pax_global_header00006660000000000000000000000064147506511620014521gustar00rootroot0000000000000052 comment=824e46bfd9fbf240389e626981caf39a51f5e980 futhark-0.25.27/000077500000000000000000000000001475065116200133425ustar00rootroot00000000000000futhark-0.25.27/.github/000077500000000000000000000000001475065116200147025ustar00rootroot00000000000000futhark-0.25.27/.github/actions/000077500000000000000000000000001475065116200163425ustar00rootroot00000000000000futhark-0.25.27/.github/actions/benchmark/000077500000000000000000000000001475065116200202745ustar00rootroot00000000000000futhark-0.25.27/.github/actions/benchmark/action.yml000066400000000000000000000026021475065116200222740ustar00rootroot00000000000000name: 'Benchmark' description: 'Run benchmark suite' inputs: backend: description: 'Backend to use' required: true system: description: 'Name of system (e.g. GPU name)' required: true futhark-options: description: 'Options to pass to futhark bench' required: false default: '' slurm-options: description: 'Options to pass to srun' required: false default: '' runs: using: "composite" steps: - name: Download Benchmarks. 
uses: ./.github/actions/futhark-slurm with: script: | module load perl cd futhark-benchmarks ./get-data.sh external-data.txt slurm-options: --time=30:00 - uses: ./.github/actions/futhark-slurm with: script: | hostname module unload cuda module load cuda/11.8 futhark bench futhark-benchmarks \ --backend ${{inputs.backend}} \ --exclude no_${{inputs.system}} \ --json futhark-${{inputs.backend}}-${{inputs.system}}-$GITHUB_SHA.json \ --ignore-files /lib/ ${{inputs.futhark-options}} slurm-options: --time=0-02:00:00 ${{inputs.slurm-options}} - uses: actions/upload-artifact@v4 with: name: futhark-${{inputs.backend}}-${{inputs.system}}-${{ github.sha }}.json path: futhark-${{inputs.backend}}-${{inputs.system}}-${{ github.sha }}.json futhark-0.25.27/.github/actions/futhark-slurm/000077500000000000000000000000001475065116200211465ustar00rootroot00000000000000futhark-0.25.27/.github/actions/futhark-slurm/action.yml000066400000000000000000000022331475065116200231460ustar00rootroot00000000000000name: 'Futhark Script' description: 'Run script on slurm if available where Futhark will be available.' 
inputs: script: description: 'Script to run' required: true slurm-options: description: 'Options to pass to srun' required: false default: '' runs: using: "composite" steps: - uses: actions/download-artifact@v4 with: name: futhark-nightly-linux-x86_64.tar.xz - name: Install from nightly tarball shell: bash run: | tar xvf futhark-nightly-linux-x86_64.tar.xz make -C futhark-nightly-linux-x86_64/ install PREFIX=$HOME/.local echo "$HOME/.local/bin" >> $GITHUB_PATH - uses: ./.github/actions/is-slurm id: slurm - if: steps.slurm.outputs.is-slurm == 'false' shell: bash run: | ${{inputs.script}} - if: steps.slurm.outputs.is-slurm == 'true' shell: bash run: | printf '#!/bin/bash ${{inputs.script}}' > temp.sh chmod +x temp.sh [ -d /scratch ] && export TMPDIR=/scratch srun ${{inputs.slurm-options}} --exclude=hendrixgpu23fl,hendrixgpu24fl,hendrixgpu25fl,hendrixgpu26fl temp.sh rm temp.sh futhark-0.25.27/.github/actions/is-slurm/000077500000000000000000000000001475065116200201155ustar00rootroot00000000000000futhark-0.25.27/.github/actions/is-slurm/action.yml000066400000000000000000000010761475065116200221210ustar00rootroot00000000000000name: 'Is Slurm Installed' description: 'Checks if slurm is installed by checking if srun can be used.' outputs: is-slurm: description: "If slurm is used." value: ${{ steps.slurm.outputs.is-slurm }} runs: using: "composite" steps: - name: Check if slurm can be run id: slurm shell: bash run: | printf '#!/bin/bash if ! 
srun --version &> /dev/null || [ $(hostname) = hendrixfut02fl.unicph.domain ]; then echo "is-slurm=false" else echo "is-slurm=true" fi' | bash >> $GITHUB_OUTPUT futhark-0.25.27/.github/workflows/000077500000000000000000000000001475065116200167375ustar00rootroot00000000000000futhark-0.25.27/.github/workflows/benchmark.yml000066400000000000000000000113371475065116200214210ustar00rootroot00000000000000name: Benchmark on: pull_request: types: [ labeled, synchronize ] push: branches: [ master ] workflow_dispatch: jobs: build-linux-nix: if: github.repository == 'diku-dk/futhark' && (!github.event.pull_request.draft && contains(github.event.pull_request.labels.*.name, 'run-benchmarks') || github.ref == 'refs/heads/master') runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v4 - name: Install Nix uses: cachix/install-nix-action@v27 - uses: cachix/cachix-action@v15 with: name: futhark signingKey: '${{ secrets.CACHIX_SIGNING_KEY }}' - name: Build Futhark run: nix-build --argstr suffix nightly-linux-x86_64 --argstr commit $GITHUB_SHA - uses: actions/upload-artifact@v4 with: name: futhark-nightly-linux-x86_64.tar.xz path: result/futhark-nightly-linux-x86_64.tar.xz benchmark-MI100-opencl: runs-on: MI100 needs: [build-linux-nix] env: TMPDIR: "/scratch" steps: - uses: actions/checkout@v4 with: submodules: recursive fetch-depth: 0 - uses: ./.github/actions/benchmark with: backend: opencl system: MI100 benchmark-MI100-hip: runs-on: MI100 needs: [build-linux-nix] env: TMPDIR: "/scratch" steps: - uses: actions/checkout@v4 with: submodules: recursive fetch-depth: 0 - uses: ./.github/actions/benchmark with: backend: hip system: MI100 benchmark-A100: runs-on: hendrix needs: [build-linux-nix] steps: - uses: actions/checkout@v4 with: submodules: recursive fetch-depth: 0 - uses: ./.github/actions/benchmark with: backend: opencl system: A100 slurm-options: -p gpu --mem=48G --gres=gpu:a100:1 --job-name=fut-opencl-A100 --exclude=hendrixgpu14fl - uses: ./.github/actions/benchmark with: 
backend: cuda system: A100 slurm-options: -p gpu --mem=48G --gres=gpu:a100:1 --job-name=fut-cuda-A100 --exclude=hendrixgpu14fl # benchmark-titanx-cuda: # runs-on: hendrix # needs: [build-linux-nix] # steps: # - uses: actions/checkout@v4 # with: # submodules: recursive # fetch-depth: 0 # - uses: ./.github/actions/benchmark # with: # backend: cuda # system: titanx # slurm-options: -p gpu --gres=gpu:titanx:1 --job-name=fut-cuda-titanx --exclude=hendrixgpu05fl,hendrixgpu06fl # futhark-options: --exclude=mem_16gb # benchmark-titanx-opencl: # runs-on: hendrix # needs: [build-linux-nix] # steps: # - uses: actions/checkout@v4 # with: # submodules: recursive # fetch-depth: 0 # - uses: ./.github/actions/benchmark # with: # backend: opencl # system: titanx # slurm-options: -p gpu --gres=gpu:titanx:1 --job-name=fut-opencl-titanx --exclude=hendrixgpu05fl,hendrixgpu06fl # futhark-options: --exclude=mem_16gb benchmark-titanrtx-cuda: runs-on: hendrix needs: [build-linux-nix] steps: - uses: actions/checkout@v4 with: submodules: recursive fetch-depth: 0 - uses: ./.github/actions/benchmark with: backend: cuda system: titanrtx slurm-options: -p gpu --mem=48G --gres=gpu:titanrtx:1 --job-name=fut-cuda-titanrtx --exclude=hendrixgpu05fl,hendrixgpu06fl benchmark-titanrtx-opencl: runs-on: hendrix needs: [build-linux-nix] steps: - uses: actions/checkout@v4 with: submodules: recursive fetch-depth: 0 - uses: ./.github/actions/benchmark with: backend: opencl system: titanrtx slurm-options: -p gpu --mem=48G --gres=gpu:titanrtx:1 --job-name=fut-opencl-titanrtx --exclude=hendrixgpu05fl,hendrixgpu06fl benchmark-results: runs-on: ubuntu-22.04 needs: - benchmark-A100 # - benchmark-MI100-opencl # - benchmark-MI100-hip # - benchmark-titanx-cuda # - benchmark-titanx-opencl - benchmark-titanrtx-cuda - benchmark-titanrtx-opencl if: github.ref == 'refs/heads/master' steps: - name: Install SSH key uses: shimataro/ssh-key-action@v2 with: key: ${{ secrets.SSHKEY }} known_hosts: ${{ secrets.KNOWN_HOSTS }} - 
uses: actions/download-artifact@v4 - run: | ls -R - name: Prepare package run: | mkdir -p package cp */futhark-*.json package/ gzip package/*.json for x in package/*.json.gz; do cp $x $(echo $x | sed "s/$GITHUB_SHA/latest/"); done - name: scp to server run: | scp -o StrictHostKeyChecking=no package/* futhark@futhark-lang.org:/var/www/htdocs/futhark-lang.org/benchmark-results futhark-0.25.27/.github/workflows/main.yml000066400000000000000000000472201475065116200204130ustar00rootroot00000000000000name: CI on: pull_request: push: branches: [ master ] jobs: build-linux-nix: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v4 - name: Install Nix uses: cachix/install-nix-action@v27 - uses: cachix/cachix-action@v15 with: name: futhark signingKey: '${{ secrets.CACHIX_SIGNING_KEY }}' - name: Build Futhark run: nix-build --argstr suffix nightly-linux-x86_64 --argstr commit $GITHUB_SHA - uses: actions/upload-artifact@v4 with: name: futhark-nightly-linux-x86_64.tar.xz path: result/futhark-nightly-linux-x86_64.tar.xz build-linux-cabal: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v4 - name: Workaround runner image issue # https://github.com/actions/runner-images/issues/7061 run: sudo chown -R $USER /usr/local/.ghcup - uses: haskell-actions/setup@v2 with: ghc-version: '9.8.1' - uses: actions/cache@v4 name: Cache ~/.cabal/packages, ~/.cabal/store and dist-newstyle with: path: | ~/.cabal/packages ~/.cabal/store dist-newstyle key: ${{ runner.os }}-cabal-${{ hashFiles('futhark.cabal', 'cabal.project') }} - name: cabal check run: cabal check - name: Build Futhark run: | cabal --version ghc --version make configure make build make install build-mac-cabal: runs-on: macos-latest steps: - uses: actions/checkout@v4 - name: Install dependencies run: | brew install cabal-install ghc sphinx-doc echo "/opt/homebrew/opt/sphinx-doc/bin" >> $GITHUB_PATH echo "/opt/homebrew/opt/ghc/bin" >> $GITHUB_PATH - uses: actions/cache@v4 name: Cache ~/.cabal/packages, ~/.cabal/store and 
dist-newstyle with: path: | ~/.cabal/packages ~/.cabal/store dist-newstyle key: ${{ runner.os }}-cabal-${{ hashFiles('futhark.cabal', 'cabal.project') }} - name: Build run: | cp -r tools/release/skeleton futhark-nightly-macos-x86_64 mkdir -p futhark-nightly-macos-x86_64/bin cabal v2-update cabal install --install-method=copy --overwrite-policy=always --installdir=futhark-nightly-macos-x86_64/bin mkdir -p futhark-nightly-macos-x86_64/share/man/man1/ (cd docs; make man) cp -r docs/_build/man/* futhark-nightly-macos-x86_64/share/man/man1/ mkdir -p futhark-nightly-macos-x86_64/share/futhark cp LICENSE futhark-nightly-macos-x86_64/share/futhark/ echo "${GITHUB_SHA}" > futhark-nightly-macos-x86_64/commit-id tar -Jcf futhark-nightly-macos-x86_64.tar.xz futhark-nightly-macos-x86_64 - uses: actions/upload-artifact@v4 with: name: futhark-nightly-macos-x86_64.tar.xz path: futhark-nightly-macos-x86_64.tar.xz build-windows-cabal: runs-on: windows-latest steps: - uses: actions/checkout@v4 - id: setup-haskell uses: haskell-actions/setup@v2 with: ghc-version: '9.8.1' - uses: actions/cache@v4 name: Cache cabal stuff with: path: | ${{ steps.setup-haskell.outputs.cabal-store }} dist-newstyle key: ${{ runner.os }}-cabal-${{ hashFiles('futhark.cabal', 'cabal.project') }} - name: Build shell: bash run: | cabal update try() { cabal install --install-method=copy --overwrite-policy=always --installdir=.; } try || try || try - uses: vimtor/action-zip@v1.2 with: files: futhark.exe dest: futhark-nightly-windows-x86_64.zip - uses: actions/upload-artifact@v4 with: name: futhark-nightly-windows-x86_64.zip path: futhark-nightly-windows-x86_64.zip haskell-test-style: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v4 - name: Install Nix uses: cachix/install-nix-action@v27 - uses: actions/cache@v4 name: Cache ~/.cabal/packages, ~/.cabal/store and dist-newstyle with: path: | ~/.cabal/packages ~/.cabal/store dist-newstyle key: ${{ runner.os }}-${{ hashFiles('nix/sources.json') }}-style - name: 
Style check run: nix-shell --pure --run "make check" python-test-style: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v4 - name: Install Nix uses: cachix/install-nix-action@v27 - name: Style check run: nix-shell --pure --run "black --check ." - name: Type check run: nix-shell --pure --run "mypy ." # Fails mysteriously right now. # # build-docs: # runs-on: ubuntu-22.04 # steps: # - uses: actions/checkout@v4 # - name: Install Nix # uses: cachix/install-nix-action@v18 # - uses: actions/cache@v4 # name: Cache ~/.cabal/packages, ~/.cabal/store and dist-newstyle # with: # path: | # ~/.cabal/packages # ~/.cabal/store # dist-newstyle # key: ${{ runner.os }}-${{ hashFiles('nix/sources.json', 'futhark.cabal', 'cabal.project') }}-haddock # - name: Run haddock # run: | # nix-shell --pure --run "make configure" # nix-shell --pure --run "make docs" test-interpreter: runs-on: ubuntu-22.04 needs: [build-linux-nix] steps: - uses: actions/checkout@v4 - uses: actions/download-artifact@v4 with: name: futhark-nightly-linux-x86_64.tar.xz - name: Install from nightly tarball run: | tar xvf futhark-nightly-linux-x86_64.tar.xz make -C futhark-nightly-linux-x86_64/ install PREFIX=$HOME/.local echo "$HOME/.local/bin" >> $GITHUB_PATH - run: | futhark test -i tests test-structure: runs-on: ubuntu-22.04 needs: [build-linux-nix] steps: - uses: actions/checkout@v4 - uses: actions/download-artifact@v4 with: name: futhark-nightly-linux-x86_64.tar.xz - name: Install from nightly tarball run: | tar xvf futhark-nightly-linux-x86_64.tar.xz make -C futhark-nightly-linux-x86_64/ install PREFIX=$HOME/.local echo "$HOME/.local/bin" >> $GITHUB_PATH - run: | futhark test -s tests test-c: runs-on: ubuntu-22.04 needs: [build-linux-nix] steps: - uses: actions/checkout@v4 - name: Install dependencies run: | sudo apt-get update sudo apt-get install -y python3-jsonschema - uses: actions/download-artifact@v4 with: name: futhark-nightly-linux-x86_64.tar.xz - name: Install from nightly tarball run: | tar xvf 
futhark-nightly-linux-x86_64.tar.xz make -C futhark-nightly-linux-x86_64/ install PREFIX=$HOME/.local echo "$HOME/.local/bin" >> $GITHUB_PATH - run: | export CFLAGS="-fsanitize=undefined -fsanitize=address -fno-sanitize-recover -O" futhark test -c --backend=c tests --no-tuning make -C tests_lib/c -j test-multicore: runs-on: ubuntu-22.04 needs: [build-linux-nix] steps: - uses: actions/checkout@v4 - name: Install dependencies run: | sudo apt-get update sudo apt-get install -y python3-jsonschema - uses: actions/download-artifact@v4 with: name: futhark-nightly-linux-x86_64.tar.xz - name: Install from nightly tarball run: | tar xvf futhark-nightly-linux-x86_64.tar.xz make -C futhark-nightly-linux-x86_64/ install PREFIX=$HOME/.local echo "$HOME/.local/bin" >> $GITHUB_PATH - run: | export CFLAGS="-fsanitize=undefined -fno-sanitize-recover -O" futhark test -c --backend=multicore tests --no-tuning make -C tests_lib/c FUTHARK_BACKEND=multicore test-ispc: runs-on: ubuntu-22.04 needs: [build-linux-nix] steps: - uses: actions/checkout@v4 - name: Install dependencies run: | sudo apt-get update sudo apt-get install -y python3-jsonschema - name: Install Nix uses: cachix/install-nix-action@v27 - uses: actions/download-artifact@v4 with: name: futhark-nightly-linux-x86_64.tar.xz - name: Install from nightly tarball run: | tar xvf futhark-nightly-linux-x86_64.tar.xz make -C futhark-nightly-linux-x86_64/ install PREFIX=$HOME/.local echo "$HOME/.local/bin" >> $GITHUB_PATH - run: | export CFLAGS="-fsanitize=undefined -fno-sanitize-recover -O" nix-shell --run 'futhark test -c --backend=ispc tests --no-tuning' nix-shell --run 'make -C tests_lib/c FUTHARK_BACKEND=ispc' test-python: runs-on: ubuntu-22.04 needs: [build-linux-nix] steps: - uses: actions/checkout@v4 - name: Install dependencies run: | sudo apt-get update sudo apt-get install -y python3-numpy - uses: actions/download-artifact@v4 with: name: futhark-nightly-linux-x86_64.tar.xz - name: Install from nightly tarball run: | tar xvf 
futhark-nightly-linux-x86_64.tar.xz make -C futhark-nightly-linux-x86_64/ install PREFIX=$HOME/.local echo "$HOME/.local/bin" >> $GITHUB_PATH - run: | futhark test -c --no-terminal --no-tuning --backend=python --exclude=no_python --exclude=compiled tests make -C tests_lib/python -j test-oclgrind: runs-on: ubuntu-22.04 needs: [build-linux-nix] steps: - uses: actions/checkout@v4 - name: Install dependencies run: | sudo apt-get update sudo apt-get install -y opencl-headers oclgrind nvidia-opencl-dev - uses: actions/download-artifact@v4 with: name: futhark-nightly-linux-x86_64.tar.xz - name: Install from nightly tarball run: | tar xvf futhark-nightly-linux-x86_64.tar.xz make -C futhark-nightly-linux-x86_64/ install PREFIX=$HOME/.local echo "$HOME/.local/bin" >> $GITHUB_PATH - run: | futhark test tests -c --no-terminal --backend=opencl --exclude=compiled --exclude=no_oclgrind --cache-extension=cache --pass-option=--build-option=-O0 --runner=tools/oclgrindrunner.sh test-pyoclgrind: runs-on: ubuntu-22.04 needs: [build-linux-nix] steps: - uses: actions/checkout@v4 - name: Install dependencies run: | sudo apt-get update sudo apt-get install -y opencl-headers oclgrind - uses: actions/download-artifact@v4 with: name: futhark-nightly-linux-x86_64.tar.xz - name: Install from nightly tarball run: | tar xvf futhark-nightly-linux-x86_64.tar.xz make -C futhark-nightly-linux-x86_64/ install PREFIX=$HOME/.local echo "$HOME/.local/bin" >> $GITHUB_PATH - run: | set -e python -m venv virtualenv source virtualenv/bin/activate pip install 'numpy<2.0.0' pyopencl jsonschema futhark test tests -c --no-terminal --backend=pyopencl --exclude=compiled --exclude=no_oclgrind --cache-extension=cache --pass-option=--build-option=-O0 --runner=tools/oclgrindrunner.sh test-opencl: runs-on: hendrix needs: [build-linux-nix] if: github.repository == 'diku-dk/futhark' steps: - uses: actions/checkout@v4 - uses: ./.github/actions/futhark-slurm with: script: | set -e python -m venv virtualenv source 
virtualenv/bin/activate pip install jsonschema module unload cuda module load cuda/11.8 futhark test tests \ --backend=opencl \ --cache-extension=cache FUTHARK_BACKEND=opencl make -C tests_lib/c -j slurm-options: -p gpu --time=0-01:00:00 --gres=gpu:1 --job-name=fut-opencl-test --exclude=hendrixgpu05fl,hendrixgpu06fl test-pyopencl: runs-on: hendrix needs: [build-linux-nix] if: github.repository == 'diku-dk/futhark' steps: - uses: actions/checkout@v4 - uses: ./.github/actions/futhark-slurm with: script: | set -e python -m venv virtualenv source virtualenv/bin/activate pip install numpy pyopencl jsonschema module unload cuda module load cuda/11.8 futhark test tests --no-terminal --backend=pyopencl slurm-options: -p gpu --time=0-02:00:00 --gres=gpu:1 --job-name=fut-pyopencl-test --exclude=hendrixgpu05fl,hendrixgpu06fl test-cuda: runs-on: cuda needs: [build-linux-nix] if: github.repository == 'diku-dk/futhark' steps: - uses: actions/checkout@v4 - uses: ./.github/actions/futhark-slurm with: script: | set -e python -m venv virtualenv source virtualenv/bin/activate pip install jsonschema module unload cuda module load cuda/11.8 nvidia-smi --query-gpu=gpu_name --format=csv,noheader futhark test tests \ --backend=cuda \ --cache-extension=cache \ --concurrency=8 FUTHARK_BACKEND=cuda make -C tests_lib/c -j slurm-options: -p gpu --time=0-01:00:00 --gres=gpu:a100:1 --job-name=fut-cuda-test test-hip: runs-on: MI100 needs: [build-linux-nix] if: github.repository == 'diku-dk/futhark' steps: - uses: actions/checkout@v4 - run: | python -m venv virtualenv source virtualenv/bin/activate pip install jsonschema - uses: ./.github/actions/futhark-slurm with: script: | set -e futhark test tests --no-terminal --backend=hip --concurrency=8 source virtualenv/bin/activate FUTHARK_BACKEND=hip make -C tests_lib/c -j test-wasm: runs-on: ubuntu-22.04 needs: [build-linux-nix] steps: - uses: actions/checkout@v4 - uses: mymindstorm/setup-emsdk@v12 with: version: 2.0.18 actions-cache-folder: 
'emsdk-cache' - uses: actions/setup-node@v3.5.1 with: node-version: '16.x' - uses: actions/download-artifact@v4 with: name: futhark-nightly-linux-x86_64.tar.xz - name: Install from nightly tarball run: | tar xvf futhark-nightly-linux-x86_64.tar.xz make -C futhark-nightly-linux-x86_64/ install PREFIX=$HOME/.local echo "$HOME/.local/bin" >> $GITHUB_PATH - name: Run tests run: | node --version export EMCFLAGS="-sINITIAL_MEMORY=2147418112 -O1" # 2gb - 64kb... largest value of memory futhark test \ -c \ --backend=wasm \ --runner=./tools/node-simd.sh \ --no-tuning \ --exclude=no_wasm tests test-wasm-multicore: runs-on: ubuntu-22.04 needs: [build-linux-nix] steps: - uses: actions/checkout@v4 - uses: mymindstorm/setup-emsdk@v12 with: version: 2.0.18 actions-cache-folder: 'emsdk-mc-cache' - uses: actions/setup-node@v3.5.1 with: node-version: '16.x' - uses: actions/download-artifact@v4 with: name: futhark-nightly-linux-x86_64.tar.xz - name: Install from nightly tarball run: | tar xvf futhark-nightly-linux-x86_64.tar.xz make -C futhark-nightly-linux-x86_64/ install PREFIX=$HOME/.local echo "$HOME/.local/bin" >> $GITHUB_PATH - name: Run tests run: | node --version export EMCFLAGS="-sINITIAL_MEMORY=2147418112 -O1 -s PTHREAD_POOL_SIZE=12" # 2gb - 64kb... 
largest value of memory futhark test -c --backend=wasm-multicore --runner=./tools/node-threaded.sh --no-tuning --exclude=no_wasm tests test-wasm-lib: runs-on: ubuntu-22.04 needs: [build-linux-nix] steps: - uses: actions/checkout@v4 - uses: mymindstorm/setup-emsdk@v12 with: version: 2.0.18 actions-cache-folder: 'emsdk-cache' - uses: actions/setup-node@v3.5.1 with: node-version: '16.x' - uses: actions/download-artifact@v4 with: name: futhark-nightly-linux-x86_64.tar.xz - name: Install from nightly tarball run: | tar xvf futhark-nightly-linux-x86_64.tar.xz make -C futhark-nightly-linux-x86_64/ install PREFIX=$HOME/.local echo "$HOME/.local/bin" >> $GITHUB_PATH - name: Run tests run: | make -C tests_lib/javascript test-ir-parser: runs-on: ubuntu-22.04 needs: [build-linux-nix] steps: - uses: actions/checkout@v4 - uses: actions/download-artifact@v4 with: name: futhark-nightly-linux-x86_64.tar.xz - name: Install from nightly tarball run: | tar xvf futhark-nightly-linux-x86_64.tar.xz make -C futhark-nightly-linux-x86_64/ install PREFIX=$HOME/.local echo "$HOME/.local/bin" >> $GITHUB_PATH - run: | tools/testparser.sh tests test-formatter: runs-on: ubuntu-22.04 needs: [build-linux-nix] steps: - uses: actions/checkout@v4 - uses: actions/download-artifact@v4 with: name: futhark-nightly-linux-x86_64.tar.xz - name: Install from nightly tarball run: | tar xvf futhark-nightly-linux-x86_64.tar.xz make -C futhark-nightly-linux-x86_64/ install PREFIX=$HOME/.local echo "$HOME/.local/bin" >> $GITHUB_PATH - run: | tools/testfmt.sh tests test-tools: runs-on: ubuntu-22.04 needs: [build-linux-nix] steps: - name: Install OS dependencies run: | sudo apt-get update sudo apt-get install -y ffmpeg oclgrind - uses: actions/checkout@v4 - uses: actions/download-artifact@v4 with: name: futhark-nightly-linux-x86_64.tar.xz - name: Install from nightly tarball run: | tar xvf futhark-nightly-linux-x86_64.tar.xz make -C futhark-nightly-linux-x86_64/ install PREFIX=$HOME/.local echo "$HOME/.local/bin" >> 
$GITHUB_PATH - run: | cd tests_pkg && sh test.sh - run: | cd tests_literate && sh test.sh - run: | cd tests_repl && sh test.sh - run: | cd tests_bench && sh test.sh - run: | cd tests_adhoc && sh test.sh - run: | cd tests_fmt && sh test.sh - run: | futhark doc -o prelude-docs /dev/null tar -Jcf prelude-docs.tar.xz prelude-docs - uses: actions/upload-artifact@v4 with: name: prelude-docs.tar.xz path: prelude-docs.tar.xz deploy-nightly: runs-on: ubuntu-22.04 needs: [build-linux-nix, build-mac-cabal, build-windows-cabal, test-tools] if: github.ref == 'refs/heads/master' steps: - name: Install SSH key uses: shimataro/ssh-key-action@v2 with: key: ${{ secrets.SSHKEY }} known_hosts: ${{ secrets.KNOWN_HOSTS }} - uses: actions/download-artifact@v4 with: name: futhark-nightly-windows-x86_64.zip - uses: actions/download-artifact@v4 with: name: futhark-nightly-macos-x86_64.tar.xz - uses: actions/download-artifact@v4 with: name: futhark-nightly-linux-x86_64.tar.xz - uses: actions/download-artifact@v4 with: name: prelude-docs.tar.xz - name: scp tarballs to server run: | scp -o StrictHostKeyChecking=no futhark-nightly-*-x86_64.{tar.xz,zip} futhark@futhark-lang.org:/var/www/htdocs/futhark-lang.org/releases - name: copy docs to server run: | tar -xf prelude-docs.tar.xz rsync -rv -e 'ssh -o "StrictHostKeyChecking no"' prelude-docs/* futhark@futhark-lang.org:/var/www/htdocs/futhark-lang.org/docs/prelude/ - name: make nightly release uses: "mathieucarbou/marvinpinto-action-automatic-releases@latest" with: repo_token: "${{ secrets.GITHUB_TOKEN }}" automatic_release_tag: "nightly" prerelease: true title: "nightly" files: | futhark-nightly-linux-x86_64.tar.xz futhark-nightly-macos-x86_64.tar.xz futhark-nightly-windows-x86_64.zip futhark-0.25.27/.github/workflows/release.yml000066400000000000000000000053261475065116200211100ustar00rootroot00000000000000name: Release on: push: tags: - 'v*' # Push events to matching v*, i.e. 
v1.0, v20.15.10 jobs: build: name: Create and upload release tarballs if: github.event.base_ref == 'refs/heads/master' runs-on: ubuntu-latest steps: - name: Checkout code uses: actions/checkout@v4 - name: Install Nix uses: cachix/install-nix-action@v20 - name: Prepare metadata id: metadata run: | echo "VERSION=${GITHUB_REF#refs/tags/v}" >> $GITHUB_OUTPUT echo "TARBALL=futhark-${GITHUB_REF#refs/tags/v}-linux-x86_64.tar.xz" >> $GITHUB_OUTPUT - name: Build tarball env: VERSION: ${{ steps.metadata.outputs.VERSION }} run: nix-build --argstr suffix $VERSION-linux-x86_64 --argstr commit $GITHUB_SHA - name: Extract release changes env: VERSION: ${{ steps.metadata.outputs.VERSION }} run: sh tools/changelog.sh $VERSION < CHANGELOG.md > release_changes.md - name: Create release id: create_release uses: actions/create-release@v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: tag_name: ${{ github.ref }} release_name: ${{ steps.metadata.outputs.VERSION }} body_path: release_changes.md draft: false prerelease: false - name: Upload tarball id: upload-release-asset uses: actions/upload-release-asset@v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: upload_url: ${{ steps.create_release.outputs.upload_url }} asset_path: result/${{ steps.metadata.outputs.TARBALL }} asset_name: ${{ steps.metadata.outputs.TARBALL }} asset_content_type: application/x-xz - name: Install SSH key uses: shimataro/ssh-key-action@v2 with: key: ${{ secrets.SSHKEY }} known_hosts: ${{ secrets.KNOWN_HOSTS }} - name: scp to server env: TARBALL: ${{ steps.metadata.outputs.TARBALL }} run: scp -o StrictHostKeyChecking=no result/$TARBALL futhark@futhark-lang.org:/var/www/htdocs/futhark-lang.org/releases upload-to-hackage: name: Upload to Hackage if: github.event.base_ref == 'refs/heads/master' runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Install Nix uses: cachix/install-nix-action@v20 - uses: cachix/cachix-action@v12 with: name: futhark signingKey: '${{ secrets.CACHIX_SIGNING_KEY }}' - 
name: Upload to Hackage run: | export HACKAGE_KEY export VERSION nix-shell --run 'cabal update' nix-shell --run tools/release/hackage.sh env: HACKAGE_KEY: ${{ secrets.HACKAGE_KEY }} VERSION: ${{ steps.metadata.outputs.VERSION }} futhark-0.25.27/.gitignore000066400000000000000000000012261475065116200153330ustar00rootroot00000000000000# Ignore non-directories in root, so you can use it as scratch space. /* !/*/ # The previous lines explicitly unignores .git /.git # Ignore Emacs autosave files. \#*\# .#* *~ # Ignore IntelliJ configuration files .idea/ # Ignore cabal build directory. dist dist-newstyle # Ignore IDE info .hie .vscode # Ignore side-effects of compilation. *.aux *.log *.bbl *.blg *.out *.prof *.o *.hi *.snm *.toc *.vrb *.nav # Ignore core dumps. core core.* # Ignore flymake-files created by Emacs. *flymake* # Ignore temporary backup directories. *-backup # Ignore directories often created by testing data/ lib/ testparser/ # Ignore macOS system files. .DS_Store futhark-0.25.27/.gitmodules000066400000000000000000000001651475065116200155210ustar00rootroot00000000000000[submodule "futhark-benchmarks"] path = futhark-benchmarks url = https://github.com/diku-dk/futhark-benchmarks.git futhark-0.25.27/.hlint.yaml000066400000000000000000000016101475065116200154200ustar00rootroot00000000000000- modules: - {name: [Data.Set, Data.HashSet], as: S} - {name: [Data.Map, Data.Map.*], as: M} - {name: [Data.List.NonEmpty], as: NE} - {name: [Data.List], as: L} - error: {lhs: map subExpRes, rhs: subExpsRes} - error: {lhs: map varRes, rhs: varsRes} - error: {lhs: return, rhs: pure} - error: {lhs: nameIn x y, rhs: x `nameIn` y, side: not (isInfixApp original) && not (isParen result), name: Use infix} - error: {lhs: not (x `nameIn` y), rhs: x `notNameIn` y} - error: {lhs: not (any (`nameIn` x) y), rhs: all (`notNameIn` x) y} - error: {lhs: Data.Text.pack (show x), rhs: Futhark.Util.showText x} - error: {lhs: Data.List.nub, rhs: Futhark.Util.nubOrd} - error: {lhs: "copyDWIM p [] 
y [DimFix i]", rhs: "copyDWIMFix p [] y [i]" } - error: {lhs: "copyDWIM x [DimFix i] y []", rhs: "copyDWIMFix x [i] y []" } - error: {lhs: "copyDWIM x [DimFix i] y [DimFix j]", rhs: "copyDWIMFix x [i] y [j]" } futhark-0.25.27/.mailmap000066400000000000000000000006541475065116200147700ustar00rootroot00000000000000Cosmin Oancea Henrik Urms Martin Elsman Maxwell Orok Mikkel Storgaard Knudsen Oleksandr Shturmov Philip Munksgaard Robert Schenck futhark-0.25.27/.readthedocs.yaml000066400000000000000000000011601475065116200165670ustar00rootroot00000000000000# Read the Docs configuration file for Sphinx projects # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details version: 2 build: os: ubuntu-22.04 tools: python: "3.11" # Build documentation in the "docs/" directory with Sphinx sphinx: configuration: docs/conf.py # Optionally build your docs in additional formats such as PDF and ePub formats: - pdf # Optional but recommended, declare the Python requirements required # to build your documentation # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html python: install: - requirements: docs/requirements.txt futhark-0.25.27/CHANGELOG.md000066400000000000000000003035601475065116200151620ustar00rootroot00000000000000# Changelog All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). ## [0.25.27] ### Added * Improved reverse-mode AD of `scan` with complicated operators. Work by Peter Adema and Sophus Valentin Willumsgaard. ### Fixed * `futhark eval`: any errors in the provided .fut file would cause a "file not found" error message. * Handling of module-dependent size expressions in type abbreviations (#2209). * A `let`-bound size would mistakenly be in scope of the bound expression (#2210). * An overzealous floating-point simplification rule. 
* Corrected AD of `x**y` where `x==0` (#2216). * `futhark fmt`: correct file name in parse errors. * A bug in the "sink" optimisation pass could cause compiler crashes. * Compile errors with newer versions of `ispc`. ## [0.25.26] ### Fixed * Some Windows compatibility quirks (#2200, #2201). * `futhark pkg`: fixed parsing of Git timestamps in Z time zone. * GPU backends did not handle array constants correctly in some cases. * `futhark fmt`: do not throw away doc comments for `local` definitions. * `futhark fmt`: improve formatting of value specs. * `futhark fmt`: add `--check` option. ## [0.25.25] ### Added * Improvements to `futhark fmt`. ### Fixed * Sizes that go out of scope due to use of higher order functions will now work in more cases by adding existentials. (#2193) * Tracing inside AD operators with the interpreter now prints values properly. * Compiled and interpreted code now have same treatment of inclusive ranges with start==end and negative step size, e.g. `1..0...1` produces `[1]` rather than an invalid range error. * Inconsistent handling of types in lambda lifting (#2197). * Invalid primal results from `vjp2` in interpreter (#2199). ## [0.25.24] ### Added * `futhark doc` now produces better (and stable) anchor IDs. * `futhark profile` now supports multiple JSON files. * `futhark fmt`, by William Due and Therese Lyngby. * Lambdas can now be passed as the last argument to a function application. ### Fixed * Negation of floating-point positive zero now produces a negative zero. * Necessary inlining of functions used inside AD constructs. * A compile time regression for programs that used higher order functions very aggressively. * Uniqueness bug related to slice simplification. ## [0.25.23] ### Added * Trailing commas are now allowed for arrays, records, and tuples in the textual value format and in FutharkScript. * Faster floating-point atomics with OpenCL backend on AMD and NVIDIA GPUs. This affects histogram workloads. 
* AD is now supported by the interpreter (thanks to Marcus Jensen). ### Fixed * Some instances of invalid copy removal. (Again.) * An issue related to entry points with nontrivial sizes in their arguments, where the entry points were also used as normal functions elsewhere. (#2184) ## [0.25.22] ### Added * `futhark script` now supports an `-f` option. * `futhark script` now supports the builtin procedure `$store`. ### Fixed * An error in tuning file validation. * Constant folding for loops that produce floating point results could result in different numerical behaviour. * Compiler crash in memory short circuiting (#2176). ## [0.25.21] ### Added * Logging now prints more GPU information on context initialisation. * GPU cache size can now be configured (tuning param: `default_cache`). * GPU shared memory can now be configured (tuning param: `default_shared_memory`). * GPU register capacity can now be configured. * `futhark script` now accepts a `-b` option for producing binary output. ### Fixed * Type names for element types of array indexing functions in C interface are now often better - although there are still cases where you end up with hashed names. (#2172) * In some cases, GPU failures would not be reported properly if a previous failure was pending. * `auto output` didn't work if the `.fut` file did not have any path components. * Improved detection of malformed tuning files. ## [0.25.20] ### Added * Better error message when in-place updates fail at runtime due to a shape mismatch. ### Fixed * `#[unroll]` on an outer loop now no longer causes unrolling of all loops nested inside the loop body. * Obscure issue related to replications of constants in complex intrablock kernels. * Interpreter no longer crashes on attributes in patterns. * Fixes to array indexing through C API when using GPU backends. ## [0.25.19] ### Added * The compiler now does slightly less aggressive inlining. Use the `#[inline]` attribute if you want to force inlining of some function. 
* Arrays of opaque types now support indexing through the C API. Arrays of records can also be constructed. (#2082) ### Fixed * The `opencl` backend now always passes `-cl-fp32-correctly-rounded-divide-sqrt` to the kernel compiler, in order to match CUDA and HIP behaviour. ## [0.25.18] ### Added * New prelude function: `rep`, an implicit form of `replicate`. * Improved handling of large monomorphic single-dimensional array literals (#2160). ### Fixed * `futhark repl` no longer asks for confirmation on EOF. * Obscure oversight related to abstract size-lifted types (#2120). * Accidental exponential-time algorithm in layout optimisation for multicore backends (#2151). ## [0.25.17] ### Added * Faster device-to-device copies on CUDA. * "More correctly" detect L2 cache size for OpenCL backend on AMD GPUs. ### Fixed * Handling of `..` in `import` paths (again). * Detection of impossible loop parameter sizes (#2144). * Rare case where GPU histograms would use slightly too much shared memory and fail at run-time. * Rare crash in layout optimisation. ## [0.25.16] ### Added * ``futhark test``: `--no-terminal` now prints status messages even when no failures occur. * ``futhark test`` no longer runs ``structure`` tests by default. Pass ``-s`` to run them. * Rewritten array layout optimisation pass by Bjarke Pedersen and Oscar Nelin. Minor speedup for some programs, but is more importantly a principled foundation for further improvements. * Better error message when exceeding shared memory limits. * Better dead code removal for the GPU representation (minor impact on some programs). ### Fixed * Bugs related to deduplication of array payloads in sum types. Unfortunately, fixed by just not deduplicating in those cases. * Frontend bug related to turning size expressions into variables (#2136). * Another exotic monomorphisation bug. ## [0.25.15] ### Added * Incremental Flattening generates fewer redundant code versions. * Better simplification of slices.
(#2125) ### Fixed * Ignore type suffixes when unifying expressions (#2124). * In the C API, opaque types that correspond to an array of an opaque type are now once again named `futhark_opaque_arr_...`. * `cuda` backend did not correctly profile CPU-to-GPU scalar copies. ## [0.25.14] ### Added * The prelude definition of `filter` is now more memory efficient, particularly when the output is much smaller than the input. (#2109) * New configuration for GPU backends: `futhark_context_config_set_unified_memory`, also available on executables as ``--unified-memory``. * The "raw" API functions now do something potentially useful, but are still considered experimental. * `futhark --version` now reports GHC version. ### Fixed * Incorrect type checking of let-bound sizes occurring multiple times in pattern. (#2103). * A concatenation simplification would sometimes mess up sizes. (#2104) * Bug related to monomorphisation of polymorphic local functions (#2106). * Rare crash in short circuiting. * Referencing an unbound type parameter could crash the type checker (#2113, #2114). * Futhark now works with GHC 9.8 (#2105). ## [0.25.13] ### Added * Incremental flattening of `map`-`scan` compositions with nested parallelism (similar to the logic for `map`-`reduce` compositions that we have had for years). * `futhark script`, for running FutharkScript expressions from the command line. * `futhark repl` now prints out a message when it ignores a breakpoint during initialisation. (#2098) ### Fixed * Flattening of `scatter` with multi-dimensional elements (#2089). * Some instances of not-actually-irregular allocations were mistakenly interpreted as irregular. Fixing this was a dividend of the memory representation simplifications of 0.25.12. * Obscure issue related to expansion of shared memory allocations (#2092). * A crash in alias checking under some rare circumstances (#2096). * Mishandling of existential sizes for top level constants. 
(#2099) * Compiler crash when generating code for copying nothing at all. (#2100) ## [0.25.12] ### Added * `f16.copysign`, `f32.copysign`, `f64.copysign`. * Trailing commas are now allowed for all syntactical elements that involve comma-separation. (#2068) * The C API now allows destruction and construction of sum types (with some caveats). (#2074) * An overall reduction in memory copies, through simplifying the internal representation. ### Fixed * C API would define distinct entry point types for Futhark types that differed only in naming of sizes (#2080). * `==` and `!=` on sum types with array payloads. Constructing them is now a bit slower, though. (#2081) * Somewhat obscure simplification error caused by neglecting to update metadata when removing dead scatter outputs. * Compiler crash due to the type checker forgetting to respect the explicitly ascribed non-consuming diet of loop parameters (#2067). * Size inference did incomplete level/scope checking, which could result in circular sizes, which usually manifested as the type checker going into an infinite loop (#2073). * The OpenCL backend now more gracefully handles lack of platform. ## [0.25.11] ### Added * New prelude function: `manifest`. For doing subtle things to memory. * The GPU backends now handle up to 20 operators in a single fused reduction. * CUDA/HIP terminology for GPU concepts (e.g. "thread block") is now used in all public interfaces. The OpenCL names are still supported for backwards compatibility. * More fusion across array slicing. ### Fixed * Compatibility with CUDA versions prior to 12. ## [0.25.10] ### Added * Faster non-commutative reductions in the GPU backends. Work by Anders Holst and Christian Påbøl Jacobsen. ### Fixed * Interpreter crash for certain complicated size expressions involving internal bindings (#2053). * Incorrect type checking of `let` binding with explicit size quantification, where size appears in type of body (#2048).
* GPU code generation for non-commutative non-segmented reductions with array operands (#2051). * Histogram with non-vectorised reduction operators (#2056). (But it is probably not a good idea to write such programs.) * Futhark's LSP server should work better with Eglot. * Incorrect copy removal inside histograms could cause compiler error (#2058). * CUDA backend now correctly queries for available shared memory, which affects performance (hopefully positively). * `futhark literate` now switches to the directory containing the `.fut` file before executing its contents. This fixes accessing files through relative paths. ## [0.25.9] ### Added * The `cuda` and `hip` backends now generate faster code for `scan`s that have been fused with `map`s that internally produce arrays. Work by Anders Holst and Christian Påbøl Jacobsen. * `f16.ldexp`, `f32.ldexp`, `f64.ldexp`, corresponding to the functions in the C math library. ### Fixed * Incorrect data dependency information for `scatter` and `vjp` could cause invalid simplification. * Barrier divergence in certain complicated kernels that contain both bounds checks and intragroup scans. ## [0.25.8] ### Added * FutharkScript now has a `$loadbytes` builtin function for reading arbitrary bytes into Futhark programs. * `futhark profile` can now process reports produced by the C API function `futhark_context_report()`. * `futhark profile` now also produces a timeline of events. ### Fixed * `futhark literate` no longer fails if the final line is a directive without a trailing newline. * Parser now allows arbitrary patterns in function parameters and `let` bindings, although the type checker will reject any that are refutable (#2017). * Avoid generating invalid code in cases where deduplicated sum types are exposed through entry points (#1960). * A bug in data dependency analysis for histogram operations would mistakenly classify some loop parameters as redundant, leading to code being removed.
## [0.25.7] ### Added * `futhark autotune` now supports `hip` backend. * Better parallelisation of `scatter` when the target is multidimensional (#2035). ### Fixed * Very large `iota`s now work. * Lambda lifting in `while` conditions (#2038). * Size expressions in local function parameters had an interesting interaction with defunctionalisation (#2040). * The `store` command in server executables did not properly synchronise when storing opaque values, which would lead to use-after-free errors. ## [0.25.6] ### Added * The various C API functions that accept strings now perform a copy, meaning the caller does not have to keep the strings alive. * Slightly better lexer error messages. * Fusion across slicing is now possible in some cases. * New tool: `futhark profile`. ### Fixed * Inefficient locking for certain segmented histograms (#2024). ## [0.25.5] ### Added * `futhark repl` now has a `:format` command. Work by Dominic Kennedy. ### Fixed * Textual floating-point numbers printed by executables now always print enough digits to not hide information. Binary output is unchanged. * Invalid CSE on constants could crash the compiler (#2021). ## [0.25.4] ### Fixed * Invalid simplification (#2015). * Rarely occurring deadlock for fused map-scan compositions in CUDA backend, when a bounds check failed in the map function. * Compiler and interpreter crash for tricky interactions of abstract types and sizes (#2016). Solved by banning such uses - in principle this could break code. * Incomplete alias tracking could cause removal of necessary copies, leading to compiler crash (#2018). ## [0.25.3] ### Added * pyopencl backend: compatibility with future versions of PyOpenCL. * New backend: hip. ### Fixed * Exotic problems related to intra-group reductions with array operands. (Very rare in real code, although sometimes generated by AD.) * Interpreter issue related to sizes in modules (#1992, #1993, #2002). * Incorrect insertion of size arguments in complex cases (#1998).
* Incorrect handling of `match` in lambda lifting (#2000). * Regression in checking of consumption (#2007). * Error in type checking of horizontally fused `scatter`s could crash the compiler (#2009). * Size-polymorphic value bindings with existential sizes are now rejected by type checker (#1993). * Single pass scans with complicated fused map functions were insufficiently memory-expanded (#2023). * Invalid short circuiting (#2013). ## [0.25.2] ### Added * Flattening/unflattening as the final operation in an entry point no longer forces a copy. * The `opencl` backend no longer *always* fails on platforms that do not support 64-bit integer atomics, although it will still fail if the program needs them. * Various performance improvements to the compiler itself; particularly the frontend. It should be moderately faster. ### Fixed * Code generation for `f16` literals in CUDA backend (#1979). * Branches that return arrays differing in sign of their stride (#1984). ## [0.25.1] ### Added * Arbitrary expressions of type `i64` are now allowed as sizes. Work by Lubin Bailly. * New prelude function `resize`. ### Removed * The prelude functions `concat_to` and `flatten_to`. They are often not necessary now, and otherwise `resize` is available. ### Changed * The prelude functions `flatten` and `unflatten` (and their multidimensional variants), as well as `split`, now have more precise types. * Local and anonymous (lambda) functions that *must* return unique results (because they are passed to a higher order function that requires this) must now have an explicit return type ascription that declares this, using `*`. This is very rare (in practice unobserved) in real programs. ### Fixed * `futhark doc` produced some invalid links. * `flatten` did not properly check for claimed negative array sizes. * Type checker crash on some ill-typed programs (#1926). * Some soundness bugs in memory short circuiting (#1927, #1930). * Another compiler crash in block tiling (#1933, #1940).
* Global arrays with size parameters no longer have aliases. * `futhark eval` no longer crashes on ambiguously typed expressions (#1946). * A code motion pass was ignorant of consumption constraints, leading to compiler crash (#1947). * Type checker could get confused and think unknown sizes were available when they really weren't (#1950). * Some index optimisations removed certificates (#1952). * GPU backends can now transpose arrays whose size does not fit in a 32-bit integer (#1953). * Bug in alias checking for the core language type checker (#1949). Actually (finally) a proper fix of #803. * Defunctionalisation duplicates less code (#1968). ## [0.24.3] ### Fixed * Certain cases of noninlined functions in `multicore` backend. * Defunctionalisation of `match` where the constructors carry functions (#1917). * Shape coercions involving sum types (#1918). This required tightening the rules a little bit, so some coercions involving anonymous sizes may now be rejected. Add expected sizes as needed. * Defunctionalisation sometimes forgot about sizes bound at top level (#1920). ## [0.24.2] ### Added * `futhark literate` (and FutharkScript in general) is now able to do a bit of type-coercion of constants. ### Fixed * Accumulators (produced by AD) had defective code generation for intra-group GPU kernel versions. (#1895) * The header file generated by `--library` contained a prototype for an undefined function. (#1896) * Crashing bug in LSP caused by stray `undefined` (#1907). * Missing check for anonymous sizes in type abbreviations (#1903). * Defunctionalisation crashed on projection of holes. * Loop optimisation would sometimes remove holes. * A potential barrier divergence for certain GPU kernels that fail bounds checking. * A potential infinite loop when looking up aliases (#1915). * `futhark literate`: less extraneous whitespace. ## [0.24.1] ### Added * The manifest file now lists which tuning parameters are relevant for each entry point.
(#1884) * A command `tuning_params` has been added to the server protocol. ### Changed * If part of a function parameter is marked as consuming ("unique"), the *entire* parameter is now marked as consuming. ### Fixed * A somewhat obscure simplification rule could mess up use of memory. * Corner case optimisation for mapping over `iota` (#1874). * AD for certain combinations of `map` and indexing (#1878). * Barrier divergence in generated code in some exotic cases (#1883). * Handling of higher-order holes (#1885). ## [0.23.1] ### Added * `f16.log1p`/`f32.log1p`/`f64.log1p` by nbos (#1820). * Better syntax errors for invalid use of `!`. * `futhark literate` now supports a `$loadaudio` builtin function for loading audio to Futhark programs (#1829). * You can now mix consumption and higher-order terms in slightly more cases (#1836). * `futhark pkg` now invokes Git directly rather than scraping GitHub/GitLab. This means package paths can now refer to any Git repository, as long as `git clone` works. In particular, you can use private and self-hosted repositories. * Significant reduction in compilation time by doing internal sanity checks in separate thread. * New command: `futhark eval`. Evaluates Futhark expressions provided as command line arguments, optionally allowing a file import (#1408). * `script input` now allows the use of `$loaddata`. * Datasets used in `futhark test` and `futhark bench` can now be named (#1859). * New command `futhark benchcmp` by William Due. ### Changed * The C API function `futhark_context_new_with_command_queue()` for the OpenCL backend has been replaced with a configuration setting `futhark_context_config_set_command_queue()`. ### Fixed * Minor parser regression that mostly affects the REPL (#1822). * Parser did not recognise custom infix operators that did not have a builtin operator as prefix (#1824). * GPU backends: expansion of irregular nested allocations involving consumption (#1837, #1838). 
* CLI executables now handle entry points with names that are not valid C identifiers (#1841). * Various oversights in the type checking of uniqueness annotations for higher-order functions (#1842). * Invalid short-circuiting could cause compiler crashes (#1843). * Defunctionalisation could mess up sum types, leading to invalid code generation by internalisation, leading to a compiler crash (#1847). * The `#[break]` attribute now provides the right environment to `futhark repl`, allowing local variables to be inspected. * Simplification of concatenations (#1851). * Array payloads in sum types no longer need parens (#1853). * When a file is loaded with `futhark repl`, `local` declarations are now available. * Missing alias propagation when pattern matching incompletely known sum types (#1855). * `reduce_by_index` and `hist` were in some cases unable to handle input sizes that do not fit in a 32-bit integer. * A fusion bug related to fusing across transpositions could result in a compiler crash (#1858). ## [0.22.7] ### Added * `futhark literate` now supports an `:audio` directive for producing audio files from arrays of `i8` (#1810). * `futhark multicore` now parallelises copies (#1799). * `futhark multicore` now uses an allocator that better handles large allocations (#1768). ### Fixed * Some record field names could cause generation of invalid C API names (#1806). * Memory block merging was extremely and unnecessarily slow for programs with many entry points. * Simplification mistake could lead to out-of-bounds reads (#1808). * `futhark lsp` now handles some bookkeeping messages sent by Eglot. * Parser can now handle arbitrarily complex chaining of indexing and projection. * Detect and complain about source files without .fut extension (#1813). * Overly conservative checking of whether a function parameter is allowed to be consumed - it was disallowed if it contained any scalars (#1816). ## [0.22.6] ### Added * Slightly better type errors for sum types (#1792).
* Better tracing output in interpreter (#1795). * Improved optimisation of code that uses zero-element arrays (sometimes used for type witnesses). ### Fixed * Mishandling of bounds checks in parallel backends could cause compiler crashes (#1791). * Mis-simplification of certain sequentialised scatters into single-element arrays (#1793). * Invalid `scatter` fusion would cause an internal compiler error (#1794). * The code generator flipped the order of `match` cases. * Simplification of concatenations (#1796). * Defunctionalisation error for fancy size programming (#1798). * Code generation for load-unbalanced reductions in multicore backend (#1800). * Futhark now works on CUDA 12 (#1801). * `mul_hi` and `mad_hi` for signed integer types now actually do signed multiplication (previously it was always unsigned). ## [0.22.5] ### Added * Memory short circuiting is now also applied to the `multicore` backend. * Fixes for AD of `scan` with nonscalar operators. Work by Lotte Bruun and Ulrik Larsen. * Generalised histograms now supported in AD. Work by Lotte Bruun and Ulrik Larsen. * OpenCL kernels now embed group size information, which can potentially be used for better register allocation by the OpenCL implementation. ### Fixed * A hole in the type checker that allowed one to sneak functions out of conditionals (#1787). * `futhark repl`: unsigned integers are now printed correctly. * A bug in the type-checking of `match` (#1783). * Missing sequencing in type-checking of in-place `let` expressions (#1786). * Crash in defunctionaliser caused by duplicate parameter names (#1780). * Infelicities in multicore codegen for scans (#1777). ## [0.22.4] ### Added * Memory short circuiting, a major new optimisation by Philip Munksgaard that avoids copies by constructing values in-place. ### Fixed * `assert` was sometimes optimised away by CSE. * `futhark literate` now handles type abbreviations in entry points (#1750). * Handling of non-inlined functions in GPU code.
Still very restricted. * Uniqueness checking bug (#1751). * Simplification bug (#1753). * A bug related to optimisation of scalar code migrated to GPU. * Memory optimisation bug for top-level constants (#1755). * Handling of holes in defunctionalisation (again). * A few cases where optimisation (safely but perhaps confusingly) removed bounds checks (#1758). * Futhark now works on Windows again (#1734). This support remains very flaky and not well tested. * Type inference of field projection (#1762). ## [0.22.3] ### Fixed * Non-server executables neglected to synchronise before printing results (#1731). * Fixed handling of holes in defunctionalisation (#1738). * Ascription of higher-order modules (#1741). * Fixed compiler crash when attempting to tile irregular parallelism (#1739). * Incorrect regularity checking in loop interchange (#1744). * Pattern match exhaustiveness of bools (#1748). * Improper consumption checking of nonlocal bindings (#1749). ## [0.22.2] ### Added * `futhark repl` is now delightfully more colourful. * `futhark repl` no longer prints scalar types with type suffixes (#1724). * `futhark pyopencl` executables now accept ``--build-option``. * New functions: `f16.nextafter`, `f32.nextafter`, `f64.nextafter`, matching the ones in the C math library. * `futhark literate` now prints directives in the output exactly as they appear in the source. ### Fixed * Diagnostics will no longer contain control codes when output is a file. * CLI executables now fail when printing an opaque instead of producing garbage. This improves handling of some incorrect uses of `auto output` (#1251). ## [0.22.1] ### Removed * Removed prelude functions `reduce_stream`, `map_stream`, `reduce_stream_per`, and `reduce_map_per`. ### Fixed * Various fixes to scalar migration (#1721). ## [0.21.15] ### Fixed * Corrupted OOM error messages. * Excessive internal fragmentation for some programs (#1715). 
## [0.21.14] ### Fixed * `replicate` was broken for sizes that didn't fit in `i32`. * Transposition of empty arrays in interpreter (#1700). * Exotic simplification error (#1309). * Rare race condition could lead to leaking of error message memory in `multicore` and `ispc` backends (#1709). * Overzealous aliasing for built-in overloaded types (#1711). ## [0.21.13] ### Added * New fusion engine by Amar Topalovic and Walter Restelli-Nielsen. Fuses more, which is good for some programs and bad for others. Most programs see no change. This is mainly a robust foundation for further improvements. * New experimental backend: `ispc`. By Louis Marott Normann, Kristoffer August Kortbæk, William Pema Norbu Holmes Malling, and Oliver Bak Kjersgaard Petersen. * New prelude functions: `hist`, `spread`. These function as non-consuming variants of `reduce_by_index` and `scatter`. * Using `==` to compare arrays is now deprecated. * New command: `futhark tokens`. You probably don't care about this one. * In the C API, opaque types that correspond to tuples or records can now have their elements inspected and be created from elements (#1568). * New server protocol commands: `types`, `fields`, `entry_points`. * Tuples and records can now be passed from FutharkScript to Futhark entry points (#1684). ### Fixed * Sometimes missing cases in `match` expressions were not detected. * A defective simplification rule could in very rare cases lead to infinite recursion in the compiler (#1685). * Some broken links in `futhark doc` output (#1686). * Incorrect checking of whether a function parameter is consumable based on its type (#1687). * Missing aliases for functions that return multiple aliased values (#1690). * `new`/`values` functions for GPU backends are now properly asynchronous (#1664). This may uncover bugs in application code. ## [0.21.12] ### Added * Somewhat simplified the handling of "uniqueness types" (which is a term we are moving away from). 
You should never see `*` in non-function types, and they are better thought of as effect indicators. * `futhark literate`: prints tracing output (and other logging messages that may occur) when run with `-v` (#1678). * Entry points can now be any valid Futhark identifier. ### Fixed * `futhark test -C` was broken. * `futhark_context_free()` for the GPU backends neglected to free some memory used for internal bookkeeping, which could lead to memory leaks for processes that repeatedly create and destroy contexts (#1676). * FutharkScript now allows `'` in names. * `futhark lsp` now handles warnings in programs that also have errors. ## [0.21.11] ### Added * The CUDA backend now supports compute capability 8.6 and 8.7. * Philip Børgesen has implemented a new optimisation for GPU backends that migrates scalar work to the GPU, in order to reduce synchronisation. This results in major speedup for some programs. * String literals are now allowed in `input` blocks. * Experimental and undocumented support for automatic differentiation, available on the secret menu. * Assertions and attributes are now ignored when used as size expressions. E.g. `iota (assert p n) 0` now has size `n`. * `futhark test` only runs the interpreter if passed `-i`. * `futhark literate` now shows progress bars when run with `-v`. ### Fixed * `futhark lsp` is now better at handling multiple files (#1647). * Incorrect handling of local quantification when determining type equivalence during module type ascription (#1648). * Incorrect checking of liftedness when instantiating polymorphic functions during module type ascription. * Tightened some restrictions on the use of existential sizes that could otherwise lead to compiler crashes (#1650). This restriction is perhaps a bit *too* strict and it may be possible to loosen it in the future. * Another defunctorisation bug (#1653). Somehow we find these every time Martin Elsman writes a nontrivial Futhark program.
* `futhark bench`: convergence phase now does at least `--runs` runs. * Errors and warnings no longer awkwardly mixed together in console output. * Slightly better type errors for ambiguous sizes (#1661). * Better type errors for module ascription involving nested modules (#1660). * `futhark doc`: some formatting bugs. * `futhark doc` didn't notice all `local` module types (#1666). * Missing consumption check in optimisation could lead to ICE (#1669). ## [0.21.10] ### Added * New math functions: `f16.erf`, `f32.erf`, `f64.erf`. * New math functions: `f16.erfc`, `f32.erfc`, `f64.erfc`. * New math functions: `f16.cbrt`, `f32.cbrt`, `f64.cbrt`. ### Fixed * Variables being indexed now have correct source spans in AST. * `futhark lsp`s hover information now contains proper range information. * `futhark query` and `futhark lsp` incorrectly thought size parameters had type `i32`. * `futhark doc` put documentation for prelude modules in the wrong location (which also led to messed-up style sheets). ## [0.21.9] ### Added * Sun Haoran has implemented unnamed typed holes, with syntax `???`. * Sun Haoran has implemented the beginnings of a language server: `futhark lsp`. A VSCode language extension is available on the marketplace, but the language server should work with any editor. * Crucial new command: `futhark thanks`. * The GPU backends now support a caching mechanism for JIT-compiled code, significantly improving startup times. Use the `futhark_context_config_set_cache_file()` in the C API, the `--cache-file` option on executables, or the `--cache-extension` option on `futhark test` and `futhark bench`. These also work for the non-GPU backends, but currently have no effect. (#1614) * Aleksander Junge has improved `futhark bench` such that it intelligently chooses how many runs to perform (#1335). ### Fixed * Incomplete simplification would cause some instances of nested parallelism to require irregular allocations (#1610). 
* Missing alias checking for a simplification rule related to in-place updates (#1615, #1628). * Incorrect code generation for certain copies of transposed arrays (#1627). * Fusion would mistakenly try to treat some loops with irregular sizes (#1631). * Memory annotation bug for non-inlined functions (#1634). ## [0.21.8] ### Added * Slightly better parse errors (again). * `futhark literate` now supports a `file:` option in `:img` and `:video` directives (#1491). ### Fixed * Improved hoisting of size computations. This could cause some regular nested parallel programs to run into compiler limitations, as if they were irregular. * Rare code generation bug for histograms (#1609). ## [0.21.7] ### Added * `futhark check-syntax`: check syntactic validity of a program file, without type-checking. * Parsing multi-file programs is now parallelised, making it *slightly* faster. * Reloading a large program in `futhark repl` is now faster, as long as not too many of its files have been modified (#1597). ### Fixed * Mistake in occurs check could cause infinite loop in type checker for programs with type errors (#1599). ## [0.21.6] ### Added * `futhark bench` now explicitly notes when a tuning file is not present. * `futhark bench`, `futhark test` and friends are now better at handling fatally terminating programs (e.g. segmentation faults). * Generated C code is now a lot smaller for large programs, as error recovery has been more centralised (#1584). ### Fixed * Some bugs in checking for local memory capacity for particularly exotic generated code. * Insufficient hoisting of allocation sizes led to problems with memory expansion in rare cases (#1579). * Conversion of floating-point NaNs and infinities to integers now well defined (produces zero) (#1586). * Better handling of OOM for certain short-lived CPU-side allocations (#1585). ## [0.21.5] ### Added * API functions now return more precise error codes in some cases. * Out-of-memory errors contain more information. 
### Fixed * Yet another in-place lowering issue (#1569). * Removed unnecessary bounds checks in register tiling, giving about 1.8x speedup on e.g. matrix multiplication on newer NVIDIA GPUs. * A parser bug erroneously demanded whitespace in some type expressions (#1573). * Some memory was not being freed correctly when shutting down OpenCL and CUDA contexts, which could lead to memory leaks in processes that created and freed many contexts. * An incorrect copy-removal in some exotic cases (#1572). * 'restore'-functions might perform undefined pointer arithmetic when passed garbage. ## [0.21.4] ### Fixed * A size inference bug in type checking of `loop`s (#1565). * Exotic flattening bug (#1563). * Segmented `reduce_by_index` with fairly small histogram size would use vastly more memory than needed. ## [0.21.3] ### Added * Parse errors now list possible expected tokens. * Lexer errors now mention the file. ### Fixed * Overloaded number literals cannot be sum types (#1557). * Defective GPU code generation for vectorised non-commutative operators (#1559). * Excessive memory usage for some programs (#1325). ## [0.21.2] ### Added * New functions: `reduce_by_index_2d`, `reduce_by_index_3d`. * Manifests now contain compiler version information. ### Fixed * Allocation insertion pass bug (#1546). * An exotic bug involving TLS and dynamically loading code generated by the `multicore` backend. * Unconstrained ambiguous types now default to `()` (#1552). This should essentially never have any observable impact, except that more programs will type check. * Double buffering compiler crash (#1553). ## [0.21.1] ### Added * Top-level value definitions can (and should) now be declared with `def`, although `let` still works. * New tool: `futhark defs`, for printing locations of top-level definitions. ### Changed * `def` is now a reserved word. ### Fixed * Contrived intra-group code versions with no actual parallelism would be given a group size of zero (#1524).
## [0.20.8] ### Added * `futhark repl` now allows Ctrl-c to interrupt execution. ### Fixed * Alias tracking of sum types. * Proper checking that a function declared to return a unique-typed value actually does so. * Faulty uniqueness checking and inference for lambdas (#1535). * Monomorphisation would duplicate functions under rare circumstances (#1537). * Interpreter didn't check that the arguments passed to `unflatten` made sense (#1539). * `futhark literate` now supports a `$loaddata` builtin function for passing datasets to Futhark programs. ## [0.20.7] ### Added * Better exploitation of parallelism in fused nested segmented reductions. * Prelude function `not` for negating booleans. ### Fixed * Some incorrect removal of copies (#1505). * Handling of parametric modules with top-level existentials (#1510). * Module substitution fixes (#1512, #1518). * Invalid in-place lowering (#1523). * Incorrect code generation for some intra-group parallel code versions. * Flattening crash in the presence of irregular parallelism (#1525). * Incorrect substitution of type abbreviations with hidden sizes (#1531). * Proper handling of NaN in `min`/`max` functions for `f16`/`f32`/`f64` in interpreter (#1528). ## [0.20.6] ### Added * Much better code generation for segmented scans with vectorisable operators. ### Fixed * Fixes to extremely exotic GPU scans involving array operators. * Missing alias tracking led to invalid rewrites, causing a compiler crash (#1499). * Top-level bindings with existential sizes were mishandled (#1500, #1501). * A variety of memory leaks in the multicore backend, mostly (or perhaps exclusively) centered around context freeing or failing programs - this should not have affected many people. * Various fixes to `f16` handling in the GPU backends. ## [0.20.5] ### Added * Existential sizes can now be explicitly quantified in type expressions (#1308). * Significantly expanded error index. * Attributes can now be numeric. * Patterns can now have attributes. 
None have any effect at the moment. * `futhark autotune` and `futhark bench` now take a `--spec-file` option for loading a test specification from another file. ### Fixed * `auto output` reference datasets are now recreated when the program is newer than the data files. * Exotic hoisting bug (#1490). ## [0.20.4] ### Added * Tuning parameters now (officially) exposed in the C API. * `futhark autotune` is now 2-3x faster on many programs, as it now keeps the process running. * Negative numeric literals are now allowed in `case` patterns. ### Fixed * `futhark_context_config_set_profiling` was missing for the `c` backend. * Correct handling of nested entry points (#1478). * Incorrect type information recorded when doing in-place lowering (#1481). ## [0.20.3] ### Added * Executables produced by C backends now take a `--no-print-result` option. * The C backends now generate a manifest when compiling with `--library`. This can be used by FFI generators (#1465). * The beginnings of a Rust-style error index. * `scan` on newer CUDA devices is now much faster. ### Fixed * Unique opaque types are named properly in entry points. * The CUDA backend in library mode no longer `exit()`s the process if NVRTC initialisation fails. ## [0.20.2] ### Fixed * Simplification bug (#1455). * In-place-lowering bug (#1457). * Another in-place-lowering bug (#1460). * Don't try to tile inside loops with parameters with variant sizes (#1462). * Don't consider it an ICE when the user passes invalid command line options (#1464). ## [0.20.1] ### Added * The `#[trace]` and `#[break]` attributes now replace the `trace` and `break` functions (although they are still present in slightly-reduced but compatible form). * The `#[opaque]` attribute replaces the `opaque` function, which is now deprecated. * Tracing now works in compiled code, albeit with several caveats (mainly, it does not work for code running on the GPU). * New `wasm` and `wasm-multicore` backends by Philip Lassen. 
Still very experimental; do not expect API stability. * New intrinsic type `f16`, along with a prelude module `f16`. Implemented with hardware support where it is available, and with `f32`-based emulation where it is not. * Sometimes slightly more informative error message when input of the wrong type is passed to a test program. ### Changed * The `!` function in the integer modules is now called `not`. * `!` is now builtin syntax. You can no longer define a function called `!`. It is extremely unlikely this affects you. This removes the last special-casing of prefix operators. * A prefix operator section (i.e. `(!)`) is no longer permitted (and it never was according to the grammar). * The offset parameter for the "raw" array creation functions in the C API is now `int64_t` instead of `int`. ### Fixed * `i64.abs` was wrong for arguments that did not fit in an `i32`. * Some `f32` operations (`**`, `abs`, `max`) would be done in double precision on the CUDA backend. * Yet another defunctorisation bug (#1397). * The `clz` function would sometimes exhibit undefined behaviour in CPU code (#1415). * Operator priority of prefix `-` was wrong - it is now the same as `!` (#1419). * `futhark hash` is now invariant to source location as well as stable across OS/compiler/library versions. * `futhark literate` is now much better at avoiding unnecessary recalculation. * Fixed a hole in size type checking that would usually lead to compiler crashes (#1435). * Underscores now allowed in numeric literals in test data (#1440). * The `cuda` backend did not use single-pass segmented scans as intended. Now it does. ## [0.19.7] ### Added * A new memory reuse optimisation has been added. This results in slightly lower footprint for many programs. * The `cuda` backend now uses a fast single-pass implementation for segmented `scan`s, due to Morten Tychsen Clausen (#1375). * `futhark bench` now prints interim results while it is running. 
### Fixed * `futhark test` now provides better error message when asked to test an undefined entry point (#1367). * `futhark pkg` now detects some nonsensical package paths (#1364). * FutharkScript now parses `f x y` as applying `f` to `x` and `y`, rather than as `f (x y)`. * Some internal array utility functions would not be generated if entry points exposed both unit arrays and boolean arrays (#1374). * Nested reductions used (much) more memory for intermediate results than strictly needed. * Size propagation bug in defunctionalisation (#1384). * In the C FFI, array types used only internally to implement opaque types are no longer exposed (#1387). * `futhark bench` now copes with test programs that consume their input (#1386). This required an extension of the server protocol as well. ## [0.19.6] ### Added * `f32.hypot` and `f64.hypot` are now much more numerically exact in the interpreter. * Generated code now contains a header with information about the version of Futhark used (and maybe more information in the future). * Testing/benchmarking with large input data (including randomly generated data) is much faster, as each file is now only read once. * Test programs may now use arbitrary FutharkScript expressions to produce test input, in particular expressions that produce opaque values. This affects both testing, benchmarking, and autotuning. * Compilation is about 10% faster, especially for large programs. ### Fixed * `futhark repl` had trouble with declarations that produced unknown sizes (#1347). * Entry points can now have same name as (undocumented!) compiler intrinsics. * FutharkScript now detects too many arguments passed to functions. * Sequentialisation bug (#1350). * Missing causality check for index sections. * `futhark test` now reports mismatches using proper indexes (#1356). * Missing alias checking in fusion could lead to compiler crash (#1358). * The absolute value of NaN is no longer infinity in the interpreter (#1359). 
* Proper detection of zero strides in compiler (#1360). * Invalid memory accesses related to internal bookkeeping of bounds checking. ## [0.19.5] ### Added * Initial work on granting programmers more control over existential sizes, starting with making type abbreviations function as existential quantifiers (#1301). * FutharkScript now also supports arrays and scientific notation. * Added `f32.epsilon` and `f64.epsilon` for the difference between 1.0 and the next larger representable number. * Added `f32.hypot` and `f64.hypot` for your hypotenuse needs (#1344). * Local size bindings in `let` expressions, e.g.: ``` let [n] (xs': [n]i32) = filter (>0) xs in ... ``` ### Fixed * `futhark_context_report()` now internally calls `futhark_context_sync()` before collecting profiling information (if applicable). * `futhark literate`: Parse errors for expression directives are now detected properly. * `futhark autotune` now works with the `cuda` backend (#1312). * Devious fusion bug (#1322) causing compiler crashes. * Memory expansion bug for certain complex GPU kernels (#1328). * Complex expressions in index sections (#1332). * Handling of sizes in abstract types in the interpreter (#1333). * Type checking of explicit size requirements in `loop` parameter (#1324). * Various alias checking bugs (#1300, #1340). ## [0.19.4] ### Fixed * Some uniqueness ignorance in fusion (#1291). * An invalid transformation could in rare cases cause race conditions (#1292). * Generated Python and C code should now be warning-free. * Missing check for uses of size-lifted types (#1294). * Error in simplification of concatenations could cause compiler crashes (#1296). ## [0.19.3] ### Added * Better `futhark test`/`futhark bench` errors when test data does not have the expected type. ### Fixed * Mismatch between how thresholds were printed and what the autotuner was looking for (#1269). * `zip` now produces unique arrays (#1271).
* `futhark literate` no longer chokes on lines beginning with `--` without a following whitespace. * `futhark literate`: `:loadimg` was broken due to overzealous type checking (#1276). * `futhark literate`: `:loadimg` now handles relative paths properly. * `futhark hash` no longer considers the built-in prelude. * Server executables had broken store/restore commands for opaque types. ## [0.19.2] ### Added * New subcommand: `futhark hash`. * `futhark literate` is now smart about when to regenerate image and animation files. * `futhark literate` now produces better error messages when passing expressions of the wrong type to directives. ### Fixed * Type-checking of higher-order functions that take consuming functional arguments. * Missing cases in causality checking (#1263). * `f32.sgn` was mistakenly defined with double precision arithmetic. * Only include double-precision atomics if actually needed by program (this avoids problems on devices that only support single precision). * A lambda lifting bug due to not handling existential sizes produced by loops correctly (#1267). * Incorrect uniqueness attributes inserted by lambda lifting (#1268). * FutharkScript record expressions were a bit too sensitive to whitespace. ## [0.19.1] ### Added * `futhark literate` now supports a `$loadimg` builtin function for passing images to Futhark programs. * The `futhark literate` directive for generating videos is now `:video`. * Support for 64-bit atomics on CUDA and OpenCL for higher performance with `reduce_by_index` in particular. Double-precision float atomics are used on CUDA. * New functions: `f32.recip` and `f64.recip` for multiplicative inverses. * Executables produced with the `c` and `multicore` backends now also accept `--tuning` and `--size` options (although there are not yet any tunable sizes). * New functions: `scatter_2d` and `scatter_3d` for scattering to multi-dimensional arrays (#1258).
### Removed * The math modules no longer define the name `negate` (use `neg` instead). ### Fixed * Exotic core language alias tracking bug (#1239). * Issue with entry points returning constant arrays (#1240). * Overzealous CSE collided with uniqueness types (#1241). * Defunctionalisation issue (#1242). * Tiling inside multiply nested loops (#1243). * Substitution bug in interpreter (#1250). * `f32.sgn`/`f64.sgn` now correct for NaN arguments. * CPU backends (`c`/`multicore`) are now more careful about staying in single precision for `f32` functions (#1253). * `futhark test` and `futhark bench` now detect program initialisation errors in a saner way (#1246). * Partial application of operators with parameters used in a size-dependent way now works (#1256). * An issue regarding abstract size-lifted sum types (#1260). ## [0.18.6] ### Added * The C API now exposes serialisation functions for opaque values. * The C API now lets you pick which stream (if any) is used for logging prints (#1214). * New compilation mode: `--server`. For now used to support faster benchmarking and testing tools, but can be used to build even fancier things in the future (#1179). * Significantly faster reading/writing of large values. This mainly means that validation of test and benchmark results is much faster (close to an order of magnitude). * The experimental `futhark literate` command allows vaguely a notebook-like programming experience. * All compilers now accept an `--entry` option for treating more functions as entry points. * The `negate` function is now `neg`, but `negate` is kept around for a short while for backwards compatibility. * Generated header-files are now declared `extern "C"` when processed with a C++ compiler. * Parser errors in test blocks used by `futhark bench` and `futhark test` are now reported with much better error messages. ### Fixed * Interaction between slice simplification and in-place updates (#1222). 
* Problem with user-defined functions with the same name as intrinsics. * Names from transitive imports no longer leak into scope (#1231). * Pattern-matching unit values now works (#1232). ## [0.18.5] ### Fixed * Fix tiling crash (#1203). * `futhark run` now does slightly more type-checking of its inputs (#1208). * Sum type deduplication issue (#1209). * Missing parentheses when printing sum values in interpreter. ## [0.18.4] ### Added * When compiling to binaries in the C-based backends, the compiler now respects the ``CFLAGS`` and ``CC`` environment variables. * GPU backends: avoid some bounds-checks for parallel sections inside intra-kernel loops. * The `cuda` backend now uses a much faster single-pass `scan` implementation, although only for nonsegmented scans where the operator operates on scalars. ### Fixed * `futhark dataset` now correctly detects trailing commas in textual input (#1189). * Fixed local memory capacity check for intra-group-parallel GPU kernels. * Fixed compiler bug on segmented rotates where the rotation amount is variant to the nest (#1192). * `futhark repl` no longer crashes on type errors in given file (#1193). * Fixed a simplification error for certain arithmetic expressions (#1194). * Fixed a small uniqueness-related bug in the compilation of operator section. * Sizes of opaque entry point arguments are now properly checked (related to #1198). ## [0.18.3] ### Fixed * Python backend now disables spurious NumPy overflow warnings for both library and binary code (#1180). * Undid deadlocking over-synchronisation for freeing opaque objects. * `futhark datacmp` now handles bad input files better (#1181). ## [0.18.2] ### Added * The GPU loop tiler can now handle loops where only a subset of the input arrays are tiled. Matrix-vector multiplication is one important program where this helps (#1145). * The number of threads used by the `multicore` backend is now configurable (`--num-threads` and `futhark_context_config_set_num_threads()`). 
(#1162) ### Fixed * PyOpenCL backend would mistakenly still treat entry point argument sizes as 32 bit. * Warnings are now reported even for programs with type errors. * Multicore backend now works properly for very large iteration spaces. * A few internal generated functions (`init_constants()`, `free_constants()`) were mistakenly declared non-static. * Process exit code is now nonzero when compiler bugs and limitations are encountered. * Multicore backend crashed on `reduce_by_index` with nonempty target and empty input. * Fixed a flattening issue for certain complex `map` nestings (#1168). * Made API function `futhark_context_clear_caches()` thread safe (#1169). * API functions for freeing opaque objects are now thread-safe (#1169). * Tools such as `futhark dataset` no longer crash with an internal error if writing to a broken pipe (but they will return a nonzero exit code). * Defunctionalisation had a name shadowing issue that would crop up for programs making very advanced use of functional representations (#1174). * Type checker erroneously permitted pattern-matching on string literals (this would fail later in the compiler). * New coverage checker for pattern matching, which is more correct. However, it may not provide quite as nice counter-examples (#1134). * Fix rare internalisation error (#1177). ## [0.18.1] ### Added * Experimental multi-threaded CPU backend, `multicore`. ### Changed * All sizes are now of type `i64`. This has wide-ranging implications and most programs will need to be updated (#134). ## [0.17.3] ### Added * Improved parallelisation of `futhark bench` compilation. ### Fixed * Dataset generation for test programs now uses the right `futhark` executable (#1133). * Really fix NaN comparisons in interpreter (#1070, again). * Fix entry points with a parameter that is a sum type where multiple constructors contain arrays of the same statically known size. * Fix in monomorphisation of types with constant sizes.
* Fix in in-place lowering (#1142). * Fix tiling inside multiple nested loops (#1143). ## [0.17.2] ### Added * Obscure loop optimisation (#1110). * Faster matrix transposition in C backend. * Library code generated with CUDA backend can now be called from multiple threads. * Better optimisation of concatenations of array literals and replicates. * Array creation C API functions now accept `const` pointers. * Arrays can now be indexed (but not sliced) with any signed integer type (#1122). * Added `--list-devices` command to OpenCL binaries (#1131). * Added `--help` command to C, CUDA and OpenCL binaries (#1131). ### Removed * The integer modules no longer contain `iota` and `replicate` functions. The top-level ones still exist. * The `size` module type has been removed from the prelude. ### Changed * Range literals may no longer be produced from unsigned integers. ### Fixed * Entry points with names that are not valid C (or Python) identifiers are now pointed out as problematic, rather than generating invalid C code. * Exotic tiling bug (#1112). * Missing synchronisation for in-place updates at group level. * Fixed (in a hacky way) an issue where `reduce_by_index` would use too much local memory on AMD GPUs when using the OpenCL backend. ## [0.16.4] ### Added * `#[unroll]` attribute. * Better error message when writing `a[i][j]` (#1095). * Better error message when missing "in" (#1091). ### Fixed * Fixed compiler crash on certain patterns of nested parallelism (#1068, #1069). * NaN comparisons are now done properly in interpreter (#1070). * Fix incorrect movement of array indexing into branches of `if`s (#1073). * Fix defunctorisation bug (#1088). * Fix issue where loop tiling might generate out-of-bounds reads (#1094). * Scans of empty arrays no longer result in out-of-bounds memory reads. * Fix yet another defunctionalisation bug due to missing eta-expansion (#1100).
## [0.16.3] ### Added * `random` input blocks for `futhark test` and `futhark bench` now support floating-point literals, which must always have either an `f32` or `f64` suffix. * The `cuda` backend now supports the `-d` option for executables. * The integer modules now contain a `ctz` function for counting trailing zeroes. ### Fixed * The `pyopencl` backend now works with OpenCL devices that have multiple types (most importantly, oclgrind). * Fix barrier divergence when generating code for group-level collective copies in GPU backend. * Intra-group flattening now looks properly inside of branches. * Intra-group flattened code versions are no longer used when the resulting workgroups would have fewer than 32 threads (with default thresholds anyway) (#1064). ## [0.16.2] ### Added * `futhark autotune`: added `--pass-option`. ### Fixed * `futhark bench`: progress bar is now correct when number of runs is less than 10 (#1050). * Aliases of arguments passed for consuming parameters are now properly checked (#1053). * When using a GPU backend, errors are now properly cleared. Previously, once e.g. an out-of-bounds error had occurred, all future operations would fail with the same error. * Size-coercing a transposed array no longer leads to invalid code generation (#1054). ## [0.16.1] ### Added * Incremental flattening is now performed by default. Use attributes to constrain and direct the flattening if you have exotic needs. This will likely need further iteration and refinement. * Better code generation for `reverse` (and the equivalent explicit slice). * `futhark bench` now prints progress bars. * The `cuda` backend now supports similar profiling as the `opencl` backend, although it is likely slightly less accurate in the presence of concurrent operations. * A preprocessor macro `FUTHARK_BACKEND_foo` is now defined in generated header files, where *foo* is the name of the backend used.
* Non-inlined functions (via `#[noinline]`) are now supported in GPU code, but only for functions that *exclusively* operate on scalars. * `futhark repl` now accepts a command line argument to load a program initially. * Attributes are now also permitted on declarations and specs. * `futhark repl` now has a `:nanbreak` command (#839). ### Removed * The C# backend has been removed (#984). * The `unsafe` keyword has been removed. Use `#[unsafe]` instead. ### Changed * Out-of-bounds literals are now an error rather than a warning. * Type ascriptions on entry points now always result in opaque types when the underlying concrete type is a tuple (#1048). ### Fixed * Fix bug in slice simplification (#992). * Fixed a type checker bug for tracking the aliases of closures (#995). * Fixed handling of dumb terminals in `futhark test` (#1000). * Fixed exotic monomorphisation case involving lifted type parameters instantiated with functions that take named parameters (#1026). * Further tightening of the causality restriction (#1042). * Fixed alias tracking for right-operand operator sections (#1043). ## [0.15.8] ### Added * Warnings for overflowing literals, such as `1000 : u8`. * Futhark now supports an attribute system, whereby expressions can be tagged with attributes that provide hints or directions to the compiler. This is an expert-level feature, but it is sometimes useful. ## [0.15.7] ### Added * Faster index calculations for very tight GPU kernels (such as the ones corresponding to 2D tiling). * `scan` with vectorised operators (e.g. `map2 (+)`) is now faster in some cases. * The C API has now been documented and stabilized, including obtaining profiling information (although this is still unstructured). ### Fixed * Fixed some cases of missing fusion (#953). * Context deinitialisation is now more complete, and should not leak memory (or at least not nearly as much, if any).
This makes it viable to repeatedly create and free Futhark contexts in the same process (although this can still be quite slow). ## [0.15.6] ### Added * Binary operators now act as left-to-right sequence points with respect to size types. * `futhark bench` now has more colourful and hopefully readable output. * The compiler is now about 30% faster for most nontrivial programs. This is due to parallelising the inlining stage, and tweaking the default configuration of the Haskell RTS. * `futhark dataset` is now about 8-10x faster. ### Fixed * Fixed some errors regarding constants (#941). * Fixed a few missing type checker cases for sum types (#938). * Fix OOB write in CUDA backend runtime code (#950). ## [0.15.5] ### Added * `reduce_by_index` with `f32`-addition is now approximately 2x faster in the CUDA backend. ### Fixed * Fixed kernel extractor bug in `if`-interchange (#921). * Fixed some cases of malformed kernel code generation (#922). * Fixed rare memory corruption bug involving branches returning arrays (#923). * Fixed spurious warning about entry points involving opaque return types, where the type annotations are put on a higher-order return type. * Fixed incorrect size type checking for sum types in negative position with unknown constructors (#927). * Fixed loop interchange for permuted sequential loops with more than one outer parallel loop (#928). * Fixed a type checking bug for branches returning incomplete sum types (#931). ## [0.15.4] ### Added * `futhark pkg` now shells out to `curl` for HTTP requests. * `futhark doc` now supports proper GitHub-flavored Markdown, as it uses the `cmark-gfm` library internally. * Top-level constants are now executed only once per program instance. This matters when Futhark is used to generate library code. * `futhark autotune` is better at handling degrees of parallelism that assume multiple magnitudes during a single run. * `futhark pkg` now uses `curl` to retrieve packages. 
* Type errors are now printed in red for better legibility (thanks to @mxxo!). ### Fixed * Fixed incorrect handling of opaques in entry point return types. * `futhark pkg` now works properly with GitLab (#899). ## [0.15.3] ### Added * `scan` now supports operators whose operands are arrays. They are significantly slower than primitive-typed scans, so avoid them if at all possible. * Precomputed constants are now handled much more efficiently. * Certain large programs that rely heavily on inlining now compile orders of magnitude faster. ### Fixed * Some fixes to complicated module expressions. * `futhark pkg` should no longer crash uncontrollably on network errors (#894). * Fixed local open in interpreter (#887). * Fix error regarding entry points that called other entry points which contained local functions (#895). * Fix loading OpenCL kernels from a binary. ## [0.15.2] ### Fixed * Fix a REPL regression that made it unable to handle overloaded types (such as numeric literals, oops). * The uniqueness of a record is now the minimum of the uniqueness of any of its elements (#870). * Bug in causality checking has been fixed (#872). * Invariant memory allocations in scan/reduce operators are now supported. * `futhark run` now performs more type checking on entry point input (#876). * Compiled Futhark programs now check for EOF after the last input argument has been read (#877). * Fixed a bug in `loop` type checking that prevented the result from ever aliasing the initial parameter values (#879). ## [0.15.1] ### Added * Futhark now type-checks size annotations using a size-dependent type system. * The parallel code generators can now handle bounds checking and other safety checks. * Integer division by zero is now properly safety-checked and produces an error message. * Integer exponentiation with negative exponent is now properly safety-checked and produces an error message. * Serious effort has been put into improving type errors. 
* `reduce_by_index` may be somewhat faster for complex operators on histograms that barely fit in local memory. * Improved handling of in-place updates of multidimensional arrays nested in `map`. These are now properly parallelised. * Added `concat_to` and `flatten_to` functions to prelude. * Added `indices` function to the prelude. * `futhark check` and all compilers now take a `-w` option for disabling warnings. * `futhark bench` now accepts `--pass-compiler-option`. * The integer modules now have `mad_hi` and `mul_hi` functions for getting the upper part of multiplications. Thanks to @porcuquine for the contribution! * The `f32` and `f64` modules now also define `sinh`, `cosh`, `tanh`, `asinh`, `acosh`, and `atanh` functions. * The `f32` and `f64` modules now also define `fma` and `mad` functions. ### Removed * Removed `update`, `split2`, `intersperse`, `intercalate`, `pick`, `steps`, and `range` from the prelude. ### Changed * `"futlib"` is now called `"prelude"`, and it is now an error to import it explicitly. ### Fixed * Corrected address calculations in `csharp` backend. * The C backends are now more careful about generating overflowing integer operations (since this is undefined behaviour in C, but defined in Futhark). * `futhark dataset` no longer crashes uncontrollably when used incorrectly (#849). ## [0.14.1] ### Added * The optimiser is now somewhat better at removing unnecessary copies of array slices. * `futhark bench` and `futhark test` now take a `--concurrency` option for limiting how many threads are used for housekeeping tasks. Set this to a low value if you run out of memory. * `random` test blocks are now allowed to contain integer literals with type suffixes. * `:frame ` command for `futhark repl` for inspecting the stack. * `e :> t` notation, which means the same as `e : t` for now, but will have looser constraints in the future. * Size-lifted type abbreviations can be declared with `type~` and size-lifted type parameters with `'~`. 
These currently have no significant difference from fully lifted types. ### Changed * Tuples are now 0-indexed (#821, which also includes a conversion script). * Invalid ranges like `1..<0` now produce a run-time error instead of an empty array. * Record updates (`r with f = e`) now require `r` to have a completely known type up to `f`. This is a restriction that will hopefully be lifted in the future. * The backtrace format has changed to be innermost-first, like pretty much all other languages. * Value specs must now explicitly quantify all sizes of function parameters. Instead of `val sum: []t -> t` you must write `val sum [n]: [n]t -> t`. * `futhark test` now once again numbers un-named data sets from 0 rather than from 1. This fits a new general principle of always numbering from 0 in Futhark. * Type abbreviations declared with `type` may no longer contain functions or anonymous sizes in their definition. Use `type^` for these cases. Just a warning for now, but will be an error in the future. ### Fixed * Work around (probable) AMD OpenCL compiler bug for `reduce_by_index` operations with complex operators that require locking. * Properly handle another ICE on parse errors in test stanzas (#819). * `futhark_context_new_with_command_queue()` now actually works. Oops. * Different scopes are now properly addressed during type inference (#838). Realistically, there will still be some missing cases. ## [0.13.2] ### Added * New subcommand, `futhark query`, for looking up information about the name at some position in a file. Intended for editor integration. * (Finally) automatic support for compute capability 7.5 in the CUDA backend. * Somewhat better performance for very large target arrays for `reduce_by_index`. ### Fixed * Fixed a slice-iota simplification bug (#813). * Fixed defunctionalisation crash involving intrinsics (#814). ## [0.13.1] ### Added * Stack traces are now multiline for better legibility.
### Changed * The `empty(t)` notation now specifies the type of the *entire value* (not just the element type), and requires dimension sizes when `t` is an array (e.g. `empty(i32)` is no longer allowed, you need for example `empty([0]i32)`). * All input files are now assumed to be in UTF-8. ### Fixed * Fixed exponential-time behaviour for certain kernels with large arithmetic expressions (#805). * `futhark test` and friends no longer crash when reporting some errors (#808). * Fix uniqueness of loop results (#810). ## [0.12.3] ### Added * Character literals can now be any integer type. * The integer modules now have `popc` and `clz` functions. * Tweaked inlining so that larger programs may now compile faster (observed about 20%). * Pattern-matching on large sum-typed values taken from arrays may be a bit faster. ### Fixed * Various small fixes to type errors. * All internal functions used in generated C code are now properly declared `static`. * Fixed bugs when handling dimensions and aliases in type ascriptions. ## [0.12.2] ### Added * New tool: `futhark autotune`, for tuning the threshold parameters used by incremental flattening. Based on work by Svend Lund Breddam, Simon Rotendahl, and Carl Mathias Graae Larsen. * New tool: `futhark dataget`, for extracting test input data. Most will probably never use this. * Programs compiled with the `cuda` backend now take options `--default-group-size`, `--default-num-groups`, and `--default-tile-size`. * Segmented `reduce_by_index` operations are now substantially faster for small histograms. * New functions: `f32.lerp` and `f64.lerp`, for linear interpolation. ### Fixed * Fixes to aliasing of record updates. * Fixed unnecessary array duplicates after coalescing optimisations. * `reduce_by_index` nested in `map`s will no longer sometimes require huge amounts of memory. * Source location is now correct for unknown infix operators. * Function parameters are no longer in scope of themselves (#798).
* Fixed a nasty out-of-bounds error in handling of irregular allocations. * The `floor`/`ceil` functions in `f32`/`f64` now handle infinities correctly (and are also faster). * Using `%` on floats now computes fmod instead of crashing the compiler. ## [0.12.1] ### Added * The internal representation of parallel constructs has been overhauled and many optimisations rewritten. The overall performance impact should be neutral on aggregate, but there may be changes for some programs (please report if so). * Futhark now supports structurally typed sum types and pattern matching! This work was done by Robert Schenck. There remain some problems with arrays of sum types that themselves contain arrays. * Significant reduction in compile time for some large programs. * Manually specified type parameters need no longer be exhaustive. * Mapped `rotate` is now simplified better. This can be particularly helpful for stencils with wraparound. ### Removed * The `~` prefix operator has been removed. `!` has been extended to perform bitwise negation when applied to integers. ### Changed * The `--futhark` option for `futhark bench` and `futhark test` now defaults to the binary being used for the subcommands themselves. * The legacy `futhark -t` option (which did the same as `futhark check`) has been removed. * Lambdas now bind less tightly than type ascription. * `stream_map` is now `map_stream` and `stream_red` is now `reduce_stream`. ### Fixed * `futhark test` now understands `--no-tuning` as it was always supposed to. * `futhark bench` and `futhark test` now interpret `--exclude` in the same way. * The Python and C# backends can now properly read binary boolean input. ## [0.11.2] ### Fixed * Entry points whose types are opaque due to module ascription, yet whose representation is simple (scalars or arrays of scalars) were mistakenly made non-opaque when compiled with ``--library``. This has been fixed. * The CUDA backend now supports default sizes in `.tuning` files.
* Loop interchange across multiple dimensions was broken in some cases (#767). * The sequential C# backend now generates code that compiles (#772). * The sequential Python backend now generates code that runs (#765). ## [0.11.1] ### Added * Segmented scans are a good bit faster. * `reduce_by_index` has received a new implementation that uses local memory, and is now often a good bit faster when the target array is not too large. * The `f32` and `f64` modules now contain `gamma` and `lgamma` functions. At present these do not work in the C# backend. * Some instances of `reduce` with vectorised operators (e.g. `map2 (+)`) are orders of magnitude faster than before. * Memory usage is now lower on some programs (specifically the ones that have large `map`s with internal intermediate arrays). ### Removed * Size *parameters* (not *annotations*) are no longer permitted directly in `let` and `loop` bindings, nor in lambdas. You are likely not affected (except for the `stream` constructs; see below). Few people used this. ### Changed * The array creation functions exported by generated C code now take `int64_t` arguments for the shape, rather than `int`. This is in line with what the shape functions return. * The types for `stream_map`, `stream_map_per`, `stream_red`, and `stream_red_per` have been changed, such that the chunk function now takes the chunk size as the first argument. ### Fixed * Fixes to reading values under Python 3. * The type of a variable can now be deduced from its use as a size annotation. * The code generated by the C-based backends is now also compilable as C++. * Fix memory corruption bug that would occur on very large segmented reductions (large segments, and many of them). ## [0.10.2] ### Added * `reduce_by_index` is now a good bit faster on operators whose arguments are two 32-bit values. 
* The type checker warns on size annotations for function parameters and return types that will not be visible from the outside, because they refer to names nested inside tuples or records. For example, the function let f (n: i32, m: i32): [n][m]i32 = ... will cause such a warning. It should instead be written let f (n: i32) (m: i32): [n][m]i32 = ... * A new library function `futhark_context_config_select_device_interactively()` has been added. ### Fixed * Fix reading and writing of binary files for C-compiled executables on Windows. * Fixed a couple of overly strict internal sanity checks related to in-place updates (#735, #736). * Fixed a couple of convoluted defunctorisation bugs (#739). ## [0.10.1] ### Added * Using definitions from the `intrinsic` module outside the prelude now results in a warning. * `reduce_by_index` with vectorised operators (e.g. `map2 (+)`) is orders of magnitude faster than before. * Executables generated with the `pyopencl` backend now support the options `--default-tile-size`, `--default-group-size`, `--default-num-groups`, `--default-threshold`, and `--size`. * Executables generated with `c` and `opencl` now print a help text if run with invalid options. The `py` and `pyopencl` backends already did this. * Generated executables now support a `--tuning` flag for passing many tuned sizes in a file. * Executables generated with the `cuda` backend now take an `--nvrtc-option` option. * Executables generated with the `opencl` backend now take a `--build-option` option. ### Removed * The old `futhark-*` executables have been removed. ### Changed * If an array is passed for a function parameter of a polymorphic type, all arrays passed for parameters of that type must have the same shape. For example, given a function let pair 't (x: t) (y: t) = (x, y) The application `pair [1] [2,3]` will now fail at run-time. * `futhark test` now numbers un-named data sets from 1 rather than 0. 
This only affects the text output and the generated JSON files, and fits the tuple element ordering in Futhark. * String literals are now of type `[]u8` and contain UTF-8 encoded bytes. ### Fixed * A significant problematic interaction between empty arrays and inner size declarations has been closed (#714). This follows a range of lesser empty-array fixes from 0.9.1. * `futhark datacmp` now prints to stdout, not stderr. * Fixed a major potential out-of-bounds access when sequentialising `reduce_by_index` (in most cases the bug was hidden by subsequent C compiler optimisations). * The result of an anonymous function is now also forbidden from aliasing a global variable, just as with named functions. * Parallel scans now work correctly when using a CPU OpenCL implementation. * `reduce_by_index` was broken on newer NVIDIA GPUs when using fancy operators. This has been fixed. ## [0.9.1] ### Added * `futhark cuda`: a new CUDA backend by Jakob Stokholm Bertelsen. * New command for comparing data files: `futhark datacmp`. * An `:mtype` command for `futhark repl` that shows the type of a module expression. * `futhark run` takes a `-w` option for disabling warnings. ### Changed * Major command reorganisation: all Futhark programs have been combined into a single all-powerful `futhark` program. Instead of e.g. `futhark-foo`, use `futhark foo`. Wrappers will be kept around under the old names for a little while. `futharki` has been split into two commands: `futhark repl` and `futhark run`. Also, `py` has become `python` and `cs` has become `csharp`, but `pyopencl` and `csopencl` have remained as they were. * The result of a function is now forbidden from aliasing a global variable. Surprisingly little code is affected by this. * A global definition may not be ascribed a unique type. This never had any effect in the first place, but now the compiler will explicitly complain.
* Source spans are now printed in a slightly different format, with the ending line number omitted when it is the same as the start line number. ### Fixed * `futharki` now reports source locations of `trace` expressions properly. * The type checker now properly complains if you try to define a type abbreviation that has unused size parameters. ## [0.8.1] ### Added * Now warns when `/futlib/...` files are redundantly imported. * `futharki` now prints warnings for files that are ":load"ed. * The compiler now warns when entry points are declared with types that will become unnamed and opaque, and thus impossible to provide from the outside. * Type variables invented by the type checker will now have a unicode subscript to distinguish them from type parameters originating in the source code. * `futhark-test` and `futhark-bench` now support generating random test data. * The library backends now generate proper names for arrays of opaque values. * The parser now permits empty programs. * Most transpositions are now a good bit faster, especially on NVIDIA GPUs. ### Removed * The `<-` symbol can no longer be used for in-place updates and record updates (deprecated in 0.7.3). ### Changed * Entry points that accept a single tuple-typed parameter are no longer silently rewritten to accept multiple parameters. ### Fixed * The `:type` command in `futharki` can now handle polymorphic expressions (#669). * Fixed serious bug related to chaining record updates. * Fixed type inference of record fields (#677). * `futharki` no longer goes into an infinite loop if a ``for`` loop contains a negative upper bound. * Overloaded number types can no longer carry aliases (#682). ## [0.7.4] ### Added * Support type parameters for operator specs defined with `val`. ### Fixed * Fixed nasty defunctionalisation bug (#661). * `cabal sdist` and `stack sdist` work now.
## [0.7.3] ### Added * Significant performance changes: there is now a constant extra compilation overhead (less than 200ms on most machines). However, the rest of the compiler is 30-40% faster (or more in some cases). * A warning when ambiguously typed expressions are assigned a default (`i32` or `f64`). * In-place updates and record updates are now written with `=` instead of `<-`. The latter is deprecated and will be removed in the next major version (#650). ### Fixed * Polymorphic value bindings now work properly with module type ascription. * The type checker no longer requires types used inside local functions to be unambiguous at the point where the local function is defined. They must still be unambiguous by the time the top-level function ends. This is similar to what other ML languages do. * `futhark-bench` now writes "μs" instead of "us". * Type inference for infix operators now works properly. ## [0.7.2] ### Added * `futhark-pkg` now supports GitLab. * `futhark-test`s `--notty` option now has a `--no-terminal` alias. `--notty` is deprecated, but still works. * `futhark-test` now supports multiple entry points per test block. * Functional record updates: `r with f <- x`. ### Fixed * Fix the `-C` option for `futhark-test`. * Fixed incorrect type of `reduce_by_index`. * Segmented `reduce_by_index` now uses much less memory. ## [0.7.1] ### Added * C# backend by Mikkel Storgaard Knudsen (`futhark-cs`/`futhark-csopencl`). * `futhark-test` and `futhark-bench` now take a `--runner` option. * `futharki` now uses a new interpreter that directly interprets the source language, rather than operating on the desugared core language. In practice, this means that the interactive mode is better, but that interpretation is also much slower. * A `trace` function that is semantically `id`, but makes `futharki` print out the value. * A `break` function that is semantically `id`, but makes `futharki` stop and provide the opportunity to inspect variables in scope. 
* A new SOAC, `reduce_by_index`, for expressing generalised reductions (sometimes called histograms). Designed and implemented by Sune Hellfritzsch. ### Removed * Most of futlib has been removed. Use external packages instead: * `futlib/colour` => https://github.com/athas/matte * `futlib/complex` => https://github.com/diku-dk/complex * `futlib/date` => https://github.com/diku-dk/date * `futlib/fft` => https://github.com/diku-dk/fft * `futlib/linalg` => https://github.com/diku-dk/linalg * `futlib/merge_sort`, `futlib/radix_sort` => https://github.com/diku-dk/sorts * `futlib/random` => https://github.com/diku-dk/cpprandom * `futlib/segmented` => https://github.com/diku-dk/segmented * `futlib/sobol` => https://github.com/diku-dk/sobol * `futlib/vector` => https://github.com/athas/vector No replacement: `futlib/mss`, `futlib/lss`. * `zip6`/`zip7`/`zip8` and their `unzip` variants have been removed. If you build gigantic tuples, you're on your own. * The `>>>` operator has been removed. Use an unsigned integer type if you want zero-extended right shifts. ### Changed * The `largest`/`smallest` values for numeric modules have been renamed `highest`/`lowest`. ### Fixed * Many small things. ## [0.6.3] ### Added * Added a package manager: `futhark-pkg`. See also [the documentation](http://futhark.readthedocs.io/en/latest/package-management.html). * Added `log2` and `log10` functions to `f32` and `f64`. * Module type refinement (`with`) now permits refining parametric types. * Better error message when invalid values are passed to generated Python entry points. * `futhark-doc` now ignores files whose doc comment is the word "ignore". * `copy` now works on values of any type, not just arrays. * Better type inference for array indexing. ### Fixed * Floating-point numbers are now correctly rounded to nearest even integer, even in exotic cases (#377). * Fixed a nasty bug in the type checking of calls to consuming functions (#596).
## [0.6.2] ### Added * Bounds checking errors now show the erroneous index and the size of the indexed array. Some other size-related errors also show more information, but it will be a while before they are all converted (and say something useful - it's not entirely straightforward). * Opaque types now have significantly more readable names, especially if you add manual size annotations to the entry point definitions. * Backticked infix operators can now be used in operator sections. ### Fixed * `f64.e` is no longer pi. * Generated C library code will no longer `abort()` on application errors (#584). * Fix file imports on Windows. * `futhark-c` and `futhark-opencl` now generate thread-safe code (#586). * Significantly better behaviour in OOM situations. * Fixed an unsound interaction between in-place updates and parametric polymorphism (#589). ## [0.6.1] ### Added * The `real` module type now specifies `tan`. * `futharki` now supports entering declarations. * `futharki` now supports a `:type` command (or `:t` for short). * `futhark-test` and `futhark-bench` now support gzipped data files. They must have a `.gz` extension. * Generated code now frees memory much earlier, which can help reduce the footprint. * Compilers now accept a `--safe` flag to make them ignore `unsafe`. * Module types may now define *lifted* abstract types, using the notation `type ^t`. These may be instantiated with functional types. A lifted abstract type has all the same restrictions as a lifted type parameter. ### Removed * The `rearrange` construct has been removed. Use `transpose` instead. * `futhark-mode.el` has been moved to a [separate repository](https://github.com/diku-dk/futhark-mode). * Removed `|>>` and `<<|`. Use `>->` and `<-<` instead. * The `empty` construct is no longer supported. Just use empty array literals. ### Changed * Imports of the basis library must now use an absolute path (e.g. `/futlib/fft`, not simply `futlib/fft`).
* `/futlib/vec2` and `/futlib/vec3` have been replaced by a new `/futlib/vector` file. * Entry points generated by the C code backend are now prefixed with `futhark_entry_` rather than just `futhark_`. * `zip` and `unzip` are no longer language constructs, but library functions, and work only on two arrays and pairs, respectively. Use functions `zipN/unzipN` (for `2<=n<=8`). ### Fixed * Better error message on EOF. * Fixed handling of `..` in `import` paths. * Type errors (and other compiler feedback) will no longer contain internal names. * `futhark-test` and friends can now cope with infinities and NaNs. Such values are printed and read as `f32.nan`, `f32.inf`, `-f32.inf`, and similarly for `f64`. In `futhark-test`, NaNs compare equal. ## [0.5.2] ### Added * Array index section: `(.[i])` is shorthand for `(\x -> x[i])`. Full slice syntax supported. (#559) * New `assert` construct. (#464) * `futhark-mode.el` now contains a definition for flycheck. ### Fixed * The index produced by `futhark-doc` now contains correct links. * Windows linebreaks are now fully supported for test files (#558). ## [0.5.1] ### Added * Entry points need no longer be syntactically first-order. * Added overloaded numeric literals (#532). This means type suffixes are rarely required. * Binary and unary operators may now be bound in patterns by enclosing them in parentheses. * `futhark-doc` now produces much nicer documentation. Markdown is now supported in documentation comments. * `/futlib/functional` now has operators `>->` and `<-<` for function composition. `<<|` and `|>>` are deprecated. * `/futlib/segmented` now has a `segmented_reduce`. * Scans and reductions can now be horizontally fused. * `futhark-bench` now supports multiple entry points, just like `futhark-test`. * ".." is now supported in `include` paths. ### Removed * The `reshape` construct has been removed. Use the `flatten`/`unflatten` functions instead. * `concat` and `rotate` no longer support the `@` notation.
Use `map` nests instead. * Removed `-I`/`--library`. These never worked with `futhark-test`/`futhark-bench` anyway. ### Changed * When defining a module type, a module of the same name is no longer defined (#538). * The `default` keyword is no longer supported. * `/futlib/merge_sort` and `/futlib/radix_sort` now define functions instead of modules. ### Fixed * Better type inference for `rearrange` and `rotate`. * `import` path resolution is now much more robust. ## [0.4.1] ### Added * Unused-result elimination for reductions; particularly useful when computing with dual numbers for automatic differentiation. * Record field projection is now possible for variables of (then) unknown types. A function parameter must still have an unambiguous (complete) type by the time it finishes checking. ### Fixed * Fixed interaction between type ascription and type inference (#529). * Fixed duplication when an entry point was also called as a function. * Futhark now compiles cleanly with GHC 8.4.1 (this is also the new default). ## [0.4.0] ### Added * The constructor for generated PyOpenCL classes now accepts a `command_queue` parameter (#480). * Transposing small arrays is now much faster when using OpenCL backend (#478). * Infix operators can now be defined in prefix notation, e.g.: let (+) (x: i32) (y: i32) = x - y This permits them to have type- and shape parameters. * Comparison operators (<=, <, >, >=) are now valid for boolean operands. * Ordinary functions can be used as infix by enclosing them in backticks, as in Haskell. They are left-associative and have lowest priority. * Numeric modules now have `largest`/`smallest` values. * Numeric modules now have `sum`, `product`, `maximum`, and `minimum` functions. * Added ``--Werror`` command line option to compilers. * Higher-order functions are now supported (#323). * Type inference is now supported, although with some limitations around records, in-place updates, and `unzip`. 
(#503) * Added a range of higher-order utility functions to the prelude, including (among others): val (|>) '^a '^b: a -> (a -> b) -> b val (<|) '^a '^b: (a -> b) -> a -> b val (|>>) '^a 'b '^c: (a -> b) -> (b -> c) -> a -> c val (<<|) '^a 'b '^c: (b -> c) -> (a -> b) -> a -> c ### Changed * `FUTHARK_VERSIONED_CODE` is now `FUTHARK_INCREMENTAL_FLATTENING`. * The SOACs `map`, `reduce`, `filter`, `partition`, `scan`, `stream_red`, and `stream_map` have been replaced with library functions. * The futlib/mss and futlib/lss modules have been rewritten to use higher-order functions instead of modules. ### Fixed * Transpositions in generated OpenCL code no longer crash on large but empty arrays (#483). * Booleans can now be compared with relational operators without crashing the compiler (#499). ## [0.3.1] ### Added * `futhark-bench` now tries to align benchmark results for better legibility. ### Fixed * `futhark-test`: now handles CRLF linebreaks correctly (#471). * A record field can be projected from an array index expression (#473). * Futhark will now never automatically pick Apple's CPU device for OpenCL, as it is rather broken. You can still select it manually (#475). * Fixes to `set_bit` functions in the math module (#476). ## [0.3.0] ### Added * A comprehensible error message is now issued when attempting to run a Futhark program on an OpenCL device that does not support the types used by the program. A common case was trying to use double-precision floats on an Intel GPU. * Parallelism inside of a branch can now be exploited if the branch condition and the sizes of its results are invariant to all enclosing parallel loops. * A new OpenCL memory manager can in some cases dramatically improve performance for repeated invocations of the same entry point. * Experimental support for incremental flattening. Set the environment variable `FUTHARK_VERSIONED_CODE` to any value to try it out. * `futhark-dataset`: Add `-t`/`-type` option. Useful for inspecting data files.
* Better error message when ranges are written with two dots (`x..y`). * Type errors involving abstract types from modules now use qualified names (less "expected 't', got 't'", more "expected 'foo.t', got 'bar.t'"). * Shorter compile times for most programs. * `futhark-bench`: Add ``--skip-compilation`` flag. * `scatter` expressions nested in `map`s are now parallelised. * futlib: an `fft` module has been added, thanks to David P.H. Jørgensen and Kasper Abildtrup Hansen. ### Removed * `futhark-dataset`: Removed `--binary-no-header` and `--binary-only-header` options. * The `split` language construct has been removed. There is a library function `split` that does approximately the same. ### Changed * futlib: the `complex` module now produces a non-abstract `complex` type. * futlib: the `random` module has been overhauled, with several new engines and adaptors changed, and some of the module types changed. In particular, `rng_distribution` now contains a numeric module instead of an abstract type. * futlib: The `vec2` and `vec3` modules now represent vectors as records rather than tuples. * futlib: The `linalg` module now has distinct convenience functions for multiplying matrices with row and column vectors. * Only entry points defined directly in the file given to the compiler will be visible. * Range literals are now written without brackets: `x...y`. * The syntax `(-x)` can no longer be used for a partial application of subtraction. * `futhark-test` and `futhark-bench` will no longer append `.bin` to executables. * `futhark-test` and `futhark-bench` now replace actual/expected files from previous runs, rather than increasing the litter. ### Fixed * Fusion would sometimes remove safety checks on e.g. `reshape` (#436). * Variables used as implicit fields in a record construction are now properly recognised as being used. * futlib: the `num_bits` field for the integer modules in `math` now has correct values.
## [0.2.0] ### Added * Run-time errors due to failed assertions now include a stack trace. * Generated OpenCL code now picks more sensible group size and count when running on a CPU. * `scatter` expressions nested in `map`s may now be parallelised ("segmented scatter"). * Add `num_bits`/`get_bit`/`set_bit` functions to numeric module types, including a new `float` module type. * Size annotations may now refer to preceding parameters, e.g.: let f (n: i32) (xs: [n]i32) = ... * `futhark-doc`: retain parameter names in generated docs. * `futhark-doc`: now takes `-v`/`--verbose` options. * `futhark-doc`: now generates valid HTML. * `futhark-doc`: now permits files to contain a leading documentation comment. * `futhark-py`/`futhark-pyopencl`: Better dynamic type checking in entry points. * Primitive functions (sqrt etc) can now be constant-folded. * Futlib: /futlib/vec2 added. ### Removed * The built-in `shape` function has been removed. Use `length` or size parameters. ### Changed * The `from_i32`/`from_i64` functions of the `numeric` module type have been replaced with functions named `i32`/`i64`. Similar functions have been added for all the other primitive types (factored into a new `from_prim` module type). * The overloaded type conversion functions (`i32`, `f32`, `bool`, etc) have been removed. Four functions have been introduced for the special cases of converting between `f32`/`f64` and `i32`: `r32`, `r64`, `t32`, `t64`. * Modules and variables now inhabit the same name space. As a consequence, we now use `x.y` to access field `y` of record `x`. * Record expression syntax has been simplified. Record concatenation and update is no longer directly supported. However, fields can now be implicitly defined: `{x,y}` now creates a record with fields `x` and `y`, with values taken from the variables `x` and `y` in scope. ### Fixed * The `!=` operator now works properly on arrays (#426). * Allocations were sometimes hoisted incorrectly (#419). * `f32.e` is no longer pi.
* Various other fixes. ## [0.1.0] (This is just a list of highlights of what was included in the first release.) * Code generators: Python and C, both with OpenCL. * Higher-order ML-style module system. * In-place updates. * Tooling: futhark-test, futhark-bench, futhark-dataset, futhark-doc. * Beginnings of a basis library, "futlib". futhark-0.25.27/CITATION.cff000066400000000000000000000002721475065116200152350ustar00rootroot00000000000000cff-version: 1.2.0 message: "If you use this software, please cite it as below." authors: - given-names: "The Futhark Hackers" title: "Futhark" url: "https://github.com/diku-dk/futhark" futhark-0.25.27/CODE_OF_CONDUCT.md000066400000000000000000000062151475065116200161450ustar00rootroot00000000000000# Contributor Covenant Code of Conduct ## Our Pledge In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation. 
## Our Standards Examples of behavior that contributes to creating a positive environment include: * Using welcoming and inclusive language * Being respectful of differing viewpoints and experiences * Gracefully accepting constructive criticism * Focusing on what is best for the community * Showing empathy towards other community members Examples of unacceptable behavior by participants include: * The use of sexualized language or imagery and unwelcome sexual attention or advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or electronic address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Our Responsibilities Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. ## Scope This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at athas@sigkill.dk. 
The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version] [homepage]: http://contributor-covenant.org [version]: http://contributor-covenant.org/version/1/4/ futhark-0.25.27/CONTRIBUTING.md000066400000000000000000000026741475065116200156040ustar00rootroot00000000000000## How to contribute to the Futhark compiler * See [Get Involved](https://futhark-lang.org/getinvolved.html) on the main website. * Read [HACKING.md](HACKING.md) * Read [STYLE.md](STYLE.md) #### **Did you find a bug?** * **Ensure the bug was not already reported** by searching on GitHub under [Issues](https://github.com/diku-dk/futhark/issues). * If you're unable to find an open issue addressing the problem, [open a new one](https://github.com/diku-dk/futhark/issues/new). Be sure to include a **title and clear description**, as much relevant information as possible, and a **code sample** or an **executable test case** demonstrating the expected behavior that is not occurring. #### **Did you write a patch that fixes a bug?** * Open a new GitHub pull request with the patch. * Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable. * Before submitting, please read the [Style Guide](STYLE.md) to know more about coding conventions and benchmarks. 
#### **Do you intend to add a new feature or change an existing one?** * Open an issue to start a discussion (or join Gitter or IRC). It's best to talk these things through before spending a lot of time writing code. #### **Do you have questions about the source code?** * Ask any question about the code base on Gitter or IRC (#futhark on Libera.Chat). Please don't open GitHub issues just to ask questions. futhark-0.25.27/HACKING.md000066400000000000000000000313421475065116200147330ustar00rootroot00000000000000# Hacking on the Futhark Compiler The Futhark compiler is a significant body of code with a not entirely straightforward design. The main source of documentation is the Haddock comments in the source code itself, including [this general overview of the compiler architecture](https://hackage.haskell.org/package/futhark-0.24.3/docs/Futhark.html). To build the compiler, you need a recent version of [GHC](http://ghc.haskell.org/), which can be installed via [ghcup](https://www.haskell.org/ghcup/). Alternatively, if you [install Nix](https://nixos.org/download.html#download-nix) then you can run `nix-shell` to get a shell environment in which all necessary tools are installed. After that, run `make docs` to generate internal compiler documentation in HTML format. The last few lines of output will tell you the name of an `index.html` file which you should then open. Go to the documentation for the module named `Futhark`, which contains an introduction to the compiler architecture. For contributing code, see the [Haskell style guide](STYLE.md). If you feel that the documentation is incomplete, or something lacks an explanation, then feel free to [report it as an issue](https://github.com/diku-dk/futhark/issues). Documentation bugs are bugs too. ## Building We include a `Makefile` with the following targets. * `make build` (or just `make`) builds the compiler. 
* `make install` builds the compiler and copies the resulting binaries to `$HOME/.local/bin`, or `$PREFIX/bin` if the `PREFIX` environment variable is set. * `make docs` builds internal compiler documentation. For the user documentation, see the `docs/` subdirectory. * `make check` style-checks all code. Requires [GNU Parallel](https://www.gnu.org/software/parallel/). * `make check-commit` style-checks all code staged for a commit. Requires [GNU Parallel](https://www.gnu.org/software/parallel/). You can also use `cabal` directly if you are familiar with it. In particular, `cabal run futhark -- args...` is useful for running the Futhark compiler with the provided args. ### Enabling profiling Asking GHC to generate profiling information is useful not just for the obvious purpose of gathering profiling information, but also so that stack traces become more informative. Run `make configure-profile` to turn on profiling. This setting will be stored in the file `cabal.project.local` and all future builds will include profiling information. Note that the compiler runs significantly slower this way. To produce a profiling report when running the compiler, add `+RTS -p` to the *end* of the command line. See also [the chapter on profiling](https://downloads.haskell.org/ghc/latest/docs/users_guide/profiling.html) in the GHC User's Guide. Note that GHC's code generator is sometimes slightly buggy in its handling of profiled code. If you encounter a compiler crash with an error message like "PAP object entered", then this is a GHC bug. ### Debugging compiler crashes By default, Haskell does not produce very good stack traces. If you compile with `make configure-profile` as mentioned above, you can pass `+RTS -xc` to the Futhark compiler in order to get better stack traces. You will see that you actually get *multiple* stack traces, as the Haskell runtime system will print a stack trace for every signal it receives, and several of these occur early, when the program is read from disk.
Also, the *final* stack trace is often some diagnostic artifact. Usually the second-to-last stack trace is what you are looking for. ## Testing ### Only internal compilation This command tests compilation *without* compiling the generated C code, which speeds up testing for internal compiler errors: futhark test -C tests --pass-compiler-option=--library ### Running only a single unit test cabal run unit -- -p '/reshape . fix . iota 3d/' The argument to `-p` is the name of the test that fails, as reported by `cabal test`. You may have to scroll through the output a bit to find it. ## Debugging Internal Type Errors The Futhark compiler uses a typed core language, and the type checker is run after every pass. If a given pass produces a program with inconsistent typing, the compiler will report an error and abort. While not every compiler bug will manifest itself as a core language type error (unfortunately), many will. To write the erroneous core program to `filename` in case of a type error, pass `-vfilename` to the compiler. This will also enable verbose output, so you can tell which pass fails. The `-v` option is also useful when the compiler itself crashes, as you can at least tell where in the pipeline it got to. ## Checking Generated Code Hacking on the compiler will often involve inspecting the quality of the generated code. The recommended way to do this is to use `futhark c` or `futhark opencl` to compile a Futhark program to an executable. These backends insert various forms of instrumentation that can be enabled by passing run-time options to the generated executable. - As a first resort, use the `-t` option to enable the built-in runtime measurements. A nice trick is to pass `-t /dev/stderr`, while redirecting standard output to `/dev/null`. This will print the runtime on the screen, but not the execution result. - Optionally use `-r` to ask for several runs, e.g. `-r 10`. If combined with `-t`, this will cause several runtimes to be printed (one per line).
- Pass `-D` to have the program print information on allocation and deallocation of memory. - (`futhark opencl` and `futhark cuda` only) The same `-D` option also enables synchronous execution. `clFinish()` or the CUDA equivalent will be called after most GPU operations, and a running log of kernel invocations will be printed. At the end of execution, the program prints a table summarising all kernels and their total runtime and average runtime. ## Using `futhark dev` For debugging specific compiler passes, the `futhark dev` subcommand allows you to tailor your own compilation pipeline using command line options. It is also useful for seeing what the AST looks like after specific passes. ### `FUTHARK_COMPILER_DEBUGGING` environment variable You can set the level of debug verbosity via the environment variable `FUTHARK_COMPILER_DEBUGGING`. It has the following effects: - `FUTHARK_COMPILER_DEBUGGING=1`: + The frontend prints internal names. (This may affect code generation in some cases, so turn it off when actually generating code.) + Tools that talk to server-mode executables will print the messages sent back and forth on the standard error stream. - `FUTHARK_COMPILER_DEBUGGING=2`: + All of the effects of `FUTHARK_COMPILER_DEBUGGING=1`. + The frontend prints explicit type annotations. ## Running compiler pipelines You can run the various compiler passes in whatever order you wish. There are also various shorthands for running entire standard pipelines: - `--gpu`: pipeline used for GPU backends (stopping just before adding memory information). - `--gpu-mem`: pipeline used for GPU backends, with memory information. This will show the IR that is passed to ImpGen. - `--seq`: pipeline used for sequential backends (stopping just before adding memory information). - `--seq-mem`: pipeline used for sequential backends, with memory information. This will show the IR that is passed to ImpGen.
- `--mc`: pipeline used for multicore backends (stopping just before adding memory information). - `--mc-mem`: pipeline used for multicore backends, with memory information. This will show the IR that is passed to ImpGen. By default, `futhark dev` will print the resulting IR. You can switch to a different *action* with one of the following options: - `--compile-imp-seq`: generate sequential ImpCode and print it. - `--compile-imp-gpu`: generate GPU ImpCode and print it. - `--compile-imp-multicore`: generate multicore ImpCode and print it. You must use the appropriate pipeline as well (e.g. `--gpu-mem` for `--compile-imp-gpu`). You can also use e.g. `--backend=c` to run the same code generation and compilation as `futhark c`. This is useful for experimenting with other compiler pipelines while still producing an executable or library. ## When you are about to have a bad day When using the `cuda` backend, you can use the `--dump-ptx` runtime option to dump PTX, a kind of high-level assembly for NVIDIA GPUs, corresponding to the GPU kernels. This can be used to investigate why the generated code isn't running as fast as you expect (not fun), or even whether NVIDIA's compiler is miscompiling something (extremely not fun). With the OpenCL backend, `--dump-opencl-binary` does the same thing. On AMD platforms, `--dump-opencl-binary` tends to produce an actual binary of some kind, and it is pretty tricky to obtain a debugger for it (they are available and open source, but the documentation and installation instructions are terrible). Instead, AMD's OpenCL kernel compiler accepts a `-save-temps=foo` build option, which will make it write certain intermediate files, prefixed with `foo`. In particular, it will write an `.s` file that contains what appears to be HSA assembly (at least when using ROCm). If you find yourself having to do this, then you are definitely going to have a bad day, and probably evening and night as well.
## Minimising programs Sometimes you have a program that produces the wrong results rather than crashing the compiler. These are some of the most difficult bugs to handle. If the result is at least deterministic and you have some way of compiling the program that does work (either an older version or a different backend), then the following procedure is useful for reducing the program as much as possible. Suppose that we are trying to debug a miscompilation for the `opencl` backend where the `c` backend works, the failing program is `prog.fut`, and the input data is `prog.in`. Write the following script `test.sh`: ``` set -x set -e futhark c prog.fut -o prog-c futhark opencl prog.fut -o prog-opencl cat prog.in | ./prog-c -b > output-c cat prog.in | ./prog-opencl -b > output-opencl futhark datacmp output-c output-opencl ``` This compares the results obtained from running the program with the two compilers. You can now (manually) start removing parts of `prog.fut` while regularly rerunning `test.sh` to verify that it still fails. In particular, you can easily remove program return values, which is not the case if you are comparing against a fixed expected output. Eventually you will have a hopefully small program that produces different results with the two compilers, and you can look in detail at the IR to figure out what goes wrong. ## Graphs of internal data structures Some passes can prettyprint internal representations in [GraphViz](https://graphviz.org/) format. For example, to see the fusion graph (prior to fusion), do $ futhark dev -e --inline-aggr -e foo.fut --fusion-graph > foo.dot and then to render `foo.dot` as `foo.dot.pdf` with GraphViz: $ dot foo.dot -Tpdf -O ## Using Oclgrind [Oclgrind](https://github.com/jrprice/oclgrind) is an OpenCL simulator similar to Valgrind that can help find memory and synchronisation errors. It runs code somewhat slowly, but it allows testing of OpenCL code on systems that are not otherwise capable of executing OpenCL. 
It is very easy to run a program in Oclgrind: oclgrind ./foo For use in `futhark test`, we have [a wrapper script](tools/oclgrindrunner.sh) that returns with a nonzero exit code if Oclgrind detects a memory error. You use it as follows: futhark test foo.fut --backend=opencl --runner=tools/oclgrindrunner.sh Some versions of Oclgrind have an unfortunate habit of [generating code they don't know how to execute](https://github.com/jrprice/Oclgrind/issues/204). To work around this, disable optimisations in the OpenCL compiler: futhark test foo.fut --backend=opencl --runner=tools/oclgrindrunner.sh --pass-option=--build-option=-O0 ## Using `futhark script` The `futhark script` command is a handy way to run (server-mode) executables with arbitrary input, while also seeing logging output in real time. This is particularly useful for programs whose benchmarking inputs are complicated FutharkScript expressions. If you have a program `infinite.fut` containing ```Futhark entry main n = iterate 1000000000 (map (+1)) (iota n) ``` then you can run ``` $ futhark script -D infinite.fut 'main 10i64' ``` to run it with debug prints. You can also use `-L` instead of `-D` to just enable logging. The `main 10i64` can be an arbitrary FutharkScript expression. The above will compile `infinite.fut` using the `c` backend before running it. Pass a `--backend` option to `futhark script` to use a different backend, or pass an already compiled program instead of a `.fut` file (e.g., `infinite`). See the manpages for `futhark script` and `futhark literate` for more information. futhark-0.25.27/LICENSE000066400000000000000000000013771475065116200143570ustar00rootroot00000000000000ISC License Copyright (c) 2013-2022. DIKU, University of Copenhagen Permission to use, copy, modify, and/or distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. futhark-0.25.27/Makefile000066400000000000000000000022221475065116200150000ustar00rootroot00000000000000# This Makefile mostly serves to abbreviate build commands that are # unnecessarily obtuse or longwinded. It depends on the underlying # build tool (cabal) to actually do anything incrementally. # Configuration is mostly read from cabal.project. PREFIX?=$(HOME)/.local INSTALLBIN?=$(PREFIX)/bin/futhark UNAME:=$(shell uname) # Disable all implicit rules. .SUFFIXES: .PHONY: all configure build install docs check check-commit clean all: build configure: cabal update cabal configure configure-profile: cabal configure --enable-profiling --profiling-detail=toplevel-functions build: cabal build install: build install -d $(shell dirname $(INSTALLBIN)) install "$$(cabal -v0 list-bin exe:futhark)" $(INSTALLBIN) docs: cabal haddock \ --enable-documentation \ --haddock-html \ --haddock-options=--show-all \ --haddock-options=--quickjump \ --haddock-options=--show-all \ --haddock-options=--hyperlinked-source check: tools/style-check.sh src unittests check-commit: tools/style-check.sh $$(git diff-index --cached --ignore-submodules=all --name-status HEAD | awk '$$1 != "D" { print $$2 }') unittest: cabal run unit clean: cabal clean futhark-0.25.27/README.md000066400000000000000000000035051475065116200146240ustar00rootroot00000000000000 The Futhark Programming Language ========== [![Join the chat at 
https://gitter.im/futhark-lang/Lobby](https://badges.gitter.im/futhark-lang/Lobby.svg)](https://gitter.im/futhark-lang/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)[![CI](https://github.com/diku-dk/futhark/workflows/CI/badge.svg)](https://github.com/diku-dk/futhark/actions)[![DOI](https://zenodo.org/badge/7960131.svg)](https://zenodo.org/badge/latestdoi/7960131) Futhark is a purely functional data-parallel programming language in the ML family. It can be compiled to typically very efficient parallel code, running on either a CPU or GPU. The language is developed at [DIKU](http://diku.dk) at the University of Copenhagen, originally as part of the [HIPERFIT centre](http://hiperfit.dk). It is quite stable and suitable for practical programming. For more information, see: * [A collection of code examples](https://futhark-lang.org/examples.html) * [Installation instructions](http://futhark.readthedocs.io/en/latest/installation.html) * [The main website](http://futhark-lang.org) * [Parallel Programming in Futhark](https://futhark-book.readthedocs.io/en/latest/), an extensive introduction and guide * [The Futhark User's Guide](http://futhark.readthedocs.io) * [Documentation for the built-in prelude](https://futhark-lang.org/docs/prelude) * [Futhark libraries](https://futhark-lang.org/pkgs/) [![Packaging status](https://repology.org/badge/vertical-allrepos/futhark.svg)](https://repology.org/project/futhark/versions) Hacking ======= Issues tagged with [good first issue](https://github.com/diku-dk/futhark/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) do not require deep knowledge of the code base. For contributing code, see [the hacking instructions](HACKING.md). 
futhark-0.25.27/STYLE.md000066400000000000000000000232301475065116200145640ustar00rootroot00000000000000Futhark Compiler Style Guide ============================ This document provides guidelines and advice on how to write and format code that will fit with the existing code base and ensure some degree of consistency. Some of these rules are enforced by the automatic quality checker that is run after each push, but most are not. Not all of the current code follows these guidelines. If you find yourself working on such code, please reformat it while you are there. When something isn't covered by this guide you should stay consistent with the code in the other modules. There are two sets of rules in this document: *Style Rules*, which cover low-level details such as the amount of indentation and how to name things, and *Design Rules*, which are less specific high-level guidelines on some of the principles underlying the compiler design. Style Rules =========== Ormolu ------ Futhark uses [Ormolu](https://github.com/tweag/ormolu/) to enforce consistency in formatting style. The style is checked by our CI, so you should make sure any pull requests comply with Ormolu's formatting before trying to merge. ### Installation Installing Ormolu can be done through Cabal, Stack, Nix or the Arch Linux package manager. For instance, to install through Cabal, run the following command: ``` cabal install ormolu ``` If you're running Nix or NixOS, you can just use `nix-shell` to enter a development environment with ormolu already installed. ### Basic Usage The following command formats a single file: ``` ormolu -i FILE.hs ``` This command can be used to format all Haskell files in Futhark, while checking that the formatting is idempotent: ``` ./tools/run-formatter.sh src unittests ``` The idempotence check is mostly done to make sure Ormolu (which is still a young tool in active development) doesn't introduce any unnecessary changes. Any idempotence errors should be reported upstream.
### Editor Integration Emacs has [ormolu.el](https://github.com/vyorkin/ormolu.el) to help with formatting. Once installed (as per the instructions on that page), it will automatically format any open Haskell file on save. The Ormolu README lists further integrations for VS Code and vim. ### Limitations Ormolu doesn't handle all aspects of coding style. For instance, it will do no significant rewrites of your code, like if-then-else to pattern matches or enforcing the 80 character line limit, but it will ensure consistency in alignment and basic formatting. Therefore, as a Futhark contributor, use Ormolu to ensure basic style consistency, while still taking care to follow the more general style rules listed below. Formatting ---------- ### Line Length Maximum line length is *80 characters*. Ormolu doesn't enforce the 80 character line limit, so it is up to the user to introduce the necessary line breaks in the code. However, Ormolu will take a hint. Imagine you've got the following line of code: ```haskell onKernels :: (SegOp SegLevel KernelsMem -> ReuseAllocsM (SegOp SegLevel KernelsMem)) -> Stms KernelsMem -> ReuseAllocsM (Stms KernelsMem) ``` If the user introduces a newline before `ReuseAllocsM`, turning the above into the following: ```haskell onKernels :: (SegOp SegLevel KernelsMem -> ReuseAllocsM (SegOp SegLevel KernelsMem)) -> Stms KernelsMem -> ReuseAllocsM (Stms KernelsMem) ``` Ormolu will pick up the hint and reformat the entire declaration to this: ```haskell onKernels :: (SegOp SegLevel KernelsMem -> ReuseAllocsM (SegOp SegLevel KernelsMem)) -> Stms KernelsMem -> ReuseAllocsM (Stms KernelsMem) ``` ### Blank Lines In large `do`-blocks, separate logically distinct chunks of code with a single blank line. ### Long Expressions Long expressions should be split over multiple lines. If splitting a definition using the `$` operator, Ormolu will incrementally add more indentation, which may sometimes be undesirable.
For instance, the following expression: ```haskell someAtrociouslyLongVariableName <- someFunction $ someOtherFunction withSomeVar $ someThirdFunction something $ map somethingElse ``` will be reformatted into: ```haskell someAtrociouslyLongVariableName <- someFunction $ someOtherFunction withSomeVar $ someThirdFunction something $ map somethingElse ``` If you'd rather keep everything equally nested, consider using the `&` operator instead, which is like a reverse `$`. The code above is semantically identical to this: ```haskell someAtrociouslyLongVariableName <- map somethingElse & someThirdFunction something & someOtherFunction withSomeVar & someFunction ``` ### Export Lists Format export lists as follows: ```haskell module Data.Set ( -- * The @Set@ type empty, Set, singleton, -- * Querying member, ) where ``` ### If-then-else clauses Generally, guards and pattern matches should be preferred over if-then-else clauses, where possible. For example, instead of this: ```haskell foo b = if not b then a else c ``` Prefer: ```haskell foo False = a foo True = c ``` Short cases should usually be put on a single line (when line length allows it). ### Pattern matching Prefer pattern-matching in function clauses to `case`. Consider using [view patterns](https://ghc.haskell.org/trac/ghc/wiki/ViewPatterns), but be careful not to go overboard. Imports ------- Try to use explicit import lists or `qualified` imports for standard and third-party libraries. This makes the code more robust against changes in these libraries. Exception: the Prelude. Comments -------- ### Punctuation Write proper sentences; start with a capital letter and use proper punctuation. ### Top-Level Definitions Comment every top level function (particularly exported functions), and provide a type signature; use Haddock syntax in the comments. Comment every exported data type. Function example: ```haskell -- | Send a message on a socket. The socket must be in a connected -- state. Returns the number of bytes sent.
Applications are -- responsible for ensuring that all data has been sent. send :: Socket -- ^ Connected socket -> ByteString -- ^ Data to send -> IO Int -- ^ Bytes sent ``` For functions, the documentation should give enough information to apply the function without looking at the function's definition. Record example: ```haskell -- | Bla bla bla. data Person = Person { age :: !Int -- ^ Age , name :: !String -- ^ First name } ``` For fields that require longer comments, format them like so: ```haskell data Record = Record { -- | This is a very very very long comment that is split over -- multiple lines. field1 :: !Text -- | This is a second very very very long comment that is split -- over multiple lines. , field2 :: !Int } ``` Naming ------ Use camel case (e.g. `functionName`) when naming functions and upper camel case (e.g. `DataType`) when naming data types. Use underscores to separate words in variables and parameters. For compound names consisting of just two words, it is acceptable to not separate them at all, e.g. `flatarrs` instead of `flat_arrs`. If a variable or parameter is also a function, use your judgement as to whether it is most like a function or most like a value. For readability reasons, don't capitalize all letters when using an abbreviation. For example, write `HttpServer` instead of `HTTPServer`. Exception: two-letter abbreviations, e.g. `IO`. Use concise and short names, but do not abbreviate aggressively, especially in complex names. For example, use `filter_result_size`, not `flt_res_sz`. Misc ---- ### Functions Avoid partial functions like `head` and `!`. They can usually (always in the case of `head`) be replaced with a `case`-expression that provides a meaningful error message in the "impossible" case. Do not use `map`, `mapM`, `zipWithM` or similar with a nontrivial anonymous function. Either give the function a name or use `forM`. ### Literate Haskell Never use Literate Haskell. ### Warnings Code should be compilable with `-Wall -Werror`.
There should be no warnings. `hlint` should not complain (except for a few rules that we have disabled - see `tools/style-check.sh`). ### Braces and semicolons Never use braces and semicolons - always use whitespace-based layout instead (except for generated code). ### Prefer `pure` to `return` When writing monadic code, use `pure` instead of `return`. Design Rules ============ * We try not to use too many crazy language extensions. Haskell is merely the implementation language, so we try to keep it simple, and we are not trying to push GHC to its limits. Syntactic language extensions are fine, as are most extensions that are isolated to the module and do not show up in the external module interface. Do not go overboard with type system trickery. The trickery we do have already causes plenty of pain. * Be aware of [boolean blindness](https://existentialtype.wordpress.com/2011/03/15/boolean-blindness/) and try to avoid it. This helps significantly with avoiding partial functions. For example, if a list must never be empty, represent it as a pair of an element and a list instead of just a list. Notes ===== For long comments, we (try to) use the Notes convention from GHC, [explained here](https://www.stackbuilders.com/news/the-notes-of-ghc). Essentially, instead of writing very long in-line comments that break the flow of the code, we write ```haskell -- See Note [Foo Bar] ``` and then somewhere else in the file (perhaps at the bottom), we put ```haskell -- Note [Foo Bar] -- -- Here is how you foo the bar... ``` There is no automation around this or a hard rule for what a "long comment" is. It's just a convention. Credits ======= Based on [this style guide](https://github.com/tibbe/haskell-style-guide/blob/master/haskell-style.md).
futhark-0.25.27/Setup.hs000077500000000000000000000002131475065116200147750ustar00rootroot00000000000000#!/usr/bin/env runhaskell import Distribution.Simple main :: IO () main = defaultMainWithHooks myHooks where myHooks = simpleUserHooks futhark-0.25.27/assets/000077500000000000000000000000001475065116200146445ustar00rootroot00000000000000futhark-0.25.27/assets/logo.svg000066400000000000000000000061151475065116200163300ustar00rootroot00000000000000 image/svg+xml futhark-0.25.27/assets/ohyes.png000066400000000000000000000440361475065116200165100ustar00rootroot00000000000000PNG  IHDRsBIT|d IDATxw|ڀnz%" r(( >zkQ"(6H-tBHo[f|̮lMݐ>jvfNa9-"D!5uBX±~=.Dgābc\W$i;uԒ D^KR&\Ue_Ѝ7!V%NȻ{!.gm屧PmMcu(c+'רʦ@TTa۸J{vЊ9!NȑFjK^wj1Qr | ,wX9>*ub9*(|-jA UK6ضj1P6C gѹɻi"s_-q76H!5?ώ1s| TDn*u']y gk@svQ>V$P.[8*~F, ƌdxSjdK U#aZ 8^*ڏHMҶWE (ET<5?eBrDiV\8vpOE6r'ڼ6; <*:N>SӝW(dڼZM!P?S@+o݊}\=CǤu~%}QrwruoОuiNK*oG>6/_q7(q_c͗Fa=q{I9tg G=ʿ!d/?)`;,Be͸IJԕu^VC/Do#[ +qp(R˩oZ3v-9 Ftm ?^S㞔$G#croi0DĐ"Éq 6 )yo!U}"DZѠzkKEIc;][$G{7{jզMdndU UP6^28'Yӹt1ۚC jv R*OL!~O ΢GMl$$ca9bO&겳!̶; *)wla[=QPv6@[<};/G9וCߞB (x_@UiWf2rr9 wC([*!Gm_5v$AXQm $PG1q8ro+;J84 Llv9D^4id@-B)JT>$am@Ƣ]P2.G98a'E&3B9 ؖd:-0u);$g֑=֕/,SOBLq>9l&n bY## )ap 1cjUZ.ӕnjRŻw R@(!>:~)qF84r܍c;~%BGsJݧ܃}ϗycDZW@X cH2Sl ~A‘`ޏ9Sj w{UلP--P([M}t %)o_sF=\Ό3}1~D7܄*?*lO]%Xzv=0bKEqGozr=V4_D7bԟUjG ;ku@V:CZSw^HSH|7)k)'f .E2k(#8,H*3qwS*G:ao7_], $w$c~) IwIPTd%cŁZ U)Щ`.$+E#[hpRXנ5mSqFq4ʦZ2Oj"!̧(!\뵆C3dHn*',Ø6*Iw:P4DÂ!? Mv#VdT%;ЭPB.W@2f5Q=T-?a]ZLXL,aG`| EhQ[/jZae5GcN9^`L"$rx-Sˮ*k>DɗR4,g= 0,ӯ͸EZVVSML]f"Y7uCIH:XQ`eU80jpd}N/{C 1cVI6bZ@oKnkh}j i&$Čzg Y+mk$ˀV( w_KQ0k!j1_Et._E1D+pa4AWAu;P? 
(;47Zƾi h4|/n9^^ڝHa'crP9=Iu PR y[ؤ Q=1`>sv!T3źڽ@Ud:Ct̽a>D8SVR.kBʅ9|ۋ8M&$H $OGWHنmvrJu@Hc\ =YkVkW*@e)~*gV.nHD*c`,k&f$R'_sf:u˓@?'x=}pf-P1㸜q Pd~3V$TL*UJ?ZmVC9a]ȝ>8XnW6`} 2`ZطheH?FnkG *ph;@=Ǿ'{֡>ixfa?^~H_έW$c6[)?ƽZr1;6*ֿ+6E:0C@|x΋PN;p^z* BҋAm UP6 4A寫QE-ww :?|tABPd5w 6:|t.-p&&@QPx-Q q *Ђ9ٲCptn:q/uAW9k֡ -o5n#T%7+a=^zR<;J?[m5J]|g Zd!2À7p QgO JC0eŒ>EрO9HBNv/ u*=Ϗ/(jBկ[k۰]WM9D 0wK7w0*2Lvk)8]뷦I{dh+PWVWu Vܬ~O=ʾODE.!o%P󋔈d@NꎱC'mf2cjUcoq<[Us8::]NS Tپ蕽{+: @b]QWk}*@8tFvLĐ0D\\h{Ժ.[p4u먦$#Zu|+]C5ƶWh~yfKuUPJv}S(dv:nĥb S&8Uk+EEoaTN&TWŁ+ag+9|he ~˦nP)sp1{վ.Q M"vaB)(=@=Wr0A=07@ڗTo11e Rg0 j^Sx]0z3kQFth~YU~ӶF^8)[RqO)К?C2=K֢[QTng&9࿧8|:W.9mz$ӉM @Gj5)iji4?<k^F;3II50 GsENK&6r6.נ]) ]v{c+Gc?2w Nrƒŏn`}3i HIČ;x0 ]IqT? D!r(%:(lhCD3SZ!hWHMٔo^zCmG#j8R?&b} ]I vGE(}tZL0²r.*H,̽"{|2#TFɤOǎeZy4"ft iyx6V XW6ȉ|BtI|v2 cTR%,Q ~Q8v}L*C ):4I zzS$,|@ަrK N9vj[$Kė72^seY<6nO|і0l"ݯ?@;躶ǚK5f]~VKH.]klPr NH%s|H&öa?%[S-KD0[]ƾqBcuHNt*܌}4P'ނc$! #TpwkswrԈ p!+P{W)k ve[JM{,B+vb[ 8vmE+\P bDLij{Etk0v= @+!+g%4½BE_9 Tz82 =b苀@cn.L}w@G{M7+h"4Tw"n~=qhRbYG@IQ=M%Pa#@3Qg J8 @$􇶉ʽmF׫v0h1)(5Dk0y9uX TzG9TO'/)Ml,K_çELPfN(1Qkׄ݁cׯ *@hN-b2~ 3j)\F~ASg~P}*DK8ko&Kj]ʾDjgITmK"#i3?D]{x.s+ۦyۺ{a)c@ 1DE1\ԋuˆ5R\6kA-w`pҏa]Еfr2Q>L"? 
\o÷NZҦ~̶&l (ze_Pֵ@p"!I:$S>K[c־b2~<ڍfYԹs.&veFe²<UOb[ S@#1e"b DXZ׊wR3he{Mu.o_ aU}bFC]ꗐ M Vľ5isQ 9 vqXP!X'ro !j"Τֹnn\kʾBJ?YHI/@O@첅D^lƯ:$Pn0o6c{??ñ}Z6rG|+{|MmLfoܣ!YB+VGB՚1DOi#BO LZIܭ4;C4d{S*Dq٠~SǾi!º#>DjG{$K:Ǿ}U GJh QG$Y"i~2c٘ܬvM'9f] CnBxGF Q&m} o V!jTNC_8yA=1=u>1ŐF-gzjҠB4CԂ_tܔd=`7Wьta'_^lDZ5h!CպȻ6M%.;#؍ʤpiJ$bF{%7XS5Y*8&Tp͡s:"eEnG=IL&SɁ/웍zU&(|h-Oq> m7 P^¶#zQxtjR\nHp`[ǜ0qD(/#h> D5`[sl[M,.NG,}O1'؊y`_ ;@h\$YBN퉲P5`T$^hGQ#L٩{$?nu2=vd*Dc`s\f s [Hh+; fgMn*D/4+5SM,[F x*wA.OMΨH{M &0@@V'Ȅԛo)X]u `jSsWdf†%AQ__ag8z ] 'O`؉HM} Z0?Y]SY ?(g!u>d>@&@l`{,dwQI%̭ls3.Q5_QF(" q eZ^-0!|> .B_aʙpj /KCU?ͧ 9ٗlE3Ӯ3Agz8,q} 8r'|v|s "4 <ħi4$jI܅F8CI1v,Ka:EDGHh>Jym`|9GU[m1r*  !O1ApTx ̚ `xnǦ;:dcv5L: wJz:#,pe۰T-S Xp9#C;=sK+yblO!X" s&) Xõ>8N(Pg޴' =;xx:<AyMa5|}08)>r>}V/Tо'ía266Ln@xHcXI#Ą80@ۮp&  1c9aAu#kds~!75SvA$Mh% Y6\xtLz~ r` vj_[*(BkWÓwzB*D[!LGP: t7 Z>[쩐} Nxx7.T;GP%YR(mTf~>t0yٴuRO`QVx?hڦ!j _]o8rf,arTf:w,?Fb-۠86F˜{@ +#r22*a΃U2ѵiH,A ?r*5vd$KI,*RaRA N87BX=m * a}]n^Hp2ܹЊ Gps83cx#nqHNEXU3:A rp<? ׆k`( 9GaQi9|lE儋-a(| lψVn/*D؟` (^3@Ld|q]yO 4%ڟAؐèf"A!ƩSΤN(yޟ\(rWBT'Tip%*DHmauA / 'Y>$6(~Z!X G&8an1VA[ {T˄MQK}Tq-o 7.%7.Kc!𢫿*h`Rov!  z#vP|/pMŏT.E⭸;ؤm -] aQ, r/]=?׊(#t8O hdqI ?j{E;6 Kx~́cC`|HB=Wi.z᫰2sPo:gD?ٚ[ &TW2FB? GP @M}~wշ ρ4)H٪0|Zzpm>G[5΂04$P!\6ÑCP1^4@[kH8Vk 4C7{g=Xϰ= .L N";6p1  ]o5VPFININ^ۛ Gmv `m#CBJ+_p}fy > ( H h\x>No^xm2Oc4|IטyH=jg|I[O!@Iފ >ԻQ@)@ ЮrA!A~3v)_ymaX[ RB F_#Lwg{Pm2]@a2j3X 7Y%~}_sHF/W(oN(r Xݐ@Aydx82 R+Mf8. r.Ѷ>a̶ r0fňɎrIHu5(CO.ubxN;O$c7|BVݿa|K]M$2}~4@7Bwksw*S?B $5s**r]`1[0kFKS em`u,̚wG뻸U$_nj6Q8]@ܭ ew Gs; 뮅ZԂ+Qau'p]ú@Eo.Sj`eFOűܠsax D,1gjSNmX|K{!Z.;+"}GfM ק:5j(87^vs**)"xť@~w4P *AypX.)p^m߆oo^vm*93'·>ͅt' ȃ޳}ƾ~q7Nc Tc{1ߟ-PK-3*`sg7`^/h sS>Whȷ:lB u*72)̈́mڏaTs[x0nu hN[? 
]NjCLFʤUH{tSAk#\W{*h1nZOs%`r X>ʏgRރ `ڕ(t$S&=3ѧZQΐC&`XmZيFdSxUFV; a~w= ni9p Z41Bb7 hm0.x#.xV}W:V/`V?e}O`  M<[/*~L,w-r2lZS+Z8W[@ߌohD 0 ϻ~#:- NP g+_Z]᠀¨\s}X/xO՘e#( u{=Zp/v:B G6p}MB4`o&AkCU@&[΁~t xX  D:֫0Ar7W$V~f Lq݃+.po \Կ#D#`P0 _-8ߋ,&r$}+&8i( D' ~ xaS2\X@N(6AHXԿ!D# CYlfܿGav(vNV&yg> !WsXgpo)<ܧU}x;E W{aK7%i]&-]|=ֶ_ׇ>𤷃jS@[σ$?V`Z>TNeS͕88ۜyPC$ w$_5e!e5Pl Y #򡪺ul-z ~sso#6+= '¿76;KJX<⺰7<@sK=PK(tXk]q,H^WV ;υCC\Vt+ݽ\㝱%J܅JX2d}2d~ k]/+Sˍ0U ~88.N/|25dޅeE`[ݴз~=ȋ%P1(,gc.+[ ʭ>D !nV@ۀ?\kٰ3IW; jf %66:=xM07cXI]Bv gÎ#PFy7AEfuF:Tӌ9.rE';dLptqK~ ܫv5|T[N Bc.7t2 B UM`u~iQe)np\6{_L3[sJ(@<?5EC40l;H^ہ4x&Asa;.eX{?̈mk^)T#[)CG,3B!VUe2 ,`6(ɃjMrTx"/~m-x7 k[!aC`5C7s]@#QwE𚿺SaKh*qoni.NEVeX4uO>2D Nme 9HɩsB?TQB]ux]0BV1X߇ftxKvj\Af-D &@T+HO=`1é,.R`Z[F<W#}2rurO h+ ѩuWw &6t^W@hpaE|,Oc"A췰 .pDX/v{:\UT ?EB".ݑWV uzXKMKҧ_Tm?]A.([dbxjF029 @ \e V,eɰk'A x}3l"?d֊8$<-h8{o#`f) *. - :L.a 6n<8$Wo:=MW;I!j]ȷTb LE`h @S`zvtr-:y´ @U \V@z'L7×!]+$z*U"Ab3r9t|Fn)8ILs.ii3ʞ΀4`΢ar [皫b",$Dqiy :-\OaXSY1V҂n q|=jP9^:MLvq t0 C3(@ ]\G-= n /`S(%` BQ.dnySGa[, 6@lHxQu@MC省F=\=-Q )qme`>XWG@-P`&Is{  .Kp[\@=_;j.ܓ\(jae'^A+kv ̐Dtf*h3`S\mmclb产C^4\ops[@bXkf񗝱Akd[Ne{RajH*`r&4{ &ipu\Q2H?EwnDYW C`mKaf;.Lp4~WpD!X,x0_ϒ BW5/ h`+jk  xa9|8+gmqr^QXݽN I?k煠N{A=M3ܺjzu+3mIDATyPvL~>ݝ& ^}.cb@R -jPؚGI<<īhLe E.(p qCn~@'p:C26GyN)_\ ņ1Ne\t;|Iң2n_ټuL_Vu.ؔz@(Ij3wQ"<z3'W|Q X߅q'\֛?HUS o%5Q籌{s6V'Wqu:ԟ?|3ïesa~J/4&YʞC骰j_Õdm-bQ:Ć$-12a'܂ W^a}؞$}׿\+#lM(GQOVRLDYfVQݸ4 z13?d) )YZbP؊~_‰te<}+/dϓ9w /Ley{fw"Mro(k"R^NaqvP VtҮ͍>\zﰫAsМAQ8>A턯qMCtly>Xv0pU粍b52Ylj#88.go ޞ١{sGCc_}q2%_=krϗ0ҥl:*c3ECjoih$UHAq  t)ޤ5հv*ztYnĴS| 8rkӉ8+rg-+Ҩi]OG$57J4o$rg"0IUcb?e4^lo(Q FB%<4elV}yPQM/yeL/ A;z]Iүumr311-}^|YØ-C)Iҫ¸@&!|$tv}gUGr$On uQN,$]Ŗx7vIKt7sپOgԼî)^g*55i}'K'T әE}\ 'c?Ər ?ci *Iz Rٜ)jCty%1j  ߣ1^ˣDa۩+4IWF,gsԹk{9FQ#&,}x=[(*<ª33# jN ?w/Zzy|_8*Cq ʹlBgLҚV?[ͭ5G*$[1M2TuJ H _3X~ 7$f4=Ĵ%l8]PLy2Ǘ q&ȣleOmTβ,`i/JYDh~vց\LŋlXA˧)9$1vA\Бs:rhfLj>gP4Di0Ue㶋v1ڿѵv*Bb3_)/+m2s%_2 .K‚ji4oa}S2u:ҦKw`=;帜E/{`nLVg~/A渣0 H rAѱj>ɤu,kE#9 
)ŅM4ܳ;n+Ȕi8P}ԡAF~ͯ<)=!IENDB`futhark-0.25.27/cabal.project000066400000000000000000000002521475065116200157730ustar00rootroot00000000000000packages: futhark.cabal index-state: 2025-01-10T04:07:23Z package futhark ghc-options: -j -fwrite-ide-info -hiedir=.hie allow-newer: base, template-haskell, ghc-prim futhark-0.25.27/default.nix000066400000000000000000000131021475065116200155030ustar00rootroot00000000000000# This default.nix builds a tarball containing a statically linked # Futhark binary and some manpages. Likely to only work on linux. # # Just run 'nix-build' and fish the tarball out of 'result/'. # # For the Haskell dependencies that diverge from our pinned Nixpkgs, # we use cabal2nix like thus: # # $ cabal2nix cabal://sexp-grammar-2.2.1 > nix/sexp-grammar.nix # # And then import them into the configuration. Although note that # Nixpkgs also tends to contain the newest version of each Hackage # package, even if it is not the default. # # To update the Nixpkgs snapshot (which also includes tooling), use: # # $ niv update nixpkgs -b master # # Also remember this guide: https://github.com/Gabriel439/haskell-nix/blob/master/project1/README.md { suffix ? "nightly", commit ? "" }: let config = { packageOverrides = pkgs: rec { # Very ugly hack to use an older version of elfutils, as the # newest apparently does not work with static linking. 
elfutils191 = pkgs.callPackage ./nix/elfutils191.nix {}; haskellPackages = pkgs.haskellPackages.override { overrides = haskellPackagesNew: haskellPackagesOld: rec { futhark-data = haskellPackagesNew.callPackage ./nix/futhark-data.nix { }; futhark-server = haskellPackagesNew.callPackage ./nix/futhark-server.nix { }; futhark-manifest = haskellPackagesNew.callPackage ./nix/futhark-manifest.nix { }; zlib = haskellPackagesNew.callPackage ./nix/zlib.nix {zlib=pkgs.zlib;}; futhark = # callCabal2Nix does not do a great job at determining # which files must be included as source, which causes # trouble if you have lots of other large files lying # around (say, data files for testing). As a workaround # we explicitly tell it which files are needed. This must # be _manually_ kept in sync with whatever the cabal file requires. let sources = ["futhark.cabal" "Setup.hs" "LICENSE" "^src.*" "^rts.*" "^docs.*" "^prelude.*" "^assets.*" "^unittests.*" ]; cleanSource = src: pkgs.lib.sourceByRegex src sources; in pkgs.haskell.lib.overrideCabal (pkgs.haskell.lib.addBuildTools (haskellPackagesOld.callCabal2nix "futhark" (cleanSource ./.) { }) [ pkgs.python312Packages.sphinx ]) ( _drv: { isLibrary = false; isExecutable = true; enableSharedExecutables = false; enableSharedLibraries = false; enableLibraryProfiling = false; configureFlags = [ "--ghc-option=-Werror" "--ghc-option=-split-sections" "--ghc-option=-optl=-static" "--extra-lib-dirs=${pkgs.ncurses.override { enableStatic = true; }}/lib" # Static linking crud "--extra-lib-dirs=${pkgs.glibc.static}/lib" "--extra-lib-dirs=${pkgs.gmp6.override { withStatic = true; }}/lib" "--extra-lib-dirs=${pkgs.libffi.overrideAttrs (old: { dontDisableStatic = true; })}/lib" # The ones below are due to GHC's runtime system # depending on libdw (DWARF info), which depends on # a bunch of compression algorithms. 
"--ghc-option=-optl=-lbz2" "--ghc-option=-optl=-lz" "--ghc-option=-optl=-lelf" "--ghc-option=-optl=-llzma" "--ghc-option=-optl=-lzstd" "--extra-lib-dirs=${pkgs.zlib.static}/lib" "--extra-lib-dirs=${(pkgs.xz.override { enableStatic = true; }).out}/lib" "--extra-lib-dirs=${(pkgs.zstd.override { enableStatic = true; }).out}/lib" "--extra-lib-dirs=${(pkgs.bzip2.override { enableStatic = true; }).out}/lib" "--extra-lib-dirs=${(elfutils191.overrideAttrs (old: { dontDisableStatic= true; })).out}/lib" ]; preBuild = '' if [ "${commit}" ]; then echo "${commit}" > commit-id; fi ''; postBuild = (_drv.postBuild or "") + '' make -C docs man ''; postInstall = (_drv.postInstall or "") + '' mkdir -p $out/share/man/man1 cp docs/_build/man/*.1 $out/share/man/man1/ mkdir -p $out/share/futhark/ cp LICENSE $out/share/futhark/ ''; } ); }; }; }; }; sources = import ./nix/sources.nix; pkgs = import sources.nixpkgs { inherit config; }; futhark = pkgs.haskellPackages.futhark; in pkgs.stdenv.mkDerivation rec { name = "futhark-" + suffix; version = futhark.version; src = tools/release; buildInputs = [ futhark ]; buildPhase = '' cp -r skeleton futhark-${suffix} cp -r ${futhark}/bin futhark-${suffix}/bin mkdir -p futhark-${suffix}/share cp -r ${futhark}/share/man futhark-${suffix}/share/ chmod +w -R futhark-${suffix} cp ${futhark}/share/futhark/LICENSE futhark-${suffix}/ [ "${commit}" ] && echo "${commit}" > futhark-${suffix}/commit-id tar -Jcf futhark-${suffix}.tar.xz futhark-${suffix} ''; installPhase = '' mkdir -p $out cp futhark-${suffix}.tar.xz $out/futhark-${suffix}.tar.xz ''; } futhark-0.25.27/docs/000077500000000000000000000000001475065116200142725ustar00rootroot00000000000000futhark-0.25.27/docs/.gitignore000066400000000000000000000000341475065116200162570ustar00rootroot00000000000000# Sphinx build folder _buildfuthark-0.25.27/docs/Makefile000066400000000000000000000151561475065116200157420ustar00rootroot00000000000000# Makefile for Sphinx documentation # # You can set these variables 
from the command line. SPHINXOPTS = SPHINXBUILD = sphinx-build PAPER = BUILDDIR = _build # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) endif # Internal variables. PAPEROPT_a4 = -D latex_paper_size=a4 PAPEROPT_letter = -D latex_paper_size=letter ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . # the i18n builder cannot share the environment and doctrees with the others I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext help: @echo "Please use \`make ' where is one of" @echo " html to make standalone HTML files" @echo " dirhtml to make HTML files named index.html in directories" @echo " singlehtml to make a single large HTML file" @echo " pickle to make pickle files" @echo " json to make JSON files" @echo " htmlhelp to make HTML files and a HTML help project" @echo " qthelp to make HTML files and a qthelp project" @echo " devhelp to make HTML files and a Devhelp project" @echo " epub to make an epub" @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" @echo " latexpdf to make LaTeX files and run them through pdflatex" @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" @echo " text to make text files" @echo " man to make manual pages" @echo " texinfo to make Texinfo files" @echo " info to make Texinfo files and run them through makeinfo" @echo " gettext to make PO message catalogs" @echo " changes to make an overview of all changed/added/deprecated items" 
@echo " xml to make Docutils-native XML files" @echo " pseudoxml to make pseudoxml-XML files for display purposes" @echo " linkcheck to check all external links for integrity" @echo " doctest to run all doctests embedded in the documentation (if enabled)" clean: rm -rf $(BUILDDIR)/* html: $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." dirhtml: $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." singlehtml: $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml @echo @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." pickle: $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle @echo @echo "Build finished; now you can process the pickle files." json: $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json @echo @echo "Build finished; now you can process the JSON files." htmlhelp: $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp @echo @echo "Build finished; now you can run HTML Help Workshop with the" \ ".hhp project file in $(BUILDDIR)/htmlhelp." qthelp: $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp @echo @echo "Build finished; now you can run "qcollectiongenerator" with the" \ ".qhcp project file in $(BUILDDIR)/qthelp, like this:" @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/Futhark.qhcp" @echo "To view the help file:" @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/Futhark.qhc" devhelp: $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp @echo @echo "Build finished." @echo "To view the help file:" @echo "# mkdir -p $$HOME/.local/share/devhelp/Futhark" @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/Futhark" @echo "# devhelp" epub: $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub @echo @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 
latex: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." @echo "Run \`make' in that directory to run these through (pdf)latex" \ "(use \`make latexpdf' here to do that automatically)." latexpdf: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo "Running LaTeX files through pdflatex..." $(MAKE) -C $(BUILDDIR)/latex all-pdf @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." latexpdfja: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo "Running LaTeX files through platex and dvipdfmx..." $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." text: $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text @echo @echo "Build finished. The text files are in $(BUILDDIR)/text." man: $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man @echo @echo "Build finished. The manual pages are in $(BUILDDIR)/man." texinfo: $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo @echo @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." @echo "Run \`make' in that directory to run these through makeinfo" \ "(use \`make info' here to do that automatically)." info: $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo @echo "Running Texinfo files through makeinfo..." make -C $(BUILDDIR)/texinfo info @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." gettext: $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale @echo @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." changes: $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes @echo @echo "The overview file is in $(BUILDDIR)/changes." linkcheck: $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck @echo @echo "Link check complete; look for any errors in the above output " \ "or in $(BUILDDIR)/linkcheck/output.txt." 
doctest: $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest @echo "Testing of doctests in the sources finished, look at the " \ "results in $(BUILDDIR)/doctest/output.txt." xml: $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml @echo @echo "Build finished. The XML files are in $(BUILDDIR)/xml." pseudoxml: $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml @echo @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." futhark-0.25.27/docs/_theme/000077500000000000000000000000001475065116200155335ustar00rootroot00000000000000futhark-0.25.27/docs/_theme/futhark/000077500000000000000000000000001475065116200171775ustar00rootroot00000000000000futhark-0.25.27/docs/_theme/futhark/static/000077500000000000000000000000001475065116200204665ustar00rootroot00000000000000futhark-0.25.27/docs/_theme/futhark/static/style.css000066400000000000000000000226131475065116200223440ustar00rootroot00000000000000/* * Futhark theme based on bizstyle.css_t * ~~~~~~~~~~~~~~ * * :copyright: Copyright 2011-2014 by Sphinx team, see AUTHORS. * :license: BSD, see LICENSE for details. 
* */ @import url("basic.css"); /* -- page layout ----------------------------------------------------------- */ body { font-family: 'Lucida Grande', 'Lucida Sans Unicode', 'Geneva', 'Verdana', sans-serif; font-size: 14px; letter-spacing: -0.01em; line-height: 150%; text-align: center; background-color: #fff9e5; color: black; padding: 0; border-right: 1px solid #5f021f; border-left: 1px solid #5f021f; margin: 0px 40px 0px 40px; } div.document { text-align: left; background-repeat: repeat-x; } div.documentwrapper { float: left; width: 100%; } div.bodywrapper { margin: 0 0 0 240px; border-left: 1px solid #ccc; } div.body { margin: 0; padding: 0.5em 20px 20px 20px; } {%- if theme_rightsidebar|tobool %} div.bodywrapper { margin: 0 calc({{ theme_sidebarwidth|todim }} + 30px) 0 0; border-right: 1px solid #ccc; } {%- else %} div.bodywrapper { margin: 0 0 0 calc({{ theme_sidebarwidth|todim }} + 30px); } {%- endif %} div.related { font-size: 1em; } div.related ul { background-color: ##5f021f; height: 100%; overflow: hidden; border-top: 1px solid #ddd; border-bottom: 1px solid #ddd; } div.related ul li { margin: 0; padding: 0; height: 2em; float: left; } div.related ul li.right { float: right; margin-right: 5px; } div.related ul li a { margin: 0; padding: 0 5px 0 5px; line-height: 1.75em; } div.related ul li a:hover { text-decoration: underline; } div.sphinxsidebarwrapper { padding: 0; } div.sphinxsidebar { padding: 0.5em 12px 12px 12px; width: {{ theme_sidebarwidth|todim }}; {%- if theme_rightsidebar|tobool %} float: right; {%- endif %} font-size: 1em; text-align: left; } div.sphinxsidebar h3, div.sphinxsidebar h4 { margin: 1em 0 0.5em 0; font-size: 1em; padding: 0.1em 0 0.1em 0.5em; color: white; border: 1px solid #5f021f; background-color: #5f021f; } div.sphinxsidebar h3 a { color: white; } div.sphinxsidebar ul { padding-left: 1.5em; margin-top: 7px; padding: 0; line-height: 130%; } div.sphinxsidebar ul ul { margin-left: 20px; } div.sphinxsidebar input { border: 1px solid 
#5f021f; } div.footer { color: #5f021f; padding: 3px 8px 3px 0; clear: both; font-size: 0.8em; text-align: right; border-bottom: 1px solid #5f021f; } div.footer a { color: #5f021f; text-decoration: underline; } /* -- body styles ----------------------------------------------------------- */ p { margin: 0.8em 0 0.5em 0; } a { color: #5f021f; text-decoration: none; } a:hover { color: #5f021f; text-decoration: underline; } div.body a { text-decoration: underline; } h1, h2, h3 { color: {{ theme_maincolor }}; } h1 { margin: 0; padding: 0.7em 0 0.3em 0; font-size: 1.5em; } h2 { margin: 1.3em 0 0.2em 0; font-size: 1.35em; padding-bottom: .5em; border-bottom: 1px solid #5f021f; } h3 { margin: 1em 0 -0.3em 0; font-size: 1.2em; padding-bottom: .3em; border-bottom: 1px solid #CCCCCC; } div.body h1 a, div.body h2 a, div.body h3 a, div.body h4 a, div.body h5 a, div.body h6 a { color: black!important; } h1 a.anchor, h2 a.anchor, h3 a.anchor, h4 a.anchor, h5 a.anchor, h6 a.anchor { display: none; margin: 0 0 0 0.3em; padding: 0 0.2em 0 0.2em; color: #aaa!important; } h1:hover a.anchor, h2:hover a.anchor, h3:hover a.anchor, h4:hover a.anchor, h5:hover a.anchor, h6:hover a.anchor { display: inline; } h1 a.anchor:hover, h2 a.anchor:hover, h3 a.anchor:hover, h4 a.anchor:hover, h5 a.anchor:hover, h6 a.anchor:hover { color: #777; background-color: #eee; } a.headerlink { color: #c60f0f!important; font-size: 1em; margin-left: 6px; padding: 0 4px 0 4px; text-decoration: none!important; } a.headerlink:hover { background-color: #ccc; color: white!important; } cite, code, tt { font-family: 'Consolas', 'Deja Vu Sans Mono', 'Bitstream Vera Sans Mono', monospace; font-size: 0.95em; letter-spacing: 0.01em; } code { background-color: #F2F2F2; border-bottom: 1px solid #ddd; color: #333; } code.descname, code.descclassname, code.xref { border: 0; } hr { border: 1px solid #abc; margin: 2em; } a code { border: 0; color: #CA7900; } a code:hover { color: #2491CF; } pre { background-color: #eeeeee; 
font-family: 'Consolas', 'Deja Vu Sans Mono', 'Bitstream Vera Sans Mono', monospace; font-size: 0.95em; letter-spacing: 0.015em; line-height: 120%; padding: 0.5em; } pre a { color: inherit; text-decoration: underline; } td.linenos pre { padding: 0.5em 0; } div.quotebar { background-color: #f8f8f8; max-width: 250px; float: right; padding: 2px 7px; border: 1px solid #ccc; } div.topic { background-color: #f8f8f8; } table { border-collapse: collapse; margin: 0 -0.5em 0 -0.5em; } table td, table th { padding: 0.2em 0.5em 0.2em 0.5em; } div.admonition { font-size: 0.9em; margin: 1em 0 1em 0; border: 3px solid #cccccc; background-color: #165e83; padding: 0; } div.admonition p { margin: 0.5em 1em 0.5em 1em; padding: 0; } div.admonition li p { margin-left: 0; } div.admonition pre, div.warning pre { margin: 0; } div.highlight { margin: 0.4em 1em; } div.admonition p.admonition-title { margin: 0; padding: 0.1em 0 0.1em 0.5em; color: white; border-bottom: 3px solid #cccccc; font-weight: bold; background-color: #165e83; } div.danger { border: 3px solid #f0908d; background-color: #f0cfa0; } div.error { border: 3px solid #f0908d; background-color: #ede4cd; } div.warning { border: 3px solid #f8b862; background-color: #f0cfa0; } div.caution { border: 3px solid #f8b862; background-color: #ede4cd; } div.attention { border: 3px solid #f8b862; background-color: #f3f3f3; } div.important { border: 3px solid #f0cfa0; background-color: #ede4cd; } div.note { border: 3px solid #f0cfa0; background-color: #f3f3f3; } div.hint { border: 3px solid #bed2c3; background-color: #f3f3f3; } div.tip { border: 3px solid #bed2c3; background-color: #f3f3f3; } div.danger p.admonition-title, div.error p.admonition-title { background-color: #b7282e; border-bottom: 3px solid #f0908d; } div.caution p.admonition-title, div.warning p.admonition-title, div.attention p.admonition-title { background-color: #f19072; border-bottom: 3px solid #f8b862; } div.note p.admonition-title, div.important p.admonition-title { 
background-color: #f8b862; border-bottom: 3px solid #f0cfa0; } div.hint p.admonition-title, div.tip p.admonition-title { background-color: #7ebea5; border-bottom: 3px solid #bed2c3; } div.admonition ul, div.admonition ol, div.warning ul, div.warning ol { margin: 0.1em 0.5em 0.5em 3em; padding: 0; } div.versioninfo { margin: 1em 0 0 0; border: 1px solid #ccc; background-color: #DDEAF0; padding: 8px; line-height: 1.3em; font-size: 0.9em; } .viewcode-back { font-family: 'Lucida Grande', 'Lucida Sans Unicode', 'Geneva', 'Verdana', sans-serif; } div.viewcode-block:target { background-color: #f4debf; border-top: 1px solid #ac9; border-bottom: 1px solid #ac9; } p.versionchanged span.versionmodified { font-size: 0.9em; margin-right: 0.2em; padding: 0.1em; background-color: #DCE6A0; } dl.field-list > dt { color: white; background-color: #82A0BE; } dl.field-list > dd { background-color: #f7f7f7; } /* -- table styles ---------------------------------------------------------- */ table.docutils { margin: 1em 0; padding: 0; border: 1px solid white; background-color: #f7f7f7; } table.docutils td, table.docutils th { padding: 1px 8px 1px 5px; border-top: 0; border-left: 0; border-right: 1px solid white; border-bottom: 1px solid white; } table.docutils td p { margin-top: 0; margin-bottom: 0.3em; } table.field-list td, table.field-list th { border: 0 !important; word-break: break-word; } table.footnote td, table.footnote th { border: 0 !important; } th { color: white; text-align: left; padding-right: 5px; background-color: #82A0BE; } div.literal-block-wrapper div.code-block-caption { background-color: #EEE; border-style: solid; border-color: #CCC; border-width: 1px 5px; } /* WIDE DESKTOP STYLE */ @media only screen and (min-width: 1176px) { body { margin: 0 40px 0 40px; } } /* TABLET STYLE */ @media only screen and (min-width: 768px) and (max-width: 991px) { body { margin: 0 40px 0 40px; } } /* MOBILE LAYOUT (PORTRAIT/320px) */ @media only screen and (max-width: 767px) { body { 
margin: 0; } div.bodywrapper { margin: 0; width: 100%; border: none; } div.sphinxsidebar { display: none; } } /* MOBILE LAYOUT (LANDSCAPE/480px) */ @media only screen and (min-width: 480px) and (max-width: 767px) { body { margin: 0 20px 0 20px; } } /* RETINA OVERRIDES */ @media only screen and (-webkit-min-device-pixel-ratio: 2), only screen and (min-device-pixel-ratio: 2) { } /* -- end ------------------------------------------------------------------- */ futhark-0.25.27/docs/_theme/futhark/theme.conf000066400000000000000000000002751475065116200211540ustar00rootroot00000000000000[theme] inherit = bizstyle stylesheet = style.css pygments_style = sphinx [options] rightsidebar = false stickysidebar = false bodyfont = sans-serif headfont = 'Trebuchet MS', sans-serif futhark-0.25.27/docs/binary-data-format.rst000066400000000000000000000054241475065116200205120ustar00rootroot00000000000000.. _binary-data-format: Binary Data Format ================== Futhark programs compiled to an executable support both textual and binary input. Both are read via standard input, and can be mixed, such that one argument to an entry point may be binary, and another may be textual. The binary input format takes up significantly less space on disk, and can be read much faster than the textual format. This chapter describes the binary input format and its current limitations. The input formats (whether textual or binary) are not used for Futhark programs compiled to libraries, which instead use whichever format is supported by their host language. Currently reading binary input is only supported for compiled programs. It is *not* supported for ``futhark run``. You can generate random data in the binary format with ``futhark dataset`` (:ref:`futhark-dataset(1)`). This tool can also be used to convert between binary and textual data. Futhark-generated executables can be asked to generate binary output with the ``-b`` option. 
Specification ------------- Elements that are bigger than one byte are always stored using little endian -- we mostly run our code on x86 hardware so this seemed like a reasonable choice. When reading input for an argument to the entry function, we need to be able to differentiate between text and binary input. If the first non-whitespace character of the input is a ``b`` we will parse this argument as binary, otherwise we will parse it in text format. Allowing preceding whitespace characters makes it easy to use binary input for some arguments, and text input for others. The general format has this header:: b Where ``version`` is a byte containing the version of the binary format used for encoding (currently 2), ``num_dims`` is the number of dimensions in the array as a single byte (0 for scalar), and ``type`` is a 4 character string describing the type of the values(s) -- see below for more details. Encoding a scalar value is done by treating it as a 0-dimensional array:: b 0 To encode an array, we encode the number of dimensions ``n`` as a single byte, each dimension ``dim_i`` as an unsigned 64-bit little endian integer, and finally all the values in row-major order in their binary little endian representation:: b ... Type Values ~~~~~~~~~~~ A type is identified by a 4 character ASCII string (four bytes). Valid types are:: " i8" " i16" " i32" " i64" " u8" " u16" " u32" " u64" " f16" " f32" " f64" "bool" Note that unsigned and signed integers have the same byte-level representation. Values of type ``bool`` are encoded with a byte each. The results are undefined if this byte is not either 0 or 1. futhark-0.25.27/docs/c-api.rst000066400000000000000000001042541475065116200160230ustar00rootroot00000000000000.. _c-api: C API Reference =============== A Futhark program ``futlib.fut`` compiled to a C library with the ``--library`` command line option produces two files: ``futlib.c`` and ``futlib.h``. The API provided in the ``.h`` file is documented in the following. 
The ``.h`` file can be included by a C++ source file to access the functions (``extern "C"`` is added automatically), but the ``.c`` file must be compiled with a proper C compiler and the resulting object file linked with the rest of the program. Using the API requires creating a *configuration object*, which is then used to obtain a *context object*, which is then used to perform most other operations, such as calling Futhark functions. Most functions that can fail return an integer: 0 on success and a non-zero value on error, as documented below. Others return a ``NULL`` pointer. Use :c:func:`futhark_context_get_error` to get a (possibly) more precise error message. Some functions take a C string (``const char*``) as argument. Unless otherwise indicated, the string will be copied if necessary, meaning the argument string can always be modified (or freed) after the function returns. .. c:macro:: FUTHARK_BACKEND_foo A preprocessor macro identifying that the backend *foo* was used to generate the code; e.g. ``c``, ``opencl``, or ``cuda``. This can be used for conditional compilation of code that only works with specific backends. Error codes ----------- Most errors result in a not otherwise specified nonzero return code, but a few classes of errors have distinct error codes. .. c:macro:: FUTHARK_SUCCESS Defined as ``0``. Returned in case of success. .. c:macro:: FUTHARK_PROGRAM_ERROR Defined as ``2``. Returned when the program fails due to out-of-bounds, an invalid size coercion, invalid entry point arguments, or similar misuse. .. c:macro:: FUTHARK_OUT_OF_MEMORY Defined as ``3``. Returned when the program fails to allocate memory. This is (somewhat) reliable only for GPU memory - due to overcommit and other VM tricks, you should not expect running out of main memory to be reported gracefully. Configuration ------------- Context creation is parameterised by a configuration object. 
Any changes to the configuration must be made *before* calling :c:func:`futhark_context_new`. A configuration object must not be freed before any context objects for which it is used. The same configuration must *not* be used for multiple concurrent contexts. Configuration objects are cheap to create and destroy. .. c:struct:: futhark_context_config An opaque struct representing a Futhark configuration. .. c:function:: struct futhark_context_config *futhark_context_config_new(void) Produce a new configuration object. You must call :c:func:`futhark_context_config_free` when you are done with it. .. c:function:: void futhark_context_config_free(struct futhark_context_config *cfg) Free the configuration object. .. c:function:: void futhark_context_config_set_debugging(struct futhark_context_config *cfg, int flag) With a nonzero flag, enable various debugging information, with the details specific to the backend. This may involve spewing copious amounts of information to the standard error stream. It is also likely to make the program run much slower. .. c:function:: void futhark_context_config_set_profiling(struct futhark_context_config *cfg, int flag) With a nonzero flag, enable the capture of profiling information. This should not significantly impact program performance. Use :c:func:`futhark_context_report` to retrieve captured information, the details of which are backend-specific. .. c:function:: void futhark_context_config_set_logging(struct futhark_context_config *cfg, int flag) With a nonzero flag, print a running log to standard error of what the program is doing. .. c:function:: int futhark_context_config_set_tuning_param(struct futhark_context_config *cfg, const char *param_name, size_t new_value) Set the value of a tuning parameter. Returns zero on success, and non-zero if the parameter cannot be set. This is usually because a parameter of the given name does not exist. 
See :c:func:`futhark_get_tuning_param_count` and :c:func:`futhark_get_tuning_param_name` for how to query which parameters are available. Most of the tuning parameters are applied only when the context is created, but some may be changed even after the context is active. At the moment, only parameters of class "threshold" may change after the context has been created. Use :c:func:`futhark_get_tuning_param_class` to determine the class of a tuning parameter. .. c:function:: int futhark_get_tuning_param_count(void) Return the number of available tuning parameters. Useful for knowing how to call :c:func:`futhark_get_tuning_param_name` and :c:func:`futhark_get_tuning_param_class`. .. c:function:: const char* futhark_get_tuning_param_name(int i) Return the name of tuning parameter *i*, counting from zero. .. c:function:: const char* futhark_get_tuning_param_class(int i) Return the class of tuning parameter *i*, counting from zero. .. c:function:: void futhark_context_config_set_cache_file(struct futhark_context_config *cfg, const char *fname) Ask the Futhark context to use a file with the designated file as a cross-execution cache. This can result in faster initialisation of the program next time it is run. For example, the GPU backends will store JIT-compiled GPU code in this file. The cache is managed entirely automatically, and if it is invalid or stale, the program performs initialisation from scratch. There is no machine-readable way to get information about whether the cache was hit succesfully, but you can enable logging to see what happens. Pass ``NULL`` to disable caching (this is the default). Context ------- .. c:struct:: futhark_context An opaque struct representing a Futhark context. .. c:function:: struct futhark_context *futhark_context_new(struct futhark_context_config *cfg) Create a new context object. You must call :c:func:`futhark_context_free` when you are done with it. 
It is fine for multiple contexts to co-exist within the same process, but you must not pass values between them. They have the same C type, so this is an easy mistake to make. After you have created a context object, you must immediately call :c:func:`futhark_context_get_error`, which will return non-``NULL`` if initialisation failed. If initialisation has failed, then you still need to call :c:func:`futhark_context_free` to release resources used for the context object, but you must not use the context object for anything else. .. c:function:: void futhark_context_free(struct futhark_context *ctx) Free the context object. It must not be used again. You must call :c:func:`futhark_context_sync` before calling this function to ensure there are no outstanding asynchronous operations still running. The configuration must be freed separately with :c:func:`futhark_context_config_free`. .. c:function:: int futhark_context_sync(struct futhark_context *ctx) Block until all outstanding operations, including copies, have finished executing. Many API functions are asynchronous on their own. .. c:function:: void futhark_context_pause_profiling(struct futhark_context *ctx) Temporarily suspend the collection of profiling information. Has no effect if profiling was not enabled in the configuration. .. c:function:: void futhark_context_unpause_profiling(struct futhark_context *ctx) Resume the collection of profiling information. Has no effect if profiling was not enabled in the configuration. .. c:function:: char *futhark_context_get_error(struct futhark_context *ctx) A human-readable string describing the last error. Returns ``NULL`` if no error has occurred. It is the caller's responsibility to ``free()`` the returned string. Any subsequent call to the function returns ``NULL``, until a new error occurs. .. 
c:function:: void futhark_context_set_logging_file(struct futhark_context *ctx, FILE* f) Set the stream used to print diagnostics, debug prints, and logging messages during runtime. This is ``stderr`` by default. Even when this is used to re-route logging messages, fatal errors will still only be printed to ``stderr``. .. c:function:: char *futhark_context_report(struct futhark_context *ctx) Produce a C string encoding a JSON object with debug and profiling information collected during program runtime. It is the caller's responsibility to free the returned string. It is likely to only contain interesting information if :c:func:`futhark_context_config_set_debugging` or :c:func:`futhark_context_config_set_profiling` has been called previously. Returns ``NULL`` on failure. .. c:function:: int futhark_context_clear_caches(struct futhark_context *ctx) Release any context-internal caches and buffers that may otherwise use computer resources. This is useful for freeing up those resources when no Futhark entry points are expected to run for some time. Particularly relevant when using a GPU backend, due to the relative scarcity of GPU memory. Values ------ Primitive types (``i32``, ``bool``, etc) are mapped directly to their corresponding C type. The ``f16`` type is mapped to ``uint16_t``, because C does not have a standard ``half`` type. This integer contains the bitwise representation of the ``f16`` value in the IEEE 754 binary16 format. .. _array-values: Arrays of Primitive Values ~~~~~~~~~~~~~~~~~~~~~~~~~~ For each distinct array type of primitives (ignoring sizes), an opaque C struct is defined. Arrays of ``f16`` are presented as containing ``uint16_t`` elements. For types that do not map cleanly to C, including records, sum types, and arrays of tuples, see :ref:`opaques`. All array values share a similar API, which is illustrated here for the case of the type ``[]i32``. 
The creation/retrieval functions are all asynchronous, so make sure to call :c:func:`futhark_context_sync` when appropriate. Memory management is entirely manual. All values that are created with a ``new`` function, or returned from an entry point, *must* at some point be freed manually. Values are internally reference counted, so even for entry points that return their input unchanged, you must still free both the input and the output - this will not result in a double free. .. c:struct:: futhark_i32_1d An opaque struct representing a Futhark value of type ``[]i32``. .. c:function:: struct futhark_i32_1d *futhark_new_i32_1d(struct futhark_context *ctx, int32_t *data, int64_t dim0) Asynchronously create a new array based on the given data. The dimensions express the number of elements. The data is copied into the new value. It is the caller's responsibility to eventually call :c:func:`futhark_free_i32_1d`. Multi-dimensional arrays are assumed to be in row-major form. Returns ``NULL`` on failure. .. c:function:: struct futhark_i32_1d *futhark_new_raw_i32_1d(struct futhark_context *ctx, char *data, int64_t dim0) Create an array based on *raw* data, which is used for the representation of the array. The ``data`` pointer must remain valid for the lifetime of the array and will not be freed by Futhark. Returns ``NULL`` on failure. The type of the ``data`` argument depends on the backend, and is for example ``cl_mem`` when using the OpenCL backend. **This is an experimental and unstable interface.** .. c:function:: int futhark_free_i32_1d(struct futhark_context *ctx, struct futhark_i32_1d *arr) Free the value. In practice, this merely decrements the reference count by one. The value (or at least this reference) must not be used again after this function returns. .. 
c:function:: int futhark_values_i32_1d(struct futhark_context *ctx, struct futhark_i32_1d *arr, int32_t *data) Asynchronously copy data from the value into ``data``, which must point to free memory, allocated by the caller, with sufficient space to store the full array. Multi-dimensional arrays are written in row-major form. .. c:function:: int futhark_index_i32_1d(struct futhark_context *ctx, int32_t *out, struct futhark_i32_1d *arr, int64_t i0); Asynchronously copy a single element from the array and store it in ``*out``. Returns a nonzero value if the index is out of bounds. **Note:** if you need to read many elements, it is much faster to retrieve the entire array with the ``values`` function, particularly when using a GPU backend. .. c:function:: const int64_t *futhark_shape_i32_1d(struct futhark_context *ctx, struct futhark_i32_1d *arr) Return a pointer to the shape of the array, with one element per dimension. The lifetime of the shape is the same as ``arr``, and must *not* be manually freed. Assuming ``arr`` is a valid object, this function cannot fail. .. c:function:: char* futhark_values_raw_i32_1d(struct futhark_context *ctx, struct futhark_i32_1d *arr) Return a pointer to the underlying storage of the array. The return type depends on the backend, and is for example ``cl_mem`` when using the OpenCL backend. If using unified memory with the ``hip`` or ``cuda`` backends, the pointer can be accessed directly from CPU code. **This is an experimental and unstable interface.** .. _opaques: Opaque Values ~~~~~~~~~~~~~ Each instance of a complex type in an entry point (records, nested tuples, etc) is represented by an opaque C struct named ``futhark_opaque_foo``. In the general case, ``foo`` will be a hash of the internal representation. However, if you insert an explicit type annotation in the entry point (and the type name contains only characters valid in C identifiers), that name will be used. 
Note that arrays contain brackets, which are not valid in identifiers. Defining a type abbreviation is the best way around this. The API for opaque values is similar to that of arrays, and the same rules for memory management apply. You cannot construct them from scratch (unless they correspond to records or tuples, see :ref:`records`), but must obtain them via entry points (or deserialisation, see :c:func:`futhark_restore_opaque_foo`). .. c:struct:: futhark_opaque_foo An opaque struct representing a Futhark value of type ``foo``. .. c:function:: int futhark_free_opaque_foo(struct futhark_context *ctx, struct futhark_opaque_foo *obj) Free the value. In practice, this merely decrements the reference count by one. The value (or at least this reference) must not be used again after this function returns. .. c:function:: int futhark_store_opaque_foo(struct futhark_context *ctx, const struct futhark_opaque_foo *obj, void **p, size_t *n) Serialise an opaque value to a byte sequence, which can later be restored with :c:func:`futhark_restore_opaque_foo`. The byte representation is not otherwise specified, and is not stable between compiler versions or programs. It is stable under change of compiler backend, but not change of compiler version, or modification to the source program (although in most cases the format will not change). The variable pointed to by ``n`` will always be set to the number of bytes needed to represent the value. The ``p`` parameter is more complex: * If ``p`` is ``NULL``, the function will write to ``*n``, but not actually serialise the opaque value. * If ``*p`` is ``NULL``, the function will allocate sufficient storage with ``malloc()``, serialise the value, and write the address of the byte representation to ``*p``. The caller gains ownership of this allocation and is responsible for freeing it. * Otherwise, the serialised representation of the value will be stored at ``*p``, which *must* have room for at least ``*n`` bytes. 
This is done asynchronously. Returns 0 on success. .. c:function:: struct futhark_opaque_foo* futhark_restore_opaque_foo(struct futhark_context *ctx, const void *p) Asynchronously restore a byte sequence previously written with :c:func:`futhark_store_opaque_foo`. Returns ``NULL`` on failure. The byte sequence does not need to have been generated by the same program *instance*, but it *must* have been generated by the same Futhark program, and compiled with the same version of the Futhark compiler. .. _records: Records ~~~~~~~ A record is an opaque type (see above) that supports additional functions to *project* individual fields (read their values) and to construct a value given values for the fields. An opaque type is a record if its definition is a record at the Futhark level. Note that a tuple is simply a record with numeric fields. The projection and construction functions are equivalent in functionality to writing entry points by hand, and so serve only to cut down on boilerplate. Important things to be aware of: 1. The objects constructed through these functions have their own lifetime (like any objects returned from an entry point) and must be manually freed, independently of the records from which they are projected, or the fields they are constructed from. 2. The objects are however in an *aliasing* relationship with the fields or original record. This means you must be careful when passing them to entry points that consume their arguments. As always, you don't have to worry about this if you never write entry points that consume their arguments. 3. You must synchronise before using any scalar results. The precise functions generated depend on the fields of the record. The following functions assume a record with Futhark-level type ``type t = {foo: t1, bar: t2}`` where ``t1`` and ``t2`` are also opaque types. ..
c:function:: int futhark_new_opaque_t(struct futhark_context *ctx, struct futhark_opaque_t **out, const struct futhark_opaque_t2 *bar, const struct futhark_opaque_t1 *foo); Construct a record in ``*out`` which has the given values for the ``bar`` and ``foo`` fields. The parameters are the fields in alphabetic order. As a special case, if the record is a tuple (i.e., has numeric fields), the parameters are ordered numerically. Tuple fields are named ``vX`` where ``X`` is an integer. The resulting record *aliases* the values provided for ``bar`` and ``foo``, but has its own lifetime, and all values must be individually freed when they are no longer needed. .. c:function:: int futhark_project_opaque_t_bar(struct futhark_context *ctx, struct futhark_opaque_t2 **out, const struct futhark_opaque_t *obj); Extract the value of the field ``bar`` from the provided record. The resulting value *aliases* the record, but has its own lifetime, and must eventually be freed. .. c:function:: int futhark_project_opaque_t_foo(struct futhark_context *ctx, struct futhark_opaque_t1 **out, const struct futhark_opaque_t *obj); Extract the value of the field ``foo`` from the provided record. The resulting value *aliases* the record, but has its own lifetime, and must eventually be freed. .. _sums: Sums ~~~~ A sum type is an opaque type (see above) that supports construction and destruction functions. An opaque type is a sum type if its definition is a sum type at the Futhark level. Similarly to records (see :ref:`Records`), this functionality is equivalent to writing entry points by hand, and has the same properties regarding lifetimes. A sum type consists of one or more variants. A value of this type is always an instance of one of these variants. In the C API, these variants are numbered from zero.
The numbering is given by the order in which they are represented in the manifest (see :ref:`manifest`), which is also the order in which their associated functions are defined in the header file. For an opaque sum type ``t``, the following function is always generated. .. c:function:: int futhark_variant_opaque_t(struct futhark_context *ctx, const struct futhark_opaque_t *v); Return the identifying number of the variant of which this sum type is an instance (see above). Cannot fail. For each variant ``foo``, construction and destruction functions are defined. The following assume ``t`` is defined as ``type t = #foo ([]i32) bool``. .. c:function:: int futhark_new_opaque_t_foo(struct futhark_context *ctx, struct futhark_opaque_t **out, const struct futhark_i32_1d *v0, const bool v1); Construct a value of type ``t`` that is an instance of the variant ``foo``. Arguments are provided in the same order as in the Futhark-level ``foo`` constructor. **Beware:** if ``t`` has size parameters that are only used for *other* variants than the one that is being instantiated, those size parameters will be set to 0. If this is a problem for your application, define your own entry point for constructing a value with the proper sizes. .. c:function:: int futhark_destruct_opaque_t_foo(struct futhark_context *ctx, struct futhark_i32_1d **v0, bool *v1, const struct futhark_opaque_t *obj); Extract the payload of variant ``foo`` from the sum value. Despite the name, "destruction" does not free the sum type value. The extracted values alias the sum value, but have their own lifetime, and must eventually be freed. **Precondition:** ``obj`` must be an instance of the ``foo`` variant, which can be determined with :c:func:`futhark_variant_opaque_t`. .. _arrays_of_opaques: Arrays of Non-Primitive Values ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ An array that contains a non-primitive type is considered an opaque value.
However, it also supports a subset of the API documented in :ref:`array-values`. For an opaque array type ``[]t``, the following functions are always generated (assuming the generated C type is ``arr_t``): .. c:function:: int futhark_index_opaque_arr_t(struct futhark_context *ctx, struct futhark_opaque_t **out, struct futhark_opaque_arr_t *arr, int64_t i0); Asynchronously copy a single element from the array and store it in ``*out``. Returns a nonzero value if the index is out of bounds. .. c:function:: const int64_t *futhark_shape_opaque_arr_t(struct futhark_context *ctx, struct futhark_opaque_arr_t *arr); Return a pointer to the shape of the array, with one element per dimension. The lifetime of the shape is the same as ``arr``, and must *not* be manually freed. Assuming ``arr`` is a valid object, this function cannot fail. Additionally, if the element type is a record (or equivalently a tuple), for example if the array type is ``[](f32,f32)``, the following functions are also available: .. c:function:: int futhark_zip_opaque_arr1d_tup2_f32_f32(struct futhark_context *ctx, struct futhark_opaque_arr1d_tup2_f32_f32 **out, const struct futhark_f32_1d *f_0, const struct futhark_f32_1d *f_1); Construct an array of records from arrays of the component values. This is analogous to ``zip`` in the source language. The provided arrays must have compatible shapes, and the function returns nonzero if they do not. **Note:** This is a cheap operation, as it does not copy array elements. **Note:** The resulting array aliases the original arrays. .. c:function:: int futhark_project_opaque_arr1d_tup2_f32_f32_0(struct futhark_context *ctx, struct futhark_f32_1d **out, const struct futhark_opaque_arr1d_tup2_f32_f32 *obj); Retrieve an array of all the ``.0`` fields of the array elements. A similar function is provided for each field. **Note:** This is a cheap operation, as it does not copy array elements. **Note:** The resulting array aliases the original array. 
Entry points ------------ Entry points are mapped 1:1 to C functions. Return values are handled with *out*-parameters. For example, this Futhark entry point:: entry sum = i32.sum Results in the following C function: .. c:function:: int futhark_entry_sum(struct futhark_context *ctx, int32_t *out0, const struct futhark_i32_1d *in0) Asynchronously call the entry point with the given arguments. Make sure to call :c:func:`futhark_context_sync` before using the value of ``out0``. Errors are indicated by a nonzero return value. On error, the *out*-parameters are not touched. The precise semantics of the return value depends on the backend. For the sequential C backend, errors will always be available when the entry point returns, and :c:func:`futhark_context_sync` will always return zero. When using a GPU backend such as ``cuda`` or ``opencl``, the entry point may still be running asynchronous operations when it returns, in which case the entry point may return zero successfully, even though execution has already failed (or will fail). These problems will be reported when :c:func:`futhark_context_sync` is called. Therefore, be careful to check the return code of *both* the entry point itself, and :c:func:`futhark_context_sync`. For the rules on entry points that consume their input, see :ref:`api-consumption`. Note that even if a value has been consumed, you must still manually free it. This is the only operation that is permitted on a consumed value. GPU --- The following API functions are available when using the ``opencl``, ``cuda``, or ``hip`` backends. .. c:function:: void futhark_context_config_set_device(struct futhark_context_config *cfg, const char *s) Use the first device whose name contains the given string. The special string ``#k``, where ``k`` is an integer, can be used to pick the *k*-th device, numbered from zero. If used in conjunction with :c:func:`futhark_context_config_set_platform`, only the devices from matching platforms are considered. ..
c:function:: void futhark_context_config_set_unified_memory(struct futhark_context_config* cfg, int flag); Use "unified" memory for GPU arrays. This means arrays are located in memory that is also accessible from the CPU. The details depend on the backend and hardware in use. The following values are supported: * 0: never use unified memory (the default on ``hip``). * 1: always use unified memory. * 2: use managed memory if the device claims to support it (the default on ``cuda``). Exotic ~~~~~~ The following functions are not interesting to most users. .. c:function:: void futhark_context_config_set_default_thread_block_size(struct futhark_context_config *cfg, int size) Set the default number of work-items in a thread block. .. c:function:: void futhark_context_config_set_default_group_size(struct futhark_context_config *cfg, int size) Identical to :c:func:`futhark_context_config_set_default_thread_block_size`; provided for backwards compatibility. .. c:function:: void futhark_context_config_set_default_grid_size(struct futhark_context_config *cfg, int num) Set the default number of thread blocks used for kernels. .. c:function:: void futhark_context_config_set_default_num_groups(struct futhark_context_config *cfg, int num) Identical to :c:func:`futhark_context_config_set_default_grid_size`; provided for backwards compatibility. .. c:function:: void futhark_context_config_set_default_tile_size(struct futhark_context_config *cfg, int num) Set the default tile size used when executing kernels that have been block tiled. .. c:function:: const char* futhark_context_config_get_program(struct futhark_context_config *cfg) Retrieve the embedded GPU program. The context configuration keeps ownership, so don't free the string. .. c:function:: void futhark_context_config_set_program(struct futhark_context_config *cfg, const char *program) Instead of using the embedded GPU program, use the provided string, which is copied by this function.
OpenCL ------ The following API functions are available only when using the ``opencl`` backend. .. c:function:: void futhark_context_config_set_platform(struct futhark_context_config *cfg, const char *s) Use the first OpenCL platform whose name contains the given string. The special string ``#k``, where ``k`` is an integer, can be used to pick the *k*-th platform, numbered from zero. .. c:function:: void futhark_context_config_select_device_interactively(struct futhark_context_config *cfg) Immediately conduct an interactive dialogue on standard output to select the platform and device from a list. .. c:function:: void futhark_context_config_set_command_queue(struct futhark_context_config *cfg, cl_command_queue queue) Use exactly this command queue for the context. If this is set, all other device/platform configuration options are ignored. Once the context is active, the command queue belongs to Futhark and must not be used by anything else. This is useful for implementing custom device selection logic in application code. .. c:function:: cl_command_queue futhark_context_get_command_queue(struct futhark_context *ctx) Retrieve the command queue used by the Futhark context. Be very careful with it - enqueueing your own work is unlikely to go well. Exotic ~~~~~~ The following functions are used for debugging generated code or advanced usage. .. c:function:: void futhark_context_config_add_build_option(struct futhark_context_config *cfg, const char *opt) Add a build option to the OpenCL kernel compiler. See the OpenCL specification for `clBuildProgram` for available options. .. c:function:: cl_program futhark_context_get_program(struct futhark_context_config *cfg) Retrieve the compiled OpenCL program. .. c:function:: void futhark_context_config_load_binary_from(struct futhark_context_config *cfg, const char *path) During :c:func:`futhark_context_new`, read a compiled OpenCL binary from the given file instead of using the embedded program. 
CUDA ---- The following API functions are available when using the ``cuda`` backend. Exotic ~~~~~~ The following functions are used for debugging generated code or advanced usage. .. c:function:: void futhark_context_config_add_nvrtc_option(struct futhark_context_config *cfg, const char *opt) Add a build option to the NVRTC compiler. See the CUDA documentation for ``nvrtcCompileProgram`` for available options. .. c:function:: void futhark_context_dump_ptx_to(struct futhark_context_config *cfg, const char *path) During :c:func:`futhark_context_new`, dump the generated PTX code to the given file. .. c:function:: void futhark_context_config_load_ptx_from(struct futhark_context_config *cfg, const char *path) During :c:func:`futhark_context_new`, read PTX code from the given file instead of using the embedded program. Multicore --------- The following API functions are available when using the ``multicore`` backend. .. c:function:: void futhark_context_config_set_num_threads(struct futhark_context_config *cfg, int n) The number of threads used to run parallel operations. If set to a value less than ``1``, then the runtime system will use one thread per detected core. General guarantees ------------------ Calling an entry point, or interacting with Futhark values through the functions listed above, has no system-wide side effects, such as writing to the file system, launching processes, or performing network connections. Defects in the program or Futhark compiler itself can with high probability result only in the consumption of CPU or GPU resources, or a process crash. Using the ``#[unsafe]`` attribute with in-place updates can result in writes to arbitrary memory locations. A malicious program can likely exploit this to obtain arbitrary code execution, just as with any insecure C program. If you must run untrusted code, consider using the ``--safe`` command line option to instruct the compiler to disable ``#[unsafe]``. 
Initialising a Futhark context likewise has no side effects, except if explicitly configured differently, such as by using :c:func:`futhark_context_config_dump_program_to`. In its default configuration, Futhark will not access the file system. Note that for the GPU backends, the underlying API (such as CUDA or OpenCL) may perform file system operations during startup, and perhaps for caching GPU kernels in some cases. This is beyond Futhark's control. Violating the restrictions of consumption (see :ref:`api-consumption`) can result in undefined behaviour. This does not matter for programs whose entry points do not have unique parameter types (:ref:`in-place-updates`). .. _manifest: Manifest -------- When compiling with ``--library``, the C backends generate a machine-readable *manifest* in JSON format that describes the API of the compiled Futhark program. Specifically, the manifest contains: * A mapping from the name of each entry point to: * The C function name of the entry point. * A list of all *inputs*, including their type (as a name) and *whether they are unique* (consuming). * A list of all *outputs*, including their type (as a name) and *whether they are unique*. * A list of all *tuning parameters* that can influence the execution of this entry point. These are not necessarily unique to the entry point. * A mapping from the name of each non-scalar type to: * The C type used to represent this type (which is in practice always a pointer of some kind). * What *kind* of type this is - either an *array* or an *opaque*. * For arrays, the element type and rank. * A mapping from *operations* to the names of the C functions that implement the operations for the type. The types of the C functions are as documented above. The following operations are listed: * For arrays: ``free``, ``shape``, ``values``, ``new``, ``index``. * For opaques: ``free``, ``store``, ``restore``.
* For opaques that are actually records (including tuples): * The list of fields, including their type and a projection function. The field ordering here is the one expected by the *new* function. * The name of the C *new* function for creating a record from field values. * For opaques that are actually arrays of records: * The element type and rank. * The operations ``index``, ``shape``, ``zip``. * The fields, which will be the fields of the element type, but with the dimensions prepended. These are the types of the arrays that should be passed to the ``zip`` function. * For other opaques that are actually arrays: * The element type and rank. * The operations ``index`` and ``shape``. Manifests are defined by the following JSON Schema: .. include:: manifest.schema.json :code: json It is likely that we will add more fields in the future, but it is unlikely that we will remove any. futhark-0.25.27/docs/c-porting-guide.rst000066400000000000000000000141171475065116200200250ustar00rootroot00000000000000.. _c-porting-guide: C Porting Guide =============== This short document contains a collection of tips and tricks for porting simple numerical C code to Futhark. Futhark's sequential fragment is powerful enough to permit a rather straightforward translation of sequential C code that does not rely on pointer mutation. Additionally, we provide hints on how to recognise C coding patterns that are symptoms of C's weak type system, and how better to organise it in Futhark. One intended audience of this document is a programmer who needs to translate a benchmark application written in C, or needs to use a simple numerical algorithm that is already available in the form of C source code. Where This Guide Falls Short ---------------------------- Some C code makes use of unstructured returns and nonlocal exits (``return`` inside loops, for example). These are not easy to express in Futhark, and will require massaging the control flow a bit.
C code that uses ``goto`` is likewise not easy to port. Types ----- Futhark provides scalar types that match the ones commonly used in C: ``u8``/``u16``/``u32``/``u64`` for the unsigned integers, ``i8``/``i16``/``i32``/``i64`` for the signed, and ``f32``/``f64`` for ``float`` and ``double`` respectively. In contrast to C, Futhark does not automatically promote types in expressions - you will have to manually make sure that both operands to e.g. a multiplication are of the exact same type. This means that you will need to understand exactly which types a given expression in the original C program operates on, which generally boils down to converting the type of the (type-wise) smaller operand to that of the larger. Note that the Futhark ``bool`` type is not considered a number. Operators --------- Most of the C operators can be found in Futhark with their usual names. Note however that the Futhark ``/`` and ``%`` operators for integers round towards negative infinity, whereas their counterparts in C round towards zero. You can write ``//`` and ``%%`` if you want the C behaviour. There is no difference if both operands are non-negative, but ``//`` and ``%%`` may be slightly faster. For unsigned numbers, they are exactly the same. Variable Mutation ----------------- As C is a sequential language, most C programs quite obviously rely heavily on mutating variables. However, in many programs, this is done in a static manner without indirection through pointers (except for arrays; see below), which is conceptually similar to just declaring a new variable of the same name that shadows the old one. If this is the case, a C assignment can generally be translated to just a ``let``-binding. As an example, let us consider the following function for computing the modular multiplicative inverse of a 16-bit unsigned integer (part of the IDEA encryption algorithm): ..
code-block:: c static uint16_t ideaInv(uint16_t a) { uint32_t b; uint32_t q; uint32_t r; int32_t t; int32_t u; int32_t v; b = 0x10001; u = 0; v = 1; while(a > 0) { q = b / a; r = b % a; b = a; a = r; t = v; v = u - q * v; u = t; } if(u < 0) u += 0x10001; return u; } Each iteration of the loop mutates the variables ``a``, ``b``, ``v``, and ``u`` in ways that are visible to the following iteration. Conversely, the "mutations" of ``q``, ``r``, and ``t`` are not truly mutations, and the variable declarations could be moved inside the loop if we wished. Presumably, the C programmer left them outside for aesthetic reasons. When translating such code, it is important to determine exactly how much *true* mutation is going on, and how much is just reuse of variable space. This can usually be done by checking whether a variable is read before it is written in any given iteration - if not, then it is not true mutation. The variables that change value from one iteration of the loop to the next will need to be maintained as *merge parameters* of the Futhark ``do``-loop. The Futhark program resulting from a straightforward port looks as follows: .. code-block:: futhark let main(a: u16): u32 = let b = 0x10001u32 let u = 0i32 let v = 1i32 in let (_,_,u,_) = loop ((a,b,u,v)) while a > 0u16 do let q = b / u32.u16(a) let r = b % u32.u16(a) let b = u32.u16(a) let a = u16.u32(r) let t = v let v = u - i32.u32 (q) * v let u = t in (a,b,u,v) in u32.i32(if u < 0 then u + 0x10001 else u) Note the heavy use of type conversion and type suffixes for constants. This is necessary due to Futhark's lack of implicit conversions. Note also the conspicuous way in which the ``do``-loop is written - the result of a loop iteration consists of variables whose names are identical to those of the merge parameters. This program can still be massaged to make it more idiomatic Futhark - for example, the variable ``t`` only serves to store the old value of ``v`` that is otherwise clobbered. 
This can be written more elegantly by simply inlining the expressions in the result part of the loop body. Arrays ------ Dynamically sized multidimensional arrays are somewhat awkward in C, and are often simulated via single-dimensional arrays with explicitly calculated indices: .. code:: c a[i * M + j] = foo; This indicates a two-dimensional array ``a`` whose *inner* dimension is of size ``M``. We can usually look at where ``a`` is allocated to figure out what the size of the outer dimension must be: .. code:: c a = malloc(N * M * sizeof(int)); We see clearly that ``a`` is a two-dimensional integer array of size ``N`` times ``M`` - or of type ``[N][M]i32`` in Futhark. Thus, the update expression above would be translated as:: let a[i,j] = foo in ... C programs usually first allocate an array, then enter a loop to provide its initial values. This is not possible in Futhark - consider whether you can write it as a ``replicate``, an ``iota``, or a ``map``. In the worst case, use ``replicate`` to obtain an array of the desired size, then use a ``do``-loop with in-place updates to initialise it (but note that this will run strictly sequentially). futhark-0.25.27/docs/conf.py000077500000000000000000000366271475065116200156120ustar00rootroot00000000000000#!/usr/bin/env python3 # -*- coding: utf-8 -*- # # Futhark documentation build configuration file, created by # sphinx-quickstart on Tue Mar 24 14:21:12 2015. # # This file is execfile()d with the current directory set to its # containing dir. # # Note that not all possible configuration values are present in this # autogenerated file. # # All configuration values have a default; values that are commented out # serve to show the default.
import sys
import os
import re
from pygments.lexer import RegexLexer, bygroups
from pygments import token
from pygments import unistring as uni
from pygments.token import (
    Text,
    Comment,
    Operator,
    Keyword,
    Name,
    String,
    Number,
    Punctuation,
    Whitespace,
)
from sphinx.highlighting import lexers
from typing import Dict, Any, List, Tuple

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
# sys.path.insert(0, os.path.abspath('.'))

# -- General configuration ------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
# needs_sphinx = '1.0'

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ["sphinx.ext.todo", "sphinx.ext.mathjax"]

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# The suffix of source filenames.
source_suffix = ".rst"

# The encoding of source files.
# source_encoding = 'utf-8-sig'

# The master toctree document.
master_doc = "index"

# General information about the project.
project = "Futhark"
copyright = "2013-2020, DIKU, University of Copenhagen"

# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.


# No reason for a cabal file parser; let's just hack it.
def get_version():
    """Return the package version declared in ``../futhark.cabal``.

    The path is relative to the current working directory, which Sphinx sets
    to the directory containing this configuration file.

    Returns:
        The value of the first top-level ``version:`` field, as a string.

    Raises:
        OSError: if the cabal file cannot be opened.
        RuntimeError: if the cabal file has no ``version:`` field.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open("../futhark.cabal", "r", encoding="utf-8") as cabal:
        cabal_file = cabal.read()
    # Match "version:" at the start of a line. Tolerate trailing blanks and a
    # CRLF line ending. Cabal field names are case-insensitive.
    match = re.search(
        r"^version:[ \t]*(\S+)[ \t]*\r?$",
        cabal_file,
        flags=re.MULTILINE | re.IGNORECASE,
    )
    if match is None:
        raise RuntimeError("could not find a version field in ../futhark.cabal")
    return match.group(1)


version = get_version()
# The full version, including alpha/beta/rc tags.
release = version

# The language for content autogenerated by Sphinx.
# Refer to documentation for a list of supported languages.
# language = None

# There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used:
# today = ''
# Else, today_fmt is used as the format for a strftime call.
# today_fmt = '%B %d, %Y'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
exclude_patterns = ["_build", "lib"]

# The reST default role (used for this markup: `text`) to use for all
# documents.
# default_role = None

# If true, '()' will be appended to :func: etc. cross-reference text.
# add_function_parentheses = True

# If true, the current module name will be prepended to all description
# unit titles (such as .. function::).
# add_module_names = True

# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
# show_authors = False

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "sphinx"


# Pygments lexer for Futhark source, registered below under the name
# "futhark" so that ``.. code-block:: futhark`` is highlighted.
class FutharkLexer(RegexLexer):
    """
    A Futhark lexer

    .. versionadded:: 2.8
    """

    name = "Futhark"
    url = "https://futhark-lang.org/"
    aliases = ["futhark"]
    filenames = ["*.fut"]
    mimetypes = ["text/x-futhark"]

    # Primitive numeric types. These are highlighted as types and also
    # accepted as literal suffixes (e.g. 1u8, 2.5f16).
    num_types = (
        "i8",
        "i16",
        "i32",
        "i64",
        "u8",
        "u16",
        "u32",
        "u64",
        "f16",
        "f32",
        "f64",
    )

    other_types = ("bool",)

    reserved = (
        "if",
        "then",
        "else",
        "def",
        "let",
        "loop",
        "in",
        "with",
        "type",
        "val",
        "entry",
        "for",
        "while",
        "do",
        "case",
        "match",
        "include",
        "import",
        "module",
        "open",
        "local",
        "assert",
        "_",
    )

    # Named ASCII control-character escapes accepted after a backslash.
    ascii = (
        "NUL",
        "SOH",
        "[SE]TX",
        "EOT",
        "ENQ",
        "ACK",
        "BEL",
        "BS",
        "HT",
        "LF",
        "VT",
        "FF",
        "CR",
        "S[OI]",
        "DLE",
        "DC[1-4]",
        "NAK",
        "SYN",
        "ETB",
        "CAN",
        "EM",
        "SUB",
        "ESC",
        "[FGRU]S",
        "SP",
        "DEL",
    )

    # Optional type suffix on numeric literals.
    num_postfix = r"(%s)?" % "|".join(num_types)

    identifier_re = "[a-zA-Z_][a-zA-Z_0-9']*"

    # opstart_re = '+\-\*/%=\!><\|&\^'

    # RegexLexer tries the rules of a state in order and uses the first one
    # that matches, so ordering below is significant.
    tokens = {
        "root": [
            (r"--(.*?)$", Comment.Single),
            (r"\s+", Whitespace),
            (r"\(\)", Punctuation),
            (r"\b(%s)(?!\')\b" % "|".join(reserved), Keyword.Reserved),
            (
                r"\b(%s)(?!\')\b" % "|".join(num_types + other_types),
                Keyword.Type,
            ),
            # Identifiers
            (r"#\[([a-zA-Z_\(\) ]*)\]", Comment.Preproc),
            (r"[#!]?(%s\.)*%s" % (identifier_re, identifier_re), Name),
            (r"\\", Operator),
            (r"[-+/%=!><|&*^][-+/%=!><|&*^.]*", Operator),
            (r"[][(),:;`{}?.\']", Punctuation),
            # Numbers
            (
                r"0[xX]_*[\da-fA-F](_*[\da-fA-F])*_*[pP][+-]?\d(_*\d)*" + num_postfix,
                Number.Float,
            ),
            (
                r"0[xX]_*[\da-fA-F](_*[\da-fA-F])*\.[\da-fA-F](_*[\da-fA-F])*"
                r"(_*[pP][+-]?\d(_*\d)*)?" + num_postfix,
                Number.Float,
            ),
            (r"\d(_*\d)*_*[eE][+-]?\d(_*\d)*" + num_postfix, Number.Float),
            (
                r"\d(_*\d)*\.\d(_*\d)*(_*[eE][+-]?\d(_*\d)*)?" + num_postfix,
                Number.Float,
            ),
            (r"0[bB]_*[01](_*[01])*" + num_postfix, Number.Bin),
            (r"0[xX]_*[\da-fA-F](_*[\da-fA-F])*" + num_postfix, Number.Hex),
            (r"\d(_*\d)*" + num_postfix, Number.Integer),
            # Character/String Literals
            # NOTE: a lone "'" is already consumed by the Punctuation rule
            # above, so this rule is effectively unreachable.
            (r"'", String.Char, "character"),
            (r'"', String, "string"),
            # Special
            # NOTE: "[" and "()" are already consumed by earlier Punctuation
            # rules, so these two rules are effectively unreachable.
            (r"\[[a-zA-Z_\d]*\]", Keyword.Type),
            (r"\(\)", Name.Builtin),
        ],
        "character": [
            # Allows multi-chars, incorrectly.
            (r"[^\\']'", String.Char, "#pop"),
            (r"\\", String.Escape, "escape"),
            ("'", String.Char, "#pop"),
        ],
        "string": [
            (r'[^\\"]+', String),
            (r"\\", String.Escape, "escape"),
            ('"', String, "#pop"),
        ],
        "escape": [
            (r'[abfnrtv"\'&\\]', String.Escape, "#pop"),
            (r"\^[][" + uni.Lu + r"@^_]", String.Escape, "#pop"),
            ("|".join(ascii), String.Escape, "#pop"),
            (r"o[0-7]+", String.Escape, "#pop"),
            (r"x[\da-fA-F]+", String.Escape, "#pop"),
            (r"\d+", String.Escape, "#pop"),
            (r"(\s+)(\\)", bygroups(Whitespace, String.Escape), "#pop"),
        ],
    }


# Make the lexer available to Sphinx under the name "futhark".
lexers["futhark"] = FutharkLexer()

highlight_language = "text"

# A list of ignored prefixes for module index sorting.
# modindex_common_prefix = []

# Set to True to keep warnings as "system message" paragraphs in the built
# documents.
# keep_warnings = False


# -- Options for HTML output ----------------------------------------------

# HTML / HTML Help theme: the custom "futhark" theme, looked up in the local
# _theme directory.
html_theme, html_theme_path = "futhark", ["_theme"]

# Theme-specific options; none are customised.
html_theme_options: Dict[Any, Any] = dict()

# Other HTML settings, all left at their Sphinx defaults:
# html_title = None             # defaults to "<project> v<release> documentation"
# html_short_title = None       # navigation-bar title; defaults to html_title
# html_logo = None              # image placed at the top of the sidebar
# html_favicon = None           # .ico file (16x16 or 32x32 pixels) in the static path
# html_static_path = ['_static']  # custom static files, copied after the builtin ones
# html_extra_path = []          # extra files (robots.txt, .htaccess) copied to the root
# html_last_updated_fmt = '%b %d, %Y'  # strftime format of the "Last updated on:" stamp
# html_use_smartypants = True   # typographic quotes and dashes

# Sidebar templates, keyed by document-name pattern ("**" matches every page).
html_sidebars = {
    "**": ["globaltoc.html", "relations.html", "sourcelink.html", "searchbox.html"]
}

# More HTML settings, all left at their Sphinx defaults:
# html_additional_pages = {}    # extra pages: page name -> template name
# html_domain_indices = True    # generate the module index
# html_use_index = True         # generate the general index
# html_split_index = False      # split the index into one page per letter
# html_show_sourcelink = True   # link to the reST sources from each page
# html_show_sphinx = True       # "Created using Sphinx" in the footer
# html_show_copyright = True    # "(C) Copyright ..." in the footer
# html_use_opensearch = ''      # base URL for an OpenSearch <link> tag
# html_file_suffix = None       # suffix for generated HTML files (e.g. ".xhtml")

# Output file base name for the HTML help builder.
htmlhelp_basename = "Futharkdoc"


# -- Options for LaTeX output ---------------------------------------------

# No LaTeX customisation. Recognised keys include 'papersize' ('letterpaper'
# or 'a4paper'), 'pointsize' ('10pt', '11pt' or '12pt') and 'preamble'
# (additional LaTeX preamble).
latex_elements: Dict[str, str] = dict()

# LaTeX documents to build, one tuple each:
# (source start file, target name, title, author,
#  documentclass [howto, manual, or own class]).
latex_documents = [
    ("index", "Futhark.tex", "Futhark User's Guide", "DIKU", "manual")
]

# The name of an image file (relative to this directory) to place at the top of
# the title page.
# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. # latex_use_parts = false # If true, show page references after internal links. # latex_show_pagerefs = false # If true, show URL addresses after external links. # latex_show_urls = false # Documents to append as an appendix to all manuals. # latex_appendices = [] # If false, no module index is generated. # latex_domain_indices = True # -- Options for manual page output --------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages: List[Tuple[str, str, str, List[Any], int]] = [ ("man/futhark", "futhark", "a parallel functional array language", [], 1), ( "man/futhark-autotune", "futhark-autotune", "calibrate run-time parameters", [], 1, ), ("man/futhark-c", "futhark-c", "compile Futhark to sequential C", [], 1), ( "man/futhark-multicore", "futhark-multicore", "compile Futhark to multithreaded C", [], 1, ), ( "man/futhark-ispc", "futhark-ispc", "compile Futhark to multithreaded ISPC", [], 1, ), ( "man/futhark-opencl", "futhark-opencl", "compile Futhark to OpenCL", [], 1, ), ("man/futhark-cuda", "futhark-cuda", "compile Futhark to CUDA", [], 1), ("man/futhark-hip", "futhark-hip", "compile Futhark to HIP", [], 1), ( "man/futhark-python", "futhark-python", "compile Futhark to sequential Python", [], 1, ), ( "man/futhark-pyopencl", "futhark-pyopencl", "compile Futhark to Python and OpenCL", [], 1, ), ( "man/futhark-wasm", "futhark-wasm", "compile Futhark to WebAssembly", [], 1, ), ( "man/futhark-wasm-multicore", "futhark-wasm-multicore", "compile Futhark to parallel WebAssembly", [], 1, ), ("man/futhark-run", "futhark-run", "interpret Futhark program", [], 1), ( "man/futhark-repl", "futhark-repl", "interactive Futhark read-eval-print-loop", [], 1, ), ("man/futhark-test", "futhark-test", "test Futhark programs", [], 1), ( "man/futhark-bench", "futhark-bench", 
"benchmark Futhark programs", [], 1, ), ( "man/futhark-doc", "futhark-doc", "generate documentation for Futhark code", [], 1, ), ( "man/futhark-dataset", "futhark-dataset", "generate random data sets", [], 1, ), ( "man/futhark-fmt", "futhark-fmt", "format Futhark programs", [], 1, ), ("man/futhark-pkg", "futhark-pkg", "manage Futhark packages", [], 1), ( "man/futhark-literate", "futhark-literate", "execute literate Futhark program", [], 1, ), ( "man/futhark-script", "futhark-script", "execute FutharkScript expression", [], 1, ), ( "man/futhark-profile", "futhark-profile", "profile Futhark programs", [], 1, ), ] # If true, show URL addresses after external links. # man_show_urls = false # -- Options for Texinfo output ------------------------------------------- # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ ( "index", "Futhark", "Futhark Documentation", "DIKU", "Futhark", "One line description of project.", "Miscellaneous", ), ] # Documents to append as an appendix to all manuals. # texinfo_appendices = [] # If false, no module index is generated. # texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. # texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. # texinfo_no_detailmenu = false futhark-0.25.27/docs/default.nix000066400000000000000000000001741475065116200164400ustar00rootroot00000000000000with import {}; stdenv.mkDerivation { name = "futhark-docs"; buildInputs = [ python37Packages.sphinx ]; } futhark-0.25.27/docs/error-index.rst000066400000000000000000000704661475065116200172770ustar00rootroot00000000000000.. _error-index: Compiler Error Index ==================== Elaboration on type errors produced by the compiler. Many error messages contain links to the sections below. Uniqueness errors ----------------- .. 
_use-after-consume: "Using *x*, but this was consumed at *y*." ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ A core principle of uniqueness typing (see :ref:`in-place-updates`) is that after a variable is "consumed", it must not be used again. For example, this is invalid, and will result in the error above: .. code-block:: futhark let y = x with [0] = 0 in x Several operations can *consume* a variable: array update expressions, calling a function with unique-typed parameters, or passing it as the initial value of a unique-typed loop parameter. When a variable is consumed, its *aliases* are also considered consumed. Aliasing is the possibility of two variables occupying the same memory at run-time. For example, this will fail as above, because ``y`` and ``x`` are aliased: .. code-block:: futhark let y = x let z = y with [0] = 0 in x We can always break aliasing by using a ``copy`` expression: .. code-block:: futhark let y = copy x let z = y with [0] = 0 in x .. _not-consumable: "Would consume *x*, which is not consumable" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This error message occurs for programs that try to perform a consumption (such as an in-place update) on variables that are not consumable. For example, it would occur for the following program: .. code-block:: futhark def f (a: []i32) = let a[0] = a[0]+1 in a Only arrays with a *unique array type* can be consumed. Such a type is written by prefixing the array type with an asterisk. The program could be fixed by writing it like this: .. code-block:: futhark def f (a: *[]i32) = let a[0] = a[0]+1 in a Note that this places extra obligations on the caller of the ``f`` function, since it now *consumes* its argument. See :ref:`in-place-updates` for the full details. You can always obtain a unique copy of an array by using ``copy``: ..
code-block:: futhark def f (a: []i32) = let a = copy a let a[0] = a[0]+1 in a But note that in most cases (although not all), this subverts the purpose of using in-place updates in the first place. .. _return-aliased: "Unique-typed return value of *x* is aliased to *y*, which is not consumable" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This can be caused by a function like this: .. code-block:: futhark def f (xs: []i32) : *[]i32 = xs We are saying that ``f`` returns a *unique* array - meaning it has no aliases - but at the same time, it aliases the parameter *xs*, which is not marked as being unique (see :ref:`in-place-updates`). This violates one of the core guarantees provided by uniqueness types, namely that a unique return value does not alias any value that might be used in the future. Imagine if this was permitted, and we had a program that used ``f``: .. code-block:: futhark let b = f a let b[0] = x ... The update of ``b`` is fine, but if ``b`` was allowed to alias ``a`` (hence occupying the same memory), then we would be modifying ``a`` as well, which is a violation of referential transparency. As with most uniqueness errors, it can be fixed by using ``copy xs`` to break the aliasing. We can also change the type of ``f`` to take a unique array as input: .. code-block:: futhark def f (xs: *[]i32) : *[]i32 = xs This makes ``xs`` "consumable", in the sense used by the error message. .. _unique-return-aliased: "A unique-typed component of the return value of *x* is aliased to some other component" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Caused by programs like the following: .. code-block:: futhark def main (xs: *[]i32) : (*[]i32, *[]i32) = (xs, xs) While we are allowed to "consume" ``xs``, as it is a unique parameter, this function is trying to return two unique values that alias each other. 
This violates one of the core guarantees provided by uniqueness types, namely that a unique return value does not alias any value that might be used in the future (see :ref:`in-place-updates`) - and in this case, the two values alias each other. We can fix this by inserting copies to break the aliasing: .. code-block:: futhark def main (xs: *[]i32) : (*[]i32, *[]i32) = (xs, copy xs) .. _self-aliasing-arg: "Argument passed for consuming parameter is self-aliased." ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Caused by programs like the following: .. code-block:: futhark def g (t: *([]i64, []i64)) = 0 def f n = let x = iota n in g (x,x) The function ``g`` expects to consume two separate ``[]i64`` arrays, but ``f`` passes it a tuple containing two references to the same physical array. This is not allowed, as ``g`` must be allowed to assume that components of consuming record- or tuple parameters have no internal aliases. We can fix this by inserting copies to break the aliasing: .. code-block:: futhark def f n = let x = iota n in g (copy (x,x)) Alternatively, we could duplicate the expression producing the array: .. code-block:: futhark def f n = g (iota n, iota n) .. _consuming-parameter: "Consuming parameter passed non-unique argument" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Caused by programs like the following: .. code-block:: futhark def update (xs: *[]i32) = xs with [0] = 0 def f (ys: []i32) = update ys The ``update`` function *consumes* its ``xs`` argument to perform an :ref:`in-place update <in-place-updates>`, as denoted by the asterisk before the type. However, the ``f`` function tries to pass an array that it is not allowed to consume (no asterisk before the type). One solution is to change the type of ``f`` so that it also consumes its input, which allows it to pass it on to ``update``: .. code-block:: futhark def f (ys: *[]i32) = update ys Another solution is to ``copy`` the array that we pass to ``update``: ..
code-block:: futhark def f (ys: []i32) = update (copy ys) .. _consuming-argument: "Non-consuming higher-order parameter passed consuming argument." ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This error occurs when we have a higher-order function that expects a function that does *not* consume its arguments, and we pass it one that does: .. code-block:: futhark def apply 'a 'b (f: a -> b) (x: a) = f x def consume (xs: *[]i32) = xs with [0] = 0 def f (arr: *[]i32) = apply consume arr We can fix this by changing ``consume`` so that it does not have to consume its argument, by adding a ``copy``: .. code-block:: futhark def consume (xs: []i32) = copy xs with [0] = 0 Or we can create a variant of ``apply`` that accepts a consuming function: .. code-block:: futhark def apply 'a 'b (f: *a -> b) (x: *a) = f x .. _alias-free-variable: "Function result aliases the free variable *x*" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Caused by definitions such as the following: .. code-block:: futhark def x = [1,2,3] def f () = x To simplify the tracking of aliases, the Futhark type system requires that the result of a function may only alias the function parameters, not any free variables. Use ``copy`` to fix this: .. code-block:: futhark def f () = copy x .. _size-expression-bind: "Size expression with binding is replaced by unknown size." ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ To illustrate this error, consider the following program .. code-block:: futhark def main (xs: *[]i64) = let a = iota (let n = 10 in n+n) in ... Intuitively, the type of ``a`` should be ``[let n = 10 in n+n]i64``, but this puts a binding into a size expression, which is invalid. Therefore, the type checker invents an :term:`unknown size` variable, say ``l``, and assigns ``a`` the type ``[l]i64``. .. _size-expression-consume: "Size expression with consumption is replaced by unknown size."
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ To illustrate this error, consider the following program .. code-block:: futhark def consume (xs: *[]i64): i64 = xs[0] def main (xs: *[]i64) = let a = iota (consume xs) in ... Intuitively, the type of ``a`` should be ``[consume xs]i64``, but this puts a consumption of the array ``xs`` into a size expression, which is invalid. Therefore, the type checker invents an :term:`unknown size` variable, say ``l``, and assigns ``a`` the type ``[l]i64``. .. _inaccessible-size: "Parameter *x* refers to size *y* which will not be accessible to the caller" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This happens when the size of an array parameter depends on a name that cannot be expressed in the function type: .. code-block:: futhark def f (x: i64, y: i64) (A: [x]bool) = true Intuitively, this function might have the following type: .. code-block:: futhark val f : (x: i64, y: i64) -> [x]bool -> bool But this is not currently a valid Futhark type. In a function type, each parameter can be named *as a whole*, but it cannot be taken apart in a pattern. In this case, we could fix it by splitting the tuple parameter into two separate parameters: .. code-block:: futhark def f (x: i64) (y: i64) (A: [x]bool) = true This gives the following type: .. code-block:: futhark val f : (x: i64) -> (y: i64) -> [x]bool -> bool Another workaround is to loosen the static safety, and use a size coercion to give A its expected size: .. code-block:: futhark def f (x: i64, y: i64) (A_unsized: []bool) = let A = A_unsized :> [x]bool in true This will produce a function with the following type: .. code-block:: futhark val f [d] : (i64, i64) -> [d]bool -> bool This does however lose the constraint that the size of the array must match one of the elements of the tuple, which means the program may fail at run-time. The error is not always due to an explicit type annotation.
It might also be due to size inference: .. code-block:: futhark def f (x: i64, y: i64) (A: []bool) = zip A (iota x) Here the type rules force ``A`` to have size ``x``, leading to a problematic type. It can be fixed using the techniques above. Size errors ----------- .. _unused-size: "Size *x* unused in pattern." ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Caused by expressions like this: .. code-block:: futhark let [n] (y: i32) = x And functions like this: .. code-block:: futhark def f [n] (x: i32) = x Since ``n`` is not the size of anything, it cannot be assigned a value at runtime. Hence this program is rejected. .. _causality-check: "Causality check" ~~~~~~~~~~~~~~~~~ Causality check errors occur when the program is written in such a way that a size is needed before it is actually computed. See :ref:`causality` for the full rules. Contrived example: .. code-block:: futhark def f (b: bool) (xs: []i32) = let a = [] : [][]i32 let b = [filter (>0) xs] in a[0] == b[0] Here the inner size of the array ``a`` must be the same as the inner size of ``b``, but the inner size of ``b`` depends on a ``filter`` operation that is executed after ``a`` is constructed. There are various ways to fix causality errors. In the above case, we could merely change the order of statements, such that ``b`` is bound first, meaning that the size is available by the time ``a`` is bound. In many other cases, we can lift out the "size-producing" expressions into a separate ``let``-binding preceding the problematic expressions. .. _unknown-param-def: "Unknown size *x* in parameter of *y*" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This error occurs when you define a function that can never be applied, as it requires an input of a specific size, and that size is an :term:`unknown size`. Somewhat contrived example: .. code-block:: futhark def f (x: bool) = let n = if x then 10 else 20 in \(y: [n]bool) -> ...
The above constructs a function that accepts an array of size 10 or 20, based on the value of the ``x`` argument. But the type of ``f true`` by itself would be ``?[n].[n]bool -> bool``, where the ``n`` is unknown. There is no way to construct an array of the right size, so the type checker rejects this program. (In a fully dependently typed language, the type would have been ``[10]bool -> bool``, but Futhark does not do any type-level computation.) In most cases, this error means you have done something you didn't actually mean to. However, in the case that the above really is what you intend, the workaround is to make the function fully polymorphic, and then perform a size coercion to the desired size inside the function body itself: .. code-block:: futhark def f (x: bool) = let n = if x then 10 else 20 in \(y_any: []bool) -> let y = y_any :> [n]bool in true This requires a check at run-time, but it is the only way to accomplish this in Futhark. .. _existential-param-ret: "Existential size would appear in function parameter of return type" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This occurs most commonly when we use function composition with one or more functions that return an *existential size*. Example: .. code-block:: futhark filter (>0) >-> length The ``filter`` function has this type: .. code-block:: futhark val filter [n] 't : (t -> bool) -> [n]t -> ?[m].[m]t That is, ``filter`` returns an array whose size is not known until the function actually returns. The ``length`` function has this type: .. code-block:: futhark val length [n] 't : [n]t -> i64 Whenever ``length`` occurs (as in the composition above), the type checker must *instantiate* the ``[n]`` with the concrete symbolic size of its input array. But in the composition, that size does not actually exist until ``filter`` has been fully applied.
For that matter, the type checker does not know what ``>->`` does, and for all it knows it may actually apply ``filter`` many times to different arrays, yielding different sizes. This makes it impossible to uniquely instantiate the type of ``length``, and therefore the program is rejected. The common workaround is to use *pipelining* instead of composition whenever we use functions with existential return types: .. code-block:: futhark xs |> filter (>0) |> length This works because ``|>`` is left-associative, and hence the ``xs |> filter (>0)`` part will be fully evaluated to a concrete array before ``length`` is reached. We can of course also write it as ``length (filter (>0) xs)``, with no use of either pipelining or composition. .. _unused-existential: "Existential size *n* not used as array size" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This error occurs for type expressions that bind an existential size for which there is no :term:`constructive use`, such as in the following examples: .. code-block:: futhark ?[n].bool ?[n].bool -> [n]bool When we use existential quantification, we are required to use the size constructively within its scope, *in particular* it must not be used exclusively in the parameter or return type of a function. To understand the motivation behind this rule, consider that when we use an existential quantifier we are saying that there is *some size*. The size is not known statically, but must be read from some value (i.e. array) at runtime. In the first example above, the existential size ``n`` is not used at all, so the actual value cannot be determined at runtime. In the second example, while an array ``[n]bool`` does exist, it is part of a function type, and at runtime functions are black boxes and don't "carry" the size of their parameter or result types. The workaround is to actually use the existential size. This can be as simple as adding a *witness array* of type ``[n]()``: ..
code-block:: futhark ?[n].([n](),bool) ?[n].([n](), bool -> [n]bool) Such an array will take up no space at runtime. .. _anonymous-nonconstructive: "Type abbreviation contains an anonymous size not used constructively as an array size." ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This error occurs for type abbreviations that use anonymous sizes, such as the following: .. code-block:: futhark type^ t = []bool -> bool Such an abbreviation is actually shorthand for .. code-block:: futhark type^ t = ?[n].[n]bool -> bool which is erroneous, but with workarounds, as explained in :ref:`unused-existential`. .. _unify-consuming-param: "Parameter types *x* and *y* are incompatible regarding consuming their arguments" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This error occurs when you provide a function that *does* consume its argument in a context that expects a function that *does not* consume its argument. As a simple example, consider the following contrived function that does consume its argument: .. code-block:: futhark def f (xs: *[]f32) : f32 = 0f32 Now we define another function that is merely ``f``, but with a type annotation that tries to hide the consumption: .. code-block:: futhark def g : []f32 -> f32 = f Allowing this would permit us to hide the fact that ``f`` consumes its argument, which would not be sound, so the type checker complains. .. _ambiguous-size: "Ambiguous size *x*" ~~~~~~~~~~~~~~~~~~~~ There are various sources for this error, but they all have the same ultimate cause: the type checker cannot figure out how some symbolic size name should be resolved to a concrete size. The simplest example, although contrived, is probably this: .. code-block:: futhark let [n][m] (xss: [n][m]i64) = [] The type checker can infer that ``n`` should be zero, but how can it possibly figure out the shape of the (non-existent) rows of the two-dimensional array?
This can be fixed in many ways, but adding a type ascription to the array is one of them: ``[] : [0][2]i64``. Another common case arises when using holes. For an expression ``length ???``, how would the type checker figure out the intended size of the array that the hole represents? Again, this can be solved with a type ascription: ``length (??? : [10]bool)``. Finally, ambiguous sizes can also occur for functions that use size parameters only in "non-witnessing" position, meaning sizes that are not actually used as sizes of real arrays. An example: .. code-block:: futhark def f [n] (g: [n]i64 -> i64) : i64 = n def main = f (\xs -> xs[0]) Note that ``f`` is a higher order function, and that the size parameter ``n`` is only used in the type of the ``g`` function. Futhark's value model is such that given a value of type ``[n]i64 -> i64``, we cannot extract an ``n`` from it. Using a function such as ``f`` is only valid when ``n`` can be inferred from the usage, which is not the case here. Again, we can fix it by adding a type ascription to disambiguate: .. code-block:: futhark def main = f (\(xs:[1]i64) -> xs[0]) Module errors ------------- .. _module-is-parametric: "Module *x* is a parametric module" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ A parametric module is a module-level function: .. code-block:: futhark module PM (P: {val x : i64}) = { def y = P.x + 2 } If we directly try to access the component of ``PM``, as ``PM.y``, we will get an error. To use ``PM`` we must first apply it to a module of the expected type: .. code-block:: futhark module M = PM { def x = 2 : i64 } Now we can say ``M.y``. See :ref:`module-system` for more. Other errors ------------ .. _literal-out-of-bounds: "Literal out of bounds" ~~~~~~~~~~~~~~~~~~~~~~~ This occurs for overloaded constants such as ``1234`` that are inferred by context to have a type that is too narrow for their value. Example: ..
code-block:: 257 : u8 It is not an error to have a *non-overloaded* numeric constant whose value is too large for its type. The following is perfectly cromulent: .. code-block:: 257u8 In such cases, the behaviour is overflow (so this is equivalent to ``1u8``). .. _ambiguous-type: "Type is ambiguous" ~~~~~~~~~~~~~~~~~~~ There are various cases where the type checker is unable to infer the full type of something. For example: .. code-block:: futhark def f r = r.x We know that ``r`` must be a record with a field called ``x``, but the record might also have other fields. Instead of assuming a perhaps too narrow type, the type checker signals an error. The solution is always to add a type annotation in one or more places to disambiguate the type: .. code-block:: futhark def f (r: {x:bool, y:i32}) = r.x Usually the best spot to add such an annotation is on a function parameter, as above. But for ambiguous sum types, we often have to put it on the return type. Consider: .. code-block:: futhark def f (x: bool) = #some x The type of this function is ambiguous, because the type checker must know which other constructors (apart from ``#some``) the type has. We fix it with a type annotation on the return type: .. code-block:: futhark def f (x: bool) : (#some bool | #none) = #some x See :ref:`typeabbrevs` for how to avoid typing long types in several places. .. _may-not-be-redefined: "The *x* operator may not be redefined" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The ``&&`` and ``||`` operators have magical short-circuiting behaviour, and therefore may not be redefined. There is no way to define your own short-circuiting operators. .. _unmatched-cases: "Unmatched cases in match expression" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Futhark requires ``match`` expressions to be *exhaustive* - that is, to cover all possible forms of the value being matched. Example: ..
code-block:: futhark def f (x: i32) = match x case 0 -> false case 1 -> true Usually this is an actual bug, and you fix it by adding the missing cases. But sometimes you *know* that the missing cases will never actually occur at run-time. To satisfy the type checker, you can turn the final case into a wildcard that matches anything: .. code-block:: futhark def f (x: i32) = match x case 0 -> false case _ -> true Alternatively, you can add a wildcard case that explicitly asserts that it should never happen: .. code-block:: futhark def f (x: i32) = match x case 0 -> false case 1 -> true case _ -> assert false false :ref:`See here ` for details on how to use ``assert``. .. _refutable-pattern: "Refutable pattern not allowed here" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This occurs when you try to use a :term:`refutable pattern` in a ``let`` binding or function parameter. A refutable pattern is a pattern that is not guaranteed to match a well-typed value. For example, this expression tries to bind an arbitrary tuple value ``x`` to a pattern that requires the first element to be ``2``: .. code-block:: futhark let (2, y) = x in 0 What should happen at run-time if the first element of ``x`` is not 2? Refutable patterns are only allowed in ``match`` expressions, where the failure to match can be handled. For example: .. code-block:: futhark match x case (2, y) -> 0 case _ -> ... -- do something else .. _record-type-not-known: "Full type of *x* is not known at this point" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ When performing a :ref:`record update `, the type of the field we are updating must be known. This restriction is based on a limitation in the type checker, so the notion of "known" is a bit subtle: .. code-block:: futhark def f r : {x:i32} = r with x = 0 Even though the return type annotation disambiguates the type, this program still fails to type check. This is because the return type is not consulted until *after* the body of the function has been checked.
The solution is to put a type annotation on the parameter instead: .. code-block:: futhark def f (r : {x:i32}) = r with x = 0 Entry points ------------ .. _nested-entry: "Entry points may not be declared inside modules." ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This occurs when the program uses the ``entry`` keyword inside a module: .. code-block:: futhark module m = { entry f x = x + 1 } Entry points can only be declared at the top level of a file. When we wish to make a function from inside a module available as an entry point, we must define a wrapper function: .. code-block:: futhark module m = { def f x = x + 1 } entry f x = m.f x .. _polymorphic-entry: "Entry point functions may not be polymorphic." ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Entry points are Futhark functions that can be called from other languages, and are therefore limited in how advanced their types can be. In this case, the problem is that an entry point may not have a polymorphic type, for example: .. code-block:: futhark entry dup 't (x: t) : (t,t) = (x,x) This is an invalid entry point because it uses a type parameter ``'t``. This error occurs frequently when we want to test a polymorphic function. In such cases, the solution is to define one or more *monomorphic* entry points, each for a distinct type. For example, we can define a variety of monomorphic entry points that call the built-in function ``scan``: .. code-block:: futhark entry scan_i32 (xs: []i32) = scan (+) 0 xs entry scan_f32 (xs: []f32) = scan (*) 1 xs .. _higher-order-entry: "Entry point functions may not be higher-order." ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Entry points are Futhark functions that can be called from other languages, and are therefore limited in how advanced their types can be. In this case, the problem is that an entry point may not use functions as input or output. For example: .. code-block:: futhark entry apply (f: i32 -> i32) (x: i32) = f x There is no simple workaround for such cases.
One option is to manually `defunctionalise `_ to use a non-functional encoding of the functional values, but this can quickly get very elaborate. Following up on the example above, if we know that the only functions that would ever be passed are ``(+y)`` or ``(*y)`` for some ``y``, we could do something like the following: .. code-block:: futhark type function = #add i32 | #mul i32 entry apply (f: function) (x: i32) = match f case #add y -> x + y case #mul y -> x * y In many cases, however, the best solution is simply to define a handful of simpler entry points instead of a single complicated one. .. _size-polymorphic-entry: "Entry point functions must not be size-polymorphic in their return type." ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This somewhat rare error occurs when an entry point returns an array that can have an arbitrary size chosen by its caller. Contrived example: .. code-block:: futhark -- Entry point taking no parameters. entry f [n] : [0][n]i32 = [] The size ``n`` is chosen by the caller. Note that the ``n`` might be inferred and invisible, as in this example: .. code-block:: futhark entry g : [0][]i32 = [] When calling functions within a Futhark program, size parameters are handled by type inference, but entry points are called from the outside world, which is not subject to type inference. If you really must have entry points like this, turn the size parameter into an ordinary parameter: .. code-block:: futhark entry f (n: i64) : [0][n]i32 = [] .. _nonconstructive-entry: "Entry point size parameter [n] only used non-constructively." ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This error occurs for programs such as the following: .. code-block:: futhark entry main [x] (A: [x+1]i32) = ... The size parameter ``[x]`` is only used in a size expression ``x+1``, rather than directly as an array size. This is allowed for ordinary functions, but not for entry points.
The reason is that entry points are not subject to ordinary type inference, as they are called from the external world, meaning that the value of the size parameter ``[x]`` will have to be determined from the size of the array ``A``. This is in principle not a problem for simple sizes like ``x+1``, as it is obvious that ``x == length A - 1``, but in the general case it would require computing function inverses that might not exist. For this reason, entry points require that all size parameters are used :term:`constructively <constructive use>`. As a workaround, you can rewrite the entry point as follows: .. code-block:: futhark entry main [n] (A: [n]i32) = let x = n-1 let A = A :> [x+1]i32 ... Or pass ``x`` explicitly: .. code-block:: futhark entry main (x: i64) (A: [x+1]i32) = ... futhark-0.25.27/docs/glossary.rst000066400000000000000000000366411475065116200167010ustar00rootroot00000000000000.. _glossary: Glossary ======== The following defines various Futhark-specific terms used in the documentation and in compiler output. .. glossary:: :sorted: Abstract type A type whose definition has been hidden through a :term:`module ascription`. Aliases The *aliases* of a variable are the set of other variables with which it might be :term:`aliased <aliasing>`. Aliasing Whether two values might potentially share the same memory at run-time. Clearly, after ``let y = x``, ``x`` and ``y`` are aliased. Also, the slice ``x[i:j]`` is aliased with ``x``. Aliasing is used to type-check :ref:`in-place-updates`. Anonymous size In a type expression, a size of the form ``[]``. Will be :term:`elaborated <elaboration>` to some name (possibly :term:`existentially bound <existential size>`) by the type checker. Attribute Auxiliary information attached to an expression or declaration, which the compiler or other tools might use for various purposes. See :ref:`attributes`. Coercion Shorthand for a :ref:`size-coercion`.
Compiler backend A Futhark compiler backend is technically only responsible for the final compilation result, but from the user's perspective is also coupled with a :term:`compiler pipeline`. The backend corresponds to a compiler subcommand, such as ``futhark c``, ``futhark cuda``, ``futhark multicore``, etc. Compiler frontend The part of the compiler responsible for reading code from files, parsing it, and type checking it. Compiler pipeline The series of compiler passes that lie between the :term:`compiler frontend` and the :term:`compiler backend`. Responsible for the majority of program optimisations. In principle the pipeline could be configurable, but in practice each backend is coupled with a specific pipeline. Constructive use A variable ``n`` is used *constructively* in a type if it is used as the size of an array at least once outside of any function arrows. For example, the following types use ``n`` constructively: * ``[n]bool`` * ``([n]bool, bool -> [n]bool)`` * ``([n]bool, [n+1]bool)`` The following do not: * ``[n+1]bool`` * ``bool -> [n]bool`` * ``[n]bool -> bool`` Consumption If a value is passed for a *consuming* function parameter, that value may no longer be used. We say that an expression is *with consumption* if any values are consumed in the expression. This is banned in some cases where that expression might otherwise be evaluated multiple times. See :ref:`in-place-updates`. Data parallelism Performing the same operation on multiple elements of a collection, such as an array. The ``map`` :term:`SOAC` is the simplest example. This is the form of parallelism supported by Futhark. `See also Wikipedia `_. Defunctionalisation A program transformation always performed by the Futhark compiler, that replaces function values with non-function values. The goal is to avoid having indirect calls through function pointers at run-time. To permit zero-overhead defunctionalisation, the Futhark type rules impose restrictions on :term:`lifted types `.
Defunctorisation A program transformation always performed by the Futhark compiler, that compiles away modules using an approach similar to :term:`defunctionalisation`. This makes using e.g. a :term:`parametric module` completely free at run-time. Elaboration The process carried out by the type checker, where it infers and inserts information not explicitly provided in the program. The most important part of this is type inference, but it also includes various other things. Existential size An existential size is a size that is bound by the existential quantifier ``?`` in the same type. For example, in a type ``[n]bool -> ?[m].[m]bool``, the size ``m`` is existential. When such a function is applied, each existential size is instantiated as an :term:`unknown size`. Functor The Standard ML term for what Futhark calls a :term:`parametric module`. GPU backend A :term:`compiler backend` that ultimately produces GPU code. The backends ``cuda``, ``hip``, ``opencl``, and ``pyopencl`` are GPU backends. These have more restrictions than some other backends, particularly with respect to :term:`irregular nested data parallelism`. Higher-ranked type A type that does not describe :term:`values `. Can be seen as a partially applied :term:`type constructor`. Not directly supported by Futhark, but a similar effect can be achieved through the :ref:`module-system`. In-place updates A somewhat misleading term for the syntactic forms ``x with [i] = v`` and ``let x[i] = v``. These are not semantic in-place updates, but can be operationally understood as such. See :ref:`in-place-updates`. Invariant Not :term:`variant`. Irrefutable pattern A :term:`pattern` that will always match a value of its type. For example, ``(x,y)`` is a pattern that will match any tuple. See also :term:`refutable pattern`. Irregular Something that is not regular. Usually used as shorthand for :term:`irregular nested data parallelism` or :term:`irregular array`. Irregular array An array where the elements do not have the same size.
For example, ``[[1], [2,3]]`` is irregular. These are not supported in Futhark. Irregular nested data parallelism An instance of :term:`nested data parallelism`, where the :term:`parallel width` of inner parallelism is :term:`variant` to the outer parallelism. For example, the following expression exhibits irregular nested data parallelism:: map (\n -> reduce (+) 0 (iota n)) ns This is because the width of the inner ``reduce`` is ``n``, and every iteration of the outer ``map`` has a (potentially) different ``n``. The Futhark :term:`GPU backends <GPU backend>` *currently* do not support irregular nested data parallelism well, and will usually sequentialise the irregular loops. In cases that require an :term:`irregular memory allocation`, the compiler may entirely fail to generate code. Irregular memory allocation A situation that occurs when the generated code has to allocate memory inside of an instance of :term:`nested data parallelism`, where the amount to allocate is variant to the outer parallel levels. As a contrived example (that the actual compiler would just optimise away), consider:: map (\n -> let A = iota n in A[10]) ns To construct the array ``A`` in memory, we require ``8n`` bytes, but ``n`` is not known until we start executing the body of the ``map``. While such simple cases are handled, more complicated ones that involve nested sequential loops are not supported by the :term:`GPU backends <GPU backend>`. Parametric module A function from :term:`modules <module>` to modules. The most powerful form of abstraction provided by Futhark. Polymorphic Usually means a :term:`polymorphic function`, but sometimes a :term:`parametric modules `. Should not be used to describe a :term:`type constructor `. Polymorphism The concept of being :term:`polymorphic`. Polymorphic function A function with :term:`type parameters `, such that the function can be applied to arguments of various types. Compiled using :term:`monomorphisation`.
Lifted type A type that may contain functions, including function types themselves. These have various restrictions on their use in order to support :term:`defunctionalisation`. See :ref:`hofs`. Module A mapping from names to definitions of types, values, or nested modules. See :ref:`module-system`. Module ascription A feature of the module system through which the contents of a module can be hidden. Written as ``m : mt`` where ``m`` is a :term:`module expression` and ``mt`` is a :term:`module type expression`. See :ref:`module-system`. Module expression An expression that is evaluated at compile time, through :term:`defunctorisation`, to a :term:`module`. Most commonly just the name of a module. Module type A description of the interface of a :term:`module`. Most commonly used to hide contents in a :term:`module ascription` or to require implementation of an interface in a :term:`parametric module`. Module type expression An expression that is evaluated during type-checking to a :term:`module type`. Monomorphisation A program transformation that instantiates a copy of each :term:`polymorphic` function for each type it is used with. Performed by the Futhark compiler. Name A lexical token consisting of alphanumeric characters and underscores, for example ``map`` and ``do_it``. Most variables are names. See also :term:`symbol`. Nested data parallelism Nested :term:`data parallelism` occurs when a parallel construct is used inside of another parallel construct. For example, a ``reduce`` might be used inside a function passed to ``map``. Parallel width A somewhat informal term used to describe the size of an array on which we apply a :term:`SOAC`. For example, if ``x`` has type ``[1000]i32``, then ``map f x`` has a parallel width of 1000. Intuitively, the "number of processors" that would be needed to fully exploit the parallelism of the program, although :term:`nested data parallelism` muddles the picture.
Pattern A syntactical construct for decomposing a value into its constituent parts. Patterns are used in function parameters, ``let``-bindings, and ``match``. See :ref:`patterns`. Recursion A function that calls itself. Currently not supported in Futhark. Refutable pattern A :term:`pattern` that does not match all possible values. For example, the pattern ``(1,x)`` matches only tuples where the first element is ``1``. These may not be used in ``let`` expressions or in function parameters. See also :term:`irrefutable pattern`. Regular nested data parallelism An instance of :term:`nested data parallelism` that is not :term:`irregular`. Fully supported by every :term:`GPU backend`. Size The symbolic size of an array dimension or :term:`abstract type`. Size expression An expression that occurs as the size of an array or size argument. For example, in the type ``[x+2]i32``, ``x+2`` is a size expression. Size expressions can occur syntactically in source code, or due to parameter substitution when applying a :term:`size-dependent function`. Size-dependent function A function where the size of the result depends on the values of the parameters. The function ``iota`` is perhaps the simplest example. Size types Size-dependent types An umbrella term for the part of Futhark's type system that tracks array sizes. See :ref:`size-types`. Size-lifted type A type that may contain internal hidden sizes. These cannot be array elements, as that might potentially result in an :term:`irregular array`. See :ref:`typeabbrevs`. Size argument An argument to a :term:`type constructor` in a :term:`type expression` of the form ``[n]`` or ``[]``. The latter is called an :term:`anonymous size`. Must match a corresponding :term:`size parameter`. Size parameter A parameter of a :term:`polymorphic function` or :term:`type constructor` that ranges over :term:`sizes `. These are written as ``[n]`` for some ``n``, after which ``n`` is in scope as a term of type ``i64`` within the rest of the definition.
Do not confuse them with :term:`type parameters `. SOAC Second Order Array Combinator A term covering the main parallel building blocks provided by Futhark: functions such as ``map``, ``reduce``, ``scan``, and so on. They are *second order* because they accept a functional argument, and so permit :term:`nested data parallelism`. Symbol A lexical token that consists of symbolic (non-alphanumeric) characters, and can be bound to a value. Infix operators such as ``+`` and ``/`` are symbols. See also :term:`name`. Type A classification of values. ``i32`` and ``[10]i32`` are examples of types. Type abbreviation A shorthand for a longer type, e.g. ``type t = [100]i32``. Can accept :term:`type parameters ` and :term:`size parameters `. The definition is visible to users, unless hidden with a :term:`module ascription`. See :ref:`typeabbrevs`. Type argument An argument to a :term:`type constructor` that is itself a :term:`type`. Must match a corresponding :term:`type parameter`. Type constructor A :term:`type abbreviation` or :term:`abstract type` that has at least one :term:`type parameter` or :term:`size parameter`. Futhark does not support :term:`higher-ranked types `, so when referencing a type constructor in a :term:`type expression`, you must provide corresponding :term:`type arguments ` and :term:`size arguments ` in the appropriate order. Type expression A syntactic construct that is evaluated to a :term:`type` in the type checker, but may contain uses of :term:`type abbreviations ` and :term:`anonymous sizes `. Type parameter A parameter of a :term:`polymorphic function` or :term:`type constructor` that ranges over types. These are written as ``'t`` for some ``t``, after which ``t`` is in scope as a type within the rest of the definition. Do not confuse them with :term:`size parameters `. Uniqueness types A somewhat misleading term that describes Futhark's system of allowing :term:`consumption` of values, in the interest of allowing :term:`in-place updates`.
The only place where *uniqueness* truly occurs is in return types, where e.g. the return type of ``copy`` is *unique* to indicate that the result does not :term:`alias <aliasing>` the argument. Unknown size A size that arises either from invoking a function whose result type contains an existentially quantified size, such as ``filter``, or because the original :term:`size expression` involves variables that have gone out of scope. Value An object such as the integer ``123`` or the array ``[1,2,3]``. Variables are bound to values, and all valid expressions have a :term:`type` describing the form of values they can return. Variant When some value ``v`` computed inside a loop takes a different value for each iteration inside the loop, we say that ``v`` is *variant* to the loop (and otherwise :term:`invariant`). Often used to talk about :term:`irregularity `. When something is nested inside multiple loops, it may be variant to just one of them. futhark-0.25.27/docs/index.rst000066400000000000000000000037761475065116200161500ustar00rootroot00000000000000.. Futhark documentation master file, created by sphinx-quickstart on Tue Mar 24 14:21:12 2015. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. Futhark User's Guide ==================== Welcome to the documentation for the Futhark compiler and language. For a basic introduction, please see `the Futhark website `_. To get started, read the page on :ref:`installation`. Once the compiler has been installed, you might want to take a look at :ref:`usage`. This User's Guide contains a :ref:`language-reference`, but new Futhark programmers are probably better served by reading `Parallel Programming in Futhark `_ first. Documentation for the built-in prelude is also `available online `_. The particularly interested reader may also want to peruse the `publications `_, or the `development blog `_. ..
toctree:: :caption: Table of Contents :maxdepth: 2 :numbered: installation.rst usage.rst language-reference.rst c-api.rst js-api.rst package-management.rst performance.rst error-index.rst server-protocol.rst c-porting-guide.rst versus-other-languages.rst binary-data-format.rst glossary.rst .. toctree:: :caption: Manual Pages :maxdepth: 1 man/futhark-autotune.rst man/futhark-bench.rst man/futhark-c.rst man/futhark-cuda.rst man/futhark-dataset.rst man/futhark-doc.rst man/futhark-fmt.rst man/futhark-hip.rst man/futhark-ispc.rst man/futhark-literate.rst man/futhark-script.rst man/futhark-multicore.rst man/futhark-opencl.rst man/futhark-pkg.rst man/futhark-profile.rst man/futhark-pyopencl.rst man/futhark-python.rst man/futhark-repl.rst man/futhark-run.rst man/futhark-test.rst man/futhark-wasm-multicore.rst man/futhark-wasm.rst man/futhark.rst futhark-0.25.27/docs/installation.rst000066400000000000000000000222151475065116200175270ustar00rootroot00000000000000.. _installation: Installation ============ There are two main ways to install the Futhark compiler: using a precompiled tarball or compiling from source. Both methods are discussed below. If you are using Linux, see :ref:`linux-installation`. If you are using Windows, see :ref:`windows-installation`. If you are using macOS, see :ref:`macos-installation`. Futhark is also available via `Nix `_. If you are using Nix, simply install the ``futhark`` derivation from Nixpkgs. Dependencies ------------ The Linux binaries we distribute are statically linked and should not require any special libraries installed system-wide. When building from source on Linux and macOS, you will need to have the ``gmp``, ``tinfo``, and ``zlib`` libraries installed. These are pretty common, so you may already have them. On Debian-like systems (e.g. Ubuntu), use:: sudo apt install libtinfo-dev libgmp-dev zlib1g-dev If you install Futhark via a package manager (e.g. Homebrew, Nix, or AUR), you shouldn't need to worry about any of this.
Actually *running* the output of the Futhark compiler may require additional dependencies, for example an OpenCL library and GPU driver. See the documentation for the respective compiler backends. Compiling from source --------------------- To compile Futhark you must first install an appropriate version of GHC, either with `ghcup `_ or a package manager. Any version since GHC 9.0 should work. You also need the ``cabal`` command line program, which ghcup will install for you as well. You then either retrieve a `source release tarball `_ or perform a checkout of our Git repository:: $ git clone https://github.com/diku-dk/futhark.git This will create a directory ``futhark``, which you must enter:: $ cd futhark First you must run the following command to download metadata about Futhark's dependencies:: $ make configure To build the Futhark compiler and all of its dependencies, run:: $ make build This step typically requires at least 8GiB of memory. This will create files in your ``~/.cabal`` and ``~/.ghc`` directories. After building, you can copy the binaries to your ``$HOME/.local/bin`` directory by running:: $ make install You can set the ``PREFIX`` environment variable to indicate a different installation path. Note that this does not install the Futhark manual pages. You can delete ``~/.cabal`` and ``~/.ghc`` after this if you wish - the ``futhark`` binary will still work. Installing from a precompiled snapshot -------------------------------------- Tarballs of binary releases can be `found online `_, but are available only for very few platforms (as of this writing, only GNU/Linux on x86_64). See the enclosed ``README.md`` for installation instructions. Furthermore, every day a program automatically clones the Git repository, builds the compiler, and packages a simple tarball containing the resulting binaries, built manpages, and a simple ``Makefile`` for installing. 
The implication is that these tarballs are not vetted in any way, nor more stable than Git HEAD at any particular moment in time. They are provided for users who wish to use the most recent code, but are unable to compile Futhark themselves. We build such binary snapshots for the following operating systems: **Linux (x86_64)** `futhark-nightly-linux-x86_64.tar.xz `_ **macOS (x86_64)** `futhark-nightly-macos-x86_64.zip `_ **Windows (x86_64)** `futhark-nightly-windows-x86_64.zip `_ You will still likely need to make a C compiler (such as GCC) available on your own. .. _linux-installation: Installing Futhark on Linux --------------------------- * `Homebrew`_ is a distribution-agnostic package manager for macOS and Linux that contains a formula for Futhark. If Homebrew is already installed (which does not require ``root`` access), installation is as easy as:: $ brew install futhark * Arch Linux users can use a `futhark-nightly package `_ or a `regular futhark package `_. * NixOS users can install the ``futhark`` derivation. Otherwise (or if the version in the package system is too old), your best bet is to install from source or use a tarball, as described above. .. _`Linuxbrew`: http://linuxbrew.sh/ Using OpenCL or CUDA ~~~~~~~~~~~~~~~~~~~~ If you wish to use ``futhark opencl`` or ``futhark cuda``, you must have the OpenCL or CUDA libraries installed, respectively. Consult your favourite search engine for instructions on how to do this on your distribution. It is usually not terribly difficult if you already have working GPU drivers. For OpenCL, note that there is a distinction between the general OpenCL host library (``OpenCL.so``) that Futhark links against, and the *Installable Client Driver* (ICD) that OpenCL uses to actually talk to the hardware. You will need both. Working display drivers for the GPU do not imply that an ICD has been installed - they are usually in a separate package. Consult your favourite search engine for details.
.. _macos-installation: Installing Futhark on macOS --------------------------- Futhark is available on `Homebrew`_, and the latest release can be installed via:: $ brew install futhark Or you can install the unreleased development version with:: $ brew install --HEAD futhark This has to compile from source, so it takes a little while (20-30 minutes is common). macOS ships with one OpenCL platform and various devices. One of these devices is always the CPU, which is not fully functional, and is never picked by Futhark by default. You can still select it manually with the usual mechanisms (see :ref:`executable-options`), but it is unlikely to be able to run most Futhark programs. Depending on the system, there may also be one or more GPU devices, and Futhark will simply pick the first one as always. On multi-GPU MacBooks, this is the low-power integrated GPU. It should work just fine, but you might have better performance if you use the dedicated GPU instead. On a Mac with an AMD GPU, this is done by passing ``-dAMD`` to the generated Futhark executable. .. _`Homebrew`: https://brew.sh/ .. _windows-installation: Setting up Futhark on Windows ----------------------------- Due to limited maintenance and testing resources, Futhark is only partially supported on Windows. A precompiled nightly snapshot is available above. In most cases, it is better to install `WSL `_ and follow the Linux instructions above. The C code generated by the Futhark compiler should work on Windows, except for the ``multicore`` backend. Alternatively, you can use the C compiler that is installed with `w64devkit`_. Using HIP ~~~~~~~~~~~~~~~~~~~~ *Note*: dependencies can sometimes move faster than this documentation. This Windows/HIP HowTo was written in Jan 2025 on a setup without WSL, using PowerShell, and after the installation of `w64devkit`_ via `scoop`_. If you wish to use ``futhark hip`` on Windows, you must have the `ROCm/HIP SDK`_ installed on your system.
The SDK installation will create a ``HIP_PATH`` environment variable on your system pointing to the installed SDK. It is advised to check that this variable has indeed been created. In order for ``futhark hip`` to work, you need to set up three environment variables in your PowerShell:: # CPATH creation so that the compiler can find the HIP headers $env:CPATH = $env:HIP_PATH + "include" # LIBRARY_PATH creation so that the linker can find the HIP libraries $env:LIBRARY_PATH = $env:HIP_PATH + "lib" # PATH modification so that the compiled app can find the HIP DLLs at runtime $env:PATH = $env:PATH + ";" + $env:HIP_PATH + "bin" .. _`ROCm/HIP SDK`: https://www.amd.com/en/developer/resources/rocm-hub/hip-sdk.html .. _`w64devkit`: https://github.com/skeeto/w64devkit .. _`scoop`: https://scoop.sh/ Futhark with Nix ---------------- Futhark mostly works fine with Nix and `NixOS `_, but when using OpenCL you may need to make more packages available in your environment. This is regardless of whether you are using the ``futhark`` package from Nixpkgs or one you have installed otherwise. * On NixOS, for OpenCL, you should import ``opencl-headers`` and ``ocl-icd``. You also need some form of OpenCL backend. If you have an AMD GPU and use ROCm, you may also need ``rocm-opencl-runtime``. * On NixOS, for CUDA (and probably also OpenCL on NVIDIA GPUs), you need ``cudatoolkit``. However, ``cudatoolkit`` does not appear to provide ``libcuda.so`` and similar libraries. These are instead provided in an ``nvidia_x11`` package that is specific to some kernel version, e.g. ``linuxPackages_5_4.nvidia_x11``. You will need this as well. * On macOS, for OpenCL, you need ``darwin.apple_sdk.frameworks.OpenCL``. These can be easily made available with e.g.:: nix-shell -p opencl-headers -p ocl-icd futhark-0.25.27/docs/js-api.rst000066400000000000000000000113371475065116200162140ustar00rootroot00000000000000..
_js-api: JavaScript API Reference ======================== The :ref:`futhark-wasm(1)` and :ref:`futhark-wasm-multicore(1)` compilers produce JavaScript wrapper code to allow JavaScript programs to invoke the generated WebAssembly code. This chapter describes the API exposed by the wrapper. First a warning: **the JavaScript API is experimental**. It may change incompatibly even in minor versions of the compiler. A Futhark program ``futlib.fut`` compiled with a WASM backend as a library with the ``--library`` command line option produces five files: * ``futlib.c``, ``futlib.h``: Implementation and header C files generated by the compiler, similar to ``futhark c``. You can delete these - they are not needed at run-time. * ``futlib.class.js``: An intermediate build artifact. Feel free to delete it. * ``futlib.wasm``: A compiled WebAssembly module, which must be present at runtime. * ``futlib.mjs``: An ES6 module that can be imported by other JavaScript code, and implements the API given in the following. The module exports a function, ``newFutharkContext``, which is a factory function that returns a Promise producing a ``FutharkContext`` instance (see below). A simple usage example: .. code-block:: javascript import { newFutharkContext } from './futlib.mjs'; var fc; newFutharkContext().then(x => fc = x); General concerns ---------------- Memory management is completely manual, as JavaScript does not support finalizers that could let Futhark hook into the garbage collector. You are responsible for eventually freeing all objects produced by the API, using the appropriate methods. FutharkContext -------------- FutharkContext is a class that contains information about the context and configuration from the C API. It has methods for invoking the Futhark entry points and creating FutharkArrays on the WebAssembly heap. .. js:function:: newFutharkContext() Asynchronously create a new ``FutharkContext`` object. ..
js:class:: FutharkContext() A bookkeeping class representing an instance of a Futhark program. Do *not* directly invoke its constructor - always use the ``newFutharkContext()`` factory function. .. js:function:: FutharkContext.free() Frees all memory created by the ``FutharkContext`` object. Should be called when the ``FutharkContext`` is done being used. It is an error to use a ``FutharkArray`` or ``FutharkOpaque`` after the ``FutharkContext`` on which they were defined has been freed. Values ------ Numeric types ``u8``, ``u16``, ``u32``, ``i8``, ``i16``, ``i32``, ``f32``, and ``f64`` are mapped to JavaScript's standard number type. 64-bit integers ``u64`` and ``i64`` are mapped to ``BigInt``. ``bool`` is mapped to JavaScript's ``boolean`` type. Arrays are represented by the ``FutharkArray`` class. Complex types (records, nested tuples, etc.) are represented by the ``FutharkOpaque`` class. FutharkArray ------------ ``FutharkArray`` has the following API: .. js:function:: FutharkArray.toArray() Returns a nested JavaScript array. .. js:function:: FutharkArray.toTypedArray() Returns a flat typed array of the underlying data. .. js:function:: FutharkArray.shape() Returns the shape of the FutharkArray as an array of BigInts. .. js:function:: FutharkArray.free() Frees the memory used by the ``FutharkArray``. ``FutharkContext`` also contains two functions for creating ``FutharkArrays`` from JavaScript arrays and typed arrays, for each array type that appears in an entry point. All array types share similar API methods on the ``FutharkContext``, which are illustrated here for the case of the type ``[]i32``. .. js:function:: FutharkContext.new_i32_1d_from_jsarray(jsarray) Creates and returns a one-dimensional ``i32`` ``FutharkArray`` representing the JavaScript array ``jsarray``. .. js:function:: FutharkContext.new_i32_1d(array, dim1) Creates and returns a one-dimensional ``i32`` ``FutharkArray`` representing the typed array ``array``, with the size given by ``dim1``.
FutharkOpaque ------------- Complex types (records, nested tuples, etc.) are represented by ``FutharkOpaque``. It has no use outside of being accepted and returned by entry point functions. For this reason the class has only one method, which frees the memory when the ``FutharkOpaque`` is no longer used. .. js:function:: FutharkOpaque.free() Frees memory used by the ``FutharkOpaque``. Should be called when the ``FutharkOpaque`` is no longer used. Entry Points ------------ Each entry point in the compiled Futhark program has an entry point method on the ``FutharkContext``. .. js:function:: FutharkContext.(in1, ..., inN) Invokes the Futhark entry point with its N arguments and returns the result. If the result is a tuple, the return value is an array. futhark-0.25.27/docs/language-reference.rst000066400000000000000000001774171475065116200205640ustar00rootroot00000000000000.. _language-reference: Language Reference ================== This reference seeks to describe every construct in the Futhark language. It is not presented in a tutorial fashion, but rather intended for quick lookup and documentation of subtleties. For this reason, it is not written in a bottom-up manner, and some concepts may be used before they are fully defined. It is a good idea to have a basic grasp of Futhark (or some other functional programming language) before reading this reference. An ambiguous grammar is given for the full language. The text describes how ambiguities are resolved in practice (for example by applying rules of operator precedence). This reference describes only the language itself. Documentation for the built-in prelude is `available elsewhere `_. Comments -------- Line comments are indicated with ``--`` and continue until end of line. A contiguous block of line comments beginning with ``-- |`` is a *documentation comment* and has special meaning to documentation tools. Documentation comments are only allowed immediately before declarations.
Trailing commas --------------- All syntactical elements that involve comma-separated sequencing permit an optional trailing comma. Identifiers and Keywords ------------------------ .. productionlist:: name: `letter` `constituent`* | "_" `constituent`* constituent: `letter` | `digit` | "_" | "'" quals: (`name` ".")+ qualname: `name` | `quals` `name` symbol: `symstartchar` `symchar`* qualsymbol: `symbol` | `quals` `symbol` | "`" `qualname` "`" fieldid: `decimal` | `name` symstartchar: "+" | "-" | "*" | "/" | "%" | "=" | "!" | ">" | "<" | "|" | "&" | "^" symchar: `symstartchar` | "." constructor: "#" `name` Many things in Futhark are named. When we are defining something, we give it an unqualified name (`name`). When referencing something inside a module, we use a qualified name (`qualname`). We can also use symbols (`symbol`, `qualsymbol`), which are treated as infix by the grammar. The constructor names of a sum type are identifiers prefixed with ``#``, with no space afterwards. The fields of a record are named with `fieldid`. Note that a `fieldid` can be a decimal number. Futhark has three distinct name spaces: terms, module types, and types. Modules (including parametric modules) and values both share the term namespace. .. _reserved: Reserved names and symbols ~~~~~~~~~~~~~~~~~~~~~~~~~~ A reserved name or symbol may be used only when explicitly present in the grammar. In particular, they cannot be bound in definitions. The following identifiers are reserved: ``true``, ``false``, ``if``, ``then``, ``else``, ``def``, ``let``, ``loop``, ``in``, ``val``, ``for``, ``do``, ``with``, ``local``, ``open``, ``include``, ``import``, ``type``, ``entry``, ``module``, ``while``, ``assert``, ``match``, ``case``. The following symbols are reserved: ``=``. .. _primitives: Primitive Types and Values -------------------------- .. productionlist:: literal: `intnumber` | `floatnumber` | "true" | "false" Boolean literals are written ``true`` and ``false``.
The primitive types in Futhark are the signed integer types ``i8``, ``i16``, ``i32``, ``i64``, the unsigned integer types ``u8``, ``u16``, ``u32``, ``u64``, the floating-point types ``f16``, ``f32``, ``f64``, as well as ``bool``. .. productionlist:: int_type: "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" float_type: "f16" | "f32" | "f64" Numeric literals can be suffixed with their intended type. For example ``42i8`` is of type ``i8``, and ``1337e2f64`` is of type ``f64``. If no suffix is given, the type of the literal will be inferred based on its use. If the use is not constrained, integral literals will be assigned type ``i32``, and decimal literals type ``f64``. Hexadecimal literals are supported by prefixing with ``0x``, and binary literals by prefixing with ``0b``. Floats can also be written in hexadecimal format such as ``0x1.fp3``, instead of the usual decimal notation. Here, ``0x1.f`` evaluates to ``1 15/16`` and the ``p3`` multiplies it by ``2^3 = 8``. .. productionlist:: intnumber: (`decimal` | `hexadecimal` | `binary`) [`int_type`] decimal: `decdigit` (`decdigit` |"_")* hexadecimal: 0 ("x" | "X") `hexdigit` (`hexdigit` |"_")* binary: 0 ("b" | "B") `bindigit` (`bindigit` | "_")* .. productionlist:: floatnumber: (`pointfloat` | `exponentfloat` | `hexadecimalfloat`) [`float_type`] pointfloat: [`intpart`] `fraction` exponentfloat: (`intpart` | `pointfloat`) `exponent` hexadecimalfloat: 0 ("x" | "X") `hexintpart` `hexfraction` ("p"|"P") ["+" | "-"] `decdigit`+ intpart: `decdigit` (`decdigit` |"_")* fraction: "." `decdigit` (`decdigit` |"_")* hexintpart: `hexdigit` (`hexdigit` | "_")* hexfraction: "." `hexdigit` (`hexdigit` |"_")* exponent: ("e" | "E") ["+" | "-"] `decdigit`+ .. productionlist:: decdigit: "0"..."9" hexdigit: `decdigit` | "a"..."f" | "A"..."F" bindigit: "0" | "1" Compound Types and Values ~~~~~~~~~~~~~~~~~~~~~~~~~ .. 
productionlist:: type: `qualname` : | `array_type` : | `tuple_type` : | `record_type` : | `sum_type` : | `function_type` : | `type_application` : | `existential_size` Compound types can be constructed based on the primitive types. The Futhark type system is entirely structural, and type abbreviations are merely shorthands. The only exception is abstract types whose definition has been hidden via the module system (see :ref:`module-system`). .. productionlist:: tuple_type: "(" ")" | "(" `type` ("," `type`)+ [","] ")" A tuple value or type is written as a sequence of comma-separated values or types enclosed in parentheses. For example, ``(0, 1)`` is a tuple value of type ``(i32,i32)``. The elements of a tuple need not have the same type -- the value ``(false, 1, 2.0)`` is of type ``(bool, i32, f64)``. A tuple element can also be another tuple, as in ``((1,2),(3,4))``, which is of type ``((i32,i32),(i32,i32))``. A tuple cannot have just one element, but empty tuples are permitted, although they are not very useful. Empty tuples are written ``()`` and are of type ``()``. .. productionlist:: array_type: "[" [`exp`] "]" `type` An array value is written as a sequence of zero or more comma-separated values enclosed in square brackets: ``[1,2,3]``. An array type is written as ``[d]t``, where ``t`` is the element type of the array, and ``d`` is an expression of type ``i64`` indicating the number of elements in the array. We can elide ``d`` and write just ``[]`` (an :term:`anonymous size`), in which case the size will be inferred. An anonymous size is a syntactic shorthand, and is always replaced by an actual size by the type checker (either via inference or by inventing a new name, depending on context). As an example, an array of three integers could be written as ``[1,2,3]``, and has type ``[3]i32``. An empty array is written as ``[]``, and its type is inferred from its use. 
When writing Futhark values for such uses as ``futhark test`` (but not when writing programs), empty arrays are written ``empty([0]t)`` for an empty array of type ``[0]t``. When using ``empty``, all dimensions must be given a size, and at least one must be zero, e.g. ``empty([2][0]i32)``. Multi-dimensional arrays are supported in Futhark, but they must be *regular*, meaning that all inner arrays must have the same shape. For example, ``[[1,2], [3,4], [5,6]]`` is a valid array of type ``[3][2]i32``, but ``[[1,2], [3,4,5], [6,7]]`` is not, because there we cannot come up with integers ``m`` and ``n`` such that ``[m][n]i32`` describes the array. The restriction to regular arrays is rooted in low-level concerns about efficient compilation. However, we can understand it in language terms by the inability to write a type with consistent dimension sizes for an irregular array value. In a Futhark program, all array values, including intermediate (unnamed) arrays, must be typeable. .. productionlist:: sum_type: `constructor` `type`* ("|" `constructor` `type`*)* Sum types are anonymous in Futhark, and are written as the constructors separated by vertical bars. Each constructor consists of a ``#``-prefixed *name*, followed by zero or more types, called its *payload*. **Note:** The current implementation of sum types is fairly inefficient, in that all possible constructors of a sum-typed value will be resident in memory. Avoid using sum types where multiple constructors have large payloads. .. productionlist:: record_type: "{" "}" | "{" `fieldid` ":" `type` ("," `fieldid` ":" `type`)* [","] "}" Records are mappings from field names to values, with the field names known statically. A tuple behaves in all respects like a record with numeric field names starting from zero, and vice versa. It is an error for a record type to name the same field twice. A trailing comma is permitted. .. 
productionlist:: type_application: `type` `type_arg` | "*" `type` type_arg: "[" [`dim`] "]" | `type` A parametric type abbreviation can be applied by juxtaposing its name and its arguments. The application must provide as many arguments as the type abbreviation has parameters - partial application is presently not allowed. See `Type Abbreviations`_ for further details. .. productionlist:: function_type: `param_type` "->" `type` param_type: `type` | "(" `name` ":" `type` ")" Functions are classified via function types, but they are not fully first class. See :ref:`hofs` for the details. .. productionlist:: stringlit: '"' `stringchar`* '"' stringchar: charlit: "'" `char` "'" char: String literals are supported, but only as syntactic sugar for UTF-8 encoded arrays of ``u8`` values. There is no character type in Futhark, but character literals are interpreted as integers of the corresponding Unicode code point. .. productionlist:: existential_size: "?" ("[" `name` "]")+ "." `type` An existential size quantifier brings an unknown size into scope within a type. This can be used to encode constraints for statically unknown array sizes. Declarations ------------ A Futhark module consists of a sequence of declarations. Files are also modules. Each declaration is processed in order, and a declaration can only refer to names bound by preceding declarations. .. productionlist:: dec: `val_bind` | `type_bind` | `mod_bind` | `mod_type_bind` : | "open" `mod_exp` : | "import" `stringlit` : | "local" `dec` : | "#[" `attr` "]" `dec` Any names defined by a declaration inside a module are by default visible to users of that module (see :ref:`module-system`). * ``open mod_exp`` brings names bound in ``mod_exp`` into the current scope. These names will also be visible to users of the module. * ``local dec`` has the meaning of ``dec``, but any names bound by ``dec`` will not be visible outside the module. 
* ``import "foo"`` is a shorthand for ``local open import "foo"``, where the ``import`` is interpreted as a module expression (see :ref:`module-system`). * ``#[attr] dec`` adds an attribute to a declaration (see :ref:`attributes`). Declaring Functions and Values ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. productionlist:: val_bind: ("def" | "entry" | "let") (`name` | "(" `symbol` ")") `type_param`* `pat`* [":" `type`] "=" `exp` : | ("def" | "entry" | "let") `pat` `symbol` `pat` [":" `type`] "=" `exp` **Note:** using ``let`` to define top-level bindings is deprecated. Functions and constants must be defined before they are used. A function declaration must specify the name, parameters, and body of the function:: def name params...: rettype = body Hindley-Milner-style type inference is supported. A parameter may be given a type with the notation ``(name: type)``. Functions may not be recursive. The sizes of the arguments can be constrained - see `Size Types`_. A function can be *polymorphic* by using type parameters, in the same way as for `Type Abbreviations`_:: def reverse [n] 't (xs: [n]t): [n]t = xs[::-1] Type parameters for a function do not need to cover the types of all parameters. The type checker will add more if necessary. For example, the following is well typed:: def pair 'a (x: a) y = (x, y) A new type variable will be invented for the parameter ``y``. Shape and type parameters are not passed explicitly when calling a function, but are automatically derived. If an array value *v* is passed for a type parameter *t*, all other arguments passed of type *t* must have the same shape as *v*. For example, consider the following definition:: def pair 't (x: t) (y: t) = (x, y) The application ``pair [1] [2,3]`` is ill-typed. To simplify the handling of in-place updates (see :ref:`in-place-updates`), the value returned by a function may not alias any global variables.
User-Defined Operators ~~~~~~~~~~~~~~~~~~~~~~ Infix operators are defined much like functions:: def (p1: t1) op (p2: t2): rt = ... For example:: def (a:i32,b:i32) +^ (c:i32,d:i32) = (a+c, b+d) We can also define operators by enclosing the operator name in parentheses and suffixing the parameters, as an ordinary function:: def (+^) (a:i32,b:i32) (c:i32,d:i32) = (a+c, b+d) This is necessary when defining a polymorphic operator. A valid operator name is a non-empty sequence of characters chosen from the string ``"+-*/%=!><|&^"``. The fixity of an operator is determined by its first characters, which must correspond to a built-in operator. Thus, ``+^`` binds like ``+``, whilst ``*^`` binds like ``*``. The longest such prefix is used to determine fixity, so ``>>=`` binds like ``>>``, not like ``>``. It is not permitted to define operators with the names ``&&`` or ``||`` (although they are accepted as prefixes of longer operator names). This is because a user-defined version of these operators would not be short-circuiting. User-defined operators behave exactly like ordinary functions, except for being infix. A built-in operator can be shadowed (i.e. a new ``+`` can be defined). This will result in the built-in polymorphic operator becoming inaccessible, except through the ``intrinsics`` module. An infix operator can also be defined with prefix notation, like an ordinary function, by enclosing it in parentheses:: def (+) (x: i32) (y: i32) = x - y This is necessary when defining operators that take type or shape parameters. .. _entry-points: Entry Points ~~~~~~~~~~~~ Apart from declaring a function with the keyword ``def``, it can also be declared with ``entry``. When the Futhark program is compiled any top-level function declared with ``entry`` will be exposed as an entry point. If the Futhark program has been compiled as a library, these are the functions that will be exposed.
If compiled as an executable, you can use the ``--entry-point`` command line option of the generated executable to select the entry point you wish to run. Any top-level function named ``main`` will always be considered an entry point, whether it is declared with ``entry`` or not. The name of an entry point must not contain an apostrophe (``'``), even though that is normally permitted in Futhark identifiers. Value Declarations ~~~~~~~~~~~~~~~~~~ A named value/constant can be declared as follows:: def name: type = definition The definition can be an arbitrary expression, including function calls and other values, although they must be in scope before the value is defined. If the return type contains any anonymous sizes (see `Size types`_), new existential sizes will be constructed for them. .. _typeabbrevs: Type Abbreviations ~~~~~~~~~~~~~~~~~~ .. productionlist:: type_bind: ("type" | "type^" | "type~") `name` `type_param`* "=" `type` type_param: "[" `name` "]" | "'" `name` | "'~" `name` | "'^" `name` Type abbreviations function as shorthands for the purpose of documentation or brevity. After a type binding ``type t1 = t2``, the name ``t1`` can be used as a shorthand for the type ``t2``. Type abbreviations do not create distinct types: the types ``t1`` and ``t2`` are entirely interchangeable. If the right-hand side of a type contains existential sizes, it must be declared "size-lifted" with ``type~``. If it (potentially) contains a function, it must be declared "fully lifted" with ``type^``. A lifted type can also contain existential sizes. Lifted types cannot be put in arrays. Fully lifted types cannot be returned from conditional or loop expressions. A type abbreviation can have zero or more parameters. A type parameter enclosed with square brackets is a *size parameter*, and can be used in the definition as an array size, or as a size argument to other type abbreviations. When passing an argument for a shape parameter, it must be enclosed in square brackets. 
Example:: type two_intvecs [n] = ([n]i32, [n]i32) def x: two_intvecs [2] = (iota 2, replicate 2 0) When referencing a type abbreviation, size parameters work much like array sizes. Like sizes, they can be passed an anonymous size (``[]``). All size parameters must be used in the definition of the type abbreviation. A type parameter prefixed with a single quote is a *type parameter*. It is in scope as a type in the definition of the type abbreviation. Whenever the type abbreviation is used in a type expression, a type argument must be passed for the parameter. Type arguments need not be prefixed with single quotes:: type two_vecs [n] 't = ([n]t, [n]t) type two_intvecs [n] = two_vecs [n] i32 def x: two_vecs [2] i32 = (iota 2, replicate 2 0) A *size-lifted type parameter* is prefixed with ``'~``, and a *fully lifted type parameter* with ``'^``. These have the same rules and restrictions as lifted type abbreviations. Expressions ----------- Expressions are the basic construct of any Futhark program. An expression has a statically determined *type*, and produces a *value* at runtime. Futhark is an eager/strict language ("call by value"). The basic elements of expressions are called *atoms*, for example literals and variables, but also more complicated forms. .. productionlist:: atom: `literal` : | `qualname` ("." `fieldid`)* : | `stringlit` : | `charlit` : | "(" ")" : | "(" `exp` ")" ("." `fieldid`)* : | "(" `exp` ("," `exp`)+ [","] ")" : | "{" "}" : | "{" `field` ("," `field`)* [","] "}" : | `qualname` `slice` : | "(" `exp` ")" `slice` : | `quals` "." "(" `exp` ")" : | "[" `exp` ("," `exp`)* [","] "]" : | "(" `qualsymbol` ")" : | "(" `exp` `qualsymbol` ")" : | "(" `qualsymbol` `exp` ")" : | "(" ( "." `field` )+ ")" : | "(" "." `slice` ")" : | "???" exp: `atom` : | `exp` `qualsymbol` `exp` : | `exp` `exp` : | "!" `exp` : | "-" `exp` : | `constructor` `exp`* : | `exp` ":" `type` : | `exp` ":>" `type` : | `exp` [ ".." `exp` ] "..." `exp` : | `exp` [ ".." 
`exp` ] "..<" `exp` : | `exp` [ ".." `exp` ] "..>" `exp` : | "if" `exp` "then" `exp` "else" `exp` : | "let" `size`* `pat` "=" `exp` "in" `exp` : | "let" `name` `slice` "=" `exp` "in" `exp` : | "let" `name` `type_param`* `pat`+ [":" `type`] "=" `exp` "in" `exp` : | "(" "\" `pat`+ [":" `type`] "->" `exp` ")" : | "loop" `pat` ["=" `exp`] `loopform` "do" `exp` : | "#[" `attr` "]" `exp` : | "unsafe" `exp` : | "assert" `atom` `atom` : | `exp` "with" `slice` "=" `exp` : | `exp` "with" `fieldid` ("." `fieldid`)* "=" `exp` : | "match" `exp` ("case" `pat` "->" `exp`)+ slice: "[" `index` ("," `index`)* [","] "]" field: `fieldid` "=" `exp` : | `name` size : "[" `name` "]" pat: `name` : | `pat_literal` : | "_" : | "(" ")" : | "(" `pat` ")" : | "(" `pat` ("," `pat`)+ [","] ")" : | "{" "}" : | "{" `fieldid` ["=" `pat`] ("," `fieldid` ["=" `pat`])* [","] "}" : | `constructor` `pat`* : | `pat` ":" `type` : | "#[" `attr` "]" `pat` pat_literal: [ "-" ] `intnumber` : | [ "-" ] `floatnumber` : | `charlit` : | "true" : | "false" loopform : "for" `name` "<" `exp` : | "for" `pat` "in" `exp` : | "while" `exp` index: `exp` [":" [`exp`]] [":" [`exp`]] : | [`exp`] ":" `exp` [":" [`exp`]] : | [`exp`] [":" `exp`] ":" [`exp`] Some of the built-in expression forms have parallel semantics, but it is not guaranteed that the parallel constructs in Futhark are evaluated in parallel, especially if they are nested in complicated ways. Their purpose is to give the compiler as much freedom and information as possible, in order to enable it to maximise the efficiency of the generated code. Resolving Ambiguities ~~~~~~~~~~~~~~~~~~~~~ The above grammar contains some ambiguities, which in the concrete implementation are resolved via a combination of lexer and grammar transformations. For ease of understanding, they are presented here in natural text. * An expression ``x.y`` may either be a reference to the name ``y`` in the module ``x``, or the field ``y`` in the record ``x``.
Modules and values occupy the same name space, so this is disambiguated by whether ``x`` is a value or a module. * A type ascription (``exp : type``) cannot appear as an array index, as it conflicts with the syntax for slicing. * In ``f [x]``, there is an ambiguity between indexing the array ``f`` at position ``x``, or calling the function ``f`` with the singleton array ``[x]``. We resolve this the following way: * If there is a space between ``f`` and the opening bracket, it is treated as a function application. * Otherwise, it is an array index operation. * An expression ``(-x)`` is parsed as the variable ``x`` negated and enclosed in parentheses, rather than an operator section partially applying the infix operator ``-``. * Prefix operators bind more tightly than infix operators. Note that the only prefix operators are the builtin ``!`` and ``-``, and more cannot be defined. In particular, a user-defined operator beginning with ``!`` binds as ``!=``, as shown in the table below, not as the prefix operator ``!``. * Function and type application bind more tightly than infix operators. * ``#foo #bar`` is interpreted as a constructor with a ``#bar`` payload, not as applying ``#foo`` to ``#bar`` (the latter would be semantically invalid anyway). * `Attributes`_ bind less tightly than any other syntactic construct. * A type application ``pt [n]t`` is parsed as an application of the type constructor ``pt`` to the size argument ``[n]`` and the type ``t``. To pass a single array-typed parameter, enclose it in parens. * The bodies of ``let``, ``if``, and ``loop`` extend as far to the right as possible. * The following table describes the precedence and associativity of infix operators in both expressions and type expressions. All operators in the same row have the same precedence. The rows are listed in increasing order of precedence. Note that not all operators listed here are used in expressions; nevertheless, they are still used for resolving ambiguities.
================= ============= **Associativity** **Operators** ================= ============= left ``,`` left ``:``, ``:>`` left ```symbol``` left ``||`` left ``&&`` left ``<=`` ``>=`` ``>`` ``<`` ``==`` ``!=`` ``!`` ``=`` left ``&`` ``^`` ``|`` left ``<<`` ``>>`` left ``+`` ``-`` left ``*`` ``/`` ``%`` ``//`` ``%%`` left ``|>`` right ``<|`` right ``->`` left ``**`` left juxtaposition ================= ============= .. _patterns: Patterns ~~~~~~~~ We say that a pattern is *irrefutable* if it can never fail to match a value of the appropriate type. Concretely, this means that it does not require any specific sum type constructor (unless the type in question has only a single constructor), or any specific numeric or boolean literal. Patterns used in function parameters and ``let`` bindings must be irrefutable. Patterns used in ``case`` need not be irrefutable. A pattern ``_`` matches any value. A pattern consisting of a literal value (e.g. a numeric constant) matches exactly that value. Semantics of Simple Expressions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ `literal` ......... Evaluates to itself. `qualname` .......... A variable name; evaluates to its value in the current environment. `stringlit` ........... Evaluates to an array of type ``[]u8`` that contains the characters encoded as UTF-8. ``()`` ...... Evaluates to an empty tuple. ``( e )`` ......... Evaluates to the result of ``e``. ``???`` ....... A *typed hole*, usable as a placeholder expression. The type checker will infer any necessary type for this expression. This can sometimes result in an ambiguous type, which can be resolved using a type ascription. Evaluating a typed hole results in a run-time error. ``(e1, e2, ..., eN)`` ..................... Evaluates to a tuple containing ``N`` values. Equivalent to the record literal ``{0=e1, 1=e2, ..., N-1=eN}``. ``{f1, f2, ..., fN}`` ..................... A record expression consists of a comma-separated sequence of *field expressions*. 
Each field expression defines the value of a field in the record. A field expression can take one of two forms: ``f = e``: defines a field with the name ``f`` and the value resulting from evaluating ``e``. ``f``: defines a field with the name ``f`` and the value of the variable ``f`` in scope. Each field may only be defined once. ``a[i]`` ........ Return the element at the given position in the array. The index may be a comma-separated list of indices instead of just a single index. If the number of indices given is less than the rank of the array, an array is returned. Each index must be of type ``i64``. The array ``a`` must be a variable name or a parenthesised expression. Furthermore, there *may not* be a space between ``a`` and the opening bracket. This disambiguates the array indexing ``a[i]``, from ``a [i]``, which is a function call with a literal array. .. _slices: ``a[i:j:s]`` ............ Return a slice of the array ``a`` from index ``i`` to ``j``, the former inclusive and the latter exclusive, taking every ``s``-th element. The ``s`` parameter may not be zero. If ``s`` is negative, it means to start at ``i`` and descend by steps of size ``s`` to ``j`` (not inclusive). Slicing can be done only with expressions of type ``i64``. It is generally a bad idea for ``s`` to be non-constant. Slicing of multiple dimensions can be done by separating with commas, and may be intermixed freely with indexing. If ``s`` is elided it defaults to ``1``. If ``i`` or ``j`` is elided, their value depends on the sign of ``s``. If ``s`` is positive, ``i`` becomes ``0`` and ``j`` becomes the length of the array. If ``s`` is negative, ``i`` becomes the length of the array minus one, and ``j`` becomes minus one. This means that ``a[::-1]`` is the reverse of the array ``a``. In the general case, the size of the array produced by a slice is unknown (see `Size types`_).
In a few cases, the size is known statically: * ``a[0:n]`` has size ``n`` * ``a[:n]`` has size ``n`` * ``a[0:n:1]`` has size ``n`` * ``a[:n:1]`` has size ``n`` This holds only if ``n`` is a variable or constant. ``[x, y, z]`` ............. Create an array containing the indicated elements. Each element must have the same type and shape. **Large array optimisation**: as a special case, large one-dimensional array literals consisting *entirely* of monomorphic constants (i.e., numbers must have a type suffix) are handled with specialised fast-path code by the compiler. To keep compile times manageable, make sure that all very large array literals (more than about ten thousand elements) are of this form. This is likely relevant only for generated code. .. _range: ``x..y...z`` ............ Construct an integer array whose first element is ``x`` and which proceeds with a stride of ``y-x`` until reaching ``z`` (inclusive). The ``..y`` part can be elided in which case a stride of ``1`` is used. All components must be of the same integer type. A run-time error occurs if ``z`` is less than ``x`` or ``y``, or if ``x`` and ``y`` are the same value. In the general case, the size of the array produced by a range is unknown (see `Size types`_). In a few cases, the size is known statically: * ``0..<n`` has size ``n`` * ``0..1..<n`` has size ``n`` * ``1..2...n`` has size ``n`` This holds only if ``n`` is a variable or constant. ``x..y..<z`` ............ Construct an integer array whose first element is ``x``, and which proceeds upwards with a stride of ``y-x`` until reaching ``z`` (exclusive). The ``..y`` part can be elided in which case a stride of ``1`` is used. A run-time error occurs if ``z`` is less than ``x`` or ``y``, or if ``x`` and ``y`` are the same value. ``x..y..>z`` ............ Construct an integer array whose first element is ``x``, and which proceeds downwards with a stride of ``y-x`` until reaching ``z`` (exclusive). The ``..y`` part can be elided in which case a stride of ``-1`` is used. A run-time error occurs if ``z`` is greater than ``x`` or ``y``, or if ``x`` and ``y`` are the same value. ``e.f`` ........ Access field ``f`` of the expression ``e``, which must be a record or tuple. ``m.(e)`` ......... Evaluate the expression ``e`` with the module ``m`` locally opened, as if by ``open``. This can make some expressions easier to read and write, without polluting the global scope with a declaration-level ``open``.
``x`` *binop* ``y`` ................... Apply an operator to ``x`` and ``y``. Operators are functions like any other, and can be user-defined. Futhark pre-defines certain "magical" *overloaded* operators that work on several types. Overloaded operators cannot be defined by the user. Both operands must have the same type. The predefined operators and their semantics are: ``**`` Power operator, defined for all numeric types. ``//``, ``%%`` Division and remainder on integers, with rounding towards zero. ``*``, ``/``, ``%``, ``+``, ``-`` The usual arithmetic operators, defined for all numeric types. Note that ``/`` and ``%`` round towards negative infinity when used on integers - this is different from in C. ``^``, ``&``, ``|``, ``>>``, ``<<``, ``>>>`` Bitwise operators, respectively bitwise xor, and, or, arithmetic shift right and left, and logical shift right. **Shifting is undefined if the right operand is negative, or greater than or equal to the length in bits of the left operand.** Note that, unlike in C, bitwise operators have *higher* priority than comparison operators. This means that ``x & y == z`` is understood as ``(x & y) == z``, rather than ``x & (y == z)`` as it would in C. Note that the latter is a type error in Futhark anyhow. ``==``, ``!=`` Compare any two values of builtin or compound type for equality. ``<``, ``<=``, ``>``, ``>=`` Compare any two values of numeric type for ordering. ```qualname``` Use ``qualname``, which may be any non-operator function name, as an infix operator. ``x && y`` .......... Short-circuiting logical conjunction; both operands must be of type ``bool``. ``x || y`` .......... Short-circuiting logical disjunction; both operands must be of type ``bool``. ``f x`` ....... Apply the function ``f`` to the argument ``x``. ``#c x y z`` ............ Apply the sum type constructor ``#c`` to the payload ``x``, ``y``, and ``z``. A constructor application is always assumed to be saturated, i.e. its entire payload provided. 
This means that constructors may not be partially applied. ``e : t`` ......... Annotate that ``e`` is expected to be of type ``t``, failing with a type error if it is not. If ``t`` is an array with shape declarations, the correctness of the shape declarations is checked at run-time. Due to ambiguities, this syntactic form cannot appear as an array index expression unless it is first enclosed in parentheses. However, as an array index must always be of type ``i64``, there is never a reason to put an explicit type ascription there. ``e :> t`` .......... Coerce the size of ``e`` to ``t``. The type of ``t`` must match the type of ``e``, except that the sizes may be statically different. At run-time, it will be verified that the sizes are the same. ``! x`` ....... Logical negation if ``x`` is of type ``bool``. Bitwise negation if ``x`` is of integral type. ``- x`` ....... Numerical negation of ``x``, which must be of numeric type. ``#[attr] e`` ............. Apply the given attribute to the expression. Attributes are an ad-hoc and optional mechanism for providing extra information, directives, or hints to the compiler. See :ref:`attributes` for more information. ``unsafe e`` ............ Elide safety checks and assertions (such as bounds checking) that occur during execution of ``e``. This is useful if the compiler is otherwise unable to avoid bounds checks (e.g. when using indirect indexes), but you really do not want them there. Make very sure that the code is correct; eliding such checks can lead to memory corruption. This construct is deprecated. Use the ``#[unsafe]`` attribute instead. .. _assert: ``assert cond e`` ................. Terminate execution with an error if ``cond`` evaluates to false, otherwise produce the result of evaluating ``e``. Unless ``e`` produces a value that is used subsequently (it can just be a variable), dead code elimination may remove the assertion. ``a with [i] = e`` ................... 
Return ``a``, but with the element at position ``i`` changed to contain the result of evaluating ``e``. Consumes ``a``. .. _record_update: ``r with f = e`` ................. Return the record ``r``, but with field ``f`` changed to have value ``e``. The type of the field must remain unchanged. Type inference is limited: ``r`` must have a *completely known type* up to ``f``. This sometimes requires extra type annotations to make the type of ``r`` known. ``if c then a else b`` ...................... If ``c`` evaluates to ``true``, evaluate ``a``, else evaluate ``b``. Binding Expressions ~~~~~~~~~~~~~~~~~~~ ``let pat = e in body`` ....................... Evaluate ``e`` and bind the result to the irrefutable pattern ``pat`` (see :ref:`patterns`) while evaluating ``body``. The ``in`` keyword is optional if ``body`` is a ``let`` expression. The binding is not let-generalised, meaning it has a monomorphic type. This can be significant if ``e`` is of functional type. If ``e`` is of type ``i64`` and ``pat`` binds only a single name ``v``, then the type of the overall expression is the type of ``body``, but with any occurrence of ``v`` replaced by ``e``. ``let [n] pat = e in body`` ........................... As above, but bind sizes (here ``n``) used in the pattern (here to the size of the array being bound). All sizes must be used in the pattern. Roughly equivalent to ``let f [n] pat = body in f e``. ``let a[i] = v in body`` ........................ Write ``v`` to ``a[i]`` and evaluate ``body``. The given index need not be complete and can also be a slice, but in these cases, the value of ``v`` must be an array of the proper size. This notation is syntactic sugar for ``let a = a with [i] = v in body``. ``let f params... = e in body`` ............................... Bind ``f`` to a function with the given parameters and definition (``e``) and evaluate ``body``. The function will be treated as aliasing any free variables in ``e``. 
The function is not in scope of itself, and hence cannot be recursive. ``loop pat = initial for x in a do loopbody`` ............................................. 1. Bind ``pat`` to the initial values given in ``initial``. 2. For each element ``x`` in ``a``, evaluate ``loopbody`` and rebind ``pat`` to the result of the evaluation. 3. Return the final value of ``pat``. The ``= initial`` can be left out, in which case initial values for the pattern are taken from equivalently named variables in the environment. I.e., ``loop x for i < n do ...`` is equivalent to ``loop x = x for i < n do ...``. ``loop pat = initial for x < n do loopbody`` ............................................ Equivalent to ``loop (pat = initial) for x in [0..1..<n] do loopbody``. ``loop pat = initial while cond do loopbody`` ............................................. 1. Bind ``pat`` to the initial values given in ``initial``. 2. If ``cond`` evaluates to true, bind ``pat`` to the result of evaluating ``loopbody``, and repeat the step. 3. Return the final value of ``pat``. ``match x case p1 -> e1 case p2 -> e2`` ....................................... Match the value produced by ``x`` to each of the patterns in turn, picking the first one that succeeds. The result of the corresponding expression is the value of the entire ``match`` expression. All the expressions associated with a ``case`` must have the same type (but not necessarily match the type of ``x``). It is a type error if there is not a ``case`` for every possible value of ``x`` - inexhaustive pattern matching is not allowed. Function Expressions ~~~~~~~~~~~~~~~~~~~~ ``\x y z: t -> e`` .................. Produces an anonymous function taking parameters ``x``, ``y``, and ``z``, whose return type is ``t``, and whose body is ``e``. Lambdas do not permit type parameters; use a named function if you want a polymorphic function. ``(binop)`` ........... An *operator section* that is equivalent to ``\x y -> x *binop* y``. ``(x binop)`` ............. An *operator section* that is equivalent to ``\y -> x *binop* y``. ``(binop y)`` ............. An *operator section* that is equivalent to ``\x -> x *binop* y``. ``(.a.b.c)`` ............ An *operator section* that is equivalent to ``\x -> x.a.b.c``. ``(.[i,j])`` ............ An *operator section* that is equivalent to ``\x -> x[i,j]``. .. 
_hofs: Higher-order functions ---------------------- At a high level, Futhark functions are values, and can be used like any other value. However, to ensure that the compiler is able to compile the higher-order functions efficiently via *defunctionalisation*, certain type-driven restrictions exist on how functions can be used. These also apply to any record or tuple containing a function (a *functional type*): * Arrays of functions are not permitted. * A function cannot be returned from an ``if`` expression. * A ``loop`` parameter cannot be a function. Further, *type parameters* are divided into *non-lifted* (bound with an apostrophe, e.g. ``'t``), *size-lifted* (``'~t``), and *fully lifted* (``'^t``). Only fully lifted type parameters may be instantiated with a functional type. Within a function, a lifted type parameter is treated as a functional type. See also `In-place updates`_ for details on how consumption interacts with higher-order functions. Type Inference -------------- Futhark supports Hindley-Milner-style type inference, so in many cases explicit type annotations can be left off. Record field projection cannot in isolation be fully inferred, and may need type annotations where their inputs are bound. The same goes when constructing sum types, as Futhark cannot assume that a given constructor only belongs to a single type. Further, consumed parameters (see `In-place updates`_) must be explicitly annotated. Type inference processes top-level declarations in top-down order, and the type of a top-level function must be completely inferred at its definition site. Specifically, if a top-level function uses overloaded arithmetic operators, the resolution of those overloads cannot be influenced by later uses of the function. Local bindings made with ``let`` are not made polymorphic through let-generalisation *unless* they are syntactically functions, meaning they have at least one named parameter. .. 
_size-types: Size Types ---------- Futhark supports a system of size-dependent types that statically checks that the sizes of arrays passed to a function are compatible. Whenever a pattern occurs (in ``let``, ``loop``, and function parameters), as well as in return types, the types of the bindings express invariants about the shapes of arrays that are accepted or produced by the function. For example:: def f [n] (a: [n]i32) (b: [n]i32): [n]i32 = map2 (+) a b We use a *size parameter*, ``[n]``, to explicitly quantify a size. The ``[n]`` parameter is not explicitly passed when calling ``f``. Rather, its value is implicitly deduced from the arguments passed for the value parameters. An array type can contain *anonymous sizes*, e.g. ``[]i32``, for which the type checker will invent fresh size parameters, which ensures that all arrays have a size. On the right-hand side of a function arrow ("return types"), this results in an *existential size* that is not known until the function is fully applied, e.g:: val filter [n] 'a : (p: a -> bool) -> (as: [n]a) -> ?[k].[k]a Sizes can be any expression of type ``i64`` that does not consume any free variables. Size parameters can be used as ordinary variables of type ``i64`` within the scope of the parameters. The type checker verifies that the program obeys any constraints imposed by size annotations. *Size-dependent types* are supported, as the names of parameters can be used in the return type of a function:: def replicate 't (n: i64) (x: t): [n]t = ... An application ``replicate 10 0`` will have type ``[10]i32``. Whenever we write a type ``[e]t``, ``e`` must be a well-typed expression of type ``i64`` in scope (possibly by referencing names bound as a size parameter). .. _unknown-sizes: Unknown sizes ~~~~~~~~~~~~~ There are cases where the type checker cannot assign a precise size to the result of some operation. 
For example, the type of ``filter`` is:: val filter [n] 'a : (a -> bool) -> [n]a -> ?[m].[m]a The function returns an array of *some existential size* ``m``, but it cannot be known in advance. When an application ``filter p xs`` is found, the result will be of type ``[k]a``, where ``k`` is a fresh *unknown size* that is considered distinct from every other size in the program. It is sometimes necessary to perform a size coercion (see `Size coercion`_) to convert an unknown size to a known size. Generally, unknown sizes are constructed whenever the true size cannot be expressed. The following lists all possible sources of unknown sizes. Size going out of scope ....................... An unknown size is created in some cases when a type references a name that has gone out of scope:: match ... case #some c -> replicate c 0 The type of ``replicate c 0`` is ``[c]i32``, but since ``c`` is locally bound, the type of the entire expression is ``[k]i32`` for some fresh ``k``. Consuming expression passed as function argument ................................................ The type of ``replicate e 0`` should be ``[e]i32``, but if ``e`` is an expression that is not valid as a size, this is not expressible. Therefore an unknown size ``k`` is created and the size of the expression becomes ``[k]i32``. Compound expression used as range bound ....................................... While a simple range expression such as ``0..<n`` can be assigned type ``[n]i64``, a range expression whose bound is a compound expression, such as ``0..<(n+1)``, may instead be given an unknown size. Complex ranges .............. Most complex ranges, such as ``a..<b``, will have an unknown size. The exceptions are the special cases listed for :ref:`ranges <range>`. Existential size in function return type ........................................ Whenever the result of a function application has an existential size, that size is replaced with a fresh unknown size variable. 
For example, ``filter`` has the following type:: val filter [n] 'a : (p: a -> bool) -> (as: [n]a) -> ?[k].[k]a For an application ``filter f xs``, the type checker invents a fresh unknown size ``k'``, and the actual type for this specific application will be ``[k']a``. Branches of ``if`` return arrays of different sizes ................................................... When an ``if`` (or ``match``) expression has branches that return arrays of different sizes, the differing sizes will be replaced with fresh unknown sizes. For example:: if b then [[1,2], [3,4]] else [[5,6]] This expression will have type ``[k][2]i32``, for some fresh ``k``. **Important:** The check whether the sizes differ is done when first encountering the ``if`` or ``match`` during type checking. At this point, the type checker may not realise that the two sizes are actually equal, even though constraints later in the function force them to be. This can always be resolved by adding type annotations. An array produced by a loop does not have a known size ...................................................... If the size of some loop parameter is not maintained across a loop iteration, the final result of the loop will contain unknown sizes. For example:: loop xs = [1] for i < n do xs ++ xs Similar to conditionals, the type checker may sometimes be too cautious in assuming that some size may change during the loop. Adding type annotations to the loop parameter can be used to resolve this. .. _size-coercion: Size coercion ~~~~~~~~~~~~~ Size coercion, written with ``:>``, can be used to perform a runtime-checked coercion of one size to another. This can be useful as an escape hatch in the size type system:: def concat_to 'a (m: i64) (a: []a) (b: []a) : [m]a = a ++ b :> [m]a .. _causality: Causality restriction ~~~~~~~~~~~~~~~~~~~~~ Conceptually, size parameters are assigned their value by reading the sizes of concrete values passed along as parameters. 
This means that any size parameter must be used as the size of some parameter. This is an error:: def f [n] (x: i32) = n The following is not an error:: def f [n] (g: [n]i32 -> [n]i32) = ... However, using this function comes with a constraint: whenever an application ``f x`` occurs, the value of the size parameter must be inferable. Specifically, this value must have been used as the size of an array *before* the ``f x`` application is encountered. The notion of "before" is subtle, as there is no evaluation ordering of a Futhark expression, *except* that a ``let``-binding is always evaluated before its body, the argument to a function is always evaluated before the function itself, and the left operand to an operator is evaluated before the right. The causality restriction only occurs when a function has size parameters whose first use is *not* as a concrete array size. For example, it does not apply to uses of the following function:: def f [n] (arr: [n]i32) (g: [n]i32 -> [n]i32) = ... This is because the proper value of ``n`` can be read directly from the actual size of the array. Empty array literals ~~~~~~~~~~~~~~~~~~~~ Just as with size-polymorphic functions, when constructing an empty array, we must know the exact size of the (missing) elements. For example, in the following program we are forcing the elements of ``a`` to be the same as the elements of ``b``, but the size of the elements of ``b`` are not known at the time ``a`` is constructed:: def main (b: bool) (xs: []i32) = let a = [] : [][]i32 let b = [filter (>0) xs] in a[0] == b[0] The result is a type error. Sum types ~~~~~~~~~ When constructing a value of a sum type, the compiler must still be able to determine the size of the constructors that are *not* used. This is illegal:: type sum = #foo ([]i32) | #bar ([]i32) def main (xs: *[]i32) = let v : sum = #foo xs in xs Modules ~~~~~~~ When matching a module with a module type (see :ref:`module-system`), a non-lifted abstract type (i.e. 
one that is declared with ``type`` rather than ``type^``) may not be implemented by a type abbreviation that contains any existential sizes. This is to ensure that if we have the following:: module m : { type t } = ... Then we can construct an array of values of type ``m.t`` without worrying about constructing an irregular array. Higher-order functions ~~~~~~~~~~~~~~~~~~~~~~ When a higher-order function takes a functional argument whose return type is a non-lifted type parameter, any instantiation of that type parameter must have a non-existential size. If the return type is a lifted type parameter, then the instantiation may contain existential sizes. This is why the type of ``map`` guarantees regular arrays:: val map [n] 'a 'b : (a -> b) -> [n]a -> [n]b The type parameter ``b`` can only be replaced with a type that has non-existential sizes, which means they must be the same for every application of the function. In contrast, this is the type of the pipeline operator:: val (|>) '^a '^b : a -> (a -> b) -> b The provided function can return something with an existential size (such as ``filter``). A function whose return type has an unknown size ................................................ If a function (named or anonymous) is inferred to have a return type that contains an unknown size variable created *within* the function body, that size variable will be replaced with an existential size. In most cases this is not important, but it means that an expression like the following is ill-typed:: map (\xs -> iota (length xs)) (xss : [n][m]i32) This is because the ``(length xs)`` expression gives rise to some fresh size ``k``. The lambda is then assigned the type ``[m]i32 -> [k]i64``, which is immediately turned into ``[m]i32 -> ?[k].[k]i64`` because ``k`` was generated inside its body. A function of this type cannot be passed to ``map``, as explained before. The solution is to bind ``length`` to a name *before* the lambda. 
_in-place-updates: In-place Updates ---------------- In-place updates do not provide observable side effects, but they do provide a way to efficiently update an array in-place, with the guarantee that the cost is proportional to the size of the value(s) being written, not the size of the full array. The ``a with [i] = v`` language construct, and derived forms, performs an in-place update. The compiler verifies that the original array (``a``) is not used on any execution path following the in-place update. This involves also checking that no *alias* of ``a`` is used. Generally, most language constructs produce new arrays, but some (slicing) create arrays that alias their input arrays. When defining a function parameter we can mark it as *consuming* by prefixing it with an asterisk. For a return type, we can mark it as *alias-free* by prefixing it with an asterisk. For example:: def modify (a: *[]i32) (i: i64) (x: i32): *[]i32 = a with [i] = a[i] + x A parameter that is not consuming is called *observing*. In the parameter declaration ``a: *[]i32``, the asterisk means that the function ``modify`` has been given "ownership" of the array ``a``, meaning that any caller of ``modify`` will never reference array ``a`` after the call again. This allows the ``with`` expression to perform an in-place update. After a call ``modify a i x``, neither ``a`` nor any variable that *aliases* ``a`` may be used on any following execution path. If an asterisk is present at *any point* inside a tuple parameter type, the parameter as a whole is considered consuming. For example:: def consumes_both ((a,b): (*[]i32,[]i32)) = ... This is usually not desirable behaviour. Use multiple parameters instead:: def consumes_first_arg (a: *[]i32) (b: []i32) = ... For bulk in-place updates with multiple values, use the ``scatter`` function from the `prelude <https://futhark-lang.org/docs/prelude/>`_. Alias Analysis ~~~~~~~~~~~~~~ The rules used by the Futhark compiler to determine aliasing are intuitive in the intra-procedural case. 
Aliases are associated with entire arrays. Aliases of a record or tuple are tracked for each element, not for the record or tuple itself. Most constructs produce fresh arrays, with no aliases. The main exceptions are ``if``, ``loop``, function calls, and variable literals. * After a binding ``let a = b``, that simply assigns a new name to an existing variable, the variable ``a`` aliases ``b``. Similarly for record projections and patterns. * The result of an ``if`` aliases the union of the aliases of the components. * The result of a ``loop`` aliases the initial values, as well as any aliases that the merge parameters may assume at the end of an iteration, computed to a fixed point. * The aliases of a value returned from a function are the most interesting case, and depend on whether the return value is declared *alias-free* (with an asterisk ``*``) or not. If it is declared alias-free, then it has no aliases. Otherwise, it aliases all arguments passed for *non-consumed* parameters. In-place Updates and Higher-Order Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Consumption generally interacts inflexibly with higher-order functions. The issue is that we cannot control how many times a function argument is applied, or to what, so it is not safe to pass a function that consumes its argument. The following two conservative rules govern the interaction between consumption and higher-order functions: 1. In the expression ``let p = e1 in ...``, if *any* in-place update takes place in the expression ``e1``, the value bound by ``p`` must not be or contain a function. 2. A function that consumes one of its arguments may not be passed as a higher-order argument to another function. .. _module-system: Modules ------- .. 
productionlist:: mod_bind: "module" `name` `mod_param`* [":" `mod_type_exp`] "=" `mod_exp` mod_param: "(" `name` ":" `mod_type_exp` ")" mod_type_bind: "module" "type" `name` "=" `mod_type_exp` Futhark supports an ML-style higher-order module system. *Modules* can contain types, functions, and other modules and module types. *Module types* are used to classify the contents of modules, and *parametric modules* are used to abstract over modules (essentially module-level functions). In Standard ML, modules, module types and parametric modules are called *structs*, *signatures*, and *functors*, respectively. Module names exist in the same name space as values, but module types have their own name space. Module bindings ~~~~~~~~~~~~~~~ ``module m = mod_exp`` ...................... Binds *m* to the module produced by the module expression ``mod_exp``. Any name ``x`` in the module produced by ``mod_exp`` can then be accessed with ``m.x``. ``module m : mod_type_exp = mod_exp`` ..................................... Shorthand for ``module m = mod_exp : mod_type_exp``. ``module m mod_params... = mod_exp`` .................................... Shorthand for ``module m = \mod_params... -> mod_exp``. This produces a parametric module. ``module type mt = mod_type_exp`` ................................. Binds *mt* to the module type produced by the module type expression ``mod_type_exp``. Module Expressions ~~~~~~~~~~~~~~~~~~ .. productionlist:: mod_exp: `qualname` : | `mod_exp` ":" `mod_type_exp` : | "\" "(" `mod_param`* ")" [":" `mod_type_exp`] "->" `mod_exp` : | `mod_exp` `mod_exp` : | "(" `mod_exp` ")" : | "{" `dec`* "}" : | "import" `stringlit` A module expression produces a module. Modules are collections of bindings produced by declarations (`dec`). In particular, a module may contain other modules or module types. ``qualname`` ............ Evaluates to the module of the given name. ``(mod_exp)`` ............. Evaluates to ``mod_exp``. 
``mod_exp : mod_type_exp`` .......................... *Module ascription* evaluates the module expression and the module type expression, verifies that the module implements the module type, then returns a module that exposes only the functionality described by the module type. This is how internal details of a module can be hidden. As a slightly ad-hoc limitation, ascription is forbidden when a type substitution of size-lifted types occurs in a size appearing at the top level. ``\(p: mt1): mt2 -> e`` ....................... Constructs a *parametric module* (a function at the module level) that accepts a parameter of module type ``mt1`` and returns a module of type ``mt2``. The result module type ``mt2`` is optional, but the parameter type ``mt1`` is not. ``e1 e2`` ......... Apply the parametric module ``e1`` to the module ``e2``. ``{ decs }`` ............ Returns a module that contains the given definitions. The resulting module defines any name defined by any declaration that is not ``local``, *in particular* including names made available via ``open``. ``import "foo"`` ................ Returns a module that contains the definitions of the file ``"foo"`` relative to the current file. Module Type Expressions ~~~~~~~~~~~~~~~~~~~~~~~ .. productionlist:: mod_type_exp: `qualname` : | "{" `spec`* "}" : | `mod_type_exp` "with" `qualname` `type_param`* "=" `type` : | "(" `mod_type_exp` ")" : | "(" `name` ":" `mod_type_exp` ")" "->" `mod_type_exp` : | `mod_type_exp` "->" `mod_type_exp` .. 
productionlist:: spec: "val" `name` `type_param`* ":" `type` : | "val" `symbol` `type_param`* ":" `type` : | ("type" | "type^" | "type~") `name` `type_param`* "=" `type` : | ("type" | "type^" | "type~") `name` `type_param`* : | "module" `name` ":" `mod_type_exp` : | "include" `mod_type_exp` : | "#[" `attr` "]" `spec` Module types classify modules, with the only (unimportant) difference in expressivity being that modules can contain module types, but module types cannot specify that a module must contain a specific module type. They can specify of course that a module contains a *submodule* of a specific module type. A module type expression can be the name of another module type, or a sequence of *specifications*, or *specs*, enclosed in curly braces. A spec can be a *value spec*, indicating the presence of a function or value, an *abstract type spec*, or a *type abbreviation spec*. In a value spec, sizes in types on the left-hand side of a function arrow must not be anonymous. For example, this is forbidden:: val sum: []t -> t Instead write:: val sum [n]: [n]t -> t But this is allowed, because the empty size is not to the left of a function arrow:: val evens [n]: [n]i32 -> []i32 .. _other-files: Referencing Other Files ----------------------- You can refer to external files in a Futhark file like this:: import "file" The above will include all non-``local`` top-level definitions from ``file.fut`` and make them available in the current file (but will not export them). The ``.fut`` extension is implied. You can also include files from subdirectories:: import "path/to/a/file" The above will include the file ``path/to/a/file.fut`` relative to the including file. Qualified imports are also possible, where a module is created for the file:: module M = import "file" In fact, a plain ``import "file"`` is equivalent to:: local open import "file" To re-export names from another file in the current module, use:: open import "file" .. _attributes: Attributes ---------- .. 
productionlist:: attr: `name` : | `decimal` : | `name` "(" [`attr` ("," `attr`)* [","]] ")" An expression, declaration, pattern, or module type spec can be prefixed with an attribute, written as ``#[attr]``. This may affect how it is treated by the compiler or other tools. In no case will attributes affect or change the *semantics* of a program, but it may affect how well it compiles and runs (or in some cases, whether it compiles or runs at all). Unknown attributes are silently ignored. Most have no effect in the interpreter. An attribute can be either an *atom*, written as an identifier or number, or *compound*, consisting of an identifier and a comma-separated sequence of attributes. The latter is used for grouping and encoding of more complex information. Expression attributes ~~~~~~~~~~~~~~~~~~~~~ Many expression attributes affect second-order array combinators (*SOACS*). These must be applied to a fully saturated function application or they will have no effect. If two SOACs with contradictory attributes are combined through fusion, it is unspecified which attributes take precedence. The following expression attributes are supported. ``trace`` ......... Print the value produced by the attributed expression. Used for debugging. Somewhat unreliable outside of the interpreter, and in particular does not work for GPU device code. ``trace(tag)`` .............. Like ``trace``, but prefix output with *tag*, which must lexically be an identifier. ``break`` ......... In the interpreter, pause execution *before* evaluating the expression. No effect for compiled code. ``opaque`` .......... The compiler will treat the attributed expression as a black box. This is used to work around optimisation deficiencies (or bugs), although it should hopefully rarely be necessary. ``incremental_flattening(no_outer)`` .................................... When using incremental flattening, do not generate the "only outer parallelism" version for the attributed SOACs. 
``incremental_flattening(no_intra)`` .................................... When using incremental flattening, do not generate the "intra-block parallelism" version for the attributed SOACs. ``incremental_flattening(only_intra)`` ...................................... When using incremental flattening, *only* generate the "intra-block parallelism" version of the attributed SOACs. **Beware**: the resulting program will fail to run if the inner parallelism does not fit on the device. ``incremental_flattening(only_inner)`` ...................................... When using incremental flattening, do not generate multiple versions for this SOAC, but do exploit inner parallelism (which may give rise to multiple versions at deeper levels). ``noinline`` ............ Do not inline the attributed function application. If used within a parallel construct (e.g. ``map``), this will likely prevent the GPU backends from generating working code. ``sequential`` .............. *Fully* sequentialise the attributed SOAC. ``sequential_outer`` .................... Turn the outer parallelism in the attributed SOAC sequential, but preserve any inner parallelism. ``sequential_inner`` .................... Exploit only outer parallelism in the attributed SOAC. ``unroll`` .......... Fully unroll the attributed ``loop``. If the compiler cannot determine the exact number of iterations (possibly after other optimisations and simplifications have taken place), then this attribute has no code generation effect, but instead results in a warning. Be very careful with this attribute: it can massively increase program size (possibly crashing the compiler) if the loop has a huge number of iterations. ``unsafe`` .......... Do not perform any dynamic safety checks (such as bound checks) during execution of the attributed expression. ``warn(safety_checks)`` ....................... 
Make the compiler issue a warning if the attributed expression (or its subexpressions) requires safety checks (such as bounds checking) at run-time. This is used for performance-critical code where you want to be told when the compiler is unable to statically verify the safety of all operations. Declaration attributes ~~~~~~~~~~~~~~~~~~~~~~ The following declaration attributes are supported. ``noinline`` ............ Do not inline any calls to this function. If the function is then used within a parallel construct (e.g. ``map``), this will likely prevent the GPU backends from generating working code. ``inline`` .......... Always inline calls to this function. Pattern attributes ~~~~~~~~~~~~~~~~~~ No pattern attributes are currently supported by the compiler itself, although they are syntactically permitted and may be used by other tools. Spec attributes ~~~~~~~~~~~~~~~ No spec attributes are currently supported by the compiler itself, although they are syntactically permitted and may be used by other tools. futhark-0.25.27/docs/man/000077500000000000000000000000001475065116200150455ustar00rootroot00000000000000futhark-0.25.27/docs/man/futhark-autotune.rst000066400000000000000000000041461475065116200211120ustar00rootroot00000000000000.. role:: ref(emphasis) .. _futhark-autotune(1): ================ futhark-autotune ================ SYNOPSIS ======== futhark autotune [options...] DESCRIPTION =========== ``futhark autotune`` attempts to find optimal values for threshold parameters given representative datasets. This is done by repeatedly running the program through :ref:`futhark-bench(1)` with different values for the threshold parameters. When ``futhark autotune`` finishes tuning a program ``foo.fut``, the results are written to ``foo.fut.tuning``, which will then automatically be picked up by subsequent uses of :ref:`futhark-bench(1)` and :ref:`futhark-test(1)`.
OPTIONS ======= --backend=name The backend used when compiling Futhark programs (without leading ``futhark``, e.g. just ``opencl``). --futhark=program The program used to perform operations (eg. compilation). Defaults to the binary running ``futhark autotune`` itself. --pass-option=opt Pass an option to programs that are being run. For example, we might want to run OpenCL programs on a specific device:: futhark autotune prog.fut --backend=opencl --pass-option=-dHawaii --runs=count The number of runs per data set. -v, --verbose Print verbose information about what the tuner is doing. Pass multiple times to increase the amount of information printed. --skip-compilation Do not run the compiler, and instead assume that the program has already been compiled. Use with caution. --spec-file=FILE Ignore the test specification in the program file(s), and instead load them from this other file. These external test specifications use the same syntax as normal, but *without* line comment prefixes. A ``==`` is still expected. --tuning=EXTENSION Change the extension used for tuning files (``.tuning`` by default). --timeout=seconds Initial tuning timeout for each dataset in seconds. After running the initial tuning run on each dataset, the timeout is based on the run time of that initial tuning. Defaults to 60. A negative timeout means to wait indefinitely. SEE ALSO ======== :ref:`futhark-bench(1)` futhark-0.25.27/docs/man/futhark-bench.rst000066400000000000000000000137661475065116200203350ustar00rootroot00000000000000.. role:: ref(emphasis) .. _futhark-bench(1): ============= futhark-bench ============= SYNOPSIS ======== futhark bench [options...] programs... DESCRIPTION =========== This tool is the recommended way to benchmark Futhark programs. Programs are compiled using the specified backend (``c`` by default), then run a number of times for each test case, and the arithmetic mean runtime and 95% confidence interval are printed on standard output.
Refer to :ref:`futhark-test(1)` for information on how to format test data. A program will be ignored if it contains no data sets - it will not even be compiled. If compilation of a program fails, then ``futhark bench`` will abort immediately. If execution of a test set fails, an error message will be printed and benchmarking will continue (and ``--json`` will write the file), but a non-zero exit code will be returned at the end. METHODOLOGY =========== For each program and dataset, ``futhark bench`` first does a single "warmup" run that is discarded. After that it uses a two-phase technique. 1. The *initial phase* performs ten runs (change with ``-r``), or performs runs for at least half a second, whichever takes longer. If the resulting measurements are sufficiently statistically robust (determined using standard deviation and autocorrelation metrics), the results are produced and the second phase is not entered. Otherwise, the results are discarded and the second phase is entered. 2. The *convergence phase* keeps performing runs until a measurement of sufficient statistical quality is reached. The notion of "sufficient statistical quality" is based on heuristics. The intent is that ``futhark bench`` will in most cases do *the right thing* by default, both when benchmarking long-running programs and short-running programs. If you want complete control, disable the convergence phase with ``--no-convergence-phase`` and set the number of runs you want with ``-r``. OPTIONS ======= --backend=name The backend used when compiling Futhark programs (without leading ``futhark``, e.g. just ``opencl``). --cache-extension=EXTENSION For a program ``foo.fut``, pass ``--cache-file foo.fut.EXTENSION``. By default, ``--cache-file`` is not passed. --concurrency=NUM The number of benchmark programs to prepare concurrently. Defaults to the number of cores available. *Prepare* means to compile the benchmark, as well as generate any needed datasets.
In some cases, this generation can take too much memory, in which case lowering ``--concurrency`` may help. --convergence-max-seconds=NUM Don't run the convergence phase for longer than this. Reaching this limit does not mean that the measurements have converged. Defaults to 300 seconds (five minutes). --entry-point=name Only run entry points with this name. --exclude-case=TAG Do not run test cases that contain the given tag. Cases marked with "nobench", "disable", or "no_foo" (where *foo* is the backend used) are ignored by default. --futhark=program The program used to perform operations (eg. compilation). Defaults to the binary running ``futhark bench`` itself. --ignore-files=REGEX Ignore files whose path matches the given regular expression. --json=file Write raw results in JSON format to the specified file. --no-tuning Do not look for tuning files. --no-convergence-phase Do not run the convergence phase. --pass-option=opt Pass an option to benchmark programs that are being run. For example, we might want to run OpenCL programs on a specific device:: futhark bench prog.fut --backend=opencl --pass-option=-dHawaii --pass-compiler-option=opt Pass an extra option to the compiler when compiling the programs. --profile Enable profiling for the binary (by passing ``--profiling`` and ``--logging``) and store the recorded information in the file indicated by ``--json`` (which is required), along with the other benchmarking results. --runner=program If set to a non-empty string, compiled programs are not run directly, but instead the indicated *program* is run with its first argument being the path to the compiled Futhark program. This is useful for compilation targets that cannot be executed directly (as with :ref:`futhark-pyopencl(1)` on some platforms), or when you wish to run the program on a remote machine. --runs=count The number of runs per data set.
--skip-compilation Do not run the compiler, and instead assume that each benchmark program has already been compiled into a server-mode executable. Use with caution. --spec-file=FILE Ignore the test specification in the program file(s), and instead load them from this other file. These external test specifications use the same syntax as normal, but *without* line comment prefixes. A ``==`` is still expected. --timeout=seconds If the runtime for a dataset exceeds this integral number of seconds, it is aborted. Note that the time is allotted not *per run*, but for *all runs* for a dataset. A twenty second limit for ten runs thus means each run has only two seconds (minus initialisation overhead). A negative timeout means to wait indefinitely. -v, --verbose Print verbose information about what the benchmark is doing. Pass multiple times to increase the amount of information printed. --tuning=EXTENSION For each program being run, look for a tuning file with this extension, which is suffixed to the name of the program. For example, given ``--tuning=tuning`` (the default), the program ``foo.fut`` will be passed the tuning file ``foo.fut.tuning`` if it exists. EXAMPLES ======== The following program benchmarks how quickly we can sum arrays of different sizes:: -- How quickly can we reduce arrays? -- -- == -- nobench input { 0i64 } -- output { 0i64 } -- input { 100i64 } -- output { 4950i64 } -- compiled input { 10000i64 } -- output { 49995000i64 } -- compiled input { 1000000i64 } -- output { 499999500000i64 } let main(n: i64): i64 = reduce (+) 0 (iota n) SEE ALSO ======== :ref:`futhark-c(1)`, :ref:`futhark-test(1)` futhark-0.25.27/docs/man/futhark-c.rst000066400000000000000000000072161475065116200174710ustar00rootroot00000000000000.. role:: ref(emphasis) .. _futhark-c(1): ========= futhark-c ========= SYNOPSIS ======== futhark c [options...] 
DESCRIPTION =========== ``futhark c`` translates a Futhark program to sequential C code, and either compiles that C code with a C compiler (see below) to an executable binary program, or produces a ``.h`` and ``.c`` file that can be linked with other code. The standard Futhark optimisation pipeline is used. The resulting program will read the arguments to the entry point (``main`` by default) from standard input and print its return value on standard output. The arguments are read and printed in Futhark syntax. OPTIONS ======= -h Print help text to standard output and exit. --entry-point NAME Treat this top-level function as an entry point. --library Generate a library instead of an executable. Appends ``.c``/``.h`` to the name indicated by the ``-o`` option to determine output file names. -o outfile Where to write the result. If the source program is named ``foo.fut``, this defaults to ``foo``. --safe Ignore ``unsafe`` in program and perform safety checks unconditionally. --server Generate a server-mode executable that reads commands from stdin. -v, --verbose Enable debugging output. If compilation fails due to a compiler error, the result of the last successful compiler step will be printed to standard error. -V Print version information on standard output and exit. -W Do not print any warnings. --Werror Treat warnings as errors. ENVIRONMENT VARIABLES ===================== ``CC`` The C compiler used to compile the program. Defaults to ``cc`` if unset. ``CFLAGS`` Space-separated list of options passed to the C compiler. Defaults to ``-O3 -std=c99`` if unset. EXECUTABLE OPTIONS ================== The following options are accepted by executables generated by ``futhark c``. -h, --help Print help text to standard output and exit. -b, --binary-output Print the program result in the binary output format. The default is human-readable text, which is very slow. Not accepted by server-mode executables.
--cache-file=FILE Store any reusable initialisation data in this file, possibly speeding up subsequent launches. -D, --debugging Perform possibly expensive internal correctness checks and verbose logging. Implies ``-L``. -e, --entry-point=FUN The entry point to run. Defaults to ``main``. Not accepted by server-mode executables. -L, --log Print various low-overhead logging information to stderr while running. -n, --no-print-result Do not print the program result. Not accepted by server-mode executables. -P, --profile Gather profiling data during execution. Mostly interesting in ``--server`` mode. Implied by ``-D``. --param=ASSIGNMENT Set a tuning parameter to the given value. ``ASSIGNMENT`` must be of the form ``NAME=INT``. Use ``--print-params`` to see which names are available. --print-params Print all tuning parameters that can be set with ``--param`` or ``--tuning``. -r, --runs=NUM Perform NUM runs of the program. With ``-t``, the runtime for each individual run will be printed. Additionally, a single leading warmup run will be performed (not counted). Only the final run will have its result written to stdout. Not accepted by server-mode executables. -t, --write-runtime-to=FILE Print the time taken to execute the program to the indicated file, an integral number of microseconds. Not accepted by server-mode executables. --tuning=FILE Read size=value assignments from the given file. SEE ALSO ======== :ref:`futhark-opencl(1)`, :ref:`futhark-cuda(1)`, :ref:`futhark-test(1)` futhark-0.25.27/docs/man/futhark-cuda.rst000066400000000000000000000067351475065116200201700ustar00rootroot00000000000000.. role:: ref(emphasis) .. _futhark-cuda(1): ============== futhark-cuda ============== SYNOPSIS ======== futhark cuda [options...]
DESCRIPTION =========== ``futhark cuda`` translates a Futhark program to C code invoking CUDA kernels, and either compiles that C code with a C compiler to an executable binary program, or produces a ``.h`` and ``.c`` file that can be linked with other code. The standard Futhark optimisation pipeline is used. ``futhark cuda`` uses ``-lcuda -lcudart -lnvrtc`` to link. If using ``--library``, you will need to do the same when linking the final binary. The generated CUDA code can be called from multiple CPU threads, as it brackets every API operation with ``cuCtxPushCurrent()`` and ``cuCtxPopCurrent()``. OPTIONS ======= Accepts the same options as :ref:`futhark-c(1)`. ENVIRONMENT VARIABLES ===================== ``CC`` The C compiler used to compile the program. Defaults to ``cc`` if unset. ``CFLAGS`` Space-separated list of options passed to the C compiler. Defaults to ``-O -std=c99`` if unset. EXECUTABLE OPTIONS ================== Generated executables accept the same options as those generated by :ref:`futhark-c(1)`. The ``-t`` option behaves as with :ref:`futhark-opencl(1)`. The following additional options are accepted. -h, --help Print help text to standard output and exit. --default-thread-block-size=INT The default size of thread blocks that are launched. Capped to the hardware limit if necessary. --default-num-thread-blocks=INT The default number of thread blocks that are launched. --default-threshold=INT The default parallelism threshold used for comparisons when selecting between code versions generated by incremental flattening. Intuitively, the amount of parallelism needed to saturate the GPU. --default-tile-size=INT The default tile size used when performing two-dimensional tiling (the workgroup size will be the square of the tile size). --dump-cuda=FILE Don't run the program, but instead dump the embedded CUDA kernels to the indicated file. Useful if you want to see what is actually being executed. 
--dump-ptx=FILE Don't run the program, but instead dump the PTX-compiled version of the embedded kernels to the indicated file. --load-cuda=FILE Instead of using the embedded CUDA kernels, load them from the indicated file. --load-ptx=FILE Load PTX code from the indicated file. --nvrtc-option=OPT Add an additional build option to the string passed to NVRTC. Refer to the CUDA documentation for which options are supported. Be careful - some options can easily result in invalid results. ENVIRONMENT =========== If run without ``--library``, ``futhark cuda`` will invoke a C compiler to compile the generated C program into a binary. This only works if the C compiler can find the necessary CUDA libraries. On most systems, CUDA is installed in ``/usr/local/cuda``, which is usually not part of the default compiler search path. You may need to set the following environment variables before running ``futhark cuda``:: LIBRARY_PATH=/usr/local/cuda/lib64 LD_LIBRARY_PATH=/usr/local/cuda/lib64/ CPATH=/usr/local/cuda/include At runtime the generated program must be able to find the CUDA installation directory, which is normally located at ``/usr/local/cuda``. If you have CUDA installed elsewhere, set any of the ``CUDA_HOME``, ``CUDA_ROOT``, or ``CUDA_PATH`` environment variables to the proper directory. SEE ALSO ======== :ref:`futhark-opencl(1)` futhark-0.25.27/docs/man/futhark-dataset.rst000066400000000000000000000044761475065116200207010ustar00rootroot00000000000000.. role:: ref(emphasis) .. _futhark-dataset(1): =============== futhark-dataset =============== SYNOPSIS ======== futhark dataset [options...] DESCRIPTION =========== Generate random values in Futhark syntax, which can be useful when generating input datasets for program testing. All Futhark primitive types are supported. Tuples are not supported. Arrays of specific (non-random) sizes can be generated. You can specify maximum and minimum bounds for values, as well as the random seed used when generating the data. 
The generated values are written to standard output. If no ``-g``/``--generate`` options are passed, values are read from standard input, and printed to standard output in the indicated format. The input format (whether textual or binary) is automatically detected. Returns a nonzero exit code if it fails to write the full output. OPTIONS ======= -b, --binary Output data in binary Futhark format (must precede --generate). -g type, --generate type Generate a value of the indicated type, e.g. ``-g i32`` or ``-g [10]f32``. The type may also be a value, in which case that literal value is generated. -s int Set the seed used for the RNG. 1 by default. --T-bounds= Set inclusive lower and upper bounds on generated values of type ``T``. ``T`` is any primitive type, e.g. ``i32`` or ``f32``. The bounds apply to any following uses of the ``-g`` option. You can alter the output format using the following flags. To use them, add them before data generation (--generate): --text Output data in text format (must precede --generate). Default. -t, --type Output the types of values (textually) instead of the values themselves. Mostly useful when reading values on stdin. EXAMPLES ======== Generate a 4 by 2 integer matrix:: futhark dataset -g [4][2]i32 Generate an array of floating-point numbers and an array of indices into that array:: futhark dataset -g [10]f32 --i64-bounds=0:9 -g [100]i64 To generate binary data, the ``--binary`` must come before the ``--generate``:: futhark dataset --binary --generate=[42]i32 Create a binary data file from a data file:: futhark dataset --binary < any_data > binary_data Determine the types of values contained in a data file:: futhark dataset -t < any_data SEE ALSO ======== :ref:`futhark-test(1)`, :ref:`futhark-bench(1)` futhark-0.25.27/docs/man/futhark-doc.rst000066400000000000000000000032431475065116200200100ustar00rootroot00000000000000.. role:: ref(emphasis) .. 
_futhark-doc(1): =========== futhark-doc =========== SYNOPSIS ======== futhark doc [options...] dir DESCRIPTION =========== ``futhark doc`` generates HTML-formatted documentation from Futhark code. One HTML file will be created for each ``.fut`` file in the given directory, as well as any file reachable through ``import`` expressions. The given Futhark code will be considered as one cohesive whole, and must be type-correct. Futhark definitions may be documented by prefixing them with a block of line comments starting with :literal:`-- |` (see example below). Simple Markdown syntax is supported within these comments. A link to another identifier is possible with the notation :literal:`\`name\`@namespace`, where ``namespace`` must be either ``term``, ``type``, or ``mtype`` (module names are in the ``term`` namespace). A file may contain a leading documentation comment, which will be considered the file *abstract*. ``futhark doc`` will ignore any file whose documentation comment consists solely of the word "ignore". This is useful for files that contain tests, or are otherwise not relevant to the reader of the documentation. OPTIONS ======= -h Print help text to standard output and exit. -o outdir The name of the directory that will contain the generated documentation. This option is mandatory. -v, --verbose Print status messages to stderr while running. -V Print version information on standard output and exit. EXAMPLES ======== .. code-block:: futhark -- | Gratuitous re-implementation of `map`@term. -- -- Does exactly the same. let mymap = ... SEE ALSO ======== :ref:`futhark-test(1)`, :ref:`futhark-bench(1)` futhark-0.25.27/docs/man/futhark-fmt.rst000066400000000000000000000017231475065116200200320ustar00rootroot00000000000000.. role:: ref(emphasis) .. _futhark-fmt(1): =========== futhark-fmt =========== SYNOPSIS ======== futhark fmt [options...] [FILES] DESCRIPTION =========== Reformat the given Futhark programs. 
If no files are provided, read Futhark program on stdin and produce formatted output on stdout. If stdout is a terminal, the output will be syntax highlighted. In contrast to many other automatic formatters, the formatting is somewhat sensitive to the formatting of the input program. In particular, single-line expressions will usually be kept on a single line, even if they are very long. To force ``futhark fmt`` to break these, insert a linebreak at an arbitrary location. OPTIONS ======= --check Check if the given files are correctly formatted, and if not, terminate with an error message and a nonzero exit code. -h Print help text to standard output and exit. -V Print version information on standard output and exit. SEE ALSO ======== :ref:`futhark-doc(1)` futhark-0.25.27/docs/man/futhark-hip.rst000066400000000000000000000051751475065116200200310ustar00rootroot00000000000000.. role:: ref(emphasis) .. _futhark-hip(1): ============== futhark-hip ============== SYNOPSIS ======== futhark hip [options...] DESCRIPTION =========== ``futhark hip`` translates a Futhark program to C code invoking HIP kernels, and either compiles that C code with a C compiler to an executable binary program, or produces a ``.h`` and ``.c`` file that can be linked with other code. The standard Futhark optimisation pipeline is used. ``futhark hip`` uses ``-lhiprtc -lamdhip64`` to link. If using ``--library``, you will need to do the same when linking the final binary. Although the HIP backend can be made to work on NVIDIA GPUs, you are probably better off using the very similar :ref:`futhark-cuda(1)`. OPTIONS ======= Accepts the same options as :ref:`futhark-c(1)`. ENVIRONMENT VARIABLES ===================== ``CC`` The C compiler used to compile the program. Defaults to ``cc`` if unset. ``CFLAGS`` Space-separated list of options passed to the C compiler. Defaults to ``-O -std=c99`` if unset. 
EXECUTABLE OPTIONS ================== Generated executables accept the same options as those generated by :ref:`futhark-c(1)`. For commonality, the options use OpenCL nomenclature ("group" instead of "thread block"). The following additional options are accepted. --default-group-size=INT The default size of thread blocks that are launched. Capped to the hardware limit if necessary. --default-num-groups=INT The default number of thread blocks that are launched. --default-threshold=INT The default parallelism threshold used for comparisons when selecting between code versions generated by incremental flattening. Intuitively, the amount of parallelism needed to saturate the GPU. --default-tile-size=INT The default tile size used when performing two-dimensional tiling (the workgroup size will be the square of the tile size). --dump-hip=FILE Don't run the program, but instead dump the embedded HIP kernels to the indicated file. Useful if you want to see what is actually being executed. --load-hip=FILE Instead of using the embedded HIP kernels, load them from the indicated file. --build-option=OPT Add an additional build option to the string passed to the kernel compiler (HIPRTC). Refer to the HIP documentation for which options are supported. Be careful - some options can easily result in invalid results. ENVIRONMENT =========== If run without ``--library``, ``futhark hip`` will invoke a C compiler to compile the generated C program into a binary. This only works if the C compiler can find the necessary HIP libraries. SEE ALSO ======== :ref:`futhark(1)` futhark-0.25.27/docs/man/futhark-ispc.rst000066400000000000000000000027371475065116200202100ustar00rootroot00000000000000.. role:: ref(emphasis) .. _futhark-ispc(1): ============ futhark-ispc ============ SYNOPSIS ======== futhark ispc [options...] DESCRIPTION =========== ``futhark ispc`` translates a Futhark program to a combination of C and ISPC code, with ISPC used for parallel loops. 
It otherwise operates similarly to :ref:`futhark-multicore(1)`. You need to have ``ispc`` on your ``PATH``. OPTIONS ======= Accepts the same options as :ref:`futhark-multicore(1)`. ENVIRONMENT VARIABLES ===================== ``CC`` The C compiler used to compile the program. Defaults to ``cc`` if unset. ``CFLAGS`` Space-separated list of options passed to the C compiler. Defaults to ``-O3 -std=c99 -pthread`` if unset. ``ISPCFLAGS`` Space-separated list of options passed to ``ispc``. Defaults to ``-O3 --woff`` if unset. EXECUTABLE OPTIONS ================== Generated executables accept the same options as those generated by :ref:`futhark-multicore(1)`. LIBRARY USAGE ============= When compiling a program ``foo.fut`` with ``futhark ispc --library``, a ``foo.kernels.ispc`` file is produced that must be compiled with ``ispc`` and linked with the final program. For example:: $ ispc -o foo.kernels.o foo.kernels.ispc --addressing=64 --pic --woff -O3 BUGS ==== Currently works only on Unix-like systems because of a dependency on pthreads. Adding support for Windows would likely not be difficult. SEE ALSO ======== :ref:`futhark-multicore(1)`, :ref:`futhark-test(1)` futhark-0.25.27/docs/man/futhark-literate.rst000066400000000000000000000233401475065116200210540ustar00rootroot00000000000000.. role:: ref(emphasis) .. _futhark-literate(1): ================ futhark-literate ================ SYNOPSIS ======== futhark literate [options...] program DESCRIPTION =========== The command ``futhark literate foo.fut`` will compile the given program and then generate a Markdown file ``foo.md`` that contains a prettyprinted form of the program. This is useful for demonstrating programming techniques. * Top-level comments that start with a line comment marker (``--``) and a space in the next column will be turned into ordinary text in the Markdown file. * Ordinary top-level definitions will be enclosed in Markdown code blocks. 
* Any *directives* will be executed and replaced with their output. See below. **Warning:** Do not run untrusted programs. See SAFETY below. Image directives and builtin functions shell out to ``convert`` (from ImageMagick). Video and audio generation uses ``ffmpeg``. For an input file ``foo.fut``, all generated files will be in a directory named ``foo-img``. A ``file`` parameter passed to a directive may not contain a directory component or spaces. OPTIONS ======= --backend=name The backend used when compiling Futhark programs (without leading ``futhark``, e.g. just ``opencl``). Defaults to ``c``. --futhark=program The program used to perform operations (e.g. compilation). Defaults to the binary running ``futhark literate`` itself. --output=FILE Override the default output file. The image directory will be set to the provided ``FILE`` with its extension stripped and ``-img/`` appended. --pass-option=opt Pass an option to the programs that are being run. For example, we might want to run OpenCL programs on a specific device:: futhark literate prog.fut --backend=opencl --pass-option=-dHawaii --pass-compiler-option=opt Pass an extra option to the compiler when compiling the programs. --skip-compilation Do not run the compiler, and instead assume that the program has already been compiled. Use with caution. --stop-on-error Terminate immediately without producing an output file if a directive fails. Otherwise a file will still be produced, and failing directives will be followed by an error message. -v, --verbose Print verbose information on stderr about directives as they are executing. This is also needed to see ``#[trace]`` output.
Any directives that produce images for a program ``foo.fut`` will place them in the directory ``foo-img/``. If this directory already exists, it will be deleted. A directive is a line starting with ``-- >``, which must follow an empty line. Arguments to the directive follow on the remainder of the line. Any expression arguments are given in a very restricted subset of Futhark called *FutharkScript* (see below). Some directives take mandatory or optional parameters. These are entered after a semicolon *and a linebreak*. The following directives are supported: * ``> e`` Shows the result of executing the FutharkScript expression ``e``, which can have any (transparent) type. * ``> :video e[; parameters...]`` Creates a video from ``e``. The optional parameters are lines of the form *key: value*: * ``repeat: <true|false>`` * ``fps: <int>`` * ``format: <webm|gif>`` * ``file: <name>``. Make sure to provide a proper extension. ``e`` must be one of the following: * A 3D array where the 2D elements are of a type acceptable to ``:img``, and the outermost dimension is the number of frames. * A triple ``(s -> (img,s), s, i64)``, for some types ``s`` and ``img``, where ``img`` is an array acceptable to ``:img``. This means not all frames have to be held in memory at once. * ``> :brief <directive>`` The same as the given *directive* (which must not start with another ``>``), but suppress parameters when printing it. * ``> :covert <directive>`` The same as the given *directive* (which must not start with another ``>``), but do not show the directive itself in the output, only its result. * ``> :img e[; parameters...]`` Visualises ``e``. The optional parameters are lines of the form *key: value*: * ``file: NAME``. Make sure to use a proper extension. The expression ``e`` must have one of the following types: * ``[][]i32`` and ``[][]u32`` Interpreted as ARGB pixel values. * ``[][]f32`` and ``[][]f64`` Interpreted as greyscale. Values should be between 0 and 1, with 0 being black and 1 being white. * ``[][]u8`` Interpreted as greyscale.
0 is black and 255 is white. * ``[][]bool`` Interpreted as black and white. ``false`` is black and ``true`` is white. * ``> :plot2d e[; size=(height,width)]`` Shows a plot generated with ``gnuplot`` of ``e``, which must be an expression of type ``([]t, []t)``, where ``t`` is some numeric type. The two arrays must have the same length and are interpreted as ``x`` and ``y`` values, respectively. The expression may also be a record expression (*not* merely the name of a Futhark variable of record type), where each field will be plotted separately and must have the type mentioned above. * ``> :gnuplot e; script...`` Similar to ``plot2d``, except that it uses the provided Gnuplot script. The ``e`` argument must be a record whose fields are tuples of one-dimensional arrays, and the data will be available in temporary files whose names are in variables named after the record fields. Each file will contain a column of data for each array in the corresponding tuple. Use ``set term png size width,height`` to change the size to ``width`` by ``height`` pixels. * ``> :audio e[; parameters...]`` Creates a sound-file from ``e``. The optional parameters are lines of the form *key: value*: * ``sampling_frequency: <int>`` The sampling frequency (in Hz) of the input. Defaults to ``44100``. * ``codec: <codec>`` The codec of the output. Defaults to ``wav``. Other common options include ``mp3``, ``flac``, ``ogg`` and ``opus``. The expression ``e`` must have one of the following types: * ``[]i8`` and ``[]u8`` Interpreted as PCM signed/unsigned 8-bit audio. * ``[]i16`` and ``[]u16`` Interpreted as PCM signed/unsigned 16-bit audio. * ``[]i32`` and ``[]u32`` Interpreted as PCM signed/unsigned 32-bit audio. * ``[]f32`` and ``[]f64`` Interpreted as PCM 32/64-bit floating-point audio. Should only contain values between ``-1.0`` and ``1.0``. For each type of input, it is also possible to give expressions with a two-dimensional type instead, e.g. ``[][]f32``.
These expressions are interpreted as an array of channels, making it possible to do stereo audio by returning e.g. ``[2][]f32``. For stereo output, the first row is the left channel and the second row is the right channel. This functionality uses the amerge filter from ffmpeg, so consult the documentation there for additional information. FUTHARKSCRIPT ============= Only an extremely limited subset of Futhark is supported: .. productionlist:: script_exp: `script_fun` `script_exp`* : | "(" `script_exp` ")" : | "(" `script_exp` ( "," `script_exp` )+ ")" : | "[" `script_exp` ( "," `script_exp` )* "]" : | "empty" "(" ("[" `decimal` "]" )+ `script_type` ")" : | "{" "}" : | "{" (`id` "=" `script_exp`) ("," `id` "=" `script_exp`)* "}" : | "let" `script_pat` "=" `script_exp` "in" `script_exp` : | `literal` script_pat: `id` | "(" `id` ("," `id`)+ ")" script_fun: `id` | "$" `id` script_type: `int_type` | `float_type` | "bool" Note that empty arrays must be written using the ``empty(t)`` notation, e.g. ``empty([0]i32)``. Function applications are either of Futhark functions or *builtin functions*. The latter are prefixed with ``$`` and are magical (usually impure) functions that could not possibly be implemented in Futhark. The following builtins are supported: * ``$loadimg "file"`` reads an image from the given file and returns it as a row-major ``[][]u32`` array with each pixel encoded as ARGB. * ``$loaddata "file"`` reads a dataset from the given file. When the file contains a single value, that value is returned directly. Otherwise, a tuple of values is returned, which should be destructured before use. For example: ``let (a, b) = $loaddata "foo.in" in bar a b``. * ``$loadbytes "file"`` reads the contents of the given file as an array of type ``[]u8``. * ``$loadaudio "file"`` reads audio from the given file and returns it as a ``[][]f64``, where each row corresponds to a channel of the original sound file. Most common audio formats are supported, including mp3, ogg, wav, flac and opus.
FutharkScript supports a form of automatic uncurrying. If a function taking *n* parameters is applied to a single argument that is an *n*-element tuple, the function is applied to the elements of the tuple as individual arguments. SAFETY ====== Some directives (e.g. ``:gnuplot``) can run arbitrary shell commands. Other directives or builtin functions can read or write arbitrary files. Running an untrusted literate Futhark program is as dangerous as running a shell script you downloaded off the Internet. Before running a program from an unknown source, you should always give it a quick read to see if anything looks fishy. BUGS ==== FutharkScript expressions can only refer to names defined in the file passed to ``futhark literate``, not any names in imported files. SEE ALSO ======== :ref:`futhark-script(1)`, :ref:`futhark-test(1)`, :ref:`futhark-bench(1)` futhark-0.25.27/docs/man/futhark-multicore.rst000066400000000000000000000026641475065116200212540ustar00rootroot00000000000000.. role:: ref(emphasis) .. _futhark-multicore(1): ================= futhark-multicore ================= SYNOPSIS ======== futhark multicore [options...] DESCRIPTION =========== ``futhark multicore`` translates a Futhark program to multithreaded C code, and either compiles that C code with a C compiler to an executable binary program, or produces a ``.h`` and ``.c`` file that can be linked with other code. The standard Futhark optimisation pipeline is used. The resulting program will read the arguments to the entry point (``main`` by default) from standard input and print its return value on standard output. The arguments are read and printed in Futhark syntax. OPTIONS ======= Accepts the same options as :ref:`futhark-c(1)`. ENVIRONMENT VARIABLES ===================== ``CC`` The C compiler used to compile the program. Defaults to ``cc`` if unset. ``CFLAGS`` Space-separated list of options passed to the C compiler. Defaults to ``-O3 -std=c99 -pthread`` if unset. 
EXECUTABLE OPTIONS ================== Generated executables accept the same options as those generated by :ref:`futhark-c(1)`. The following additional options are accepted. --num-threads=INT Use this many physical threads. BUGS ==== Currently works only on Unix-like systems because of a dependency on pthreads. Adding support for Windows would likely not be difficult. SEE ALSO ======== :ref:`futhark-c(1)`, :ref:`futhark-test(1)` futhark-0.25.27/docs/man/futhark-opencl.rst000066400000000000000000000074631475065116200205330ustar00rootroot00000000000000.. role:: ref(emphasis) .. _futhark-opencl(1): ============== futhark-opencl ============== SYNOPSIS ======== futhark opencl [options...] DESCRIPTION =========== ``futhark opencl`` translates a Futhark program to C code invoking OpenCL kernels, and either compiles that C code with a C compiler to an executable binary program, or produces a ``.h`` and ``.c`` file that can be linked with other code. The standard Futhark optimisation pipeline is used. ``futhark opencl`` uses ``-lOpenCL`` to link (``-framework OpenCL`` on macOS). If using ``--library``, you will need to do the same when linking the final binary. The GPU terminology used is derived from CUDA nomenclature (e.g. "thread block" instead of "workgroup"), but OpenCL nomenclature is also supported for compatibility. OPTIONS ======= Accepts the same options as :ref:`futhark-c(1)`. ENVIRONMENT VARIABLES ===================== ``CC`` The C compiler used to compile the program. Defaults to ``cc`` if unset. ``CFLAGS`` Space-separated list of options passed to the C compiler. Defaults to ``-O -std=c99`` if unset. EXECUTABLE OPTIONS ================== Generated executables accept the same options as those generated by :ref:`futhark-c(1)`. For the ``-t`` option, the time taken to perform device setup or teardown, including writing the input or reading the result, is not included in the measurement.
In particular, this means that timing starts after all kernels have been compiled and data has been copied to the device buffers but before setting any kernel arguments. Timing stops after the kernels are done running, but before data has been read from the buffers or the buffers have been released. The following additional options are accepted. --build-option=OPT Add an additional build option to the string passed to ``clBuildProgram()``. Refer to the OpenCL documentation for which options are supported. Be careful - some options can easily result in invalid results. --default-thread-block-size=INT, --default-group-size=INT The default size of thread blocks that are launched. Capped to the hardware limit if necessary. --default-num-thread-blocks=INT, --default-num-groups=INT The default number of thread blocks that are launched. --default-threshold=INT The default parallelism threshold used for comparisons when selecting between code versions generated by incremental flattening. Intuitively, the amount of parallelism needed to saturate the GPU. --default-tile-size=INT The default tile size used when performing two-dimensional tiling (the thread block size will be the square of the tile size). -d, --device=NAME Use the first OpenCL device whose name contains the given string. The special string ``#k``, where ``k`` is an integer, can be used to pick the *k*-th device, numbered from zero. If used in conjunction with ``-p``, only the devices from matching platforms are considered. --dump-opencl=FILE Don't run the program, but instead dump the embedded OpenCL program to the indicated file. Useful if you want to see what is actually being executed. --dump-opencl-binary=FILE Don't run the program, but instead dump the compiled version of the embedded OpenCL program to the indicated file. On NVIDIA platforms, this will be PTX code. --load-opencl=FILE Instead of using the embedded OpenCL program, load it from the indicated file.
--load-opencl-binary=FILE Load an OpenCL binary from the indicated file. -p, --platform=NAME Use the first OpenCL platform whose name contains the given string. The special string ``#k``, where ``k`` is an integer, can be used to pick the *k*-th platform, numbered from zero. --list-devices List all OpenCL devices and platforms available on the system. SEE ALSO ======== :ref:`futhark-test(1)`, :ref:`futhark-cuda(1)`, :ref:`futhark-c(1)` futhark-0.25.27/docs/man/futhark-pkg.rst000066400000000000000000000116311475065116200200240ustar00rootroot00000000000000.. role:: ref(emphasis) .. _futhark-pkg(1): =========== futhark-pkg =========== SYNOPSIS ======== futhark pkg add PKGPATH [X.Y.Z] futhark pkg check futhark pkg init PKGPATH futhark pkg fmt futhark pkg remove PKGPATH futhark pkg sync futhark pkg upgrade futhark pkg versions DESCRIPTION =========== This tool is used to modify the package manifest (``futhark.pkg``) and download the required packages it describes. ``futhark pkg`` is not a build system; you will still need to compile your Futhark code with the usual compilers. The only purpose of ``futhark pkg`` is to download code (and perform other package management utility tasks). This manpage is not a general introduction to package management in Futhark; see the User's Guide for that. The ``futhark pkg`` subcommands will modify only two locations in the file system (relative to the current working directory): the ``futhark.pkg`` file, and the contents of ``lib/``. When modifying ``lib/``, ``futhark pkg`` constructs the new version in ``lib~new/`` and backs up the old version in ``lib~old``. If ``futhark pkg`` should fail for any reason, you can recover the old state by moving ``lib~old`` back. These temporary directories are erased if ``futhark pkg`` finishes without errors. 
The ``futhark pkg sync`` and ``futhark pkg init`` subcommands are the only ones that actually modify ``lib/``; the others modify only ``futhark.pkg`` and require you to manually run ``futhark pkg sync`` afterwards. Most commands take a ``-v``/``--verbose`` option that makes ``futhark pkg`` write running diagnostics to stderr. Packages must correspond to Git repositories, and all interactions are done by invoking ``git``. COMMANDS ======== futhark pkg add PKGPATH [X.Y.Z] ------------------------------- Add the specified package of the given minimum version as a requirement to ``futhark.pkg``. If no version is provided, the newest one is used. If the package is already required in ``futhark.pkg``, the new version requirement will replace the old one. Note that adding a package does not automatically download it. Run ``futhark pkg sync`` to do that. futhark pkg check ----------------- Verify that ``futhark.pkg`` is valid and that all required packages are available in the indicated versions. This command does not check that these versions contain well-formed code. If a package path is defined in ``futhark.pkg``, also checks that ``.fut`` files are located at the expected location in the file system. futhark pkg init PKGPATH ------------------------ Create a new ``futhark.pkg`` defining a package with the given package path, and initially no requirements. futhark pkg fmt --------------- Reformat the ``futhark.pkg`` file, while retaining any comments. futhark pkg remove PKGPATH -------------------------- Remove a package from ``futhark.pkg``. Does *not* remove it from the ``lib/`` directory. futhark pkg sync ---------------- Populate the ``lib/`` directory with the packages listed in ``futhark.pkg``. **Warning**: this will delete everything in ``lib/`` that does not relate to a file listed in ``futhark.pkg``, as well as any local modifications.
futhark pkg upgrade ------------------- Upgrade all package requirements in ``futhark.pkg`` to the newest available versions. futhark pkg versions PKGPATH ---------------------------- Print all available versions for the given package path. COMMIT VERSIONS =============== It is possible to use ``futhark pkg`` with packages that have not yet made proper releases. This is done via pseudoversions of the form ``0.0.0-yyyymmddhhmmss+commitid``. The timestamp is not verified against the actual commit. The timestamp ensures that newer commits take precedence if multiple packages depend on a commit version for the same package. If ``futhark pkg add`` is given a package with no releases, the most recent commit will be used. In this case, the timestamp is merely set to the current time. Commit versions are awkward and fragile, and should not be relied upon. Issue proper releases (even experimental 0.x versions) as soon as feasible. Released versions also always take precedence over commit versions, since any version number will be greater than 0.0.0. EXAMPLES ======== Create a new package that will be hosted at ``https://github.com/sturluson/edda``:: futhark pkg init github.com/sturluson/edda Add a package dependency:: futhark pkg add github.com/sturluson/hattatal Download the dependencies:: futhark pkg sync And then you're ready to start hacking! (Except that these packages do not actually exist.) BUGS ==== Since the ``lib/`` directory is populated with transitive dependencies as well, it is possible for a package to depend unwittingly on one of the dependencies of its dependencies, without the ``futhark.pkg`` file reflecting this. There is no caching of package metadata between invocations, so the network traffic can be rather heavy. SEE ALSO ======== :ref:`futhark-test(1)`, :ref:`futhark-doc(1)` futhark-0.25.27/docs/man/futhark-profile.rst000066400000000000000000000070721475065116200207070ustar00rootroot00000000000000.. role:: ref(emphasis) ..
_futhark-profile(1): =============== futhark-profile =============== SYNOPSIS ======== futhark profile JSONFILES DESCRIPTION =========== This tool produces human-readable profiling information based on information collected with :ref:`futhark-bench(1)`. Futhark has only rudimentary support for profiling. While the system can collect information about the run-time behaviour of the program, there is currently no automatic way to connect the information to the program source code. However, the collected information can still be useful for estimating the source of inefficiencies. USAGE ===== The first step is to run :ref:`futhark-bench(1)` on your program, while passing ``--profile`` and ``--json``. This will produce a JSON file containing runtime measurements, as well as collected profiling information. If you neglect to pass ``--profile``, the profiling information will be missing. The information in the JSON file is complete, but it is difficult for humans to read. The next step is to run ``futhark profile`` on the JSON file. For a JSON file ``prog.json``, this will create a *top level directory* ``prog.prof`` that contains files with human-readable profiling information. A set of files will be created for each benchmark dataset. If the original invocation of ``futhark bench`` included multiple programs, then ``futhark profile`` will create subdirectories for each program (although all inside the same top level directory). You can pass multiple JSON files to ``futhark profile``. Each will produce a distinct top level directory. Files produced -------------- Supposing a dataset ``foo``, ``futhark profile`` will produce the following files in the top level directory. * ``foo.log``: the running log produced during execution. Contains many details on dynamic behaviour, depending on the exact backend. * ``foo.summary``: a summary of memory usage and cost centres. For the GPU backends, the cost centres are kernel executions and memory copies.
* ``foo.timeline``: a list of all recorded profiling events, in the order in which they occurred, along with their runtime and other available information. Technicalities -------------- The profiling information, including the log, is collected from a *final* run performed after all the measured runs. Profiling information is not collected during the runs that contribute to the runtime measurement reported by ``futhark bench``. However, enabling profiling may still affect performance, as it changes the behaviour of the run time system. Raw reports ----------- Alternatively, the JSON can also contain a raw profiling report as produced by the C API function ``futhark_context_report()``. A directory is still created, but it will only contain a single set of files, and it will not contain a log. EXAMPLES ======== This shows the sequence of commands one might use to profile the program ``LocVolCalib.fut``, which has three datasets associated with it, using the ``hip`` backend:: $ futhark bench --backend=hip --profile --json result.json LocVolCalib.fut $ futhark profile result.json $ tree result.prof/ result.prof/ ├── LocVolCalib-data_large.in.log ├── LocVolCalib-data_large.in.summary ├── LocVolCalib-data_medium.in.log ├── LocVolCalib-data_medium.in.summary ├── LocVolCalib-data_small.in.log └── LocVolCalib-data_small.in.summary BUGS ==== Only the C-based backends currently support profiling. The ``c`` backend does not actually record useful profiling information. SEE ALSO ======== :ref:`futhark-bench(1)` futhark-0.25.27/docs/man/futhark-pyopencl.rst000066400000000000000000000022421475065116200210720ustar00rootroot00000000000000.. role:: ref(emphasis) .. _futhark-pyopencl(1): ================ futhark-pyopencl ================ SYNOPSIS ======== futhark pyopencl [options...] infile DESCRIPTION =========== ``futhark pyopencl`` translates a Futhark program to Python code invoking OpenCL kernels, which depends on Numpy and PyOpenCL.
By default, the program uses the first device of the first OpenCL platform - this can be changed by passing ``-p`` and ``-d`` options to the generated program (not to ``futhark pyopencl`` itself). The resulting program will otherwise behave exactly as one compiled with ``futhark py``. While the sequential host-level code is pure Python and just as slow as in ``futhark py``, parallel sections will have been compiled to OpenCL, and run just as fast as when using ``futhark opencl``. The kernel launch overhead is significantly higher, however, so a good rule of thumb when using ``futhark pyopencl`` is to aim for having fewer but longer-lasting parallel sections. The generated code requires at least PyOpenCL version 2015.2. OPTIONS ======= Accepts the same options as :ref:`futhark-opencl(1)`. SEE ALSO ======== :ref:`futhark-py(1)`, :ref:`futhark-opencl(1)` futhark-0.25.27/docs/man/futhark-python.rst000066400000000000000000000013301475065116200205570ustar00rootroot00000000000000.. role:: ref(emphasis) .. _futhark-py(1): ============== futhark-python ============== SYNOPSIS ======== futhark python [options...] infile DESCRIPTION =========== ``futhark python`` translates a Futhark program to sequential Python code, which depends on Numpy. The resulting program will read the arguments to the ``main`` function from standard input and print its return value on standard output. The arguments are read and printed in Futhark syntax. The generated code is very slow, likely too slow to be useful. It is more interesting to use this command's big brother, :ref:`futhark-pyopencl(1)`. OPTIONS ======= Accepts the same options as :ref:`futhark-c(1)`. SEE ALSO ======== :ref:`futhark-pyopencl(1)` futhark-0.25.27/docs/man/futhark-repl.rst000066400000000000000000000017731475065116200202130ustar00rootroot00000000000000.. role:: ref(emphasis) ..
_futhark-repl(1): ============ futhark-repl ============ SYNOPSIS ======== futhark repl [program.fut] DESCRIPTION =========== Start an interactive Futhark session. This will let you interactively enter expressions and declarations which are then immediately interpreted. If the entered line can be either a declaration or an expression, it is assumed to be a declaration. The input must fit on a single line. Futhark source files can be loaded using the ``:load`` command. This will erase any interactively entered definitions. Use the ``:help`` command to see a list of commands. All commands are prefixed with a colon. ``futhark repl`` uses the Futhark interpreter, which grants access to the ``#[trace]`` and ``#[break]`` attributes. See :ref:`futhark-run(1)` for a description. OPTIONS ======= -h Print help text to standard output and exit. -V Print version information on standard output and exit. SEE ALSO ======== :ref:`futhark-run(1)`, :ref:`futhark-test(1)` futhark-0.25.27/docs/man/futhark-run.rst000066400000000000000000000016331475065116200200500ustar00rootroot00000000000000.. role:: ref(emphasis) .. _futhark-run(1): =========== futhark-run =========== SYNOPSIS ======== futhark run [options...] DESCRIPTION =========== Execute the given program by evaluating an entry point (``main`` by default) with arguments read from standard input, and write the results on standard output. ``futhark run`` is very slow, and in practice only useful for testing, teaching, and experimenting with the language. The ``#[trace]`` and ``#[break]`` attributes are fully supported in the interpreter. Tracing prints values to stdout, in contrast to compiled code, which prints to stderr. OPTIONS ======= -e NAME Run the given entry point instead of ``main``. -h Print help text to standard output and exit. -V Print version information on standard output and exit. -w, --no-warnings Disable interpreter warnings.
SEE ALSO ======== :ref:`futhark-repl(1)`, :ref:`futhark-test(1)` futhark-0.25.27/docs/man/futhark-script.rst000066400000000000000000000053241475065116200205510ustar00rootroot00000000000000.. role:: ref(emphasis) .. _futhark-script(1): ================ futhark-script ================ SYNOPSIS ======== futhark script [options...] program [expression] DESCRIPTION =========== The command ``futhark script foo.fut expr`` will compile ``foo.fut``, run the provided FutharkScript expression ``expr``, and finally print the result to stdout. It is essentially a simpler way to access the evaluation facilities of :ref:`futhark-literate(1)`, and provides the same FutharkScript facilities, with a few additional built-in procedures documented below. If the provided program does not have a ``.fut`` extension, it is assumed to be a previously compiled server-mode program, and simply run directly. When ``-e`` and ``-f`` are used, the expressions are run in the order provided, and only the value of the last expression is printed. This implies that multiple uses of these options are only useful when they invoke procedures with side effects. OPTIONS ======= --backend=name The backend used when compiling Futhark programs (without leading ``futhark``, e.g. just ``opencl``). Defaults to ``c``. -b, --binary Produce output in the binary data format. Fails if the value is not a primitive or array of primitives. -D, --debug Pass ``-D`` to the executable and show debug prints. -e, --expression=EXP Evaluate this FutharkScript expression. Expressions are run in the order provided. --futhark=program The program used to perform operations (e.g. compilation). Defaults to the binary running ``futhark script`` itself. -f, --file=FILE Read and evaluate a FutharkScript expression from this file. Expressions are run in the order provided. -L, --log Pass ``-L`` to the executable and show debug prints. --pass-option=opt Pass an option to the program that is being run.
--pass-compiler-option=opt Pass an extra option to the compiler when compiling the programs. --skip-compilation Do not run the compiler, and instead assume that the program has already been compiled. Use with caution. -v, --verbose Print verbose information on stderr about directives as they are executing. This is also needed to see ``#[trace]`` output. ADDITIONAL BUILTINS =================== * ``$store "file" v`` stores the value *v* (which must be a primitive or an array) as a binary value in the given file. BUGS ==== FutharkScript expressions can only refer to names defined in the file passed to ``futhark script``, not any names in imported files. If the result of the expression does not have an external representation (e.g. is an array of tuples), the value that is printed is misleading and somewhat nonsensical. SEE ALSO ======== :ref:`futhark-test(1)`, :ref:`futhark-bench(1)`, :ref:`futhark-literate(1)` futhark-0.25.27/docs/man/futhark-test.rst000066400000000000000000000213011475065116200202150ustar00rootroot00000000000000.. role:: ref(emphasis) .. _futhark-test(1): ============ futhark-test ============ SYNOPSIS ======== futhark test [options...] infiles... DESCRIPTION =========== Test Futhark programs based on input/output datasets. All contained ``.fut`` files within a given directory are considered. By default, tests are carried out with compiled code. This can be changed with the ``-i`` option. A Futhark test program is an ordinary Futhark program, with at least one test block describing input/output test cases and possibly other options. The last line must end in a newline. A test block consists of commented-out text with the following overall format:: description == cases... The ``description`` is an arbitrary (and possibly multiline) human-readable explanation of the test program. It is separated from the test cases by a line containing just ``==``.
Any comment starting at the beginning of the line, and containing a line consisting of just ``==``, will be considered a test block. The format of a test case is as follows:: [tags { tags... }] [entry: names...] ["name..."] [compiled|nobench|random|script] input ({ values... } | @ filename) output { values... } | auto output | error: regex If a test case begins with a quoted string, that string is reported as the dataset name, including in the JSON file produced by :ref:`futhark-bench(1)`. If no name is provided, one is automatically generated. The name must be unique across all test cases. If ``compiled`` is present before the ``input`` keyword, this test case will never be passed to the interpreter. This is useful for test cases that are annoyingly slow to interpret. The ``nobench`` keyword is for data sets that are too small to be worth benchmarking, and only has meaning to :ref:`futhark-bench(1)`. If ``input`` is preceded by ``random``, the text between the curly braces must consist of a sequence of Futhark types, including sizes in the case of arrays. When ``futhark test`` is run, a file located in a ``data/`` subdirectory, containing values of the indicated types and shapes, is automatically constructed with :ref:`futhark-dataset(1)`. Apart from sizes, integer constants (with or without type suffix), and floating-point constants (always with type suffix) are also permitted. If ``input`` is preceded by ``script``, the text between the curly braces is interpreted as a FutharkScript expression (see :ref:`futhark-literate(1)`), which is executed to generate the input. It must use only functions explicitly declared as entry points. If the expression produces an *n*-element tuple, it will be unpacked and its components passed as *n* distinct arguments to the test function. The only builtin functions supported are ``$loaddata`` and ``$loadbytes``.
If ``input`` is followed by an ``@`` and a file name (which must not contain any whitespace) instead of curly braces, values will be read from the indicated file. This is recommended for large data sets. This notation cannot be used with ``random`` input. With ``script input``, the file contents will be interpreted as a FutharkScript expression. After the ``input`` block, the expected result of the test case is written as either ``output`` followed by another block of values, or ``error:`` followed by a regex indicating an expected run-time error. If neither ``output`` nor ``error`` is given, the program will be expected to execute successfully, but its output will not be validated. If ``output`` is preceded by ``auto`` (as in ``auto output``), the expected values are automatically generated by compiling the program with ``futhark c`` and recording its result for the given input (which must not fail). This is usually only useful for testing or benchmarking alternative compilers, and not for testing the correctness of Futhark programs. This currently does not work for ``script`` inputs. Alternatively, instead of input-output pairs, the test cases can simply be a description of an expected compile time type error:: error: regex This is used to test the type checker. Tuple syntax is not supported when specifying input and output values. Instead, you can write an N-tuple as its constituent N values. Beware of syntax errors in the values - the errors reported by ``futhark test`` are very poor. An optional tags specification is permitted in the first test block. This section can contain arbitrary tags that classify the benchmark:: tags { names... } Tags are sequences of alphanumeric characters, dashes, and underscores, with each tag separated by whitespace. Any program with the ``disable`` tag is ignored by ``futhark test``. Another optional directive is ``entry``, which specifies the entry point to be used for testing.
This is useful for writing programs that test libraries with multiple entry points. Multiple entry points can be specified on the same line by separating them with spaces, and they will all be tested with the same input/output pairs. The ``entry`` directive affects subsequent input-output pairs in the same comment block, and may only be present immediately preceding these input-output pairs. If no ``entry`` is given, ``main`` is assumed. See below for an example. For many usage examples, see the ``tests`` directory in the Futhark source directory. A simple example can be found in ``EXAMPLES`` below. OPTIONS ======= --backend=program The backend used when compiling Futhark programs (without leading ``futhark``, e.g. just ``opencl``). --cache-extension=EXTENSION For a program ``foo.fut``, pass ``--cache-file foo.fut.EXTENSION``. By default, ``--cache-file`` is not passed. -c Only run compiled code - do not run the interpreter. This is the default. -C Compile the programs, but do not run them. --concurrency=NUM The number of tests to run concurrently. Defaults to the number of (hyper-)cores available. --exclude=tag Do not run test cases that contain the given tag. Cases marked with "disable" are ignored by default, as are cases marked "no_foo", where *foo* is the backend used. -i Test with the interpreter. -I Pass the programs through the compiler frontend, but do not run them. This is only useful for testing the Futhark compiler itself. -t Type-check the programs, but do not run them. -s Run ``structure`` tests. These are not run by default. When this option is passed, no other testing is done. --futhark=program The program used to perform operations (e.g. compilation). Defaults to the binary running ``futhark test`` itself. --no-terminal Change the output format to be suitable for noninteractive terminals. Prints a status message roughly every minute. --no-tuning Do not look for tuning files. --pass-option=opt Pass an option to test programs that are being run.
For example, we might want to run OpenCL programs on a specific device:: futhark test prog.fut --backend=opencl --pass-option=-dHawaii --pass-compiler-option=opt Pass an extra option to the compiler when compiling the programs. --runner=program If set to a non-empty string, compiled programs are not run directly, but instead the indicated *program* is run with its first argument being the path to the compiled Futhark program. This is useful for compilation targets that cannot be executed directly (as with :ref:`futhark-pyopencl(1)` on some platforms), or when you wish to run the program on a remote machine. --tuning=EXTENSION For each program being run, look for a tuning file with this extension, which is suffixed to the name of the program. For example, given ``--tuning=tuning`` (the default), the program ``foo.fut`` will be passed the tuning file ``foo.fut.tuning`` if it exists. ENVIRONMENT VARIABLES ===================== ``TMPDIR`` Directory used for temporary files such as gunzipped datasets and log files. EXAMPLES ======== The following program tests simple indexing and bounds checking:: -- Test simple indexing of an array. -- == -- tags { firsttag secondtag } -- input { [4,3,2,1] 1i64 } -- output { 3 } -- input { [4,3,2,1] 5i64 } -- error: Error* let main (a: []i32) (i: i64): i32 = a[i] The following program contains two entry points, both of which are tested:: let add (x: i32) (y: i32): i32 = x + y -- Test the add1 function. -- == -- entry: add1 -- input { 1 } output { 2 } entry add1 (x: i32): i32 = add x 1 -- Test the sub1 function. 
-- == -- entry: sub1 -- input { 1 } output { 0 } entry sub1 (x: i32): i32 = add x (-1) The following program contains an entry point that is tested with randomly generated data:: -- == -- random input { [100]i32 [100]i32 } auto output -- random input { [1000]i32 [1000]i32 } auto output let main xs ys = i32.product (map2 (*) xs ys) SEE ALSO ======== :ref:`futhark-bench(1)`, :ref:`futhark-repl(1)` futhark-0.25.27/docs/man/futhark-wasm-multicore.rst000066400000000000000000000016511475065116200222140ustar00rootroot00000000000000.. role:: ref(emphasis) .. _futhark-wasm-multicore(1): ====================== futhark-wasm-multicore ====================== SYNOPSIS ======== futhark wasm-multicore [options...] DESCRIPTION =========== ``futhark wasm-multicore`` translates a Futhark program to multi-threaded WebAssembly code by first generating C as ``futhark c``, and then using Emscripten (``emcc``). This produces a ``.js`` file that allows the compiled code to be invoked from JavaScript. Executables implement the Futhark server protocol and can be run with Node.js. OPTIONS ======= Accepts the same options as :ref:`futhark-c(1)`. ENVIRONMENT VARIABLES ===================== Respects the same environment variables as :ref:`futhark-wasm(1)`. EXECUTABLE OPTIONS ================== Generated executables accept the same options as those generated by :ref:`futhark-wasm(1)`. SEE ALSO ======== :ref:`futhark-c(1)`, :ref:`futhark-wasm(1)` futhark-0.25.27/docs/man/futhark-wasm.rst000066400000000000000000000023371475065116200202150ustar00rootroot00000000000000.. role:: ref(emphasis) .. _futhark-wasm(1): ============ futhark-wasm ============ SYNOPSIS ======== futhark wasm [options...] DESCRIPTION =========== ``futhark wasm`` translates a Futhark program to sequential WebAssembly code by first generating C as ``futhark c``, and then using Emscripten (``emcc``). This produces a ``.js`` file that allows the compiled code to be invoked from JavaScript. 
Executables implement the Futhark server protocol and can be run with Node.js. OPTIONS ======= Accepts the same options as :ref:`futhark-c(1)`. ENVIRONMENT VARIABLES ===================== ``CFLAGS`` Space-separated list of options passed to ``emcc``. Defaults to ``-O3 -std=c99`` if unset. ``EMCFLAGS`` Space-separated list of options passed to ``emcc``. EXECUTABLE OPTIONS ================== The following options are accepted by executables generated by ``futhark wasm``. -h, --help Print help text to standard output and exit. -D, --debugging Perform possibly expensive internal correctness checks and verbose logging. Implies ``-L``. -L, --log Print various low-overhead logging information to stderr while running. SEE ALSO ======== :ref:`futhark-c(1)`, :ref:`futhark-wasm-multicore(1)` futhark-0.25.27/docs/man/futhark.rst000066400000000000000000000067211475065116200172510ustar00rootroot00000000000000.. role:: ref(emphasis) .. _futhark(1): ======= futhark ======= SYNOPSIS ======== futhark options... DESCRIPTION =========== Futhark is a data-parallel functional array language. Through various subcommands, the ``futhark`` tool provides facilities for compiling, developing, or analysing Futhark programs. Most subcommands are documented in their own manpage. For example, ``futhark opencl`` is documented as :ref:`futhark-opencl(1)`. The remaining subcommands are documented below. COMMANDS ======== futhark benchcmp FILE_A FILE_B ------------------------------ Compare two JSON files produced by the ``--json`` option of :ref:`futhark-bench(1)`. The results show speedup of the latter file compared to the former. futhark check [-w] [-Werror] PROGRAM ------------------------------------ Check whether a Futhark program type checks. With ``-w``, no warnings are printed. With ``--Werror``, warnings are treated as errors. futhark check-syntax PROGRAM ---------------------------- Check whether a Futhark program is syntactically correct. 
futhark datacmp FILE_A FILE_B ----------------------------- Check whether the two files contain the same Futhark values. The files must be formatted using the general Futhark data format that is used by all other executables and tools (such as :ref:`futhark-dataset(1)`). All discrepancies will be reported. This is in contrast to :ref:`futhark-test(1)`, which only reports the first one. futhark dataget PROGRAM DATASET ------------------------------- Find the test dataset whose description contains ``DATASET`` (e.g. ``#1``) and print it in binary representation to standard output. This does not work for ``script`` datasets. futhark defs PROGRAM -------------------- Print names and locations of every top-level definition in the program (including top levels of modules), one per line. The program need not be type-correct, but it must not contain syntax errors. futhark dev options... PROGRAM ------------------------------ A Futhark compiler development command, intentionally undocumented and intended for use in developing the Futhark compiler, not for programmers writing in Futhark. futhark eval [-f FILE] [-w] -------------------------------------- Evaluates expressions given as command-line arguments. Optionally allows a file import using ``-f``. futhark hash PROGRAM -------------------- Print a hexadecimal hash of the program AST, including all non-builtin imports. Intended to be invariant to whitespace changes. futhark imports PROGRAM ----------------------- Print all non-builtin imported Futhark files to stdout, one per line. futhark lsp ----------- Run an LSP (Language Server Protocol) server for Futhark that communicates on standard input. There is no reason to run this by hand. It is used by LSP clients to provide editor features. futhark query PROGRAM LINE COL ------------------------------ Print information about the variable at the given position in the program. futhark thanks -------------- Expresses gratitude.
futhark tokens FILE ------------------- Print the tokens the given Futhark source file; one per line. SEE ALSO ======== :ref:`futhark-opencl(1)`, :ref:`futhark-c(1)`, :ref:`futhark-py(1)`, :ref:`futhark-pyopencl(1)`, :ref:`futhark-wasm(1)`, :ref:`futhark-wasm-multicore(1)`, :ref:`futhark-ispc(1)`, :ref:`futhark-dataset(1)`, :ref:`futhark-doc(1)`, :ref:`futhark-test(1)`, :ref:`futhark-bench(1)`, :ref:`futhark-run(1)`, :ref:`futhark-repl(1)`, :ref:`futhark-literate(1)` futhark-0.25.27/docs/manifest.schema.json000066400000000000000000000173231475065116200202400ustar00rootroot00000000000000{ "$schema": "https://json-schema.org/draft/2020-12/schema", "$id": "https://futhark-lang.org/manifest.schema.json", "title": "Futhark C Manifest", "description": "The C API presented by a compiled Futhark program", "type": "object", "properties": { "backend": {"type": "string"}, "version": {"type": "string"}, "entry_points": { "type": "object", "additionalProperties": { "type": "object", "properties": { "cfun": {"type": "string"}, "tuning_params": { "type": "array", "items": { "type": "string" } }, "outputs": { "type": "array", "items": { "type": "object", "properties": { "type": {"type": "string"}, "unique": {"type": "boolean"} }, "additionalProperties": false } }, "inputs": { "type": "array", "items": { "type": "object", "properties": { "name": {"type": "string"}, "type": {"type": "string"}, "unique": {"type": "boolean"} }, "additionalProperties": false } } } } }, "types": { "type": "object", "additionalProperties": { "oneOf": [ { "type": "object", "properties": { "kind": {"const": "opaque"}, "ctype": {"type": "string"}, "ops": { "type": "object", "properties": { "free": {"type": "string"}, "store": {"type": "string"}, "restore": {"type": "string"} }, "additionalProperties": false }, "record": { "type": "object", "properties": { "new": {"type": "string"}, "fields": { "type": "array", "items": { "type": "object", "properties": { "name": {"type": "string"}, "type": {"type": 
"string"}, "project": {"type": "string"} } } } }, "additionalProperties": false }, "sum": { "type": "object", "properties": { "variant": {"type": "string"}, "variants": { "type": "array", "items": { "type": "object", "properties": { "construct": {"type": "string"}, "destruct": {"type": "string"}, "payload": {"type": "array", "items": { "type": "string" } } } } } }, "additionalProperties": false }, "record_array": { "type": "object", "properties": { "rank": {"type": "integer"}, "elemtype": {"type": "string"}, "zip": {"type": "string"}, "index": {"type": "string"}, "shape": {"type": "string"}, "fields": { "type": "array", "items": { "type": "object", "properties": { "name": {"type": "string"}, "type": {"type": "string"}, "project": {"type": "string"} } } } }, "additionalProperties": false }, "opaque_array": { "type": "object", "properties": { "rank": {"type": "integer"}, "elemtype": {"type": "string"}, "index": {"type": "string"}, "shape": {"type": "string"} }, "additionalProperties": false } }, "required": [ "kind", "ctype", "ops" ] }, { "type": "object", "properties": { "kind": {"const": "array"}, "ctype": {"type": "string"}, "rank": {"type": "integer"}, "elemtype": { "enum": ["i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "f16", "f32", "f64", "bool"] }, "ops": { "type": "object", "properties": { "free": {"type": "string"}, "shape": {"type": "string"}, "values": {"type": "string"}, "values_raw": {"type": "string"}, "new": {"type": "string"}, "new_raw": {"type": "string"}, "index": {"type": "string"} }, "additionalProperties": false } } }] } } }, "required": ["backend", "entry_points", "types"], "additionalProperties": false } futhark-0.25.27/docs/package-management.rst000066400000000000000000000362041475065116200205360ustar00rootroot00000000000000.. _package-management: Package Management ================== This document describes ``futhark pkg``, a minimalistic package manager inspired by `vgo `_. 
A Futhark package is a downloadable collection of ``.fut`` files and little more. There is a (not necessarily comprehensive) `list of known packages `_. Basic Concepts -------------- A package is uniquely identified with a *package path*, which is similar to a URL, except without a protocol. At the moment, package paths must be something that can be passed to ``git clone``. In particular, this includes paths to repositories on major code hosting sites such as GitLab and GitHub. In the future, this will become more flexible. As an example, a package path may be ``github.com/athas/fut-foo``. Packages are versioned with `semantic version numbers `_ of the form ``X.Y.Z``. Whenever versions are indicated, all three digits must always be given (that is, ``1.0`` is not a valid shorthand for ``1.0.0``). Most ``futhark pkg`` operations involve reading and writing a *package manifest*, which is always stored in a file called ``futhark.pkg``. The ``futhark.pkg`` file is human-editable, but is in day-to-day use mainly modified by ``futhark pkg`` automatically. Using Packages -------------- Required packages can be added by using ``futhark pkg add``, for example:: $ futhark pkg add github.com/athas/fut-foo 0.1.0 This will create a new file ``futhark.pkg`` with the following contents: .. code-block:: text require { github.com/athas/fut-foo 0.1.0 #d285563c25c5152b1ae80fc64de64ff2775fa733 } This lists one required package, with its package path, minimum version (see :ref:`version-selection`), and the expected commit hash. The latter is used for verification, to ensure that the contents of a package version cannot be changed silently. ``futhark pkg`` will perform network requests to determine whether a package of the given name and with the given version exists and fail otherwise (but it will not check whether the package is otherwise well-formed). The version number can be elided, in which case ``futhark pkg`` will use the newest available version. 
If the package is already present in ``futhark.pkg``, it will simply have its version requirement changed to the one specified in the command. Any dependencies of the package will *not* be added to ``futhark.pkg``, but will still be downloaded by ``futhark pkg sync`` (see below). Adding a package with ``futhark pkg add`` modifies ``futhark.pkg``, but does not download the package files. This is done with ``futhark pkg sync`` (without further options). The contents of each required dependency and any transitive dependencies will be stored in a subdirectory of ``lib/`` corresponding to their package path. As an example:: $ futhark pkg sync $ tree lib lib └── github.com └── athas └── fut-foo └── foo.fut 3 directories, 1 file **Warning:** ``futhark pkg sync`` will remove any unrecognized files or local modifications to files in ``lib/`` (except of course the package directory of the package path listed in ``futhark.pkg``; see :ref:`creating-packages`). Packages can be removed from ``futhark.pkg`` with:: $ futhark pkg remove pkgpath You will need to run ``futhark pkg sync`` to actually remove the files in ``lib/``. The intended usage is that ``futhark.pkg`` is added to version control, but ``lib/`` is not, as the contents of ``lib/`` can always be reproduced from ``futhark.pkg``. However, adding ``lib/`` works just fine as well. Importing Files from Dependencies ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ``futhark pkg sync`` will populate the ``lib/`` directory, but does not interact with the compiler in any way. The downloaded files can be imported using the usual ``import`` mechanism (:ref:`other-files`); for example, assuming the package contains a file ``foo.fut``:: import "lib/github.com/athas/fut-foo/foo" Ultimately, everything boils down to ordinary file system semantics. This has the downside of relatively long and clumsy import paths, but the upside of predictability. 
Upgrading Dependencies ~~~~~~~~~~~~~~~~~~~~~~ The ``futhark pkg upgrade`` command will update every version requirement in ``futhark.pkg`` to be the most recent available version. You still need to run ``futhark pkg sync`` to actually retrieve the new versions. Be careful - while upgrades are safe if semantic versioning is followed correctly, this is not yet properly machine-checked, so human mistakes may occur. As an example: .. code-block:: text $ cat futhark.pkg require { github.com/athas/fut-foo 0.1.0 #d285563c25c5152b1ae80fc64de64ff2775fa733 } $ futhark pkg upgrade Upgraded github.com/athas/fut-foo 0.1.0 => 0.2.1. $ cat futhark.pkg require { github.com/athas/fut-foo 0.2.1 #3ddc9fc93c1d8ce560a3961e55547e5c78bd0f3e } $ futhark pkg sync $ tree lib lib └── github.com └── athas ├── fut-bar │   └── bar.fut └── fut-foo └── foo.fut 4 directories, 2 files Note that ``fut-foo 0.2.1`` depends on ``github.com/athas/fut-bar``, so it was fetched by ``futhark pkg sync``. ``futhark pkg upgrade`` will *never* upgrade across a major version number. Due to the principle of `Semantic Import Versioning `_, a new major version is a completely different package from the point of view of the package manager. Thus, to upgrade to a new major version, you will need to use ``futhark pkg add`` to add the new version and ``futhark pkg remove`` to remove the old version. Or you can keep it around - it is perfectly acceptable to depend on multiple major versions of the same package, because they are really different packages. .. _creating-packages: Creating Packages ----------------- A package is a directory tree (which at the moment must correspond to a Git repository). It *must* contain two things: * A file ``futhark.pkg`` at the root defining the package path and any required packages. * A *package directory* ``lib/pkg-path``, where ``pkg-path`` is the full package path. The contents of the package directory is what will be made available to users of the package. 
The repository may contain other things (tests, data files, examples, docs, other programs, etc), but these are ignored by ``futhark pkg``. This structure can be created automatically by running for example:: $ futhark pkg init github.com/sturluson/edda Note again, no ``https://``. The result is this ``futhark.pkg``:: package github.com/sturluson/edda require { } And this file hierarchy: .. code-block:: text $ tree lib lib └── github.com └── sturluson └── edda 3 directories, 0 files Note that ``futhark pkg init`` is not necessary simply to *use* packages, only when *creating* packages. When creating a package, the ``.fut`` files we are writing will be located inside the ``lib/`` directory. If the package has its own dependencies, whose files we would like to access, we can use *relative imports*. For example, assume we are creating a package ``github.com/sturluson/edda`` and we are writing a Futhark file located at ``lib/github.com/sturluson/edda/saga.fut``. Further, we have a dependency on the package ``github.com/athas/foo-fut``, which is stored in the directory ``lib/github.com/athas/foo-fut``. We can import a file ``lib/github.com/athas/foo-fut/foo.fut`` from ``lib/github.com/sturluson/edda/saga.fut`` with:: import "../foo-fut/foo" Releasing a Package ~~~~~~~~~~~~~~~~~~~ Currently, a package corresponds exactly to a GitHub repository mirroring the package path. A release is done by tagging an appropriate commit with ``git tag vX.Y.Z`` and then pushing the tag to GitHub with ``git push --tags``. In the future, this will be generalised to other code hosting sites and version control systems (and possibly self-hosted tarballs). Remember to take semantic versioning into account - unless you bump the major version number (or the major version is 0), the new version must be *fully compatible* with the old. When releasing a new package, consider getting it added to the `central package list `_. See `this page `_ for details. 
Incrementing the Major Version Number ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ While backwards-incompatible modifications to a package are sometimes unavoidable, it is wise to avoid them as much as possible, as they significantly inconvenience users. To discourage breaking compatibility, ``futhark pkg`` tries to ensure that the package developer feels this inconvenience as well. In many cases, an incompatible change can be avoided simply by adding new files to the package rather than incompatibly changing the existing ones. In the general case, the package path also encodes the major version of the package, separated with a ``@``. For example, version 5.2.1 of a package might have the package path ``github.com/user/repo@5``. For major versions 0 and 1, this can be elided. This means that multiple (major) versions of a package are completely distinct from the point of view of the package manager - this principle is called `Semantic Import Versioning `_, and is intended to facilitate backwards compatibility of packages when new versions are released. If you really must increment the major version, then you will need to change the package path in ``futhark.pkg`` to contain the new major version preceded by ``@``. For example, ``lib/github.com/sturluson/edda`` becomes ``lib/github.com/sturluson/edda@2``. As a special case, this is not necessary when moving from major version 0 to 1. Since the package path has changed, you will also need to rename the package directory in ``lib/``. This is painful and awkward, but it is less painful and awkward than what users feel when their dependencies break compatibility. Renaming a Package ~~~~~~~~~~~~~~~~~~ It is likely that the hosting location for a very long-lived package will change from time to time. Since the hosting location is embedded into the package path itself, this causes some issues for ``futhark pkg``. In simple cases, there is no problem. 
Consider a package ``github.com/asgard/loki`` which is moved to ``github.com/utgard/loki``. If no GitHub-level redirect is set up, all users must update the path by which they import the package. This is unavoidable, unfortunately. However, the old tagged versions, which contain a ``futhark.pkg`` that uses the old package path, will continue to work. This is because the package path indicated in ``futhark.pkg`` merely defines the subdirectory of ``lib/`` where the package files are to be found, while the package path used by dependents in the ``require`` section defines where the package files are located after ``futhark pkg sync``. Thus, when we import an old version of ``github.com/utgard/loki`` whose ``futhark.pkg`` defines the package as ``github.com/asgard/loki``, the package files will be retrieved from the ``lib/github.com/asgard/loki`` directory in the repository, but stored at ``lib/github.com/utgard/loki`` in the local directory. The above means that package management remains operational in simple cases of renaming, but it is awkward when a transitive dependency is renamed (or deleted). The Futhark package ecosystem is sufficiently embryonic that we have not yet developed more robust solutions. When such solutions are developed, they will likely involve some form of ``replace`` directive that allows transparent local renaming of packages, as well as perhaps a central registry of package paths that does not depend on specific source code hosts. .. _version-selection: Version Selection ----------------- The package manifest ``futhark.pkg`` declares which packages the program depends on. Dependencies are specified as the *oldest acceptable version* within the given major version. Upper version bounds are not supported, as strict adherence to semantic versioning is assumed, so any later version with the same major version number should work.
When ``futhark pkg sync`` calculates which version of a given package to download, it will pick the oldest version that still satisfies the minimum version requirements of that package in all transitive dependencies. This means that a version may be used that is newer than the one indicated in ``futhark.pkg``, but only if a dependency requires a more recent version. Tests and Documentation for Dependencies ---------------------------------------- Package management has been designed to ensure that the normal development tools work as expected with the contents of the ``lib/`` directory. For example, to ensure that all dependencies do in fact work well (or at least compile) together, run: .. code-block:: text futhark test lib Also, you can generate hyperlinked documentation for all dependencies with: .. code-block:: text futhark doc lib -o docs The file ``docs/index.html`` can be opened in a web browser to browse the documentation. Prebuilt documentation is also available via the `online package list `_. Safety ------ In contrast to some other package managers, ``futhark pkg`` does not run any package-supplied code on installation, upgrade, or removal. This means that all ``futhark pkg`` operations are in principle completely safe (barring exploitable bugs in ``futhark pkg`` itself, which are unlikely but not impossible). Further, Futhark code itself is also completely pure, so executing it cannot have any unfortunate effects, such as `infecting all of your own packages with a worm `_. The worst it can do is loop infinitely, consume arbitrarily large amounts of memory, or produce wrong results. The exception is packages that use ``unsafe``. With some cleverness, ``unsafe`` can be combined with in-place updates to perform arbitrary memory reads and writes, which can trivially lead to exploitable behaviour. You should not use untrusted code that employs ``unsafe`` (but the ``--safe`` compiler option may help).
However, this is not any worse than calling external code in a conventional impure language, which generally can perform any conceivable harmful action. Private repositories -------------------- The Futhark package manager is intentionally very simple - perhaps even simplistic. The key philosophy is that if you can ``git clone`` a repository from the command line, then ``futhark pkg`` can also access it. However, ``futhark pkg`` always uses the ``https://`` protocol when converting package paths to the URLs that are passed to ``git``, which is sometimes inconvenient for self-hosted or private repositories. As a workaround, you can modify your Git configuration file to transparently replace ``https://`` with ``ssh://`` for certain repositories. For example, you can add the following entry ``$HOME/.gitconfig``:: [url "ssh://git@github.com/sturluson"] insteadOf = https://github.com/sturluson This will make all interactions with repositories owned by the ``sturluson`` user on GitHub use SSH instead of HTTPS. futhark-0.25.27/docs/performance.rst000066400000000000000000000404411475065116200173300ustar00rootroot00000000000000.. _performance: Writing Fast Futhark Programs ============================= This document contains tips, tricks, and hints for writing efficient Futhark code. Ideally you'd need to know nothing more than an abstract cost model, but sometimes it is useful to have an idea of how the compiler will transform your program, what values look like in memory, and what kind of code the compiler will generate for you. These details are documented below. Don't be discouraged by the complexities mentioned here - most Futhark programs are written without worrying about any of these details, and they still manage to run with good performance. This document focuses on corner cases and pitfalls, which easily makes for depressing reading. 
Parallelism ----------- The Futhark compiler only generates parallel code for explicitly parallel constructs such as ``map`` and ``reduce``. A plain ``loop`` will *not* result in parallel code (unless the loop body itself contains parallel operations). The most important parallel constructs are the *second-order array combinators* (SOACs) such as ``map`` and ``reduce``, but functions such as ``copy`` are also parallel. When describing the asymptotic cost of a Futhark function, it is not enough to give a traditional big-O measure of the total amount of work. Both ``foldl`` and ``reduce`` involve *O(n)* work, where *n* is the size of the input array, but ``foldl`` is sequential while ``reduce`` is parallel, and this is an important distinction. To make this distinction, each function is described by *two* costs: the *work*, which is the total amount of operations, and the *span* (sometimes called *depth*) which is intuitively the "longest chain of sequential dependencies". We say that ``foldl`` has span *O(n)*, while ``reduce`` has span *O(log(n))*. This explains that ``reduce`` is more parallel than ``foldl``. The documentation for a Futhark function should mention both its work and span. `See this `_ for more details on parallel cost models and pointers to literature. Scans and reductions ~~~~~~~~~~~~~~~~~~~~ The ``scan`` and ``reduce`` SOACs are rather inefficient when their operators are on arrays. If possible, use tuples instead (see :ref:`performance-small-arrays`). The one exception is when the operator is a ``map2`` or equivalent. Example: .. code-block:: futhark reduce (map2 (+)) (replicate n 0) xss Such "vectorised" operators are typically handled quite efficiently. Although to be on the safe side, you can rewrite the above by interchanging the ``reduce`` and ``map``: .. code-block:: futhark map (reduce (+) 0) (transpose xss) Avoid reductions over tiny arrays, e.g. ``reduce (+) 0 [x,y,z]``. 
In such cases the compiler will generate complex code to exploit a negligible amount of parallelism. Instead, just unroll the loop manually (``x+y+z``) or perhaps use ``foldl (+) 0 [x,z,y]``, which produces a sequential loop. Histograms ~~~~~~~~~~ The ``reduce_by_index`` construct ("generalised histogram") has a clever and adaptive implementation that handles multiple updates of the same bin efficiently. Its main weakness is when computing a very large histogram (many millions of bins) where only a tiny fraction of the bins are updated. This is because the main mechanism for optimising conflicts is by duplicating the histogram in memory, but this is not efficient when it is very large. If you know your program involves such a situation, it may be better to implement the histogram operation by sorting and then performing an irregular segmented reduction. Particularly with the GPU backends, ``reduce_by_index`` is much faster when the operator involves a single 32-bit or 64-bit value. Even if you really want an 8-bit or 16-bit result, it may be faster to compute it with a 32-bit or 64-bit type and manually mask off the excess bits. Nested parallelism ~~~~~~~~~~~~~~~~~~ Futhark allows nested parallelism, understood as a parallel construct used inside some other parallel construct. The simplest example is nested SOACs. Example: .. code-block:: futhark map (\xs -> reduce (+) 0 xs) xss Nested parallelism is allowed and encouraged, but its compilation to efficient code is rather complicated, depending on the compiler backend that is used. The problem is that sometimes exploiting all levels of parallelism is not optimal, yet how much to exploit depends on run-time information that is not available to the compiler. Sequential backends !!!!!!!!!!!!!!!!!!! The sequential backends are straightforward: all parallel operations are compiled into sequential loops. Due to Futhark's low-overhead data representation (see below), this is often surprisingly efficient. 
Multicore backend !!!!!!!!!!!!!!!!! Whenever the multicore backend encounters nested parallelism, it generates two code versions: one where the nested parallel constructs are also parallelised (possibly recursively involving further nested parallelism), and one where they are turned into sequential loops. At runtime, based on the amount of work available and self-tuning heuristics, the scheduler picks the version that it believes best balances overhead with exploitation of parallelism. GPU backends !!!!!!!!!!!! The GPU backends handle parallelism through an elaborate program transformation called *incremental flattening*. The full details are beyond the scope of this document, but some properties are useful to know of. `See this paper `_ for more details. The main restriction is that the GPU backends can only handle *regular* nested parallelism, meaning that the sizes of inner parallel dimensions are invariant to the outer parallel dimensions. For example, this expression contains *irregular* nested parallelism: .. code-block:: futhark map (\i -> reduce (+) 0 (iota i)) is This is because the size of the nested parallel construct is ``i``, and ``i`` has a different value for every iteration of the outer ``map``. The Futhark compiler will currently turn the irregular constructs (here, the ``reduce``) into a sequential loop. Depending on how complicated the irregularity is, it may even refuse to generate code entirely. Incremental flattening is based on generating multiple code versions to cater to different classes of datasets. At run-time, one of these versions will be picked for execution by comparing properties of the input (its size) with a *threshold parameter*. These threshold parameters have sensible defaults, but for optimal performance, they can be tuned with :ref:`futhark-autotune(1)`. Value Representation -------------------- The compiler discards all type abstraction when compiling. 
Using the module system to make a type abstract causes no run-time overhead. Scalars ~~~~~~~ Scalar values (``i32``, ``f64``, ``bool``, etc) are represented as themselves. The internal representation does not distinguish signs, so ``i32`` and ``u32`` have the same representation, and converting between types that differ only in sign is free. Tuples ~~~~~~ Tuples are flattened and then represented directly by their individual components - there are no *tuple objects* at runtime. A function that takes an argument of type ``(f64,f64)`` corresponds to a C function that takes two arguments of type ``double``. This has one performance implication: whenever you pass a tuple to a function, the *entire* tuple is copied (except any embedded arrays, which are always passed by reference, see below). Due to the compiler's heavy use of inlining, this is rarely a problem in practice, but it can be a concern when using the ``loop`` construct with a large tuple as the loop variant parameter. Records ~~~~~~~ Records are turned into tuples by simply sorting their fields and discarding the labels. This means there is no overhead to using a record compared to using a tuple. Sum Types ~~~~~~~~~ A sum type value is represented as a tuple containing all the payload components in order, prefixed with an `i8` tag to identify the constructor. For example, .. code-block:: futhark #foo i32 bool | #bar i32 would be represented as a tuple of type .. code-block:: futhark (i8, i32, bool, i32) where the value .. code-block:: futhark #foo 42 false is represented as .. code-block:: futhark (1, 42, false, 0) where ``#foo`` is assigned the tag ``1`` because it is alphabetically after ``#bar``. To shrink the tuples, if multiple constructors have payload elements of the *same* type, the compiler assigns them to the same elements in the result tuple. The representation of the above sum type is actually the following: .. 
code-block:: futhark (i8, i32, bool) The types must be the *same* for deduplication to take place - despite ``i32`` and ``f32`` being of the same size, they cannot be assigned the same tuple element. This means that the type .. code-block:: futhark #foo [n]i32 | #bar [n]i32 is efficiently represented as .. code-block:: futhark (i8, [n]i32) However, the type .. code-block:: futhark #foo [n]i32 | #bar [n]f32 is represented as .. code-block:: futhark (i8, [n]i32, [n]f32) which is not great. Be careful when using sum types with large arrays in their payloads. Functions ~~~~~~~~~ Higher-order functions are implemented via defunctionalisation. At run-time, they are represented by a record containing their lexical closure. Since the type system forbids putting functions in arrays, this is essentially a constant cost, and not worth worrying about. Arrays ~~~~~~ Arrays are the only Futhark values that are boxed - that is, are stored on the heap. The elements of an array are unboxed, stored adjacent to each other in memory. There is zero memory overhead except for the minuscule amount needed to track the shape of the array. Multidimensional arrays !!!!!!!!!!!!!!!!!!!!!!! At the surface language level, Futhark may appear to support "arrays of arrays", and this is indeed a convenient aspect of its programming model, but at runtime multi-dimensional arrays are stored in flattened form. A value of type ``[x][y]i32`` is laid out in memory simply as one array containing *x\*y* integers. This means that constructing an array ``[x,y,z]`` can be (relatively) expensive if ``x``, ``y``, ``z`` are themselves large arrays, as they must be copied in their entirety. Since arrays cannot contain other arrays, memory management only has to be concerned with one level of indirection. In practice, it means that Futhark can use straightforward reference counting to keep track of when to free the memory backing an array, as circular references are not possible. 
Further, since arrays tend to be large and relatively few in number, the usual performance impact of naive reference counting is not present. Arrays of tuples !!!!!!!!!!!!!!!! For arrays of tuples, Futhark uses the so-called `structure of arrays `_ representation. In Futhark terms, an array ``[n](a,b,c)`` is at runtime represented as the tuple ``([n]a,[n]b,[n]c)``. This means that the final memory representation always consists of arrays of scalars. This has some significant implications. For example, ``zip`` and ``unzip`` are very cheap, as the actual runtime representation is always "unzipped", so these functions don't actually have to do anything. Since records and sum types are represented as tuples, this also explains how arrays of these are represented. Element order !!!!!!!!!!!!! The exact in-memory element ordering is up to the compiler, and depends on how the array is constructed and how it is used. Absent any other information, Futhark represents multidimensional arrays in row-major order. However, depending on how the array is traversed, the compiler may insert code to represent it in some other order. For particularly tricky programs, an array may even be duplicated in memory, represented in different ways, to ensure efficient traversal. This means you should normally *not* worry about how to represent your arrays to ensure coalesced access on GPUs or similar. That is the compiler's job. Crucial Optimisations --------------------- Some of the optimisations done by the Futhark compiler are important, complex, or subtle enough that it may be useful to know how they work, and how to write code that caters to their quirks. Fusion ~~~~~~ Futhark performs fusion of SOACs. Given an expression ``map f (map g A)``, the compiler will optimise it into a single ``map`` with the composition of ``f`` and ``g``, which avoids storing an intermediate array in memory. This is called *vertical fusion* or *producer-consumer fusion*. 
In this case the *producer* is ``map g`` and the *consumer* is ``map f``. Fusion does not depend on the expressions being adjacent as in this example, as the optimisation is performed on a data dependency graph representing the program. This means that you can decompose your programs into many small parallel operations without worrying about the overhead, as the compiler will fuse them together automatically. Not all producer-consumer relationships between SOACs can be fused. Generally, ``map`` can always be fused as a producer, but ``scan``, ``reduce``, and similar SOACs can only act as consumers. *Horizontal fusion* occurs when two SOACs take as input the same array, but are not themselves in a producer-consumer relationship. Example: .. code-block:: futhark (map f xs, map g xs) Such cases are fused into a single operation that traverses ``xs`` just once. More than two SOACs can be involved in horizontal fusion, and they need not be of the same kind (e.g. one could be a ``map`` and the other a ``reduce``). Free Operations --------------- Some operations such as array slicing, ``take``, ``drop``, ``transpose`` and ``reverse`` are "free" in the sense that they merely return a different view of some underlying array. In most cases they have constant cost, no matter the size of the array they operate on. This is because they are index space transformations that simply result in different code being generated when the arrays are eventually used. However, there are some cases where the compiler is forced to manifest such a "view" as an actual array in memory, which involves a full copy. An incomplete list follows: * Any array returned by an entry point is converted to row-major order. * An array returned by an ``if`` branch may be copied if its representation is substantially different from that of the other branch. * An array returned by a ``loop`` body may be copied if its representation is substantially different from that of the initial loop values. 
* An array is copied whenever it becomes the element of another multidimensional array. This is most obviously the case for array literals (``[x,y,z]``), but also for ``map`` expressions where the mapped function returns an array. Note that this notion of "views" is not part of the Futhark type system - it is merely an implementation detail. Strictly speaking, all functions that return an array (e.g. ``reverse``) should be considered to have a cost proportional to the size of the array, even if that cost will almost never actually be paid at run-time. If you want to be sure no copy takes place, it may be better to explicitly maintain tuples of indexes into some other array. .. _performance-small-arrays: Small Arrays ------------ The compiler assumes arrays are "large", which for example means that operations across them are worth parallelising. It also means they are boxed and heap-allocated, even when the size is a small constant. This can cause unexpectedly bad performance when using small constant-size arrays (say, five elements or less). Consider using tuples or records instead. `This post `_ contains more information on how and why. If in doubt, try both and measure which is faster. Inlining -------- The compiler currently inlines all functions at their call site, unless they have been marked with the ``noinline`` attribute (see :ref:`attributes`). This can lead to code explosion, which mostly results in slow compile times, but can also affect run-time performance. In many cases this is currently unavoidable, but sometimes the program can be rewritten such that instead of calling the same function in multiple places, it is called in a single place, in a loop. E.g. we might rewrite ``f x (f y (f z v))`` as: .. 
code-block:: futhark loop acc = v for a in [z,y,x] do f a acc futhark-0.25.27/docs/report.schema.json000066400000000000000000000012761475065116200177450ustar00rootroot00000000000000{ "$schema": "https://json-schema.org/draft/2019-09/schema", "title":"Futhark runtime report", "type": "object", "properties": { "memory": { "type": "object", "patternProperties":{ "": {"type": "integer"} } }, "events": { "type": "array", "items": { "type": "object", "properties": { "name": {"type":"string"}, "start": {"type":"integer"}, "end": {"type":"integer"} }, "required": ["name", "start", "end"] }}}, "required": ["memory", "events"] } futhark-0.25.27/docs/requirements.txt000066400000000000000000000000341475065116200175530ustar00rootroot00000000000000pyyaml>=4.2b1 sphinx>=4.2.0 futhark-0.25.27/docs/server-protocol.rst000066400000000000000000000142431475065116200201750ustar00rootroot00000000000000.. _server-protocol: Server Protocol =============== A Futhark program can be compiled to a *server executable*. Such a server maintains a Futhark context and presents a line-oriented interface (over stdin/stdout) for loading and dumping values, as well as calling the entry points in the program. The main advantage over the plain executable interface is that program initialisation is done only *once*, and we can work with opaque values. The server interface is not intended for human consumption, but is useful for writing tools on top of Futhark programs, without having to use the C API. Futhark's built-in benchmarking and testing tools use server executables. A server executable is started like any other executable, and supports most of the same command line options (:ref:`executable-options`). Basics ------ Each command is sent as a *single line* on standard input. A command consists of space-separated *words*. A word is either a sequence of non-space characters (``foo``), *or* double quotes surrounding a sequence of non-newline and non-quote characters (``"foo bar"``). 
The response is sent on standard output. The server will print ``%%% OK`` on a line by itself to indicate that a command has finished. It will also print ``%%% OK`` at startup once initialisation has finished. If initialisation fails, the process will terminate. If a command fails, the server will print ``%%% FAILURE`` followed by the error message, and then ``%%% OK`` when it is ready for more input. Some output may also precede ``%%% FAILURE``, e.g. logging statements that occurred before failure was detected. Fatal errors that lead to server shutdown may be printed to stderr. Variables --------- Some commands produce or read variables. A variable is a mapping from a name to a Futhark value. Values can be transparent (arrays and primitives), or they can be *opaque*. Opaque values can be produced by entry points and passed to other entry points, but cannot be directly inspected. Types ----- All variables have types, and all entry points accept inputs and produce outputs of defined types. The notion of transparent and opaque types is the same as in the C API: primitives and arrays of primitives are directly supported, and everything else is treated as opaque. See also :ref:`valuemapping`. When printed, types follow basic Futhark type syntax *without* sizes (e.g. ``[][]i32``). Uniqueness is not part of the types, but is indicated with an asterisk in the ``inputs`` and ``outputs`` commands (see below). Consumption and aliasing ------------------------ Since the server protocol closely models the C API, the same rules apply to entry points that consume their arguments (see :ref:`api-consumption`). In particular, consumed variables must still be freed with the ``free`` command - but this is the only operation that may be used on consumed variables. Commands -------- The following commands are supported. General Commands ~~~~~~~~~~~~~~~~ ``types`` ......... Print the names of available types, one per line. ``entry_points`` ................ 
Print the names of available entry points. ``call`` *entry* *o1* ... *oN* *i1* ... *iM* ............................................ Call the given entry point with input from the variables *i1* to *iM*. The results are stored in *o1* to *oN*, which must not already exist. ``restore`` *file* *v1* *t1* ... *vN* *tN* .......................................... Load *N* values from *file* and store them in the variables *v1* to *vN* of types *t1* to *tN*, which must not already exist. ``store`` *file* *v1* ... *vN* .............................. Store the *N* values in variables *v1* to *vN* in *file*. ``free`` *v1* ... *vN* ...................... Delete the given variables. ``rename`` *oldname* *newname* .............................. Rename the variable *oldname* to *newname*, which must not already exist. ``inputs`` *entry* .................. Print the types of inputs accepted by the given entry point, one per line. If the given input is consumed, the type is prefixed by ``*``. ``outputs`` *entry* ................... Print the types of outputs produced by the given entry point, one per line. If the given output is guaranteed to be unique (does not alias any inputs), the type is prefixed by ``*``. ``clear`` ......... Clear all internal caches and counters maintained by the Futhark context. Corresponds to :c:func:`futhark_context_clear_caches`. ``pause_profiling`` ................... Corresponds to :c:func:`futhark_context_pause_profiling`. ``unpause_profiling`` ..................... Corresponds to :c:func:`futhark_context_unpause_profiling`. ``report`` .......... Corresponds to :c:func:`futhark_context_report`. ``set_tuning_param`` *param* *value* .................................... Corresponds to :c:func:`futhark_context_config_set_tuning_param`. ``tuning_params`` *entry* ......................... For each tuning parameter relevant to the given entry point, print its name, then a space, then its class. 
This is similar to :c:func:`futhark_tuning_params_for_sum`, but note that this command prints *names* and not *integers*. ``tuning_param_class`` *param* .............................. Corresponds to :c:func:`futhark_get_tuning_param_class`. Record Commands ~~~~~~~~~~~~~~~ ``fields`` *type* ................. If the given type is a record, print a line for each field of the record. The line will contain the name of the field, followed by a space, followed by the type of the field. Note that the type name can contain spaces. The order of fields is significant, as it is the one expected by the ``new`` command. ``new`` *v0* *type* *v1* ... *vN* ................................. Create a new variable *v0* of type *type*, which must be a record type with *N* fields, where *v1* to *vN* are variables with the corresponding field types (the expected order is given by the ``fields`` command). ``project`` *to* *from* *field* ............................... Create a new variable *to* whose value is the field *field* of the record-typed variable *from*. Environment Variables --------------------- ``FUTHARK_COMPILER_DEBUGGING`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Turns on debugging output for the server when set to 1. futhark-0.25.27/docs/usage.rst000066400000000000000000000546661475065116200161510ustar00rootroot00000000000000.. _usage: Basic Usage =========== Futhark contains several code generation backends. Each is provided as a subcommand of the ``futhark`` binary. For example, ``futhark c`` compiles a Futhark program by translating it to sequential C code, while ``futhark pyopencl`` generates Python code with calls to the PyOpenCL library. The different compilers all contain the same frontend and optimisation pipeline - only the code generator is different. They all provide roughly the same command line interface, but there may be minor differences and quirks due to characteristics of the specific backends. 
There are three main ways of compiling a Futhark program: to an ordinary executable (by using ``--executable``, which is the default), to a *server executable* (``--server``), and to a library (``--library``). Plain executables can be run immediately, but are useful mostly for testing and benchmarking. Server executables are discussed in :ref:`server-protocol`. Libraries can be called from non-Futhark code. .. _executable: Compiling to Executable ----------------------- A Futhark program is stored in a file with the extension ``.fut``. It can be compiled to an executable program as follows:: $ futhark c prog.fut This makes use of the ``futhark c`` compiler, but any other will work as well. The compiler will automatically invoke ``cc`` to produce an executable binary called ``prog``. If we had used ``futhark python`` instead of ``futhark c``, the ``prog`` file would instead have contained Python code, along with a `shebang`_ for easy execution. In general, when compiling file ``foo.fut``, the result will be written to a file ``foo`` (i.e. the extension will be stripped off). This can be overridden using the ``-o`` option. For more details on specific compilers, see their individual manual pages. .. _shebang: https://en.wikipedia.org/wiki/Shebang_%28Unix%29 Executables generated by the various Futhark compilers share a common command-line interface, but may also individually support more options. When a Futhark program is run, execution starts at one of its *entry points*. By default, the entry point named ``main`` is run. An alternative entry point can be indicated by using the ``-e`` option. All entry point functions must be declared appropriately in the program (see :ref:`entry-points`). If the entry point takes any parameters, these will be read from standard input in a subset of the Futhark syntax. A binary input format is also supported; see :ref:`binary-data-format`. The result of the entry point is printed to standard output. 
Only a subset of all Futhark values can be passed to an executable. Specifically, only primitives and arrays of primitive types are supported. In particular, nested tuples and arrays of tuples are not permitted. Non-nested tuples are supported, and are treated simply as flat values. This restriction is not present for Futhark programs compiled to libraries. If an entry point *returns* any such value, its printed representation is unspecified. As a special case, an entry point is allowed to return a flat tuple. Instead of compiling, there is also an interpreter, accessible as ``futhark run`` and ``futhark repl``. The latter is an interactive prompt, useful for experimenting with Futhark expressions. Be aware that the interpreter runs code very slowly. .. _executable-options: Executable Options ^^^^^^^^^^^^^^^^^^ All generated executables support the following options. ``-h/--help`` Print help text to standard output and exit. ``-D/--debugging`` Print debugging information on standard error. Exactly what is printed, and how it looks, depends on which Futhark compiler is used. This option may also enable more conservative (and slower) execution, such as frequently synchronising to check for errors. This implies ``--log``. ``-L/--log`` Print low-overhead logging information during initialisation and during execution of entry points. Enabling this option should not affect program performance. ``--cache-file FILE`` Create (if necessary) and use data in the provided cache file to speed up subsequent launches of the same program. The cache file is automatically updated by the running program as necessary. It is safe to delete at any time, and will be recreated as necessary. ``--print-params`` Print a list of tuning parameters followed by their *parameter class* in parentheses, which indicates what they are used for. ``--param SIZE=VALUE`` Set one of the tunable sizes to the given value. Using the ``--tuning`` option is more convenient. 
``--tuning FILE`` Load tuning options from the indicated *tuning file*. The file must contain lines of the form ``SIZE=VALUE``, where each *SIZE* must be one of the sizes listed by the ``--print-params`` option (without size class), and the *VALUE* must be a non-negative integer. Extraneous spaces or blank lines are not allowed. A zero means to use the default size, whatever it may be. In case of duplicate assignments to the same size, the last one takes precedence. This is equivalent to passing each size setting on the command line using the ``--param`` option, but more convenient. Non-Server Executable Options ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The following options are only supported on non-server executables, because they make no sense in a server context. ``-t/--write-runtime-to FILE`` Print the time taken to execute the program to the indicated file, an integral number of microseconds. The time taken to perform setup or teardown, including reading the input or writing the result, is not included in the measurement. See the documentation for specific compilers to see exactly what is measured. ``-r/--runs RUNS`` Run the specified entry point the given number of times (plus a warmup run). The program result is only printed once, after the last run. If combined with ``-t``, one measurement is printed per run. This is a good way to perform benchmarking. ``-b/--binary-output`` Print the result using the binary data format (:ref:`binary-data-format`). For large outputs, this is significantly faster and takes up less space. ``-n/--no-print-result`` Do not print the result of running the program. GPU Options ~~~~~~~~~~~ The following options are supported by executables generated with the GPU backends (``opencl``, ``pyopencl``, ``hip``, and ``cuda``). ``-d/--device DEVICE`` Pick the first device whose name contains the given string. The special string ``#k``, where ``k`` is an integer, can be used to pick the *k*-th device, numbered from zero. 
``--default-thread-block-size INT`` The default size of GPU thread blocks that are launched. Capped to the hardware limit if necessary. ``--default-num-thread-blocks INT`` The default number of GPU thread blocks that are launched. ``-P/--profile`` Measure the time taken by various GPU operations (such as kernels) and print a summary at the end. Unfortunately, it is currently nontrivial (and manual) to relate these operations back to source Futhark code. ``--unified-memory INT`` Corresponds to :c:func:`futhark_context_config_set_unified_memory`. OpenCL-specific Options ~~~~~~~~~~~~~~~~~~~~~~~ The following options are supported by executables generated with the OpenCL backends (``opencl``, ``pyopencl``): ``-p/--platform PLATFORM`` Pick the first OpenCL platform whose name contains the given string. The special string ``#k``, where ``k`` is an integer, can be used to pick the *k*-th platform, numbered from zero. If used in conjunction with ``-d``, only the devices from matching platforms are considered. ``--default-group-size INT`` The default size of OpenCL workgroups that are launched. Capped to the hardware limit if necessary. ``--default-num-groups INT`` The default number of OpenCL workgroups that are launched. ``--dump-opencl FILE`` Don't run the program, but instead dump the embedded OpenCL program to the indicated file. Useful if you want to see what is actually being executed. ``--load-opencl FILE`` Instead of using the embedded OpenCL program, load it from the indicated file. This is extremely unlikely to result in successful execution unless this file is the result of a previous call to ``--dump-opencl`` (perhaps lightly modified). ``--dump-opencl-binary FILE`` Don't run the program, but instead dump the compiled version of the embedded OpenCL program to the indicated file. On NVIDIA platforms, this will be PTX code. If this option is set, no entry point will be run. ``--load-opencl-binary FILE`` Load an OpenCL binary from the indicated file. 
``--build-option OPT`` Add an additional build option to the string passed to ``clBuildProgram()``. Refer to the OpenCL documentation for which options are supported. Be careful - some options can easily result in invalid results. ``--list-devices`` List all OpenCL devices and platforms available on the system. There is rarely a need to use both ``-p`` and ``-d``. For example, to run on the first available NVIDIA GPU, ``-p NVIDIA`` is sufficient, as there is likely only a single device associated with this platform. On \*nix (including macOS), the `clinfo `_ tool (available in many package managers) can be used to determine which OpenCL platforms and devices are available on a given system. CUDA-specific Options ~~~~~~~~~~~~~~~~~~~~~ The following options are supported by executables generated by the ``cuda`` backend: ``--dump-cuda FILE`` Don't run the program, but instead dump the embedded CUDA program to the indicated file. Useful if you want to see what is actually being executed. ``--load-cuda FILE`` Instead of using the embedded CUDA program, load it from the indicated file. This is extremely unlikely to result in successful execution unless this file is the result of a previous call to ``--dump-cuda`` (perhaps lightly modified). ``--dump-ptx FILE`` As ``--dump-cuda``, but dumps the compiled PTX code instead. ``--load-ptx FILE`` Instead of using the embedded CUDA program, load compiled PTX code from the indicated file. ``--nvrtc-option OPT`` Add the given option to the command line used to compile CUDA kernels with NVRTC. The list of supported options varies with the CUDA version but can be `found in the NVRTC documentation `_. For convenience, CUDA executables also accept the same ``--default-num-groups`` and ``--default-group-size`` options that the OpenCL backend uses. These then refer to grid size and thread block size, respectively. 
Multicore options ~~~~~~~~~~~~~~~~~ The following options are supported by executables generated by the ``multicore`` backend: ``--num-threads INT`` The number of threads used to run parallel operations. If set to a value less than ``1``, then the runtime system will use one thread per detected core. ``-P/--profile`` Measure the time taken by various parallel sections and print a summary at the end. Unfortunately, it is currently nontrivial (and manual) to relate these operations back to source Futhark code. Compiling to Library -------------------- While compiling a Futhark program to an executable is useful for testing, it is not suitable for production use. Instead, a Futhark program should be compiled into a reusable library in some target language, enabling integration into a larger program. General Concerns ^^^^^^^^^^^^^^^^ Futhark entry points are mapped to some form of function or method in the target language. Generally, an entry point taking *n* parameters will result in a function taking *n* parameters. If the entry point returns an *m*-element tuple, then the function will return *m* values (although the tuple can be replaced with a single opaque value, see below). Extra parameters may be added to pass in context data, or *out*-parameters for writing the result, for target languages that do not support multiple return values from functions. The entry point should have a name that is also a valid identifier in the target language (usually C). Not all Futhark types can be mapped cleanly to the target language. Arrays of tuples, for example, are a common issue. In such cases, *opaque types* are used in the generated code. Values of these types cannot be directly inspected, but can be passed back to Futhark entry points. In the general case, these types will be named with a random hash. However, if you insert an explicit type annotation (and the type name contains only characters valid for identifiers for the used backend), the indicated name will be used. 
Note that array types contain brackets, which are usually not valid in identifiers. Defining and using a type abbreviation is the best way around this. .. _valuemapping: Value Mapping ~~~~~~~~~~~~~ The rules for how Futhark values are mapped to target language values are as follows: * Primitive types or arrays of primitive types are mapped transparently (although for the C backends, this still involves a distinct type for arrays). * All other types are mapped to an opaque type. Use a type ascription with a type abbreviation to give it a specific name, otherwise one will be generated. Return types follow these rules, with one addition: * If the return type is an *m*-element tuple, then the function returns *m* values, mapped according to the rules above (but not including this one - nested tuples are not mapped directly). This rule does not apply when the entry point has been given a return type ascription that is not syntactically a tuple type. .. _api-consumption: Consumption and Aliasing ~~~~~~~~~~~~~~~~~~~~~~~~ Futhark's support for :ref:`in-place-updates` has implications for the generated API. Unfortunately, the type system of most languages (e.g. C) is not rich enough to express the rules, so they are not statically (or currently even dynamically) checked. Since Futhark will never infer a unique/consuming type for an entry point parameter, this section can be ignored unless uniqueness annotations have been manually added to the entry point's parameter types. The rules are essentially the same as in the language itself: 1. Each entry point input parameter is either *consuming* or *nonconsuming* (the default). This corresponds to unique and nonunique types in the original Futhark program. A value passed for a consuming parameter is considered *consumed*, now has an unspecified value, and may never be used again. It must still be manually freed, if applicable. Further, any *aliases* of that value are also considered consumed and may not be used. 2. 
Each entry point output is either *unique* or *nonunique*. A unique output has no aliases. A nonunique output aliases *every* nonconsuming input parameter. Note that these distinctions are currently usually not visible in the generated API, and so correct usage requires knowledge of the original types in the Futhark function. The safest strategy is to not expose unique types in entry points. Generating C ^^^^^^^^^^^^ A Futhark program ``futlib.fut`` can be compiled to reusable C code using either:: $ futhark c --library futlib.fut Or:: $ futhark opencl --library futlib.fut This produces three files in the current directory: ``futlib.c``, ``futlib.h``, and ``futlib.json`` (see :ref:`manifest` for more on the latter). If we wish (and are on a Unix system), we can then compile ``futlib.c`` to an object file like this:: $ gcc futlib.c -c This produces a file ``futlib.o`` that can then be linked with the main application. Details of how to link the generated code with other C code are highly system-dependent, and outside the scope of this manual. On Unix, we can simply add ``futlib.o`` to the final compiler or linker command line:: $ gcc main.c -o main futlib.o Depending on the Futhark backend you are using, you may need to add some linker flags. For example, ``futhark opencl`` requires ``-lOpenCL`` (``-framework OpenCL`` on macOS). See the manual page for each compiler for details. It is also possible to simply add the generated ``.c`` file to the C compiler command line used for compiling our whole program (here ``main.c``):: $ gcc main.c -o main futlib.c The downside of this approach is that the generated ``.c`` file may contain code that causes the C compiler to warn (for example, unused support code that is not needed by the Futhark program). The generated header file (here, ``futlib.h``) specifies the API, and is intended to be human-readable. See :ref:`c-api` for more information. 
The basic usage revolves around creating a *configuration object*, which can then be used to obtain a *context object*, which must be passed whenever entry points are called. The configuration object is created using the following function:: struct futhark_context_config *futhark_context_config_new(); Depending on the backend, various functions are generated to modify the configuration. The following is always available:: void futhark_context_config_set_debugging(struct futhark_context_config *cfg, int flag); A configuration object can be used to create a context with the following function:: struct futhark_context *futhark_context_new(struct futhark_context_config *cfg); Context creation may fail. Immediately after ``futhark_context_new()``, call ``futhark_context_get_error()`` (see below), which will return a non-NULL error string if context creation failed. The API functions are all thread safe. Memory management is entirely manual. Deallocation functions are provided for all types defined in the header file. Everything returned by an entry point must be manually deallocated. For now, many internal errors, such as failure to allocate memory, will cause the function to ``abort()`` rather than return an error code. However, all application errors (such as bounds and array size checks) will produce an error code. C with OpenCL ~~~~~~~~~~~~~ When generating C code with ``futhark opencl``, you will need to link against the OpenCL library when linking the final binary:: $ gcc main.c -o main futlib.o -lOpenCL When using the OpenCL backend, extra API functions are provided for directly accessing or providing the OpenCL objects used by Futhark. Take care when using these functions. 
In particular, a Futhark context can now be configured with the command queue to use:: void futhark_context_config_set_command_queue(struct futhark_context_config *cfg, cl_command_queue queue); As a ``cl_command_queue`` specifies an OpenCL device, this is also how manual platform and device selection is possible. A function is also provided for retrieving the command queue used by some Futhark context:: cl_command_queue futhark_context_get_command_queue(struct futhark_context *ctx); This can be used to connect two separate Futhark contexts that have been loaded dynamically. The raw ``cl_mem`` object underlying a Futhark array can be accessed with the function named ``futhark_values_raw_type``, where ``type`` depends on the array in question. For example:: cl_mem futhark_values_raw_i32_1d(struct futhark_context *ctx, struct futhark_i32_1d *arr); The array will be stored in row-major form in the returned memory object. The function performs no copying, so the ``cl_mem`` still belongs to Futhark, and may be reused for other purposes when the corresponding array is freed. A dual function can be used to construct a Futhark array from a ``cl_mem``:: struct futhark_i32_1d *futhark_new_raw_i32_1d(struct futhark_context *ctx, cl_mem data, int offset, int dim0); This function *does* copy the provided memory into fresh internally allocated memory. The array is assumed to be stored in row-major form ``offset`` bytes into the memory region. See also :ref:`futhark-opencl(1)`. Generating Python ^^^^^^^^^^^^^^^^^ The ``futhark python`` and ``futhark pyopencl`` compilers both support generating reusable Python code, although only the latter generates code of sufficient performance to be worthwhile. The following describes options and parameters only available for ``futhark pyopencl``. You will need at least PyOpenCL version 2015.2.
We can use ``futhark pyopencl`` to translate the program ``futlib.fut`` into a Python module ``futlib.py`` with the following command:: $ futhark pyopencl --library futlib.fut This will create a file ``futlib.py``, which contains Python code that defines a class named ``futlib``. This class defines one method for each entry point function (see :ref:`entry-points`) in the Futhark program. The methods take one parameter for each parameter in the corresponding entry point, and return a tuple containing a value for every value returned by the entry point. For entry points returning a single (non-tuple) value, just that value is returned (that is, single-element tuples are not returned). After the class has been instantiated, these methods can be invoked to run the corresponding Futhark function. The constructor for the class takes various keyword parameters: ``interactive=BOOL`` If ``True`` (the default is ``False``), show a menu of available OpenCL platforms and devices, and use the one chosen by the user. ``platform_pref=STR`` Use the first platform that contains the given string. Similar to the ``-p`` option for executables. ``device_pref=STR`` Use the first device that contains the given string. Similar to the ``-d`` option for executables. Futhark arrays are mapped to either the Numpy ``ndarray`` type or the `pyopencl.array <https://documen.tician.de/pyopencl/array.html>`_ type. Scalars are mapped to Numpy scalar types. Reproducibility --------------- The Futhark compiler is deterministic by design, meaning that repeatedly compiling the *same program* with the *same compilation flags* and using the *same version* of the compiler will produce identical output every time. Note that this only applies to the code generated by the Futhark compiler itself. When compiling to an executable with one of the C backends (see :ref:`executable`), Futhark will invoke a C compiler that may not be perfectly reproducible. In such cases the generated ``.c`` and ``.h`` files will be reproducible, but the final executable may not be.
futhark-0.25.27/docs/versus-other-languages.rst000066400000000000000000000200071475065116200214350ustar00rootroot00000000000000.. _versus-other-languages: Futhark Compared to Other Functional Languages ============================================== This guide is intended for programmers who are familiar with other functional languages and want to start working with Futhark. Futhark is a simple language with a complex compiler. Functional programming is fundamentally well suited to data parallelism, so Futhark's syntax and underlying concepts are taken directly from established functional languages such as Haskell and the ML family. While Futhark does add a few small conveniences (built-in array types) and one complicated and unusual feature (in-place updates via uniqueness types, see :ref:`in-place-updates`), a programmer familiar with a common functional language should be able to understand the meaning of a Futhark program and quickly begin writing their own programs. To speed up this process, we describe here some of the various quirks and unexpected limitations imposed by Futhark. We also recommended reading some of the `example programs`_ along with this guide. The guide does *not* cover all Futhark features worth knowing, so do also skim :ref:`language-reference` and the :ref:`glossary`. .. _`example programs`: https://futhark-lang.org/examples.html Basic Syntax ------------ Futhark uses a keyword-based structure, with optional indentation *solely* for human readability. This aspect differs from Haskell and F#. Names are lexically divided into *identifiers* and *symbols*: * *Identifiers* begin with a letter or underscore and contain letters, numbers, underscores, and apostrophes. * *Symbols* contain the characters found in the default operators (``+-*/%=!><|&^``). All function and variable names must be identifiers, and all infix operators are symbols. An identifier can be used as an infix operator by enclosing it in backticks, as in Haskell. 
Identifiers are case-sensitive, and there is no restriction on the case of the first letter (unlike Haskell and OCaml, but like Standard ML and Flix). User-defined operators are possible, but the fixity of the operator depends on its name. Specifically, the fixity of a user-defined operator *op* is equal to the fixity of the built-in operator that is the longest prefix of *op*. For example, ``<<=`` would have the same fixity as ``<<``, and ``=<<`` the same as ``=``. This rule is the same as the rule found in OCaml and F#. Top-level functions and values are defined with ``def`` as in Flix. Local variables are bound with ``let``. Evaluation ---------- Futhark is a completely pure language, with no cheating through monads, effect systems, or anything of the sort. Evaluation is *eager* or *call-by-value*, like most non-Haskell languages. However, there is no defined evaluation order. Furthermore, the Futhark compiler is permitted to turn non-terminating programs into terminating programs, for example by removing dead code that might cause an error. Moreover, there is no way to handle errors within a Futhark program (no exceptions or similar), although errors are gracefully reported to whatever invokes the Futhark program. The evaluation semantics are entirely sequential, with parallelism being solely an operational detail. Hence, race conditions are impossible. The Futhark compiler does not automatically go looking for parallelism. Only certain special constructs and built-in library functions (such as ``map``, ``reduce``, ``scan``, and ``filter``) may be executed in parallel. Currying and partial application work as usual (although functions are not fully first class; see `Types`_). Although the ``assert`` construct looks like a function, it is not, and it cannot be partially applied. Lambda terms are written as ``\x -> x + 2``, as in Haskell. A Futhark program is read top-down, and all functions must be declared before they are used, like Standard ML.
Unlike just about all functional languages, recursive functions are *not* supported. Most of the time, you will use bulk array operations instead, but there is also a dedicated ``loop`` language construct, which is essentially syntactic sugar for tail-recursive functions. Types ----- Futhark supports a range of integer types, floating point types, and booleans (see :ref:`primitives`). A numeric literal can be suffixed with its desired type, such as ``1i8`` for an eight-bit signed integer. Unadorned numerals have their type inferred based on use. This only works for built-in numeric types. Arrays are a built-in type. The type of an array containing elements of type ``t`` is written ``[]t`` (not ``[t]`` as in Haskell), and we may optionally annotate it with a size as ``[n]t`` (see `Shape Declarations`). Array values are written as ``[1,2,3]``. Array indexing is written ``a[i]`` with *no* space allowed between the array name and the opening bracket. Indexing of multi-dimensional arrays is written ``a[i,j]``. Arrays are 0-indexed. All types can be combined in tuples as usual, as well as in *structurally typed records*, as in Standard ML and Flix. Non-recursive sum types are supported, and are also structurally typed. Abstract types are possible via the module system; see :ref:`module-system`. If a variable ``foo`` is a record of type ``{a: i32, b: bool}``, then we access field ``a`` with dot notation: ``foo.a``. Tuples are a special case of records, where all the fields have a 0-indexed numeric label. For example, ``(i32, bool)`` is the same as ``{0: i32, 1: bool}``, and can be indexed as ``foo.1``. Sum types are defined as constructors separated by a vertical bar (``|``). Constructor names always start with a ``#``. For example, ``#red | #blue i32`` is a sum type with the constructors ``#red`` and ``#blue``, where the latter has an ``i32`` as payload. The terms ``#red`` and ``#blue 2`` produce values of this type. Constructor applications must always be fully saturated.
Due to the structural type system, type annotations are sometimes necessary to resolve ambiguities. For example, the term ``#blue 2`` can produce a value of *any type* that has an appropriate constructor. Function types are written with the usual ``a -> b`` notation, and functions can be passed as arguments to other functions. However, there are some restrictions: * A function cannot be put in an array (but a record or tuple is fine). * A function cannot be returned from a branch. * A function cannot be used as a ``loop`` parameter. Function types interact with type parameters in a subtle way:: def id 't (x: t) = x This declaration defines a function ``id`` that has a type parameter ``t``. Here, ``t`` is an *unlifted* type parameter, which is guaranteed never to be a function type, and so in the body of the function we could choose to put parameter values of type ``t`` in an array. However, it means that this identity function cannot be called on a functional value. Instead, we probably want a *lifted* type parameter:: def id '^t (x: t) = x Such *lifted* type parameters are not restricted from being instantiated with function types. On the other hand, in the function definition they are subject to the same restrictions as functional types. Futhark supports Hindley-Milner type inference (with some restrictions), so we could also just write it as:: def id x = x Type abbreviations are possible:: type foo = (i32, i32) Type parameters are supported as well:: type pair 'a 'b = (a, b) As with everything else, they are structurally typed, so the types ``pair i32 bool`` and ``(i32, bool)`` are entirely interchangeable. Most unusually, this is also the case for sum types. The following two types are entirely interchangeable:: type maybe 'a = #just a | #nothing type option 'a = #nothing | #just a Only for abstract types, where the definition has been hidden via the module system, do type names have any significance. 
Size parameters can also be passed:: type vector [n] t = [n]t type i32matrix [n][m] = [n] (vector [m] i32) Note that for an actual array type, the dimensions come *before* the element type, but with a type abbreviation, a size is just another parameter. This easily becomes hard to read if you are not careful. futhark-0.25.27/futhark-benchmarks/000077500000000000000000000000001475065116200171215ustar00rootroot00000000000000futhark-0.25.27/futhark.cabal000066400000000000000000000422171475065116200160000ustar00rootroot00000000000000cabal-version: 2.4 name: futhark version: 0.25.27 synopsis: An optimising compiler for a functional, array-oriented language. description: Futhark is a small programming language designed to be compiled to efficient parallel code. It is a statically typed, data-parallel, and purely functional array language in the ML family, and comes with a heavily optimising ahead-of-time compiler that presently generates GPU code via CUDA and OpenCL, although the language itself is hardware-agnostic. . For more information, see the website at https://futhark-lang.org . For introductionary information about hacking on the Futhark compiler, see . Regarding the internal design of the compiler, the following modules make good starting points: . * "Futhark" contains a basic architectural overview of the compiler. * "Futhark.IR.Syntax" explains the basic design of the intermediate representation (IR). * "Futhark.Construct" explains how to write code that manipulates and creates AST fragments. . <> category: Futhark homepage: https://futhark-lang.org bug-reports: https://github.com/diku-dk/futhark/issues maintainer: Troels Henriksen athas@sigkill.dk license: ISC license-file: LICENSE build-type: Simple extra-source-files: -- Cabal's recompilation tracking doesn't work when we use wildcards -- here, so for now we spell out every single file. 
rts/c/atomics.h rts/c/context.h rts/c/context_prototypes.h rts/c/backends/c.h rts/c/backends/cuda.h rts/c/backends/hip.h rts/c/backends/multicore.h rts/c/backends/opencl.h rts/c/lock.h rts/c/copy.h rts/c/timing.h rts/c/errors.h rts/c/free_list.h rts/c/event_list.h rts/c/gpu.h rts/c/gpu_prototypes.h rts/c/tuning.h rts/c/values.h rts/c/half.h rts/c/cache.h rts/c/ispc_util.h rts/c/scalar.h rts/c/scalar_f16.h rts/c/scheduler.h rts/c/uniform.h rts/c/util.h rts/c/server.h rts/cuda/prelude.cu rts/futhark-doc/style.css rts/javascript/server.js rts/javascript/values.js rts/javascript/wrapperclasses.js rts/opencl/copy.cl rts/opencl/prelude.cl rts/opencl/transpose.cl rts/python/tuning.py rts/python/panic.py rts/python/memory.py rts/python/server.py rts/python/values.py rts/python/opencl.py rts/python/scalar.py prelude/functional.fut prelude/math.fut prelude/soacs.fut prelude/zip.fut prelude/ad.fut prelude/array.fut prelude/prelude.fut -- Just enough of the docs to build the manpages. docs/**/*.rst docs/Makefile docs/conf.py docs/requirements.txt extra-doc-files: assets/*.png CHANGELOG.md README.md source-repository head type: git location: https://github.com/diku-dk/futhark common common ghc-options: -Wall -Wcompat -Wno-incomplete-uni-patterns -Wno-x-partial -Wno-unrecognised-warning-flags -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists -Wunused-packages default-language: GHC2021 default-extensions: OverloadedStrings library import: common hs-source-dirs: src exposed-modules: Futhark Futhark.Actions Futhark.AD.Derivatives Futhark.AD.Fwd Futhark.AD.Rev Futhark.AD.Rev.Loop Futhark.AD.Rev.Hist Futhark.AD.Rev.Map Futhark.AD.Rev.Monad Futhark.AD.Rev.Reduce Futhark.AD.Rev.Scan Futhark.AD.Rev.Scatter Futhark.AD.Rev.SOAC Futhark.Analysis.AccessPattern Futhark.Analysis.AlgSimplify Futhark.Analysis.Alias Futhark.Analysis.CallGraph Futhark.Analysis.DataDependencies Futhark.Analysis.HORep.MapNest Futhark.Analysis.HORep.SOAC Futhark.Analysis.Interference 
Futhark.Analysis.LastUse Futhark.Analysis.MemAlias Futhark.Analysis.Metrics Futhark.Analysis.Metrics.Type Futhark.Analysis.PrimExp Futhark.Analysis.PrimExp.Convert Futhark.Analysis.PrimExp.Parse Futhark.Analysis.PrimExp.Simplify Futhark.Analysis.PrimExp.Table Futhark.Analysis.SymbolTable Futhark.Analysis.UsageTable Futhark.Bench Futhark.Builder Futhark.Builder.Class Futhark.CLI.Autotune Futhark.CLI.Bench Futhark.CLI.C Futhark.CLI.CUDA Futhark.CLI.Check Futhark.CLI.Benchcmp Futhark.CLI.Datacmp Futhark.CLI.Dataset Futhark.CLI.Defs Futhark.CLI.Dev Futhark.CLI.Doc Futhark.CLI.Eval Futhark.CLI.Fmt Futhark.CLI.HIP Futhark.CLI.Literate Futhark.CLI.LSP Futhark.CLI.Main Futhark.CLI.Misc Futhark.CLI.Multicore Futhark.CLI.MulticoreISPC Futhark.CLI.MulticoreWASM Futhark.CLI.OpenCL Futhark.CLI.Pkg Futhark.CLI.Profile Futhark.CLI.PyOpenCL Futhark.CLI.Python Futhark.CLI.Query Futhark.CLI.REPL Futhark.CLI.Run Futhark.CLI.Script Futhark.CLI.Test Futhark.CLI.WASM Futhark.CodeGen.Backends.CCUDA Futhark.CodeGen.Backends.COpenCL Futhark.CodeGen.Backends.HIP Futhark.CodeGen.Backends.GenericC Futhark.CodeGen.Backends.GenericC.CLI Futhark.CodeGen.Backends.GenericC.Code Futhark.CodeGen.Backends.GenericC.EntryPoints Futhark.CodeGen.Backends.GenericC.Fun Futhark.CodeGen.Backends.GenericC.Monad Futhark.CodeGen.Backends.GenericC.Options Futhark.CodeGen.Backends.GenericC.Pretty Futhark.CodeGen.Backends.GenericC.Server Futhark.CodeGen.Backends.GenericC.Types Futhark.CodeGen.Backends.GenericPython Futhark.CodeGen.Backends.GenericPython.AST Futhark.CodeGen.Backends.GenericPython.Options Futhark.CodeGen.Backends.GenericWASM Futhark.CodeGen.Backends.GPU Futhark.CodeGen.Backends.MulticoreC Futhark.CodeGen.Backends.MulticoreC.Boilerplate Futhark.CodeGen.Backends.MulticoreISPC Futhark.CodeGen.Backends.MulticoreWASM Futhark.CodeGen.Backends.PyOpenCL Futhark.CodeGen.Backends.PyOpenCL.Boilerplate Futhark.CodeGen.Backends.SequentialC Futhark.CodeGen.Backends.SequentialC.Boilerplate 
Futhark.CodeGen.Backends.SequentialPython Futhark.CodeGen.Backends.SequentialWASM Futhark.CodeGen.Backends.SimpleRep Futhark.CodeGen.RTS.C Futhark.CodeGen.RTS.CUDA Futhark.CodeGen.RTS.OpenCL Futhark.CodeGen.RTS.Python Futhark.CodeGen.RTS.JavaScript Futhark.CodeGen.ImpCode Futhark.CodeGen.ImpCode.GPU Futhark.CodeGen.ImpCode.Multicore Futhark.CodeGen.ImpCode.OpenCL Futhark.CodeGen.ImpCode.Sequential Futhark.CodeGen.ImpGen Futhark.CodeGen.ImpGen.CUDA Futhark.CodeGen.ImpGen.GPU Futhark.CodeGen.ImpGen.GPU.Base Futhark.CodeGen.ImpGen.GPU.Block Futhark.CodeGen.ImpGen.GPU.SegHist Futhark.CodeGen.ImpGen.GPU.SegMap Futhark.CodeGen.ImpGen.GPU.SegRed Futhark.CodeGen.ImpGen.GPU.SegScan Futhark.CodeGen.ImpGen.GPU.SegScan.SinglePass Futhark.CodeGen.ImpGen.GPU.SegScan.TwoPass Futhark.CodeGen.ImpGen.GPU.ToOpenCL Futhark.CodeGen.ImpGen.HIP Futhark.CodeGen.ImpGen.Multicore Futhark.CodeGen.ImpGen.Multicore.Base Futhark.CodeGen.ImpGen.Multicore.SegHist Futhark.CodeGen.ImpGen.Multicore.SegMap Futhark.CodeGen.ImpGen.Multicore.SegRed Futhark.CodeGen.ImpGen.Multicore.SegScan Futhark.CodeGen.ImpGen.OpenCL Futhark.CodeGen.ImpGen.Sequential Futhark.CodeGen.OpenCL.Heuristics Futhark.Compiler Futhark.Compiler.CLI Futhark.Compiler.Config Futhark.Compiler.Program Futhark.Construct Futhark.Doc.Generator Futhark.Error Futhark.Fmt.Printer Futhark.Fmt.Monad Futhark.FreshNames Futhark.Format Futhark.IR Futhark.IR.Aliases Futhark.IR.GPU Futhark.IR.GPU.Op Futhark.IR.GPU.Simplify Futhark.IR.GPU.Sizes Futhark.IR.GPUMem Futhark.IR.MC Futhark.IR.MC.Op Futhark.IR.MCMem Futhark.IR.Mem Futhark.IR.Mem.Interval Futhark.IR.Mem.LMAD Futhark.IR.Mem.Simplify Futhark.IR.Parse Futhark.IR.Pretty Futhark.IR.Prop Futhark.IR.Prop.Aliases Futhark.IR.Prop.Constants Futhark.IR.Prop.Names Futhark.IR.Prop.Pat Futhark.IR.Prop.Rearrange Futhark.IR.Prop.Reshape Futhark.IR.Prop.Scope Futhark.IR.Prop.TypeOf Futhark.IR.Prop.Types Futhark.IR.Rep Futhark.IR.Rephrase Futhark.IR.RetType Futhark.IR.SOACS Futhark.IR.SOACS.SOAC 
Futhark.IR.SOACS.Simplify Futhark.IR.SegOp Futhark.IR.Seq Futhark.IR.SeqMem Futhark.IR.Syntax Futhark.IR.Syntax.Core Futhark.IR.Traversals Futhark.IR.TypeCheck Futhark.Internalise Futhark.Internalise.AccurateSizes Futhark.Internalise.ApplyTypeAbbrs Futhark.Internalise.Bindings Futhark.Internalise.Defunctionalise Futhark.Internalise.Defunctorise Futhark.Internalise.Entry Futhark.Internalise.Exps Futhark.Internalise.FullNormalise Futhark.Internalise.Lambdas Futhark.Internalise.LiftLambdas Futhark.Internalise.Monad Futhark.Internalise.Monomorphise Futhark.Internalise.ReplaceRecords Futhark.Internalise.TypesValues Futhark.LSP.Compile Futhark.LSP.Diagnostic Futhark.LSP.Handlers Futhark.LSP.Tool Futhark.LSP.State Futhark.LSP.PositionMapping Futhark.MonadFreshNames Futhark.Optimise.BlkRegTiling Futhark.Optimise.CSE Futhark.Optimise.DoubleBuffer Futhark.Optimise.EntryPointMem Futhark.Optimise.Fusion Futhark.Optimise.Fusion.Composing Futhark.Optimise.Fusion.GraphRep Futhark.Optimise.Fusion.RulesWithAccs Futhark.Optimise.Fusion.TryFusion Futhark.Optimise.GenRedOpt Futhark.Optimise.HistAccs Futhark.Optimise.InliningDeadFun Futhark.Optimise.MemoryBlockMerging Futhark.Optimise.MemoryBlockMerging.GreedyColoring Futhark.Optimise.ArrayShortCircuiting Futhark.Optimise.ArrayShortCircuiting.ArrayCoalescing Futhark.Optimise.ArrayShortCircuiting.DataStructs Futhark.Optimise.ArrayShortCircuiting.MemRefAggreg Futhark.Optimise.ArrayShortCircuiting.TopdownAnalysis Futhark.Optimise.MergeGPUBodies Futhark.Optimise.ArrayLayout Futhark.Optimise.ArrayLayout.Transform Futhark.Optimise.ArrayLayout.Layout Futhark.Optimise.ReduceDeviceSyncs Futhark.Optimise.ReduceDeviceSyncs.MigrationTable Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph Futhark.Optimise.Simplify Futhark.Optimise.Simplify.Engine Futhark.Optimise.Simplify.Rep Futhark.Optimise.Simplify.Rule Futhark.Optimise.Simplify.Rules Futhark.Optimise.Simplify.Rules.BasicOp Futhark.Optimise.Simplify.Rules.ClosedForm 
Futhark.Optimise.Simplify.Rules.Index Futhark.Optimise.Simplify.Rules.Loop Futhark.Optimise.Simplify.Rules.Match Futhark.Optimise.Simplify.Rules.Simple Futhark.Optimise.Sink Futhark.Optimise.TileLoops Futhark.Optimise.TileLoops.Shared Futhark.Optimise.Unstream Futhark.Pass Futhark.Pass.AD Futhark.Pass.ExpandAllocations Futhark.Pass.ExplicitAllocations Futhark.Pass.ExplicitAllocations.GPU Futhark.Pass.ExplicitAllocations.MC Futhark.Pass.ExplicitAllocations.SegOp Futhark.Pass.ExplicitAllocations.Seq Futhark.Pass.ExtractKernels Futhark.Pass.ExtractKernels.BlockedKernel Futhark.Pass.ExtractKernels.DistributeNests Futhark.Pass.ExtractKernels.Distribution Futhark.Pass.ExtractKernels.ISRWIM Futhark.Pass.ExtractKernels.Interchange Futhark.Pass.ExtractKernels.Intrablock Futhark.Pass.ExtractKernels.StreamKernel Futhark.Pass.ExtractKernels.ToGPU Futhark.Pass.ExtractMulticore Futhark.Pass.FirstOrderTransform Futhark.Pass.LiftAllocations Futhark.Pass.LowerAllocations Futhark.Pass.Simplify Futhark.Passes Futhark.Pipeline Futhark.Pkg.Info Futhark.Pkg.Solve Futhark.Pkg.Types Futhark.Profile Futhark.Script Futhark.Test Futhark.Test.Spec Futhark.Test.Values Futhark.Tools Futhark.Transform.CopyPropagate Futhark.Transform.FirstOrderTransform Futhark.Transform.Rename Futhark.Transform.Substitute Futhark.Util Futhark.Util.CMath Futhark.Util.IntegralExp Futhark.Util.Loc Futhark.Util.Log Futhark.Util.Options Futhark.Util.Pretty Futhark.Util.ProgressBar Futhark.Util.Table Futhark.Version Language.Futhark Language.Futhark.Core Language.Futhark.Interpreter Language.Futhark.Interpreter.AD Language.Futhark.Interpreter.Values Language.Futhark.FreeVars Language.Futhark.Parser Language.Futhark.Parser.Monad Language.Futhark.Parser.Lexer.Tokens Language.Futhark.Parser.Lexer.Wrapper Language.Futhark.Prelude Language.Futhark.Pretty Language.Futhark.Primitive Language.Futhark.Primitive.Parse Language.Futhark.Prop Language.Futhark.Query Language.Futhark.Semantic Language.Futhark.Syntax 
Language.Futhark.Traversals Language.Futhark.Tuple Language.Futhark.TypeChecker Language.Futhark.TypeChecker.Consumption Language.Futhark.TypeChecker.Names Language.Futhark.TypeChecker.Match Language.Futhark.TypeChecker.Modules Language.Futhark.TypeChecker.Monad Language.Futhark.TypeChecker.Terms Language.Futhark.TypeChecker.Terms.Loop Language.Futhark.TypeChecker.Terms.Monad Language.Futhark.TypeChecker.Terms.Pat Language.Futhark.TypeChecker.Types Language.Futhark.TypeChecker.Unify Language.Futhark.Warnings other-modules: Language.Futhark.Parser.Parser Language.Futhark.Parser.Lexer Paths_futhark autogen-modules: Paths_futhark build-tool-depends: alex:alex , happy:happy build-depends: aeson >=2.0.0.0 , ansi-terminal >=0.6.3.1 , array >=0.4 , async >=2.0 , base >=4.15 && <5 , base16-bytestring , binary >=0.8.3 , blaze-html >=0.9.0.1 , bytestring >=0.11.2 , bytestring-to-vector >=0.3.0.1 , bmp >=1.2.6.3 , co-log-core , containers >=0.6.2.1 , cryptohash-md5 , Diff >=0.4.1 , directory >=1.3.0.0 , directory-tree >=0.12.1 , dlist >=0.6.0.1 , fgl , fgl-visualize , file-embed >=0.0.14.0 , filepath >=1.4.1.1 , free >=5.1.10 , futhark-data >= 1.1.1.0 , futhark-server >= 1.2.3.0 , futhark-manifest >= 1.5.0.0 , githash >=0.1.6.1 , half >= 0.3 , haskeline , language-c-quote >= 0.12 , lens , lsp >= 2.2.0.0 , lsp-types >= 2.0.1.0 , mainland-pretty >=0.7.1 , cmark-gfm >=0.2.1 , megaparsec >=9.0.0 , mtl >=2.2.1 , neat-interpolation >=0.3 , parallel >=3.2.1.0 , random >= 1.2.0 , process-extras >=0.7.2 , regex-tdfa >=1.2 , srcloc >=0.4 , template-haskell >=2.11.1 , temporary , terminal-size >=0.3 , text >=1.2.2.2 , time >=1.6.0.1 , transformers >=0.3 , vector >=0.12 , versions >=6.0.0 , zlib >=0.7.0.0 , statistics , mwc-random , prettyprinter >= 1.7 , prettyprinter-ansi-terminal >= 1.1 executable futhark import: common main-is: src/main.hs ghc-options: -threaded -rtsopts "-with-rtsopts=-maxN16 -qg1 -A16M" build-depends: base, futhark test-suite unit import: common type: 
exitcode-stdio-1.0 main-is: futhark_tests.hs hs-source-dirs: unittests other-modules: Futhark.AD.DerivativesTests Futhark.Analysis.AlgSimplifyTests Futhark.Analysis.PrimExp.TableTests Futhark.BenchTests Futhark.IR.GPUTests Futhark.IR.MCTests Futhark.IR.Mem.IntervalTests Futhark.IR.Mem.IxFun.Alg Futhark.IR.Mem.IxFunTests Futhark.IR.Mem.IxFunWrapper Futhark.IR.Prop.RearrangeTests Futhark.IR.Prop.ReshapeTests Futhark.IR.PropTests Futhark.IR.Syntax.CoreTests Futhark.IR.SyntaxTests Futhark.Internalise.TypesValuesTests Futhark.Optimise.MemoryBlockMerging.GreedyColoringTests Futhark.Optimise.ArrayLayout.AnalyseTests Futhark.Optimise.ArrayLayout.LayoutTests Futhark.Optimise.ArrayLayoutTests Futhark.Pkg.SolveTests Futhark.ProfileTests Language.Futhark.CoreTests Language.Futhark.PrimitiveTests Language.Futhark.SemanticTests Language.Futhark.SyntaxTests Language.Futhark.TypeChecker.TypesTests Language.Futhark.TypeCheckerTests build-depends: QuickCheck >=2.8 , mtl >=2.2.1 , base , containers , free , futhark , megaparsec , tasty , tasty-hunit , tasty-quickcheck , text futhark-0.25.27/nix/000077500000000000000000000000001475065116200141405ustar00rootroot00000000000000futhark-0.25.27/nix/elfutils191.nix000066400000000000000000000065721475065116200167540ustar00rootroot00000000000000{ lib, stdenv, fetchurl, fetchpatch, pkg-config, musl-fts , musl-obstack, m4, zlib, zstd, bzip2, bison, flex, gettext, xz, setupDebugInfoDirs , argp-standalone , enableDebuginfod ? lib.meta.availableOn stdenv.hostPlatform libarchive, sqlite, curl, libmicrohttpd, libarchive , gitUpdater, autoreconfHook }: # TODO: Look at the hardcoded paths to kernel, modules etc. 
stdenv.mkDerivation rec { pname = "elfutils"; version = "0.191"; src = fetchurl { url = "https://sourceware.org/elfutils/ftp/${version}/${pname}-${version}.tar.bz2"; hash = "sha256-33bbcTZtHXCDZfx6bGDKSDmPFDZ+sriVTvyIlxR62HE="; }; postPatch = '' patchShebangs tests/*.sh '' + lib.optionalString stdenv.hostPlatform.isRiscV '' # disable failing test: # # > dwfl_thread_getframes: No DWARF information found sed -i s/run-backtrace-dwarf.sh//g tests/Makefile.in ''; outputs = [ "bin" "dev" "out" "man" ]; # We need bzip2 in NativeInputs because otherwise we can't unpack the src, # as the host-bzip2 will be in the path. nativeBuildInputs = [ m4 bison flex gettext bzip2 ] ++ lib.optional enableDebuginfod pkg-config ++ lib.optional (stdenv.targetPlatform.useLLVM or false) autoreconfHook; buildInputs = [ zlib zstd bzip2 xz ] ++ lib.optionals stdenv.hostPlatform.isMusl [ argp-standalone musl-fts musl-obstack ] ++ lib.optionals enableDebuginfod [ sqlite curl libmicrohttpd libarchive ]; propagatedNativeBuildInputs = [ setupDebugInfoDirs ]; configureFlags = [ "--program-prefix=eu-" # prevent collisions with binutils "--enable-deterministic-archives" (lib.enableFeature enableDebuginfod "libdebuginfod") (lib.enableFeature enableDebuginfod "debuginfod") # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=101766 # Versioned symbols are nice to have, but we can do without. (lib.enableFeature (!stdenv.hostPlatform.isMicroBlaze) "symbol-versioning") ] ++ lib.optional (stdenv.targetPlatform.useLLVM or false) "--disable-demangler" ++ lib.optionals stdenv.cc.isClang [ "CFLAGS=-Wno-unused-private-field" "CXXFLAGS=-Wno-unused-private-field" ]; enableParallelBuilding = true; doCheck = # Backtrace unwinding tests rely on glibc-internal symbol names. # Musl provides slightly different forms and fails. # Let's disable tests there until musl support is fully upstreamed. 
!stdenv.hostPlatform.isMusl # Test suite tries using `uname` to determine whether certain tests # can be executed, so we need to match build and host platform exactly. && (stdenv.hostPlatform == stdenv.buildPlatform); doInstallCheck = !stdenv.hostPlatform.isMusl && (stdenv.hostPlatform == stdenv.buildPlatform); passthru.updateScript = gitUpdater { url = "https://sourceware.org/git/elfutils.git"; rev-prefix = "elfutils-"; }; meta = with lib; { homepage = "https://sourceware.org/elfutils/"; description = "Set of utilities to handle ELF objects"; platforms = platforms.linux; # https://lists.fedorahosted.org/pipermail/elfutils-devel/2014-November/004223.html badPlatforms = [ lib.systems.inspect.platformPatterns.isStatic ]; # licenses are GPL2 or LGPL3+ for libraries, GPL3+ for bins, # but since this package isn't split that way, all three are listed. license = with licenses; [ gpl2Only lgpl3Plus gpl3Plus ]; maintainers = with maintainers; [ r-burns ]; }; } futhark-0.25.27/nix/futhark-data.nix000066400000000000000000000013461475065116200172370ustar00rootroot00000000000000{ mkDerivation, base, binary, bytestring, bytestring-to-vector , containers, half, lib, megaparsec, mtl, QuickCheck, scientific , tasty, tasty-hunit, tasty-quickcheck, text, vector , vector-binary-instances }: mkDerivation { pname = "futhark-data"; version = "1.1.1.0"; sha256 = "0ef011fb779f269208c0a6b57a62e1a5ec265bfd0cde820edf400cef57451804"; libraryHaskellDepends = [ base binary bytestring bytestring-to-vector containers half megaparsec mtl scientific text vector vector-binary-instances ]; testHaskellDepends = [ base binary bytestring megaparsec QuickCheck tasty tasty-hunit tasty-quickcheck text vector ]; description = "An implementation of the Futhark data format"; license = lib.licenses.isc; } futhark-0.25.27/nix/futhark-manifest.nix000066400000000000000000000011071475065116200201270ustar00rootroot00000000000000{ mkDerivation, aeson, base, bytestring, containers, lib , QuickCheck, 
quickcheck-instances, tasty, tasty-hunit , tasty-quickcheck, text }: mkDerivation { pname = "futhark-manifest"; version = "1.5.0.0"; sha256 = "c4d076761f293f2f6251993b73e7e7de69cc15ac474e60770103f97558a3fcb1"; libraryHaskellDepends = [ aeson base bytestring containers text ]; testHaskellDepends = [ base QuickCheck quickcheck-instances tasty tasty-hunit tasty-quickcheck text ]; description = "Definition and serialisation instances for Futhark manifests"; license = lib.licenses.isc; } futhark-0.25.27/nix/futhark-server.nix000066400000000000000000000007251475065116200176340ustar00rootroot00000000000000{ mkDerivation, base, binary, bytestring, directory, futhark-data , lib, mtl, process, temporary, text }: mkDerivation { pname = "futhark-server"; version = "1.2.3.0"; sha256 = "4bd26a908ae3c41b4eb18343a8fedb193a06c802c9e8a31d99a4f87dc781f189"; libraryHaskellDepends = [ base binary bytestring directory futhark-data mtl process temporary text ]; description = "Client implementation of the Futhark server protocol"; license = lib.licenses.isc; } futhark-0.25.27/nix/sources.json000066400000000000000000000021131475065116200165130ustar00rootroot00000000000000{ "niv": { "branch": "master", "description": "Easy dependency management for Nix projects", "homepage": "https://github.com/nmattia/niv", "owner": "nmattia", "repo": "niv", "rev": "df49d53b71ad5b6b5847b32e5254924d60703c46", "sha256": "1j5p8mi1wi3pdcq0lfb881p97i232si07nb605dl92cjwnira88c", "type": "tarball", "url": "https://github.com/nmattia/niv/archive/df49d53b71ad5b6b5847b32e5254924d60703c46.tar.gz", "url_template": "https://github.com///archive/.tar.gz" }, "nixpkgs": { "branch": "master", "description": "Nix Packages collection", "homepage": "", "owner": "NixOS", "repo": "nixpkgs", "rev": "cf0e7c1ab2634e89fc2c4ec479609b8250dc0ace", "sha256": "14g5j556bkckplprzb907d06mkxzxs45dzkglacz1jrx955xx86w", "type": "tarball", "url": "https://github.com/NixOS/nixpkgs/archive/cf0e7c1ab2634e89fc2c4ec479609b8250dc0ace.tar.gz", 
"url_template": "https://github.com///archive/.tar.gz" } } futhark-0.25.27/nix/sources.nix000066400000000000000000000162241475065116200163500ustar00rootroot00000000000000# This file has been generated by Niv. let # # The fetchers. fetch_ fetches specs of type . # fetch_file = pkgs: name: spec: let name' = sanitizeName name + "-src"; in if spec.builtin or true then builtins_fetchurl { inherit (spec) url sha256; name = name'; } else pkgs.fetchurl { inherit (spec) url sha256; name = name'; }; fetch_tarball = pkgs: name: spec: let name' = sanitizeName name + "-src"; in if spec.builtin or true then builtins_fetchTarball { name = name'; inherit (spec) url sha256; } else pkgs.fetchzip { name = name'; inherit (spec) url sha256; }; fetch_git = name: spec: let ref = spec.ref or ( if spec ? branch then "refs/heads/${spec.branch}" else if spec ? tag then "refs/tags/${spec.tag}" else abort "In git source '${name}': Please specify `ref`, `tag` or `branch`!" ); submodules = spec.submodules or false; submoduleArg = let nixSupportsSubmodules = builtins.compareVersions builtins.nixVersion "2.4" >= 0; emptyArgWithWarning = if submodules then builtins.trace ( "The niv input \"${name}\" uses submodules " + "but your nix's (${builtins.nixVersion}) builtins.fetchGit " + "does not support them" ) { } else { }; in if nixSupportsSubmodules then { inherit submodules; } else emptyArgWithWarning; in builtins.fetchGit ({ url = spec.repo; inherit (spec) rev; inherit ref; } // submoduleArg); fetch_local = spec: spec.path; fetch_builtin-tarball = name: throw ''[${name}] The niv type "builtin-tarball" is deprecated. You should instead use `builtin = true`. $ niv modify ${name} -a type=tarball -a builtin=true''; fetch_builtin-url = name: throw ''[${name}] The niv type "builtin-url" will soon be deprecated. You should instead use `builtin = true`. 
$ niv modify ${name} -a type=file -a builtin=true''; # # Various helpers # # https://github.com/NixOS/nixpkgs/pull/83241/files#diff-c6f540a4f3bfa4b0e8b6bafd4cd54e8bR695 sanitizeName = name: ( concatMapStrings (s: if builtins.isList s then "-" else s) ( builtins.split "[^[:alnum:]+._?=-]+" ((x: builtins.elemAt (builtins.match "\\.*(.*)" x) 0) name) ) ); # The set of packages used when specs are fetched using non-builtins. mkPkgs = sources: system: let sourcesNixpkgs = import (builtins_fetchTarball { inherit (sources.nixpkgs) url sha256; }) { inherit system; }; hasNixpkgsPath = builtins.any (x: x.prefix == "nixpkgs") builtins.nixPath; hasThisAsNixpkgsPath = == ./.; in if builtins.hasAttr "nixpkgs" sources then sourcesNixpkgs else if hasNixpkgsPath && ! hasThisAsNixpkgsPath then import { } else abort '' Please specify either (through -I or NIX_PATH=nixpkgs=...) or add a package called "nixpkgs" to your sources.json. ''; # The actual fetching function. fetch = pkgs: name: spec: if ! builtins.hasAttr "type" spec then abort "ERROR: niv spec ${name} does not have a 'type' attribute" else if spec.type == "file" then fetch_file pkgs name spec else if spec.type == "tarball" then fetch_tarball pkgs name spec else if spec.type == "git" then fetch_git name spec else if spec.type == "local" then fetch_local spec else if spec.type == "builtin-tarball" then fetch_builtin-tarball name else if spec.type == "builtin-url" then fetch_builtin-url name else abort "ERROR: niv spec ${name} has unknown type ${builtins.toJSON spec.type}"; # If the environment variable NIV_OVERRIDE_${name} is set, then use # the path directly as opposed to the fetched source. 
replace = name: drv: let saneName = stringAsChars (c: if (builtins.match "[a-zA-Z0-9]" c) == null then "_" else c) name; ersatz = builtins.getEnv "NIV_OVERRIDE_${saneName}"; in if ersatz == "" then drv else # this turns the string into an actual Nix path (for both absolute and # relative paths) if builtins.substring 0 1 ersatz == "/" then /. + ersatz else /. + builtins.getEnv "PWD" + "/${ersatz}"; # Ports of functions for older nix versions # a Nix version of mapAttrs if the built-in doesn't exist mapAttrs = builtins.mapAttrs or ( f: set: with builtins; listToAttrs (map (attr: { name = attr; value = f attr set.${attr}; }) (attrNames set)) ); # https://github.com/NixOS/nixpkgs/blob/0258808f5744ca980b9a1f24fe0b1e6f0fecee9c/lib/lists.nix#L295 range = first: last: if first > last then [ ] else builtins.genList (n: first + n) (last - first + 1); # https://github.com/NixOS/nixpkgs/blob/0258808f5744ca980b9a1f24fe0b1e6f0fecee9c/lib/strings.nix#L257 stringToCharacters = s: map (p: builtins.substring p 1 s) (range 0 (builtins.stringLength s - 1)); # https://github.com/NixOS/nixpkgs/blob/0258808f5744ca980b9a1f24fe0b1e6f0fecee9c/lib/strings.nix#L269 stringAsChars = f: s: concatStrings (map f (stringToCharacters s)); concatMapStrings = f: list: concatStrings (map f list); concatStrings = builtins.concatStringsSep ""; # https://github.com/NixOS/nixpkgs/blob/8a9f58a375c401b96da862d969f66429def1d118/lib/attrsets.nix#L331 optionalAttrs = cond: as: if cond then as else { }; # fetchTarball version that is compatible between all the versions of Nix builtins_fetchTarball = { url, name ? null, sha256 }@attrs: let inherit (builtins) lessThan nixVersion fetchTarball; in if lessThan nixVersion "1.12" then fetchTarball ({ inherit url; } // (optionalAttrs (name != null) { inherit name; })) else fetchTarball attrs; # fetchurl version that is compatible between all the versions of Nix builtins_fetchurl = { url, name ? 
null, sha256 }@attrs: let inherit (builtins) lessThan nixVersion fetchurl; in if lessThan nixVersion "1.12" then fetchurl ({ inherit url; } // (optionalAttrs (name != null) { inherit name; })) else fetchurl attrs; # Create the final "sources" from the config mkSources = config: mapAttrs ( name: spec: if builtins.hasAttr "outPath" spec then abort "The values in sources.json should not have an 'outPath' attribute" else spec // { outPath = replace name (fetch config.pkgs name spec); } ) config.sources; # The "config" used by the fetchers mkConfig = { sourcesFile ? if builtins.pathExists ./sources.json then ./sources.json else null , sources ? if sourcesFile == null then { } else builtins.fromJSON (builtins.readFile sourcesFile) , system ? builtins.currentSystem , pkgs ? mkPkgs sources system }: rec { # The sources, i.e. the attribute set of spec name to spec inherit sources; # The "pkgs" (evaluated nixpkgs) to use for e.g. non-builtin fetchers inherit pkgs; }; in mkSources (mkConfig { }) // { __functor = _: settings: mkSources (mkConfig settings); } futhark-0.25.27/nix/zlib.nix000066400000000000000000000007641475065116200156270ustar00rootroot00000000000000{ mkDerivation, base, bytestring, lib, QuickCheck, tasty , tasty-quickcheck, zlib }: mkDerivation { pname = "zlib"; version = "0.7.0.0"; sha256 = "7e43c205e1e1ff5a4b033086ec8cce82ab658879e977c8ba02a6701946ff7a47"; libraryHaskellDepends = [ base bytestring ]; libraryPkgconfigDepends = [ zlib ]; testHaskellDepends = [ base bytestring QuickCheck tasty tasty-quickcheck ]; description = "Compression and decompression in the gzip and zlib formats"; license = lib.licenses.bsd3; } futhark-0.25.27/prelude/000077500000000000000000000000001475065116200150025ustar00rootroot00000000000000futhark-0.25.27/prelude/STYLE.md000066400000000000000000000032371475065116200162310ustar00rootroot00000000000000Futhark Prelude Library Style Guide =================================== This short document provides instructions on the coding style 
used in the Futhark Prelude Library. When you add new code, please try to adhere to the style outlined here. To ease porting and integration, Futhark is designed to permit a variety in naming and indentation style (so it can adapt to conventions already in use), but we try to be more consistent in the Prelude itself. The Prelude is generally kept very minimalistic. It is mostly intended for functions that directly wrap intrinsics, or are extremely widely used and have only one sensible implementation and name. Hence, we do not expect a great deal of development to happen here. If you disagree with any of these instructions, feel free to open a GitHub issue for discussion. Style Rules =========== The style is generally aimed at terseness. When in doubt, make it short and simple. ### Line Length The maximum line length is *80 characters*. ### Indentation Tabs are illegal. Use spaces for indenting. Indent your code blocks with *2 spaces*. ### Capitalisation Almost all names are lowercase, with `snake_case` used for long compound names. Parametric module parameters may be single capital letters. ### Comments Write proper sentences; start with a capital letter and use proper punctuation. Give non-trivial top level functions an explanatory comment. This may be skipped when the function corresponds to one in a module type. ### Modules Every module should be matched with some signature. ### Type Annotations Avoid optional type annotations unless they express some interesting property (i.e. contain shape declarations). futhark-0.25.27/prelude/ad.fut000066400000000000000000000122521475065116200161100ustar00rootroot00000000000000-- | Definitions related to automatic differentiation. -- -- Futhark supports a fairly complete facility for computing -- derivatives of functions through automatic differentiation (AD). 
-- The following does not constitute a full introduction to AD, which -- is a substantial topic by itself, but rather introduces enough -- concepts to describe Futhark's functional AD interface. AD is a -- program transformation that is implemented in the compiler itself, -- but the user interface is a handful of fairly simple functions. -- -- AD is useful when optimising cost functions, and for any other -- purpose where you might need derivatives, such as for example -- computing surface normals for signed distance functions. -- -- ## Jacobians -- -- For a differentiable function *f* whose input comprise *n* scalars -- and whose output comprises *m* scalars, the -- [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) -- for a given input point is an *m* by *n* matrix of scalars that -- each represent a [partial -- derivatives](https://en.wikipedia.org/wiki/Partial_derivative). -- Intuitively, position *(i,j)* of the Jacobian describes how -- sensitive output *i* is to input *j*. The notion of Jacobian -- generalises to functions that accept or produce compound structures -- such as arrays, records, sums, and so on, simply by "flattening -- out" the values and considering only their constituent scalars. -- -- Computing the full Jacobian is usually costly and sometimes not -- necessary, and it is not part of the AD facility provided by -- Futhark. Instead it is possible to parts of the Jacobian. -- -- We can take the product of an an *m* by *n* Jacobian with an -- *n*-element *tangent vector* to produce an *m*-element vector -- (*Jacobian-vector product*). Such a product can be computed in a -- single (augmented) execution of the function *f*, and by choosing -- the tangent vector appropriately we can use this to compute the -- full Jacobian. This is provided by the function `jvp`. 
-- -- We can also take the product of an *m*-element vector *cotangent -- vector* with the *m* by *n* Jacobian to produce an *n*-element -- vector (*Vector-Jacobian product*). This too can be computed in a -- single execution of *f*, with `vjp`. -- -- We can use the `jvp` function to produce a *column* of the full -- Jacobian, and `vjp` to produce a *row*. Which is superior for a -- given situation depends on whether the function has more inputs or -- outputs. -- -- You can freely nest `vjp` and `jvp` to compute higher-order -- derivatives. -- -- ## Efficiency -- -- Both `jvp` and `vjp` work by transforming the program to carry -- along extra information associated with each scalar value. -- -- In the case of `vjp`, this extra information takes the form of an -- additional scalar representing the tangent, which is then -- propagated in each scalar computation using essentially the [chain -- rule](https://en.wikipedia.org/wiki/Chain_rule). Therefore, `jvp` -- has a memory overhead of approximately *2x*, and a computational -- overhead of slightly more, but usually less than *4x*. -- -- In the case of `jvp`, since our starting point is a *cotangent*, -- the function is essentially first run forward, then backwards (the -- *return sweep*) to propagate the cotangent. During the return -- sweep, all intermediate results computed during the forward sweep -- must still be available, and must therefore be stored in memory -- during the forward sweep. This means that the memory usage of `jvp` -- is essentially proportional to the number of sequential steps of -- the original function (essentially turning *time* into *space*). -- The compiler does a nontrivial amount of optimisation to ameliorate -- this overhead (see [AD for an Array Language with Nested -- Parallelism](https://futhark-lang.org/publications/sc22-ad.pdf)), -- but it can still be substantial for programs with deep sequential -- loops. 
-- -- ## Differentiable functions -- -- AD only gives meaningful results for differentiable functions. The -- Futhark type system does not distinguish differentiable or -- non-differentiable operations. As a rule of thumb, a function is -- differentiable if its results are computed using a composition of -- primitive floating-point operations, without ever converting to or -- from integers. -- -- ## Limitations -- -- `jvp` is expected to work in all cases. `vjp` has limitations when -- using the GPU backends similar to those for irregular flattening. -- Specifically, you should avoid structures with variant sizes, such -- as loops that carry an array that changes size through the -- execution of the loop. -- | Jacobian-Vector Product ("forward mode"), producing also the -- primal result as the first element of the result tuple. def jvp2 'a 'b (f: a -> b) (x: a) (x': a): (b, b) = intrinsics.jvp2 f x x' -- | Vector-Jacobian Product ("reverse mode"), producing also the -- primal result as the first element of the result tuple. def vjp2 'a 'b (f: a -> b) (x: a) (y': b): (b, a) = intrinsics.vjp2 f x y' -- | Jacobian-Vector Product ("forward mode"). def jvp 'a 'b (f: a -> b) (x: a) (x': a): b = (jvp2 f x x').1 -- | Vector-Jacobian Product ("reverse mode"). def vjp 'a 'b (f: a -> b) (x: a) (y': b): a = (vjp2 f x y').1 futhark-0.25.27/prelude/array.fut000066400000000000000000000143571475065116200166520ustar00rootroot00000000000000-- | Utility functions for arrays. import "math" import "soacs" import "functional" open import "zip" -- Rexport. -- | The size of the outer dimension of an array. -- -- **Complexity:** O(1). def length [n] 't (_: [n]t) = n -- | Is the array empty? -- -- **Complexity:** O(1). def null [n] 't (_: [n]t) = n == 0 -- | The first element of the array. -- -- **Complexity:** O(1). #[inline] def head [n] 't (x: [n]t) = x[0] -- | The last element of the array. -- -- **Complexity:** O(1). 
#[inline] def last [n] 't (x: [n]t) = x[n - 1] -- | Everything but the first element of the array. -- -- **Complexity:** O(1). #[inline] def tail [n] 't (x: [n]t) : [n - 1]t = x[1:] -- | Everything but the last element of the array. -- -- **Complexity:** O(1). #[inline] def init [n] 't (x: [n]t) : [n - 1]t = x[0:n - 1] -- | Take some number of elements from the head of the array. -- -- **Complexity:** O(1). #[inline] def take [n] 't (i: i64) (x: [n]t) : [i]t = x[0:i] -- | Remove some number of elements from the head of the array. -- -- **Complexity:** O(1). #[inline] def drop [n] 't (i: i64) (x: [n]t) : [n - i]t = x[i:] -- | Statically change the size of an array. Fail at runtime if the -- imposed size does not match the actual size. Essentially syntactic -- sugar for a size coercion. #[inline] def sized [m] 't (n: i64) (xs: [m]t) : [n]t = xs :> [n]t -- | Split an array at a given position. -- -- **Complexity:** O(1). #[inline] def split [n] [m] 't (xs: [n + m]t) : ([n]t, [m]t) = (xs[0:n], xs[n:n + m] :> [m]t) -- | Return the elements of the array in reverse order. -- -- **Complexity:** O(1). #[inline] def reverse [n] 't (x: [n]t) : [n]t = x[::-1] -- | Concatenate two arrays. Warning: never try to perform a reduction -- with this operator; it will not work. -- -- **Work:** O(n). -- -- **Span:** O(1). #[inline] def (++) [n] [m] 't (xs: [n]t) (ys: [m]t) : *[n + m]t = intrinsics.concat xs ys -- | An old-fashioned way of saying `++`. #[inline] def concat [n] [m] 't (xs: [n]t) (ys: [m]t) : *[n + m]t = xs ++ ys -- | Construct an array of consecutive integers of the given length, -- starting at 0. -- -- **Work:** O(n). -- -- **Span:** O(1). #[inline] def iota (n: i64) : *[n]i64 = 0..1.. #[unsafe] a[(i + r) % n]) (iota n) -- | Construct an array of the given length containing the given -- value. -- -- **Work:** O(n). -- -- **Span:** O(1). 
#[inline] def replicate 't (n: i64) (x: t) : *[n]t = map (const x) (iota n) -- | Construct an array of an inferred length containing the given -- value. -- -- **Work:** O(n). -- -- **Span:** O(1). #[inline] def rep 't [n] (x: t) : *[n]t = replicate n x -- | Copy a value. The result will not alias anything. -- -- **Work:** O(n). -- -- **Span:** O(1). #[inline] def copy 't (a: t) : *t = ([a])[0] -- | Copy a value. The result will not alias anything. Additionally, -- there is a guarantee that the result will be laid out in row-major -- order in memory. This can be used for locality optimisations in -- cases where the compiler does not otherwise do the right thing. -- -- **Work:** O(n). -- -- **Span:** O(1). #[inline] def manifest 't (a: t) : *t = intrinsics.manifest a -- | Combines the outer two dimensions of an array. -- -- **Complexity:** O(1). #[inline] def flatten [n] [m] 't (xs: [n][m]t) : [n * m]t = intrinsics.flatten xs -- | Like `flatten`, but on the outer three dimensions of an array. #[inline] def flatten_3d [n] [m] [l] 't (xs: [n][m][l]t) : [n * m * l]t = flatten (flatten xs) -- | Like `flatten`, but on the outer four dimensions of an array. #[inline] def flatten_4d [n] [m] [l] [k] 't (xs: [n][m][l][k]t) : [n * m * l * k]t = flatten (flatten_3d xs) -- | Splits the outer dimension of an array in two. -- -- **Complexity:** O(1). #[inline] def unflatten 't [n] [m] (xs: [n * m]t) : [n][m]t = intrinsics.unflatten n m xs -- | Like `unflatten`, but produces three dimensions. #[inline] def unflatten_3d 't [n] [m] [l] (xs: [n * m * l]t) : [n][m][l]t = unflatten (unflatten xs) -- | Like `unflatten`, but produces four dimensions. #[inline] def unflatten_4d 't [n] [m] [l] [k] (xs: [n * m * l * k]t) : [n][m][l][k]t = unflatten (unflatten_3d xs) -- | Transpose an array. -- -- **Complexity:** O(1). #[inline] def transpose [n] [m] 't (a: [n][m]t) : [m][n]t = intrinsics.transpose a -- | True if all of the input elements are true. Produces true on an -- empty array. 
-- -- **Work:** O(n). -- -- **Span:** O(log(n)). def and [n] (xs: [n]bool) = all id xs -- | True if any of the input elements are true. Produces false on an -- empty array. -- -- **Work:** O(n). -- -- **Span:** O(log(n)). def or [n] (xs: [n]bool) = any id xs -- | Perform a *sequential* left-fold of an array. -- -- **Work:** O(n ✕ W(f))). -- -- **Span:** O(n ✕ S(f)). def foldl [n] 'a 'b (f: a -> b -> a) (acc: a) (bs: [n]b) : a = loop acc for b in bs do f acc b -- | Perform a *sequential* right-fold of an array. -- -- **Work:** O(n ✕ W(f))). -- -- **Span:** O(n ✕ S(f)). def foldr [n] 'a 'b (f: b -> a -> a) (acc: a) (bs: [n]b) : a = foldl (flip f) acc (reverse bs) -- | Create a value for each point in a one-dimensional index space. -- -- **Work:** *O(n ✕ W(f))* -- -- **Span:** *O(S(f))* def tabulate 'a (n: i64) (f: i64 -> a) : *[n]a = map1 f (iota n) -- | Create a value for each point in a two-dimensional index space. -- -- **Work:** *O(n ✕ m ✕ W(f))* -- -- **Span:** *O(S(f))* def tabulate_2d 'a (n: i64) (m: i64) (f: i64 -> i64 -> a) : *[n][m]a = map1 (f >-> tabulate m) (iota n) -- | Create a value for each point in a three-dimensional index space. -- -- **Work:** *O(n ✕ m ✕ o ✕ W(f))* -- -- **Span:** *O(S(f))* def tabulate_3d 'a (n: i64) (m: i64) (o: i64) (f: i64 -> i64 -> i64 -> a) : *[n][m][o]a = map1 (f >-> tabulate_2d m o) (iota n) futhark-0.25.27/prelude/functional.fut000066400000000000000000000044111475065116200176640ustar00rootroot00000000000000-- | Simple functional combinators. -- | Left-to-right application. Particularly useful for describing -- computation pipelines: -- -- ``` -- x |> f |> g |> h -- ``` def (|>) '^a '^b (x: a) (f: a -> b): b = f x -- | Right to left application. -- -- Due to the causality restriction (see the language reference) this -- is less useful than `|>`@term. 
For example, the following is -- a type error: -- -- ``` -- length <| filter (>0) [-1,0,1] -- ``` -- -- But this works: -- -- ``` -- filter (>0) [-1,0,1] |> length -- ``` def (<|) '^a '^b (f: a -> b) (x: a) = f x -- | Function composition, with values flowing from left to right. -- -- Note that functions with anonymous return sizes cannot be composed. -- For example, the following is a type error: -- -- ``` -- filter (>0) >-> length -- ``` -- -- In such cases you can use the pipe operator `|>`@term instead. def (>->) '^a '^b '^c (f: a -> b) (g: b -> c) (x: a): c = g (f x) -- | Function composition, with values flowing from right to left. -- This is the same as the `∘` operator known from mathematics. -- -- Has the same restrictions with respect to anonymous sizes as -- `>->`@term. def (<-<) '^a '^b '^c (g: b -> c) (f: a -> b) (x: a): c = g (f x) -- | Flip the arguments passed to a function. -- -- ``` -- f x y == flip f y x -- ``` def flip '^a '^b '^c (f: a -> b -> c) (b: b) (a: a): c = f a b -- | Transform a function taking a pair into a function taking two -- arguments. def curry '^a '^b '^c (f: (a, b) -> c) (a: a) (b: b): c = f (a, b) -- | Transform a function taking two arguments in a function taking a -- pair. def uncurry '^a '^b '^c (f: a -> b -> c) (a: a, b: b): c = f a b -- | The constant function. def const '^a '^b (x: a) (_: b): a = x -- | The identity function. def id '^a (x: a) = x -- | Apply a function some number of times. def iterate 'a (n: i32) (f: a -> a) (x: a) = loop x for _i < n do f x -- | Keep applying `f` until `p` returns true for the input value. -- May apply zero times. *Note*: may not terminate. def iterate_until 'a (p: a -> bool) (f: a -> a) (x: a) = loop x while !(p x) do f x -- | Keep applying `f` while `p` returns true for the input value. -- May apply zero times. *Note*: may not terminate. 
def iterate_while 'a (p: a -> bool) (f: a -> a) (x: a) = loop x while p x do f x futhark-0.25.27/prelude/math.fut000066400000000000000000001257411475065116200164650ustar00rootroot00000000000000-- | Basic mathematical modules and functions. import "soacs" -- | Describes types of values that can be created from the primitive -- numeric types (and bool). module type from_prim = { type t val i8: i8 -> t val i16: i16 -> t val i32: i32 -> t val i64: i64 -> t val u8: u8 -> t val u16: u16 -> t val u32: u32 -> t val u64: u64 -> t val f16: f16 -> t val f32: f32 -> t val f64: f64 -> t val bool: bool -> t } -- | A basic numeric module type that can be implemented for both -- integers and rational numbers. module type numeric = { include from_prim val +: t -> t -> t val -: t -> t -> t val *: t -> t -> t val /: t -> t -> t val %: t -> t -> t val **: t -> t -> t val to_i64: t -> i64 val ==: t -> t -> bool val <: t -> t -> bool val >: t -> t -> bool val <=: t -> t -> bool val >=: t -> t -> bool val !=: t -> t -> bool -- | Arithmetic negation (use `!` for bitwise negation). val neg: t -> t val max: t -> t -> t val min: t -> t -> t val abs: t -> t -- | Sign function. Produces -1, 0, or 1 if the argument is -- respectively less than, equal to, or greater than zero. val sgn: t -> t -- | The most positive representable number. val highest: t -- | The least positive representable number (most negative for -- signed types). val lowest: t -- | Returns zero on empty input. val sum [n]: [n]t -> t -- | Returns one on empty input. val product [n]: [n]t -> t -- | Returns `lowest` on empty input. val maximum [n]: [n]t -> t -- | Returns `highest` on empty input. val minimum [n]: [n]t -> t } -- | An extension of `numeric`@mtype that provides facilities that are -- only meaningful for integral types. module type integral = { include numeric -- | Like `/`@term, but rounds towards zero. This only matters when -- one of the operands is negative. May be more efficient. 
val //: t -> t -> t -- | Like `%`@term, but rounds towards zero. This only matters when -- one of the operands is negative. May be more efficient. val %%: t -> t -> t -- | Bitwise and. val &: t -> t -> t -- | Bitwise or. val |: t -> t -> t -- | Bitwise xor. val ^: t -> t -> t -- | Bitwise negation. val not: t -> t -- | Left shift; inserting zeroes. val <<: t -> t -> t -- | Arithmetic right shift, using sign extension for the leftmost bits. val >>: t -> t -> t -- | Logical right shift, inserting zeroes for the leftmost bits. val >>>: t -> t -> t val num_bits: i32 val get_bit: i32 -> t -> i32 val set_bit: i32 -> t -> i32 -> t -- | Count number of one bits. val popc: t -> i32 -- | Computes `x * y` and returns the high half of the product of x -- and y. val mul_hi: (x: t) -> (y: t) -> t -- | Computes `mul_hi a b + c`, but perhaps in a more efficient way, -- depending on the target platform. val mad_hi: (a: t) -> (b: t) -> (c: t) -> t -- | Count number of zero bits preceding the most significant set -- bit. Returns the number of bits in the type if the argument is -- zero. val clz: t -> i32 -- | Count number of trailing zero bits following the least -- significant set bit. Returns the number of bits in the type if -- the argument is zero. val ctz: t -> i32 } -- | Numbers that model real numbers to some degree. module type real = { include numeric -- | Multiplicative inverse. val recip: t -> t val from_fraction: i64 -> i64 -> t val to_i64: t -> i64 val to_f64: t -> f64 -- | Square root. val sqrt: t -> t -- | Cube root. val cbrt: t -> t val exp: t -> t val sin: t -> t val cos: t -> t val tan: t -> t val asin: t -> t val acos: t -> t val atan: t -> t val sinh: t -> t val cosh: t -> t val tanh: t -> t val asinh: t -> t val acosh: t -> t val atanh: t -> t val atan2: t -> t -> t -- | Compute the length of the hypotenuse of a right-angled -- triangle. That is, `hypot x y` computes *√(x²+y²)*. Put another -- way, the distance of *(x,y)* from origin in an Euclidean space. 
-- The calculation is performed without undue overflow or underflow -- during intermediate steps (specific accuracy depends on the -- backend). val hypot: t -> t -> t -- | The true Gamma function. val gamma: t -> t -- | The natural logarithm of the absolute value of `gamma`@term. val lgamma: t -> t -- | The error function. val erf: t -> t -- | The complementary error function. val erfc: t -> t -- | Linear interpolation. The third argument must be in the range -- `[0,1]` or the results are unspecified. val lerp: t -> t -> t -> t -- | Natural logarithm. val log: t -> t -- | Base-2 logarithm. val log2: t -> t -- | Base-10 logarithm. val log10: t -> t -- | Compute `log (1 + x)` accurately even when `x` is very small. val log1p: t -> t -- | Round towards infinity. val ceil: t -> t -- | Round towards negative infinity. val floor: t -> t -- | Round towards zero. val trunc: t -> t -- | Round to the nearest integer, with halfway cases rounded to the -- nearest even integer. Note that this differs from `round()` in -- C, but matches more modern languages. val round: t -> t -- | Computes `a*b+c`. Depending on the compiler backend, this may -- be fused into a single operation that is faster but less -- accurate. Do not confuse it with `fma`@term. val mad: (a: t) -> (b: t) -> (c: t) -> t -- | Computes `a*b+c`, with `a*b` being rounded with infinite -- precision. Rounding of intermediate products shall not -- occur. Edge case behavior is per the IEEE 754-2008 standard. val fma: (a: t) -> (b: t) -> (c: t) -> t val isinf: t -> bool val isnan: t -> bool val inf: t val nan: t val pi: t val e: t } -- | An extension of `real`@mtype that further gives access to the -- bitwise representation of the underlying number. It is presumed -- that this will be some form of IEEE float. -- -- Conversion of floats to integers is by truncation. If an infinity -- or NaN is converted to an integer, the result is zero. 
module type float = { include real -- | An unsigned integer type containing the same number of bits as -- 't'. type int_t val from_bits: int_t -> t val to_bits: t -> int_t val num_bits: i32 val get_bit: i32 -> t -> i32 val set_bit: i32 -> t -> i32 -> t -- | The difference between 1.0 and the next larger representable -- number. val epsilon: t -- | Produces the next representable number from `x` in the -- direction of `y`. val nextafter: (x: t) -> (y: t) -> t -- | Multiplies floating-point value by 2 raised to an integer power. val ldexp: t -> i32 -> t -- | Compose a floating-point value with the magnitude of `x` and the sign of `y`. val copysign: (x: t) -> (y: t) -> t } -- | Boolean numbers. When converting from a number to `bool`, 0 is -- considered `false` and any other value is `true`. module bool: from_prim with t = bool = { type t = bool def i8 = intrinsics.itob_i8_bool def i16 = intrinsics.itob_i16_bool def i32 = intrinsics.itob_i32_bool def i64 = intrinsics.itob_i64_bool def u8 (x: u8) = intrinsics.itob_i8_bool (intrinsics.sign_i8 x) def u16 (x: u16) = intrinsics.itob_i16_bool (intrinsics.sign_i16 x) def u32 (x: u32) = intrinsics.itob_i32_bool (intrinsics.sign_i32 x) def u64 (x: u64) = intrinsics.itob_i64_bool (intrinsics.sign_i64 x) def f16 (x: f16) = intrinsics.ftob_f16_bool x def f32 (x: f32) = intrinsics.ftob_f32_bool x def f64 (x: f64) = intrinsics.ftob_f64_bool x def bool (x: bool) = x } module i8: (integral with t = i8) = { type t = i8 def (+) (x: i8) (y: i8) = intrinsics.add8 (x, y) def (-) (x: i8) (y: i8) = intrinsics.sub8 (x, y) def (*) (x: i8) (y: i8) = intrinsics.mul8 (x, y) def (/) (x: i8) (y: i8) = intrinsics.sdiv8 (x, y) def (**) (x: i8) (y: i8) = intrinsics.pow8 (x, y) def (%) (x: i8) (y: i8) = intrinsics.smod8 (x, y) def (//) (x: i8) (y: i8) = intrinsics.squot8 (x, y) def (%%) (x: i8) (y: i8) = intrinsics.srem8 (x, y) def (&) (x: i8) (y: i8) = intrinsics.and8 (x, y) def (|) (x: i8) (y: i8) = intrinsics.or8 (x, y) def (^) (x: i8) (y: i8) = 
intrinsics.xor8 (x, y) def not (x: i8) = intrinsics.complement8 x def (<<) (x: i8) (y: i8) = intrinsics.shl8 (x, y) def (>>) (x: i8) (y: i8) = intrinsics.ashr8 (x, y) def (>>>) (x: i8) (y: i8) = intrinsics.lshr8 (x, y) def i8 (x: i8) = intrinsics.sext_i8_i8 x def i16 (x: i16) = intrinsics.sext_i16_i8 x def i32 (x: i32) = intrinsics.sext_i32_i8 x def i64 (x: i64) = intrinsics.sext_i64_i8 x def u8 (x: u8) = intrinsics.zext_i8_i8 (intrinsics.sign_i8 x) def u16 (x: u16) = intrinsics.zext_i16_i8 (intrinsics.sign_i16 x) def u32 (x: u32) = intrinsics.zext_i32_i8 (intrinsics.sign_i32 x) def u64 (x: u64) = intrinsics.zext_i64_i8 (intrinsics.sign_i64 x) def f16 (x: f16) = intrinsics.fptosi_f16_i8 x def f32 (x: f32) = intrinsics.fptosi_f32_i8 x def f64 (x: f64) = intrinsics.fptosi_f64_i8 x def bool = intrinsics.btoi_bool_i8 def to_i32 (x: i8) = intrinsics.sext_i8_i32 x def to_i64 (x: i8) = intrinsics.sext_i8_i64 x def (==) (x: i8) (y: i8) = intrinsics.eq_i8 (x, y) def (<) (x: i8) (y: i8) = intrinsics.slt8 (x, y) def (>) (x: i8) (y: i8) = intrinsics.slt8 (y, x) def (<=) (x: i8) (y: i8) = intrinsics.sle8 (x, y) def (>=) (x: i8) (y: i8) = intrinsics.sle8 (y, x) def (!=) (x: i8) (y: i8) = !(x == y) def sgn (x: i8) = intrinsics.ssignum8 x def abs (x: i8) = intrinsics.abs8 x def neg (x: t) = -x def max (x: t) (y: t) = intrinsics.smax8 (x, y) def min (x: t) (y: t) = intrinsics.smin8 (x, y) def highest = 127i8 def lowest = highest + 1i8 def num_bits = 8i32 def get_bit (bit: i32) (x: t) = to_i32 ((x >> i32 bit) & i32 1) def set_bit (bit: i32) (x: t) (b: i32) = ((x & i32 (!(1 intrinsics.<< bit))) | i32 (b intrinsics.<< bit)) def popc = intrinsics.popc8 def mul_hi a b = intrinsics.smul_hi8 (i8 a, i8 b) def mad_hi a b c = intrinsics.smad_hi8 (i8 a, i8 b, i8 c) def clz = intrinsics.clz8 def ctz = intrinsics.ctz8 def sum = reduce (+) (i32 0) def product = reduce (*) (i32 1) def maximum = reduce max lowest def minimum = reduce min highest } module i16: (integral with t = i16) = { type t = 
i16 def (+) (x: i16) (y: i16) = intrinsics.add16 (x, y) def (-) (x: i16) (y: i16) = intrinsics.sub16 (x, y) def (*) (x: i16) (y: i16) = intrinsics.mul16 (x, y) def (/) (x: i16) (y: i16) = intrinsics.sdiv16 (x, y) def (**) (x: i16) (y: i16) = intrinsics.pow16 (x, y) def (%) (x: i16) (y: i16) = intrinsics.smod16 (x, y) def (//) (x: i16) (y: i16) = intrinsics.squot16 (x, y) def (%%) (x: i16) (y: i16) = intrinsics.srem16 (x, y) def (&) (x: i16) (y: i16) = intrinsics.and16 (x, y) def (|) (x: i16) (y: i16) = intrinsics.or16 (x, y) def (^) (x: i16) (y: i16) = intrinsics.xor16 (x, y) def not (x: i16) = intrinsics.complement16 x def (<<) (x: i16) (y: i16) = intrinsics.shl16 (x, y) def (>>) (x: i16) (y: i16) = intrinsics.ashr16 (x, y) def (>>>) (x: i16) (y: i16) = intrinsics.lshr16 (x, y) def i8 (x: i8) = intrinsics.sext_i8_i16 x def i16 (x: i16) = intrinsics.sext_i16_i16 x def i32 (x: i32) = intrinsics.sext_i32_i16 x def i64 (x: i64) = intrinsics.sext_i64_i16 x def u8 (x: u8) = intrinsics.zext_i8_i16 (intrinsics.sign_i8 x) def u16 (x: u16) = intrinsics.zext_i16_i16 (intrinsics.sign_i16 x) def u32 (x: u32) = intrinsics.zext_i32_i16 (intrinsics.sign_i32 x) def u64 (x: u64) = intrinsics.zext_i64_i16 (intrinsics.sign_i64 x) def f16 (x: f16) = intrinsics.fptosi_f16_i16 x def f32 (x: f32) = intrinsics.fptosi_f32_i16 x def f64 (x: f64) = intrinsics.fptosi_f64_i16 x def bool = intrinsics.btoi_bool_i16 def to_i32 (x: i16) = intrinsics.sext_i16_i32 x def to_i64 (x: i16) = intrinsics.sext_i16_i64 x def (==) (x: i16) (y: i16) = intrinsics.eq_i16 (x, y) def (<) (x: i16) (y: i16) = intrinsics.slt16 (x, y) def (>) (x: i16) (y: i16) = intrinsics.slt16 (y, x) def (<=) (x: i16) (y: i16) = intrinsics.sle16 (x, y) def (>=) (x: i16) (y: i16) = intrinsics.sle16 (y, x) def (!=) (x: i16) (y: i16) = !(x == y) def sgn (x: i16) = intrinsics.ssignum16 x def abs (x: i16) = intrinsics.abs16 x def neg (x: t) = -x def max (x: t) (y: t) = intrinsics.smax16 (x, y) def min (x: t) (y: t) = intrinsics.smin16 
(x, y) def highest = 32767i16 def lowest = highest + 1i16 def num_bits = 16i32 def get_bit (bit: i32) (x: t) = to_i32 ((x >> i32 bit) & i32 1) def set_bit (bit: i32) (x: t) (b: i32) = ((x & i32 (!(1 intrinsics.<< bit))) | i32 (b intrinsics.<< bit)) def popc = intrinsics.popc16 def mul_hi a b = intrinsics.smul_hi16 (i16 a, i16 b) def mad_hi a b c = intrinsics.smad_hi16 (i16 a, i16 b, i16 c) def clz = intrinsics.clz16 def ctz = intrinsics.ctz16 def sum = reduce (+) (i32 0) def product = reduce (*) (i32 1) def maximum = reduce max lowest def minimum = reduce min highest } module i32: (integral with t = i32) = { type t = i32 def sign (x: u32) = intrinsics.sign_i32 x def unsign (x: i32) = intrinsics.unsign_i32 x def (+) (x: i32) (y: i32) = intrinsics.add32 (x, y) def (-) (x: i32) (y: i32) = intrinsics.sub32 (x, y) def (*) (x: i32) (y: i32) = intrinsics.mul32 (x, y) def (/) (x: i32) (y: i32) = intrinsics.sdiv32 (x, y) def (**) (x: i32) (y: i32) = intrinsics.pow32 (x, y) def (%) (x: i32) (y: i32) = intrinsics.smod32 (x, y) def (//) (x: i32) (y: i32) = intrinsics.squot32 (x, y) def (%%) (x: i32) (y: i32) = intrinsics.srem32 (x, y) def (&) (x: i32) (y: i32) = intrinsics.and32 (x, y) def (|) (x: i32) (y: i32) = intrinsics.or32 (x, y) def (^) (x: i32) (y: i32) = intrinsics.xor32 (x, y) def not (x: i32) = intrinsics.complement32 x def (<<) (x: i32) (y: i32) = intrinsics.shl32 (x, y) def (>>) (x: i32) (y: i32) = intrinsics.ashr32 (x, y) def (>>>) (x: i32) (y: i32) = intrinsics.lshr32 (x, y) def i8 (x: i8) = intrinsics.sext_i8_i32 x def i16 (x: i16) = intrinsics.sext_i16_i32 x def i32 (x: i32) = intrinsics.sext_i32_i32 x def i64 (x: i64) = intrinsics.sext_i64_i32 x def u8 (x: u8) = intrinsics.zext_i8_i32 (intrinsics.sign_i8 x) def u16 (x: u16) = intrinsics.zext_i16_i32 (intrinsics.sign_i16 x) def u32 (x: u32) = intrinsics.zext_i32_i32 (intrinsics.sign_i32 x) def u64 (x: u64) = intrinsics.zext_i64_i32 (intrinsics.sign_i64 x) def f16 (x: f16) = intrinsics.fptosi_f16_i32 x def f32 
(x: f32) = intrinsics.fptosi_f32_i32 x def f64 (x: f64) = intrinsics.fptosi_f64_i32 x def bool = intrinsics.btoi_bool_i32 def to_i32 (x: i32) = intrinsics.sext_i32_i32 x def to_i64 (x: i32) = intrinsics.sext_i32_i64 x def (==) (x: i32) (y: i32) = intrinsics.eq_i32 (x, y) def (<) (x: i32) (y: i32) = intrinsics.slt32 (x, y) def (>) (x: i32) (y: i32) = intrinsics.slt32 (y, x) def (<=) (x: i32) (y: i32) = intrinsics.sle32 (x, y) def (>=) (x: i32) (y: i32) = intrinsics.sle32 (y, x) def (!=) (x: i32) (y: i32) = !(x == y) def sgn (x: i32) = intrinsics.ssignum32 x def abs (x: i32) = intrinsics.abs32 x def neg (x: t) = -x def max (x: t) (y: t) = intrinsics.smax32 (x, y) def min (x: t) (y: t) = intrinsics.smin32 (x, y) def highest = 2147483647i32 def lowest = highest + 1 def num_bits = 32i32 def get_bit (bit: i32) (x: t) = to_i32 ((x >> i32 bit) & i32 1) def set_bit (bit: i32) (x: t) (b: i32) = ((x & i32 (!(1 intrinsics.<< bit))) | i32 (b intrinsics.<< bit)) def popc = intrinsics.popc32 def mul_hi a b = intrinsics.smul_hi32 (i32 a, i32 b) def mad_hi a b c = intrinsics.smad_hi32 (i32 a, i32 b, i32 c) def clz = intrinsics.clz32 def ctz = intrinsics.ctz32 def sum = reduce (+) (i32 0) def product = reduce (*) (i32 1) def maximum = reduce max lowest def minimum = reduce min highest } module i64: (integral with t = i64) = { type t = i64 def sign (x: u64) = intrinsics.sign_i64 x def unsign (x: i64) = intrinsics.unsign_i64 x def (+) (x: i64) (y: i64) = intrinsics.add64 (x, y) def (-) (x: i64) (y: i64) = intrinsics.sub64 (x, y) def (*) (x: i64) (y: i64) = intrinsics.mul64 (x, y) def (/) (x: i64) (y: i64) = intrinsics.sdiv64 (x, y) def (**) (x: i64) (y: i64) = intrinsics.pow64 (x, y) def (%) (x: i64) (y: i64) = intrinsics.smod64 (x, y) def (//) (x: i64) (y: i64) = intrinsics.squot64 (x, y) def (%%) (x: i64) (y: i64) = intrinsics.srem64 (x, y) def (&) (x: i64) (y: i64) = intrinsics.and64 (x, y) def (|) (x: i64) (y: i64) = intrinsics.or64 (x, y) def (^) (x: i64) (y: i64) = 
intrinsics.xor64 (x, y) def not (x: i64) = intrinsics.complement64 x def (<<) (x: i64) (y: i64) = intrinsics.shl64 (x, y) def (>>) (x: i64) (y: i64) = intrinsics.ashr64 (x, y) def (>>>) (x: i64) (y: i64) = intrinsics.lshr64 (x, y) def i8 (x: i8) = intrinsics.sext_i8_i64 x def i16 (x: i16) = intrinsics.sext_i16_i64 x def i32 (x: i32) = intrinsics.sext_i32_i64 x def i64 (x: i64) = intrinsics.sext_i64_i64 x def u8 (x: u8) = intrinsics.zext_i8_i64 (intrinsics.sign_i8 x) def u16 (x: u16) = intrinsics.zext_i16_i64 (intrinsics.sign_i16 x) def u32 (x: u32) = intrinsics.zext_i32_i64 (intrinsics.sign_i32 x) def u64 (x: u64) = intrinsics.zext_i64_i64 (intrinsics.sign_i64 x) def f16 (x: f16) = intrinsics.fptosi_f16_i64 x def f32 (x: f32) = intrinsics.fptosi_f32_i64 x def f64 (x: f64) = intrinsics.fptosi_f64_i64 x def bool = intrinsics.btoi_bool_i64 def to_i32 (x: i64) = intrinsics.sext_i64_i32 x def to_i64 (x: i64) = intrinsics.sext_i64_i64 x def (==) (x: i64) (y: i64) = intrinsics.eq_i64 (x, y) def (<) (x: i64) (y: i64) = intrinsics.slt64 (x, y) def (>) (x: i64) (y: i64) = intrinsics.slt64 (y, x) def (<=) (x: i64) (y: i64) = intrinsics.sle64 (x, y) def (>=) (x: i64) (y: i64) = intrinsics.sle64 (y, x) def (!=) (x: i64) (y: i64) = !(x == y) def sgn (x: i64) = intrinsics.ssignum64 x def abs (x: i64) = intrinsics.abs64 x def neg (x: t) = -x def max (x: t) (y: t) = intrinsics.smax64 (x, y) def min (x: t) (y: t) = intrinsics.smin64 (x, y) def highest = 9223372036854775807i64 def lowest = highest + 1i64 def num_bits = 64i32 def get_bit (bit: i32) (x: t) = to_i32 ((x >> i32 bit) & i32 1) def set_bit (bit: i32) (x: t) (b: i32) = ((x & i32 (!(1 intrinsics.<< bit))) | intrinsics.zext_i32_i64 (b intrinsics.<< bit)) def popc = intrinsics.popc64 def mul_hi a b = intrinsics.smul_hi64 (i64 a, i64 b) def mad_hi a b c = intrinsics.smad_hi64 (i64 a, i64 b, i64 c) def clz = intrinsics.clz64 def ctz = intrinsics.ctz64 def sum = reduce (+) (i32 0) def product = reduce (*) (i32 1) def maximum = 
reduce max lowest def minimum = reduce min highest } module u8: (integral with t = u8) = { type t = u8 def sign (x: u8) = intrinsics.sign_i8 x def unsign (x: i8) = intrinsics.unsign_i8 x def (+) (x: u8) (y: u8) = unsign (intrinsics.add8 (sign x, sign y)) def (-) (x: u8) (y: u8) = unsign (intrinsics.sub8 (sign x, sign y)) def (*) (x: u8) (y: u8) = unsign (intrinsics.mul8 (sign x, sign y)) def (/) (x: u8) (y: u8) = unsign (intrinsics.udiv8 (sign x, sign y)) def (**) (x: u8) (y: u8) = unsign (intrinsics.pow8 (sign x, sign y)) def (%) (x: u8) (y: u8) = unsign (intrinsics.umod8 (sign x, sign y)) def (//) (x: u8) (y: u8) = unsign (intrinsics.udiv8 (sign x, sign y)) def (%%) (x: u8) (y: u8) = unsign (intrinsics.umod8 (sign x, sign y)) def (&) (x: u8) (y: u8) = unsign (intrinsics.and8 (sign x, sign y)) def (|) (x: u8) (y: u8) = unsign (intrinsics.or8 (sign x, sign y)) def (^) (x: u8) (y: u8) = unsign (intrinsics.xor8 (sign x, sign y)) def not (x: u8) = unsign (intrinsics.complement8 (sign x)) def (<<) (x: u8) (y: u8) = unsign (intrinsics.shl8 (sign x, sign y)) def (>>) (x: u8) (y: u8) = unsign (intrinsics.ashr8 (sign x, sign y)) def (>>>) (x: u8) (y: u8) = unsign (intrinsics.lshr8 (sign x, sign y)) def u8 (x: u8) = unsign (i8.u8 x) def u16 (x: u16) = unsign (i8.u16 x) def u32 (x: u32) = unsign (i8.u32 x) def u64 (x: u64) = unsign (i8.u64 x) def i8 (x: i8) = unsign (intrinsics.zext_i8_i8 x) def i16 (x: i16) = unsign (intrinsics.zext_i16_i8 x) def i32 (x: i32) = unsign (intrinsics.zext_i32_i8 x) def i64 (x: i64) = unsign (intrinsics.zext_i64_i8 x) def f16 (x: f16) = unsign (intrinsics.fptoui_f16_i8 x) def f32 (x: f32) = unsign (intrinsics.fptoui_f32_i8 x) def f64 (x: f64) = unsign (intrinsics.fptoui_f64_i8 x) def bool x = unsign (intrinsics.btoi_bool_i8 x) def to_i32 (x: u8) = intrinsics.zext_i8_i32 (sign x) def to_i64 (x: u8) = intrinsics.zext_i8_i64 (sign x) def (==) (x: u8) (y: u8) = intrinsics.eq_i8 (sign x, sign y) def (<) (x: u8) (y: u8) = intrinsics.ult8 (sign x, sign 
y) def (>) (x: u8) (y: u8) = intrinsics.ult8 (sign y, sign x) def (<=) (x: u8) (y: u8) = intrinsics.ule8 (sign x, sign y) def (>=) (x: u8) (y: u8) = intrinsics.ule8 (sign y, sign x) def (!=) (x: u8) (y: u8) = !(x == y) def sgn (x: u8) = unsign (intrinsics.usignum8 (sign x)) def abs (x: u8) = x def neg (x: t) = -x def max (x: t) (y: t) = unsign (intrinsics.umax8 (sign x, sign y)) def min (x: t) (y: t) = unsign (intrinsics.umin8 (sign x, sign y)) def highest = 255u8 def lowest = 0u8 def num_bits = 8i32 def get_bit (bit: i32) (x: t) = to_i32 ((x >> i32 bit) & i32 1) def set_bit (bit: i32) (x: t) (b: i32) = ((x & i32 (!(1 intrinsics.<< bit))) | i32 (b intrinsics.<< bit)) def popc x = intrinsics.popc8 (sign x) def mul_hi a b = unsign (intrinsics.umul_hi8 (sign a, sign b)) def mad_hi a b c = unsign (intrinsics.umad_hi8 (sign a, sign b, sign c)) def clz x = intrinsics.clz8 (sign x) def ctz x = intrinsics.ctz8 (sign x) def sum = reduce (+) (i32 0) def product = reduce (*) (i32 1) def maximum = reduce max lowest def minimum = reduce min highest } module u16: (integral with t = u16) = { type t = u16 def sign (x: u16) = intrinsics.sign_i16 x def unsign (x: i16) = intrinsics.unsign_i16 x def (+) (x: u16) (y: u16) = unsign (intrinsics.add16 (sign x, sign y)) def (-) (x: u16) (y: u16) = unsign (intrinsics.sub16 (sign x, sign y)) def (*) (x: u16) (y: u16) = unsign (intrinsics.mul16 (sign x, sign y)) def (/) (x: u16) (y: u16) = unsign (intrinsics.udiv16 (sign x, sign y)) def (**) (x: u16) (y: u16) = unsign (intrinsics.pow16 (sign x, sign y)) def (%) (x: u16) (y: u16) = unsign (intrinsics.umod16 (sign x, sign y)) def (//) (x: u16) (y: u16) = unsign (intrinsics.udiv16 (sign x, sign y)) def (%%) (x: u16) (y: u16) = unsign (intrinsics.umod16 (sign x, sign y)) def (&) (x: u16) (y: u16) = unsign (intrinsics.and16 (sign x, sign y)) def (|) (x: u16) (y: u16) = unsign (intrinsics.or16 (sign x, sign y)) def (^) (x: u16) (y: u16) = unsign (intrinsics.xor16 (sign x, sign y)) def not (x: u16) 
= unsign (intrinsics.complement16 (sign x)) def (<<) (x: u16) (y: u16) = unsign (intrinsics.shl16 (sign x, sign y)) def (>>) (x: u16) (y: u16) = unsign (intrinsics.ashr16 (sign x, sign y)) def (>>>) (x: u16) (y: u16) = unsign (intrinsics.lshr16 (sign x, sign y)) def u8 (x: u8) = unsign (i16.u8 x) def u16 (x: u16) = unsign (i16.u16 x) def u32 (x: u32) = unsign (i16.u32 x) def u64 (x: u64) = unsign (i16.u64 x) def i8 (x: i8) = unsign (intrinsics.zext_i8_i16 x) def i16 (x: i16) = unsign (intrinsics.zext_i16_i16 x) def i32 (x: i32) = unsign (intrinsics.zext_i32_i16 x) def i64 (x: i64) = unsign (intrinsics.zext_i64_i16 x) def f16 (x: f16) = unsign (intrinsics.fptoui_f16_i16 x) def f32 (x: f32) = unsign (intrinsics.fptoui_f32_i16 x) def f64 (x: f64) = unsign (intrinsics.fptoui_f64_i16 x) def bool x = unsign (intrinsics.btoi_bool_i16 x) def to_i32 (x: u16) = intrinsics.zext_i16_i32 (sign x) def to_i64 (x: u16) = intrinsics.zext_i16_i64 (sign x) def (==) (x: u16) (y: u16) = intrinsics.eq_i16 (sign x, sign y) def (<) (x: u16) (y: u16) = intrinsics.ult16 (sign x, sign y) def (>) (x: u16) (y: u16) = intrinsics.ult16 (sign y, sign x) def (<=) (x: u16) (y: u16) = intrinsics.ule16 (sign x, sign y) def (>=) (x: u16) (y: u16) = intrinsics.ule16 (sign y, sign x) def (!=) (x: u16) (y: u16) = !(x == y) def sgn (x: u16) = unsign (intrinsics.usignum16 (sign x)) def abs (x: u16) = x def neg (x: t) = -x def max (x: t) (y: t) = unsign (intrinsics.umax16 (sign x, sign y)) def min (x: t) (y: t) = unsign (intrinsics.umin16 (sign x, sign y)) def highest = 65535u16 def lowest = 0u16 def num_bits = 16i32 def get_bit (bit: i32) (x: t) = to_i32 ((x >> i32 bit) & i32 1) def set_bit (bit: i32) (x: t) (b: i32) = ((x & i32 (!(1 intrinsics.<< bit))) | i32 (b intrinsics.<< bit)) def popc x = intrinsics.popc16 (sign x) def mul_hi a b = unsign (intrinsics.umul_hi16 (sign a, sign b)) def mad_hi a b c = unsign (intrinsics.umad_hi16 (sign a, sign b, sign c)) def clz x = intrinsics.clz16 (sign x) def ctz x = 
intrinsics.ctz16 (sign x) def sum = reduce (+) (i32 0) def product = reduce (*) (i32 1) def maximum = reduce max lowest def minimum = reduce min highest } module u32: (integral with t = u32) = { type t = u32 def sign (x: u32) = intrinsics.sign_i32 x def unsign (x: i32) = intrinsics.unsign_i32 x def (+) (x: u32) (y: u32) = unsign (intrinsics.add32 (sign x, sign y)) def (-) (x: u32) (y: u32) = unsign (intrinsics.sub32 (sign x, sign y)) def (*) (x: u32) (y: u32) = unsign (intrinsics.mul32 (sign x, sign y)) def (/) (x: u32) (y: u32) = unsign (intrinsics.udiv32 (sign x, sign y)) def (**) (x: u32) (y: u32) = unsign (intrinsics.pow32 (sign x, sign y)) def (%) (x: u32) (y: u32) = unsign (intrinsics.umod32 (sign x, sign y)) def (//) (x: u32) (y: u32) = unsign (intrinsics.udiv32 (sign x, sign y)) def (%%) (x: u32) (y: u32) = unsign (intrinsics.umod32 (sign x, sign y)) def (&) (x: u32) (y: u32) = unsign (intrinsics.and32 (sign x, sign y)) def (|) (x: u32) (y: u32) = unsign (intrinsics.or32 (sign x, sign y)) def (^) (x: u32) (y: u32) = unsign (intrinsics.xor32 (sign x, sign y)) def not (x: u32) = unsign (intrinsics.complement32 (sign x)) def (<<) (x: u32) (y: u32) = unsign (intrinsics.shl32 (sign x, sign y)) def (>>) (x: u32) (y: u32) = unsign (intrinsics.ashr32 (sign x, sign y)) def (>>>) (x: u32) (y: u32) = unsign (intrinsics.lshr32 (sign x, sign y)) def u8 (x: u8) = unsign (i32.u8 x) def u16 (x: u16) = unsign (i32.u16 x) def u32 (x: u32) = unsign (i32.u32 x) def u64 (x: u64) = unsign (i32.u64 x) def i8 (x: i8) = unsign (intrinsics.zext_i8_i32 x) def i16 (x: i16) = unsign (intrinsics.zext_i16_i32 x) def i32 (x: i32) = unsign (intrinsics.zext_i32_i32 x) def i64 (x: i64) = unsign (intrinsics.zext_i64_i32 x) def f16 (x: f16) = unsign (intrinsics.fptoui_f16_i32 x) def f32 (x: f32) = unsign (intrinsics.fptoui_f32_i32 x) def f64 (x: f64) = unsign (intrinsics.fptoui_f64_i32 x) def bool x = unsign (intrinsics.btoi_bool_i32 x) def to_i32 (x: u32) = intrinsics.zext_i32_i32 (sign x) 
def to_i64 (x: u32) = intrinsics.zext_i32_i64 (sign x) def (==) (x: u32) (y: u32) = intrinsics.eq_i32 (sign x, sign y) def (<) (x: u32) (y: u32) = intrinsics.ult32 (sign x, sign y) def (>) (x: u32) (y: u32) = intrinsics.ult32 (sign y, sign x) def (<=) (x: u32) (y: u32) = intrinsics.ule32 (sign x, sign y) def (>=) (x: u32) (y: u32) = intrinsics.ule32 (sign y, sign x) def (!=) (x: u32) (y: u32) = !(x == y) def sgn (x: u32) = unsign (intrinsics.usignum32 (sign x)) def abs (x: u32) = x def highest = 4294967295u32 def lowest = highest + 1u32 def neg (x: t) = -x def max (x: t) (y: t) = unsign (intrinsics.umax32 (sign x, sign y)) def min (x: t) (y: t) = unsign (intrinsics.umin32 (sign x, sign y)) def num_bits = 32i32 def get_bit (bit: i32) (x: t) = to_i32 ((x >> i32 bit) & i32 1) def set_bit (bit: i32) (x: t) (b: i32) = ((x & i32 (!(1 intrinsics.<< bit))) | i32 (b intrinsics.<< bit)) def popc x = intrinsics.popc32 (sign x) def mul_hi a b = unsign (intrinsics.umul_hi32 (sign a, sign b)) def mad_hi a b c = unsign (intrinsics.umad_hi32 (sign a, sign b, sign c)) def clz x = intrinsics.clz32 (sign x) def ctz x = intrinsics.ctz32 (sign x) def sum = reduce (+) (i32 0) def product = reduce (*) (i32 1) def maximum = reduce max lowest def minimum = reduce min highest } module u64: (integral with t = u64) = { type t = u64 def sign (x: u64) = intrinsics.sign_i64 x def unsign (x: i64) = intrinsics.unsign_i64 x def (+) (x: u64) (y: u64) = unsign (intrinsics.add64 (sign x, sign y)) def (-) (x: u64) (y: u64) = unsign (intrinsics.sub64 (sign x, sign y)) def (*) (x: u64) (y: u64) = unsign (intrinsics.mul64 (sign x, sign y)) def (/) (x: u64) (y: u64) = unsign (intrinsics.udiv64 (sign x, sign y)) def (**) (x: u64) (y: u64) = unsign (intrinsics.pow64 (sign x, sign y)) def (%) (x: u64) (y: u64) = unsign (intrinsics.umod64 (sign x, sign y)) def (//) (x: u64) (y: u64) = unsign (intrinsics.udiv64 (sign x, sign y)) def (%%) (x: u64) (y: u64) = unsign (intrinsics.umod64 (sign x, sign y)) def (&) 
(x: u64) (y: u64) = unsign (intrinsics.and64 (sign x, sign y)) def (|) (x: u64) (y: u64) = unsign (intrinsics.or64 (sign x, sign y)) def (^) (x: u64) (y: u64) = unsign (intrinsics.xor64 (sign x, sign y)) def not (x: u64) = unsign (intrinsics.complement64 (sign x)) def (<<) (x: u64) (y: u64) = unsign (intrinsics.shl64 (sign x, sign y)) def (>>) (x: u64) (y: u64) = unsign (intrinsics.ashr64 (sign x, sign y)) def (>>>) (x: u64) (y: u64) = unsign (intrinsics.lshr64 (sign x, sign y)) def u8 (x: u8) = unsign (i64.u8 x) def u16 (x: u16) = unsign (i64.u16 x) def u32 (x: u32) = unsign (i64.u32 x) def u64 (x: u64) = unsign (i64.u64 x) def i8 (x: i8) = unsign (intrinsics.zext_i8_i64 x) def i16 (x: i16) = unsign (intrinsics.zext_i16_i64 x) def i32 (x: i32) = unsign (intrinsics.zext_i32_i64 x) def i64 (x: i64) = unsign (intrinsics.zext_i64_i64 x) def f16 (x: f16) = unsign (intrinsics.fptoui_f16_i64 x) def f32 (x: f32) = unsign (intrinsics.fptoui_f32_i64 x) def f64 (x: f64) = unsign (intrinsics.fptoui_f64_i64 x) def bool x = unsign (intrinsics.btoi_bool_i64 x) def to_i32 (x: u64) = intrinsics.zext_i64_i32 (sign x) def to_i64 (x: u64) = intrinsics.zext_i64_i64 (sign x) def (==) (x: u64) (y: u64) = intrinsics.eq_i64 (sign x, sign y) def (<) (x: u64) (y: u64) = intrinsics.ult64 (sign x, sign y) def (>) (x: u64) (y: u64) = intrinsics.ult64 (sign y, sign x) def (<=) (x: u64) (y: u64) = intrinsics.ule64 (sign x, sign y) def (>=) (x: u64) (y: u64) = intrinsics.ule64 (sign y, sign x) def (!=) (x: u64) (y: u64) = !(x == y) def sgn (x: u64) = unsign (intrinsics.usignum64 (sign x)) def abs (x: u64) = x def neg (x: t) = -x def max (x: t) (y: t) = unsign (intrinsics.umax64 (sign x, sign y)) def min (x: t) (y: t) = unsign (intrinsics.umin64 (sign x, sign y)) def highest = 18446744073709551615u64 def lowest = highest + 1u64 def num_bits = 64i32 def get_bit (bit: i32) (x: t) = to_i32 ((x >> i32 bit) & i32 1) def set_bit (bit: i32) (x: t) (b: i32) = ((x & i32 (!(1 intrinsics.<< bit))) | i32 (b 
intrinsics.<< bit)) def popc x = intrinsics.popc64 (sign x) def mul_hi a b = unsign (intrinsics.umul_hi64 (sign a, sign b)) def mad_hi a b c = unsign (intrinsics.umad_hi64 (sign a, sign b, sign c)) def clz x = intrinsics.clz64 (sign x) def ctz x = intrinsics.ctz64 (sign x) def sum = reduce (+) (i32 0) def product = reduce (*) (i32 1) def maximum = reduce max lowest def minimum = reduce min highest } module f64: (float with t = f64 with int_t = u64) = { type t = f64 type int_t = u64 module i64m = i64 module u64m = u64 def (+) (x: f64) (y: f64) = intrinsics.fadd64 (x, y) def (-) (x: f64) (y: f64) = intrinsics.fsub64 (x, y) def (*) (x: f64) (y: f64) = intrinsics.fmul64 (x, y) def (/) (x: f64) (y: f64) = intrinsics.fdiv64 (x, y) def (%) (x: f64) (y: f64) = intrinsics.fmod64 (x, y) def (**) (x: f64) (y: f64) = intrinsics.fpow64 (x, y) def u8 (x: u8) = intrinsics.uitofp_i8_f64 (i8.u8 x) def u16 (x: u16) = intrinsics.uitofp_i16_f64 (i16.u16 x) def u32 (x: u32) = intrinsics.uitofp_i32_f64 (i32.u32 x) def u64 (x: u64) = intrinsics.uitofp_i64_f64 (i64.u64 x) def i8 (x: i8) = intrinsics.sitofp_i8_f64 x def i16 (x: i16) = intrinsics.sitofp_i16_f64 x def i32 (x: i32) = intrinsics.sitofp_i32_f64 x def i64 (x: i64) = intrinsics.sitofp_i64_f64 x def f16 (x: f16) = intrinsics.fpconv_f16_f64 x def f32 (x: f32) = intrinsics.fpconv_f32_f64 x def f64 (x: f64) = intrinsics.fpconv_f64_f64 x def bool (x: bool) = intrinsics.btof_bool_f64 x def from_fraction (x: i64) (y: i64) = i64 x / i64 y def to_i64 (x: f64) = intrinsics.fptosi_f64_i64 x def to_f64 (x: f64) = x def (==) (x: f64) (y: f64) = intrinsics.eq_f64 (x, y) def (<) (x: f64) (y: f64) = intrinsics.lt64 (x, y) def (>) (x: f64) (y: f64) = intrinsics.lt64 (y, x) def (<=) (x: f64) (y: f64) = intrinsics.le64 (x, y) def (>=) (x: f64) (y: f64) = intrinsics.le64 (y, x) def (!=) (x: f64) (y: f64) = !(x == y) def neg (x: t) = -x def recip (x: t) = 1 / x def max (x: t) (y: t) = intrinsics.fmax64 (x, y) def min (x: t) (y: t) = intrinsics.fmin64 
(x, y) def sgn (x: f64) = intrinsics.fsignum64 x def abs (x: f64) = intrinsics.fabs64 x def sqrt (x: f64) = intrinsics.sqrt64 x def cbrt (x: f64) = intrinsics.cbrt64 x def log (x: f64) = intrinsics.log64 x def log2 (x: f64) = intrinsics.log2_64 x def log10 (x: f64) = intrinsics.log10_64 x def log1p (x: f64) = intrinsics.log1p_64 x def exp (x: f64) = intrinsics.exp64 x def sin (x: f64) = intrinsics.sin64 x def cos (x: f64) = intrinsics.cos64 x def tan (x: f64) = intrinsics.tan64 x def acos (x: f64) = intrinsics.acos64 x def asin (x: f64) = intrinsics.asin64 x def atan (x: f64) = intrinsics.atan64 x def sinh (x: f64) = intrinsics.sinh64 x def cosh (x: f64) = intrinsics.cosh64 x def tanh (x: f64) = intrinsics.tanh64 x def acosh (x: f64) = intrinsics.acosh64 x def asinh (x: f64) = intrinsics.asinh64 x def atanh (x: f64) = intrinsics.atanh64 x def atan2 (x: f64) (y: f64) = intrinsics.atan2_64 (x, y) def hypot (x: f64) (y: f64) = intrinsics.hypot64 (x, y) def gamma = intrinsics.gamma64 def lgamma = intrinsics.lgamma64 def erf = intrinsics.erf64 def erfc = intrinsics.erfc64 def lerp v0 v1 t = intrinsics.lerp64 (v0, v1, t) def fma a b c = intrinsics.fma64 (a, b, c) def mad a b c = intrinsics.mad64 (a, b, c) def ceil = intrinsics.ceil64 def floor = intrinsics.floor64 def trunc (x: f64): f64 = i64 (i64m.f64 x) def round = intrinsics.round64 def nextafter x y = intrinsics.nextafter64 (x, y) def ldexp x y = intrinsics.ldexp64 (x, y) def copysign x y = intrinsics.copysign64 (x, y) def to_bits (x: f64): u64 = u64m.i64 (intrinsics.to_bits64 x) def from_bits (x: u64): f64 = intrinsics.from_bits64 (intrinsics.sign_i64 x) def num_bits = 64i32 def get_bit (bit: i32) (x: t) = u64m.get_bit bit (to_bits x) def set_bit (bit: i32) (x: t) (b: i32) = from_bits (u64m.set_bit bit (to_bits x) b) def isinf (x: f64) = intrinsics.isinf64 x def isnan (x: f64) = intrinsics.isnan64 x def inf = 1f64 / 0f64 def nan = 0f64 / 0f64 def highest = inf def lowest = -inf def epsilon = 
2.220446049250313e-16f64 def pi = 3.1415926535897932384626433832795028841971693993751058209749445923078164062f64 def e = 2.718281828459045235360287471352662497757247093699959574966967627724076630353f64 def sum = reduce (+) (i32 0) def product = reduce (*) (i32 1) def maximum = reduce max lowest def minimum = reduce min highest } module f32: (float with t = f32 with int_t = u32) = { type t = f32 type int_t = u32 module i32m = i32 module u32m = u32 module f64m = f64 def (+) (x: f32) (y: f32) = intrinsics.fadd32 (x, y) def (-) (x: f32) (y: f32) = intrinsics.fsub32 (x, y) def (*) (x: f32) (y: f32) = intrinsics.fmul32 (x, y) def (/) (x: f32) (y: f32) = intrinsics.fdiv32 (x, y) def (%) (x: f32) (y: f32) = intrinsics.fmod32 (x, y) def (**) (x: f32) (y: f32) = intrinsics.fpow32 (x, y) def u8 (x: u8) = intrinsics.uitofp_i8_f32 (i8.u8 x) def u16 (x: u16) = intrinsics.uitofp_i16_f32 (i16.u16 x) def u32 (x: u32) = intrinsics.uitofp_i32_f32 (i32.u32 x) def u64 (x: u64) = intrinsics.uitofp_i64_f32 (i64.u64 x) def i8 (x: i8) = intrinsics.sitofp_i8_f32 x def i16 (x: i16) = intrinsics.sitofp_i16_f32 x def i32 (x: i32) = intrinsics.sitofp_i32_f32 x def i64 (x: i64) = intrinsics.sitofp_i64_f32 x def f16 (x: f16) = intrinsics.fpconv_f16_f32 x def f32 (x: f32) = intrinsics.fpconv_f32_f32 x def f64 (x: f64) = intrinsics.fpconv_f64_f32 x def bool (x: bool) = intrinsics.btof_bool_f32 x def from_fraction (x: i64) (y: i64) = i64 x / i64 y def to_i64 (x: f32) = intrinsics.fptosi_f32_i64 x def to_f64 (x: f32) = intrinsics.fpconv_f32_f64 x def (==) (x: f32) (y: f32) = intrinsics.eq_f32 (x, y) def (<) (x: f32) (y: f32) = intrinsics.lt32 (x, y) def (>) (x: f32) (y: f32) = intrinsics.lt32 (y, x) def (<=) (x: f32) (y: f32) = intrinsics.le32 (x, y) def (>=) (x: f32) (y: f32) = intrinsics.le32 (y, x) def (!=) (x: f32) (y: f32) = !(x == y) def neg (x: t) = -x def recip (x: t) = 1 / x def max (x: t) (y: t) = intrinsics.fmax32 (x, y) def min (x: t) (y: t) = intrinsics.fmin32 (x, y) def sgn (x: f32) = 
intrinsics.fsignum32 x def abs (x: f32) = intrinsics.fabs32 x def sqrt (x: f32) = intrinsics.sqrt32 x def cbrt (x: f32) = intrinsics.cbrt32 x def log (x: f32) = intrinsics.log32 x def log2 (x: f32) = intrinsics.log2_32 x def log10 (x: f32) = intrinsics.log10_32 x def log1p (x: f32) = intrinsics.log1p_32 x def exp (x: f32) = intrinsics.exp32 x def sin (x: f32) = intrinsics.sin32 x def cos (x: f32) = intrinsics.cos32 x def tan (x: f32) = intrinsics.tan32 x def acos (x: f32) = intrinsics.acos32 x def asin (x: f32) = intrinsics.asin32 x def atan (x: f32) = intrinsics.atan32 x def sinh (x: f32) = intrinsics.sinh32 x def cosh (x: f32) = intrinsics.cosh32 x def tanh (x: f32) = intrinsics.tanh32 x def acosh (x: f32) = intrinsics.acosh32 x def asinh (x: f32) = intrinsics.asinh32 x def atanh (x: f32) = intrinsics.atanh32 x def atan2 (x: f32) (y: f32) = intrinsics.atan2_32 (x, y) def hypot (x: f32) (y: f32) = intrinsics.hypot32 (x, y) def gamma = intrinsics.gamma32 def lgamma = intrinsics.lgamma32 def erf = intrinsics.erf32 def erfc = intrinsics.erfc32 def lerp v0 v1 t = intrinsics.lerp32 (v0, v1, t) def fma a b c = intrinsics.fma32 (a, b, c) def mad a b c = intrinsics.mad32 (a, b, c) def ceil = intrinsics.ceil32 def floor = intrinsics.floor32 def trunc (x: f32): f32 = i32 (i32m.f32 x) def round = intrinsics.round32 def nextafter x y = intrinsics.nextafter32 (x, y) def ldexp x y = intrinsics.ldexp32 (x, y) def copysign x y = intrinsics.copysign32 (x, y) def to_bits (x: f32): u32 = u32m.i32 (intrinsics.to_bits32 x) def from_bits (x: u32): f32 = intrinsics.from_bits32 (intrinsics.sign_i32 x) def num_bits = 32i32 def get_bit (bit: i32) (x: t) = u32m.get_bit bit (to_bits x) def set_bit (bit: i32) (x: t) (b: i32) = from_bits (u32m.set_bit bit (to_bits x) b) def isinf (x: f32) = intrinsics.isinf32 x def isnan (x: f32) = intrinsics.isnan32 x def inf = 1f32 / 0f32 def nan = 0f32 / 0f32 def highest = inf def lowest = -inf def epsilon = 1.1920929e-7f32 def pi = f64 f64m.pi def e = f64 
f64m.e def sum = reduce (+) (i32 0) def product = reduce (*) (i32 1) def maximum = reduce max lowest def minimum = reduce min highest } -- | Emulated with single precision on systems that do not natively -- support half precision. This means you might get more accurate -- results than on real systems, but it is also likely to be -- significantly slower than just using `f32` in the first place. module f16: (float with t = f16 with int_t = u16) = { type t = f16 type int_t = u16 module i16m = i16 module u16m = u16 module f64m = f64 def (+) (x: f16) (y: f16) = intrinsics.fadd16 (x, y) def (-) (x: f16) (y: f16) = intrinsics.fsub16 (x, y) def (*) (x: f16) (y: f16) = intrinsics.fmul16 (x, y) def (/) (x: f16) (y: f16) = intrinsics.fdiv16 (x, y) def (%) (x: f16) (y: f16) = intrinsics.fmod16 (x, y) def (**) (x: f16) (y: f16) = intrinsics.fpow16 (x, y) def u8 (x: u8) = intrinsics.uitofp_i8_f16 (i8.u8 x) def u16 (x: u16) = intrinsics.uitofp_i16_f16 (i16.u16 x) def u32 (x: u32) = intrinsics.uitofp_i32_f16 (i32.u32 x) def u64 (x: u64) = intrinsics.uitofp_i64_f16 (i64.u64 x) def i8 (x: i8) = intrinsics.sitofp_i8_f16 x def i16 (x: i16) = intrinsics.sitofp_i16_f16 x def i32 (x: i32) = intrinsics.sitofp_i32_f16 x def i64 (x: i64) = intrinsics.sitofp_i64_f16 x def f16 (x: f16) = intrinsics.fpconv_f16_f16 x def f32 (x: f32) = intrinsics.fpconv_f32_f16 x def f64 (x: f64) = intrinsics.fpconv_f64_f16 x def bool (x: bool) = intrinsics.btof_bool_f16 x def from_fraction (x: i64) (y: i64) = i64 x / i64 y def to_i64 (x: f16) = intrinsics.fptosi_f16_i64 x def to_f64 (x: f16) = intrinsics.fpconv_f16_f64 x def (==) (x: f16) (y: f16) = intrinsics.eq_f16 (x, y) def (<) (x: f16) (y: f16) = intrinsics.lt16 (x, y) def (>) (x: f16) (y: f16) = intrinsics.lt16 (y, x) def (<=) (x: f16) (y: f16) = intrinsics.le16 (x, y) def (>=) (x: f16) (y: f16) = intrinsics.le16 (y, x) def (!=) (x: f16) (y: f16) = !(x == y) def neg (x: t) = -x def recip (x: t) = 1 / x def max (x: t) (y: t) = intrinsics.fmax16 (x, y) def 
min (x: t) (y: t) = intrinsics.fmin16 (x, y) def sgn (x: f16) = intrinsics.fsignum16 x def abs (x: f16) = intrinsics.fabs16 x def sqrt (x: f16) = intrinsics.sqrt16 x def cbrt (x: f16) = intrinsics.cbrt16 x def log (x: f16) = intrinsics.log16 x def log2 (x: f16) = intrinsics.log2_16 x def log10 (x: f16) = intrinsics.log10_16 x def log1p (x: f16) = intrinsics.log1p_16 x def exp (x: f16) = intrinsics.exp16 x def sin (x: f16) = intrinsics.sin16 x def cos (x: f16) = intrinsics.cos16 x def tan (x: f16) = intrinsics.tan16 x def acos (x: f16) = intrinsics.acos16 x def asin (x: f16) = intrinsics.asin16 x def atan (x: f16) = intrinsics.atan16 x def sinh (x: f16) = intrinsics.sinh16 x def cosh (x: f16) = intrinsics.cosh16 x def tanh (x: f16) = intrinsics.tanh16 x def acosh (x: f16) = intrinsics.acosh16 x def asinh (x: f16) = intrinsics.asinh16 x def atanh (x: f16) = intrinsics.atanh16 x def atan2 (x: f16) (y: f16) = intrinsics.atan2_16 (x, y) def hypot (x: f16) (y: f16) = intrinsics.hypot16 (x, y) def gamma = intrinsics.gamma16 def lgamma = intrinsics.lgamma16 def erf = intrinsics.erf16 def erfc = intrinsics.erfc16 def lerp v0 v1 t = intrinsics.lerp16 (v0, v1, t) def fma a b c = intrinsics.fma16 (a, b, c) def mad a b c = intrinsics.mad16 (a, b, c) def ceil = intrinsics.ceil16 def floor = intrinsics.floor16 def trunc (x: f16): f16 = i16 (i16m.f16 x) def round = intrinsics.round16 def nextafter x y = intrinsics.nextafter16 (x, y) def ldexp x y = intrinsics.ldexp16 (x, y) def copysign x y = intrinsics.copysign16 (x, y) def to_bits (x: f16): u16 = u16m.i16 (intrinsics.to_bits16 x) def from_bits (x: u16): f16 = intrinsics.from_bits16 (intrinsics.sign_i16 x) def num_bits = 16i32 def get_bit (bit: i32) (x: t) = u16m.get_bit bit (to_bits x) def set_bit (bit: i32) (x: t) (b: i32) = from_bits (u16m.set_bit bit (to_bits x) b) def isinf (x: f16) = intrinsics.isinf16 x def isnan (x: f16) = intrinsics.isnan16 x def inf = 1f16 / 0f16 def nan = 0f16 / 0f16 def highest = inf def lowest = -inf 
-- Machine epsilon for f16: the gap between 1.0 and the next larger f16
-- value. f16 has 10 fraction bits, so this is 2^-10. (1.1920929e-7,
-- i.e. 2^-23, is the f32 epsilon and is wrong for this type.)
def epsilon = 9.765625e-4f16
def pi = f64 f64m.pi
def e = f64 f64m.e
def sum = reduce (+) (i32 0)
def product = reduce (*) (i32 1)
def maximum = reduce max lowest
def minimum = reduce min highest
}
futhark-0.25.27/prelude/prelude.fut000066400000000000000000000025071475065116200171660ustar00rootroot00000000000000
-- | The default prelude that is implicitly available in all Futhark
-- files.

open import "soacs"
open import "array"
open import "math"
open import "functional"
open import "ad"

-- | Create single-precision float from integer.
def r32 (x: i32): f32 = f32.i32 x

-- | Create integer from single-precision float.
def t32 (x: f32): i32 = i32.f32 x

-- | Create double-precision float from integer.
def r64 (x: i32): f64 = f64.i32 x

-- | Create integer from double-precision float.
def t64 (x: f64): i32 = i32.f64 x

-- | Negate a boolean. `not x` is the same as `!x`. This function is
-- mostly useful for passing to `map`.
def not (x: bool): bool = !x

-- | Semantically just identity, but serves as an optimisation
-- inhibitor. The compiler will treat this function as a black box.
-- You can use this to work around optimisation deficiencies (or
-- bugs), although it should hopefully rarely be necessary.
-- Deprecated: use `#[opaque]` attribute instead.
def opaque 't (x: t): t = #[opaque] x

-- | Semantically just identity, but at runtime, the argument value
-- will be printed. Deprecated: use `#[trace]` attribute instead.
def trace 't (x: t): t = #[trace(trace)] x

-- | Semantically just identity, but acts as a break point in
-- `futhark repl`. Deprecated: use `#[break]` attribute instead.
def break 't (x: t): t = #[break] x
futhark-0.25.27/prelude/soacs.fut000066400000000000000000000220561475065116200166370ustar00rootroot00000000000000
-- | Various Second-Order Array Combinators that are operationally
-- parallel in a way that can be exploited by the compiler.
--
-- The functions here are recognised specially by the compiler (or
-- built on those that are).
-- The asymptotic [work and
-- span](https://en.wikipedia.org/wiki/Analysis_of_parallel_algorithms)
-- is provided for each function, but note that this easily hides very
-- substantial constant factors. For example, `scan`@term is *much*
-- slower than `reduce`@term, although they have the same asymptotic
-- complexity.
--
-- **Higher-order complexity**
--
-- Specifying the time complexity of higher-order functions is tricky
-- because it depends on the functional argument. We use the informal
-- convention that *W(f)* denotes the largest (asymptotic) *work* of
-- function *f*, for the values it may be applied to. Similarly,
-- *S(f)* denotes the largest span. See [this Wikipedia
-- article](https://en.wikipedia.org/wiki/Analysis_of_parallel_algorithms)
-- for a general introduction to these constructs.
--
-- **Reminder on terminology**
--
-- A function `op` is said to be *associative* if
--
--     (x `op` y) `op` z == x `op` (y `op` z)
--
-- for all `x`, `y`, `z`. Similarly, it is *commutative* if
--
--     x `op` y == y `op` x
--
-- The value `o` is a *neutral element* if
--
--     x `op` o == o `op` x == x
--
-- for any `x`.

-- Implementation note: many of these definitions contain dynamically
-- checked size casts. These will be removed by the compiler, but are
-- necessary for the type checker, as the 'intrinsics' functions are
-- not size-dependently typed.

-- Provides the `zip2`..`zip5` helpers that `map2`..`map5` are built on.
import "zip"

-- | Apply the given function to each element of an array.
--
-- **Work:** *O(n ✕ W(f))*
--
-- **Span:** *O(S(f))*
def map 'a [n] 'x (f: a -> x) (as: [n]a) : *[n]x =
  -- Direct wrapper over the compiler intrinsic.
  intrinsics.map f as

-- | Apply the given function to each element of a single array.
--
-- **Work:** *O(n ✕ W(f))*
--
-- **Span:** *O(S(f))*
def map1 'a [n] 'x (f: a -> x) (as: [n]a) : *[n]x =
  -- Delegates directly to `map`.
  map f as

-- | As `map1`@term, but with one more array.
--
-- **Work:** *O(n ✕ W(f))*
--
-- **Span:** *O(S(f))*
def map2 'a 'b [n] 'x (f: a -> b -> x) (as: [n]a) (bs: [n]b) : *[n]x =
  -- Zip the inputs and apply `f` to the components of each tuple.
  map (\(a, b) -> f a b) (zip2 as bs)

-- | As `map2`@term, but with one more array.
--
-- **Work:** *O(n ✕ W(f))*
--
-- **Span:** *O(S(f))*
def map3 'a 'b 'c [n] 'x (f: a -> b -> c -> x) (as: [n]a) (bs: [n]b) (cs: [n]c) : *[n]x =
  map (\(a, b, c) -> f a b c) (zip3 as bs cs)

-- | As `map3`@term, but with one more array.
--
-- **Work:** *O(n ✕ W(f))*
--
-- **Span:** *O(S(f))*
def map4 'a 'b 'c 'd [n] 'x (f: a -> b -> c -> d -> x) (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d) : *[n]x =
  map (\(a, b, c, d) -> f a b c d) (zip4 as bs cs ds)

-- | As `map4`@term, but with one more array.
--
-- **Work:** *O(n ✕ W(f))*
--
-- **Span:** *O(S(f))*
def map5 'a 'b 'c 'd 'e [n] 'x (f: a -> b -> c -> d -> e -> x) (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d) (es: [n]e) : *[n]x =
  map (\(a, b, c, d, e) -> f a b c d e) (zip5 as bs cs ds es)

-- | Reduce the array `as` with `op`, with `ne` as the neutral
-- element for `op`. The function `op` must be associative. If
-- it is not, the return value is unspecified. If the value returned
-- by the operator is an array, it must have the exact same size as
-- the neutral element, and that must again have the same size as the
-- elements of the input array.
--
-- **Work:** *O(n ✕ W(op))*
--
-- **Span:** *O(log(n) ✕ W(op))*
--
-- Note that the complexity implies that parallelism in the combining
-- operator will *not* be exploited.
def reduce [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a) : a =
  intrinsics.reduce op ne as

-- | As `reduce`, but the operator must also be commutative. This is
-- potentially faster than `reduce`. For simple built-in operators,
-- like addition, the compiler already knows that the operator is
-- commutative, so plain `reduce`@term will work just as well.
-- -- **Work:** *O(n ✕ W(op))* -- -- **Span:** *O(log(n) ✕ W(op))* def reduce_comm [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a) : a = intrinsics.reduce_comm op ne as -- | `h = hist op ne k is as` computes a generalised `k`-bin histogram -- `h`, such that `h[i]` is the sum of those values `as[j]` for which -- `is[j]==i`. The summation is done with `op`, which must be a -- commutative and associative function with neutral element `ne`. If -- a bin has no elements, its value will be `ne`. -- -- **Work:** *O(k + n ✕ W(op))* -- -- **Span:** *O(n ✕ W(op))* in the worst case (all updates to same -- position), but *O(W(op))* in the best case. -- -- In practice, linear span only occurs if *k* is also very large. def hist 'a [n] (op: a -> a -> a) (ne: a) (k: i64) (is: [n]i64) (as: [n]a) : *[k]a = intrinsics.hist_1d 1 (map (\_ -> ne) (0..1.. a -> a) (ne: a) (is: [n]i64) (as: [n]a) : *[k]a = intrinsics.hist_1d 1 dest f ne is as -- | As `reduce_by_index`, but with two-dimensional indexes. def reduce_by_index_2d 'a [k] [n] [m] (dest: *[k][m]a) (f: a -> a -> a) (ne: a) (is: [n](i64, i64)) (as: [n]a) : *[k][m]a = intrinsics.hist_2d 1 dest f ne is as -- | As `reduce_by_index`, but with three-dimensional indexes. def reduce_by_index_3d 'a [k] [n] [m] [l] (dest: *[k][m][l]a) (f: a -> a -> a) (ne: a) (is: [n](i64, i64, i64)) (as: [n]a) : *[k][m][l]a = intrinsics.hist_3d 1 dest f ne is as -- | Inclusive prefix scan. Has the same caveats with respect to -- associativity and complexity as `reduce`. -- -- **Work:** *O(n ✕ W(op))* -- -- **Span:** *O(log(n) ✕ W(op))* def scan [n] 'a (op: a -> a -> a) (ne: a) (as: [n]a) : *[n]a = intrinsics.scan op ne as -- | Split an array into those elements that satisfy the given -- predicate, and those that do not. 
-- -- **Work:** *O(n ✕ W(p))* -- -- **Span:** *O(log(n) ✕ W(p))* def partition [n] 'a (p: a -> bool) (as: [n]a) : ?[k].([k]a, [n - k]a) = let p' x = if p x then 0 else 1 let (as', is) = intrinsics.partition 2 p' as in (as'[0:is[0]], as'[is[0]:n]) -- | Split an array by two predicates, producing three arrays. -- -- **Work:** *O(n ✕ (W(p1) + W(p2)))* -- -- **Span:** *O(log(n) ✕ (W(p1) + W(p2)))* def partition2 [n] 'a (p1: a -> bool) (p2: a -> bool) (as: [n]a) : ?[k][l].([k]a, [l]a, [n - k - l]a) = let p' x = if p1 x then 0 else if p2 x then 1 else 2 let (as', is) = intrinsics.partition 3 p' as in ( as'[0:is[0]] , as'[is[0]:is[0] + is[1]] :> [is[1]]a , as'[is[0] + is[1]:n] :> [n - is[0] - is[1]]a ) -- | Return `true` if the given function returns `true` for all -- elements in the array. -- -- **Work:** *O(n ✕ W(f))* -- -- **Span:** *O(log(n) + S(f))* def all [n] 'a (f: a -> bool) (as: [n]a) : bool = reduce (&&) true (map f as) -- | Return `true` if the given function returns `true` for any -- elements in the array. -- -- **Work:** *O(n ✕ W(f))* -- -- **Span:** *O(log(n) + S(f))* def any [n] 'a (f: a -> bool) (as: [n]a) : bool = reduce (||) false (map f as) -- | `r = spread k x is vs` produces an array `r` such that `r[i] = -- vs[j]` where `is[j] == i`, or `x` if no such `j` exists. -- Intuitively, `is` is an array indicating where the corresponding -- elements of `vs` should be located in the result. Out-of-bounds -- elements of `is` are ignored. In-bounds duplicates in `is` result -- in unspecified behaviour - see `hist`@term for a function that can -- handle this. -- -- **Work:** *O(k + n)* -- -- **Span:** *O(1)* def spread 't [n] (k: i64) (x: t) (is: [n]i64) (vs: [n]t) : *[k]t = intrinsics.scatter (map (\_ -> x) (0..1.. 
bool) (as: [n]a) : *[]a = let flags = map (\x -> if p x then 1 else 0) as let offsets = scan (+) 0 flags let m = if n == 0 then 0 else offsets[n - 1] in scatter (map (\x -> x) as[:m]) (map2 (\f o -> if f == 1 then o - 1 else -1) flags offsets) as futhark-0.25.27/prelude/zip.fut000066400000000000000000000046211475065116200163270ustar00rootroot00000000000000-- | Transforming arrays of tuples into tuples of arrays and back -- again. -- -- These are generally very cheap operations, as the internal compiler -- representation is always tuples of arrays. -- The main reason this module exists is that we need it to define -- SOACs like `map2`. -- We need a map to define some of the zip variants, but this file is -- depended upon by soacs.fut. So we just define a quick-and-dirty -- internal one here that uses the intrinsic version. local def internal_map 'a [n] 'x (f: a -> x) (as: [n]a) : *[n]x = intrinsics.map f as -- | Construct an array of pairs from two arrays. def zip [n] 'a 'b (as: [n]a) (bs: [n]b) : *[n](a, b) = intrinsics.zip as bs -- | Construct an array of pairs from two arrays. def zip2 [n] 'a 'b (as: [n]a) (bs: [n]b) : *[n](a, b) = zip as bs -- | As `zip2`@term, but with one more array. def zip3 [n] 'a 'b 'c (as: [n]a) (bs: [n]b) (cs: [n]c) : *[n](a, b, c) = internal_map (\(a, (b, c)) -> (a, b, c)) (zip as (zip2 bs cs)) -- | As `zip3`@term, but with one more array. def zip4 [n] 'a 'b 'c 'd (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d) : *[n](a, b, c, d) = internal_map (\(a, (b, c, d)) -> (a, b, c, d)) (zip as (zip3 bs cs ds)) -- | As `zip4`@term, but with one more array. def zip5 [n] 'a 'b 'c 'd 'e (as: [n]a) (bs: [n]b) (cs: [n]c) (ds: [n]d) (es: [n]e) : *[n](a, b, c, d, e) = internal_map (\(a, (b, c, d, e)) -> (a, b, c, d, e)) (zip as (zip4 bs cs ds es)) -- | Turn an array of pairs into two arrays. def unzip [n] 'a 'b (xs: [n](a, b)) : ([n]a, [n]b) = intrinsics.unzip xs -- | Turn an array of pairs into two arrays. 
def unzip2 [n] 'a 'b (xs: [n](a, b)) : ([n]a, [n]b) =
  unzip xs

-- | As `unzip2`@term, but with one more array.
def unzip3 [n] 'a 'b 'c (xs: [n](a, b, c)) : ([n]a, [n]b, [n]c) =
  -- Regroup each triple as (a, (b, c)) so two binary unzips suffice.
  let nested = internal_map (\(x, y, z) -> (x, (y, z))) xs
  let (firsts, rest) = unzip nested
  let (seconds, thirds) = unzip rest
  in (firsts, seconds, thirds)

-- | As `unzip3`@term, but with one more array.
def unzip4 [n] 'a 'b 'c 'd (xs: [n](a, b, c, d)) : ([n]a, [n]b, [n]c, [n]d) =
  -- Pair up the last two components, reuse unzip3, then split the pairs.
  let nested = internal_map (\(w, x, y, z) -> (w, x, (y, z))) xs
  let (c0, c1, pairs) = unzip3 nested
  let (c2, c3) = unzip pairs
  in (c0, c1, c2, c3)

-- | As `unzip4`@term, but with one more array.
def unzip5 [n] 'a 'b 'c 'd 'e (xs: [n](a, b, c, d, e)) : ([n]a, [n]b, [n]c, [n]d, [n]e) =
  -- Pair up the last two components, reuse unzip4, then split the pairs.
  let nested = internal_map (\(v, w, x, y, z) -> (v, w, x, (y, z))) xs
  let (c0, c1, c2, pairs) = unzip4 nested
  let (c3, c4) = unzip pairs
  in (c0, c1, c2, c3, c4)
futhark-0.25.27/pyproject.toml000066400000000000000000000001101475065116200162460ustar00rootroot00000000000000[tool.black] line-length = 79 exclude = "benchmark-performance-plot.py" futhark-0.25.27/rts/000077500000000000000000000000001475065116200141525ustar00rootroot00000000000000futhark-0.25.27/rts/README.md000066400000000000000000000004031475065116200154260ustar00rootroot00000000000000Futhark runtime system directory ================================ This directory contains bits and pieces used in generated code. It is put here instead of embedded in the compiler source code to ease modification and, hopefully, enable standalone testing. futhark-0.25.27/rts/c/000077500000000000000000000000001475065116200143745ustar00rootroot00000000000000futhark-0.25.27/rts/c/STYLE.md000066400000000000000000000015151475065116200156200ustar00rootroot00000000000000Coding Style for Futhark C Runtime Component == * Use two spaces for indentation, no tabs. * Try to stay below 80 characters per line. * Braces are mandatory for control flow structures. * Use only line comments. * Use snake_case for naming, except preprocessor macros, which are uppercase. * Check all return values.
* Do not use header guards, and do not include one RTS header from another. The header files here are not intended to be used as normal C header files, but are instead copied into the generated program in a specific order. An argument could be made that perhaps they ought to be `.c` files instead. * Start all files with the comment `// Start of foo.h.` and end with `// End of foo.h.`. This makes the concatenated code easier to navigate. * Ensure, as far as possible, that the code is also valid C++. futhark-0.25.27/rts/c/atomics.h000066400000000000000000000432751475065116200162170ustar00rootroot00000000000000// Start of atomics.h SCALAR_FUN_ATTR int32_t atomic_xchg_i32_global(volatile __global int32_t *p, int32_t x); SCALAR_FUN_ATTR int32_t atomic_xchg_i32_shared(volatile __local int32_t *p, int32_t x); SCALAR_FUN_ATTR int32_t atomic_cmpxchg_i32_global(volatile __global int32_t *p, int32_t cmp, int32_t val); SCALAR_FUN_ATTR int32_t atomic_cmpxchg_i32_shared(volatile __local int32_t *p, int32_t cmp, int32_t val); SCALAR_FUN_ATTR int32_t atomic_add_i32_global(volatile __global int32_t *p, int32_t x); SCALAR_FUN_ATTR int32_t atomic_add_i32_shared(volatile __local int32_t *p, int32_t x); SCALAR_FUN_ATTR float atomic_fadd_f32_global(volatile __global float *p, float x); SCALAR_FUN_ATTR float atomic_fadd_f32_shared(volatile __local float *p, float x); SCALAR_FUN_ATTR int32_t atomic_smax_i32_global(volatile __global int32_t *p, int32_t x); SCALAR_FUN_ATTR int32_t atomic_smax_i32_shared(volatile __local int32_t *p, int32_t x); SCALAR_FUN_ATTR int32_t atomic_smin_i32_global(volatile __global int32_t *p, int32_t x); SCALAR_FUN_ATTR int32_t atomic_smin_i32_shared(volatile __local int32_t *p, int32_t x); SCALAR_FUN_ATTR uint32_t atomic_umax_i32_global(volatile __global uint32_t *p, uint32_t x); SCALAR_FUN_ATTR uint32_t atomic_umax_i32_shared(volatile __local uint32_t *p, uint32_t x); SCALAR_FUN_ATTR uint32_t atomic_umin_i32_global(volatile __global uint32_t *p, uint32_t x);
SCALAR_FUN_ATTR uint32_t atomic_umin_i32_shared(volatile __local uint32_t *p, uint32_t x); SCALAR_FUN_ATTR int32_t atomic_and_i32_global(volatile __global int32_t *p, int32_t x); SCALAR_FUN_ATTR int32_t atomic_and_i32_shared(volatile __local int32_t *p, int32_t x); SCALAR_FUN_ATTR int32_t atomic_or_i32_global(volatile __global int32_t *p, int32_t x); SCALAR_FUN_ATTR int32_t atomic_or_i32_shared(volatile __local int32_t *p, int32_t x); SCALAR_FUN_ATTR int32_t atomic_xor_i32_global(volatile __global int32_t *p, int32_t x); SCALAR_FUN_ATTR int32_t atomic_xor_i32_shared(volatile __local int32_t *p, int32_t x); SCALAR_FUN_ATTR int32_t atomic_xchg_i32_global(volatile __global int32_t *p, int32_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicExch((int32_t*)p, x); #else return atomic_xor(p, x); #endif } SCALAR_FUN_ATTR int32_t atomic_xchg_i32_shared(volatile __local int32_t *p, int32_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicExch((int32_t*)p, x); #else return atomic_xor(p, x); #endif } SCALAR_FUN_ATTR int32_t atomic_cmpxchg_i32_global(volatile __global int32_t *p, int32_t cmp, int32_t val) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicCAS((int32_t*)p, cmp, val); #else return atomic_cmpxchg(p, cmp, val); #endif } SCALAR_FUN_ATTR int32_t atomic_cmpxchg_i32_shared(volatile __local int32_t *p, int32_t cmp, int32_t val) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicCAS((int32_t*)p, cmp, val); #else return atomic_cmpxchg(p, cmp, val); #endif } SCALAR_FUN_ATTR int32_t atomic_add_i32_global(volatile __global int32_t *p, int32_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicAdd((int32_t*)p, x); #else return atomic_add(p, x); #endif } SCALAR_FUN_ATTR int32_t atomic_add_i32_shared(volatile __local int32_t *p, int32_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicAdd((int32_t*)p, x); #else return atomic_add(p, x); #endif } SCALAR_FUN_ATTR float 
atomic_fadd_f32_global(volatile __global float *p, float x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicAdd((float*)p, x); // On OpenCL, use technique from // https://pipinspace.github.io/blog/atomic-float-addition-in-opencl.html #elif defined(cl_nv_pragma_unroll) // use hardware-supported atomic addition on Nvidia GPUs with inline // PTX assembly float ret; asm volatile("atom.global.add.f32 %0,[%1],%2;":"=f"(ret):"l"(p),"f"(x):"memory"); return ret; #elif defined(__opencl_c_ext_fp32_global_atomic_add) // use hardware-supported atomic addition on some Intel GPUs return atomic_fetch_add_explicit((volatile __global atomic_float*)p, x, memory_order_relaxed); #elif __has_builtin(__builtin_amdgcn_global_atomic_fadd_f32) // use hardware-supported atomic addition on some AMD GPUs return __builtin_amdgcn_global_atomic_fadd_f32(p, x); #else // fallback emulation: // https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639/5 float old = x; float ret; while ((old=atomic_xchg(p, ret=atomic_xchg(p, 0.0f)+old))!=0.0f); return ret; #endif } SCALAR_FUN_ATTR float atomic_fadd_f32_shared(volatile __local float *p, float x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicAdd((float*)p, x); #else union { int32_t i; float f; } old; union { int32_t i; float f; } assumed; old.f = *p; do { assumed.f = old.f; old.f = old.f + x; old.i = atomic_cmpxchg_i32_shared((volatile __local int32_t*)p, assumed.i, old.i); } while (assumed.i != old.i); return old.f; #endif } SCALAR_FUN_ATTR int32_t atomic_smax_i32_global(volatile __global int32_t *p, int32_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicMax((int32_t*)p, x); #else return atomic_max(p, x); #endif } SCALAR_FUN_ATTR int32_t atomic_smax_i32_shared(volatile __local int32_t *p, int32_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicMax((int32_t*)p, x); #else return atomic_max(p, x); #endif } SCALAR_FUN_ATTR int32_t 
atomic_smin_i32_global(volatile __global int32_t *p, int32_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicMin((int32_t*)p, x); #else return atomic_min(p, x); #endif } SCALAR_FUN_ATTR int32_t atomic_smin_i32_shared(volatile __local int32_t *p, int32_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicMin((int32_t*)p, x); #else return atomic_min(p, x); #endif } SCALAR_FUN_ATTR uint32_t atomic_umax_i32_global(volatile __global uint32_t *p, uint32_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicMax((uint32_t*)p, x); #else return atomic_max(p, x); #endif } SCALAR_FUN_ATTR uint32_t atomic_umax_i32_shared(volatile __local uint32_t *p, uint32_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicMax((uint32_t*)p, x); #else return atomic_max(p, x); #endif } SCALAR_FUN_ATTR uint32_t atomic_umin_i32_global(volatile __global uint32_t *p, uint32_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicMin((uint32_t*)p, x); #else return atomic_min(p, x); #endif } SCALAR_FUN_ATTR uint32_t atomic_umin_i32_shared(volatile __local uint32_t *p, uint32_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicMin((uint32_t*)p, x); #else return atomic_min(p, x); #endif } SCALAR_FUN_ATTR int32_t atomic_and_i32_global(volatile __global int32_t *p, int32_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicAnd((int32_t*)p, x); #else return atomic_and(p, x); #endif } SCALAR_FUN_ATTR int32_t atomic_and_i32_shared(volatile __local int32_t *p, int32_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicAnd((int32_t*)p, x); #else return atomic_and(p, x); #endif } SCALAR_FUN_ATTR int32_t atomic_or_i32_global(volatile __global int32_t *p, int32_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicOr((int32_t*)p, x); #else return atomic_or(p, x); #endif } SCALAR_FUN_ATTR int32_t atomic_or_i32_shared(volatile __local int32_t *p, int32_t x) { #if 
defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicOr((int32_t*)p, x); #else return atomic_or(p, x); #endif } SCALAR_FUN_ATTR int32_t atomic_xor_i32_global(volatile __global int32_t *p, int32_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicXor((int32_t*)p, x); #else return atomic_xor(p, x); #endif } SCALAR_FUN_ATTR int32_t atomic_xor_i32_shared(volatile __local int32_t *p, int32_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicXor((int32_t*)p, x); #else return atomic_xor(p, x); #endif } // Start of 64 bit atomics #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) || defined(cl_khr_int64_base_atomics) && defined(cl_khr_int64_extended_atomics) SCALAR_FUN_ATTR int64_t atomic_xchg_i64_global(volatile __global int64_t *p, int64_t x); SCALAR_FUN_ATTR int64_t atomic_xchg_i64_shared(volatile __local int64_t *p, int64_t x); SCALAR_FUN_ATTR int64_t atomic_cmpxchg_i64_global(volatile __global int64_t *p, int64_t cmp, int64_t val); SCALAR_FUN_ATTR int64_t atomic_cmpxchg_i64_shared(volatile __local int64_t *p, int64_t cmp, int64_t val); SCALAR_FUN_ATTR int64_t atomic_add_i64_global(volatile __global int64_t *p, int64_t x); SCALAR_FUN_ATTR int64_t atomic_add_i64_shared(volatile __local int64_t *p, int64_t x); SCALAR_FUN_ATTR int64_t atomic_smax_i64_global(volatile __global int64_t *p, int64_t x); SCALAR_FUN_ATTR int64_t atomic_smax_i64_shared(volatile __local int64_t *p, int64_t x); SCALAR_FUN_ATTR int64_t atomic_smin_i64_global(volatile __global int64_t *p, int64_t x); SCALAR_FUN_ATTR int64_t atomic_smin_i64_shared(volatile __local int64_t *p, int64_t x); SCALAR_FUN_ATTR uint64_t atomic_umax_i64_global(volatile __global uint64_t *p, uint64_t x); SCALAR_FUN_ATTR uint64_t atomic_umax_i64_shared(volatile __local uint64_t *p, uint64_t x); SCALAR_FUN_ATTR uint64_t atomic_umin_i64_global(volatile __global uint64_t *p, uint64_t x); SCALAR_FUN_ATTR uint64_t atomic_umin_i64_shared(volatile __local uint64_t *p, uint64_t x); 
SCALAR_FUN_ATTR int64_t atomic_and_i64_global(volatile __global int64_t *p, int64_t x); SCALAR_FUN_ATTR int64_t atomic_and_i64_shared(volatile __local int64_t *p, int64_t x); SCALAR_FUN_ATTR int64_t atomic_or_i64_global(volatile __global int64_t *p, int64_t x); SCALAR_FUN_ATTR int64_t atomic_or_i64_shared(volatile __local int64_t *p, int64_t x); SCALAR_FUN_ATTR int64_t atomic_xor_i64_global(volatile __global int64_t *p, int64_t x); SCALAR_FUN_ATTR int64_t atomic_xor_i64_shared(volatile __local int64_t *p, int64_t x); #ifdef FUTHARK_F64_ENABLED SCALAR_FUN_ATTR double atomic_fadd_f64_global(volatile __global double *p, double x); SCALAR_FUN_ATTR double atomic_fadd_f64_shared(volatile __local double *p, double x); #endif SCALAR_FUN_ATTR int64_t atomic_xchg_i64_global(volatile __global int64_t *p, int64_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicExch((uint64_t*)p, x); #else return atom_xor(p, x); #endif } SCALAR_FUN_ATTR int64_t atomic_xchg_i64_shared(volatile __local int64_t *p, int64_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicExch((uint64_t*)p, x); #else return atom_xor(p, x); #endif } SCALAR_FUN_ATTR int64_t atomic_cmpxchg_i64_global(volatile __global int64_t *p, int64_t cmp, int64_t val) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicCAS((uint64_t*)p, cmp, val); #else return atom_cmpxchg(p, cmp, val); #endif } SCALAR_FUN_ATTR int64_t atomic_cmpxchg_i64_shared(volatile __local int64_t *p, int64_t cmp, int64_t val) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicCAS((uint64_t*)p, cmp, val); #else return atom_cmpxchg(p, cmp, val); #endif } SCALAR_FUN_ATTR int64_t atomic_add_i64_global(volatile __global int64_t *p, int64_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicAdd((uint64_t*)p, x); #else return atom_add(p, x); #endif } SCALAR_FUN_ATTR int64_t atomic_add_i64_shared(volatile __local int64_t *p, int64_t x) { #if defined(FUTHARK_CUDA) || 
defined(FUTHARK_HIP) return atomicAdd((uint64_t*)p, x); #else return atom_add(p, x); #endif } #ifdef FUTHARK_F64_ENABLED SCALAR_FUN_ATTR double atomic_fadd_f64_global(volatile __global double *p, double x) { #if defined(FUTHARK_CUDA) && __CUDA_ARCH__ >= 600 || defined(FUTHARK_HIP) return atomicAdd((double*)p, x); // On OpenCL, use technique from // https://pipinspace.github.io/blog/atomic-float-addition-in-opencl.html #elif defined(cl_nv_pragma_unroll) // use hardware-supported atomic addition on Nvidia GPUs with inline // PTX assembly double ret; asm volatile("atom.global.add.f64 %0,[%1],%2;":"=d"(ret):"l"(p),"d"(x):"memory"); return ret; #elif __has_builtin(__builtin_amdgcn_global_atomic_fadd_f64) // use hardware-supported atomic addition on some AMD GPUs return __builtin_amdgcn_global_atomic_fadd_f64(p, x); #else // fallback emulation: // https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639/5 union {int64_t i; double f;} old; union {int64_t i; double f;} ret; old.f = x; while (1) { ret.i = atom_xchg((volatile __global int64_t*)p, (int64_t)0); ret.f += old.f; old.i = atom_xchg((volatile __global int64_t*)p, ret.i); if (old.i == 0) { break; } } return ret.f; #endif } SCALAR_FUN_ATTR double atomic_fadd_f64_shared(volatile __local double *p, double x) { #if defined(FUTHARK_CUDA) && __CUDA_ARCH__ >= 600 || defined(FUTHARK_HIP) return atomicAdd((double*)p, x); #else union { int64_t i; double f; } old; union { int64_t i; double f; } assumed; old.f = *p; do { assumed.f = old.f; old.f = old.f + x; old.i = atomic_cmpxchg_i64_shared((volatile __local int64_t*)p, assumed.i, old.i); } while (assumed.i != old.i); return old.f; #endif } #endif SCALAR_FUN_ATTR int64_t atomic_smax_i64_global(volatile __global int64_t *p, int64_t x) { #if defined(FUTHARK_CUDA) return atomicMax((int64_t*)p, x); #elif defined(FUTHARK_HIP) // Currentely missing in HIP; probably a temporary oversight. 
int64_t old = *p, assumed; do { assumed = old; old = smax64(old, x); old = atomic_cmpxchg_i64_global((volatile __global int64_t*)p, assumed, old); } while (assumed != old); return old; #else return atom_max(p, x); #endif } SCALAR_FUN_ATTR int64_t atomic_smax_i64_shared(volatile __local int64_t *p, int64_t x) { #if defined(FUTHARK_CUDA) return atomicMax((int64_t*)p, x); #elif defined(FUTHARK_HIP) // Currentely missing in HIP; probably a temporary oversight. int64_t old = *p, assumed; do { assumed = old; old = smax64(old, x); old = atomic_cmpxchg_i64_shared((volatile __local int64_t*)p, assumed, old); } while (assumed != old); return old; #else return atom_max(p, x); #endif } SCALAR_FUN_ATTR int64_t atomic_smin_i64_global(volatile __global int64_t *p, int64_t x) { #if defined(FUTHARK_CUDA) return atomicMin((int64_t*)p, x); #elif defined(FUTHARK_HIP) // Currentely missing in HIP; probably a temporary oversight. int64_t old = *p, assumed; do { assumed = old; old = smin64(old, x); old = atomic_cmpxchg_i64_global((volatile __global int64_t*)p, assumed, old); } while (assumed != old); return old; #else return atom_min(p, x); #endif } SCALAR_FUN_ATTR int64_t atomic_smin_i64_shared(volatile __local int64_t *p, int64_t x) { #if defined(FUTHARK_CUDA) return atomicMin((int64_t*)p, x); #elif defined(FUTHARK_HIP) // Currentely missing in HIP; probably a temporary oversight. 
int64_t old = *p, assumed; do { assumed = old; old = smin64(old, x); old = atomic_cmpxchg_i64_shared((volatile __local int64_t*)p, assumed, old); } while (assumed != old); return old; #else return atom_min(p, x); #endif } SCALAR_FUN_ATTR uint64_t atomic_umax_i64_global(volatile __global uint64_t *p, uint64_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicMax((uint64_t*)p, x); #else return atom_max(p, x); #endif } SCALAR_FUN_ATTR uint64_t atomic_umax_i64_shared(volatile __local uint64_t *p, uint64_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicMax((uint64_t*)p, x); #else return atom_max(p, x); #endif } SCALAR_FUN_ATTR uint64_t atomic_umin_i64_global(volatile __global uint64_t *p, uint64_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicMin((uint64_t*)p, x); #else return atom_min(p, x); #endif } SCALAR_FUN_ATTR uint64_t atomic_umin_i64_shared(volatile __local uint64_t *p, uint64_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicMin((uint64_t*)p, x); #else return atom_min(p, x); #endif } SCALAR_FUN_ATTR int64_t atomic_and_i64_global(volatile __global int64_t *p, int64_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicAnd((uint64_t*)p, x); #else return atom_and(p, x); #endif } SCALAR_FUN_ATTR int64_t atomic_and_i64_shared(volatile __local int64_t *p, int64_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicAnd((uint64_t*)p, x); #else return atom_and(p, x); #endif } SCALAR_FUN_ATTR int64_t atomic_or_i64_global(volatile __global int64_t *p, int64_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicOr((uint64_t*)p, x); #else return atom_or(p, x); #endif } SCALAR_FUN_ATTR int64_t atomic_or_i64_shared(volatile __local int64_t *p, int64_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicOr((uint64_t*)p, x); #else return atom_or(p, x); #endif } SCALAR_FUN_ATTR int64_t atomic_xor_i64_global(volatile __global int64_t *p, 
int64_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicXor((uint64_t*)p, x); #else return atom_xor(p, x); #endif } SCALAR_FUN_ATTR int64_t atomic_xor_i64_shared(volatile __local int64_t *p, int64_t x) { #if defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) return atomicXor((uint64_t*)p, x); #else return atom_xor(p, x); #endif } #endif // defined(FUTHARK_CUDA) || defined(FUTHARK_HIP) || defined(cl_khr_int64_base_atomics) && defined(cl_khr_int64_extended_atomics) // End of atomics.h futhark-0.25.27/rts/c/backends/000077500000000000000000000000001475065116200161465ustar00rootroot00000000000000futhark-0.25.27/rts/c/backends/c.h000066400000000000000000000025641475065116200165500ustar00rootroot00000000000000// Start of backends/c.h struct futhark_context_config { int in_use; int debugging; int profiling; int logging; char *cache_fname; int num_tuning_params; int64_t *tuning_params; const char** tuning_param_names; const char** tuning_param_vars; const char** tuning_param_classes; }; static void backend_context_config_setup(struct futhark_context_config* cfg) { (void)cfg; } static void backend_context_config_teardown(struct futhark_context_config* cfg) { (void)cfg; } int futhark_context_config_set_tuning_param(struct futhark_context_config* cfg, const char *param_name, size_t param_value) { (void)cfg; (void)param_name; (void)param_value; return 1; } struct futhark_context { struct futhark_context_config* cfg; int detail_memory; int debugging; int profiling; int profiling_paused; int logging; lock_t lock; char *error; lock_t error_lock; FILE *log; struct constants *constants; struct free_list free_list; struct event_list event_list; int64_t peak_mem_usage_default; int64_t cur_mem_usage_default; struct program* program; bool program_initialised; }; int backend_context_setup(struct futhark_context* ctx) { (void)ctx; return 0; } void backend_context_teardown(struct futhark_context* ctx) { (void)ctx; } int futhark_context_sync(struct futhark_context* ctx) { 
(void)ctx; return 0; } // End of backends/c.h futhark-0.25.27/rts/c/backends/cuda.h000066400000000000000000001032701475065116200172360ustar00rootroot00000000000000// Start of backends/cuda.h. // Forward declarations. // Invoked by setup_opencl() after the platform and device has been // found, but before the program is loaded. Its intended use is to // tune constants based on the selected platform and device. static void set_tuning_params(struct futhark_context* ctx); static char* get_failure_msg(int failure_idx, int64_t args[]); #define CUDA_SUCCEED_FATAL(x) cuda_api_succeed_fatal(x, #x, __FILE__, __LINE__) #define CUDA_SUCCEED_NONFATAL(x) cuda_api_succeed_nonfatal(x, #x, __FILE__, __LINE__) #define NVRTC_SUCCEED_FATAL(x) nvrtc_api_succeed_fatal(x, #x, __FILE__, __LINE__) #define NVRTC_SUCCEED_NONFATAL(x) nvrtc_api_succeed_nonfatal(x, #x, __FILE__, __LINE__) // Take care not to override an existing error. #define CUDA_SUCCEED_OR_RETURN(e) { \ char *serror = CUDA_SUCCEED_NONFATAL(e); \ if (serror) { \ if (!ctx->error) { \ ctx->error = serror; \ } else { \ free(serror); \ } \ return bad; \ } \ } // CUDA_SUCCEED_OR_RETURN returns the value of the variable 'bad' in // scope. By default, it will be this one. Create a local variable // of some other type if needed. This is a bit of a hack, but it // saves effort in the code generator. 
static const int bad = 1; static inline void cuda_api_succeed_fatal(CUresult res, const char *call, const char *file, int line) { if (res != CUDA_SUCCESS) { const char *err_str; cuGetErrorString(res, &err_str); if (err_str == NULL) { err_str = "Unknown"; } futhark_panic(-1, "%s:%d: CUDA call\n %s\nfailed with error code %d (%s)\n", file, line, call, res, err_str); } } static char* cuda_api_succeed_nonfatal(CUresult res, const char *call, const char *file, int line) { if (res != CUDA_SUCCESS) { const char *err_str; cuGetErrorString(res, &err_str); if (err_str == NULL) { err_str = "Unknown"; } return msgprintf("%s:%d: CUDA call\n %s\nfailed with error code %d (%s)\n", file, line, call, res, err_str); } else { return NULL; } } static inline void nvrtc_api_succeed_fatal(nvrtcResult res, const char *call, const char *file, int line) { if (res != NVRTC_SUCCESS) { const char *err_str = nvrtcGetErrorString(res); futhark_panic(-1, "%s:%d: NVRTC call\n %s\nfailed with error code %d (%s)\n", file, line, call, res, err_str); } } static char* nvrtc_api_succeed_nonfatal(nvrtcResult res, const char *call, const char *file, int line) { if (res != NVRTC_SUCCESS) { const char *err_str = nvrtcGetErrorString(res); return msgprintf("%s:%d: NVRTC call\n %s\nfailed with error code %d (%s)\n", file, line, call, res, err_str); } else { return NULL; } } struct futhark_context_config { int in_use; int debugging; int profiling; int logging; char* cache_fname; int num_tuning_params; int64_t *tuning_params; const char** tuning_param_names; const char** tuning_param_vars; const char** tuning_param_classes; // Uniform fields above. 
char* program; int num_nvrtc_opts; char* *nvrtc_opts; char* preferred_device; int preferred_device_num; int unified_memory; char* dump_ptx_to; char* load_ptx_from; struct gpu_config gpu; }; static void backend_context_config_setup(struct futhark_context_config *cfg) { cfg->num_nvrtc_opts = 0; cfg->nvrtc_opts = (char**) malloc(sizeof(char*)); cfg->nvrtc_opts[0] = NULL; cfg->program = strconcat(gpu_program); cfg->preferred_device_num = 0; cfg->preferred_device = strdup(""); cfg->dump_ptx_to = NULL; cfg->load_ptx_from = NULL; cfg->unified_memory = 2; cfg->gpu = gpu_config_initial; cfg->gpu.default_block_size = 256; cfg->gpu.default_tile_size = 32; cfg->gpu.default_reg_tile_size = 2; cfg->gpu.default_threshold = 32*1024; } static void backend_context_config_teardown(struct futhark_context_config* cfg) { for (int i = 0; i < cfg->num_nvrtc_opts; i++) { free(cfg->nvrtc_opts[i]); } free(cfg->nvrtc_opts); free(cfg->dump_ptx_to); free(cfg->load_ptx_from); free(cfg->preferred_device); free(cfg->program); } void futhark_context_config_add_nvrtc_option(struct futhark_context_config *cfg, const char *opt) { cfg->nvrtc_opts[cfg->num_nvrtc_opts] = strdup(opt); cfg->num_nvrtc_opts++; cfg->nvrtc_opts = (char **) realloc(cfg->nvrtc_opts, (cfg->num_nvrtc_opts + 1) * sizeof(char *)); cfg->nvrtc_opts[cfg->num_nvrtc_opts] = NULL; } void futhark_context_config_set_device(struct futhark_context_config *cfg, const char *s) { int x = 0; if (*s == '#') { s++; while (isdigit(*s)) { x = x * 10 + (*s++)-'0'; } // Skip trailing spaces. 
while (isspace(*s)) { s++; } } free(cfg->preferred_device); cfg->preferred_device = strdup(s); cfg->preferred_device_num = x; } const char* futhark_context_config_get_program(struct futhark_context_config *cfg) { return cfg->program; } void futhark_context_config_set_program(struct futhark_context_config *cfg, const char *s) { free(cfg->program); cfg->program = strdup(s); } void futhark_context_config_dump_ptx_to(struct futhark_context_config *cfg, const char *path) { free(cfg->dump_ptx_to); cfg->dump_ptx_to = strdup(path); } void futhark_context_config_load_ptx_from(struct futhark_context_config *cfg, const char *path) { free(cfg->load_ptx_from); cfg->load_ptx_from = strdup(path); } void futhark_context_config_set_unified_memory(struct futhark_context_config* cfg, int flag) { cfg->unified_memory = flag; } // A record of something that happened. struct profiling_record { cudaEvent_t *events; // Points to two events. const char *name; }; struct futhark_context { struct futhark_context_config* cfg; int detail_memory; int debugging; int profiling; int profiling_paused; int logging; lock_t lock; char *error; lock_t error_lock; FILE *log; struct constants *constants; struct free_list free_list; struct event_list event_list; int64_t peak_mem_usage_default; int64_t cur_mem_usage_default; struct program* program; bool program_initialised; // Uniform fields above. CUdeviceptr global_failure; CUdeviceptr global_failure_args; struct tuning_params tuning_params; // True if a potentially failing kernel has been enqueued. 
int32_t failure_is_an_option; int total_runs; long int total_runtime; int64_t peak_mem_usage_device; int64_t cur_mem_usage_device; CUdevice dev; CUcontext cu_ctx; CUmodule module; CUstream stream; struct free_list gpu_free_list; size_t max_thread_block_size; size_t max_grid_size; size_t max_tile_size; size_t max_threshold; size_t max_shared_memory; size_t max_bespoke; size_t max_registers; size_t max_cache; size_t lockstep_width; struct builtin_kernels* kernels; }; #define CU_DEV_ATTR(x) (CU_DEVICE_ATTRIBUTE_##x) #define device_query(dev,attrib) _device_query(dev, CU_DEV_ATTR(attrib)) static int _device_query(CUdevice dev, CUdevice_attribute attrib) { int val; CUDA_SUCCEED_FATAL(cuDeviceGetAttribute(&val, attrib, dev)); return val; } #define CU_FUN_ATTR(x) (CU_FUNC_ATTRIBUTE_##x) #define function_query(fn,attrib) _function_query(dev, CU_FUN_ATTR(attrib)) static int _function_query(CUfunction dev, CUfunction_attribute attrib) { int val; CUDA_SUCCEED_FATAL(cuFuncGetAttribute(&val, attrib, dev)); return val; } static int cuda_device_setup(struct futhark_context *ctx) { struct futhark_context_config *cfg = ctx->cfg; char name[256]; int count, chosen = -1, best_cc = -1; int cc_major_best = 0, cc_minor_best = 0; int cc_major = 0, cc_minor = 0; CUdevice dev; CUDA_SUCCEED_FATAL(cuDeviceGetCount(&count)); if (count == 0) { return 1; } int num_device_matches = 0; // XXX: Current device selection policy is to choose the device with the // highest compute capability (if no preferred device is set). // This should maybe be changed, since greater compute capability is not // necessarily an indicator of better performance. 
for (int i = 0; i < count; i++) { CUDA_SUCCEED_FATAL(cuDeviceGet(&dev, i)); cc_major = device_query(dev, COMPUTE_CAPABILITY_MAJOR); cc_minor = device_query(dev, COMPUTE_CAPABILITY_MINOR); CUDA_SUCCEED_FATAL(cuDeviceGetName(name, sizeof(name) - 1, dev)); name[sizeof(name) - 1] = 0; if (cfg->logging) { fprintf(ctx->log, "Device #%d: name=\"%s\", compute capability=%d.%d\n", i, name, cc_major, cc_minor); } if (device_query(dev, COMPUTE_MODE) == CU_COMPUTEMODE_PROHIBITED) { if (cfg->logging) { fprintf(ctx->log, "Device #%d is compute-prohibited, ignoring\n", i); } continue; } if (best_cc == -1 || cc_major > cc_major_best || (cc_major == cc_major_best && cc_minor > cc_minor_best)) { best_cc = i; cc_major_best = cc_major; cc_minor_best = cc_minor; } if (strstr(name, cfg->preferred_device) != NULL && num_device_matches++ == cfg->preferred_device_num) { chosen = i; break; } } if (chosen == -1) { chosen = best_cc; } if (chosen == -1) { return 1; } if (cfg->logging) { fprintf(ctx->log, "Using device #%d\n", chosen); } CUDA_SUCCEED_FATAL(cuDeviceGet(&ctx->dev, chosen)); return 0; } static const char *cuda_nvrtc_get_arch(CUdevice dev) { static struct { int major; int minor; const char *arch_str; } const x[] = { { 3, 0, "compute_30" }, { 3, 2, "compute_32" }, { 3, 5, "compute_35" }, { 3, 7, "compute_37" }, { 5, 0, "compute_50" }, { 5, 2, "compute_52" }, { 5, 3, "compute_53" }, { 6, 0, "compute_60" }, { 6, 1, "compute_61" }, { 6, 2, "compute_62" }, { 7, 0, "compute_70" }, { 7, 2, "compute_72" }, { 7, 5, "compute_75" }, { 8, 0, "compute_80" }, { 8, 6, "compute_80" }, { 8, 7, "compute_80" } }; int major = device_query(dev, COMPUTE_CAPABILITY_MAJOR); int minor = device_query(dev, COMPUTE_CAPABILITY_MINOR); int chosen = -1; int num_archs = sizeof(x)/sizeof(x[0]); for (int i = 0; i < num_archs; i++) { if (x[i].major < major || (x[i].major == major && x[i].minor <= minor)) { chosen = i; } else { break; } } if (chosen == -1) { futhark_panic(-1, "Unsupported compute capability %d.%d\n", 
major, minor); } if (x[chosen].major != major || x[chosen].minor != minor) { fprintf(stderr, "Warning: device compute capability is %d.%d, but newest supported by Futhark is %d.%d.\n", major, minor, x[chosen].major, x[chosen].minor); } return x[chosen].arch_str; } static void cuda_nvrtc_mk_build_options(struct futhark_context *ctx, const char *extra_opts[], char*** opts_out, size_t *n_opts) { int arch_set = 0, num_extra_opts; struct futhark_context_config *cfg = ctx->cfg; char** macro_names; int64_t* macro_vals; int num_macros = gpu_macros(ctx, ¯o_names, ¯o_vals); // nvrtc cannot handle multiple -arch options. Hence, if one of the // extra_opts is -arch, we have to be careful not to do our usual // automatic generation. for (num_extra_opts = 0; extra_opts[num_extra_opts] != NULL; num_extra_opts++) { if (strstr(extra_opts[num_extra_opts], "-arch") == extra_opts[num_extra_opts] || strstr(extra_opts[num_extra_opts], "--gpu-architecture") == extra_opts[num_extra_opts]) { arch_set = 1; } } size_t i = 0, n_opts_alloc = 20 + num_macros + num_extra_opts + cfg->num_tuning_params; char **opts = (char**) malloc(n_opts_alloc * sizeof(char *)); if (!arch_set) { opts[i++] = strdup("-arch"); opts[i++] = strdup(cuda_nvrtc_get_arch(ctx->dev)); } opts[i++] = strdup("-default-device"); if (cfg->debugging) { opts[i++] = strdup("-G"); opts[i++] = strdup("-lineinfo"); } else { opts[i++] = strdup("--disable-warnings"); } opts[i++] = msgprintf("-D%s=%d", "max_thread_block_size", (int)ctx->max_thread_block_size); opts[i++] = msgprintf("-D%s=%d", "max_shared_memory", (int)ctx->max_shared_memory); opts[i++] = msgprintf("-D%s=%d", "max_registers", (int)ctx->max_registers); for (int j = 0; j < num_macros; j++) { opts[i++] = msgprintf("-D%s=%zu", macro_names[j], macro_vals[j]); } for (int j = 0; j < cfg->num_tuning_params; j++) { opts[i++] = msgprintf("-D%s=%zu", cfg->tuning_param_vars[j], cfg->tuning_params[j]); } opts[i++] = msgprintf("-DLOCKSTEP_WIDTH=%zu", ctx->lockstep_width); opts[i++] = 
msgprintf("-DMAX_THREADS_PER_BLOCK=%zu", ctx->max_thread_block_size); // Time for the best lines of the code in the entire compiler. if (getenv("CUDA_HOME") != NULL) { opts[i++] = msgprintf("-I%s/include", getenv("CUDA_HOME")); } if (getenv("CUDA_ROOT") != NULL) { opts[i++] = msgprintf("-I%s/include", getenv("CUDA_ROOT")); } if (getenv("CUDA_PATH") != NULL) { opts[i++] = msgprintf("-I%s/include", getenv("CUDA_PATH")); } opts[i++] = msgprintf("-I/usr/local/cuda/include"); opts[i++] = msgprintf("-I/usr/include"); for (int j = 0; extra_opts[j] != NULL; j++) { opts[i++] = strdup(extra_opts[j]); } opts[i++] = msgprintf("-DTR_BLOCK_DIM=%d", TR_BLOCK_DIM); opts[i++] = msgprintf("-DTR_TILE_DIM=%d", TR_TILE_DIM); opts[i++] = msgprintf("-DTR_ELEMS_PER_THREAD=%d", TR_ELEMS_PER_THREAD); free(macro_names); free(macro_vals); *n_opts = i; *opts_out = opts; } static char* cuda_nvrtc_build(const char *src, const char *opts[], size_t n_opts, char **ptx) { nvrtcProgram prog; char *problem = NULL; problem = NVRTC_SUCCEED_NONFATAL(nvrtcCreateProgram(&prog, src, "futhark-cuda", 0, NULL, NULL)); if (problem) { return problem; } nvrtcResult res = nvrtcCompileProgram(prog, n_opts, opts); if (res != NVRTC_SUCCESS) { size_t log_size; if (nvrtcGetProgramLogSize(prog, &log_size) == NVRTC_SUCCESS) { char *log = (char*) malloc(log_size); if (nvrtcGetProgramLog(prog, log) == NVRTC_SUCCESS) { problem = msgprintf("NVRTC compilation failed.\n\n%s\n", log); } else { problem = msgprintf("Could not retrieve compilation log\n"); } free(log); } return problem; } size_t ptx_size; NVRTC_SUCCEED_FATAL(nvrtcGetPTXSize(prog, &ptx_size)); *ptx = (char*) malloc(ptx_size); NVRTC_SUCCEED_FATAL(nvrtcGetPTX(prog, *ptx)); NVRTC_SUCCEED_FATAL(nvrtcDestroyProgram(&prog)); return NULL; } static void cuda_load_ptx_from_cache(struct futhark_context_config *cfg, const char *src, const char *opts[], size_t n_opts, struct cache_hash *h, const char *cache_fname, char **ptx) { if (cfg->logging) { fprintf(stderr, "Restoring 
cache from from %s...\n", cache_fname); } cache_hash_init(h); for (size_t i = 0; i < n_opts; i++) { cache_hash(h, opts[i], strlen(opts[i])); } cache_hash(h, src, strlen(src)); size_t ptxsize; errno = 0; if (cache_restore(cache_fname, h, (unsigned char**)ptx, &ptxsize) != 0) { if (cfg->logging) { fprintf(stderr, "Failed to restore cache (errno: %s)\n", strerror(errno)); } } } static void cuda_size_setup(struct futhark_context *ctx) { struct futhark_context_config *cfg = ctx->cfg; if (cfg->gpu.default_block_size > ctx->max_thread_block_size) { if (cfg->gpu.default_block_size_changed) { fprintf(stderr, "Note: Device limits default block size to %zu (down from %zu).\n", ctx->max_thread_block_size, cfg->gpu.default_block_size); } cfg->gpu.default_block_size = ctx->max_thread_block_size; } if (cfg->gpu.default_grid_size > ctx->max_grid_size) { if (cfg->gpu.default_grid_size_changed) { fprintf(stderr, "Note: Device limits default grid size to %zu (down from %zu).\n", ctx->max_grid_size, cfg->gpu.default_grid_size); } cfg->gpu.default_grid_size = ctx->max_grid_size; } if (cfg->gpu.default_tile_size > ctx->max_tile_size) { if (cfg->gpu.default_tile_size_changed) { fprintf(stderr, "Note: Device limits default tile size to %zu (down from %zu).\n", ctx->max_tile_size, cfg->gpu.default_tile_size); } cfg->gpu.default_tile_size = ctx->max_tile_size; } if (!cfg->gpu.default_grid_size_changed) { cfg->gpu.default_grid_size = (device_query(ctx->dev, MULTIPROCESSOR_COUNT) * device_query(ctx->dev, MAX_THREADS_PER_MULTIPROCESSOR)) / cfg->gpu.default_block_size; } for (int i = 0; i < cfg->num_tuning_params; i++) { const char *size_class = cfg->tuning_param_classes[i]; int64_t *size_value = &cfg->tuning_params[i]; const char* size_name = cfg->tuning_param_names[i]; int64_t max_value = 0, default_value = 0; if (strstr(size_class, "thread_block_size") == size_class) { max_value = ctx->max_thread_block_size; default_value = cfg->gpu.default_block_size; } else if (strstr(size_class, 
"grid_size") == size_class) { max_value = ctx->max_grid_size; default_value = cfg->gpu.default_grid_size; // XXX: as a quick and dirty hack, use twice as many threads for // histograms by default. We really should just be smarter // about sizes somehow. if (strstr(size_name, ".seghist_") != NULL) { default_value *= 2; } } else if (strstr(size_class, "tile_size") == size_class) { max_value = ctx->max_tile_size; default_value = cfg->gpu.default_tile_size; } else if (strstr(size_class, "reg_tile_size") == size_class) { max_value = 0; // No limit. default_value = cfg->gpu.default_reg_tile_size; } else if (strstr(size_class, "shared_memory") == size_class) { max_value = ctx->max_shared_memory; default_value = ctx->max_shared_memory; } else if (strstr(size_class, "cache") == size_class) { max_value = ctx->max_cache; default_value = ctx->max_cache; } else if (strstr(size_class, "threshold") == size_class) { // Threshold can be as large as it takes. default_value = cfg->gpu.default_threshold; } else { // Bespoke sizes have no limit or default. 
} if (*size_value == 0) { *size_value = default_value; } else if (max_value > 0 && *size_value > max_value) { fprintf(stderr, "Note: Device limits %s to %zu (down from %zu)\n", size_name, max_value, *size_value); *size_value = max_value; } } } static char* cuda_module_setup(struct futhark_context *ctx, const char *src, const char *extra_opts[], const char* cache_fname) { char *ptx = NULL; struct futhark_context_config *cfg = ctx->cfg; if (cfg->load_ptx_from) { ptx = slurp_file(cfg->load_ptx_from, NULL); } char **opts; size_t n_opts; cuda_nvrtc_mk_build_options(ctx, extra_opts, &opts, &n_opts); if (cfg->logging) { fprintf(stderr, "NVRTC compile options:\n"); for (size_t j = 0; j < n_opts; j++) { fprintf(stderr, "\t%s\n", opts[j]); } fprintf(stderr, "\n"); } struct cache_hash h; int loaded_ptx_from_cache = 0; if (cache_fname != NULL) { cuda_load_ptx_from_cache(cfg, src, (const char**)opts, n_opts, &h, cache_fname, &ptx); if (ptx != NULL) { if (cfg->logging) { fprintf(stderr, "Restored PTX from cache; now loading module...\n"); } if (cuModuleLoadData(&ctx->module, ptx) == CUDA_SUCCESS) { if (cfg->logging) { fprintf(stderr, "Success!\n"); } loaded_ptx_from_cache = 1; } else { if (cfg->logging) { fprintf(stderr, "Failed!\n"); } free(ptx); ptx = NULL; } } } if (ptx == NULL) { char* problem = cuda_nvrtc_build(src, (const char**)opts, n_opts, &ptx); if (problem != NULL) { return problem; } } if (cfg->dump_ptx_to != NULL) { dump_file(cfg->dump_ptx_to, ptx, strlen(ptx)); } if (!loaded_ptx_from_cache) { CUDA_SUCCEED_FATAL(cuModuleLoadData(&ctx->module, ptx)); } if (cache_fname != NULL && !loaded_ptx_from_cache) { if (cfg->logging) { fprintf(stderr, "Caching PTX in %s...\n", cache_fname); } errno = 0; if (cache_store(cache_fname, &h, (const unsigned char*)ptx, strlen(ptx)) != 0) { fprintf(stderr, "Failed to cache PTX: %s\n", strerror(errno)); } } for (size_t i = 0; i < n_opts; i++) { free((char *)opts[i]); } free(opts); free(ptx); return NULL; } struct cuda_event { cudaEvent_t 
start; cudaEvent_t end; }; static struct cuda_event* cuda_event_new(struct futhark_context* ctx) { if (ctx->profiling && !ctx->profiling_paused) { struct cuda_event* e = malloc(sizeof(struct cuda_event)); cudaEventCreate(&e->start); cudaEventCreate(&e->end); return e; } else { return NULL; } } static int cuda_event_report(struct str_builder* sb, struct cuda_event* e) { float ms; CUresult err; if ((err = cuEventElapsedTime(&ms, e->start, e->end)) != CUDA_SUCCESS) { return err; } // CUDA provides milisecond resolution, but we want microseconds. str_builder(sb, ",\"duration\":%f", ms*1000); if ((err = cuEventDestroy(e->start)) != CUDA_SUCCESS) { return 1; } if ((err = cuEventDestroy(e->end)) != CUDA_SUCCESS) { return 1; } free(e); return 0; } int futhark_context_sync(struct futhark_context* ctx) { CUDA_SUCCEED_OR_RETURN(cuCtxPushCurrent(ctx->cu_ctx)); CUDA_SUCCEED_OR_RETURN(cuCtxSynchronize()); if (ctx->failure_is_an_option) { // Check for any delayed error. int32_t failure_idx; CUDA_SUCCEED_OR_RETURN( cuMemcpyDtoH(&failure_idx, ctx->global_failure, sizeof(int32_t))); ctx->failure_is_an_option = 0; if (failure_idx >= 0) { // We have to clear global_failure so that the next entry point // is not considered a failure from the start. 
int32_t no_failure = -1; CUDA_SUCCEED_OR_RETURN( cuMemcpyHtoD(ctx->global_failure, &no_failure, sizeof(int32_t))); int64_t args[max_failure_args+1]; CUDA_SUCCEED_OR_RETURN( cuMemcpyDtoH(&args, ctx->global_failure_args, sizeof(args))); ctx->error = get_failure_msg(failure_idx, args); return FUTHARK_PROGRAM_ERROR; } } CUDA_SUCCEED_OR_RETURN(cuCtxPopCurrent(&ctx->cu_ctx)); return 0; } struct builtin_kernels* init_builtin_kernels(struct futhark_context* ctx); void free_builtin_kernels(struct futhark_context* ctx, struct builtin_kernels* kernels); int backend_context_setup(struct futhark_context* ctx) { ctx->failure_is_an_option = 0; ctx->total_runs = 0; ctx->total_runtime = 0; ctx->peak_mem_usage_device = 0; ctx->cur_mem_usage_device = 0; ctx->kernels = NULL; CUDA_SUCCEED_FATAL(cuInit(0)); if (cuda_device_setup(ctx) != 0) { futhark_panic(-1, "No suitable CUDA device found.\n"); } CUDA_SUCCEED_FATAL(cuCtxCreate(&ctx->cu_ctx, 0, ctx->dev)); free_list_init(&ctx->gpu_free_list); if (ctx->cfg->unified_memory == 2) { ctx->cfg->unified_memory = device_query(ctx->dev, MANAGED_MEMORY); } if (ctx->cfg->logging) { if (ctx->cfg->unified_memory) { fprintf(ctx->log, "Using managed memory\n"); } else { fprintf(ctx->log, "Using unmanaged memory\n"); } } ctx->max_thread_block_size = device_query(ctx->dev, MAX_THREADS_PER_BLOCK); ctx->max_grid_size = device_query(ctx->dev, MAX_GRID_DIM_X); ctx->max_tile_size = sqrt(ctx->max_thread_block_size); ctx->max_threshold = 1U<<31; // No limit. ctx->max_bespoke = 1U<<31; // No limit. if (ctx->cfg->gpu.default_registers != 0) { ctx->max_registers = ctx->cfg->gpu.default_registers; } else { ctx->max_registers = device_query(ctx->dev, MAX_REGISTERS_PER_BLOCK); } if (ctx->cfg->gpu.default_shared_memory != 0) { ctx->max_shared_memory = ctx->cfg->gpu.default_shared_memory; } else { // MAX_SHARED_MEMORY_PER_BLOCK gives bogus numbers (48KiB); probably // for backwards compatibility. Add _OPTIN and you seem to get the // right number. 
ctx->max_shared_memory = device_query(ctx->dev, MAX_SHARED_MEMORY_PER_BLOCK_OPTIN); #if CUDART_VERSION >= 12000 ctx->max_shared_memory -= device_query(ctx->dev, RESERVED_SHARED_MEMORY_PER_BLOCK); #endif } if (ctx->cfg->gpu.default_cache != 0) { ctx->max_cache = ctx->cfg->gpu.default_cache; } else { ctx->max_cache = device_query(ctx->dev, L2_CACHE_SIZE); } ctx->lockstep_width = device_query(ctx->dev, WARP_SIZE); CUDA_SUCCEED_FATAL(cuStreamCreate(&ctx->stream, CU_STREAM_DEFAULT)); cuda_size_setup(ctx); gpu_init_log(ctx); ctx->error = cuda_module_setup(ctx, ctx->cfg->program, (const char**)ctx->cfg->nvrtc_opts, ctx->cfg->cache_fname); if (ctx->error != NULL) { futhark_panic(1, "During CUDA initialisation:\n%s\n", ctx->error); } int32_t no_error = -1; CUDA_SUCCEED_FATAL(cuMemAlloc(&ctx->global_failure, sizeof(no_error))); CUDA_SUCCEED_FATAL(cuMemcpyHtoD(ctx->global_failure, &no_error, sizeof(no_error))); // The +1 is to avoid zero-byte allocations. CUDA_SUCCEED_FATAL(cuMemAlloc(&ctx->global_failure_args, sizeof(int64_t)*(max_failure_args+1))); if ((ctx->kernels = init_builtin_kernels(ctx)) == NULL) { return 1; } return 0; } void backend_context_teardown(struct futhark_context* ctx) { if (ctx->kernels != NULL) { free_builtin_kernels(ctx, ctx->kernels); cuMemFree(ctx->global_failure); cuMemFree(ctx->global_failure_args); CUDA_SUCCEED_FATAL(gpu_free_all(ctx)); CUDA_SUCCEED_FATAL(cuStreamDestroy(ctx->stream)); CUDA_SUCCEED_FATAL(cuModuleUnload(ctx->module)); CUDA_SUCCEED_FATAL(cuCtxDestroy(ctx->cu_ctx)); } free_list_destroy(&ctx->gpu_free_list); } // GPU ABSTRACTION LAYER // Types. typedef CUfunction gpu_kernel; typedef CUdeviceptr gpu_mem; static void gpu_create_kernel(struct futhark_context *ctx, gpu_kernel* kernel, const char* name) { if (ctx->debugging) { fprintf(ctx->log, "Creating kernel %s.\n", name); } CUDA_SUCCEED_FATAL(cuModuleGetFunction(kernel, ctx->module, name)); // Unless the below is set, the kernel is limited to 48KiB of memory. 
CUDA_SUCCEED_FATAL(cuFuncSetAttribute(*kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, ctx->max_shared_memory)); } static void gpu_free_kernel(struct futhark_context *ctx, gpu_kernel kernel) { (void)ctx; (void)kernel; } static int gpu_scalar_to_device(struct futhark_context* ctx, gpu_mem dst, size_t offset, size_t size, void *src) { struct cuda_event *event = cuda_event_new(ctx); if (event != NULL) { add_event(ctx, "copy_scalar_to_dev", strdup(""), event, (event_report_fn)cuda_event_report); CUDA_SUCCEED_FATAL(cuEventRecord(event->start, ctx->stream)); } CUDA_SUCCEED_OR_RETURN(cuMemcpyHtoD(dst + offset, src, size)); if (event != NULL) { CUDA_SUCCEED_FATAL(cuEventRecord(event->end, ctx->stream)); } return FUTHARK_SUCCESS; } static int gpu_scalar_from_device(struct futhark_context* ctx, void *dst, gpu_mem src, size_t offset, size_t size) { struct cuda_event *event = cuda_event_new(ctx); if (event != NULL) { add_event(ctx, "copy_scalar_from_dev", strdup(""), event, (event_report_fn)cuda_event_report); CUDA_SUCCEED_FATAL(cuEventRecord(event->start, ctx->stream)); } CUDA_SUCCEED_OR_RETURN(cuMemcpyDtoH(dst, src + offset, size)); if (event != NULL) { CUDA_SUCCEED_FATAL(cuEventRecord(event->end, ctx->stream)); } return FUTHARK_SUCCESS; } static int gpu_memcpy(struct futhark_context* ctx, gpu_mem dst, int64_t dst_offset, gpu_mem src, int64_t src_offset, int64_t nbytes) { struct cuda_event *event = cuda_event_new(ctx); if (event != NULL) { add_event(ctx, "copy_dev_to_dev", strdup(""), event, (event_report_fn)cuda_event_report); CUDA_SUCCEED_FATAL(cuEventRecord(event->start, ctx->stream)); } CUDA_SUCCEED_OR_RETURN(cuMemcpyAsync(dst+dst_offset, src+src_offset, nbytes, ctx->stream)); if (event != NULL) { CUDA_SUCCEED_FATAL(cuEventRecord(event->end, ctx->stream)); } return FUTHARK_SUCCESS; } static int memcpy_host2gpu(struct futhark_context* ctx, bool sync, gpu_mem dst, int64_t dst_offset, const unsigned char* src, int64_t src_offset, int64_t nbytes) { if (nbytes > 0) { 
struct cuda_event *event = cuda_event_new(ctx); if (event != NULL) { add_event(ctx, "copy_host_to_dev", strdup(""), event, (event_report_fn)cuda_event_report); CUDA_SUCCEED_FATAL(cuEventRecord(event->start, ctx->stream)); } if (sync) { CUDA_SUCCEED_OR_RETURN (cuMemcpyHtoD(dst + dst_offset, src + src_offset, nbytes)); } else { CUDA_SUCCEED_OR_RETURN (cuMemcpyHtoDAsync(dst + dst_offset, src + src_offset, nbytes, ctx->stream)); } if (event != NULL) { CUDA_SUCCEED_FATAL(cuEventRecord(event->end, ctx->stream)); } } return FUTHARK_SUCCESS; } static int memcpy_gpu2host(struct futhark_context* ctx, bool sync, unsigned char* dst, int64_t dst_offset, gpu_mem src, int64_t src_offset, int64_t nbytes) { if (nbytes > 0) { struct cuda_event *event = cuda_event_new(ctx); if (event != NULL) { add_event(ctx, "copy_dev_to_host", strdup(""), event, (event_report_fn)cuda_event_report); CUDA_SUCCEED_FATAL(cuEventRecord(event->start, ctx->stream)); } if (sync) { CUDA_SUCCEED_OR_RETURN (cuMemcpyDtoH(dst + dst_offset, src + src_offset, nbytes)); } else { CUDA_SUCCEED_OR_RETURN (cuMemcpyDtoHAsync(dst + dst_offset, src + src_offset, nbytes, ctx->stream)); } if (event != NULL) { CUDA_SUCCEED_FATAL(cuEventRecord(event->end, ctx->stream)); } if (sync && ctx->failure_is_an_option && futhark_context_sync(ctx) != 0) { return 1; } } return FUTHARK_SUCCESS; } static int gpu_launch_kernel(struct futhark_context* ctx, gpu_kernel kernel, const char *name, const int32_t grid[3], const int32_t block[3], unsigned int shared_mem_bytes, int num_args, void* args[num_args], size_t args_sizes[num_args]) { (void) args_sizes; if (shared_mem_bytes > ctx->max_shared_memory) { set_error(ctx, msgprintf("Kernel %s with %d bytes of memory exceeds device limit of %d\n", name, shared_mem_bytes, (int)ctx->max_shared_memory)); return 1; } int64_t time_start = 0, time_end = 0; if (ctx->debugging) { time_start = get_wall_time(); } struct cuda_event *event = cuda_event_new(ctx); if (event != NULL) { 
CUDA_SUCCEED_FATAL(cuEventRecord(event->start, ctx->stream)); add_event(ctx, name, msgprintf("Kernel %s with\n" " grid=(%d,%d,%d)\n" " block=(%d,%d,%d)\n" " shared memory=%d", name, grid[0], grid[1], grid[2], block[0], block[1], block[2], shared_mem_bytes), event, (event_report_fn)cuda_event_report); } CUDA_SUCCEED_OR_RETURN (cuLaunchKernel(kernel, grid[0], grid[1], grid[2], block[0], block[1], block[2], shared_mem_bytes, ctx->stream, args, NULL)); if (event != NULL) { CUDA_SUCCEED_FATAL(cuEventRecord(event->end, ctx->stream)); } if (ctx->debugging) { CUDA_SUCCEED_FATAL(cuCtxSynchronize()); time_end = get_wall_time(); long int time_diff = time_end - time_start; fprintf(ctx->log, " runtime: %ldus\n", time_diff); } if (ctx->logging) { fprintf(ctx->log, "\n"); } return FUTHARK_SUCCESS; } static int gpu_alloc_actual(struct futhark_context *ctx, size_t size, gpu_mem *mem_out) { CUresult res; if (ctx->cfg->unified_memory) { res = cuMemAllocManaged(mem_out, size, CU_MEM_ATTACH_GLOBAL); } else { res = cuMemAlloc(mem_out, size); } if (res == CUDA_ERROR_OUT_OF_MEMORY) { return FUTHARK_OUT_OF_MEMORY; } CUDA_SUCCEED_OR_RETURN(res); return FUTHARK_SUCCESS; } static int gpu_free_actual(struct futhark_context *ctx, gpu_mem mem) { (void)ctx; CUDA_SUCCEED_OR_RETURN(cuMemFree(mem)); return FUTHARK_SUCCESS; } // End of backends/cuda.h. futhark-0.25.27/rts/c/backends/hip.h000066400000000000000000000730621475065116200171070ustar00rootroot00000000000000// Start of backends/hip.h. // Forward declarations. // Invoked by setup_opencl() after the platform and device has been // found, but before the program is loaded. Its intended use is to // tune constants based on the selected platform and device. 
static void set_tuning_params(struct futhark_context* ctx); static char* get_failure_msg(int failure_idx, int64_t args[]); #define HIP_SUCCEED_FATAL(x) hip_api_succeed_fatal(x, #x, __FILE__, __LINE__) #define HIP_SUCCEED_NONFATAL(x) hip_api_succeed_nonfatal(x, #x, __FILE__, __LINE__) #define HIPRTC_SUCCEED_FATAL(x) hiprtc_api_succeed_fatal(x, #x, __FILE__, __LINE__) #define HIPRTC_SUCCEED_NONFATAL(x) hiprtc_api_succeed_nonfatal(x, #x, __FILE__, __LINE__) // Take care not to override an existing error. #define HIP_SUCCEED_OR_RETURN(e) { \ char *serror = HIP_SUCCEED_NONFATAL(e); \ if (serror) { \ if (!ctx->error) { \ ctx->error = serror; \ } else { \ free(serror); \ } \ return bad; \ } \ } // HIP_SUCCEED_OR_RETURN returns the value of the variable 'bad' in // scope. By default, it will be this one. Create a local variable // of some other type if needed. This is a bit of a hack, but it // saves effort in the code generator. static const int bad = 1; static inline void hip_api_succeed_fatal(hipError_t res, const char *call, const char *file, int line) { if (res != hipSuccess) { const char *err_str = hipGetErrorString(res); if (err_str == NULL) { err_str = "Unknown"; } futhark_panic(-1, "%s:%d: HIP call\n %s\nfailed with error code %d (%s)\n", file, line, call, res, err_str); } } static char* hip_api_succeed_nonfatal(hipError_t res, const char *call, const char *file, int line) { if (res != hipSuccess) { const char *err_str = hipGetErrorString(res); if (err_str == NULL) { err_str = "Unknown"; } return msgprintf("%s:%d: HIP call\n %s\nfailed with error code %d (%s)\n", file, line, call, res, err_str); } else { return NULL; } } static inline void hiprtc_api_succeed_fatal(hiprtcResult res, const char *call, const char *file, int line) { if (res != HIPRTC_SUCCESS) { const char *err_str = hiprtcGetErrorString(res); futhark_panic(-1, "%s:%d: HIPRTC call\n %s\nfailed with error code %d (%s)\n", file, line, call, res, err_str); } } static char* 
hiprtc_api_succeed_nonfatal(hiprtcResult res, const char *call, const char *file, int line) { if (res != HIPRTC_SUCCESS) { const char *err_str = hiprtcGetErrorString(res); return msgprintf("%s:%d: HIPRTC call\n %s\nfailed with error code %d (%s)\n", file, line, call, res, err_str); } else { return NULL; } } struct futhark_context_config { int in_use; int debugging; int profiling; int logging; char* cache_fname; int num_tuning_params; int64_t *tuning_params; const char** tuning_param_names; const char** tuning_param_vars; const char** tuning_param_classes; // Uniform fields above. char* program; int num_build_opts; char* *build_opts; int unified_memory; char* preferred_device; int preferred_device_num; struct gpu_config gpu; }; static void backend_context_config_setup(struct futhark_context_config *cfg) { cfg->num_build_opts = 0; cfg->build_opts = (char**) malloc(sizeof(char*)); cfg->build_opts[0] = NULL; cfg->preferred_device_num = 0; cfg->preferred_device = strdup(""); cfg->program = strconcat(gpu_program); cfg->unified_memory = 0; cfg->gpu = gpu_config_initial; cfg->gpu.default_block_size = 256; cfg->gpu.default_tile_size = 32; cfg->gpu.default_reg_tile_size = 2; cfg->gpu.default_threshold = 32*1024; } static void backend_context_config_teardown(struct futhark_context_config* cfg) { for (int i = 0; i < cfg->num_build_opts; i++) { free(cfg->build_opts[i]); } free(cfg->build_opts); free(cfg->preferred_device); free(cfg->program); } void futhark_context_config_add_build_option(struct futhark_context_config *cfg, const char *opt) { cfg->build_opts[cfg->num_build_opts] = strdup(opt); cfg->num_build_opts++; cfg->build_opts = (char **) realloc(cfg->build_opts, (cfg->num_build_opts + 1) * sizeof(char *)); cfg->build_opts[cfg->num_build_opts] = NULL; } void futhark_context_config_set_device(struct futhark_context_config *cfg, const char *s) { int x = 0; if (*s == '#') { s++; while (isdigit(*s)) { x = x * 10 + (*s++)-'0'; } // Skip trailing spaces. 
while (isspace(*s)) { s++; } } free(cfg->preferred_device); cfg->preferred_device = strdup(s); cfg->preferred_device_num = x; } const char* futhark_context_config_get_program(struct futhark_context_config *cfg) { return cfg->program; } void futhark_context_config_set_program(struct futhark_context_config *cfg, const char *s) { free(cfg->program); cfg->program = strdup(s); } void futhark_context_config_set_unified_memory(struct futhark_context_config* cfg, int flag) { cfg->unified_memory = flag; } struct futhark_context { struct futhark_context_config* cfg; int detail_memory; int debugging; int profiling; int profiling_paused; int logging; lock_t lock; char *error; lock_t error_lock; FILE *log; struct constants *constants; struct free_list free_list; struct event_list event_list; int64_t peak_mem_usage_default; int64_t cur_mem_usage_default; bool program_initialised; // Uniform fields above. void* global_failure; void* global_failure_args; struct tuning_params tuning_params; // True if a potentially failing kernel has been enqueued. 
int32_t failure_is_an_option; int total_runs; long int total_runtime; int64_t peak_mem_usage_device; int64_t cur_mem_usage_device; struct program* program; hipDevice_t dev; int dev_id; hipModule_t module; hipStream_t stream; struct free_list gpu_free_list; size_t max_thread_block_size; size_t max_grid_size; size_t max_tile_size; size_t max_threshold; size_t max_shared_memory; size_t max_bespoke; size_t max_registers; size_t max_cache; size_t lockstep_width; struct builtin_kernels* kernels; }; static int device_query(int dev_id, hipDeviceAttribute_t attr) { int val; HIP_SUCCEED_FATAL(hipDeviceGetAttribute(&val, attr, dev_id)); return val; } static int function_query(hipFunction_t f, hipFunction_attribute attr) { int val; HIP_SUCCEED_FATAL(hipFuncGetAttribute(&val, attr, f)); return val; } static int hip_device_setup(struct futhark_context *ctx) { struct futhark_context_config *cfg = ctx->cfg; int count, chosen = -1; hipDevice_t dev; HIP_SUCCEED_FATAL(hipGetDeviceCount(&count)); if (count == 0) { return 1; } int num_device_matches = 0; for (int i = 0; i < count; i++) { hipDeviceProp_t prop; hipGetDeviceProperties(&prop, i); if (cfg->logging) { fprintf(ctx->log, "Device #%d: name=\"%s\"\n", i, prop.name); } if (strstr(prop.name, cfg->preferred_device) != NULL && num_device_matches++ == cfg->preferred_device_num) { chosen = i; break; } } if (chosen == -1) { return 1; } if (cfg->logging) { fprintf(ctx->log, "Using device #%d\n", chosen); } ctx->dev_id = chosen; HIP_SUCCEED_FATAL(hipDeviceGet(&ctx->dev, ctx->dev_id)); return 0; } static void hip_load_code_from_cache(struct futhark_context_config *cfg, const char *src, const char *opts[], size_t n_opts, struct cache_hash *h, const char *cache_fname, char **code, size_t *code_size) { if (cfg->logging) { fprintf(stderr, "Restoring cache from from %s...\n", cache_fname); } cache_hash_init(h); for (size_t i = 0; i < n_opts; i++) { cache_hash(h, opts[i], strlen(opts[i])); } cache_hash(h, src, strlen(src)); errno = 0; if 
(cache_restore(cache_fname, h, (unsigned char**)code, code_size) != 0) { if (cfg->logging) { fprintf(stderr, "Failed to restore cache (errno: %s)\n", strerror(errno)); } } } static void hip_size_setup(struct futhark_context *ctx) { struct futhark_context_config *cfg = ctx->cfg; if (cfg->gpu.default_block_size > ctx->max_thread_block_size) { if (cfg->gpu.default_block_size_changed) { fprintf(stderr, "Note: Device limits default block size to %zu (down from %zu).\n", ctx->max_thread_block_size, cfg->gpu.default_block_size); } cfg->gpu.default_block_size = ctx->max_thread_block_size; } if (cfg->gpu.default_grid_size > ctx->max_grid_size) { if (cfg->gpu.default_grid_size_changed) { fprintf(stderr, "Note: Device limits default grid size to %zu (down from %zu).\n", ctx->max_grid_size, cfg->gpu.default_grid_size); } cfg->gpu.default_grid_size = ctx->max_grid_size; } if (cfg->gpu.default_tile_size > ctx->max_tile_size) { if (cfg->gpu.default_tile_size_changed) { fprintf(stderr, "Note: Device limits default tile size to %zu (down from %zu).\n", ctx->max_tile_size, cfg->gpu.default_tile_size); } cfg->gpu.default_tile_size = ctx->max_tile_size; } if (!cfg->gpu.default_grid_size_changed) { cfg->gpu.default_grid_size = (device_query(ctx->dev, hipDeviceAttributePhysicalMultiProcessorCount) * device_query(ctx->dev, hipDeviceAttributeMaxThreadsPerMultiProcessor)) / cfg->gpu.default_block_size; } for (int i = 0; i < cfg->num_tuning_params; i++) { const char *size_class = cfg->tuning_param_classes[i]; int64_t *size_value = &cfg->tuning_params[i]; const char* size_name = cfg->tuning_param_names[i]; int64_t max_value = 0, default_value = 0; if (strstr(size_class, "thread_block_size") == size_class) { max_value = ctx->max_thread_block_size; default_value = cfg->gpu.default_block_size; } else if (strstr(size_class, "grid_size") == size_class) { max_value = ctx->max_grid_size; default_value = cfg->gpu.default_grid_size; // XXX: as a quick and dirty hack, use twice as many threads for // 
histograms by default. We really should just be smarter // about sizes somehow. if (strstr(size_name, ".seghist_") != NULL) { default_value *= 2; } } else if (strstr(size_class, "tile_size") == size_class) { max_value = ctx->max_tile_size; default_value = cfg->gpu.default_tile_size; } else if (strstr(size_class, "reg_tile_size") == size_class) { max_value = 0; // No limit. default_value = cfg->gpu.default_reg_tile_size; } else if (strstr(size_class, "shared_memory") == size_class) { max_value = ctx->max_shared_memory; default_value = ctx->max_shared_memory; } else if (strstr(size_class, "cache") == size_class) { max_value = ctx->max_cache; default_value = ctx->max_cache; } else if (strstr(size_class, "threshold") == size_class) { // Threshold can be as large as it takes. default_value = cfg->gpu.default_threshold; } else { // Bespoke sizes have no limit or default. } if (*size_value == 0) { *size_value = default_value; } else if (max_value > 0 && *size_value > max_value) { fprintf(stderr, "Note: Device limits %s to %zu (down from %zu)\n", size_name, max_value, *size_value); *size_value = max_value; } } } static char* hiprtc_build(const char *src, const char *opts[], size_t n_opts, char **code, size_t *code_size) { hiprtcProgram prog; char *problem = NULL; problem = HIPRTC_SUCCEED_NONFATAL(hiprtcCreateProgram(&prog, src, "futhark-hip", 0, NULL, NULL)); if (problem) { return problem; } hiprtcResult res = hiprtcCompileProgram(prog, n_opts, opts); if (res != HIPRTC_SUCCESS) { size_t log_size; if (hiprtcGetProgramLogSize(prog, &log_size) == HIPRTC_SUCCESS) { char *log = (char*) malloc(log_size+1); log[log_size] = 0; // HIPRTC does not zero-terminate. 
if (hiprtcGetProgramLog(prog, log) == HIPRTC_SUCCESS) { problem = msgprintf("HIPRTC compilation failed.\n\n%s\n", log); } else { problem = msgprintf("Could not retrieve compilation log\n"); } free(log); } return problem; } HIPRTC_SUCCEED_FATAL(hiprtcGetCodeSize(prog, code_size)); *code = (char*) malloc(*code_size); HIPRTC_SUCCEED_FATAL(hiprtcGetCode(prog, *code)); HIPRTC_SUCCEED_FATAL(hiprtcDestroyProgram(&prog)); return NULL; } static void hiprtc_mk_build_options(struct futhark_context *ctx, const char *extra_opts[], char*** opts_out, size_t *n_opts) { int arch_set = 0, num_extra_opts; struct futhark_context_config *cfg = ctx->cfg; char** macro_names; int64_t* macro_vals; int num_macros = gpu_macros(ctx, ¯o_names, ¯o_vals); for (num_extra_opts = 0; extra_opts[num_extra_opts] != NULL; num_extra_opts++) { if (strstr(extra_opts[num_extra_opts], "--gpu-architecture") == extra_opts[num_extra_opts]) { arch_set = 1; } } size_t i = 0, n_opts_alloc = 20 + num_macros + num_extra_opts + cfg->num_tuning_params; char **opts = (char**) malloc(n_opts_alloc * sizeof(char *)); if (!arch_set) { hipDeviceProp_t props; HIP_SUCCEED_FATAL(hipGetDeviceProperties(&props, ctx->dev_id)); opts[i++] = msgprintf("--gpu-architecture=%s", props.gcnArchName); } if (cfg->debugging) { opts[i++] = strdup("-G"); opts[i++] = strdup("-lineinfo"); } opts[i++] = msgprintf("-D%s=%d", "max_thread_block_size", (int)ctx->max_thread_block_size); opts[i++] = msgprintf("-D%s=%d", "max_shared_memory", (int)ctx->max_shared_memory); opts[i++] = msgprintf("-D%s=%d", "max_registers", (int)ctx->max_registers); for (int j = 0; j < num_macros; j++) { opts[i++] = msgprintf("-D%s=%zu", macro_names[j], macro_vals[j]); } for (int j = 0; j < cfg->num_tuning_params; j++) { opts[i++] = msgprintf("-D%s=%zu", cfg->tuning_param_vars[j], cfg->tuning_params[j]); } opts[i++] = msgprintf("-DLOCKSTEP_WIDTH=%zu", ctx->lockstep_width); opts[i++] = msgprintf("-DMAX_THREADS_PER_BLOCK=%zu", ctx->max_thread_block_size); for (int j = 0; 
extra_opts[j] != NULL; j++) { opts[i++] = strdup(extra_opts[j]); } opts[i++] = msgprintf("-DTR_BLOCK_DIM=%d", TR_BLOCK_DIM); opts[i++] = msgprintf("-DTR_TILE_DIM=%d", TR_TILE_DIM); opts[i++] = msgprintf("-DTR_ELEMS_PER_THREAD=%d", TR_ELEMS_PER_THREAD); free(macro_names); free(macro_vals); *n_opts = i; *opts_out = opts; } static char* hip_module_setup(struct futhark_context *ctx, const char *src, const char *extra_opts[], const char* cache_fname) { char *code = NULL; size_t code_size = 0; struct futhark_context_config *cfg = ctx->cfg; char **opts; size_t n_opts; hiprtc_mk_build_options(ctx, extra_opts, &opts, &n_opts); if (cfg->logging) { fprintf(stderr, "HIPRTC build options:\n"); for (size_t j = 0; j < n_opts; j++) { fprintf(stderr, "\t%s\n", opts[j]); } fprintf(stderr, "\n"); } struct cache_hash h; int loaded_code_from_cache = 0; if (cache_fname != NULL) { hip_load_code_from_cache(cfg, src, (const char**)opts, n_opts, &h, cache_fname, &code, &code_size); if (code != NULL) { if (cfg->logging) { fprintf(stderr, "Restored compiled code from cache; now loading module...\n"); } if (hipModuleLoadData(&ctx->module, code) == hipSuccess) { if (cfg->logging) { fprintf(stderr, "Success!\n"); } loaded_code_from_cache = 1; } else { if (cfg->logging) { fprintf(stderr, "Failed!\n"); } free(code); code = NULL; } } } if (code == NULL) { char* problem = hiprtc_build(src, (const char**)opts, n_opts, &code, &code_size); if (problem != NULL) { return problem; } } if (!loaded_code_from_cache) { HIP_SUCCEED_FATAL(hipModuleLoadData(&ctx->module, code)); } if (cache_fname != NULL && !loaded_code_from_cache) { if (cfg->logging) { fprintf(stderr, "Caching compiled code in %s...\n", cache_fname); } errno = 0; if (cache_store(cache_fname, &h, (const unsigned char*)code, code_size) != 0) { fprintf(stderr, "Failed to cache compiled code: %s\n", strerror(errno)); } } for (size_t i = 0; i < n_opts; i++) { free((char *)opts[i]); } free(opts); free(code); return NULL; } struct hip_event { 
hipEvent_t start; hipEvent_t end; }; static struct hip_event* hip_event_new(struct futhark_context* ctx) { if (ctx->profiling && !ctx->profiling_paused) { struct hip_event* e = malloc(sizeof(struct hip_event)); hipEventCreate(&e->start); hipEventCreate(&e->end); return e; } else { return NULL; } } static int hip_event_report(struct str_builder* sb, struct hip_event* e) { float ms; hipError_t err; if ((err = hipEventElapsedTime(&ms, e->start, e->end)) != hipSuccess) { return err; } // HIP provides milisecond resolution, but we want microseconds. str_builder(sb, ",\"duration\":%f", ms*1000); if ((err = hipEventDestroy(e->start)) != hipSuccess) { return 1; } if ((err = hipEventDestroy(e->end)) != hipSuccess) { return 1; } free(e); return 0; } int futhark_context_sync(struct futhark_context* ctx) { HIP_SUCCEED_OR_RETURN(hipStreamSynchronize(ctx->stream)); if (ctx->failure_is_an_option) { // Check for any delayed error. int32_t failure_idx; HIP_SUCCEED_OR_RETURN(hipMemcpyDtoH(&failure_idx, ctx->global_failure, sizeof(int32_t))); ctx->failure_is_an_option = 0; if (failure_idx >= 0) { // We have to clear global_failure so that the next entry point // is not considered a failure from the start. 
int32_t no_failure = -1; HIP_SUCCEED_OR_RETURN(hipMemcpyHtoD(ctx->global_failure, &no_failure, sizeof(int32_t))); int64_t args[max_failure_args+1]; HIP_SUCCEED_OR_RETURN(hipMemcpyDtoH(&args, ctx->global_failure_args, sizeof(args))); ctx->error = get_failure_msg(failure_idx, args); return FUTHARK_PROGRAM_ERROR; } } return 0; } struct builtin_kernels* init_builtin_kernels(struct futhark_context* ctx); void free_builtin_kernels(struct futhark_context* ctx, struct builtin_kernels* kernels); int backend_context_setup(struct futhark_context* ctx) { ctx->failure_is_an_option = 0; ctx->total_runs = 0; ctx->total_runtime = 0; ctx->peak_mem_usage_device = 0; ctx->cur_mem_usage_device = 0; ctx->kernels = NULL; HIP_SUCCEED_FATAL(hipInit(0)); if (hip_device_setup(ctx) != 0) { futhark_panic(-1, "No suitable HIP device found.\n"); } free_list_init(&ctx->gpu_free_list); if (ctx->cfg->unified_memory == 2) { ctx->cfg->unified_memory = device_query(ctx->dev, hipDeviceAttributeManagedMemory); } if (ctx->cfg->logging) { if (ctx->cfg->unified_memory) { fprintf(ctx->log, "Using managed memory\n"); } else { fprintf(ctx->log, "Using unmanaged memory\n"); } } ctx->max_thread_block_size = device_query(ctx->dev, hipDeviceAttributeMaxThreadsPerBlock); ctx->max_grid_size = device_query(ctx->dev, hipDeviceAttributeMaxGridDimX); ctx->max_tile_size = sqrt(ctx->max_thread_block_size); ctx->max_threshold = 1U<<31; // No limit. 
ctx->max_bespoke = 0; ctx->max_registers = device_query(ctx->dev, hipDeviceAttributeMaxRegistersPerBlock); if (ctx->cfg->gpu.default_shared_memory != 0) { ctx->max_shared_memory = ctx->cfg->gpu.default_shared_memory; } else { ctx->max_shared_memory = device_query(ctx->dev, hipDeviceAttributeMaxSharedMemoryPerBlock); } if (ctx->cfg->gpu.default_cache != 0) { ctx->max_cache = ctx->cfg->gpu.default_cache; } else { ctx->max_cache = device_query(ctx->dev, hipDeviceAttributeL2CacheSize); } // FIXME: in principle we should query hipDeviceAttributeWarpSize // from the device, which will provide 64 on AMD GPUs. // Unfortunately, we currently do nasty implicit intra-warp // synchronisation in codegen, which does not work when this is 64. // Once our codegen properly synchronises intra-warp operations, we // can use the actual hardware lockstep width instead. ctx->lockstep_width = 32; HIP_SUCCEED_FATAL(hipStreamCreate(&ctx->stream)); hip_size_setup(ctx); gpu_init_log(ctx); ctx->error = hip_module_setup(ctx, ctx->cfg->program, (const char**)ctx->cfg->build_opts, ctx->cfg->cache_fname); if (ctx->error != NULL) { futhark_panic(1, "During HIP initialisation:\n%s\n", ctx->error); } int32_t no_error = -1; HIP_SUCCEED_FATAL(hipMalloc(&ctx->global_failure, sizeof(no_error))); HIP_SUCCEED_FATAL(hipMemcpyHtoD(ctx->global_failure, &no_error, sizeof(no_error))); // The +1 is to avoid zero-byte allocations. 
HIP_SUCCEED_FATAL(hipMalloc(&ctx->global_failure_args, sizeof(int64_t)*(max_failure_args+1))); if ((ctx->kernels = init_builtin_kernels(ctx)) == NULL) { return 1; } return 0; } void backend_context_teardown(struct futhark_context* ctx) { if (ctx->kernels != NULL) { free_builtin_kernels(ctx, ctx->kernels); hipFree(ctx->global_failure); hipFree(ctx->global_failure_args); HIP_SUCCEED_FATAL(gpu_free_all(ctx)); HIP_SUCCEED_FATAL(hipStreamDestroy(ctx->stream)); HIP_SUCCEED_FATAL(hipModuleUnload(ctx->module)); } free_list_destroy(&ctx->gpu_free_list); } // GPU ABSTRACTION LAYER typedef hipFunction_t gpu_kernel; typedef hipDeviceptr_t gpu_mem; static void gpu_create_kernel(struct futhark_context *ctx, gpu_kernel* kernel, const char* name) { if (ctx->debugging) { fprintf(ctx->log, "Creating kernel %s.\n", name); } HIP_SUCCEED_FATAL(hipModuleGetFunction(kernel, ctx->module, name)); } static void gpu_free_kernel(struct futhark_context *ctx, gpu_kernel kernel) { (void)ctx; (void)kernel; } static int gpu_scalar_to_device(struct futhark_context* ctx, gpu_mem dst, size_t offset, size_t size, void *src) { struct hip_event *event = hip_event_new(ctx); if (event != NULL) { add_event(ctx, "copy_scalar_to_dev", strdup(""), event, (event_report_fn)hip_event_report); HIP_SUCCEED_FATAL(hipEventRecord(event->start, ctx->stream)); } HIP_SUCCEED_OR_RETURN(hipMemcpyHtoD((unsigned char*)dst + offset, src, size)); if (event != NULL) { HIP_SUCCEED_FATAL(hipEventRecord(event->end, ctx->stream)); } return FUTHARK_SUCCESS; } static int gpu_scalar_from_device(struct futhark_context* ctx, void *dst, gpu_mem src, size_t offset, size_t size) { struct hip_event *event = hip_event_new(ctx); if (event != NULL) { add_event(ctx, "copy_scalar_from_dev", strdup(""), event, (event_report_fn)hip_event_report); HIP_SUCCEED_FATAL(hipEventRecord(event->start, ctx->stream)); } HIP_SUCCEED_OR_RETURN(hipMemcpyDtoH(dst, (unsigned char*)src + offset, size)); if (event != NULL) { 
HIP_SUCCEED_FATAL(hipEventRecord(event->end, ctx->stream)); } return FUTHARK_SUCCESS; } static int gpu_memcpy(struct futhark_context* ctx, gpu_mem dst, int64_t dst_offset, gpu_mem src, int64_t src_offset, int64_t nbytes) { struct hip_event *event = hip_event_new(ctx); if (event != NULL) { add_event(ctx, "copy_dev_to_dev", strdup(""), event, (event_report_fn)hip_event_report); HIP_SUCCEED_FATAL(hipEventRecord(event->start, ctx->stream)); } HIP_SUCCEED_OR_RETURN(hipMemcpyWithStream((unsigned char*)dst+dst_offset, (unsigned char*)src+src_offset, nbytes, hipMemcpyDeviceToDevice, ctx->stream)); if (event != NULL) { HIP_SUCCEED_FATAL(hipEventRecord(event->end, ctx->stream)); } return FUTHARK_SUCCESS; } static int memcpy_host2gpu(struct futhark_context* ctx, bool sync, gpu_mem dst, int64_t dst_offset, const unsigned char* src, int64_t src_offset, int64_t nbytes) { if (nbytes > 0) { struct hip_event *event = hip_event_new(ctx); if (event != NULL) { add_event(ctx, "copy_host_to_dev", strdup(""), event, (event_report_fn)hip_event_report); HIP_SUCCEED_FATAL(hipEventRecord(event->start, ctx->stream)); } if (sync) { HIP_SUCCEED_OR_RETURN (hipMemcpyHtoD((unsigned char*)dst + dst_offset, (unsigned char*)src + src_offset, nbytes)); } else { HIP_SUCCEED_OR_RETURN (hipMemcpyHtoDAsync((unsigned char*)dst + dst_offset, (unsigned char*)src + src_offset, nbytes, ctx->stream)); } if (event != NULL) { HIP_SUCCEED_FATAL(hipEventRecord(event->end, ctx->stream)); } } return FUTHARK_SUCCESS; } static int memcpy_gpu2host(struct futhark_context* ctx, bool sync, unsigned char* dst, int64_t dst_offset, gpu_mem src, int64_t src_offset, int64_t nbytes) { if (nbytes > 0) { struct hip_event *event = hip_event_new(ctx); if (event != NULL) { add_event(ctx, "copy_dev_to_host", strdup(""), event, (event_report_fn)hip_event_report); HIP_SUCCEED_FATAL(hipEventRecord(event->start, ctx->stream)); } if (sync) { HIP_SUCCEED_OR_RETURN (hipMemcpyDtoH(dst + dst_offset, (unsigned char*)src + src_offset, nbytes)); 
} else { HIP_SUCCEED_OR_RETURN (hipMemcpyDtoHAsync(dst + dst_offset, (unsigned char*)src + src_offset, nbytes, ctx->stream)); } if (event != NULL) { HIP_SUCCEED_FATAL(hipEventRecord(event->end, ctx->stream)); } if (sync && ctx->failure_is_an_option && futhark_context_sync(ctx) != 0) { return 1; } } return FUTHARK_SUCCESS; } static int gpu_launch_kernel(struct futhark_context* ctx, gpu_kernel kernel, const char *name, const int32_t grid[3], const int32_t block[3], unsigned int shared_mem_bytes, int num_args, void* args[num_args], size_t args_sizes[num_args]) { (void) args_sizes; if (shared_mem_bytes > ctx->max_shared_memory) { set_error(ctx, msgprintf("Kernel %s with %d bytes of memory exceeds device limit of %d\n", name, shared_mem_bytes, (int)ctx->max_shared_memory)); return 1; } int64_t time_start = 0, time_end = 0; if (ctx->debugging) { time_start = get_wall_time(); } struct hip_event *event = hip_event_new(ctx); if (event != NULL) { HIP_SUCCEED_FATAL(hipEventRecord(event->start, ctx->stream)); add_event(ctx, name, msgprintf("Kernel %s with\n" " grid=(%d,%d,%d)\n" " block=(%d,%d,%d)\n" " shared memory=%d", name, grid[0], grid[1], grid[2], block[0], block[1], block[2], shared_mem_bytes), event, (event_report_fn)hip_event_report); } HIP_SUCCEED_OR_RETURN (hipModuleLaunchKernel(kernel, grid[0], grid[1], grid[2], block[0], block[1], block[2], shared_mem_bytes, ctx->stream, args, NULL)); if (event != NULL) { HIP_SUCCEED_FATAL(hipEventRecord(event->end, ctx->stream)); } if (ctx->debugging) { HIP_SUCCEED_FATAL(hipStreamSynchronize(ctx->stream)); time_end = get_wall_time(); long int time_diff = time_end - time_start; fprintf(ctx->log, " runtime: %ldus\n", time_diff); } if (ctx->logging) { fprintf(ctx->log, "\n"); } return FUTHARK_SUCCESS; } static int gpu_alloc_actual(struct futhark_context *ctx, size_t size, gpu_mem *mem_out) { hipError_t res; if (ctx->cfg->unified_memory) { res = hipMallocManaged(mem_out, size, hipMemAttachGlobal); } else { res = hipMalloc(mem_out, 
size); } if (res == hipErrorOutOfMemory) { return FUTHARK_OUT_OF_MEMORY; } HIP_SUCCEED_OR_RETURN(res); return FUTHARK_SUCCESS; } static int gpu_free_actual(struct futhark_context *ctx, gpu_mem mem) { (void)ctx; HIP_SUCCEED_OR_RETURN(hipFree(mem)); return FUTHARK_SUCCESS; } // End of backends/hip.h. futhark-0.25.27/rts/c/backends/multicore.h000066400000000000000000000054201475065116200203230ustar00rootroot00000000000000// Start of backends/multicore.h struct futhark_context_config { int in_use; int debugging; int profiling; int logging; char *cache_fname; int num_tuning_params; int64_t *tuning_params; const char** tuning_param_names; const char** tuning_param_vars; const char** tuning_param_classes; // Uniform fields above. int num_threads; }; static void backend_context_config_setup(struct futhark_context_config* cfg) { cfg->num_threads = 0; } static void backend_context_config_teardown(struct futhark_context_config* cfg) { (void)cfg; } void futhark_context_config_set_num_threads(struct futhark_context_config *cfg, int n) { cfg->num_threads = n; } int futhark_context_config_set_tuning_param(struct futhark_context_config* cfg, const char *param_name, size_t param_value) { (void)cfg; (void)param_name; (void)param_value; return 1; } struct futhark_context { struct futhark_context_config* cfg; int detail_memory; int debugging; int profiling; int profiling_paused; int logging; lock_t lock; char *error; lock_t error_lock; FILE *log; struct constants *constants; struct free_list free_list; struct event_list event_list; int64_t peak_mem_usage_default; int64_t cur_mem_usage_default; struct program* program; bool program_initialised; // Uniform fields above. 
lock_t event_list_lock; struct scheduler scheduler; int total_runs; long int total_runtime; int64_t tuning_timing; int64_t tuning_iter; }; int backend_context_setup(struct futhark_context* ctx) { // Initialize rand() fast_srand(time(0)); int tune_kappa = 0; double kappa = 5.1f * 1000; if (tune_kappa) { if (determine_kappa(&kappa) != 0) { ctx->error = strdup("Failed to determine kappa."); return 1; } } if (scheduler_init(&ctx->scheduler, ctx->cfg->num_threads > 0 ? ctx->cfg->num_threads : num_processors(), kappa) != 0) { ctx->error = strdup("Failed to initialise scheduler."); return 1; } create_lock(&ctx->event_list_lock); return 0; } void backend_context_teardown(struct futhark_context* ctx) { (void)scheduler_destroy(&ctx->scheduler); free_lock(&ctx->event_list_lock); } int futhark_context_sync(struct futhark_context* ctx) { (void)ctx; return 0; } struct mc_event { // Time in microseconds. uint64_t bef, aft; }; static struct mc_event* mc_event_new(struct futhark_context* ctx) { if (ctx->profiling && !ctx->profiling_paused) { struct mc_event* e = malloc(sizeof(struct mc_event)); return e; } else { return NULL; } } static int mc_event_report(struct str_builder* sb, struct mc_event* e) { float ms = e->aft - e->bef; str_builder(sb, ",\"duration\":%f", ms); free(e); return 0; } // End of backends/multicore.h futhark-0.25.27/rts/c/backends/opencl.h000066400000000000000000001412211475065116200176000ustar00rootroot00000000000000// Start of backends/opencl.h // Note [32-bit transpositions] // // Transposition kernels are much slower when they have to use 64-bit // arithmetic. I observed about 0.67x slowdown on an A100 GPU when // transposing four-byte elements (much less when transposing 8-byte // elements). Unfortunately, 64-bit arithmetic is a requirement for // large arrays (see #1953 for what happens otherwise). We generate // both 32- and 64-bit index arithmetic versions of transpositions, // and dynamically pick between them at runtime. 
// This is an
// unfortunate code bloat, and it would be preferable if we could
// simply optimise the 64-bit version to make this distinction
// unnecessary.  Fortunately these kernels are quite small.

// Forward declarations.
struct opencl_device_option;
// Invoked by setup_opencl() after the platform and device has been
// found, but before the program is loaded.  Its intended use is to
// tune constants based on the selected platform and device.
static void post_opencl_setup(struct futhark_context*, struct opencl_device_option*);
static void set_tuning_params(struct futhark_context* ctx);
static char* get_failure_msg(int failure_idx, int64_t args[]);

// Wrappers that pass the failing call's source text (#e), file and line
// to opencl_succeed_fatal() (which aborts) or opencl_succeed_nonfatal()
// (which returns an error message, or NULL on success).
#define OPENCL_SUCCEED_FATAL(e) opencl_succeed_fatal(e, #e, __FILE__, __LINE__)
#define OPENCL_SUCCEED_NONFATAL(e) opencl_succeed_nonfatal(e, #e, __FILE__, __LINE__)
// Take care not to override an existing error.
// On failure of 'e': keep the first message in ctx->error (freeing any
// later one) and return 'bad' from the enclosing function.  Expands to a
// braced block and needs a variable named 'ctx' in scope.
#define OPENCL_SUCCEED_OR_RETURN(e) {             \
    char *serror = OPENCL_SUCCEED_NONFATAL(e);    \
    if (serror) {                                 \
      if (!ctx->error) {                          \
        ctx->error = serror;                      \
      } else {                                    \
        free(serror);                             \
      }                                           \
      return bad;                                 \
    }                                             \
  }

// OPENCL_SUCCEED_OR_RETURN returns the value of the variable 'bad' in
// scope.  By default, it will be this one.  Create a local variable
// of some other type if needed.  This is a bit of a hack, but it
// saves effort in the code generator.
static const int bad = 1; static const char* opencl_error_string(cl_int err) { switch (err) { case CL_SUCCESS: return "Success!"; case CL_DEVICE_NOT_FOUND: return "Device not found."; case CL_DEVICE_NOT_AVAILABLE: return "Device not available"; case CL_COMPILER_NOT_AVAILABLE: return "Compiler not available"; case CL_MEM_OBJECT_ALLOCATION_FAILURE: return "Memory object allocation failure"; case CL_OUT_OF_RESOURCES: return "Out of resources"; case CL_OUT_OF_HOST_MEMORY: return "Out of host memory"; case CL_PROFILING_INFO_NOT_AVAILABLE: return "Profiling information not available"; case CL_MEM_COPY_OVERLAP: return "Memory copy overlap"; case CL_IMAGE_FORMAT_MISMATCH: return "Image format mismatch"; case CL_IMAGE_FORMAT_NOT_SUPPORTED: return "Image format not supported"; case CL_BUILD_PROGRAM_FAILURE: return "Program build failure"; case CL_MAP_FAILURE: return "Map failure"; case CL_INVALID_VALUE: return "Invalid value"; case CL_INVALID_DEVICE_TYPE: return "Invalid device type"; case CL_INVALID_PLATFORM: return "Invalid platform"; case CL_INVALID_DEVICE: return "Invalid device"; case CL_INVALID_CONTEXT: return "Invalid context"; case CL_INVALID_QUEUE_PROPERTIES: return "Invalid queue properties"; case CL_INVALID_COMMAND_QUEUE: return "Invalid command queue"; case CL_INVALID_HOST_PTR: return "Invalid host pointer"; case CL_INVALID_MEM_OBJECT: return "Invalid memory object"; case CL_INVALID_IMAGE_FORMAT_DESCRIPTOR: return "Invalid image format descriptor"; case CL_INVALID_IMAGE_SIZE: return "Invalid image size"; case CL_INVALID_SAMPLER: return "Invalid sampler"; case CL_INVALID_BINARY: return "Invalid binary"; case CL_INVALID_BUILD_OPTIONS: return "Invalid build options"; case CL_INVALID_PROGRAM: return "Invalid program"; case CL_INVALID_PROGRAM_EXECUTABLE: return "Invalid program executable"; case CL_INVALID_KERNEL_NAME: return "Invalid kernel name"; case CL_INVALID_KERNEL_DEFINITION: return "Invalid kernel definition"; case CL_INVALID_KERNEL: return "Invalid kernel"; 
case CL_INVALID_ARG_INDEX: return "Invalid argument index"; case CL_INVALID_ARG_VALUE: return "Invalid argument value"; case CL_INVALID_ARG_SIZE: return "Invalid argument size"; case CL_INVALID_KERNEL_ARGS: return "Invalid kernel arguments"; case CL_INVALID_WORK_DIMENSION: return "Invalid work dimension"; case CL_INVALID_WORK_GROUP_SIZE: return "Invalid work group size"; case CL_INVALID_WORK_ITEM_SIZE: return "Invalid work item size"; case CL_INVALID_GLOBAL_OFFSET: return "Invalid global offset"; case CL_INVALID_EVENT_WAIT_LIST: return "Invalid event wait list"; case CL_INVALID_EVENT: return "Invalid event"; case CL_INVALID_OPERATION: return "Invalid operation"; case CL_INVALID_GL_OBJECT: return "Invalid OpenGL object"; case CL_INVALID_BUFFER_SIZE: return "Invalid buffer size"; case CL_INVALID_MIP_LEVEL: return "Invalid mip-map level"; default: return "Unknown"; } } static void opencl_succeed_fatal(cl_int ret, const char *call, const char *file, int line) { if (ret != CL_SUCCESS) { futhark_panic(-1, "%s:%d: OpenCL call\n %s\nfailed with error code %d (%s)\n", file, line, call, ret, opencl_error_string(ret)); } } static char* opencl_succeed_nonfatal(cl_int ret, const char *call, const char *file, int line) { if (ret != CL_SUCCESS) { return msgprintf("%s:%d: OpenCL call\n %s\nfailed with error code %d (%s)\n", file, line, call, ret, opencl_error_string(ret)); } else { return NULL; } } struct futhark_context_config { int in_use; int debugging; int profiling; int logging; char *cache_fname; int num_tuning_params; int64_t *tuning_params; const char** tuning_param_names; const char** tuning_param_vars; const char** tuning_param_classes; // Uniform fields above. 
char* program; int preferred_device_num; char* preferred_platform; char* preferred_device; int ignore_blacklist; int unified_memory; char* dump_binary_to; char* load_binary_from; int num_build_opts; char* *build_opts; cl_command_queue queue; int queue_set; struct gpu_config gpu; }; static void backend_context_config_setup(struct futhark_context_config* cfg) { cfg->num_build_opts = 0; cfg->build_opts = (char**) malloc(sizeof(const char*)); cfg->build_opts[0] = NULL; cfg->preferred_device_num = 0; cfg->preferred_platform = strdup(""); cfg->preferred_device = strdup(""); cfg->ignore_blacklist = 0; cfg->dump_binary_to = NULL; cfg->load_binary_from = NULL; cfg->program = strconcat(gpu_program); cfg->unified_memory = 2; cfg->gpu = gpu_config_initial; cfg->queue_set = 0; } static void backend_context_config_teardown(struct futhark_context_config* cfg) { for (int i = 0; i < cfg->num_build_opts; i++) { free(cfg->build_opts[i]); } free(cfg->build_opts); free(cfg->dump_binary_to); free(cfg->load_binary_from); free(cfg->preferred_device); free(cfg->preferred_platform); free(cfg->program); } void futhark_context_config_add_build_option(struct futhark_context_config* cfg, const char *opt) { cfg->build_opts[cfg->num_build_opts] = strdup(opt); cfg->num_build_opts++; cfg->build_opts = (char**) realloc(cfg->build_opts, (cfg->num_build_opts+1) * sizeof(char*)); cfg->build_opts[cfg->num_build_opts] = NULL; } void futhark_context_config_set_device(struct futhark_context_config *cfg, const char* s) { int x = 0; if (*s == '#') { s++; while (isdigit(*s)) { x = x * 10 + (*s++)-'0'; } // Skip trailing spaces. 
while (isspace(*s)) { s++; } } free(cfg->preferred_device); cfg->preferred_device = strdup(s); cfg->preferred_device_num = x; cfg->ignore_blacklist = 1; } void futhark_context_config_set_platform(struct futhark_context_config *cfg, const char *s) { free(cfg->preferred_platform); cfg->preferred_platform = strdup(s); cfg->ignore_blacklist = 1; } void futhark_context_config_set_command_queue(struct futhark_context_config *cfg, cl_command_queue q) { cfg->queue = q; cfg->queue_set = 1; } struct opencl_device_option { cl_platform_id platform; cl_device_id device; cl_device_type device_type; char *platform_name; char *device_name; }; static char* opencl_platform_info(cl_platform_id platform, cl_platform_info param) { size_t req_bytes; char *info; OPENCL_SUCCEED_FATAL(clGetPlatformInfo(platform, param, 0, NULL, &req_bytes)); info = (char*) malloc(req_bytes); OPENCL_SUCCEED_FATAL(clGetPlatformInfo(platform, param, req_bytes, info, NULL)); return info; } static char* opencl_device_info(cl_device_id device, cl_device_info param) { size_t req_bytes; char *info; OPENCL_SUCCEED_FATAL(clGetDeviceInfo(device, param, 0, NULL, &req_bytes)); info = (char*) malloc(req_bytes); OPENCL_SUCCEED_FATAL(clGetDeviceInfo(device, param, req_bytes, info, NULL)); return info; } static int is_blacklisted(const char *platform_name, const char *device_name, const struct futhark_context_config *cfg) { if (strcmp(cfg->preferred_platform, "") != 0 || strcmp(cfg->preferred_device, "") != 0) { return 0; } else if (strstr(platform_name, "Apple") != NULL && strstr(device_name, "Intel(R) Core(TM)") != NULL) { return 1; } else { return 0; } } static void opencl_all_device_options(struct opencl_device_option **devices_out, size_t *num_devices_out) { size_t num_devices = 0, num_devices_added = 0; cl_platform_id *all_platforms; cl_uint *platform_num_devices; cl_uint num_platforms; // Find the number of platforms. OPENCL_SUCCEED_FATAL(clGetPlatformIDs(0, NULL, &num_platforms)); // Make room for them. 
all_platforms = calloc(num_platforms, sizeof(cl_platform_id)); platform_num_devices = calloc(num_platforms, sizeof(cl_uint)); // Fetch all the platforms. OPENCL_SUCCEED_FATAL(clGetPlatformIDs(num_platforms, all_platforms, NULL)); // Count the number of devices for each platform, as well as the // total number of devices. for (cl_uint i = 0; i < num_platforms; i++) { if (clGetDeviceIDs(all_platforms[i], CL_DEVICE_TYPE_ALL, 0, NULL, &platform_num_devices[i]) == CL_SUCCESS) { num_devices += platform_num_devices[i]; } else { platform_num_devices[i] = 0; } } // Make room for all the device options. struct opencl_device_option *devices = calloc(num_devices, sizeof(struct opencl_device_option)); // Loop through the platforms, getting information about their devices. for (cl_uint i = 0; i < num_platforms; i++) { cl_platform_id platform = all_platforms[i]; cl_uint num_platform_devices = platform_num_devices[i]; if (num_platform_devices == 0) { continue; } char *platform_name = opencl_platform_info(platform, CL_PLATFORM_NAME); cl_device_id *platform_devices = calloc(num_platform_devices, sizeof(cl_device_id)); // Fetch all the devices. OPENCL_SUCCEED_FATAL(clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_platform_devices, platform_devices, NULL)); // Loop through the devices, adding them to the devices array. for (cl_uint i = 0; i < num_platform_devices; i++) { char *device_name = opencl_device_info(platform_devices[i], CL_DEVICE_NAME); devices[num_devices_added].platform = platform; devices[num_devices_added].device = platform_devices[i]; OPENCL_SUCCEED_FATAL(clGetDeviceInfo(platform_devices[i], CL_DEVICE_TYPE, sizeof(cl_device_type), &devices[num_devices_added].device_type, NULL)); // We don't want the structs to share memory, so copy the platform name. // Each device name is already unique. 
devices[num_devices_added].platform_name = strclone(platform_name); devices[num_devices_added].device_name = device_name; num_devices_added++; } free(platform_devices); free(platform_name); } free(all_platforms); free(platform_num_devices); *devices_out = devices; *num_devices_out = num_devices; } void futhark_context_config_select_device_interactively(struct futhark_context_config *cfg) { struct opencl_device_option *devices; size_t num_devices; opencl_all_device_options(&devices, &num_devices); printf("Choose OpenCL device:\n"); const char *cur_platform = ""; for (size_t i = 0; i < num_devices; i++) { struct opencl_device_option device = devices[i]; if (strcmp(cur_platform, device.platform_name) != 0) { printf("Platform: %s\n", device.platform_name); cur_platform = device.platform_name; } printf("[%d] %s\n", (int)i, device.device_name); } int selection; printf("Choice: "); if (scanf("%d", &selection) == 1) { cfg->preferred_platform = ""; cfg->preferred_device = ""; cfg->preferred_device_num = selection; cfg->ignore_blacklist = 1; } // Free all the platform and device names. for (size_t j = 0; j < num_devices; j++) { free(devices[j].platform_name); free(devices[j].device_name); } free(devices); } void futhark_context_config_list_devices(struct futhark_context_config *cfg) { (void)cfg; struct opencl_device_option *devices; size_t num_devices; opencl_all_device_options(&devices, &num_devices); const char *cur_platform = ""; for (size_t i = 0; i < num_devices; i++) { struct opencl_device_option device = devices[i]; if (strcmp(cur_platform, device.platform_name) != 0) { printf("Platform: %s\n", device.platform_name); cur_platform = device.platform_name; } printf("[%d]: %s\n", (int)i, device.device_name); } // Free all the platform and device names. 
for (size_t j = 0; j < num_devices; j++) { free(devices[j].platform_name); free(devices[j].device_name); } free(devices); } const char* futhark_context_config_get_program(struct futhark_context_config *cfg) { return cfg->program; } void futhark_context_config_set_program(struct futhark_context_config *cfg, const char *s) { free(cfg->program); cfg->program = strdup(s); } void futhark_context_config_dump_binary_to(struct futhark_context_config *cfg, const char *path) { free(cfg->dump_binary_to); cfg->dump_binary_to = strdup(path); } void futhark_context_config_load_binary_from(struct futhark_context_config *cfg, const char *path) { free(cfg->load_binary_from); cfg->load_binary_from = strdup(path); } void futhark_context_config_set_unified_memory(struct futhark_context_config* cfg, int flag) { cfg->unified_memory = flag; } struct futhark_context { struct futhark_context_config* cfg; int detail_memory; int debugging; int profiling; int profiling_paused; int logging; lock_t lock; char *error; lock_t error_lock; FILE *log; struct constants *constants; struct free_list free_list; struct event_list event_list; int64_t peak_mem_usage_default; int64_t cur_mem_usage_default; struct program* program; bool program_initialised; // Uniform fields above. cl_mem global_failure; cl_mem global_failure_args; struct tuning_params tuning_params; // True if a potentially failing kernel has been enqueued. 
cl_int failure_is_an_option; int total_runs; long int total_runtime; int64_t peak_mem_usage_device; int64_t cur_mem_usage_device; cl_device_id device; cl_context ctx; cl_command_queue queue; cl_program clprogram; struct free_list gpu_free_list; size_t max_thread_block_size; size_t max_grid_size; size_t max_tile_size; size_t max_threshold; size_t max_shared_memory; size_t max_bespoke; size_t max_registers; size_t max_cache; size_t lockstep_width; struct builtin_kernels* kernels; }; static cl_build_status build_gpu_program(cl_program program, cl_device_id device, const char* options, char** log) { cl_int clBuildProgram_error = clBuildProgram(program, 1, &device, options, NULL, NULL); // Avoid termination due to CL_BUILD_PROGRAM_FAILURE if (clBuildProgram_error != CL_SUCCESS && clBuildProgram_error != CL_BUILD_PROGRAM_FAILURE) { OPENCL_SUCCEED_FATAL(clBuildProgram_error); } cl_build_status build_status; OPENCL_SUCCEED_FATAL(clGetProgramBuildInfo(program, device, CL_PROGRAM_BUILD_STATUS, sizeof(cl_build_status), &build_status, NULL)); if (build_status != CL_BUILD_SUCCESS) { char *build_log; size_t ret_val_size; OPENCL_SUCCEED_FATAL(clGetProgramBuildInfo(program, device, CL_PROGRAM_BUILD_LOG, 0, NULL, &ret_val_size)); build_log = (char*) malloc(ret_val_size+1); OPENCL_SUCCEED_FATAL(clGetProgramBuildInfo(program, device, CL_PROGRAM_BUILD_LOG, ret_val_size, build_log, NULL)); // The spec technically does not say whether the build log is // zero-terminated, so let's be careful. 
build_log[ret_val_size] = '\0'; *log = build_log; } return build_status; } static char* mk_compile_opts(struct futhark_context *ctx, const char *extra_build_opts[], struct opencl_device_option device_option) { int compile_opts_size = 1024; for (int i = 0; i < ctx->cfg->num_tuning_params; i++) { compile_opts_size += strlen(ctx->cfg->tuning_param_names[i]) + 20; } char** macro_names; int64_t* macro_vals; int num_macros = gpu_macros(ctx, ¯o_names, ¯o_vals); for (int i = 0; extra_build_opts[i] != NULL; i++) { compile_opts_size += strlen(extra_build_opts[i] + 1); } for (int i = 0; i < num_macros; i++) { compile_opts_size += strlen(macro_names[i]) + 1 + 20; } char *compile_opts = (char*) malloc(compile_opts_size); int w = snprintf(compile_opts, compile_opts_size, "-DLOCKSTEP_WIDTH=%d ", (int)ctx->lockstep_width); w += snprintf(compile_opts+w, compile_opts_size-w, "-D%s=%d ", "max_thread_block_size", (int)ctx->max_thread_block_size); w += snprintf(compile_opts+w, compile_opts_size-w, "-D%s=%d ", "max_shared_memory", (int)ctx->max_shared_memory); w += snprintf(compile_opts+w, compile_opts_size-w, "-D%s=%d ", "max_registers", (int)ctx->max_registers); for (int i = 0; i < ctx->cfg->num_tuning_params; i++) { w += snprintf(compile_opts+w, compile_opts_size-w, "-D%s=%d ", ctx->cfg->tuning_param_vars[i], (int)ctx->cfg->tuning_params[i]); } for (int i = 0; extra_build_opts[i] != NULL; i++) { w += snprintf(compile_opts+w, compile_opts_size-w, "%s ", extra_build_opts[i]); } for (int i = 0; i < num_macros; i++) { w += snprintf(compile_opts+w, compile_opts_size-w, "-D%s=%zu ", macro_names[i], macro_vals[i]); } w += snprintf(compile_opts+w, compile_opts_size-w, "-DTR_BLOCK_DIM=%d -DTR_TILE_DIM=%d -DTR_ELEMS_PER_THREAD=%d ", TR_BLOCK_DIM, TR_TILE_DIM, TR_ELEMS_PER_THREAD); // Oclgrind claims to support cl_khr_fp16, but this is not actually // the case. 
if (strcmp(device_option.platform_name, "Oclgrind") == 0) { w += snprintf(compile_opts+w, compile_opts_size-w, "-DEMULATE_F16 "); } // By default, OpenCL allows imprecise (but faster) division and // square root operations. For equivalence with other backends, ask // for correctly rounded ones here. w += snprintf(compile_opts+w, compile_opts_size-w, "-cl-fp32-correctly-rounded-divide-sqrt"); free(macro_names); free(macro_vals); return compile_opts; } static cl_event* opencl_event_new(struct futhark_context* ctx) { if (ctx->profiling && !ctx->profiling_paused) { return malloc(sizeof(cl_event)); } else { return NULL; } } static int opencl_event_report(struct str_builder* sb, cl_event* e) { cl_int err; cl_ulong start_t, end_t; assert(e != NULL); OPENCL_SUCCEED_FATAL(clGetEventProfilingInfo(*e, CL_PROFILING_COMMAND_START, sizeof(start_t), &start_t, NULL)); OPENCL_SUCCEED_FATAL(clGetEventProfilingInfo(*e, CL_PROFILING_COMMAND_END, sizeof(end_t), &end_t, NULL)); // OpenCL provides nanosecond resolution, but we want microseconds. str_builder(sb, ",\"duration\":%f", (end_t - start_t)/1000.0); OPENCL_SUCCEED_FATAL(clReleaseEvent(*e)); free(e); return 0; } int futhark_context_sync(struct futhark_context* ctx) { // Check for any delayed error. cl_int failure_idx = -1; if (ctx->failure_is_an_option) { OPENCL_SUCCEED_OR_RETURN( clEnqueueReadBuffer(ctx->queue, ctx->global_failure, CL_FALSE, 0, sizeof(cl_int), &failure_idx, 0, NULL, NULL)); ctx->failure_is_an_option = 0; } OPENCL_SUCCEED_OR_RETURN(clFinish(ctx->queue)); if (failure_idx >= 0) { // We have to clear global_failure so that the next entry point // is not considered a failure from the start. 
cl_int no_failure = -1; OPENCL_SUCCEED_OR_RETURN( clEnqueueWriteBuffer(ctx->queue, ctx->global_failure, CL_TRUE, 0, sizeof(cl_int), &no_failure, 0, NULL, NULL)); int64_t args[max_failure_args+1]; OPENCL_SUCCEED_OR_RETURN( clEnqueueReadBuffer(ctx->queue, ctx->global_failure_args, CL_TRUE, 0, sizeof(args), &args, 0, NULL, NULL)); ctx->error = get_failure_msg(failure_idx, args); return FUTHARK_PROGRAM_ERROR; } return 0; } // We take as input several strings representing the program, because // C does not guarantee that the compiler supports particularly large // literals. Notably, Visual C has a limit of 2048 characters. The // array must be NULL-terminated. static void setup_opencl_with_command_queue(struct futhark_context *ctx, cl_command_queue queue, const char* extra_build_opts[], const char* cache_fname) { int error; free_list_init(&ctx->gpu_free_list); ctx->queue = queue; OPENCL_SUCCEED_FATAL(clGetCommandQueueInfo(ctx->queue, CL_QUEUE_CONTEXT, sizeof(cl_context), &ctx->ctx, NULL)); // Fill out the device info. This is redundant work if we are // called from setup_opencl() (which is the common case), but I // doubt it matters much. 
struct opencl_device_option device_option; OPENCL_SUCCEED_FATAL(clGetCommandQueueInfo(ctx->queue, CL_QUEUE_DEVICE, sizeof(cl_device_id), &device_option.device, NULL)); OPENCL_SUCCEED_FATAL(clGetDeviceInfo(device_option.device, CL_DEVICE_PLATFORM, sizeof(cl_platform_id), &device_option.platform, NULL)); OPENCL_SUCCEED_FATAL(clGetDeviceInfo(device_option.device, CL_DEVICE_TYPE, sizeof(cl_device_type), &device_option.device_type, NULL)); device_option.platform_name = opencl_platform_info(device_option.platform, CL_PLATFORM_NAME); device_option.device_name = opencl_device_info(device_option.device, CL_DEVICE_NAME); ctx->device = device_option.device; if (f64_required) { cl_uint supported; OPENCL_SUCCEED_FATAL(clGetDeviceInfo(device_option.device, CL_DEVICE_PREFERRED_VECTOR_WIDTH_DOUBLE, sizeof(cl_uint), &supported, NULL)); if (!supported) { futhark_panic(1, "Program uses double-precision floats, but this is not supported on the chosen device: %s\n", device_option.device_name); } } bool is_amd = strstr(device_option.platform_name, "AMD") != NULL; bool is_nvidia = strstr(device_option.platform_name, "NVIDIA CUDA") != NULL; size_t max_thread_block_size; OPENCL_SUCCEED_FATAL(clGetDeviceInfo(device_option.device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &max_thread_block_size, NULL)); size_t max_tile_size = sqrt(max_thread_block_size); cl_ulong max_shared_memory; OPENCL_SUCCEED_FATAL(clGetDeviceInfo(device_option.device, CL_DEVICE_LOCAL_MEM_SIZE, sizeof(size_t), &max_shared_memory, NULL)); // Futhark reserves 4 bytes for bookkeeping information. max_shared_memory -= 4; // The OpenCL implementation may reserve some local memory bytes for // various purposes. In principle, we should use // clGetKernelWorkGroupInfo() to figure out for each kernel how much // is actually available, but our current code generator design // makes this infeasible. 
Instead, we have this nasty hack where we // arbitrarily subtract some bytes, based on empirical measurements // (but which might be arbitrarily wrong). Fortunately, we rarely // try to really push the local memory usage. if (is_nvidia) { max_shared_memory -= 12; } else if (is_amd) { max_shared_memory -= 16; } // Make sure this function is defined. post_opencl_setup(ctx, &device_option); if (max_thread_block_size < ctx->cfg->gpu.default_block_size) { if (ctx->cfg->gpu.default_block_size_changed) { fprintf(stderr, "Note: Device limits default group size to %zu (down from %zu).\n", max_thread_block_size, ctx->cfg->gpu.default_block_size); } ctx->cfg->gpu.default_block_size = max_thread_block_size; } if (max_tile_size < ctx->cfg->gpu.default_tile_size) { if (ctx->cfg->gpu.default_tile_size_changed) { fprintf(stderr, "Note: Device limits default tile size to %zu (down from %zu).\n", max_tile_size, ctx->cfg->gpu.default_tile_size); } ctx->cfg->gpu.default_tile_size = max_tile_size; } // Some of the code generated by Futhark will use the L2 cache size // to make very precise decisions about execution. OpenCL does not // specify whether CL_DEVICE_GLOBAL_MEM_CACHE_SIZE is L1 or L2 cache // (or maybe something else entirely). NVIDIA's implementation // reports L2, but AMDs reports L1 (and provides no way to query for // the L2 size). That means it is time to hack. cl_ulong l2_cache_size; if (ctx->cfg->gpu.default_cache != 0) { l2_cache_size = ctx->cfg->gpu.default_cache; } else { cl_ulong opencl_cache_size; OPENCL_SUCCEED_FATAL(clGetDeviceInfo(device_option.device, CL_DEVICE_GLOBAL_MEM_CACHE_SIZE, sizeof(opencl_cache_size), &opencl_cache_size, NULL)); if (is_amd) { // We multiply the L1 cache size with the number of compute units // times 4 (number of SIMD units with GCN). Empirically this // doesn't get us the right result, but it gets us fairly close. 
cl_ulong compute_units; OPENCL_SUCCEED_FATAL(clGetDeviceInfo(device_option.device, CL_DEVICE_MAX_COMPUTE_UNITS, sizeof(compute_units), &compute_units, NULL)); l2_cache_size = opencl_cache_size * compute_units * 4; } else { l2_cache_size = opencl_cache_size; } if (l2_cache_size == 0) { // Some code assumes nonzero cache. l2_cache_size = 1024*1024; } } ctx->max_thread_block_size = max_thread_block_size; ctx->max_tile_size = max_tile_size; // No limit. ctx->max_threshold = ctx->max_grid_size = 1U<<31; // No limit. if (ctx->cfg->gpu.default_cache != 0) { ctx->max_cache = ctx->cfg->gpu.default_cache; } else { ctx->max_cache = l2_cache_size; } if (ctx->cfg->gpu.default_registers != 0) { ctx->max_registers = ctx->cfg->gpu.default_registers; } else { ctx->max_registers = 1<<16; // I cannot find a way to query for this. } if (ctx->cfg->gpu.default_shared_memory != 0) { ctx->max_shared_memory = ctx->cfg->gpu.default_shared_memory; } else { ctx->max_shared_memory = max_shared_memory; } // Now we go through all the sizes, clamp them to the valid range, // or set them to the default. for (int i = 0; i < ctx->cfg->num_tuning_params; i++) { const char *size_class = ctx->cfg->tuning_param_classes[i]; int64_t *size_value = &ctx->cfg->tuning_params[i]; const char* size_name = ctx->cfg->tuning_param_names[i]; int64_t max_value = 0, default_value = 0; if (strstr(size_class, "thread_block_size") == size_class) { max_value = max_thread_block_size; default_value = ctx->cfg->gpu.default_block_size; } else if (strstr(size_class, "grid_size") == size_class) { max_value = max_thread_block_size; // Futhark assumes this constraint. default_value = ctx->cfg->gpu.default_grid_size; // XXX: as a quick and dirty hack, use twice as many threads for // histograms by default. We really should just be smarter // about sizes somehow. 
if (strstr(size_name, ".seghist_") != NULL) { default_value *= 2; } } else if (strstr(size_class, "tile_size") == size_class) { max_value = sqrt(max_thread_block_size); default_value = ctx->cfg->gpu.default_tile_size; } else if (strstr(size_class, "reg_tile_size") == size_class) { max_value = 0; // No limit. default_value = ctx->cfg->gpu.default_reg_tile_size; } else if (strstr(size_class, "shared_memory") == size_class) { max_value = ctx->max_shared_memory; default_value = ctx->max_shared_memory; } else if (strstr(size_class, "cache") == size_class) { max_value = ctx->max_cache; default_value = ctx->max_cache; } else if (strstr(size_class, "threshold") == size_class) { // Threshold can be as large as it takes. default_value = ctx->cfg->gpu.default_threshold; } else { // Bespoke sizes have no limit or default. } if (*size_value == 0) { *size_value = default_value; } else if (max_value > 0 && *size_value > max_value) { fprintf(stderr, "Note: Device limits %s to %d (down from %d)\n", size_name, (int)max_value, (int)*size_value); *size_value = max_value; } } if (ctx->lockstep_width == 0) { ctx->lockstep_width = 1; } gpu_init_log(ctx); char *compile_opts = mk_compile_opts(ctx, extra_build_opts, device_option); if (ctx->cfg->logging) { fprintf(stderr, "OpenCL compiler options: %s\n", compile_opts); } const char* opencl_src = ctx->cfg->program; cl_program prog; error = CL_SUCCESS; struct cache_hash h; int loaded_from_cache = 0; if (ctx->cfg->load_binary_from == NULL) { size_t src_size = 0; if (cache_fname != NULL) { if (ctx->cfg->logging) { fprintf(stderr, "Restoring cache from from %s...\n", cache_fname); } cache_hash_init(&h); cache_hash(&h, opencl_src, strlen(opencl_src)); cache_hash(&h, compile_opts, strlen(compile_opts)); unsigned char *buf; size_t bufsize; errno = 0; if (cache_restore(cache_fname, &h, &buf, &bufsize) != 0) { if (ctx->cfg->logging) { fprintf(stderr, "Failed to restore cache (errno: %s)\n", strerror(errno)); } } else { if (ctx->cfg->logging) { 
fprintf(stderr, "Cache restored; loading OpenCL binary...\n"); } cl_int status = 0; prog = clCreateProgramWithBinary(ctx->ctx, 1, &device_option.device, &bufsize, (const unsigned char**)&buf, &status, &error); if (status == CL_SUCCESS) { loaded_from_cache = 1; if (ctx->cfg->logging) { fprintf(stderr, "Loading succeeded.\n"); } } else { if (ctx->cfg->logging) { fprintf(stderr, "Loading failed.\n"); } } } } if (!loaded_from_cache) { if (ctx->cfg->logging) { fprintf(stderr, "Creating OpenCL program...\n"); } const char* src_ptr[] = {opencl_src}; prog = clCreateProgramWithSource(ctx->ctx, 1, src_ptr, &src_size, &error); OPENCL_SUCCEED_FATAL(error); } } else { if (ctx->cfg->logging) { fprintf(stderr, "Loading OpenCL binary from %s...\n", ctx->cfg->load_binary_from); } size_t binary_size; unsigned char *fut_opencl_bin = (unsigned char*) slurp_file(ctx->cfg->load_binary_from, &binary_size); assert(fut_opencl_bin != NULL); const unsigned char *binaries[1] = { fut_opencl_bin }; cl_int status = 0; prog = clCreateProgramWithBinary(ctx->ctx, 1, &device_option.device, &binary_size, binaries, &status, &error); OPENCL_SUCCEED_FATAL(status); OPENCL_SUCCEED_FATAL(error); } if (ctx->cfg->logging) { fprintf(stderr, "Building OpenCL program...\n"); } char* build_log; cl_build_status status = build_gpu_program(prog, device_option.device, compile_opts, &build_log); free(compile_opts); if (status != CL_BUILD_SUCCESS) { ctx->error = msgprintf("Compilation of OpenCL program failed.\nBuild log:\n%s", build_log); // We are giving up on initialising this OpenCL context. That also // means we need to free all the OpenCL bits we have managed to // allocate thus far, as futhark_context_free() will not touch // these unless initialisation was completely successful. 
(void)clReleaseProgram(prog); (void)clReleaseCommandQueue(ctx->queue); (void)clReleaseContext(ctx->ctx); free(build_log); return; } size_t binary_size = 0; unsigned char *binary = NULL; int store_in_cache = cache_fname != NULL && !loaded_from_cache; if (store_in_cache || ctx->cfg->dump_binary_to != NULL) { OPENCL_SUCCEED_FATAL(clGetProgramInfo(prog, CL_PROGRAM_BINARY_SIZES, sizeof(size_t), &binary_size, NULL)); binary = (unsigned char*) malloc(binary_size); OPENCL_SUCCEED_FATAL(clGetProgramInfo(prog, CL_PROGRAM_BINARIES, sizeof(unsigned char*), &binary, NULL)); } if (store_in_cache) { if (ctx->cfg->logging) { fprintf(stderr, "Caching OpenCL binary in %s...\n", cache_fname); } if (cache_store(cache_fname, &h, binary, binary_size) != 0) { printf("Failed to cache binary: %s\n", strerror(errno)); } } if (ctx->cfg->dump_binary_to != NULL) { if (ctx->cfg->logging) { fprintf(stderr, "Dumping OpenCL binary to %s...\n", ctx->cfg->dump_binary_to); } dump_file(ctx->cfg->dump_binary_to, binary, binary_size); } ctx->clprogram = prog; } static struct opencl_device_option get_preferred_device(struct futhark_context *ctx, const struct futhark_context_config *cfg) { struct opencl_device_option *devices; size_t num_devices; opencl_all_device_options(&devices, &num_devices); int num_device_matches = 0; for (size_t i = 0; i < num_devices; i++) { struct opencl_device_option device = devices[i]; if (strstr(device.platform_name, cfg->preferred_platform) != NULL && strstr(device.device_name, cfg->preferred_device) != NULL && (cfg->ignore_blacklist || !is_blacklisted(device.platform_name, device.device_name, cfg)) && num_device_matches++ == cfg->preferred_device_num) { // Free all the platform and device names, except the ones we have chosen. 
for (size_t j = 0; j < num_devices; j++) { if (j != i) { free(devices[j].platform_name); free(devices[j].device_name); } } free(devices); return device; } } ctx->error = strdup("Could not find acceptable OpenCL device.\n"); struct opencl_device_option device; return device; } static void setup_opencl(struct futhark_context *ctx, const char *extra_build_opts[], const char* cache_fname) { struct opencl_device_option device_option = get_preferred_device(ctx, ctx->cfg); if (ctx->error != NULL) { return; } if (ctx->cfg->logging) { fprintf(stderr, "Using platform: %s\n", device_option.platform_name); fprintf(stderr, "Using device: %s\n", device_option.device_name); } // Note that NVIDIA's OpenCL requires the platform property cl_context_properties properties[] = { CL_CONTEXT_PLATFORM, (cl_context_properties)device_option.platform, 0 }; cl_int clCreateContext_error; ctx->ctx = clCreateContext(properties, 1, &device_option.device, NULL, NULL, &clCreateContext_error); OPENCL_SUCCEED_FATAL(clCreateContext_error); cl_int clCreateCommandQueue_error; cl_command_queue queue = clCreateCommandQueue(ctx->ctx, device_option.device, ctx->cfg->profiling ? CL_QUEUE_PROFILING_ENABLE : 0, &clCreateCommandQueue_error); OPENCL_SUCCEED_FATAL(clCreateCommandQueue_error); setup_opencl_with_command_queue(ctx, queue, extra_build_opts, cache_fname); } struct builtin_kernels* init_builtin_kernels(struct futhark_context* ctx); void free_builtin_kernels(struct futhark_context* ctx, struct builtin_kernels* kernels); int backend_context_setup(struct futhark_context* ctx) { ctx->lockstep_width = 0; // Real value set later. 
ctx->failure_is_an_option = 0; ctx->total_runs = 0; ctx->total_runtime = 0; ctx->peak_mem_usage_device = 0; ctx->cur_mem_usage_device = 0; ctx->kernels = NULL; if (ctx->cfg->queue_set) { setup_opencl_with_command_queue(ctx, ctx->cfg->queue, (const char**)ctx->cfg->build_opts, ctx->cfg->cache_fname); } else { setup_opencl(ctx, (const char**)ctx->cfg->build_opts, ctx->cfg->cache_fname); } if (ctx->error != NULL) { return 1; } cl_int error; cl_int no_error = -1; ctx->global_failure = clCreateBuffer(ctx->ctx, CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR, sizeof(cl_int), &no_error, &error); OPENCL_SUCCEED_OR_RETURN(error); // The +1 is to avoid zero-byte allocations. ctx->global_failure_args = clCreateBuffer(ctx->ctx, CL_MEM_READ_WRITE, sizeof(int64_t)*(max_failure_args+1), NULL, &error); OPENCL_SUCCEED_OR_RETURN(error); if ((ctx->kernels = init_builtin_kernels(ctx)) == NULL) { return 1; } return FUTHARK_SUCCESS; } static int gpu_free_all(struct futhark_context *ctx); void backend_context_teardown(struct futhark_context* ctx) { if (ctx->kernels != NULL) { free_builtin_kernels(ctx, ctx->kernels); OPENCL_SUCCEED_FATAL(clReleaseMemObject(ctx->global_failure)); OPENCL_SUCCEED_FATAL(clReleaseMemObject(ctx->global_failure_args)); (void)gpu_free_all(ctx); (void)clReleaseProgram(ctx->clprogram); (void)clReleaseCommandQueue(ctx->queue); (void)clReleaseContext(ctx->ctx); } free_list_destroy(&ctx->gpu_free_list); } cl_command_queue futhark_context_get_command_queue(struct futhark_context* ctx) { return ctx->queue; } // GPU ABSTRACTION LAYER // Types. 
typedef cl_kernel gpu_kernel; typedef cl_mem gpu_mem; static void gpu_create_kernel(struct futhark_context *ctx, gpu_kernel* kernel, const char* name) { if (ctx->debugging) { fprintf(ctx->log, "Creating kernel %s.\n", name); } cl_int error; *kernel = clCreateKernel(ctx->clprogram, name, &error); OPENCL_SUCCEED_FATAL(error); } static void gpu_free_kernel(struct futhark_context *ctx, gpu_kernel kernel) { (void)ctx; clReleaseKernel(kernel); } static int gpu_scalar_to_device(struct futhark_context* ctx, gpu_mem dst, size_t offset, size_t size, void *src) { cl_event* event = opencl_event_new(ctx); if (event != NULL) { add_event(ctx, "copy_scalar_to_dev", strdup(""), event, (event_report_fn)opencl_event_report); } OPENCL_SUCCEED_OR_RETURN (clEnqueueWriteBuffer (ctx->queue, dst, CL_TRUE, offset, size, src, 0, NULL, event)); return 0; } static int gpu_scalar_from_device(struct futhark_context* ctx, void *dst, gpu_mem src, size_t offset, size_t size) { cl_event* event = opencl_event_new(ctx); if (event != NULL) { add_event(ctx, "copy_scalar_from_dev", strdup(""), event, (event_report_fn)opencl_event_report); } OPENCL_SUCCEED_OR_RETURN (clEnqueueReadBuffer (ctx->queue, src, ctx->failure_is_an_option ? CL_FALSE : CL_TRUE, offset, size, dst, 0, NULL, event)); return 0; } static int gpu_memcpy(struct futhark_context* ctx, gpu_mem dst, int64_t dst_offset, gpu_mem src, int64_t src_offset, int64_t nbytes) { if (nbytes > 0) { cl_event* event = opencl_event_new(ctx); if (event != NULL) { add_event(ctx, "copy_dev_to_dev", strdup(""), event, (event_report_fn)opencl_event_report); } // OpenCL swaps the usual order of operands for memcpy()-like // functions. The order below is not a typo. 
OPENCL_SUCCEED_OR_RETURN (clEnqueueCopyBuffer (ctx->queue, src, dst, src_offset, dst_offset, nbytes, 0, NULL, event)); if (ctx->debugging) { OPENCL_SUCCEED_FATAL(clFinish(ctx->queue)); } } return FUTHARK_SUCCESS; } static int memcpy_host2gpu(struct futhark_context* ctx, bool sync, gpu_mem dst, int64_t dst_offset, const unsigned char* src, int64_t src_offset, int64_t nbytes) { if (nbytes > 0) { cl_event* event = opencl_event_new(ctx); if (event != NULL) { add_event(ctx, "copy_host_to_dev", strdup(""), event, (event_report_fn)opencl_event_report); } OPENCL_SUCCEED_OR_RETURN (clEnqueueWriteBuffer(ctx->queue, dst, sync ? CL_TRUE : CL_FALSE, (size_t)dst_offset, (size_t)nbytes, src + src_offset, 0, NULL, event)); if (ctx->debugging) { OPENCL_SUCCEED_FATAL(clFinish(ctx->queue)); } } return FUTHARK_SUCCESS; } static int memcpy_gpu2host(struct futhark_context* ctx, bool sync, unsigned char* dst, int64_t dst_offset, gpu_mem src, int64_t src_offset, int64_t nbytes) { if (nbytes > 0) { cl_event* event = opencl_event_new(ctx); if (event != NULL) { add_event(ctx, "copy_dev_to_host", strdup(""), event, (event_report_fn)opencl_event_report); } OPENCL_SUCCEED_OR_RETURN (clEnqueueReadBuffer(ctx->queue, src, ctx->failure_is_an_option ? CL_FALSE : sync ? 
CL_TRUE : CL_FALSE, src_offset, nbytes, dst + dst_offset, 0, NULL, event)); if (sync && ctx->failure_is_an_option && futhark_context_sync(ctx) != 0) { return 1; } } return FUTHARK_SUCCESS; } static int gpu_launch_kernel(struct futhark_context* ctx, gpu_kernel kernel, const char *name, const int32_t grid[3], const int32_t block[3], unsigned int shared_mem_bytes, int num_args, void* args[num_args], size_t args_sizes[num_args]) { if (shared_mem_bytes > ctx->max_shared_memory) { set_error(ctx, msgprintf("Kernel %s with %d bytes of memory exceeds device limit of %d\n", name, shared_mem_bytes, (int)ctx->max_shared_memory)); return 1; } int64_t time_start = 0, time_end = 0; cl_event* event = opencl_event_new(ctx); if (event != NULL) { add_event(ctx, name, msgprintf("Kernel %s with\n" " grid=(%d,%d,%d)\n" " block=(%d,%d,%d)\n" " shared memory=%d", name, grid[0], grid[1], grid[2], block[0], block[1], block[2], shared_mem_bytes), event, (event_report_fn)opencl_event_report); } if (ctx->debugging) { time_start = get_wall_time(); } // Some implementations do not work with 0-byte shared memory. if (shared_mem_bytes == 0) { shared_mem_bytes = 4; } OPENCL_SUCCEED_OR_RETURN (clSetKernelArg(kernel, 0, shared_mem_bytes, NULL)); for (int i = 0; i < num_args; i++) { OPENCL_SUCCEED_OR_RETURN (clSetKernelArg(kernel, i+1, args_sizes[i], args[i])); } const size_t global_work_size[3] = {(size_t)grid[0]*block[0], (size_t)grid[1]*block[1], (size_t)grid[2]*block[2]}; const size_t local_work_size[3] = {block[0], block[1], block[2]}; OPENCL_SUCCEED_OR_RETURN (clEnqueueNDRangeKernel(ctx->queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, event)); if (ctx->debugging) { OPENCL_SUCCEED_FATAL(clFinish(ctx->queue)); time_end = get_wall_time(); long int time_diff = time_end - time_start; fprintf(ctx->log, " runtime: %ldus\n", time_diff); } if (ctx->logging) { fprintf(ctx->log, "\n"); } return FUTHARK_SUCCESS; } // Allocate memory from driver. 
The problem is that OpenCL may perform // lazy allocation, so we cannot know whether an allocation succeeded // until the first time we try to use it. Hence we immediately // perform a write to see if the allocation succeeded. This is slow, // but the assumption is that this operation will be rare (most things // will go through the free list). static int gpu_alloc_actual(struct futhark_context *ctx, size_t size, gpu_mem *mem_out) { int error; *mem_out = clCreateBuffer(ctx->ctx, CL_MEM_READ_WRITE, size, NULL, &error); OPENCL_SUCCEED_OR_RETURN(error); int x = 2; error = clEnqueueWriteBuffer(ctx->queue, *mem_out, CL_TRUE, 0, sizeof(x), &x, 0, NULL, NULL); // No need to wait for completion here. clWaitForEvents() cannot // return mem object allocation failures. This implies that the // buffer is faulted onto the device on enqueue. (Observation by // Andreas Kloeckner.) if (error == CL_MEM_OBJECT_ALLOCATION_FAILURE) { return FUTHARK_OUT_OF_MEMORY; } OPENCL_SUCCEED_OR_RETURN(error); return FUTHARK_SUCCESS; } static int gpu_free_actual(struct futhark_context *ctx, gpu_mem mem) { (void)ctx; OPENCL_SUCCEED_OR_RETURN(clReleaseMemObject(mem)); return FUTHARK_SUCCESS; } // End of backends/opencl.h futhark-0.25.27/rts/c/cache.h000066400000000000000000000073121475065116200156130ustar00rootroot00000000000000// Start of cache.h #define CACHE_HASH_SIZE 8 // In 32-bit words. struct cache_hash { uint32_t hash[CACHE_HASH_SIZE]; }; // Initialise a blank cache. static void cache_hash_init(struct cache_hash *c); // Hash some bytes and add them to the accumulated hash. static void cache_hash(struct cache_hash *out, const char *in, size_t n); // Try to restore cache contents from a file with the given name. // Assumes the cache is invalid if it contains the given hash. // Allocates memory and reads the cache conents, which is returned in // *buf with size *buflen. If the cache is successfully loaded, this // function returns 0. Otherwise it returns nonzero. 
Errno is set if // the failure to load the cache is due to anything except invalid // cache conents. Note that failing to restore the cache is not // necessarily a problem: it might just be invalid or not created yet. static int cache_restore(const char *fname, const struct cache_hash *hash, unsigned char **buf, size_t *buflen); // Store cache contents in the given file, with the given hash. static int cache_store(const char *fname, const struct cache_hash *hash, const unsigned char *buf, size_t buflen); // Now for the implementation. static void cache_hash_init(struct cache_hash *c) { memset(c->hash, 0, CACHE_HASH_SIZE * sizeof(uint32_t)); } static void cache_hash(struct cache_hash *out, const char *in, size_t n) { // Adaptation of djb2 for larger output size by storing intermediate // states. uint32_t hash = 5381; for (size_t i = 0; i < n; i++) { hash = ((hash << 5) + hash) + in[i]; out->hash[i % CACHE_HASH_SIZE] ^= hash; } } #define CACHE_HEADER_SIZE 8 static const char cache_header[CACHE_HEADER_SIZE] = "FUTHARK\0"; static int cache_restore(const char *fname, const struct cache_hash *hash, unsigned char **buf, size_t *buflen) { FILE *f = fopen(fname, "rb"); if (f == NULL) { return 1; } char f_header[CACHE_HEADER_SIZE]; if (fread(f_header, sizeof(char), CACHE_HEADER_SIZE, f) != CACHE_HEADER_SIZE) { goto error; } if (memcmp(f_header, cache_header, CACHE_HEADER_SIZE) != 0) { goto error; } if (fseek(f, 0, SEEK_END) != 0) { goto error; } int64_t f_size = (int64_t)ftell(f); if (fseek(f, CACHE_HEADER_SIZE, SEEK_SET) != 0) { goto error; } int64_t expected_size; if (fread(&expected_size, sizeof(int64_t), 1, f) != 1) { goto error; } if (f_size != expected_size) { errno = 0; goto error; } int32_t f_hash[CACHE_HASH_SIZE]; if (fread(f_hash, sizeof(int32_t), CACHE_HASH_SIZE, f) != CACHE_HASH_SIZE) { goto error; } if (memcmp(f_hash, hash->hash, CACHE_HASH_SIZE) != 0) { errno = 0; goto error; } *buflen = f_size - CACHE_HEADER_SIZE - sizeof(int64_t) - 
CACHE_HASH_SIZE*sizeof(int32_t); *buf = malloc(*buflen); if (fread(*buf, sizeof(char), *buflen, f) != *buflen) { free(*buf); goto error; } fclose(f); return 0; error: fclose(f); return 1; } static int cache_store(const char *fname, const struct cache_hash *hash, const unsigned char *buf, size_t buflen) { FILE *f = fopen(fname, "wb"); if (f == NULL) { return 1; } if (fwrite(cache_header, CACHE_HEADER_SIZE, 1, f) != 1) { goto error; } int64_t size = CACHE_HEADER_SIZE + sizeof(int64_t) + CACHE_HASH_SIZE*sizeof(int32_t) + buflen; if (fwrite(&size, sizeof(size), 1, f) != 1) { goto error; } if (fwrite(hash->hash, sizeof(int32_t), CACHE_HASH_SIZE, f) != CACHE_HASH_SIZE) { goto error; } if (fwrite(buf, sizeof(unsigned char), buflen, f) != buflen) { goto error; } fclose(f); return 0; error: fclose(f); return 1; } // End of cache.h futhark-0.25.27/rts/c/context.h000066400000000000000000000136511475065116200162370ustar00rootroot00000000000000// Start of context.h // Internal functions. static void set_error(struct futhark_context* ctx, char *error) { lock_lock(&ctx->error_lock); if (ctx->error == NULL) { ctx->error = error; } else { free(error); } lock_unlock(&ctx->error_lock); } // XXX: should be static, but used in ispc_util.h void lexical_realloc_error(struct futhark_context* ctx, size_t new_size) { set_error(ctx, msgprintf("Failed to allocate memory.\nAttempted allocation: %12lld bytes\n", (long long) new_size)); } static int lexical_realloc(struct futhark_context *ctx, unsigned char **ptr, int64_t *old_size, int64_t new_size) { unsigned char *new = realloc(*ptr, (size_t)new_size); if (new == NULL) { lexical_realloc_error(ctx, new_size); return FUTHARK_OUT_OF_MEMORY; } else { *ptr = new; *old_size = new_size; return FUTHARK_SUCCESS; } } static void free_all_in_free_list(struct futhark_context* ctx) { fl_mem mem; free_list_pack(&ctx->free_list); while (free_list_first(&ctx->free_list, (fl_mem*)&mem) == 0) { free((void*)mem); } } static int is_small_alloc(size_t size) { 
return size < 1024*1024; } static void host_alloc(struct futhark_context* ctx, size_t size, const char* tag, size_t* size_out, void** mem_out) { if (is_small_alloc(size) || free_list_find(&ctx->free_list, size, tag, size_out, (fl_mem*)mem_out) != 0) { *size_out = size; *mem_out = malloc(size); } } static void host_free(struct futhark_context* ctx, size_t size, const char* tag, void* mem) { // Small allocations are handled by malloc()s own free list. The // threshold here is kind of arbitrary, but seems to work OK. // Larger allocations are mmap()ed/munmapped() every time, which is // very slow, and Futhark programs tend to use a few very large // allocations. if (is_small_alloc(size)) { free(mem); } else { free_list_insert(&ctx->free_list, size, (fl_mem)mem, tag); } } static void add_event(struct futhark_context* ctx, const char* name, char* description, void* data, event_report_fn f) { if (ctx->logging) { fprintf(ctx->log, "Event: %s\n%s\n", name, description); } add_event_to_list(&ctx->event_list, name, description, data, f); } char *futhark_context_get_error(struct futhark_context *ctx) { char *error = ctx->error; ctx->error = NULL; return error; } void futhark_context_config_set_debugging(struct futhark_context_config *cfg, int flag) { cfg->profiling = cfg->logging = cfg->debugging = flag; } void futhark_context_config_set_profiling(struct futhark_context_config *cfg, int flag) { cfg->profiling = flag; } void futhark_context_config_set_logging(struct futhark_context_config *cfg, int flag) { cfg->logging = flag; } void futhark_context_config_set_cache_file(struct futhark_context_config *cfg, const char *f) { cfg->cache_fname = strdup(f); } int futhark_get_tuning_param_count(void) { return num_tuning_params; } const char *futhark_get_tuning_param_name(int i) { return tuning_param_names[i]; } const char *futhark_get_tuning_param_class(int i) { return tuning_param_classes[i]; } void futhark_context_set_logging_file(struct futhark_context *ctx, FILE *f){ ctx->log = 
f; } void futhark_context_pause_profiling(struct futhark_context *ctx) { ctx->profiling_paused = 1; } void futhark_context_unpause_profiling(struct futhark_context *ctx) { ctx->profiling_paused = 0; } struct futhark_context_config* futhark_context_config_new(void) { struct futhark_context_config* cfg = malloc(sizeof(struct futhark_context_config)); if (cfg == NULL) { return NULL; } cfg->in_use = 0; cfg->debugging = 0; cfg->profiling = 0; cfg->logging = 0; cfg->cache_fname = NULL; cfg->num_tuning_params = num_tuning_params; cfg->tuning_params = malloc(cfg->num_tuning_params * sizeof(int64_t)); memcpy(cfg->tuning_params, tuning_param_defaults, cfg->num_tuning_params * sizeof(int64_t)); cfg->tuning_param_names = tuning_param_names; cfg->tuning_param_vars = tuning_param_vars; cfg->tuning_param_classes = tuning_param_classes; backend_context_config_setup(cfg); return cfg; } void futhark_context_config_free(struct futhark_context_config* cfg) { assert(!cfg->in_use); backend_context_config_teardown(cfg); free(cfg->cache_fname); free(cfg->tuning_params); free(cfg); } struct futhark_context* futhark_context_new(struct futhark_context_config* cfg) { struct futhark_context* ctx = malloc(sizeof(struct futhark_context)); if (ctx == NULL) { return NULL; } assert(!cfg->in_use); ctx->cfg = cfg; ctx->cfg->in_use = 1; ctx->program_initialised = false; create_lock(&ctx->error_lock); create_lock(&ctx->lock); free_list_init(&ctx->free_list); event_list_init(&ctx->event_list); ctx->peak_mem_usage_default = 0; ctx->cur_mem_usage_default = 0; ctx->constants = malloc(sizeof(struct constants)); ctx->debugging = cfg->debugging; ctx->logging = cfg->logging; ctx->detail_memory = cfg->logging; ctx->profiling = cfg->profiling; ctx->profiling_paused = 0; ctx->error = NULL; ctx->log = stderr; set_tuning_params(ctx); if (backend_context_setup(ctx) == 0) { setup_program(ctx); init_constants(ctx); ctx->program_initialised = true; (void)futhark_context_clear_caches(ctx); 
(void)futhark_context_sync(ctx); } return ctx; } void futhark_context_free(struct futhark_context* ctx) { if (ctx->program_initialised) { free_constants(ctx); teardown_program(ctx); } backend_context_teardown(ctx); free_all_in_free_list(ctx); free_list_destroy(&ctx->free_list); event_list_free(&ctx->event_list); free(ctx->constants); free(ctx->error); free_lock(&ctx->lock); free_lock(&ctx->error_lock); ctx->cfg->in_use = 0; free(ctx); } // End of context.h futhark-0.25.27/rts/c/context_prototypes.h000066400000000000000000000046011475065116200205420ustar00rootroot00000000000000// Start of context_prototypes.h // // Prototypes for the functions in context.h, or that will be called // from those functions, that need to be available very early. struct futhark_context_config; struct futhark_context; static void set_error(struct futhark_context* ctx, char *error); // These are called in context/config new/free functions and contain // shared setup. They are generated by the compiler itself. static int init_constants(struct futhark_context*); static int free_constants(struct futhark_context*); static void setup_program(struct futhark_context* ctx); static void teardown_program(struct futhark_context *ctx); // Allocate host memory. Must be freed with host_free(). static void host_alloc(struct futhark_context* ctx, size_t size, const char* tag, size_t* size_out, void** mem_out); // Allocate memory allocated with host_alloc(). static void host_free(struct futhark_context* ctx, size_t size, const char* tag, void* mem); // Log that a copy has occurred. 
static void log_copy(struct futhark_context* ctx, const char *kind, int r, int64_t dst_offset, int64_t dst_strides[r], int64_t src_offset, int64_t src_strides[r], int64_t shape[r]); static void log_transpose(struct futhark_context* ctx, int64_t k, int64_t m, int64_t n); static bool lmad_map_tr(int64_t *num_arrays_out, int64_t *n_out, int64_t *m_out, int r, const int64_t dst_strides[r], const int64_t src_strides[r], const int64_t shape[r]); static bool lmad_contiguous(int r, int64_t strides[r], int64_t shape[r]); static bool lmad_memcpyable(int r, int64_t dst_strides[r], int64_t src_strides[r], int64_t shape[r]); static void add_event(struct futhark_context* ctx, const char* name, char* description, void* data, event_report_fn f); // Functions that must be defined by the backend. static void backend_context_config_setup(struct futhark_context_config* cfg); static void backend_context_config_teardown(struct futhark_context_config* cfg); static int backend_context_setup(struct futhark_context *ctx); static void backend_context_teardown(struct futhark_context *ctx); // End of of context_prototypes.h futhark-0.25.27/rts/c/copy.h000066400000000000000000000265451475065116200155330ustar00rootroot00000000000000// Start of copy.h // Cache-oblivious map-transpose function. 
// Generates map_transpose_NAME: for each of the k arrays, dst[j*n + i] =
// src[i*m + j] over rows [rb,re) and columns [cb,ce), recursively halving
// the larger extent until a tile is at most 64x64.  Extents are int64_t so
// that arrays with more than 2^31 rows or columns are not truncated.
#define GEN_MAP_TRANSPOSE(NAME, ELEM_TYPE)                              \
  static void map_transpose_##NAME                                      \
  (ELEM_TYPE* dst, ELEM_TYPE* src,                                      \
   int64_t k, int64_t m, int64_t n,                                     \
   int64_t cb, int64_t ce, int64_t rb, int64_t re)                      \
  {                                                                     \
    int64_t r = re - rb;                                                \
    int64_t c = ce - cb;                                                \
    if (k == 1) {                                                       \
      if (r <= 64 && c <= 64) {                                         \
        for (int64_t j = 0; j < c; j++) {                               \
          for (int64_t i = 0; i < r; i++) {                             \
            dst[(j + cb) * n + (i + rb)] = src[(i + rb) * m + (j + cb)]; \
          }                                                             \
        }                                                               \
      } else if (c <= r) {                                              \
        map_transpose_##NAME(dst, src, k, m, n, cb, ce, rb, rb + r/2);  \
        map_transpose_##NAME(dst, src, k, m, n, cb, ce, rb + r/2, re);  \
      } else {                                                          \
        map_transpose_##NAME(dst, src, k, m, n, cb, cb + c/2, rb, re);  \
        map_transpose_##NAME(dst, src, k, m, n, cb + c/2, ce, rb, re);  \
      }                                                                 \
    } else {                                                            \
      for (int64_t i = 0; i < k; i++) {                                 \
        map_transpose_##NAME(dst + i * m * n, src + i * m * n, 1, m, n, cb, ce, rb, re); \
      }                                                                 \
    }                                                                   \
  }

// Straightforward LMAD copy function.  Generates lmad_copy_elements_NAME,
// which copies element by element, recursing over the outermost dimension.
// The loop counters are int64_t to match the int64_t shape.
#define GEN_LMAD_COPY_ELEMENTS(NAME, ELEM_TYPE)                         \
  static void lmad_copy_elements_##NAME(int r,                          \
                                        ELEM_TYPE* dst, int64_t dst_strides[r], \
                                        ELEM_TYPE *src, int64_t src_strides[r], \
                                        int64_t shape[r]) {             \
    if (r == 1) {                                                       \
      for (int64_t i = 0; i < shape[0]; i++) {                          \
        dst[i*dst_strides[0]] = src[i*src_strides[0]];                  \
      }                                                                 \
    } else if (r > 1) {                                                 \
      for (int64_t i = 0; i < shape[0]; i++) {                          \
        lmad_copy_elements_##NAME(r-1,                                  \
                                  dst+i*dst_strides[0], dst_strides+1,  \
                                  src+i*src_strides[0], src_strides+1,  \
                                  shape+1);                             \
      }                                                                 \
    }                                                                   \
  }

// Check whether this LMAD can be seen as a transposed 2D array.  This
// is done by checking every possible splitting point.  On success the
// collapsed extents are written to *n_out (dimensions before the split)
// and *m_out (dimensions from the split onwards).  The products are
// accumulated in int64_t so that they cannot overflow int.
static bool lmad_is_tr(int64_t *n_out, int64_t *m_out,
                       int r,
                       const int64_t strides[r],
                       const int64_t shape[r]) {
  for (int i = 1; i < r; i++) {
    int64_t n = 1, m = 1;
    bool ok = true;
    int64_t expected = 1;
    // Check strides before 'i'.
    for (int j = i-1; j >= 0; j--) {
      ok = ok && strides[j] == expected;
      expected *= shape[j];
      n *= shape[j];
    }
    // Check strides after 'i'.
    for (int j = r-1; j >= i; j--) {
      ok = ok && strides[j] == expected;
      expected *= shape[j];
      m *= shape[j];
    }
    if (ok) {
      *n_out = n;
      *m_out = m;
      return true;
    }
  }
  return false;
}

// This function determines whether the 'dst' LMAD is row-major and
// 'src' LMAD is column-major.  Both LMADs are for arrays of the same
// shape.  Both LMADs are allowed to have additional dimensions "on
// top".  Essentially, this function determines whether a copy from
// 'src' to 'dst' is a "map(transpose)" that we know how to implement
// efficiently.  The LMADs can have arbitrary rank, and the main
// challenge here is checking whether the src LMAD actually
// corresponds to a 2D column-major layout by morally collapsing
// dimensions.  There is a lot of looping here, but the actual trip
// count is going to be very low in practice.
//
// Returns true if this is indeed a map(transpose), and writes the
// number of arrays, and moral array size to appropriate output
// parameters.  A rank-0 LMAD is never a transpose; it is rejected before
// any zero-length array is declared.
static bool lmad_map_tr(int64_t *num_arrays_out, int64_t *n_out, int64_t *m_out,
                        int r,
                        const int64_t dst_strides[r],
                        const int64_t src_strides[r],
                        const int64_t shape[r]) {
  if (r < 1) {
    *num_arrays_out = 1;
    return false;
  }

  int64_t rowmajor_strides[r];
  rowmajor_strides[r-1] = 1;

  for (int i = r-2; i >= 0; i--) {
    rowmajor_strides[i] = rowmajor_strides[i+1] * shape[i+1];
  }

  // map_r will be the number of mapped dimensions on top.
  int map_r = 0;
  int64_t num_arrays = 1;
  for (int i = 0; i < r; i++) {
    if (dst_strides[i] != rowmajor_strides[i] ||
        src_strides[i] != rowmajor_strides[i]) {
      break;
    } else {
      num_arrays *= shape[i];
      map_r++;
    }
  }

  *num_arrays_out = num_arrays;

  if (r == map_r) {
    return false;
  }

  if (memcmp(&rowmajor_strides[map_r],
             &dst_strides[map_r],
             sizeof(int64_t)*(r-map_r)) == 0) {
    return lmad_is_tr(n_out, m_out, r-map_r, src_strides+map_r, shape+map_r);
  } else if (memcmp(&rowmajor_strides[map_r],
                    &src_strides[map_r],
                    sizeof(int64_t)*(r-map_r)) == 0) {
    return lmad_is_tr(m_out, n_out, r-map_r, dst_strides+map_r, shape+map_r);
  }
  return false;
}

// Check if the strides correspond to row-major strides of *any*
// permutation of the shape.  This is done by recursive search with
// backtracking.  This is worst-case exponential, but hopefully the
// arrays we encounter do not have that many dimensions.  (Each level
// tries every unused dimension once; 'used' is restored on backtracking.)
static bool lmad_contiguous_search(int checked, int64_t expected,
                                   int r,
                                   int64_t strides[r], int64_t shape[r], bool used[r]) {
  for (int j = 0; j < r; j++) {
    if (!used[j] && strides[j] == expected && strides[j] >= 0) {
      used[j] = true;
      if (checked+1 == r ||
          lmad_contiguous_search(checked+1, expected * shape[j], r, strides, shape, used)) {
        return true;
      }
      used[j] = false;
    }
  }
  return false;
}

// Does this LMAD correspond to an array with positive strides and no
// holes?  Rank 0 yields false (as the search would), without declaring a
// zero-length array.
static bool lmad_contiguous(int r, int64_t strides[r], int64_t shape[r]) {
  if (r < 1) {
    return false;
  }
  bool used[r];
  for (int i = 0; i < r; i++) {
    used[i] = false;
  }
  return lmad_contiguous_search(0, 1, r, strides, shape, used);
}

// Does this copy correspond to something that could be done with a
// memcpy()-like operation?  I.e. do the LMADs actually represent the
// same in-memory layout and are they contiguous?
static bool lmad_memcpyable(int r, int64_t dst_strides[r], int64_t src_strides[r], int64_t shape[r]) { if (!lmad_contiguous(r, dst_strides, shape)) { return false; } for (int i = 0; i < r; i++) { if (dst_strides[i] != src_strides[i] && shape[i] != 1) { return false; } } return true; } static void log_copy(struct futhark_context* ctx, const char *kind, int r, int64_t dst_offset, int64_t dst_strides[r], int64_t src_offset, int64_t src_strides[r], int64_t shape[r]) { if (ctx->logging) { fprintf(ctx->log, "\n# Copy %s\n", kind); fprintf(ctx->log, "Shape: "); for (int i = 0; i < r; i++) { fprintf(ctx->log, "[%ld]", (long int)shape[i]); } fprintf(ctx->log, "\n"); fprintf(ctx->log, "Dst offset: %ld\n", (long int)dst_offset); fprintf(ctx->log, "Dst strides:"); for (int i = 0; i < r; i++) { fprintf(ctx->log, " %ld", (long int)dst_strides[i]); } fprintf(ctx->log, "\n"); fprintf(ctx->log, "Src offset: %ld\n", (long int)src_offset); fprintf(ctx->log, "Src strides:"); for (int i = 0; i < r; i++) { fprintf(ctx->log, " %ld", (long int)src_strides[i]); } fprintf(ctx->log, "\n"); } } static void log_transpose(struct futhark_context* ctx, int64_t k, int64_t n, int64_t m) { if (ctx->logging) { fprintf(ctx->log, "## Transpose\n"); fprintf(ctx->log, "Arrays : %ld\n", (long int)k); fprintf(ctx->log, "X elements : %ld\n", (long int)m); fprintf(ctx->log, "Y elements : %ld\n", (long int)n); fprintf(ctx->log, "\n"); } } #define GEN_LMAD_COPY(NAME, ELEM_TYPE) \ static void lmad_copy_##NAME \ (struct futhark_context *ctx, int r, \ ELEM_TYPE* dst, int64_t dst_offset, int64_t dst_strides[r], \ ELEM_TYPE *src, int64_t src_offset, int64_t src_strides[r], \ int64_t shape[r]) { \ log_copy(ctx, "CPU to CPU", r, dst_offset, dst_strides, \ src_offset, src_strides, shape); \ int64_t size = 1; \ for (int i = 0; i < r; i++) { size *= shape[i]; } \ if (size == 0) { return; } \ int64_t k, n, m; \ if (lmad_map_tr(&k, &n, &m, \ r, dst_strides, src_strides, shape)) { \ log_transpose(ctx, k, n, m); \ 
map_transpose_##NAME \ (dst+dst_offset, src+src_offset, k, n, m, 0, n, 0, m); \ } else if (lmad_memcpyable(r, dst_strides, src_strides, shape)) { \ if (ctx->logging) {fprintf(ctx->log, "## Flat copy\n\n");} \ memcpy(dst+dst_offset, src+src_offset, size*sizeof(*dst)); \ } else { \ if (ctx->logging) {fprintf(ctx->log, "## General copy\n\n");} \ lmad_copy_elements_##NAME \ (r, \ dst+dst_offset, dst_strides, \ src+src_offset, src_strides, shape); \ } \ } GEN_MAP_TRANSPOSE(1b, uint8_t) GEN_MAP_TRANSPOSE(2b, uint16_t) GEN_MAP_TRANSPOSE(4b, uint32_t) GEN_MAP_TRANSPOSE(8b, uint64_t) GEN_LMAD_COPY_ELEMENTS(1b, uint8_t) GEN_LMAD_COPY_ELEMENTS(2b, uint16_t) GEN_LMAD_COPY_ELEMENTS(4b, uint32_t) GEN_LMAD_COPY_ELEMENTS(8b, uint64_t) GEN_LMAD_COPY(1b, uint8_t) GEN_LMAD_COPY(2b, uint16_t) GEN_LMAD_COPY(4b, uint32_t) GEN_LMAD_COPY(8b, uint64_t) // End of copy.h futhark-0.25.27/rts/c/errors.h000066400000000000000000000001321475065116200160550ustar00rootroot00000000000000#define FUTHARK_SUCCESS 0 #define FUTHARK_PROGRAM_ERROR 2 #define FUTHARK_OUT_OF_MEMORY 3 futhark-0.25.27/rts/c/event_list.h000066400000000000000000000032651475065116200167270ustar00rootroot00000000000000// Start of event_list.h typedef int (*event_report_fn)(struct str_builder*, void*); struct event { void* data; event_report_fn f; const char* name; char *description; }; struct event_list { struct event *events; int num_events; int capacity; }; static void event_list_init(struct event_list *l) { l->capacity = 100; l->num_events = 0; l->events = calloc(l->capacity, sizeof(struct event)); } static void event_list_free(struct event_list *l) { free(l->events); } static void add_event_to_list(struct event_list *l, const char* name, char* description, void* data, event_report_fn f) { if (l->num_events == l->capacity) { l->capacity *= 2; l->events = realloc(l->events, l->capacity * sizeof(struct event)); } l->events[l->num_events].name = name; l->events[l->num_events].description = description; 
l->events[l->num_events].data = data; l->events[l->num_events].f = f; l->num_events++; } static int report_events_in_list(struct event_list *l, struct str_builder* sb) { int ret = 0; for (int i = 0; i < l->num_events; i++) { if (i != 0) { str_builder_str(sb, ","); } str_builder_str(sb, "{\"name\":"); str_builder_json_str(sb, l->events[i].name); str_builder_str(sb, ",\"description\":"); str_builder_json_str(sb, l->events[i].description); free(l->events[i].description); if (l->events[i].f(sb, l->events[i].data) != 0) { ret = 1; break; } str_builder(sb, "}"); } event_list_free(l); event_list_init(l); return ret; } // End of event_list.h futhark-0.25.27/rts/c/free_list.h000066400000000000000000000116461475065116200165310ustar00rootroot00000000000000// Start of free_list.h. typedef uintptr_t fl_mem; // An entry in the free list. May be invalid, to avoid having to // deallocate entries as soon as they are removed. There is also a // tag, to help with memory reuse. struct free_list_entry { size_t size; fl_mem mem; const char *tag; unsigned char valid; }; struct free_list { struct free_list_entry *entries; // Pointer to entries. int capacity; // Number of entries. int used; // Number of valid entries. lock_t lock; // Thread safety. }; static void free_list_init(struct free_list *l) { l->capacity = 30; // Picked arbitrarily. l->used = 0; l->entries = (struct free_list_entry*) malloc(sizeof(struct free_list_entry) * l->capacity); for (int i = 0; i < l->capacity; i++) { l->entries[i].valid = 0; } create_lock(&l->lock); } // Remove invalid entries from the free list. static void free_list_pack(struct free_list *l) { lock_lock(&l->lock); int p = 0; for (int i = 0; i < l->capacity; i++) { if (l->entries[i].valid) { l->entries[p] = l->entries[i]; if (i > p) { l->entries[i].valid = 0; } p++; } } // Now p is the number of used elements. We don't want it to go // less than the default capacity (although in practice it's OK as // long as it doesn't become 1). 
if (p < 30) { p = 30; } l->entries = realloc(l->entries, p * sizeof(struct free_list_entry)); l->capacity = p; lock_unlock(&l->lock); } static void free_list_destroy(struct free_list *l) { assert(l->used == 0); free(l->entries); free_lock(&l->lock); } // Not part of the interface, so no locking. static int free_list_find_invalid(struct free_list *l) { int i; for (i = 0; i < l->capacity; i++) { if (!l->entries[i].valid) { break; } } return i; } static void free_list_insert(struct free_list *l, size_t size, fl_mem mem, const char *tag) { lock_lock(&l->lock); int i = free_list_find_invalid(l); if (i == l->capacity) { // List is full; so we have to grow it. int new_capacity = l->capacity * 2 * sizeof(struct free_list_entry); l->entries = realloc(l->entries, new_capacity); for (int j = 0; j < l->capacity; j++) { l->entries[j+l->capacity].valid = 0; } l->capacity *= 2; } // Now 'i' points to the first invalid entry. l->entries[i].valid = 1; l->entries[i].size = size; l->entries[i].mem = mem; l->entries[i].tag = tag; l->used++; lock_unlock(&l->lock); } // Determine whether this entry in the free list is acceptable for // satisfying the request. Not public, so no locking. static bool free_list_acceptable(size_t size, const char* tag, struct free_list_entry *entry) { // We check not just the hard requirement (is the entry acceptable // and big enough?) but also put a cap on how much wasted space // (internal fragmentation) we allow. This is necessarily a // heuristic, and a crude one. if (!entry->valid) { return false; } if (size > entry->size) { return false; } // We know the block fits. Now the question is whether it is too // big. Our policy is as follows: // // 1) We don't care about wasted space below 4096 bytes (to avoid // churn in tiny allocations). // // 2) If the tag matches, we allow _any_ amount of wasted space. // // 3) Otherwise we allow up to 50% wasted space. 
if (entry->size < 4096) { return true; } if (entry->tag == tag) { return true; } if (entry->size < size * 2) { return true; } return false; } // Find and remove a memory block of the indicated tag, or if that // does not exist, another memory block with exactly the desired size. // Returns 0 on success. static int free_list_find(struct free_list *l, size_t size, const char *tag, size_t *size_out, fl_mem *mem_out) { lock_lock(&l->lock); int size_match = -1; int i; int ret = 1; for (i = 0; i < l->capacity; i++) { if (free_list_acceptable(size, tag, &l->entries[i]) && (size_match < 0 || l->entries[i].size < l->entries[size_match].size)) { // If this entry is valid, has sufficient size, and is smaller than the // best entry found so far, use this entry. size_match = i; } } if (size_match >= 0) { l->entries[size_match].valid = 0; *size_out = l->entries[size_match].size; *mem_out = l->entries[size_match].mem; l->used--; ret = 0; } lock_unlock(&l->lock); return ret; } // Remove the first block in the free list. Returns 0 if a block was // removed, and nonzero if the free list was already empty. static int free_list_first(struct free_list *l, fl_mem *mem_out) { lock_lock(&l->lock); int ret = 1; for (int i = 0; i < l->capacity; i++) { if (l->entries[i].valid) { l->entries[i].valid = 0; *mem_out = l->entries[i].mem; l->used--; ret = 0; break; } } lock_unlock(&l->lock); return ret; } // End of free_list.h. futhark-0.25.27/rts/c/gpu.h000066400000000000000000000701421475065116200153440ustar00rootroot00000000000000// Start of gpu.h // Generic functions that use our tiny GPU abstraction layer. The // entire context must be defined before this header is included. 
In // particular we expect the following functions to be available: static int gpu_free_actual(struct futhark_context *ctx, gpu_mem mem); static int gpu_alloc_actual(struct futhark_context *ctx, size_t size, gpu_mem *mem_out); int gpu_launch_kernel(struct futhark_context* ctx, gpu_kernel kernel, const char *name, const int32_t grid[3], const int32_t block[3], unsigned int shared_mem_bytes, int num_args, void* args[num_args], size_t args_sizes[num_args]); int gpu_memcpy(struct futhark_context* ctx, gpu_mem dst, int64_t dst_offset, gpu_mem src, int64_t src_offset, int64_t nbytes); int gpu_scalar_from_device(struct futhark_context* ctx, void *dst, gpu_mem src, size_t offset, size_t size); int gpu_scalar_to_device(struct futhark_context* ctx, gpu_mem dst, size_t offset, size_t size, void *src); void gpu_create_kernel(struct futhark_context *ctx, gpu_kernel* kernel, const char* name); static void gpu_init_log(struct futhark_context *ctx) { if (ctx->cfg->logging) { fprintf(ctx->log, "Default block size: %ld\n", (long)ctx->cfg->gpu.default_block_size); fprintf(ctx->log, "Default grid size: %ld\n", (long)ctx->cfg->gpu.default_grid_size); fprintf(ctx->log, "Default tile size: %ld\n", (long)ctx->cfg->gpu.default_tile_size); fprintf(ctx->log, "Default register tile size: %ld\n", (long)ctx->cfg->gpu.default_reg_tile_size); fprintf(ctx->log, "Default cache: %ld\n", (long)ctx->cfg->gpu.default_cache); fprintf(ctx->log, "Default registers: %ld\n", (long)ctx->cfg->gpu.default_registers); fprintf(ctx->log, "Default threshold: %ld\n", (long)ctx->cfg->gpu.default_threshold); fprintf(ctx->log, "Max thread block size: %ld\n", (long)ctx->max_thread_block_size); fprintf(ctx->log, "Max grid size: %ld\n", (long)ctx->max_grid_size); fprintf(ctx->log, "Max tile size: %ld\n", (long)ctx->max_tile_size); fprintf(ctx->log, "Max threshold: %ld\n", (long)ctx->max_threshold); fprintf(ctx->log, "Max shared memory: %ld\n", (long)ctx->max_shared_memory); fprintf(ctx->log, "Max registers: %ld\n", 
(long)ctx->max_registers); fprintf(ctx->log, "Max cache: %ld\n", (long)ctx->max_cache); fprintf(ctx->log, "Lockstep width: %ld\n", (long)ctx->lockstep_width); } } // Generic GPU command line options. void futhark_context_config_set_default_thread_block_size(struct futhark_context_config *cfg, int size) { cfg->gpu.default_block_size = size; cfg->gpu.default_block_size_changed = 1; } void futhark_context_config_set_default_group_size(struct futhark_context_config *cfg, int size) { futhark_context_config_set_default_thread_block_size(cfg, size); } void futhark_context_config_set_default_grid_size(struct futhark_context_config *cfg, int num) { cfg->gpu.default_grid_size = num; cfg->gpu.default_grid_size_changed = 1; } void futhark_context_config_set_default_num_groups(struct futhark_context_config *cfg, int num) { futhark_context_config_set_default_grid_size(cfg, num); } void futhark_context_config_set_default_tile_size(struct futhark_context_config *cfg, int size) { cfg->gpu.default_tile_size = size; cfg->gpu.default_tile_size_changed = 1; } void futhark_context_config_set_default_reg_tile_size(struct futhark_context_config *cfg, int size) { cfg->gpu.default_reg_tile_size = size; } void futhark_context_config_set_default_cache(struct futhark_context_config *cfg, int size) { cfg->gpu.default_cache = size; } void futhark_context_config_set_default_registers(struct futhark_context_config *cfg, int size) { cfg->gpu.default_registers = size; } void futhark_context_config_set_default_threshold(struct futhark_context_config *cfg, int size) { cfg->gpu.default_threshold = size; } int futhark_context_config_set_tuning_param(struct futhark_context_config *cfg, const char *param_name, size_t new_value) { for (int i = 0; i < cfg->num_tuning_params; i++) { if (strcmp(param_name, cfg->tuning_param_names[i]) == 0) { cfg->tuning_params[i] = new_value; return 0; } } if (strcmp(param_name, "default_thread_block_size") == 0 || strcmp(param_name, "default_group_size") == 0) { 
cfg->gpu.default_block_size = new_value; return 0; } if (strcmp(param_name, "default_grid_size") == 0 || strcmp(param_name, "default_num_groups") == 0) { cfg->gpu.default_grid_size = new_value; return 0; } if (strcmp(param_name, "default_threshold") == 0) { cfg->gpu.default_threshold = new_value; return 0; } if (strcmp(param_name, "default_tile_size") == 0) { cfg->gpu.default_tile_size = new_value; return 0; } if (strcmp(param_name, "default_reg_tile_size") == 0) { cfg->gpu.default_reg_tile_size = new_value; return 0; } if (strcmp(param_name, "default_cache") == 0) { cfg->gpu.default_cache = new_value; return 0; } if (strcmp(param_name, "default_shared_memory") == 0) { cfg->gpu.default_shared_memory = new_value; return 0; } return 1; } // End of GPU command line optiopns. // Max number of thead blocks we allow along the second or third // dimension for transpositions. #define MAX_TR_THREAD_BLOCKS 65535 struct builtin_kernels { // We have a lot of ways to transpose arrays. gpu_kernel map_transpose_1b; gpu_kernel map_transpose_1b_low_height; gpu_kernel map_transpose_1b_low_width; gpu_kernel map_transpose_1b_small; gpu_kernel map_transpose_1b_large; gpu_kernel map_transpose_2b; gpu_kernel map_transpose_2b_low_height; gpu_kernel map_transpose_2b_low_width; gpu_kernel map_transpose_2b_small; gpu_kernel map_transpose_2b_large; gpu_kernel map_transpose_4b; gpu_kernel map_transpose_4b_low_height; gpu_kernel map_transpose_4b_low_width; gpu_kernel map_transpose_4b_small; gpu_kernel map_transpose_4b_large; gpu_kernel map_transpose_8b; gpu_kernel map_transpose_8b_low_height; gpu_kernel map_transpose_8b_low_width; gpu_kernel map_transpose_8b_small; gpu_kernel map_transpose_8b_large; // And a few ways of copying. 
gpu_kernel lmad_copy_1b; gpu_kernel lmad_copy_2b; gpu_kernel lmad_copy_4b; gpu_kernel lmad_copy_8b; }; struct builtin_kernels* init_builtin_kernels(struct futhark_context* ctx) { struct builtin_kernels *kernels = malloc(sizeof(struct builtin_kernels)); gpu_create_kernel(ctx, &kernels->map_transpose_1b, "map_transpose_1b"); gpu_create_kernel(ctx, &kernels->map_transpose_1b_large, "map_transpose_1b_large"); gpu_create_kernel(ctx, &kernels->map_transpose_1b_low_height, "map_transpose_1b_low_height"); gpu_create_kernel(ctx, &kernels->map_transpose_1b_low_width, "map_transpose_1b_low_width"); gpu_create_kernel(ctx, &kernels->map_transpose_1b_small, "map_transpose_1b_small"); gpu_create_kernel(ctx, &kernels->map_transpose_2b, "map_transpose_2b"); gpu_create_kernel(ctx, &kernels->map_transpose_2b_large, "map_transpose_2b_large"); gpu_create_kernel(ctx, &kernels->map_transpose_2b_low_height, "map_transpose_2b_low_height"); gpu_create_kernel(ctx, &kernels->map_transpose_2b_low_width, "map_transpose_2b_low_width"); gpu_create_kernel(ctx, &kernels->map_transpose_2b_small, "map_transpose_2b_small"); gpu_create_kernel(ctx, &kernels->map_transpose_4b, "map_transpose_4b"); gpu_create_kernel(ctx, &kernels->map_transpose_4b_large, "map_transpose_4b_large"); gpu_create_kernel(ctx, &kernels->map_transpose_4b_low_height, "map_transpose_4b_low_height"); gpu_create_kernel(ctx, &kernels->map_transpose_4b_low_width, "map_transpose_4b_low_width"); gpu_create_kernel(ctx, &kernels->map_transpose_4b_small, "map_transpose_4b_small"); gpu_create_kernel(ctx, &kernels->map_transpose_8b, "map_transpose_8b"); gpu_create_kernel(ctx, &kernels->map_transpose_8b_large, "map_transpose_8b_large"); gpu_create_kernel(ctx, &kernels->map_transpose_8b_low_height, "map_transpose_8b_low_height"); gpu_create_kernel(ctx, &kernels->map_transpose_8b_low_width, "map_transpose_8b_low_width"); gpu_create_kernel(ctx, &kernels->map_transpose_8b_small, "map_transpose_8b_small"); gpu_create_kernel(ctx, 
&kernels->lmad_copy_1b, "lmad_copy_1b"); gpu_create_kernel(ctx, &kernels->lmad_copy_2b, "lmad_copy_2b"); gpu_create_kernel(ctx, &kernels->lmad_copy_4b, "lmad_copy_4b"); gpu_create_kernel(ctx, &kernels->lmad_copy_8b, "lmad_copy_8b"); return kernels; } void free_builtin_kernels(struct futhark_context* ctx, struct builtin_kernels* kernels) { gpu_free_kernel(ctx, kernels->map_transpose_1b); gpu_free_kernel(ctx, kernels->map_transpose_1b_large); gpu_free_kernel(ctx, kernels->map_transpose_1b_low_height); gpu_free_kernel(ctx, kernels->map_transpose_1b_low_width); gpu_free_kernel(ctx, kernels->map_transpose_1b_small); gpu_free_kernel(ctx, kernels->map_transpose_2b); gpu_free_kernel(ctx, kernels->map_transpose_2b_large); gpu_free_kernel(ctx, kernels->map_transpose_2b_low_height); gpu_free_kernel(ctx, kernels->map_transpose_2b_low_width); gpu_free_kernel(ctx, kernels->map_transpose_2b_small); gpu_free_kernel(ctx, kernels->map_transpose_4b); gpu_free_kernel(ctx, kernels->map_transpose_4b_large); gpu_free_kernel(ctx, kernels->map_transpose_4b_low_height); gpu_free_kernel(ctx, kernels->map_transpose_4b_low_width); gpu_free_kernel(ctx, kernels->map_transpose_4b_small); gpu_free_kernel(ctx, kernels->map_transpose_8b); gpu_free_kernel(ctx, kernels->map_transpose_8b_large); gpu_free_kernel(ctx, kernels->map_transpose_8b_low_height); gpu_free_kernel(ctx, kernels->map_transpose_8b_low_width); gpu_free_kernel(ctx, kernels->map_transpose_8b_small); gpu_free_kernel(ctx, kernels->lmad_copy_1b); gpu_free_kernel(ctx, kernels->lmad_copy_2b); gpu_free_kernel(ctx, kernels->lmad_copy_4b); gpu_free_kernel(ctx, kernels->lmad_copy_8b); free(kernels); } static int gpu_alloc(struct futhark_context *ctx, FILE *log, size_t min_size, const char *tag, gpu_mem *mem_out, size_t *size_out) { if (min_size < sizeof(int)) { min_size = sizeof(int); } gpu_mem* memptr; if (free_list_find(&ctx->gpu_free_list, min_size, tag, size_out, (fl_mem*)&memptr) == 0) { // Successfully found a free block. 
Is it big enough? if (*size_out >= min_size) { if (ctx->cfg->debugging) { fprintf(log, "No need to allocate: Found a block in the free list.\n"); } *mem_out = *memptr; free(memptr); return FUTHARK_SUCCESS; } else { if (ctx->cfg->debugging) { fprintf(log, "Found a free block, but it was too small.\n"); } int error = gpu_free_actual(ctx, *memptr); free(memptr); if (error != FUTHARK_SUCCESS) { return error; } } } *size_out = min_size; // We have to allocate a new block from the driver. If the // allocation does not succeed, then we might be in an out-of-memory // situation. We now start freeing things from the free list until // we think we have freed enough that the allocation will succeed. // Since we don't know how far the allocation is from fitting, we // have to check after every deallocation. This might be pretty // expensive. Let's hope that this case is hit rarely. if (ctx->cfg->debugging) { fprintf(log, "Actually allocating the desired block.\n"); } int error = gpu_alloc_actual(ctx, min_size, mem_out); while (error == FUTHARK_OUT_OF_MEMORY) { if (ctx->cfg->debugging) { fprintf(log, "Out of GPU memory: releasing entry from the free list...\n"); } gpu_mem* memptr; if (free_list_first(&ctx->gpu_free_list, (fl_mem*)&memptr) == 0) { gpu_mem mem = *memptr; free(memptr); error = gpu_free_actual(ctx, mem); if (error != FUTHARK_SUCCESS) { return error; } } else { break; } error = gpu_alloc_actual(ctx, min_size, mem_out); } return error; } static int gpu_free(struct futhark_context *ctx, gpu_mem mem, size_t size, const char *tag) { gpu_mem* memptr = malloc(sizeof(gpu_mem)); *memptr = mem; free_list_insert(&ctx->gpu_free_list, size, (fl_mem)memptr, tag); return FUTHARK_SUCCESS; } static int gpu_free_all(struct futhark_context *ctx) { free_list_pack(&ctx->gpu_free_list); gpu_mem* memptr; while (free_list_first(&ctx->gpu_free_list, (fl_mem*)&memptr) == 0) { gpu_mem mem = *memptr; free(memptr); int error = gpu_free_actual(ctx, mem); if (error != FUTHARK_SUCCESS) { return 
error; } } return FUTHARK_SUCCESS; } static int gpu_map_transpose(struct futhark_context* ctx, gpu_kernel kernel_default, gpu_kernel kernel_low_height, gpu_kernel kernel_low_width, gpu_kernel kernel_small, gpu_kernel kernel_large, const char *name, size_t elem_size, gpu_mem dst, int64_t dst_offset, gpu_mem src, int64_t src_offset, int64_t k, int64_t n, int64_t m) { int64_t mulx = TR_BLOCK_DIM / n; int64_t muly = TR_BLOCK_DIM / m; int32_t mulx32 = mulx; int32_t muly32 = muly; int32_t k32 = k; int32_t n32 = n; int32_t m32 = m; gpu_kernel kernel = kernel_default; int32_t grid[3]; int32_t block[3]; void* args[11]; size_t args_sizes[11] = { sizeof(gpu_mem), sizeof(int64_t), sizeof(gpu_mem), sizeof(int64_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) }; args[0] = &dst; args[1] = &dst_offset; args[2] = &src; args[3] = &src_offset; args[7] = &mulx; args[8] = &muly; if (dst_offset + k * n * m <= 2147483647L && src_offset + k * n * m <= 2147483647L) { if (m <= TR_BLOCK_DIM/2 && n <= TR_BLOCK_DIM/2) { if (ctx->logging) { fprintf(ctx->log, "Using small kernel\n"); } kernel = kernel_small; grid[0] = ((k * n * m) + (TR_BLOCK_DIM*TR_BLOCK_DIM) - 1) / (TR_BLOCK_DIM*TR_BLOCK_DIM); grid[1] = 1; grid[2] = 1; block[0] = TR_BLOCK_DIM*TR_BLOCK_DIM; block[1] = 1; block[2] = 1; } else if (m <= TR_BLOCK_DIM/2 && TR_BLOCK_DIM < n) { if (ctx->logging) { fprintf(ctx->log, "Using low-width kernel\n"); } kernel = kernel_low_width; int64_t x_elems = m; int64_t y_elems = (n + muly - 1) / muly; grid[0] = (x_elems + TR_BLOCK_DIM - 1) / TR_BLOCK_DIM; grid[1] = (y_elems + TR_BLOCK_DIM - 1) / TR_BLOCK_DIM; grid[2] = k; block[0] = TR_BLOCK_DIM; block[1] = TR_BLOCK_DIM; block[2] = 1; } else if (n <= TR_BLOCK_DIM/2 && TR_BLOCK_DIM < m) { if (ctx->logging) { fprintf(ctx->log, "Using low-height kernel\n"); } kernel = kernel_low_height; int64_t x_elems = (m + mulx - 1) / mulx; int64_t y_elems = n; grid[0] = (x_elems + TR_BLOCK_DIM - 1) / TR_BLOCK_DIM; grid[1] = 
(y_elems + TR_BLOCK_DIM - 1) / TR_BLOCK_DIM; grid[2] = k; block[0] = TR_BLOCK_DIM; block[1] = TR_BLOCK_DIM; block[2] = 1; } else { if (ctx->logging) { fprintf(ctx->log, "Using default kernel\n"); } kernel = kernel_default; grid[0] = (m+TR_TILE_DIM-1)/TR_TILE_DIM; grid[1] = (n+TR_TILE_DIM-1)/TR_TILE_DIM; grid[2] = k; block[0] = TR_TILE_DIM; block[1] = TR_TILE_DIM/TR_ELEMS_PER_THREAD; block[2] = 1; } args[4] = &k32; args[5] = &m32; args[6] = &n32; args[7] = &mulx32; args[8] = &muly32; } else { if (ctx->logging) { fprintf(ctx->log, "Using large kernel\n"); } kernel = kernel_large; grid[0] = (m+TR_TILE_DIM-1)/TR_TILE_DIM; grid[1] = (n+TR_TILE_DIM-1)/TR_TILE_DIM; grid[2] = k; block[0] = TR_TILE_DIM; block[1] = TR_TILE_DIM/TR_ELEMS_PER_THREAD; block[2] = 1; args[4] = &k; args[5] = &m; args[6] = &n; args[7] = &mulx; args[8] = &muly; args_sizes[4] = sizeof(int64_t); args_sizes[5] = sizeof(int64_t); args_sizes[6] = sizeof(int64_t); args_sizes[7] = sizeof(int64_t); args_sizes[8] = sizeof(int64_t); } // Cap the number of thead blocks we launch and figure out how many // repeats we need alongside each dimension. int32_t repeat_1 = grid[1] / MAX_TR_THREAD_BLOCKS; int32_t repeat_2 = grid[2] / MAX_TR_THREAD_BLOCKS; grid[1] = repeat_1 > 0 ? MAX_TR_THREAD_BLOCKS : grid[1]; grid[2] = repeat_2 > 0 ? 
MAX_TR_THREAD_BLOCKS : grid[2]; args[9] = &repeat_1; args[10] = &repeat_2; args_sizes[9] = sizeof(repeat_1); args_sizes[10] = sizeof(repeat_2); if (ctx->logging) { fprintf(ctx->log, "\n"); } return gpu_launch_kernel(ctx, kernel, name, grid, block, TR_TILE_DIM*(TR_TILE_DIM+1)*elem_size, sizeof(args)/sizeof(args[0]), args, args_sizes); } #define GEN_MAP_TRANSPOSE_GPU2GPU(NAME, ELEM_TYPE) \ static int map_transpose_gpu2gpu_##NAME \ (struct futhark_context* ctx, \ gpu_mem dst, int64_t dst_offset, \ gpu_mem src, int64_t src_offset, \ int64_t k, int64_t m, int64_t n) \ { \ return \ gpu_map_transpose \ (ctx, \ ctx->kernels->map_transpose_##NAME, \ ctx->kernels->map_transpose_##NAME##_low_height, \ ctx->kernels->map_transpose_##NAME##_low_width, \ ctx->kernels->map_transpose_##NAME##_small, \ ctx->kernels->map_transpose_##NAME##_large, \ "map_transpose_" #NAME, sizeof(ELEM_TYPE), \ dst, dst_offset, src, src_offset, \ k, n, m); \ } static int gpu_lmad_copy(struct futhark_context* ctx, gpu_kernel kernel, int r, gpu_mem dst, int64_t dst_offset, int64_t dst_strides[r], gpu_mem src, int64_t src_offset, int64_t src_strides[r], int64_t shape[r]) { if (r > 8) { set_error(ctx, strdup("Futhark runtime limitation:\nCannot copy array of greater than rank 8.\n")); return 1; } int64_t n = 1; for (int i = 0; i < r; i++) { n *= shape[i]; } void* args[6+(8*3)]; size_t args_sizes[6+(8*3)]; args[0] = &dst; args_sizes[0] = sizeof(gpu_mem); args[1] = &dst_offset; args_sizes[1] = sizeof(dst_offset); args[2] = &src; args_sizes[2] = sizeof(gpu_mem); args[3] = &src_offset; args_sizes[3] = sizeof(src_offset); args[4] = &n; args_sizes[4] = sizeof(n); args[5] = &r; args_sizes[5] = sizeof(r); int64_t zero = 0; for (int i = 0; i < 8; i++) { args_sizes[6+i*3] = sizeof(int64_t); args_sizes[6+i*3+1] = sizeof(int64_t); args_sizes[6+i*3+2] = sizeof(int64_t); if (i < r) { args[6+i*3] = &shape[i]; args[6+i*3+1] = &dst_strides[i]; args[6+i*3+2] = &src_strides[i]; } else { args[6+i*3] = &zero; args[6+i*3+1] = 
&zero; args[6+i*3+2] = &zero; } } const size_t w = 256; // XXX: hardcoded thread block size. return gpu_launch_kernel(ctx, kernel, "copy_lmad_dev_to_dev", (const int32_t[3]) {(n+w-1)/w,1,1}, (const int32_t[3]) {w,1,1}, 0, 6+(8*3), args, args_sizes); } #define GEN_LMAD_COPY_ELEMENTS_GPU2GPU(NAME, ELEM_TYPE) \ static int lmad_copy_elements_gpu2gpu_##NAME \ (struct futhark_context* ctx, \ int r, \ gpu_mem dst, int64_t dst_offset, int64_t dst_strides[r], \ gpu_mem src, int64_t src_offset, int64_t src_strides[r], \ int64_t shape[r]) { \ return gpu_lmad_copy(ctx, ctx->kernels->lmad_copy_##NAME, r, \ dst, dst_offset, dst_strides, \ src, src_offset, src_strides, \ shape); \ } \ #define GEN_LMAD_COPY_GPU2GPU(NAME, ELEM_TYPE) \ static int lmad_copy_gpu2gpu_##NAME \ (struct futhark_context* ctx, \ int r, \ gpu_mem dst, int64_t dst_offset, int64_t dst_strides[r], \ gpu_mem src, int64_t src_offset, int64_t src_strides[r], \ int64_t shape[r]) { \ log_copy(ctx, "GPU to GPU", r, dst_offset, dst_strides, \ src_offset, src_strides, shape); \ int64_t size = 1; \ for (int i = 0; i < r; i++) { size *= shape[i]; } \ if (size == 0) { return FUTHARK_SUCCESS; } \ int64_t k, n, m; \ if (lmad_map_tr(&k, &n, &m, \ r, dst_strides, src_strides, shape)) { \ log_transpose(ctx, k, n, m); \ return map_transpose_gpu2gpu_##NAME \ (ctx, dst, dst_offset, src, src_offset, k, n, m); \ } else if (lmad_memcpyable(r, dst_strides, src_strides, shape)) { \ if (ctx->logging) {fprintf(ctx->log, "## Flat copy\n\n");} \ return gpu_memcpy(ctx, \ dst, dst_offset*sizeof(ELEM_TYPE), \ src, src_offset*sizeof(ELEM_TYPE), \ size * sizeof(ELEM_TYPE)); \ } else { \ if (ctx->logging) {fprintf(ctx->log, "## General copy\n\n");} \ return lmad_copy_elements_gpu2gpu_##NAME \ (ctx, r, \ dst, dst_offset, dst_strides, \ src, src_offset, src_strides, \ shape); \ } \ } static int lmad_copy_elements_host2gpu(struct futhark_context *ctx, size_t elem_size, int r, gpu_mem dst, int64_t dst_offset, int64_t dst_strides[r], unsigned char* 
src, int64_t src_offset, int64_t src_strides[r], int64_t shape[r]) { (void)ctx; (void)elem_size; (void)r; (void)dst; (void)dst_offset; (void)dst_strides; (void)src; (void)src_offset; (void)src_strides; (void)shape; set_error(ctx, strdup("Futhark runtime limitation:\nCannot copy unstructured array from host to GPU.\n")); return 1; } static int lmad_copy_elements_gpu2host (struct futhark_context *ctx, size_t elem_size, int r, unsigned char* dst, int64_t dst_offset, int64_t dst_strides[r], gpu_mem src, int64_t src_offset, int64_t src_strides[r], int64_t shape[r]) { (void)ctx; (void)elem_size; (void)r; (void)dst; (void)dst_offset; (void)dst_strides; (void)src; (void)src_offset; (void)src_strides; (void)shape; set_error(ctx, strdup("Futhark runtime limitation:\nCannot copy unstructured array from GPU to host.\n")); return 1; } #define GEN_LMAD_COPY_ELEMENTS_HOSTGPU(NAME, ELEM_TYPE) \ static int lmad_copy_elements_gpu2gpu_##NAME \ (struct futhark_context* ctx, \ int r, \ gpu_mem dst, int64_t dst_offset, int64_t dst_strides[r], \ gpu_mem src, int64_t src_offset, int64_t src_strides[r], \ int64_t shape[r]) { \ return (ctx, ctx->kernels->lmad_copy_##NAME, r, \ dst, dst_offset, dst_strides, \ src, src_offset, src_strides, \ shape); \ } \ static int lmad_copy_host2gpu(struct futhark_context* ctx, size_t elem_size, bool sync, int r, gpu_mem dst, int64_t dst_offset, int64_t dst_strides[r], unsigned char* src, int64_t src_offset, int64_t src_strides[r], int64_t shape[r]) { log_copy(ctx, "Host to GPU", r, dst_offset, dst_strides, src_offset, src_strides, shape); int64_t size = elem_size; for (int i = 0; i < r; i++) { size *= shape[i]; } if (size == 0) { return FUTHARK_SUCCESS; } int64_t k, n, m; if (lmad_memcpyable(r, dst_strides, src_strides, shape)) { if (ctx->logging) {fprintf(ctx->log, "## Flat copy\n\n");} return memcpy_host2gpu(ctx, sync, dst, dst_offset*elem_size, src, src_offset*elem_size, size); } else { if (ctx->logging) {fprintf(ctx->log, "## General copy\n\n");} int 
error; error = lmad_copy_elements_host2gpu (ctx, elem_size, r, dst, dst_offset, dst_strides, src, src_offset, src_strides, shape); if (error == 0 && sync) { error = futhark_context_sync(ctx); } return error; } } static int lmad_copy_gpu2host(struct futhark_context* ctx, size_t elem_size, bool sync, int r, unsigned char* dst, int64_t dst_offset, int64_t dst_strides[r], gpu_mem src, int64_t src_offset, int64_t src_strides[r], int64_t shape[r]) { log_copy(ctx, "Host to GPU", r, dst_offset, dst_strides, src_offset, src_strides, shape); int64_t size = elem_size; for (int i = 0; i < r; i++) { size *= shape[i]; } if (size == 0) { return FUTHARK_SUCCESS; } int64_t k, n, m; if (lmad_memcpyable(r, dst_strides, src_strides, shape)) { if (ctx->logging) {fprintf(ctx->log, "## Flat copy\n\n");} return memcpy_gpu2host(ctx, sync, dst, dst_offset*elem_size, src, src_offset*elem_size, size); } else { if (ctx->logging) {fprintf(ctx->log, "## General copy\n\n");} int error; error = lmad_copy_elements_gpu2host (ctx, elem_size, r, dst, dst_offset, dst_strides, src, src_offset, src_strides, shape); if (error == 0 && sync) { error = futhark_context_sync(ctx); } return error; } } GEN_MAP_TRANSPOSE_GPU2GPU(1b, uint8_t) GEN_MAP_TRANSPOSE_GPU2GPU(2b, uint16_t) GEN_MAP_TRANSPOSE_GPU2GPU(4b, uint32_t) GEN_MAP_TRANSPOSE_GPU2GPU(8b, uint64_t) GEN_LMAD_COPY_ELEMENTS_GPU2GPU(1b, uint8_t) GEN_LMAD_COPY_ELEMENTS_GPU2GPU(2b, uint16_t) GEN_LMAD_COPY_ELEMENTS_GPU2GPU(4b, uint32_t) GEN_LMAD_COPY_ELEMENTS_GPU2GPU(8b, uint64_t) GEN_LMAD_COPY_GPU2GPU(1b, uint8_t) GEN_LMAD_COPY_GPU2GPU(2b, uint16_t) GEN_LMAD_COPY_GPU2GPU(4b, uint32_t) GEN_LMAD_COPY_GPU2GPU(8b, uint64_t) // End of gpu.h futhark-0.25.27/rts/c/gpu_prototypes.h000066400000000000000000000023041475065116200176470ustar00rootroot00000000000000// Start of gpu_prototypes.h // Constants used for transpositions. In principle these should be configurable. 
#define TR_BLOCK_DIM 16 #define TR_TILE_DIM (TR_BLOCK_DIM*2) #define TR_ELEMS_PER_THREAD 8 // Config stuff included in every GPU backend. struct gpu_config { size_t default_block_size; size_t default_grid_size; size_t default_tile_size; size_t default_reg_tile_size; size_t default_cache; size_t default_shared_memory; size_t default_registers; size_t default_threshold; int default_block_size_changed; int default_grid_size_changed; int default_tile_size_changed; }; // The following are dummy sizes that mean the concrete defaults // will be set during initialisation via hardware-inspection-based // heuristics. struct gpu_config gpu_config_initial = { 0 }; // Must be defined by the user. static int gpu_macros(struct futhark_context *ctx, char*** names, int64_t** values); static void gpu_init_log(struct futhark_context *ctx); struct builtin_kernels* init_builtin_kernels(struct futhark_context* ctx); void free_builtin_kernels(struct futhark_context* ctx, struct builtin_kernels* kernels); static int gpu_free_all(struct futhark_context *ctx); // End of gpu_prototypes.h futhark-0.25.27/rts/c/half.h000066400000000000000000001037431475065116200154670ustar00rootroot00000000000000// Start of half.h. // Conversion functions are from http://half.sourceforge.net/, but // translated to C. // // Copyright (c) 2012-2021 Christian Rau // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. #ifndef __OPENCL_VERSION__ #define __constant #endif __constant static const uint16_t base_table[512] = { 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, 0x0040, 0x0080, 0x0100, 0x0200, 0x0400, 0x0800, 0x0C00, 0x1000, 0x1400, 0x1800, 0x1C00, 0x2000, 0x2400, 0x2800, 0x2C00, 0x3000, 0x3400, 0x3800, 0x3C00, 0x4000, 0x4400, 0x4800, 0x4C00, 0x5000, 0x5400, 0x5800, 0x5C00, 0x6000, 0x6400, 0x6800, 0x6C00, 0x7000, 0x7400, 0x7800, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 
0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x7C00, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8001, 0x8002, 0x8004, 0x8008, 0x8010, 0x8020, 0x8040, 0x8080, 0x8100, 0x8200, 0x8400, 0x8800, 0x8C00, 0x9000, 0x9400, 0x9800, 0x9C00, 0xA000, 0xA400, 0xA800, 0xAC00, 0xB000, 0xB400, 0xB800, 0xBC00, 0xC000, 0xC400, 0xC800, 0xCC00, 0xD000, 0xD400, 0xD800, 0xDC00, 0xE000, 0xE400, 0xE800, 0xEC00, 0xF000, 0xF400, 0xF800, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 
0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00, 0xFC00 }; __constant static const unsigned char shift_table[512] = { 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 
24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 13 }; __constant static const uint32_t mantissa_table[2048] = { 0x00000000, 0x33800000, 0x34000000, 0x34400000, 0x34800000, 0x34A00000, 0x34C00000, 0x34E00000, 0x35000000, 0x35100000, 0x35200000, 0x35300000, 0x35400000, 0x35500000, 0x35600000, 0x35700000, 0x35800000, 0x35880000, 0x35900000, 0x35980000, 0x35A00000, 0x35A80000, 0x35B00000, 0x35B80000, 0x35C00000, 0x35C80000, 0x35D00000, 0x35D80000, 0x35E00000, 0x35E80000, 0x35F00000, 0x35F80000, 0x36000000, 0x36040000, 0x36080000, 0x360C0000, 0x36100000, 0x36140000, 0x36180000, 0x361C0000, 0x36200000, 0x36240000, 0x36280000, 0x362C0000, 0x36300000, 0x36340000, 0x36380000, 0x363C0000, 0x36400000, 0x36440000, 0x36480000, 0x364C0000, 0x36500000, 0x36540000, 0x36580000, 0x365C0000, 0x36600000, 0x36640000, 0x36680000, 0x366C0000, 0x36700000, 0x36740000, 0x36780000, 0x367C0000, 0x36800000, 0x36820000, 0x36840000, 0x36860000, 0x36880000, 0x368A0000, 0x368C0000, 0x368E0000, 0x36900000, 0x36920000, 0x36940000, 0x36960000, 0x36980000, 0x369A0000, 0x369C0000, 0x369E0000, 0x36A00000, 0x36A20000, 0x36A40000, 0x36A60000, 0x36A80000, 0x36AA0000, 0x36AC0000, 0x36AE0000, 0x36B00000, 0x36B20000, 0x36B40000, 0x36B60000, 0x36B80000, 0x36BA0000, 0x36BC0000, 
0x36BE0000, 0x36C00000, 0x36C20000, 0x36C40000, 0x36C60000, 0x36C80000, 0x36CA0000, 0x36CC0000, 0x36CE0000, 0x36D00000, 0x36D20000, 0x36D40000, 0x36D60000, 0x36D80000, 0x36DA0000, 0x36DC0000, 0x36DE0000, 0x36E00000, 0x36E20000, 0x36E40000, 0x36E60000, 0x36E80000, 0x36EA0000, 0x36EC0000, 0x36EE0000, 0x36F00000, 0x36F20000, 0x36F40000, 0x36F60000, 0x36F80000, 0x36FA0000, 0x36FC0000, 0x36FE0000, 0x37000000, 0x37010000, 0x37020000, 0x37030000, 0x37040000, 0x37050000, 0x37060000, 0x37070000, 0x37080000, 0x37090000, 0x370A0000, 0x370B0000, 0x370C0000, 0x370D0000, 0x370E0000, 0x370F0000, 0x37100000, 0x37110000, 0x37120000, 0x37130000, 0x37140000, 0x37150000, 0x37160000, 0x37170000, 0x37180000, 0x37190000, 0x371A0000, 0x371B0000, 0x371C0000, 0x371D0000, 0x371E0000, 0x371F0000, 0x37200000, 0x37210000, 0x37220000, 0x37230000, 0x37240000, 0x37250000, 0x37260000, 0x37270000, 0x37280000, 0x37290000, 0x372A0000, 0x372B0000, 0x372C0000, 0x372D0000, 0x372E0000, 0x372F0000, 0x37300000, 0x37310000, 0x37320000, 0x37330000, 0x37340000, 0x37350000, 0x37360000, 0x37370000, 0x37380000, 0x37390000, 0x373A0000, 0x373B0000, 0x373C0000, 0x373D0000, 0x373E0000, 0x373F0000, 0x37400000, 0x37410000, 0x37420000, 0x37430000, 0x37440000, 0x37450000, 0x37460000, 0x37470000, 0x37480000, 0x37490000, 0x374A0000, 0x374B0000, 0x374C0000, 0x374D0000, 0x374E0000, 0x374F0000, 0x37500000, 0x37510000, 0x37520000, 0x37530000, 0x37540000, 0x37550000, 0x37560000, 0x37570000, 0x37580000, 0x37590000, 0x375A0000, 0x375B0000, 0x375C0000, 0x375D0000, 0x375E0000, 0x375F0000, 0x37600000, 0x37610000, 0x37620000, 0x37630000, 0x37640000, 0x37650000, 0x37660000, 0x37670000, 0x37680000, 0x37690000, 0x376A0000, 0x376B0000, 0x376C0000, 0x376D0000, 0x376E0000, 0x376F0000, 0x37700000, 0x37710000, 0x37720000, 0x37730000, 0x37740000, 0x37750000, 0x37760000, 0x37770000, 0x37780000, 0x37790000, 0x377A0000, 0x377B0000, 0x377C0000, 0x377D0000, 0x377E0000, 0x377F0000, 0x37800000, 0x37808000, 0x37810000, 0x37818000, 0x37820000, 
0x37828000, 0x37830000, 0x37838000, 0x37840000, 0x37848000, 0x37850000, 0x37858000, 0x37860000, 0x37868000, 0x37870000, 0x37878000, 0x37880000, 0x37888000, 0x37890000, 0x37898000, 0x378A0000, 0x378A8000, 0x378B0000, 0x378B8000, 0x378C0000, 0x378C8000, 0x378D0000, 0x378D8000, 0x378E0000, 0x378E8000, 0x378F0000, 0x378F8000, 0x37900000, 0x37908000, 0x37910000, 0x37918000, 0x37920000, 0x37928000, 0x37930000, 0x37938000, 0x37940000, 0x37948000, 0x37950000, 0x37958000, 0x37960000, 0x37968000, 0x37970000, 0x37978000, 0x37980000, 0x37988000, 0x37990000, 0x37998000, 0x379A0000, 0x379A8000, 0x379B0000, 0x379B8000, 0x379C0000, 0x379C8000, 0x379D0000, 0x379D8000, 0x379E0000, 0x379E8000, 0x379F0000, 0x379F8000, 0x37A00000, 0x37A08000, 0x37A10000, 0x37A18000, 0x37A20000, 0x37A28000, 0x37A30000, 0x37A38000, 0x37A40000, 0x37A48000, 0x37A50000, 0x37A58000, 0x37A60000, 0x37A68000, 0x37A70000, 0x37A78000, 0x37A80000, 0x37A88000, 0x37A90000, 0x37A98000, 0x37AA0000, 0x37AA8000, 0x37AB0000, 0x37AB8000, 0x37AC0000, 0x37AC8000, 0x37AD0000, 0x37AD8000, 0x37AE0000, 0x37AE8000, 0x37AF0000, 0x37AF8000, 0x37B00000, 0x37B08000, 0x37B10000, 0x37B18000, 0x37B20000, 0x37B28000, 0x37B30000, 0x37B38000, 0x37B40000, 0x37B48000, 0x37B50000, 0x37B58000, 0x37B60000, 0x37B68000, 0x37B70000, 0x37B78000, 0x37B80000, 0x37B88000, 0x37B90000, 0x37B98000, 0x37BA0000, 0x37BA8000, 0x37BB0000, 0x37BB8000, 0x37BC0000, 0x37BC8000, 0x37BD0000, 0x37BD8000, 0x37BE0000, 0x37BE8000, 0x37BF0000, 0x37BF8000, 0x37C00000, 0x37C08000, 0x37C10000, 0x37C18000, 0x37C20000, 0x37C28000, 0x37C30000, 0x37C38000, 0x37C40000, 0x37C48000, 0x37C50000, 0x37C58000, 0x37C60000, 0x37C68000, 0x37C70000, 0x37C78000, 0x37C80000, 0x37C88000, 0x37C90000, 0x37C98000, 0x37CA0000, 0x37CA8000, 0x37CB0000, 0x37CB8000, 0x37CC0000, 0x37CC8000, 0x37CD0000, 0x37CD8000, 0x37CE0000, 0x37CE8000, 0x37CF0000, 0x37CF8000, 0x37D00000, 0x37D08000, 0x37D10000, 0x37D18000, 0x37D20000, 0x37D28000, 0x37D30000, 0x37D38000, 0x37D40000, 0x37D48000, 0x37D50000, 
0x37D58000, 0x37D60000, 0x37D68000, 0x37D70000, 0x37D78000, 0x37D80000, 0x37D88000, 0x37D90000, 0x37D98000, 0x37DA0000, 0x37DA8000, 0x37DB0000, 0x37DB8000, 0x37DC0000, 0x37DC8000, 0x37DD0000, 0x37DD8000, 0x37DE0000, 0x37DE8000, 0x37DF0000, 0x37DF8000, 0x37E00000, 0x37E08000, 0x37E10000, 0x37E18000, 0x37E20000, 0x37E28000, 0x37E30000, 0x37E38000, 0x37E40000, 0x37E48000, 0x37E50000, 0x37E58000, 0x37E60000, 0x37E68000, 0x37E70000, 0x37E78000, 0x37E80000, 0x37E88000, 0x37E90000, 0x37E98000, 0x37EA0000, 0x37EA8000, 0x37EB0000, 0x37EB8000, 0x37EC0000, 0x37EC8000, 0x37ED0000, 0x37ED8000, 0x37EE0000, 0x37EE8000, 0x37EF0000, 0x37EF8000, 0x37F00000, 0x37F08000, 0x37F10000, 0x37F18000, 0x37F20000, 0x37F28000, 0x37F30000, 0x37F38000, 0x37F40000, 0x37F48000, 0x37F50000, 0x37F58000, 0x37F60000, 0x37F68000, 0x37F70000, 0x37F78000, 0x37F80000, 0x37F88000, 0x37F90000, 0x37F98000, 0x37FA0000, 0x37FA8000, 0x37FB0000, 0x37FB8000, 0x37FC0000, 0x37FC8000, 0x37FD0000, 0x37FD8000, 0x37FE0000, 0x37FE8000, 0x37FF0000, 0x37FF8000, 0x38000000, 0x38004000, 0x38008000, 0x3800C000, 0x38010000, 0x38014000, 0x38018000, 0x3801C000, 0x38020000, 0x38024000, 0x38028000, 0x3802C000, 0x38030000, 0x38034000, 0x38038000, 0x3803C000, 0x38040000, 0x38044000, 0x38048000, 0x3804C000, 0x38050000, 0x38054000, 0x38058000, 0x3805C000, 0x38060000, 0x38064000, 0x38068000, 0x3806C000, 0x38070000, 0x38074000, 0x38078000, 0x3807C000, 0x38080000, 0x38084000, 0x38088000, 0x3808C000, 0x38090000, 0x38094000, 0x38098000, 0x3809C000, 0x380A0000, 0x380A4000, 0x380A8000, 0x380AC000, 0x380B0000, 0x380B4000, 0x380B8000, 0x380BC000, 0x380C0000, 0x380C4000, 0x380C8000, 0x380CC000, 0x380D0000, 0x380D4000, 0x380D8000, 0x380DC000, 0x380E0000, 0x380E4000, 0x380E8000, 0x380EC000, 0x380F0000, 0x380F4000, 0x380F8000, 0x380FC000, 0x38100000, 0x38104000, 0x38108000, 0x3810C000, 0x38110000, 0x38114000, 0x38118000, 0x3811C000, 0x38120000, 0x38124000, 0x38128000, 0x3812C000, 0x38130000, 0x38134000, 0x38138000, 0x3813C000, 0x38140000, 
0x38144000, 0x38148000, 0x3814C000, 0x38150000, 0x38154000, 0x38158000, 0x3815C000, 0x38160000, 0x38164000, 0x38168000, 0x3816C000, 0x38170000, 0x38174000, 0x38178000, 0x3817C000, 0x38180000, 0x38184000, 0x38188000, 0x3818C000, 0x38190000, 0x38194000, 0x38198000, 0x3819C000, 0x381A0000, 0x381A4000, 0x381A8000, 0x381AC000, 0x381B0000, 0x381B4000, 0x381B8000, 0x381BC000, 0x381C0000, 0x381C4000, 0x381C8000, 0x381CC000, 0x381D0000, 0x381D4000, 0x381D8000, 0x381DC000, 0x381E0000, 0x381E4000, 0x381E8000, 0x381EC000, 0x381F0000, 0x381F4000, 0x381F8000, 0x381FC000, 0x38200000, 0x38204000, 0x38208000, 0x3820C000, 0x38210000, 0x38214000, 0x38218000, 0x3821C000, 0x38220000, 0x38224000, 0x38228000, 0x3822C000, 0x38230000, 0x38234000, 0x38238000, 0x3823C000, 0x38240000, 0x38244000, 0x38248000, 0x3824C000, 0x38250000, 0x38254000, 0x38258000, 0x3825C000, 0x38260000, 0x38264000, 0x38268000, 0x3826C000, 0x38270000, 0x38274000, 0x38278000, 0x3827C000, 0x38280000, 0x38284000, 0x38288000, 0x3828C000, 0x38290000, 0x38294000, 0x38298000, 0x3829C000, 0x382A0000, 0x382A4000, 0x382A8000, 0x382AC000, 0x382B0000, 0x382B4000, 0x382B8000, 0x382BC000, 0x382C0000, 0x382C4000, 0x382C8000, 0x382CC000, 0x382D0000, 0x382D4000, 0x382D8000, 0x382DC000, 0x382E0000, 0x382E4000, 0x382E8000, 0x382EC000, 0x382F0000, 0x382F4000, 0x382F8000, 0x382FC000, 0x38300000, 0x38304000, 0x38308000, 0x3830C000, 0x38310000, 0x38314000, 0x38318000, 0x3831C000, 0x38320000, 0x38324000, 0x38328000, 0x3832C000, 0x38330000, 0x38334000, 0x38338000, 0x3833C000, 0x38340000, 0x38344000, 0x38348000, 0x3834C000, 0x38350000, 0x38354000, 0x38358000, 0x3835C000, 0x38360000, 0x38364000, 0x38368000, 0x3836C000, 0x38370000, 0x38374000, 0x38378000, 0x3837C000, 0x38380000, 0x38384000, 0x38388000, 0x3838C000, 0x38390000, 0x38394000, 0x38398000, 0x3839C000, 0x383A0000, 0x383A4000, 0x383A8000, 0x383AC000, 0x383B0000, 0x383B4000, 0x383B8000, 0x383BC000, 0x383C0000, 0x383C4000, 0x383C8000, 0x383CC000, 0x383D0000, 0x383D4000, 0x383D8000, 
0x383DC000, 0x383E0000, 0x383E4000, 0x383E8000, 0x383EC000, 0x383F0000, 0x383F4000, 0x383F8000, 0x383FC000, 0x38400000, 0x38404000, 0x38408000, 0x3840C000, 0x38410000, 0x38414000, 0x38418000, 0x3841C000, 0x38420000, 0x38424000, 0x38428000, 0x3842C000, 0x38430000, 0x38434000, 0x38438000, 0x3843C000, 0x38440000, 0x38444000, 0x38448000, 0x3844C000, 0x38450000, 0x38454000, 0x38458000, 0x3845C000, 0x38460000, 0x38464000, 0x38468000, 0x3846C000, 0x38470000, 0x38474000, 0x38478000, 0x3847C000, 0x38480000, 0x38484000, 0x38488000, 0x3848C000, 0x38490000, 0x38494000, 0x38498000, 0x3849C000, 0x384A0000, 0x384A4000, 0x384A8000, 0x384AC000, 0x384B0000, 0x384B4000, 0x384B8000, 0x384BC000, 0x384C0000, 0x384C4000, 0x384C8000, 0x384CC000, 0x384D0000, 0x384D4000, 0x384D8000, 0x384DC000, 0x384E0000, 0x384E4000, 0x384E8000, 0x384EC000, 0x384F0000, 0x384F4000, 0x384F8000, 0x384FC000, 0x38500000, 0x38504000, 0x38508000, 0x3850C000, 0x38510000, 0x38514000, 0x38518000, 0x3851C000, 0x38520000, 0x38524000, 0x38528000, 0x3852C000, 0x38530000, 0x38534000, 0x38538000, 0x3853C000, 0x38540000, 0x38544000, 0x38548000, 0x3854C000, 0x38550000, 0x38554000, 0x38558000, 0x3855C000, 0x38560000, 0x38564000, 0x38568000, 0x3856C000, 0x38570000, 0x38574000, 0x38578000, 0x3857C000, 0x38580000, 0x38584000, 0x38588000, 0x3858C000, 0x38590000, 0x38594000, 0x38598000, 0x3859C000, 0x385A0000, 0x385A4000, 0x385A8000, 0x385AC000, 0x385B0000, 0x385B4000, 0x385B8000, 0x385BC000, 0x385C0000, 0x385C4000, 0x385C8000, 0x385CC000, 0x385D0000, 0x385D4000, 0x385D8000, 0x385DC000, 0x385E0000, 0x385E4000, 0x385E8000, 0x385EC000, 0x385F0000, 0x385F4000, 0x385F8000, 0x385FC000, 0x38600000, 0x38604000, 0x38608000, 0x3860C000, 0x38610000, 0x38614000, 0x38618000, 0x3861C000, 0x38620000, 0x38624000, 0x38628000, 0x3862C000, 0x38630000, 0x38634000, 0x38638000, 0x3863C000, 0x38640000, 0x38644000, 0x38648000, 0x3864C000, 0x38650000, 0x38654000, 0x38658000, 0x3865C000, 0x38660000, 0x38664000, 0x38668000, 0x3866C000, 0x38670000, 
0x38674000, 0x38678000, 0x3867C000, 0x38680000, 0x38684000, 0x38688000, 0x3868C000, 0x38690000, 0x38694000, 0x38698000, 0x3869C000, 0x386A0000, 0x386A4000, 0x386A8000, 0x386AC000, 0x386B0000, 0x386B4000, 0x386B8000, 0x386BC000, 0x386C0000, 0x386C4000, 0x386C8000, 0x386CC000, 0x386D0000, 0x386D4000, 0x386D8000, 0x386DC000, 0x386E0000, 0x386E4000, 0x386E8000, 0x386EC000, 0x386F0000, 0x386F4000, 0x386F8000, 0x386FC000, 0x38700000, 0x38704000, 0x38708000, 0x3870C000, 0x38710000, 0x38714000, 0x38718000, 0x3871C000, 0x38720000, 0x38724000, 0x38728000, 0x3872C000, 0x38730000, 0x38734000, 0x38738000, 0x3873C000, 0x38740000, 0x38744000, 0x38748000, 0x3874C000, 0x38750000, 0x38754000, 0x38758000, 0x3875C000, 0x38760000, 0x38764000, 0x38768000, 0x3876C000, 0x38770000, 0x38774000, 0x38778000, 0x3877C000, 0x38780000, 0x38784000, 0x38788000, 0x3878C000, 0x38790000, 0x38794000, 0x38798000, 0x3879C000, 0x387A0000, 0x387A4000, 0x387A8000, 0x387AC000, 0x387B0000, 0x387B4000, 0x387B8000, 0x387BC000, 0x387C0000, 0x387C4000, 0x387C8000, 0x387CC000, 0x387D0000, 0x387D4000, 0x387D8000, 0x387DC000, 0x387E0000, 0x387E4000, 0x387E8000, 0x387EC000, 0x387F0000, 0x387F4000, 0x387F8000, 0x387FC000, 0x38000000, 0x38002000, 0x38004000, 0x38006000, 0x38008000, 0x3800A000, 0x3800C000, 0x3800E000, 0x38010000, 0x38012000, 0x38014000, 0x38016000, 0x38018000, 0x3801A000, 0x3801C000, 0x3801E000, 0x38020000, 0x38022000, 0x38024000, 0x38026000, 0x38028000, 0x3802A000, 0x3802C000, 0x3802E000, 0x38030000, 0x38032000, 0x38034000, 0x38036000, 0x38038000, 0x3803A000, 0x3803C000, 0x3803E000, 0x38040000, 0x38042000, 0x38044000, 0x38046000, 0x38048000, 0x3804A000, 0x3804C000, 0x3804E000, 0x38050000, 0x38052000, 0x38054000, 0x38056000, 0x38058000, 0x3805A000, 0x3805C000, 0x3805E000, 0x38060000, 0x38062000, 0x38064000, 0x38066000, 0x38068000, 0x3806A000, 0x3806C000, 0x3806E000, 0x38070000, 0x38072000, 0x38074000, 0x38076000, 0x38078000, 0x3807A000, 0x3807C000, 0x3807E000, 0x38080000, 0x38082000, 0x38084000, 
0x38086000, 0x38088000, 0x3808A000, 0x3808C000, 0x3808E000, 0x38090000, 0x38092000, 0x38094000, 0x38096000, 0x38098000, 0x3809A000, 0x3809C000, 0x3809E000, 0x380A0000, 0x380A2000, 0x380A4000, 0x380A6000, 0x380A8000, 0x380AA000, 0x380AC000, 0x380AE000, 0x380B0000, 0x380B2000, 0x380B4000, 0x380B6000, 0x380B8000, 0x380BA000, 0x380BC000, 0x380BE000, 0x380C0000, 0x380C2000, 0x380C4000, 0x380C6000, 0x380C8000, 0x380CA000, 0x380CC000, 0x380CE000, 0x380D0000, 0x380D2000, 0x380D4000, 0x380D6000, 0x380D8000, 0x380DA000, 0x380DC000, 0x380DE000, 0x380E0000, 0x380E2000, 0x380E4000, 0x380E6000, 0x380E8000, 0x380EA000, 0x380EC000, 0x380EE000, 0x380F0000, 0x380F2000, 0x380F4000, 0x380F6000, 0x380F8000, 0x380FA000, 0x380FC000, 0x380FE000, 0x38100000, 0x38102000, 0x38104000, 0x38106000, 0x38108000, 0x3810A000, 0x3810C000, 0x3810E000, 0x38110000, 0x38112000, 0x38114000, 0x38116000, 0x38118000, 0x3811A000, 0x3811C000, 0x3811E000, 0x38120000, 0x38122000, 0x38124000, 0x38126000, 0x38128000, 0x3812A000, 0x3812C000, 0x3812E000, 0x38130000, 0x38132000, 0x38134000, 0x38136000, 0x38138000, 0x3813A000, 0x3813C000, 0x3813E000, 0x38140000, 0x38142000, 0x38144000, 0x38146000, 0x38148000, 0x3814A000, 0x3814C000, 0x3814E000, 0x38150000, 0x38152000, 0x38154000, 0x38156000, 0x38158000, 0x3815A000, 0x3815C000, 0x3815E000, 0x38160000, 0x38162000, 0x38164000, 0x38166000, 0x38168000, 0x3816A000, 0x3816C000, 0x3816E000, 0x38170000, 0x38172000, 0x38174000, 0x38176000, 0x38178000, 0x3817A000, 0x3817C000, 0x3817E000, 0x38180000, 0x38182000, 0x38184000, 0x38186000, 0x38188000, 0x3818A000, 0x3818C000, 0x3818E000, 0x38190000, 0x38192000, 0x38194000, 0x38196000, 0x38198000, 0x3819A000, 0x3819C000, 0x3819E000, 0x381A0000, 0x381A2000, 0x381A4000, 0x381A6000, 0x381A8000, 0x381AA000, 0x381AC000, 0x381AE000, 0x381B0000, 0x381B2000, 0x381B4000, 0x381B6000, 0x381B8000, 0x381BA000, 0x381BC000, 0x381BE000, 0x381C0000, 0x381C2000, 0x381C4000, 0x381C6000, 0x381C8000, 0x381CA000, 0x381CC000, 0x381CE000, 0x381D0000, 
0x381D2000, 0x381D4000, 0x381D6000, 0x381D8000, 0x381DA000, 0x381DC000, 0x381DE000, 0x381E0000, 0x381E2000, 0x381E4000, 0x381E6000, 0x381E8000, 0x381EA000, 0x381EC000, 0x381EE000, 0x381F0000, 0x381F2000, 0x381F4000, 0x381F6000, 0x381F8000, 0x381FA000, 0x381FC000, 0x381FE000, 0x38200000, 0x38202000, 0x38204000, 0x38206000, 0x38208000, 0x3820A000, 0x3820C000, 0x3820E000, 0x38210000, 0x38212000, 0x38214000, 0x38216000, 0x38218000, 0x3821A000, 0x3821C000, 0x3821E000, 0x38220000, 0x38222000, 0x38224000, 0x38226000, 0x38228000, 0x3822A000, 0x3822C000, 0x3822E000, 0x38230000, 0x38232000, 0x38234000, 0x38236000, 0x38238000, 0x3823A000, 0x3823C000, 0x3823E000, 0x38240000, 0x38242000, 0x38244000, 0x38246000, 0x38248000, 0x3824A000, 0x3824C000, 0x3824E000, 0x38250000, 0x38252000, 0x38254000, 0x38256000, 0x38258000, 0x3825A000, 0x3825C000, 0x3825E000, 0x38260000, 0x38262000, 0x38264000, 0x38266000, 0x38268000, 0x3826A000, 0x3826C000, 0x3826E000, 0x38270000, 0x38272000, 0x38274000, 0x38276000, 0x38278000, 0x3827A000, 0x3827C000, 0x3827E000, 0x38280000, 0x38282000, 0x38284000, 0x38286000, 0x38288000, 0x3828A000, 0x3828C000, 0x3828E000, 0x38290000, 0x38292000, 0x38294000, 0x38296000, 0x38298000, 0x3829A000, 0x3829C000, 0x3829E000, 0x382A0000, 0x382A2000, 0x382A4000, 0x382A6000, 0x382A8000, 0x382AA000, 0x382AC000, 0x382AE000, 0x382B0000, 0x382B2000, 0x382B4000, 0x382B6000, 0x382B8000, 0x382BA000, 0x382BC000, 0x382BE000, 0x382C0000, 0x382C2000, 0x382C4000, 0x382C6000, 0x382C8000, 0x382CA000, 0x382CC000, 0x382CE000, 0x382D0000, 0x382D2000, 0x382D4000, 0x382D6000, 0x382D8000, 0x382DA000, 0x382DC000, 0x382DE000, 0x382E0000, 0x382E2000, 0x382E4000, 0x382E6000, 0x382E8000, 0x382EA000, 0x382EC000, 0x382EE000, 0x382F0000, 0x382F2000, 0x382F4000, 0x382F6000, 0x382F8000, 0x382FA000, 0x382FC000, 0x382FE000, 0x38300000, 0x38302000, 0x38304000, 0x38306000, 0x38308000, 0x3830A000, 0x3830C000, 0x3830E000, 0x38310000, 0x38312000, 0x38314000, 0x38316000, 0x38318000, 0x3831A000, 0x3831C000, 
0x3831E000, 0x38320000, 0x38322000, 0x38324000, 0x38326000, 0x38328000, 0x3832A000, 0x3832C000, 0x3832E000, 0x38330000, 0x38332000, 0x38334000, 0x38336000, 0x38338000, 0x3833A000, 0x3833C000, 0x3833E000, 0x38340000, 0x38342000, 0x38344000, 0x38346000, 0x38348000, 0x3834A000, 0x3834C000, 0x3834E000, 0x38350000, 0x38352000, 0x38354000, 0x38356000, 0x38358000, 0x3835A000, 0x3835C000, 0x3835E000, 0x38360000, 0x38362000, 0x38364000, 0x38366000, 0x38368000, 0x3836A000, 0x3836C000, 0x3836E000, 0x38370000, 0x38372000, 0x38374000, 0x38376000, 0x38378000, 0x3837A000, 0x3837C000, 0x3837E000, 0x38380000, 0x38382000, 0x38384000, 0x38386000, 0x38388000, 0x3838A000, 0x3838C000, 0x3838E000, 0x38390000, 0x38392000, 0x38394000, 0x38396000, 0x38398000, 0x3839A000, 0x3839C000, 0x3839E000, 0x383A0000, 0x383A2000, 0x383A4000, 0x383A6000, 0x383A8000, 0x383AA000, 0x383AC000, 0x383AE000, 0x383B0000, 0x383B2000, 0x383B4000, 0x383B6000, 0x383B8000, 0x383BA000, 0x383BC000, 0x383BE000, 0x383C0000, 0x383C2000, 0x383C4000, 0x383C6000, 0x383C8000, 0x383CA000, 0x383CC000, 0x383CE000, 0x383D0000, 0x383D2000, 0x383D4000, 0x383D6000, 0x383D8000, 0x383DA000, 0x383DC000, 0x383DE000, 0x383E0000, 0x383E2000, 0x383E4000, 0x383E6000, 0x383E8000, 0x383EA000, 0x383EC000, 0x383EE000, 0x383F0000, 0x383F2000, 0x383F4000, 0x383F6000, 0x383F8000, 0x383FA000, 0x383FC000, 0x383FE000, 0x38400000, 0x38402000, 0x38404000, 0x38406000, 0x38408000, 0x3840A000, 0x3840C000, 0x3840E000, 0x38410000, 0x38412000, 0x38414000, 0x38416000, 0x38418000, 0x3841A000, 0x3841C000, 0x3841E000, 0x38420000, 0x38422000, 0x38424000, 0x38426000, 0x38428000, 0x3842A000, 0x3842C000, 0x3842E000, 0x38430000, 0x38432000, 0x38434000, 0x38436000, 0x38438000, 0x3843A000, 0x3843C000, 0x3843E000, 0x38440000, 0x38442000, 0x38444000, 0x38446000, 0x38448000, 0x3844A000, 0x3844C000, 0x3844E000, 0x38450000, 0x38452000, 0x38454000, 0x38456000, 0x38458000, 0x3845A000, 0x3845C000, 0x3845E000, 0x38460000, 0x38462000, 0x38464000, 0x38466000, 0x38468000, 
0x3846A000, 0x3846C000, 0x3846E000, 0x38470000, 0x38472000, 0x38474000, 0x38476000, 0x38478000, 0x3847A000, 0x3847C000, 0x3847E000, 0x38480000, 0x38482000, 0x38484000, 0x38486000, 0x38488000, 0x3848A000, 0x3848C000, 0x3848E000, 0x38490000, 0x38492000, 0x38494000, 0x38496000, 0x38498000, 0x3849A000, 0x3849C000, 0x3849E000, 0x384A0000, 0x384A2000, 0x384A4000, 0x384A6000, 0x384A8000, 0x384AA000, 0x384AC000, 0x384AE000, 0x384B0000, 0x384B2000, 0x384B4000, 0x384B6000, 0x384B8000, 0x384BA000, 0x384BC000, 0x384BE000, 0x384C0000, 0x384C2000, 0x384C4000, 0x384C6000, 0x384C8000, 0x384CA000, 0x384CC000, 0x384CE000, 0x384D0000, 0x384D2000, 0x384D4000, 0x384D6000, 0x384D8000, 0x384DA000, 0x384DC000, 0x384DE000, 0x384E0000, 0x384E2000, 0x384E4000, 0x384E6000, 0x384E8000, 0x384EA000, 0x384EC000, 0x384EE000, 0x384F0000, 0x384F2000, 0x384F4000, 0x384F6000, 0x384F8000, 0x384FA000, 0x384FC000, 0x384FE000, 0x38500000, 0x38502000, 0x38504000, 0x38506000, 0x38508000, 0x3850A000, 0x3850C000, 0x3850E000, 0x38510000, 0x38512000, 0x38514000, 0x38516000, 0x38518000, 0x3851A000, 0x3851C000, 0x3851E000, 0x38520000, 0x38522000, 0x38524000, 0x38526000, 0x38528000, 0x3852A000, 0x3852C000, 0x3852E000, 0x38530000, 0x38532000, 0x38534000, 0x38536000, 0x38538000, 0x3853A000, 0x3853C000, 0x3853E000, 0x38540000, 0x38542000, 0x38544000, 0x38546000, 0x38548000, 0x3854A000, 0x3854C000, 0x3854E000, 0x38550000, 0x38552000, 0x38554000, 0x38556000, 0x38558000, 0x3855A000, 0x3855C000, 0x3855E000, 0x38560000, 0x38562000, 0x38564000, 0x38566000, 0x38568000, 0x3856A000, 0x3856C000, 0x3856E000, 0x38570000, 0x38572000, 0x38574000, 0x38576000, 0x38578000, 0x3857A000, 0x3857C000, 0x3857E000, 0x38580000, 0x38582000, 0x38584000, 0x38586000, 0x38588000, 0x3858A000, 0x3858C000, 0x3858E000, 0x38590000, 0x38592000, 0x38594000, 0x38596000, 0x38598000, 0x3859A000, 0x3859C000, 0x3859E000, 0x385A0000, 0x385A2000, 0x385A4000, 0x385A6000, 0x385A8000, 0x385AA000, 0x385AC000, 0x385AE000, 0x385B0000, 0x385B2000, 0x385B4000, 
0x385B6000, 0x385B8000, 0x385BA000, 0x385BC000, 0x385BE000, 0x385C0000, 0x385C2000, 0x385C4000, 0x385C6000, 0x385C8000, 0x385CA000, 0x385CC000, 0x385CE000, 0x385D0000, 0x385D2000, 0x385D4000, 0x385D6000, 0x385D8000, 0x385DA000, 0x385DC000, 0x385DE000, 0x385E0000, 0x385E2000, 0x385E4000, 0x385E6000, 0x385E8000, 0x385EA000, 0x385EC000, 0x385EE000, 0x385F0000, 0x385F2000, 0x385F4000, 0x385F6000, 0x385F8000, 0x385FA000, 0x385FC000, 0x385FE000, 0x38600000, 0x38602000, 0x38604000, 0x38606000, 0x38608000, 0x3860A000, 0x3860C000, 0x3860E000, 0x38610000, 0x38612000, 0x38614000, 0x38616000, 0x38618000, 0x3861A000, 0x3861C000, 0x3861E000, 0x38620000, 0x38622000, 0x38624000, 0x38626000, 0x38628000, 0x3862A000, 0x3862C000, 0x3862E000, 0x38630000, 0x38632000, 0x38634000, 0x38636000, 0x38638000, 0x3863A000, 0x3863C000, 0x3863E000, 0x38640000, 0x38642000, 0x38644000, 0x38646000, 0x38648000, 0x3864A000, 0x3864C000, 0x3864E000, 0x38650000, 0x38652000, 0x38654000, 0x38656000, 0x38658000, 0x3865A000, 0x3865C000, 0x3865E000, 0x38660000, 0x38662000, 0x38664000, 0x38666000, 0x38668000, 0x3866A000, 0x3866C000, 0x3866E000, 0x38670000, 0x38672000, 0x38674000, 0x38676000, 0x38678000, 0x3867A000, 0x3867C000, 0x3867E000, 0x38680000, 0x38682000, 0x38684000, 0x38686000, 0x38688000, 0x3868A000, 0x3868C000, 0x3868E000, 0x38690000, 0x38692000, 0x38694000, 0x38696000, 0x38698000, 0x3869A000, 0x3869C000, 0x3869E000, 0x386A0000, 0x386A2000, 0x386A4000, 0x386A6000, 0x386A8000, 0x386AA000, 0x386AC000, 0x386AE000, 0x386B0000, 0x386B2000, 0x386B4000, 0x386B6000, 0x386B8000, 0x386BA000, 0x386BC000, 0x386BE000, 0x386C0000, 0x386C2000, 0x386C4000, 0x386C6000, 0x386C8000, 0x386CA000, 0x386CC000, 0x386CE000, 0x386D0000, 0x386D2000, 0x386D4000, 0x386D6000, 0x386D8000, 0x386DA000, 0x386DC000, 0x386DE000, 0x386E0000, 0x386E2000, 0x386E4000, 0x386E6000, 0x386E8000, 0x386EA000, 0x386EC000, 0x386EE000, 0x386F0000, 0x386F2000, 0x386F4000, 0x386F6000, 0x386F8000, 0x386FA000, 0x386FC000, 0x386FE000, 0x38700000, 
0x38702000, 0x38704000, 0x38706000, 0x38708000, 0x3870A000, 0x3870C000, 0x3870E000, 0x38710000, 0x38712000, 0x38714000, 0x38716000, 0x38718000, 0x3871A000, 0x3871C000, 0x3871E000, 0x38720000, 0x38722000, 0x38724000, 0x38726000, 0x38728000, 0x3872A000, 0x3872C000, 0x3872E000, 0x38730000, 0x38732000, 0x38734000, 0x38736000, 0x38738000, 0x3873A000, 0x3873C000, 0x3873E000, 0x38740000, 0x38742000, 0x38744000, 0x38746000, 0x38748000, 0x3874A000, 0x3874C000, 0x3874E000, 0x38750000, 0x38752000, 0x38754000, 0x38756000, 0x38758000, 0x3875A000, 0x3875C000, 0x3875E000, 0x38760000, 0x38762000, 0x38764000, 0x38766000, 0x38768000, 0x3876A000, 0x3876C000, 0x3876E000, 0x38770000, 0x38772000, 0x38774000, 0x38776000, 0x38778000, 0x3877A000, 0x3877C000, 0x3877E000, 0x38780000, 0x38782000, 0x38784000, 0x38786000, 0x38788000, 0x3878A000, 0x3878C000, 0x3878E000, 0x38790000, 0x38792000, 0x38794000, 0x38796000, 0x38798000, 0x3879A000, 0x3879C000, 0x3879E000, 0x387A0000, 0x387A2000, 0x387A4000, 0x387A6000, 0x387A8000, 0x387AA000, 0x387AC000, 0x387AE000, 0x387B0000, 0x387B2000, 0x387B4000, 0x387B6000, 0x387B8000, 0x387BA000, 0x387BC000, 0x387BE000, 0x387C0000, 0x387C2000, 0x387C4000, 0x387C6000, 0x387C8000, 0x387CA000, 0x387CC000, 0x387CE000, 0x387D0000, 0x387D2000, 0x387D4000, 0x387D6000, 0x387D8000, 0x387DA000, 0x387DC000, 0x387DE000, 0x387E0000, 0x387E2000, 0x387E4000, 0x387E6000, 0x387E8000, 0x387EA000, 0x387EC000, 0x387EE000, 0x387F0000, 0x387F2000, 0x387F4000, 0x387F6000, 0x387F8000, 0x387FA000, 0x387FC000, 0x387FE000 }; __constant static const uint32_t exponent_table[64] = { 0x00000000, 0x00800000, 0x01000000, 0x01800000, 0x02000000, 0x02800000, 0x03000000, 0x03800000, 0x04000000, 0x04800000, 0x05000000, 0x05800000, 0x06000000, 0x06800000, 0x07000000, 0x07800000, 0x08000000, 0x08800000, 0x09000000, 0x09800000, 0x0A000000, 0x0A800000, 0x0B000000, 0x0B800000, 0x0C000000, 0x0C800000, 0x0D000000, 0x0D800000, 0x0E000000, 0x0E800000, 0x0F000000, 0x47800000, 0x80000000, 0x80800000, 
0x81000000, 0x81800000, 0x82000000, 0x82800000, 0x83000000, 0x83800000, 0x84000000, 0x84800000, 0x85000000, 0x85800000, 0x86000000, 0x86800000, 0x87000000, 0x87800000, 0x88000000, 0x88800000, 0x89000000, 0x89800000, 0x8A000000, 0x8A800000, 0x8B000000, 0x8B800000, 0x8C000000, 0x8C800000, 0x8D000000, 0x8D800000, 0x8E000000, 0x8E800000, 0x8F000000, 0xC7800000 }; __constant static const unsigned short offset_table[64] = { 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024 }; SCALAR_FUN_ATTR uint16_t float2halfbits(float value) { union { float x; uint32_t y; } u; u.x = value; uint32_t bits = u.y; uint16_t hbits = base_table[bits>>23] + (uint16_t)((bits&0x7FFFFF)>>shift_table[bits>>23]);; return hbits; } SCALAR_FUN_ATTR float halfbits2float(uint16_t value) { uint32_t bits = mantissa_table[offset_table[value>>10]+(value&0x3FF)] + exponent_table[value>>10]; union { uint32_t x; float y; } u; u.x = bits; return u.y; } SCALAR_FUN_ATTR uint16_t halfbitsnextafter(uint16_t from, uint16_t to) { int fabs = from & 0x7FFF, tabs = to & 0x7FFF; if(fabs > 0x7C00 || tabs > 0x7C00) { return ((from&0x7FFF)>0x7C00) ? (from|0x200) : (to|0x200); } if(from == to || !(fabs|tabs)) { return to; } if(!fabs) { return (to&0x8000)+1; } unsigned int out = from + (((from>>15)^(unsigned int)((from^(0x8000|(0x8000-(from>>15))))<(to^(0x8000|(0x8000-(to>>15))))))<<1) - 1; return out; } // End of half.h. futhark-0.25.27/rts/c/ispc_util.h000066400000000000000000000423261475065116200165470ustar00rootroot00000000000000// Start of ispc_util.h. // This header file implements various operations that are useful only when // generating ISPC code. 
// This includes wrappers for parts of Futhark's C runtime.

// Report the gang width (number of program instances) to the caller.
export uniform int64_t get_gang_size() {
  uniform int64_t lanes = programCount;
  return lanes;
}

// The stdlib extract() has no overloads for pointers. make_extract(ty)
// adds one that pulls the pointer held by lane `idx` out of a varying
// pointer to uniform `ty`, by round-tripping the address through int64.
#define make_extract(ty)                                                                \
static inline uniform ty * uniform extract(uniform ty * varying ptr, uniform int idx) { \
  uniform int64 lane_addr = extract((int64)ptr, idx);                                   \
  return (uniform ty * uniform)lane_addr;                                               \
}

// Scalar element types.
make_extract(int8)
make_extract(int16)
make_extract(int32)
make_extract(int64)
make_extract(uint8)
make_extract(uint16)
make_extract(uint32)
make_extract(uint64)
make_extract(float16)
make_extract(float)
make_extract(double)

/* Disabled instantiations for pointer element types:
   make_extract(int8* uniform)    make_extract(int16* uniform)
   make_extract(int32* uniform)   make_extract(int64* uniform)
   make_extract(uint8* uniform)   make_extract(uint16* uniform)
   make_extract(uint32* uniform)  make_extract(uint64* uniform)
   make_extract(float16* uniform) make_extract(float* uniform)
   make_extract(double* uniform)
*/

// Runtime struct element types.
make_extract(struct futhark_context)
make_extract(struct memblock)

// Handling of atomics
// Atomic CAS acts differently in GCC and ISPC, so we emulate it.
// cas_bits_equal(a, b): decides whether a compare-and-exchange succeeded.
// The GCC builtin being emulated compares raw bit patterns, so the wrapper
// must too. With floating-point `==`:
//  * a NaN cell never "equals" the expected NaN, so a successful exchange
//    would be reported as a failure (and a CAS retry loop would spin forever);
//  * -0.0 == +0.0, so a failed exchange would be reported as a success.
// For integer types bitwise equality is plain `==`.
#define make_cas_bits_equal_int(ty)                                      \
static inline uniform bool cas_bits_equal(uniform ty a, uniform ty b) {  \
  return a == b;                                                         \
}                                                                        \
static inline varying bool cas_bits_equal(varying ty a, varying ty b) {  \
  return a == b;                                                         \
}

make_cas_bits_equal_int(int32)
make_cas_bits_equal_int(int64)
make_cas_bits_equal_int(uint32)
make_cas_bits_equal_int(uint64)

static inline uniform bool cas_bits_equal(uniform float a, uniform float b) {
  return intbits(a) == intbits(b);
}
static inline varying bool cas_bits_equal(varying float a, varying float b) {
  return intbits(a) == intbits(b);
}
static inline uniform bool cas_bits_equal(uniform double a, uniform double b) {
  return intbits(a) == intbits(b);
}
static inline varying bool cas_bits_equal(varying double a, varying double b) {
  return intbits(a) == intbits(b);
}

// atomic_compare_exchange_wrapper(mem, old, val): GCC-style CAS.
// Atomically replaces *mem with val if *mem currently holds *old. Returns
// true on success. On failure, stores the value actually observed in *old
// and returns false. Overloads cover:
//   1. uniform pointer, uniform value;
//   2. varying pointer to uniform cells (one cell per lane);
//   3. uniform pointer to a varying cell, handled lane by lane;
//   4. uniform pointer with a varying value (the last active lane's value is
//      used).
#define make_atomic_compare_exchange_wrapper(ty)                                     \
static inline uniform bool atomic_compare_exchange_wrapper(uniform ty * uniform mem, \
                                                           uniform ty * uniform old, \
                                                           const uniform ty val){    \
  uniform ty actual = atomic_compare_exchange_global(mem, *old, val);                \
  if (cas_bits_equal(actual, *old)){                                                 \
    return 1;                                                                        \
  }                                                                                  \
  *old = actual;                                                                     \
  return 0;                                                                          \
}                                                                                    \
static inline varying bool atomic_compare_exchange_wrapper(uniform ty * varying mem, \
                                                           varying ty * uniform old, \
                                                           const varying ty val){    \
  varying ty actual = atomic_compare_exchange_global(mem, *old, val);                \
  bool res = 0;                                                                      \
  if(cas_bits_equal(actual, *old)){                                                  \
    res = 1;                                                                         \
  } else {                                                                           \
    *old = actual;                                                                   \
  }                                                                                  \
  return res;                                                                        \
}                                                                                    \
static inline varying bool atomic_compare_exchange_wrapper(varying ty * uniform mem, \
                                                           varying ty * uniform old, \
                                                           const varying ty val){    \
  uniform ty * uniform base_mem = (uniform ty * uniform)mem;                         \
  uniform ty * uniform base_old = (uniform ty * uniform)old;                         \
  bool res = 0;                                                                      \
  foreach_active (i) {                                                               \
    uniform ty * uniform curr_mem = base_mem + i;                                    \
    uniform ty * uniform curr_old = base_old + i;                                    \
    uniform ty curr_val = extract(val, i);                                           \
    uniform bool curr = atomic_compare_exchange_wrapper(                             \
                            curr_mem, curr_old, curr_val);                           \
    res = insert(res, i, curr);                                                      \
  }                                                                                  \
  return res;                                                                        \
}                                                                                    \
static inline uniform bool atomic_compare_exchange_wrapper(uniform ty * uniform mem, \
                                                           uniform ty * uniform old, \
                                                           const varying ty val){    \
  uniform ty v = 0;                                                                  \
  foreach_active (i) v = extract(val, i);                                            \
  return atomic_compare_exchange_wrapper(mem, old, v);                               \
}

make_atomic_compare_exchange_wrapper(int32)
make_atomic_compare_exchange_wrapper(int64)
make_atomic_compare_exchange_wrapper(uint32)
make_atomic_compare_exchange_wrapper(uint64)
make_atomic_compare_exchange_wrapper(float)
make_atomic_compare_exchange_wrapper(double)

// This code generates missing overloads for atomic operations on uniform
// pointers to varying values.
// make_single_atomic(name, ty): overload of ISPC's atomic_<name>_global for
// a uniform pointer to a varying `ty`. Each active lane atomically updates
// its own element (base address + lane index). The values previously held
// are gathered back into a varying result.
#define make_single_atomic(name, ty)                                            \
static inline ty atomic_##name##_global(varying ty * uniform mem, ty val) {     \
  uniform ty * uniform lanes = (uniform ty * uniform)mem;                       \
  ty prev = 0;                                                                  \
  foreach_active (lane) {                                                       \
    uniform ty lane_prev = atomic_##name##_global(lanes + lane,                 \
                                                  extract(val, lane));          \
    prev = insert(prev, lane, lane_prev);                                       \
  }                                                                             \
  return prev;                                                                  \
}

// Instantiate one atomic operation for every supported integer type.
#define make_all_atomic(name)         \
  make_single_atomic(name, int32)     \
  make_single_atomic(name, int64)     \
  make_single_atomic(name, uint32)    \
  make_single_atomic(name, uint64)

make_all_atomic(add)
make_all_atomic(subtract)
make_all_atomic(and)
make_all_atomic(or)
make_all_atomic(xor)
make_all_atomic(swap)

// make_varying(x) forces a value to be varying. Literals have unbound
// variability and could otherwise steer overload resolution for the atomic
// operations to the wrong candidate. Each type gets an identity overload for
// both uniform and varying arguments.
#define make_varying_pair(ty)                                              \
static inline varying ty make_varying(uniform ty x) { return x; }          \
static inline varying ty make_varying(varying ty x) { return x; }

make_varying_pair(int32)
make_varying_pair(int64)
make_varying_pair(uint32)
make_varying_pair(uint64)

// Redirect atomic operations to the relevant ISPC overloads.
// Map the GCC-style __atomic builtins used by generated code onto ISPC's
// atomic_*_global functions.  Only the pointer and operand are forwarded;
// the remaining macro arguments (memory orders, and the weak flag for
// compare-exchange) are dropped.  make_varying forces the operand to be
// varying so that literals select the intended overload.
#define __atomic_fetch_add(x,y,z) atomic_add_global(x,make_varying(y))
// ISPC's standard library spells this atomic_subtract_global (there is no
// atomic_sub_global); this matches make_all_atomic(subtract) above.
#define __atomic_fetch_sub(x,y,z) atomic_subtract_global(x,make_varying(y))
#define __atomic_fetch_and(x,y,z) atomic_and_global(x,make_varying(y))
#define __atomic_fetch_or(x,y,z) atomic_or_global(x,make_varying(y))
#define __atomic_fetch_xor(x,y,z) atomic_xor_global(x,make_varying(y))
#define __atomic_exchange_n(x,y,z) atomic_swap_global(x,make_varying(y))
#define __atomic_compare_exchange_n(x,y,z,h,j,k) atomic_compare_exchange_wrapper(x,y,z)

// Memory allocation handling

// Reference-counted memory block: `references` is a pointer to the
// count, `mem` the data, `size` its size and `desc` a description.
// NOTE(review): field meanings are inferred from names and from how the
// memblock_* functions are called — confirm against the C definition.
struct memblock {
  int32_t * references;
  uint8_t * mem;
  int64_t size;
  const int8_t * desc;
};

// Release memory with ISPC's delete (one overload per pointer variability).
static inline void free(void* ptr) {
  delete ptr;
}

static inline void free(void* uniform ptr) {
  delete ptr;
}

extern "C" unmasked uniform unsigned char * uniform realloc(uniform unsigned char * uniform ptr, uniform int64_t new_size);

extern "C" unmasked uniform char * uniform lexical_realloc_error(uniform struct futhark_context * uniform ctx, uniform int64_t new_size);

// lexical_realloc: reallocate *ptr to new_size bytes with the C realloc.
// On failure it reports through lexical_realloc_error and returns
// FUTHARK_OUT_OF_MEMORY, leaving *ptr unchanged.  On success it stores the
// new pointer in *ptr and new_size in *old_size, then returns
// FUTHARK_SUCCESS.  The overloads below adapt this to the different
// uniform/varying combinations of the pointer, size slot and requested size.

// Fully uniform case: one allocation.
static inline uniform int lexical_realloc(uniform struct futhark_context * uniform ctx,
                                          unsigned char uniform * uniform * uniform ptr,
                                          int64_t uniform * uniform old_size,
                                          uniform int64_t new_size) {
  uniform unsigned char * uniform memptr = realloc(*ptr, new_size);
  if (memptr == NULL) {
    lexical_realloc_error(ctx, new_size);
    return FUTHARK_OUT_OF_MEMORY;
  } else {
    *ptr = memptr;
    *old_size = new_size;
    return FUTHARK_SUCCESS;
  }
}

// Single shared buffer but a per-lane requested size: allocate the
// largest size requested by any active lane.
static inline uniform int lexical_realloc(uniform struct futhark_context *ctx,
                                          unsigned char uniform * uniform * uniform ptr,
                                          int64_t uniform * uniform old_size,
                                          varying int64_t new_size) {
  return lexical_realloc(ctx, ptr, old_size, reduce_max(new_size));
}

// One buffer per lane (varying pointer), reached through a varying pointer
// to uniform sizes: reallocate each active lane's buffer separately.
// Any lane failure makes the whole call return FUTHARK_OUT_OF_MEMORY.
static inline uniform int lexical_realloc(uniform struct futhark_context * uniform ctx,
                                          unsigned char uniform * varying * uniform ptr,
                                          int64_t uniform * varying old_size,
                                          varying int64_t new_size) {
  uniform int err = FUTHARK_SUCCESS;
  foreach_active(i){
    uniform unsigned char * uniform memptr = realloc(extract(*ptr,i), extract(new_size,i));
    if (memptr == NULL) {
      lexical_realloc_error(ctx, extract(new_size,i));
      err = FUTHARK_OUT_OF_MEMORY;
    } else {
      // Replace only lane i of the varying pointer.
      *ptr = (uniform unsigned char * varying)insert((int64_t)*ptr, i, (uniform int64_t) memptr);
      *old_size = new_size;
    }
  }
  return err;
}

// As above, but the sizes live in a varying int64_t behind a uniform pointer.
static inline uniform int lexical_realloc(uniform struct futhark_context * uniform ctx,
                                          unsigned char uniform * varying * uniform ptr,
                                          int64_t varying * uniform old_size,
                                          varying int64_t new_size) {
  uniform int err = FUTHARK_SUCCESS;
  foreach_active(i){
    uniform unsigned char * uniform memptr = realloc(extract(*ptr,i), extract(new_size,i));
    if (memptr == NULL) {
      lexical_realloc_error(ctx, extract(new_size,i));
      err = FUTHARK_OUT_OF_MEMORY;
    } else {
      *ptr = (uniform unsigned char * varying)insert((int64_t)*ptr, i, (uniform int64_t) memptr);
      *old_size = new_size;
    }
  }
  return err;
}

// size_t-sized slot: reinterpret it as varying int64_t and delegate.
static inline uniform int lexical_realloc(uniform struct futhark_context * uniform ctx,
                                          unsigned char uniform * varying * uniform ptr,
                                          size_t varying * uniform old_size,
                                          varying int64_t new_size) {
  return lexical_realloc(ctx, ptr, (varying int64_t * uniform)old_size, new_size);
}

// Uniform pointer to a varying buffer: a single allocation of
// new_size*programCount bytes holds every lane's data, while *old_size
// records the per-lane size new_size.
static inline uniform int lexical_realloc(uniform struct futhark_context * uniform ctx,
                                          unsigned char varying * uniform * uniform ptr,
                                          size_t varying * uniform old_size,
                                          uniform int64_t new_size) {
  uniform int err = FUTHARK_SUCCESS;
  uniform unsigned char * uniform memptr = realloc((uniform unsigned char * uniform )*ptr,
                                                   new_size*programCount);
  if (memptr == NULL) {
    lexical_realloc_error(ctx, new_size);
    err = FUTHARK_OUT_OF_MEMORY;
  } else {
    *ptr = (varying unsigned char * uniform)memptr;
    *old_size = new_size;
  }
  return err;
}

// As above, with a per-lane requested size: use the largest one.
static inline uniform int lexical_realloc(uniform struct futhark_context * uniform ctx,
                                          unsigned char varying * uniform * uniform ptr,
                                          size_t varying * uniform old_size,
                                          varying int64_t new_size) {
  return lexical_realloc(ctx, ptr, old_size, reduce_max(new_size));
}

// The varying overloads of memblock_unref/alloc/set below call the uniform
// C implementation once per active lane and OR the error codes together.
// For a varying struct memblock, the struct is first scattered into a
// per-lane array of uniform structs (aos[programIndex] = ...).  Each lane's
// element is then operated on, and the results are gathered back.

extern "C" unmasked uniform int memblock_unref(uniform struct futhark_context * uniform ctx,
                                               uniform struct memblock * uniform lhs,
                                               uniform const char * uniform lhs_desc);

static uniform int memblock_unref(uniform struct futhark_context * varying ctx,
                                  uniform struct memblock * varying lhs,
                                  uniform const char * uniform lhs_desc) {
  uniform int err = 0;
  foreach_active(i) {
    err |= memblock_unref(extract(ctx,i), extract(lhs,i), lhs_desc);
  }
  return err;
}

static uniform int memblock_unref(uniform struct futhark_context * uniform ctx,
                                  varying struct memblock * uniform lhs,
                                  uniform const char * uniform lhs_desc) {
  uniform int err = 0;
  varying struct memblock _lhs = *lhs;
  uniform struct memblock aos[programCount];
  aos[programIndex] = _lhs;
  foreach_active(i){
    err |= memblock_unref(ctx, &aos[i], lhs_desc);
  }
  *lhs = aos[programIndex];
  return err;
}

extern "C" unmasked uniform int memblock_alloc(uniform struct futhark_context * uniform ctx,
                                               uniform struct memblock * uniform block,
                                               uniform int64_t size,
                                               uniform const char * uniform block_desc);

static uniform int memblock_alloc(uniform struct futhark_context * varying ctx,
                                  uniform struct memblock * varying block,
                                  varying int64_t size,
                                  uniform const char * uniform block_desc) {
  uniform int err = 0;
  foreach_active(i){
    err |= memblock_alloc(extract(ctx,i), extract(block,i), extract(size, i), block_desc);
  }
  return err;
}

static uniform int memblock_alloc(uniform struct futhark_context * uniform ctx,
                                  varying struct memblock * uniform block,
                                  uniform int64_t size,
                                  uniform const char * uniform block_desc) {
  uniform int err = 0;
  varying struct memblock _block = *block;
  uniform struct memblock aos[programCount];
  aos[programIndex] = _block;
  foreach_active(i){
    err |= memblock_alloc(ctx, &aos[i], size, block_desc);
  }
  *block = aos[programIndex];
  return err;
}

static uniform int memblock_alloc(uniform struct futhark_context * uniform ctx,
                                  varying struct memblock * uniform block,
                                  varying int64_t size,
                                  uniform const char * uniform block_desc) {
  uniform int err = 0;
  varying struct memblock _block = *block;
  uniform struct memblock aos[programCount];
  aos[programIndex] = _block;
  foreach_active(i){
    err |= memblock_alloc(ctx, &aos[i], extract(size, i), block_desc);
  }
  *block = aos[programIndex];
  return err;
}

extern "C" unmasked uniform int memblock_set(uniform struct futhark_context * uniform ctx,
                                             uniform struct memblock * uniform lhs,
                                             uniform struct memblock * uniform rhs,
                                             uniform const char * uniform lhs_desc);

static uniform int memblock_set (uniform struct futhark_context * uniform ctx,
                                 varying struct memblock * uniform lhs,
                                 varying struct memblock * uniform rhs,
                                 uniform const char * uniform lhs_desc) {
  uniform int err = 0;
  varying struct memblock _lhs = *lhs;
  varying struct memblock _rhs = *rhs;
  uniform struct memblock aos1[programCount];
  aos1[programIndex] = _lhs;
  uniform struct memblock aos2[programCount];
  aos2[programIndex] = _rhs;
  foreach_active(i) {
    err |= memblock_set(ctx, &aos1[i], &aos2[i], lhs_desc);
  }
  *lhs = aos1[programIndex];
  *rhs = aos2[programIndex];
  return err;
}

static uniform int memblock_set (uniform struct futhark_context * uniform ctx,
                                 varying struct memblock * uniform lhs,
                                 uniform struct memblock * uniform rhs,
                                 uniform const char * uniform lhs_desc) {
  uniform int err = 0;
  varying struct memblock _lhs = *lhs;
  uniform struct memblock aos1[programCount];
  aos1[programIndex] = _lhs;
  foreach_active(i) {
    err |= memblock_set(ctx, &aos1[i], rhs, lhs_desc);
  }
  *lhs = aos1[programIndex];
  return err;
}

// End of ispc_util.h.
// futhark-0.25.27/rts/c/lock.h000066400000000000000000000025351475065116200155020ustar00000000000000
// Start of lock.h.

// A very simple cross-platform implementation of locks.  Uses
// pthreads on Unix and the Win32 mutex API on Windows.  Futhark's
// host-level code is not multithreaded, but user code may be, so we
// need some mechanism for ensuring atomic access to API functions.
// This is that mechanism.  It is not exposed to user code at all, so
// we do not have to worry about name collisions.
#ifdef _WIN32

typedef HANDLE lock_t;

// Create an unlocked, unnamed mutex with default security attributes.
static void create_lock(lock_t *lock) {
  *lock = CreateMutex(NULL,  // Default security attributes.
                      FALSE, // Initially unlocked.
                      NULL); // Unnamed.
  assert(*lock != NULL);
}

// The Win32 calls below are made outside assert().  When NDEBUG is
// defined, assert() expands to nothing, and an acquire/release placed
// inside it would silently disappear.  Only the result check may be
// compiled out.
static void lock_lock(lock_t *lock) {
  DWORD r = WaitForSingleObject(*lock, INFINITE);
  assert(r == WAIT_OBJECT_0);
  (void)r;
}

static void lock_unlock(lock_t *lock) {
  BOOL r = ReleaseMutex(*lock);
  assert(r);
  (void)r;
}

static void free_lock(lock_t *lock) {
  CloseHandle(*lock);
}

#else
// Assuming POSIX

#include <pthread.h>

typedef pthread_mutex_t lock_t;

// Each wrapper keeps the pthread call outside assert() and casts the
// result to void, so the call survives NDEBUG without an unused-variable
// warning.
static void create_lock(lock_t *lock) {
  int r = pthread_mutex_init(lock, NULL);
  assert(r == 0);
  (void)r;
}

static void lock_lock(lock_t *lock) {
  int r = pthread_mutex_lock(lock);
  assert(r == 0);
  (void)r;
}

static void lock_unlock(lock_t *lock) {
  int r = pthread_mutex_unlock(lock);
  assert(r == 0);
  (void)r;
}

// Release any resources held by a mutex initialised with create_lock.
// The lock must not be held when this is called.
static void free_lock(lock_t *lock) {
  (void)pthread_mutex_destroy(lock);
}

#endif

// End of lock.h.
// futhark-0.25.27/rts/c/scalar.h000066400000000000000000002001431475065116200160120ustar00000000000000
// Start of scalar.h.

// Implementation of the primitive scalar operations.  Very
// repetitive.  This code is inserted directly into both CUDA and
// OpenCL programs, as well as the CPU code, so it has some #ifdefs to
// work everywhere.  Some operations are defined as macros because
// this allows us to use them as constant expressions in things like
// array sizes and static initialisers.

// Some of the #ifdefs are because OpenCL uses type-generic functions
// for some operations (e.g. sqrt), while C and CUDA sensibly use
// distinct functions for different precisions (e.g. sqrtf() and
// sqrt()).  This is quite annoying.  Due to C's unfortunate casting
// rules, it is also really easy to accidentally implement
// floating-point functions in the wrong precision, so be careful.

// Double-precision definitions are only included if the preprocessor
// macro FUTHARK_F64_ENABLED is set.
SCALAR_FUN_ATTR int32_t futrts_to_bits32(float x); SCALAR_FUN_ATTR float futrts_from_bits32(int32_t x); SCALAR_FUN_ATTR uint8_t add8(uint8_t x, uint8_t y) { return x + y; } SCALAR_FUN_ATTR uint16_t add16(uint16_t x, uint16_t y) { return x + y; } SCALAR_FUN_ATTR uint32_t add32(uint32_t x, uint32_t y) { return x + y; } SCALAR_FUN_ATTR uint64_t add64(uint64_t x, uint64_t y) { return x + y; } SCALAR_FUN_ATTR uint8_t sub8(uint8_t x, uint8_t y) { return x - y; } SCALAR_FUN_ATTR uint16_t sub16(uint16_t x, uint16_t y) { return x - y; } SCALAR_FUN_ATTR uint32_t sub32(uint32_t x, uint32_t y) { return x - y; } SCALAR_FUN_ATTR uint64_t sub64(uint64_t x, uint64_t y) { return x - y; } SCALAR_FUN_ATTR uint8_t mul8(uint8_t x, uint8_t y) { return x * y; } SCALAR_FUN_ATTR uint16_t mul16(uint16_t x, uint16_t y) { return x * y; } SCALAR_FUN_ATTR uint32_t mul32(uint32_t x, uint32_t y) { return x * y; } SCALAR_FUN_ATTR uint64_t mul64(uint64_t x, uint64_t y) { return x * y; } #if defined(ISPC) SCALAR_FUN_ATTR uint8_t udiv8(uint8_t x, uint8_t y) { // This strange pattern is used to prevent the ISPC compiler from // causing SIGFPEs and bogus results on divisions where inactive lanes // have 0-valued divisors. It ensures that any inactive lane instead // has a divisor of 1. 
https://github.com/ispc/ispc/issues/2292 uint8_t ys = 1; foreach_active(i){ ys = y; } return x / ys; } SCALAR_FUN_ATTR uint16_t udiv16(uint16_t x, uint16_t y) { uint16_t ys = 1; foreach_active(i){ ys = y; } return x / ys; } SCALAR_FUN_ATTR uint32_t udiv32(uint32_t x, uint32_t y) { uint32_t ys = 1; foreach_active(i){ ys = y; } return x / ys; } SCALAR_FUN_ATTR uint64_t udiv64(uint64_t x, uint64_t y) { uint64_t ys = 1; foreach_active(i){ ys = y; } return x / ys; } SCALAR_FUN_ATTR uint8_t udiv_up8(uint8_t x, uint8_t y) { uint8_t ys = 1; foreach_active(i){ ys = y; } return (x + y - 1) / ys; } SCALAR_FUN_ATTR uint16_t udiv_up16(uint16_t x, uint16_t y) { uint16_t ys = 1; foreach_active(i){ ys = y; } return (x + y - 1) / ys; } SCALAR_FUN_ATTR uint32_t udiv_up32(uint32_t x, uint32_t y) { uint32_t ys = 1; foreach_active(i){ ys = y; } return (x + y - 1) / ys; } SCALAR_FUN_ATTR uint64_t udiv_up64(uint64_t x, uint64_t y) { uint64_t ys = 1; foreach_active(i){ ys = y; } return (x + y - 1) / ys; } SCALAR_FUN_ATTR uint8_t umod8(uint8_t x, uint8_t y) { uint8_t ys = 1; foreach_active(i){ ys = y; } return x % ys; } SCALAR_FUN_ATTR uint16_t umod16(uint16_t x, uint16_t y) { uint16_t ys = 1; foreach_active(i){ ys = y; } return x % ys; } SCALAR_FUN_ATTR uint32_t umod32(uint32_t x, uint32_t y) { uint32_t ys = 1; foreach_active(i){ ys = y; } return x % ys; } SCALAR_FUN_ATTR uint64_t umod64(uint64_t x, uint64_t y) { uint64_t ys = 1; foreach_active(i){ ys = y; } return x % ys; } SCALAR_FUN_ATTR uint8_t udiv_safe8(uint8_t x, uint8_t y) { uint8_t ys = 1; foreach_active(i){ ys = y; } return y == 0 ? 0 : x / ys; } SCALAR_FUN_ATTR uint16_t udiv_safe16(uint16_t x, uint16_t y) { uint16_t ys = 1; foreach_active(i){ ys = y; } return y == 0 ? 0 : x / ys; } SCALAR_FUN_ATTR uint32_t udiv_safe32(uint32_t x, uint32_t y) { uint32_t ys = 1; foreach_active(i){ ys = y; } return y == 0 ? 
0 : x / ys; } SCALAR_FUN_ATTR uint64_t udiv_safe64(uint64_t x, uint64_t y) { uint64_t ys = 1; foreach_active(i){ ys = y; } return y == 0 ? 0 : x / ys; } SCALAR_FUN_ATTR uint8_t udiv_up_safe8(uint8_t x, uint8_t y) { uint8_t ys = 1; foreach_active(i){ ys = y; } return y == 0 ? 0 : (x + y - 1) / ys; } SCALAR_FUN_ATTR uint16_t udiv_up_safe16(uint16_t x, uint16_t y) { uint16_t ys = 1; foreach_active(i){ ys = y; } return y == 0 ? 0 : (x + y - 1) / ys; } SCALAR_FUN_ATTR uint32_t udiv_up_safe32(uint32_t x, uint32_t y) { uint32_t ys = 1; foreach_active(i){ ys = y; } return y == 0 ? 0 : (x + y - 1) / ys; } SCALAR_FUN_ATTR uint64_t udiv_up_safe64(uint64_t x, uint64_t y) { uint64_t ys = 1; foreach_active(i){ ys = y; } return y == 0 ? 0 : (x + y - 1) / ys; } SCALAR_FUN_ATTR uint8_t umod_safe8(uint8_t x, uint8_t y) { uint8_t ys = 1; foreach_active(i){ ys = y; } return y == 0 ? 0 : x % ys; } SCALAR_FUN_ATTR uint16_t umod_safe16(uint16_t x, uint16_t y) { uint16_t ys = 1; foreach_active(i){ ys = y; } return y == 0 ? 0 : x % ys; } SCALAR_FUN_ATTR uint32_t umod_safe32(uint32_t x, uint32_t y) { uint32_t ys = 1; foreach_active(i){ ys = y; } return y == 0 ? 0 : x % ys; } SCALAR_FUN_ATTR uint64_t umod_safe64(uint64_t x, uint64_t y) { uint64_t ys = 1; foreach_active(i){ ys = y; } return y == 0 ? 0 : x % ys; } SCALAR_FUN_ATTR int8_t sdiv8(int8_t x, int8_t y) { int8_t ys = 1; foreach_active(i){ ys = y; } int8_t q = x / ys; int8_t r = x % ys; return q - ((r != 0 && r < 0 != y < 0) ? 1 : 0); } SCALAR_FUN_ATTR int16_t sdiv16(int16_t x, int16_t y) { int16_t ys = 1; foreach_active(i){ ys = y; } int16_t q = x / ys; int16_t r = x % ys; return q - ((r != 0 && r < 0 != y < 0) ? 1 : 0); } SCALAR_FUN_ATTR int32_t sdiv32(int32_t x, int32_t y) { int32_t ys = 1; foreach_active(i){ ys = y; } int32_t q = x / ys; int32_t r = x % ys; return q - ((r != 0 && r < 0 != y < 0) ? 
1 : 0); } SCALAR_FUN_ATTR int64_t sdiv64(int64_t x, int64_t y) { int64_t ys = 1; foreach_active(i){ ys = y; } int64_t q = x / ys; int64_t r = x % ys; return q - ((r != 0 && r < 0 != y < 0) ? 1 : 0); } SCALAR_FUN_ATTR int8_t sdiv_up8(int8_t x, int8_t y) { return sdiv8(x + y - 1, y); } SCALAR_FUN_ATTR int16_t sdiv_up16(int16_t x, int16_t y) { return sdiv16(x + y - 1, y); } SCALAR_FUN_ATTR int32_t sdiv_up32(int32_t x, int32_t y) { return sdiv32(x + y - 1, y); } SCALAR_FUN_ATTR int64_t sdiv_up64(int64_t x, int64_t y) { return sdiv64(x + y - 1, y); } SCALAR_FUN_ATTR int8_t smod8(int8_t x, int8_t y) { int8_t ys = 1; foreach_active(i){ ys = y; } int8_t r = x % ys; return r + (r == 0 || (x > 0 && y > 0) || (x < 0 && y < 0) ? 0 : y); } SCALAR_FUN_ATTR int16_t smod16(int16_t x, int16_t y) { int16_t ys = 1; foreach_active(i){ ys = y; } int16_t r = x % ys; return r + (r == 0 || (x > 0 && y > 0) || (x < 0 && y < 0) ? 0 : y); } SCALAR_FUN_ATTR int32_t smod32(int32_t x, int32_t y) { int32_t ys = 1; foreach_active(i){ ys = y; } int32_t r = x % ys; return r + (r == 0 || (x > 0 && y > 0) || (x < 0 && y < 0) ? 0 : y); } SCALAR_FUN_ATTR int64_t smod64(int64_t x, int64_t y) { int64_t ys = 1; foreach_active(i){ ys = y; } int64_t r = x % ys; return r + (r == 0 || (x > 0 && y > 0) || (x < 0 && y < 0) ? 0 : y); } SCALAR_FUN_ATTR int8_t sdiv_safe8(int8_t x, int8_t y) { return y == 0 ? 0 : sdiv8(x, y); } SCALAR_FUN_ATTR int16_t sdiv_safe16(int16_t x, int16_t y) { return y == 0 ? 0 : sdiv16(x, y); } SCALAR_FUN_ATTR int32_t sdiv_safe32(int32_t x, int32_t y) { return y == 0 ? 0 : sdiv32(x, y); } SCALAR_FUN_ATTR int64_t sdiv_safe64(int64_t x, int64_t y) { return y == 0 ? 
0 : sdiv64(x, y); } SCALAR_FUN_ATTR int8_t sdiv_up_safe8(int8_t x, int8_t y) { return sdiv_safe8(x + y - 1, y); } SCALAR_FUN_ATTR int16_t sdiv_up_safe16(int16_t x, int16_t y) { return sdiv_safe16(x + y - 1, y); } SCALAR_FUN_ATTR int32_t sdiv_up_safe32(int32_t x, int32_t y) { return sdiv_safe32(x + y - 1, y); } SCALAR_FUN_ATTR int64_t sdiv_up_safe64(int64_t x, int64_t y) { return sdiv_safe64(x + y - 1, y); } SCALAR_FUN_ATTR int8_t smod_safe8(int8_t x, int8_t y) { return y == 0 ? 0 : smod8(x, y); } SCALAR_FUN_ATTR int16_t smod_safe16(int16_t x, int16_t y) { return y == 0 ? 0 : smod16(x, y); } SCALAR_FUN_ATTR int32_t smod_safe32(int32_t x, int32_t y) { return y == 0 ? 0 : smod32(x, y); } SCALAR_FUN_ATTR int64_t smod_safe64(int64_t x, int64_t y) { return y == 0 ? 0 : smod64(x, y); } SCALAR_FUN_ATTR int8_t squot8(int8_t x, int8_t y) { int8_t ys = 1; foreach_active(i){ ys = y; } return x / ys; } SCALAR_FUN_ATTR int16_t squot16(int16_t x, int16_t y) { int16_t ys = 1; foreach_active(i){ ys = y; } return x / ys; } SCALAR_FUN_ATTR int32_t squot32(int32_t x, int32_t y) { int32_t ys = 1; foreach_active(i){ ys = y; } return x / ys; } SCALAR_FUN_ATTR int64_t squot64(int64_t x, int64_t y) { int64_t ys = 1; foreach_active(i){ ys = y; } return x / ys; } SCALAR_FUN_ATTR int8_t srem8(int8_t x, int8_t y) { int8_t ys = 1; foreach_active(i){ ys = y; } return x % ys; } SCALAR_FUN_ATTR int16_t srem16(int16_t x, int16_t y) { int16_t ys = 1; foreach_active(i){ ys = y; } return x % ys; } SCALAR_FUN_ATTR int32_t srem32(int32_t x, int32_t y) { int32_t ys = 1; foreach_active(i){ ys = y; } return x % ys; } SCALAR_FUN_ATTR int64_t srem64(int64_t x, int64_t y) { int8_t ys = 1; foreach_active(i){ ys = y; } return x % ys; } SCALAR_FUN_ATTR int8_t squot_safe8(int8_t x, int8_t y) { int8_t ys = 1; foreach_active(i){ ys = y; } return y == 0 ? 0 : x / ys; } SCALAR_FUN_ATTR int16_t squot_safe16(int16_t x, int16_t y) { int16_t ys = 1; foreach_active(i){ ys = y; } return y == 0 ? 
0 : x / ys; } SCALAR_FUN_ATTR int32_t squot_safe32(int32_t x, int32_t y) { int32_t ys = 1; foreach_active(i){ ys = y; } return y == 0 ? 0 : x / ys; } SCALAR_FUN_ATTR int64_t squot_safe64(int64_t x, int64_t y) { int64_t ys = 1; foreach_active(i){ ys = y; } return y == 0 ? 0 : x / ys; } SCALAR_FUN_ATTR int8_t srem_safe8(int8_t x, int8_t y) { int8_t ys = 1; foreach_active(i){ ys = y; } return y == 0 ? 0 : x % ys; } SCALAR_FUN_ATTR int16_t srem_safe16(int16_t x, int16_t y) { int16_t ys = 1; foreach_active(i){ ys = y; } return y == 0 ? 0 : x % ys; } SCALAR_FUN_ATTR int32_t srem_safe32(int32_t x, int32_t y) { int32_t ys = 1; foreach_active(i){ ys = y; } return y == 0 ? 0 : x % ys; } SCALAR_FUN_ATTR int64_t srem_safe64(int64_t x, int64_t y) { int64_t ys = 1; foreach_active(i){ ys = y; } return y == 0 ? 0 : x % ys; } #else SCALAR_FUN_ATTR uint8_t udiv8(uint8_t x, uint8_t y) { return x / y; } SCALAR_FUN_ATTR uint16_t udiv16(uint16_t x, uint16_t y) { return x / y; } SCALAR_FUN_ATTR uint32_t udiv32(uint32_t x, uint32_t y) { return x / y; } SCALAR_FUN_ATTR uint64_t udiv64(uint64_t x, uint64_t y) { return x / y; } SCALAR_FUN_ATTR uint8_t udiv_up8(uint8_t x, uint8_t y) { return (x + y - 1) / y; } SCALAR_FUN_ATTR uint16_t udiv_up16(uint16_t x, uint16_t y) { return (x + y - 1) / y; } SCALAR_FUN_ATTR uint32_t udiv_up32(uint32_t x, uint32_t y) { return (x + y - 1) / y; } SCALAR_FUN_ATTR uint64_t udiv_up64(uint64_t x, uint64_t y) { return (x + y - 1) / y; } SCALAR_FUN_ATTR uint8_t umod8(uint8_t x, uint8_t y) { return x % y; } SCALAR_FUN_ATTR uint16_t umod16(uint16_t x, uint16_t y) { return x % y; } SCALAR_FUN_ATTR uint32_t umod32(uint32_t x, uint32_t y) { return x % y; } SCALAR_FUN_ATTR uint64_t umod64(uint64_t x, uint64_t y) { return x % y; } SCALAR_FUN_ATTR uint8_t udiv_safe8(uint8_t x, uint8_t y) { return y == 0 ? 0 : x / y; } SCALAR_FUN_ATTR uint16_t udiv_safe16(uint16_t x, uint16_t y) { return y == 0 ? 
0 : x / y; } SCALAR_FUN_ATTR uint32_t udiv_safe32(uint32_t x, uint32_t y) { return y == 0 ? 0 : x / y; } SCALAR_FUN_ATTR uint64_t udiv_safe64(uint64_t x, uint64_t y) { return y == 0 ? 0 : x / y; } SCALAR_FUN_ATTR uint8_t udiv_up_safe8(uint8_t x, uint8_t y) { return y == 0 ? 0 : (x + y - 1) / y; } SCALAR_FUN_ATTR uint16_t udiv_up_safe16(uint16_t x, uint16_t y) { return y == 0 ? 0 : (x + y - 1) / y; } SCALAR_FUN_ATTR uint32_t udiv_up_safe32(uint32_t x, uint32_t y) { return y == 0 ? 0 : (x + y - 1) / y; } SCALAR_FUN_ATTR uint64_t udiv_up_safe64(uint64_t x, uint64_t y) { return y == 0 ? 0 : (x + y - 1) / y; } SCALAR_FUN_ATTR uint8_t umod_safe8(uint8_t x, uint8_t y) { return y == 0 ? 0 : x % y; } SCALAR_FUN_ATTR uint16_t umod_safe16(uint16_t x, uint16_t y) { return y == 0 ? 0 : x % y; } SCALAR_FUN_ATTR uint32_t umod_safe32(uint32_t x, uint32_t y) { return y == 0 ? 0 : x % y; } SCALAR_FUN_ATTR uint64_t umod_safe64(uint64_t x, uint64_t y) { return y == 0 ? 0 : x % y; } SCALAR_FUN_ATTR int8_t sdiv8(int8_t x, int8_t y) { int8_t q = x / y; int8_t r = x % y; return q - ((r != 0 && r < 0 != y < 0) ? 1 : 0); } SCALAR_FUN_ATTR int16_t sdiv16(int16_t x, int16_t y) { int16_t q = x / y; int16_t r = x % y; return q - ((r != 0 && r < 0 != y < 0) ? 1 : 0); } SCALAR_FUN_ATTR int32_t sdiv32(int32_t x, int32_t y) { int32_t q = x / y; int32_t r = x % y; return q - ((r != 0 && r < 0 != y < 0) ? 1 : 0); } SCALAR_FUN_ATTR int64_t sdiv64(int64_t x, int64_t y) { int64_t q = x / y; int64_t r = x % y; return q - ((r != 0 && r < 0 != y < 0) ? 
1 : 0); } SCALAR_FUN_ATTR int8_t sdiv_up8(int8_t x, int8_t y) { return sdiv8(x + y - 1, y); } SCALAR_FUN_ATTR int16_t sdiv_up16(int16_t x, int16_t y) { return sdiv16(x + y - 1, y); } SCALAR_FUN_ATTR int32_t sdiv_up32(int32_t x, int32_t y) { return sdiv32(x + y - 1, y); } SCALAR_FUN_ATTR int64_t sdiv_up64(int64_t x, int64_t y) { return sdiv64(x + y - 1, y); } SCALAR_FUN_ATTR int8_t smod8(int8_t x, int8_t y) { int8_t r = x % y; return r + (r == 0 || (x > 0 && y > 0) || (x < 0 && y < 0) ? 0 : y); } SCALAR_FUN_ATTR int16_t smod16(int16_t x, int16_t y) { int16_t r = x % y; return r + (r == 0 || (x > 0 && y > 0) || (x < 0 && y < 0) ? 0 : y); } SCALAR_FUN_ATTR int32_t smod32(int32_t x, int32_t y) { int32_t r = x % y; return r + (r == 0 || (x > 0 && y > 0) || (x < 0 && y < 0) ? 0 : y); } SCALAR_FUN_ATTR int64_t smod64(int64_t x, int64_t y) { int64_t r = x % y; return r + (r == 0 || (x > 0 && y > 0) || (x < 0 && y < 0) ? 0 : y); } SCALAR_FUN_ATTR int8_t sdiv_safe8(int8_t x, int8_t y) { return y == 0 ? 0 : sdiv8(x, y); } SCALAR_FUN_ATTR int16_t sdiv_safe16(int16_t x, int16_t y) { return y == 0 ? 0 : sdiv16(x, y); } SCALAR_FUN_ATTR int32_t sdiv_safe32(int32_t x, int32_t y) { return y == 0 ? 0 : sdiv32(x, y); } SCALAR_FUN_ATTR int64_t sdiv_safe64(int64_t x, int64_t y) { return y == 0 ? 0 : sdiv64(x, y); } SCALAR_FUN_ATTR int8_t sdiv_up_safe8(int8_t x, int8_t y) { return sdiv_safe8(x + y - 1, y); } SCALAR_FUN_ATTR int16_t sdiv_up_safe16(int16_t x, int16_t y) { return sdiv_safe16(x + y - 1, y); } SCALAR_FUN_ATTR int32_t sdiv_up_safe32(int32_t x, int32_t y) { return sdiv_safe32(x + y - 1, y); } SCALAR_FUN_ATTR int64_t sdiv_up_safe64(int64_t x, int64_t y) { return sdiv_safe64(x + y - 1, y); } SCALAR_FUN_ATTR int8_t smod_safe8(int8_t x, int8_t y) { return y == 0 ? 0 : smod8(x, y); } SCALAR_FUN_ATTR int16_t smod_safe16(int16_t x, int16_t y) { return y == 0 ? 0 : smod16(x, y); } SCALAR_FUN_ATTR int32_t smod_safe32(int32_t x, int32_t y) { return y == 0 ? 
0 : smod32(x, y); } SCALAR_FUN_ATTR int64_t smod_safe64(int64_t x, int64_t y) { return y == 0 ? 0 : smod64(x, y); } SCALAR_FUN_ATTR int8_t squot8(int8_t x, int8_t y) { return x / y; } SCALAR_FUN_ATTR int16_t squot16(int16_t x, int16_t y) { return x / y; } SCALAR_FUN_ATTR int32_t squot32(int32_t x, int32_t y) { return x / y; } SCALAR_FUN_ATTR int64_t squot64(int64_t x, int64_t y) { return x / y; } SCALAR_FUN_ATTR int8_t srem8(int8_t x, int8_t y) { return x % y; } SCALAR_FUN_ATTR int16_t srem16(int16_t x, int16_t y) { return x % y; } SCALAR_FUN_ATTR int32_t srem32(int32_t x, int32_t y) { return x % y; } SCALAR_FUN_ATTR int64_t srem64(int64_t x, int64_t y) { return x % y; } SCALAR_FUN_ATTR int8_t squot_safe8(int8_t x, int8_t y) { return y == 0 ? 0 : x / y; } SCALAR_FUN_ATTR int16_t squot_safe16(int16_t x, int16_t y) { return y == 0 ? 0 : x / y; } SCALAR_FUN_ATTR int32_t squot_safe32(int32_t x, int32_t y) { return y == 0 ? 0 : x / y; } SCALAR_FUN_ATTR int64_t squot_safe64(int64_t x, int64_t y) { return y == 0 ? 0 : x / y; } SCALAR_FUN_ATTR int8_t srem_safe8(int8_t x, int8_t y) { return y == 0 ? 0 : x % y; } SCALAR_FUN_ATTR int16_t srem_safe16(int16_t x, int16_t y) { return y == 0 ? 0 : x % y; } SCALAR_FUN_ATTR int32_t srem_safe32(int32_t x, int32_t y) { return y == 0 ? 0 : x % y; } SCALAR_FUN_ATTR int64_t srem_safe64(int64_t x, int64_t y) { return y == 0 ? 0 : x % y; } #endif SCALAR_FUN_ATTR int8_t smin8(int8_t x, int8_t y) { return x < y ? x : y; } SCALAR_FUN_ATTR int16_t smin16(int16_t x, int16_t y) { return x < y ? x : y; } SCALAR_FUN_ATTR int32_t smin32(int32_t x, int32_t y) { return x < y ? x : y; } SCALAR_FUN_ATTR int64_t smin64(int64_t x, int64_t y) { return x < y ? x : y; } SCALAR_FUN_ATTR uint8_t umin8(uint8_t x, uint8_t y) { return x < y ? x : y; } SCALAR_FUN_ATTR uint16_t umin16(uint16_t x, uint16_t y) { return x < y ? x : y; } SCALAR_FUN_ATTR uint32_t umin32(uint32_t x, uint32_t y) { return x < y ? 
x : y; } SCALAR_FUN_ATTR uint64_t umin64(uint64_t x, uint64_t y) { return x < y ? x : y; } SCALAR_FUN_ATTR int8_t smax8(int8_t x, int8_t y) { return x < y ? y : x; } SCALAR_FUN_ATTR int16_t smax16(int16_t x, int16_t y) { return x < y ? y : x; } SCALAR_FUN_ATTR int32_t smax32(int32_t x, int32_t y) { return x < y ? y : x; } SCALAR_FUN_ATTR int64_t smax64(int64_t x, int64_t y) { return x < y ? y : x; } SCALAR_FUN_ATTR uint8_t umax8(uint8_t x, uint8_t y) { return x < y ? y : x; } SCALAR_FUN_ATTR uint16_t umax16(uint16_t x, uint16_t y) { return x < y ? y : x; } SCALAR_FUN_ATTR uint32_t umax32(uint32_t x, uint32_t y) { return x < y ? y : x; } SCALAR_FUN_ATTR uint64_t umax64(uint64_t x, uint64_t y) { return x < y ? y : x; } SCALAR_FUN_ATTR uint8_t shl8(uint8_t x, uint8_t y) { return (uint8_t)(x << y); } SCALAR_FUN_ATTR uint16_t shl16(uint16_t x, uint16_t y) { return (uint16_t)(x << y); } SCALAR_FUN_ATTR uint32_t shl32(uint32_t x, uint32_t y) { return x << y; } SCALAR_FUN_ATTR uint64_t shl64(uint64_t x, uint64_t y) { return x << y; } SCALAR_FUN_ATTR uint8_t lshr8(uint8_t x, uint8_t y) { return x >> y; } SCALAR_FUN_ATTR uint16_t lshr16(uint16_t x, uint16_t y) { return x >> y; } SCALAR_FUN_ATTR uint32_t lshr32(uint32_t x, uint32_t y) { return x >> y; } SCALAR_FUN_ATTR uint64_t lshr64(uint64_t x, uint64_t y) { return x >> y; } SCALAR_FUN_ATTR int8_t ashr8(int8_t x, int8_t y) { return x >> y; } SCALAR_FUN_ATTR int16_t ashr16(int16_t x, int16_t y) { return x >> y; } SCALAR_FUN_ATTR int32_t ashr32(int32_t x, int32_t y) { return x >> y; } SCALAR_FUN_ATTR int64_t ashr64(int64_t x, int64_t y) { return x >> y; } SCALAR_FUN_ATTR uint8_t and8(uint8_t x, uint8_t y) { return x & y; } SCALAR_FUN_ATTR uint16_t and16(uint16_t x, uint16_t y) { return x & y; } SCALAR_FUN_ATTR uint32_t and32(uint32_t x, uint32_t y) { return x & y; } SCALAR_FUN_ATTR uint64_t and64(uint64_t x, uint64_t y) { return x & y; } SCALAR_FUN_ATTR uint8_t or8(uint8_t x, uint8_t y) { return x | y; } SCALAR_FUN_ATTR 
uint16_t or16(uint16_t x, uint16_t y) { return x | y; } SCALAR_FUN_ATTR uint32_t or32(uint32_t x, uint32_t y) { return x | y; } SCALAR_FUN_ATTR uint64_t or64(uint64_t x, uint64_t y) { return x | y; } SCALAR_FUN_ATTR uint8_t xor8(uint8_t x, uint8_t y) { return x ^ y; } SCALAR_FUN_ATTR uint16_t xor16(uint16_t x, uint16_t y) { return x ^ y; } SCALAR_FUN_ATTR uint32_t xor32(uint32_t x, uint32_t y) { return x ^ y; } SCALAR_FUN_ATTR uint64_t xor64(uint64_t x, uint64_t y) { return x ^ y; } SCALAR_FUN_ATTR bool ult8(uint8_t x, uint8_t y) { return x < y; } SCALAR_FUN_ATTR bool ult16(uint16_t x, uint16_t y) { return x < y; } SCALAR_FUN_ATTR bool ult32(uint32_t x, uint32_t y) { return x < y; } SCALAR_FUN_ATTR bool ult64(uint64_t x, uint64_t y) { return x < y; } SCALAR_FUN_ATTR bool ule8(uint8_t x, uint8_t y) { return x <= y; } SCALAR_FUN_ATTR bool ule16(uint16_t x, uint16_t y) { return x <= y; } SCALAR_FUN_ATTR bool ule32(uint32_t x, uint32_t y) { return x <= y; } SCALAR_FUN_ATTR bool ule64(uint64_t x, uint64_t y) { return x <= y; } SCALAR_FUN_ATTR bool slt8(int8_t x, int8_t y) { return x < y; } SCALAR_FUN_ATTR bool slt16(int16_t x, int16_t y) { return x < y; } SCALAR_FUN_ATTR bool slt32(int32_t x, int32_t y) { return x < y; } SCALAR_FUN_ATTR bool slt64(int64_t x, int64_t y) { return x < y; } SCALAR_FUN_ATTR bool sle8(int8_t x, int8_t y) { return x <= y; } SCALAR_FUN_ATTR bool sle16(int16_t x, int16_t y) { return x <= y; } SCALAR_FUN_ATTR bool sle32(int32_t x, int32_t y) { return x <= y; } SCALAR_FUN_ATTR bool sle64(int64_t x, int64_t y) { return x <= y; } SCALAR_FUN_ATTR uint8_t pow8(uint8_t x, uint8_t y) { uint8_t res = 1, rem = y; while (rem != 0) { if (rem & 1) res *= x; rem >>= 1; x *= x; } return res; } SCALAR_FUN_ATTR uint16_t pow16(uint16_t x, uint16_t y) { uint16_t res = 1, rem = y; while (rem != 0) { if (rem & 1) res *= x; rem >>= 1; x *= x; } return res; } SCALAR_FUN_ATTR uint32_t pow32(uint32_t x, uint32_t y) { uint32_t res = 1, rem = y; while (rem != 0) { if (rem 
& 1) res *= x; rem >>= 1; x *= x; } return res; } SCALAR_FUN_ATTR uint64_t pow64(uint64_t x, uint64_t y) { uint64_t res = 1, rem = y; while (rem != 0) { if (rem & 1) res *= x; rem >>= 1; x *= x; } return res; } SCALAR_FUN_ATTR bool itob_i8_bool(int8_t x) { return x != 0; } SCALAR_FUN_ATTR bool itob_i16_bool(int16_t x) { return x != 0; } SCALAR_FUN_ATTR bool itob_i32_bool(int32_t x) { return x != 0; } SCALAR_FUN_ATTR bool itob_i64_bool(int64_t x) { return x != 0; } SCALAR_FUN_ATTR int8_t btoi_bool_i8(bool x) { return x; } SCALAR_FUN_ATTR int16_t btoi_bool_i16(bool x) { return x; } SCALAR_FUN_ATTR int32_t btoi_bool_i32(bool x) { return x; } SCALAR_FUN_ATTR int64_t btoi_bool_i64(bool x) { return x; } #define sext_i8_i8(x) ((int8_t) (int8_t) (x)) #define sext_i8_i16(x) ((int16_t) (int8_t) (x)) #define sext_i8_i32(x) ((int32_t) (int8_t) (x)) #define sext_i8_i64(x) ((int64_t) (int8_t) (x)) #define sext_i16_i8(x) ((int8_t) (int16_t) (x)) #define sext_i16_i16(x) ((int16_t) (int16_t) (x)) #define sext_i16_i32(x) ((int32_t) (int16_t) (x)) #define sext_i16_i64(x) ((int64_t) (int16_t) (x)) #define sext_i32_i8(x) ((int8_t) (int32_t) (x)) #define sext_i32_i16(x) ((int16_t) (int32_t) (x)) #define sext_i32_i32(x) ((int32_t) (int32_t) (x)) #define sext_i32_i64(x) ((int64_t) (int32_t) (x)) #define sext_i64_i8(x) ((int8_t) (int64_t) (x)) #define sext_i64_i16(x) ((int16_t) (int64_t) (x)) #define sext_i64_i32(x) ((int32_t) (int64_t) (x)) #define sext_i64_i64(x) ((int64_t) (int64_t) (x)) #define zext_i8_i8(x) ((int8_t) (uint8_t) (x)) #define zext_i8_i16(x) ((int16_t) (uint8_t) (x)) #define zext_i8_i32(x) ((int32_t) (uint8_t) (x)) #define zext_i8_i64(x) ((int64_t) (uint8_t) (x)) #define zext_i16_i8(x) ((int8_t) (uint16_t) (x)) #define zext_i16_i16(x) ((int16_t) (uint16_t) (x)) #define zext_i16_i32(x) ((int32_t) (uint16_t) (x)) #define zext_i16_i64(x) ((int64_t) (uint16_t) (x)) #define zext_i32_i8(x) ((int8_t) (uint32_t) (x)) #define zext_i32_i16(x) ((int16_t) (uint32_t) (x)) #define 
zext_i32_i32(x) ((int32_t) (uint32_t) (x)) #define zext_i32_i64(x) ((int64_t) (uint32_t) (x)) #define zext_i64_i8(x) ((int8_t) (uint64_t) (x)) #define zext_i64_i16(x) ((int16_t) (uint64_t) (x)) #define zext_i64_i32(x) ((int32_t) (uint64_t) (x)) #define zext_i64_i64(x) ((int64_t) (uint64_t) (x)) SCALAR_FUN_ATTR int8_t abs8(int8_t x) { return (int8_t)abs(x); } SCALAR_FUN_ATTR int16_t abs16(int16_t x) { return (int16_t)abs(x); } SCALAR_FUN_ATTR int32_t abs32(int32_t x) { return abs(x); } SCALAR_FUN_ATTR int64_t abs64(int64_t x) { #if defined(__OPENCL_VERSION__) || defined(ISPC) return abs(x); #else return llabs(x); #endif } #if defined(__OPENCL_VERSION__) SCALAR_FUN_ATTR int32_t futrts_popc8(int8_t x) { return popcount(x); } SCALAR_FUN_ATTR int32_t futrts_popc16(int16_t x) { return popcount(x); } SCALAR_FUN_ATTR int32_t futrts_popc32(int32_t x) { return popcount(x); } SCALAR_FUN_ATTR int32_t futrts_popc64(int64_t x) { return popcount(x); } #elif defined(__CUDA_ARCH__) SCALAR_FUN_ATTR int32_t futrts_popc8(int8_t x) { return __popc(zext_i8_i32(x)); } SCALAR_FUN_ATTR int32_t futrts_popc16(int16_t x) { return __popc(zext_i16_i32(x)); } SCALAR_FUN_ATTR int32_t futrts_popc32(int32_t x) { return __popc(x); } SCALAR_FUN_ATTR int32_t futrts_popc64(int64_t x) { return __popcll(x); } #else // Not OpenCL or CUDA, but plain C. 
SCALAR_FUN_ATTR int32_t futrts_popc8(uint8_t x) { int c = 0; for (; x; ++c) { x &= x - 1; } return c; } SCALAR_FUN_ATTR int32_t futrts_popc16(uint16_t x) { int c = 0; for (; x; ++c) { x &= x - 1; } return c; } SCALAR_FUN_ATTR int32_t futrts_popc32(uint32_t x) { int c = 0; for (; x; ++c) { x &= x - 1; } return c; } SCALAR_FUN_ATTR int32_t futrts_popc64(uint64_t x) { int c = 0; for (; x; ++c) { x &= x - 1; } return c; } #endif #if defined(__OPENCL_VERSION__) SCALAR_FUN_ATTR uint8_t futrts_umul_hi8 ( uint8_t a, uint8_t b) { return mul_hi(a, b); } SCALAR_FUN_ATTR uint16_t futrts_umul_hi16(uint16_t a, uint16_t b) { return mul_hi(a, b); } SCALAR_FUN_ATTR uint32_t futrts_umul_hi32(uint32_t a, uint32_t b) { return mul_hi(a, b); } SCALAR_FUN_ATTR uint64_t futrts_umul_hi64(uint64_t a, uint64_t b) { return mul_hi(a, b); } SCALAR_FUN_ATTR uint8_t futrts_smul_hi8 ( int8_t a, int8_t b) { return mul_hi(a, b); } SCALAR_FUN_ATTR uint16_t futrts_smul_hi16(int16_t a, int16_t b) { return mul_hi(a, b); } SCALAR_FUN_ATTR uint32_t futrts_smul_hi32(int32_t a, int32_t b) { return mul_hi(a, b); } SCALAR_FUN_ATTR uint64_t futrts_smul_hi64(int64_t a, int64_t b) { return mul_hi(a, b); } #elif defined(__CUDA_ARCH__) SCALAR_FUN_ATTR uint8_t futrts_umul_hi8(uint8_t a, uint8_t b) { return ((uint16_t)a) * ((uint16_t)b) >> 8; } SCALAR_FUN_ATTR uint16_t futrts_umul_hi16(uint16_t a, uint16_t b) { return ((uint32_t)a) * ((uint32_t)b) >> 16; } SCALAR_FUN_ATTR uint32_t futrts_umul_hi32(uint32_t a, uint32_t b) { return __umulhi(a, b); } SCALAR_FUN_ATTR uint64_t futrts_umul_hi64(uint64_t a, uint64_t b) { return __umul64hi(a, b); } SCALAR_FUN_ATTR uint8_t futrts_smul_hi8 ( int8_t a, int8_t b) { return ((int16_t)a) * ((int16_t)b) >> 8; } SCALAR_FUN_ATTR uint16_t futrts_smul_hi16(int16_t a, int16_t b) { return ((int32_t)a) * ((int32_t)b) >> 16; } SCALAR_FUN_ATTR uint32_t futrts_smul_hi32(int32_t a, int32_t b) { return __mulhi(a, b); } SCALAR_FUN_ATTR uint64_t futrts_smul_hi64(int64_t a, int64_t b) { return 
__mul64hi(a, b); } #elif defined(ISPC) SCALAR_FUN_ATTR uint8_t futrts_umul_hi8(uint8_t a, uint8_t b) { return ((uint16_t)a) * ((uint16_t)b) >> 8; } SCALAR_FUN_ATTR uint16_t futrts_umul_hi16(uint16_t a, uint16_t b) { return ((uint32_t)a) * ((uint32_t)b) >> 16; } SCALAR_FUN_ATTR uint32_t futrts_umul_hi32(uint32_t a, uint32_t b) { return ((uint64_t)a) * ((uint64_t)b) >> 32; } SCALAR_FUN_ATTR uint64_t futrts_umul_hi64(uint64_t a, uint64_t b) { uint64_t ah = a >> 32; uint64_t al = a & 0xffffffff; uint64_t bh = b >> 32; uint64_t bl = b & 0xffffffff; uint64_t p1 = al * bl; uint64_t p2 = al * bh; uint64_t p3 = ah * bl; uint64_t p4 = ah * bh; uint64_t p1h = p1 >> 32; uint64_t p2h = p2 >> 32; uint64_t p3h = p3 >> 32; uint64_t p2l = p2 & 0xffffffff; uint64_t p3l = p3 & 0xffffffff; uint64_t l = p1h + p2l + p3l; uint64_t m = (p2 >> 32) + (p3 >> 32); uint64_t h = (l >> 32) + m + p4; return h; } SCALAR_FUN_ATTR int8_t futrts_smul_hi8 ( int8_t a, int8_t b) { return ((uint16_t)a) * ((uint16_t)b) >> 8; } SCALAR_FUN_ATTR int16_t futrts_smul_hi16(int16_t a, int16_t b) { return ((uint32_t)a) * ((uint32_t)b) >> 16; } SCALAR_FUN_ATTR int32_t futrts_smul_hi32(int32_t a, int32_t b) { return ((uint64_t)a) * ((uint64_t)b) >> 32; } SCALAR_FUN_ATTR int64_t futrts_smul_hi64(int64_t a, int64_t b) { uint64_t ah = a >> 32; uint64_t al = a & 0xffffffff; uint64_t bh = b >> 32; uint64_t bl = b & 0xffffffff; uint64_t p1 = al * bl; int64_t p2 = al * bh; int64_t p3 = ah * bl; uint64_t p4 = ah * bh; uint64_t p1h = p1 >> 32; uint64_t p2h = p2 >> 32; uint64_t p3h = p3 >> 32; uint64_t p2l = p2 & 0xffffffff; uint64_t p3l = p3 & 0xffffffff; uint64_t l = p1h + p2l + p3l; uint64_t m = (p2 >> 32) + (p3 >> 32); uint64_t h = (l >> 32) + m + p4; return h; } #else // Not OpenCL, ISPC, or CUDA, but plain C. 
SCALAR_FUN_ATTR uint8_t futrts_umul_hi8(uint8_t a, uint8_t b) { return ((uint16_t)a) * ((uint16_t)b) >> 8; } SCALAR_FUN_ATTR uint16_t futrts_umul_hi16(uint16_t a, uint16_t b) { return ((uint32_t)a) * ((uint32_t)b) >> 16; } SCALAR_FUN_ATTR uint32_t futrts_umul_hi32(uint32_t a, uint32_t b) { return ((uint64_t)a) * ((uint64_t)b) >> 32; } SCALAR_FUN_ATTR uint64_t futrts_umul_hi64(uint64_t a, uint64_t b) { return ((__uint128_t)a) * ((__uint128_t)b) >> 64; } SCALAR_FUN_ATTR int8_t futrts_smul_hi8(int8_t a, int8_t b) { return ((int16_t)a) * ((int16_t)b) >> 8; } SCALAR_FUN_ATTR int16_t futrts_smul_hi16(int16_t a, int16_t b) { return ((int32_t)a) * ((int32_t)b) >> 16; } SCALAR_FUN_ATTR int32_t futrts_smul_hi32(int32_t a, int32_t b) { return ((int64_t)a) * ((int64_t)b) >> 32; } SCALAR_FUN_ATTR int64_t futrts_smul_hi64(int64_t a, int64_t b) { return ((__int128_t)a) * ((__int128_t)b) >> 64; } #endif #if defined(__OPENCL_VERSION__) SCALAR_FUN_ATTR uint8_t futrts_umad_hi8 ( uint8_t a, uint8_t b, uint8_t c) { return mad_hi(a, b, c); } SCALAR_FUN_ATTR uint16_t futrts_umad_hi16(uint16_t a, uint16_t b, uint16_t c) { return mad_hi(a, b, c); } SCALAR_FUN_ATTR uint32_t futrts_umad_hi32(uint32_t a, uint32_t b, uint32_t c) { return mad_hi(a, b, c); } SCALAR_FUN_ATTR uint64_t futrts_umad_hi64(uint64_t a, uint64_t b, uint64_t c) { return mad_hi(a, b, c); } SCALAR_FUN_ATTR uint8_t futrts_smad_hi8( int8_t a, int8_t b, int8_t c) { return mad_hi(a, b, c); } SCALAR_FUN_ATTR uint16_t futrts_smad_hi16(int16_t a, int16_t b, int16_t c) { return mad_hi(a, b, c); } SCALAR_FUN_ATTR uint32_t futrts_smad_hi32(int32_t a, int32_t b, int32_t c) { return mad_hi(a, b, c); } SCALAR_FUN_ATTR uint64_t futrts_smad_hi64(int64_t a, int64_t b, int64_t c) { return mad_hi(a, b, c); } #else // Not OpenCL SCALAR_FUN_ATTR uint8_t futrts_umad_hi8( uint8_t a, uint8_t b, uint8_t c) { return futrts_umul_hi8(a, b) + c; } SCALAR_FUN_ATTR uint16_t futrts_umad_hi16(uint16_t a, uint16_t b, uint16_t c) { return 
futrts_umul_hi16(a, b) + c; } SCALAR_FUN_ATTR uint32_t futrts_umad_hi32(uint32_t a, uint32_t b, uint32_t c) { return futrts_umul_hi32(a, b) + c; } SCALAR_FUN_ATTR uint64_t futrts_umad_hi64(uint64_t a, uint64_t b, uint64_t c) { return futrts_umul_hi64(a, b) + c; } SCALAR_FUN_ATTR uint8_t futrts_smad_hi8 ( int8_t a, int8_t b, int8_t c) { return futrts_smul_hi8(a, b) + c; } SCALAR_FUN_ATTR uint16_t futrts_smad_hi16(int16_t a, int16_t b, int16_t c) { return futrts_smul_hi16(a, b) + c; } SCALAR_FUN_ATTR uint32_t futrts_smad_hi32(int32_t a, int32_t b, int32_t c) { return futrts_smul_hi32(a, b) + c; } SCALAR_FUN_ATTR uint64_t futrts_smad_hi64(int64_t a, int64_t b, int64_t c) { return futrts_smul_hi64(a, b) + c; } #endif #if defined(__OPENCL_VERSION__) SCALAR_FUN_ATTR int32_t futrts_clzz8(int8_t x) { return clz(x); } SCALAR_FUN_ATTR int32_t futrts_clzz16(int16_t x) { return clz(x); } SCALAR_FUN_ATTR int32_t futrts_clzz32(int32_t x) { return clz(x); } SCALAR_FUN_ATTR int32_t futrts_clzz64(int64_t x) { return clz(x); } #elif defined(__CUDA_ARCH__) SCALAR_FUN_ATTR int32_t futrts_clzz8(int8_t x) { return __clz(zext_i8_i32(x)) - 24; } SCALAR_FUN_ATTR int32_t futrts_clzz16(int16_t x) { return __clz(zext_i16_i32(x)) - 16; } SCALAR_FUN_ATTR int32_t futrts_clzz32(int32_t x) { return __clz(x); } SCALAR_FUN_ATTR int32_t futrts_clzz64(int64_t x) { return __clzll(x); } #elif defined(ISPC) SCALAR_FUN_ATTR int32_t futrts_clzz8(int8_t x) { return count_leading_zeros((int32_t)(uint8_t)x)-24; } SCALAR_FUN_ATTR int32_t futrts_clzz16(int16_t x) { return count_leading_zeros((int32_t)(uint16_t)x)-16; } SCALAR_FUN_ATTR int32_t futrts_clzz32(int32_t x) { return count_leading_zeros(x); } SCALAR_FUN_ATTR int32_t futrts_clzz64(int64_t x) { return count_leading_zeros(x); } #else // Not OpenCL, ISPC or CUDA, but plain C. SCALAR_FUN_ATTR int32_t futrts_clzz8(int8_t x) { return x == 0 ? 8 : __builtin_clz((uint32_t)zext_i8_i32(x)) - 24; } SCALAR_FUN_ATTR int32_t futrts_clzz16(int16_t x) { return x == 0 ? 
16 : __builtin_clz((uint32_t)zext_i16_i32(x)) - 16; } SCALAR_FUN_ATTR int32_t futrts_clzz32(int32_t x) { return x == 0 ? 32 : __builtin_clz((uint32_t)x); } SCALAR_FUN_ATTR int32_t futrts_clzz64(int64_t x) { return x == 0 ? 64 : __builtin_clzll((uint64_t)x); } #endif #if defined(__OPENCL_VERSION__) SCALAR_FUN_ATTR int32_t futrts_ctzz8(int8_t x) { int i = 0; for (; i < 8 && (x & 1) == 0; i++, x >>= 1) ; return i; } SCALAR_FUN_ATTR int32_t futrts_ctzz16(int16_t x) { int i = 0; for (; i < 16 && (x & 1) == 0; i++, x >>= 1) ; return i; } SCALAR_FUN_ATTR int32_t futrts_ctzz32(int32_t x) { int i = 0; for (; i < 32 && (x & 1) == 0; i++, x >>= 1) ; return i; } SCALAR_FUN_ATTR int32_t futrts_ctzz64(int64_t x) { int i = 0; for (; i < 64 && (x & 1) == 0; i++, x >>= 1) ; return i; } #elif defined(__CUDA_ARCH__) SCALAR_FUN_ATTR int32_t futrts_ctzz8(int8_t x) { int y = __ffs(x); return y == 0 ? 8 : y - 1; } SCALAR_FUN_ATTR int32_t futrts_ctzz16(int16_t x) { int y = __ffs(x); return y == 0 ? 16 : y - 1; } SCALAR_FUN_ATTR int32_t futrts_ctzz32(int32_t x) { int y = __ffs(x); return y == 0 ? 32 : y - 1; } SCALAR_FUN_ATTR int32_t futrts_ctzz64(int64_t x) { int y = __ffsll(x); return y == 0 ? 64 : y - 1; } #elif defined(ISPC) SCALAR_FUN_ATTR int32_t futrts_ctzz8(int8_t x) { return x == 0 ? 8 : count_trailing_zeros((int32_t)x); } SCALAR_FUN_ATTR int32_t futrts_ctzz16(int16_t x) { return x == 0 ? 16 : count_trailing_zeros((int32_t)x); } SCALAR_FUN_ATTR int32_t futrts_ctzz32(int32_t x) { return count_trailing_zeros(x); } SCALAR_FUN_ATTR int32_t futrts_ctzz64(int64_t x) { return count_trailing_zeros(x); } #else // Not OpenCL or CUDA, but plain C. SCALAR_FUN_ATTR int32_t futrts_ctzz8(int8_t x) { return x == 0 ? 8 : __builtin_ctz((uint32_t)x); } SCALAR_FUN_ATTR int32_t futrts_ctzz16(int16_t x) { return x == 0 ? 16 : __builtin_ctz((uint32_t)x); } SCALAR_FUN_ATTR int32_t futrts_ctzz32(int32_t x) { return x == 0 ? 
32 : __builtin_ctz((uint32_t)x); } SCALAR_FUN_ATTR int32_t futrts_ctzz64(int64_t x) { return x == 0 ? 64 : __builtin_ctzll((uint64_t)x); } #endif SCALAR_FUN_ATTR float fdiv32(float x, float y) { return x / y; } SCALAR_FUN_ATTR float fadd32(float x, float y) { return x + y; } SCALAR_FUN_ATTR float fsub32(float x, float y) { return x - y; } SCALAR_FUN_ATTR float fmul32(float x, float y) { return x * y; } SCALAR_FUN_ATTR bool cmplt32(float x, float y) { return x < y; } SCALAR_FUN_ATTR bool cmple32(float x, float y) { return x <= y; } SCALAR_FUN_ATTR float sitofp_i8_f32(int8_t x) { return (float) x; } SCALAR_FUN_ATTR float sitofp_i16_f32(int16_t x) { return (float) x; } SCALAR_FUN_ATTR float sitofp_i32_f32(int32_t x) { return (float) x; } SCALAR_FUN_ATTR float sitofp_i64_f32(int64_t x) { return (float) x; } SCALAR_FUN_ATTR float uitofp_i8_f32(uint8_t x) { return (float) x; } SCALAR_FUN_ATTR float uitofp_i16_f32(uint16_t x) { return (float) x; } SCALAR_FUN_ATTR float uitofp_i32_f32(uint32_t x) { return (float) x; } SCALAR_FUN_ATTR float uitofp_i64_f32(uint64_t x) { return (float) x; } #ifdef __OPENCL_VERSION__ SCALAR_FUN_ATTR float fabs32(float x) { return fabs(x); } SCALAR_FUN_ATTR float fmax32(float x, float y) { return fmax(x, y); } SCALAR_FUN_ATTR float fmin32(float x, float y) { return fmin(x, y); } SCALAR_FUN_ATTR float fpow32(float x, float y) { return pow(x, y); } #elif defined(ISPC) SCALAR_FUN_ATTR float fabs32(float x) { return abs(x); } SCALAR_FUN_ATTR float fmax32(float x, float y) { return isnan(x) ? y : isnan(y) ? x : max(x, y); } SCALAR_FUN_ATTR float fmin32(float x, float y) { return isnan(x) ? y : isnan(y) ? x : min(x, y); } SCALAR_FUN_ATTR float fpow32(float a, float b) { float ret; foreach_active (i) { uniform float r = pow(extract(a, i), extract(b, i)); ret = insert(ret, i, r); } return ret; } #else // Not OpenCL, but CUDA or plain C. 
SCALAR_FUN_ATTR float fabs32(float x) { return fabsf(x); } SCALAR_FUN_ATTR float fmax32(float x, float y) { return fmaxf(x, y); } SCALAR_FUN_ATTR float fmin32(float x, float y) { return fminf(x, y); } SCALAR_FUN_ATTR float fpow32(float x, float y) { return powf(x, y); } #endif SCALAR_FUN_ATTR bool futrts_isnan32(float x) { return isnan(x); } #if defined(ISPC) SCALAR_FUN_ATTR bool futrts_isinf32(float x) { return !isnan(x) && isnan(x - x); } SCALAR_FUN_ATTR bool futrts_isfinite32(float x) { return !isnan(x) && !futrts_isinf32(x); } #else SCALAR_FUN_ATTR bool futrts_isinf32(float x) { return isinf(x); } #endif SCALAR_FUN_ATTR int8_t fptosi_f32_i8(float x) { if (futrts_isnan32(x) || futrts_isinf32(x)) { return 0; } else { return (int8_t) x; } } SCALAR_FUN_ATTR int16_t fptosi_f32_i16(float x) { if (futrts_isnan32(x) || futrts_isinf32(x)) { return 0; } else { return (int16_t) x; } } SCALAR_FUN_ATTR int32_t fptosi_f32_i32(float x) { if (futrts_isnan32(x) || futrts_isinf32(x)) { return 0; } else { return (int32_t) x; } } SCALAR_FUN_ATTR int64_t fptosi_f32_i64(float x) { if (futrts_isnan32(x) || futrts_isinf32(x)) { return 0; } else { return (int64_t) x; }; } SCALAR_FUN_ATTR uint8_t fptoui_f32_i8(float x) { if (futrts_isnan32(x) || futrts_isinf32(x)) { return 0; } else { return (uint8_t) (int8_t) x; } } SCALAR_FUN_ATTR uint16_t fptoui_f32_i16(float x) { if (futrts_isnan32(x) || futrts_isinf32(x)) { return 0; } else { return (uint16_t) (int16_t) x; } } SCALAR_FUN_ATTR uint32_t fptoui_f32_i32(float x) { if (futrts_isnan32(x) || futrts_isinf32(x)) { return 0; } else { return (uint32_t) (int32_t) x; } } SCALAR_FUN_ATTR uint64_t fptoui_f32_i64(float x) { if (futrts_isnan32(x) || futrts_isinf32(x)) { return 0; } else { return (uint64_t) (int64_t) x; } } SCALAR_FUN_ATTR bool ftob_f32_bool(float x) { return x != 0; } SCALAR_FUN_ATTR float btof_bool_f32(bool x) { return x ? 
1 : 0; } #ifdef __OPENCL_VERSION__ SCALAR_FUN_ATTR float futrts_log32(float x) { return log(x); } SCALAR_FUN_ATTR float futrts_log2_32(float x) { return log2(x); } SCALAR_FUN_ATTR float futrts_log10_32(float x) { return log10(x); } SCALAR_FUN_ATTR float futrts_log1p_32(float x) { return log1p(x); } SCALAR_FUN_ATTR float futrts_sqrt32(float x) { return sqrt(x); } SCALAR_FUN_ATTR float futrts_cbrt32(float x) { return cbrt(x); } SCALAR_FUN_ATTR float futrts_exp32(float x) { return exp(x); } SCALAR_FUN_ATTR float futrts_cos32(float x) { return cos(x); } SCALAR_FUN_ATTR float futrts_sin32(float x) { return sin(x); } SCALAR_FUN_ATTR float futrts_tan32(float x) { return tan(x); } SCALAR_FUN_ATTR float futrts_acos32(float x) { return acos(x); } SCALAR_FUN_ATTR float futrts_asin32(float x) { return asin(x); } SCALAR_FUN_ATTR float futrts_atan32(float x) { return atan(x); } SCALAR_FUN_ATTR float futrts_cosh32(float x) { return cosh(x); } SCALAR_FUN_ATTR float futrts_sinh32(float x) { return sinh(x); } SCALAR_FUN_ATTR float futrts_tanh32(float x) { return tanh(x); } SCALAR_FUN_ATTR float futrts_acosh32(float x) { return acosh(x); } SCALAR_FUN_ATTR float futrts_asinh32(float x) { return asinh(x); } SCALAR_FUN_ATTR float futrts_atanh32(float x) { return atanh(x); } SCALAR_FUN_ATTR float futrts_atan2_32(float x, float y) { return atan2(x, y); } SCALAR_FUN_ATTR float futrts_hypot32(float x, float y) { return hypot(x, y); } SCALAR_FUN_ATTR float futrts_gamma32(float x) { return tgamma(x); } SCALAR_FUN_ATTR float futrts_lgamma32(float x) { return lgamma(x); } SCALAR_FUN_ATTR float futrts_erf32(float x) { return erf(x); } SCALAR_FUN_ATTR float futrts_erfc32(float x) { return erfc(x); } SCALAR_FUN_ATTR float fmod32(float x, float y) { return fmod(x, y); } SCALAR_FUN_ATTR float futrts_round32(float x) { return rint(x); } SCALAR_FUN_ATTR float futrts_floor32(float x) { return floor(x); } SCALAR_FUN_ATTR float futrts_ceil32(float x) { return ceil(x); } SCALAR_FUN_ATTR float 
futrts_nextafter32(float x, float y) { return nextafter(x, y); } SCALAR_FUN_ATTR float futrts_lerp32(float v0, float v1, float t) { return mix(v0, v1, t); } SCALAR_FUN_ATTR float futrts_ldexp32(float x, int32_t y) { return ldexp(x, y); } SCALAR_FUN_ATTR float futrts_copysign32(float x, float y) { return copysign(x, y); } SCALAR_FUN_ATTR float futrts_mad32(float a, float b, float c) { return mad(a, b, c); } SCALAR_FUN_ATTR float futrts_fma32(float a, float b, float c) { return fma(a, b, c); } #elif defined(ISPC) SCALAR_FUN_ATTR float futrts_log32(float x) { return futrts_isfinite32(x) || (futrts_isinf32(x) && x < 0)? log(x) : x; } SCALAR_FUN_ATTR float futrts_log2_32(float x) { return futrts_log32(x) / log(2.0f); } SCALAR_FUN_ATTR float futrts_log10_32(float x) { return futrts_log32(x) / log(10.0f); } SCALAR_FUN_ATTR float futrts_log1p_32(float x) { if(x == -1.0f || (futrts_isinf32(x) && x > 0.0f)) return x / 0.0f; float y = 1.0f + x; float z = y - 1.0f; return log(y) - (z-x)/y; } SCALAR_FUN_ATTR float futrts_sqrt32(float x) { return sqrt(x); } extern "C" unmasked uniform float cbrtf(uniform float); SCALAR_FUN_ATTR float futrts_cbrt32(float x) { float res; foreach_active (i) { uniform float r = cbrtf(extract(x, i)); res = insert(res, i, r); } return res; } SCALAR_FUN_ATTR float futrts_exp32(float x) { return exp(x); } SCALAR_FUN_ATTR float futrts_cos32(float x) { return cos(x); } SCALAR_FUN_ATTR float futrts_sin32(float x) { return sin(x); } SCALAR_FUN_ATTR float futrts_tan32(float x) { return tan(x); } SCALAR_FUN_ATTR float futrts_acos32(float x) { return acos(x); } SCALAR_FUN_ATTR float futrts_asin32(float x) { return asin(x); } SCALAR_FUN_ATTR float futrts_atan32(float x) { return atan(x); } SCALAR_FUN_ATTR float futrts_cosh32(float x) { return (exp(x)+exp(-x)) / 2.0f; } SCALAR_FUN_ATTR float futrts_sinh32(float x) { return (exp(x)-exp(-x)) / 2.0f; } SCALAR_FUN_ATTR float futrts_tanh32(float x) { return futrts_sinh32(x)/futrts_cosh32(x); } SCALAR_FUN_ATTR float 
futrts_acosh32(float x) { float f = x+sqrt(x*x-1); if(futrts_isfinite32(f)) return log(f); return f; } SCALAR_FUN_ATTR float futrts_asinh32(float x) { float f = x+sqrt(x*x+1); if(futrts_isfinite32(f)) return log(f); return f; } SCALAR_FUN_ATTR float futrts_atanh32(float x) { float f = (1+x)/(1-x); if(futrts_isfinite32(f)) return log(f)/2.0f; return f; } SCALAR_FUN_ATTR float futrts_atan2_32(float x, float y) { return (x == 0.0f && y == 0.0f) ? 0.0f : atan2(x, y); } SCALAR_FUN_ATTR float futrts_hypot32(float x, float y) { if (futrts_isfinite32(x) && futrts_isfinite32(y)) { x = abs(x); y = abs(y); float a; float b; if (x >= y){ a = x; b = y; } else { a = y; b = x; } if(b == 0){ return a; } int e; float an; float bn; an = frexp (a, &e); bn = ldexp (b, - e); float cn; cn = sqrt (an * an + bn * bn); return ldexp (cn, e); } else { if (futrts_isinf32(x) || futrts_isinf32(y)) return INFINITY; else return x + y; } } extern "C" unmasked uniform float tgammaf(uniform float x); SCALAR_FUN_ATTR float futrts_gamma32(float x) { float res; foreach_active (i) { uniform float r = tgammaf(extract(x, i)); res = insert(res, i, r); } return res; } extern "C" unmasked uniform float lgammaf(uniform float x); SCALAR_FUN_ATTR float futrts_lgamma32(float x) { float res; foreach_active (i) { uniform float r = lgammaf(extract(x, i)); res = insert(res, i, r); } return res; } extern "C" unmasked uniform float erff(uniform float x); SCALAR_FUN_ATTR float futrts_erf32(float x) { float res; foreach_active (i) { uniform float r = erff(extract(x, i)); res = insert(res, i, r); } return res; } extern "C" unmasked uniform float erfcf(uniform float x); SCALAR_FUN_ATTR float futrts_erfc32(float x) { float res; foreach_active (i) { uniform float r = erfcf(extract(x, i)); res = insert(res, i, r); } return res; } SCALAR_FUN_ATTR float fmod32(float x, float y) { return x - y * trunc(x/y); } SCALAR_FUN_ATTR float futrts_round32(float x) { return round(x); } SCALAR_FUN_ATTR float futrts_floor32(float x) { 
return floor(x); } SCALAR_FUN_ATTR float futrts_ceil32(float x) { return ceil(x); } extern "C" unmasked uniform float nextafterf(uniform float x, uniform float y); SCALAR_FUN_ATTR float futrts_nextafter32(float x, float y) { float res; foreach_active (i) { uniform float r = nextafterf(extract(x, i), extract(y, i)); res = insert(res, i, r); } return res; } SCALAR_FUN_ATTR float futrts_lerp32(float v0, float v1, float t) { return v0 + (v1 - v0) * t; } SCALAR_FUN_ATTR float futrts_ldexp32(float x, int32_t y) { return x * pow((uniform float)2.0, (float)y); } SCALAR_FUN_ATTR float futrts_copysign32(float x, float y) { int32_t xb = futrts_to_bits32(x); int32_t yb = futrts_to_bits32(y); return futrts_from_bits32((xb & ~(1<<31)) | (yb & (1<<31))); } SCALAR_FUN_ATTR float futrts_mad32(float a, float b, float c) { return a * b + c; } SCALAR_FUN_ATTR float futrts_fma32(float a, float b, float c) { return a * b + c; } #else // Not OpenCL or ISPC, but CUDA or plain C. SCALAR_FUN_ATTR float futrts_log32(float x) { return logf(x); } SCALAR_FUN_ATTR float futrts_log2_32(float x) { return log2f(x); } SCALAR_FUN_ATTR float futrts_log10_32(float x) { return log10f(x); } SCALAR_FUN_ATTR float futrts_log1p_32(float x) { return log1pf(x); } SCALAR_FUN_ATTR float futrts_sqrt32(float x) { return sqrtf(x); } SCALAR_FUN_ATTR float futrts_cbrt32(float x) { return cbrtf(x); } SCALAR_FUN_ATTR float futrts_exp32(float x) { return expf(x); } SCALAR_FUN_ATTR float futrts_cos32(float x) { return cosf(x); } SCALAR_FUN_ATTR float futrts_sin32(float x) { return sinf(x); } SCALAR_FUN_ATTR float futrts_tan32(float x) { return tanf(x); } SCALAR_FUN_ATTR float futrts_acos32(float x) { return acosf(x); } SCALAR_FUN_ATTR float futrts_asin32(float x) { return asinf(x); } SCALAR_FUN_ATTR float futrts_atan32(float x) { return atanf(x); } SCALAR_FUN_ATTR float futrts_cosh32(float x) { return coshf(x); } SCALAR_FUN_ATTR float futrts_sinh32(float x) { return sinhf(x); } SCALAR_FUN_ATTR float futrts_tanh32(float 
x) { return tanhf(x); } SCALAR_FUN_ATTR float futrts_acosh32(float x) { return acoshf(x); } SCALAR_FUN_ATTR float futrts_asinh32(float x) { return asinhf(x); } SCALAR_FUN_ATTR float futrts_atanh32(float x) { return atanhf(x); } SCALAR_FUN_ATTR float futrts_atan2_32(float x, float y) { return atan2f(x, y); } SCALAR_FUN_ATTR float futrts_hypot32(float x, float y) { return hypotf(x, y); } SCALAR_FUN_ATTR float futrts_gamma32(float x) { return tgammaf(x); } SCALAR_FUN_ATTR float futrts_lgamma32(float x) { return lgammaf(x); } SCALAR_FUN_ATTR float futrts_erf32(float x) { return erff(x); } SCALAR_FUN_ATTR float futrts_erfc32(float x) { return erfcf(x); } SCALAR_FUN_ATTR float fmod32(float x, float y) { return fmodf(x, y); } SCALAR_FUN_ATTR float futrts_round32(float x) { return rintf(x); } SCALAR_FUN_ATTR float futrts_floor32(float x) { return floorf(x); } SCALAR_FUN_ATTR float futrts_ceil32(float x) { return ceilf(x); } SCALAR_FUN_ATTR float futrts_nextafter32(float x, float y) { return nextafterf(x, y); } SCALAR_FUN_ATTR float futrts_lerp32(float v0, float v1, float t) { return v0 + (v1 - v0) * t; } SCALAR_FUN_ATTR float futrts_ldexp32(float x, int32_t y) { return ldexpf(x, y); } SCALAR_FUN_ATTR float futrts_copysign32(float x, float y) { return copysignf(x, y); } SCALAR_FUN_ATTR float futrts_mad32(float a, float b, float c) { return a * b + c; } SCALAR_FUN_ATTR float futrts_fma32(float a, float b, float c) { return fmaf(a, b, c); } #endif #if defined(ISPC) SCALAR_FUN_ATTR int32_t futrts_to_bits32(float x) { return intbits(x); } SCALAR_FUN_ATTR float futrts_from_bits32(int32_t x) { return floatbits(x); } #else SCALAR_FUN_ATTR int32_t futrts_to_bits32(float x) { union { float f; int32_t t; } p; p.f = x; return p.t; } SCALAR_FUN_ATTR float futrts_from_bits32(int32_t x) { union { int32_t f; float t; } p; p.f = x; return p.t; } #endif SCALAR_FUN_ATTR float fsignum32(float x) { return futrts_isnan32(x) ? x : (x > 0 ? 1 : 0) - (x < 0 ? 
1 : 0); } #ifdef FUTHARK_F64_ENABLED SCALAR_FUN_ATTR double futrts_from_bits64(int64_t x); SCALAR_FUN_ATTR int64_t futrts_to_bits64(double x); #if defined(ISPC) SCALAR_FUN_ATTR bool futrts_isinf64(float x) { return !isnan(x) && isnan(x - x); } SCALAR_FUN_ATTR bool futrts_isfinite64(float x) { return !isnan(x) && !futrts_isinf64(x); } SCALAR_FUN_ATTR double fdiv64(double x, double y) { return x / y; } SCALAR_FUN_ATTR double fadd64(double x, double y) { return x + y; } SCALAR_FUN_ATTR double fsub64(double x, double y) { return x - y; } SCALAR_FUN_ATTR double fmul64(double x, double y) { return x * y; } SCALAR_FUN_ATTR bool cmplt64(double x, double y) { return x < y; } SCALAR_FUN_ATTR bool cmple64(double x, double y) { return x <= y; } SCALAR_FUN_ATTR double sitofp_i8_f64(int8_t x) { return (double) x; } SCALAR_FUN_ATTR double sitofp_i16_f64(int16_t x) { return (double) x; } SCALAR_FUN_ATTR double sitofp_i32_f64(int32_t x) { return (double) x; } SCALAR_FUN_ATTR double sitofp_i64_f64(int64_t x) { return (double) x; } SCALAR_FUN_ATTR double uitofp_i8_f64(uint8_t x) { return (double) x; } SCALAR_FUN_ATTR double uitofp_i16_f64(uint16_t x) { return (double) x; } SCALAR_FUN_ATTR double uitofp_i32_f64(uint32_t x) { return (double) x; } SCALAR_FUN_ATTR double uitofp_i64_f64(uint64_t x) { return (double) x; } SCALAR_FUN_ATTR double fabs64(double x) { return abs(x); } SCALAR_FUN_ATTR double fmax64(double x, double y) { return isnan(x) ? y : isnan(y) ? x : max(x, y); } SCALAR_FUN_ATTR double fmin64(double x, double y) { return isnan(x) ? y : isnan(y) ? x : min(x, y); } SCALAR_FUN_ATTR double fpow64(double a, double b) { float ret; foreach_active (i) { uniform float r = pow(extract(a, i), extract(b, i)); ret = insert(ret, i, r); } return ret; } SCALAR_FUN_ATTR double futrts_log64(double x) { return futrts_isfinite64(x) || (futrts_isinf64(x) && x < 0)? 
log(x) : x; } SCALAR_FUN_ATTR double futrts_log2_64(double x) { return futrts_log64(x)/log(2.0d); } SCALAR_FUN_ATTR double futrts_log10_64(double x) { return futrts_log64(x)/log(10.0d); } SCALAR_FUN_ATTR double futrts_log1p_64(double x) { if(x == -1.0d || (futrts_isinf64(x) && x > 0.0d)) return x / 0.0d; double y = 1.0d + x; double z = y - 1.0d; return log(y) - (z-x)/y; } SCALAR_FUN_ATTR double futrts_sqrt64(double x) { return sqrt(x); } extern "C" unmasked uniform double cbrt(uniform double); SCALAR_FUN_ATTR double futrts_cbrt64(double x) { double res; foreach_active (i) { uniform double r = cbrtf(extract(x, i)); res = insert(res, i, r); } return res; } SCALAR_FUN_ATTR double futrts_exp64(double x) { return exp(x); } SCALAR_FUN_ATTR double futrts_cos64(double x) { return cos(x); } SCALAR_FUN_ATTR double futrts_sin64(double x) { return sin(x); } SCALAR_FUN_ATTR double futrts_tan64(double x) { return tan(x); } SCALAR_FUN_ATTR double futrts_acos64(double x) { return acos(x); } SCALAR_FUN_ATTR double futrts_asin64(double x) { return asin(x); } SCALAR_FUN_ATTR double futrts_atan64(double x) { return atan(x); } SCALAR_FUN_ATTR double futrts_cosh64(double x) { return (exp(x)+exp(-x)) / 2.0d; } SCALAR_FUN_ATTR double futrts_sinh64(double x) { return (exp(x)-exp(-x)) / 2.0d; } SCALAR_FUN_ATTR double futrts_tanh64(double x) { return futrts_sinh64(x)/futrts_cosh64(x); } SCALAR_FUN_ATTR double futrts_acosh64(double x) { double f = x+sqrt(x*x-1.0d); if(futrts_isfinite64(f)) return log(f); return f; } SCALAR_FUN_ATTR double futrts_asinh64(double x) { double f = x+sqrt(x*x+1.0d); if(futrts_isfinite64(f)) return log(f); return f; } SCALAR_FUN_ATTR double futrts_atanh64(double x) { double f = (1.0d+x)/(1.0d-x); if(futrts_isfinite64(f)) return log(f)/2.0d; return f; } SCALAR_FUN_ATTR double futrts_atan2_64(double x, double y) { return atan2(x, y); } extern "C" unmasked uniform double hypot(uniform double x, uniform double y); SCALAR_FUN_ATTR double futrts_hypot64(double x, double 
y) { double res; foreach_active (i) { uniform double r = hypot(extract(x, i), extract(y, i)); res = insert(res, i, r); } return res; } extern "C" unmasked uniform double tgamma(uniform double x); SCALAR_FUN_ATTR double futrts_gamma64(double x) { double res; foreach_active (i) { uniform double r = tgamma(extract(x, i)); res = insert(res, i, r); } return res; } extern "C" unmasked uniform double lgamma(uniform double x); SCALAR_FUN_ATTR double futrts_lgamma64(double x) { double res; foreach_active (i) { uniform double r = lgamma(extract(x, i)); res = insert(res, i, r); } return res; } extern "C" unmasked uniform double erf(uniform double x); SCALAR_FUN_ATTR double futrts_erf64(double x) { double res; foreach_active (i) { uniform double r = erf(extract(x, i)); res = insert(res, i, r); } return res; } extern "C" unmasked uniform double erfc(uniform double x); SCALAR_FUN_ATTR double futrts_erfc64(double x) { double res; foreach_active (i) { uniform double r = erfc(extract(x, i)); res = insert(res, i, r); } return res; } SCALAR_FUN_ATTR double futrts_fma64(double a, double b, double c) { return a * b + c; } SCALAR_FUN_ATTR double futrts_round64(double x) { return round(x); } SCALAR_FUN_ATTR double futrts_ceil64(double x) { return ceil(x); } extern "C" unmasked uniform double nextafter(uniform float x, uniform double y); SCALAR_FUN_ATTR float futrts_nextafter64(double x, double y) { double res; foreach_active (i) { uniform double r = nextafter(extract(x, i), extract(y, i)); res = insert(res, i, r); } return res; } SCALAR_FUN_ATTR double futrts_floor64(double x) { return floor(x); } SCALAR_FUN_ATTR bool futrts_isnan64(double x) { return isnan(x); } SCALAR_FUN_ATTR int8_t fptosi_f64_i8(double x) { if (futrts_isnan64(x) || futrts_isinf64(x)) { return 0; } else { return (int8_t) x; } } SCALAR_FUN_ATTR int16_t fptosi_f64_i16(double x) { if (futrts_isnan64(x) || futrts_isinf64(x)) { return 0; } else { return (int16_t) x; } } SCALAR_FUN_ATTR int32_t fptosi_f64_i32(double x) { if 
(futrts_isnan64(x) || futrts_isinf64(x)) { return 0; } else { return (int32_t) x; } } SCALAR_FUN_ATTR int64_t fptosi_f64_i64(double x) { if (futrts_isnan64(x) || futrts_isinf64(x)) { return 0; } else { return (int64_t) x; } } SCALAR_FUN_ATTR uint8_t fptoui_f64_i8(double x) { if (futrts_isnan64(x) || futrts_isinf64(x)) { return 0; } else { return (uint8_t) (int8_t) x; } } SCALAR_FUN_ATTR uint16_t fptoui_f64_i16(double x) { if (futrts_isnan64(x) || futrts_isinf64(x)) { return 0; } else { return (uint16_t) (int16_t) x; } } SCALAR_FUN_ATTR uint32_t fptoui_f64_i32(double x) { if (futrts_isnan64(x) || futrts_isinf64(x)) { return 0; } else { return (uint32_t) (int32_t) x; } } SCALAR_FUN_ATTR uint64_t fptoui_f64_i64(double x) { if (futrts_isnan64(x) || futrts_isinf64(x)) { return 0; } else { return (uint64_t) (int64_t) x; } } SCALAR_FUN_ATTR bool ftob_f64_bool(double x) { return x != 0.0; } SCALAR_FUN_ATTR double btof_bool_f64(bool x) { return x ? 1.0 : 0.0; } SCALAR_FUN_ATTR int64_t futrts_to_bits64(double x) { int64_t res; foreach_active (i) { uniform double tmp = extract(x, i); uniform int64_t r = *((uniform int64_t* uniform)&tmp); res = insert(res, i, r); } return res; } SCALAR_FUN_ATTR double futrts_from_bits64(int64_t x) { double res; foreach_active (i) { uniform int64_t tmp = extract(x, i); uniform double r = *((uniform double* uniform)&tmp); res = insert(res, i, r); } return res; } SCALAR_FUN_ATTR double fmod64(double x, double y) { return x - y * trunc(x/y); } SCALAR_FUN_ATTR double fsignum64(double x) { return futrts_isnan64(x) ? x : (x > 0 ? 1.0d : 0.0d) - (x < 0 ? 
1.0d : 0.0d); } SCALAR_FUN_ATTR double futrts_lerp64(double v0, double v1, double t) { return v0 + (v1 - v0) * t; } SCALAR_FUN_ATTR double futrts_ldexp64(double x, int32_t y) { return x * pow((uniform double)2.0, (double)y); } SCALAR_FUN_ATTR double futrts_copysign64(double x, double y) { int64_t xb = futrts_to_bits64(x); int64_t yb = futrts_to_bits64(y); return futrts_from_bits64((xb & ~(((int64_t)1)<<63)) | (yb & (((int64_t)1)<<63))); } SCALAR_FUN_ATTR double futrts_mad64(double a, double b, double c) { return a * b + c; } SCALAR_FUN_ATTR float fpconv_f32_f32(float x) { return (float) x; } SCALAR_FUN_ATTR double fpconv_f32_f64(float x) { return (double) x; } SCALAR_FUN_ATTR float fpconv_f64_f32(double x) { return (float) x; } SCALAR_FUN_ATTR double fpconv_f64_f64(double x) { return (double) x; } #else SCALAR_FUN_ATTR double fdiv64(double x, double y) { return x / y; } SCALAR_FUN_ATTR double fadd64(double x, double y) { return x + y; } SCALAR_FUN_ATTR double fsub64(double x, double y) { return x - y; } SCALAR_FUN_ATTR double fmul64(double x, double y) { return x * y; } SCALAR_FUN_ATTR bool cmplt64(double x, double y) { return x < y; } SCALAR_FUN_ATTR bool cmple64(double x, double y) { return x <= y; } SCALAR_FUN_ATTR double sitofp_i8_f64(int8_t x) { return (double) x; } SCALAR_FUN_ATTR double sitofp_i16_f64(int16_t x) { return (double) x; } SCALAR_FUN_ATTR double sitofp_i32_f64(int32_t x) { return (double) x; } SCALAR_FUN_ATTR double sitofp_i64_f64(int64_t x) { return (double) x; } SCALAR_FUN_ATTR double uitofp_i8_f64(uint8_t x) { return (double) x; } SCALAR_FUN_ATTR double uitofp_i16_f64(uint16_t x) { return (double) x; } SCALAR_FUN_ATTR double uitofp_i32_f64(uint32_t x) { return (double) x; } SCALAR_FUN_ATTR double uitofp_i64_f64(uint64_t x) { return (double) x; } SCALAR_FUN_ATTR double fabs64(double x) { return fabs(x); } SCALAR_FUN_ATTR double fmax64(double x, double y) { return fmax(x, y); } SCALAR_FUN_ATTR double fmin64(double x, double y) { return fmin(x, 
y); } SCALAR_FUN_ATTR double fpow64(double x, double y) { return pow(x, y); } SCALAR_FUN_ATTR double futrts_log64(double x) { return log(x); } SCALAR_FUN_ATTR double futrts_log2_64(double x) { return log2(x); } SCALAR_FUN_ATTR double futrts_log10_64(double x) { return log10(x); } SCALAR_FUN_ATTR double futrts_log1p_64(double x) { return log1p(x); } SCALAR_FUN_ATTR double futrts_sqrt64(double x) { return sqrt(x); } SCALAR_FUN_ATTR double futrts_cbrt64(double x) { return cbrt(x); } SCALAR_FUN_ATTR double futrts_exp64(double x) { return exp(x); } SCALAR_FUN_ATTR double futrts_cos64(double x) { return cos(x); } SCALAR_FUN_ATTR double futrts_sin64(double x) { return sin(x); } SCALAR_FUN_ATTR double futrts_tan64(double x) { return tan(x); } SCALAR_FUN_ATTR double futrts_acos64(double x) { return acos(x); } SCALAR_FUN_ATTR double futrts_asin64(double x) { return asin(x); } SCALAR_FUN_ATTR double futrts_atan64(double x) { return atan(x); } SCALAR_FUN_ATTR double futrts_cosh64(double x) { return cosh(x); } SCALAR_FUN_ATTR double futrts_sinh64(double x) { return sinh(x); } SCALAR_FUN_ATTR double futrts_tanh64(double x) { return tanh(x); } SCALAR_FUN_ATTR double futrts_acosh64(double x) { return acosh(x); } SCALAR_FUN_ATTR double futrts_asinh64(double x) { return asinh(x); } SCALAR_FUN_ATTR double futrts_atanh64(double x) { return atanh(x); } SCALAR_FUN_ATTR double futrts_atan2_64(double x, double y) { return atan2(x, y); } SCALAR_FUN_ATTR double futrts_hypot64(double x, double y) { return hypot(x, y); } SCALAR_FUN_ATTR double futrts_gamma64(double x) { return tgamma(x); } SCALAR_FUN_ATTR double futrts_lgamma64(double x) { return lgamma(x); } SCALAR_FUN_ATTR double futrts_erf64(double x) { return erf(x); } SCALAR_FUN_ATTR double futrts_erfc64(double x) { return erfc(x); } SCALAR_FUN_ATTR double futrts_fma64(double a, double b, double c) { return fma(a, b, c); } SCALAR_FUN_ATTR double futrts_round64(double x) { return rint(x); } SCALAR_FUN_ATTR double futrts_ceil64(double x) { 
return ceil(x); } SCALAR_FUN_ATTR float futrts_nextafter64(float x, float y) { return nextafter(x, y); } SCALAR_FUN_ATTR double futrts_floor64(double x) { return floor(x); } SCALAR_FUN_ATTR bool futrts_isnan64(double x) { return isnan(x); } SCALAR_FUN_ATTR bool futrts_isinf64(double x) { return isinf(x); } SCALAR_FUN_ATTR int8_t fptosi_f64_i8(double x) { if (futrts_isnan64(x) || futrts_isinf64(x)) { return 0; } else { return (int8_t) x; } } SCALAR_FUN_ATTR int16_t fptosi_f64_i16(double x) { if (futrts_isnan64(x) || futrts_isinf64(x)) { return 0; } else { return (int16_t) x; } } SCALAR_FUN_ATTR int32_t fptosi_f64_i32(double x) { if (futrts_isnan64(x) || futrts_isinf64(x)) { return 0; } else { return (int32_t) x; } } SCALAR_FUN_ATTR int64_t fptosi_f64_i64(double x) { if (futrts_isnan64(x) || futrts_isinf64(x)) { return 0; } else { return (int64_t) x; } } SCALAR_FUN_ATTR uint8_t fptoui_f64_i8(double x) { if (futrts_isnan64(x) || futrts_isinf64(x)) { return 0; } else { return (uint8_t) (int8_t) x; } } SCALAR_FUN_ATTR uint16_t fptoui_f64_i16(double x) { if (futrts_isnan64(x) || futrts_isinf64(x)) { return 0; } else { return (uint16_t) (int16_t) x; } } SCALAR_FUN_ATTR uint32_t fptoui_f64_i32(double x) { if (futrts_isnan64(x) || futrts_isinf64(x)) { return 0; } else { return (uint32_t) (int32_t) x; } } SCALAR_FUN_ATTR uint64_t fptoui_f64_i64(double x) { if (futrts_isnan64(x) || futrts_isinf64(x)) { return 0; } else { return (uint64_t) (int64_t) x; } } SCALAR_FUN_ATTR bool ftob_f64_bool(double x) { return x != 0; } SCALAR_FUN_ATTR double btof_bool_f64(bool x) { return x ? 1 : 0; } SCALAR_FUN_ATTR int64_t futrts_to_bits64(double x) { union { double f; int64_t t; } p; p.f = x; return p.t; } SCALAR_FUN_ATTR double futrts_from_bits64(int64_t x) { union { int64_t f; double t; } p; p.f = x; return p.t; } SCALAR_FUN_ATTR double fmod64(double x, double y) { return fmod(x, y); } SCALAR_FUN_ATTR double fsignum64(double x) { return futrts_isnan64(x) ? 
x : (x > 0) - (x < 0); } SCALAR_FUN_ATTR double futrts_lerp64(double v0, double v1, double t) { #ifdef __OPENCL_VERSION__ return mix(v0, v1, t); #else return v0 + (v1 - v0) * t; #endif } SCALAR_FUN_ATTR double futrts_ldexp64(double x, int32_t y) { return ldexp(x, y); } SCALAR_FUN_ATTR float futrts_copysign64(double x, double y) { return copysign(x, y); } SCALAR_FUN_ATTR double futrts_mad64(double a, double b, double c) { #ifdef __OPENCL_VERSION__ return mad(a, b, c); #else return a * b + c; #endif } SCALAR_FUN_ATTR float fpconv_f32_f32(float x) { return (float) x; } SCALAR_FUN_ATTR double fpconv_f32_f64(float x) { return (double) x; } SCALAR_FUN_ATTR float fpconv_f64_f32(double x) { return (float) x; } SCALAR_FUN_ATTR double fpconv_f64_f64(double x) { return (double) x; } #endif #endif #define futrts_cond_f16(x,y,z) ((x) ? (y) : (z)) #define futrts_cond_f32(x,y,z) ((x) ? (y) : (z)) #define futrts_cond_f64(x,y,z) ((x) ? (y) : (z)) #define futrts_cond_i8(x,y,z) ((x) ? (y) : (z)) #define futrts_cond_i16(x,y,z) ((x) ? (y) : (z)) #define futrts_cond_i32(x,y,z) ((x) ? (y) : (z)) #define futrts_cond_i64(x,y,z) ((x) ? (y) : (z)) #define futrts_cond_bool(x,y,z) ((x) ? (y) : (z)) #define futrts_cond_unit(x,y,z) ((x) ? (y) : (z)) // End of scalar.h. futhark-0.25.27/rts/c/scalar_f16.h000066400000000000000000000433701475065116200164750ustar00rootroot00000000000000// Start of scalar_f16.h. // Half-precision is emulated if needed (e.g. in straight C) with the // native type used if possible. The emulation works by typedef'ing // 'float' to 'f16', and then implementing all operations on single // precision. To cut down on duplication, we use the same code for // those Futhark functions that require just operators or casts. The // in-memory representation for arrays will still be 16 bits even // under emulation, so the compiler will have to be careful when // generating reads or writes. 
// Half precision is used natively with OpenCL's cl_khr_fp16, in CUDA device
// code for __CUDA_ARCH__ >= 600, and under ISPC; every other configuration
// defines EMULATE_F16 and emulates it in single precision.
#if !defined(cl_khr_fp16) && !(defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) && !(defined(ISPC))
#define EMULATE_F16
#endif

// Native OpenCL half support must be enabled explicitly via the extension pragma.
#if !defined(EMULATE_F16) && defined(__OPENCL_VERSION__)
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif

// Select the type that represents f16 values in this configuration.
#ifdef EMULATE_F16

// Note that the half-precision storage format is still 16 bits - the
// compiler has to be very careful when reading and writing such values!
typedef float f16;

#elif defined(ISPC)
typedef float16 f16;

#else

#ifdef __CUDA_ARCH__
// NOTE(review): the header name is missing from this #include as shown here;
// presumably it should be <cuda_fp16.h> (which declares `half`) -- confirm.
#include
#endif

typedef half f16;

#endif

// Operations shared by every configuration; they use only the built-in
// operators and casts on f16. Some of these functions convert to single
// precision because half precision versions are not available.
SCALAR_FUN_ATTR f16 fadd16(f16 x, f16 y) { return x + y; }
SCALAR_FUN_ATTR f16 fsub16(f16 x, f16 y) { return x - y; }
SCALAR_FUN_ATTR f16 fmul16(f16 x, f16 y) { return x * y; }
SCALAR_FUN_ATTR bool cmplt16(f16 x, f16 y) { return x < y; }
SCALAR_FUN_ATTR bool cmple16(f16 x, f16 y) { return x <= y; }
// Integer -> f16 conversions (signed, then unsigned), all by plain cast.
SCALAR_FUN_ATTR f16 sitofp_i8_f16(int8_t x) { return (f16) x; }
SCALAR_FUN_ATTR f16 sitofp_i16_f16(int16_t x) { return (f16) x; }
SCALAR_FUN_ATTR f16 sitofp_i32_f16(int32_t x) { return (f16) x; }
SCALAR_FUN_ATTR f16 sitofp_i64_f16(int64_t x) { return (f16) x; }
SCALAR_FUN_ATTR f16 uitofp_i8_f16(uint8_t x) { return (f16) x; }
SCALAR_FUN_ATTR f16 uitofp_i16_f16(uint16_t x) { return (f16) x; }
SCALAR_FUN_ATTR f16 uitofp_i32_f16(uint32_t x) { return (f16) x; }
SCALAR_FUN_ATTR f16 uitofp_i64_f16(uint64_t x) { return (f16) x; }
// f16 -> integer conversions; the 8-bit variants go through float first.
// NOTE(review): unlike the f64 conversions (which return 0 for NaN and
// infinities), no NaN/infinity guard is applied here -- confirm intended.
SCALAR_FUN_ATTR int8_t fptosi_f16_i8(f16 x) { return (int8_t) (float) x; }
SCALAR_FUN_ATTR int16_t fptosi_f16_i16(f16 x) { return (int16_t) x; }
SCALAR_FUN_ATTR int32_t fptosi_f16_i32(f16 x) { return (int32_t) x; }
SCALAR_FUN_ATTR int64_t fptosi_f16_i64(f16 x) { return (int64_t) x; }
SCALAR_FUN_ATTR uint8_t fptoui_f16_i8(f16 x) { return (uint8_t) (float) x; }
SCALAR_FUN_ATTR uint16_t fptoui_f16_i16(f16 x) { return (uint16_t) x; }
SCALAR_FUN_ATTR uint32_t fptoui_f16_i32(f16 x) { return (uint32_t) x; }
SCALAR_FUN_ATTR uint64_t fptoui_f16_i64(f16 x) { return (uint64_t) x; }
// Nonzero test; 0 is cast to f16 so both operands have the same type.
SCALAR_FUN_ATTR bool ftob_f16_bool(f16 x) { return x != (f16)0; }
SCALAR_FUN_ATTR f16 btof_bool_f16(bool x) { return x ? 1 : 0; }

// ---- Native half precision (OpenCL with cl_khr_fp16, ISPC, or CUDA). ----
#ifndef EMULATE_F16
// NaN test performed after widening to float.
SCALAR_FUN_ATTR bool futrts_isnan16(f16 x) { return isnan((float)x); }

#ifdef __OPENCL_VERSION__
// OpenCL: builtins applied directly to half values (presumably via the
// half overloads enabled by cl_khr_fp16 above).
SCALAR_FUN_ATTR f16 fabs16(f16 x) { return fabs(x); }
SCALAR_FUN_ATTR f16 fmax16(f16 x, f16 y) { return fmax(x, y); }
SCALAR_FUN_ATTR f16 fmin16(f16 x, f16 y) { return fmin(x, y); }
SCALAR_FUN_ATTR f16 fpow16(f16 x, f16 y) { return pow(x, y); }
#elif defined(ISPC)
SCALAR_FUN_ATTR f16 fabs16(f16 x) { return abs(x); }
// If exactly one operand is NaN the other operand is returned; this is
// checked explicitly before falling back to max/min.
SCALAR_FUN_ATTR f16 fmax16(f16 x, f16 y) { return futrts_isnan16(x) ? y : futrts_isnan16(y) ? x : max(x, y); }
SCALAR_FUN_ATTR f16 fmin16(f16 x, f16 y) { return futrts_isnan16(x) ? y : futrts_isnan16(y) ? x : min(x, y); }
SCALAR_FUN_ATTR f16 fpow16(f16 x, f16 y) { return pow(x, y); }
#else // Assuming CUDA.
// CUDA: single-precision functions; f16 arguments/results are converted implicitly.
SCALAR_FUN_ATTR f16 fabs16(f16 x) { return fabsf(x); }
SCALAR_FUN_ATTR f16 fmax16(f16 x, f16 y) { return fmaxf(x, y); }
SCALAR_FUN_ATTR f16 fmin16(f16 x, f16 y) { return fminf(x, y); }
SCALAR_FUN_ATTR f16 fpow16(f16 x, f16 y) { return powf(x, y); }
#endif

#if defined(ISPC)
// For non-NaN x, x - x is NaN exactly when x is infinite. Note that these
// two helpers take float rather than f16.
SCALAR_FUN_ATTR bool futrts_isinf16(float x) { return !futrts_isnan16(x) && futrts_isnan16(x - x); }
SCALAR_FUN_ATTR bool futrts_isfinite16(float x) { return !futrts_isnan16(x) && !futrts_isinf16(x); }
#else
SCALAR_FUN_ATTR bool futrts_isinf16(f16 x) { return isinf((float)x); }
#endif

#ifdef __OPENCL_VERSION__
// OpenCL: elementary functions map one-to-one onto the builtins.
SCALAR_FUN_ATTR f16 futrts_log16(f16 x) { return log(x); }
SCALAR_FUN_ATTR f16 futrts_log2_16(f16 x) { return log2(x); }
SCALAR_FUN_ATTR f16 futrts_log10_16(f16 x) { return log10(x); }
SCALAR_FUN_ATTR f16 futrts_log1p_16(f16 x) { return log1p(x); }
SCALAR_FUN_ATTR f16 futrts_sqrt16(f16 x) { return sqrt(x); }
SCALAR_FUN_ATTR f16 futrts_cbrt16(f16 x) { return cbrt(x); }
SCALAR_FUN_ATTR f16 futrts_exp16(f16 x) { return exp(x); }
SCALAR_FUN_ATTR f16 futrts_cos16(f16 x) { return cos(x); }
SCALAR_FUN_ATTR f16 futrts_sin16(f16 x) { return sin(x); }
SCALAR_FUN_ATTR f16 futrts_tan16(f16 x) { return tan(x); }
SCALAR_FUN_ATTR f16 futrts_acos16(f16 x) { return acos(x); }
SCALAR_FUN_ATTR f16 futrts_asin16(f16 x) { return asin(x); }
SCALAR_FUN_ATTR f16 futrts_atan16(f16 x) { return atan(x); }
SCALAR_FUN_ATTR f16 futrts_cosh16(f16 x) { return cosh(x); }
SCALAR_FUN_ATTR f16 futrts_sinh16(f16 x) { return sinh(x); }
SCALAR_FUN_ATTR f16 futrts_tanh16(f16 x) { return tanh(x); }
SCALAR_FUN_ATTR f16 futrts_acosh16(f16 x) { return acosh(x); }
SCALAR_FUN_ATTR f16 futrts_asinh16(f16 x) { return asinh(x); }
SCALAR_FUN_ATTR f16 futrts_atanh16(f16 x) { return atanh(x); }
SCALAR_FUN_ATTR f16 futrts_atan2_16(f16 x, f16 y) { return atan2(x, y); }
SCALAR_FUN_ATTR f16 futrts_hypot16(f16 x, f16 y) { return hypot(x, y); }
SCALAR_FUN_ATTR f16 futrts_gamma16(f16 x) { return tgamma(x); }
SCALAR_FUN_ATTR f16 futrts_lgamma16(f16 x) { return lgamma(x); }
SCALAR_FUN_ATTR f16 futrts_erf16(f16 x) { return erf(x); }
SCALAR_FUN_ATTR f16 futrts_erfc16(f16 x) { return erfc(x); }
SCALAR_FUN_ATTR f16 fmod16(f16 x, f16 y) { return fmod(x, y); }
SCALAR_FUN_ATTR f16 futrts_round16(f16 x) { return rint(x); }
SCALAR_FUN_ATTR f16 futrts_floor16(f16 x) { return floor(x); }
SCALAR_FUN_ATTR f16 futrts_ceil16(f16 x) { return ceil(x); }
SCALAR_FUN_ATTR f16 futrts_nextafter16(f16 x, f16 y) { return nextafter(x, y); }
SCALAR_FUN_ATTR f16 futrts_lerp16(f16 v0, f16 v1, f16 t) { return mix(v0, v1, t); }
SCALAR_FUN_ATTR f16 futrts_ldexp16(f16 x, int32_t y) { return ldexp(x, y); }
SCALAR_FUN_ATTR f16 futrts_copysign16(f16 x, f16 y) { return copysign(x, y); }
SCALAR_FUN_ATTR f16 futrts_mad16(f16 a, f16 b, f16 c) { return mad(a, b, c); }
SCALAR_FUN_ATTR f16 futrts_fma16(f16 a, f16 b, f16 c) { return fma(a, b, c); }
#elif defined(ISPC)
// ISPC: some functions are evaluated in single precision and cast back to
// float16; others are built from exp/log identities on f16 values.
// log is applied only to finite values and -inf; +inf and NaN pass through unchanged.
SCALAR_FUN_ATTR f16 futrts_log16(f16 x) { return futrts_isfinite16(x) || (futrts_isinf16(x) && x < 0) ? log(x) : x; }
SCALAR_FUN_ATTR f16 futrts_log2_16(f16 x) { return futrts_log16(x) / log(2.0f16); }
SCALAR_FUN_ATTR f16 futrts_log10_16(f16 x) { return futrts_log16(x) / log(10.0f16); }
// x == -1 and x == +inf are produced by dividing by zero (-inf and +inf).
// Otherwise log(1 + x) is corrected by (z - x) / y, the rounding error made
// when forming y = 1 + x.
SCALAR_FUN_ATTR f16 futrts_log1p_16(f16 x) { if(x == -1.0f16 || (futrts_isinf16(x) && x > 0.0f16)) return x / 0.0f16; f16 y = 1.0f16 + x; f16 z = y - 1.0f16; return log(y) - (z-x)/y; }
SCALAR_FUN_ATTR f16 futrts_sqrt16(f16 x) { return (float16)sqrt((float)x); }
SCALAR_FUN_ATTR f16 futrts_exp16(f16 x) { return exp(x); }
SCALAR_FUN_ATTR f16 futrts_cos16(f16 x) { return (float16)cos((float)x); }
SCALAR_FUN_ATTR f16 futrts_sin16(f16 x) { return (float16)sin((float)x); }
SCALAR_FUN_ATTR f16 futrts_tan16(f16 x) { return (float16)tan((float)x); }
SCALAR_FUN_ATTR f16 futrts_acos16(f16 x) { return (float16)acos((float)x); }
SCALAR_FUN_ATTR f16 futrts_asin16(f16 x) { return (float16)asin((float)x); }
SCALAR_FUN_ATTR f16 futrts_atan16(f16 x) { return (float16)atan((float)x); }
// Hyperbolic functions via exp on f16 values.
SCALAR_FUN_ATTR f16 futrts_cosh16(f16 x) { return (exp(x)+exp(-x)) / 2.0f16; }
SCALAR_FUN_ATTR f16 futrts_sinh16(f16 x) { return (exp(x)-exp(-x)) / 2.0f16; }
SCALAR_FUN_ATTR f16 futrts_tanh16(f16 x) { return futrts_sinh16(x)/futrts_cosh16(x); }
// acosh(x) = log(x + sqrt(x*x - 1)); a non-finite intermediate is returned as-is.
SCALAR_FUN_ATTR f16 futrts_acosh16(f16 x) { float16 f = x+(float16)sqrt((float)(x*x-1)); if(futrts_isfinite16(f)) return log(f); return f; }
// asinh(x) = log(x + sqrt(x*x + 1)); a non-finite intermediate is returned as-is.
SCALAR_FUN_ATTR f16 futrts_asinh16(f16 x) { float16 f = x+(float16)sqrt((float)(x*x+1)); if(futrts_isfinite16(f)) return log(f); return f; }
// atanh(x) = log((1 + x) / (1 - x)) / 2; a non-finite quotient is returned as-is.
SCALAR_FUN_ATTR f16 futrts_atanh16(f16 x) { float16 f = (1+x)/(1-x); if(futrts_isfinite16(f)) return log(f)/2.0f16; return f; }
SCALAR_FUN_ATTR f16 futrts_atan2_16(f16 x, f16 y) { return (float16)atan2((float)x, (float)y); }
SCALAR_FUN_ATTR f16 futrts_hypot16(f16 x, f16 y) { return (float16)futrts_hypot32((float)x, (float)y); }
// tgammaf/lgammaf are declared as scalar (uniform) C functions, so they
// are applied lane by lane with foreach_active/extract/insert.
extern "C" unmasked uniform float tgammaf(uniform float x);
SCALAR_FUN_ATTR f16 futrts_gamma16(f16 x) { f16 res; foreach_active (i) { uniform f16 r = (f16)tgammaf(extract((float)x, i)); res = insert(res, i, r); } return res; }
extern "C" unmasked uniform float lgammaf(uniform float x);
SCALAR_FUN_ATTR f16 futrts_lgamma16(f16 x) { f16 res; foreach_active (i) { uniform f16 r = (f16)lgammaf(extract((float)x, i)); res = insert(res, i, r); } return res; }
SCALAR_FUN_ATTR f16 futrts_cbrt16(f16 x) { f16 res = (f16)futrts_cbrt32((float)x); return res; }
SCALAR_FUN_ATTR f16 futrts_erf16(f16 x) { f16 res = (f16)futrts_erf32((float)x); return res; }
SCALAR_FUN_ATTR f16 futrts_erfc16(f16 x) { f16 res = (f16)futrts_erfc32((float)x); return res; }
// Truncated remainder x - y * trunc(x / y); note the quotient x / y is
// formed in half precision before being widened for trunc.
SCALAR_FUN_ATTR f16 fmod16(f16 x, f16 y) { return x - y * (float16)trunc((float) (x/y)); }
SCALAR_FUN_ATTR f16 futrts_round16(f16 x) { return (float16)round((float)x); }
SCALAR_FUN_ATTR f16 futrts_floor16(f16 x) { return (float16)floor((float)x); }
SCALAR_FUN_ATTR f16 futrts_ceil16(f16 x) { return (float16)ceil((float)x); }
// NOTE(review): steps by a single-precision ulp and then rounds to f16, which
// may differ from a true half-precision nextafter -- confirm intended.
SCALAR_FUN_ATTR f16 futrts_nextafter16(f16 x, f16 y) { return (float16)futrts_nextafter32((float)x, (float) y); }
SCALAR_FUN_ATTR f16 futrts_lerp16(f16 v0, f16 v1, f16 t) { return v0 + (v1 - v0) * t; }
SCALAR_FUN_ATTR f16 futrts_ldexp16(f16 x, int32_t y) { return futrts_ldexp32((float)x, y); }
SCALAR_FUN_ATTR f16 futrts_copysign16(f16 x, f16 y) { return futrts_copysign32((float)x, y); }
SCALAR_FUN_ATTR f16 futrts_mad16(f16 a, f16 b, f16 c) { return a * b + c; }
// NOTE(review): written as a * b + c rather than a dedicated fma call, so
// single-rounding fma semantics are not guaranteed here.
SCALAR_FUN_ATTR f16 futrts_fma16(f16 a, f16 b, f16 c) { return a * b + c; }
#else // Assume CUDA.
SCALAR_FUN_ATTR f16 futrts_log16(f16 x) { return hlog(x); } SCALAR_FUN_ATTR f16 futrts_log2_16(f16 x) { return hlog2(x); } SCALAR_FUN_ATTR f16 futrts_log10_16(f16 x) { return hlog10(x); } SCALAR_FUN_ATTR f16 futrts_log1p_16(f16 x) { return (f16)log1pf((float)x); } SCALAR_FUN_ATTR f16 futrts_sqrt16(f16 x) { return hsqrt(x); } SCALAR_FUN_ATTR f16 futrts_cbrt16(f16 x) { return cbrtf(x); } SCALAR_FUN_ATTR f16 futrts_exp16(f16 x) { return hexp(x); } SCALAR_FUN_ATTR f16 futrts_cos16(f16 x) { return hcos(x); } SCALAR_FUN_ATTR f16 futrts_sin16(f16 x) { return hsin(x); } SCALAR_FUN_ATTR f16 futrts_tan16(f16 x) { return tanf(x); } SCALAR_FUN_ATTR f16 futrts_acos16(f16 x) { return acosf(x); } SCALAR_FUN_ATTR f16 futrts_asin16(f16 x) { return asinf(x); } SCALAR_FUN_ATTR f16 futrts_atan16(f16 x) { return atanf(x); } SCALAR_FUN_ATTR f16 futrts_cosh16(f16 x) { return coshf(x); } SCALAR_FUN_ATTR f16 futrts_sinh16(f16 x) { return sinhf(x); } SCALAR_FUN_ATTR f16 futrts_tanh16(f16 x) { return tanhf(x); } SCALAR_FUN_ATTR f16 futrts_acosh16(f16 x) { return acoshf(x); } SCALAR_FUN_ATTR f16 futrts_asinh16(f16 x) { return asinhf(x); } SCALAR_FUN_ATTR f16 futrts_atanh16(f16 x) { return atanhf(x); } SCALAR_FUN_ATTR f16 futrts_atan2_16(f16 x, f16 y) { return atan2f(x, y); } SCALAR_FUN_ATTR f16 futrts_hypot16(f16 x, f16 y) { return hypotf(x, y); } SCALAR_FUN_ATTR f16 futrts_gamma16(f16 x) { return tgammaf(x); } SCALAR_FUN_ATTR f16 futrts_lgamma16(f16 x) { return lgammaf(x); } SCALAR_FUN_ATTR f16 futrts_erf16(f16 x) { return erff(x); } SCALAR_FUN_ATTR f16 futrts_erfc16(f16 x) { return erfcf(x); } SCALAR_FUN_ATTR f16 fmod16(f16 x, f16 y) { return fmodf(x, y); } SCALAR_FUN_ATTR f16 futrts_round16(f16 x) { return rintf(x); } SCALAR_FUN_ATTR f16 futrts_floor16(f16 x) { return hfloor(x); } SCALAR_FUN_ATTR f16 futrts_ceil16(f16 x) { return hceil(x); } SCALAR_FUN_ATTR f16 futrts_nextafter16(f16 x, f16 y) { return __ushort_as_half(halfbitsnextafter(__half_as_ushort(x), __half_as_ushort(y))); } 
SCALAR_FUN_ATTR f16 futrts_lerp16(f16 v0, f16 v1, f16 t) { return v0 + (v1 - v0) * t; } SCALAR_FUN_ATTR f16 futrts_ldexp16(f16 x, int32_t y) { return futrts_ldexp32((float)x, y); } SCALAR_FUN_ATTR f16 futrts_copysign16(f16 x, f16 y) { return futrts_copysign32((float)x, y); } SCALAR_FUN_ATTR f16 futrts_mad16(f16 a, f16 b, f16 c) { return a * b + c; } SCALAR_FUN_ATTR f16 futrts_fma16(f16 a, f16 b, f16 c) { return fmaf(a, b, c); } #endif // The CUDA __half type cannot be put in unions for some reason, so we // use bespoke conversion functions instead. #ifdef __CUDA_ARCH__ SCALAR_FUN_ATTR int16_t futrts_to_bits16(f16 x) { return __half_as_ushort(x); } SCALAR_FUN_ATTR f16 futrts_from_bits16(int16_t x) { return __ushort_as_half(x); } #elif defined(ISPC) SCALAR_FUN_ATTR int16_t futrts_to_bits16(f16 x) { varying int16_t y = *((varying int16_t * uniform)&x); return y; } SCALAR_FUN_ATTR f16 futrts_from_bits16(int16_t x) { varying f16 y = *((varying f16 * uniform)&x); return y; } #else SCALAR_FUN_ATTR int16_t futrts_to_bits16(f16 x) { union { f16 f; int16_t t; } p; p.f = x; return p.t; } SCALAR_FUN_ATTR f16 futrts_from_bits16(int16_t x) { union { int16_t f; f16 t; } p; p.f = x; return p.t; } #endif #else // No native f16 - emulate. 
SCALAR_FUN_ATTR f16 fabs16(f16 x) { return fabs32(x); } SCALAR_FUN_ATTR f16 fmax16(f16 x, f16 y) { return fmax32(x, y); } SCALAR_FUN_ATTR f16 fmin16(f16 x, f16 y) { return fmin32(x, y); } SCALAR_FUN_ATTR f16 fpow16(f16 x, f16 y) { return fpow32(x, y); } SCALAR_FUN_ATTR bool futrts_isnan16(f16 x) { return futrts_isnan32(x); } SCALAR_FUN_ATTR bool futrts_isinf16(f16 x) { return futrts_isinf32(x); } SCALAR_FUN_ATTR f16 futrts_log16(f16 x) { return futrts_log32(x); } SCALAR_FUN_ATTR f16 futrts_log2_16(f16 x) { return futrts_log2_32(x); } SCALAR_FUN_ATTR f16 futrts_log10_16(f16 x) { return futrts_log10_32(x); } SCALAR_FUN_ATTR f16 futrts_log1p_16(f16 x) { return futrts_log1p_32(x); } SCALAR_FUN_ATTR f16 futrts_sqrt16(f16 x) { return futrts_sqrt32(x); } SCALAR_FUN_ATTR f16 futrts_cbrt16(f16 x) { return futrts_cbrt32(x); } SCALAR_FUN_ATTR f16 futrts_exp16(f16 x) { return futrts_exp32(x); } SCALAR_FUN_ATTR f16 futrts_cos16(f16 x) { return futrts_cos32(x); } SCALAR_FUN_ATTR f16 futrts_sin16(f16 x) { return futrts_sin32(x); } SCALAR_FUN_ATTR f16 futrts_tan16(f16 x) { return futrts_tan32(x); } SCALAR_FUN_ATTR f16 futrts_acos16(f16 x) { return futrts_acos32(x); } SCALAR_FUN_ATTR f16 futrts_asin16(f16 x) { return futrts_asin32(x); } SCALAR_FUN_ATTR f16 futrts_atan16(f16 x) { return futrts_atan32(x); } SCALAR_FUN_ATTR f16 futrts_cosh16(f16 x) { return futrts_cosh32(x); } SCALAR_FUN_ATTR f16 futrts_sinh16(f16 x) { return futrts_sinh32(x); } SCALAR_FUN_ATTR f16 futrts_tanh16(f16 x) { return futrts_tanh32(x); } SCALAR_FUN_ATTR f16 futrts_acosh16(f16 x) { return futrts_acosh32(x); } SCALAR_FUN_ATTR f16 futrts_asinh16(f16 x) { return futrts_asinh32(x); } SCALAR_FUN_ATTR f16 futrts_atanh16(f16 x) { return futrts_atanh32(x); } SCALAR_FUN_ATTR f16 futrts_atan2_16(f16 x, f16 y) { return futrts_atan2_32(x, y); } SCALAR_FUN_ATTR f16 futrts_hypot16(f16 x, f16 y) { return futrts_hypot32(x, y); } SCALAR_FUN_ATTR f16 futrts_gamma16(f16 x) { return futrts_gamma32(x); } SCALAR_FUN_ATTR f16 
futrts_lgamma16(f16 x) { return futrts_lgamma32(x); } SCALAR_FUN_ATTR f16 futrts_erf16(f16 x) { return futrts_erf32(x); } SCALAR_FUN_ATTR f16 futrts_erfc16(f16 x) { return futrts_erfc32(x); } SCALAR_FUN_ATTR f16 fmod16(f16 x, f16 y) { return fmod32(x, y); } SCALAR_FUN_ATTR f16 futrts_round16(f16 x) { return futrts_round32(x); } SCALAR_FUN_ATTR f16 futrts_floor16(f16 x) { return futrts_floor32(x); } SCALAR_FUN_ATTR f16 futrts_ceil16(f16 x) { return futrts_ceil32(x); } SCALAR_FUN_ATTR f16 futrts_nextafter16(f16 x, f16 y) { return halfbits2float(halfbitsnextafter(float2halfbits(x), float2halfbits(y))); } SCALAR_FUN_ATTR f16 futrts_lerp16(f16 v0, f16 v1, f16 t) { return futrts_lerp32(v0, v1, t); } SCALAR_FUN_ATTR f16 futrts_ldexp16(f16 x, int32_t y) { return futrts_ldexp32(x, y); } SCALAR_FUN_ATTR f16 futrts_copysign16(f16 x, f16 y) { return futrts_copysign32((float)x, y); } SCALAR_FUN_ATTR f16 futrts_mad16(f16 a, f16 b, f16 c) { return futrts_mad32(a, b, c); } SCALAR_FUN_ATTR f16 futrts_fma16(f16 a, f16 b, f16 c) { return futrts_fma32(a, b, c); } // Even when we are using an OpenCL that does not support cl_khr_fp16, // it must still support vload_half for actually creating a // half-precision number, which can then be efficiently converted to a // float. Similarly for vstore_half. #ifdef __OPENCL_VERSION__ SCALAR_FUN_ATTR int16_t futrts_to_bits16(f16 x) { int16_t y; // Violating strict aliasing here. vstore_half((float)x, 0, (half*)&y); return y; } SCALAR_FUN_ATTR f16 futrts_from_bits16(int16_t x) { return (f16)vload_half(0, (half*)&x); } #else SCALAR_FUN_ATTR int16_t futrts_to_bits16(f16 x) { return (int16_t)float2halfbits(x); } SCALAR_FUN_ATTR f16 futrts_from_bits16(int16_t x) { return halfbits2float((uint16_t)x); } SCALAR_FUN_ATTR f16 fsignum16(f16 x) { return futrts_isnan16(x) ? x : (x > 0 ? 1 : 0) - (x < 0 ? 
1 : 0); } #endif #endif SCALAR_FUN_ATTR float fpconv_f16_f16(f16 x) { return x; } SCALAR_FUN_ATTR float fpconv_f16_f32(f16 x) { return x; } SCALAR_FUN_ATTR f16 fpconv_f32_f16(float x) { return (f16) x; } #ifdef FUTHARK_F64_ENABLED SCALAR_FUN_ATTR double fpconv_f16_f64(f16 x) { return (double) x; } #if defined(ISPC) SCALAR_FUN_ATTR f16 fpconv_f64_f16(double x) { return (f16) ((float)x); } #else SCALAR_FUN_ATTR f16 fpconv_f64_f16(double x) { return (f16) x; } #endif #endif // End of scalar_f16.h. futhark-0.25.27/rts/c/scheduler.h000066400000000000000000001111651475065116200165300ustar00rootroot00000000000000// start of scheduler.h // First, the API that the generated code will access. In principle, // we could then compile the scheduler separately and link an object // file with the generated code. In practice, we will embed all of // this in the generated code. // Scheduler handle. struct scheduler; // Initialise a scheduler (and start worker threads). static int scheduler_init(struct scheduler *scheduler, int num_workers, double kappa); // Shut down a scheduler (and destroy worker threads). static int scheduler_destroy(struct scheduler *scheduler); // Figure out the smallest amount of work that amortises task // creation. static int determine_kappa(double *kappa); // How a segop should be scheduled. enum scheduling { DYNAMIC, STATIC }; // How a given task should be executed. Filled out by the scheduler // and passed to the segop function struct scheduler_info { int64_t iter_pr_subtask; int64_t remainder; int nsubtasks; enum scheduling sched; int wake_up_threads; int64_t *task_time; int64_t *task_iter; }; // A segop function. This is what you hand the scheduler for // execution. typedef int (*segop_fn)(void* args, int64_t iterations, int tid, struct scheduler_info info); // A task for the scheduler to execute. 
struct scheduler_segop { void *args; segop_fn top_level_fn; segop_fn nested_fn; int64_t iterations; enum scheduling sched; // Pointers to timer and iter associated with the task int64_t *task_time; int64_t *task_iter; // For debugging const char* name; }; static inline int scheduler_prepare_task(struct scheduler *scheduler, struct scheduler_segop *task); typedef int (*parloop_fn)(void* args, int64_t start, int64_t end, int subtask_id, int tid); // A parallel parloop task. struct scheduler_parloop { void* args; parloop_fn fn; int64_t iterations; struct scheduler_info info; // For debugging const char* name; }; static inline int scheduler_execute_task(struct scheduler *scheduler, struct scheduler_parloop *task); // Then the API implementation. #include #if defined(_WIN32) #include #elif defined(__APPLE__) #include // For getting cpu usage of threads #include #include #elif defined(__linux__) #include #include #include #elif defined(__EMSCRIPTEN__) #include #include #include #include #endif /* Multicore Utility functions */ /* A wrapper for getting rusage on Linux and MacOS */ /* TODO maybe figure out this for windows */ static inline int getrusage_thread(struct rusage *rusage) { int err = -1; #if defined(__APPLE__) thread_basic_info_data_t info = { 0 }; mach_msg_type_number_t info_count = THREAD_BASIC_INFO_COUNT; kern_return_t kern_err; kern_err = thread_info(mach_thread_self(), THREAD_BASIC_INFO, (thread_info_t)&info, &info_count); if (kern_err == KERN_SUCCESS) { memset(rusage, 0, sizeof(struct rusage)); rusage->ru_utime.tv_sec = info.user_time.seconds; rusage->ru_utime.tv_usec = info.user_time.microseconds; rusage->ru_stime.tv_sec = info.system_time.seconds; rusage->ru_stime.tv_usec = info.system_time.microseconds; err = 0; } else { errno = EINVAL; } #elif defined(__linux__) || __EMSCRIPTEN__ err = getrusage(RUSAGE_THREAD, rusage); #endif return err; } /* returns the number of logical cores */ static int num_processors(void) { #if defined(_WIN32) /* 
https://docs.microsoft.com/en-us/windows/win32/api/sysinfoapi/ns-sysinfoapi-system_info */ SYSTEM_INFO sysinfo; GetSystemInfo(&sysinfo); int ncores = sysinfo.dwNumberOfProcessors; fprintf(stderr, "Found %d cores on your Windows machine\n Is that correct?\n", ncores); return ncores; #elif defined(__APPLE__) int ncores; size_t ncores_size = sizeof(ncores); CHECK_ERRNO(sysctlbyname("hw.logicalcpu", &ncores, &ncores_size, NULL, 0), "sysctlbyname (hw.logicalcpu)"); return ncores; #elif defined(__linux__) return get_nprocs(); #elif __EMSCRIPTEN__ return emscripten_num_logical_cores(); #else fprintf(stderr, "operating system not recognised\n"); return -1; #endif } static unsigned int g_seed; // Used to seed the generator. static inline void fast_srand(unsigned int seed) { g_seed = seed; } // Compute a pseudorandom integer. // Output value in range [0, 32767] static inline unsigned int fast_rand(void) { g_seed = (214013*g_seed+2531011); return (g_seed>>16)&0x7FFF; } struct subtask_queue { int capacity; // Size of the buffer. int first; // Index of the start of the ring buffer. int num_used; // Number of used elements in the buffer. struct subtask **buffer; pthread_mutex_t mutex; // Mutex used for synchronisation. pthread_cond_t cond; // Condition variable used for synchronisation. 
int dead; #if defined(MCPROFILE) /* Profiling fields */ uint64_t time_enqueue; uint64_t time_dequeue; uint64_t n_dequeues; uint64_t n_enqueues; #endif }; /* A subtask that can be executed by a worker */ struct subtask { /* The parloop function */ parloop_fn fn; /* Execution parameters */ void* args; int64_t start, end; int id; /* Dynamic scheduling parameters */ int chunkable; int64_t chunk_size; /* Shared variables across subtasks */ volatile int *counter; // Counter for ongoing subtasks // Shared task timers and iterators int64_t *task_time; int64_t *task_iter; /* For debugging */ const char *name; }; struct worker { pthread_t thread; struct scheduler *scheduler; /* Reference to the scheduler struct the worker belongs to*/ struct subtask_queue q; int dead; int tid; /* Just a thread id */ /* "thread local" time fields used for online algorithm */ uint64_t timer; uint64_t total; int nested; /* How nested the current computation is */ // Profiling fields int output_usage; /* Whether to dump thread usage */ uint64_t time_spent_working; /* Time spent in parloop functions */ }; static inline void output_worker_usage(struct worker *worker) { struct rusage usage; CHECK_ERRNO(getrusage_thread(&usage), "getrusage_thread"); struct timeval user_cpu_time = usage.ru_utime; struct timeval sys_cpu_time = usage.ru_stime; fprintf(stderr, "tid: %2d - work time %10llu us - user time: %10llu us - sys: %10llu us\n", worker->tid, (long long unsigned)worker->time_spent_working / 1000, (long long unsigned)(user_cpu_time.tv_sec * 1000000 + user_cpu_time.tv_usec), (long long unsigned)(sys_cpu_time.tv_sec * 1000000 + sys_cpu_time.tv_usec)); } /* Doubles the size of the queue */ static inline int subtask_queue_grow_queue(struct subtask_queue *subtask_queue) { int new_capacity = 2 * subtask_queue->capacity; #ifdef MCDEBUG fprintf(stderr, "Growing queue to %d\n", subtask_queue->capacity * 2); #endif struct subtask **new_buffer = calloc(new_capacity, sizeof(struct subtask*)); for (int i = 0; i 
< subtask_queue->num_used; i++) { new_buffer[i] = subtask_queue->buffer[(subtask_queue->first + i) % subtask_queue->capacity]; } free(subtask_queue->buffer); subtask_queue->buffer = new_buffer; subtask_queue->capacity = new_capacity; subtask_queue->first = 0; return 0; } // Initialise a job queue with the given capacity. The queue starts out // empty. Returns non-zero on error. static inline int subtask_queue_init(struct subtask_queue *subtask_queue, int capacity) { assert(subtask_queue != NULL); memset(subtask_queue, 0, sizeof(struct subtask_queue)); subtask_queue->capacity = capacity; subtask_queue->buffer = calloc(capacity, sizeof(struct subtask*)); if (subtask_queue->buffer == NULL) { return -1; } CHECK_ERRNO(pthread_mutex_init(&subtask_queue->mutex, NULL), "pthread_mutex_init"); CHECK_ERRNO(pthread_cond_init(&subtask_queue->cond, NULL), "pthread_cond_init"); return 0; } // Destroy the job queue. Blocks until the queue is empty before it // is destroyed. static inline int subtask_queue_destroy(struct subtask_queue *subtask_queue) { assert(subtask_queue != NULL); CHECK_ERR(pthread_mutex_lock(&subtask_queue->mutex), "pthread_mutex_lock"); while (subtask_queue->num_used != 0) { CHECK_ERR(pthread_cond_wait(&subtask_queue->cond, &subtask_queue->mutex), "pthread_cond_wait"); } // Queue is now empty. Let's kill it! 
subtask_queue->dead = 1; free(subtask_queue->buffer); CHECK_ERR(pthread_cond_broadcast(&subtask_queue->cond), "pthread_cond_broadcast"); CHECK_ERR(pthread_mutex_unlock(&subtask_queue->mutex), "pthread_mutex_unlock"); return 0; } static inline void dump_queue(struct worker *worker) { struct subtask_queue *subtask_queue = &worker->q; CHECK_ERR(pthread_mutex_lock(&subtask_queue->mutex), "pthread_mutex_lock"); for (int i = 0; i < subtask_queue->num_used; i++) { struct subtask * subtask = subtask_queue->buffer[(subtask_queue->first + i) % subtask_queue->capacity]; printf("queue tid %d with %d task %s\n", worker->tid, i, subtask->name); } CHECK_ERR(pthread_cond_broadcast(&subtask_queue->cond), "pthread_cond_broadcast"); CHECK_ERR(pthread_mutex_unlock(&subtask_queue->mutex), "pthread_mutex_unlock"); } // Push an element onto the end of the job queue. Blocks if the // subtask_queue is full (its size is equal to its capacity). Returns // non-zero on error. It is an error to push a job onto a queue that // has been destroyed. static inline int subtask_queue_enqueue(struct worker *worker, struct subtask *subtask ) { assert(worker != NULL); struct subtask_queue *subtask_queue = &worker->q; #ifdef MCPROFILE uint64_t start = get_wall_time(); #endif CHECK_ERR(pthread_mutex_lock(&subtask_queue->mutex), "pthread_mutex_lock"); // Wait until there is room in the subtask_queue. while (subtask_queue->num_used == subtask_queue->capacity && !subtask_queue->dead) { if (subtask_queue->num_used == subtask_queue->capacity) { CHECK_ERR(subtask_queue_grow_queue(subtask_queue), "subtask_queue_grow_queue"); continue; } CHECK_ERR(pthread_cond_wait(&subtask_queue->cond, &subtask_queue->mutex), "pthread_cond_wait"); } if (subtask_queue->dead) { CHECK_ERR(pthread_mutex_unlock(&subtask_queue->mutex), "pthread_mutex_unlock"); return -1; } // If we made it past the loop, there is room in the subtask_queue. 
subtask_queue->buffer[(subtask_queue->first + subtask_queue->num_used) % subtask_queue->capacity] = subtask; subtask_queue->num_used++; #ifdef MCPROFILE uint64_t end = get_wall_time(); subtask_queue->time_enqueue += (end - start); subtask_queue->n_enqueues++; #endif // Broadcast a reader (if any) that there is now an element. CHECK_ERR(pthread_cond_broadcast(&subtask_queue->cond), "pthread_cond_broadcast"); CHECK_ERR(pthread_mutex_unlock(&subtask_queue->mutex), "pthread_mutex_unlock"); return 0; } /* Like subtask_queue_dequeue, but with two differences: 1) the subtask is stolen from the __front__ of the queue 2) returns immediately if there is no subtasks queued, as we dont' want to block on another workers queue and */ static inline int subtask_queue_steal(struct worker *worker, struct subtask **subtask) { struct subtask_queue *subtask_queue = &worker->q; assert(subtask_queue != NULL); #ifdef MCPROFILE uint64_t start = get_wall_time(); #endif CHECK_ERR(pthread_mutex_lock(&subtask_queue->mutex), "pthread_mutex_lock"); if (subtask_queue->num_used == 0) { CHECK_ERR(pthread_cond_broadcast(&subtask_queue->cond), "pthread_cond_broadcast"); CHECK_ERR(pthread_mutex_unlock(&subtask_queue->mutex), "pthread_mutex_unlock"); return 1; } if (subtask_queue->dead) { CHECK_ERR(pthread_mutex_unlock(&subtask_queue->mutex), "pthread_mutex_unlock"); return -1; } // Tasks gets stolen from the "front" struct subtask *cur_back = subtask_queue->buffer[subtask_queue->first]; struct subtask *new_subtask = NULL; int remaining_iter = cur_back->end - cur_back->start; // If subtask is chunkable, we steal half of the iterations if (cur_back->chunkable && remaining_iter > 1) { int64_t half = remaining_iter / 2; new_subtask = malloc(sizeof(struct subtask)); *new_subtask = *cur_back; new_subtask->start = cur_back->end - half; cur_back->end = new_subtask->start; __atomic_fetch_add(cur_back->counter, 1, __ATOMIC_RELAXED); } else { new_subtask = cur_back; subtask_queue->num_used--; 
subtask_queue->first = (subtask_queue->first + 1) % subtask_queue->capacity; } *subtask = new_subtask; if (*subtask == NULL) { CHECK_ERR(pthread_mutex_unlock(&subtask_queue->mutex), "pthred_mutex_unlock"); return 1; } #ifdef MCPROFILE uint64_t end = get_wall_time(); subtask_queue->time_dequeue += (end - start); subtask_queue->n_dequeues++; #endif // Broadcast a writer (if any) that there is now room for more. CHECK_ERR(pthread_cond_broadcast(&subtask_queue->cond), "pthread_cond_broadcast"); CHECK_ERR(pthread_mutex_unlock(&subtask_queue->mutex), "pthread_mutex_unlock"); return 0; } // Pop an element from the back of the job queue. // Optional argument can be provided to block or not static inline int subtask_queue_dequeue(struct worker *worker, struct subtask **subtask, int blocking) { assert(worker != NULL); struct subtask_queue *subtask_queue = &worker->q; #ifdef MCPROFILE uint64_t start = get_wall_time(); #endif CHECK_ERR(pthread_mutex_lock(&subtask_queue->mutex), "pthread_mutex_lock"); if (subtask_queue->num_used == 0 && !blocking) { CHECK_ERR(pthread_mutex_unlock(&subtask_queue->mutex), "pthread_mutex_unlock"); return 1; } // Try to steal some work while the subtask_queue is empty while (subtask_queue->num_used == 0 && !subtask_queue->dead) { pthread_cond_wait(&subtask_queue->cond, &subtask_queue->mutex); } if (subtask_queue->dead) { CHECK_ERR(pthread_mutex_unlock(&subtask_queue->mutex), "pthread_mutex_unlock"); return -1; } // dequeue pops from the back *subtask = subtask_queue->buffer[(subtask_queue->first + subtask_queue->num_used - 1) % subtask_queue->capacity]; subtask_queue->num_used--; if (*subtask == NULL) { assert(!"got NULL ptr"); CHECK_ERR(pthread_mutex_unlock(&subtask_queue->mutex), "pthred_mutex_unlock"); return -1; } #ifdef MCPROFILE uint64_t end = get_wall_time(); subtask_queue->time_dequeue += (end - start); subtask_queue->n_dequeues++; #endif // Broadcast a writer (if any) that there is now room for more. 
CHECK_ERR(pthread_cond_broadcast(&subtask_queue->cond), "pthread_cond_broadcast"); CHECK_ERR(pthread_mutex_unlock(&subtask_queue->mutex), "pthread_mutex_unlock"); return 0; } static inline int subtask_queue_is_empty(struct subtask_queue *subtask_queue) { return subtask_queue->num_used == 0; } /* Scheduler definitions */ struct scheduler { struct worker *workers; int num_threads; int minimum_chunk_size; // If there is work to steal => active_work > 0 volatile int active_work; // Only one error can be returned at the time now. Maybe we can // provide a stack like structure for pushing errors onto if we wish // to backpropagte multiple errors volatile int error; // kappa time unit in nanoseconds double kappa; }; // Thread local variable worker struct // Note that, accesses to tls variables are expensive // Minimize direct references to this variable __thread struct worker* worker_local = NULL; static int64_t total_now(int64_t total, int64_t time) { return total + (get_wall_time_ns() - time); } static int random_other_worker(struct scheduler *scheduler, int my_id) { int my_num_workers = scheduler->num_threads; assert(my_num_workers != 1); int i = fast_rand() % (my_num_workers - 1); if (i >= my_id) { i++; } #ifdef MCDEBUG assert(i >= 0); assert(i < my_num_workers); assert(i != my_id); #endif return i; } static inline int64_t compute_chunk_size(int64_t minimum_chunk_size, double kappa, struct subtask* subtask) { double C = (double)*subtask->task_time / (double)*subtask->task_iter; if (C == 0.0F) C += DBL_EPSILON; return smax64((int64_t)(kappa / C), minimum_chunk_size); } /* Takes a chunk from subtask and enqueues the remaining iterations onto the worker's queue */ /* A no-op if the subtask is not chunkable */ static inline struct subtask* chunk_subtask(struct worker* worker, struct subtask *subtask) { if (subtask->chunkable) { // Do we have information from previous runs avaliable if (*subtask->task_iter > 0) { subtask->chunk_size = 
compute_chunk_size(worker->scheduler->minimum_chunk_size, worker->scheduler->kappa, subtask); assert(subtask->chunk_size > 0); } int64_t remaining_iter = subtask->end - subtask->start; assert(remaining_iter > 0); if (remaining_iter > subtask->chunk_size) { struct subtask *new_subtask = malloc(sizeof(struct subtask)); *new_subtask = *subtask; // increment the subtask join counter to account for new subtask __atomic_fetch_add(subtask->counter, 1, __ATOMIC_RELAXED); // Update range parameters subtask->end = subtask->start + subtask->chunk_size; new_subtask->start = subtask->end; subtask_queue_enqueue(worker, new_subtask); } } return subtask; } static inline int run_subtask(struct worker* worker, struct subtask* subtask) { assert(subtask != NULL); assert(worker != NULL); subtask = chunk_subtask(worker, subtask); worker->total = 0; worker->timer = get_wall_time_ns(); #if defined(MCPROFILE) int64_t start = worker->timer; #endif worker->nested++; int err = subtask->fn(subtask->args, subtask->start, subtask->end, subtask->id, worker->tid); worker->nested--; // Some error occured during some other subtask // so we just clean-up and return if (worker->scheduler->error != 0) { // Even a failed task counts as finished. __atomic_fetch_sub(subtask->counter, 1, __ATOMIC_RELAXED); free(subtask); return 0; } if (err != 0) { __atomic_store_n(&worker->scheduler->error, err, __ATOMIC_RELAXED); } // Total sequential time spent int64_t time_elapsed = total_now(worker->total, worker->timer); #if defined(MCPROFILE) worker->time_spent_working += get_wall_time_ns() - start; #endif int64_t iter = subtask->end - subtask->start; // report measurements // These updates should really be done using a single atomic CAS operation __atomic_fetch_add(subtask->task_time, time_elapsed, __ATOMIC_RELAXED); __atomic_fetch_add(subtask->task_iter, iter, __ATOMIC_RELAXED); // We need a fence here, since if the counter is decremented before either // of the two above are updated bad things can happen, e.g. 
if they are stack-allocated __atomic_thread_fence(__ATOMIC_SEQ_CST); __atomic_fetch_sub(subtask->counter, 1, __ATOMIC_RELAXED); free(subtask); return 0; } static inline int is_small(struct scheduler_segop *task, struct scheduler *scheduler, int *nsubtasks) { int64_t time = *task->task_time; int64_t iter = *task->task_iter; if (task->sched == DYNAMIC || iter == 0) { *nsubtasks = scheduler->num_threads; return 0; } // Estimate the constant C double C = (double)time / (double)iter; double cur_task_iter = (double) task->iterations; // Returns true if the task is small i.e. // if the number of iterations times C is smaller // than the overhead of subtask creation if (C == 0.0F || C * cur_task_iter < scheduler->kappa) { *nsubtasks = 1; return 1; } // Else compute how many subtasks this tasks should create int64_t min_iter_pr_subtask = smax64(scheduler->kappa / C, 1); *nsubtasks = smin64(smax64(task->iterations / min_iter_pr_subtask, 1), scheduler->num_threads); return 0; } // TODO make this prettier static inline struct subtask* create_subtask(parloop_fn fn, void* args, const char* name, volatile int* counter, int64_t *timer, int64_t *iter, int64_t start, int64_t end, int chunkable, int64_t chunk_size, int id) { struct subtask* subtask = malloc(sizeof(struct subtask)); if (subtask == NULL) { assert(!"malloc failed in create_subtask"); return NULL; } subtask->fn = fn; subtask->args = args; subtask->counter = counter; subtask->task_time = timer; subtask->task_iter = iter; subtask->start = start; subtask->end = end; subtask->id = id; subtask->chunkable = chunkable; subtask->chunk_size = chunk_size; subtask->name = name; return subtask; } static int dummy_counter = 0; static int64_t dummy_timer = 0; static int64_t dummy_iter = 0; static int dummy_fn(void *args, int64_t start, int64_t end, int subtask_id, int tid) { (void)args; (void)start; (void)end; (void)subtask_id; (void)tid; return 0; } // Wake up threads, who are blocking by pushing a dummy task // onto their queue 
static inline void wake_up_threads(struct scheduler *scheduler, int start_tid, int end_tid) { #if defined(MCDEBUG) assert(start_tid >= 1); assert(end_tid <= scheduler->num_threads); #endif for (int i = start_tid; i < end_tid; i++) { struct subtask *subtask = create_subtask(dummy_fn, NULL, "dummy_fn", &dummy_counter, &dummy_timer, &dummy_iter, 0, 0, 0, 0, 0); CHECK_ERR(subtask_queue_enqueue(&scheduler->workers[i], subtask), "subtask_queue_enqueue"); } } static inline int is_finished(struct worker *worker) { return worker->dead && subtask_queue_is_empty(&worker->q); } // Try to steal from a random queue static inline int steal_from_random_worker(struct worker* worker) { int my_id = worker->tid; struct scheduler* scheduler = worker->scheduler; int k = random_other_worker(scheduler, my_id); struct worker *worker_k = &scheduler->workers[k]; struct subtask* subtask = NULL; int retval = subtask_queue_steal(worker_k, &subtask); if (retval == 0) { subtask_queue_enqueue(worker, subtask); return 1; } return 0; } static inline void *scheduler_worker(void* args) { struct worker *worker = (struct worker*) args; struct scheduler *scheduler = worker->scheduler; worker_local = worker; struct subtask *subtask = NULL; while(!is_finished(worker)) { if (!subtask_queue_is_empty(&worker->q)) { int retval = subtask_queue_dequeue(worker, &subtask, 0); if (retval == 0) { assert(subtask != NULL); CHECK_ERR(run_subtask(worker, subtask), "run_subtask"); } // else someone stole our work } else if (scheduler->active_work) { /* steal */ while (!is_finished(worker) && scheduler->active_work) { if (steal_from_random_worker(worker)) { break; } } } else { /* go back to sleep and wait for work */ int retval = subtask_queue_dequeue(worker, &subtask, 1); if (retval == 0) { assert(subtask != NULL); CHECK_ERR(run_subtask(worker, subtask), "run_subtask"); } } } assert(subtask_queue_is_empty(&worker->q)); #if defined(MCPROFILE) if (worker->output_usage) output_worker_usage(worker); #endif return NULL; } 
static inline int scheduler_execute_parloop(struct scheduler *scheduler, struct scheduler_parloop *task, int64_t *timer) { struct worker *worker = worker_local; struct scheduler_info info = task->info; int64_t iter_pr_subtask = info.iter_pr_subtask; int64_t remainder = info.remainder; int nsubtasks = info.nsubtasks; volatile int join_counter = nsubtasks; // Shared timer used to sum up all // sequential work from each subtask int64_t task_timer = 0; int64_t task_iter = 0; enum scheduling sched = info.sched; /* If each subtasks should be processed in chunks */ int chunkable = sched == STATIC ? 0 : 1; int64_t chunk_size = scheduler->minimum_chunk_size; // The initial chunk size when no info is avaliable if (info.wake_up_threads || sched == DYNAMIC) __atomic_add_fetch(&scheduler->active_work, nsubtasks, __ATOMIC_RELAXED); int64_t start = 0; int64_t end = iter_pr_subtask + (int64_t)(remainder != 0); for (int subtask_id = 0; subtask_id < nsubtasks; subtask_id++) { struct subtask *subtask = create_subtask(task->fn, task->args, task->name, &join_counter, &task_timer, &task_iter, start, end, chunkable, chunk_size, subtask_id); assert(subtask != NULL); // In most cases we will never have more subtasks than workers, // but there can be exceptions (e.g. the kappa tuning function). struct worker *subtask_worker = worker->nested ? 
&scheduler->workers[worker->tid] : &scheduler->workers[subtask_id % scheduler->num_threads]; CHECK_ERR(subtask_queue_enqueue(subtask_worker, subtask), "subtask_queue_enqueue"); // Update range params start = end; end += iter_pr_subtask + ((subtask_id + 1) < remainder); } if (info.wake_up_threads) { wake_up_threads(scheduler, nsubtasks, scheduler->num_threads); } // Join (wait for subtasks to finish) while(join_counter != 0) { if (!subtask_queue_is_empty(&worker->q)) { struct subtask *subtask = NULL; int err = subtask_queue_dequeue(worker, &subtask, 0); if (err == 0 ) { CHECK_ERR(run_subtask(worker, subtask), "run_subtask"); } } else { if (steal_from_random_worker(worker)) { struct subtask *subtask = NULL; int err = subtask_queue_dequeue(worker, &subtask, 0); if (err == 0) { CHECK_ERR(run_subtask(worker, subtask), "run_subtask"); } } } } if (info.wake_up_threads || sched == DYNAMIC) { __atomic_sub_fetch(&scheduler->active_work, nsubtasks, __ATOMIC_RELAXED); } // Write back timing results of all sequential work (*timer) += task_timer; return scheduler->error; } static inline int scheduler_execute_task(struct scheduler *scheduler, struct scheduler_parloop *task) { struct worker *worker = worker_local; int err = 0; // How much sequential work was performed by the task int64_t task_timer = 0; /* Execute task sequential or parallel based on decision made earlier */ if (task->info.nsubtasks == 1) { int64_t start = get_wall_time_ns(); err = task->fn(task->args, 0, task->iterations, 0, worker->tid); int64_t end = get_wall_time_ns(); task_timer = end - start; worker->time_spent_working += task_timer; // Report time measurements // TODO the update of both of these should really be a single atomic!! 
__atomic_fetch_add(task->info.task_time, task_timer, __ATOMIC_RELAXED); __atomic_fetch_add(task->info.task_iter, task->iterations, __ATOMIC_RELAXED); } else { // Add "before" time if we already are inside a task int64_t time_before = 0; if (worker->nested > 0) { time_before = total_now(worker->total, worker->timer); } err = scheduler_execute_parloop(scheduler, task, &task_timer); // Report time measurements // TODO the update of both of these should really be a single atomic!! __atomic_fetch_add(task->info.task_time, task_timer, __ATOMIC_RELAXED); __atomic_fetch_add(task->info.task_iter, task->iterations, __ATOMIC_RELAXED); // Update timers to account for new timings worker->total = time_before + task_timer; worker->timer = get_wall_time_ns(); } return err; } /* Decide on how schedule the incoming task i.e. how many subtasks and to run sequential or (potentially nested) parallel code body */ static inline int scheduler_prepare_task(struct scheduler* scheduler, struct scheduler_segop *task) { assert(task != NULL); struct worker *worker = worker_local; struct scheduler_info info; info.task_time = task->task_time; info.task_iter = task->task_iter; int nsubtasks; // Decide if task should be scheduled sequentially if (is_small(task, scheduler, &nsubtasks)) { info.iter_pr_subtask = task->iterations; info.remainder = 0; info.nsubtasks = nsubtasks; return task->top_level_fn(task->args, task->iterations, worker->tid, info); } else { info.iter_pr_subtask = task->iterations / nsubtasks; info.remainder = task->iterations % nsubtasks; info.sched = task->sched; switch (task->sched) { case STATIC: info.nsubtasks = info.iter_pr_subtask == 0 ? info.remainder : ((task->iterations - info.remainder) / info.iter_pr_subtask); break; case DYNAMIC: // As any thread can take any subtasks, we are being safe with using // an upper bound on the number of tasks such that the task allocate enough memory info.nsubtasks = info.iter_pr_subtask == 0 ? 
info.remainder : nsubtasks; break; default: assert(!"Got unknown scheduling"); } } info.wake_up_threads = 0; // We only use the nested parallel segop function if we can't exchaust all cores // using the outer most level if (task->nested_fn != NULL && info.nsubtasks < scheduler->num_threads && info.nsubtasks == task->iterations) { if (worker->nested == 0) info.wake_up_threads = 1; return task->nested_fn(task->args, task->iterations, worker->tid, info); } return task->top_level_fn(task->args, task->iterations, worker->tid, info); } // Now some code for finding the proper value of kappa on a given // machine (the smallest amount of work that amortises the cost of // task creation). struct tuning_struct { int32_t *free_tuning_res; int32_t *array; }; // Reduction function over an integer array static int tuning_loop(void *args, int64_t start, int64_t end, int flat_tid, int tid) { (void)flat_tid; (void)tid; int err = 0; struct tuning_struct *tuning_struct = (struct tuning_struct *) args; int32_t *array = tuning_struct->array; int32_t *tuning_res = tuning_struct->free_tuning_res; int32_t sum = 0; for (int i = start; i < end; i++) { int32_t y = array[i]; sum = add32(sum, y); } *tuning_res = sum; return err; } // The main entry point for the tuning process. Sets the provided // variable ``kappa``. 
static int determine_kappa(double *kappa) { int err = 0; int64_t iterations = 100000000; int64_t tuning_time = 0; int64_t tuning_iter = 0; int32_t *array = malloc(sizeof(int32_t) * iterations); for (int64_t i = 0; i < iterations; i++) { array[i] = fast_rand(); } int64_t start_tuning = get_wall_time_ns(); /* **************************** */ /* Run sequential reduce first' */ /* **************************** */ int64_t tuning_sequentiual_start = get_wall_time_ns(); struct tuning_struct tuning_struct; int32_t tuning_res; tuning_struct.free_tuning_res = &tuning_res; tuning_struct.array = array; err = tuning_loop(&tuning_struct, 0, iterations, 0, 0); int64_t tuning_sequentiual_end = get_wall_time_ns(); int64_t sequential_elapsed = tuning_sequentiual_end - tuning_sequentiual_start; double C = (double)sequential_elapsed / (double)iterations; fprintf(stderr, " Time for sequential run is %lld - Found C %f\n", (long long)sequential_elapsed, C); /* ********************** */ /* Now run tuning process */ /* ********************** */ // Setup a scheduler with a single worker struct scheduler scheduler; scheduler.num_threads = 1; scheduler.workers = malloc(sizeof(struct worker)); worker_local = &scheduler.workers[0]; worker_local->tid = 0; CHECK_ERR(subtask_queue_init(&scheduler.workers[0].q, 1024), "failed to init queue for worker %d\n", 0); // Start tuning for kappa double kappa_tune = 1000; // Initial kappa is 1 us double ratio; int64_t time_elapsed; while(1) { int64_t min_iter_pr_subtask = (int64_t) (kappa_tune / C) == 0 ? 
1 : (kappa_tune / C); int nsubtasks = iterations / min_iter_pr_subtask; struct scheduler_info info; info.iter_pr_subtask = min_iter_pr_subtask; info.nsubtasks = iterations / min_iter_pr_subtask; info.remainder = iterations % min_iter_pr_subtask; info.task_time = &tuning_time; info.task_iter = &tuning_iter; info.sched = STATIC; struct scheduler_parloop parloop; parloop.name = "tuning_loop"; parloop.fn = tuning_loop; parloop.args = &tuning_struct; parloop.iterations = iterations; parloop.info = info; int64_t tuning_chunked_start = get_wall_time_ns(); int determine_kappa_err = scheduler_execute_task(&scheduler, &parloop); assert(determine_kappa_err == 0); int64_t tuning_chunked_end = get_wall_time_ns(); time_elapsed = tuning_chunked_end - tuning_chunked_start; ratio = (double)time_elapsed / (double)sequential_elapsed; if (ratio < 1.055) { break; } kappa_tune += 100; // Increase by 100 ns at the time fprintf(stderr, "nsubtask %d - kappa %f - ratio %f\n", nsubtasks, kappa_tune, ratio); } int64_t end_tuning = get_wall_time_ns(); fprintf(stderr, "tuning took %lld ns and found kappa %f - time %lld - ratio %f\n", (long long)end_tuning - start_tuning, kappa_tune, (long long)time_elapsed, ratio); *kappa = kappa_tune; // Clean-up CHECK_ERR(subtask_queue_destroy(&scheduler.workers[0].q), "failed to destroy queue"); free(array); free(scheduler.workers); return err; } static int scheduler_init(struct scheduler *scheduler, int num_workers, double kappa) { #ifdef FUTHARK_BACKEND_ispc int64_t get_gang_size(); scheduler->minimum_chunk_size = get_gang_size(); #else scheduler->minimum_chunk_size = 1; #endif assert(num_workers > 0); scheduler->kappa = kappa; scheduler->num_threads = num_workers; scheduler->active_work = 0; scheduler->error = 0; scheduler->workers = calloc(num_workers, sizeof(struct worker)); const int queue_capacity = 1024; worker_local = &scheduler->workers[0]; worker_local->tid = 0; worker_local->scheduler = scheduler; CHECK_ERR(subtask_queue_init(&worker_local->q, 
queue_capacity), "failed to init queue for worker %d\n", 0); for (int i = 1; i < num_workers; i++) { struct worker *cur_worker = &scheduler->workers[i]; memset(cur_worker, 0, sizeof(struct worker)); cur_worker->tid = i; cur_worker->output_usage = 0; cur_worker->scheduler = scheduler; CHECK_ERR(subtask_queue_init(&cur_worker->q, queue_capacity), "failed to init queue for worker %d\n", i); CHECK_ERR(pthread_create(&cur_worker->thread, NULL, &scheduler_worker, cur_worker), "Failed to create worker %d\n", i); } return 0; } static int scheduler_destroy(struct scheduler *scheduler) { // We assume that this function is called by the thread controlling // the first worker, which is why we treat scheduler->workers[0] // specially here. // First mark them all as dead. for (int i = 1; i < scheduler->num_threads; i++) { struct worker *cur_worker = &scheduler->workers[i]; cur_worker->dead = 1; } // Then destroy their task queues (this will wake up the threads and // make them do their shutdown). for (int i = 1; i < scheduler->num_threads; i++) { struct worker *cur_worker = &scheduler->workers[i]; subtask_queue_destroy(&cur_worker->q); } // Then actually wait for them to stop. for (int i = 1; i < scheduler->num_threads; i++) { struct worker *cur_worker = &scheduler->workers[i]; CHECK_ERR(pthread_join(scheduler->workers[i].thread, NULL), "pthread_join"); } // And then destroy our own queue. subtask_queue_destroy(&scheduler->workers[0].q); free(scheduler->workers); return 0; } // End of scheduler.h futhark-0.25.27/rts/c/server.h000066400000000000000000000620061475065116200160570ustar00rootroot00000000000000// Start of server.h. // Forward declarations of things that we technically don't know until // the application header file is included, but which we need. 
struct futhark_context_config; struct futhark_context; char *futhark_context_get_error(struct futhark_context *ctx); int futhark_context_sync(struct futhark_context *ctx); int futhark_context_clear_caches(struct futhark_context *ctx); int futhark_context_config_set_tuning_param(struct futhark_context_config *cfg, const char *param_name, size_t new_value); int futhark_get_tuning_param_count(void); const char* futhark_get_tuning_param_name(int i); const char* futhark_get_tuning_param_class(int i); typedef int (*restore_fn)(const void*, FILE *, struct futhark_context*, void*); typedef void (*store_fn)(const void*, FILE *, struct futhark_context*, void*); typedef int (*free_fn)(const void*, struct futhark_context*, void*); typedef int (*project_fn)(struct futhark_context*, void*, const void*); typedef int (*new_fn)(struct futhark_context*, void**, const void*[]); struct field { const char *name; const struct type *type; project_fn project; }; struct record { int num_fields; const struct field* fields; new_fn new; }; struct type { const char *name; restore_fn restore; store_fn store; free_fn free; const void *aux; const struct record *record; }; int free_scalar(const void *aux, struct futhark_context *ctx, void *p) { (void)aux; (void)ctx; (void)p; // Nothing to do. 
return 0; } #define DEF_SCALAR_TYPE(T) \ int restore_##T(const void *aux, FILE *f, \ struct futhark_context *ctx, void *p) { \ (void)aux; \ (void)ctx; \ return read_scalar(f, &T##_info, p); \ } \ \ void store_##T(const void *aux, FILE *f, \ struct futhark_context *ctx, void *p) { \ (void)aux; \ (void)ctx; \ write_scalar(f, 1, &T##_info, p); \ } \ \ struct type type_##T = \ { .name = #T, \ .restore = restore_##T, \ .store = store_##T, \ .free = free_scalar \ } \ DEF_SCALAR_TYPE(i8); DEF_SCALAR_TYPE(i16); DEF_SCALAR_TYPE(i32); DEF_SCALAR_TYPE(i64); DEF_SCALAR_TYPE(u8); DEF_SCALAR_TYPE(u16); DEF_SCALAR_TYPE(u32); DEF_SCALAR_TYPE(u64); DEF_SCALAR_TYPE(f16); DEF_SCALAR_TYPE(f32); DEF_SCALAR_TYPE(f64); DEF_SCALAR_TYPE(bool); struct value { const struct type *type; union { void *v_ptr; int8_t v_i8; int16_t v_i16; int32_t v_i32; int64_t v_i64; uint8_t v_u8; uint16_t v_u16; uint32_t v_u32; uint64_t v_u64; uint16_t v_f16; float v_f32; double v_f64; bool v_bool; } value; }; void* value_ptr(struct value *v) { if (v->type == &type_i8) { return &v->value.v_i8; } if (v->type == &type_i16) { return &v->value.v_i16; } if (v->type == &type_i32) { return &v->value.v_i32; } if (v->type == &type_i64) { return &v->value.v_i64; } if (v->type == &type_u8) { return &v->value.v_u8; } if (v->type == &type_u16) { return &v->value.v_u16; } if (v->type == &type_u32) { return &v->value.v_u32; } if (v->type == &type_u64) { return &v->value.v_u64; } if (v->type == &type_f16) { return &v->value.v_f16; } if (v->type == &type_f32) { return &v->value.v_f32; } if (v->type == &type_f64) { return &v->value.v_f64; } if (v->type == &type_bool) { return &v->value.v_bool; } return &v->value.v_ptr; } struct variable { // NULL name indicates free slot. Name is owned by this struct. 
char *name; struct value value; }; typedef int (*entry_point_fn)(struct futhark_context*, void**, void**); struct entry_point { const char *name; entry_point_fn f; const char** tuning_params; const struct type **out_types; bool *out_unique; const struct type **in_types; bool *in_unique; }; int entry_num_ins(struct entry_point *e) { int count = 0; while (e->in_types[count]) { count++; } return count; } int entry_num_outs(struct entry_point *e) { int count = 0; while (e->out_types[count]) { count++; } return count; } struct futhark_prog { // Last entry point identified by NULL name. struct entry_point *entry_points; // Last type identified by NULL name. const struct type **types; }; struct server_state { struct futhark_prog prog; struct futhark_context_config *cfg; struct futhark_context *ctx; int variables_capacity; struct variable *variables; }; struct variable* get_variable(struct server_state *s, const char *name) { for (int i = 0; i < s->variables_capacity; i++) { if (s->variables[i].name != NULL && strcmp(s->variables[i].name, name) == 0) { return &s->variables[i]; } } return NULL; } struct variable* create_variable(struct server_state *s, const char *name, const struct type *type) { int found = -1; for (int i = 0; i < s->variables_capacity; i++) { if (found == -1 && s->variables[i].name == NULL) { found = i; } else if (s->variables[i].name != NULL && strcmp(s->variables[i].name, name) == 0) { return NULL; } } if (found != -1) { // Found a free spot. s->variables[found].name = strdup(name); s->variables[found].value.type = type; return &s->variables[found]; } // Need to grow the buffer. 
found = s->variables_capacity; s->variables_capacity *= 2; s->variables = realloc(s->variables, s->variables_capacity * sizeof(struct variable)); s->variables[found].name = strdup(name); s->variables[found].value.type = type; for (int i = found+1; i < s->variables_capacity; i++) { s->variables[i].name = NULL; } return &s->variables[found]; } void drop_variable(struct variable *v) { free(v->name); v->name = NULL; } int arg_exists(const char *args[], int i) { return args[i] != NULL; } const char* get_arg(const char *args[], int i) { if (!arg_exists(args, i)) { futhark_panic(1, "Insufficient command args.\n"); } return args[i]; } const struct type* get_type(struct server_state *s, const char *name) { for (int i = 0; s->prog.types[i]; i++) { if (strcmp(s->prog.types[i]->name, name) == 0) { return s->prog.types[i]; } } futhark_panic(1, "Unknown type %s\n", name); return NULL; } struct entry_point* get_entry_point(struct server_state *s, const char *name) { for (int i = 0; s->prog.entry_points[i].name; i++) { if (strcmp(s->prog.entry_points[i].name, name) == 0) { return &s->prog.entry_points[i]; } } return NULL; } // Print the command-done marker, indicating that we are ready for // more input. void ok(void) { printf("%%%%%% OK\n"); fflush(stdout); } // Print the failure marker. Output is now an error message until the // next ok(). void failure(void) { printf("%%%%%% FAILURE\n"); } void error_check(struct server_state *s, int err) { if (err != 0) { failure(); char *error = futhark_context_get_error(s->ctx); if (error != NULL) { puts(error); } free(error); } } void cmd_call(struct server_state *s, const char *args[]) { const char *name = get_arg(args, 0); struct entry_point *e = get_entry_point(s, name); if (e == NULL) { failure(); printf("Unknown entry point: %s\n", name); return; } int num_outs = entry_num_outs(e); int num_ins = entry_num_ins(e); // +1 to avoid zero-size arrays, which is UB. 
void* outs[num_outs+1]; void* ins[num_ins+1]; for (int i = 0; i < num_ins; i++) { const char *in_name = get_arg(args, 1+num_outs+i); struct variable *v = get_variable(s, in_name); if (v == NULL) { failure(); printf("Unknown variable: %s\n", in_name); return; } if (v->value.type != e->in_types[i]) { failure(); printf("Wrong input type. Expected %s, got %s.\n", e->in_types[i]->name, v->value.type->name); return; } ins[i] = value_ptr(&v->value); } for (int i = 0; i < num_outs; i++) { const char *out_name = get_arg(args, 1+i); struct variable *v = create_variable(s, out_name, e->out_types[i]); if (v == NULL) { failure(); printf("Variable already exists: %s\n", out_name); return; } outs[i] = value_ptr(&v->value); } int64_t t_start = get_wall_time(); int err = e->f(s->ctx, outs, ins); err |= futhark_context_sync(s->ctx); int64_t t_end = get_wall_time(); long long int elapsed_usec = t_end - t_start; printf("runtime: %lld\n", elapsed_usec); error_check(s, err); if (err != 0) { // Need to uncreate the output variables, which would otherwise be left // in an uninitialised state. 
for (int i = 0; i < num_outs; i++) { const char *out_name = get_arg(args, 1+i); struct variable *v = get_variable(s, out_name); if (v) { drop_variable(v); } } } } void cmd_restore(struct server_state *s, const char *args[]) { const char *fname = get_arg(args, 0); FILE *f = fopen(fname, "rb"); if (f == NULL) { failure(); printf("Failed to open %s: %s\n", fname, strerror(errno)); return; } int bad = 0; int values = 0; for (int i = 1; arg_exists(args, i); i+=2, values++) { const char *vname = get_arg(args, i); const char *type = get_arg(args, i+1); const struct type *t = get_type(s, type); struct variable *v = create_variable(s, vname, t); if (v == NULL) { bad = 1; failure(); printf("Variable already exists: %s\n", vname); break; } errno = 0; if (t->restore(t->aux, f, s->ctx, value_ptr(&v->value)) != 0) { bad = 1; failure(); printf("Failed to restore variable %s.\n" "Possibly malformed data in %s (errno: %s)\n", vname, fname, strerror(errno)); drop_variable(v); break; } } if (!bad && end_of_input(f) != 0) { failure(); printf("Expected EOF after reading %d values from %s\n", values, fname); } fclose(f); if (!bad) { int err = futhark_context_sync(s->ctx); error_check(s, err); } } void cmd_store(struct server_state *s, const char *args[]) { const char *fname = get_arg(args, 0); FILE *f = fopen(fname, "wb"); if (f == NULL) { failure(); printf("Failed to open %s: %s\n", fname, strerror(errno)); } else { for (int i = 1; arg_exists(args, i); i++) { const char *vname = get_arg(args, i); struct variable *v = get_variable(s, vname); if (v == NULL) { failure(); printf("Unknown variable: %s\n", vname); return; } const struct type *t = v->value.type; t->store(t->aux, f, s->ctx, value_ptr(&v->value)); } fclose(f); } } void cmd_free(struct server_state *s, const char *args[]) { for (int i = 0; arg_exists(args, i); i++) { const char *name = get_arg(args, i); struct variable *v = get_variable(s, name); if (v == NULL) { failure(); printf("Unknown variable: %s\n", name); return; } const 
struct type *t = v->value.type; int err = t->free(t->aux, s->ctx, value_ptr(&v->value)); error_check(s, err); drop_variable(v); } } void cmd_rename(struct server_state *s, const char *args[]) { const char *oldname = get_arg(args, 0); const char *newname = get_arg(args, 1); struct variable *old = get_variable(s, oldname); struct variable *new = get_variable(s, newname); if (old == NULL) { failure(); printf("Unknown variable: %s\n", oldname); return; } if (new != NULL) { failure(); printf("Variable already exists: %s\n", newname); return; } free(old->name); old->name = strdup(newname); } void cmd_inputs(struct server_state *s, const char *args[]) { const char *name = get_arg(args, 0); struct entry_point *e = get_entry_point(s, name); if (e == NULL) { failure(); printf("Unknown entry point: %s\n", name); return; } int num_ins = entry_num_ins(e); for (int i = 0; i < num_ins; i++) { if (e->in_unique[i]) { putchar('*'); } puts(e->in_types[i]->name); } } void cmd_outputs(struct server_state *s, const char *args[]) { const char *name = get_arg(args, 0); struct entry_point *e = get_entry_point(s, name); if (e == NULL) { failure(); printf("Unknown entry point: %s\n", name); return; } int num_outs = entry_num_outs(e); for (int i = 0; i < num_outs; i++) { if (e->out_unique[i]) { putchar('*'); } puts(e->out_types[i]->name); } } void cmd_clear(struct server_state *s, const char *args[]) { (void)args; int err = 0; for (int i = 0; i < s->variables_capacity; i++) { struct variable *v = &s->variables[i]; if (v->name != NULL) { err |= v->value.type->free(v->value.type->aux, s->ctx, value_ptr(&v->value)); drop_variable(v); } } err |= futhark_context_clear_caches(s->ctx); error_check(s, err); } void cmd_pause_profiling(struct server_state *s, const char *args[]) { (void)args; futhark_context_pause_profiling(s->ctx); } void cmd_unpause_profiling(struct server_state *s, const char *args[]) { (void)args; futhark_context_unpause_profiling(s->ctx); } void cmd_report(struct server_state *s, 
const char *args[]) { (void)args; char *report = futhark_context_report(s->ctx); if (report) { puts(report); } else { failure(); report = futhark_context_get_error(s->ctx); if (report) { puts(report); } else { puts("Failed to produce profiling report.\n"); } } free(report); } void cmd_set_tuning_param(struct server_state *s, const char *args[]) { const char *param = get_arg(args, 0); const char *val_s = get_arg(args, 1); size_t val = atol(val_s); int err = futhark_context_config_set_tuning_param(s->cfg, param, val); error_check(s, err); if (err != 0) { printf("Failed to set tuning parameter %s to %ld\n", param, (long)val); } } void cmd_tuning_params(struct server_state *s, const char *args[]) { const char *name = get_arg(args, 0); struct entry_point *e = get_entry_point(s, name); if (e == NULL) { failure(); printf("Unknown entry point: %s\n", name); return; } const char **params = e->tuning_params; for (int i = 0; params[i] != NULL; i++) { printf("%s\n", params[i]); } } void cmd_tuning_param_class(struct server_state *s, const char *args[]) { (void)s; const char *param = get_arg(args, 0); int n = futhark_get_tuning_param_count(); for (int i = 0; i < n; i++) { if (strcmp(futhark_get_tuning_param_name(i), param) == 0) { printf("%s\n", futhark_get_tuning_param_class(i)); return; } } failure(); printf("Unknown tuning parameter: %s\n", param); } void cmd_fields(struct server_state *s, const char *args[]) { const char *type = get_arg(args, 0); const struct type *t = get_type(s, type); const struct record *r = t->record; if (r == NULL) { failure(); printf("Not a record type\n"); return; } for (int i = 0; i < r->num_fields; i++) { const struct field f = r->fields[i]; printf("%s %s\n", f.name, f.type->name); } } void cmd_project(struct server_state *s, const char *args[]) { const char *to_name = get_arg(args, 0); const char *from_name = get_arg(args, 1); const char *field_name = get_arg(args, 2); struct variable *from = get_variable(s, from_name); if (from == NULL) { 
failure(); printf("Unknown variable: %s\n", from_name); return; } const struct type *from_type = from->value.type; const struct record *r = from_type->record; if (r == NULL) { failure(); printf("Not a record type\n"); return; } const struct field *field = NULL; for (int i = 0; i < r->num_fields; i++) { if (strcmp(r->fields[i].name, field_name) == 0) { field = &r->fields[i]; break; } } if (field == NULL) { failure(); printf("No such field\n"); } struct variable *to = create_variable(s, to_name, field->type); if (to == NULL) { failure(); printf("Variable already exists: %s\n", to_name); return; } field->project(s->ctx, value_ptr(&to->value), from->value.value.v_ptr); } void cmd_new(struct server_state *s, const char *args[]) { const char *to_name = get_arg(args, 0); const char *type_name = get_arg(args, 1); const struct type *type = get_type(s, type_name); struct variable *to = create_variable(s, to_name, type); if (to == NULL) { failure(); printf("Variable already exists: %s\n", to_name); return; } const struct record* r = type->record; if (r == NULL) { failure(); printf("Not a record type\n"); return; } int num_args = 0; for (int i = 2; arg_exists(args, i); i++) { num_args++; } if (num_args != r->num_fields) { failure(); printf("%d fields expected but %d values provided.\n", num_args, r->num_fields); return; } const void** value_ptrs = alloca(num_args * sizeof(void*)); for (int i = 0; i < num_args; i++) { struct variable* v = get_variable(s, args[2+i]); if (v == NULL) { failure(); printf("Unknown variable: %s\n", args[2+i]); return; } if (strcmp(v->value.type->name, r->fields[i].type->name) != 0) { failure(); printf("Field %s mismatch: expected type %s, got %s\n", r->fields[i].name, r->fields[i].type->name, v->value.type->name); return; } value_ptrs[i] = value_ptr(&v->value); } r->new(s->ctx, value_ptr(&to->value), value_ptrs); } void cmd_entry_points(struct server_state *s, const char *args[]) { (void)args; for (int i = 0; s->prog.entry_points[i].name; i++) { 
puts(s->prog.entry_points[i].name); } } void cmd_types(struct server_state *s, const char *args[]) { (void)args; for (int i = 0; s->prog.types[i] != NULL; i++) { puts(s->prog.types[i]->name); } } char *next_word(char **line) { char *p = *line; while (isspace(*p)) { p++; } if (*p == 0) { return NULL; } if (*p == '"') { char *save = p+1; // Skip ahead till closing quote. p++; while (*p && *p != '"') { p++; } if (*p == '"') { *p = 0; *line = p+1; return save; } else { return NULL; } } else { char *save = p; // Skip ahead till next whitespace. while (*p && !isspace(*p)) { p++; } if (*p) { *p = 0; *line = p+1; } else { *line = p; } return save; } } void process_line(struct server_state *s, char *line) { int max_num_tokens = 1000; const char* tokens[max_num_tokens]; int num_tokens = 0; while ((tokens[num_tokens] = next_word(&line)) != NULL) { num_tokens++; if (num_tokens == max_num_tokens) { futhark_panic(1, "Line too long.\n"); } } const char *command = tokens[0]; if (command == NULL) { failure(); printf("Empty line\n"); } else if (strcmp(command, "call") == 0) { cmd_call(s, tokens+1); } else if (strcmp(command, "restore") == 0) { cmd_restore(s, tokens+1); } else if (strcmp(command, "store") == 0) { cmd_store(s, tokens+1); } else if (strcmp(command, "free") == 0) { cmd_free(s, tokens+1); } else if (strcmp(command, "rename") == 0) { cmd_rename(s, tokens+1); } else if (strcmp(command, "inputs") == 0) { cmd_inputs(s, tokens+1); } else if (strcmp(command, "outputs") == 0) { cmd_outputs(s, tokens+1); } else if (strcmp(command, "clear") == 0) { cmd_clear(s, tokens+1); } else if (strcmp(command, "pause_profiling") == 0) { cmd_pause_profiling(s, tokens+1); } else if (strcmp(command, "unpause_profiling") == 0) { cmd_unpause_profiling(s, tokens+1); } else if (strcmp(command, "report") == 0) { cmd_report(s, tokens+1); } else if (strcmp(command, "set_tuning_param") == 0) { cmd_set_tuning_param(s, tokens+1); } else if (strcmp(command, "tuning_params") == 0) { cmd_tuning_params(s, 
tokens+1); } else if (strcmp(command, "tuning_param_class") == 0) { cmd_tuning_param_class(s, tokens+1); } else if (strcmp(command, "fields") == 0) { cmd_fields(s, tokens+1); } else if (strcmp(command, "new") == 0) { cmd_new(s, tokens+1); } else if (strcmp(command, "project") == 0) { cmd_project(s, tokens+1); } else if (strcmp(command, "entry_points") == 0) { cmd_entry_points(s, tokens+1); } else if (strcmp(command, "types") == 0) { cmd_types(s, tokens+1); } else { futhark_panic(1, "Unknown command: %s\n", command); } } void run_server(struct futhark_prog *prog, struct futhark_context_config *cfg, struct futhark_context *ctx) { char *line = NULL; size_t buflen = 0; ssize_t linelen; struct server_state s = { .cfg = cfg, .ctx = ctx, .variables_capacity = 100, .prog = *prog }; s.variables = malloc(s.variables_capacity * sizeof(struct variable)); for (int i = 0; i < s.variables_capacity; i++) { s.variables[i].name = NULL; } ok(); while ((linelen = getline(&line, &buflen, stdin)) > 0) { process_line(&s, line); ok(); } free(s.variables); free(line); } // The aux struct lets us write generic method implementations without // code duplication. 
typedef void* (*array_new_fn)(struct futhark_context *, const void*, const int64_t*); typedef const int64_t* (*array_shape_fn)(struct futhark_context*, void*); typedef int (*array_values_fn)(struct futhark_context*, void*, void*); typedef int (*array_free_fn)(struct futhark_context*, void*); struct array_aux { int rank; const struct primtype_info_t* info; const char *name; array_new_fn new; array_shape_fn shape; array_values_fn values; array_free_fn free; }; int restore_array(const struct array_aux *aux, FILE *f, struct futhark_context *ctx, void *p) { void *data = NULL; int64_t shape[aux->rank]; if (read_array(f, aux->info, &data, shape, aux->rank) != 0) { return 1; } void *arr = aux->new(ctx, data, shape); if (arr == NULL) { return 1; } int err = futhark_context_sync(ctx); *(void**)p = arr; free(data); return err; } void store_array(const struct array_aux *aux, FILE *f, struct futhark_context *ctx, void *p) { void *arr = *(void**)p; const int64_t *shape = aux->shape(ctx, arr); int64_t size = sizeof(aux->info->size); for (int i = 0; i < aux->rank; i++) { size *= shape[i]; } int32_t *data = malloc(size); assert(aux->values(ctx, arr, data) == 0); assert(futhark_context_sync(ctx) == 0); assert(write_array(f, 1, aux->info, data, shape, aux->rank) == 0); free(data); } int free_array(const struct array_aux *aux, struct futhark_context *ctx, void *p) { void *arr = *(void**)p; return aux->free(ctx, arr); } typedef void* (*opaque_restore_fn)(struct futhark_context*, void*); typedef int (*opaque_store_fn)(struct futhark_context*, const void*, void **, size_t *); typedef int (*opaque_free_fn)(struct futhark_context*, void*); struct opaque_aux { opaque_restore_fn restore; opaque_store_fn store; opaque_free_fn free; }; int restore_opaque(const struct opaque_aux *aux, FILE *f, struct futhark_context *ctx, void *p) { // We have a problem: we need to load data from 'f', since the // restore function takes a pointer, but we don't know how much we // need (and cannot possibly). 
So we do something hacky: we read // *all* of the file, pass all of the data to the restore function // (which doesn't care if there's extra at the end), then we compute // how much space the the object actually takes in serialised form // and rewind the file to that position. The only downside is more IO. size_t start = ftell(f); size_t size; char *bytes = fslurp_file(f, &size); void *obj = aux->restore(ctx, bytes); free(bytes); if (obj != NULL) { *(void**)p = obj; size_t obj_size; (void)aux->store(ctx, obj, NULL, &obj_size); fseek(f, start+obj_size, SEEK_SET); return 0; } else { fseek(f, start, SEEK_SET); return 1; } } void store_opaque(const struct opaque_aux *aux, FILE *f, struct futhark_context *ctx, void *p) { void *obj = *(void**)p; size_t obj_size; void *data = NULL; (void)aux->store(ctx, obj, &data, &obj_size); assert(futhark_context_sync(ctx) == 0); fwrite(data, sizeof(char), obj_size, f); free(data); } int free_opaque(const struct opaque_aux *aux, struct futhark_context *ctx, void *p) { void *obj = *(void**)p; return aux->free(ctx, obj); } // End of server.h. futhark-0.25.27/rts/c/timing.h000066400000000000000000000015441475065116200160400ustar00rootroot00000000000000// Start of timing.h. // The function get_wall_time() returns the wall time in microseconds // (with an unspecified offset). #ifdef _WIN32 #include static int64_t get_wall_time(void) { LARGE_INTEGER time,freq; assert(QueryPerformanceFrequency(&freq)); assert(QueryPerformanceCounter(&time)); return ((double)time.QuadPart / freq.QuadPart) * 1000000; } static int64_t get_wall_time_ns(void) { return get_wall_time() * 1000; } #else // Assuming POSIX #include #include static int64_t get_wall_time(void) { struct timeval time; assert(gettimeofday(&time,NULL) == 0); return time.tv_sec * 1000000 + time.tv_usec; } static int64_t get_wall_time_ns(void) { struct timespec time; assert(clock_gettime(CLOCK_REALTIME, &time) == 0); return time.tv_sec * 1000000000 + time.tv_nsec; } #endif // End of timing.h. 
futhark-0.25.27/rts/c/tuning.h000066400000000000000000000027651475065116200160630ustar00rootroot00000000000000
// Start of tuning.h.

// True if s contains only spaces/tabs/newlines, or its first
// non-blank characters are "--" (a comment).
int is_blank_line_or_comment(const char *s) {
  size_t i = strspn(s, " \t\n");
  return s[i] == '\0' || // Line is blank.
         strncmp(s + i, "--", 2) == 0; // Line is comment.
}

// Read "name=int" lines from fname and pass each to set_tuning_param.
// Blank lines and "--" comments are skipped.  Returns NULL on success,
// otherwise a heap-allocated error message that the caller must free.
static char* load_tuning_file(const char *fname,
                              void *cfg,
                              int (*set_tuning_param)(void*, const char*, size_t)) {
  const int max_line_len = 1024;
  char* line = (char*) malloc(max_line_len);

  FILE *f = fopen(fname, "r");

  if (f == NULL) {
    snprintf(line, max_line_len, "Cannot open file: %s", strerror(errno));
    return line;
  }

  // NULL while all is well; otherwise the message to hand back.
  char *result = NULL;
  int lineno = 0;
  while (result == NULL && fgets(line, max_line_len, f) != NULL) {
    lineno++;
    if (is_blank_line_or_comment(line)) {
      continue;
    }
    char *eql = strchr(line, '=');
    if (eql == NULL) {
      snprintf(line, max_line_len,
               "Invalid line %d (must be of form 'name=int').", lineno);
      result = line;
      break;
    }
    *eql = 0;
    char *endptr;
    long value = strtol(eql+1, &endptr, 10);
    // Require at least one digit, and allow only trailing whitespace
    // (this also accepts CRLF line endings).
    int no_digits = endptr == eql+1;
    endptr += strspn(endptr, " \t\r\n");
    if (no_digits || *endptr != '\0') {
      snprintf(line, max_line_len,
               "Invalid line %d (must be of form 'name=int').", lineno);
      result = line;
      break;
    }
    if (set_tuning_param(cfg, line, (size_t)value) != 0) {
      char* err = (char*) malloc(max_line_len + 50);
      snprintf(err, max_line_len + 50,
               "Unknown name '%s' on line %d.", line, lineno);
      result = err;
      break;
    }
  }

  // Single exit: the file is closed on every path (previously leaked).
  fclose(f);
  if (result != line) {
    free(line);
  }
  return result;
}

// End of tuning.h.
futhark-0.25.27/rts/c/uniform.h000066400000000000000000001245551475065116200162400ustar00rootroot00000000000000
// Start of uniform.h

// Uniform versions of all library functions, to improve performance
// in ISPC when used in a uniform context.
#if defined(ISPC) static inline uniform uint8_t add8(uniform uint8_t x, uniform uint8_t y) { return x + y; } static inline uniform uint16_t add16(uniform uint16_t x, uniform uint16_t y) { return x + y; } static inline uniform uint32_t add32(uniform uint32_t x, uniform uint32_t y) { return x + y; } static inline uniform uint64_t add64(uniform uint64_t x, uniform uint64_t y) { return x + y; } static inline uniform uint8_t sub8(uniform uint8_t x, uniform uint8_t y) { return x - y; } static inline uniform uint16_t sub16(uniform uint16_t x, uniform uint16_t y) { return x - y; } static inline uniform uint32_t sub32(uniform uint32_t x, uniform uint32_t y) { return x - y; } static inline uniform uint64_t sub64(uniform uint64_t x, uniform uint64_t y) { return x - y; } static inline uniform uint8_t mul8(uniform uint8_t x, uniform uint8_t y) { return x * y; } static inline uniform uint16_t mul16(uniform uint16_t x, uniform uint16_t y) { return x * y; } static inline uniform uint32_t mul32(uniform uint32_t x, uniform uint32_t y) { return x * y; } static inline uniform uint64_t mul64(uniform uint64_t x, uniform uint64_t y) { return x * y; } static inline uniform uint8_t udiv8(uniform uint8_t x, uniform uint8_t y) { return x / y; } static inline uniform uint16_t udiv16(uniform uint16_t x, uniform uint16_t y) { return x / y; } static inline uniform uint32_t udiv32(uniform uint32_t x, uniform uint32_t y) { return x / y; } static inline uniform uint64_t udiv64(uniform uint64_t x, uniform uint64_t y) { return x / y; } static inline uniform uint8_t udiv_up8(uniform uint8_t x, uniform uint8_t y) { return (x + y - 1) / y; } static inline uniform uint16_t udiv_up16(uniform uint16_t x, uniform uint16_t y) { return (x + y - 1) / y; } static inline uniform uint32_t udiv_up32(uniform uint32_t x, uniform uint32_t y) { return (x + y - 1) / y; } static inline uniform uint64_t udiv_up64(uniform uint64_t x, uniform uint64_t y) { return (x + y - 1) / y; } static inline uniform uint8_t 
umod8(uniform uint8_t x, uniform uint8_t y) { return x % y; } static inline uniform uint16_t umod16(uniform uint16_t x, uniform uint16_t y) { return x % y; } static inline uniform uint32_t umod32(uniform uint32_t x, uniform uint32_t y) { return x % y; } static inline uniform uint64_t umod64(uniform uint64_t x, uniform uint64_t y) { return x % y; } static inline uniform uint8_t udiv_safe8(uniform uint8_t x, uniform uint8_t y) { return y == 0 ? 0 : x / y; } static inline uniform uint16_t udiv_safe16(uniform uint16_t x, uniform uint16_t y) { return y == 0 ? 0 : x / y; } static inline uniform uint32_t udiv_safe32(uniform uint32_t x, uniform uint32_t y) { return y == 0 ? 0 : x / y; } static inline uniform uint64_t udiv_safe64(uniform uint64_t x, uniform uint64_t y) { return y == 0 ? 0 : x / y; } static inline uniform uint8_t udiv_up_safe8(uniform uint8_t x, uniform uint8_t y) { return y == 0 ? 0 : (x + y - 1) / y; } static inline uniform uint16_t udiv_up_safe16(uniform uint16_t x, uniform uint16_t y) { return y == 0 ? 0 : (x + y - 1) / y; } static inline uniform uint32_t udiv_up_safe32(uniform uint32_t x, uniform uint32_t y) { return y == 0 ? 0 : (x + y - 1) / y; } static inline uniform uint64_t udiv_up_safe64(uniform uint64_t x, uniform uint64_t y) { return y == 0 ? 0 : (x + y - 1) / y; } static inline uniform uint8_t umod_safe8(uniform uint8_t x, uniform uint8_t y) { return y == 0 ? 0 : x % y; } static inline uniform uint16_t umod_safe16(uniform uint16_t x, uniform uint16_t y) { return y == 0 ? 0 : x % y; } static inline uniform uint32_t umod_safe32(uniform uint32_t x, uniform uint32_t y) { return y == 0 ? 0 : x % y; } static inline uniform uint64_t umod_safe64(uniform uint64_t x, uniform uint64_t y) { return y == 0 ? 0 : x % y; } static inline uniform int8_t sdiv8(uniform int8_t x, uniform int8_t y) { uniform int8_t q = x / y; uniform int8_t r = x % y; return q - ((r != 0 && r < 0 != y < 0) ? 
1 : 0); } static inline uniform int16_t sdiv16(uniform int16_t x, uniform int16_t y) { uniform int16_t q = x / y; uniform int16_t r = x % y; return q - ((r != 0 && r < 0 != y < 0) ? 1 : 0); } static inline uniform int32_t sdiv32(uniform int32_t x, uniform int32_t y) { uniform int32_t q = x / y; uniform int32_t r = x % y; return q - ((r != 0 && r < 0 != y < 0) ? 1 : 0); } static inline uniform int64_t sdiv64(uniform int64_t x, uniform int64_t y) { uniform int64_t q = x / y; uniform int64_t r = x % y; return q - ((r != 0 && r < 0 != y < 0) ? 1 : 0); } static inline uniform int8_t sdiv_up8(uniform int8_t x, uniform int8_t y) { return sdiv8(x + y - 1, y); } static inline uniform int16_t sdiv_up16(uniform int16_t x, uniform int16_t y) { return sdiv16(x + y - 1, y); } static inline uniform int32_t sdiv_up32(uniform int32_t x, uniform int32_t y) { return sdiv32(x + y - 1, y); } static inline uniform int64_t sdiv_up64(uniform int64_t x, uniform int64_t y) { return sdiv64(x + y - 1, y); } static inline uniform int8_t smod8(uniform int8_t x, uniform int8_t y) { uniform int8_t r = x % y; return r + (r == 0 || (x > 0 && y > 0) || (x < 0 && y < 0) ? 0 : y); } static inline uniform int16_t smod16(uniform int16_t x, uniform int16_t y) { uniform int16_t r = x % y; return r + (r == 0 || (x > 0 && y > 0) || (x < 0 && y < 0) ? 0 : y); } static inline uniform int32_t smod32(uniform int32_t x, uniform int32_t y) { uniform int32_t r = x % y; return r + (r == 0 || (x > 0 && y > 0) || (x < 0 && y < 0) ? 0 : y); } static inline uniform int64_t smod64(uniform int64_t x, uniform int64_t y) { uniform int64_t r = x % y; return r + (r == 0 || (x > 0 && y > 0) || (x < 0 && y < 0) ? 0 : y); } static inline uniform int8_t sdiv_safe8(uniform int8_t x, uniform int8_t y) { return y == 0 ? 0 : sdiv8(x, y); } static inline uniform int16_t sdiv_safe16(uniform int16_t x, uniform int16_t y) { return y == 0 ? 
0 : sdiv16(x, y); } static inline uniform int32_t sdiv_safe32(uniform int32_t x, uniform int32_t y) { return y == 0 ? 0 : sdiv32(x, y); } static inline uniform int64_t sdiv_safe64(uniform int64_t x, uniform int64_t y) { return y == 0 ? 0 : sdiv64(x, y); } static inline uniform int8_t sdiv_up_safe8(uniform int8_t x, uniform int8_t y) { return sdiv_safe8(x + y - 1, y); } static inline uniform int16_t sdiv_up_safe16(uniform int16_t x, uniform int16_t y) { return sdiv_safe16(x + y - 1, y); } static inline uniform int32_t sdiv_up_safe32(uniform int32_t x, uniform int32_t y) { return sdiv_safe32(x + y - 1, y); } static inline uniform int64_t sdiv_up_safe64(uniform int64_t x, uniform int64_t y) { return sdiv_safe64(x + y - 1, y); } static inline uniform int8_t smod_safe8(uniform int8_t x, uniform int8_t y) { return y == 0 ? 0 : smod8(x, y); } static inline uniform int16_t smod_safe16(uniform int16_t x, uniform int16_t y) { return y == 0 ? 0 : smod16(x, y); } static inline uniform int32_t smod_safe32(uniform int32_t x, uniform int32_t y) { return y == 0 ? 0 : smod32(x, y); } static inline uniform int64_t smod_safe64(uniform int64_t x, uniform int64_t y) { return y == 0 ? 
0 : smod64(x, y); } static inline uniform int8_t squot8(uniform int8_t x, uniform int8_t y) { return x / y; } static inline uniform int16_t squot16(uniform int16_t x, uniform int16_t y) { return x / y; } static inline uniform int32_t squot32(uniform int32_t x, uniform int32_t y) { return x / y; } static inline uniform int64_t squot64(uniform int64_t x, uniform int64_t y) { return x / y; } static inline uniform int8_t srem8(uniform int8_t x, uniform int8_t y) { return x % y; } static inline uniform int16_t srem16(uniform int16_t x, uniform int16_t y) { return x % y; } static inline uniform int32_t srem32(uniform int32_t x, uniform int32_t y) { return x % y; } static inline uniform int64_t srem64(uniform int64_t x, uniform int64_t y) { return x % y; } static inline uniform int8_t squot_safe8(uniform int8_t x, uniform int8_t y) { return y == 0 ? 0 : x / y; } static inline uniform int16_t squot_safe16(uniform int16_t x, uniform int16_t y) { return y == 0 ? 0 : x / y; } static inline uniform int32_t squot_safe32(uniform int32_t x, uniform int32_t y) { return y == 0 ? 0 : x / y; } static inline uniform int64_t squot_safe64(uniform int64_t x, uniform int64_t y) { return y == 0 ? 0 : x / y; } static inline uniform int8_t srem_safe8(uniform int8_t x, uniform int8_t y) { return y == 0 ? 0 : x % y; } static inline uniform int16_t srem_safe16(uniform int16_t x, uniform int16_t y) { return y == 0 ? 0 : x % y; } static inline uniform int32_t srem_safe32(uniform int32_t x, uniform int32_t y) { return y == 0 ? 0 : x % y; } static inline uniform int64_t srem_safe64(uniform int64_t x, uniform int64_t y) { return y == 0 ? 0 : x % y; } static inline uniform int8_t smin8(uniform int8_t x, uniform int8_t y) { return x < y ? x : y; } static inline uniform int16_t smin16(uniform int16_t x, uniform int16_t y) { return x < y ? x : y; } static inline uniform int32_t smin32(uniform int32_t x, uniform int32_t y) { return x < y ? 
x : y; } static inline uniform int64_t smin64(uniform int64_t x, uniform int64_t y) { return x < y ? x : y; } static inline uniform uint8_t umin8(uniform uint8_t x, uniform uint8_t y) { return x < y ? x : y; } static inline uniform uint16_t umin16(uniform uint16_t x, uniform uint16_t y) { return x < y ? x : y; } static inline uniform uint32_t umin32(uniform uint32_t x, uniform uint32_t y) { return x < y ? x : y; } static inline uniform uint64_t umin64(uniform uint64_t x, uniform uint64_t y) { return x < y ? x : y; } static inline uniform int8_t smax8(uniform int8_t x, uniform int8_t y) { return x < y ? y : x; } static inline uniform int16_t smax16(uniform int16_t x, uniform int16_t y) { return x < y ? y : x; } static inline uniform int32_t smax32(uniform int32_t x, uniform int32_t y) { return x < y ? y : x; } static inline uniform int64_t smax64(uniform int64_t x, uniform int64_t y) { return x < y ? y : x; } static inline uniform uint8_t umax8(uniform uint8_t x, uniform uint8_t y) { return x < y ? y : x; } static inline uniform uint16_t umax16(uniform uint16_t x, uniform uint16_t y) { return x < y ? y : x; } static inline uniform uint32_t umax32(uniform uint32_t x, uniform uint32_t y) { return x < y ? y : x; } static inline uniform uint64_t umax64(uniform uint64_t x, uniform uint64_t y) { return x < y ? 
y : x; } static inline uniform uint8_t shl8(uniform uint8_t x, uniform uint8_t y) { return (uniform uint8_t)(x << y); } static inline uniform uint16_t shl16(uniform uint16_t x, uniform uint16_t y) { return (uniform uint16_t)(x << y); } static inline uniform uint32_t shl32(uniform uint32_t x, uniform uint32_t y) { return x << y; } static inline uniform uint64_t shl64(uniform uint64_t x, uniform uint64_t y) { return x << y; } static inline uniform uint8_t lshr8(uniform uint8_t x, uniform uint8_t y) { return x >> y; } static inline uniform uint16_t lshr16(uniform uint16_t x, uniform uint16_t y) { return x >> y; } static inline uniform uint32_t lshr32(uniform uint32_t x, uniform uint32_t y) { return x >> y; } static inline uniform uint64_t lshr64(uniform uint64_t x, uniform uint64_t y) { return x >> y; } static inline uniform int8_t ashr8(uniform int8_t x, uniform int8_t y) { return x >> y; } static inline uniform int16_t ashr16(uniform int16_t x, uniform int16_t y) { return x >> y; } static inline uniform int32_t ashr32(uniform int32_t x, uniform int32_t y) { return x >> y; } static inline uniform int64_t ashr64(uniform int64_t x, uniform int64_t y) { return x >> y; } static inline uniform uint8_t and8(uniform uint8_t x, uniform uint8_t y) { return x & y; } static inline uniform uint16_t and16(uniform uint16_t x, uniform uint16_t y) { return x & y; } static inline uniform uint32_t and32(uniform uint32_t x, uniform uint32_t y) { return x & y; } static inline uniform uint64_t and64(uniform uint64_t x, uniform uint64_t y) { return x & y; } static inline uniform uint8_t or8(uniform uint8_t x, uniform uint8_t y) { return x | y; } static inline uniform uint16_t or16(uniform uint16_t x, uniform uint16_t y) { return x | y; } static inline uniform uint32_t or32(uniform uint32_t x, uniform uint32_t y) { return x | y; } static inline uniform uint64_t or64(uniform uint64_t x, uniform uint64_t y) { return x | y; } static inline uniform uint8_t xor8(uniform uint8_t x, uniform 
uint8_t y) { return x ^ y; }
static inline uniform uint16_t xor16(uniform uint16_t x, uniform uint16_t y) { return y ^ x; }
static inline uniform uint32_t xor32(uniform uint32_t x, uniform uint32_t y) { return y ^ x; }
static inline uniform uint64_t xor64(uniform uint64_t x, uniform uint64_t y) { return y ^ x; }

// ---- Unsigned comparisons -------------------------------------------
static inline uniform bool ult8(uniform uint8_t x, uniform uint8_t y) { return y > x; }
static inline uniform bool ult16(uniform uint16_t x, uniform uint16_t y) { return y > x; }
static inline uniform bool ult32(uniform uint32_t x, uniform uint32_t y) { return y > x; }
static inline uniform bool ult64(uniform uint64_t x, uniform uint64_t y) { return y > x; }
static inline uniform bool ule8(uniform uint8_t x, uniform uint8_t y) { return y >= x; }
static inline uniform bool ule16(uniform uint16_t x, uniform uint16_t y) { return y >= x; }
static inline uniform bool ule32(uniform uint32_t x, uniform uint32_t y) { return y >= x; }
static inline uniform bool ule64(uniform uint64_t x, uniform uint64_t y) { return y >= x; }

// ---- Signed comparisons ---------------------------------------------
static inline uniform bool slt8(uniform int8_t x, uniform int8_t y) { return y > x; }
static inline uniform bool slt16(uniform int16_t x, uniform int16_t y) { return y > x; }
static inline uniform bool slt32(uniform int32_t x, uniform int32_t y) { return y > x; }
static inline uniform bool slt64(uniform int64_t x, uniform int64_t y) { return y > x; }
static inline uniform bool sle8(uniform int8_t x, uniform int8_t y) { return y >= x; }
static inline uniform bool sle16(uniform int16_t x, uniform int16_t y) { return y >= x; }
static inline uniform bool sle32(uniform int32_t x, uniform int32_t y) { return y >= x; }
static inline uniform bool sle64(uniform int64_t x, uniform int64_t y) { return y >= x; }

// x raised to y by repeated squaring; the result is kept in 8 bits.
static inline uniform uint8_t pow8(uniform uint8_t x, uniform uint8_t y) {
  uniform uint8_t acc = 1;
  for (uniform uint8_t e = y; e != 0; e >>= 1) {
    if (e & 1) acc *= x;
    x *= x;
  }
  return acc;
}
static inline uniform uint16_t pow16(uniform uint16_t x,
uniform uint16_t y) { uniform uint16_t res = 1, rem = y; while (rem != 0) { if (rem & 1) res *= x; rem >>= 1; x *= x; } return res; }

// x raised to y by repeated squaring; results are kept at operand width.
static inline uniform uint32_t pow32(uniform uint32_t x, uniform uint32_t y) {
  uniform uint32_t acc = 1;
  for (uniform uint32_t e = y; e != 0; e >>= 1) {
    if (e & 1) acc *= x;
    x *= x;
  }
  return acc;
}
static inline uniform uint64_t pow64(uniform uint64_t x, uniform uint64_t y) {
  uniform uint64_t acc = 1;
  for (uniform uint64_t e = y; e != 0; e >>= 1) {
    if (e & 1) acc *= x;
    x *= x;
  }
  return acc;
}

// ---- Integer <-> bool -----------------------------------------------
static inline uniform bool itob_i8_bool(uniform int8_t x) { return !(x == 0); }
static inline uniform bool itob_i16_bool(uniform int16_t x) { return !(x == 0); }
static inline uniform bool itob_i32_bool(uniform int32_t x) { return !(x == 0); }
static inline uniform bool itob_i64_bool(uniform int64_t x) { return !(x == 0); }
static inline uniform int8_t btoi_bool_i8(uniform bool x) { return x ? 1 : 0; }
static inline uniform int16_t btoi_bool_i16(uniform bool x) { return x ? 1 : 0; }
static inline uniform int32_t btoi_bool_i32(uniform bool x) { return x ? 1 : 0; }
static inline uniform int64_t btoi_bool_i64(uniform bool x) { return x ? 1 : 0; }

// ---- Absolute value -------------------------------------------------
static uniform int8_t abs8(uniform int8_t x) { return (uniform int8_t)abs(x); }
static uniform int16_t abs16(uniform int16_t x) { return (uniform int16_t)abs(x); }
static uniform int32_t abs32(uniform int32_t x) { return abs(x); }
static uniform int64_t abs64(uniform int64_t x) { return abs(x); }

// ---- Population count: clear the lowest set bit until none remain ---
static uniform int32_t futrts_popc8(uniform uint8_t x) {
  uniform int n = 0;
  while (x != 0) {
    x &= x - 1;
    n++;
  }
  return n;
}
static uniform int32_t futrts_popc16(uniform uint16_t x) {
  uniform int n = 0;
  while (x != 0) {
    x &= x - 1;
    n++;
  }
  return n;
}
static uniform int32_t futrts_popc32(uniform uint32_t x) {
  uniform int n = 0;
  while (x != 0) {
    x &= x - 1;
    n++;
  }
  return n;
}
static uniform int32_t futrts_popc64(uniform uint64_t x) {
  uniform int n = 0;
  while (x != 0) {
    x &= x - 1;
    n++;
  }
  return n;
}
static uniform uint8_t futrts_mul_hi8(uniform uint8_t a, uniform uint8_t b) { uniform uint16_t
aa = a; uniform uint16_t bb = b; return aa * bb >> 8; }

// High half of the double-width product, computed by widening the operands
// to the next larger unsigned type.
static uniform uint16_t futrts_mul_hi16(uniform uint16_t a, uniform uint16_t b) { uniform uint32_t aa = a; uniform uint32_t bb = b; return aa * bb >> 16; }
static uniform uint32_t futrts_mul_hi32(uniform uint32_t a, uniform uint32_t b) { uniform uint64_t aa = a; uniform uint64_t bb = b; return aa * bb >> 32; }

// High 64 bits of the 128-bit product a*b. No wider integer type is used,
// so the product is assembled from four 32x32->64-bit partial products:
//   a*b = p4*2^64 + (p2 + p3)*2^32 + p1
static uniform uint64_t futrts_mul_hi64(uniform uint64_t a, uniform uint64_t b) {
  uniform uint64_t ah = a >> 32;
  uniform uint64_t al = a & 0xffffffff;
  uniform uint64_t bh = b >> 32;
  uniform uint64_t bl = b & 0xffffffff;

  uniform uint64_t p1 = al * bl;
  uniform uint64_t p2 = al * bh;
  uniform uint64_t p3 = ah * bl;
  uniform uint64_t p4 = ah * bh;

  // Sum of the three terms that land in bits 32..63 of the full product.
  // Each is < 2^32, so the sum is < 3*2^32 and cannot overflow; its high
  // half is the carry into bit 64.
  uniform uint64_t l = (p1 >> 32) + (p2 & 0xffffffff) + (p3 & 0xffffffff);

  return p4 + (p2 >> 32) + (p3 >> 32) + (l >> 32);
}

// Multiply-high, then add c; the addition wraps at the operand width.
static uniform uint8_t futrts_mad_hi8(uniform uint8_t a, uniform uint8_t b, uniform uint8_t c) { return futrts_mul_hi8(a, b) + c; }
static uniform uint16_t futrts_mad_hi16(uniform uint16_t a, uniform uint16_t b, uniform uint16_t c) { return futrts_mul_hi16(a, b) + c; }
static uniform uint32_t futrts_mad_hi32(uniform uint32_t a, uniform uint32_t b, uniform uint32_t c) { return futrts_mul_hi32(a, b) + c; }
static uniform uint64_t futrts_mad_hi64(uniform uint64_t a, uniform uint64_t b, uniform uint64_t c) { return futrts_mul_hi64(a, b) + c; }

// Count leading zeros. Narrow inputs are zero-extended to 32 bits first,
// so the extra high zero bits (24 or 16) are subtracted back off.
static uniform int32_t futrts_clzz8(uniform int8_t x) { return count_leading_zeros((uniform int32_t)(uniform uint8_t)x)-24; }
static uniform int32_t futrts_clzz16(uniform int16_t x) { return count_leading_zeros((uniform int32_t)(uniform uint16_t)x)-16; }
static uniform int32_t futrts_clzz32(uniform int32_t x) { return count_leading_zeros(x); }
static uniform int32_t futrts_clzz64(uniform int64_t x) { return
count_leading_zeros(x); }

// Count trailing zeros. The narrow variants return their own bit width for
// a zero input explicitly instead of relying on the 32-bit count.
static uniform int32_t futrts_ctzz8(uniform int8_t x) {
  if (x == 0) return 8;
  return count_trailing_zeros((uniform int32_t)x);
}
static uniform int32_t futrts_ctzz16(uniform int16_t x) {
  if (x == 0) return 16;
  return count_trailing_zeros((uniform int32_t)x);
}
static uniform int32_t futrts_ctzz32(uniform int32_t x) { return count_trailing_zeros(x); }
static uniform int32_t futrts_ctzz64(uniform int64_t x) { return count_trailing_zeros(x); }

// ---- f32 arithmetic and comparisons ---------------------------------
static inline uniform float fdiv32(uniform float x, uniform float y) { return x / y; }
static inline uniform float fadd32(uniform float x, uniform float y) { return x + y; }
static inline uniform float fsub32(uniform float x, uniform float y) { return x - y; }
static inline uniform float fmul32(uniform float x, uniform float y) { return x * y; }
static inline uniform bool cmplt32(uniform float x, uniform float y) { return y > x; }
static inline uniform bool cmple32(uniform float x, uniform float y) { return y >= x; }

// ---- Integer -> f32 -------------------------------------------------
static inline uniform float sitofp_i8_f32(uniform int8_t x) { return (uniform float) x; }
static inline uniform float sitofp_i16_f32(uniform int16_t x) { return (uniform float) x; }
static inline uniform float sitofp_i32_f32(uniform int32_t x) { return (uniform float) x; }
static inline uniform float sitofp_i64_f32(uniform int64_t x) { return (uniform float) x; }
static inline uniform float uitofp_i8_f32(uniform uint8_t x) { return (uniform float) x; }
static inline uniform float uitofp_i16_f32(uniform uint16_t x) { return (uniform float) x; }
static inline uniform float uitofp_i32_f32(uniform uint32_t x) { return (uniform float) x; }
static inline uniform float uitofp_i64_f32(uniform uint64_t x) { return (uniform float) x; }

static inline uniform float fabs32(uniform float x) { return abs(x); }

// Maximum that skips a NaN operand in favour of the other one.
static inline uniform float fmax32(uniform float x, uniform float y) {
  if (isnan(x)) return y;
  if (isnan(y)) return x;
  return max(x, y);
}
static inline uniform float fmin32(uniform float x, uniform float y) { return isnan(x) ?
y : isnan(y) ? x : min(x, y); }
static inline uniform float fpow32(uniform float x, uniform float y) { return pow(x, y); }

// ---- f32 classification ---------------------------------------------
static inline uniform bool futrts_isnan32(uniform float x) { return isnan(x); }
// With NaN excluded, x - x is NaN only when x is infinite.
static inline uniform bool futrts_isinf32(uniform float x) { return !isnan(x) && isnan(x - x); }
static inline uniform bool futrts_isfinite32(uniform float x) { return !isnan(x) && !futrts_isinf32(x); }

// ---- f32 -> integer -------------------------------------------------
// NaN and infinities map to 0. The unsigned variants convert through the
// signed type of the same width.
static inline uniform int8_t fptosi_f32_i8(uniform float x) {
  if (futrts_isnan32(x) || futrts_isinf32(x)) return 0;
  return (uniform int8_t) x;
}
static inline uniform int16_t fptosi_f32_i16(uniform float x) {
  if (futrts_isnan32(x) || futrts_isinf32(x)) return 0;
  return (uniform int16_t) x;
}
static inline uniform int32_t fptosi_f32_i32(uniform float x) {
  if (futrts_isnan32(x) || futrts_isinf32(x)) return 0;
  return (uniform int32_t) x;
}
static inline uniform int64_t fptosi_f32_i64(uniform float x) {
  if (futrts_isnan32(x) || futrts_isinf32(x)) return 0;
  return (uniform int64_t) x;
}
static inline uniform uint8_t fptoui_f32_i8(uniform float x) {
  if (futrts_isnan32(x) || futrts_isinf32(x)) return 0;
  return (uniform uint8_t) (uniform int8_t) x;
}
static inline uniform uint16_t fptoui_f32_i16(uniform float x) {
  if (futrts_isnan32(x) || futrts_isinf32(x)) return 0;
  return (uniform uint16_t) (uniform int16_t) x;
}
static inline uniform uint32_t fptoui_f32_i32(uniform float x) {
  if (futrts_isnan32(x) || futrts_isinf32(x)) return 0;
  return (uniform uint32_t) (uniform int32_t) x;
}
static inline uniform uint64_t fptoui_f32_i64(uniform float x) {
  if (futrts_isnan32(x) || futrts_isinf32(x)) return 0;
  return (uniform uint64_t) (uniform int64_t) x;
}
static inline uniform float futrts_log32(uniform float x) { return futrts_isfinite32(x) || (futrts_isinf32(x) && x < 0)?
log(x) : x; }
static inline uniform float futrts_log2_32(uniform float x) { return futrts_log32(x) / log(2.0f); }
static inline uniform float futrts_log10_32(uniform float x) { return futrts_log32(x) / log(10.0f); }

// log(1 + x). -1 and +inf are handled first (x / 0.0f gives -inf or +inf).
static inline uniform float futrts_log1p_32(uniform float x) {
  if(x == -1.0f || (futrts_isinf32(x) && x > 0.0f)) return x / 0.0f;
  uniform float y = 1.0f + x;
  uniform float z = y - 1.0f;
  // z - x captures the rounding error made when forming 1 + x.
  return log(y) - (z-x)/y;
}
static inline uniform float futrts_sqrt32(uniform float x) { return sqrt(x); }
extern "C" unmasked uniform float cbrtf(uniform float);
static inline uniform float futrts_cbrt32(uniform float x) { return cbrtf(x); }
static inline uniform float futrts_exp32(uniform float x) { return exp(x); }
static inline uniform float futrts_cos32(uniform float x) { return cos(x); }
static inline uniform float futrts_sin32(uniform float x) { return sin(x); }
static inline uniform float futrts_tan32(uniform float x) { return tan(x); }
static inline uniform float futrts_acos32(uniform float x) { return acos(x); }
static inline uniform float futrts_asin32(uniform float x) { return asin(x); }
static inline uniform float futrts_atan32(uniform float x) { return atan(x); }

// Hyperbolic functions from their exponential definitions.
static inline uniform float futrts_cosh32(uniform float x) { return (exp(x)+exp(-x)) / 2.0f; }
static inline uniform float futrts_sinh32(uniform float x) { return (exp(x)-exp(-x)) / 2.0f; }

// tanh(x) = sinh(x)/cosh(x). Once |x| is large enough for exp to overflow
// (about 89 in f32), that quotient becomes inf/inf = NaN. For |x| > 9,
// tanh(x) already rounds to +-1.0f (and the quotient evaluates to exactly
// that), so the saturated value is returned directly. NaN inputs fail both
// tests and still propagate through the quotient.
static inline uniform float futrts_tanh32(uniform float x) {
  if (x > 9.0f) return 1.0f;
  if (x < -9.0f) return -1.0f;
  return futrts_sinh32(x)/futrts_cosh32(x);
}

// Inverse hyperbolic functions via their log forms; a non-finite log
// argument is returned unchanged.
static inline uniform float futrts_acosh32(uniform float x) { uniform float f = x+sqrt(x*x-1); if(futrts_isfinite32(f)) return log(f); return f; }
static inline uniform float futrts_asinh32(uniform float x) { uniform float f = x+sqrt(x*x+1); if(futrts_isfinite32(f)) return log(f); return f; }
static inline uniform float futrts_atanh32(uniform float x) { uniform float f = (1+x)/(1-x); if(futrts_isfinite32(f)) return log(f)/2.0f; return f; }
static inline uniform float futrts_atan2_32(uniform float x, uniform
float y) { return (x == 0.0f && y == 0.0f) ? 0.0f : atan2(x, y); }

// sqrt(x*x + y*y) without intermediate overflow/underflow: the larger
// magnitude a is split by frexp into an * 2^e (an in [0.5, 1)), the smaller
// is scaled by 2^-e, and the result is scaled back with ldexp. If either
// input is infinite the result is +inf; otherwise a NaN propagates via x + y.
static inline uniform float futrts_hypot32(uniform float x, uniform float y) {
  if (futrts_isfinite32(x) && futrts_isfinite32(y)) {
    x = abs(x);
    y = abs(y);
    uniform float a;
    uniform float b;
    if (x >= y){
      a = x;
      b = y;
    } else {
      a = y;
      b = x;
    }
    if(b == 0){
      return a;
    }

    uniform int e;
    uniform float an;
    uniform float bn;
    an = frexp (a, &e);
    bn = ldexp (b, - e);
    uniform float cn;
    cn = sqrt (an * an + bn * bn);
    return ldexp (cn, e);
  } else {
    if (futrts_isinf32(x) || futrts_isinf32(y)) return INFINITY;
    else return x + y;
  }
}

extern "C" unmasked uniform float tgammaf(uniform float x);
static inline uniform float futrts_gamma32(uniform float x) { return tgammaf(x); }

// Declare lgammaf itself (not a second tgammaf) so that the call below
// refers to a declared C library function.
extern "C" unmasked uniform float lgammaf(uniform float x);
static inline uniform float futrts_lgamma32(uniform float x) { return lgammaf(x); }

extern "C" unmasked uniform float erff(uniform float);
static inline uniform float futrts_erf32(uniform float x) { return erff(x); }
extern "C" unmasked uniform float erfcf(uniform float);
static inline uniform float futrts_erfc32(uniform float x) { return erfcf(x); }

// Remainder with the sign of x (truncated division).
static inline uniform float fmod32(uniform float x, uniform float y) { return x - y * trunc(x/y); }
static inline uniform float futrts_round32(uniform float x) { return round(x); }
static inline uniform float futrts_floor32(uniform float x) { return floor(x); }
static inline uniform float futrts_ceil32(uniform float x) { return ceil(x); }
static inline uniform float futrts_lerp32(uniform float v0, uniform float v1, uniform float t) { return v0 + (v1 - v0) * t; }
static inline uniform float futrts_mad32(uniform float a, uniform float b, uniform float c) { return a * b + c; }
// NOTE(review): written as a plain multiply-add; whether it is fused is up
// to the compiler - confirm if single-rounding FMA semantics are required.
static inline uniform float futrts_fma32(uniform float a, uniform float b, uniform float c) { return a * b + c; }

// Reinterpret the bits of a float as an int32, and back.
static inline uniform int32_t futrts_to_bits32(uniform float x) { return intbits(x); }
static inline uniform float futrts_from_bits32(uniform int32_t x) {
return floatbits(x); }

// Sign of x as -1, 0 or 1; NaN is returned unchanged.
static inline uniform float fsignum32(uniform float x) { return futrts_isnan32(x) ? x : (x > 0 ? 1 : 0) - (x < 0 ? 1 : 0); }

#ifdef FUTHARK_F64_ENABLED

// f64 classification. These take a double: narrowing to float would turn
// large finite doubles (e.g. 1e300) into infinities and misclassify them.
// With NaN excluded, x - x is NaN only when x is infinite.
static inline uniform bool futrts_isinf64(uniform double x) { return !isnan(x) && isnan(x - x); }
static inline uniform bool futrts_isfinite64(uniform double x) { return !isnan(x) && !futrts_isinf64(x); }

// ---- f64 arithmetic and comparisons ---------------------------------
static inline uniform double fdiv64(uniform double x, uniform double y) { return x / y; }
static inline uniform double fadd64(uniform double x, uniform double y) { return x + y; }
static inline uniform double fsub64(uniform double x, uniform double y) { return x - y; }
static inline uniform double fmul64(uniform double x, uniform double y) { return x * y; }
static inline uniform bool cmplt64(uniform double x, uniform double y) { return x < y; }
static inline uniform bool cmple64(uniform double x, uniform double y) { return x <= y; }

// ---- Integer -> f64 -------------------------------------------------
static inline uniform double sitofp_i8_f64(uniform int8_t x) { return (uniform double) x; }
static inline uniform double sitofp_i16_f64(uniform int16_t x) { return (uniform double) x; }
static inline uniform double sitofp_i32_f64(uniform int32_t x) { return (uniform double) x; }
static inline uniform double sitofp_i64_f64(uniform int64_t x) { return (uniform double) x; }
static inline uniform double uitofp_i8_f64(uniform uint8_t x) { return (uniform double) x; }
static inline uniform double uitofp_i16_f64(uniform uint16_t x) { return (uniform double) x; }
static inline uniform double uitofp_i32_f64(uniform uint32_t x) { return (uniform double) x; }
static inline uniform double uitofp_i64_f64(uniform uint64_t x) { return (uniform double) x; }

static inline uniform double fabs64(uniform double x) { return abs(x); }
// Maximum that skips a NaN operand in favour of the other one.
static inline uniform double fmax64(uniform double x, uniform double y) { return isnan(x) ? y : isnan(y) ? x : max(x, y); }
static inline uniform double fmin64(uniform double x, uniform double y) { return isnan(x) ? y : isnan(y) ?
x : min(x, y); }
static inline uniform double fpow64(uniform double x, uniform double y) { return pow(x, y); }

// log(x), except that NaN and +inf are returned unchanged.
static inline uniform double futrts_log64(uniform double x) { return futrts_isfinite64(x) || (futrts_isinf64(x) && x < 0)? log(x) : x; }
static inline uniform double futrts_log2_64(uniform double x) { return futrts_log64(x)/log(2.0d); }
static inline uniform double futrts_log10_64(uniform double x) { return futrts_log64(x)/log(10.0d); }

// log(1 + x). -1 and +inf are handled first (x / 0.0d gives -inf or +inf).
static inline uniform double futrts_log1p_64(uniform double x) {
  if(x == -1.0d || (futrts_isinf64(x) && x > 0.0d)) return x / 0.0d;
  uniform double y = 1.0d + x;
  uniform double z = y - 1.0d;
  // z - x captures the rounding error made when forming 1 + x.
  return log(y) - (z-x)/y;
}
static inline uniform double futrts_sqrt64(uniform double x) { return sqrt(x); }
extern "C" unmasked uniform double cbrt(uniform double);
static inline uniform double futrts_cbrt64(uniform double x) { return cbrt(x); }
static inline uniform double futrts_exp64(uniform double x) { return exp(x); }
static inline uniform double futrts_cos64(uniform double x) { return cos(x); }
static inline uniform double futrts_sin64(uniform double x) { return sin(x); }
static inline uniform double futrts_tan64(uniform double x) { return tan(x); }
static inline uniform double futrts_acos64(uniform double x) { return acos(x); }
static inline uniform double futrts_asin64(uniform double x) { return asin(x); }
static inline uniform double futrts_atan64(uniform double x) { return atan(x); }

// Hyperbolic functions from their exponential definitions.
static inline uniform double futrts_cosh64(uniform double x) { return (exp(x)+exp(-x)) / 2.0d; }
static inline uniform double futrts_sinh64(uniform double x) { return (exp(x)-exp(-x)) / 2.0d; }

// tanh(x) = sinh(x)/cosh(x). Once |x| is large enough for exp to overflow
// (about 710 in f64), that quotient becomes inf/inf = NaN. For |x| > 20,
// tanh(x) already rounds to +-1.0 (and the quotient evaluates to exactly
// that), so the saturated value is returned directly. NaN still propagates.
static inline uniform double futrts_tanh64(uniform double x) {
  if (x > 20.0d) return 1.0d;
  if (x < -20.0d) return -1.0d;
  return futrts_sinh64(x)/futrts_cosh64(x);
}
static inline uniform double futrts_acosh64(uniform double x) { uniform double f = x+sqrt(x*x-1.0d); if(futrts_isfinite64(f)) return log(f); return f; }
static inline uniform double futrts_asinh64(uniform double x) { uniform double f = x+sqrt(x*x+1.0d);
if(futrts_isfinite64(f)) return log(f); return f; }

// atanh(x) = log((1+x)/(1-x)) / 2; a non-finite ratio is returned as-is.
static inline uniform double futrts_atanh64(uniform double x) {
  uniform double ratio = (1.0d+x)/(1.0d-x);
  return futrts_isfinite64(ratio) ? log(ratio)/2.0d : ratio;
}
static inline uniform double futrts_atan2_64(uniform double x, uniform double y) { return atan2(x, y); }

// ---- Thin wrappers over C math library routines ---------------------
extern "C" unmasked uniform double hypot(uniform double x, uniform double y);
static inline uniform double futrts_hypot64(uniform double x, uniform double y) { return hypot(x, y); }
extern "C" unmasked uniform double tgamma(uniform double x);
static inline uniform double futrts_gamma64(uniform double x) { return tgamma(x); }
extern "C" unmasked uniform double lgamma(uniform double x);
static inline uniform double futrts_lgamma64(uniform double x) { return lgamma(x); }
extern "C" unmasked uniform double erf(uniform double);
static inline uniform double futrts_erf64(uniform double x) { return erf(x); }
extern "C" unmasked uniform double erfc(uniform double);
static inline uniform double futrts_erfc64(uniform double x) { return erfc(x); }

static inline uniform double futrts_fma64(uniform double a, uniform double b, uniform double c) { return a * b + c; }
static inline uniform double futrts_round64(uniform double x) { return round(x); }
static inline uniform double futrts_ceil64(uniform double x) { return ceil(x); }
static inline uniform double futrts_floor64(uniform double x) { return floor(x); }
static inline uniform bool futrts_isnan64(uniform double x) { return isnan(x); }

// ---- f64 -> signed integer; NaN and infinities map to 0 -------------
static inline uniform int8_t fptosi_f64_i8(uniform double x) {
  if (futrts_isnan64(x) || futrts_isinf64(x)) return 0;
  return (uniform int8_t) x;
}
static inline uniform int16_t fptosi_f64_i16(uniform double x) {
  if (futrts_isnan64(x) || futrts_isinf64(x)) return 0;
  return (uniform int16_t) x;
}
static inline uniform int32_t fptosi_f64_i32(uniform double x) {
  if (futrts_isnan64(x) || futrts_isinf64(x)) return 0;
  return (uniform int32_t) x;
}
static
inline uniform int64_t fptosi_f64_i64(uniform double x) { if (futrts_isnan64(x) || futrts_isinf64(x)) { return 0; } else { return (uniform int64_t) x; } }

// ---- f64 -> unsigned integer ----------------------------------------
// NaN and infinities map to 0; other values convert through the signed
// type of the same width.
static inline uniform uint8_t fptoui_f64_i8(uniform double x) {
  if (futrts_isnan64(x) || futrts_isinf64(x)) return 0;
  return (uniform uint8_t) (uniform int8_t) x;
}
static inline uniform uint16_t fptoui_f64_i16(uniform double x) {
  if (futrts_isnan64(x) || futrts_isinf64(x)) return 0;
  return (uniform uint16_t) (uniform int16_t) x;
}
static inline uniform uint32_t fptoui_f64_i32(uniform double x) {
  if (futrts_isnan64(x) || futrts_isinf64(x)) return 0;
  return (uniform uint32_t) (uniform int32_t) x;
}
static inline uniform uint64_t fptoui_f64_i64(uniform double x) {
  if (futrts_isnan64(x) || futrts_isinf64(x)) return 0;
  return (uniform uint64_t) (uniform int64_t) x;
}

// ---- Floating point <-> bool ----------------------------------------
static inline uniform bool ftob_f64_bool(uniform double x) { return x != 0.0; }
static inline uniform double btof_bool_f64(uniform bool x) {
  if (x) return 1.0;
  return 0.0;
}
static inline uniform bool ftob_f32_bool(uniform float x) { return x != 0; }
static inline uniform float btof_bool_f32(uniform bool x) {
  if (x) return 1;
  return 0;
}

// Reinterpret the bits of a double as an int64 and back, through a
// uniform pointer to the argument.
static inline uniform int64_t futrts_to_bits64(uniform double x) { return *((uniform int64_t* uniform)&x); }
static inline uniform double futrts_from_bits64(uniform int64_t x) { return *((uniform double* uniform)&x); }

// Remainder with the sign of x (truncated division).
static inline uniform double fmod64(uniform double x, uniform double y) { return x - y * trunc(x/y); }
static inline uniform double fsignum64(uniform double x) { return futrts_isnan64(x) ? x : (x > 0 ? 1.0d : 0.0d) - (x < 0 ?
1.0d : 0.0d); }

// Linear interpolation and multiply-add for f64.
static inline uniform double futrts_lerp64(uniform double v0, uniform double v1, uniform double t) { return v0 + (v1 - v0) * t; }
static inline uniform double futrts_mad64(uniform double a, uniform double b, uniform double c) { return a * b + c; }

// ---- Floating-point precision conversions ---------------------------
static inline uniform float fpconv_f32_f32(uniform float x) { return (uniform float) x; }
static inline uniform double fpconv_f32_f64(uniform float x) { return (uniform double) x; }
static inline uniform float fpconv_f64_f32(uniform double x) { return (uniform float) x; }
static inline uniform double fpconv_f64_f64(uniform double x) { return (uniform double) x; }
static inline uniform double fpconv_f16_f64(uniform f16 x) { return (uniform double) x; }
// f64 -> f16 goes through f32.
static inline uniform f16 fpconv_f64_f16(uniform double x) { return (uniform f16) ((uniform float)x); }
#endif

// ---- f16 arithmetic and comparisons ---------------------------------
// NOTE(review): no fdiv16 appears next to fadd16/fsub16/fmul16 here;
// presumably it is defined elsewhere - confirm.
static inline uniform f16 fadd16(uniform f16 x, uniform f16 y) { return x + y; }
static inline uniform f16 fsub16(uniform f16 x, uniform f16 y) { return x - y; }
static inline uniform f16 fmul16(uniform f16 x, uniform f16 y) { return x * y; }
static inline uniform bool cmplt16(uniform f16 x, uniform f16 y) { return y > x; }
static inline uniform bool cmple16(uniform f16 x, uniform f16 y) { return y >= x; }

// ---- Integer -> f16 -------------------------------------------------
static inline uniform f16 sitofp_i8_f16(uniform int8_t x) { return (uniform f16) x; }
static inline uniform f16 sitofp_i16_f16(uniform int16_t x) { return (uniform f16) x; }
static inline uniform f16 sitofp_i32_f16(uniform int32_t x) { return (uniform f16) x; }
static inline uniform f16 sitofp_i64_f16(uniform int64_t x) { return (uniform f16) x; }
static inline uniform f16 uitofp_i8_f16(uniform uint8_t x) { return (uniform f16) x; }
static inline uniform f16 uitofp_i16_f16(uniform uint16_t x) { return (uniform f16) x; }
static inline uniform f16 uitofp_i32_f16(uniform uint32_t x) { return (uniform f16) x; }
static inline uniform f16 uitofp_i64_f16(uniform uint64_t x) { return (uniform f16) x; }
static inline uniform int8_t fptosi_f16_i8(uniform f16
x) { return (uniform int8_t) (uniform float) x; }

// ---- f16 -> integer -------------------------------------------------
static inline uniform int16_t fptosi_f16_i16(uniform f16 x) { return (uniform int16_t) x; }
static inline uniform int32_t fptosi_f16_i32(uniform f16 x) { return (uniform int32_t) x; }
static inline uniform int64_t fptosi_f16_i64(uniform f16 x) { return (uniform int64_t) x; }
static inline uniform uint8_t fptoui_f16_i8(uniform f16 x) { return (uniform uint8_t) (uniform float) x; }
static inline uniform uint16_t fptoui_f16_i16(uniform f16 x) { return (uniform uint16_t) x; }
static inline uniform uint32_t fptoui_f16_i32(uniform f16 x) { return (uniform uint32_t) x; }
static inline uniform uint64_t fptoui_f16_i64(uniform f16 x) { return (uniform uint64_t) x; }

static inline uniform f16 fabs16(uniform f16 x) { return abs(x); }
// NaN test performed after widening to f32.
static inline uniform bool futrts_isnan16(uniform f16 x) { return isnan((uniform float)x); }

// Maximum / minimum that skip a NaN operand in favour of the other one.
// fmin16 previously returned min(x, y) directly, unlike fmax16, fmin32
// and fmin64, which all ignore a NaN operand.
static inline uniform f16 fmax16(uniform f16 x, uniform f16 y) { return futrts_isnan16(x) ? y : futrts_isnan16(y) ? x : max(x, y); }
static inline uniform f16 fmin16(uniform f16 x, uniform f16 y) { return futrts_isnan16(x) ? y : futrts_isnan16(y) ? x : min(x, y); }
static inline uniform f16 fpow16(uniform f16 x, uniform f16 y) { return pow(x, y); }

// f16 classification. The parameter is a float; f16 arguments widen to
// float exactly. With NaN excluded, x - x is NaN only when x is infinite.
static inline uniform bool futrts_isinf16(uniform float x) { return !futrts_isnan16(x) && futrts_isnan16(x - x); }
static inline uniform bool futrts_isfinite16(uniform float x) { return !futrts_isnan16(x) && !futrts_isinf16(x); }
static inline uniform f16 futrts_log16(uniform f16 x) { return futrts_isfinite16(x) || (futrts_isinf16(x) && x < 0)?
log(x) : x; }
static inline uniform f16 futrts_log2_16(uniform f16 x) { return futrts_log16(x) / log(2.0f16); }
static inline uniform f16 futrts_log10_16(uniform f16 x) { return futrts_log16(x) / log(10.0f16); }

// log(1 + x). -1 and +inf are handled first (x / 0.0f16 gives -inf or +inf).
static inline uniform f16 futrts_log1p_16(uniform f16 x) {
  if(x == -1.0f16 || (futrts_isinf16(x) && x > 0.0f16)) return x / 0.0f16;
  uniform f16 y = 1.0f16 + x;
  uniform f16 z = y - 1.0f16;
  // z - x captures the rounding error made when forming 1 + x.
  return log(y) - (z-x)/y;
}

// Most f16 transcendental functions are evaluated in f32 and narrowed.
static inline uniform f16 futrts_sqrt16(uniform f16 x) { return (uniform f16)sqrt((uniform float)x); }
extern "C" unmasked uniform float cbrtf(uniform float);
static inline uniform f16 futrts_cbrt16(uniform f16 x) { return (uniform f16)cbrtf((uniform float)x); }
static inline uniform f16 futrts_exp16(uniform f16 x) { return exp(x); }
static inline uniform f16 futrts_cos16(uniform f16 x) { return (uniform f16)cos((uniform float)x); }
static inline uniform f16 futrts_sin16(uniform f16 x) { return (uniform f16)sin((uniform float)x); }
static inline uniform f16 futrts_tan16(uniform f16 x) { return (uniform f16)tan((uniform float)x); }
static inline uniform f16 futrts_acos16(uniform f16 x) { return (uniform f16)acos((uniform float)x); }
static inline uniform f16 futrts_asin16(uniform f16 x) { return (uniform f16)asin((uniform float)x); }
static inline uniform f16 futrts_atan16(uniform f16 x) { return (uniform f16)atan((uniform float)x); }

// Hyperbolic functions from their exponential definitions, in f16.
static inline uniform f16 futrts_cosh16(uniform f16 x) { return (exp(x)+exp(-x)) / 2.0f16; }
static inline uniform f16 futrts_sinh16(uniform f16 x) { return (exp(x)-exp(-x)) / 2.0f16; }

// tanh(x) = sinh(x)/cosh(x). exp overflows the f16 range (max 65504) for
// x above about 11, which turned the quotient into inf/inf = NaN. For
// |x| > 8, tanh(x) already rounds to +-1 in f16 (and the quotient evaluates
// to exactly that), so the saturated value is returned directly. NaN
// still propagates.
static inline uniform f16 futrts_tanh16(uniform f16 x) {
  if (x > 8.0f16) return 1.0f16;
  if (x < -8.0f16) return -1.0f16;
  return futrts_sinh16(x)/futrts_cosh16(x);
}
static inline uniform f16 futrts_acosh16(uniform f16 x) { uniform f16 f = x+(uniform f16)sqrt((uniform float)(x*x-1)); if(futrts_isfinite16(f)) return log(f); return f; }
static inline uniform f16 futrts_asinh16(uniform f16 x) { uniform f16 f = x+(uniform f16)sqrt((uniform float)(x*x+1)); if(futrts_isfinite16(f)) return log(f); return
f; }

// atanh(x) = log((1+x)/(1-x)) / 2; a non-finite ratio is returned as-is.
static inline uniform f16 futrts_atanh16(uniform f16 x) { uniform f16 f = (1+x)/(1-x); if(futrts_isfinite16(f)) return log(f)/2.0f16; return f; }
static inline uniform f16 futrts_atan2_16(uniform f16 x, uniform f16 y) { return (uniform f16)atan2((uniform float)x, (uniform float)y); }
static inline uniform f16 futrts_hypot16(uniform f16 x, uniform f16 y) { return (uniform f16)futrts_hypot32((uniform float)x, (uniform float)y); }

extern "C" unmasked uniform float tgammaf(uniform float x);
static inline uniform f16 futrts_gamma16(uniform f16 x) { return (uniform f16)tgammaf((uniform float)x); }
extern "C" unmasked uniform float lgammaf(uniform float x);
static inline uniform f16 futrts_lgamma16(uniform f16 x) { return (uniform f16)lgammaf((uniform float)x); }

// f16 error functions, evaluated in f32. They follow the futrts_<op>16
// naming used by every other f16 helper. The f16 overloads of the
// futrts_erf32/futrts_erfc32 names are kept as forwarders so that any
// existing caller of those overloads keeps working.
extern "C" unmasked uniform float erff(uniform float);
static inline uniform f16 futrts_erf16(uniform f16 x) { return (uniform f16)erff((uniform float)x); }
static inline uniform f16 futrts_erf32(uniform f16 x) { return futrts_erf16(x); }
extern "C" unmasked uniform float erfcf(uniform float);
static inline uniform f16 futrts_erfc16(uniform f16 x) { return (uniform f16)erfcf((uniform float)x); }
static inline uniform f16 futrts_erfc32(uniform f16 x) { return futrts_erfc16(x); }

// Remainder with the sign of x (truncated division).
static inline uniform f16 fmod16(uniform f16 x, uniform f16 y) { return x - y * (uniform f16)trunc((uniform float) (x/y)); }
static inline uniform f16 futrts_round16(uniform f16 x) { return (uniform f16)round((uniform float)x); }
static inline uniform f16 futrts_floor16(uniform f16 x) { return (uniform f16)floor((uniform float)x); }
static inline uniform f16 futrts_ceil16(uniform f16 x) { return (uniform f16)ceil((uniform float)x); }
static inline uniform f16 futrts_lerp16(uniform f16 v0, uniform f16 v1, uniform f16 t) { return v0 + (v1 - v0) * t; }
static inline uniform f16 futrts_mad16(uniform f16 a, uniform f16 b, uniform f16 c) { return a * b + c; }
static inline uniform f16 futrts_fma16(uniform f16 a, uniform f16 b, uniform f16 c) { return a * b + c; }

// Reinterpret the bits of an f16 as an int16. The pointer itself is
// qualified uniform, matching futrts_to_bits64.
static inline uniform int16_t futrts_to_bits16(uniform f16 x) { return *((uniform int16_t * uniform)&x); }
static inline uniform f16
futrts_from_bits16(uniform int16_t x) { return *((uniform f16 *)&x); } static inline uniform float fpconv_f16_f16(uniform f16 x) { return x; } static inline uniform float fpconv_f16_f32(uniform f16 x) { return x; } static inline uniform f16 fpconv_f32_f16(uniform float x) { return (uniform f16) x; } #endif // End of uniform.h. futhark-0.25.27/rts/c/util.h000066400000000000000000000120331475065116200155210ustar00rootroot00000000000000// Start of util.h. // // Various helper functions that are useful in all generated C code. #include #include static const char *fut_progname = "(embedded Futhark)"; static void futhark_panic(int eval, const char *fmt, ...) __attribute__((noreturn)); static char* msgprintf(const char *s, ...); static void* slurp_file(const char *filename, size_t *size); static int dump_file(const char *file, const void *buf, size_t n); struct str_builder; static void str_builder_init(struct str_builder *b); static void str_builder(struct str_builder *b, const char *s, ...); static char *strclone(const char *str); static void futhark_panic(int eval, const char *fmt, ...) { va_list ap; va_start(ap, fmt); fprintf(stderr, "%s: ", fut_progname); vfprintf(stderr, fmt, ap); va_end(ap); exit(eval); } // For generating arbitrary-sized error messages. It is the callers // responsibility to free the buffer at some point. static char* msgprintf(const char *s, ...) { va_list vl; va_start(vl, s); size_t needed = 1 + (size_t)vsnprintf(NULL, 0, s, vl); char *buffer = (char*) malloc(needed); va_start(vl, s); // Must re-init. vsnprintf(buffer, needed, s, vl); return buffer; } static inline void check_err(int errval, int sets_errno, const char *fun, int line, const char *msg, ...) { if (errval) { char errnum[10]; va_list vl; va_start(vl, msg); fprintf(stderr, "ERROR: "); vfprintf(stderr, msg, vl); fprintf(stderr, " in %s() at line %d with error code %s\n", fun, line, sets_errno ? strerror(errno) : errnum); exit(errval); } } #define CHECK_ERR(err, ...) 
check_err(err, 0, __func__, __LINE__, __VA_ARGS__) #define CHECK_ERRNO(err, ...) check_err(err, 1, __func__, __LINE__, __VA_ARGS__) // Read the rest of an open file into a NUL-terminated string; returns // NULL on error. static void* fslurp_file(FILE *f, size_t *size) { long start = ftell(f); fseek(f, 0, SEEK_END); long src_size = ftell(f)-start; fseek(f, start, SEEK_SET); unsigned char *s = (unsigned char*) malloc((size_t)src_size + 1); if (fread(s, 1, (size_t)src_size, f) != (size_t)src_size) { free(s); s = NULL; } else { s[src_size] = '\0'; } if (size) { *size = (size_t)src_size; } return s; } // Read a file into a NUL-terminated string; returns NULL on error. static void* slurp_file(const char *filename, size_t *size) { FILE *f = fopen(filename, "rb"); // To avoid Windows messing with linebreaks. if (f == NULL) return NULL; unsigned char *s = fslurp_file(f, size); fclose(f); return s; } // Dump 'n' bytes from 'buf' into the file at the designated location. // Returns 0 on success. static int dump_file(const char *file, const void *buf, size_t n) { FILE *f = fopen(file, "w"); if (f == NULL) { return 1; } if (fwrite(buf, sizeof(char), n, f) != n) { return 1; } if (fclose(f) != 0) { return 1; } return 0; } struct str_builder { char *str; size_t capacity; // Size of buffer. size_t used; // Bytes used, *not* including final zero. }; static void str_builder_init(struct str_builder *b) { b->capacity = 10; b->used = 0; b->str = malloc(b->capacity); b->str[0] = 0; } static void str_builder(struct str_builder *b, const char *s, ...) { va_list vl; va_start(vl, s); size_t needed = (size_t)vsnprintf(NULL, 0, s, vl); while (b->capacity < b->used + needed + 1) { b->capacity *= 2; b->str = realloc(b->str, b->capacity); } va_start(vl, s); // Must re-init. 
vsnprintf(b->str+b->used, b->capacity-b->used, s, vl); b->used += needed; } static void str_builder_str(struct str_builder *b, const char *s) { size_t needed = strlen(s); if (b->capacity < b->used + needed + 1) { b->capacity *= 2; b->str = realloc(b->str, b->capacity); } strcpy(b->str+b->used, s); b->used += needed; } static void str_builder_char(struct str_builder *b, char c) { size_t needed = 1; if (b->capacity < b->used + needed + 1) { b->capacity *= 2; b->str = realloc(b->str, b->capacity); } b->str[b->used] = c; b->str[b->used+1] = 0; b->used += needed; } static void str_builder_json_str(struct str_builder* sb, const char* s) { str_builder_char(sb, '"'); for (int j = 0; s[j]; j++) { char c = s[j]; switch (c) { case '\n': str_builder_str(sb, "\\n"); break; case '"': str_builder_str(sb, "\\\""); break; default: str_builder_char(sb, c); } } str_builder_char(sb, '"'); } static char *strclone(const char *str) { size_t size = strlen(str) + 1; char *copy = (char*) malloc(size); if (copy == NULL) { return NULL; } memcpy(copy, str, size); return copy; } // Assumes NULL-terminated. static char *strconcat(const char *src_fragments[]) { size_t src_len = 0; const char **p; for (p = src_fragments; *p; p++) { src_len += strlen(*p); } char *src = (char*) malloc(src_len + 1); size_t n = 0; for (p = src_fragments; *p; p++) { strcpy(src + n, *p); n += strlen(*p); } return src; } // End of util.h. futhark-0.25.27/rts/c/values.h000066400000000000000000000540421475065116200160510ustar00rootroot00000000000000// Start of values.h. //// Text I/O typedef int (*writer)(FILE*, const void*); typedef int (*bin_reader)(void*); typedef int (*str_reader)(const char *, void*); struct array_reader { char* elems; int64_t n_elems_space; int64_t elem_size; int64_t n_elems_used; int64_t *shape; str_reader elem_reader; }; static void skipspaces(FILE *f) { int c; do { c = getc(f); } while (isspace(c)); if (c != EOF) { ungetc(c, f); } } static int constituent(char c) { return isalnum(c) || c == '.' 
|| c == '-' || c == '+' || c == '_'; } // Produces an empty token only on EOF. static void next_token(FILE *f, char *buf, int bufsize) { start: skipspaces(f); int i = 0; while (i < bufsize) { int c = getc(f); buf[i] = (char)c; if (c == EOF) { buf[i] = 0; return; } else if (c == '-' && i == 1 && buf[0] == '-') { // Line comment, so skip to end of line and start over. for (; c != '\n' && c != EOF; c = getc(f)); goto start; } else if (!constituent((char)c)) { if (i == 0) { // We permit single-character tokens that are not // constituents; this lets things like ']' and ',' be // tokens. buf[i+1] = 0; return; } else { ungetc(c, f); buf[i] = 0; return; } } i++; } buf[bufsize-1] = 0; } static int next_token_is(FILE *f, char *buf, int bufsize, const char* expected) { next_token(f, buf, bufsize); return strcmp(buf, expected) == 0; } static void remove_underscores(char *buf) { char *w = buf; for (char *r = buf; *r; r++) { if (*r != '_') { *w++ = *r; } } *w++ = 0; } static int read_str_elem(char *buf, struct array_reader *reader) { int ret; if (reader->n_elems_used == reader->n_elems_space) { reader->n_elems_space *= 2; reader->elems = (char*) realloc(reader->elems, (size_t)(reader->n_elems_space * reader->elem_size)); } ret = reader->elem_reader(buf, reader->elems + reader->n_elems_used * reader->elem_size); if (ret == 0) { reader->n_elems_used++; } return ret; } static int read_str_array_elems(FILE *f, char *buf, int bufsize, struct array_reader *reader, int64_t dims) { int ret = 1; int expect_elem = 1; char *knows_dimsize = (char*) calloc((size_t)dims, sizeof(char)); int cur_dim = (int)dims-1; int64_t *elems_read_in_dim = (int64_t*) calloc((size_t)dims, sizeof(int64_t)); while (1) { next_token(f, buf, bufsize); if (strcmp(buf, "]") == 0) { expect_elem = 0; if (knows_dimsize[cur_dim]) { if (reader->shape[cur_dim] != elems_read_in_dim[cur_dim]) { ret = 1; break; } } else { knows_dimsize[cur_dim] = 1; reader->shape[cur_dim] = elems_read_in_dim[cur_dim]; } if (cur_dim == 0) { 
ret = 0; break; } else { cur_dim--; elems_read_in_dim[cur_dim]++; } } else if (!expect_elem && strcmp(buf, ",") == 0) { expect_elem = 1; } else if (expect_elem) { if (strcmp(buf, "[") == 0) { if (cur_dim == dims - 1) { ret = 1; break; } cur_dim++; elems_read_in_dim[cur_dim] = 0; } else if (cur_dim == dims - 1) { ret = read_str_elem(buf, reader); if (ret != 0) { break; } expect_elem = 0; elems_read_in_dim[cur_dim]++; } else { ret = 1; break; } } else { ret = 1; break; } } free(knows_dimsize); free(elems_read_in_dim); return ret; } static int read_str_empty_array(FILE *f, char *buf, int bufsize, const char *type_name, int64_t *shape, int64_t dims) { if (strlen(buf) == 0) { // EOF return 1; } if (strcmp(buf, "empty") != 0) { return 1; } if (!next_token_is(f, buf, bufsize, "(")) { return 1; } for (int i = 0; i < dims; i++) { if (!next_token_is(f, buf, bufsize, "[")) { return 1; } next_token(f, buf, bufsize); if (sscanf(buf, "%"SCNu64, (uint64_t*)&shape[i]) != 1) { return 1; } if (!next_token_is(f, buf, bufsize, "]")) { return 1; } } if (!next_token_is(f, buf, bufsize, type_name)) { return 1; } if (!next_token_is(f, buf, bufsize, ")")) { return 1; } // Check whether the array really is empty. for (int i = 0; i < dims; i++) { if (shape[i] == 0) { return 0; } } // Not an empty array! 
return 1; } static int read_str_array(FILE *f, int64_t elem_size, str_reader elem_reader, const char *type_name, void **data, int64_t *shape, int64_t dims) { int ret; struct array_reader reader; char buf[100]; int dims_seen; for (dims_seen = 0; dims_seen < dims; dims_seen++) { if (!next_token_is(f, buf, sizeof(buf), "[")) { break; } } if (dims_seen == 0) { return read_str_empty_array(f, buf, sizeof(buf), type_name, shape, dims); } if (dims_seen != dims) { return 1; } reader.shape = shape; reader.n_elems_used = 0; reader.elem_size = elem_size; reader.n_elems_space = 16; reader.elems = (char*) realloc(*data, (size_t)(elem_size*reader.n_elems_space)); reader.elem_reader = elem_reader; ret = read_str_array_elems(f, buf, sizeof(buf), &reader, dims); *data = reader.elems; return ret; } #define READ_STR(MACRO, PTR, SUFFIX) \ remove_underscores(buf); \ int j; \ if (sscanf(buf, "%"MACRO"%n", (PTR*)dest, &j) == 1) { \ return !(strcmp(buf+j, "") == 0 || strcmp(buf+j, SUFFIX) == 0); \ } else { \ return 1; \ } static int read_str_i8(char *buf, void* dest) { // Some platforms (WINDOWS) does not support scanf %hhd or its // cousin, %SCNi8. Read into int first to avoid corrupting // memory. // // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=63417 remove_underscores(buf); int j, x; if (sscanf(buf, "%i%n", &x, &j) == 1) { *(int8_t*)dest = (int8_t)x; return !(strcmp(buf+j, "") == 0 || strcmp(buf+j, "i8") == 0); } else { return 1; } } static int read_str_u8(char *buf, void* dest) { // Some platforms (WINDOWS) does not support scanf %hhd or its // cousin, %SCNu8. Read into int first to avoid corrupting // memory. 
// // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=63417 remove_underscores(buf); int j, x; if (sscanf(buf, "%i%n", &x, &j) == 1) { *(uint8_t*)dest = (uint8_t)x; return !(strcmp(buf+j, "") == 0 || strcmp(buf+j, "u8") == 0); } else { return 1; } } static int read_str_i16(char *buf, void* dest) { READ_STR(SCNi16, int16_t, "i16"); } static int read_str_u16(char *buf, void* dest) { READ_STR(SCNi16, int16_t, "u16"); } static int read_str_i32(char *buf, void* dest) { READ_STR(SCNi32, int32_t, "i32"); } static int read_str_u32(char *buf, void* dest) { READ_STR(SCNi32, int32_t, "u32"); } static int read_str_i64(char *buf, void* dest) { READ_STR(SCNi64, int64_t, "i64"); } static int read_str_u64(char *buf, void* dest) { // FIXME: This is not correct, as SCNu64 only permits decimal // literals. However, SCNi64 does not handle very large numbers // correctly (it's really for signed numbers, so that's fair). READ_STR(SCNu64, uint64_t, "u64"); } static int read_str_f16(char *buf, void* dest) { remove_underscores(buf); if (strcmp(buf, "f16.nan") == 0) { *(uint16_t*)dest = float2halfbits(NAN); return 0; } else if (strcmp(buf, "f16.inf") == 0) { *(uint16_t*)dest = float2halfbits(INFINITY); return 0; } else if (strcmp(buf, "-f16.inf") == 0) { *(uint16_t*)dest = float2halfbits(-INFINITY); return 0; } else { int j; float x; if (sscanf(buf, "%f%n", &x, &j) == 1) { if (strcmp(buf+j, "") == 0 || strcmp(buf+j, "f16") == 0) { *(uint16_t*)dest = float2halfbits(x); return 0; } } return 1; } } static int read_str_f32(char *buf, void* dest) { remove_underscores(buf); if (strcmp(buf, "f32.nan") == 0) { *(float*)dest = (float)NAN; return 0; } else if (strcmp(buf, "f32.inf") == 0) { *(float*)dest = (float)INFINITY; return 0; } else if (strcmp(buf, "-f32.inf") == 0) { *(float*)dest = (float)-INFINITY; return 0; } else { READ_STR("f", float, "f32"); } } static int read_str_f64(char *buf, void* dest) { remove_underscores(buf); if (strcmp(buf, "f64.nan") == 0) { *(double*)dest = (double)NAN; return 
0; } else if (strcmp(buf, "f64.inf") == 0) { *(double*)dest = (double)INFINITY; return 0; } else if (strcmp(buf, "-f64.inf") == 0) { *(double*)dest = (double)-INFINITY; return 0; } else { READ_STR("lf", double, "f64"); } } static int read_str_bool(char *buf, void* dest) { if (strcmp(buf, "true") == 0) { *(char*)dest = 1; return 0; } else if (strcmp(buf, "false") == 0) { *(char*)dest = 0; return 0; } else { return 1; } } static int write_str_i8(FILE *out, int8_t *src) { return fprintf(out, "%hhdi8", *src); } static int write_str_u8(FILE *out, uint8_t *src) { return fprintf(out, "%hhuu8", *src); } static int write_str_i16(FILE *out, int16_t *src) { return fprintf(out, "%hdi16", *src); } static int write_str_u16(FILE *out, uint16_t *src) { return fprintf(out, "%huu16", *src); } static int write_str_i32(FILE *out, int32_t *src) { return fprintf(out, "%di32", *src); } static int write_str_u32(FILE *out, uint32_t *src) { return fprintf(out, "%uu32", *src); } static int write_str_i64(FILE *out, int64_t *src) { return fprintf(out, "%"PRIi64"i64", *src); } static int write_str_u64(FILE *out, uint64_t *src) { return fprintf(out, "%"PRIu64"u64", *src); } static int write_str_f16(FILE *out, uint16_t *src) { float x = halfbits2float(*src); if (isnan(x)) { return fprintf(out, "f16.nan"); } else if (isinf(x) && x >= 0) { return fprintf(out, "f16.inf"); } else if (isinf(x)) { return fprintf(out, "-f16.inf"); } else { return fprintf(out, "%.*ff16", FLT_DIG, x); } } static int write_str_f32(FILE *out, float *src) { float x = *src; if (isnan(x)) { return fprintf(out, "f32.nan"); } else if (isinf(x) && x >= 0) { return fprintf(out, "f32.inf"); } else if (isinf(x)) { return fprintf(out, "-f32.inf"); } else { return fprintf(out, "%.*ff32", FLT_DIG, x); } } static int write_str_f64(FILE *out, double *src) { double x = *src; if (isnan(x)) { return fprintf(out, "f64.nan"); } else if (isinf(x) && x >= 0) { return fprintf(out, "f64.inf"); } else if (isinf(x)) { return fprintf(out, 
"-f64.inf"); } else { return fprintf(out, "%.*ff64", DBL_DIG, x); } } static int write_str_bool(FILE *out, void *src) { return fprintf(out, *(char*)src ? "true" : "false"); } //// Binary I/O #define BINARY_FORMAT_VERSION 2 #define IS_BIG_ENDIAN (!*(unsigned char *)&(uint16_t){1}) static void flip_bytes(size_t elem_size, unsigned char *elem) { for (size_t j=0; j #include static void set_binary_mode(FILE *f) { setmode(fileno(f), O_BINARY); } #else static void set_binary_mode(FILE *f) { (void)f; } #endif static int read_byte(FILE *f, void* dest) { size_t num_elems_read = fread(dest, 1, 1, f); return num_elems_read == 1 ? 0 : 1; } //// Types struct primtype_info_t { const char binname[4]; // Used for parsing binary data. const char* type_name; // Same name as in Futhark. const int64_t size; // in bytes const writer write_str; // Write in text format. const str_reader read_str; // Read in text format. }; static const struct primtype_info_t i8_info = {.binname = " i8", .type_name = "i8", .size = 1, .write_str = (writer)write_str_i8, .read_str = (str_reader)read_str_i8}; static const struct primtype_info_t i16_info = {.binname = " i16", .type_name = "i16", .size = 2, .write_str = (writer)write_str_i16, .read_str = (str_reader)read_str_i16}; static const struct primtype_info_t i32_info = {.binname = " i32", .type_name = "i32", .size = 4, .write_str = (writer)write_str_i32, .read_str = (str_reader)read_str_i32}; static const struct primtype_info_t i64_info = {.binname = " i64", .type_name = "i64", .size = 8, .write_str = (writer)write_str_i64, .read_str = (str_reader)read_str_i64}; static const struct primtype_info_t u8_info = {.binname = " u8", .type_name = "u8", .size = 1, .write_str = (writer)write_str_u8, .read_str = (str_reader)read_str_u8}; static const struct primtype_info_t u16_info = {.binname = " u16", .type_name = "u16", .size = 2, .write_str = (writer)write_str_u16, .read_str = (str_reader)read_str_u16}; static const struct primtype_info_t u32_info = {.binname = 
" u32", .type_name = "u32", .size = 4, .write_str = (writer)write_str_u32, .read_str = (str_reader)read_str_u32}; static const struct primtype_info_t u64_info = {.binname = " u64", .type_name = "u64", .size = 8, .write_str = (writer)write_str_u64, .read_str = (str_reader)read_str_u64}; static const struct primtype_info_t f16_info = {.binname = " f16", .type_name = "f16", .size = 2, .write_str = (writer)write_str_f16, .read_str = (str_reader)read_str_f16}; static const struct primtype_info_t f32_info = {.binname = " f32", .type_name = "f32", .size = 4, .write_str = (writer)write_str_f32, .read_str = (str_reader)read_str_f32}; static const struct primtype_info_t f64_info = {.binname = " f64", .type_name = "f64", .size = 8, .write_str = (writer)write_str_f64, .read_str = (str_reader)read_str_f64}; static const struct primtype_info_t bool_info = {.binname = "bool", .type_name = "bool", .size = 1, .write_str = (writer)write_str_bool, .read_str = (str_reader)read_str_bool}; static const struct primtype_info_t* primtypes[] = { &i8_info, &i16_info, &i32_info, &i64_info, &u8_info, &u16_info, &u32_info, &u64_info, &f16_info, &f32_info, &f64_info, &bool_info, NULL // NULL-terminated }; // General value interface. All endian business taken care of at // lower layers. 
static int read_is_binary(FILE *f) { skipspaces(f); int c = getc(f); if (c == 'b') { int8_t bin_version; int ret = read_byte(f, &bin_version); if (ret != 0) { futhark_panic(1, "binary-input: could not read version.\n"); } if (bin_version != BINARY_FORMAT_VERSION) { futhark_panic(1, "binary-input: File uses version %i, but I only understand version %i.\n", bin_version, BINARY_FORMAT_VERSION); } return 1; } ungetc(c, f); return 0; } static const struct primtype_info_t* read_bin_read_type_enum(FILE *f) { char read_binname[4]; int num_matched = fscanf(f, "%4c", read_binname); if (num_matched != 1) { futhark_panic(1, "binary-input: Couldn't read element type.\n"); } const struct primtype_info_t **type = primtypes; for (; *type != NULL; type++) { // I compare the 4 characters manually instead of using strncmp because // this allows any value to be used, also NULL bytes if (memcmp(read_binname, (*type)->binname, 4) == 0) { return *type; } } futhark_panic(1, "binary-input: Did not recognize the type '%s'.\n", read_binname); return NULL; } static void read_bin_ensure_scalar(FILE *f, const struct primtype_info_t *expected_type) { int8_t bin_dims; int ret = read_byte(f, &bin_dims); if (ret != 0) { futhark_panic(1, "binary-input: Couldn't get dims.\n"); } if (bin_dims != 0) { futhark_panic(1, "binary-input: Expected scalar (0 dimensions), but got array with %i dimensions.\n", bin_dims); } const struct primtype_info_t *bin_type = read_bin_read_type_enum(f); if (bin_type != expected_type) { futhark_panic(1, "binary-input: Expected scalar of type %s but got scalar of type %s.\n", expected_type->type_name, bin_type->type_name); } } //// High-level interface static int read_bin_array(FILE *f, const struct primtype_info_t *expected_type, void **data, int64_t *shape, int64_t dims) { int ret; int8_t bin_dims; ret = read_byte(f, &bin_dims); if (ret != 0) { futhark_panic(1, "binary-input: Couldn't get dims.\n"); } if (bin_dims != dims) { futhark_panic(1, "binary-input: Expected %i 
dimensions, but got array with %i dimensions.\n", dims, bin_dims); } const struct primtype_info_t *bin_primtype = read_bin_read_type_enum(f); if (expected_type != bin_primtype) { futhark_panic(1, "binary-input: Expected %iD-array with element type '%s' but got %iD-array with element type '%s'.\n", dims, expected_type->type_name, dims, bin_primtype->type_name); } int64_t elem_count = 1; for (int i=0; isize; void* tmp = realloc(*data, (size_t)(elem_count * elem_size)); if (tmp == NULL) { futhark_panic(1, "binary-input: Failed to allocate array of size %i.\n", elem_count * elem_size); } *data = tmp; int64_t num_elems_read = (int64_t)fread(*data, (size_t)elem_size, (size_t)elem_count, f); if (num_elems_read != elem_count) { futhark_panic(1, "binary-input: tried to read %i elements of an array, but only got %i elements.\n", elem_count, num_elems_read); } // If we're on big endian platform we must change all multibyte elements // from using little endian to big endian if (IS_BIG_ENDIAN && elem_size != 1) { flip_bytes((size_t)elem_size, (unsigned char*) *data); } return 0; } static int read_array(FILE *f, const struct primtype_info_t *expected_type, void **data, int64_t *shape, int64_t dims) { if (!read_is_binary(f)) { return read_str_array(f, expected_type->size, (str_reader)expected_type->read_str, expected_type->type_name, data, shape, dims); } else { return read_bin_array(f, expected_type, data, shape, dims); } } static int end_of_input(FILE *f) { skipspaces(f); char token[2]; next_token(f, token, sizeof(token)); if (strcmp(token, "") == 0) { return 0; } else { return 1; } } static int write_str_array(FILE *out, const struct primtype_info_t *elem_type, const unsigned char *data, const int64_t *shape, int8_t rank) { if (rank==0) { elem_type->write_str(out, (const void*)data); } else { int64_t len = (int64_t)shape[0]; int64_t slice_size = 1; int64_t elem_size = elem_type->size; for (int8_t i = 1; i < rank; i++) { slice_size *= shape[i]; } if (len*slice_size == 0) { 
fprintf(out, "empty("); for (int64_t i = 0; i < rank; i++) { fprintf(out, "[%"PRIi64"]", shape[i]); } fprintf(out, "%s", elem_type->type_name); fprintf(out, ")"); } else if (rank==1) { fputc('[', out); for (int64_t i = 0; i < len; i++) { elem_type->write_str(out, (const void*) (data + i * elem_size)); if (i != len-1) { fprintf(out, ", "); } } fputc(']', out); } else { fputc('[', out); for (int64_t i = 0; i < len; i++) { write_str_array(out, elem_type, data + i * slice_size * elem_size, shape+1, rank-1); if (i != len-1) { fprintf(out, ", "); } } fputc(']', out); } } return 0; } static int write_bin_array(FILE *out, const struct primtype_info_t *elem_type, const unsigned char *data, const int64_t *shape, int8_t rank) { int64_t num_elems = 1; for (int64_t i = 0; i < rank; i++) { num_elems *= shape[i]; } fputc('b', out); fputc((char)BINARY_FORMAT_VERSION, out); fwrite(&rank, sizeof(int8_t), 1, out); fwrite(elem_type->binname, 4, 1, out); if (shape != NULL) { fwrite(shape, sizeof(int64_t), (size_t)rank, out); } if (IS_BIG_ENDIAN) { for (int64_t i = 0; i < num_elems; i++) { const unsigned char *elem = data+i*elem_type->size; for (int64_t j = 0; j < elem_type->size; j++) { fwrite(&elem[elem_type->size-j], 1, 1, out); } } } else { fwrite(data, (size_t)elem_type->size, (size_t)num_elems, out); } return 0; } static int write_array(FILE *out, int write_binary, const struct primtype_info_t *elem_type, const void *data, const int64_t *shape, const int8_t rank) { if (write_binary) { return write_bin_array(out, elem_type, data, shape, rank); } else { return write_str_array(out, elem_type, data, shape, rank); } } static int read_scalar(FILE *f, const struct primtype_info_t *expected_type, void *dest) { if (!read_is_binary(f)) { char buf[100]; next_token(f, buf, sizeof(buf)); return expected_type->read_str(buf, dest); } else { read_bin_ensure_scalar(f, expected_type); size_t elem_size = (size_t)expected_type->size; size_t num_elems_read = fread(dest, elem_size, 1, f); if 
(IS_BIG_ENDIAN) { flip_bytes(elem_size, (unsigned char*) dest); } return num_elems_read == 1 ? 0 : 1; } } static int write_scalar(FILE *out, int write_binary, const struct primtype_info_t *type, void *src) { if (write_binary) { return write_bin_array(out, type, src, NULL, 0); } else { return type->write_str(out, src); } } // End of values.h. futhark-0.25.27/rts/cuda/000077500000000000000000000000001475065116200150665ustar00rootroot00000000000000futhark-0.25.27/rts/cuda/prelude.cu000066400000000000000000000045351475065116200170660ustar00rootroot00000000000000// start of prelude.cu #define SCALAR_FUN_ATTR __device__ static inline #define FUTHARK_FUN_ATTR __device__ static #define FUTHARK_F64_ENABLED typedef char int8_t; typedef short int16_t; typedef int int32_t; typedef long long int64_t; typedef unsigned char uint8_t; typedef unsigned short uint16_t; typedef unsigned int uint32_t; typedef unsigned long long uint64_t; #define __global #define __local #define __private #define __constant #define __write_only #define __read_only static inline __device__ int get_tblock_id(int d) { switch (d) { case 0: return blockIdx.x; case 1: return blockIdx.y; case 2: return blockIdx.z; default: return 0; } } static inline __device__ int get_num_tblocks(int d) { switch(d) { case 0: return gridDim.x; case 1: return gridDim.y; case 2: return gridDim.z; default: return 0; } } static inline __device__ int get_global_id(int d) { switch (d) { case 0: return threadIdx.x + blockIdx.x * blockDim.x; case 1: return threadIdx.y + blockIdx.y * blockDim.y; case 2: return threadIdx.z + blockIdx.z * blockDim.z; default: return 0; } } static inline __device__ int get_local_id(int d) { switch (d) { case 0: return threadIdx.x; case 1: return threadIdx.y; case 2: return threadIdx.z; default: return 0; } } static inline __device__ int get_local_size(int d) { switch (d) { case 0: return blockDim.x; case 1: return blockDim.y; case 2: return blockDim.z; default: return 0; } } static inline __device__ int 
get_global_size(int d) { switch (d) { case 0: return gridDim.x * blockDim.x; case 1: return gridDim.y * blockDim.y; case 2: return gridDim.z * blockDim.z; default: return 0; } } #define CLK_LOCAL_MEM_FENCE 1 #define CLK_GLOBAL_MEM_FENCE 2 static inline __device__ void barrier(int x) { __syncthreads(); } static inline __device__ void mem_fence_local() { __threadfence_block(); } static inline __device__ void mem_fence_global() { __threadfence(); } static inline __device__ void barrier_local() { __syncthreads(); } #define NAN (0.0/0.0) #define INFINITY (1.0/0.0) extern volatile __shared__ unsigned char shared_mem[]; #define SHARED_MEM_PARAM #define FUTHARK_KERNEL extern "C" __global__ __launch_bounds__(MAX_THREADS_PER_BLOCK) #define FUTHARK_KERNEL_SIZED(a,b,c) extern "C" __global__ __launch_bounds__(a*b*c) // End of prelude.cu futhark-0.25.27/rts/futhark-doc/000077500000000000000000000000001475065116200163615ustar00rootroot00000000000000futhark-0.25.27/rts/futhark-doc/style.css000066400000000000000000000060701475065116200202360ustar00rootroot00000000000000body { background-color: #fff9e5; font-family: sans-serif; padding: 0px; margin: 0px; margin-top: 0em; margin-left: 0em; margin-right: 0em; overflow-y: scroll; line-height: 1.4; } h1 { margin-top: 0; display: inline; } h2 { margin-bottom: 1ex; font-size: 2em; color: #5f021f; } h3 { font-size: 1em; color: #5f021f; } p { max-width: 65em; } td { vertical-align: top; } #header { color: #fff9e5; background: #5f0220; padding-left: 1em; } #navigation { float: right; vertical-align: centre; } #navigation ol { display: block; } #navigation li { display: inline; list-style-type: none; border-left: 1px solid #fff9e5; padding-left: 1em; padding-right: 1em; } #navigation a, #navigation a:visited { color: #fff9e5; text-decoration: none; } #content { display: flex; } main { padding-left: 1em; padding-right: 1em; width: 100%; } #abstract { margin-bottom: 2ex; } #footer { color: #fff9e5; background: #5f0220; padding-left: 1em; 
padding-right: 1em; } #footer a, #footer a:visited { color: #fff9e5; } @media (max-width:950px) { #filenav { display: none; } } #filenav { color: #fff9e5; background: #5f0220; padding-left: 1em; padding-right: 1em; } #filenav a:hover { background: #fff9e5; display: block; color: #5f0220; } #filenav a, #footer a:visited { color: #fff9e5; text-decoration: none; } #filenav > ul { list-style: none; padding: 0; font-family: monospace; font-weight: bold; } .file_desc { /* Hack to avoid breaking a file description across multiple columns on the contents page. */ display: table; width: 100%; } @media (min-width:950px) { .file_list { columns: 3; } .file_list > dd { break-before: avoid; } } .keyword { font-weight: bold; } .specs { margin-left: 2ex; border-collapse: collapse; } .specs tr:hover td { background: #ccc; } .spec_eql { text-align: right; } .decl_description { margin-bottom: 2ex; } .decl_name { font-weight: bold; } .desc_header { background: #ffcc7a; font-family: monospace; border-top: 1px solid #5f021f; } .desc_doc { margin-left: 2ex; } .synopsis_link { font-weight: bold; text-decoration: none; font-size: 1.3em; width: 2ex; display: inline-block; } .self_link { text-decoration: none; } #module { font-family: monospace; white-space: pre; display: block; unicode-bidi: embed; } #doc_index thead { font-weight: bold; } #doc_index td { padding-right: 1em; } .doc_index_name { font-family: monospace; } .doc_index_namespace { font-style: italic; } .doc_index_initial { font-size: 2em; font-weight: bold; } .doc_index_initial a, .doc_index_initial a:visited { color: #5f0220; } #doc_index_list li { display: inline; list-style-type: none; padding-right: 1em; } #doc_index_list a, #doc_index_list a:visited { color: #5f0220; text-decoration: none; } futhark-0.25.27/rts/javascript/000077500000000000000000000000001475065116200163205ustar00rootroot00000000000000futhark-0.25.27/rts/javascript/server.js000066400000000000000000000152671475065116200201770ustar00rootroot00000000000000// 
Start of server.js class Server { constructor(ctx) { this.ctx = ctx; this._vars = {}; this._types = {}; this._commands = [ 'inputs', 'outputs', 'call', 'restore', 'store', 'free', 'clear', 'pause_profiling', 'unpause_profiling', 'report', 'rename' ]; } _get_arg(args, i) { if (i < args.length) { return args[i]; } else { throw 'Insufficient command args'; } } _get_entry_point(entry) { if (entry in this.ctx.get_entry_points()) { return this.ctx.get_entry_points()[entry]; } else { throw "Unkown entry point: " + entry; } } _check_var(vname) { if (!(vname in this._vars)) { throw 'Unknown variable: ' + vname; } } _set_var(vname, v, t) { this._vars[vname] = v; this._types[vname] = t; } _get_type(vname) { this._check_var(vname); return this._types[vname]; } _get_var(vname) { this._check_var(vname); return this._vars[vname]; } _cmd_inputs(args) { var entry = this._get_arg(args, 0); var inputs = this._get_entry_point(entry)[1]; for (var i = 0; i < inputs.length; i++) { console.log(inputs[i]); } } _cmd_outputs(args) { var entry = this._get_arg(args, 0); var outputs = this._get_entry_point(entry)[2]; for (var i = 0; i < outputs.length; i++) { console.log(outputs[i]); } } _cmd_dummy(args) { // pass } _cmd_free(args) { for (var i = 0; i < args.length; i++) { var vname = args[i]; this._check_var(vname); delete this._vars[vname]; } } _cmd_rename(args) { var oldname = this._get_arg(args, 0) var newname = this._get_arg(args, 1) if (newname in this._vars) { throw "Variable already exists: " + newname; } this._vars[newname] = this._vars[oldname]; this._types[newname] = this._types[oldname]; delete this._vars[oldname]; delete this._types[oldname]; } _cmd_call(args) { var entry = this._get_entry_point(this._get_arg(args, 0)); var num_ins = entry[1].length; var num_outs = entry[2].length; var expected_len = 1 + num_outs + num_ins if (args.length != expected_len) { throw "Invalid argument count, expected " + expected_len } var out_vnames = args.slice(1, num_outs+1) for (var i = 0; i < 
out_vnames.length; i++) { var out_vname = out_vnames[i]; if (out_vname in this._vars) { throw "Variable already exists: " + out_vname; } } var in_vnames = args.slice(1+num_outs); var ins = []; for (var i = 0; i < in_vnames.length; i++) { ins.push(this._get_var(in_vnames[i])); } // Call entry point function from string name var bef = performance.now()*1000; var vals = this.ctx[entry[0]].apply(this.ctx, ins); var aft = performance.now()*1000; if (num_outs == 1) { this._set_var(out_vnames[0], vals, entry[2][0]); } else { for (var i = 0; i < out_vnames.length; i++) { this._set_var(out_vnames[i], vals[i], entry[2][i]); } } console.log("runtime: " + Math.round(aft-bef)); } _cmd_store(args) { var fname = this._get_arg(args, 0); for (var i = 1; i < args.length; i++) { var vname = args[i]; var value = this._get_var(vname); var typ = this._get_type(vname); var fs = require("fs"); var bin_val = construct_binary_value(value, typ); fs.appendFileSync(fname, bin_val, 'binary') } } fut_to_dim_typ(typ) { var type = typ; var count = 0; while (type.substr(0, 2) == '[]') { count = count + 1; type = type.slice(2); } return [count, type]; } _cmd_restore(args) { if (args.length % 2 == 0) { throw "Invalid argument count"; } var fname = args[0]; var args = args.slice(1); var as = args; var reader = new Reader(fname); while (as.length != 0) { var vname = as[0]; var typename = as[1]; as = as.slice(2); if (vname in this._vars) { throw "Variable already exists: " + vname; } try { var value = read_value(typename, reader); if (typeof value == 'number' || typeof value == 'bigint') { this._set_var(vname, value, typename); } else { // We are working with an array and need to create to convert [shape, arr] to futhark ptr var shape= value[0]; var arr = value[1]; var dimtyp = this.fut_to_dim_typ(typename); var dim = dimtyp[0]; var typ = dimtyp[1]; var arg_list = [arr, ...shape]; var fnam = "new_" + typ + "_" + dim + "d"; var ptr = this.ctx[fnam].apply(this.ctx, arg_list); this._set_var(vname, ptr, 
typename); } } catch (err) { var err_msg = "Failed to restore variable " + vname + ".\nPossibly malformed data in " + fname + ".\n" + err.toString(); throw err_msg; } } skip_spaces(reader); if (reader.get_buff().length != 0) { throw "Expected EOF after reading values"; } } _process_line(line) { // TODO make sure it splits on anywhite space var words = line.split(" "); if (words.length == 0) { throw "Empty line"; } else { var cmd = words[0]; var args = words.splice(1); if (this._commands.includes(cmd)) { switch (cmd) { case 'inputs': this._cmd_inputs(args); break; case 'outputs': this._cmd_outputs(args); break case 'call': this._cmd_call(args); break case 'restore': this._cmd_restore(args); break case 'store': this._cmd_store(args); break case 'free': this._cmd_free(args); break case 'clear': this._cmd_dummy(args); break case 'pause_profiling': this._cmd_dummy(args); break case 'unpause_profiling': this._cmd_dummy(args); break case 'report': this._cmd_dummy(args); break case 'rename': this._cmd_rename(args); break } } else { throw "Unknown command: " + cmd; } } } run() { console.log('%%% OK'); // TODO figure out if flushing is neccesary for JS const readline = require('readline'); const rl = readline.createInterface(process.stdin); rl.on('line', (line) => { if (line == "") { rl.close(); } try { this._process_line(line); console.log('%%% OK'); } catch (err) { console.log('%%% FAILURE'); console.log(err); console.log('%%% OK'); } }).on('close', () => { process.exit(0); }); } } // End of server.js futhark-0.25.27/rts/javascript/values.js000066400000000000000000000116041475065116200201570ustar00rootroot00000000000000// Start of values.js var futharkPrimtypes = new Set([ 'i8', 'i16', 'i32', 'i64', 'u8', 'u16', 'u32', 'u64', 'f16', 'f32', 'f64', 'bool']); var typToType = { ' i8' : Int8Array , ' i16' : Int16Array , ' i32' : Int32Array , ' i64' : BigInt64Array , ' u8' : Uint8Array , ' u16' : Uint16Array , ' u32' : Uint32Array , ' u64' : BigUint64Array , ' f16' : Uint16Array 
, ' f32' : Float32Array , ' f64' : Float64Array , 'bool' : Uint8Array }; function binToStringArray(buff, array) { for (var i = 0; i < array.length; i++) { array[i] = buff[i]; } } function fileToBuff(fname) { var readline = require('readline'); var fs = require('fs'); var buff = fs.readFileSync(fname); return buff; } var typToSize = { "bool" : 1, " u8" : 1, " i8" : 1, " u16" : 2, " i16" : 2, " u32" : 4, " i32" : 4, " f16" : 2, " f32" : 4, " u64" : 8, " i64" : 8, " f64" : 8, } function toU8(ta) { return new Uint8Array(ta.buffer, ta.byteOffset, ta.byteLength); } function construct_binary_value(v, typ) { var dims; var payload_bytes; var filler; if (v instanceof FutharkOpaque) { throw "Opaques are not supported"; } else if (v instanceof FutharkArray) { var t = v.futharkType(); var ftype = " ".slice(t.length) + t; var shape = v.shape(); var ta = v.toTypedArray(shape); var da = new BigInt64Array(shape); dims = shape.length; payload_bytes = da.byteLength + ta.byteLength; filler = (bytes) => { bytes.set(toU8(da), 7); bytes.set(toU8(ta), 7 + da.byteLength); } } else { var ftype = " ".slice(typ.length) + typ; dims = 0; payload_bytes = typToSize[ftype]; filler = (bytes) => { var scalar = new (typToType[ftype])([v]); bytes.set(toU8(scalar), 7); } } var total_bytes = 7 + payload_bytes; var bytes = new Uint8Array(total_bytes); bytes[0] = Buffer.from('b').readUInt8(); bytes[1] = 2; bytes[2] = dims; for (var i = 0; i < 4; i++) { bytes[3+i] = ftype.charCodeAt(i); } filler(bytes); return Buffer.from(bytes); } class Reader { constructor(f) { this.f = f; this.buff = fileToBuff(f); } read_bin_array(num_dim, typ) { var u8_array = new Uint8Array(num_dim * 8); binToStringArray(this.buff.slice(0, num_dim * 8), u8_array); var shape = new BigInt64Array(u8_array.buffer); var length = shape[0]; for (var i = 1; i < shape.length; i++) { length = length * shape[i]; } length = Number(length); var dbytes = typToSize[typ]; var u8_data = new Uint8Array(length * dbytes); 
binToStringArray(this.buff.slice(num_dim * 8, num_dim * 8 + dbytes * length), u8_data); var data = new (typToType[typ])(u8_data.buffer); var tmp_buff = this.buff.slice(num_dim * 8, num_dim * 8 + dbytes * length); this.buff = this.buff.slice(num_dim * 8 + dbytes * length); return [shape, data]; } read_bin_scalar(typ) { var size = typToSize[typ]; var u8_array = new Uint8Array(size); binToStringArray(this.buff, u8_array); var array = new (typToType[typ])(u8_array.buffer); this.buff = this.buff.slice(u8_array.length); // Update buff to be unread part of the string return array[0]; } skip_spaces() { while (this.buff.length > 0 && this.buff.slice(0, 1).toString().trim() == "") { this.buff = this.buff.slice(1); } } read_binary(typename, dim) { // Skip leading white space while (this.buff.slice(0, 1).toString().trim() == "") { this.buff = this.buff.slice(1); } if (this.buff[0] != 'b'.charCodeAt(0)) { throw "Not in binary format" } var version = this.buff[1]; if (version != 2) { throw "Not version 2"; } var num_dim = this.buff[2]; var typ = this.buff.slice(3, 7); this.buff = this.buff.slice(7); var exp_typ = "[]".repeat(dim) + typename; var given_typ = "[]".repeat(num_dim) + typ.toString().trim(); console.log(exp_typ); console.log(given_typ); if (exp_typ !== given_typ) { throw ("Expected type : " + exp_typ + ", Actual type : " + given_typ); } if (num_dim === 0) { return this.read_bin_scalar(typ); } else { return this.read_bin_array(num_dim, typ); } } get_buff() { return this.buff; } } // Function is redudant but is helpful for keeping consistent with python implementation function skip_spaces(reader) { reader.skip_spaces(); } function read_value(typename, reader) { var typ = typename; var dim = 0; while (typ.slice(0, 2) === "[]") { dim = dim + 1; typ = typ.slice(2); } if (!futharkPrimtypes.has(typ)) { throw ("Unkown type: " + typ); } var val = reader.read_binary(typ, dim); return val; } // End of values.js 
futhark-0.25.27/rts/javascript/wrapperclasses.js000066400000000000000000000035121475065116200217150ustar00rootroot00000000000000// Start of wrapperclasses.js class FutharkArray { constructor(ctx, ptr, type_name, dim, heap, fshape, fvalues, ffree) { this.ctx = ctx; this.ptr = ptr; this.type_name = type_name; this.dim = dim; this.heap = heap; this.fshape = fshape; this.fvalues = fvalues; this.ffree = ffree; this.valid = true; } validCheck() { if (!this.valid) { throw "Using freed memory" } } futharkType() { return this.type_name; } free() { this.validCheck(); this.ffree(this.ctx.ctx, this.ptr); this.valid = false; } shape() { this.validCheck(); var s = this.fshape(this.ctx.ctx, this.ptr) >> 3; return Array.from(this.ctx.wasm.HEAP64.subarray(s, s + this.dim)); } toTypedArray(dims = this.shape()) { this.validCheck(); console.assert(dims.length === this.dim, "dim=%s,dims=%s", this.dim, dims.toString()); var length = Number(dims.reduce((a, b) => a * b)); var v = this.fvalues(this.ctx.ctx, this.ptr) / this.heap.BYTES_PER_ELEMENT; return this.heap.subarray(v, v + length); } toArray() { this.validCheck(); var dims = this.shape(); var ta = this.toTypedArray(dims); return (function nest(offs, ds) { var d0 = Number(ds[0]); if (ds.length === 1) { return Array.from(ta.subarray(offs, offs + d0)); } else { var d1 = Number(ds[1]); return Array.from(Array(d0), (x,i) => nest(offs + i * d1, ds.slice(1))); } })(0, dims); } } class FutharkOpaque { constructor(ctx, ptr, ffree) { this.ctx = ctx; this.ptr = ptr; this.ffree = ffree; this.valid = true; } validCheck() { if (!this.valid) { throw "Using freed memory" } } free() { this.validCheck(); this.ffree(this.ctx.ctx, this.ptr); this.valid = false; } } // End of wrapperclasses.js futhark-0.25.27/rts/opencl/000077500000000000000000000000001475065116200154325ustar00rootroot00000000000000futhark-0.25.27/rts/opencl/copy.cl000066400000000000000000000130471475065116200167310ustar00rootroot00000000000000// Start of copy.cl #define 
GEN_COPY_KERNEL(NAME, ELEM_TYPE) \ FUTHARK_KERNEL void lmad_copy_##NAME(SHARED_MEM_PARAM \ __global ELEM_TYPE *dst_mem, \ int64_t dst_offset, \ __global ELEM_TYPE *src_mem, \ int64_t src_offset, \ int64_t n, \ int r, \ int64_t shape0, int64_t dst_stride0, int64_t src_stride0, \ int64_t shape1, int64_t dst_stride1, int64_t src_stride1, \ int64_t shape2, int64_t dst_stride2, int64_t src_stride2, \ int64_t shape3, int64_t dst_stride3, int64_t src_stride3, \ int64_t shape4, int64_t dst_stride4, int64_t src_stride4, \ int64_t shape5, int64_t dst_stride5, int64_t src_stride5, \ int64_t shape6, int64_t dst_stride6, int64_t src_stride6, \ int64_t shape7, int64_t dst_stride7, int64_t src_stride7) { \ int64_t gtid = get_global_id(0); \ int64_t remainder = gtid; \ \ if (gtid >= n) { \ return; \ } \ \ if (r > 0) { \ int64_t i = remainder % shape0; \ dst_offset += i * dst_stride0; \ src_offset += i * src_stride0; \ remainder /= shape0; \ } \ if (r > 1) { \ int64_t i = remainder % shape1; \ dst_offset += i * dst_stride1; \ src_offset += i * src_stride1; \ remainder /= shape1; \ } \ if (r > 2) { \ int64_t i = remainder % shape2; \ dst_offset += i * dst_stride2; \ src_offset += i * src_stride2; \ remainder /= shape2; \ } \ if (r > 3) { \ int64_t i = remainder % shape3; \ dst_offset += i * dst_stride3; \ src_offset += i * src_stride3; \ remainder /= shape3; \ } \ if (r > 4) { \ int64_t i = remainder % shape4; \ dst_offset += i * dst_stride4; \ src_offset += i * src_stride4; \ remainder /= shape4; \ } \ if (r > 5) { \ int64_t i = remainder % shape5; \ dst_offset += i * dst_stride5; \ src_offset += i * src_stride5; \ remainder /= shape5; \ } \ if (r > 6) { \ int64_t i = remainder % shape6; \ dst_offset += i * dst_stride6; \ src_offset += i * src_stride6; \ remainder /= shape6; \ } \ if (r > 7) { \ int64_t i = remainder % shape7; \ dst_offset += i * dst_stride7; \ src_offset += i * src_stride7; \ remainder /= shape7; \ } \ \ dst_mem[dst_offset] = src_mem[src_offset]; \ } 
GEN_COPY_KERNEL(1b, uint8_t) GEN_COPY_KERNEL(2b, uint16_t) GEN_COPY_KERNEL(4b, uint32_t) GEN_COPY_KERNEL(8b, uint64_t) // End of copy.cl futhark-0.25.27/rts/opencl/prelude.cl000066400000000000000000000030661475065116200174170ustar00rootroot00000000000000// Start of prelude.cl #define SCALAR_FUN_ATTR static inline #define FUTHARK_FUN_ATTR static typedef char int8_t; typedef short int16_t; typedef int int32_t; typedef long int64_t; typedef uchar uint8_t; typedef ushort uint16_t; typedef uint uint32_t; typedef ulong uint64_t; #define get_tblock_id(d) get_group_id(d) #define get_num_tblocks(d) get_num_groups(d) // Clang-based OpenCL implementations need this for 'static' to work. #ifdef cl_clang_storage_class_specifiers #pragma OPENCL EXTENSION cl_clang_storage_class_specifiers : enable #endif #pragma OPENCL EXTENSION cl_khr_byte_addressable_store : enable #ifdef FUTHARK_F64_ENABLED #pragma OPENCL EXTENSION cl_khr_fp64 : enable #endif #pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable #pragma OPENCL EXTENSION cl_khr_int64_extended_atomics : enable // NVIDIAs OpenCL does not create device-wide memory fences (see #734), so we // use inline assembly if we detect we are on an NVIDIA GPU. #ifdef cl_nv_pragma_unroll static inline void mem_fence_global() { asm("membar.gl;"); } #else static inline void mem_fence_global() { mem_fence(CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE); } #endif static inline void mem_fence_local() { mem_fence(CLK_LOCAL_MEM_FENCE); } static inline void barrier_local() { barrier(CLK_LOCAL_MEM_FENCE); } // Important for this to be int64_t so it has proper alignment for any type. 
#define SHARED_MEM_PARAM __local uint64_t* shared_mem, #define FUTHARK_KERNEL __kernel #define FUTHARK_KERNEL_SIZED(a,b,c) __attribute__((reqd_work_group_size(a, b, c))) __kernel // End of prelude.cl futhark-0.25.27/rts/opencl/transpose.cl000066400000000000000000000467601475065116200200050ustar00rootroot00000000000000// Start of transpose.cl #define GEN_TRANSPOSE_KERNELS(NAME, ELEM_TYPE) \ FUTHARK_KERNEL_SIZED(TR_BLOCK_DIM*2, TR_TILE_DIM/TR_ELEMS_PER_THREAD, 1)\ void map_transpose_##NAME(SHARED_MEM_PARAM \ __global ELEM_TYPE *dst_mem, \ int64_t dst_offset, \ __global ELEM_TYPE *src_mem, \ int64_t src_offset, \ int32_t num_arrays, \ int32_t x_elems, \ int32_t y_elems, \ int32_t mulx, \ int32_t muly, \ int32_t repeat_1, \ int32_t repeat_2) { \ (void)mulx; (void)muly; \ __local ELEM_TYPE* block = (__local ELEM_TYPE*)shared_mem; \ int tblock_id_0 = get_tblock_id(0); \ int global_id_0 = get_global_id(0); \ int tblock_id_1 = get_tblock_id(1); \ int global_id_1 = get_global_id(1); \ for (int i1 = 0; i1 <= repeat_1; i1++) { \ int tblock_id_2 = get_tblock_id(2); \ int global_id_2 = get_global_id(2); \ for (int i2 = 0; i2 <= repeat_2; i2++) { \ int32_t our_array_offset = tblock_id_2 * x_elems * y_elems; \ int32_t odata_offset = dst_offset + our_array_offset; \ int32_t idata_offset = src_offset + our_array_offset; \ int32_t x_index = global_id_0; \ int32_t y_index = tblock_id_1 * TR_TILE_DIM + get_local_id(1); \ if (x_index < x_elems) { \ for (int32_t j = 0; j < TR_ELEMS_PER_THREAD; j++) { \ int32_t index_i = (y_index + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD)) * x_elems + x_index; \ if (y_index + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD) < y_elems) { \ block[(get_local_id(1) + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD)) * (TR_TILE_DIM+1) + \ get_local_id(0)] = \ src_mem[idata_offset + index_i]; \ } \ } \ } \ barrier_local(); \ x_index = tblock_id_1 * TR_TILE_DIM + get_local_id(0); \ y_index = tblock_id_0 * TR_TILE_DIM + get_local_id(1); \ if (x_index < y_elems) { \ for (int32_t j = 0; j 
< TR_ELEMS_PER_THREAD; j++) { \ int32_t index_out = (y_index + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD)) * y_elems + x_index; \ if (y_index + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD) < x_elems) { \ dst_mem[(odata_offset + index_out)] = \ block[get_local_id(0) * (TR_TILE_DIM+1) + \ get_local_id(1) + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD)]; \ } \ } \ } \ tblock_id_2 += get_num_tblocks(2); \ global_id_2 += get_global_size(2); \ } \ tblock_id_1 += get_num_tblocks(1); \ global_id_1 += get_global_size(1); \ } \ } \ \ FUTHARK_KERNEL_SIZED(TR_BLOCK_DIM, TR_BLOCK_DIM, 1) \ void map_transpose_##NAME##_low_height(SHARED_MEM_PARAM \ __global ELEM_TYPE *dst_mem, \ int64_t dst_offset, \ __global ELEM_TYPE *src_mem, \ int64_t src_offset, \ int32_t num_arrays, \ int32_t x_elems, \ int32_t y_elems, \ int32_t mulx, \ int32_t muly, \ int32_t repeat_1, \ int32_t repeat_2) { \ __local ELEM_TYPE* block = (__local ELEM_TYPE*)shared_mem; \ int tblock_id_0 = get_tblock_id(0); \ int global_id_0 = get_global_id(0); \ int tblock_id_1 = get_tblock_id(1); \ int global_id_1 = get_global_id(1); \ for (int i1 = 0; i1 <= repeat_1; i1++) { \ int tblock_id_2 = get_tblock_id(2); \ int global_id_2 = get_global_id(2); \ for (int i2 = 0; i2 <= repeat_2; i2++) { \ int32_t our_array_offset = tblock_id_2 * x_elems * y_elems; \ int32_t odata_offset = dst_offset + our_array_offset; \ int32_t idata_offset = src_offset + our_array_offset; \ int32_t x_index = \ tblock_id_0 * TR_BLOCK_DIM * mulx + \ get_local_id(0) + \ get_local_id(1)%mulx * TR_BLOCK_DIM; \ int32_t y_index = tblock_id_1 * TR_BLOCK_DIM + get_local_id(1)/mulx; \ int32_t index_in = y_index * x_elems + x_index; \ if (x_index < x_elems && y_index < y_elems) { \ block[get_local_id(1) * (TR_BLOCK_DIM+1) + get_local_id(0)] = \ src_mem[idata_offset + index_in]; \ } \ barrier_local(); \ x_index = tblock_id_1 * TR_BLOCK_DIM + get_local_id(0)/mulx; \ y_index = \ tblock_id_0 * TR_BLOCK_DIM * mulx + \ get_local_id(1) + \ (get_local_id(0)%mulx) * TR_BLOCK_DIM; \ 
int32_t index_out = y_index * y_elems + x_index; \ if (x_index < y_elems && y_index < x_elems) { \ dst_mem[odata_offset + index_out] = \ block[get_local_id(0) * (TR_BLOCK_DIM+1) + get_local_id(1)]; \ } \ tblock_id_2 += get_num_tblocks(2); \ global_id_2 += get_global_size(2); \ } \ tblock_id_1 += get_num_tblocks(1); \ global_id_1 += get_global_size(1); \ } \ } \ \ FUTHARK_KERNEL_SIZED(TR_BLOCK_DIM, TR_BLOCK_DIM, 1) \ void map_transpose_##NAME##_low_width(SHARED_MEM_PARAM \ __global ELEM_TYPE *dst_mem, \ int64_t dst_offset, \ __global ELEM_TYPE *src_mem, \ int64_t src_offset, \ int32_t num_arrays, \ int32_t x_elems, \ int32_t y_elems, \ int32_t mulx, \ int32_t muly, \ int32_t repeat_1, \ int32_t repeat_2) { \ __local ELEM_TYPE* block = (__local ELEM_TYPE*)shared_mem; \ int tblock_id_0 = get_tblock_id(0); \ int global_id_0 = get_global_id(0); \ int tblock_id_1 = get_tblock_id(1); \ int global_id_1 = get_global_id(1); \ for (int i1 = 0; i1 <= repeat_1; i1++) { \ int tblock_id_2 = get_tblock_id(2); \ int global_id_2 = get_global_id(2); \ for (int i2 = 0; i2 <= repeat_2; i2++) { \ int32_t our_array_offset = tblock_id_2 * x_elems * y_elems; \ int32_t odata_offset = dst_offset + our_array_offset; \ int32_t idata_offset = src_offset + our_array_offset; \ int32_t x_index = tblock_id_0 * TR_BLOCK_DIM + get_local_id(0)/muly; \ int32_t y_index = \ tblock_id_1 * TR_BLOCK_DIM * muly + \ get_local_id(1) + (get_local_id(0)%muly) * TR_BLOCK_DIM; \ int32_t index_in = y_index * x_elems + x_index; \ if (x_index < x_elems && y_index < y_elems) { \ block[get_local_id(1) * (TR_BLOCK_DIM+1) + get_local_id(0)] = \ src_mem[idata_offset + index_in]; \ } \ barrier_local(); \ x_index = tblock_id_1 * TR_BLOCK_DIM * muly + \ get_local_id(0) + (get_local_id(1)%muly) * TR_BLOCK_DIM; \ y_index = tblock_id_0 * TR_BLOCK_DIM + get_local_id(1)/muly; \ int32_t index_out = y_index * y_elems + x_index; \ if (x_index < y_elems && y_index < x_elems) { \ dst_mem[odata_offset + index_out] = \ 
block[get_local_id(0) * (TR_BLOCK_DIM+1) + get_local_id(1)]; \ } \ tblock_id_2 += get_num_tblocks(2); \ global_id_2 += get_num_tblocks(2) * get_local_size(2); \ } \ tblock_id_1 += get_num_tblocks(1); \ global_id_1 += get_num_tblocks(1) * get_local_size(1); \ } \ } \ \ FUTHARK_KERNEL_SIZED(TR_BLOCK_DIM*TR_BLOCK_DIM, 1, 1) \ void map_transpose_##NAME##_small(SHARED_MEM_PARAM \ __global ELEM_TYPE *dst_mem, \ int64_t dst_offset, \ __global ELEM_TYPE *src_mem, \ int64_t src_offset, \ int32_t num_arrays, \ int32_t x_elems, \ int32_t y_elems, \ int32_t mulx, \ int32_t muly, \ int32_t repeat_1, \ int32_t repeat_2) { \ (void)mulx; (void)muly; \ __local ELEM_TYPE* block = (__local ELEM_TYPE*)shared_mem; \ int tblock_id_0 = get_tblock_id(0); \ int global_id_0 = get_global_id(0); \ int tblock_id_1 = get_tblock_id(1); \ int global_id_1 = get_global_id(1); \ for (int i1 = 0; i1 <= repeat_1; i1++) { \ int tblock_id_2 = get_tblock_id(2); \ int global_id_2 = get_global_id(2); \ for (int i2 = 0; i2 <= repeat_2; i2++) { \ int32_t our_array_offset = global_id_0/(y_elems * x_elems) * y_elems * x_elems; \ int32_t x_index = (global_id_0 % (y_elems * x_elems))/y_elems; \ int32_t y_index = global_id_0%y_elems; \ int32_t odata_offset = dst_offset + our_array_offset; \ int32_t idata_offset = src_offset + our_array_offset; \ int32_t index_in = y_index * x_elems + x_index; \ int32_t index_out = x_index * y_elems + y_index; \ if (global_id_0 < x_elems * y_elems * num_arrays) { \ dst_mem[odata_offset + index_out] = src_mem[idata_offset + index_in]; \ } \ tblock_id_2 += get_num_tblocks(2); \ global_id_2 += get_global_size(2); \ } \ tblock_id_1 += get_num_tblocks(1); \ global_id_1 += get_global_size(1); \ } \ } \ \ FUTHARK_KERNEL_SIZED(TR_BLOCK_DIM*2, TR_TILE_DIM/TR_ELEMS_PER_THREAD, 1)\ void map_transpose_##NAME##_large(SHARED_MEM_PARAM \ __global ELEM_TYPE *dst_mem, \ int64_t dst_offset, \ __global ELEM_TYPE *src_mem, \ int64_t src_offset, \ int64_t num_arrays, \ int64_t x_elems, \ int64_t 
y_elems, \ int64_t mulx, \ int64_t muly, \ int32_t repeat_1, \ int32_t repeat_2) { \ (void)mulx; (void)muly; \ __local ELEM_TYPE* block = (__local ELEM_TYPE*)shared_mem; \ int tblock_id_0 = get_tblock_id(0); \ int global_id_0 = get_global_id(0); \ int tblock_id_1 = get_tblock_id(1); \ int global_id_1 = get_global_id(1); \ for (int i1 = 0; i1 <= repeat_1; i1++) { \ int tblock_id_2 = get_tblock_id(2); \ int global_id_2 = get_global_id(2); \ for (int i2 = 0; i2 <= repeat_2; i2++) { \ int64_t our_array_offset = tblock_id_2 * x_elems * y_elems; \ int64_t odata_offset = dst_offset + our_array_offset; \ int64_t idata_offset = src_offset + our_array_offset; \ int64_t x_index = global_id_0; \ int64_t y_index = tblock_id_1 * TR_TILE_DIM + get_local_id(1); \ if (x_index < x_elems) { \ for (int64_t j = 0; j < TR_ELEMS_PER_THREAD; j++) { \ int64_t index_i = (y_index + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD)) * x_elems + x_index; \ if (y_index + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD) < y_elems) { \ block[(get_local_id(1) + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD)) * (TR_TILE_DIM+1) + \ get_local_id(0)] = \ src_mem[idata_offset + index_i]; \ } \ } \ } \ barrier_local(); \ x_index = tblock_id_1 * TR_TILE_DIM + get_local_id(0); \ y_index = tblock_id_0 * TR_TILE_DIM + get_local_id(1); \ if (x_index < y_elems) { \ for (int64_t j = 0; j < TR_ELEMS_PER_THREAD; j++) { \ int64_t index_out = (y_index + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD)) * y_elems + x_index; \ if (y_index + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD) < x_elems) { \ dst_mem[(odata_offset + index_out)] = \ block[get_local_id(0) * (TR_TILE_DIM+1) + \ get_local_id(1) + j * (TR_TILE_DIM/TR_ELEMS_PER_THREAD)]; \ } \ } \ } \ tblock_id_2 += get_num_tblocks(2); \ global_id_2 += get_global_size(2); \ } \ tblock_id_1 += get_num_tblocks(1); \ global_id_1 += get_global_size(1); \ } \ } \ GEN_TRANSPOSE_KERNELS(1b, uint8_t) GEN_TRANSPOSE_KERNELS(2b, uint16_t) GEN_TRANSPOSE_KERNELS(4b, uint32_t) GEN_TRANSPOSE_KERNELS(8b, uint64_t) // End of 
transpose.cl futhark-0.25.27/rts/python/000077500000000000000000000000001475065116200154735ustar00rootroot00000000000000futhark-0.25.27/rts/python/memory.py000066400000000000000000000107371475065116200173650ustar00rootroot00000000000000# Start of memory.py. import ctypes as ct def allocateMem(size): return np.empty(size, dtype=np.byte) # Copy an array if its is not-None. This is important for treating # Numpy arrays as flat memory, but has some overhead. def normaliseArray(x): if (x.base is x) or (x.base is None): return x else: return x.copy() def unwrapArray(x): return x.ravel().view(np.byte) def indexArray(x, offset, bt): return x.view(bt)[offset] def writeScalarArray(x, offset, v): x.view(type(v))[offset] = v # An opaque Futhark value. class opaque(object): def __init__(self, desc, *payload): self.data = payload self.desc = desc def __repr__(self): return "".format(self.desc) # LMAD stuff def lmad_contiguous_search(checked, expected, strides, shape, used): for i in range(len(strides)): for j in range(len(strides)): if not used[j] and strides[j] == expected and strides[j] >= 0: used[j] = True if checked + 1 == len(strides) or lmad_contiguous_search( checked + 1, expected * shape[j], strides, shape, used ): return True used[j] = False return False def lmad_contiguous(strides, shape): used = len(strides) * [False] return lmad_contiguous_search(0, 1, strides, shape, used) def lmad_memcpyable(dst_strides, src_strides, shape): if not lmad_contiguous(dst_strides, shape): return False for i in range(len(dst_strides)): if dst_strides[i] != src_strides[i] and shape[i] != 1: return False return True def lmad_is_tr(strides, shape): r = len(shape) for i in range(1, r): n = 1 m = 1 ok = True expected = 1 # Check strides before 'i'. for j in range(i - 1, -1, -1): ok = ok and strides[j] == expected expected *= shape[j] n *= shape[j] # Check strides after 'i'. 
for j in range(r - 1, i - 1, -1): ok = ok and strides[j] == expected expected *= shape[j] m *= shape[j] if ok: return (n, m) return None def lmad_map_tr(dst_strides, src_strides, shape): r = len(dst_strides) rowmajor_strides = [0] * r rowmajor_strides[r - 1] = 1 for i in range(r - 2, -1, -1): rowmajor_strides[i] = rowmajor_strides[i + 1] * shape[i + 1] # map_r will be the number of mapped dimensions on top. map_r = 0 k = 1 for i in range(r): if ( dst_strides[i] != rowmajor_strides[i] or src_strides[i] != rowmajor_strides[i] ): break else: k *= shape[i] map_r += 1 if rowmajor_strides[map_r:] == dst_strides[map_r:]: r = lmad_is_tr(src_strides[map_r:], shape[map_r:]) if r is not None: (n, m) = r return (k, n, m) elif rowmajor_strides[map_r:] == src_strides[map_r:]: r = lmad_is_tr(dst_strides[map_r:], shape[map_r:]) if r is not None: (n, m) = r return (k, m, n) # Sic! return None def lmad_copy_elements( pt, dst, dst_offset, dst_strides, src, src_offset, src_strides, shape ): if len(shape) == 1: for i in range(shape[0]): writeScalarArray( dst, dst_offset + i * dst_strides[0], indexArray(src, src_offset + i * src_strides[0], pt), ) else: for i in range(shape[0]): lmad_copy_elements( pt, dst, dst_offset + i * dst_strides[0], dst_strides[1:], src, src_offset + i * src_strides[0], src_strides[1:], shape[1:], ) def lmad_copy( pt, dst, dst_offset, dst_strides, src, src_offset, src_strides, shape ): if lmad_memcpyable(dst_strides, src_strides, shape): dst[ dst_offset * ct.sizeof(pt) : dst_offset * ct.sizeof(pt) + np.prod(shape) * ct.sizeof(pt) ] = src[ src_offset * ct.sizeof(pt) : src_offset * ct.sizeof(pt) + np.prod(shape) * ct.sizeof(pt) ] else: lmad_copy_elements( pt, dst, dst_offset, dst_strides, src, src_offset, src_strides, shape, ) # End of memory.py. futhark-0.25.27/rts/python/opencl.py000066400000000000000000000365121475065116200173340ustar00rootroot00000000000000# Stub code for OpenCL setup. 
import pyopencl as cl import numpy as np import sys if cl.version.VERSION < (2015, 2): raise Exception( "Futhark requires at least PyOpenCL version 2015.2. Installed version is %s." % cl.version.VERSION_TEXT ) TR_BLOCK_DIM = 16 TR_TILE_DIM = TR_BLOCK_DIM * 2 TR_ELEMS_PER_THREAD = 8 def parse_preferred_device(s): pref_num = 0 if len(s) > 1 and s[0] == "#": i = 1 while i < len(s): if not s[i].isdigit(): break else: pref_num = pref_num * 10 + int(s[i]) i += 1 while i < len(s) and s[i].isspace(): i += 1 return (s[i:], pref_num) else: return (s, 0) def get_prefered_context( interactive=False, platform_pref=None, device_pref=None ): if device_pref != None: (device_pref, device_num) = parse_preferred_device(device_pref) else: device_num = 0 if interactive: return cl.create_some_context(interactive=True) def blacklisted(p, d): return ( platform_pref == None and device_pref == None and p.name == "Apple" and d.name.find("Intel(R) Core(TM)") >= 0 ) def platform_ok(p): return not platform_pref or p.name.find(platform_pref) >= 0 def device_ok(d): return not device_pref or d.name.find(device_pref) >= 0 device_matches = 0 for p in cl.get_platforms(): if not platform_ok(p): continue for d in p.get_devices(): if blacklisted(p, d) or not device_ok(d): continue if device_matches == device_num: return cl.Context(devices=[d]) else: device_matches += 1 raise Exception( "No OpenCL platform and device matching constraints found." 
) def param_assignment(s): name, value = s.split("=") return (name, int(value)) def check_types(self, required_types): if "f64" in required_types: if ( self.device.get_info(cl.device_info.PREFERRED_VECTOR_WIDTH_DOUBLE) == 0 ): raise Exception( "Program uses double-precision floats, but this is not supported on chosen device: %s" % self.device.name ) def apply_size_heuristics(self, size_heuristics, sizes): for platform_name, device_type, size, valuef in size_heuristics: if ( sizes[size] == None and self.platform.name.find(platform_name) >= 0 and (self.device.type & device_type) == device_type ): sizes[size] = valuef(self.device) return sizes def to_c_str_rep(x): if type(x) is bool or type(x) is np.bool_: if x: return "true" else: return "false" else: return str(x) def initialise_opencl_object( self, program_src="", build_options=[], command_queue=None, interactive=False, platform_pref=None, device_pref=None, default_group_size=None, default_num_groups=None, default_tile_size=None, default_reg_tile_size=None, default_threshold=None, size_heuristics=[], required_types=[], all_sizes={}, user_sizes={}, constants=[], ): if command_queue is None: self.ctx = get_prefered_context( interactive, platform_pref, device_pref ) self.queue = cl.CommandQueue(self.ctx) else: self.ctx = command_queue.context self.queue = command_queue self.device = self.queue.device self.platform = self.device.platform self.pool = cl.tools.MemoryPool(cl.tools.ImmediateAllocator(self.queue)) device_type = self.device.type check_types(self, required_types) max_group_size = int(self.device.max_work_group_size) max_tile_size = int(np.sqrt(self.device.max_work_group_size)) self.max_thread_block_size = max_group_size self.max_tile_size = max_tile_size self.max_threshold = 0 self.max_grid_size = 0 self.max_shared_memory = int(self.device.local_mem_size) # Futhark reserves 4 bytes of local memory for its own purposes. self.max_shared_memory -= 4 # See comment in rts/c/opencl.h. 
if self.platform.name.find("NVIDIA CUDA") >= 0: self.max_shared_memory -= 12 elif self.platform.name.find("AMD") >= 0: self.max_shared_memory -= 16 self.max_registers = int(2**16) # Not sure how to query for this. self.max_cache = self.device.get_info(cl.device_info.GLOBAL_MEM_CACHE_SIZE) if self.max_cache == 0: self.max_cache = 1024 * 1024 self.free_list = {} self.global_failure = self.pool.allocate(np.int32().itemsize) cl.enqueue_fill_buffer( self.queue, self.global_failure, np.int32(-1), 0, np.int32().itemsize ) self.global_failure_args = self.pool.allocate( np.int64().itemsize * (self.global_failure_args_max + 1) ) self.failure_is_an_option = np.int32(0) if "default_group_size" in sizes: default_group_size = sizes["default_group_size"] del sizes["default_group_size"] if "default_num_groups" in sizes: default_num_groups = sizes["default_num_groups"] del sizes["default_num_groups"] if "default_tile_size" in sizes: default_tile_size = sizes["default_tile_size"] del sizes["default_tile_size"] if "default_reg_tile_size" in sizes: default_reg_tile_size = sizes["default_reg_tile_size"] del sizes["default_reg_tile_size"] if "default_threshold" in sizes: default_threshold = sizes["default_threshold"] del sizes["default_threshold"] default_group_size_set = default_group_size != None default_tile_size_set = default_tile_size != None default_sizes = apply_size_heuristics( self, size_heuristics, { "group_size": default_group_size, "tile_size": default_tile_size, "reg_tile_size": default_reg_tile_size, "num_groups": default_num_groups, "lockstep_width": None, "threshold": default_threshold, }, ) default_group_size = default_sizes["group_size"] default_num_groups = default_sizes["num_groups"] default_threshold = default_sizes["threshold"] default_tile_size = default_sizes["tile_size"] default_reg_tile_size = default_sizes["reg_tile_size"] lockstep_width = default_sizes["lockstep_width"] if default_group_size > max_group_size: if default_group_size_set: sys.stderr.write( 
"Note: Device limits group size to {} (down from {})\n".format( max_tile_size, default_group_size ) ) default_group_size = max_group_size if default_tile_size > max_tile_size: if default_tile_size_set: sys.stderr.write( "Note: Device limits tile size to {} (down from {})\n".format( max_tile_size, default_tile_size ) ) default_tile_size = max_tile_size for k, v in user_sizes.items(): if k in all_sizes: all_sizes[k]["value"] = v else: raise Exception( "Unknown size: {}\nKnown sizes: {}".format( k, " ".join(all_sizes.keys()) ) ) self.sizes = {} for k, v in all_sizes.items(): if v["class"] == "thread_block_size": max_value = max_group_size default_value = default_group_size elif v["class"] == "grid_size": max_value = max_group_size # Intentional! default_value = default_num_groups elif v["class"] == "tile_size": max_value = max_tile_size default_value = default_tile_size elif v["class"] == "reg_tile_size": max_value = None default_value = default_reg_tile_size elif v["class"].startswith("shared_memory"): max_value = self.max_shared_memory default_value = self.max_shared_memory elif v["class"].startswith("cache"): max_value = self.max_cache default_value = self.max_cache elif v["class"].startswith("threshold"): max_value = None default_value = default_threshold else: # Bespoke sizes have no limit or default. max_value = None if v["value"] == None: self.sizes[k] = default_value elif max_value != None and v["value"] > max_value: sys.stderr.write( "Note: Device limits {} to {} (down from {}\n".format( k, max_value, v["value"] ) ) self.sizes[k] = max_value else: self.sizes[k] = v["value"] # XXX: we perform only a subset of z-encoding here. Really, the # compiler should provide us with the variables to which # parameters are mapped. 
if len(program_src) >= 0: build_options += ["-DLOCKSTEP_WIDTH={}".format(lockstep_width)] build_options += [ "-D{}={}".format("max_thread_block_size", max_group_size) ] build_options += [ "-D{}={}".format( s.replace("z", "zz") .replace(".", "zi") .replace("#", "zh") .replace("'", "zq"), v, ) for (s, v) in self.sizes.items() ] build_options += [ "-D{}={}".format(s, to_c_str_rep(f())) for (s, f) in constants ] if self.platform.name == "Oclgrind": build_options += ["-DEMULATE_F16"] build_options += [ f"-DTR_BLOCK_DIM={TR_BLOCK_DIM}", f"-DTR_TILE_DIM={TR_TILE_DIM}", f"-DTR_ELEMS_PER_THREAD={TR_ELEMS_PER_THREAD}", ] program = cl.Program(self.ctx, program_src).build(build_options) self.transpose_kernels = { 1: { "default": program.map_transpose_1b, "low_height": program.map_transpose_1b_low_height, "low_width": program.map_transpose_1b_low_width, "small": program.map_transpose_1b_small, "large": program.map_transpose_1b_large, }, 2: { "default": program.map_transpose_2b, "low_height": program.map_transpose_2b_low_height, "low_width": program.map_transpose_2b_low_width, "small": program.map_transpose_2b_small, "large": program.map_transpose_2b_large, }, 4: { "default": program.map_transpose_4b, "low_height": program.map_transpose_4b_low_height, "low_width": program.map_transpose_4b_low_width, "small": program.map_transpose_4b_small, "large": program.map_transpose_4b_large, }, 8: { "default": program.map_transpose_8b, "low_height": program.map_transpose_8b_low_height, "low_width": program.map_transpose_8b_low_width, "small": program.map_transpose_8b_small, "large": program.map_transpose_8b_large, }, } self.copy_kernels = { 1: program.lmad_copy_1b, 2: program.lmad_copy_2b, 4: program.lmad_copy_4b, 8: program.lmad_copy_8b, } return program def opencl_alloc(self, min_size, tag): min_size = 1 if min_size == 0 else min_size assert min_size > 0 return self.pool.allocate(min_size) def opencl_free_all(self): self.pool.free_held() def sync(self): failure = np.empty(1, 
dtype=np.int32) cl.enqueue_copy(self.queue, failure, self.global_failure, is_blocking=True) self.failure_is_an_option = np.int32(0) if failure[0] >= 0: # Reset failure information. cl.enqueue_fill_buffer( self.queue, self.global_failure, np.int32(-1), 0, np.int32().itemsize, ) # Read failure args. failure_args = np.empty( self.global_failure_args_max + 1, dtype=np.int64 ) cl.enqueue_copy( self.queue, failure_args, self.global_failure_args, is_blocking=True, ) raise Exception(self.failure_msgs[failure[0]].format(*failure_args)) def map_transpose_gpu2gpu( self, elem_size, dst, dst_offset, src, src_offset, k, n, m ): kernels = self.transpose_kernels[elem_size] kernel = kernels["default"] mulx = TR_BLOCK_DIM / n muly = TR_BLOCK_DIM / m group_dims = (TR_TILE_DIM, TR_TILE_DIM // TR_ELEMS_PER_THREAD, 1) dims = ( (m + TR_TILE_DIM - 1) // TR_TILE_DIM * group_dims[0], (n + TR_TILE_DIM - 1) // TR_TILE_DIM * group_dims[1], k, ) k32 = np.int32(k) n32 = np.int32(n) m32 = np.int32(m) mulx32 = np.int32(mulx) muly32 = np.int32(muly) kernel.set_args( cl.LocalMemory(TR_TILE_DIM * (TR_TILE_DIM + 1) * elem_size), dst, dst_offset, src, src_offset, k32, m32, n32, mulx32, muly32, np.int32(0), np.int32(0), ) cl.enqueue_nd_range_kernel(self.queue, kernel, dims, group_dims) def copy_elements_gpu2gpu( self, elem_size, dst, dst_offset, dst_strides, src, src_offset, src_strides, shape, ): r = len(shape) if r > 8: raise Exception( "Futhark runtime limitation:\nCannot copy array of greater than rank 8.\n" ) n = np.prod(shape) zero = np.int64(0) layout_args = [None] * (8 * 3) for i in range(8): if i < r: layout_args[i * 3 + 0] = shape[i] layout_args[i * 3 + 1] = dst_strides[i] layout_args[i * 3 + 2] = src_strides[i] else: layout_args[i * 3 + 0] = zero layout_args[i * 3 + 1] = zero layout_args[i * 3 + 2] = zero kernel = self.copy_kernels[elem_size] kernel.set_args( cl.LocalMemory(1), dst, dst_offset, src, src_offset, n, np.int32(r), *layout_args, ) w = 256 dims = ((n + w - 1) // w * w,) group_dims 
= (w,) cl.enqueue_nd_range_kernel(self.queue, kernel, dims, group_dims) def lmad_copy_gpu2gpu( self, pt, dst, dst_offset, dst_strides, src, src_offset, src_strides, shape ): elem_size = ct.sizeof(pt) nbytes = np.prod(shape) * elem_size if nbytes == 0: return None if lmad_memcpyable(dst_strides, src_strides, shape): cl.enqueue_copy( self.queue, dst, src, dst_offset=dst_offset * elem_size, src_offset=src_offset * elem_size, byte_count=nbytes, ) else: tr = lmad_map_tr(dst_strides, src_strides, shape) if tr is not None: (k, n, m) = tr map_transpose_gpu2gpu( self, elem_size, dst, dst_offset, src, src_offset, k, m, n ) else: copy_elements_gpu2gpu( self, elem_size, dst, dst_offset, dst_strides, src, src_offset, src_strides, shape, ) futhark-0.25.27/rts/python/panic.py000066400000000000000000000003131475065116200171340ustar00rootroot00000000000000# Start of panic.py. def panic(exitcode, fmt, *args): sys.stderr.write("%s: " % sys.argv[0]) sys.stderr.write(fmt % args) sys.stderr.write("\n") sys.exit(exitcode) # End of panic.py. futhark-0.25.27/rts/python/scalar.py000066400000000000000000000420201475065116200173100ustar00rootroot00000000000000# Start of scalar.py. import numpy as np import math import struct def intlit(t, x): if t == np.int8: return np.int8(x) elif t == np.int16: return np.int16(x) elif t == np.int32: return np.int32(x) else: return np.int64(x) def signed(x): if type(x) == np.uint8: return np.int8(x) elif type(x) == np.uint16: return np.int16(x) elif type(x) == np.uint32: return np.int32(x) else: return np.int64(x) def unsigned(x): if type(x) == np.int8: return np.uint8(x) elif type(x) == np.int16: return np.uint16(x) elif type(x) == np.int32: return np.uint32(x) else: return np.uint64(x) def shlN(x, y): return x << y def ashrN(x, y): return x >> y # Python is so slow that we just make all the unsafe operations safe, # always. 
def sdivN(x, y):
    """Signed integer division rounding towards negative infinity.

    Division by zero yields 0 of the same type as ``x`` instead of raising.
    """
    if y == 0:
        return intlit(type(x), 0)
    else:
        return x // y


def sdiv_upN(x, y):
    """Signed division of ``x + y - 1`` by ``y`` (floor); 0 when ``y`` is zero."""
    if y == 0:
        return intlit(type(x), 0)
    else:
        return (x + y - intlit(type(x), 1)) // y


def smodN(x, y):
    """Signed modulo with Python/floor semantics; 0 when ``y`` is zero."""
    if y == 0:
        return intlit(type(x), 0)
    else:
        return x % y


def udivN(x, y):
    """Unsigned division of two signed-typed operands; 0 when ``y`` is zero.

    The operands are reinterpreted as unsigned, divided, and the quotient is
    reinterpreted back to the signed type.
    """
    if y == 0:
        return intlit(type(x), 0)
    else:
        return signed(unsigned(x) // unsigned(y))


def udiv_upN(x, y):
    """Unsigned division rounding up; 0 when ``y`` is zero."""
    if y == 0:
        return intlit(type(x), 0)
    else:
        return signed(
            (unsigned(x) + unsigned(y) - unsigned(intlit(type(x), 1))) // unsigned(y)
        )


def umodN(x, y):
    """Unsigned remainder of two signed-typed operands; 0 when ``y`` is zero."""
    if y == 0:
        return intlit(type(x), 0)
    else:
        return signed(unsigned(x) % unsigned(y))


def squotN(x, y):
    """Signed division truncating towards zero; 0 when ``y`` is zero."""
    if y == 0:
        return intlit(type(x), 0)
    else:
        return np.floor_divide(np.abs(x), np.abs(y)) * np.sign(x) * np.sign(y)


def sremN(x, y):
    """Remainder matching ``squotN`` (sign follows ``x``); 0 when ``y`` is zero."""
    if y == 0:
        return intlit(type(x), 0)
    else:
        return np.remainder(np.abs(x), np.abs(y)) * np.sign(x)


def sminN(x, y):
    return min(x, y)


def smaxN(x, y):
    return max(x, y)


def uminN(x, y):
    # Compare as unsigned, return the result reinterpreted as signed.
    return signed(min(unsigned(x), unsigned(y)))


def umaxN(x, y):
    # Compare as unsigned, return the result reinterpreted as signed.
    return signed(max(unsigned(x), unsigned(y)))


def fminN(x, y):
    # np.fmin ignores a NaN operand when the other one is a number.
    return np.fmin(x, y)


def fmaxN(x, y):
    # np.fmax ignores a NaN operand when the other one is a number.
    return np.fmax(x, y)


def powN(x, y):
    return x**y


def fpowN(x, y):
    return x**y


def sleN(x, y):
    return x <= y


def sltN(x, y):
    return x < y


def uleN(x, y):
    return unsigned(x) <= unsigned(y)


def ultN(x, y):
    return unsigned(x) < unsigned(y)


# Logical (zero-filling) right shifts: shift the unsigned reinterpretation,
# then reinterpret the result as signed again.
def lshr8(x, y):
    return np.int8(np.uint8(x) >> np.uint8(y))


def lshr16(x, y):
    return np.int16(np.uint16(x) >> np.uint16(y))


def lshr32(x, y):
    return np.int32(np.uint32(x) >> np.uint32(y))


def lshr64(x, y):
    return np.int64(np.uint64(x) >> np.uint64(y))


def sext_T_i8(x):
    return np.int8(x)


def sext_T_i16(x):
    return np.int16(x)


def sext_T_i32(x):
    return np.int32(x)


def sext_T_i64(x):
    return np.int64(x)


def itob_T_bool(x):
    return bool(x)


def btoi_bool_i8(x):
    return np.int8(x)


def btoi_bool_i16(x):
    return np.int16(x)


def btoi_bool_i32(x):
    return np.int32(x)


def btoi_bool_i64(x):
    return np.int64(x)


def ftob_T_bool(x):
    return bool(x)


def btof_bool_f16(x):
    return np.float16(x)


def btof_bool_f32(x):
    return np.float32(x)


def btof_bool_f64(x):
    return np.float64(x)


# Zero extension: reinterpret the source as unsigned of its own width, then
# convert to the (signed) target type.
def zext_i8_i8(x):
    return np.int8(np.uint8(x))


def zext_i8_i16(x):
    return np.int16(np.uint8(x))


def zext_i8_i32(x):
    return np.int32(np.uint8(x))


def zext_i8_i64(x):
    return np.int64(np.uint8(x))


def zext_i16_i8(x):
    return np.int8(np.uint16(x))


def zext_i16_i16(x):
    return np.int16(np.uint16(x))


def zext_i16_i32(x):
    return np.int32(np.uint16(x))


def zext_i16_i64(x):
    return np.int64(np.uint16(x))


def zext_i32_i8(x):
    return np.int8(np.uint32(x))


def zext_i32_i16(x):
    return np.int16(np.uint32(x))


def zext_i32_i32(x):
    return np.int32(np.uint32(x))


def zext_i32_i64(x):
    return np.int64(np.uint32(x))


def zext_i64_i8(x):
    return np.int8(np.uint64(x))


def zext_i64_i16(x):
    return np.int16(np.uint64(x))


def zext_i64_i32(x):
    return np.int32(np.uint64(x))


def zext_i64_i64(x):
    return np.int64(np.uint64(x))


# Width-specific names used by generated code.  The "safe" variants share the
# implementation of the plain ones because every implementation above already
# returns 0 on division by zero.
sdiv8 = sdiv16 = sdiv32 = sdiv64 = sdivN
sdiv_up8 = sdiv_up16 = sdiv_up32 = sdiv_up64 = sdiv_upN
sdiv_safe8 = sdiv_safe16 = sdiv_safe32 = sdiv_safe64 = sdivN
sdiv_up_safe8 = sdiv_up_safe16 = sdiv_up_safe32 = sdiv_up_safe64 = sdiv_upN
smod8 = smod16 = smod32 = smod64 = smodN
smod_safe8 = smod_safe16 = smod_safe32 = smod_safe64 = smodN
udiv8 = udiv16 = udiv32 = udiv64 = udivN
udiv_up8 = udiv_up16 = udiv_up32 = udiv_up64 = udiv_upN
udiv_safe8 = udiv_safe16 = udiv_safe32 = udiv_safe64 = udivN
udiv_up_safe8 = udiv_up_safe16 = udiv_up_safe32 = udiv_up_safe64 = udiv_upN
umod8 = umod16 = umod32 = umod64 = umodN
umod_safe8 = umod_safe16 = umod_safe32 = umod_safe64 = umodN
squot8 = squot16 = squot32 = squot64 = squotN
squot_safe8 = squot_safe16 = squot_safe32 = squot_safe64 = squotN
srem8 = srem16 = srem32 = srem64 = sremN
srem_safe8 = srem_safe16 = srem_safe32 = srem_safe64 = sremN
shl8 = shl16 = shl32 = shl64 = shlN
ashr8 = ashr16 = ashr32 = ashr64 = ashrN
smax8 = smax16 = smax32 = smax64 = smaxN
smin8 = smin16 = smin32 = smin64 = sminN
umax8 = umax16 = umax32 = umax64 = umaxN
umin8 = umin16 = umin32 = umin64 = uminN
pow8 = pow16 = pow32 = pow64 = powN
fpow16 = fpow32 = fpow64 = fpowN
fmax16 = fmax32 = fmax64 = fmaxN
fmin16 = fmin32 = fmin64 = fminN
sle8 = sle16 = sle32 = sle64 = sleN
slt8 = slt16 = slt32 = slt64 = sltN
ule8 = ule16 = ule32 = ule64 = uleN
ult8 = ult16 = ult32 = ult64 = ultN
sext_i8_i8 = sext_i16_i8 = sext_i32_i8 = sext_i64_i8 = sext_T_i8
sext_i8_i16 = sext_i16_i16 = sext_i32_i16 = sext_i64_i16 = sext_T_i16
sext_i8_i32 = sext_i16_i32 = sext_i32_i32 = sext_i64_i32 = sext_T_i32
sext_i8_i64 = sext_i16_i64 = sext_i32_i64 = sext_i64_i64 = sext_T_i64
itob_i8_bool = itob_i16_bool = itob_i32_bool = itob_i64_bool = itob_T_bool
ftob_f16_bool = ftob_f32_bool = ftob_f64_bool = ftob_T_bool


def clz_T(x):
    """Count leading zero bits of a fixed-width numpy integer.

    Repeatedly shifts left until the sign bit is set (``x < 0``); returns the
    bit width when ``x`` is zero.
    """
    n = np.int32(0)
    bits = x.itemsize * 8
    for i in range(bits):
        if x < 0:
            break
        n += np.int32(1)
        x <<= np.int8(1)
    return n


def ctz_T(x):
    """Count trailing zero bits; returns the bit width when ``x`` is zero."""
    n = np.int32(0)
    bits = x.itemsize * 8
    for i in range(bits):
        if (x & 1) == 1:
            break
        n += np.int32(1)
        x >>= np.int8(1)
    return n


def popc_T(x):
    """Population count: each iteration clears the lowest set bit."""
    c = np.int32(0)
    while x != 0:
        x &= x - np.int8(1)
        c += np.int32(1)
    return c


futhark_popc8 = futhark_popc16 = futhark_popc32 = futhark_popc64 = popc_T
futhark_clzz8 = futhark_clzz16 = futhark_clzz32 = futhark_clzz64 = clz_T
futhark_ctzz8 = futhark_ctzz16 = futhark_ctzz32 = futhark_ctzz64 = ctz_T


def ssignum(x):
    return np.sign(x)


def usignum(x):
    """Sign of ``x`` interpreted as unsigned: 0 for zero, otherwise 1.

    A negative signed value is a large positive unsigned one, so it maps to 1
    directly; negating it would overflow for the minimum representable value.
    """
    if x < 0:
        return type(x)(1)
    else:
        return ssignum(x)


def sitofp_T_f32(x):
    return np.float32(x)


sitofp_i8_f32 = sitofp_i16_f32 = sitofp_i32_f32 = sitofp_i64_f32 = sitofp_T_f32


def sitofp_T_f64(x):
    return np.float64(x)


sitofp_i8_f64 = sitofp_i16_f64 = sitofp_i32_f64 = sitofp_i64_f64 = sitofp_T_f64


def uitofp_T_f32(x):
    return np.float32(unsigned(x))


uitofp_i8_f32 = uitofp_i16_f32 = uitofp_i32_f32 = uitofp_i64_f32 = uitofp_T_f32


def uitofp_T_f64(x):
    return np.float64(unsigned(x))


uitofp_i8_f64 = uitofp_i16_f64 = uitofp_i32_f64 = uitofp_i64_f64 = uitofp_T_f64


# Float-to-signed conversion: truncate towards zero; NaN and infinities give 0.
def fptosi_T_i8(x):
    if np.isnan(x) or np.isinf(x):
        return np.int8(0)
    else:
        return np.int8(np.trunc(x))


fptosi_f16_i8 = fptosi_f32_i8 = fptosi_f64_i8 = fptosi_T_i8


def fptosi_T_i16(x):
    if np.isnan(x) or np.isinf(x):
        return np.int16(0)
    else:
        return np.int16(np.trunc(x))


fptosi_f16_i16 = fptosi_f32_i16 = fptosi_f64_i16 = fptosi_T_i16


def fptosi_T_i32(x):
    if np.isnan(x) or np.isinf(x):
        return np.int32(0)
    else:
        return np.int32(np.trunc(x))


fptosi_f16_i32 = fptosi_f32_i32 = fptosi_f64_i32 = fptosi_T_i32


def fptosi_T_i64(x):
    if np.isnan(x) or np.isinf(x):
        return np.int64(0)
    else:
        return np.int64(np.trunc(x))


fptosi_f16_i64 = fptosi_f32_i64 = fptosi_f64_i64 = fptosi_T_i64


def _fptoui_bits(x, t, bits):
    """Truncate float ``x`` to an unsigned ``bits``-wide integer, returned as
    the same bit pattern in the signed numpy type ``t``.

    The wrapping is done on Python ints: casting a float in
    [2**(bits-1), 2**bits) straight to the signed type is out of range and
    yields platform-dependent results (e.g. INT_MIN) instead of the intended
    bit pattern.  NaN and infinities give 0.
    """
    if np.isnan(x) or np.isinf(x):
        return t(0)
    v = int(np.trunc(x)) % (1 << bits)
    if v >= 1 << (bits - 1):
        v -= 1 << bits
    return t(v)


def fptoui_T_i8(x):
    return _fptoui_bits(x, np.int8, 8)


fptoui_f16_i8 = fptoui_f32_i8 = fptoui_f64_i8 = fptoui_T_i8


def fptoui_T_i16(x):
    return _fptoui_bits(x, np.int16, 16)


fptoui_f16_i16 = fptoui_f32_i16 = fptoui_f64_i16 = fptoui_T_i16


def fptoui_T_i32(x):
    return _fptoui_bits(x, np.int32, 32)


fptoui_f16_i32 = fptoui_f32_i32 = fptoui_f64_i32 = fptoui_T_i32


def fptoui_T_i64(x):
    return _fptoui_bits(x, np.int64, 64)


fptoui_f16_i64 = fptoui_f32_i64 = fptoui_f64_i64 = fptoui_T_i64


def fpconv_f16_f32(x):
    return np.float32(x)


def fpconv_f16_f64(x):
    return np.float64(x)


def fpconv_f32_f16(x):
    return np.float16(x)


def fpconv_f32_f64(x):
    return np.float64(x)


def fpconv_f64_f16(x):
    return np.float16(x)


def fpconv_f64_f32(x):
    return np.float32(x)


# High half of the full-width product.  Narrow widths are widened to 64 bits;
# the 64-bit variants use Python's arbitrary-precision ints.
def futhark_umul_hi8(a, b):
    return np.int8(
        (np.uint64(np.uint8(a)) * np.uint64(np.uint8(b))) >> np.uint64(8)
    )


def futhark_umul_hi16(a, b):
    return np.int16(
        (np.uint64(np.uint16(a)) * np.uint64(np.uint16(b))) >> np.uint64(16)
    )


def futhark_umul_hi32(a, b):
    return np.int32(
        (np.uint64(np.uint32(a)) * np.uint64(np.uint32(b))) >> np.uint64(32)
    )


def futhark_umul_hi64(a, b):
    return np.int64(np.uint64(int(np.uint64(a)) * int(np.uint64(b)) >> 64))


def futhark_smul_hi8(a, b):
    return np.int8((np.int64(a) * np.int64(b)) >> np.int64(8))


def futhark_smul_hi16(a, b):
    return np.int16((np.int64(a) * np.int64(b)) >> np.int64(16))


def futhark_smul_hi32(a, b):
    return np.int32((np.int64(a) * np.int64(b)) >> np.int64(32))


def futhark_smul_hi64(a, b):
    return np.int64(int(a) * int(b) >> 64)


def futhark_umad_hi8(a, b, c):
    return futhark_umul_hi8(a, b) + c


def futhark_umad_hi16(a, b, c):
    return futhark_umul_hi16(a, b) + c


def futhark_umad_hi32(a, b, c):
    return futhark_umul_hi32(a, b) + c


def futhark_umad_hi64(a, b, c):
    return futhark_umul_hi64(a, b) + c


def futhark_smad_hi8(a, b, c):
    return futhark_smul_hi8(a, b) + c


def futhark_smad_hi16(a, b, c):
    return futhark_smul_hi16(a, b) + c


def futhark_smad_hi32(a, b, c):
    return futhark_smul_hi32(a, b) + c


def futhark_smad_hi64(a, b, c):
    return futhark_smul_hi64(a, b) + c


# 64-bit floating-point intrinsics.
def futhark_log64(x):
    return np.float64(np.log(x))


def futhark_log2_64(x):
    return np.float64(np.log2(x))


def futhark_log10_64(x):
    return np.float64(np.log10(x))


def futhark_log1p_64(x):
    return np.float64(np.log1p(x))


def futhark_sqrt64(x):
    return np.sqrt(x)


def futhark_cbrt64(x):
    return np.cbrt(x)


def futhark_exp64(x):
    return np.exp(x)


def futhark_cos64(x):
    return np.cos(x)


def futhark_sin64(x):
    return np.sin(x)


def futhark_tan64(x):
    return np.tan(x)


def futhark_acos64(x):
    return np.arccos(x)


def futhark_asin64(x):
    return np.arcsin(x)


def futhark_atan64(x):
    return np.arctan(x)


def futhark_cosh64(x):
    return np.cosh(x)


def futhark_sinh64(x):
    return np.sinh(x)


def futhark_tanh64(x):
    return np.tanh(x)


def futhark_acosh64(x):
    return np.arccosh(x)


def futhark_asinh64(x):
    return np.arcsinh(x)


def futhark_atanh64(x):
    return np.arctanh(x)


def futhark_atan2_64(x, y):
    return np.arctan2(x, y)


def futhark_hypot64(x, y):
    return np.hypot(x, y)


def futhark_gamma64(x):
    return np.float64(math.gamma(x))


def futhark_lgamma64(x):
    return np.float64(math.lgamma(x))


def futhark_erf64(x):
    return np.float64(math.erf(x))


def futhark_erfc64(x):
    return np.float64(math.erfc(x))


def futhark_round64(x):
    # np.round rounds halves to even.
    return np.round(x)


def futhark_ceil64(x):
    return np.ceil(x)


def futhark_floor64(x):
    return np.floor(x)


def futhark_nextafter64(x, y):
    return np.nextafter(x, y)


def futhark_isnan64(x):
    return np.isnan(x)


def futhark_isinf64(x):
    return np.isinf(x)


def futhark_to_bits64(x):
    """Reinterpret the bits of a float64 as a signed int64."""
    s = struct.pack(">d", x)
    return np.int64(struct.unpack(">q", s)[0])


def futhark_from_bits64(x):
    """Reinterpret the bits of a signed int64 as a float64."""
    s = struct.pack(">q", x)
    return np.float64(struct.unpack(">d", s)[0])


# 32-bit floating-point intrinsics.
def futhark_log32(x):
    return np.float32(np.log(x))


def futhark_log2_32(x):
    return np.float32(np.log2(x))


def futhark_log10_32(x):
    return np.float32(np.log10(x))


def futhark_log1p_32(x):
    return np.float32(np.log1p(x))


def futhark_sqrt32(x):
    return np.float32(np.sqrt(x))


def futhark_cbrt32(x):
    return np.float32(np.cbrt(x))


def futhark_exp32(x):
    return np.exp(x)


def futhark_cos32(x):
    return np.cos(x)


def futhark_sin32(x):
    return np.sin(x)


def futhark_tan32(x):
    return np.tan(x)


def futhark_acos32(x):
    return np.arccos(x)


def futhark_asin32(x):
    return np.arcsin(x)


def futhark_atan32(x):
    return np.arctan(x)


def futhark_cosh32(x):
    return np.cosh(x)


def futhark_sinh32(x):
    return np.sinh(x)


def futhark_tanh32(x):
    return np.tanh(x)


def futhark_acosh32(x):
    return np.arccosh(x)


def futhark_asinh32(x):
    return np.arcsinh(x)


def futhark_atanh32(x):
    return np.arctanh(x)


def futhark_atan2_32(x, y):
    return np.arctan2(x, y)


def futhark_hypot32(x, y):
    return np.hypot(x, y)


def futhark_gamma32(x):
    return np.float32(math.gamma(x))


def futhark_lgamma32(x):
    return np.float32(math.lgamma(x))


def futhark_erf32(x):
    return np.float32(math.erf(x))


def futhark_erfc32(x):
    return np.float32(math.erfc(x))


def futhark_round32(x):
    # np.round rounds halves to even.
    return np.round(x)


def futhark_ceil32(x):
    return np.ceil(x)


def futhark_floor32(x):
    return np.floor(x)


def futhark_nextafter32(x, y):
    return np.nextafter(x, y)


def futhark_isnan32(x):
    return np.isnan(x)


def futhark_isinf32(x):
    return np.isinf(x)


def futhark_to_bits32(x):
    """Reinterpret the bits of a float32 as a signed int32."""
    s = struct.pack(">f", x)
    return np.int32(struct.unpack(">l", s)[0])


def futhark_from_bits32(x):
    """Reinterpret the bits of a signed int32 as a float32."""
    s = struct.pack(">l", x)
    return np.float32(struct.unpack(">f", s)[0])


# 16-bit floating-point intrinsics.
def futhark_log16(x):
    return np.float16(np.log(x))


def futhark_log2_16(x):
    return np.float16(np.log2(x))


def futhark_log10_16(x):
    return np.float16(np.log10(x))


def futhark_log1p_16(x):
    return np.float16(np.log1p(x))


def futhark_sqrt16(x):
    return np.float16(np.sqrt(x))


def futhark_cbrt16(x):
    return np.float16(np.cbrt(x))


def futhark_exp16(x):
    return np.exp(x)


def futhark_cos16(x):
    return np.cos(x)


def futhark_sin16(x):
    return np.sin(x)


def futhark_tan16(x):
    return np.tan(x)


def futhark_acos16(x):
    return np.arccos(x)


def futhark_asin16(x):
    return np.arcsin(x)


def futhark_atan16(x):
    return np.arctan(x)


def futhark_cosh16(x):
    return np.cosh(x)


def futhark_sinh16(x):
    return np.sinh(x)


def futhark_tanh16(x):
    return np.tanh(x)


def futhark_acosh16(x):
    return np.arccosh(x)


def futhark_asinh16(x):
    return np.arcsinh(x)


def futhark_atanh16(x):
    return np.arctanh(x)


def futhark_atan2_16(x, y):
    return np.arctan2(x, y)


def futhark_hypot16(x, y):
    return np.hypot(x, y)


def futhark_gamma16(x):
    return np.float16(math.gamma(x))


def futhark_lgamma16(x):
    return np.float16(math.lgamma(x))


def futhark_erf16(x):
    return np.float16(math.erf(x))


def futhark_erfc16(x):
    return np.float16(math.erfc(x))


def futhark_round16(x):
    # np.round rounds halves to even.
    return np.round(x)


def futhark_ceil16(x):
    return np.ceil(x)


def futhark_floor16(x):
    return np.floor(x)


def futhark_nextafter16(x, y):
    return np.nextafter(x, y)


def futhark_isnan16(x):
    return np.isnan(x)


def futhark_isinf16(x):
    return np.isinf(x)


def futhark_to_bits16(x):
    """Reinterpret the bits of a float16 as a signed int16.

    The bits are unpacked as signed ("h") so the value always fits in int16.
    An unsigned unpack gives values >= 2**15 for floats with the sign bit
    set, which newer NumPy versions refuse to convert to int16.
    """
    s = struct.pack(">e", x)
    return np.int16(struct.unpack(">h", s)[0])


def futhark_from_bits16(x):
    """Reinterpret the bits of a 16-bit integer as a float16."""
    s = struct.pack(">H", np.uint16(x))
    return np.float16(struct.unpack(">e", s)[0])


# Linear interpolation between v0 (t = 0) and v1 (t = 1).
def futhark_lerp16(v0, v1, t):
    return v0 + (v1 - v0) * t


def futhark_lerp32(v0, v1, t):
    return v0 + (v1 - v0) * t


def futhark_lerp64(v0, v1, t):
    return v0 + (v1 - v0) * t
def futhark_ldexp16(x, y): return np.ldexp(x, y) def futhark_ldexp32(x, y): return np.ldexp(x, y) def futhark_ldexp64(x, y): return np.ldexp(x, y) def futhark_mad16(a, b, c): return a * b + c def futhark_mad32(a, b, c): return a * b + c def futhark_mad64(a, b, c): return a * b + c def futhark_fma16(a, b, c): return a * b + c def futhark_fma32(a, b, c): return a * b + c def futhark_fma64(a, b, c): return a * b + c futhark_copysign16 = futhark_copysign32 = futhark_copysign64 = np.copysign def futhark_cond(x, y, z): return y if x else z futhark_cond_f16 = futhark_cond_f32 = futhark_cond_f64 = futhark_cond futhark_cond_i18 = futhark_cond_i16 = futhark_cond_i32 = futhark_cond_i64 = ( futhark_cond ) futhark_cond_bool = futhark_cond_unit = futhark_cond # End of scalar.py. futhark-0.25.27/rts/python/server.py000066400000000000000000000144061475065116200173600ustar00rootroot00000000000000# Start of server.py import sys import time import shlex # For string splitting class Server: def __init__(self, ctx): self._ctx = ctx self._vars = {} class Failure(BaseException): def __init__(self, msg): self.msg = msg def _get_arg(self, args, i): if i < len(args): return args[i] else: raise self.Failure("Insufficient command args") def _get_entry_point(self, entry): if entry in self._ctx.entry_points: return self._ctx.entry_points[entry] else: raise self.Failure("Unknown entry point: %s" % entry) def _check_var(self, vname): if not vname in self._vars: raise self.Failure("Unknown variable: %s" % vname) def _check_new_var(self, vname): if vname in self._vars: raise self.Failure("Variable already exists: %s" % vname) def _get_var(self, vname): self._check_var(vname) return self._vars[vname] def _cmd_inputs(self, args): entry = self._get_arg(args, 0) for t in self._get_entry_point(entry)[1]: print(t) def _cmd_outputs(self, args): entry = self._get_arg(args, 0) for t in self._get_entry_point(entry)[2]: print(t) def _cmd_dummy(self, args): pass def _cmd_free(self, args): for vname in args: 
self._check_var(vname) del self._vars[vname] def _cmd_rename(self, args): oldname = self._get_arg(args, 0) newname = self._get_arg(args, 1) self._check_var(oldname) self._check_new_var(newname) self._vars[newname] = self._vars[oldname] del self._vars[oldname] def _cmd_call(self, args): entry = self._get_entry_point(self._get_arg(args, 0)) entry_fname = entry[0] num_ins = len(entry[1]) num_outs = len(entry[2]) exp_len = 1 + num_outs + num_ins if len(args) != exp_len: raise self.Failure("Invalid argument count, expected %d" % exp_len) out_vnames = args[1 : num_outs + 1] for out_vname in out_vnames: self._check_new_var(out_vname) in_vnames = args[1 + num_outs :] ins = [self._get_var(in_vname) for in_vname in in_vnames] try: (runtime, vals) = getattr(self._ctx, entry_fname)(*ins) except Exception as e: raise self.Failure(str(e)) print("runtime: %d" % runtime) if num_outs == 1: self._vars[out_vnames[0]] = vals else: for out_vname, val in zip(out_vnames, vals): self._vars[out_vname] = val def _store_val(self, f, value): # In case we are using the PyOpenCL backend, we first # need to convert OpenCL arrays to ordinary NumPy # arrays. We do this in a nasty way. if isinstance(value, opaque): for component in value.data: self._store_val(f, component) elif ( isinstance(value, np.number) or isinstance(value, bool) or isinstance(value, np.bool_) or isinstance(value, np.ndarray) ): # Ordinary NumPy value. f.write(construct_binary_value(value)) else: # Assuming PyOpenCL array. 
f.write(construct_binary_value(value.get())) def _cmd_store(self, args): fname = self._get_arg(args, 0) with open(fname, "wb") as f: for i in range(1, len(args)): self._store_val(f, self._get_var(args[i])) def _restore_val(self, reader, typename): if typename in self._ctx.opaques: vs = [] for t in self._ctx.opaques[typename]: vs += [read_value(t, reader)] return opaque(typename, *vs) else: return read_value(typename, reader) def _cmd_restore(self, args): if len(args) % 2 == 0: raise self.Failure("Invalid argument count") fname = args[0] args = args[1:] with open(fname, "rb") as f: reader = ReaderInput(f) while args != []: vname = args[0] typename = args[1] args = args[2:] if vname in self._vars: raise self.Failure("Variable already exists: %s" % vname) try: self._vars[vname] = self._restore_val(reader, typename) except ValueError: raise self.Failure( "Failed to restore variable %s.\n" "Possibly malformed data in %s.\n" % (vname, fname) ) skip_spaces(reader) if reader.get_char() != b"": raise self.Failure("Expected EOF after reading values") def _cmd_types(self, args): for k in self._ctx.opaques.keys(): print(k) def _cmd_entry_points(self, args): for k in self._ctx.entry_points.keys(): print(k) _commands = { "inputs": _cmd_inputs, "outputs": _cmd_outputs, "call": _cmd_call, "restore": _cmd_restore, "store": _cmd_store, "free": _cmd_free, "rename": _cmd_rename, "clear": _cmd_dummy, "pause_profiling": _cmd_dummy, "unpause_profiling": _cmd_dummy, "report": _cmd_dummy, "types": _cmd_types, "entry_points": _cmd_entry_points, } def _process_line(self, line): lex = shlex.shlex(line) lex.quotes = '"' lex.whitespace_split = True lex.commenters = "" words = list(lex) if words == []: raise self.Failure("Empty line") else: cmd = words[0] args = words[1:] if cmd in self._commands: self._commands[cmd](self, args) else: raise self.Failure("Unknown command: %s" % cmd) def run(self): while True: print("%%% OK", flush=True) line = sys.stdin.readline() if line == "": return try: 
self._process_line(line) except self.Failure as e: print("%%% FAILURE") print(e.msg) # End of server.py futhark-0.25.27/rts/python/tuning.py000066400000000000000000000003061475065116200173500ustar00rootroot00000000000000# Start of tuning.py def read_tuning_file(kvs, f): for line in f.read().splitlines(): size, value = line.split("=") kvs[size] = int(value) return kvs # End of tuning.py. futhark-0.25.27/rts/python/values.py000066400000000000000000000514141475065116200173510ustar00rootroot00000000000000# Start of values.py. # Hacky parser/reader/writer for values written in Futhark syntax. # Used for reading stdin when compiling standalone programs with the # Python code generator. import numpy as np import string import struct import sys class ReaderInput: def __init__(self, f): self.f = f self.lookahead_buffer = [] def get_char(self): if len(self.lookahead_buffer) == 0: return self.f.read(1) else: c = self.lookahead_buffer[0] self.lookahead_buffer = self.lookahead_buffer[1:] return c def unget_char(self, c): self.lookahead_buffer = [c] + self.lookahead_buffer def get_chars(self, n): n1 = min(n, len(self.lookahead_buffer)) s = b"".join(self.lookahead_buffer[:n1]) self.lookahead_buffer = self.lookahead_buffer[n1:] n2 = n - n1 if n2 > 0: s += self.f.read(n2) return s def peek_char(self): c = self.get_char() if c: self.unget_char(c) return c def skip_spaces(f): c = f.get_char() while c != None: if c.isspace(): c = f.get_char() elif c == b"-": # May be line comment. if f.peek_char() == b"-": # Yes, line comment. Skip to end of line. 
while c != b"\n" and c != None: c = f.get_char() else: break else: break if c: f.unget_char(c) def parse_specific_char(f, expected): got = f.get_char() if got != expected: f.unget_char(got) raise ValueError return True def parse_specific_string(f, s): # This funky mess is intended, and is caused by the fact that if `type(b) == # bytes` then `type(b[0]) == int`, but we need to match each element with a # `bytes`, so therefore we make each character an array element b = s.encode("utf8") bs = [b[i : i + 1] for i in range(len(b))] read = [] try: for c in bs: parse_specific_char(f, c) read.append(c) return True except ValueError: for c in read[::-1]: f.unget_char(c) raise def optional(p, *args): try: return p(*args) except ValueError: return None def optional_specific_string(f, s): c = f.peek_char() # This funky mess is intended, and is caused by the fact that if `type(b) == # bytes` then `type(b[0]) == int`, but we need to match each element with a # `bytes`, so therefore we make each character an array element b = s.encode("utf8") bs = [b[i : i + 1] for i in range(len(b))] if c == bs[0]: return parse_specific_string(f, s) else: return False def sepEndBy(p, sep, *args): elems = [] x = optional(p, *args) if x != None: elems += [x] while optional(sep, *args) != None: x = optional(p, *args) if x == None: break else: elems += [x] return elems # Assumes '0x' has already been read def parse_hex_int(f): s = b"" c = f.get_char() while c != None: if c in b"01234556789ABCDEFabcdef": s += c c = f.get_char() elif c == b"_": c = f.get_char() # skip _ else: f.unget_char(c) break return str(int(s, 16)).encode("utf8") # ugh def parse_int(f): s = b"" c = f.get_char() if c == b"0" and f.peek_char() in b"xX": c = f.get_char() # skip X return parse_hex_int(f) else: while c != None: if c.isdigit(): s += c c = f.get_char() elif c == b"_": c = f.get_char() # skip _ else: f.unget_char(c) break if len(s) == 0: raise ValueError return s def parse_int_signed(f): s = b"" c = f.get_char() if c == 
b"-" and f.peek_char().isdigit(): return c + parse_int(f) else: if c != b"+": f.unget_char(c) return parse_int(f) def read_str_comma(f): skip_spaces(f) parse_specific_char(f, b",") return b"," def read_str_int(f, s): skip_spaces(f) x = int(parse_int_signed(f)) optional_specific_string(f, s) return x def read_str_uint(f, s): skip_spaces(f) x = int(parse_int(f)) optional_specific_string(f, s) return x def read_str_i8(f): return np.int8(read_str_int(f, "i8")) def read_str_i16(f): return np.int16(read_str_int(f, "i16")) def read_str_i32(f): return np.int32(read_str_int(f, "i32")) def read_str_i64(f): return np.int64(read_str_int(f, "i64")) def read_str_u8(f): return np.uint8(read_str_int(f, "u8")) def read_str_u16(f): return np.uint16(read_str_int(f, "u16")) def read_str_u32(f): return np.uint32(read_str_int(f, "u32")) def read_str_u64(f): return np.uint64(read_str_int(f, "u64")) def read_char(f): skip_spaces(f) parse_specific_char(f, b"'") c = f.get_char() parse_specific_char(f, b"'") return c def read_str_hex_float(f, sign): int_part = parse_hex_int(f) parse_specific_char(f, b".") frac_part = parse_hex_int(f) parse_specific_char(f, b"p") exponent = parse_int(f) int_val = int(int_part, 16) frac_val = float(int(frac_part, 16)) / (16 ** len(frac_part)) exp_val = int(exponent) total_val = (int_val + frac_val) * (2.0**exp_val) if sign == b"-": total_val = -1 * total_val return float(total_val) def read_str_decimal(f): skip_spaces(f) c = f.get_char() if c == b"-": sign = b"-" else: f.unget_char(c) sign = b"" # Check for hexadecimal float c = f.get_char() if c == "0" and (f.peek_char() in ["x", "X"]): f.get_char() return read_str_hex_float(f, sign) else: f.unget_char(c) bef = optional(parse_int, f) if bef == None: bef = b"0" parse_specific_char(f, b".") aft = parse_int(f) elif optional(parse_specific_char, f, b"."): aft = parse_int(f) else: aft = b"0" if optional(parse_specific_char, f, b"E") or optional( parse_specific_char, f, b"e" ): expt = parse_int_signed(f) else: expt 
= b"0" return float(sign + bef + b"." + aft + b"E" + expt) def read_str_f16(f): skip_spaces(f) try: parse_specific_string(f, "f16.nan") return np.float32(np.nan) except ValueError: try: parse_specific_string(f, "f16.inf") return np.float32(np.inf) except ValueError: try: parse_specific_string(f, "-f16.inf") return np.float32(-np.inf) except ValueError: x = read_str_decimal(f) optional_specific_string(f, "f16") return x def read_str_f32(f): skip_spaces(f) try: parse_specific_string(f, "f32.nan") return np.float32(np.nan) except ValueError: try: parse_specific_string(f, "f32.inf") return np.float32(np.inf) except ValueError: try: parse_specific_string(f, "-f32.inf") return np.float32(-np.inf) except ValueError: x = read_str_decimal(f) optional_specific_string(f, "f32") return x def read_str_f64(f): skip_spaces(f) try: parse_specific_string(f, "f64.nan") return np.float64(np.nan) except ValueError: try: parse_specific_string(f, "f64.inf") return np.float64(np.inf) except ValueError: try: parse_specific_string(f, "-f64.inf") return np.float64(-np.inf) except ValueError: x = read_str_decimal(f) optional_specific_string(f, "f64") return x def read_str_bool(f): skip_spaces(f) if f.peek_char() == b"t": parse_specific_string(f, "true") return True elif f.peek_char() == b"f": parse_specific_string(f, "false") return False else: raise ValueError def read_str_empty_array(f, type_name, rank): parse_specific_string(f, "empty") parse_specific_char(f, b"(") dims = [] for i in range(rank): parse_specific_string(f, "[") dims += [int(parse_int(f))] parse_specific_string(f, "]") if np.prod(dims) != 0: raise ValueError parse_specific_string(f, type_name) parse_specific_char(f, b")") return tuple(dims) def read_str_array_elems(f, elem_reader, type_name, rank): skip_spaces(f) try: parse_specific_char(f, b"[") except ValueError: return read_str_empty_array(f, type_name, rank) else: xs = sepEndBy(elem_reader, read_str_comma, f) skip_spaces(f) parse_specific_char(f, b"]") return xs def 
read_str_array_helper(f, elem_reader, type_name, rank): def nested_row_reader(_): return read_str_array_helper(f, elem_reader, type_name, rank - 1) if rank == 1: row_reader = elem_reader else: row_reader = nested_row_reader return read_str_array_elems(f, row_reader, type_name, rank) def expected_array_dims(l, rank): if rank > 1: n = len(l) if n == 0: elem = [] else: elem = l[0] return [n] + expected_array_dims(elem, rank - 1) else: return [len(l)] def verify_array_dims(l, dims): if dims[0] != len(l): raise ValueError if len(dims) > 1: for x in l: verify_array_dims(x, dims[1:]) def read_str_array(f, elem_reader, type_name, rank, bt): elems = read_str_array_helper(f, elem_reader, type_name, rank) if type(elems) == tuple: # Empty array return np.empty(elems, dtype=bt) else: dims = expected_array_dims(elems, rank) verify_array_dims(elems, dims) return np.array(elems, dtype=bt) ################################################################################ READ_BINARY_VERSION = 2 # struct format specified at # https://docs.python.org/2/library/struct.html#format-characters def mk_bin_scalar_reader(t): def bin_reader(f): fmt = FUTHARK_PRIMTYPES[t]["bin_format"] size = FUTHARK_PRIMTYPES[t]["size"] tf = FUTHARK_PRIMTYPES[t]["numpy_type"] return tf(struct.unpack("<" + fmt, f.get_chars(size))[0]) return bin_reader read_bin_i8 = mk_bin_scalar_reader("i8") read_bin_i16 = mk_bin_scalar_reader("i16") read_bin_i32 = mk_bin_scalar_reader("i32") read_bin_i64 = mk_bin_scalar_reader("i64") read_bin_u8 = mk_bin_scalar_reader("u8") read_bin_u16 = mk_bin_scalar_reader("u16") read_bin_u32 = mk_bin_scalar_reader("u32") read_bin_u64 = mk_bin_scalar_reader("u64") read_bin_f16 = mk_bin_scalar_reader("f16") read_bin_f32 = mk_bin_scalar_reader("f32") read_bin_f64 = mk_bin_scalar_reader("f64") read_bin_bool = mk_bin_scalar_reader("bool") def read_is_binary(f): skip_spaces(f) c = f.get_char() if c == b"b": bin_version = read_bin_u8(f) if bin_version != READ_BINARY_VERSION: panic( 1, 
"binary-input: File uses version %i, but I only understand version %i.\n", bin_version, READ_BINARY_VERSION, ) return True else: f.unget_char(c) return False FUTHARK_PRIMTYPES = { "i8": { "binname": b" i8", "size": 1, "bin_reader": read_bin_i8, "str_reader": read_str_i8, "bin_format": "b", "numpy_type": np.int8, }, "i16": { "binname": b" i16", "size": 2, "bin_reader": read_bin_i16, "str_reader": read_str_i16, "bin_format": "h", "numpy_type": np.int16, }, "i32": { "binname": b" i32", "size": 4, "bin_reader": read_bin_i32, "str_reader": read_str_i32, "bin_format": "i", "numpy_type": np.int32, }, "i64": { "binname": b" i64", "size": 8, "bin_reader": read_bin_i64, "str_reader": read_str_i64, "bin_format": "q", "numpy_type": np.int64, }, "u8": { "binname": b" u8", "size": 1, "bin_reader": read_bin_u8, "str_reader": read_str_u8, "bin_format": "B", "numpy_type": np.uint8, }, "u16": { "binname": b" u16", "size": 2, "bin_reader": read_bin_u16, "str_reader": read_str_u16, "bin_format": "H", "numpy_type": np.uint16, }, "u32": { "binname": b" u32", "size": 4, "bin_reader": read_bin_u32, "str_reader": read_str_u32, "bin_format": "I", "numpy_type": np.uint32, }, "u64": { "binname": b" u64", "size": 8, "bin_reader": read_bin_u64, "str_reader": read_str_u64, "bin_format": "Q", "numpy_type": np.uint64, }, "f16": { "binname": b" f16", "size": 2, "bin_reader": read_bin_f16, "str_reader": read_str_f16, "bin_format": "e", "numpy_type": np.float16, }, "f32": { "binname": b" f32", "size": 4, "bin_reader": read_bin_f32, "str_reader": read_str_f32, "bin_format": "f", "numpy_type": np.float32, }, "f64": { "binname": b" f64", "size": 8, "bin_reader": read_bin_f64, "str_reader": read_str_f64, "bin_format": "d", "numpy_type": np.float64, }, "bool": { "binname": b"bool", "size": 1, "bin_reader": read_bin_bool, "str_reader": read_str_bool, "bin_format": "b", "numpy_type": bool, }, } def read_bin_read_type(f): read_binname = f.get_chars(4) for k, v in FUTHARK_PRIMTYPES.items(): if v["binname"] == 
read_binname: return k panic(1, "binary-input: Did not recognize the type '%s'.\n", read_binname) def numpy_type_to_type_name(t): for k, v in FUTHARK_PRIMTYPES.items(): if v["numpy_type"] == t: return k raise Exception("Unknown Numpy type: {}".format(t)) def read_bin_ensure_scalar(f, expected_type): dims = read_bin_i8(f) if dims != 0: panic( 1, "binary-input: Expected scalar (0 dimensions), but got array with %i dimensions.\n", dims, ) bin_type = read_bin_read_type(f) if bin_type != expected_type: panic( 1, "binary-input: Expected scalar of type %s but got scalar of type %s.\n", expected_type, bin_type, ) # ------------------------------------------------------------------------------ # General interface for reading Primitive Futhark Values # ------------------------------------------------------------------------------ def read_scalar(f, ty): if read_is_binary(f): read_bin_ensure_scalar(f, ty) return FUTHARK_PRIMTYPES[ty]["bin_reader"](f) return FUTHARK_PRIMTYPES[ty]["str_reader"](f) def read_array(f, expected_type, rank): if not read_is_binary(f): str_reader = FUTHARK_PRIMTYPES[expected_type]["str_reader"] return read_str_array( f, str_reader, expected_type, rank, FUTHARK_PRIMTYPES[expected_type]["numpy_type"], ) bin_rank = read_bin_u8(f) if bin_rank != rank: panic( 1, "binary-input: Expected %i dimensions, but got array with %i dimensions.\n", rank, bin_rank, ) bin_type_enum = read_bin_read_type(f) if expected_type != bin_type_enum: panic( 1, "binary-input: Expected %iD-array with element type '%s' but got %iD-array with element type '%s'.\n", rank, expected_type, bin_rank, bin_type_enum, ) shape = [] elem_count = 1 for i in range(rank): bin_size = read_bin_i64(f) elem_count *= bin_size shape.append(bin_size) bin_fmt = FUTHARK_PRIMTYPES[bin_type_enum]["bin_format"] # We first read the expected number of types into a bytestring, # then use np.frombuffer. 
This is because np.fromfile does not # work on things that are insufficiently file-like, like a network # stream. bytes = f.get_chars(elem_count * FUTHARK_PRIMTYPES[expected_type]["size"]) arr = np.frombuffer( bytes, dtype=FUTHARK_PRIMTYPES[bin_type_enum]["numpy_type"] ) arr.shape = shape return arr.copy() # To ensure it is writeable. if sys.version_info >= (3, 0): input_reader = ReaderInput(sys.stdin.buffer) else: input_reader = ReaderInput(sys.stdin) import re def read_value(type_desc, reader=input_reader): """Read a value of the given type. The type is a string representation of the Futhark type.""" m = re.match(r"((?:\[\])*)([a-z0-9]+)$", type_desc) if m: dims = int(len(m.group(1)) / 2) basetype = m.group(2) assert m and basetype in FUTHARK_PRIMTYPES, "Unknown type: {}".format( type_desc ) if dims > 0: return read_array(reader, basetype, dims) else: return read_scalar(reader, basetype) def end_of_input(entry, f=input_reader): skip_spaces(f) if f.get_char() != b"": panic(1, 'Expected EOF on stdin after reading input for "%s".', entry) def write_value_text(v, out=sys.stdout): if type(v) == np.uint8: out.write("%uu8" % v) elif type(v) == np.uint16: out.write("%uu16" % v) elif type(v) == np.uint32: out.write("%uu32" % v) elif type(v) == np.uint64: out.write("%uu64" % v) elif type(v) == np.int8: out.write("%di8" % v) elif type(v) == np.int16: out.write("%di16" % v) elif type(v) == np.int32: out.write("%di32" % v) elif type(v) == np.int64: out.write("%di64" % v) elif type(v) in [bool, np.bool_]: if v: out.write("true") else: out.write("false") elif type(v) == np.float16: if np.isnan(v): out.write("f16.nan") elif np.isinf(v): if v >= 0: out.write("f16.inf") else: out.write("-f16.inf") else: out.write("%.6ff16" % v) elif type(v) == np.float32: if np.isnan(v): out.write("f32.nan") elif np.isinf(v): if v >= 0: out.write("f32.inf") else: out.write("-f32.inf") else: out.write("%.6ff32" % v) elif type(v) == np.float64: if np.isnan(v): out.write("f64.nan") elif np.isinf(v): 
if v >= 0: out.write("f64.inf") else: out.write("-f64.inf") else: out.write("%.6ff64" % v) elif type(v) == np.ndarray: if np.prod(v.shape) == 0: tname = numpy_type_to_type_name(v.dtype) out.write( "empty({}{})".format( "".join(["[{}]".format(d) for d in v.shape]), tname ) ) else: first = True out.write("[") for x in v: if not first: out.write(", ") first = False write_value(x, out=out) out.write("]") else: raise Exception("Cannot print value of type {}: {}".format(type(v), v)) type_strs = { np.dtype("int8"): b" i8", np.dtype("int16"): b" i16", np.dtype("int32"): b" i32", np.dtype("int64"): b" i64", np.dtype("uint8"): b" u8", np.dtype("uint16"): b" u16", np.dtype("uint32"): b" u32", np.dtype("uint64"): b" u64", np.dtype("float16"): b" f16", np.dtype("float32"): b" f32", np.dtype("float64"): b" f64", np.dtype("bool"): b"bool", } def construct_binary_value(v): t = v.dtype shape = v.shape elems = 1 for d in shape: elems *= d num_bytes = 1 + 1 + 1 + 4 + len(shape) * 8 + elems * t.itemsize bytes = bytearray(num_bytes) bytes[0] = np.int8(ord("b")) bytes[1] = 2 bytes[2] = np.int8(len(shape)) bytes[3:7] = type_strs[t] for i in range(len(shape)): bytes[7 + i * 8 : 7 + (i + 1) * 8] = np.int64(shape[i]).tobytes() bytes[7 + len(shape) * 8 :] = np.ascontiguousarray(v).tobytes() return bytes def write_value_binary(v, out=sys.stdout): if sys.version_info >= (3, 0): out = out.buffer out.write(construct_binary_value(v)) def write_value(v, out=sys.stdout, binary=False): if binary: return write_value_binary(v, out=out) else: return write_value_text(v, out=out) # End of values.py. futhark-0.25.27/setup.cfg000066400000000000000000000001251475065116200151610ustar00rootroot00000000000000[mypy] ignore_missing_imports = True exclude = benchmark-performance-plot.py|conf.py futhark-0.25.27/shell.nix000066400000000000000000000020501475065116200151660ustar00rootroot00000000000000# See header comment in default.nix for how to update sources.nix. 
let
  # nixpkgs comes from the source list in nix/sources.nix.
  sources = import ./nix/sources.nix;
  pkgs = import sources.nixpkgs {};
  # Python and GHC package sets that the tool list below draws from.
  python = pkgs.python312Packages;
  haskell = pkgs.haskell.packages.ghc96;
in
pkgs.stdenv.mkDerivation {
  name = "futhark";
  buildInputs =
    with pkgs;
    [
      cabal-install
      cacert
      curl
      file
      git
      parallel
      haskell.ghc
      ormolu
      haskell.weeder
      haskell.haskell-language-server
      haskellPackages.graphmod
      haskellPackages.apply-refact
      xdot
      hlint
      pkg-config
      zlib
      zlib.out
      cabal2nix
      ghcid
      niv
      ispc
      python.python
      python.mypy
      python.black
      python.cycler
      python.numpy
      python.pyopencl
      python.matplotlib
      python.jsonschema
      python.sphinx
      python.sphinxcontrib-bibtex
      imagemagick # needed for literate tests
    ]
    # OpenCL/ROCm tooling is only added when the host is Linux.
    ++ lib.optionals (stdenv.isLinux)
      [ opencl-headers
        ocl-icd
        oclgrind
        rocmPackages.clr
      ]
    ;
}
futhark-0.25.27/src/000077500000000000000000000000001475065116200141315ustar00rootroot00000000000000futhark-0.25.27/src/Futhark.hs000066400000000000000000000150011475065116200160660ustar00rootroot00000000000000
-- |
--
-- This module contains no code. Its purpose is to explain the
-- overall architecture of the code base, with links to the relevant
-- modules and definitions. It is intended for people who intend to
-- hack on the compiler itself, not for those who just want to use its
-- public API as a library. Most of what's here discusses the
-- compiler itself, but we'll also provide pointers for where to find
-- the code for the various other tools and subcommands.
--
-- Much of the documentation here consists of references to other
-- modules that contain more information. Make sure to follow the
-- links. Also, make sure to click the *source* links to study the
-- actual code. Many of the links are provided not because the exposed
-- module API is particularly interesting, but because their
-- implementation is.
--
-- = Compiler
--
-- At a very high level, the Futhark compiler has a fairly
-- conventional architecture.
After reading and type-checking a -- program, it is gradually transformed, passing through a variety of -- different /intermediate representations/ (IRs), where the specifics -- depend on which code generator is intended. The final result is -- then written to one or more files. We can divide the compiler into -- roughly three parts: frontend, middle-end [sic], and backend. -- -- == Frontend -- -- The frontend of the compiler is concerned with parsing the Futhark -- source language, type-checking it, and transforming it to the core -- IR used by the middle-end of the compiler. The following modules -- are of particular interest: -- -- * "Language.Futhark.Syntax" contains the source language AST definition. -- -- * "Language.Futhark.Parser" contains parsers for the source -- language and various fragments. -- -- * "Futhark.Compiler" and "Futhark.Compiler.Program" contain -- functions for conveniently reading and type-checking an entire -- Futhark source program, chasing down imports, and so on. -- -- * "Language.Futhark.TypeChecker" is the type checker itself. -- -- * /Internalisation/ of the source language to the core IR is in -- "Futhark.Internalise". -- -- == Middle-end -- -- The compiler middle-end is based around /passes/ that are -- essentially pure functions that accept a program as input and -- produce a program as output. A composition of such passes is -- called a /pipeline/. The various compiler backends (@futhark -- opencl@, @futhark multicore@, etc.) use different pipelines. See -- "Futhark.Pass" and (less importantly) "Futhark.Pipeline". The -- actual pipelines we use are in "Futhark.Passes". -- -- The compiler frontend produces a core IR program that is then -- processed by a pipeline in the middle-end. The fundamental -- structure of this IR is defined in "Futhark.IR.Syntax". As -- mentioned in that module, the IR supports multiple -- /representations/.
The middle-end always starts with the program -- in the SOACS representation ("Futhark.IR.SOACS"), which closely -- resembles the kind of free-form nested parallelism that the source -- language permits. Depending on the specific compilation pipeline, -- the program will eventually be transformed to various other -- representations. The output of the middle-end is always a core IR -- program in some representation (probably not the same one that it -- started with). -- -- Many passes will involve constructing core IR AST fragments. See -- "Futhark.Construct" for advice and guidance on how to do that. -- -- === Main representations -- -- * "Futhark.IR.SOACS": the initial core representation of any -- Futhark program. This is the representation on which we perform -- optimisations such as fusion, inlining, and a host of other -- cleanups. -- -- * "Futhark.IR.GPU": a representation where parallelism is expressed -- with flat /segmented operations/, and a few other GPU-specific -- operations are also supported. The pass -- "Futhark.Pass.ExtractKernels" transforms a -- 'Futhark.IR.SOACS.SOACS' program to a 'Futhark.IR.GPU.GPU' -- program. -- -- * "Futhark.IR.MC": a representation where parallelism is expressed -- with flat /segmented operations/, and a few other multicore-specific -- operations are also supported. The pass -- "Futhark.Pass.ExtractMulticore" transforms a -- 'Futhark.IR.SOACS.SOACS' program to a 'Futhark.IR.MC.MC' -- program. -- -- * "Futhark.IR.GPUMem": like 'Futhark.IR.GPU.GPU', but with memory -- information. See "Futhark.IR.Mem" for information on how we -- represent memory in the Futhark compiler. This representation is -- produced by "Futhark.Pass.ExplicitAllocations.GPU". -- -- * "Futhark.IR.MCMem": the multicore counterpart to -- 'Futhark.IR.GPUMem.GPUMem'. Produced by -- "Futhark.Pass.ExplicitAllocations.MC". -- -- == Backend -- -- The backend accepts a program in some core IR representation.
-- Currently this must always be a representation with memory -- information (e.g. 'Futhark.IR.GPUMem.GPUMem'). It then translates -- the program to the /imperative/ IR, which we call ImpCode. -- -- * "Futhark.CodeGen.ImpCode": The main definition of ImpCode, which -- is an extensible representation like the core IR. -- -- * "Futhark.CodeGen.ImpCode.GPU": an example of ImpCode -- extended/specialised to handle GPU operations (but at a high -- level, before generating actual CUDA/OpenCL kernel code). -- -- * "Futhark.CodeGen.ImpGen": a translator from core IR to ImpCode. -- Heavily parameterised so that it can handle the various dialects -- of both core IR and ImpCode that it must be expected to work -- with. In practice, when adding a new backend (or modifying an -- existing one), you'll be working on more specialised modules such -- as "Futhark.CodeGen.ImpGen.GPU". -- -- Ultimately, ImpCode must then be translated to some /real/ -- executable language, which in our case means C or Python. -- -- * "Futhark.CodeGen.Backends.GenericC": framework for translating -- ImpCode to C. Rather convoluted because it is used in a large -- number of settings. The module -- "Futhark.CodeGen.Backends.SequentialC" shows perhaps the simplest -- use of it. -- -- * "Futhark.CodeGen.Backends.GenericPython": the generic Python code -- generator (see "Futhark.CodeGen.Backends.SequentialPython" for a -- usage example). -- -- = Command line interface -- -- The @futhark@ program dispatches to a Haskell-level main function -- based on its first argument (the subcommand). Some of these -- subcommands are implemented as their own modules under the -- @Futhark.CLI@ hierarchy. Others, particularly the simplest ones, -- are bundled together in "Futhark.CLI.Misc". 
module Futhark () where
futhark-0.25.27/src/Futhark/000077500000000000000000000000001475065116200155355ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/AD/000077500000000000000000000000001475065116200160215ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/AD/Derivatives.hs000066400000000000000000000375011475065116200206500ustar00rootroot00000000000000
-- | Partial derivatives of scalar Futhark operations and built-in functions.
module Futhark.AD.Derivatives
  ( pdBuiltin,
    pdBinOp,
    pdUnOp,
  )
where

import Data.Bifunctor (bimap)
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Syntax.Core (Name, VName, nameToText)
import Futhark.Util.IntegralExp
import Prelude hiding (quot)

-- | An integer literal of the given integer type.
iConst :: IntType -> Integer -> PrimExp VName
iConst it = ValueExp . IntValue . intValue it

-- | A floating-point literal of the given float type.
fConst :: FloatType -> Double -> PrimExp VName
fConst ft = ValueExp . FloatValue . floatValue ft

-- | Forget the phantom type of both components of a pair of
-- expressions.
untyped2 :: (TPrimExp t v, TPrimExp t v) -> (PrimExp v, PrimExp v)
untyped2 = bimap untyped untyped

-- | @pdUnOp op x@ computes the partial derivatives of @op@
-- with respect to @x@.
pdUnOp :: UnOp -> PrimExp VName -> PrimExp VName
pdUnOp op x = case op of
  Abs it -> UnOpExp (SSignum it) x
  FAbs ft -> UnOpExp (FSignum ft) x
  Neg Bool -> x
  Neg Unit -> x
  Neg (IntType it) -> iConst it (-1)
  Neg (FloatType ft) -> fConst ft (-1)
  Complement it -> UnOpExp (Complement it) x
  SSignum it -> iConst it 0
  USignum it -> iConst it 0
  FSignum ft -> fConst ft 0

-- | A function computing the pair of partial derivatives of a binary
-- operator at a specific phantom type.
type OnBinOp t v = TPrimExp t v -> TPrimExp t v -> (TPrimExp t v, TPrimExp t v)

-- | Dispatch on an integer type to the matching typed derivative
-- function, then strip the phantom types from the result.
intBinOp ::
  OnBinOp Int8 v ->
  OnBinOp Int16 v ->
  OnBinOp Int32 v ->
  OnBinOp Int64 v ->
  IntType ->
  PrimExp v ->
  PrimExp v ->
  (PrimExp v, PrimExp v)
intBinOp on8 on16 on32 on64 it a b = case it of
  Int8 -> untyped2 $ on8 (isInt8 a) (isInt8 b)
  Int16 -> untyped2 $ on16 (isInt16 a) (isInt16 b)
  Int32 -> untyped2 $ on32 (isInt32 a) (isInt32 b)
  Int64 -> untyped2 $ on64 (isInt64 a) (isInt64 b)

-- | Dispatch on a float type to the matching typed derivative
-- function, then strip the phantom types from the result.
floatBinOp ::
  OnBinOp Half v ->
  OnBinOp Float v ->
  OnBinOp Double v ->
  FloatType ->
  PrimExp v ->
  PrimExp v ->
  (PrimExp v, PrimExp v)
floatBinOp on16 on32 on64 ft a b = case ft of
  Float16 -> untyped2 $ on16 (isF16 a) (isF16 b)
  Float32 -> untyped2 $ on32 (isF32 a) (isF32 b)
  Float64 -> untyped2 $ on64 (isF64 a) (isF64 b)

-- | @pdBinOp op x y@ computes the partial derivatives of @op@ with
-- respect to @x@ and @y@.
pdBinOp :: BinOp -> PrimExp VName -> PrimExp VName -> (PrimExp VName, PrimExp VName)
pdBinOp (Add it _) _ _ = (iConst it 1, iConst it 1)
pdBinOp (Sub it _) _ _ = (iConst it 1, iConst it (-1))
pdBinOp (Mul _ _) x y = (y, x)
pdBinOp (Pow it) a b =
  intBinOp derivs derivs derivs derivs it a b
  where
    derivs x y = ((x `pow` (y - 1)) * y, 0) -- FIXME (wrt y)
pdBinOp (SDiv it _) a b =
  intBinOp derivs derivs derivs derivs it a b
  where
    derivs x y = (1 `quot` y, negate (x `quot` (y * y)))
pdBinOp (SDivUp it _) a b =
  intBinOp derivs derivs derivs derivs it a b
  where
    derivs x y = (1 `quot` y, negate (x `quot` (y * y)))
pdBinOp (SQuot it _) a b =
  intBinOp derivs derivs derivs derivs it a b
  where
    derivs x y = (1 `quot` y, negate (x `quot` (y * y)))
pdBinOp (UDiv it _) a b =
  intBinOp derivs derivs derivs derivs it a b
  where
    derivs x y = (1 `quot` y, negate (x `quot` (y * y)))
pdBinOp (UDivUp it _) a b =
  intBinOp derivs derivs derivs derivs it a b
  where
    derivs x y = (1 `quot` y, negate (x `quot` (y * y)))
pdBinOp (UMod it _) _ _ = (iConst it 1, iConst it 0) -- FIXME
pdBinOp (SMod it _) _ _ = (iConst it 1, iConst it 0) -- FIXME
pdBinOp (SRem it _) _ _ = (iConst it 1, iConst it 0) -- FIXME
pdBinOp (FMod ft) _ _ = (fConst ft 1, fConst ft 0) -- FIXME
-- The unsigned max/min derivatives build their comparisons explicitly
-- with CmpUle/CmpUlt.  The lifted operators (.>=., .<., ...) used for
-- the signed variants are, per Futhark.Analysis.PrimExp, signed for
-- integers, which would pick the wrong operand for unsigned values with
-- the high bit set.  (NOTE(review): confirm against PrimExp if that
-- module changes.)
pdBinOp (UMax it) a b =
  ( ConvOpExp (BToI it) $ CmpOpExp (CmpUle it) b a, -- a >= b
    ConvOpExp (BToI it) $ CmpOpExp (CmpUlt it) a b -- a < b
  )
pdBinOp (SMax it) a b =
  intBinOp derivs derivs derivs derivs it a b
  where
    derivs x y = (fromBoolExp (x .>=. y), fromBoolExp (x .<. y))
pdBinOp (UMin it) a b =
  ( ConvOpExp (BToI it) $ CmpOpExp (CmpUle it) a b, -- a <= b
    ConvOpExp (BToI it) $ CmpOpExp (CmpUlt it) b a -- a > b
  )
pdBinOp (SMin it) a b =
  intBinOp derivs derivs derivs derivs it a b
  where
    derivs x y = (fromBoolExp (x .<=. y), fromBoolExp (x .>. y))
--
pdBinOp (Shl it) a b =
  pdBinOp (Mul it OverflowWrap) a $ BinOpExp (Pow it) (iConst it 2) b
pdBinOp (LShr it) a b =
  pdBinOp (UDiv it Unsafe) a $ BinOpExp (Pow it) (iConst it 2) b
pdBinOp (AShr it) a b =
  pdBinOp (SDiv it Unsafe) a $ BinOpExp (Pow it) (iConst it 2) b
pdBinOp (And it) _a _b = (iConst it 0, iConst it 0) -- FIXME
pdBinOp (Or it) _a _b = (iConst it 0, iConst it 0) -- FIXME
pdBinOp (Xor it) _a _b = (iConst it 0, iConst it 0) -- FIXME
--
pdBinOp (FAdd ft) _ _ = (fConst ft 1, fConst ft 1)
pdBinOp (FSub ft) _ _ = (fConst ft 1, fConst ft (-1))
pdBinOp (FMul _) x y = (y, x)
pdBinOp (FDiv ft) a b =
  floatBinOp derivs derivs derivs ft a b
  where
    derivs x y = (1 / y, negate (x / (y * y)))
pdBinOp (FPow ft) a b =
  floatBinOp derivs derivs derivs ft a b
  where
    derivs x y =
      ( y * (x ** (y - 1)),
        condExp (x .<=. 0) 0 ((x ** y) * log x)
      )
pdBinOp (FMax ft) a b =
  floatBinOp derivs derivs derivs ft a b
  where
    derivs x y = (fromBoolExp (x .>=. y), fromBoolExp (x .<. y))
pdBinOp (FMin ft) a b =
  floatBinOp derivs derivs derivs ft a b
  where
    derivs x y = (fromBoolExp (x .<=. y), fromBoolExp (x .>. y))
pdBinOp LogAnd a b = (b, a)
pdBinOp LogOr _ _ = (ValueExp $ BoolValue True, ValueExp $ BoolValue False)

-- | @pdBuiltin f args@ computes the partial derivative of @f@
-- applied to @args@ with respect to each of its arguments. Returns
-- 'Nothing' if no such derivative is known.
pdBuiltin :: Name -> [PrimExp VName] -> Maybe [PrimExp VName]
pdBuiltin "sqrt16" [x] =
  Just [untyped $ 1 / (2 * sqrt (isF16 x))]
pdBuiltin "sqrt32" [x] =
  Just [untyped $ 1 / (2 * sqrt (isF32 x))]
pdBuiltin "sqrt64" [x] =
  Just [untyped $ 1 / (2 * sqrt (isF64 x))]
pdBuiltin "cbrt16" [x] =
  Just [untyped $ 1 / (3 * cbrt16 (isF16 x) * cbrt16 (isF16 x))]
  where
    cbrt16 a = isF16 $ FunExp "cbrt16" [untyped a] $ FloatType Float16
pdBuiltin "cbrt32" [x] =
  Just [untyped $ 1 / (3 * cbrt32 (isF32 x) * cbrt32 (isF32 x))]
  where
    cbrt32 a = isF32 $ FunExp "cbrt32" [untyped a] $ FloatType Float32
pdBuiltin "cbrt64" [x] =
  Just [untyped $ 1 / (3 * cbrt64 (isF64 x) * cbrt64 (isF64 x))]
  where
    -- The result type must match the f64 argument (was Float32).
    cbrt64 a = isF64 $ FunExp "cbrt64" [untyped a] $ FloatType Float64
pdBuiltin "log16" [x] =
  Just [untyped $ 1 / isF16 x]
pdBuiltin "log32" [x] =
  Just [untyped $ 1 / isF32 x]
pdBuiltin "log64" [x] =
  Just [untyped $ 1 / isF64 x]
pdBuiltin "log10_16" [x] =
  Just [untyped $ 1 / (isF16 x * log 10)]
pdBuiltin "log10_32" [x] =
  Just [untyped $ 1 / (isF32 x * log 10)]
pdBuiltin "log10_64" [x] =
  Just [untyped $ 1 / (isF64 x * log 10)]
pdBuiltin "log2_16" [x] =
  Just [untyped $ 1 / (isF16 x * log 2)]
pdBuiltin "log2_32" [x] =
  Just [untyped $ 1 / (isF32 x * log 2)]
pdBuiltin "log2_64" [x] =
  Just [untyped $ 1 / (isF64 x * log 2)]
pdBuiltin "log1p_16" [x] =
  Just [untyped $ 1 / (isF16 x + 1)]
pdBuiltin "log1p_32" [x] =
  Just [untyped $ 1 / (isF32 x + 1)]
pdBuiltin "log1p_64" [x] =
  Just [untyped $ 1 / (isF64 x + 1)]
pdBuiltin "exp16" [x] =
  Just [untyped $ exp (isF16 x)]
pdBuiltin "exp32" [x] =
  Just [untyped $ exp (isF32 x)]
pdBuiltin "exp64" [x] =
  Just [untyped $ exp (isF64 x)]
pdBuiltin "sin16" [x] =
  Just [untyped $ cos (isF16 x)]
pdBuiltin "sin32" [x] =
  Just [untyped $ cos (isF32 x)]
pdBuiltin "sin64" [x] =
  Just [untyped $ cos (isF64 x)]
pdBuiltin "sinh16" [x] =
  Just [untyped $ cosh (isF16 x)]
pdBuiltin "sinh32" [x] =
  Just [untyped $ cosh (isF32 x)]
pdBuiltin "sinh64" [x] =
  Just [untyped $ cosh (isF64 x)]
pdBuiltin "cos16" [x] =
  Just [untyped $ -sin (isF16 x)]
pdBuiltin "cos32" [x] =
  Just [untyped $ -sin (isF32 x)]
pdBuiltin "cos64" [x] =
  Just [untyped $ -sin (isF64 x)]
pdBuiltin "cosh16" [x] =
  Just [untyped $ sinh (isF16 x)]
pdBuiltin "cosh32" [x] =
  Just [untyped $ sinh (isF32 x)]
pdBuiltin "cosh64" [x] =
  Just [untyped $ sinh (isF64 x)]
pdBuiltin "tan16" [x] =
  Just [untyped $ 1 / (cos (isF16 x) * cos (isF16 x))]
pdBuiltin "tan32" [x] =
  Just [untyped $ 1 / (cos (isF32 x) * cos (isF32 x))]
pdBuiltin "tan64" [x] =
  Just [untyped $ 1 / (cos (isF64 x) * cos (isF64 x))]
pdBuiltin "asin16" [x] =
  Just [untyped $ 1 / sqrt (1 - isF16 x * isF16 x)]
pdBuiltin "asin32" [x] =
  Just [untyped $ 1 / sqrt (1 - isF32 x * isF32 x)]
pdBuiltin "asin64" [x] =
  Just [untyped $ 1 / sqrt (1 - isF64 x * isF64 x)]
pdBuiltin "asinh16" [x] =
  Just [untyped $ 1 / sqrt (1 + isF16 x * isF16 x)]
pdBuiltin "asinh32" [x] =
  Just [untyped $ 1 / sqrt (1 + isF32 x * isF32 x)]
pdBuiltin "asinh64" [x] =
  Just [untyped $ 1 / sqrt (1 + isF64 x * isF64 x)]
pdBuiltin "acos16" [x] =
  Just [untyped $ -1 / sqrt (1 - isF16 x * isF16 x)]
pdBuiltin "acos32" [x] =
  Just [untyped $ -1 / sqrt (1 - isF32 x * isF32 x)]
pdBuiltin "acos64" [x] =
  Just [untyped $ -1 / sqrt (1 - isF64 x * isF64 x)]
pdBuiltin "acosh16" [x] =
  Just [untyped $ 1 / sqrt (isF16 x * isF16 x - 1)]
pdBuiltin "acosh32" [x] =
  Just [untyped $ 1 / sqrt (isF32 x * isF32 x - 1)]
pdBuiltin "acosh64" [x] =
  Just [untyped $ 1 / sqrt (isF64 x * isF64 x - 1)]
pdBuiltin "atan16" [x] =
  Just [untyped $ 1 / (1 + isF16 x * isF16 x)]
pdBuiltin "atan32" [x] =
  Just [untyped $ 1 / (1 + isF32 x * isF32 x)]
pdBuiltin "atan64" [x] =
  Just [untyped $ 1 / (1 + isF64 x * isF64 x)]
-- d/dx atanh(x) = 1 / (1 - x^2).  (cosh(x)^2 would only be right if
-- evaluated at atanh(x), not at x.)
pdBuiltin "atanh16" [x] =
  Just [untyped $ 1 / (1 - isF16 x * isF16 x)]
pdBuiltin "atanh32" [x] =
  Just [untyped $ 1 / (1 - isF32 x * isF32 x)]
pdBuiltin "atanh64" [x] =
  Just [untyped $ 1 / (1 - isF64 x * isF64 x)]
-- atan2(x, y) = atan(x / y), so d/dx = y / (x^2 + y^2) and
-- d/dy = -x / (x^2 + y^2).
pdBuiltin "atan2_16" [x, y] =
  Just
    [ untyped $ isF16 y / (isF16 x * isF16 x + isF16 y * isF16 y),
      untyped $ -isF16 x / (isF16 x * isF16 x + isF16 y * isF16 y)
    ]
pdBuiltin "atan2_32" [x, y] =
  Just
    [ untyped $ isF32 y / (isF32 x * isF32 x + isF32 y * isF32 y),
      untyped $ -isF32 x / (isF32 x * isF32 x + isF32 y * isF32 y)
    ]
pdBuiltin "atan2_64" [x, y] =
  Just
    [ untyped $ isF64 y / (isF64 x * isF64 x + isF64 y * isF64 y),
      untyped $ -isF64 x / (isF64 x * isF64 x + isF64 y * isF64 y)
    ]
pdBuiltin "tanh16" [x] =
  Just [untyped $ 1 - tanh (isF16 x) * tanh (isF16 x)]
pdBuiltin "tanh32" [x] =
  Just [untyped $ 1 - tanh (isF32 x) * tanh (isF32 x)]
pdBuiltin "tanh64" [x] =
  Just [untyped $ 1 - tanh (isF64 x) * tanh (isF64 x)]
pdBuiltin "fma16" [a, b, _c] = Just [b, a, fConst Float16 1]
pdBuiltin "fma32" [a, b, _c] = Just [b, a, fConst Float32 1]
pdBuiltin "fma64" [a, b, _c] = Just [b, a, fConst Float64 1]
pdBuiltin "mad16" [a, b, _c] = Just [b, a, fConst Float16 1]
pdBuiltin "mad32" [a, b, _c] = Just [b, a, fConst Float32 1]
pdBuiltin "mad64" [a, b, _c] = Just [b, a, fConst Float64 1]
pdBuiltin "from_bits16" [_] = Just [fConst Float16 1]
pdBuiltin "from_bits32" [_] = Just [fConst Float32 1]
pdBuiltin "from_bits64" [_] = Just [fConst Float64 1]
pdBuiltin "to_bits16" [_] = Just [iConst Int16 1]
pdBuiltin "to_bits32" [_] = Just [iConst Int32 1]
pdBuiltin "to_bits64" [_] = Just [iConst Int64 1]
pdBuiltin "hypot16" [x, y] =
  Just
    [ untyped $ isF16 x / isF16 (FunExp "hypot16" [x, y] $ FloatType Float16),
      untyped $ isF16 y / isF16 (FunExp "hypot16" [x, y] $ FloatType Float16)
    ]
pdBuiltin "hypot32" [x, y] =
  Just
    [ untyped $ isF32 x / isF32 (FunExp "hypot32" [x, y] $ FloatType Float32),
      untyped $ isF32 y / isF32 (FunExp "hypot32" [x, y] $ FloatType Float32)
    ]
pdBuiltin "hypot64" [x, y] =
  Just
    [ untyped $ isF64 x / isF64 (FunExp "hypot64" [x, y] $ FloatType Float64),
      untyped $ isF64 y / isF64 (FunExp "hypot64" [x, y] $ FloatType Float64)
    ]
pdBuiltin "lerp16" [v0, v1, t] =
  Just
    [ untyped $ 1 - fMax16 0 (fMin16 1 (isF16 t)),
      untyped $ fMax16 0 (fMin16 1 (isF16 t)),
      untyped $ isF16 v1 - isF16 v0
    ]
pdBuiltin "lerp32" [v0, v1, t] =
  Just
    [ untyped $ 1 - fMax32 0 (fMin32 1 (isF32 t)),
      untyped $ fMax32 0 (fMin32 1 (isF32 t)),
      untyped $ isF32 v1 - isF32 v0
    ]
pdBuiltin "lerp64" [v0, v1, t] =
  Just
    [ untyped $ 1 - fMax64 0 (fMin64 1 (isF64 t)),
      untyped $ fMax64 0 (fMin64 1 (isF64 t)),
      untyped $ isF64 v1 - isF64 v0
    ]
-- ldexp(x, y) = x * 2^y, so d/dx = 2^y and d/dy = ln 2 * 2^y * x.  The
-- exponent is converted to the float type before use.  NOTE(review):
-- this assumes the exponent argument is an i32, as in C's ldexp.
pdBuiltin "ldexp16" [x, y] =
  Just
    [ untyped $ 2 ** y',
      untyped $ log 2 * (2 ** y') * isF16 x
    ]
  where
    y' = isF16 $ ConvOpExp (SIToFP Int32 Float16) y
pdBuiltin "ldexp32" [x, y] =
  Just
    [ untyped $ 2 ** y',
      untyped $ log 2 * (2 ** y') * isF32 x
    ]
  where
    y' = isF32 $ ConvOpExp (SIToFP Int32 Float32) y
pdBuiltin "ldexp64" [x, y] =
  Just
    [ untyped $ 2 ** y',
      untyped $ log 2 * (2 ** y') * isF64 x
    ]
  where
    y' = isF64 $ ConvOpExp (SIToFP Int32 Float64) y
pdBuiltin "erf16" [z] =
  Just [untyped $ (2 / sqrt pi) * exp (negate (isF16 z * isF16 z))]
pdBuiltin "erf32" [z] =
  Just [untyped $ (2 / sqrt pi) * exp (negate (isF32 z * isF32 z))]
pdBuiltin "erf64" [z] =
  Just [untyped $ (2 / sqrt pi) * exp (negate (isF64 z * isF64 z))]
pdBuiltin "erfc16" [z] =
  Just [untyped $ negate $ (2 / sqrt pi) * exp (negate (isF16 z * isF16 z))]
pdBuiltin "erfc32" [z] =
  Just [untyped $ negate $ (2 / sqrt pi) * exp (negate (isF32 z * isF32 z))]
pdBuiltin "erfc64" [z] =
  Just [untyped $ negate $ (2 / sqrt pi) * exp (negate (isF64 z * isF64 z))]
-- copysign(x, y) = |x| * sign(y), so d/dx = sign(x) * sign(y) (using
-- the same signum subgradient as the FAbs case of pdUnOp).
pdBuiltin "copysign16" [x, y] =
  Just
    [ untyped $ isF16 (UnOpExp (FSignum Float16) x) * isF16 (UnOpExp (FSignum Float16) y),
      fConst Float16 0
    ]
pdBuiltin "copysign32" [x, y] =
  Just
    [ untyped $ isF32 (UnOpExp (FSignum Float32) x) * isF32 (UnOpExp (FSignum Float32) y),
      fConst Float32 0
    ]
pdBuiltin "copysign64" [x, y] =
  Just
    [ untyped $ isF64 (UnOpExp (FSignum Float64) x) * isF64 (UnOpExp (FSignum Float64) y),
      fConst Float64 0
    ]
pdBuiltin h [x, _y, _z]
  | Just t <- isCondFun $ nameToText h =
      Just
        [ boolToT t false,
          boolToT t $ isBool x,
          boolToT t $ bNot $ isBool x
        ]
  where
    boolToT t = case t of
      IntType it -> ConvOpExp (BToI it) . untyped
      FloatType ft -> ConvOpExp (SIToFP Int32 ft) . ConvOpExp (BToI Int32) . untyped
      Bool -> untyped
      Unit -> const $ ValueExp UnitValue
-- More problematic derivatives follow below.
pdBuiltin "umul_hi8" [x, y] = Just [y, x]
pdBuiltin "umul_hi16" [x, y] = Just [y, x]
pdBuiltin "umul_hi32" [x, y] = Just [y, x]
pdBuiltin "umul_hi64" [x, y] = Just [y, x]
pdBuiltin "umad_hi8" [a, b, _c] = Just [b, a, iConst Int8 1]
pdBuiltin "umad_hi16" [a, b, _c] = Just [b, a, iConst Int16 1]
pdBuiltin "umad_hi32" [a, b, _c] = Just [b, a, iConst Int32 1]
pdBuiltin "umad_hi64" [a, b, _c] = Just [b, a, iConst Int64 1]
pdBuiltin "smul_hi8" [x, y] = Just [y, x]
pdBuiltin "smul_hi16" [x, y] = Just [y, x]
pdBuiltin "smul_hi32" [x, y] = Just [y, x]
pdBuiltin "smul_hi64" [x, y] = Just [y, x]
pdBuiltin "smad_hi8" [a, b, _c] = Just [b, a, iConst Int8 1]
pdBuiltin "smad_hi16" [a, b, _c] = Just [b, a, iConst Int16 1]
pdBuiltin "smad_hi32" [a, b, _c] = Just [b, a, iConst Int32 1]
pdBuiltin "smad_hi64" [a, b, _c] = Just [b, a, iConst Int64 1]
pdBuiltin "isnan16" [_] = Just [untyped false]
pdBuiltin "isnan32" [_] = Just [untyped false]
pdBuiltin "isnan64" [_] = Just [untyped false]
pdBuiltin "isinf16" [_] = Just [untyped false]
pdBuiltin "isinf32" [_] = Just [untyped false]
pdBuiltin "isinf64" [_] = Just [untyped false]
pdBuiltin "round16" [_] = Just [fConst Float16 0]
pdBuiltin "round32" [_] = Just [fConst Float32 0]
pdBuiltin "round64" [_] = Just [fConst Float64 0]
pdBuiltin "ceil16" [_] = Just [fConst Float16 0]
pdBuiltin "ceil32" [_] = Just [fConst Float32 0]
pdBuiltin "ceil64" [_] = Just [fConst Float64 0]
pdBuiltin "floor16" [_] = Just [fConst Float16 0]
pdBuiltin "floor32" [_] = Just [fConst Float32 0]
pdBuiltin "floor64" [_] = Just [fConst Float64 0]
pdBuiltin "nextafter16" [_, _] = Just [fConst Float16 1, fConst Float16 0]
pdBuiltin "nextafter32" [_, _] = Just [fConst Float32 1, fConst Float32 0]
pdBuiltin "nextafter64" [_, _] = Just [fConst Float64 1, fConst Float64 0]
pdBuiltin "clz8" [_] = Just [iConst Int32 0]
pdBuiltin "clz16" [_] = Just [iConst Int32 0]
pdBuiltin "clz32" [_] = Just [iConst Int32 0]
pdBuiltin "clz64" [_] = Just [iConst Int32 0]
pdBuiltin "ctz8" [_] = Just [iConst Int32 0]
pdBuiltin "ctz16" [_] = Just [iConst Int32 0]
pdBuiltin "ctz32" [_] = Just [iConst Int32 0]
pdBuiltin "ctz64" [_] = Just [iConst Int32 0]
pdBuiltin "popc8" [_] = Just [iConst Int32 0]
pdBuiltin "popc16" [_] = Just [iConst Int32 0]
pdBuiltin "popc32" [_] = Just [iConst Int32 0]
pdBuiltin "popc64" [_] = Just [iConst Int32 0]
pdBuiltin _ _ = Nothing
futhark-0.25.27/src/Futhark/AD/Fwd.hs000066400000000000000000000367531475065116200171130ustar00rootroot00000000000000
{-# LANGUAGE TypeFamilies #-}

module Futhark.AD.Fwd (fwdJVP) where

import Control.Monad
import Control.Monad.RWS.Strict
import Control.Monad.State.Strict
import Data.Bifunctor (second)
import Data.List (transpose)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map qualified as M
import Futhark.AD.Derivatives
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.Construct
import Futhark.IR.SOACS

zeroTan :: Type -> ADM SubExp
zeroTan (Prim t) = pure $ constant $ blankPrimValue t
zeroTan t = error $ "zeroTan on non-primitive type: " ++ prettyString t

zeroExp :: Type -> Exp SOACS
zeroExp (Prim pt) =
  BasicOp $ SubExp $ Constant $ blankPrimValue pt
zeroExp (Array pt shape _) =
  BasicOp $ Replicate shape $ Constant $ blankPrimValue pt
zeroExp t = error $ "zeroExp: " ++ show t

tanType :: TypeBase s u -> ADM (TypeBase s u)
tanType (Acc acc ispace ts u) = do
  ts_tan <- mapM tanType ts
  pure $ Acc acc ispace (ts ++ ts_tan) u
tanType t = pure t

slocal' :: ADM a -> ADM a
slocal' = slocal id

slocal :: (RState -> RState) -> ADM a -> ADM a
slocal f m = do
  s <- get
  modify f
  a <- m
  modify $ \s' -> s' {stateTans = stateTans s}
  pure a

data RState = RState
  { stateTans :: M.Map VName VName,
    stateNameSource :: VNameSource
  }

newtype ADM a = ADM (BuilderT SOACS (State RState) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState RState,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )

instance MonadBuilder ADM where
  type Rep ADM = SOACS
  mkExpDecM pat e = ADM $ mkExpDecM pat e
  mkBodyM bnds res = ADM $ mkBodyM bnds res
  mkLetNamesM pat e = ADM $ mkLetNamesM pat e

  addStms = ADM . addStms
  collectStms (ADM m) = ADM $ collectStms m

instance MonadFreshNames (State RState) where
  getNameSource = gets stateNameSource
  putNameSource src = modify (\env -> env {stateNameSource = src})

runADM :: (MonadFreshNames m) => ADM a -> m a
runADM (ADM m) =
  modifyNameSource $ \vn ->
    second stateNameSource $
      runState
        (fst <$> runBuilderT m mempty)
        (RState mempty vn)

tanVName :: VName -> ADM VName
tanVName v = newVName (baseString v <> "_tan")

insertTan :: VName -> VName -> ADM ()
insertTan v v' =
  modify $ \env -> env {stateTans = M.insert v v' (stateTans env)}

class TanBuilder a where
  newTan :: a -> ADM a
  bundleNew :: a -> ADM [a]

bundleNewList :: (TanBuilder a) => [a] -> ADM [a]
bundleNewList = fmap mconcat . mapM bundleNew

instance TanBuilder (PatElem (TypeBase s u)) where
  newTan (PatElem p t)
    | isAcc t = do
        insertTan p p
        t' <- tanType t
        pure $ PatElem p t'
    | otherwise = do
        p' <- tanVName p
        insertTan p p'
        t' <- tanType t
        pure $ PatElem p' t'
  bundleNew pe@(PatElem _ t) = do
    pe' <- newTan pe
    if isAcc t
      then pure [pe']
      else pure [pe, pe']

newTanPat :: (TanBuilder (PatElem t)) => Pat t -> ADM (Pat t)
newTanPat (Pat pes) = Pat <$> mapM newTan pes

bundleNewPat :: (TanBuilder (PatElem t)) => Pat t -> ADM (Pat t)
bundleNewPat (Pat pes) = Pat <$> bundleNewList pes

instance TanBuilder (Param (TypeBase s u)) where
  newTan (Param _ p t) = do
    PatElem p' t' <- newTan $ PatElem p t
    pure $ Param mempty p' t'
  bundleNew param@(Param _ _ (Prim Unit)) =
    pure [param]
  bundleNew param@(Param _ _ t) = do
    param' <- newTan param
    if isAcc t
      then pure [param']
      else pure [param, param']

instance (Tangent a) => TanBuilder (Param (TypeBase s u), a) where
  newTan (p, x) = (,) <$> newTan p <*> tangent x
  bundleNew (p, x) = do
    b <- bundleNew p
    x_tan <- tangent x
    pure $ zip b [x, x_tan]

class Tangent a where
  tangent :: a -> ADM a
  bundleTan :: a -> ADM [a]

instance Tangent (TypeBase s u) where
  tangent = tanType
  bundleTan t
    | isAcc t = do
        t' <- tangent t
        pure [t']
    | otherwise = do
        t' <- tangent t
        pure [t, t']

bundleTangents :: (Tangent a) => [a] -> ADM [a]
bundleTangents = (mconcat <$>) . mapM bundleTan

instance Tangent VName where
  tangent v = do
    maybeTan <- gets $ M.lookup v . stateTans
    case maybeTan of
      Just v_tan -> pure v_tan
      Nothing -> do
        t <- lookupType v
        letExp (baseString v <> "_implicit_tan") $ zeroExp t
  bundleTan v = do
    t <- lookupType v
    if isAcc t
      then pure [v]
      else do
        v_tan <- tangent v
        pure [v, v_tan]

instance Tangent SubExp where
  tangent (Constant c) = zeroTan $ Prim $ primValueType c
  tangent (Var v) = Var <$> tangent v
  bundleTan c@Constant {} = do
    c_tan <- tangent c
    pure [c, c_tan]
  bundleTan (Var v) = fmap Var <$> bundleTan v

instance Tangent SubExpRes where
  tangent (SubExpRes cs se) = SubExpRes cs <$> tangent se
  bundleTan (SubExpRes cs se) = map (SubExpRes cs) <$> bundleTan se

basicFwd :: Pat Type -> StmAux () -> BasicOp -> ADM ()
basicFwd pat aux op = do
  pat_tan <- newTanPat pat
  case op of
    SubExp se -> do
      se_tan <- tangent se
      addStm $ Let pat_tan aux $ BasicOp $ SubExp se_tan
    Opaque opaqueop se -> do
      se_tan <- tangent se
      addStm $ Let pat_tan aux $ BasicOp $ Opaque opaqueop se_tan
    ArrayLit ses t -> do
      ses_tan <- mapM tangent ses
      addStm $ Let pat_tan aux $ BasicOp $ ArrayLit ses_tan t
    UnOp unop x -> do
      let t = unOpType unop
          x_pe = primExpFromSubExp t x
          dx = pdUnOp unop x_pe
      x_tan <- primExpFromSubExp t <$> tangent x
      auxing aux $ letBindNames (patNames pat_tan) <=< toExp $ x_tan ~*~ dx
    BinOp bop x y -> do
      let t = binOpType bop
      x_tan <- primExpFromSubExp t <$> tangent x
      y_tan <- primExpFromSubExp t <$> tangent y
      let (wrt_x, wrt_y) =
            pdBinOp bop (primExpFromSubExp t x) (primExpFromSubExp t y)
      auxing aux $
        letBindNames (patNames pat_tan) <=< toExp $
          x_tan ~*~ wrt_x ~+~ y_tan ~*~ wrt_y
    CmpOp {} ->
      addStm $ Let pat_tan aux $ BasicOp op
    ConvOp cop x -> do
      x_tan <- tangent x
      addStm $ Let pat_tan aux $ BasicOp $ ConvOp cop x_tan
    Assert {} -> pure ()
Index arr slice -> do arr_tan <- tangent arr addStm $ Let pat_tan aux $ BasicOp $ Index arr_tan slice Update safety arr slice se -> do arr_tan <- tangent arr se_tan <- tangent se addStm $ Let pat_tan aux $ BasicOp $ Update safety arr_tan slice se_tan Concat d (arr :| arrs) w -> do arr_tan <- tangent arr arrs_tans <- mapM tangent arrs addStm $ Let pat_tan aux $ BasicOp $ Concat d (arr_tan :| arrs_tans) w Manifest ds arr -> do arr_tan <- tangent arr addStm $ Let pat_tan aux $ BasicOp $ Manifest ds arr_tan Iota n _ _ it -> do addStm $ Let pat_tan aux $ BasicOp $ Replicate (Shape [n]) (intConst it 0) Replicate n x -> do x_tan <- tangent x addStm $ Let pat_tan aux $ BasicOp $ Replicate n x_tan Scratch t shape -> addStm $ Let pat_tan aux $ BasicOp $ Scratch t shape Reshape k reshape arr -> do arr_tan <- tangent arr addStm $ Let pat_tan aux $ BasicOp $ Reshape k reshape arr_tan Rearrange perm arr -> do arr_tan <- tangent arr addStm $ Let pat_tan aux $ BasicOp $ Rearrange perm arr_tan _ -> error $ "basicFwd: Unsupported op " ++ prettyString op fwdLambda :: Lambda SOACS -> ADM (Lambda SOACS) fwdLambda l@(Lambda params ret body) = Lambda <$> bundleNewList params <*> bundleTangents ret <*> inScopeOf l (fwdBody body) fwdStreamLambda :: Lambda SOACS -> ADM (Lambda SOACS) fwdStreamLambda l@(Lambda params ret body) = Lambda <$> ((take 1 params ++) <$> bundleNewList (drop 1 params)) <*> bundleTangents ret <*> inScopeOf l (fwdBody body) interleave :: [a] -> [a] -> [a] interleave xs ys = concat $ transpose [xs, ys] zeroFromSubExp :: SubExp -> ADM VName zeroFromSubExp (Constant c) = letExp "zero" . BasicOp . SubExp . 
Constant $ blankPrimValue (primValueType c) zeroFromSubExp (Var v) = do t <- lookupType v letExp "zero" $ zeroExp t fwdSOAC :: Pat Type -> StmAux () -> SOAC SOACS -> ADM () fwdSOAC pat aux (Screma size xs (ScremaForm f scs reds)) = do pat' <- bundleNewPat pat xs' <- bundleTangents xs f' <- fwdLambda f scs' <- mapM fwdScan scs reds' <- mapM fwdRed reds addStm $ Let pat' aux $ Op $ Screma size xs' $ ScremaForm f' scs' reds' where fwdScan :: Scan SOACS -> ADM (Scan SOACS) fwdScan sc = do op' <- fwdLambda $ scanLambda sc neutral_tans <- mapM zeroFromSubExp $ scanNeutral sc pure $ Scan { scanNeutral = scanNeutral sc `interleave` map Var neutral_tans, scanLambda = op' } fwdRed :: Reduce SOACS -> ADM (Reduce SOACS) fwdRed red = do op' <- fwdLambda $ redLambda red neutral_tans <- mapM zeroFromSubExp $ redNeutral red pure $ Reduce { redComm = redComm red, redLambda = op', redNeutral = redNeutral red `interleave` map Var neutral_tans } fwdSOAC pat aux (Stream size xs nes lam) = do pat' <- bundleNewPat pat lam' <- fwdStreamLambda lam xs' <- bundleTangents xs nes_tan <- mapM (fmap Var . zeroFromSubExp) nes let nes' = interleave nes nes_tan addStm $ Let pat' aux $ Op $ Stream size xs' nes' lam' fwdSOAC pat aux (Hist w arrs ops bucket_fun) = do pat' <- bundleNewPat pat ops' <- mapM fwdHist ops bucket_fun' <- fwdHistBucket bucket_fun arrs' <- bundleTangents arrs addStm $ Let pat' aux $ Op $ Hist w arrs' ops' bucket_fun' where n_indices = sum $ map (shapeRank . histShape) ops fwdBodyHist (Body _ stms res) = buildBody_ $ do mapM_ fwdStm stms let (res_is, res_vs) = splitAt n_indices res (res_is ++) <$> bundleTangents res_vs fwdHistBucket l@(Lambda params ret body) = let (r_is, r_vs) = splitAt n_indices ret in Lambda <$> bundleNewList params <*> ((r_is ++) <$> bundleTangents r_vs) <*> inScopeOf l (fwdBodyHist body) fwdHist :: HistOp SOACS -> ADM (HistOp SOACS) fwdHist (HistOp shape rf dest nes op) = do dest' <- bundleTangents dest nes_tan <- mapM (fmap Var . 
zeroFromSubExp) nes op' <- fwdLambda op pure $ HistOp { histShape = shape, histRaceFactor = rf, histDest = dest', histNeutral = interleave nes nes_tan, histOp = op' } fwdSOAC (Pat pes) aux (Scatter w ivs as lam) = do as_tan <- mapM (\(s, n, a) -> do a_tan <- tangent a; pure (s, n, a_tan)) as pes_tan <- mapM newTan pes ivs' <- bundleTangents ivs let (as_ws, as_ns, _as_vs) = unzip3 as n_indices = sum $ zipWith (*) as_ns $ map length as_ws lam' <- fwdScatterLambda n_indices lam let s = Let (Pat (pes ++ pes_tan)) aux $ Op $ Scatter w ivs' (as ++ as_tan) lam' addStm s where fwdScatterLambda :: Int -> Lambda SOACS -> ADM (Lambda SOACS) fwdScatterLambda n_indices (Lambda params ret body) = do params' <- bundleNewList params ret_tan <- mapM tangent $ drop n_indices ret body' <- fwdBodyScatter n_indices body let indices = concat $ replicate 2 $ take n_indices ret ret' = indices ++ drop n_indices ret ++ ret_tan pure $ Lambda params' ret' body' fwdBodyScatter :: Int -> Body SOACS -> ADM (Body SOACS) fwdBodyScatter n_indices (Body _ stms res) = do (res_tan, stms') <- collectStms $ do mapM_ fwdStm stms mapM tangent $ drop n_indices res let indices = concat $ replicate 2 $ take n_indices res res' = indices ++ drop n_indices res ++ res_tan pure $ mkBody stms' res' fwdSOAC _ _ JVP {} = error "fwdSOAC: nested JVP not allowed." fwdSOAC _ _ VJP {} = error "fwdSOAC: nested VJP not allowed." fwdStm :: Stm SOACS -> ADM () fwdStm (Let pat aux (BasicOp (UpdateAcc safety acc i x))) = do pat' <- bundleNewPat pat x' <- bundleTangents x acc_tan <- tangent acc addStm $ Let pat' aux $ BasicOp $ UpdateAcc safety acc_tan i x' fwdStm stm@(Let pat aux (BasicOp e)) = do -- XXX: this has to be too naive. unless (any isAcc $ patTypes pat) $ addStm stm basicFwd pat aux e fwdStm stm@(Let pat _ (Apply f args _ _)) | Just (ret, argts) <- M.lookup f builtInFunctions = do addStm stm arg_tans <- zipWith primExpFromSubExp argts <$> mapM (tangent . 
fst) args pat_tan <- newTanPat pat let arg_pes = zipWith primExpFromSubExp argts (map fst args) case pdBuiltin f arg_pes of Nothing -> error $ "No partial derivative defined for builtin function: " ++ prettyString f Just derivs -> do let convertTo tt e | e_t == tt = e | otherwise = case (tt, e_t) of (IntType tt', IntType ft) -> ConvOpExp (SExt ft tt') e (FloatType tt', FloatType ft) -> ConvOpExp (FPConv ft tt') e (Bool, FloatType ft) -> ConvOpExp (FToB ft) e (FloatType tt', Bool) -> ConvOpExp (BToF tt') e _ -> error $ "fwdStm.convertTo: " ++ prettyString (f, tt, e_t) where e_t = primExpType e zipWithM_ (letBindNames . pure) (patNames pat_tan) =<< mapM toExp (zipWith (~*~) (map (convertTo ret) arg_tans) derivs) fwdStm (Let pat aux (Match ses cases defbody (MatchDec ret ifsort))) = do cases' <- slocal' $ mapM (traverse fwdBody) cases defbody' <- slocal' $ fwdBody defbody pat' <- bundleNewPat pat ret' <- bundleTangents ret addStm $ Let pat' aux $ Match ses cases' defbody' $ MatchDec ret' ifsort fwdStm (Let pat aux (Loop val_pats loop@(WhileLoop v) body)) = do val_pats' <- bundleNewList val_pats pat' <- bundleNewPat pat body' <- localScope (scopeOfFParams (map fst val_pats) <> scopeOfLoopForm loop) . slocal' $ fwdBody body addStm $ Let pat' aux $ Loop val_pats' (WhileLoop v) body' fwdStm (Let pat aux (Loop val_pats loop@(ForLoop i it bound) body)) = do pat' <- bundleNewPat pat val_pats' <- bundleNewList val_pats body' <- localScope (scopeOfFParams (map fst val_pats) <> scopeOfLoopForm loop) . slocal' $ fwdBody body addStm $ Let pat' aux $ Loop val_pats' (ForLoop i it bound) body' fwdStm (Let pat aux (WithAcc inputs lam)) = do inputs' <- forM inputs $ \(shape, arrs, op) -> do arrs_tan <- mapM tangent arrs op' <- case op of Nothing -> pure Nothing Just (op_lam, nes) -> do nes_tan <- mapM (fmap Var . 
zeroFromSubExp) nes op_lam' <- fwdLambda op_lam case op_lam' of Lambda ps ret body -> do let op_lam'' = Lambda (removeIndexTans (shapeRank shape) ps) ret body pure $ Just (op_lam'', interleave nes nes_tan) pure (shape, arrs <> arrs_tan, op') pat' <- bundleNewPat pat lam' <- fwdLambda lam addStm $ Let pat' aux $ WithAcc inputs' lam' where removeIndexTans 0 ps = ps removeIndexTans i (p : _ : ps) = p : removeIndexTans (i - 1) ps removeIndexTans _ ps = ps fwdStm (Let pat aux (Op soac)) = fwdSOAC pat aux soac fwdStm stm = error $ "unhandled forward mode AD for Stm: " ++ prettyString stm ++ "\n" ++ show stm fwdBody :: Body SOACS -> ADM (Body SOACS) fwdBody (Body _ stms res) = buildBody_ $ do mapM_ fwdStm stms bundleTangents res fwdBodyTansLast :: Body SOACS -> ADM (Body SOACS) fwdBodyTansLast (Body _ stms res) = buildBody_ $ do mapM_ fwdStm stms (res <>) <$> mapM tangent res fwdJVP :: (MonadFreshNames m) => Scope SOACS -> Lambda SOACS -> m (Lambda SOACS) fwdJVP scope l@(Lambda params ret body) = runADM . localScope scope . inScopeOf l $ do params_tan <- mapM newTan params body_tan <- fwdBodyTansLast body ret_tan <- mapM tangent ret pure $ Lambda (params ++ params_tan) (ret <> ret_tan) body_tan futhark-0.25.27/src/Futhark/AD/Rev.hs000066400000000000000000000376261475065116200171270ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- Naming scheme: -- -- An adjoint-related object for "x" is named "x_adj". This means -- both actual adjoints and statements. -- -- Do not assume "x'" means anything related to derivatives. 
module Futhark.AD.Rev (revVJP) where import Control.Monad import Data.List ((\\)) import Data.List.NonEmpty (NonEmpty (..)) import Data.Map qualified as M import Futhark.AD.Derivatives import Futhark.AD.Rev.Loop import Futhark.AD.Rev.Monad import Futhark.AD.Rev.SOAC import Futhark.Analysis.PrimExp.Convert import Futhark.Builder import Futhark.IR.SOACS import Futhark.Tools import Futhark.Transform.Rename import Futhark.Transform.Substitute import Futhark.Util (takeLast) patName :: Pat Type -> ADM VName patName (Pat [pe]) = pure $ patElemName pe patName pat = error $ "Expected single-element pattern: " ++ prettyString pat -- The vast majority of BasicOps require no special treatment in the -- forward pass and produce one value (and hence one adjoint). We -- deal with that case here. commonBasicOp :: Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName) commonBasicOp pat aux op m = do addStm $ Let pat aux $ BasicOp op m pat_v <- patName pat pat_adj <- lookupAdjVal pat_v pure (pat_v, pat_adj) diffBasicOp :: Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM () diffBasicOp pat aux e m = case e of CmpOp cmp x y -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m returnSweepCode $ do let t = cmpOpType cmp update contrib = do void $ updateSubExpAdj x contrib void $ updateSubExpAdj y contrib case t of FloatType ft -> update <=< letExp "contrib" $ Match [Var pat_adj] [Case [Just $ BoolValue True] $ resultBody [constant (floatValue ft (1 :: Int))]] (resultBody [constant (floatValue ft (0 :: Int))]) (MatchDec [Prim (FloatType ft)] MatchNormal) IntType it -> update <=< letExp "contrib" $ BasicOp $ ConvOp (BToI it) (Var pat_adj) Bool -> update pat_adj Unit -> pure () -- ConvOp op x -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m returnSweepCode $ do contrib <- letExp "contrib" $ BasicOp $ ConvOp (flipConvOp op) $ Var pat_adj updateSubExpAdj x contrib -- UnOp op x -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m returnSweepCode $ do let t = unOpType op contrib 
<- do let x_pe = primExpFromSubExp t x pat_adj' = primExpFromSubExp t (Var pat_adj) dx = pdUnOp op x_pe letExp "contrib" <=< toExp $ pat_adj' ~*~ dx updateSubExpAdj x contrib -- BinOp op x y -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m returnSweepCode $ do let t = binOpType op (wrt_x, wrt_y) = pdBinOp op (primExpFromSubExp t x) (primExpFromSubExp t y) pat_adj' = primExpFromSubExp t $ Var pat_adj adj_x <- letExp "binop_x_adj" <=< toExp $ pat_adj' ~*~ wrt_x adj_y <- letExp "binop_y_adj" <=< toExp $ pat_adj' ~*~ wrt_y updateSubExpAdj x adj_x updateSubExpAdj y adj_y -- SubExp se -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m returnSweepCode $ updateSubExpAdj se pat_adj -- Assert {} -> void $ commonBasicOp pat aux e m -- ArrayVal {} -> void $ commonBasicOp pat aux e m -- ArrayLit elems _ -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m t <- lookupType pat_adj returnSweepCode $ do forM_ (zip [(0 :: Int64) ..] elems) $ \(i, se) -> do let slice = fullSlice t [DimFix (constant i)] updateSubExpAdj se <=< letExp "elem_adj" $ BasicOp $ Index pat_adj slice -- Index arr slice -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m returnSweepCode $ do void $ updateAdjSlice slice arr pat_adj FlatIndex {} -> error "FlatIndex not handled by AD yet." FlatUpdate {} -> error "FlatUpdate not handled by AD yet." -- Opaque _ se -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m returnSweepCode $ updateSubExpAdj se pat_adj -- Reshape k _ arr -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m returnSweepCode $ do arr_shape <- arrayShape <$> lookupType arr void $ updateAdj arr <=< letExp "adj_reshape" . BasicOp $ Reshape k arr_shape pat_adj -- Rearrange perm arr -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m returnSweepCode $ void $ updateAdj arr <=< letExp "adj_rearrange" . 
BasicOp $ Rearrange (rearrangeInverse perm) pat_adj -- Replicate (Shape []) (Var se) -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m returnSweepCode $ void $ updateAdj se pat_adj -- Replicate (Shape ns) x -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m returnSweepCode $ do x_t <- subExpType x lam <- addLambda x_t ne <- letSubExp "zero" $ zeroExp x_t n <- letSubExp "rep_size" =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) ns pat_adj_flat <- letExp (baseString pat_adj <> "_flat") . BasicOp $ Reshape ReshapeArbitrary (Shape $ n : arrayDims x_t) pat_adj reduce <- reduceSOAC [Reduce Commutative lam [ne]] updateSubExpAdj x =<< letExp "rep_contrib" (Op $ Screma n [pat_adj_flat] reduce) -- Concat d (arr :| arrs) _ -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m returnSweepCode $ do let sliceAdj _ [] = pure [] sliceAdj start (v : vs) = do v_t <- lookupType v let w = arraySize 0 v_t slice = DimSlice start w (intConst Int64 1) pat_adj_slice <- letExp (baseString pat_adj <> "_slice") $ BasicOp $ Index pat_adj (sliceAt v_t d [slice]) start' <- letSubExp "start" $ BasicOp $ BinOp (Add Int64 OverflowUndef) start w slices <- sliceAdj start' vs pure $ pat_adj_slice : slices slices <- sliceAdj (intConst Int64 0) $ arr : arrs zipWithM_ updateAdj (arr : arrs) slices -- Manifest _ se -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m returnSweepCode $ void $ updateAdj se pat_adj -- Scratch {} -> void $ commonBasicOp pat aux e m -- Iota n _ _ t -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m returnSweepCode $ do ne <- letSubExp "zero" $ zeroExp $ Prim $ IntType t lam <- addLambda $ Prim $ IntType t reduce <- reduceSOAC [Reduce Commutative lam [ne]] updateSubExpAdj n =<< letExp "iota_contrib" (Op $ Screma n [pat_adj] reduce) -- Update safety arr slice v -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m returnSweepCode $ do v_adj <- letExp "update_val_adj" $ BasicOp $ Index pat_adj slice t <- lookupType v_adj v_adj_copy <- case t of Array {} -> letExp 
"update_val_adj_copy" . BasicOp $ Replicate mempty (Var v_adj) _ -> pure v_adj updateSubExpAdj v v_adj_copy zeroes <- letSubExp "update_zero" . zeroExp =<< subExpType v void $ updateAdj arr =<< letExp "update_src_adj" (BasicOp $ Update safety pat_adj slice zeroes) -- See Note [Adjoints of accumulators] UpdateAcc _ _ is vs -> do addStm $ Let pat aux $ BasicOp e m pat_adjs <- mapM lookupAdjVal (patNames pat) returnSweepCode $ do forM_ (zip pat_adjs vs) $ \(adj, v) -> do adj_i <- letExp "updateacc_val_adj" $ BasicOp $ Index adj $ Slice $ map DimFix is updateSubExpAdj v adj_i vjpOps :: VjpOps vjpOps = VjpOps { vjpLambda = diffLambda, vjpStm = diffStm } diffStm :: Stm SOACS -> ADM () -> ADM () diffStm (Let pat aux (BasicOp e)) m = diffBasicOp pat aux e m diffStm stm@(Let pat _ (Apply f args _ _)) m | Just (ret, argts) <- M.lookup f builtInFunctions = do addStm stm m pat_adj <- lookupAdjVal =<< patName pat let arg_pes = zipWith primExpFromSubExp argts (map fst args) pat_adj' = primExpFromSubExp ret (Var pat_adj) convert ft tt | ft == tt = id convert (IntType ft) (IntType tt) = ConvOpExp (SExt ft tt) convert (FloatType ft) (FloatType tt) = ConvOpExp (FPConv ft tt) convert Bool (FloatType tt) = ConvOpExp (BToF tt) convert (FloatType ft) Bool = ConvOpExp (FToB ft) convert ft tt = error $ "diffStm.convert: " ++ prettyString (f, ft, tt) contribs <- case pdBuiltin f arg_pes of Nothing -> error $ "No partial derivative defined for builtin function: " ++ prettyString f Just derivs -> forM (zip derivs argts) $ \(deriv, argt) -> letExp "contrib" <=< toExp . convert ret argt $ pat_adj' ~*~ deriv zipWithM_ updateSubExpAdj (map fst args) contribs diffStm stm@(Let pat _ (Match ses cases defbody _)) m = do addStm stm m returnSweepCode $ do let cases_free = map freeIn cases defbody_free = freeIn defbody branches_free = namesToList $ mconcat $ defbody_free : cases_free adjs <- mapM lookupAdj $ patNames pat branches_free_adj <- ( pure . 
takeLast (length branches_free) <=< letTupExp "branch_adj" <=< renameExp ) =<< eMatch ses (map (fmap $ diffBody adjs branches_free) cases) (diffBody adjs branches_free defbody) zipWithM_ insAdj branches_free branches_free_adj diffStm (Let pat aux (Op soac)) m = vjpSOAC vjpOps pat aux soac m diffStm (Let pat aux loop@Loop {}) m = diffLoop diffStms pat aux loop m -- See Note [Adjoints of accumulators] diffStm stm@(Let pat _aux (WithAcc inputs lam)) m = do addStm stm m returnSweepCode $ do adjs <- mapM lookupAdj $ patNames pat lam' <- renameLambda lam free_vars <- filterM isActive $ namesToList $ freeIn lam' free_accs <- filterM (fmap isAcc . lookupType) free_vars let free_vars' = free_vars \\ free_accs lam'' <- diffLambda' adjs free_vars' lam' inputs' <- mapM renameInputLambda inputs free_adjs <- letTupExp "with_acc_contrib" $ WithAcc inputs' lam'' zipWithM_ insAdj (arrs <> free_vars') free_adjs where arrs = concatMap (\(_, as, _) -> as) inputs renameInputLambda (shape, as, Just (f, nes)) = do f' <- renameLambda f pure (shape, as, Just (f', nes)) renameInputLambda input = pure input diffLambda' res_adjs get_adjs_for (Lambda params ts body) = localScope (scopeOfLParams params) $ do Body () stms res <- diffBody res_adjs get_adjs_for body let body' = Body () stms $ take (length inputs) res <> takeLast (length get_adjs_for) res ts' <- mapM lookupType get_adjs_for pure $ Lambda params (take (length inputs) ts <> ts') body' diffStm stm _ = error $ "diffStm unhandled:\n" ++ prettyString stm diffStms :: Stms SOACS -> ADM () diffStms all_stms | Just (stm, stms) <- stmsHead all_stms = do (subst, copy_stms) <- copyConsumedArrsInStm stm let (stm', stms') = substituteNames subst (stm, stms) diffStms copy_stms >> diffStm stm' (diffStms stms') forM_ (M.toList subst) $ \(from, to) -> setAdj from =<< lookupAdj to | otherwise = pure () -- | Preprocess statements before differentiating. -- For now, it's just stripmining. 
preprocess :: Stms SOACS -> ADM (Stms SOACS) preprocess = stripmineStms diffBody :: [Adj] -> [VName] -> Body SOACS -> ADM (Body SOACS) diffBody res_adjs get_adjs_for (Body () stms res) = subAD $ subSubsts $ do let onResult (SubExpRes _ (Constant _)) _ = pure () onResult (SubExpRes _ (Var v)) v_adj = void $ updateAdj v =<< adjVal v_adj (adjs, stms') <- collectStms $ do zipWithM_ onResult (takeLast (length res_adjs) res) res_adjs diffStms =<< preprocess stms mapM lookupAdjVal get_adjs_for pure $ Body () stms' $ res <> varsRes adjs diffLambda :: [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS) diffLambda res_adjs get_adjs_for (Lambda params _ body) = localScope (scopeOfLParams params) $ do Body () stms res <- diffBody res_adjs get_adjs_for body let body' = Body () stms $ takeLast (length get_adjs_for) res ts' <- mapM lookupType get_adjs_for pure $ Lambda params ts' body' revVJP :: (MonadFreshNames m) => Scope SOACS -> Lambda SOACS -> m (Lambda SOACS) revVJP scope (Lambda params ts body) = runADM . localScope (scope <> scopeOfLParams params) $ do params_adj <- forM (zip (map resSubExp (bodyResult body)) ts) $ \(se, t) -> Param mempty <$> maybe (newVName "const_adj") adjVName (subExpVar se) <*> pure t body' <- localScope (scopeOfLParams params_adj) $ diffBody (map adjFromParam params_adj) (map paramName params) body pure $ Lambda (params ++ params_adj) (ts <> map paramType params) body' -- Note [Adjoints of accumulators] -- -- The general case of taking adjoints of WithAcc is tricky. We make -- some assumptions and lay down a basic design. -- -- First, we assume that any WithAccs that occur in the program are -- the result of previous invocations of VJP. This means we can rely -- on the operator having a constant adjoint (it's some kind of -- addition). -- -- Second, the adjoint of an accumulator is an array of the same type -- as the underlying array. For example, the adjoint type of the -- primal type 'acc(c, [n], {f64})' is '[n]f64'. 
In principle the -- adjoint of 'acc(c, [n], {f64,f32})' should be two arrays of type -- '[]f64', '[]f32'. Our current design assumes that adjoints are -- single variables. This is fixable. -- -- # Adjoint of UpdateAcc -- -- Consider primal code -- -- update_acc(acc, i, v) -- -- Interpreted as an imperative statement, this means -- -- acc[i] ⊕= v -- -- for some '⊕'. Normally all the compiler knows of '⊕' is that it -- is associative and commutative, but because we assume that all -- accumulators are the result of previous AD transformations, we -- can assume that '⊕' actually behaves like addition - that is, has -- unit partial derivatives. So the return sweep is -- -- v += acc_adj[i] -- -- # Adjoint of Map -- -- Suppose we have primal code -- -- let acc' = -- map (...) acc -- -- where "acc : acc(c, [n], {f64})" and the width of the Map is "w". -- Our normal transformation for Map input arrays is to similarly map -- their adjoint, but clearly this doesn't work here because the -- semantics of mapping an adjoint is an "implicit replicate". So -- when generating the return sweep we actually perform that -- replication: -- -- map (...) (replicate w acc_adj) -- -- But what about the contributions to "acc'"? Those we also have to -- take special care of. The result of the map itself is actually a -- multidimensional array: -- -- let acc_contribs = -- map (...) (replicate w acc'_adj) -- -- which we must then sum to add to the contribution. -- -- acc_adj += sum(acc_contribs) -- -- I'm slightly worried about the asymptotics of this, since my -- intuition of this is that the contributions might be rather sparse. -- (Maybe completely zero? If so it will be simplified away -- entirely.) Perhaps a better solution is to treat -- accumulator-inputs in the primal code as we do free variables, and -- create accumulators for them in the return sweep. 
-- -- # Consumption -- -- A minor problem is that our usual way of handling consumption (Note -- [Consumption]) is not viable, because accumulators are not -- copyable. Fortunately, while the accumulators that are consumed in -- the forward sweep will also be present in the return sweep given -- our current translation rules, they will be dead code. As long as -- we are careful to run dead code elimination after revVJP, we should -- be good. futhark-0.25.27/src/Futhark/AD/Rev/000077500000000000000000000000001475065116200165555ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/AD/Rev/Hist.hs000066400000000000000000001037151475065116200200270ustar00rootroot00000000000000{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeFamilies #-} module Futhark.AD.Rev.Hist ( diffMinMaxHist, diffMulHist, diffAddHist, diffVecHist, diffHist, ) where import Control.Monad import Futhark.AD.Rev.Monad import Futhark.Analysis.PrimExp.Convert import Futhark.Builder import Futhark.IR.SOACS import Futhark.Tools import Futhark.Transform.Rename getBinOpPlus :: PrimType -> BinOp getBinOpPlus (IntType x) = Add x OverflowUndef getBinOpPlus (FloatType f) = FAdd f getBinOpPlus _ = error "In getBinOpMul, Hist.hs: input not supported" getBinOpDiv :: PrimType -> BinOp getBinOpDiv (IntType t) = SDiv t Unsafe getBinOpDiv (FloatType t) = FDiv t getBinOpDiv _ = error "In getBinOpDiv, Hist.hs: input not supported" withinBounds :: [(SubExp, VName)] -> TPrimExp Bool VName withinBounds [] = TPrimExp $ ValueExp (BoolValue True) withinBounds [(q, i)] = (le64 i .<. pe64 q) .&&. (pe64 (intConst Int64 (-1)) .<. le64 i) withinBounds (qi : qis) = withinBounds [qi] .&&. 
withinBounds qis elseIf :: (MonadBuilder m, BranchType (Rep m) ~ ExtType) => PrimType -> [(m (Exp (Rep m)), m (Exp (Rep m)))] -> [m (Body (Rep m))] -> m (Exp (Rep m)) elseIf t [(c1, c2)] [bt, bf] = eIf (eCmpOp (CmpEq t) c1 c2) bt bf elseIf t ((c1, c2) : cs) (bt : bs) = eIf (eCmpOp (CmpEq t) c1 c2) bt $ eBody $ pure $ elseIf t cs bs elseIf _ _ _ = error "In elseIf, Hist.hs: input not supported" bindSubExpRes :: (MonadBuilder m) => String -> [SubExpRes] -> m [VName] bindSubExpRes s = traverse ( \(SubExpRes cs se) -> do bn <- newVName s certifying cs $ letBindNames [bn] $ BasicOp $ SubExp se pure bn ) nestedmap :: [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS) nestedmap [] _ lam = pure lam nestedmap s@(h : r) pt lam = do params <- traverse (\tp -> newParam "x" $ Array tp (Shape s) NoUniqueness) pt body <- nestedmap r pt lam mkLambda params $ fmap varsRes . letTupExp "res" . Op $ Screma h (map paramName params) (mapSOAC body) -- \ds hs -> map2 lam ds hs mkF' :: Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], [VName], Lambda SOACS) mkF' lam tps n = do lam' <- renameLambda lam ds_params <- traverse (newParam "ds_param") tps hs_params <- traverse (newParam "hs_param") tps let ds_pars = fmap paramName ds_params let hs_pars = fmap paramName hs_params lam_map <- mkLambda (ds_params <> hs_params) $ fmap varsRes . letTupExp "map_f'" . 
Op $ Screma n (ds_pars <> hs_pars) (mapSOAC lam') pure (ds_pars, hs_pars, lam_map)

-- \ls as rs -> map3 (\li ai ri -> li `lam` ai `lam` ri) ls as rs
--
-- Builds the above mapped lambda over arrays of length @n@ with element
-- types @tps@. The inner lambda applies @lam@ twice: first to @(li, ai)@,
-- then to that result and @ri@. Returns the names of the @as@ parameters
-- (the ones to be differentiated with respect to) together with the lambda.
mkF :: Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], Lambda SOACS)
mkF lam tps n = do
  lam_l <- renameLambda lam
  lam_r <- renameLambda lam
  let q = length $ lambdaReturnType lam
      (lps, aps) = splitAt q $ lambdaParams lam_l
      (ips, rps) = splitAt q $ lambdaParams lam_r
  lam' <- mkLambda (lps <> aps <> rps) $ do
    lam_l_res <- bodyBind $ lambdaBody lam_l
    -- Feed the result of @li `lam` ai@ into the left operands of lam_r.
    forM_ (zip ips lam_l_res) $ \(ip, SubExpRes cs se) ->
      certifying cs $ letBindNames [paramName ip] $ BasicOp $ SubExp se
    bodyBind $ lambdaBody lam_r

  ls_params <- traverse (newParam "ls_param") tps
  as_params <- traverse (newParam "as_param") tps
  rs_params <- traverse (newParam "rs_param") tps
  let map_params = ls_params <> as_params <> rs_params
  lam_map <-
    mkLambda map_params $
      fmap varsRes . letTupExp "map_f" $
        Op $
          Screma n (map paramName map_params) $
            mapSOAC lam'

  pure (map paramName as_params, lam_map)

-- | Maps every index of @is@ that is out of bounds with respect to @w@
-- to @w@ itself; in-bounds indices are kept unchanged.
mapout :: VName -> SubExp -> SubExp -> ADM VName
mapout is n w = do
  par_is <- newParam "is" $ Prim int64
  is'_lam <-
    mkLambda [par_is] $
      fmap varsRes . letTupExp "is'"
        =<< eIf
          (toExp $ withinBounds $ pure (w, paramName par_is))
          (eBody $ pure $ eParam par_is)
          (eBody $ pure $ eSubExp w)

  letExp "is'" $ Op $ Screma n (pure is) $ mapSOAC is'_lam

-- | Scatters every array of @vs@ into the corresponding array of @dst@,
-- all at the same indices @is@, using a single 'Scatter' of width @n@.
multiScatter :: SubExp -> [VName] -> VName -> [VName] -> ADM [VName]
multiScatter n dst is vs = do
  tps <- traverse lookupType vs
  par_i <- newParam "i" $ Prim int64
  scatter_params <- traverse (newParam "scatter_param" . rowType) tps
  scatter_lam <-
    mkLambda (par_i : scatter_params) $
      fmap subExpsRes . mapM (letSubExp "scatter_map_res") =<< do
        -- One copy of the index per destination, followed by the values.
        p1 <- replicateM (length scatter_params) $ eParam par_i
        p2 <- traverse eParam scatter_params
        pure $ p1 <> p2

  let spec = zipWith (\t -> (,,) (Shape $ pure $ arraySize 0 t) 1) tps dst
  letTupExp "scatter_res" . Op $ Scatter n (is : vs) spec scatter_lam

-- | Applies the same slice @s@ (completed with 'fullSlice') to every
-- array in @vs@ and binds each result.
multiIndex :: [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex vs s = do
  traverse
    ( \x -> do
        t <- lookupType x
        letExp "sorted" $ BasicOp $ Index x (fullSlice t s)
    )
    vs

--
-- special case of histogram with min/max as operator.
-- Original, assuming `is: [n]i64` and `dst: [w]btp`
--     let x = reduce_by_index dst minmax ne is vs
-- Forward sweep:
--     need to copy dst: reverse sweep might use it
--       (see ex. in reducebyindexminmax6.fut where the first map requires the original dst to be differentiated).
--     let dst_cpy = copy dst
--     let (x, x_inds) = zip vs (iota n)
--       |> reduce_by_index (dst_cpy,-1s) argminmax (ne,-1) is
--
-- Reverse sweep:
--     dst_bar += map2 (\i b -> if i == -1
--                              then b
--                              else 0
--                     ) x_inds x_bar
--     vs_ctrbs = map2 (\i b -> if i == -1
--                              then 0
--                              else vs_bar[i] + b
--                     ) x_inds x_bar
--     vs_bar <- scatter vs_bar x_inds vs_ctrbs
diffMinMaxHist ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffMinMaxHist _ops x aux n minmax ne is vs w rf dst m = do
  let t = binOpType minmax

  vs_type <- lookupType vs
  let vs_elm_type = elemType vs_type
  let vs_dims = arrayDims vs_type
  let inner_dims = tail vs_dims
  let nr_dims = length vs_dims

  dst_type <- lookupType dst
  let dst_dims = arrayDims dst_type

  dst_cpy <-
    letExp (baseString dst <> "_copy") . BasicOp $
      Replicate mempty (Var dst)

  -- argmin/argmax operator: on ties keep the smallest index.
  acc_v_p <- newParam "acc_v" $ Prim t
  acc_i_p <- newParam "acc_i" $ Prim int64
  v_p <- newParam "v" $ Prim t
  i_p <- newParam "i" $ Prim int64
  hist_lam_inner <-
    mkLambda [acc_v_p, acc_i_p, v_p, i_p] $
      fmap varsRes . letTupExp "idx_res"
        =<< eIf
          (eCmpOp (CmpEq t) (eParam acc_v_p) (eParam v_p))
          ( eBody
              [ eParam acc_v_p,
                eBinOp (SMin Int64) (eParam acc_i_p) (eParam i_p)
              ]
          )
          ( eBody
              [ eIf
                  ( eCmpOp
                      (CmpEq t)
                      (eParam acc_v_p)
                      (eBinOp minmax (eParam acc_v_p) (eParam v_p))
                  )
                  (eBody [eParam acc_v_p, eParam acc_i_p])
                  (eBody [eParam v_p, eParam i_p])
              ]
          )
  hist_lam <- nestedmap inner_dims [vs_elm_type, int64, vs_elm_type, int64] hist_lam_inner

  -- Index -1 marks bins where the original dst value won.
  dst_minus_ones <-
    letExp "minus_ones" . BasicOp $
      Replicate (Shape dst_dims) (intConst Int64 (-1))
  ne_minus_ones <-
    letSubExp "minus_ones" . BasicOp $
      Replicate (Shape inner_dims) (intConst Int64 (-1))
  iota_n <-
    letExp "red_iota" . BasicOp $
      Iota n (intConst Int64 0) (intConst Int64 1) Int64

  inp_iota <- do
    if nr_dims == 1
      then pure iota_n
      else do
        i <- newParam "i" $ Prim int64
        lam <-
          mkLambda [i] $
            fmap varsRes . letTupExp "res" =<< do
              pure $ BasicOp $ Replicate (Shape inner_dims) $ Var $ paramName i

        letExp "res" $ Op $ Screma n [iota_n] $ mapSOAC lam

  let hist_op =
        HistOp
          (Shape [w])
          rf
          [dst_cpy, dst_minus_ones]
          [ne, if nr_dims == 1 then intConst Int64 (-1) else ne_minus_ones]
          hist_lam
  f' <- mkIdentityLambda [Prim int64, rowType vs_type, rowType $ Array int64 (Shape vs_dims) NoUniqueness]
  x_inds <- newVName (baseString x <> "_inds")
  auxing aux $
    letBindNames [x, x_inds] $
      Op $
        Hist n [is, vs, inp_iota] [hist_op] f'

  m

  x_bar <- lookupAdjVal x

  x_ind_dst <- newParam (baseString x <> "_ind_param") $ Prim int64
  x_bar_dst <- newParam (baseString x <> "_bar_param") $ Prim t
  dst_lam_inner <-
    mkLambda [x_ind_dst, x_bar_dst] $
      fmap varsRes . letTupExp "dst_bar"
        =<< eIf
          (toExp $ le64 (paramName x_ind_dst) .==. -1)
          (eBody $ pure $ eParam x_bar_dst)
          (eBody $ pure $ eSubExp $ Constant $ blankPrimValue t)
  dst_lam <- nestedmap inner_dims [int64, vs_elm_type] dst_lam_inner

  dst_bar <-
    letExp (baseString dst <> "_bar") . Op $
      Screma w [x_inds, x_bar] (mapSOAC dst_lam)

  updateAdj dst dst_bar

  vs_bar <- lookupAdjVal vs

  inds' <-
    traverse (letExp "inds" . BasicOp . Replicate (Shape [w]) . Var)
      =<< mk_indices inner_dims []
  let inds = x_inds : inds'

  par_x_ind_vs <- replicateM nr_dims $ newParam (baseString x <> "_ind_param") $ Prim int64
  par_x_bar_vs <- newParam (baseString x <> "_bar_param") $ Prim t
  vs_lam_inner <-
    mkLambda (par_x_bar_vs : par_x_ind_vs) $
      fmap varsRes . letTupExp "res"
        =<< eIf
          (toExp $ le64 (paramName $ head par_x_ind_vs) .==. -1)
          (eBody $ pure $ eSubExp $ Constant $ blankPrimValue t)
          ( eBody $ pure $ do
              vs_bar_i <-
                letSubExp (baseString vs_bar <> "_el") . BasicOp $
                  Index vs_bar . Slice $
                    fmap (DimFix . Var . paramName) par_x_ind_vs
              eBinOp (getBinOpPlus t) (eParam par_x_bar_vs) (eSubExp vs_bar_i)
          )
  vs_lam <- nestedmap inner_dims (vs_elm_type : replicate nr_dims int64) vs_lam_inner

  vs_bar_p <-
    letExp (baseString vs <> "_partial") . Op $
      Screma w (x_bar : inds) (mapSOAC vs_lam)

  q <-
    letSubExp "q"
      =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) dst_dims

  -- Flatten indices and contributions so a single scatter can write them;
  -- entries whose first index is -1 fall out of bounds and are dropped.
  scatter_inps <- do
    -- traverse (letExp "flat" . BasicOp . Reshape [DimNew q]) $ inds ++ [vs_bar_p]
    -- ToDo: Cosmin asks: is the below the correct translation of the line above?
    traverse (letExp "flat" . BasicOp . Reshape ReshapeArbitrary (Shape [q])) $
      inds ++ [vs_bar_p]

  f'' <- mkIdentityLambda $ replicate nr_dims (Prim int64) ++ [Prim t]
  vs_bar' <-
    letExp (baseString vs <> "_bar") . Op $
      Scatter q scatter_inps [(Shape vs_dims, 1, vs_bar)] f''
  insAdj vs vs_bar'
  where
    -- Index arrays enumerating every position of an array with the
    -- given dimensions (one array per dimension).
    mk_indices :: [SubExp] -> [SubExp] -> ADM [VName]
    mk_indices [] _ = pure []
    mk_indices [d] iotas = do
      reps <- traverse (letExp "rep" . BasicOp . Replicate (Shape [d])) iotas
      iota_d <-
        letExp "red_iota" . BasicOp $
          Iota d (intConst Int64 0) (intConst Int64 1) Int64
      pure $ reps ++ [iota_d]
    mk_indices (d : dims) iotas = do
      iota_d <-
        letExp "red_iota" . BasicOp $
          Iota d (intConst Int64 0) (intConst Int64 1) Int64
      i_param <- newParam "i" $ Prim int64
      lam <-
        mkLambda [i_param] $
          fmap varsRes $
            mk_indices dims $
              iotas ++ [Var $ paramName i_param]
      letTupExp "res" $ Op $ Screma d [iota_d] $ mapSOAC lam

--
-- special case of histogram with multiplication as operator.
-- Original, assuming `is: [n]i64` and `dst: [w]btp`
--     let x = reduce_by_index dst (*) ne is vs
-- Forward sweep:
--     dst does not need to be copied: dst is not overwritten
--     let (ps, zs) = map (\v -> if v == 0 then (1,1) else (v,0)) vs
--     let non_zero_prod = reduce_by_index nes (*) ne is ps
--     let zero_count = reduce_by_index 0s (+) 0 is zs
--     let h_part = map2 (\p c -> if c == 0 then p else 0
--                       ) non_zero_prod zero_count
--     let x = map2 (*) dst h_part
--
-- Reverse sweep:
--     dst_bar += map2 (*) h_part x_bar

--     let part_bar = map2 (*) dst x_bar
--     vs_bar += map2 (\i v -> let zr_cts = zero_count[i]
--                             let pr_bar = part_bar[i]
--                             let nz_prd = non_zero_prod[i]
--                             in if zr_cts == 0
--                             then pr_bar * (nz_prd / v)
--                             else if zr_cts == 1 and v == 0
--                             then nz_prd * pr_bar
--                             else 0
--                    ) is vs
diffMulHist ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffMulHist _ops x aux n mul ne is vs w rf dst m = do
  let t = binOpType mul

  vs_type <- lookupType vs
  let vs_dims = arrayDims vs_type
  let vs_elm_type = elemType vs_type

  dst_type <- lookupType dst
  let dst_dims = arrayDims dst_type

  let inner_dims = tail vs_dims

  v_param <- newParam "v" $ Prim t
  lam_ps_zs_inner <-
    mkLambda [v_param] $
      fmap varsRes . letTupExp "map_res"
        =<< eIf
          (eCmpOp (CmpEq t) (eParam v_param) (eSubExp $ Constant $ blankPrimValue t))
          (eBody $ fmap eSubExp [Constant $ onePrimValue t, intConst Int64 1])
          (eBody [eParam v_param, eSubExp $ intConst Int64 0])
  lam_ps_zs <- nestedmap vs_dims [vs_elm_type] lam_ps_zs_inner
  ps_zs_res <- eLambda lam_ps_zs [eSubExp $ Var vs]
  ps_zs <- bindSubExpRes "ps_zs" ps_zs_res
  let [ps, zs] = ps_zs

  lam_mul_inner <- binOpLambda mul t
  lam_mul <- nestedmap inner_dims [vs_elm_type, vs_elm_type] lam_mul_inner
  nz_prods0 <- letExp "nz_prd" $ BasicOp $ Replicate (Shape [w]) ne
  let hist_nzp = HistOp (Shape [w]) rf [nz_prods0] [ne] lam_mul

  lam_add_inner <- binOpLambda (Add Int64 OverflowUndef) int64
  lam_add <- nestedmap inner_dims [int64, int64] lam_add_inner
  zr_counts0 <- letExp "zr_cts" $ BasicOp $ Replicate (Shape dst_dims) (intConst Int64 0)
  zrn_ne <- letSubExp "zr_ne" $ BasicOp $ Replicate (Shape inner_dims) (intConst Int64 0)
  let hist_zrn =
        HistOp
          (Shape [w])
          rf
          [zr_counts0]
          [if length vs_dims == 1 then intConst Int64 0 else zrn_ne]
          lam_add

  f' <- mkIdentityLambda [Prim int64, Prim int64, rowType vs_type, rowType $ Array int64 (Shape vs_dims) NoUniqueness]
  nz_prods <- newVName "non_zero_prod"
  zr_counts <- newVName "zero_count"
  auxing aux $
    letBindNames [nz_prods, zr_counts] $
      Op $
        Hist n [is, is, ps, zs] [hist_nzp, hist_zrn] f'

  p_param <- newParam "prod" $ Prim t
  c_param <- newParam "count" $ Prim int64
  lam_h_part_inner <-
    mkLambda [p_param, c_param] $
      fmap varsRes . letTupExp "h_part"
        =<< eIf
          (toExp $ 0 .==. le64 (paramName c_param))
          (eBody $ pure $ eParam p_param)
          (eBody $ pure $ eSubExp $ Constant $ blankPrimValue t)
  lam_h_part <- nestedmap dst_dims [vs_elm_type, int64] lam_h_part_inner
  h_part_res <- eLambda lam_h_part $ map (eSubExp . Var) [nz_prods, zr_counts]
  h_part' <- bindSubExpRes "h_part" h_part_res
  let [h_part] = h_part'

  lam_mul_inner' <- binOpLambda mul t
  lam_mul' <- nestedmap dst_dims [vs_elm_type, vs_elm_type] lam_mul_inner'
  x_res <- eLambda lam_mul' $ map (eSubExp . Var) [dst, h_part]
  x' <- bindSubExpRes "x" x_res
  auxing aux $ letBindNames [x] $ BasicOp $ SubExp $ Var $ head x'

  m

  x_bar <- lookupAdjVal x

  lam_mul'' <- renameLambda lam_mul'
  dst_bar_res <- eLambda lam_mul'' $ map (eSubExp . Var) [h_part, x_bar]
  dst_bar <- bindSubExpRes (baseString dst <> "_bar") dst_bar_res
  updateAdj dst $ head dst_bar

  lam_mul''' <- renameLambda lam_mul'
  part_bar_res <- eLambda lam_mul''' $ map (eSubExp . Var) [dst, x_bar]
  part_bar' <- bindSubExpRes "part_bar" part_bar_res
  let [part_bar] = part_bar'

  inner_params <- zipWithM newParam ["zr_cts", "pr_bar", "nz_prd", "a"] $ map Prim [int64, t, t, t]
  let [zr_cts, pr_bar, nz_prd, a_param] = inner_params
  lam_vsbar_inner <-
    mkLambda inner_params $
      fmap varsRes . letTupExp "vs_bar" =<< do
        eIf
          (eCmpOp (CmpEq int64) (eSubExp $ intConst Int64 0) (eParam zr_cts))
          (eBody $ pure $ eBinOp mul (eParam pr_bar) $ eBinOp (getBinOpDiv t) (eParam nz_prd) $ eParam a_param)
          ( eBody $
              pure $
                eIf
                  ( eBinOp
                      LogAnd
                      (eCmpOp (CmpEq int64) (eSubExp $ intConst Int64 1) (eParam zr_cts))
                      (eCmpOp (CmpEq t) (eSubExp $ Constant $ blankPrimValue t) $ eParam a_param)
                  )
                  (eBody $ pure $ eBinOp mul (eParam nz_prd) (eParam pr_bar))
                  (eBody $ pure $ eSubExp $ Constant $ blankPrimValue t)
          )
  lam_vsbar_middle <- nestedmap inner_dims [int64, t, t, t] lam_vsbar_inner

  i_param <- newParam "i" $ Prim int64
  a_param' <- newParam "a" $ rowType vs_type
  lam_vsbar <-
    mkLambda [i_param, a_param'] $
      fmap varsRes . letTupExp "vs_bar"
        =<< eIf
          (toExp $ withinBounds $ pure (w, paramName i_param))
          ( buildBody_ $ do
              let i = fullSlice vs_type [DimFix $ Var $ paramName i_param]
              names <- traverse newVName ["zr_cts", "pr_bar", "nz_prd"]
              zipWithM_ (\name -> letBindNames [name] . BasicOp . flip Index i) names [zr_counts, part_bar, nz_prods]
              eLambda lam_vsbar_middle $ map (eSubExp . Var) names <> [eParam a_param']
          )
          (eBody $ pure $ pure $ zeroExp $ rowType dst_type)

  vs_bar <- letExp (baseString vs <> "_bar") $ Op $ Screma n [is, vs] $ mapSOAC lam_vsbar

  updateAdj vs vs_bar

--
-- special case of histogram with add as operator.
-- Original, assuming `is: [n]i64` and `dst: [w]btp`
--     let x = reduce_by_index dst (+) ne is vs
-- Forward sweep:
--     copy dst, because the histogram consumes its destination and the
--     original dst may still be needed elsewhere
--     let dst_cpy = copy dst
--     let x = reduce_by_index dst_cpy (+) ne is vs
--
-- Reverse sweep:
--     dst_bar += x_bar
--
--     vs_bar += map (\i -> if 0 <= i && i < w then x_bar[i] else 0) is
--
-- Contributions whose index is out of bounds were discarded in the forward
-- sweep, so their adjoint is zero (not the neutral element @ne@).
diffAddHist ::
  VjpOps -> VName -> StmAux () -> SubExp -> Lambda SOACS -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffAddHist _ops x aux n add ne is vs w rf dst m = do
  let t = paramDec $ head $ lambdaParams add

  dst_cpy <-
    letExp (baseString dst <> "_copy") . BasicOp $
      Replicate mempty (Var dst)

  f <- mkIdentityLambda [Prim int64, t]
  auxing aux . letBindNames [x] . Op $
    Hist n [is, vs] [HistOp (Shape [w]) rf [dst_cpy] [ne] add] f

  m

  x_bar <- lookupAdjVal x

  updateAdj dst x_bar

  x_type <- lookupType x
  i_param <- newParam (baseString vs <> "_i") $ Prim int64
  let i = paramName i_param
  lam_vsbar <-
    mkLambda [i_param] $
      fmap varsRes . letTupExp "vs_bar"
        =<< eIf
          (toExp $ withinBounds $ pure (w, i))
          (eBody $ pure $ pure $ BasicOp $ Index x_bar $ fullSlice x_type [DimFix $ Var i])
          -- Dropped (out-of-bounds) writes contribute a zero adjoint.
          (eBody $ pure $ pure $ zeroExp t)

  vs_bar <- letExp (baseString vs <> "_bar") $ Op $ Screma n [is] $ mapSOAC lam_vsbar
  updateAdj vs vs_bar

-- Special case for vectorised combining operator. Rewrite
--   reduce_by_index dst (map2 op) nes is vss
-- to
--   map3 (\dst_col vss_col ne ->
--           reduce_by_index dst_col op ne is vss_col
--        ) (transpose dst) (transpose vss) nes |> transpose
-- before differentiating.
-- The transposition @dims@ swaps the two outermost dimensions, so each
-- column of the vectorised operator becomes a row that a scalar-operator
-- histogram can process. The generated statements are collected and then
-- differentiated with 'vjpStm', so the special cases above can apply to
-- the inner histograms. (@op@ is presumably the scalar operator taken from
-- the @map2@ lambda; confirm at the call site.)
diffVecHist ::
  VjpOps -> VName -> StmAux () -> SubExp -> Lambda SOACS -> VName -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffVecHist ops x aux n op nes is vss w rf dst m = do
  stms <- collectStms_ $ do
    rank <- arrayRank <$> lookupType vss
    let dims = [1, 0] ++ drop 2 [0 .. rank - 1]

    dstT <- letExp "dstT" $ BasicOp $ Rearrange dims dst
    vssT <- letExp "vssT" $ BasicOp $ Rearrange dims vss
    t_dstT <- lookupType dstT
    t_vssT <- lookupType vssT
    t_nes <- lookupType nes

    dst_col <- newParam "dst_col" $ rowType t_dstT
    vss_col <- newParam "vss_col" $ rowType t_vssT
    ne <- newParam "ne" $ rowType t_nes

    f <- mkIdentityLambda (Prim int64 : lambdaReturnType op)
    map_lam <-
      mkLambda [dst_col, vss_col, ne] $ do
        -- TODO Have to copy dst_col, but isn't it already unique?
        dst_col_cpy <-
          letExp "dst_col_cpy" . BasicOp $
            Replicate mempty (Var $ paramName dst_col)
        fmap (varsRes . pure) . letExp "col_res" $
          Op $
            Hist n [is, paramName vss_col] [HistOp (Shape [w]) rf [dst_col_cpy] [Var $ paramName ne] op] f
    histT <-
      letExp "histT" $
        Op $
          Screma (arraySize 0 t_dstT) [dstT, vssT, nes] $
            mapSOAC map_lam
    -- Transpose back; dims is its own inverse (a swap of the first two axes).
    auxing aux . letBindNames [x] . BasicOp $ Rearrange dims histT
  foldr (vjpStm ops) m stms

--
-- a step in the radix sort implementation
-- it assumes the key we are sorting
-- after is [n]i64 and it is the first VName
--
-- local def radix_sort_step [n] 't (xs: [n]t) (get_bit: i32 -> t -> i32)
--                                  (digit_n: i32): [n]t =
--   let num x = get_bit (digit_n+1) x * 2 + get_bit digit_n x
--   let pairwise op (a1,b1,c1,d1) (a2,b2,c2,d2) =
--     (a1 `op` a2, b1 `op` b2, c1 `op` c2, d1 `op` d2)
--   let bins = xs |> map num
--   let flags = bins |> map (\x -> if x == 0 then (1,0,0,0)
--                                  else if x == 1 then (0,1,0,0)
--                                  else if x == 2 then (0,0,1,0)
--                                  else (0,0,0,1))
--   let offsets = scan (pairwise (+)) (0,0,0,0) flags
--   let (na,nb,nc,_nd) = last offsets
--   let f bin (a,b,c,d) = match bin
--                         case 0 -> a-1
--                         case 1 -> na+b-1
--                         case 2 -> na+nb+c-1
--                         case _ -> na+nb+nc+d-1
--   let is = map2 f bins offsets
--   in scatter scratch is xs
--
-- Here the key (head xs) is first clamped with 'mapout', so out-of-bounds
-- keys sort as @w@. All arrays in @xs@ (including the unclamped key) are
-- permuted together by the final 'multiScatter'. @bit@ selects the lowest
-- of the two key bits handled by this step.
radixSortStep :: [VName] -> [Type] -> SubExp -> SubExp -> SubExp -> ADM [VName]
radixSortStep xs tps bit n w = do
  -- let is = head xs
  is <- mapout (head xs) n w

  -- num x = ((x >> bit) & 1) + 2 * ((x >> (bit + 1)) & 1)
  num_param <- newParam "num" $ Prim int64
  num_lam <-
    mkLambda [num_param] $
      fmap varsRes . letTupExp "num_res"
        =<< eBinOp
          (Add Int64 OverflowUndef)
          ( eBinOp
              (And Int64)
              (eBinOp (AShr Int64) (eParam num_param) (eSubExp bit))
              (iConst 1)
          )
          ( eBinOp
              (Mul Int64 OverflowUndef)
              (iConst 2)
              ( eBinOp
                  (And Int64)
                  (eBinOp (AShr Int64) (eParam num_param) (eBinOp (Add Int64 OverflowUndef) (eSubExp bit) (iConst 1)))
                  (iConst 1)
              )
          )

  bins <- letExp "bins" $ Op $ Screma n [is] $ mapSOAC num_lam
  -- One-hot encoding of each bin as four i64 flags.
  flag_param <- newParam "flag" $ Prim int64
  flag_lam <-
    mkLambda [flag_param] $
      fmap varsRes . letTupExp "flag_res"
        =<< elseIf
          int64
          (map ((,) (eParam flag_param) . iConst) [0 .. 2])
          (map (eBody . fmap iConst . (\i -> map (\j -> if i == j then 1 else 0) [0 .. 3])) ([0 .. 3] :: [Integer]))

  flags <- letTupExp "flags" $ Op $ Screma n [bins] $ mapSOAC flag_lam

  scan_params <- traverse (flip newParam $ Prim int64) ["a1", "b1", "c1", "d1", "a2", "b2", "c2", "d2"]
  scan_lam <-
    mkLambda scan_params $
      fmap subExpsRes . mapM (letSubExp "scan_res") =<< do
        uncurry (zipWithM (eBinOp $ Add Int64 OverflowUndef)) $ splitAt 4 $ map eParam scan_params

  scan <- scanSOAC $ pure $ Scan scan_lam $ map (intConst Int64) [0, 0, 0, 0]
  offsets <- letTupExp "offsets" $ Op $ Screma n flags scan

  -- The last inclusive-scan element holds the per-bin totals.
  -- NOTE(review): when n = 0 this indexes the empty offsets at -1; the
  -- values are only consumed by width-n maps, but confirm that this Index
  -- is not bounds-checked.
  ind <- letSubExp "ind_last" =<< eBinOp (Sub Int64 OverflowUndef) (eSubExp n) (iConst 1)
  let i = Slice [DimFix ind]
  nabcd <- traverse newVName ["na", "nb", "nc", "nd"]
  zipWithM_ (\abcd -> letBindNames [abcd] . BasicOp . flip Index i) nabcd offsets

  -- Destination index: (offset within own bin) - 1 + sizes of lower bins.
  let vars = map Var nabcd
  map_params <- traverse (flip newParam $ Prim int64) ["bin", "a", "b", "c", "d"]
  map_lam <-
    mkLambda map_params $
      fmap varsRes . letTupExp "map_res"
        =<< elseIf
          int64
          (map ((,) (eParam $ head map_params) . iConst) [0 .. 2])
          ( zipWith
              ( \j p -> eBody $ pure $ do
                  t <- letSubExp "t" =<< eBinOp (Sub Int64 OverflowUndef) (eParam p) (iConst 1)
                  foldBinOp (Add Int64 OverflowUndef) (intConst Int64 0) (t : take j vars)
              )
              [0 .. 3]
              (tail map_params)
          )

  nis <- letExp "nis" $ Op $ Screma n (bins : offsets) $ mapSOAC map_lam

  scatter_dst <- traverse (\t -> letExp "scatter_dst" $ BasicOp $ Scratch (elemType t) (arrayDims t)) tps
  multiScatter n scatter_dst nis xs
  where
    iConst c = eSubExp $ intConst Int64 c

--
-- the radix sort implementation
-- def radix_sort [n] 't (xs: [n]i64) =
--   let iters = if n == 0 then 0 else 32
--   in loop xs for i < iters do radix_sort_step xs i64.get_bit (i*2)
--
-- In the code below the iteration count is not fixed at 32. 'log2' counts
-- how many right shifts bring its argument to zero (its bit length), and it
-- is applied to @w + 1@. Each step sorts two bits, so the loop runs half
-- that many iterations, rounded up.
radixSort :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort xs n w = do
  logw <- log2 =<< letSubExp "w1" =<< toExp (pe64 w + 1)
  -- ceil logw by (logw + 1) / 2
  iters <- letSubExp "iters" =<< toExp (untyped (pe64 logw + 1) ~/~ untyped (pe64 (intConst Int64 2)))

  types <- traverse lookupType xs
  params <- zipWithM (\x -> newParam (baseString x) . flip toDecl Nonunique) xs types
  i <- newVName "i"
  loopbody <- buildBody_ . localScope (scopeOfFParams params) $
    fmap varsRes $ do
      bit <- letSubExp "bit" =<< toExp (le64 i * 2)
      radixSortStep (map paramName params) types bit n w

  letTupExp "sorted" $
    Loop (zip params $ map Var xs) (ForLoop i Int64 iters) loopbody
  where
    -- While-loop computing the number of right shifts until @m@ reaches 0.
    log2 :: SubExp -> ADM SubExp
    log2 m = do
      params <- zipWithM newParam ["cond", "r", "i"] $ map Prim [Bool, int64, int64]
      let [cond, r, i] = params

      body <- buildBody_ . localScope (scopeOfFParams params) $ do
        r' <- letSubExp "r'" =<< toExp (le64 (paramName r) .>>. 1)
        cond' <- letSubExp "cond'" =<< toExp (bNot $ pe64 r' .==. 0)
        i' <- letSubExp "i'" =<< toExp (le64 (paramName i) + 1)
        pure $ subExpsRes [cond', r', i']

      cond_init <- letSubExp "test" =<< toExp (bNot $ pe64 m .==. 0)

      l <- letTupExp' "log2res" $ Loop (zip params [cond_init, m, Constant $ blankPrimValue int64]) (WhileLoop $ paramName cond) body
      let [_, _, res] = l
      pure res

-- Sorts by the key @head xs@ and returns
-- @[permutation (iota n sorted by key), sorted keys] ++ (tail xs permuted)@.
radixSort' :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort' xs n w = do
  iota_n <-
    letExp "red_iota" . BasicOp $
      Iota n (intConst Int64 0) (intConst Int64 1) Int64

  radres <- radixSort [head xs, iota_n] n w
  let [is', iota'] = radres

  -- Gather the remaining arrays through the sorting permutation.
  i_param <- newParam "i" $ Prim int64
  let slice = [DimFix $ Var $ paramName i_param]
  map_lam <- mkLambda [i_param] $ varsRes <$> multiIndex (tail xs) slice

  sorted <- letTupExp "sorted" $ Op $ Screma n [iota'] $ mapSOAC map_lam
  pure $ iota' : is' : sorted

--
-- generic case of histogram.
-- Original, assuming `is: [n]i64` and `dst: [w]btp`
--     let xs = reduce_by_index dst odot ne is as
-- Forward sweep:
--     let h_part = reduce_by_index (replicate w ne) odot ne is as
--     let xs = map2 odot dst h_part
-- Reverse sweep:
--     h_part_bar = f'' dst h_part
--     dst_bar += f' dst h_part
--     let flag = map (\i -> i == 0 || sis[i] != sis[i-1]) (iota n)
--     let flag_rev = map (\i -> i==0 || flag[n-i]) (iota n)
--     let ls = seg_scan_exc odot ne flag sas
--     let rs = reverse sas |>
--         seg_scan_exc odot ne flag_rev |> reverse
--     let f_bar = map (\i -> if i < w && -1 < i
--                            then h_part_bar[i]
--                            else 0s
--                     ) sis
--     let sas_bar = f ls sas rs
--     as_bar += scatter (Scratch alpha n) siota sas_bar
-- Where:
--     siota: 'iota n' sorted wrt 'is'
--     sis: 'is' sorted wrt 'is'
--     sas: 'as' sorted wrt 'is'
--     f'' = vjpLambda xs_bar h_part (map2 odot)
--     f' = vjpLambda xs_bar dst (map2 odot)
--     f = vjpLambda f_bar sas (map3 (\li ai ri -> li odot ai odot ri))
--     0s is an alpha-dimensional array with 0 (possibly 0-dim)
diffHist ::
  VjpOps -> [VName] -> StmAux () -> SubExp -> Lambda SOACS -> [SubExp] -> [VName] -> [SubExp] -> SubExp -> [VName] -> ADM () -> ADM ()
diffHist ops xs aux n lam0 ne as w rf dst m = do
  as_type <- traverse lookupType $ tail as
  dst_type <- traverse lookupType dst

  nes <- traverse (letExp "new_dst" . BasicOp . Replicate (Shape $ pure $ head w)) ne

  h_map <- mkIdentityLambda $ Prim int64 : map rowType as_type
  h_part <- traverse (newVName . flip (<>) "_h_part" . baseString) xs
  auxing aux . letBindNames h_part . Op $
    Hist n as [HistOp (Shape w) rf nes ne lam0] h_map

  lam0' <- renameLambda lam0
  auxing aux . letBindNames xs . Op $
    Screma (head w) (dst <> h_part) (mapSOAC lam0')

  m

  xs_bar <- traverse lookupAdjVal xs

  (dst_params, hp_params, f') <- mkF' lam0 dst_type $ head w
  f'_adj_dst <- vjpLambda ops (map adjFromVar xs_bar) dst_params f'
  f'_adj_hp <- vjpLambda ops (map adjFromVar xs_bar) hp_params f'

  dst_bar' <- eLambda f'_adj_dst $ map (eSubExp . Var) $ dst <> h_part
  dst_bar <- bindSubExpRes "dst_bar" dst_bar'
  zipWithM_ updateAdj dst dst_bar

  h_part_bar' <- eLambda f'_adj_hp $ map (eSubExp . Var) $ dst <> h_part
  h_part_bar <- bindSubExpRes "h_part_bar" h_part_bar'

  lam <- renameLambda lam0
  lam' <- renameLambda lam0

  -- is' <- mapout (head as) n (head w)
  -- sorted <- radixSort' (is' : tail as) n $ head w
  sorted <- radixSort' as n $ head w
  let siota = head sorted
  let sis = head $ tail sorted
  let sas = drop 2 sorted

  iota_n <- letExp "iota" $ BasicOp $ Iota n (intConst Int64 0) (intConst Int64 1) Int64

  par_i <- newParam "i" $ Prim int64
  flag_lam <- mkFlagLam par_i sis
  flag <- letExp "flag" $ Op $ Screma n [iota_n] $ mapSOAC flag_lam

  -- map (\i -> (if flag[i] then (true,ne) else (false,vs[i-1]), if i==0 || flag[n-i] then (true,ne) else (false,vs[n-i]))) (iota n)
  par_i' <- newParam "i" $ Prim int64
  let i' = paramName par_i'
  g_lam <-
    mkLambda [par_i'] $
      fmap subExpsRes . mapM (letSubExp "scan_inps") =<< do
        im1 <- letSubExp "i_1" =<< toExp (le64 i' - 1)
        nmi <- letSubExp "n_i" =<< toExp (pe64 n - le64 i')
        let s1 = [DimFix im1]
        let s2 = [DimFix nmi]

        -- flag array for left scan
        f1 <- letSubExp "f1" $ BasicOp $ Index flag $ Slice [DimFix $ Var i']

        -- array for left scan
        r1 <-
          letTupExp' "r1"
            =<< eIf
              (eSubExp f1)
              (eBody $ fmap eSubExp ne)
              (eBody . fmap (eSubExp . Var) =<< multiIndex sas s1)

        -- array for right scan inc flag
        r2 <-
          letTupExp' "r2"
            =<< eIf
              (toExp $ le64 i' .==. 0)
              (eBody $ fmap eSubExp $ Constant (onePrimValue Bool) : ne)
              ( eBody $ pure $ do
                  eIf
                    (pure $ BasicOp $ Index flag $ Slice s2)
                    (eBody $ fmap eSubExp $ Constant (onePrimValue Bool) : ne)
                    ( eBody . fmap eSubExp . (Constant (blankPrimValue Bool) :) . fmap Var
                        =<< multiIndex sas s2
                    )
              )

        traverse eSubExp $ f1 : r1 ++ r2

  -- scan (\(f1,v1) (f2,v2) ->
  --   let f = f1 || f2
  --   let v = if f2 then v2 else g v1 v2
  --   in (f,v) ) (false,ne) (zip flags vals)
  -- The right scan runs over the reversed array with the same operator
  -- order; this relies on reduce_by_index operators being commutative.
  scan_lams <-
    traverse
      ( \l -> do
          f1 <- newParam "f1" $ Prim Bool
          f2 <- newParam "f2" $ Prim Bool
          ps <- lambdaParams <$> renameLambda lam0
          let (p1, p2) = splitAt (length ne) ps

          mkLambda (f1 : p1 ++ f2 : p2) $
            fmap varsRes . letTupExp "scan_res" =<< do
              let f = eBinOp LogOr (eParam f1) (eParam f2)
              eIf
                (eParam f2)
                (eBody $ f : fmap eParam p2)
                ( eBody . (f :) . fmap (eSubExp . Var)
                    =<< bindSubExpRes "gres"
                    =<< eLambda l (fmap eParam ps)
                )
      )
      [lam, lam']

  let ne' = Constant (BoolValue False) : ne

  scansres <-
    letTupExp "adj_ctrb_scan" . Op $
      Screma n [iota_n] (scanomapSOAC (map (`Scan` ne') scan_lams) g_lam)

  -- Drop the flag component of each scan result.
  let (_ : ls_arr, _ : rs_arr_rev) = splitAt (length ne + 1) scansres

  -- map (\i -> if 0 <= i && i < w then h_part_bar[i] else 0s) sis
  par_i'' <- newParam "i" $ Prim int64
  let i'' = paramName par_i''
  map_lam <-
    mkLambda [par_i''] $
      fmap varsRes . letTupExp "scan_res"
        =<< eIf
          (toExp $ withinBounds $ pure (head w, i''))
          (eBody . fmap (eSubExp . Var) =<< multiIndex h_part_bar [DimFix $ Var i''])
          ( eBody $ do
              map (\t -> pure $ BasicOp $ Replicate (Shape $ tail $ arrayDims t) (Constant $ blankPrimValue $ elemType t)) as_type
          )

  f_bar <- letTupExp "f_bar" $ Op $ Screma n [sis] $ mapSOAC map_lam

  (as_params, f) <- mkF lam0 as_type n
  f_adj <- vjpLambda ops (map adjFromVar f_bar) as_params f

  -- map (\i -> rs_arr_rev[n-i-1]) (iota n)
  par_i''' <- newParam "i" $ Prim int64
  let i''' = paramName par_i'''
  rev_lam <- mkLambda [par_i'''] $ do
    nmim1 <- letSubExp "n_i_1" =<< toExp (pe64 n - le64 i''' - 1)
    varsRes <$> multiIndex rs_arr_rev [DimFix nmim1]

  rs_arr <- letTupExp "rs_arr" $ Op $ Screma n [iota_n] $ mapSOAC rev_lam

  sas_bar <-
    bindSubExpRes "sas_bar"
      =<< eLambda f_adj (map (eSubExp . Var) $ ls_arr <> sas <> rs_arr)

  -- Undo the sort: write each adjoint back to its original position.
  scatter_dst <- traverse (\t -> letExp "scatter_dst" $ BasicOp $ Scratch (elemType t) (arrayDims t)) as_type
  as_bar <- multiScatter n scatter_dst siota sas_bar

  zipWithM_ updateAdj (tail as) as_bar
  where
    -- map (\i -> if i == 0 then true else is[i] != is[i-1]) (iota n)
    mkFlagLam :: LParam SOACS -> VName -> ADM (Lambda SOACS)
    mkFlagLam par_i sis =
      mkLambda [par_i] $
        fmap varsRes . letTupExp "flag" =<< do
          let i = paramName par_i
          eIf
            (toExp (le64 i .==. 0))
            (eBody $ pure $ eSubExp $ Constant $ onePrimValue Bool)
            ( eBody $ pure $ do
                i_p <- letExp "i_p" =<< toExp (le64 i - 1)
                vs <- traverse (letExp "vs" . BasicOp . Index sis . Slice . pure . DimFix . Var) [i, i_p]
                let [vs_i, vs_p] = vs
                toExp $ bNot $ le64 vs_i .==. le64 vs_p
            )
futhark-0.25.27/src/Futhark/AD/Rev/Loop.hs000066400000000000000000000412001475065116200200170ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-}

module Futhark.AD.Rev.Loop (diffLoop, stripmineStms) where

import Control.Monad
import Data.Foldable (toList)
import Data.List ((\\))
import Data.Map qualified as M
import Data.Maybe
import Futhark.AD.Rev.Monad
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.Aliases (consumedInStms)
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (nubOrd, traverseFold)

-- | A convenience function to bring the components of a for-loop into
-- scope and throw an error if the passed 'Exp' is not a for-loop.
bindForLoop ::
  (PrettyRep rep) =>
  Exp rep ->
  ( [(Param (FParamInfo rep), SubExp)] ->
    LoopForm ->
    VName ->
    IntType ->
    SubExp ->
    Body rep ->
    a
  ) ->
  a
bindForLoop (Loop val_pats form@(ForLoop i it bound) body) f =
  f val_pats form i it bound body
bindForLoop e _ = error $ "bindForLoop: not a for-loop:\n" <> prettyString e

-- | A convenience function to rename a for-loop and then bind the
-- renamed components. The callback also receives the whole renamed loop.
renameForLoop ::
  (MonadFreshNames m, Renameable rep, PrettyRep rep) =>
  Exp rep ->
  ( Exp rep ->
    [(Param (FParamInfo rep), SubExp)] ->
    LoopForm ->
    VName ->
    IntType ->
    SubExp ->
    Body rep ->
    m a
  ) ->
  m a
renameForLoop loop f = renameExp loop >>= \loop' -> bindForLoop loop' (f loop')

-- | Is the loop a while-loop?
isWhileLoop :: Exp rep -> Bool
isWhileLoop (Loop _ WhileLoop {} _) = True
isWhileLoop _ = False

-- | Augments a while-loop to also compute the number of iterations.
computeWhileIters :: Exp SOACS -> ADM SubExp
computeWhileIters (Loop val_pats (WhileLoop b) body) = do
  -- Extra loop parameter counting iterations, starting from zero and
  -- incremented once per execution of the body.
  bound_v <- newVName "bound"
  let t = Prim $ IntType Int64
      bound_param = Param mempty bound_v t
  bound_init <- letSubExp "bound_init" $ zeroExp t
  body' <- localScope (scopeOfFParams [bound_param]) $
    buildBody_ $ do
      bound_plus_one <-
        let one = Constant $ IntValue $ intValue Int64 (1 :: Int)
         in letSubExp "bound+1" $ BasicOp $ BinOp (Add Int64 OverflowUndef) (Var bound_v) one
      addStms $ bodyStms body
      pure (pure (subExpRes bound_plus_one) <> bodyResult body)
  res <- letTupExp' "loop" $ Loop ((bound_param, bound_init) : val_pats) (WhileLoop b) body'
  -- Only the counter is returned; the other loop results are discarded.
  pure $ head res
computeWhileIters e = error $ "computeWhileIters: not a while-loop:\n" <> prettyString e

-- | Converts a 'WhileLoop' into a 'ForLoop' that runs for @bound_se@
-- iterations. @bound_se@ must be an upper bound on the number of
-- iterations of the while-loop; it is expected to come from a
-- @#[bound(n)]@ attribute on the surrounding 'Loop'. Once the while
-- condition becomes false, the remaining iterations pass the loop values
-- through unchanged. The resulting for-loop always executes @bound_se@
-- iterations, so the tighter the bound the better.
convertWhileLoop :: SubExp -> Exp SOACS -> ADM (Exp SOACS)
convertWhileLoop bound_se (Loop val_pats (WhileLoop cond) body) =
  localScope (scopeOfFParams $ map fst val_pats) $ do
    i <- newVName "i"
    body' <-
      eBody
        [ eIf
            (pure $ BasicOp $ SubExp $ Var cond)
            (pure body)
            (resultBodyM $ map (Var . paramName . fst) val_pats)
        ]
    pure $ Loop val_pats (ForLoop i Int64 bound_se) body'
convertWhileLoop _ e = error $ "convertWhileLoop: not a while-loop:\n" <> prettyString e

-- | @nestifyLoop bound n loop@ transforms a loop into a depth-@n@ loop nest
-- of @bound@-iteration loops. This transformation does not preserve
-- the original semantics of the loop: @n@ and @bound@ may be arbitrary and have
-- no relation to the number of iterations of @loop@.
nestifyLoop ::
  SubExp ->
  Integer ->
  Exp SOACS ->
  ADM (Exp SOACS)
nestifyLoop bound_se = nestifyLoop'
  where
    -- At remaining depth @n@, the outer index @i@ advances the inner
    -- iteration space by a stride of @bound_se^(n-1)@. Each level
    -- substitutes the index of the level below, so the innermost body
    -- observes the row-major linearised index
    --   i_1 * bound^(n-1) + i_2 * bound^(n-2) + ... + i_n,
    -- which enumerates [0, bound^n) in order. (Accumulating the offsets
    -- multiplicatively, as @offset * i@, is only correct for n <= 2.)
    nestifyLoop' n loop = bindForLoop loop nestify
      where
        nestify val_pats _form i it _bound body
          | n > 1 = do
              renameForLoop loop $ \_loop' val_pats' _form' i' it' _bound' body' -> do
                let loop_params = map fst val_pats
                    loop_params' = map fst val_pats'
                    loop_inits' = map (Var . paramName) loop_params
                    val_pats'' = zip loop_params' loop_inits'
                outer_body <- buildBody_ $ do
                  stride <-
                    letSubExp "stride" . BasicOp $
                      BinOp (Pow it) bound_se (Constant $ IntValue $ intValue it (n - 1))
                  offset' <-
                    letSubExp "offset" . BasicOp $
                      BinOp (Mul it OverflowUndef) stride (Var i)

                  inner_body <- insertStmsM $ do
                    i_inner <-
                      letExp "i_inner" . BasicOp $
                        BinOp (Add it OverflowUndef) offset' (Var i')
                    pure $ substituteNames (M.singleton i' i_inner) body'

                  inner_loop <-
                    letTupExp "inner_loop"
                      =<< nestifyLoop'
                        (n - 1)
                        (Loop val_pats'' (ForLoop i' it' bound_se) inner_body)
                  pure $ varsRes inner_loop
                pure $ Loop val_pats (ForLoop i it bound_se) outer_body
          | n == 1 =
              pure $ Loop val_pats (ForLoop i it bound_se) body
          | otherwise = pure loop

-- | @stripmine n pat loop@ stripmines a loop into a depth-@n@ loop nest.
-- An additional @bound - (floor(bound^(1/n)))^n@-iteration remainder loop is
-- inserted after the stripmined loop which executes the remaining iterations
-- so that the stripmined loop is semantically equivalent to the original loop.
-- The root is computed in 'Float64' and truncated back to the loop's integer
-- type.
stripmine :: Integer -> Pat Type -> Exp SOACS -> ADM (Stms SOACS)
stripmine n pat loop = do
  bindForLoop loop $ \_val_pats _form _i it bound _body -> do
    let n_root = Constant $ FloatValue $ floatValue Float64 (1 / fromIntegral n :: Double)
    bound_float <- letSubExp "bound_f64" $ BasicOp $ ConvOp (UIToFP it Float64) bound
    bound' <- letSubExp "bound" $ BasicOp $ BinOp (FPow Float64) bound_float n_root
    bound_int <- letSubExp "bound_int" $ BasicOp $ ConvOp (FPToUI Float64 it) bound'
    total_iters <-
      letSubExp "total_iters" . BasicOp $
        BinOp (Pow it) bound_int (Constant $ IntValue $ intValue it n)
    remain_iters <-
      letSubExp "remain_iters" $ BasicOp $ BinOp (Sub it OverflowUndef) bound total_iters
    mined_loop <- nestifyLoop bound_int n loop
    pat' <- renamePat pat
    renameForLoop loop $ \_loop val_pats' _form' i' it' _bound' body' -> do
      -- The remainder loop continues from iteration @total_iters@.
      remain_body <- insertStmsM $ do
        i_remain <-
          letExp "i_remain" . BasicOp $
            BinOp (Add it OverflowUndef) total_iters (Var i')
        pure $ substituteNames (M.singleton i' i_remain) body'
      let loop_params_rem = map fst val_pats'
          loop_inits_rem = map (Var . patElemName) $ patElems pat'
          val_pats_rem = zip loop_params_rem loop_inits_rem
          remain_loop = Loop val_pats_rem (ForLoop i' it' remain_iters) remain_body
      collectStms_ $ do
        letBind pat' mined_loop
        letBind pat remain_loop

-- | Stripmines a statement. Only has an effect when the statement's
-- expression is a for-loop with a @#[stripmine(n)]@ attribute, where
-- @n@ is the nesting depth. If several such attributes are present, the
-- first one is used.
stripmineStm :: Stm SOACS -> ADM (Stms SOACS)
stripmineStm stm@(Let pat aux loop@(Loop _ ForLoop {} _)) =
  case nums of
    (n : _) -> stripmine n pat loop
    _ -> pure $ oneStm stm
  where
    extractNum (AttrComp "stripmine" [AttrInt n]) = Just n
    extractNum _ = Nothing
    nums = catMaybes $ mapAttrs extractNum $ stmAuxAttrs aux
stripmineStm stm = pure $ oneStm stm

-- | Applies 'stripmineStm' to every statement.
stripmineStms :: Stms SOACS -> ADM (Stms SOACS)
stripmineStms = traverseFold stripmineStm

-- | Forward pass transformation of a loop. This includes modifying the loop
-- to save the loop values at each iteration onto a tape as well as copying
-- any consumed arrays in the loop's body and consuming said copies in lieu of
-- the originals (which will be consumed later in the reverse pass).
-- Parameters carrying the @true_dep@ attribute are neither copied nor taped.
fwdLoop :: Pat Type -> StmAux () -> Exp SOACS -> ADM ()
fwdLoop pat aux loop =
  bindForLoop loop $ \val_pats form i _it bound body -> do
    bound64 <- asIntS Int64 bound
    let loop_params = map fst val_pats
        is_true_dep = inAttrs (AttrName "true_dep") . paramAttrs
        dont_copy_params = filter is_true_dep loop_params
        dont_copy = map paramName dont_copy_params
        loop_params_to_copy = loop_params \\ dont_copy_params

    -- One tape array of length @bound@ per taped loop parameter.
    empty_saved_array <-
      forM loop_params_to_copy $ \p ->
        letSubExp (baseString (paramName p) <> "_empty_saved")
          =<< eBlank (arrayOf (paramDec p) (Shape [bound64]) NoUniqueness)

    (body', (saved_pats, saved_params)) <- buildBody $
      localScope (scopeOfFParams loop_params) $
        localScope (scopeOfLoopForm form) $ do
          copy_substs <- copyConsumedArrsInBody dont_copy body
          addStms $ bodyStms body
          i_i64 <- asIntS Int64 $ Var i
          (saved_updates, saved_pats_params) <- fmap unzip $
            forM loop_params_to_copy $ \p -> do
              let v = paramName p
                  t = paramDec p
              saved_param_v <- newVName $ baseString v <> "_saved"
              saved_pat_v <- newVName $ baseString v <> "_saved"
              setLoopTape v saved_pat_v
              let saved_param = Param mempty saved_param_v $ arrayOf t (Shape [bound64]) Unique
                  saved_pat = PatElem saved_pat_v $ arrayOf t (Shape [bound64]) NoUniqueness
              -- Write the iteration's value of @v@ into slot @i@ of the tape.
              saved_update <-
                localScope (scopeOfFParams [saved_param]) $
                  letInPlace
                    (baseString v <> "_saved_update")
                    saved_param_v
                    (fullSlice (fromDecl $ paramDec saved_param) [DimFix i_i64])
                    (substituteNames copy_substs $ BasicOp $ SubExp $ Var v)
              pure (saved_update, (saved_pat, saved_param))
          pure (bodyResult body <> varsRes saved_updates, unzip saved_pats_params)

    let pat' = pat <> Pat saved_pats
        val_pats' = val_pats <> zip saved_params empty_saved_array
    addStm $ Let pat' aux $ Loop val_pats' form body'

-- | Construct a loop value-pattern for the adjoint of the
-- given variable.
valPatAdj :: VName -> ADM (Param DeclType, SubExp)
valPatAdj v = do
  v_adj <- adjVName v
  init_adj <- lookupAdjVal v
  t <- lookupType init_adj
  pure (Param mempty v_adj (toDecl t Unique), Var init_adj)

-- | 'valPatAdj' applied to every variable in every component.
valPatAdjs :: LoopInfo [VName] -> ADM (LoopInfo [(Param DeclType, SubExp)])
valPatAdjs = (mapM . mapM) valPatAdj

-- | Reverses a loop by substituting the loop index.
reverseIndices :: Exp SOACS -> ADM (Substitutions, Stms SOACS)
reverseIndices loop =
  bindForLoop loop $ \_val_pats form i it bound _body -> do
    -- bound - 1: the index of the last iteration.  Emitted directly into
    -- the current builder context.
    bound_minus_one <-
      localScope (scopeOfLoopForm form) $
        let one = Constant $ IntValue $ intValue it (1 :: Int)
         in letSubExp "bound-1" $ BasicOp $ BinOp (Sub it OverflowUndef) bound one

    -- i_rev = (bound - 1) - i.  These statements are returned to the
    -- caller rather than emitted here.
    (i_rev, i_stms) <- collectStms $
      localScope (scopeOfLoopForm form) $
        letExp (baseString i <> "_rev") $
          BasicOp $
            BinOp (Sub it OverflowWrap) bound_minus_one (Var i)

    pure (M.singleton i i_rev, i_stms)

-- | Returns a substitution which substitutes values in the reverse
-- loop body with values from the tape.  Statements that read the tape
-- at index @i'@ are emitted into the current context.  Parameters
-- marked with the @true_dep@ attribute, and parameters without a tape
-- entry, are not substituted.
restore :: Stms SOACS -> [Param DeclType] -> VName -> ADM Substitutions
restore stms_adj loop_params' i' =
  M.fromList . catMaybes <$> mapM f loop_params'
  where
    dont_copy =
      map paramName $ filter (inAttrs (AttrName "true_dep") . paramAttrs) loop_params'
    -- Variables consumed by the adjoint statements.  This depends only
    -- on stms_adj, so it is computed once rather than per parameter.
    consumed =
      namesToList $ consumedInStms $ fst $ Alias.analyseStms mempty stms_adj
    f p
      | v `notElem` dont_copy = do
          m_vs <- lookupLoopTape v
          case m_vs of
            Nothing -> pure Nothing
            Just vs -> do
              vs_t <- lookupType vs
              i_i64' <- asIntS Int64 $ Var i'
              v' <- letExp "restore" $ BasicOp $ Index vs $ fullSlice vs_t [DimFix i_i64']
              t <- lookupType v
              -- A restored array that the adjoint code consumes gets its
              -- own copy.
              v'' <- case (t, v `elem` consumed) of
                (Array {}, True) ->
                  letExp "restore_copy" $ BasicOp $ Replicate mempty $ Var v'
                _ -> pure v'
              pure $ Just (v, v'')
      | otherwise = pure Nothing
      where
        v = paramName p

-- | A type to keep track of and separate values corresponding to different
-- parts of the loop.
data LoopInfo a = LoopInfo
  { loopRes :: a,
    loopFree :: a,
    loopVals :: a
  }
  deriving (Functor, Foldable, Traversable, Show)

-- | Transforms a for-loop into its reverse-mode derivative.
revLoop :: (Stms SOACS -> ADM ()) -> Pat Type -> Exp SOACS -> ADM ()
revLoop diffStms pat loop =
  bindForLoop loop $ \val_pats _form _i _it _bound _body ->
    -- Work on a renamed copy of the loop (loop', val_pats', body', ...).
    renameForLoop loop $ \loop' val_pats' form' i' _it' _bound' body' -> do
      let loop_params = map fst val_pats
          (loop_params', loop_vals') = unzip val_pats'
          getVName Constant {} = Nothing
          getVName (Var v) = Just v
          -- Primal variables whose adjoints are threaded through the
          -- reverse loop: the body results, the free variables of the
          -- loop (excluding variable initial values), and the variable
          -- initial values themselves.
          loop_vnames =
            LoopInfo
              { loopRes = mapMaybe subExpResVName $ bodyResult body',
                loopFree = namesToList (freeIn loop') \\ mapMaybe getVName loop_vals',
                loopVals = nubOrd $ mapMaybe getVName loop_vals'
              }

      -- The tape was keyed by the original parameter names; re-key it
      -- to the renamed parameters.
      renameLoopTape $ M.fromList $ zip (map paramName loop_params) (map paramName loop_params')

      -- Seed the adjoint of each variable body result with the adjoint
      -- of the corresponding pattern element.
      forM_ (zip (bodyResult body') $ patElems pat) $ \(se_res, pe) ->
        case subExpResVName se_res of
          Just v -> setAdj v =<< lookupAdj (patElemName pe)
          Nothing -> pure ()

      (i_subst, i_stms) <- reverseIndices loop'

      -- Loop parameters (and initial values) for the adjoints.
      val_pat_adjs <- valPatAdjs loop_vnames
      let val_pat_adjs_list = concat $ toList val_pat_adjs

      -- Build the body of the adjoint loop.
      (loop_adjs, stms_adj) <- collectStms $
        localScope (scopeOfLoopForm form' <> scopeOfFParams (map fst val_pat_adjs_list <> loop_params')) $ do
          addStms i_stms
          (loop_adjs, stms_adj) <- collectStms $
            subAD $ do
              -- Within the body, each tracked primal's adjoint is the
              -- corresponding adjoint loop parameter.
              zipWithM_
                (\val_pat v -> insAdj v (paramName $ fst val_pat))
                val_pat_adjs_list
                (concat $ toList loop_vnames)
              diffStms $ bodyStms body'

              -- Read out the adjoints to return from each iteration: the
              -- loop parameters' adjoints take the place of the body
              -- results' adjoints.
              loop_res_adjs <- mapM (lookupAdjVal . paramName) loop_params'
              loop_free_adjs <- mapM lookupAdjVal $ loopFree loop_vnames
              loop_vals_adjs <- mapM lookupAdjVal $ loopVals loop_vnames

              pure $
                LoopInfo
                  { loopRes = loop_res_adjs,
                    loopFree = loop_free_adjs,
                    loopVals = loop_vals_adjs
                  }
          -- Reload taped loop-parameter values, then the adjoint
          -- statements, both indexed by the reversed iteration number.
          (substs, restore_stms) <-
            collectStms $ restore stms_adj loop_params' i'
          addStms $ substituteNames i_subst restore_stms
          addStms $ substituteNames i_subst $ substituteNames substs stms_adj
          pure loop_adjs

      inScopeOf stms_adj $
        localScope (scopeOfFParams $ map fst val_pat_adjs_list) $
          do
            let body_adj =
                  mkBody stms_adj $ varsRes $ concat $ toList loop_adjs
                -- Parameters marked true_dep are not taped (see fwdLoop and
                -- restore); references to them in the adjoint loop are
                -- replaced with the corresponding result of the primal loop.
                restore_true_deps = M.fromList $
                  flip mapMaybe (zip loop_params' $ patElems pat) $ \(p, pe) ->
                    if p `elem` filter (inAttrs (AttrName "true_dep") . paramAttrs) loop_params'
                      then Just (paramName p, patElemName pe)
                      else Nothing
            adjs' <-
              letTupExp "loop_adj" $
                substituteNames restore_true_deps $
                  Loop val_pat_adjs_list form' body_adj
            -- Split the adjoint loop's results back into the three
            -- LoopInfo groups.
            let (loop_res_adjs, loop_free_var_val_adjs) =
                  splitAt (length $ loopRes loop_adjs) adjs'
                (loop_free_adjs, loop_val_adjs) =
                  splitAt (length $ loopFree loop_adjs) loop_free_var_val_adjs
            returnSweepCode $ do
              -- Adjoints of the loop parameters flow to the initial values.
              zipWithM_ updateSubExpAdj loop_vals' loop_res_adjs
              -- Free-variable adjoints were threaded through the loop
              -- starting from their prior value (see valPatAdj), so they
              -- are overwritten rather than added to.
              zipWithM_ insAdj (loopFree loop_vnames) loop_free_adjs
              zipWithM_ updateAdj (loopVals loop_vnames) loop_val_adjs

-- | Transforms a loop into its reverse-mode derivative.
diffLoop :: (Stms SOACS -> ADM ()) -> Pat Type -> StmAux () -> Exp SOACS -> ADM () -> ADM ()
diffLoop diffStms pat aux loop m
  | isWhileLoop loop =
      -- While-loops are first converted to for-loops.  The bound comes
      -- from the first #[bound(n)] attribute on the statement if there
      -- is one, and is otherwise computed with 'computeWhileIters'.
      let getBound (AttrComp "bound" [AttrInt b]) = Just b
          getBound _ = Nothing
          bounds = catMaybes $ mapAttrs getBound $ stmAuxAttrs aux
       in case bounds of
            (bound : _) -> do
              let bound_se = Constant $ IntValue $ intValue Int64 bound
              for_loop <- convertWhileLoop bound_se loop
              diffLoop diffStms pat aux for_loop m
            _ -> do
              bound <- computeWhileIters loop
              for_loop <- convertWhileLoop bound =<< renameExp loop
              diffLoop diffStms pat aux for_loop m
  | otherwise = do
      -- Forward pass, then the continuation, then the reverse pass.
      fwdLoop pat aux loop
      m
      revLoop diffStms pat loop
futhark-0.25.27/src/Futhark/AD/Rev/Map.hs000066400000000000000000000240101475065116200176230ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-}

-- | VJP transformation for Map SOACs.  This is a pretty complicated
-- case due to the possibility of free variables.
module Futhark.AD.Rev.Map (vjpMap) where

import Control.Monad
import Data.Bifunctor (first)
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util (splitAt3)

-- | A classification of a free variable based on its adjoint.  The
-- 'VName' stored is *not* the adjoint, but the primal variable.
data AdjVar
  = -- | Adjoint is already an accumulator.
    FreeAcc VName
  | -- | Should be given an accumulator adjoint; its current adjoint is
    -- an array with this shape and element type.
    FreeArr VName Shape PrimType
  | -- | Does not need an accumulator adjoint.  'classifyAdjVars'
    -- produces this when the adjoint is neither an array nor an
    -- accumulator.
    FreeNonAcc VName

-- Classify each variable by the type of its current adjoint.  Note that
-- 'lookupAdjVal' manifests the adjoint as a variable in the process.
classifyAdjVars :: [VName] -> ADM [AdjVar]
classifyAdjVars = mapM f
  where
    f v = do
      v_adj <- lookupAdjVal v
      v_adj_t <- lookupType v_adj
      case v_adj_t of
        Array pt shape _ ->
          pure $ FreeArr v shape pt
        Acc {} ->
          pure $ FreeAcc v
        _ ->
          pure $ FreeNonAcc v

-- Split classified variables into (array variables with shape and
-- element type, accumulator variables, all others), preserving order.
partitionAdjVars :: [AdjVar] -> ([(VName, (Shape, PrimType))], [VName], [VName])
partitionAdjVars [] = ([], [], [])
partitionAdjVars (fv : fvs) =
  case fv of
    FreeArr v shape t -> ((v, (shape, t)) : xs, ys, zs)
    FreeAcc v -> (xs, v : ys, zs)
    FreeNonAcc v -> (xs, ys, v : zs)
  where
    (xs, ys, zs) = partitionAdjVars fvs

-- Like 'buildBody', but the resulting body is passed through
-- 'renameBody' before being returned.
buildRenamedBody ::
  (MonadBuilder m) =>
  m (Result, a) ->
  m (Body (Rep m), a)
buildRenamedBody m = do
  (body, x) <- buildBody m
  body' <- renameBody body
  pure (body', x)

-- Run 'm' inside a WithAcc over the given inputs.  'm' receives the
-- accumulator parameter names and its Result becomes the body of the
-- WithAcc lambda.  Adjoint changes made by 'm' are discarded ('subAD').
-- With no inputs, no WithAcc is built and the results of 'm' are simply
-- bound to fresh variables.
withAcc ::
  [(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))] ->
  ([VName] -> ADM Result) ->
  ADM [VName]
withAcc [] m =
  mapM (letExp "withacc_res" . BasicOp . SubExp . resSubExp) =<< m []
withAcc inputs m = do
  (cert_params, acc_params) <- fmap unzip $
    forM inputs $ \(shape, arrs, _) -> do
      cert_param <- newParam "acc_cert_p" $ Prim Unit
      ts <- mapM (fmap (stripArray (shapeRank shape)) . lookupType) arrs
      acc_param <- newParam "acc_p" $ Acc (paramName cert_param) shape ts NoUniqueness
      pure (cert_param, acc_param)
  acc_lam <-
    subAD $ mkLambda (cert_params ++ acc_params) $ m $ map paramName acc_params
  letTupExp "withhacc_res" $ WithAcc inputs acc_lam

-- | Perform VJP on a Map.  The 'Adj' list is the adjoints of the
-- result of the map.
vjpMap :: VjpOps -> [Adj] -> StmAux () -> SubExp -> Lambda SOACS -> [VName] -> ADM ()
vjpMap ops res_adjs _ w map_lam as
  | Just res_ivs <- mapM isSparse res_adjs = returnSweepCode $ do
      -- Since at most only a constant number of adjoints are nonzero
      -- (length res_ivs), there is no need for the return sweep code to
      -- contain a Map at all.

      free <- filterM isActive $ namesToList $ freeIn map_lam `namesSubtract` namesFromList as
      free_ts <- mapM lookupType free
      let adjs_for = map paramName (lambdaParams map_lam) ++ free
          adjs_ts = map paramType (lambdaParams map_lam) ++ free_ts

      -- oneHot res_i adj_v: result adjoints that are zero everywhere
      -- except for result number res_i, which is adj_v.
      let oneHot res_i adj_v = zipWith f [0 :: Int ..] $ lambdaReturnType map_lam
            where
              f j t
                | res_i == j = adj_v
                | otherwise = AdjZero (arrayShape t) (elemType t)
          -- Values for the out-of-bounds case do not matter, as we will
          -- be writing to an out-of-bounds index anyway, which is ignored.
          ooBounds adj_i = subAD . buildRenamedBody $ do
            forM_ (zip as adjs_ts) $ \(a, t) -> do
              scratch <- letSubExp "oo_scratch" =<< eBlank t
              updateAdjIndex a (OutOfBounds, adj_i) scratch
            -- We must make sure that all free variables have the same
            -- representation in the oo-branch as in the ib-branch.
            -- In practice we do this by manifesting the adjoint.
            -- This is probably efficient, since the adjoint of a free
            -- variable is probably either a scalar or an accumulator.
            forM_ free $ \v -> insAdj v =<< adjVal =<< lookupAdj v
            first subExpsRes . adjsReps <$> mapM lookupAdj (as <> free)
          -- Differentiate the map function at element adj_i of the
          -- inputs, seeded with adj_v for result res_i.
          inBounds res_i adj_i adj_v = subAD . buildRenamedBody $ do
            forM_ (zip (lambdaParams map_lam) as) $ \(p, a) -> do
              a_t <- lookupType a
              letBindNames [paramName p] . BasicOp . Index a $
                fullSlice a_t [DimFix adj_i]
            adj_elems <-
              fmap (map resSubExp) . bodyBind . lambdaBody
                =<< vjpLambda ops (oneHot res_i (AdjVal adj_v)) adjs_for map_lam
            let (as_adj_elems, free_adj_elems) = splitAt (length as) adj_elems
            forM_ (zip as as_adj_elems) $ \(a, a_adj_elem) ->
              updateAdjIndex a (AssumeBounds, adj_i) a_adj_elem
            forM_ (zip free free_adj_elems) $ \(v, adj_se) -> do
              adj_se_v <- letExp "adj_v" (BasicOp $ SubExp adj_se)
              insAdj v adj_se_v
            first subExpsRes . adjsReps <$> mapM lookupAdj (as <> free)

          -- Generate an iteration of the map function for every
          -- position.  This is a bit inefficient - probably we could do
          -- some deduplication.  For CheckBounds, both branches are
          -- built and selected with an 'if', using the provided flag if
          -- any and otherwise a bounds test against w.
          forPos res_i (check, adj_i, adj_v) = do
            adjs <-
              case check of
                CheckBounds b -> do
                  (obbranch, mkadjs) <- ooBounds adj_i
                  (ibbranch, _) <- inBounds res_i adj_i adj_v
                  fmap mkadjs . letTupExp' "map_adj_elem"
                    =<< eIf
                      (maybe (eDimInBounds (eSubExp w) (eSubExp adj_i)) eSubExp b)
                      (pure ibbranch)
                      (pure obbranch)
                AssumeBounds -> do
                  (body, mkadjs) <- inBounds res_i adj_i adj_v
                  mkadjs . map resSubExp <$> bodyBind body
                OutOfBounds ->
                  mapM lookupAdj as

            zipWithM setAdj (as <> free) adjs

          -- Generate an iteration of the map function for every result.
          forRes res_i = mapM_ (forPos res_i)

      zipWithM_ forRes [0 ..] res_ivs
  where
    -- Sparse result adjoints whose shape is exactly [w].
    isSparse (AdjSparse (Sparse shape _ ivs)) = do
      guard $ shapeDims shape == [w]
      Just ivs
    isSparse _ = Nothing
-- See Note [Adjoints of accumulators] for how we deal with
-- accumulators - it's a bit tricky here.
--
-- General case: build a reverse lambda and apply it with a Screma map
-- over the inputs, the result adjoints and the free-variable adjoints.
vjpMap ops pat_adj aux w map_lam as = returnSweepCode $ do
  pat_adj_vals <- forM (zip pat_adj (lambdaReturnType map_lam)) $ \(adj, t) ->
    case t of
      Acc {} -> letExp "acc_adj_rep" . BasicOp . Replicate (Shape [w]) . Var =<< adjVal adj
      _ -> adjVal adj
  pat_adj_params <-
    mapM (newParam "map_adj_p" . rowType <=< lookupType) pat_adj_vals

  map_lam' <- renameLambda map_lam
  free <- filterM isActive $ namesToList $ freeIn map_lam'

  accAdjoints free $ \free_with_adjs free_without_adjs -> do
    free_adjs <- mapM lookupAdjVal free_with_adjs
    free_adjs_ts <- mapM lookupType free_adjs
    free_adjs_params <- mapM (newParam "free_adj_p") free_adjs_ts
    let lam_rev_params =
          lambdaParams map_lam' ++ pat_adj_params ++ free_adjs_params
        adjs_for = map paramName (lambdaParams map_lam') ++ free
    lam_rev <-
      mkLambda lam_rev_params . subAD . noAdjsFor free_without_adjs $ do
        zipWithM_ insAdj free_with_adjs $ map paramName free_adjs_params
        bodyBind . lambdaBody
          =<< vjpLambda ops (map adjFromParam pat_adj_params) adjs_for map_lam'

    (param_contribs, free_contribs) <-
      fmap (splitAt (length (lambdaParams map_lam'))) $
        auxing aux . letTupExp "map_adjs" . Op $
          Screma w (as ++ pat_adj_vals ++ free_adjs) (mapSOAC lam_rev)

    -- Crucial that we handle the free contribs first in case 'free'
    -- and 'as' intersect.
    zipWithM_ freeContrib free free_contribs
    let param_ts = map paramType (lambdaParams map_lam')
    forM_ (zip3 param_ts as param_contribs) $ \(param_t, a, param_contrib) ->
      case param_t of
        Acc {} -> freeContrib a param_contrib
        _ -> updateAdj a param_contrib
  where
    -- Prepend n int64 index parameters to a lambda.
    addIdxParams n lam = do
      idxs <- replicateM n $ newParam "idx" $ Prim int64
      pure $ lam {lambdaParams = idxs ++ lambdaParams lam}

    accAddLambda n t = addIdxParams n =<< addLambda t

    -- WithAcc input for a free variable with an array adjoint: the
    -- current adjoint array, an addition operator and a zero neutral
    -- element.
    withAccInput (v, (shape, pt)) = do
      v_adj <- lookupAdjVal v
      add_lam <- accAddLambda (shapeRank shape) $ Prim pt
      zero <- letSubExp "zero" $ zeroExp $ Prim pt
      pure (shape, [v_adj], Just (add_lam, [zero]))

    -- Give array-adjoint free variables accumulator adjoints inside a
    -- WithAcc, run 'm' with (variables that now have accumulator
    -- adjoints, the remaining free variables), and afterwards rebind
    -- all adjoints from the results of the WithAcc.
    accAdjoints free m = do
      (arr_free, acc_free, nonacc_free) <-
        partitionAdjVars <$> classifyAdjVars free
      arr_free' <- mapM withAccInput arr_free
      -- We only consider those input arrays that are also not free in
      -- the lambda.
      let as_nonfree = filter (`notElem` free) as
      (arr_adjs, acc_adjs, rest_adjs) <-
        fmap (splitAt3 (length arr_free) (length acc_free)) . withAcc arr_free' $ \accs -> do
          zipWithM_ insAdj (map fst arr_free) accs
          () <- m (acc_free ++ map fst arr_free) (namesFromList nonacc_free)
          acc_free_adj <- mapM lookupAdjVal acc_free
          arr_free_adj <- mapM (lookupAdjVal . fst) arr_free
          nonacc_free_adj <- mapM lookupAdjVal nonacc_free
          as_nonfree_adj <- mapM lookupAdjVal as_nonfree
          pure $ varsRes $ arr_free_adj <> acc_free_adj <> nonacc_free_adj <> as_nonfree_adj
      zipWithM_ insAdj acc_free acc_adjs
      zipWithM_ insAdj (map fst arr_free) arr_adjs
      let (nonacc_adjs, as_nonfree_adjs) = splitAt (length nonacc_free) rest_adjs
      zipWithM_ insAdj nonacc_free nonacc_adjs
      zipWithM_ insAdj as_nonfree as_nonfree_adjs

    -- Add per-element contributions to the adjoint of 'v': accumulators
    -- are adopted directly, anything else is summed with a Reduce first.
    freeContrib v contribs = do
      contribs_t <- lookupType contribs
      case rowType contribs_t of
        Acc {} -> void $ insAdj v contribs
        t -> do
          lam <- addLambda t
          zero <- letSubExp "zero" $ zeroExp t
          reduce <- reduceSOAC [Reduce Commutative lam [zero]]
          contrib_sum <-
            letExp (baseString v <> "_contrib_sum") . Op $
              Screma w [contribs] reduce
          void $ updateAdj v contrib_sum
futhark-0.25.27/src/Futhark/AD/Rev/Monad.hs000066400000000000000000000440501475065116200201520ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-}

-- Naming scheme:
--
-- An adjoint-related object for "x" is named "x_adj". This means
-- both actual adjoints and statements.
--
-- Do not assume "x'" means anything related to derivatives.
module Futhark.AD.Rev.Monad
  ( ADM,
    RState (..),
    runADM,
    Adj (..),
    InBounds (..),
    Sparse (..),
    adjFromParam,
    adjFromVar,
    lookupAdj,
    lookupAdjVal,
    adjVal,
    updateAdj,
    updateSubExpAdj,
    updateAdjSlice,
    updateAdjIndex,
    setAdj,
    insAdj,
    adjsReps,
    --
    copyConsumedArrsInStm,
    copyConsumedArrsInBody,
    addSubstitution,
    returnSweepCode,
    --
    adjVName,
    subAD,
    noAdjsFor,
    subSubsts,
    isActive,
    --
    tabNest,
    oneExp,
    zeroExp,
    unitAdjOfType,
    addLambda,
    --
    VjpOps (..),
    --
    setLoopTape,
    lookupLoopTape,
    substLoopTape,
    renameLoopTape,
  )
where

import Control.Monad
import Control.Monad.State.Strict
import Data.Bifunctor (second)
import Data.List (foldl')
import Data.Map qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.Aliases (consumedInStms)
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Substitute
import Futhark.Util (chunks)

-- | An expression for the zero of a type: a 'blankPrimValue' constant
-- for primitive types, or a replicate of it for arrays.  Calls 'error'
-- for any other type.
zeroExp :: Type -> Exp rep
zeroExp (Prim pt) = BasicOp $ SubExp $ Constant $ blankPrimValue pt
zeroExp (Array pt shape _) =
  BasicOp $ Replicate shape $ Constant $ blankPrimValue pt
zeroExp t = error $ "zeroExp: " ++ prettyString t

-- The value one of a primitive type ('True' for Bool, 'UnitValue' for
-- Unit).
onePrim :: PrimType -> PrimValue
onePrim (IntType it) = IntValue $ intValue it (1 :: Int)
onePrim (FloatType ft) = FloatValue $ floatValue ft (1 :: Double)
onePrim Bool = BoolValue True
onePrim Unit = UnitValue

-- | An expression for the value one of a type: a scalar for primitive
-- types, or a replicate of one for arrays.  Calls 'error' for any
-- other type.
oneExp :: Type -> Exp rep
oneExp (Prim t) = BasicOp $ SubExp $ constant $ onePrim t
oneExp (Array pt shape _) =
  BasicOp $ Replicate shape $ Constant $ onePrim pt
oneExp t = error $ "oneExp: " ++ prettyString t

-- | Whether 'Sparse' should check bounds or assume they are correct.
-- The latter results in simpler code.
data InBounds
  = -- | If a SubExp is provided, it references a boolean that is true
    -- when in-bounds.
    CheckBounds (Maybe SubExp)
  | -- | Indexes are assumed to be in bounds; updates are generated as
    -- 'Unsafe'.
    AssumeBounds
  | -- | Dynamically these will always fail, so don't bother
    -- generating code for the update.  This is only needed to ensure
    -- a consistent representation of sparse Jacobians.
    OutOfBounds
  deriving (Eq, Ord, Show)

-- | A symbolic representation of an array that is all zeroes, except
-- at certain indexes.
data Sparse = Sparse
  { -- | The shape of the array.
    sparseShape :: Shape,
    -- | Element type of the array.
    sparseType :: PrimType,
    -- | Locations and values of nonzero values.  Indexes may be
    -- negative, in which case the value is ignored (unless
    -- 'AssumeBounds' is used).
    sparseIdxVals :: [(InBounds, SubExp, SubExp)]
  }
  deriving (Eq, Ord, Show)

-- | The adjoint of a variable.
data Adj
  = AdjSparse Sparse
  | AdjVal SubExp
  | AdjZero Shape PrimType
  deriving (Eq, Ord, Show)

-- Only a variable directly inside 'AdjVal' is substituted; sparse and
-- zero adjoints are returned unchanged.
instance Substitute Adj where
  substituteNames m (AdjVal (Var v)) = AdjVal $ Var $ substituteNames m v
  substituteNames _ adj = adj

-- Bind an array of the given shape filled with the zero of 't' (just a
-- zero of 't' when the shape has rank 0).  The replicate is tagged with
-- the "sequential" attribute.
zeroArray :: (MonadBuilder m) => Shape -> Type -> m VName
zeroArray shape t
  | shapeRank shape == 0 =
      letExp "zero" $ zeroExp t
  | otherwise = do
      zero <- letSubExp "zero" $ zeroExp t
      attributing (oneAttr "sequential") $
        letExp "zeroes_" . BasicOp $
          Replicate shape zero

-- Manifest a 'Sparse' adjoint: start from a zero array and apply one
-- in-place 'Update' per entry ('Unsafe' for 'AssumeBounds', 'Safe' for
-- 'CheckBounds', skipped for 'OutOfBounds').  The optional flag inside
-- 'CheckBounds' is not consulted here.
sparseArray :: (MonadBuilder m, Rep m ~ SOACS) => Sparse -> m VName
sparseArray (Sparse shape t ivs) = do
  flip (foldM f) ivs =<< zeroArray shape (Prim t)
  where
    arr_t = Prim t `arrayOfShape` shape
    f arr (check, i, se) = do
      let stm s =
            letExp "sparse" . BasicOp $
              Update s arr (fullSlice arr_t [DimFix i]) se
      case check of
        AssumeBounds -> stm Unsafe
        CheckBounds _ -> stm Safe
        OutOfBounds -> pure arr

-- | An 'AdjVal' adjoint referring to the given variable.
adjFromVar :: VName -> Adj
adjFromVar = AdjVal . Var

-- | An 'AdjVal' adjoint referring to the given parameter.
adjFromParam :: Param t -> Adj
adjFromParam = adjFromVar . paramName

-- | An adjoint holding the value one ('oneExp') of the given type,
-- bound to a fresh variable.
unitAdjOfType :: Type -> ADM Adj
unitAdjOfType t = AdjVal <$> letSubExp "adj_unit" (oneExp t)

-- | The values representing an adjoint in symbolic form.  This is
-- used for when we wish to return an Adj from a Body or similar
-- without forcing manifestation.  Also returns a function for
-- reassembling the Adj from a new representation (the list must have
-- the same length).
adjRep :: Adj -> ([SubExp], [SubExp] -> Adj)
adjRep (AdjVal se) = ([se], \[se'] -> AdjVal se')
adjRep (AdjZero shape pt) = ([], \[] -> AdjZero shape pt)
adjRep (AdjSparse (Sparse shape pt ivs)) =
  (concatMap ivRep ivs, AdjSparse . Sparse shape pt . repIvs ivs)
  where
    -- Each sparse entry is represented by its index and its value.
    ivRep (_, i, v) = [i, v]
    repIvs ((check, _, _) : ivs') (i : v : ses) =
      (check', i, v) : repIvs ivs' ses
      where
        -- An 'OutOfBounds' entry is deliberately reassembled as a
        -- bounds-checked entry with a constant-False flag.  The
        -- consumers in this module ('sparseArray', 'updateAdjIndex')
        -- ignore that flag and emit a Safe update, so the new index is
        -- not assumed to be out of bounds.
        check' = case check of
          AssumeBounds -> AssumeBounds
          CheckBounds b -> CheckBounds b
          OutOfBounds -> CheckBounds (Just (constant False)) -- sic!
    repIvs _ _ = []

-- | Conveniently convert a list of Adjs to their representation, as
-- well as produce a function for converting back.
adjsReps :: [Adj] -> ([SubExp], [SubExp] -> [Adj])
adjsReps adjs =
  let (reps, fs) = unzip $ map adjRep adjs
   in (concat reps, zipWith ($) fs . chunks (map length reps))

-- | State of the AD monad.
data RState = RState
  { -- Current adjoint of each primal variable.
    stateAdjs :: M.Map VName Adj,
    -- Loop parameter to tape array (see 'setLoopTape').
    stateLoopTape :: Substitutions,
    -- Substitutions applied to code emitted by 'returnSweepCode'.
    stateSubsts :: Substitutions,
    stateNameSource :: VNameSource
  }

-- | The AD monad: a SOACS builder on top of a 'State' over 'RState'.
newtype ADM a = ADM (BuilderT SOACS (State RState) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState RState,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )

-- All builder operations are delegated to the wrapped 'BuilderT'.
instance MonadBuilder ADM where
  type Rep ADM = SOACS
  mkExpDecM pat e = ADM $ mkExpDecM pat e
  mkBodyM bnds res = ADM $ mkBodyM bnds res
  mkLetNamesM pat e = ADM $ mkLetNamesM pat e

  addStms = ADM . addStms
  collectStms (ADM m) = ADM $ collectStms m

-- Fresh names are drawn from the name source stored in 'RState'.
instance MonadFreshNames (State RState) where
  getNameSource = gets stateNameSource
  putNameSource src = modify (\env -> env {stateNameSource = src})

-- | Run an 'ADM' action with no adjoints, loop tape or substitutions,
-- threading the caller's name source.  Only the first component of the
-- 'runBuilderT' result (the action's value) is kept.
runADM :: (MonadFreshNames m) => ADM a -> m a
runADM (ADM m) =
  modifyNameSource $ \vn ->
    second stateNameSource $
      runState
        (fst <$> runBuilderT m mempty)
        (RState mempty mempty mempty vn)

-- | Manifest an adjoint as a variable: 'AdjVal' is bound directly,
-- 'AdjSparse' via 'sparseArray' and 'AdjZero' via 'zeroArray'.
adjVal :: Adj -> ADM VName
adjVal (AdjVal se) = letExp "const_adj" $ BasicOp $ SubExp se
adjVal (AdjSparse sparse) = sparseArray sparse
adjVal (AdjZero shape t) = zeroArray shape $ Prim t

-- | Set a specific adjoint.
setAdj :: VName -> Adj -> ADM ()
setAdj v v_adj = modify $ \env ->
  env {stateAdjs = M.insert v v_adj $ stateAdjs env}

-- | Set an 'AdjVal' adjoint.  Simple wrapper around 'setAdj'.
insAdj :: VName -> VName -> ADM ()
insAdj v = setAdj v . AdjVal . Var

-- | A fresh name for the adjoint of the given variable.
adjVName :: VName -> ADM VName
adjVName v = newVName (baseString v <> "_adj")

-- | Create copies of all arrays consumed in the given statement, and
-- return statements which include copies of the consumed arrays.
-- The returned substitution maps each consumed array to its copy;
-- the reverse mapping (copy to original) is recorded with
-- 'addSubstitution'.
--
-- See Note [Consumption].
copyConsumedArrsInStm :: Stm SOACS -> ADM (Substitutions, Stms SOACS)
copyConsumedArrsInStm s = inScopeOf s $ collectStms $ copyConsumedArrsInStm' s
  where
    copyConsumedArrsInStm' stm =
      let onConsumed v = inScopeOf s $ do
            v_t <- lookupType v
            case v_t of
              Array {} -> do
                v' <-
                  letExp (baseString v <> "_ad_copy") . BasicOp $
                    Replicate mempty (Var v)
                addSubstitution v' v
                pure [(v, v')]
              _ -> pure mempty
       in M.fromList . mconcat
            <$> mapM onConsumed (namesToList $ consumedInStms $ fst (Alias.analyseStms mempty (oneStm stm)))

-- | Copy every array consumed in the body (except those listed in the
-- first argument) and return a substitution from each original to its
-- copy.  Calls 'error' if a consumed variable is an accumulator.
copyConsumedArrsInBody :: [VName] -> Body SOACS -> ADM Substitutions
copyConsumedArrsInBody dontCopy b =
  mconcat <$> mapM onConsumed (filter (`notElem` dontCopy) $ namesToList $ consumedInBody (Alias.analyseBody mempty b))
  where
    onConsumed v = do
      v_t <- lookupType v
      case v_t of
        Acc {} -> error $ "copyConsumedArrsInBody: Acc " <> prettyString v
        Array {} ->
          M.singleton v
            <$> letExp (baseString v <> "_ad_copy") (BasicOp $ Replicate mempty (Var v))
        _ -> pure mempty

-- | Run an action and emit its statements with the substitutions
-- accumulated via 'addSubstitution' applied to them.
returnSweepCode :: ADM a -> ADM a
returnSweepCode m = do
  (a, stms) <- collectStms m
  substs <- gets stateSubsts
  addStms $ substituteNames substs stms
  pure a

-- | Record that @v@ should be replaced by @v'@ in code emitted through
-- 'returnSweepCode'.
addSubstitution :: VName -> VName -> ADM ()
addSubstitution v v' = modify $ \env ->
  env {stateSubsts = M.insert v v' $ stateSubsts env}

-- | While evaluating this action, pretend these variables have no
-- adjoints.  Restore current adjoints afterwards.  This is used for
-- handling certain nested operations.  XXX: feels like this should
-- really be part of subAD, somehow.  Main challenge is that we don't
-- want to blank out Accumulator adjoints.  Also, might be inefficient
-- to blank out array adjoints.
noAdjsFor :: Names -> ADM a -> ADM a
noAdjsFor names m = do
  -- Keep each name paired with its adjoint.  (Zipping the names with a
  -- 'mapMaybe'-filtered list of adjoints would misalign the pairs as
  -- soon as one name has no adjoint.)
  old <- gets $ \env ->
    mapMaybe (\v -> (,) v <$> M.lookup v (stateAdjs env)) names'
  modify $ \env -> env {stateAdjs = foldl' (flip M.delete) (stateAdjs env) names'}
  x <- m
  -- Left-biased union: the saved adjoints take precedence.
  modify $ \env -> env {stateAdjs = M.fromList old <> stateAdjs env}
  pure x
  where
    names' = namesToList names

-- The operator used to add two adjoints of a primitive type.
addBinOp :: PrimType -> BinOp
addBinOp (IntType it) = Add it OverflowWrap
addBinOp (FloatType ft) = FAdd ft
addBinOp Bool = LogAnd
addBinOp Unit = LogAnd

-- | @tabNest n vs f@ wraps @f@ in @n@ nested maps over the outer
-- dimension of @vs@.  @f@ receives the index variables (outermost
-- first) and the innermost elements of @vs@.
tabNest :: Int -> [VName] -> ([VName] -> [VName] -> ADM [VName]) -> ADM [VName]
tabNest = tabNest' []
  where
    tabNest' is 0 vs f = f (reverse is) vs
    tabNest' is n vs f = do
      vs_ts <- mapM lookupType vs
      let w = arraysSize 0 vs_ts
      iota <-
        letExp "tab_iota" . BasicOp $
          Iota w (intConst Int64 0) (intConst Int64 1) Int64
      iparam <- newParam "i" $ Prim int64
      params <- forM vs $ \v ->
        newParam (baseString v <> "_p") . rowType =<< lookupType v
      ((ret, res), stms) <- collectStms . localScope (scopeOfLParams (iparam : params)) $ do
        res <- tabNest' (paramName iparam : is) (n - 1) (map paramName params) f
        ret <- mapM lookupType res
        pure (ret, varsRes res)
      let lam = Lambda (iparam : params) ret (Body () stms res)
      letTupExp "tab" $ Op $ Screma w (iota : vs) (mapSOAC lam)

-- | Construct a lambda for adding two values of the given type.
addLambda :: Type -> ADM (Lambda SOACS)
addLambda (Prim pt) = binOpLambda (addBinOp pt) pt
addLambda t@Array {} = do
  xs_p <- newParam "xs" t
  ys_p <- newParam "ys" t
  lam <- addLambda $ rowType t
  body <- insertStmsM $ do
    res <-
      letSubExp "lam_map" . Op $
        Screma (arraySize 0 t) [paramName xs_p, paramName ys_p] (mapSOAC lam)
    pure $ resultBody [res]
  pure
    Lambda
      { lambdaParams = [xs_p, ys_p],
        lambdaReturnType = [t],
        lambdaBody = body
      }
addLambda t = error $ "addLambda: " ++ show t

-- Construct an expression for adding the two variables.
addExp :: VName -> VName -> ADM (Exp SOACS)
addExp x y = do
  x_t <- lookupType x
  case x_t of
    Prim pt ->
      pure $ BasicOp $ BinOp (addBinOp pt) (Var x) (Var y)
    Array {} -> do
      lam <- addLambda $ rowType x_t
      pure $ Op $ Screma (arraySize 0 x_t) [x, y] (mapSOAC lam)
    _ ->
      error $ "addExp: unexpected type: " ++ prettyString x_t

-- | The adjoint of a variable; an 'AdjZero' of matching shape if none
-- has been set.
lookupAdj :: VName -> ADM Adj
lookupAdj v = do
  maybeAdj <- gets $ M.lookup v . stateAdjs
  case maybeAdj of
    Nothing -> do
      v_t <- lookupType v
      case v_t of
        Acc _ shape [Prim t] _ -> pure $ AdjZero shape t
        _ -> pure $ AdjZero (arrayShape v_t) (elemType v_t)
    Just v_adj -> pure v_adj

-- | The adjoint of a variable, manifested with 'adjVal'.
lookupAdjVal :: VName -> ADM VName
lookupAdjVal v = adjVal =<< lookupAdj v

-- | Add the contribution @d@ to the adjoint of @v@.  Accumulator
-- adjoints are updated with 'UpdateAcc'; others are summed.
updateAdj :: VName -> VName -> ADM ()
updateAdj v d = do
  maybeAdj <- gets $ M.lookup v . stateAdjs
  case maybeAdj of
    Nothing ->
      insAdj v d
    Just adj -> do
      v_adj <- adjVal adj
      v_adj_t <- lookupType v_adj
      case v_adj_t of
        Acc {} -> do
          dims <- arrayDims <$> lookupType d
          ~[v_adj'] <-
            tabNest (length dims) [d, v_adj] $ \is [d', v_adj'] ->
              letTupExp "acc" . BasicOp $
                UpdateAcc Safe v_adj' (map Var is) [Var d']
          insAdj v v_adj'
        _ -> do
          v_adj' <- letExp (baseString v <> "_adj") =<< addExp v_adj d
          insAdj v v_adj'

-- | Add the contribution @d@ to the given slice of @v@'s adjoint.
updateAdjSlice :: Slice SubExp -> VName -> VName -> ADM ()
updateAdjSlice (Slice [DimFix i]) v d =
  updateAdjIndex v (AssumeBounds, i) (Var d)
updateAdjSlice slice v d = do
  t <- lookupType v
  v_adj <- lookupAdjVal v
  v_adj_t <- lookupType v_adj
  v_adj' <- case v_adj_t of
    Acc {} -> do
      let dims = sliceDims slice
      ~[v_adj'] <-
        tabNest (length dims) [d, v_adj] $ \is [d', v_adj'] -> do
          slice' <-
            traverse (toSubExp "index") $
              fixSlice (fmap pe64 slice) $
                map le64 is
          letTupExp (baseString v_adj') . BasicOp $
            UpdateAcc Safe v_adj' slice' [Var d']
      pure v_adj'
    _ -> do
      v_adjslice <-
        if primType t
          then pure v_adj
          else letExp (baseString v ++ "_slice") $ BasicOp $ Index v_adj slice
      letInPlace "updated_adj" v_adj slice =<< addExp v_adjslice d
  insAdj v v_adj'

-- | Like 'updateAdj', but a constant needs no adjoint and is ignored.
updateSubExpAdj :: SubExp -> VName -> ADM ()
updateSubExpAdj Constant {} _ = pure ()
updateSubExpAdj (Var v) d = updateAdj v d

-- The index may be negative, in which case the update has no effect.
updateAdjIndex :: VName -> (InBounds, SubExp) -> SubExp -> ADM ()
updateAdjIndex v (check, i) se = do
  maybeAdj <- gets $ M.lookup v . stateAdjs
  t <- lookupType v
  let iv = (check, i, se)
  case maybeAdj of
    -- Without a manifest adjoint, keep the update symbolic.
    Nothing ->
      setAdj v $ AdjSparse $ Sparse (arrayShape t) (elemType t) [iv]
    Just AdjZero {} ->
      setAdj v $ AdjSparse $ Sparse (arrayShape t) (elemType t) [iv]
    Just (AdjSparse (Sparse shape pt ivs)) ->
      setAdj v $ AdjSparse $ Sparse shape pt $ iv : ivs
    Just adj@AdjVal {} -> do
      v_adj <- adjVal adj
      v_adj_t <- lookupType v_adj
      se_v <- letExp "se_v" $ BasicOp $ SubExp se
      insAdj v =<< case v_adj_t of
        Acc {}
          | check == OutOfBounds ->
              pure v_adj
          | otherwise -> do
              dims <- arrayDims <$> lookupType se_v
              ~[v_adj'] <-
                tabNest (length dims) [se_v, v_adj] $ \is [se_v', v_adj'] ->
                  letTupExp "acc" . BasicOp $
                    UpdateAcc Safe v_adj' (i : map Var is) [Var se_v']
              pure v_adj'
        _ -> do
          let stms s = do
                v_adj_i <-
                  letExp (baseString v_adj <> "_i") . BasicOp $
                    Index v_adj $ fullSlice v_adj_t [DimFix i]
                se_update <- letSubExp "updated_adj_i" =<< addExp se_v v_adj_i
                letExp (baseString v_adj) . BasicOp $
                  Update s v_adj (fullSlice v_adj_t [DimFix i]) se_update
          case check of
            CheckBounds _ -> stms Safe
            AssumeBounds -> stms Unsafe
            OutOfBounds -> pure v_adj

-- | Is this primal variable active in the AD sense?  FIXME: this is
-- (obviously) much too conservative.
isActive :: VName -> ADM Bool
isActive = fmap (/= Prim Unit) . lookupType

-- | Ignore any changes to adjoints made while evaluating this action.
subAD :: ADM a -> ADM a
subAD m = do
  -- Snapshot the adjoint map, run the action, then reinstate the
  -- snapshot.  Emitted statements and other state are kept.
  old_state_adjs <- gets stateAdjs
  x <- m
  modify $ \s -> s {stateAdjs = old_state_adjs}
  pure x

-- | Ignore any changes to the substitutions (see 'addSubstitution')
-- made while evaluating this action.
subSubsts :: ADM a -> ADM a
subSubsts m = do
  old_state_substs <- gets stateSubsts
  x <- m
  modify $ \s -> s {stateSubsts = old_state_substs}
  pure x

-- | Differentiation operations passed around as a record: one for
-- lambdas and one for statements.
data VjpOps = VjpOps
  { vjpLambda :: [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS),
    vjpStm :: Stm SOACS -> ADM () -> ADM ()
  }

-- | @setLoopTape v vs@ establishes @vs@ as the name of the array
-- where values of loop parameter @v@ from the forward pass are
-- stored.
setLoopTape :: VName -> VName -> ADM ()
setLoopTape v vs = modify $ \env ->
  env {stateLoopTape = M.insert v vs $ stateLoopTape env}

-- | Look-up the name of the array where @v@ is stored.
lookupLoopTape :: VName -> ADM (Maybe VName)
lookupLoopTape v = gets $ M.lookup v . stateLoopTape

-- | @substLoopTape v v'@ makes @v'@ refer to the same tape array as
-- @v@.  That is, if @v |-> vs@ then afterwards also @v' |-> vs@.  The
-- entry for @v@ is left in place.  Does nothing if @v@ has no tape.
substLoopTape :: VName -> VName -> ADM ()
substLoopTape v v' = mapM_ (setLoopTape v') =<< lookupLoopTape v

-- | Renames the keys of the loop tape.  Useful for fixing the names
-- in the loop tape after a loop rename.  Old keys are not removed
-- (see 'substLoopTape').
renameLoopTape :: Substitutions -> ADM ()
renameLoopTape = mapM_ (uncurry substLoopTape) . M.toList

-- Note [Consumption]
--
-- Parts of this transformation depend on duplicating computation.
-- This is a problem when a primal expression consumes arrays (via
-- e.g. Update).  For example, consider how we handle this conditional:
--
--   if b then ys with [0] = 0 else ys
--
-- This consumes the array 'ys', which means that when we later
-- generate code for the return sweep, we can no longer use 'ys'.
-- This is a problem, because when we call 'diffBody' on the branch
-- bodies, we'll keep the primal code (maybe it'll be removed by
-- simplification later - we cannot know).  A similar issue occurs for
-- SOACs.
Our solution is to make copies of all consumes arrays: -- -- let ys_copy = copy ys -- -- Then we generate code for the return sweep as normal, but replace -- _every instance_ of 'ys' in the generated code with 'ys_copy'. -- This works because Futhark does not have *semantic* in-place -- updates - any uniqueness violation can be replaced with copies (on -- arrays, anyway). -- -- If we are lucky, the uses of 'ys_copy' will be removed by -- simplification, and there will be no overhead. But even if not, -- this is still (asymptotically) efficient because the array that is -- being consumed must in any case have been produced within the code -- that we are differentiating, so a copy is at most a scalar -- overhead. This is _not_ the case when loops are involved. -- -- Also, the above only works for arrays, not accumulator variables. -- Those will need some other mechanism. futhark-0.25.27/src/Futhark/AD/Rev/Reduce.hs000066400000000000000000000244761475065116200203350ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} module Futhark.AD.Rev.Reduce ( diffReduce, diffMinMaxReduce, diffVecReduce, diffMulReduce, ) where import Control.Monad import Futhark.AD.Rev.Monad import Futhark.Analysis.PrimExp.Convert import Futhark.Builder import Futhark.IR.SOACS import Futhark.Tools import Futhark.Transform.Rename eReverse :: (MonadBuilder m) => VName -> m VName eReverse arr = do arr_t <- lookupType arr let w = arraySize 0 arr_t start <- letSubExp "rev_start" . BasicOp $ BinOp (Sub Int64 OverflowUndef) w (intConst Int64 1) let stride = intConst Int64 (-1) slice = fullSlice arr_t [DimSlice start w stride] letExp (baseString arr <> "_rev") $ BasicOp $ Index arr slice scanExc :: (MonadBuilder m, Rep m ~ SOACS) => String -> Scan SOACS -> [VName] -> m [VName] scanExc desc scan arrs = do w <- arraysSize 0 <$> mapM lookupType arrs form <- scanSOAC [scan] res_incl <- letTupExp (desc <> "_incl") $ Op $ Screma w arrs form iota <- letExp "iota" . 
BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64 iparam <- newParam "iota_param" $ Prim int64 lam <- mkLambda [iparam] $ do let first_elem = eCmpOp (CmpEq int64) (eSubExp (Var (paramName iparam))) (eSubExp (intConst Int64 0)) prev = toExp $ le64 (paramName iparam) - 1 fmap subExpsRes . letTupExp' "scan_ex_res" =<< eIf first_elem (resultBodyM $ scanNeutral scan) (eBody $ map (`eIndex` [prev]) res_incl) letTupExp desc $ Op $ Screma w [iota] (mapSOAC lam) mkF :: Lambda SOACS -> ADM ([VName], Lambda SOACS) mkF lam = do lam_l <- renameLambda lam lam_r <- renameLambda lam let q = length $ lambdaReturnType lam (lps, aps) = splitAt q $ lambdaParams lam_l (ips, rps) = splitAt q $ lambdaParams lam_r lam' <- mkLambda (lps <> aps <> rps) $ do lam_l_res <- bodyBind $ lambdaBody lam_l forM_ (zip ips lam_l_res) $ \(ip, SubExpRes cs se) -> certifying cs $ letBindNames [paramName ip] $ BasicOp $ SubExp se bodyBind $ lambdaBody lam_r pure (map paramName aps, lam') diffReduce :: VjpOps -> [VName] -> SubExp -> [VName] -> Reduce SOACS -> ADM () diffReduce _ops [adj] w [a] red | Just [(op, _, _, _)] <- lamIsBinOp $ redLambda red, isAdd op = do adj_rep <- letExp (baseString adj <> "_rep") $ BasicOp $ Replicate (Shape [w]) $ Var adj void $ updateAdj a adj_rep where isAdd FAdd {} = True isAdd Add {} = True isAdd _ = False -- -- Differentiating a general single reduce: -- let y = reduce \odot ne as -- Forward sweep: -- let ls = scan_exc \odot ne as -- let rs = scan_exc \odot' ne (reverse as) -- Reverse sweep: -- let as_c = map3 (f_bar y_bar) ls as (reverse rs) -- where -- x \odot' y = y \odot x -- y_bar is the adjoint of the result y -- f l_i a_i r_i = l_i \odot a_i \odot r_i -- f_bar = the reverse diff of f with respect to a_i under the adjoint y_bar -- The plan is to create -- one scanomap SOAC which computes ls and rs -- another map which computes as_c -- diffReduce ops pat_adj w as red = do red' <- renameRed red flip_red <- renameRed =<< flipReduce red ls <- scanExc "ls" 
(redToScan red') as rs <- mapM eReverse =<< scanExc "ls" (redToScan flip_red) =<< mapM eReverse as (as_params, f) <- mkF $ redLambda red f_adj <- vjpLambda ops (map adjFromVar pat_adj) as_params f as_adj <- letTupExp "adjs" $ Op $ Screma w (ls ++ as ++ rs) (mapSOAC f_adj) zipWithM_ updateAdj as as_adj where renameRed (Reduce comm lam nes) = Reduce comm <$> renameLambda lam <*> pure nes redToScan :: Reduce SOACS -> Scan SOACS redToScan (Reduce _ lam nes) = Scan lam nes flipReduce (Reduce comm lam nes) = do lam' <- renameLambda lam {lambdaParams = flipParams $ lambdaParams lam} pure $ Reduce comm lam' nes flipParams ps = uncurry (flip (++)) $ splitAt (length ps `div` 2) ps -- -- Special case of reduce with min/max: -- let x = reduce minmax ne as -- Forward trace (assuming w = length as): -- let (x, x_ind) = -- reduce (\ acc_v acc_i v i -> -- if (acc_v == v) then (acc_v, min acc_i i) -- else if (acc_v == minmax acc_v v) -- then (acc_v, acc_i) -- else (v, i)) -- (ne_min, -1) -- (zip as (iota w)) -- Reverse trace: -- num_elems = i64.bool (0 <= x_ind) -- m_bar_repl = replicate num_elems m_bar -- as_bar[x_ind:num_elems:1] += m_bar_repl diffMinMaxReduce :: VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> ADM () -> ADM () diffMinMaxReduce _ops x aux w minmax ne as m = do let t = binOpType minmax acc_v_p <- newParam "acc_v" $ Prim t acc_i_p <- newParam "acc_i" $ Prim int64 v_p <- newParam "v" $ Prim t i_p <- newParam "i" $ Prim int64 red_lam <- mkLambda [acc_v_p, acc_i_p, v_p, i_p] $ fmap varsRes . 
letTupExp "idx_res" =<< eIf (eCmpOp (CmpEq t) (eParam acc_v_p) (eParam v_p)) ( eBody [ eParam acc_v_p, eBinOp (SMin Int64) (eParam acc_i_p) (eParam i_p) ] ) ( eBody [ eIf ( eCmpOp (CmpEq t) (eParam acc_v_p) (eBinOp minmax (eParam acc_v_p) (eParam v_p)) ) (eBody [eParam acc_v_p, eParam acc_i_p]) (eBody [eParam v_p, eParam i_p]) ] ) red_iota <- letExp "red_iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64 form <- reduceSOAC [Reduce Commutative red_lam [ne, intConst Int64 (-1)]] x_ind <- newVName (baseString x <> "_ind") auxing aux $ letBindNames [x, x_ind] $ Op $ Screma w [as, red_iota] form m x_adj <- lookupAdjVal x in_bounds <- letSubExp "minmax_in_bounds" . BasicOp $ CmpOp (CmpSlt Int64) (intConst Int64 0) w updateAdjIndex as (CheckBounds (Just in_bounds), Var x_ind) (Var x_adj) -- -- Special case of vectorised reduce: -- let x = reduce (map2 op) nes as -- Idea: -- rewrite to -- let x = map2 (\as ne -> reduce op ne as) (transpose as) nes -- and diff diffVecReduce :: VjpOps -> Pat Type -> StmAux () -> SubExp -> Commutativity -> Lambda SOACS -> VName -> VName -> ADM () -> ADM () diffVecReduce ops x aux w iscomm lam ne as m = do stms <- collectStms_ $ do rank <- arrayRank <$> lookupType as let rear = [1, 0] ++ drop 2 [0 .. rank - 1] tran_as <- letExp "tran_as" $ BasicOp $ Rearrange rear as ts <- lookupType tran_as t_ne <- lookupType ne as_param <- newParam "as_param" $ rowType ts ne_param <- newParam "ne_param" $ rowType t_ne reduce_form <- reduceSOAC [Reduce iscomm lam [Var $ paramName ne_param]] map_lam <- mkLambda [as_param, ne_param] $ fmap varsRes . 
letTupExp "idx_res" $ Op $ Screma w [paramName as_param] reduce_form addStm $ Let x aux $ Op $ Screma (arraySize 0 ts) [tran_as, ne] $ mapSOAC map_lam foldr (vjpStm ops) m stms -- -- Special case of reduce with mul: -- let x = reduce (*) ne as -- Forward trace (assuming w = length as): -- let (p, z) = map (\a -> if a == 0 then (1, 1) else (a, 0)) as -- non_zero_prod = reduce (*) ne p -- zr_count = reduce (+) 0 z -- let x = -- if 0 == zr_count -- then non_zero_prod -- else 0 -- Reverse trace: -- as_bar = map2 -- (\a a_bar -> -- if zr_count == 0 -- then a_bar + non_zero_prod/a * x_bar -- else if zr_count == 1 -- then a_bar + (if a == 0 then non_zero_prod * x_bar else 0) -- else as_bar -- ) as as_bar diffMulReduce :: VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> ADM () -> ADM () diffMulReduce _ops x aux w mul ne as m = do let t = binOpType mul let const_zero = eSubExp $ Constant $ blankPrimValue t a_param <- newParam "a" $ Prim t map_lam <- mkLambda [a_param] $ fmap varsRes . letTupExp "map_res" =<< eIf (eCmpOp (CmpEq t) (eParam a_param) const_zero) (eBody $ fmap eSubExp [Constant $ onePrimValue t, intConst Int64 1]) (eBody [eParam a_param, eSubExp $ intConst Int64 0]) ps <- newVName "ps" zs <- newVName "zs" auxing aux $ letBindNames [ps, zs] $ Op $ Screma w [as] $ mapSOAC map_lam red_lam_mul <- binOpLambda mul t red_lam_add <- binOpLambda (Add Int64 OverflowUndef) int64 red_form_mul <- reduceSOAC $ pure $ Reduce Commutative red_lam_mul $ pure ne red_form_add <- reduceSOAC $ pure $ Reduce Commutative red_lam_add $ pure $ intConst Int64 0 nz_prods <- newVName "non_zero_prod" zr_count <- newVName "zero_count" auxing aux $ letBindNames [nz_prods] $ Op $ Screma w [ps] red_form_mul auxing aux $ letBindNames [zr_count] $ Op $ Screma w [zs] red_form_add auxing aux $ letBindNames [x] =<< eIf (toExp $ 0 .==. 
le64 zr_count) (eBody $ pure $ eSubExp $ Var nz_prods) (eBody $ pure const_zero) m x_adj <- lookupAdjVal x a_param_rev <- newParam "a" $ Prim t map_lam_rev <- mkLambda [a_param_rev] $ fmap varsRes . letTupExp "adj_res" =<< eIf (toExp $ 0 .==. le64 zr_count) ( eBody $ pure $ eBinOp mul (eSubExp $ Var x_adj) $ eBinOp (getDiv t) (eSubExp $ Var nz_prods) $ eParam a_param_rev ) ( eBody $ pure $ eIf (toExp $ 1 .==. le64 zr_count) ( eBody $ pure $ eIf (eCmpOp (CmpEq t) (eParam a_param_rev) const_zero) ( eBody $ pure $ eBinOp mul (eSubExp $ Var x_adj) $ eSubExp $ Var nz_prods ) (eBody $ pure const_zero) ) (eBody $ pure const_zero) ) as_adjup <- letExp "adjs" $ Op $ Screma w [as] $ mapSOAC map_lam_rev updateAdj as as_adjup where getDiv :: PrimType -> BinOp getDiv (IntType t) = SDiv t Unsafe getDiv (FloatType t) = FDiv t getDiv _ = error "In getDiv, Reduce.hs: input not supported" futhark-0.25.27/src/Futhark/AD/Rev/SOAC.hs000066400000000000000000000211201475065116200176320ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} module Futhark.AD.Rev.SOAC (vjpSOAC) where import Control.Monad import Futhark.AD.Rev.Hist import Futhark.AD.Rev.Map import Futhark.AD.Rev.Monad import Futhark.AD.Rev.Reduce import Futhark.AD.Rev.Scan import Futhark.AD.Rev.Scatter import Futhark.Analysis.PrimExp.Convert import Futhark.Builder import Futhark.IR.SOACS import Futhark.Tools import Futhark.Util (chunks) -- We split any multi-op scan or reduction into multiple operations so -- we can detect special cases. Post-AD, the result may be fused -- again. splitScanRed :: VjpOps -> ([a] -> ADM (ScremaForm SOACS), a -> [SubExp]) -> (Pat Type, StmAux (), [a], SubExp, [VName]) -> ADM () -> ADM () splitScanRed vjpops (opSOAC, opNeutral) (pat, aux, ops, w, as) m = do let ks = map (length . 
opNeutral) ops pat_per_op = map Pat $ chunks ks $ patElems pat as_per_op = chunks ks as onOps (op : ops') (op_pat : op_pats') (op_as : op_as') = do op_form <- opSOAC [op] vjpSOAC vjpops op_pat aux (Screma w op_as op_form) $ onOps ops' op_pats' op_as' onOps _ _ _ = m onOps ops pat_per_op as_per_op -- We split multi-op histograms into multiple operations so we -- can take advantage of special cases. Post-AD, the result may -- be fused again. splitHist :: VjpOps -> Pat Type -> StmAux () -> [HistOp SOACS] -> SubExp -> [VName] -> [VName] -> ADM () -> ADM () splitHist vjpops pat aux ops w is as m = do let ks = map (length . histNeutral) ops pat_per_op = map Pat $ chunks ks $ patElems pat as_per_op = chunks ks as onOps (op : ops') (op_pat : op_pats') (op_is : op_is') (op_as : op_as') = do f <- mkIdentityLambda . (Prim int64 :) =<< traverse lookupType op_as vjpSOAC vjpops op_pat aux (Hist w (op_is : op_as) [op] f) $ onOps ops' op_pats' op_is' op_as' onOps _ _ _ _ = m onOps ops pat_per_op is as_per_op -- unfusing a map-histogram construct into a map and a histogram. histomapToMapAndHist :: Pat Type -> (SubExp, [HistOp SOACS], Lambda SOACS, [VName]) -> ADM (Stm SOACS, Stm SOACS) histomapToMapAndHist (Pat pes) (w, histops, map_lam, as) = do map_pat <- traverse accMapPatElem $ lambdaReturnType map_lam let map_stm = mkLet map_pat $ Op $ Screma w as $ mapSOAC map_lam new_lam <- mkIdentityLambda $ lambdaReturnType map_lam let hist_stm = Let (Pat pes) (defAux ()) $ Op $ Hist w (map identName map_pat) histops new_lam pure (map_stm, hist_stm) where accMapPatElem = newIdent "hist_map_res" . 
(`arrayOfRow` w) commonSOAC :: Pat Type -> StmAux () -> SOAC SOACS -> ADM () -> ADM [Adj] commonSOAC pat aux soac m = do addStm $ Let pat aux $ Op soac m returnSweepCode $ mapM lookupAdj $ patNames pat -- Reverse-mode differentiation of SOACs vjpSOAC :: VjpOps -> Pat Type -> StmAux () -> SOAC SOACS -> ADM () -> ADM () -- Differentiating Reduces vjpSOAC ops pat aux soac@(Screma w as form) m | Just [Reduce iscomm lam [Var ne]] <- isReduceSOAC form, [a] <- as, Just op <- mapOp lam = diffVecReduce ops pat aux w iscomm op ne a m | Just reds <- isReduceSOAC form, length reds > 1 = splitScanRed ops (reduceSOAC, redNeutral) (pat, aux, reds, w, as) m | Just [red] <- isReduceSOAC form, [x] <- patNames pat, [ne] <- redNeutral red, [a] <- as, Just [(op, _, _, _)] <- lamIsBinOp $ redLambda red, isMinMaxOp op = diffMinMaxReduce ops x aux w op ne a m | Just [red] <- isReduceSOAC form, [x] <- patNames pat, [ne] <- redNeutral red, [a] <- as, Just [(op, _, _, _)] <- lamIsBinOp $ redLambda red, isMulOp op = diffMulReduce ops x aux w op ne a m | Just red <- singleReduce <$> isReduceSOAC form = do pat_adj <- mapM adjVal =<< commonSOAC pat aux soac m diffReduce ops pat_adj w as red -- Differentiating Scans vjpSOAC ops pat aux soac@(Screma w as form) m | Just [Scan lam [ne]] <- isScanSOAC form, [x] <- patNames pat, [a] <- as, Just [(op, _, _, _)] <- lamIsBinOp lam, isAddOp op = do void $ commonSOAC pat aux soac m diffScanAdd ops x w lam ne a | Just [Scan lam ne] <- isScanSOAC form, Just op <- mapOp lam = do diffScanVec ops (patNames pat) aux w op ne as m | Just scans <- isScanSOAC form, length scans > 1 = splitScanRed ops (scanSOAC, scanNeutral) (pat, aux, scans, w, as) m | Just red <- singleScan <$> isScanSOAC form = do void $ commonSOAC pat aux soac m diffScan ops (patNames pat) w as red -- Differentiating Maps vjpSOAC ops pat aux soac@(Screma w as form) m | Just lam <- isMapSOAC form = do pat_adj <- commonSOAC pat aux soac m vjpMap ops pat_adj aux w lam as -- Differentiating Redomaps 
vjpSOAC ops pat _aux (Screma w as form) m | Just (reds, map_lam) <- isRedomapSOAC form = do (mapstm, redstm) <- redomapToMapAndReduce pat (w, reds, map_lam, as) vjpStm ops mapstm $ vjpStm ops redstm m -- Differentiating Scanomaps vjpSOAC ops pat _aux (Screma w as form) m | Just (scans, map_lam) <- isScanomapSOAC form = do (mapstm, scanstm) <- scanomapToMapAndScan pat (w, scans, map_lam, as) vjpStm ops mapstm $ vjpStm ops scanstm m -- Differentiating Scatter vjpSOAC ops pat aux (Scatter w ass written_info lam) m | isIdentityLambda lam = vjpScatter ops pat aux (w, ass, lam, written_info) m | otherwise = do map_idents <- mapM (\t -> newIdent "map_res" (arrayOfRow t w)) $ lambdaReturnType lam let map_stm = mkLet map_idents $ Op $ Screma w ass $ mapSOAC lam lam_id <- mkIdentityLambda $ lambdaReturnType lam let scatter_stm = Let pat aux $ Op $ Scatter w (map identName map_idents) written_info lam_id vjpStm ops map_stm $ vjpStm ops scatter_stm m -- Differentiating Histograms vjpSOAC ops pat aux (Hist n as histops f) m | isIdentityLambda f, length histops > 1 = do let (is, vs) = splitAt (length histops) as splitHist ops pat aux histops n is vs m vjpSOAC ops pat aux (Hist n [is, vs] [histop] f) m | isIdentityLambda f, [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [Var ne] lam <- histop, -- Note that the operator is vectorised, so `ne` cannot be a 'PrimValue'. 
Just op <- mapOp lam = diffVecHist ops x aux n op ne is vs w rf dst m | isIdentityLambda f, [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [ne] lam <- histop, lam' <- nestedMapOp lam, Just [(op, _, _, _)] <- lamIsBinOp lam', isMinMaxOp op = diffMinMaxHist ops x aux n op ne is vs w rf dst m | isIdentityLambda f, [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [ne] lam <- histop, lam' <- nestedMapOp lam, Just [(op, _, _, _)] <- lamIsBinOp lam', isMulOp op = diffMulHist ops x aux n op ne is vs w rf dst m | isIdentityLambda f, [x] <- patNames pat, HistOp (Shape [w]) rf [dst] [ne] lam <- histop, lam' <- nestedMapOp lam, Just [(op, _, _, _)] <- lamIsBinOp lam', isAddOp op = diffAddHist ops x aux n lam ne is vs w rf dst m vjpSOAC ops pat aux (Hist n as [histop] f) m | isIdentityLambda f, HistOp (Shape w) rf dst ne lam <- histop = do diffHist ops (patNames pat) aux n lam ne as w rf dst m vjpSOAC ops pat _aux (Hist n as histops f) m | not (isIdentityLambda f) = do (mapstm, redstm) <- histomapToMapAndHist pat (n, histops, f, as) vjpStm ops mapstm $ vjpStm ops redstm m vjpSOAC _ _ _ soac _ = error $ "vjpSOAC unhandled:\n" ++ prettyString soac --------------- --- Helpers --- --------------- isMinMaxOp :: BinOp -> Bool isMinMaxOp (SMin _) = True isMinMaxOp (UMin _) = True isMinMaxOp (FMin _) = True isMinMaxOp (SMax _) = True isMinMaxOp (UMax _) = True isMinMaxOp (FMax _) = True isMinMaxOp _ = False isMulOp :: BinOp -> Bool isMulOp (Mul _ _) = True isMulOp (FMul _) = True isMulOp _ = False isAddOp :: BinOp -> Bool isAddOp (Add _ _) = True isAddOp (FAdd _) = True isAddOp _ = False -- Identifies vectorized operators (lambdas): -- if the lambda argument is a map, then returns -- just the map's lambda; otherwise nothing. 
mapOp :: Lambda SOACS -> Maybe (Lambda SOACS) mapOp (Lambda [pa1, pa2] _ lam_body) | [SubExpRes cs r] <- bodyResult lam_body, cs == mempty, [map_stm] <- stmsToList (bodyStms lam_body), (Let (Pat [pe]) _ (Op scrm)) <- map_stm, (Screma _ [a1, a2] (ScremaForm map_lam [] [])) <- scrm, (a1 == paramName pa1 && a2 == paramName pa2) || (a1 == paramName pa2 && a2 == paramName pa1), r == Var (patElemName pe) = Just map_lam mapOp _ = Nothing -- getting the innermost lambda of a perfect-map nest -- (i.e., the first lambda that does not consists of exactly a map) nestedMapOp :: Lambda SOACS -> Lambda SOACS nestedMapOp lam = maybe lam nestedMapOp (mapOp lam) futhark-0.25.27/src/Futhark/AD/Rev/Scan.hs000066400000000000000000000462131475065116200200030ustar00rootroot00000000000000module Futhark.AD.Rev.Scan (diffScan, diffScanVec, diffScanAdd) where import Control.Monad import Data.List (transpose) import Futhark.AD.Rev.Monad import Futhark.Analysis.PrimExp.Convert import Futhark.Builder import Futhark.IR.SOACS import Futhark.IR.SOACS.Simplify (simplifyLambda) import Futhark.Tools import Futhark.Transform.Rename import Futhark.Util (chunk) data FirstOrSecond = WrtFirst | WrtSecond identityM :: Int -> Type -> ADM [[SubExp]] identityM n t = traverse (traverse (letSubExp "id")) [[if i == j then oneExp t else zeroExp t | i <- [1 .. n]] | j <- [1 .. 
n]] matrixMul :: [[PrimExp VName]] -> [[PrimExp VName]] -> PrimType -> [[PrimExp VName]] matrixMul m1 m2 t = let zero = primExpFromSubExp t $ Constant $ blankPrimValue t in [[foldl (~+~) zero $ zipWith (~*~) r q | q <- transpose m2] | r <- m1] matrixVecMul :: [[PrimExp VName]] -> [PrimExp VName] -> PrimType -> [PrimExp VName] matrixVecMul m v t = let zero = primExpFromSubExp t $ Constant $ blankPrimValue t in [foldl (~+~) zero $ zipWith (~*~) v r | r <- m] vectorAdd :: [PrimExp VName] -> [PrimExp VName] -> [PrimExp VName] vectorAdd = zipWith (~+~) orderArgs :: Special -> [a] -> [[a]] orderArgs s lst = chunk (div (length lst) $ specialScans s) lst -- computes `d(x op y)/dx` or d(x op y)/dy mkScanAdjointLam :: VjpOps -> Lambda SOACS -> FirstOrSecond -> [SubExp] -> ADM (Lambda SOACS) mkScanAdjointLam ops lam0 which adjs = do let len = length $ lambdaReturnType lam0 lam <- renameLambda lam0 let p2diff = case which of WrtFirst -> take len $ lambdaParams lam WrtSecond -> drop len $ lambdaParams lam vjpLambda ops (fmap AdjVal adjs) (map paramName p2diff) lam -- Should generate something like: -- `\ j -> let i = n - 1 - j -- if i < n-1 then ( ys_adj[i], df2dx ys[i] xs[i+1]) else (ys_adj[i],1) )` -- where `ys` is the result of scan -- `xs` is the input of scan -- `ys_adj` is the known adjoint of ys -- `j` draw values from `iota n` mkScanFusedMapLam :: -- i and j above are probably swapped in the code below VjpOps -> -- (ops) helper functions SubExp -> -- (w) ~length of arrays e.g. 
xs Lambda SOACS -> -- (scn_lam) the scan to be differentiated ('scan' turned into a lambda) [VName] -> -- (xs) input of the scan (actually as) [VName] -> -- (ys) output of the scan [VName] -> -- (ys_adj) adjoint of ys Special -> -- (s) information about which special case we're working with for the scan derivative Int -> -- (d) dimension of the input (number of elements in the input tuple) ADM (Lambda SOACS) -- output: some kind of codegen for the lambda mkScanFusedMapLam ops w scn_lam xs ys ys_adj s d = do let sc = specialCase s k = specialSubSize s ys_ts <- traverse lookupType ys idmat <- identityM (length ys) $ rowType $ head ys_ts lams <- traverse (mkScanAdjointLam ops scn_lam WrtFirst) idmat par_i <- newParam "i" $ Prim int64 let i = paramName par_i mkLambda [par_i] $ fmap varsRes . letTupExp "x" =<< eIf (toExp $ le64 i .==. 0) ( buildBody_ $ do j <- letSubExp "j" =<< toExp (pe64 w - (le64 i + 1)) y_s <- forM ys_adj $ \y_ -> letSubExp (baseString y_ ++ "_j") =<< eIndex y_ [eSubExp j] let zso = orderArgs s y_s let ido = orderArgs s $ caseJac k sc idmat pure $ subExpsRes $ concat $ zipWith (++) zso $ fmap concat ido ) ( buildBody_ $ do j <- letSubExp "j" =<< toExp (pe64 w - (le64 i + 1)) j1 <- letSubExp "j1" =<< toExp (pe64 w - le64 i) y_s <- forM ys_adj $ \y_ -> letSubExp (baseString y_ ++ "_j") =<< eIndex y_ [eSubExp j] let args = map (`eIndex` [eSubExp j]) ys ++ map (`eIndex` [eSubExp j1]) xs lam_rs <- traverse (`eLambda` args) lams let yso = orderArgs s $ subExpsRes y_s let jaco = orderArgs s $ caseJac k sc $ transpose lam_rs pure $ concat $ zipWith (++) yso $ fmap concat jaco ) where caseJac :: Int -> Maybe SpecialCase -> [[a]] -> [[a]] caseJac _ Nothing jac = jac caseJac k (Just ZeroQuadrant) jac = concat $ zipWith (\i -> map (take k . drop (i * k))) [0 .. 
d `div` k] $ chunk k jac caseJac k (Just MatrixMul) jac = take k <$> take k jac -- a1 a2 b -> a2 + b * a1 linFunT0 :: [PrimExp VName] -> [PrimExp VName] -> [[PrimExp VName]] -> Special -> PrimType -> [PrimExp VName] linFunT0 a1 a2 b s pt = let t = case specialCase s of Just MatrixMul -> concatMap (\v -> matrixVecMul b v pt) $ chunk (specialSubSize s) a1 _ -> matrixVecMul b a1 pt in a2 `vectorAdd` t -- \(a1, b1) (a2, b2) -> (a2 + b2 * a1, b2 * b1) mkScanLinFunO :: Type -> Special -> ADM (Scan SOACS) -- a is an instance of y_bar, b is a Jacobian (a 'c' in the 2023 paper) mkScanLinFunO t s = do let pt = elemType t neu_elm <- mkNeutral $ specialNeutral s let (as, bs) = specialParams s -- input size, Jacobian element count (a1s, b1s, a2s, b2s) <- mkParams (as, bs) -- create sufficient free variables to bind every element of the vectors / matrices let pet = primExpFromSubExp pt . Var -- manifest variable names as expressions let (_, n) = specialNeutral s -- output size (one side of the Jacobian) lam <- mkLambda (map (\v -> Param mempty v (rowType t)) (a1s ++ b1s ++ a2s ++ b2s)) . fmap subExpsRes $ do let [a1s', b1s', a2s', b2s'] = (fmap . 
fmap) pet [a1s, b1s, a2s, b2s] let (b1sm, b2sm) = (chunk n b1s', chunk n b2s') let t0 = linFunT0 a1s' a2s' b2sm s pt let t1 = concat $ matrixMul b2sm b1sm pt traverse (letSubExp "r" <=< toExp) $ t0 ++ t1 pure $ Scan lam neu_elm where mkNeutral (a, b) = do zeros <- replicateM a $ letSubExp "zeros" $ zeroExp $ rowType t idmat <- identityM b $ Prim $ elemType t pure $ zeros ++ concat idmat mkParams (a, b) = do a1s <- replicateM a $ newVName "a1" b1s <- replicateM b $ newVName "b1" a2s <- replicateM a $ newVName "a2" b2s <- replicateM b $ newVName "b2" pure (a1s, b1s, a2s, b2s) -- perform the final map -- let xs_contribs = -- map3 (\ i a r -> if i==0 then r else (df2dy (ys[i-1]) a) \bar{*} r) -- (iota n) xs (reverse ds) mkScanFinalMap :: VjpOps -> SubExp -> Lambda SOACS -> [VName] -> [VName] -> [VName] -> ADM [VName] mkScanFinalMap ops w scan_lam xs ys ds = do let eltps = lambdaReturnType scan_lam par_i <- newParam "i" $ Prim int64 let i = paramName par_i par_x <- zipWithM (\x -> newParam (baseString x ++ "_par_x")) xs eltps map_lam <- mkLambda (par_i : par_x) $ do j <- letSubExp "j" =<< toExp (pe64 w - (le64 i + 1)) dj <- forM ds $ \dd -> letExp (baseString dd ++ "_dj") =<< eIndex dd [eSubExp j] fmap varsRes . letTupExp "scan_contribs" =<< eIf (toExp $ le64 i .==. 0) (resultBodyM $ fmap Var dj) ( buildBody_ $ do lam <- mkScanAdjointLam ops scan_lam WrtSecond $ fmap Var dj im1 <- letSubExp "im1" =<< toExp (le64 i - 1) ys_im1 <- forM ys $ \y -> letSubExp (baseString y <> "_im1") =<< eIndex y [eSubExp im1] let args = map eSubExp $ ys_im1 ++ map (Var . paramName) par_x eLambda lam args ) iota <- letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64 letTupExp "scan_contribs" $ Op $ Screma w (iota : xs) $ mapSOAC map_lam -- | Scan special cases. data SpecialCase = ZeroQuadrant | MatrixMul deriving (Show) -- | Metadata for how to perform the scan for the return sweep. data Special = Special { -- | Size of one of the two dimensions of the Jacobian (e.g. 
3 if -- it's 3x3, must be square because scan must be a->a->a). It's -- the size of the special neutral element, not the element itself specialNeutral :: (Int, Int), -- | Size of input (nr params); Flat size of Jacobian (dim1 * -- dim2)). Number of params for the special lambda. specialParams :: (Int, Int), -- | The number of scans to do, 1 in most cases, k in the -- ZeroQuadrant (block diagonal?) case. specialScans :: Int, -- | Probably: the size of submatrices for the ZeroQuadrant (block -- diagonal?) case, or 1 otherwise. specialSubSize :: Int, -- | Which case. specialCase :: Maybe SpecialCase } deriving (Show) -- | The different ways to handle scans. The best one is chosen -- heuristically by looking at the operator. data ScanAlgo = -- | Construct and compose the Jacobians; the approach presented -- in *Reverse-Mode AD of Multi-Reduce and Scan in Futhark*. GenericIFL23 Special | -- | The approach from *Parallelism-preserving automatic -- differentiation for second-order array languages*. GenericPPAD deriving (Show) subMats :: Int -> [[Exp SOACS]] -> Exp SOACS -> Maybe Int subMats d mat zero = let sub_d = filter (\x -> d `mod` x == 0) [1 .. (d `div` 2)] poss = map (\m -> all (ok m) $ zip mat [0 .. d - 1]) sub_d tmp = filter fst (zip poss sub_d) in if null tmp then Nothing else Just $ snd $ head tmp where ok m (row, i) = all (\(v, j) -> v == zero || i `div` m == j `div` m) $ zip row [0 .. d - 1] cases :: Int -> Type -> [[Exp SOACS]] -> ScanAlgo cases d t mat = case subMats d mat $ zeroExp t of Just k -> let nonZeros = zipWith (\i -> map (take k . drop (i * k))) [0 .. 
d `div` k] $ chunk k mat in if all (== head nonZeros) $ tail nonZeros then GenericIFL23 $ Special (d, k) (d, k * k) 1 k $ Just MatrixMul else GenericIFL23 $ Special (k, k) (k, k * k) (d `div` k) k $ Just ZeroQuadrant Nothing -> case d of 1 -> GenericIFL23 $ Special (d, d) (d, d * d) 1 d Nothing _ -> GenericPPAD -- | construct and optimise a temporary lambda, that calculates the -- Jacobian of the scan op. Figure out if the Jacobian has some -- special shape, discarding the temporary lambda. identifyCase :: VjpOps -> Lambda SOACS -> ADM ScanAlgo identifyCase ops lam = do let t = lambdaReturnType lam let d = length t idmat <- identityM d $ head t lams <- traverse (mkScanAdjointLam ops lam WrtFirst) idmat par1 <- traverse (newParam "tmp1") t par2 <- traverse (newParam "tmp2") t jac_lam <- mkLambda (par1 ++ par2) $ do let args = fmap eParam $ par1 ++ par2 lam_rs <- traverse (`eLambda` args) lams pure $ concat (transpose lam_rs) simp <- simplifyLambda jac_lam let jac = chunk d $ fmap (BasicOp . SubExp . resSubExp) $ bodyResult $ lambdaBody simp pure $ cases d (head t) jac scanRight :: [VName] -> SubExp -> Scan SOACS -> ADM [VName] scanRight as w scan = do as_types <- mapM lookupType as let arg_type_row = map rowType as_types par_a1 <- zipWithM (\x -> newParam (baseString x <> "_par_a1")) as arg_type_row par_a2 <- zipWithM (\x -> newParam (baseString x <> "_par_a2")) as arg_type_row -- Just the original operator but with par_a1 and par_a2 swapped. rev_op <- mkLambda (par_a1 <> par_a2) $ do op <- renameLambda $ scanLambda scan eLambda op (map (toExp . paramName) (par_a2 <> par_a1)) -- same neutral element let e = scanNeutral scan let rev_scan = Scan rev_op e iota <- letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64 -- flip the input array (this code is inspired from the code in -- diffScanAdd, but made to work with [VName] instead VName) map_scan <- revArrLam as -- perform the scan scan_res <- letTupExp "adj_ctrb_scan" . Op . 
Screma w [iota] $ scanomapSOAC [rev_scan] map_scan -- flip the output array again rev_lam <- revArrLam scan_res letTupExp "reverse_scan_result" $ Op $ Screma w [iota] $ mapSOAC rev_lam where revArrLam :: [VName] -> ADM (Lambda SOACS) revArrLam arrs = do par_i <- newParam "i" $ Prim int64 mkLambda [par_i] . forM arrs $ \arr -> fmap varRes . letExp "ys_bar_rev" =<< eIndex arr [toExp (pe64 w - le64 (paramName par_i) - 1)] mkPPADOpLifted :: VjpOps -> [VName] -> Scan SOACS -> ADM (Lambda SOACS) mkPPADOpLifted ops as scan = do as_types <- mapM lookupType as let arg_type_row = map rowType as_types par_x1 <- zipWithM (\x -> newParam (baseString x ++ "_par_x1")) as arg_type_row par_x2_unused <- zipWithM (\x -> newParam (baseString x ++ "_par_x2_unused")) as arg_type_row par_a1 <- zipWithM (\x -> newParam (baseString x ++ "_par_a1")) as arg_type_row par_a2 <- zipWithM (\x -> newParam (baseString x ++ "_par_a2")) as arg_type_row par_y1_h <- zipWithM (\x -> newParam (baseString x ++ "_par_y1_h")) as arg_type_row par_y2_h <- zipWithM (\x -> newParam (baseString x ++ "_par_y2_h")) as arg_type_row add_lams <- mapM addLambda arg_type_row mkLambda (par_x1 ++ par_a1 ++ par_y1_h ++ par_x2_unused ++ par_a2 ++ par_y2_h) (op_lift par_x1 par_a1 par_y1_h par_a2 par_y2_h add_lams) where op_lift px1 pa1 py1 pa2 py2 adds = do op_bar_1 <- mkScanAdjointLam ops (scanLambda scan) WrtFirst (Var . paramName <$> py2) let op_bar_args = toExp . Var . paramName <$> px1 ++ pa1 z_term <- map resSubExp <$> eLambda op_bar_1 op_bar_args let z = mapM (\(z_t, y_1, add) -> head <$> eLambda add [toExp z_t, toExp y_1]) (zip3 z_term (Var . paramName <$> py1) adds) let x1 = subExpsRes <$> mapM (toSubExp "x1" . Var . paramName) px1 op <- renameLambda $ scanLambda scan let a3 = eLambda op (toExp . 
paramName <$> pa1 ++ pa2) concat <$> sequence [x1, a3, z] asLiftPPAD :: [VName] -> SubExp -> [SubExp] -> ADM [VName] asLiftPPAD as w e = do par_i <- newParam "i" $ Prim int64 lmb <- mkLambda [par_i] $ do forM (zip as e) $ \(arr, arr_e) -> do a_lift <- letExp "a_lift" =<< eIf ( do nm1 <- toSubExp "n_minus_one" $ pe64 w - 1 pure $ BasicOp $ CmpOp (CmpSlt Int64) (Var $ paramName par_i) nm1 ) ( buildBody_ $ (\x -> [subExpRes x]) <$> (letSubExp "val" =<< eIndex arr [toExp $ le64 (paramName par_i) + 1]) ) (buildBody_ $ pure [subExpRes arr_e]) pure $ varRes a_lift iota <- letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64 letTupExp "as_lift" $ Op $ Screma w [iota] $ mapSOAC lmb ysRightPPAD :: [VName] -> SubExp -> [SubExp] -> ADM [VName] ysRightPPAD ys w e = do par_i <- newParam "i" $ Prim int64 lmb <- mkLambda [par_i] $ do forM (zip ys e) $ \(arr, arr_e) -> do a_lift <- letExp "y_right" =<< eIf ( pure $ BasicOp $ CmpOp (CmpEq int64) (Var $ paramName par_i) (constant (0 :: Int64)) ) (buildBody_ $ pure [subExpRes arr_e]) ( buildBody_ $ (\x -> [subExpRes x]) <$> (letSubExp "val" =<< eIndex arr [toExp $ le64 (paramName par_i) - 1]) ) pure $ varRes a_lift iota <- letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64 letTupExp "ys_right" $ Op $ Screma w [iota] $ mapSOAC lmb finalMapPPAD :: VjpOps -> [VName] -> Scan SOACS -> ADM (Lambda SOACS) finalMapPPAD ops as scan = do as_types <- mapM lookupType as let arg_type_row = map rowType as_types par_y_right <- zipWithM (\x -> newParam (baseString x ++ "_par_y_right")) as arg_type_row par_a <- zipWithM (\x -> newParam (baseString x ++ "_par_a")) as arg_type_row par_r_adj <- zipWithM (\x -> newParam (baseString x ++ "_par_r_adj")) as arg_type_row mkLambda (par_y_right ++ par_a ++ par_r_adj) $ do op_bar_2 <- mkScanAdjointLam ops (scanLambda scan) WrtSecond (Var . paramName <$> par_r_adj) eLambda op_bar_2 $ toExp . Var . 
paramName <$> par_y_right ++ par_a diffScan :: VjpOps -> [VName] -> SubExp -> [VName] -> Scan SOACS -> ADM () diffScan ops ys w as scan = do -- ys ~ results of scan, w ~ size of input array, as ~ (unzipped) -- arrays, scan ~ scan: operator with ne scan_case <- identifyCase ops $ scanLambda scan let d = length as ys_adj <- mapM lookupAdjVal ys -- ys_bar as_ts <- mapM lookupType as as_contribs <- case scan_case of GenericPPAD -> do let e = scanNeutral scan as_lift <- asLiftPPAD as w e let m = ys ++ as_lift ++ ys_adj op_lft <- mkPPADOpLifted ops as scan a_zero <- mapM (fmap Var . letExp "rscan_zero" . zeroExp . rowType) as_ts let lft_scan = Scan op_lft $ e ++ e ++ a_zero rs_adj <- (!! 2) . chunk d <$> scanRight m w lft_scan ys_right <- ysRightPPAD ys w e final_lmb <- finalMapPPAD ops as scan letTupExp "as_bar" $ Op $ Screma w (ys_right ++ as ++ rs_adj) $ mapSOAC final_lmb GenericIFL23 sc -> do -- IFL23 map1_lam <- mkScanFusedMapLam ops w (scanLambda scan) as ys ys_adj sc d scans_lin_fun_o <- mkScanLinFunO (head as_ts) sc scan_lams <- mkScans (specialScans sc) scans_lin_fun_o iota <- letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64 r_scan <- letTupExp "adj_ctrb_scan" . Op . Screma w [iota] $ scanomapSOAC scan_lams map1_lam mkScanFinalMap ops w (scanLambda scan) as ys (splitScanRes sc r_scan d) -- Goal: calculate as_contribs in new way -- zipWithM_ updateAdj as as_contribs -- as_bar += new adjoint zipWithM_ updateAdj as as_contribs where mkScans :: Int -> Scan SOACS -> ADM [Scan SOACS] mkScans d s = replicateM d $ do lam' <- renameLambda $ scanLambda s pure $ Scan lam' $ scanNeutral s splitScanRes sc res d = concatMap (take (div d $ specialScans sc)) (orderArgs sc res) diffScanVec :: VjpOps -> [VName] -> StmAux () -> SubExp -> Lambda SOACS -> [SubExp] -> [VName] -> ADM () -> ADM () diffScanVec ops ys aux w lam ne as m = do stmts <- collectStms_ $ do rank <- arrayRank <$> lookupType (head as) let rear = [1, 0] ++ drop 2 [0 .. 
rank - 1] transp_as <- forM as $ \a -> letExp (baseString a ++ "_transp") $ BasicOp $ Rearrange rear a ts <- traverse lookupType transp_as let n = arraysSize 0 ts as_par <- traverse (newParam "as_par" . rowType) ts ne_par <- traverse (newParam "ne_par") $ lambdaReturnType lam scan_form <- scanSOAC [Scan lam (map (Var . paramName) ne_par)] map_lam <- mkLambda (as_par ++ ne_par) . fmap varsRes . letTupExp "map_res" . Op $ Screma w (map paramName as_par) scan_form transp_ys <- letTupExp "trans_ys" . Op $ Screma n (transp_as ++ subExpVars ne) (mapSOAC map_lam) forM (zip ys transp_ys) $ \(y, x) -> auxing aux $ letBindNames [y] $ BasicOp $ Rearrange rear x foldr (vjpStm ops) m stmts diffScanAdd :: VjpOps -> VName -> SubExp -> Lambda SOACS -> SubExp -> VName -> ADM () diffScanAdd _ops ys n lam' ne as = do lam <- renameLambda lam' ys_bar <- lookupAdjVal ys map_scan <- rev_arr_lam ys_bar iota <- letExp "iota" $ BasicOp $ Iota n (intConst Int64 0) (intConst Int64 1) Int64 scan_res <- letExp "res_rev" $ Op $ Screma n [iota] $ scanomapSOAC [Scan lam [ne]] map_scan rev_lam <- rev_arr_lam scan_res contrb <- letExp "contrb" $ Op $ Screma n [iota] $ mapSOAC rev_lam updateAdj as contrb where rev_arr_lam :: VName -> ADM (Lambda SOACS) rev_arr_lam arr = do par_i <- newParam "i" $ Prim int64 mkLambda [par_i] $ do a <- letExp "ys_bar_rev" =<< eIndex arr [toExp (pe64 n - le64 (paramName par_i) - 1)] pure [varRes a] futhark-0.25.27/src/Futhark/AD/Rev/Scatter.hs000066400000000000000000000167551475065116200205340ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} module Futhark.AD.Rev.Scatter (vjpScatter) where import Control.Monad import Futhark.AD.Rev.Monad import Futhark.Analysis.PrimExp.Convert import Futhark.Builder import Futhark.IR.SOACS import Futhark.Tools import Futhark.Util (chunk) withinBounds :: [(SubExp, VName)] -> TPrimExp Bool VName withinBounds [] = TPrimExp $ ValueExp (BoolValue True) withinBounds [(q, i)] = (le64 i .<. pe64 q) .&&. (pe64 (intConst Int64 (-1)) .<. 
le64 i) withinBounds (qi : qis) = withinBounds [qi] .&&. withinBounds qis -- Generates a potential tower-of-maps lambda body for an indexing operation. -- Assuming parameters: -- `arr` the array that is indexed -- `[(w_1, i_1), (w_2, i_2), ..., (w_k, i_k)]` outer lambda formal parameters and their bounds -- `[n_1,n_2,...]ptp` the type of the index expression `arr[i_1,i_2,...,i_k]` -- Generates something like: -- (\ i_1 i_2 -> -- map (\j_1 -> ... if (i_1 >= 0 && i_1 < w_1) && -- (i_2 >= 0 && i_2 < w_2) && ... -- then arr[i_1, i_2, ... j_1, ...] -- else 0 -- ) (iota n_1) -- ) -- The idea is that you do not want to put under the `if` something -- that is an array because it would not flatten well! genIdxLamBody :: VName -> [(SubExp, Param Type)] -> Type -> ADM (Body SOACS) genIdxLamBody as wpis = genRecLamBody as wpis [] where genRecLamBody :: VName -> [(SubExp, Param Type)] -> [Param Type] -> Type -> ADM (Body SOACS) genRecLamBody arr w_pis nest_pis (Array t (Shape []) _) = genRecLamBody arr w_pis nest_pis (Prim t) genRecLamBody arr w_pis nest_pis (Array t (Shape (s : ss)) _) = do new_ip <- newParam "i" (Prim int64) let t' = Prim t `arrayOfShape` Shape ss inner_lam <- mkLambda [new_ip] $ bodyBind =<< genRecLamBody arr w_pis (nest_pis ++ [new_ip]) t' let (_, orig_pis) = unzip w_pis buildBody_ . localScope (scopeOfLParams (orig_pis ++ nest_pis)) $ do iota_v <- letExp "iota" $ BasicOp $ Iota s (intConst Int64 0) (intConst Int64 1) Int64 r <- letSubExp (baseString arr ++ "_elem") $ Op $ Screma s [iota_v] (mapSOAC inner_lam) pure [subExpRes r] genRecLamBody arr w_pis nest_pis (Prim ptp) = do let (ws, orig_pis) = unzip w_pis let inds = map paramName (orig_pis ++ nest_pis) localScope (scopeOfLParams (orig_pis ++ nest_pis)) $ eBody [ eIf (toExp $ withinBounds $ zip ws $ map paramName orig_pis) ( do r <- letSubExp "r" $ BasicOp $ Index arr $ Slice $ map (DimFix . 
Var) inds resultBodyM [r] ) (resultBodyM [Constant $ blankPrimValue ptp]) ] genRecLamBody _ _ _ _ = error "In Rev.hs, helper function genRecLamBody, unreachable case reached!" -- -- Original: -- let ys = scatter xs is vs -- Assumes no duplicate indices in `is` -- Forward Sweep: -- let xs_save = gather xs is -- let ys = scatter xs is vs -- Return Sweep: -- let vs_ctrbs = gather is ys_adj -- let vs_adj \overline{+}= vs_ctrbs -- by map or generalized reduction -- let xs_adj = scatter ys_adj is \overline{0} -- let xs = scatter ys is xs_save vjpScatter1 :: PatElem Type -> StmAux () -> (SubExp, [VName], (ShapeBase SubExp, Int, VName)) -> ADM () -> ADM () vjpScatter1 pys aux (w, ass, (shp, num_vals, xs)) m = do let rank = length $ shapeDims shp (all_inds, val_as) = splitAt (rank * num_vals) ass inds_as = chunk rank all_inds xs_t <- lookupType xs let val_t = stripArray (shapeRank shp) xs_t -- computing xs_save xs_saves <- mkGather inds_as xs xs_t -- performing the scatter id_lam <- mkIdentityLambda $ replicate (shapeRank shp) (Prim int64) ++ replicate (shapeRank shp) val_t addStm $ Let (Pat [pys]) aux $ Op $ Scatter w ass [(shp, num_vals, xs)] id_lam m let ys = patElemName pys -- XXX: Since our restoration of xs will consume ys, we have to -- make a copy of ys in the chance that it is actually the result -- of the program. In that case the asymptotics will not be -- (locally) preserved, but since ys must necessarily have been -- constructed somewhere close, they are probably globally OK. ys_copy <- letExp (baseString ys <> "_copy") . BasicOp $ Replicate mempty (Var ys) returnSweepCode $ do ys_adj <- lookupAdjVal ys -- computing vs_ctrbs and updating vs_adj vs_ctrbs <- mkGather inds_as ys_adj xs_t zipWithM_ updateAdj val_as vs_ctrbs -- use Slice? -- creating xs_adj zeros <- replicateM (length val_as) . 
letExp "zeros" $ zeroExp $ xs_t `setOuterSize` w let f_tps = replicate (rank * num_vals) (Prim int64) ++ replicate num_vals val_t f <- mkIdentityLambda f_tps xs_adj <- letExp (baseString xs ++ "_adj") . Op $ Scatter w (all_inds ++ zeros) [(shp, num_vals, ys_adj)] f insAdj xs xs_adj -- reusing the ys_adj for xs_adj! f' <- mkIdentityLambda f_tps xs_rc <- auxing aux . letExp (baseString xs <> "_rc") . Op $ Scatter w (all_inds ++ xs_saves) [(shp, num_vals, ys)] f' addSubstitution xs xs_rc addSubstitution ys ys_copy where -- Creates a potential map-nest that indexes in full the array, -- and applies the condition of indices within bounds at the -- deepest level in the nest so that everything can be parallel. mkGather :: [[VName]] -> VName -> Type -> ADM [VName] mkGather inds_as arr arr_t = do ips <- forM inds_as $ \idxs -> mapM (\idx -> newParam (baseString idx ++ "_elem") (Prim int64)) idxs gather_lam <- mkLambda (concat ips) . fmap mconcat . forM ips $ \idxs -> do let q = length idxs (ws, eltp) = (take q $ arrayDims arr_t, stripArray q arr_t) bodyBind =<< genIdxLamBody arr (zip ws idxs) eltp let soac = Screma w (concat inds_as) (mapSOAC gather_lam) letTupExp (baseString arr ++ "_gather") $ Op soac vjpScatter :: VjpOps -> Pat Type -> StmAux () -> (SubExp, [VName], Lambda SOACS, [(Shape, Int, VName)]) -> ADM () -> ADM () vjpScatter ops (Pat pes) aux (w, ass, lam, written_info) m | isIdentityLambda lam, [(shp, num_vals, xs)] <- written_info, [pys] <- pes = vjpScatter1 pys aux (w, ass, (shp, num_vals, xs)) m | isIdentityLambda lam = do let sind = splitInd written_info (inds, vals) = splitAt sind ass lst_stms <- chunkScatterInps (inds, vals) (zip pes written_info) diffScatters (stmsFromList lst_stms) | otherwise = error "vjpScatter: cannot handle" where splitInd [] = 0 splitInd ((shp, num_res, _) : rest) = num_res * length (shapeDims shp) + splitInd rest chunkScatterInps (acc_inds, acc_vals) [] = case (acc_inds, acc_vals) of ([], []) -> pure [] _ -> error 
"chunkScatterInps: cannot handle" chunkScatterInps (acc_inds, acc_vals) ((pe, info@(shp, num_vals, _)) : rest) = do let num_inds = num_vals * length (shapeDims shp) (curr_inds, other_inds) = splitAt num_inds acc_inds (curr_vals, other_vals) = splitAt num_vals acc_vals vtps <- mapM lookupType curr_vals f <- mkIdentityLambda (replicate num_inds (Prim int64) ++ vtps) let stm = Let (Pat [pe]) aux . Op $ Scatter w (curr_inds ++ curr_vals) [info] f stms_rest <- chunkScatterInps (other_inds, other_vals) rest pure $ stm : stms_rest diffScatters all_stms | Just (stm, stms) <- stmsHead all_stms = vjpStm ops stm $ diffScatters stms | otherwise = m futhark-0.25.27/src/Futhark/Actions.hs000066400000000000000000000542361475065116200175030ustar00rootroot00000000000000-- | All (almost) compiler pipelines end with an 'Action', which does -- something with the result of the pipeline. module Futhark.Actions ( printAction, printAliasesAction, printLastUseGPU, printFusionGraph, printInterferenceGPU, printMemAliasGPU, printMemoryAccessAnalysis, callGraphAction, impCodeGenAction, kernelImpCodeGenAction, multicoreImpCodeGenAction, metricsAction, compileCAction, compileCtoWASMAction, compileOpenCLAction, compileCUDAAction, compileHIPAction, compileMulticoreAction, compileMulticoreToISPCAction, compileMulticoreToWASMAction, compilePythonAction, compilePyOpenCLAction, ) where import Control.Monad import Control.Monad.IO.Class import Data.Bifunctor import Data.List (intercalate) import Data.Map qualified as M import Data.Maybe (fromMaybe) import Data.Text qualified as T import Data.Text.IO qualified as T import Futhark.Analysis.AccessPattern import Futhark.Analysis.Alias import Futhark.Analysis.CallGraph (buildCallGraph) import Futhark.Analysis.Interference qualified as Interference import Futhark.Analysis.LastUse qualified as LastUse import Futhark.Analysis.MemAlias qualified as MemAlias import Futhark.Analysis.Metrics import Futhark.CodeGen.Backends.CCUDA qualified as CCUDA import 
Futhark.CodeGen.Backends.COpenCL qualified as COpenCL import Futhark.CodeGen.Backends.HIP qualified as HIP import Futhark.CodeGen.Backends.MulticoreC qualified as MulticoreC import Futhark.CodeGen.Backends.MulticoreISPC qualified as MulticoreISPC import Futhark.CodeGen.Backends.MulticoreWASM qualified as MulticoreWASM import Futhark.CodeGen.Backends.PyOpenCL qualified as PyOpenCL import Futhark.CodeGen.Backends.SequentialC qualified as SequentialC import Futhark.CodeGen.Backends.SequentialPython qualified as SequentialPy import Futhark.CodeGen.Backends.SequentialWASM qualified as SequentialWASM import Futhark.CodeGen.ImpGen.GPU qualified as ImpGenGPU import Futhark.CodeGen.ImpGen.Multicore qualified as ImpGenMulticore import Futhark.CodeGen.ImpGen.Sequential qualified as ImpGenSequential import Futhark.Compiler.CLI import Futhark.IR import Futhark.IR.GPUMem (GPUMem) import Futhark.IR.MCMem (MCMem) import Futhark.IR.SOACS (SOACS) import Futhark.IR.SeqMem (SeqMem) import Futhark.Optimise.Fusion.GraphRep qualified import Futhark.Util (runProgramWithExitCode, unixEnvironment) import Futhark.Version (versionString) import System.Directory import System.Exit import System.FilePath import System.Info qualified -- | Print the result to stdout. printAction :: (ASTRep rep) => Action rep printAction = Action { actionName = "Prettyprint", actionDescription = "Prettyprint the resulting internal representation on standard output.", actionProcedure = liftIO . putStrLn . prettyString } -- | Print the result to stdout, alias annotations. printAliasesAction :: (AliasableRep rep) => Action rep printAliasesAction = Action { actionName = "Prettyprint", actionDescription = "Prettyprint the resulting internal representation on standard output.", actionProcedure = liftIO . putStrLn . prettyString . aliasAnalysis } -- | Print last use information to stdout. 
printLastUseGPU :: Action GPUMem printLastUseGPU = Action { actionName = "print last use gpu", actionDescription = "Print last use information on gpu.", actionProcedure = liftIO . putStrLn . prettyString . bimap M.toList (M.toList . fmap M.toList) . LastUse.lastUseGPUMem . aliasAnalysis } -- | Print fusion graph to stdout. printFusionGraph :: Action SOACS printFusionGraph = Action { actionName = "print fusion graph", actionDescription = "Print fusion graph in Graphviz format.", actionProcedure = liftIO . mapM_ ( putStrLn . Futhark.Optimise.Fusion.GraphRep.pprg . Futhark.Optimise.Fusion.GraphRep.mkDepGraphForFun ) . progFuns } -- | Print interference information to stdout. printInterferenceGPU :: Action GPUMem printInterferenceGPU = Action { actionName = "print interference gpu", actionDescription = "Print interference information on gpu.", actionProcedure = liftIO . print . Interference.analyseProgGPU } -- | Print memory alias information to stdout printMemAliasGPU :: Action GPUMem printMemAliasGPU = Action { actionName = "print mem alias gpu", actionDescription = "Print memory alias information on gpu.", actionProcedure = liftIO . print . MemAlias.analyzeGPUMem } -- | Print result of array access analysis on the IR printMemoryAccessAnalysis :: (Analyse rep) => Action rep printMemoryAccessAnalysis = Action { actionName = "array-access-analysis", actionDescription = "Prettyprint the array access analysis to standard output.", actionProcedure = liftIO . putStrLn . prettyString . analyseDimAccesses } -- | Print call graph to stdout. callGraphAction :: Action SOACS callGraphAction = Action { actionName = "call-graph", actionDescription = "Prettyprint the callgraph of the result to standard output.", actionProcedure = liftIO . putStrLn . prettyString . buildCallGraph } -- | Print metrics about AST node counts to stdout. 
metricsAction :: (OpMetrics (Op rep)) => Action rep metricsAction = Action { actionName = "Compute metrics", actionDescription = "Print metrics on the final AST.", actionProcedure = liftIO . putStr . show . progMetrics } -- | Convert the program to sequential ImpCode and print it to stdout. impCodeGenAction :: Action SeqMem impCodeGenAction = Action { actionName = "Compile imperative", actionDescription = "Translate program into imperative IL and write it on standard output.", actionProcedure = liftIO . putStrLn . prettyString . snd <=< ImpGenSequential.compileProg } -- | Convert the program to GPU ImpCode and print it to stdout. kernelImpCodeGenAction :: Action GPUMem kernelImpCodeGenAction = Action { actionName = "Compile imperative kernels", actionDescription = "Translate program into imperative IL with kernels and write it on standard output.", actionProcedure = liftIO . putStrLn . prettyString . snd <=< ImpGenGPU.compileProgHIP } -- | Convert the program to CPU multicore ImpCode and print it to stdout. multicoreImpCodeGenAction :: Action MCMem multicoreImpCodeGenAction = Action { actionName = "Compile to imperative multicore", actionDescription = "Translate program into imperative multicore IL and write it on standard output.", actionProcedure = liftIO . putStrLn . prettyString . snd <=< ImpGenMulticore.compileProg } -- Lines that we prepend (in comments) to generated code. 
headerLines :: [T.Text] headerLines = T.lines $ "Generated by Futhark " <> versionString cHeaderLines :: [T.Text] cHeaderLines = map ("// " <>) headerLines pyHeaderLines :: [T.Text] pyHeaderLines = map ("# " <>) headerLines cPrependHeader :: T.Text -> T.Text cPrependHeader = (T.unlines cHeaderLines <>) pyPrependHeader :: T.Text -> T.Text pyPrependHeader = (T.unlines pyHeaderLines <>) cmdCC :: String cmdCC = fromMaybe "cc" $ lookup "CC" unixEnvironment cmdCFLAGS :: [String] -> [String] cmdCFLAGS def = maybe def words $ lookup "CFLAGS" unixEnvironment cmdISPCFLAGS :: [String] -> [String] cmdISPCFLAGS def = maybe def words $ lookup "ISPCFLAGS" unixEnvironment runCC :: String -> String -> [String] -> [String] -> FutharkM () runCC cpath outpath cflags_def ldflags = do ret <- liftIO $ runProgramWithExitCode cmdCC ( [cpath, "-o", outpath] ++ cmdCFLAGS cflags_def ++ -- The default LDFLAGS are always added. ldflags ) mempty case ret of Left err -> externalErrorS $ "Failed to run " ++ cmdCC ++ ": " ++ show err Right (ExitFailure code, _, gccerr) -> externalErrorS $ cmdCC ++ " failed with code " ++ show code ++ ":\n" ++ gccerr Right (ExitSuccess, _, _) -> pure () runISPC :: String -> String -> String -> String -> [String] -> [String] -> [String] -> FutharkM () runISPC ispcpath outpath cpath ispcextension ispc_flags cflags_def ldflags = do ret_ispc <- liftIO $ runProgramWithExitCode cmdISPC ( [ispcpath, "-o", ispcbase `addExtension` "o"] ++ ["--addressing=64", "--pic"] ++ cmdISPCFLAGS ispc_flags -- These flags are always needed ) mempty ret <- liftIO $ runProgramWithExitCode cmdCC ( [ispcbase `addExtension` "o"] ++ [cpath, "-o", outpath] ++ cmdCFLAGS cflags_def ++ -- The default LDFLAGS are always added. 
ldflags ) mempty case ret_ispc of Left err -> externalErrorS $ "Failed to run " ++ cmdISPC ++ ": " ++ show err Right (ExitFailure code, _, ispcerr) -> throwError cmdISPC code ispcerr Right (ExitSuccess, _, _) -> case ret of Left err -> externalErrorS $ "Failed to run ispc: " ++ show err Right (ExitFailure code, _, gccerr) -> throwError cmdCC code gccerr Right (ExitSuccess, _, _) -> pure () where cmdISPC = "ispc" ispcbase = outpath <> ispcextension throwError prog code err = externalErrorS $ prog ++ " failed with code " ++ show code ++ ":\n" ++ err -- | The @futhark c@ action. compileCAction :: FutharkConfig -> CompilerMode -> FilePath -> Action SeqMem compileCAction fcfg mode outpath = Action { actionName = "Compile to sequential C", actionDescription = "Compile to sequential C", actionProcedure = helper } where helper prog = do cprog <- handleWarnings fcfg $ SequentialC.compileProg versionString prog let cpath = outpath `addExtension` "c" hpath = outpath `addExtension` "h" jsonpath = outpath `addExtension` "json" case mode of ToLibrary -> do let (header, impl, manifest) = SequentialC.asLibrary cprog liftIO $ T.writeFile hpath $ cPrependHeader header liftIO $ T.writeFile cpath $ cPrependHeader impl liftIO $ T.writeFile jsonpath manifest ToExecutable -> do liftIO $ T.writeFile cpath $ SequentialC.asExecutable cprog runCC cpath outpath ["-O3", "-std=c99"] ["-lm"] ToServer -> do liftIO $ T.writeFile cpath $ SequentialC.asServer cprog runCC cpath outpath ["-O3", "-std=c99"] ["-lm"] -- | The @futhark opencl@ action. 
compileOpenCLAction :: FutharkConfig -> CompilerMode -> FilePath -> Action GPUMem compileOpenCLAction fcfg mode outpath = Action { actionName = "Compile to OpenCL", actionDescription = "Compile to OpenCL", actionProcedure = helper } where helper prog = do cprog <- handleWarnings fcfg $ COpenCL.compileProg versionString prog let cpath = outpath `addExtension` "c" hpath = outpath `addExtension` "h" jsonpath = outpath `addExtension` "json" extra_options | System.Info.os == "darwin" = ["-framework", "OpenCL"] | System.Info.os == "mingw32" = ["-lOpenCL64"] | otherwise = ["-lOpenCL"] case mode of ToLibrary -> do let (header, impl, manifest) = COpenCL.asLibrary cprog liftIO $ T.writeFile hpath $ cPrependHeader header liftIO $ T.writeFile cpath $ cPrependHeader impl liftIO $ T.writeFile jsonpath manifest ToExecutable -> do liftIO $ T.writeFile cpath $ cPrependHeader $ COpenCL.asExecutable cprog runCC cpath outpath ["-O", "-std=c99"] ("-lm" : extra_options) ToServer -> do liftIO $ T.writeFile cpath $ cPrependHeader $ COpenCL.asServer cprog runCC cpath outpath ["-O", "-std=c99"] ("-lm" : extra_options) -- | The @futhark cuda@ action. 
compileCUDAAction :: FutharkConfig -> CompilerMode -> FilePath -> Action GPUMem compileCUDAAction fcfg mode outpath = Action { actionName = "Compile to CUDA", actionDescription = "Compile to CUDA", actionProcedure = helper } where helper prog = do cprog <- handleWarnings fcfg $ CCUDA.compileProg versionString prog let cpath = outpath `addExtension` "c" hpath = outpath `addExtension` "h" jsonpath = outpath `addExtension` "json" extra_options = [ "-lcuda", "-lcudart", "-lnvrtc" ] case mode of ToLibrary -> do let (header, impl, manifest) = CCUDA.asLibrary cprog liftIO $ T.writeFile hpath $ cPrependHeader header liftIO $ T.writeFile cpath $ cPrependHeader impl liftIO $ T.writeFile jsonpath manifest ToExecutable -> do liftIO $ T.writeFile cpath $ cPrependHeader $ CCUDA.asExecutable cprog runCC cpath outpath ["-O", "-std=c99"] ("-lm" : extra_options) ToServer -> do liftIO $ T.writeFile cpath $ cPrependHeader $ CCUDA.asServer cprog runCC cpath outpath ["-O", "-std=c99"] ("-lm" : extra_options) -- | The @futhark hip@ action. 
compileHIPAction :: FutharkConfig -> CompilerMode -> FilePath -> Action GPUMem compileHIPAction fcfg mode outpath = Action { actionName = "Compile to HIP", actionDescription = "Compile to HIP", actionProcedure = helper } where helper prog = do cprog <- handleWarnings fcfg $ HIP.compileProg versionString prog let cpath = outpath `addExtension` "c" hpath = outpath `addExtension` "h" jsonpath = outpath `addExtension` "json" extra_options = [ "-lamdhip64", "-lhiprtc" ] case mode of ToLibrary -> do let (header, impl, manifest) = HIP.asLibrary cprog liftIO $ T.writeFile hpath $ cPrependHeader header liftIO $ T.writeFile cpath $ cPrependHeader impl liftIO $ T.writeFile jsonpath manifest ToExecutable -> do liftIO $ T.writeFile cpath $ cPrependHeader $ HIP.asExecutable cprog runCC cpath outpath ["-O", "-std=c99"] ("-lm" : extra_options) ToServer -> do liftIO $ T.writeFile cpath $ cPrependHeader $ HIP.asServer cprog runCC cpath outpath ["-O", "-std=c99"] ("-lm" : extra_options) -- | The @futhark multicore@ action. 
compileMulticoreAction :: FutharkConfig -> CompilerMode -> FilePath -> Action MCMem compileMulticoreAction fcfg mode outpath = Action { actionName = "Compile to multicore", actionDescription = "Compile to multicore", actionProcedure = helper } where helper prog = do cprog <- handleWarnings fcfg $ MulticoreC.compileProg versionString prog let cpath = outpath `addExtension` "c" hpath = outpath `addExtension` "h" jsonpath = outpath `addExtension` "json" case mode of ToLibrary -> do let (header, impl, manifest) = MulticoreC.asLibrary cprog liftIO $ T.writeFile hpath $ cPrependHeader header liftIO $ T.writeFile cpath $ cPrependHeader impl liftIO $ T.writeFile jsonpath manifest ToExecutable -> do liftIO $ T.writeFile cpath $ cPrependHeader $ MulticoreC.asExecutable cprog runCC cpath outpath ["-O3", "-std=c99"] ["-lm", "-pthread"] ToServer -> do liftIO $ T.writeFile cpath $ cPrependHeader $ MulticoreC.asServer cprog runCC cpath outpath ["-O3", "-std=c99"] ["-lm", "-pthread"] -- | The @futhark ispc@ action. 
compileMulticoreToISPCAction :: FutharkConfig -> CompilerMode -> FilePath -> Action MCMem compileMulticoreToISPCAction fcfg mode outpath = Action { actionName = "Compile to multicore ISPC", actionDescription = "Compile to multicore ISPC", actionProcedure = helper } where helper prog = do let cpath = outpath `addExtension` "c" hpath = outpath `addExtension` "h" jsonpath = outpath `addExtension` "json" ispcpath = outpath `addExtension` "kernels.ispc" ispcextension = "_ispc" (cprog, ispc) <- handleWarnings fcfg $ MulticoreISPC.compileProg versionString prog case mode of ToLibrary -> do let (header, impl, manifest) = MulticoreC.asLibrary cprog liftIO $ T.writeFile hpath $ cPrependHeader header liftIO $ T.writeFile cpath $ cPrependHeader impl liftIO $ T.writeFile ispcpath ispc liftIO $ T.writeFile jsonpath manifest ToExecutable -> do liftIO $ T.writeFile cpath $ cPrependHeader $ MulticoreC.asExecutable cprog liftIO $ T.writeFile ispcpath ispc runISPC ispcpath outpath cpath ispcextension ["-O3", "--woff"] ["-O3", "-std=c99"] ["-lm", "-pthread"] ToServer -> do liftIO $ T.writeFile cpath $ cPrependHeader $ MulticoreC.asServer cprog liftIO $ T.writeFile ispcpath ispc runISPC ispcpath outpath cpath ispcextension ["-O3", "--woff"] ["-O3", "-std=c99"] ["-lm", "-pthread"] pythonCommon :: (CompilerMode -> String -> prog -> FutharkM (Warnings, T.Text)) -> FutharkConfig -> CompilerMode -> FilePath -> prog -> FutharkM () pythonCommon codegen fcfg mode outpath prog = do let class_name = case mode of ToLibrary -> takeBaseName outpath _ -> "internal" pyprog <- handleWarnings fcfg $ codegen mode class_name prog case mode of ToLibrary -> liftIO $ T.writeFile (outpath `addExtension` "py") $ pyPrependHeader pyprog _ -> liftIO $ do T.writeFile outpath $ "#!/usr/bin/env python3\n" <> pyPrependHeader pyprog perms <- liftIO $ getPermissions outpath setPermissions outpath $ setOwnerExecutable True perms -- | The @futhark python@ action. 
compilePythonAction :: FutharkConfig -> CompilerMode -> FilePath -> Action SeqMem compilePythonAction fcfg mode outpath = Action { actionName = "Compile to PyOpenCL", actionDescription = "Compile to Python with OpenCL", actionProcedure = pythonCommon SequentialPy.compileProg fcfg mode outpath } -- | The @futhark pyopencl@ action. compilePyOpenCLAction :: FutharkConfig -> CompilerMode -> FilePath -> Action GPUMem compilePyOpenCLAction fcfg mode outpath = Action { actionName = "Compile to PyOpenCL", actionDescription = "Compile to Python with OpenCL", actionProcedure = pythonCommon PyOpenCL.compileProg fcfg mode outpath } cmdEMCFLAGS :: [String] -> [String] cmdEMCFLAGS def = maybe def words $ lookup "EMCFLAGS" unixEnvironment runEMCC :: String -> String -> FilePath -> [String] -> [String] -> [String] -> Bool -> FutharkM () runEMCC cpath outpath classpath cflags_def ldflags expfuns lib = do ret <- liftIO $ runProgramWithExitCode "emcc" ( [cpath, "-o", outpath] ++ ["-lnodefs.js"] ++ ["-s", "--extern-post-js", classpath] ++ ( if lib then ["-s", "EXPORT_NAME=loadWASM"] else [] ) ++ ["-s", "WASM_BIGINT"] ++ cmdCFLAGS cflags_def ++ cmdEMCFLAGS [""] ++ [ "-s", "EXPORTED_FUNCTIONS=[" ++ intercalate "," ("'_malloc'" : "'_free'" : expfuns) ++ "]" ] -- The default LDFLAGS are always added. ++ ldflags ) mempty case ret of Left err -> externalErrorS $ "Failed to run emcc: " ++ show err Right (ExitFailure code, _, emccerr) -> externalErrorS $ "emcc failed with code " ++ show code ++ ":\n" ++ emccerr Right (ExitSuccess, _, _) -> pure () -- | The @futhark wasm@ action. 
compileCtoWASMAction :: FutharkConfig -> CompilerMode -> FilePath -> Action SeqMem compileCtoWASMAction fcfg mode outpath = Action { actionName = "Compile to sequential C", actionDescription = "Compile to sequential C", actionProcedure = helper } where helper prog = do (cprog, jsprog, exps) <- handleWarnings fcfg $ SequentialWASM.compileProg versionString prog case mode of ToLibrary -> do writeLibs cprog jsprog liftIO $ T.appendFile classpath SequentialWASM.libraryExports runEMCC cpath mjspath classpath ["-O3", "-msimd128"] ["-lm"] exps True _ -> do -- Non-server executables are not supported. writeLibs cprog jsprog liftIO $ T.appendFile classpath SequentialWASM.runServer runEMCC cpath outpath classpath ["-O3", "-msimd128"] ["-lm"] exps False writeLibs cprog jsprog = do let (h, imp, _) = SequentialC.asLibrary cprog liftIO $ T.writeFile hpath h liftIO $ T.writeFile cpath imp liftIO $ T.writeFile classpath jsprog cpath = outpath `addExtension` "c" hpath = outpath `addExtension` "h" mjspath = outpath `addExtension` "mjs" classpath = outpath `addExtension` ".class.js" -- | The @futhark wasm-multicore@ action. compileMulticoreToWASMAction :: FutharkConfig -> CompilerMode -> FilePath -> Action MCMem compileMulticoreToWASMAction fcfg mode outpath = Action { actionName = "Compile to sequential C", actionDescription = "Compile to sequential C", actionProcedure = helper } where helper prog = do (cprog, jsprog, exps) <- handleWarnings fcfg $ MulticoreWASM.compileProg versionString prog case mode of ToLibrary -> do writeLibs cprog jsprog liftIO $ T.appendFile classpath MulticoreWASM.libraryExports runEMCC cpath mjspath classpath ["-O3", "-msimd128"] ["-lm", "-pthread"] exps True _ -> do -- Non-server executables are not supported. 
writeLibs cprog jsprog liftIO $ T.appendFile classpath MulticoreWASM.runServer runEMCC cpath outpath classpath ["-O3", "-msimd128"] ["-lm", "-pthread"] exps False writeLibs cprog jsprog = do let (h, imp, _) = MulticoreC.asLibrary cprog liftIO $ T.writeFile hpath h liftIO $ T.writeFile cpath imp liftIO $ T.writeFile classpath jsprog cpath = outpath `addExtension` "c" hpath = outpath `addExtension` "h" mjspath = outpath `addExtension` "mjs" classpath = outpath `addExtension` ".class.js" futhark-0.25.27/src/Futhark/Analysis/000077500000000000000000000000001475065116200173205ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Analysis/AccessPattern.hs000066400000000000000000000706741475065116200224310ustar00rootroot00000000000000{-# LANGUAGE LambdaCase #-} module Futhark.Analysis.AccessPattern ( analyseDimAccesses, analyseFunction, vnameFromSegOp, analysisPropagateByTransitivity, isInvariant, Analyse, IndexTable, ArrayName, DimAccess (..), IndexExprName, BodyType (..), SegOpName (SegmentedMap, SegmentedRed, SegmentedScan, SegmentedHist), Context (..), analyseIndex, VariableInfo (..), VarType (..), isCounter, Dependency (..), ) where import Data.Bifunctor import Data.Foldable import Data.List qualified as L import Data.Map.Strict qualified as M import Data.Maybe import Futhark.IR.Aliases import Futhark.IR.GPU import Futhark.IR.GPUMem import Futhark.IR.MC import Futhark.IR.MCMem import Futhark.IR.SOACS import Futhark.IR.Seq import Futhark.IR.SeqMem import Futhark.Util.Pretty -- | Name of a SegOp, used to identify the SegOp that an array access is -- contained in. data SegOpName = SegmentedMap {vnameFromSegOp :: VName} | SegmentedRed {vnameFromSegOp :: VName} | SegmentedScan {vnameFromSegOp :: VName} | SegmentedHist {vnameFromSegOp :: VName} deriving (Eq, Ord, Show) -- | Name of an array indexing expression. Taken from the pattern of -- the expression. 
type IndexExprName = VName data BodyType = SegOpName SegOpName | LoopBodyName VName | CondBodyName VName deriving (Show, Ord, Eq) -- | Stores the name of an array, the nest of loops, kernels, -- conditionals in which it is constructed, and the existing layout of -- the array. The latter is currently largely unused and not -- trustworthy, but might be useful in the future. type ArrayName = (VName, [BodyType], [Int]) -- | Tuple of patternName and nested `level` it index occurred at, as well as -- what the actual iteration type is. data Dependency = Dependency { lvl :: Int, varType :: VarType } deriving (Eq, Show) -- | Collect all features of access to a specific dimension of an array. data DimAccess rep = DimAccess { -- | Set of VNames of iteration variables (gtids, loop counters, etc.) -- that some access is variant to. -- An empty set indicates that the access is invariant. dependencies :: M.Map VName Dependency, -- | Used to store the name of the original expression from which `dependencies` -- was computed. `Nothing` if it is a constant. originalVar :: Maybe VName } deriving (Eq, Show) instance Semigroup (DimAccess rep) where adeps <> bdeps = DimAccess (dependencies adeps <> dependencies bdeps) ( case originalVar adeps of Nothing -> originalVar bdeps _ -> originalVar adeps ) instance Monoid (DimAccess rep) where mempty = DimAccess mempty Nothing isInvariant :: DimAccess rep -> Bool isInvariant = null . dependencies -- | For each array access in a program, this data structure stores the -- dependencies of each dimension in the access, the array name, and the -- name of the SegOp that the access is contained in. -- Each DimAccess element corresponds to an access to a given dimension -- in the given array, in the same order of the dimensions. 
type IndexTable rep = M.Map SegOpName (M.Map ArrayName (M.Map IndexExprName [DimAccess rep])) unionIndexTables :: IndexTable rep -> IndexTable rep -> IndexTable rep unionIndexTables = M.unionWith (M.unionWith M.union) -- | Make segops on arrays transitive, ie. if -- > let A = segmap (..) xs -- A indexes into xs -- > let B = segmap (..) A -- B indexes into A -- Then B also derives all A's array-accesses, like xs. -- Runs in n² analysisPropagateByTransitivity :: IndexTable rep -> IndexTable rep analysisPropagateByTransitivity idx_table = M.map foldlArrayNameMap idx_table where aggregateResults arr_name = maybe mempty foldlArrayNameMap (M.mapKeys vnameFromSegOp idx_table M.!? arr_name) foldlArrayNameMap aMap = foldl (M.unionWith M.union) aMap $ map (aggregateResults . \(a, _, _) -> a) $ M.keys aMap -- -- Helper types and functions to perform the analysis. -- -- | Used during the analysis to keep track of the dependencies of patterns -- encountered so far. data Context rep = Context { -- | A mapping from patterns occuring in Let expressions to their dependencies -- and iteration types. assignments :: M.Map VName (VariableInfo rep), -- | Maps from sliced arrays to their respective access patterns. slices :: M.Map IndexExprName (ArrayName, [VName], [DimAccess rep]), -- | A list of the segMaps encountered during the analysis in the order they -- were encountered. parents :: [BodyType], -- | Current level of recursion, also just `length parents` currentLevel :: Int } deriving (Show, Eq) instance Monoid (Context rep) where mempty = Context { assignments = mempty, slices = mempty, parents = [], currentLevel = 0 } instance Semigroup (Context rep) where Context ass0 slices0 lastBody0 lvl0 <> Context ass1 slices1 lastBody1 lvl1 = Context (ass0 <> ass1) (slices0 <> slices1) (lastBody0 <> lastBody1) (max lvl0 lvl1) -- | Extend a context with another context. -- We never have to consider the case where VNames clash in the context, since -- they are unique. 
extend :: Context rep -> Context rep -> Context rep extend = (<>) allSegMap :: Context rep -> [SegOpName] allSegMap (Context _ _ parents _) = mapMaybe f parents where f (SegOpName o) = Just o f _ = Nothing -- | Context Value (VariableInfo) is the type used in the context to categorize -- assignments. For example, a pattern might depend on a function parameter, a -- gtid, or some other pattern. data VariableInfo rep = VariableInfo { deps :: Names, level :: Int, parents_nest :: [BodyType], variableType :: VarType } deriving (Show, Eq) data VarType = ConstType | Variable | ThreadID | LoopVar deriving (Show, Eq) isCounter :: VarType -> Bool isCounter LoopVar = True isCounter ThreadID = True isCounter _ = False varInfoFromNames :: Context rep -> Names -> VariableInfo rep varInfoFromNames ctx names = do VariableInfo names (currentLevel ctx) (parents ctx) Variable -- | Wrapper around the constructur of Context. oneContext :: VName -> VariableInfo rep -> Context rep oneContext name var_info = Context { assignments = M.singleton name var_info, slices = mempty, parents = [], currentLevel = 0 } -- | Create a singular varInfo with no dependencies. varInfoZeroDeps :: Context rep -> VariableInfo rep varInfoZeroDeps ctx = VariableInfo mempty (currentLevel ctx) (parents ctx) Variable -- | Create a singular context from a segspace contextFromNames :: Context rep -> VariableInfo rep -> [VName] -> Context rep contextFromNames ctx var_info = foldl' extend ctx . map (`oneContext` var_info) -- | A representation where we can analyse access patterns. class Analyse rep where -- | Analyse the op for this representation. analyseOp :: Op rep -> Context rep -> [VName] -> (Context rep, IndexTable rep) -- | Analyse each `entry` and accumulate the results. analyseDimAccesses :: (Analyse rep) => Prog rep -> IndexTable rep analyseDimAccesses = foldMap' analyseFunction . progFuns -- | Analyse each statement in a function body. 
analyseFunction :: (Analyse rep) => FunDef rep -> IndexTable rep analyseFunction func = let stms = stmsToList . bodyStms $ funDefBody func -- Create a context containing the function parameters ctx = contextFromNames mempty (varInfoZeroDeps ctx) $ map paramName $ funDefParams func in snd $ analyseStmsPrimitive ctx stms -- | Analyse each statement in a list of statements. analyseStmsPrimitive :: (Analyse rep) => Context rep -> [Stm rep] -> (Context rep, IndexTable rep) analyseStmsPrimitive ctx = -- Fold over statements in body foldl' (\(c, r) stm -> second (unionIndexTables r) $ analyseStm c stm) (ctx, mempty) -- | Same as analyseStmsPrimitive, but change the resulting context into -- a varInfo, mapped to pattern. analyseStms :: (Analyse rep) => Context rep -> (VName -> BodyType) -> [VName] -> [Stm rep] -> (Context rep, IndexTable rep) analyseStms ctx body_constructor pats body = do -- 0. Recurse into body with ctx let (ctx'', indexTable) = analyseStmsPrimitive recContext body -- 0.1 Get all new slices let slices_new = M.difference (slices ctx'') (slices ctx) -- 0.2 Make "IndexExpressions" of the slices let slices_indices = foldl unionIndexTables indexTable $ mapMaybe ( uncurry $ \_idx_expression (array_name, patterns, dim_indices) -> Just . snd $ -- Should we use recContex instead of ctx''? analyseIndex' ctx'' patterns array_name dim_indices ) $ M.toList slices_new -- 1. We do not want the returned context directly. -- however, we do want pat to map to the names what was hit in body. -- therefore we need to subtract the old context from the returned one, -- and discard all the keys within it. -- assignments :: M.Map VName (VariableInfo rep), let in_scope_dependencies_from_body = rmOutOfScopeDeps ctx'' $ M.difference (assignments ctx'') (assignments recContext) -- 2. We are ONLY interested in the rhs of assignments (ie. the -- dependencies of pat :) ) let ctx' = foldl extend ctx $ concatVariableInfo in_scope_dependencies_from_body -- . 
map snd $ M.toList varInfos -- 3. Now we have the correct context and result (ctx' {parents = parents ctx, currentLevel = currentLevel ctx, slices = slices ctx}, slices_indices) where -- Extracts and merges `Names` in `VariableInfo`s, and makes a new VariableInfo. This -- MAY throw away needed information, but it was my best guess at a solution -- at the time of writing. concatVariableInfo dependencies = map (\pat -> oneContext pat (varInfoFromNames ctx dependencies)) pats -- Context used for "recursion" into analyseStmsPrimitive recContext = ctx { parents = parents ctx <> concatMap (\pat -> [body_constructor pat]) pats, currentLevel = currentLevel ctx + 1 } -- Recursively looks up dependencies, until they're in scope or empty set. rmOutOfScopeDeps :: Context rep -> M.Map VName (VariableInfo rep) -> Names rmOutOfScopeDeps ctx' new_assignments = let throwaway_assignments = assignments ctx' local_assignments = assignments ctx f result a var_info = -- if the VName of the assignment exists in the context, we are good if a `M.member` local_assignments then result <> oneName a else -- Otherwise, recurse on its dependencies; -- 0. Add dependencies in ctx to result let (deps_in_ctx, deps_not_in_ctx) = L.partition (`M.member` local_assignments) $ namesToList (deps var_info) deps_not_in_ctx' = M.fromList $ mapMaybe (\d -> (d,) <$> M.lookup d throwaway_assignments) deps_not_in_ctx in result <> namesFromList deps_in_ctx <> rmOutOfScopeDeps ctx' deps_not_in_ctx' in M.foldlWithKey f mempty new_assignments -- | Analyse a rep statement and return the updated context and array index -- descriptors. analyseStm :: (Analyse rep) => Context rep -> Stm rep -> (Context rep, IndexTable rep) analyseStm ctx (Let pats _ e) = do -- Get the name of the first element in a pattern let pattern_names = map patElemName $ patElems pats -- Construct the result and Context from the subexpression. If the -- subexpression is a body, we recurse into it. 
case e of BasicOp (Index name (Slice dim_subexp)) -> analyseIndex ctx pattern_names name dim_subexp BasicOp (Update _ name (Slice dim_subexp) _subexp) -> analyseIndex ctx pattern_names name dim_subexp BasicOp op -> analyseBasicOp ctx op pattern_names Match conds cases default_body _ -> analyseMatch ctx' pattern_names default_body $ map caseBody cases where ctx' = contextFromNames ctx (varInfoZeroDeps ctx) $ concatMap (namesToList . freeIn) conds Loop bindings loop body -> analyseLoop ctx bindings loop body pattern_names Apply _name diets _ _ -> analyseApply ctx pattern_names diets WithAcc _ _ -> (ctx, mempty) -- ignored Op op -> analyseOp op ctx pattern_names -- If left, this is just a regular index. If right, a slice happened. getIndexDependencies :: Context rep -> [DimIndex SubExp] -> Either [DimAccess rep] [DimAccess rep] getIndexDependencies ctx dims = fst $ foldr ( \idx (a, i) -> ( either (matchDimIndex idx) (either Right Right . matchDimIndex idx) a, i - 1 ) ) (Left [], length dims - 1) dims where matchDimIndex (DimFix subExpression) accumulator = Left $ consolidate ctx subExpression : accumulator -- If we encounter a DimSlice, add it to a map of `DimSlice`s and check -- result later. matchDimIndex (DimSlice offset num_elems stride) accumulator = -- We assume that a slice is iterated sequentially, so we have to -- create a fake dependency for the slice. let dimAccess' = DimAccess (M.singleton (VName "slice" 0) $ Dependency (currentLevel ctx) LoopVar) (Just $ VName "slice" 0) cons = consolidate ctx dimAccess = dimAccess' <> cons offset <> cons num_elems <> cons stride in Right $ dimAccess : accumulator -- | Gets the dependencies of each dimension and either returns a result, or -- adds a slice to the context. 
analyseIndex :: Context rep -> [VName] -> VName -> [DimIndex SubExp] -> (Context rep, IndexTable rep) analyseIndex ctx pats arr_name dim_indices = -- Get the dependendencies of each dimension let dependencies = getIndexDependencies ctx dim_indices -- Extend the current context with current pattern(s) and its deps ctx' = analyseIndexContextFromIndices ctx dim_indices pats -- The bodytype(s) are used in the result construction array_name' = -- For now, we assume the array is in row-major-order, hence the -- identity permutation. In the future, we might want to infer its -- layout, for example, if the array is the result of a transposition. let layout = [0 .. length dim_indices - 1] in -- 2. If the arrayname was not in assignments, it was not an immediately -- allocated array. fromMaybe (arr_name, [], layout) -- 1. Maybe find the array name, and the "stack" of body types that the -- array was allocated in. . L.find (\(n, _, _) -> n == arr_name) -- 0. Get the "stack" of bodytypes for each assignment $ map (\(n, vi) -> (n, parents_nest vi, layout)) (M.toList $ assignments ctx') in either (index ctx' array_name') (slice ctx' array_name') dependencies where slice :: Context rep -> ArrayName -> [DimAccess rep] -> (Context rep, IndexTable rep) slice context array_name dims = (context {slices = M.insert (head pats) (array_name, pats, dims) $ slices context}, mempty) index :: Context rep -> ArrayName -> [DimAccess rep] -> (Context rep, IndexTable rep) index context array_name@(name, _, _) dim_access = -- If the arrayname is a `DimSlice` we want to fixup the access case M.lookup name $ slices context of Nothing -> analyseIndex' context pats array_name dim_access Just (arr_name', pats', slice_access) -> analyseIndex' context pats' arr_name' (init slice_access ++ [head dim_access <> last slice_access] ++ drop 1 dim_access) analyseIndexContextFromIndices :: Context rep -> [DimIndex SubExp] -> [VName] -> Context rep analyseIndexContextFromIndices ctx dim_accesses pats = let 
subexprs = mapMaybe ( \case DimFix (Var v) -> Just v DimFix (Constant _) -> Nothing DimSlice _offs _n _stride -> Nothing ) dim_accesses -- Add each non-constant DimIndex as a dependency to the index expression var_info = varInfoFromNames ctx $ namesFromList subexprs in -- Extend context with the dependencies index expression foldl' extend ctx $ map (`oneContext` var_info) pats analyseIndex' :: Context rep -> [VName] -> ArrayName -> [DimAccess rep] -> (Context rep, IndexTable rep) analyseIndex' ctx _ _ [] = (ctx, mempty) analyseIndex' ctx _ _ [_] = (ctx, mempty) analyseIndex' ctx pats arr_name dim_accesses = -- Get the name of all segmaps in the current "callstack" let segmaps = allSegMap ctx idx_expr_name = pats -- IndexExprName -- For each pattern, create a mapping to the dimensional indices map_ixd_expr = map (`M.singleton` dim_accesses) idx_expr_name -- IndexExprName |-> [DimAccess] -- For each pattern -> [DimAccess] mapping, create a mapping from the array -- name that was indexed. map_array = map (M.singleton arr_name) map_ixd_expr -- ArrayName |-> IndexExprName |-> [DimAccess] -- ∀ (arr_name -> IdxExp -> [DimAccess]) mappings, create a mapping from all -- segmaps in current callstack (segThread & segGroups alike). 
results = concatMap (\ma -> map (`M.singleton` ma) segmaps) map_array res = foldl' unionIndexTables mempty results in (ctx, res) analyseBasicOp :: Context rep -> BasicOp -> [VName] -> (Context rep, IndexTable rep) analyseBasicOp ctx expression pats = -- Construct a VariableInfo from the subexpressions let ctx_val = case expression of SubExp se -> varInfoFromSubExp se Opaque _ se -> varInfoFromSubExp se ArrayVal _ _ -> (varInfoFromNames ctx mempty) {variableType = ConstType} ArrayLit ses _t -> concatVariableInfos mempty ses UnOp _ se -> varInfoFromSubExp se BinOp _ lsubexp rsubexp -> concatVariableInfos mempty [lsubexp, rsubexp] CmpOp _ lsubexp rsubexp -> concatVariableInfos mempty [lsubexp, rsubexp] ConvOp _ se -> varInfoFromSubExp se Assert se _ _ -> varInfoFromSubExp se Index name _ -> error $ "unhandled: Index (This should NEVER happen) into " ++ prettyString name Update _ name _slice _subexp -> error $ "unhandled: Update (This should NEVER happen) onto " ++ prettyString name -- Technically, do we need this case? 
Concat _ _ length_subexp -> varInfoFromSubExp length_subexp Manifest _dim name -> varInfoFromNames ctx $ oneName name Iota end start stride _ -> concatVariableInfos mempty [end, start, stride] Replicate (Shape shape) value' -> concatVariableInfos mempty (value' : shape) Scratch _ sers -> concatVariableInfos mempty sers Reshape _ (Shape shape_subexp) name -> concatVariableInfos (oneName name) shape_subexp Rearrange _ name -> varInfoFromNames ctx $ oneName name UpdateAcc _ name lsubexprs rsubexprs -> concatVariableInfos (oneName name) (lsubexprs ++ rsubexprs) FlatIndex name _ -> varInfoFromNames ctx $ oneName name FlatUpdate name _ source -> varInfoFromNames ctx $ namesFromList [name, source] ctx' = foldl' extend ctx $ map (`oneContext` ctx_val) pats in (ctx', mempty) where concatVariableInfos ne nn = varInfoFromNames ctx (ne <> mconcat (map (analyseSubExp pats ctx) nn)) varInfoFromSubExp (Constant _) = (varInfoFromNames ctx mempty) {variableType = ConstType} varInfoFromSubExp (Var v) = case M.lookup v (assignments ctx) of Just _ -> (varInfoFromNames ctx $ oneName v) {variableType = Variable} Nothing -> (varInfoFromNames ctx mempty) {variableType = Variable} -- Means a global. analyseMatch :: (Analyse rep) => Context rep -> [VName] -> Body rep -> [Body rep] -> (Context rep, IndexTable rep) analyseMatch ctx pats body parents = let ctx'' = ctx {currentLevel = currentLevel ctx - 1} in foldl ( \(ctx', res) b -> -- This Little Maneuver's Gonna Cost Us 51 Years bimap constLevel (unionIndexTables res) . analyseStms ctx' CondBodyName pats . 
stmsToList $ bodyStms b ) (ctx'', mempty) (body : parents) where constLevel context = context {currentLevel = currentLevel ctx - 1} analyseLoop :: (Analyse rep) => Context rep -> [(FParam rep, SubExp)] -> LoopForm -> Body rep -> [VName] -> (Context rep, IndexTable rep) analyseLoop ctx bindings loop body pats = do let next_level = currentLevel ctx let ctx'' = ctx {currentLevel = next_level} let ctx' = contextFromNames ctx'' ((varInfoZeroDeps ctx) {variableType = LoopVar}) $ case loop of WhileLoop iv -> iv : map (paramName . fst) bindings ForLoop iv _ _ -> iv : map (paramName . fst) bindings -- Extend context with the loop expression analyseStms ctx' LoopBodyName pats $ stmsToList $ bodyStms body analyseApply :: Context rep -> [VName] -> [(SubExp, Diet)] -> (Context rep, IndexTable rep) analyseApply ctx pats diets = ( foldl' extend ctx $ map (\pat -> oneContext pat $ varInfoFromNames ctx $ mconcat $ map (freeIn . fst) diets) pats, mempty ) segOpType :: SegOp lvl rep -> VName -> SegOpName segOpType (SegMap {}) = SegmentedMap segOpType (SegRed {}) = SegmentedRed segOpType (SegScan {}) = SegmentedScan segOpType (SegHist {}) = SegmentedHist analyseSegOp :: (Analyse rep) => SegOp lvl rep -> Context rep -> [VName] -> (Context rep, IndexTable rep) analyseSegOp op ctx pats = let next_level = currentLevel ctx + length (unSegSpace $ segSpace op) - 1 ctx' = ctx {currentLevel = next_level} segspace_context = foldl' extend ctx' . map (\(n, i) -> oneContext n $ VariableInfo mempty (currentLevel ctx + i) (parents ctx') ThreadID) . (\segspace_params -> zip segspace_params [0 ..]) -- contextFromNames ctx' Parallel . map fst . unSegSpace $ segSpace op in -- Analyse statements in the SegOp body analyseStms segspace_context (SegOpName . segOpType op) pats . stmsToList . 
kernelBodyStms $ segBody op analyseSizeOp :: SizeOp -> Context rep -> [VName] -> (Context rep, IndexTable rep) analyseSizeOp op ctx pats = let ctx' = case op of CmpSizeLe _name _class subexp -> subexprsToContext [subexp] CalcNumBlocks lsubexp _name rsubexp -> subexprsToContext [lsubexp, rsubexp] _ -> ctx -- Add sizeOp to context ctx'' = foldl' extend ctx' $ map (\pat -> oneContext pat $ (varInfoZeroDeps ctx) {parents_nest = parents ctx'}) pats in (ctx'', mempty) where subexprsToContext = contextFromNames ctx (varInfoZeroDeps ctx) . concatMap (namesToList . analyseSubExp pats ctx) -- | Analyse statements in a rep body. analyseGPUBody :: (Analyse rep) => Body rep -> Context rep -> (Context rep, IndexTable rep) analyseGPUBody body ctx = analyseStmsPrimitive ctx $ stmsToList $ bodyStms body analyseOtherOp :: Context rep -> [VName] -> (Context rep, IndexTable rep) analyseOtherOp ctx _ = (ctx, mempty) -- | Returns an intmap of names, to be used as dependencies in construction of -- VariableInfos. analyseSubExp :: [VName] -> Context rep -> SubExp -> Names analyseSubExp _ _ (Constant _) = mempty analyseSubExp _ _ (Var v) = oneName v -- | Reduce a DimFix into its set of dependencies consolidate :: Context rep -> SubExp -> DimAccess rep consolidate _ (Constant _) = mempty consolidate ctx (Var v) = DimAccess (reduceDependencies ctx v) (Just v) -- | Recursively lookup vnames until vars with no deps are reached. reduceDependencies :: Context rep -> VName -> M.Map VName Dependency reduceDependencies ctx v = case M.lookup v (assignments ctx) of Nothing -> mempty -- Means a global. 
Just (VariableInfo deps lvl _parents t) -> -- We detect whether it is a threadID or loop counter by checking -- whether or not it has any dependencies case t of ThreadID -> M.fromList [(v, Dependency lvl t)] LoopVar -> M.fromList [(v, Dependency lvl t)] Variable -> mconcat $ map (reduceDependencies ctx) $ namesToList deps ConstType -> mempty -- Misc functions -- Instances for AST types that we actually support instance Analyse GPU where analyseOp gpu_op | (SegOp op) <- gpu_op = analyseSegOp op | (SizeOp op) <- gpu_op = analyseSizeOp op | (GPUBody _ body) <- gpu_op = pure . analyseGPUBody body | (Futhark.IR.GPU.OtherOp _) <- gpu_op = analyseOtherOp instance Analyse MC where analyseOp mc_op | ParOp Nothing seq_segop <- mc_op = analyseSegOp seq_segop | ParOp (Just segop) seq_segop <- mc_op = \ctx name -> do let (ctx', res') = analyseSegOp segop ctx name let (ctx'', res'') = analyseSegOp seq_segop ctx' name (ctx'', unionIndexTables res' res'') | Futhark.IR.MC.OtherOp _ <- mc_op = analyseOtherOp -- Unfortunately we need these instances, even though they may never appear. instance Analyse GPUMem where analyseOp _ = error $ notImplementedYet "GPUMem" instance Analyse MCMem where analyseOp _ = error "Unexpected?" instance Analyse Seq where analyseOp _ = error $ notImplementedYet "Seq" instance Analyse SeqMem where analyseOp _ = error $ notImplementedYet "SeqMem" instance Analyse SOACS where analyseOp _ = error $ notImplementedYet "SOACS" notImplementedYet :: String -> String notImplementedYet s = "Access pattern analysis for the " ++ s ++ " backend is not implemented." instance Pretty (IndexTable rep) where pretty = stack . map f . 
M.toList :: IndexTable rep -> Doc ann where f (segop, arrNameToIdxExprMap) = pretty segop <+> colon <+> g arrNameToIdxExprMap g maps = lbrace indent 4 (mapprintArray $ M.toList maps) rbrace mapprintArray :: [(ArrayName, M.Map IndexExprName [DimAccess rep])] -> Doc ann mapprintArray [] = "" mapprintArray [m] = printArrayMap m mapprintArray (m : mm) = printArrayMap m mapprintArray mm printArrayMap :: (ArrayName, M.Map IndexExprName [DimAccess rep]) -> Doc ann printArrayMap ((name, _, layout), maps) = "(arr)" <+> pretty name <+> colon <+> pretty layout <+> lbrace indent 4 (mapprintIdxExpr (M.toList maps)) rbrace mapprintIdxExpr :: [(IndexExprName, [DimAccess rep])] -> Doc ann mapprintIdxExpr [] = "" mapprintIdxExpr [m] = printIdxExpMap m mapprintIdxExpr (m : mm) = printIdxExpMap m mapprintIdxExpr mm printIdxExpMap (name, mems) = "(idx)" <+> pretty name <+> ":" indent 4 (printDimAccess mems) printDimAccess :: [DimAccess rep] -> Doc ann printDimAccess dim_accesses = stack $ zipWith (curry printDim) [0 ..] dim_accesses printDim :: (Int, DimAccess rep) -> Doc ann printDim (i, m) = pretty i <+> ":" <+> indent 0 (pretty m) instance Pretty (DimAccess rep) where pretty dim_access = -- Instead of using `brackets $` we manually enclose with `[`s, to add -- spacing between the enclosed elements if case originalVar dim_access of Nothing -> True Just n -> length (dependencies dim_access) == 1 && n == head (map fst $ M.toList $ dependencies dim_access) -- Only print the original name if it is different from the first (and single) dependency then "dependencies" <+> equals <+> align (prettyDeps $ dependencies dim_access) else "dependencies" <+> equals <+> pretty (originalVar dim_access) <+> "->" <+> align (prettyDeps $ dependencies dim_access) where prettyDeps = braces . commasep . map printPair . 
M.toList printPair (name, Dependency lvl vtype) = pretty name <+> pretty lvl <+> pretty vtype instance Pretty SegOpName where pretty (SegmentedMap name) = "(segmap)" <+> pretty name pretty (SegmentedRed name) = "(segred)" <+> pretty name pretty (SegmentedScan name) = "(segscan)" <+> pretty name pretty (SegmentedHist name) = "(seghist)" <+> pretty name instance Pretty BodyType where pretty (SegOpName (SegmentedMap name)) = pretty name <+> colon <+> "segmap" pretty (SegOpName (SegmentedRed name)) = pretty name <+> colon <+> "segred" pretty (SegOpName (SegmentedScan name)) = pretty name <+> colon <+> "segscan" pretty (SegOpName (SegmentedHist name)) = pretty name <+> colon <+> "seghist" pretty (LoopBodyName name) = pretty name <+> colon <+> "loop" pretty (CondBodyName name) = pretty name <+> colon <+> "cond" instance Pretty VarType where pretty ConstType = "const" pretty Variable = "var" pretty ThreadID = "tid" pretty LoopVar = "iter" futhark-0.25.27/src/Futhark/Analysis/AlgSimplify.hs000066400000000000000000000203611475065116200220760ustar00rootroot00000000000000module Futhark.Analysis.AlgSimplify ( Prod (..), SofP, simplify0, simplify, simplify', simplifySofP, simplifySofP', sumOfProducts, sumToExp, prodToExp, add, sub, negate, isMultipleOf, maybeDivide, removeLessThans, lessThanish, compareComplexity, ) where import Data.Bits (xor) import Data.Function ((&)) import Data.List (findIndex, intersect, partition, sort, (\\)) import Data.Maybe (mapMaybe) import Futhark.Analysis.PrimExp import Futhark.Analysis.PrimExp.Convert import Futhark.IR.Prop.Names import Futhark.IR.Syntax.Core (SubExp (..), VName) import Futhark.Util import Futhark.Util.Pretty import Prelude hiding (negate) type Exp = PrimExp VName type TExp = TPrimExp Int64 VName data Prod = Prod { negated :: Bool, atoms :: [Exp] } deriving (Show, Eq, Ord) type SofP = [Prod] sumOfProducts :: Exp -> SofP sumOfProducts = map sortProduct . 
sumOfProducts' sortProduct :: Prod -> Prod sortProduct (Prod n as) = Prod n $ sort as sumOfProducts' :: Exp -> SofP sumOfProducts' (BinOpExp (Add Int64 _) e1 e2) = sumOfProducts' e1 <> sumOfProducts' e2 sumOfProducts' (BinOpExp (Sub Int64 _) (ValueExp (IntValue (Int64Value 0))) e) = map negate $ sumOfProducts' e sumOfProducts' (BinOpExp (Sub Int64 _) e1 e2) = sumOfProducts' e1 <> map negate (sumOfProducts' e2) sumOfProducts' (BinOpExp (Mul Int64 _) e1 e2) = sumOfProducts' e1 `mult` sumOfProducts' e2 sumOfProducts' (ValueExp (IntValue (Int64Value i))) = [Prod (i < 0) [ValueExp $ IntValue $ Int64Value $ abs i]] sumOfProducts' e = [Prod False [e]] mult :: SofP -> SofP -> SofP mult xs ys = [Prod (b `xor` b') (x <> y) | Prod b x <- xs, Prod b' y <- ys] negate :: Prod -> Prod negate p = p {negated = not $ negated p} sumToExp :: SofP -> Exp sumToExp [] = val 0 sumToExp [x] = prodToExp x sumToExp (x : xs) = foldl (BinOpExp $ Add Int64 OverflowUndef) (prodToExp x) $ map prodToExp xs prodToExp :: Prod -> Exp prodToExp (Prod _ []) = val 1 prodToExp (Prod True [ValueExp (IntValue (Int64Value i))]) = ValueExp $ IntValue $ Int64Value (-i) prodToExp (Prod True as) = foldl (BinOpExp $ Mul Int64 OverflowUndef) (val (-1)) as prodToExp (Prod False (a : as)) = foldl (BinOpExp $ Mul Int64 OverflowUndef) a as simplifySofP :: SofP -> SofP simplifySofP = -- TODO: Maybe 'constFoldValueExps' is not necessary after adding scaleConsts fixPoint (mapMaybe (applyZero . removeOnes) . scaleConsts . constFoldValueExps . removeNegations) simplifySofP' :: SofP -> SofP simplifySofP' = fixPoint (mapMaybe (applyZero . removeOnes) . scaleConsts . removeNegations) simplify0 :: Exp -> SofP simplify0 = simplifySofP . sumOfProducts simplify :: Exp -> Exp simplify = constFoldPrimExp . sumToExp . simplify0 simplify' :: TExp -> TExp simplify' = TPrimExp . simplify . 
untyped applyZero :: Prod -> Maybe Prod applyZero p@(Prod _ as) | val 0 `elem` as = Nothing | otherwise = Just p removeOnes :: Prod -> Prod removeOnes (Prod neg as) = let as' = filter (/= val 1) as in Prod neg $ if null as' then [ValueExp $ IntValue $ Int64Value 1] else as' removeNegations :: SofP -> SofP removeNegations [] = [] removeNegations (t : ts) = case break (== negate t) ts of (start, _ : rest) -> removeNegations $ start <> rest _ -> t : removeNegations ts constFoldValueExps :: SofP -> SofP constFoldValueExps prods = let (value_exps, others) = partition (all isPrimValue . atoms) prods value_exps' = sumOfProducts $ constFoldPrimExp $ sumToExp value_exps in value_exps' <> others intFromExp :: Exp -> Maybe Int64 intFromExp (ValueExp (IntValue x)) = Just $ valueIntegral x intFromExp _ = Nothing -- | Given @-[2, x]@ returns @(-2, [x])@ prodToScale :: Prod -> (Int64, [Exp]) prodToScale (Prod b exps) = let (scalars, exps') = partitionMaybe intFromExp exps in if b then (-(product scalars), exps') else (product scalars, exps') -- | Given @(-2, [x])@ returns @-[1, 2, x]@ scaleToProd :: (Int64, [Exp]) -> Prod scaleToProd (i, exps) = Prod (i < 0) $ ValueExp (IntValue $ Int64Value $ abs i) : exps -- | Given @[[2, x], -[x]]@ returns @[[x]]@ scaleConsts :: SofP -> SofP scaleConsts = helper [] . map prodToScale where helper :: [Prod] -> [(Int64, [Exp])] -> [Prod] helper acc [] = reverse acc helper acc ((scale, exps) : rest) = case flip focusNth rest =<< findIndex ((==) exps . snd) rest of Nothing -> helper (scaleToProd (scale, exps) : acc) rest Just (before, (scale', _), after) -> helper acc $ (scale + scale', exps) : (before <> after) isPrimValue :: Exp -> Bool isPrimValue (ValueExp _) = True isPrimValue _ = False val :: Int64 -> Exp val = ValueExp . IntValue . 
Int64Value add :: SofP -> SofP -> SofP add ps1 ps2 = simplifySofP $ ps1 <> ps2 sub :: SofP -> SofP -> SofP sub ps1 ps2 = add ps1 $ map negate ps2 isMultipleOf :: Prod -> [Exp] -> Bool isMultipleOf (Prod _ as) term = let quotient = as \\ term in sort (quotient <> term) == sort as maybeDivide :: Prod -> Prod -> Maybe Prod maybeDivide dividend divisor | Prod dividend_b dividend_factors <- dividend, Prod divisor_b divisor_factors <- divisor, quotient <- dividend_factors \\ divisor_factors, sort (quotient <> divisor_factors) == sort dividend_factors = Just $ Prod (dividend_b `xor` divisor_b) quotient | (dividend_scale, dividend_rest) <- prodToScale dividend, (divisor_scale, divisor_rest) <- prodToScale divisor, dividend_scale `mod` divisor_scale == 0, null $ divisor_rest \\ dividend_rest = Just $ Prod (signum (dividend_scale `div` divisor_scale) < 0) ( ValueExp (IntValue $ Int64Value $ dividend_scale `div` divisor_scale) : (dividend_rest \\ divisor_rest) ) | otherwise = Nothing -- | Given a list of 'Names' that we know are non-negative (>= 0), determine -- whether we can say for sure that the given 'AlgSimplify.SofP' is -- non-negative. Conservatively returns 'False' if there is any doubt. -- -- TODO: We need to expand this to be able to handle cases such as @i*n + g < (i -- + 1) * n@, if it is known that @g < n@, eg. from a 'SegSpace' or a loop form. nonNegativeish :: Names -> SofP -> Bool nonNegativeish non_negatives = all (nonNegativeishProd non_negatives) nonNegativeishProd :: Names -> Prod -> Bool nonNegativeishProd _ (Prod True _) = False nonNegativeishProd non_negatives (Prod False as) = all (nonNegativeishExp non_negatives) as nonNegativeishExp :: Names -> PrimExp VName -> Bool nonNegativeishExp _ (ValueExp v) = not $ negativeIsh v nonNegativeishExp non_negatives (LeafExp vname _) = vname `nameIn` non_negatives nonNegativeishExp _ _ = False -- | Is e1 symbolically less than or equal to e2? 
lessThanOrEqualish :: [(VName, PrimExp VName)] -> Names -> TPrimExp Int64 VName -> TPrimExp Int64 VName -> Bool lessThanOrEqualish less_thans0 non_negatives e1 e2 = case e2 - e1 & untyped & simplify0 of [] -> True simplified -> nonNegativeish non_negatives $ fixPoint (`removeLessThans` less_thans) simplified where less_thans = concatMap (\(i, bound) -> [(Var i, bound), (Constant $ IntValue $ Int64Value 0, bound)]) less_thans0 lessThanish :: [(VName, PrimExp VName)] -> Names -> TPrimExp Int64 VName -> TPrimExp Int64 VName -> Bool lessThanish less_thans non_negatives e1 = lessThanOrEqualish less_thans non_negatives (e1 + 1) removeLessThans :: SofP -> [(SubExp, PrimExp VName)] -> SofP removeLessThans = foldl ( \sofp (i, bound) -> let to_remove = simplifySofP $ Prod True [primExpFromSubExp (IntType Int64) i] : simplify0 bound in case to_remove `intersect` sofp of to_remove' | to_remove' == to_remove -> sofp \\ to_remove _ -> sofp ) compareComplexity :: SofP -> SofP -> Ordering compareComplexity xs0 ys0 = case length xs0 `compare` length ys0 of EQ -> helper xs0 ys0 c -> c where helper [] [] = EQ helper [] _ = LT helper _ [] = GT helper (px : xs) (py : ys) = case (prodToScale px, prodToScale py) of ((ix, []), (iy, [])) -> case ix `compare` iy of EQ -> helper xs ys c -> c ((_, []), (_, _)) -> LT ((_, _), (_, [])) -> GT ((_, x), (_, y)) -> case length x `compare` length y of EQ -> helper xs ys c -> c futhark-0.25.27/src/Futhark/Analysis/Alias.hs000066400000000000000000000101211475065116200207000ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | Alias analysis of a full Futhark program. Takes as input a -- program with an arbitrary rep and produces one with aliases. This -- module does not implement the aliasing logic itself, and derives -- its information from definitions in -- "Futhark.IR.Prop.Aliases" and -- "Futhark.IR.Aliases". The alias information computed -- here will include transitive aliases (note that this is not what -- the building blocks do). 
module Futhark.Analysis.Alias ( aliasAnalysis, AliasableRep, -- * Ad-hoc utilities analyseFun, analyseStms, analyseStm, analyseExp, analyseBody, analyseLambda, ) where import Data.List qualified as L import Data.Map qualified as M import Futhark.IR.Aliases -- | Perform alias analysis on a Futhark program. aliasAnalysis :: (AliasableRep rep) => Prog rep -> Prog (Aliases rep) aliasAnalysis prog = prog { progConsts = fst (analyseStms mempty (progConsts prog)), progFuns = map analyseFun (progFuns prog) } -- | Perform alias analysis on function. analyseFun :: (AliasableRep rep) => FunDef rep -> FunDef (Aliases rep) analyseFun (FunDef entry attrs fname restype params body) = FunDef entry attrs fname restype params body' where body' = analyseBody mempty body -- | Perform alias analysis on Body. analyseBody :: (AliasableRep rep) => AliasTable -> Body rep -> Body (Aliases rep) analyseBody atable (Body rep stms result) = let (stms', _atable') = analyseStms atable stms in mkAliasedBody rep stms' result -- | Perform alias analysis on statements. analyseStms :: (AliasableRep rep) => AliasTable -> Stms rep -> (Stms (Aliases rep), AliasesAndConsumed) analyseStms orig_aliases = withoutBound . L.foldl' f (mempty, (orig_aliases, mempty)) . stmsToList where withoutBound (stms, (aliases, consumed)) = let bound = foldMap (namesFromList . patNames . stmPat) stms consumed' = consumed `namesSubtract` bound in (stms, (aliases, consumed')) f (stms, aliases) stm = let stm' = analyseStm (fst aliases) stm atable' = trackAliases aliases stm' in (stms <> oneStm stm', atable') -- | Perform alias analysis on statement. analyseStm :: (AliasableRep rep) => AliasTable -> Stm rep -> Stm (Aliases rep) analyseStm aliases (Let pat (StmAux cs attrs dec) e) = let e' = analyseExp aliases e pat' = mkAliasedPat pat e' rep' = (AliasDec $ consumedInExp e', dec) in Let pat' (StmAux cs attrs rep') e' -- | Perform alias analysis on expression. 
analyseExp :: (AliasableRep rep) => AliasTable -> Exp rep -> Exp (Aliases rep) -- Would be better to put this in a BranchType annotation, but that -- requires a lot of other work. analyseExp aliases (Match cond cases defbody matchdec) = let cases' = map (fmap $ analyseBody aliases) cases defbody' = analyseBody aliases defbody all_cons = foldMap (snd . fst . bodyDec) $ defbody' : map caseBody cases' isConsumed v = any (`nameIn` unAliases all_cons) $ v : namesToList (M.findWithDefault mempty v aliases) notConsumed = AliasDec . namesFromList . filter (not . isConsumed) . namesToList . unAliases onBody (Body ((als, cons), dec) stms res) = Body ((map notConsumed als, cons), dec) stms res cases'' = map (fmap onBody) cases' defbody'' = onBody defbody' in Match cond cases'' defbody'' matchdec analyseExp aliases e = mapExp analyse e where analyse = Mapper { mapOnSubExp = pure, mapOnVName = pure, mapOnBody = const $ pure . analyseBody aliases, mapOnRetType = pure, mapOnBranchType = pure, mapOnFParam = pure, mapOnLParam = pure, mapOnOp = pure . addOpAliases aliases } -- | Perform alias analysis on lambda. analyseLambda :: (AliasableRep rep) => AliasTable -> Lambda rep -> Lambda (Aliases rep) analyseLambda aliases lam = let body = analyseBody aliases $ lambdaBody lam in lam { lambdaBody = body, lambdaParams = lambdaParams lam } futhark-0.25.27/src/Futhark/Analysis/CallGraph.hs000066400000000000000000000114561475065116200215200ustar00rootroot00000000000000-- | This module exports functionality for generating a call graph of -- an Futhark program. 
module Futhark.Analysis.CallGraph ( CallGraph, buildCallGraph, isFunInCallGraph, calls, calledByConsts, allCalledBy, numOccurences, ) where import Control.Monad.Writer.Strict import Data.List (foldl') import Data.Map.Strict qualified as M import Data.Maybe (isJust) import Data.Set qualified as S import Futhark.IR.SOACS import Futhark.Util.Pretty type FunctionTable = M.Map Name (FunDef SOACS) buildFunctionTable :: Prog SOACS -> FunctionTable buildFunctionTable = foldl expand M.empty . progFuns where expand ftab f = M.insert (funDefName f) f ftab -- | A unique (at least within a function) name identifying a function -- call. In practice the first element of the corresponding pattern. type CallId = VName data FunCalls = FunCalls { fcMap :: M.Map CallId (Attrs, Name), fcAllCalled :: S.Set Name } deriving (Eq, Ord, Show) instance Monoid FunCalls where mempty = FunCalls mempty mempty instance Semigroup FunCalls where FunCalls x1 y1 <> FunCalls x2 y2 = FunCalls (x1 <> x2) (y1 <> y2) fcCalled :: Name -> FunCalls -> Bool fcCalled f fcs = f `S.member` fcAllCalled fcs type FunGraph = M.Map Name FunCalls -- | The call graph is a mapping from a function name, i.e., the -- caller, to a record of the names of functions called *directly* (not -- transitively!) by the function. -- -- We keep track separately of the functions called by constants. data CallGraph = CallGraph { cgCalledByFuns :: FunGraph, cgCalledByConsts :: FunCalls } deriving (Eq, Ord, Show) -- | Is the given function known to the call graph? isFunInCallGraph :: Name -> CallGraph -> Bool isFunInCallGraph f = M.member f . cgCalledByFuns -- | Does the first function call the second? calls :: Name -> Name -> CallGraph -> Bool calls caller callee = maybe False (fcCalled callee) . M.lookup caller . cgCalledByFuns -- | Is the function called in any of the constants? calledByConsts :: Name -> CallGraph -> Bool calledByConsts callee = fcCalled callee . cgCalledByConsts -- | All functions called by this function. 
allCalledBy :: Name -> CallGraph -> S.Set Name allCalledBy f = maybe mempty fcAllCalled . M.lookup f . cgCalledByFuns -- | @buildCallGraph prog@ build the program's call graph. buildCallGraph :: Prog SOACS -> CallGraph buildCallGraph prog = CallGraph fg cg where fg = foldl' (buildFGfun ftable) M.empty entry_points cg = buildFGStms $ progConsts prog entry_points = S.fromList (map funDefName (filter (isJust . funDefEntryPoint) $ progFuns prog)) <> fcAllCalled cg ftable = buildFunctionTable prog count :: (Ord k) => [k] -> M.Map k Int count ks = M.fromListWith (+) $ map (,1) ks -- | Produce a mapping of the number of occurences in the call graph -- of each function. Only counts functions that are called at least -- once. numOccurences :: CallGraph -> M.Map Name Int numOccurences (CallGraph funs consts) = count $ map snd $ M.elems (fcMap consts <> foldMap fcMap (M.elems funs)) -- | @buildCallGraph ftable fg fname@ updates @fg@ with the -- contributions of function @fname@. buildFGfun :: FunctionTable -> FunGraph -> Name -> FunGraph buildFGfun ftable fg fname = -- Check if function is a non-builtin that we have not already -- processed. case M.lookup fname ftable of Just f | Nothing <- M.lookup fname fg -> do let callees = buildFGBody $ funDefBody f fg' = M.insert fname callees fg -- recursively build the callees foldl' (buildFGfun ftable) fg' $ fcAllCalled callees _ -> fg buildFGStms :: Stms SOACS -> FunCalls buildFGStms = mconcat . map buildFGstm . stmsToList buildFGBody :: Body SOACS -> FunCalls buildFGBody = buildFGStms . 
bodyStms buildFGstm :: Stm SOACS -> FunCalls buildFGstm (Let (Pat (p : _)) aux (Apply fname _ _ _)) = FunCalls (M.singleton (patElemName p) (stmAuxAttrs aux, fname)) (S.singleton fname) buildFGstm (Let _ _ (Op op)) = execWriter $ mapSOACM folder op where folder = identitySOACMapper { mapOnSOACLambda = \lam -> do tell $ buildFGBody $ lambdaBody lam pure lam } buildFGstm (Let _ _ e) = execWriter $ mapExpM folder e where folder = identityMapper { mapOnBody = \_ body -> do tell $ buildFGBody body pure body } instance Pretty FunCalls where pretty = stack . map f . M.toList . fcMap where f (x, (attrs, y)) = "=>" <+> pretty y <+> parens ("at" <+> pretty x <+> pretty attrs) instance Pretty CallGraph where pretty (CallGraph fg cg) = stack $ punctuate line $ ppFunCalls ("called at top level", cg) : map ppFunCalls (M.toList fg) where ppFunCalls (f, fcalls) = pretty f pretty (map (const '=') (nameToString f)) indent 2 (pretty fcalls) futhark-0.25.27/src/Futhark/Analysis/DataDependencies.hs000066400000000000000000000145041475065116200230400ustar00rootroot00000000000000-- | Facilities for inspecting the data dependencies of a program. module Futhark.Analysis.DataDependencies ( Dependencies, dataDependencies, depsOf, depsOf', depsOfArrays, depsOfShape, lambdaDependencies, reductionDependencies, findNecessaryForReturned, ) where import Data.List qualified as L import Data.Map.Strict qualified as M import Futhark.IR -- | A mapping from a variable name @v@, to those variables on which -- the value of @v@ is dependent. The intuition is that we could -- remove all other variables, and @v@ would still be computable. -- This also includes names bound in loops or by lambdas. type Dependencies = M.Map VName Names -- | Compute the data dependencies for an entire body. 
dataDependencies :: (ASTRep rep) => Body rep -> Dependencies dataDependencies = dataDependencies' M.empty dataDependencies' :: (ASTRep rep) => Dependencies -> Body rep -> Dependencies dataDependencies' startdeps = foldl grow startdeps . bodyStms where grow deps (Let pat _ (WithAcc inputs lam)) = let input_deps = foldMap depsOfWithAccInput inputs -- Dependencies of each input reduction are concatenated. -- Input to lam is cert_1, ..., cert_n, acc_1, ..., acc_n. lam_deps = lambdaDependencies deps lam (input_deps <> input_deps) transitive = map (depsOfNames deps) lam_deps in M.fromList (zip (patNames pat) transitive) `M.union` deps where depsOfArrays' shape = map (\arr -> oneName arr <> depsOfShape shape) depsOfWithAccInput (shape, arrs, Nothing) = depsOfArrays' shape arrs depsOfWithAccInput (shape, arrs, Just (lam', nes)) = reductionDependencies deps lam' nes (depsOfArrays' shape arrs) grow deps (Let pat _ (Op op)) = let op_deps = map (depsOfNames deps) (opDependencies op) pat_deps = map (depsOfNames deps . freeIn) (patElems pat) in if length op_deps /= length pat_deps then error . unlines $ [ "dataDependencies':", "Pattern size: " <> show (length pat_deps), "Op deps size: " <> show (length op_deps), "Expression:", prettyString op ] else M.fromList (zip (patNames pat) $ zipWith (<>) pat_deps op_deps) `M.union` deps grow deps (Let pat _ (Match c cases defbody _)) = let cases_deps = map (dataDependencies' deps . caseBody) cases defbody_deps = dataDependencies' deps defbody cdeps = foldMap (depsOf deps) c comb (pe, se_cases_deps, se_defbody_deps) = ( patElemName pe, mconcat $ se_cases_deps ++ [freeIn pe, cdeps, se_defbody_deps] ++ map (depsOfVar deps) (namesToList $ freeIn pe) ) branchdeps = M.fromList $ map comb $ zip3 (patElems pat) ( L.transpose . zipWith (map . depsOf) cases_deps $ map (map resSubExp . bodyResult . caseBody) cases ) (map (depsOf defbody_deps . 
resSubExp) (bodyResult defbody)) in M.unions $ [branchdeps, deps, defbody_deps] ++ cases_deps grow deps (Let pat _ e) = let free = freeIn pat <> freeIn e free_deps = depsOfNames deps free in M.fromList [(name, free_deps) | name <- patNames pat] `M.union` deps depsOf :: Dependencies -> SubExp -> Names depsOf _ (Constant _) = mempty depsOf deps (Var v) = depsOfVar deps v depsOf' :: SubExp -> Names depsOf' (Constant _) = mempty depsOf' (Var v) = depsOfVar mempty v depsOfVar :: Dependencies -> VName -> Names depsOfVar deps name = oneName name <> M.findWithDefault mempty name deps depsOfRes :: Dependencies -> SubExpRes -> Names depsOfRes deps (SubExpRes _ se) = depsOf deps se -- | Extend @names@ with direct dependencies in @deps@. depsOfNames :: Dependencies -> Names -> Names depsOfNames deps names = mconcat $ map (depsOfVar deps) $ namesToList names depsOfArrays :: SubExp -> [VName] -> [Names] depsOfArrays size = map (\arr -> oneName arr <> depsOf mempty size) depsOfShape :: Shape -> Names depsOfShape shape = mconcat $ map (depsOf mempty) (shapeDims shape) -- | Determine the variables on which the results of applying -- anonymous function @lam@ to @inputs@ depend. lambdaDependencies :: (ASTRep rep) => Dependencies -> Lambda rep -> [Names] -> [Names] lambdaDependencies deps lam inputs = let names_in_scope = freeIn lam <> mconcat inputs deps_in = M.fromList $ zip (boundByLambda lam) inputs deps' = dataDependencies' (deps_in <> deps) (lambdaBody lam) in map (namesIntersection names_in_scope . depsOfRes deps') (bodyResult $ lambdaBody lam) -- | Like 'lambdaDependencies', but @lam@ is a binary operation -- with a neutral element. 
reductionDependencies :: (ASTRep rep) => Dependencies -> Lambda rep -> [SubExp] -> [Names] -> [Names] reductionDependencies deps lam nes inputs = let nes' = map (depsOf deps) nes in lambdaDependencies deps lam (zipWith (<>) nes' inputs) -- | @findNecessaryForReturned p merge deps@ computes which of the -- loop parameters (@merge@) are necessary for the result of the loop, -- where @p@ given a loop parameter indicates whether the final value -- of that parameter is live after the loop. @deps@ is the data -- dependencies of the loop body. This is computed by straightforward -- fixpoint iteration. findNecessaryForReturned :: (Param dec -> Bool) -> [(Param dec, SubExp)] -> M.Map VName Names -> Names findNecessaryForReturned usedAfterLoop merge_and_res allDependencies = iterateNecessary mempty <> namesFromList (map paramName $ filter usedAfterLoop $ map fst merge_and_res) where iterateNecessary prev_necessary | necessary == prev_necessary = necessary | otherwise = iterateNecessary necessary where necessary = mconcat $ map dependencies returnedResultSubExps usedAfterLoopOrNecessary param = usedAfterLoop param || paramName param `nameIn` prev_necessary returnedResultSubExps = map snd $ filter (usedAfterLoopOrNecessary . 
fst) merge_and_res dependencies (Constant _) = mempty dependencies (Var v) = M.findWithDefault (oneName v) v allDependencies futhark-0.25.27/src/Futhark/Analysis/HORep/000077500000000000000000000000001475065116200202755ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Analysis/HORep/MapNest.hs000066400000000000000000000126071475065116200222060ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} module Futhark.Analysis.HORep.MapNest ( Nesting (..), MapNest (..), typeOf, params, inputs, setInputs, fromSOAC, toSOAC, ) where import Data.List (find) import Data.Map.Strict qualified as M import Data.Maybe import Futhark.Analysis.HORep.SOAC (SOAC) import Futhark.Analysis.HORep.SOAC qualified as SOAC import Futhark.Construct import Futhark.IR hiding (typeOf) import Futhark.IR.SOACS.SOAC qualified as Futhark import Futhark.Transform.Substitute data Nesting rep = Nesting { nestingParamNames :: [VName], nestingResult :: [VName], nestingReturnType :: [Type], nestingWidth :: SubExp } deriving (Eq, Ord, Show) data MapNest rep = MapNest SubExp (Lambda rep) [Nesting rep] [SOAC.Input] deriving (Show) typeOf :: MapNest rep -> [Type] typeOf (MapNest w lam [] _) = map (`arrayOfRow` w) $ lambdaReturnType lam typeOf (MapNest w _ (nest : _) _) = map (`arrayOfRow` w) $ nestingReturnType nest params :: MapNest rep -> [VName] params (MapNest _ lam [] _) = map paramName $ lambdaParams lam params (MapNest _ _ (nest : _) _) = nestingParamNames nest inputs :: MapNest rep -> [SOAC.Input] inputs (MapNest _ _ _ inps) = inps setInputs :: [SOAC.Input] -> MapNest rep -> MapNest rep setInputs [] (MapNest w body ns _) = MapNest w body ns [] setInputs (inp : inps) (MapNest _ body ns _) = MapNest w body ns' (inp : inps) where w = arraySize 0 $ SOAC.inputType inp ws = drop 1 $ arrayDims $ SOAC.inputType inp ns' = zipWith setDepth ns ws setDepth n nw = n {nestingWidth = nw} fromSOAC :: ( Buildable rep, MonadFreshNames m, LocalScope rep m, Op rep ~ Futhark.SOAC rep ) => SOAC rep -> m (Maybe 
(MapNest rep)) fromSOAC = fromSOAC' mempty fromSOAC' :: ( Buildable rep, MonadFreshNames m, LocalScope rep m, Op rep ~ Futhark.SOAC rep ) => [Ident] -> SOAC rep -> m (Maybe (MapNest rep)) fromSOAC' bound (SOAC.Screma w inps (SOAC.ScremaForm lam [] [])) = do maybenest <- case ( stmsToList $ bodyStms $ lambdaBody lam, bodyResult $ lambdaBody lam ) of ([Let pat _ e], res) | map resSubExp res == map Var (patNames pat) -> localScope (scopeOfLParams $ lambdaParams lam) $ SOAC.fromExp e >>= either (pure . Left) (fmap (Right . fmap (pat,)) . fromSOAC' bound') _ -> pure $ Right Nothing case maybenest of -- Do we have a nested MapNest? Right (Just (pat, mn@(MapNest inner_w body' ns' inps'))) -> do (ps, inps'') <- unzip <$> fixInputs w (zip (map paramName $ lambdaParams lam) inps) (zip (params mn) inps') let n' = Nesting { nestingParamNames = ps, nestingResult = patNames pat, nestingReturnType = typeOf mn, nestingWidth = inner_w } pure $ Just $ MapNest w body' (n' : ns') inps'' -- No nested MapNest it seems. _ -> do let isBound name | Just param <- find ((name ==) . identName) bound = Just param | otherwise = Nothing boundUsedInBody = mapMaybe isBound $ namesToList $ freeIn lam newParams <- mapM (newIdent' (++ "_wasfree")) boundUsedInBody let subst = M.fromList $ zip (map identName boundUsedInBody) (map identName newParams) inps' = inps ++ map (SOAC.addTransform (SOAC.Replicate mempty $ Shape [w]) . 
SOAC.identInput) boundUsedInBody lam' = lam { lambdaBody = substituteNames subst $ lambdaBody lam, lambdaParams = lambdaParams lam ++ [Param mempty name t | Ident name t <- newParams] } pure $ Just $ MapNest w lam' [] inps' where bound' = bound <> map paramIdent (lambdaParams lam) fromSOAC' _ _ = pure Nothing toSOAC :: ( MonadFreshNames m, HasScope rep m, Buildable rep, BuilderOps rep, Op rep ~ Futhark.SOAC rep ) => MapNest rep -> m (SOAC rep) toSOAC (MapNest w lam [] inps) = pure $ SOAC.Screma w inps (Futhark.mapSOAC lam) toSOAC (MapNest w lam (Nesting npnames nres nrettype nw : ns) inps) = do let nparams = zipWith (Param mempty) npnames $ map SOAC.inputRowType inps body <- runBodyBuilder $ localScope (scopeOfLParams nparams) $ do letBindNames nres =<< SOAC.toExp =<< toSOAC (MapNest nw lam ns $ map (SOAC.identInput . paramIdent) nparams) pure $ varsRes nres let outerlam = Lambda { lambdaParams = nparams, lambdaBody = body, lambdaReturnType = nrettype } pure $ SOAC.Screma w inps (Futhark.mapSOAC outerlam) fixInputs :: (MonadFreshNames m) => SubExp -> [(VName, SOAC.Input)] -> [(VName, SOAC.Input)] -> m [(VName, SOAC.Input)] fixInputs w ourInps = mapM inspect where isParam x (y, _) = x == y inspect (_, SOAC.Input ts v _) | Just (p, pInp) <- find (isParam v) ourInps = do let pInp' = SOAC.transformRows ts pInp p' <- newNameFromString $ baseString p pure (p', pInp') inspect (param, SOAC.Input ts a t) = do param' <- newNameFromString (baseString param ++ "_rep") pure (param', SOAC.Input (ts SOAC.|> SOAC.Replicate mempty (Shape [w])) a t) futhark-0.25.27/src/Futhark/Analysis/HORep/SOAC.hs000066400000000000000000000656221475065116200213710ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | High-level representation of SOACs. When performing -- SOAC-transformations, operating on normal 'Exp' values is somewhat -- of a nuisance, as they can represent terms that are not proper -- SOACs. 
In contrast, this module exposes a SOAC representation that -- does not enable invalid representations (except for type errors). -- -- Furthermore, while standard normalised Futhark requires that the inputs -- to a SOAC are variables or constants, the representation in this -- module also supports various index-space transformations, like -- @replicate@ or @rearrange@. This is also very convenient when -- implementing transformations. -- -- The names exported by this module conflict with the standard Futhark -- syntax tree constructors, so you are advised to use a qualified -- import: -- -- @ -- import Futhark.Analysis.HORep.SOAC (SOAC) -- import qualified Futhark.Analysis.HORep.SOAC as SOAC -- @ module Futhark.Analysis.HORep.SOAC ( -- * SOACs SOAC (..), Futhark.ScremaForm (..), inputs, setInputs, lambda, setLambda, typeOf, width, -- ** Converting to and from expressions NotSOAC (..), fromExp, toExp, toSOAC, -- * SOAC inputs Input (..), varInput, inputTransforms, identInput, isVarInput, isVarishInput, addTransform, addInitialTransforms, inputArray, inputRank, inputType, inputRowType, transformRows, transposeInput, applyTransforms, -- ** Input transformations ArrayTransforms, noTransforms, nullTransforms, (|>), (<|), viewf, ViewF (..), viewl, ViewL (..), ArrayTransform (..), transformFromExp, transformToExp, soacToStream, ) where import Data.Foldable as Foldable import Data.Maybe import Data.Sequence qualified as Seq import Futhark.Construct hiding (toExp) import Futhark.IR hiding ( Index, Iota, Rearrange, Replicate, Reshape, typeOf, ) import Futhark.IR qualified as Futhark import Futhark.IR.SOACS.SOAC ( HistOp (..), ScatterSpec, ScremaForm (..), scremaType, ) import Futhark.IR.SOACS.SOAC qualified as Futhark import Futhark.Transform.Rename (renameLambda) import Futhark.Transform.Substitute import Futhark.Util.Pretty (pretty) import Futhark.Util.Pretty qualified as PP -- | A single, simple transformation. 
If you want several, don't just -- create a list, use 'ArrayTransforms' instead. data ArrayTransform = -- | A permutation of an otherwise valid input. Rearrange Certs [Int] | -- | A reshaping of an otherwise valid input. Reshape Certs ReshapeKind Shape | -- | A reshaping of the outer dimension. ReshapeOuter Certs ReshapeKind Shape | -- | A reshaping of everything but the outer dimension. ReshapeInner Certs ReshapeKind Shape | -- | Replicate the rows of the array a number of times. Replicate Certs Shape | -- | An array indexing operation. Index Certs (Slice SubExp) deriving (Show, Eq, Ord) instance Substitute ArrayTransform where substituteNames substs (Rearrange cs xs) = Rearrange (substituteNames substs cs) xs substituteNames substs (Reshape cs k ses) = Reshape (substituteNames substs cs) k (substituteNames substs ses) substituteNames substs (ReshapeOuter cs k ses) = ReshapeOuter (substituteNames substs cs) k (substituteNames substs ses) substituteNames substs (ReshapeInner cs k ses) = ReshapeInner (substituteNames substs cs) k (substituteNames substs ses) substituteNames substs (Replicate cs se) = Replicate (substituteNames substs cs) (substituteNames substs se) substituteNames substs (Index cs slice) = Index (substituteNames substs cs) (substituteNames substs slice) -- | A sequence of array transformations, heavily inspired by -- "Data.Seq". You can decompose it using 'viewf' and 'viewl', and -- grow it by using '|>' and '<|'. These correspond closely to the -- similar operations for sequences, except that appending will try to -- normalise and simplify the transformation sequence. -- -- The data type is opaque in order to enforce normalisation -- invariants. Basically, when you grow the sequence, the -- implementation will try to coalesce neighboring permutations, for -- example by composing permutations and removing identity -- transformations. 
newtype ArrayTransforms = ArrayTransforms (Seq.Seq ArrayTransform) deriving (Eq, Ord, Show) instance Semigroup ArrayTransforms where ts1 <> ts2 = case viewf ts2 of t :< ts2' -> (ts1 |> t) <> ts2' EmptyF -> ts1 instance Monoid ArrayTransforms where mempty = noTransforms instance Substitute ArrayTransforms where substituteNames substs (ArrayTransforms ts) = ArrayTransforms $ substituteNames substs <$> ts -- | The empty transformation list. noTransforms :: ArrayTransforms noTransforms = ArrayTransforms Seq.empty -- | Is it an empty transformation list? nullTransforms :: ArrayTransforms -> Bool nullTransforms (ArrayTransforms s) = Seq.null s -- | Decompose the input-end of the transformation sequence. viewf :: ArrayTransforms -> ViewF viewf (ArrayTransforms s) = case Seq.viewl s of t Seq.:< s' -> t :< ArrayTransforms s' Seq.EmptyL -> EmptyF -- | A view of the first transformation to be applied. data ViewF = EmptyF | ArrayTransform :< ArrayTransforms -- | Decompose the output-end of the transformation sequence. viewl :: ArrayTransforms -> ViewL viewl (ArrayTransforms s) = case Seq.viewr s of s' Seq.:> t -> ArrayTransforms s' :> t Seq.EmptyR -> EmptyL -- | A view of the last transformation to be applied. data ViewL = EmptyL | ArrayTransforms :> ArrayTransform -- | Add a transform to the end of the transformation list. (|>) :: ArrayTransforms -> ArrayTransform -> ArrayTransforms (|>) = flip $ addTransform' extract add $ uncurry (flip (,)) where extract ts' = case viewl ts' of EmptyL -> Nothing ts'' :> t' -> Just (t', ts'') add t' (ArrayTransforms ts') = ArrayTransforms $ ts' Seq.|> t' -- | Add a transform at the beginning of the transformation list. 
(<|) :: ArrayTransform -> ArrayTransforms -> ArrayTransforms (<|) = addTransform' extract add id where extract ts' = case viewf ts' of EmptyF -> Nothing t' :< ts'' -> Just (t', ts'') add t' (ArrayTransforms ts') = ArrayTransforms $ t' Seq.<| ts' addTransform' :: (ArrayTransforms -> Maybe (ArrayTransform, ArrayTransforms)) -> (ArrayTransform -> ArrayTransforms -> ArrayTransforms) -> ((ArrayTransform, ArrayTransform) -> (ArrayTransform, ArrayTransform)) -> ArrayTransform -> ArrayTransforms -> ArrayTransforms addTransform' extract add swap t ts = fromMaybe (t `add` ts) $ do (t', ts') <- extract ts combined <- uncurry combineTransforms $ swap (t', t) Just $ if identityTransform combined then ts' else addTransform' extract add swap combined ts' identityTransform :: ArrayTransform -> Bool identityTransform (Rearrange _ perm) = Foldable.and $ zipWith (==) perm [0 ..] identityTransform _ = False combineTransforms :: ArrayTransform -> ArrayTransform -> Maybe ArrayTransform combineTransforms (Rearrange cs2 perm2) (Rearrange cs1 perm1) = Just $ Rearrange (cs1 <> cs2) $ perm2 `rearrangeCompose` perm1 combineTransforms _ _ = Nothing -- | Given an expression, determine whether the expression represents -- an input transformation of an array variable. If so, return the -- variable and the transformation. Only 'Rearrange' and 'Reshape' -- are possible to express this way. transformFromExp :: Certs -> Exp rep -> Maybe (VName, ArrayTransform) transformFromExp cs (BasicOp (Futhark.Rearrange perm v)) = Just (v, Rearrange cs perm) transformFromExp cs (BasicOp (Futhark.Reshape k shape v)) = Just (v, Reshape cs k shape) transformFromExp cs (BasicOp (Futhark.Replicate shape (Var v))) = Just (v, Replicate cs shape) transformFromExp cs (BasicOp (Futhark.Index v slice)) = Just (v, Index cs slice) transformFromExp _ _ = Nothing -- | Turn an array transform on an array back into an expression. 
transformToExp :: (Monad m, HasScope rep m) => ArrayTransform -> VName -> m (Certs, Exp rep) transformToExp (Replicate cs n) ia = pure (cs, BasicOp $ Futhark.Replicate n (Var ia)) transformToExp (Rearrange cs perm) ia = do r <- arrayRank <$> lookupType ia pure (cs, BasicOp $ Futhark.Rearrange (perm ++ [length perm .. r - 1]) ia) transformToExp (Reshape cs k shape) ia = do pure (cs, BasicOp $ Futhark.Reshape k shape ia) transformToExp (ReshapeOuter cs k shape) ia = do shape' <- reshapeOuter shape 1 . arrayShape <$> lookupType ia pure (cs, BasicOp $ Futhark.Reshape k shape' ia) transformToExp (ReshapeInner cs k shape) ia = do shape' <- reshapeInner shape 1 . arrayShape <$> lookupType ia pure (cs, BasicOp $ Futhark.Reshape k shape' ia) transformToExp (Index cs slice) ia = do pure (cs, BasicOp $ Futhark.Index ia slice) -- | One array input to a SOAC - a SOAC may have multiple inputs, but -- all are of this form. Only the array inputs are expressed with -- this type; other arguments, such as initial accumulator values, are -- plain expressions. The transforms are done left-to-right, that is, -- the first element of the 'ArrayTransform' list is applied first. data Input = Input ArrayTransforms VName Type deriving (Show, Eq, Ord) instance Substitute Input where substituteNames substs (Input ts v t) = Input (substituteNames substs ts) (substituteNames substs v) (substituteNames substs t) -- | Create a plain array variable input with no transformations. varInput :: (HasScope t f) => VName -> f Input varInput v = withType <$> lookupType v where withType = Input (ArrayTransforms Seq.empty) v -- | Create a plain array variable input with no transformations, from an 'Ident'. identInput :: Ident -> Input identInput v = Input (ArrayTransforms Seq.empty) (identName v) (identType v) -- | If the given input is a plain variable input, with no transforms, -- return the variable. 
isVarInput :: Input -> Maybe VName isVarInput (Input ts v _) | nullTransforms ts = Just v isVarInput _ = Nothing -- | If the given input is a plain variable input, with no non-vacuous -- transforms, return the variable. isVarishInput :: Input -> Maybe VName isVarishInput (Input ts v t) | nullTransforms ts = Just v | Reshape cs ReshapeCoerce (Shape [_]) :< ts' <- viewf ts, cs == mempty = isVarishInput $ Input ts' v t isVarishInput _ = Nothing -- | Add a transformation to the end of the transformation list. addTransform :: ArrayTransform -> Input -> Input addTransform tr (Input trs a t) = Input (trs |> tr) a t -- | Add several transformations to the start of the transformation -- list. addInitialTransforms :: ArrayTransforms -> Input -> Input addInitialTransforms ts (Input ots a t) = Input (ts <> ots) a t applyTransform :: (MonadBuilder m) => ArrayTransform -> VName -> m VName applyTransform tr ia = do (cs, e) <- transformToExp tr ia certifying cs $ letExp s e where s = case tr of Replicate {} -> "replicate" Rearrange {} -> "rearrange" Reshape {} -> "reshape" ReshapeOuter {} -> "reshape_outer" ReshapeInner {} -> "reshape_inner" Index {} -> "index" applyTransforms :: (MonadBuilder m) => ArrayTransforms -> VName -> m VName applyTransforms (ArrayTransforms ts) a = foldlM (flip applyTransform) a ts -- | Convert SOAC inputs to the corresponding expressions. inputsToSubExps :: (MonadBuilder m) => [Input] -> m [VName] inputsToSubExps = mapM f where f (Input ts a _) = applyTransforms ts a -- | Return the array name of the input. inputArray :: Input -> VName inputArray (Input _ v _) = v -- | The transformations applied to an input. inputTransforms :: Input -> ArrayTransforms inputTransforms (Input ts _ _) = ts -- | Return the type of an input. 
inputType :: Input -> Type inputType (Input (ArrayTransforms ts) _ at) = Foldable.foldl transformType at ts where transformType t (Replicate _ shape) = arrayOfShape t shape transformType t (Rearrange _ perm) = rearrangeType perm t transformType t (Reshape _ _ shape) = t `setArrayShape` shape transformType t (ReshapeOuter _ _ shape) = let Shape oldshape = arrayShape t in t `setArrayShape` Shape (shapeDims shape ++ drop 1 oldshape) transformType t (ReshapeInner _ _ shape) = let Shape oldshape = arrayShape t in t `setArrayShape` Shape (take 1 oldshape ++ shapeDims shape) transformType t (Index _ slice) = t `setArrayShape` sliceShape slice -- | Return the row type of an input. Just a convenient alias. inputRowType :: Input -> Type inputRowType = rowType . inputType -- | Return the array rank (dimensionality) of an input. Just a -- convenient alias. inputRank :: Input -> Int inputRank = arrayRank . inputType -- | Apply the transformations to every row of the input. transformRows :: ArrayTransforms -> Input -> Input transformRows (ArrayTransforms ts) = flip (Foldable.foldl transformRows') ts where transformRows' inp (Rearrange cs perm) = addTransform (Rearrange cs (0 : map (+ 1) perm)) inp transformRows' inp (Reshape cs k shape) = addTransform (ReshapeInner cs k shape) inp transformRows' inp (Replicate cs n) | inputRank inp == 1 = Rearrange mempty [1, 0] `addTransform` (Replicate cs n `addTransform` inp) | otherwise = Rearrange mempty (2 : 0 : 1 : [3 .. inputRank inp]) `addTransform` ( Replicate cs n `addTransform` (Rearrange mempty (1 : 0 : [2 .. inputRank inp - 1]) `addTransform` inp) ) transformRows' inp nts = error $ "transformRows: Cannot transform this yet:\n" ++ show nts ++ "\n" ++ show inp -- | Add to the input a 'Rearrange' transform that performs an @(k,n)@ -- transposition. The new transform will be at the end of the current -- transformation list. 
transposeInput :: Int -> Int -> Input -> Input transposeInput k n inp = addTransform (Rearrange mempty $ transposeIndex k n [0 .. inputRank inp - 1]) inp -- | A definite representation of a SOAC expression. data SOAC rep = Stream SubExp [Input] [SubExp] (Lambda rep) | Scatter SubExp [Input] (ScatterSpec VName) (Lambda rep) | Screma SubExp [Input] (ScremaForm rep) | Hist SubExp [Input] [HistOp rep] (Lambda rep) deriving (Eq, Show) -- | Returns the inputs used in a SOAC. inputs :: SOAC rep -> [Input] inputs (Stream _ arrs _ _) = arrs inputs (Scatter _ arrs _lam _spec) = arrs inputs (Screma _ arrs _) = arrs inputs (Hist _ inps _ _) = inps -- | Set the inputs to a SOAC. setInputs :: [Input] -> SOAC rep -> SOAC rep setInputs arrs (Stream w _ nes lam) = Stream (newWidth arrs w) arrs nes lam setInputs arrs (Scatter w _ lam spec) = Scatter (newWidth arrs w) arrs lam spec setInputs arrs (Screma w _ form) = Screma w arrs form setInputs inps (Hist w _ ops lam) = Hist w inps ops lam newWidth :: [Input] -> SubExp -> SubExp newWidth [] w = w newWidth (inp : _) _ = arraySize 0 $ inputType inp -- | The lambda used in a given SOAC. lambda :: SOAC rep -> Lambda rep lambda (Stream _ _ _ lam) = lam lambda (Scatter _len _ivs _spec lam) = lam lambda (Screma _ _ (ScremaForm lam _ _)) = lam lambda (Hist _ _ _ lam) = lam -- | Set the lambda used in the SOAC. setLambda :: Lambda rep -> SOAC rep -> SOAC rep setLambda lam (Stream w arrs nes _) = Stream w arrs nes lam setLambda lam (Scatter len arrs spec _lam) = Scatter len arrs spec lam setLambda lam (Screma w arrs (ScremaForm _ scan red)) = Screma w arrs (ScremaForm lam scan red) setLambda lam (Hist w ops inps _) = Hist w ops inps lam -- | The return type of a SOAC. 
typeOf :: SOAC rep -> [Type]
typeOf soac =
  case soac of
    -- The first @length nes@ results are accumulators; the remaining ones
    -- are chunks of arrays that get concatenated to the full width.
    Stream w _ nes lam ->
      let (acc_ts, arr_ts) = splitAt (length nes) (lambdaReturnType lam)
       in acc_ts ++ map (\t -> arrayOf (stripArray 1 t) (Shape [w]) NoUniqueness) arr_ts
    -- The lambda first returns all the indexes, then the values; each
    -- destination has the shape given in the scatter specification.
    Scatter _ _ dests lam ->
      let num_indexes = sum [n * length shape | (shape, n, _) <- dests]
          val_ts = drop num_indexes (lambdaReturnType lam)
       in zipWith arrayOfShape val_ts [shape | (shape, _, _) <- dests]
    Screma w _ form ->
      scremaType w form
    Hist _ _ ops _ ->
      [ t `arrayOfShape` histShape op
        | op <- ops,
          t <- lambdaReturnType (histOp op)
      ]

-- | The "width" of a SOAC is the expected outer size of its array
-- inputs _after_ input-transforms have been carried out.
width :: SOAC rep -> SubExp
width soac =
  case soac of
    Stream w _ _ _ -> w
    Scatter w _ _ _ -> w
    Screma w _ _ -> w
    Hist w _ _ _ -> w

-- | Turn a SOAC into an expression, binding any transformed inputs to
-- variables first (via 'toSOAC').
toExp ::
  (MonadBuilder m, Op (Rep m) ~ Futhark.SOAC (Rep m)) =>
  SOAC (Rep m) ->
  m (Exp (Rep m))
toExp soac = fmap Op (toSOAC soac)

-- | Turn a SOAC into the corresponding Futhark-level SOAC, replacing each
-- input by the variable produced by 'inputsToSubExps'.
toSOAC ::
  (MonadBuilder m) =>
  SOAC (Rep m) ->
  m (Futhark.SOAC (Rep m))
toSOAC soac =
  case soac of
    Stream w inps nes lam -> do
      arrs <- inputsToSubExps inps
      pure $ Futhark.Stream w arrs nes lam
    Scatter w inps dests lam -> do
      arrs <- inputsToSubExps inps
      pure $ Futhark.Scatter w arrs dests lam
    Screma w inps form -> do
      arrs <- inputsToSubExps inps
      pure $ Futhark.Screma w arrs form
    Hist w inps ops lam -> do
      arrs <- inputsToSubExps inps
      pure $ Futhark.Hist w arrs ops lam

-- | The reason why some expression cannot be converted to a 'SOAC'
-- value.
data NotSOAC
  = -- | The expression is not a (tuple-)SOAC at all.
    NotSOAC
  deriving (Show)

-- | Either convert an expression to the normalised SOAC
-- representation, or a reason why the expression does not have the
-- valid form.
fromExp :: (Op rep ~ Futhark.SOAC rep, HasScope rep m) => Exp rep -> m (Either NotSOAC (SOAC rep)) fromExp (Op (Futhark.Stream w as nes lam)) = Right <$> (Stream w <$> traverse varInput as <*> pure nes <*> pure lam) fromExp (Op (Futhark.Scatter w arrs spec lam)) = Right <$> (Scatter w <$> traverse varInput arrs <*> pure spec <*> pure lam) fromExp (Op (Futhark.Screma w arrs form)) = Right <$> (Screma w <$> traverse varInput arrs <*> pure form) fromExp (Op (Futhark.Hist w arrs ops lam)) = Right <$> (Hist w <$> traverse varInput arrs <*> pure ops <*> pure lam) fromExp _ = pure $ Left NotSOAC -- | To-Stream translation of SOACs. -- Returns the Stream SOAC and the -- extra-accumulator body-result ident if any. soacToStream :: ( HasScope rep m, MonadFreshNames m, Buildable rep, BuilderOps rep, Op rep ~ Futhark.SOAC rep ) => SOAC rep -> m (SOAC rep, [Ident]) soacToStream soac = do chunk_param <- newParam "chunk" $ Prim int64 let chvar = Var $ paramName chunk_param (lam, inps) = (lambda soac, inputs soac) w = width soac lam' <- renameLambda lam let arrrtps = mapType w lam -- the chunked-outersize of the array result and input types loutps = [arrayOfRow t chvar | t <- map rowType arrrtps] lintps = [arrayOfRow t chvar | t <- map inputRowType inps] strm_inpids <- mapM (newParam "inp") lintps -- Treat each SOAC case individually: case soac of Screma _ _ form | Just _ <- Futhark.isMapSOAC form -> do -- Map(f,a) => is translated in strem's body to: -- let strm_resids = map(f,a_ch) in strm_resids -- -- array result and input IDs of the stream's lambda strm_resids <- mapM (newIdent "res") loutps let insoac = Futhark.Screma chvar (map paramName strm_inpids) $ Futhark.mapSOAC lam' insstm = mkLet strm_resids $ Op insoac strmbdy = mkBody (oneStm insstm) $ map (subExpRes . Var . 
identName) strm_resids strmpar = chunk_param : strm_inpids strmlam = Lambda strmpar loutps strmbdy -- map(f,a) creates a stream with NO accumulators pure (Stream w inps [] strmlam, []) | Just (scans, _) <- Futhark.isScanomapSOAC form, Futhark.Scan scan_lam nes <- Futhark.singleScan scans -> do -- scanomap(scan_lam,nes,map_lam,a) => is translated in strem's body to: -- 1. let (scan0_ids,map_resids) = scanomap(scan_lam, nes, map_lam, a_ch) -- 2. let strm_resids = map (acc `+`,nes, scan0_ids) -- 3. let outerszm1id = sizeof(0,strm_resids) - 1 -- 4. let lasteel_ids = if outerszm1id < 0 -- then nes -- else strm_resids[outerszm1id] -- 5. let acc' = acc + lasteel_ids -- {acc', strm_resids, map_resids} -- the array and accumulator result types let scan_arr_ts = map (`arrayOfRow` chvar) $ lambdaReturnType scan_lam accrtps = lambdaReturnType scan_lam inpacc_ids <- mapM (newParam "inpacc") accrtps maplam <- mkMapPlusAccLam (map (Var . paramName) inpacc_ids) scan_lam -- Finally, construct the stream let strmpar = chunk_param : inpacc_ids ++ strm_inpids strmlam <- fmap fst . runBuilder . mkLambda strmpar $ do -- 1. let (scan0_ids,map_resids) = scanomap(scan_lam,nes,map_lam,a_ch) (scan0_ids, map_resids) <- fmap (splitAt (length scan_arr_ts)) . letTupExp "scan" . Op $ Futhark.Screma chvar (map paramName strm_inpids) $ Futhark.scanomapSOAC [Futhark.Scan scan_lam nes] lam' -- 2. let outerszm1id = chunksize - 1 outszm1id <- letSubExp "outszm1" . BasicOp $ BinOp (Sub Int64 OverflowUndef) (Var $ paramName chunk_param) (constant (1 :: Int64)) empty_arr <- letExp "empty_arr" . BasicOp $ CmpOp (CmpSlt Int64) outszm1id (constant (0 :: Int64)) -- 3. let lasteel_ids = ... 
let indexLast arr = eIndex arr [eSubExp outszm1id] lastel_ids <- letTupExp "lastel" =<< eIf (eSubExp $ Var empty_arr) (resultBodyM nes) (eBody $ map indexLast scan0_ids) addlelbdy <- mkPlusBnds scan_lam $ map Var $ map paramName inpacc_ids ++ lastel_ids let (addlelstm, addlelres) = (bodyStms addlelbdy, bodyResult addlelbdy) -- 4. let strm_resids = map (acc `+`,nes, scan0_ids) strm_resids <- letTupExp "strm_res" . Op $ Futhark.Screma chvar scan0_ids (Futhark.mapSOAC maplam) -- 5. let acc' = acc + lasteel_ids addStms addlelstm pure $ addlelres ++ map (subExpRes . Var) (strm_resids ++ map_resids) pure ( Stream w inps nes strmlam, map paramIdent inpacc_ids ) | Just (reds, _) <- Futhark.isRedomapSOAC form, Futhark.Reduce comm lamin nes <- Futhark.singleReduce reds -> do -- Redomap(+,lam,nes,a) => is translated in strem's body to: -- 1. let (acc0_ids,strm_resids) = redomap(+,lam,nes,a_ch) in -- 2. let acc' = acc + acc0_ids in -- {acc', strm_resids} let accrtps = take (length nes) $ lambdaReturnType lam -- the chunked-outersize of the array result and input types loutps' = drop (length nes) loutps -- the lambda with proper index foldlam = lam' -- array result and input IDs of the stream's lambda strm_resids <- mapM (newIdent "res") loutps' inpacc_ids <- mapM (newParam "inpacc") accrtps acc0_ids <- mapM (newIdent "acc0") accrtps -- 1. let (acc0_ids,strm_resids) = redomap(+,lam,nes,a_ch) in let insoac = Futhark.Screma chvar (map paramName strm_inpids) $ Futhark.redomapSOAC [Futhark.Reduce comm lamin nes] foldlam insstm = mkLet (acc0_ids ++ strm_resids) $ Op insoac -- 2. let acc' = acc + acc0_ids in addaccbdy <- mkPlusBnds lamin $ map Var $ map paramName inpacc_ids ++ map identName acc0_ids -- Construct the stream let (addaccstm, addaccres) = (bodyStms addaccbdy, bodyResult addaccbdy) strmbdy = mkBody (oneStm insstm <> addaccstm) $ addaccres ++ map (subExpRes . Var . 
identName) strm_resids strmpar = chunk_param : inpacc_ids ++ strm_inpids strmlam = Lambda strmpar (accrtps ++ loutps') strmbdy pure (Stream w inps nes strmlam, []) -- Otherwise it cannot become a stream. _ -> pure (soac, []) where mkMapPlusAccLam :: (MonadFreshNames m, Buildable rep) => [SubExp] -> Lambda rep -> m (Lambda rep) mkMapPlusAccLam accs plus = do let (accpars, rempars) = splitAt (length accs) $ lambdaParams plus parstms = zipWith (\par se -> mkLet [paramIdent par] (BasicOp $ SubExp se)) accpars accs plus_bdy = lambdaBody plus newlambdy = Body (bodyDec plus_bdy) (stmsFromList parstms <> bodyStms plus_bdy) (bodyResult plus_bdy) renameLambda $ Lambda rempars (lambdaReturnType plus) newlambdy mkPlusBnds :: (MonadFreshNames m, Buildable rep) => Lambda rep -> [SubExp] -> m (Body rep) mkPlusBnds plus accels = do plus' <- renameLambda plus let parstms = zipWith (\par se -> mkLet [paramIdent par] (BasicOp $ SubExp se)) (lambdaParams plus') accels body = lambdaBody plus' pure $ body {bodyStms = stmsFromList parstms <> bodyStms body} ppArrayTransform :: PP.Doc a -> ArrayTransform -> PP.Doc a ppArrayTransform e (Rearrange cs perm) = "rearrange" <> pretty cs <> PP.apply [PP.apply (map pretty perm), e] ppArrayTransform e (Reshape cs ReshapeArbitrary shape) = "reshape" <> pretty cs <> PP.apply [pretty shape, e] ppArrayTransform e (ReshapeOuter cs ReshapeArbitrary shape) = "reshape_outer" <> pretty cs <> PP.apply [pretty shape, e] ppArrayTransform e (ReshapeInner cs ReshapeArbitrary shape) = "reshape_inner" <> pretty cs <> PP.apply [pretty shape, e] ppArrayTransform e (Reshape cs ReshapeCoerce shape) = "coerce" <> pretty cs <> PP.apply [pretty shape, e] ppArrayTransform e (ReshapeOuter cs ReshapeCoerce shape) = "coerce_outer" <> pretty cs <> PP.apply [pretty shape, e] ppArrayTransform e (ReshapeInner cs ReshapeCoerce shape) = "coerce_inner" <> pretty cs <> PP.apply [pretty shape, e] ppArrayTransform e (Replicate cs ne) = "replicate" <> pretty cs <> PP.apply [pretty ne, 
e] ppArrayTransform e (Index cs slice) = e <> pretty cs <> pretty slice instance PP.Pretty Input where pretty (Input (ArrayTransforms ts) arr _) = foldl ppArrayTransform (pretty arr) ts instance PP.Pretty ArrayTransform where pretty = ppArrayTransform "INPUT" instance (PrettyRep rep) => PP.Pretty (SOAC rep) where pretty (Screma w arrs form) = Futhark.ppScrema w arrs form pretty (Hist len imgs ops bucket_fun) = Futhark.ppHist len imgs ops bucket_fun pretty (Stream w arrs nes lam) = Futhark.ppStream w arrs nes lam pretty (Scatter w arrs dests lam) = Futhark.ppScatter w arrs dests lam futhark-0.25.27/src/Futhark/Analysis/Interference.hs000066400000000000000000000276171475065116200223020ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | Interference analysis for Futhark programs. module Futhark.Analysis.Interference (Graph, analyseProgGPU) where import Control.Monad import Control.Monad.Reader import Data.Foldable (toList) import Data.Function ((&)) import Data.Functor ((<&>)) import Data.Map (Map) import Data.Map qualified as M import Data.Maybe (catMaybes, fromMaybe, mapMaybe) import Data.Set (Set) import Data.Set qualified as S import Futhark.Analysis.Alias qualified as AnlAls import Futhark.Analysis.LastUse (LUTabFun) import Futhark.Analysis.LastUse qualified as LastUse import Futhark.Analysis.MemAlias qualified as MemAlias import Futhark.IR.GPUMem import Futhark.Util (cartesian, invertMap) -- | The set of 'VName' currently in use. type InUse = Names -- | The set of 'VName' that are no longer in use. type LastUsed = Names -- | An interference graph. An element @(x, y)@ in the set means that there is -- an undirected edge between @x@ and @y@, and therefore the lifetimes of @x@ -- and @y@ overlap and they "interfere" with each other. We assume that pairs -- are always normalized, such that @x@ < @y@, before inserting. This should -- prevent any duplicates. We also don't allow any pairs where @x == y@. 
type Graph a = Set (a, a)

-- | The graph containing just the normalised edge between two values, or
-- the empty graph if the two values are equal.
makeEdge :: (Ord a) => a -> a -> Graph a
makeEdge v1 v2 =
  case compare v1 v2 of
    LT -> S.singleton (v1, v2)
    GT -> S.singleton (v2, v1)
    EQ -> S.empty

-- Analyse one statement, given the memory blocks in use before it.  The
-- result is the set of memory blocks in use after the statement, the set
-- of blocks whose last use is in (or nested inside) the statement, and the
-- interference edges the statement gives rise to.
analyseStm ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  InUse ->
  Stm GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseStm lumap inuse0 stm = inScopeOf stm $ do
  let pes = patElems $ stmPat stm
      -- The last-use table is keyed by the first name bound by a statement.
      lu_key = patElemName $ head pes
      -- The memory blocks backing those of the given variables that are arrays.
      memsOf vs = namesFromList . catMaybes <$> mapM memInfo vs

  -- Memory blocks of the arrays bound by this statement.
  new_mems <- memsOf $ map patElemName pes

  -- The newly bound memory must interfere with every memory block used
  -- inside the statement's expression.
  let inuse_outside = inuse0 <> new_mems

  -- Analyse the bodies nested in the expression: @inner_inuse@ is what is
  -- still in use at their ends, @inner_lus@ what reached its last use in
  -- them, and @inner_graph@ the edges found there.
  (inner_inuse, inner_lus, inner_graph) <-
    analyseExp lumap inuse_outside $ stmExp stm

  -- Memory of the variables last used by this statement, restricted to
  -- the blocks that are in use around it.
  last_use_mems <-
    M.findWithDefault mempty lu_key lumap
      & namesToList
      & memsOf
      <&> namesIntersection inuse_outside

  let inuse_after =
        (inuse_outside `namesSubtract` last_use_mems `namesSubtract` inner_lus)
          <> new_mems
      lus_after = (inner_lus <> last_use_mems) `namesSubtract` new_mems
      live_here = inuse_outside <> inner_inuse <> inner_lus <> last_use_mems
      edges = cartesian makeEdge (namesToList inuse_outside) (namesToList live_here)

  pure (inuse_after, lus_after, inner_graph <> edges)

-- We conservatively treat all memory arguments to a Loop to
-- interfere with each other, as well as anything used inside the
-- loop. This could potentially be improved by looking at the
-- interference computed by the loop body wrt. the loop arguments, but
-- probably very few programs would benefit from this.
-- Add interference edges for the memory blocks passed as loop arguments.
-- Every variable that initialises a memory-typed loop parameter gets an
-- edge to every other such variable and to every name in the @InUse@ and
-- @LastUsed@ sets computed for the loop body.  The @InUse@ and @LastUsed@
-- components are returned unchanged.
analyseLoopParams ::
  [(FParam GPUMem, SubExp)] ->
  (InUse, LastUsed, Graph VName) ->
  (InUse, LastUsed, Graph VName)
analyseLoopParams merge (inuse, lastused, graph) =
  (inuse, lastused, cartesian makeEdge mems (mems <> inner_mems) <> graph)
  where
    -- Initial values of memory-typed loop parameters (only when the
    -- initial value is a variable).
    mems = mapMaybe isMemArg merge
    inner_mems = namesToList lastused <> namesToList inuse
    isMemArg (Param _ _ MemMem {}, Var v) = Just v
    isMemArg _ = Nothing

-- | Analyse the code bodies nested inside an expression, given the memory
-- blocks in use around it.  Expressions other than 'Match', 'Loop' and a
-- 'SegOp' have no nested bodies here and contribute nothing.
analyseExp ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  InUse ->
  Exp GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseExp lumap inuse_outside expr =
  case expr of
    -- Each branch, including the default one, is analysed from the same
    -- starting set, and the per-branch results are combined with 'mconcat'.
    Match _ cases defbody _ ->
      fmap mconcat $
        mapM (analyseBody lumap inuse_outside) $
          defbody : map caseBody cases
    Loop merge _ body ->
      analyseLoopParams merge <$> analyseBody lumap inuse_outside body
    Op (Inner (SegOp segop)) -> do
      analyseSegOp lumap inuse_outside segop
    _ ->
      pure mempty

-- | Analyse the statements of a kernel body.
analyseKernelBody ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  InUse ->
  KernelBody GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseKernelBody lumap inuse body = analyseStms lumap inuse $ kernelBodyStms body

-- | Analyse the statements of a body.
analyseBody ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  InUse ->
  Body GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseBody lumap inuse body = analyseStms lumap inuse $ bodyStms body

-- | Analyse a sequence of statements in order.  The in-use set is threaded
-- from each statement to the next, while the last-used names and graph
-- edges of all statements are accumulated.
analyseStms ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  InUse ->
  Stms GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseStms lumap inuse0 stms = do
  inScopeOf stms $ foldM helper (inuse0, mempty, mempty) $ stmsToList stms
  where
    helper (inuse, lus, graph) stm = do
      (inuse', lus', graph') <- analyseStm lumap inuse stm
      pure (inuse', lus' <> lus, graph' <> graph)

-- | Analyse a segmented operation.  For 'SegRed', 'SegScan' and 'SegHist'
-- the operator lambdas are analysed after the kernel body, starting from
-- the in-use set at the end of the body.
analyseSegOp ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  InUse ->
  SegOp lvl GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseSegOp lumap inuse (SegMap _ _ _ body) =
  analyseKernelBody lumap inuse body
analyseSegOp lumap inuse (SegRed _ _ binops _ body) =
  segWithBinOps lumap inuse binops body
analyseSegOp lumap inuse (SegScan _ _ binops _ body) = do
  segWithBinOps lumap inuse binops body
analyseSegOp lumap inuse (SegHist _ _ histops _ body) = do
  (inuse', lus', graph) <- analyseKernelBody lumap inuse body
  (inuse'', lus'', graph') <- mconcat <$> mapM (analyseHistOp lumap inuse') histops
  pure (inuse'', lus' <> lus'', graph <> graph')

-- | Analyse a kernel body followed by its binary operators.  Each operator
-- lambda starts from the in-use set at the end of the body; the returned
-- in-use set is the 'mconcat' of the operators' results.
-- NOTE(review): with an empty operator list the body's in-use set is not
-- returned (mconcat of nothing is empty) — confirm that callers always
-- supply at least one operator.
segWithBinOps ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  InUse ->
  [SegBinOp GPUMem] ->
  KernelBody GPUMem ->
  m (InUse, LastUsed, Graph VName)
segWithBinOps lumap inuse binops body = do
  (inuse', lus', graph) <- analyseKernelBody lumap inuse body
  (inuse'', lus'', graph') <- mconcat <$> mapM (analyseSegBinOp lumap inuse') binops
  pure (inuse'', lus' <> lus'', graph <> graph')

-- | Analyse the lambda of a segmented binary operator.
analyseSegBinOp ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  InUse ->
  SegBinOp GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseSegBinOp lumap inuse (SegBinOp _ lambda _ _) =
  analyseLambda lumap inuse lambda

-- | Analyse the operator lambda of a histogram operation.
analyseHistOp ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  InUse ->
  HistOp GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseHistOp lumap inuse = analyseLambda lumap inuse . histOp

-- | Analyse the body of a lambda.
analyseLambda ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  InUse ->
  Lambda GPUMem ->
  m (InUse, LastUsed, Graph VName)
analyseLambda lumap inuse = analyseBody lumap inuse . lambdaBody

-- | Compute the memory interference graph of a whole program.  The
-- constants and each function are analysed separately with their own
-- last-use tables, and every resulting edge is then extended to the memory
-- aliases of its endpoints with 'applyAliases'.
analyseProgGPU :: Prog GPUMem -> Graph VName
analyseProgGPU prog =
  onConsts (progConsts prog) <> foldMap onFun (progFuns prog)
  where
    (consts_aliases, funs_aliases) = MemAlias.analyzeGPUMem prog
    (lumap_consts, lumap) = LastUse.lastUseGPUMem $ AnlAls.aliasAnalysis prog
    -- A function body is analysed in the scope given by @scopeOf f@.
    onFun f =
      applyAliases (fromMaybe mempty $ M.lookup (funDefName f) funs_aliases) $
        runReader (analyseGPU (lumap M.! funDefName f) $ bodyStms $ funDefBody f) $
          scopeOf f
    -- The constants are analysed starting from an empty scope.
    onConsts stms =
      applyAliases consts_aliases $
        runReader (analyseGPU lumap_consts stms) (mempty :: Scope GPUMem)

-- | Extend every edge @(x, y)@ of a graph to edges between all memory
-- aliases of @x@ (including @x@ itself) and all memory aliases of @y@
-- (including @y@ itself).
applyAliases :: MemAlias.MemAliases -> Graph VName -> Graph VName
applyAliases aliases =
  -- For each pair @(x, y)@ in graph, all memory aliases of x should interfere with all memory aliases of y
  foldMap
    ( \(x, y) ->
        let xs = MemAlias.aliasesOf aliases x <> oneName x
            ys = MemAlias.aliasesOf aliases y <> oneName y
         in cartesian makeEdge (namesToList xs) (namesToList ys)
    )

-- | Perform interference analysis on the given statements and return the
-- resulting interference graph.  In addition to the edges found by
-- 'analyseGPU'', edges are added between all 'DefaultSpace' memory blocks
-- whose array element sizes differ.
analyseGPU ::
  (LocalScope GPUMem m) =>
  LUTabFun ->
  Stms GPUMem ->
  m (Graph VName)
analyseGPU lumap stms = do
  (_, _, graph) <- analyseGPU' lumap stms
  -- We need to insert edges between memory blocks which differ in size, if they
  -- are in DefaultSpace. The problem is that during memory expansion,
  -- DefaultSpace arrays in kernels are interleaved. If the element sizes of two
  -- merged memory blocks are different, threads might try to read and write to
  -- overlapping memory positions. More information here:
  -- https://munksgaard.me/technical-diary/2020-12-30.html#org210775b
  spaces <- M.filter (== DefaultSpace) <$> memSpaces stms
  -- Element sizes of the DefaultSpace blocks, inverted so that blocks are
  -- grouped by element size.
  inv_size_map <-
    memSizes stms
      <&> flip M.restrictKeys (S.fromList $ M.keys spaces)
      <&> invertMap
  -- Connect every block of one size group with every block of each other
  -- group.
  let new_edges =
        cartesian
          (\x y -> if x /= y then cartesian makeEdge x y else mempty)
          inv_size_map
          inv_size_map
  pure $ graph <> new_edges

-- | Return a mapping from memory blocks to their element sizes in the given
-- statements.
memSizes :: (LocalScope GPUMem m) => Stms GPUMem -> m (Map VName Int) memSizes stms = inScopeOf stms $ fmap mconcat <$> mapM memSizesStm $ stmsToList stms where memSizesStm :: (LocalScope GPUMem m) => Stm GPUMem -> m (Map VName Int) memSizesStm (Let pat _ e) = do arraySizes <- fmap mconcat <$> mapM memElemSize $ patNames pat arraySizes' <- memSizesExp e pure $ arraySizes <> arraySizes' memSizesExp :: (LocalScope GPUMem m) => Exp GPUMem -> m (Map VName Int) memSizesExp (Op (Inner (SegOp segop))) = let body = segBody segop in inScopeOf (kernelBodyStms body) $ fmap mconcat <$> mapM memSizesStm $ stmsToList $ kernelBodyStms body memSizesExp (Match _ cases defbody _) = do mconcat <$> mapM (memSizes . bodyStms) (defbody : map caseBody cases) memSizesExp (Loop _ _ body) = memSizes $ bodyStms body memSizesExp _ = pure mempty -- | Return a mapping from memory blocks to the space they are allocated in. memSpaces :: (LocalScope GPUMem m) => Stms GPUMem -> m (Map VName Space) memSpaces stms = pure $ foldMap getSpacesStm stms where getSpacesStm :: Stm GPUMem -> Map VName Space getSpacesStm (Let (Pat [PatElem name _]) _ (Op (Alloc _ sp))) = M.singleton name sp getSpacesStm (Let _ _ (Op (Alloc _ _))) = error "impossible" getSpacesStm (Let _ _ (Op (Inner (SegOp segop)))) = foldMap getSpacesStm $ kernelBodyStms $ segBody segop getSpacesStm (Let _ _ (Match _ cases defbody _)) = foldMap (foldMap getSpacesStm . bodyStms) $ defbody : map caseBody cases getSpacesStm (Let _ _ (Loop _ _ body)) = foldMap getSpacesStm (bodyStms body) getSpacesStm _ = mempty analyseGPU' :: (LocalScope GPUMem m) => LUTabFun -> Stms GPUMem -> m (InUse, LastUsed, Graph VName) analyseGPU' lumap stms = mconcat . 
toList <$> mapM helper stms where helper :: (LocalScope GPUMem m) => Stm GPUMem -> m (InUse, LastUsed, Graph VName) helper stm@Let {stmExp = Op (Inner (SegOp segop))} = inScopeOf stm $ analyseSegOp lumap mempty segop helper stm@Let {stmExp = Match _ cases defbody _} = inScopeOf stm $ mconcat <$> mapM (analyseGPU' lumap . bodyStms) (defbody : map caseBody cases) helper stm@Let {stmExp = Loop merge _ body} = fmap (analyseLoopParams merge) . inScopeOf stm $ analyseGPU' lumap $ bodyStms body helper stm = inScopeOf stm $ pure mempty memInfo :: (LocalScope GPUMem m) => VName -> m (Maybe VName) memInfo vname = do summary <- asksScope (fmap nameInfoToMemInfo . M.lookup vname) case summary of Just (MemArray _ _ _ (ArrayIn mem _)) -> pure $ Just mem _ -> pure Nothing -- | Returns a mapping from memory block to element size. The input is the -- `VName` of a variable (supposedly an array), and the result is a mapping from -- the memory block of that array to element size of the array. memElemSize :: (LocalScope GPUMem m) => VName -> m (Map VName Int) memElemSize vname = do summary <- asksScope (fmap nameInfoToMemInfo . M.lookup vname) case summary of Just (MemArray pt _ _ (ArrayIn mem _)) -> pure $ M.singleton mem (primByteSize pt) _ -> pure mempty futhark-0.25.27/src/Futhark/Analysis/LastUse.hs000066400000000000000000000434571475065116200212510ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | Last use analysis for array short circuiting -- -- Last-Use analysis of a Futhark program in aliased explicit-memory lore form. -- Takes as input such a program or a function and produces a 'M.Map VName -- Names', in which the key identified the let stmt, and the list argument -- identifies the variables that were lastly used in that stmt. Note that the -- results of a body do not have a last use, and neither do a function -- parameters if it happens to not be used inside function's body. Such cases -- are supposed to be treated separately. 
module Futhark.Analysis.LastUse ( lastUseSeqMem, lastUseGPUMem, lastUseMCMem, LUTabFun, LUTabProg, ) where import Control.Monad import Control.Monad.Reader import Control.Monad.State.Strict import Data.Bifunctor (bimap) import Data.Function ((&)) import Data.Map.Strict qualified as M import Data.Maybe import Data.Sequence (Seq (..)) import Futhark.IR.Aliases import Futhark.IR.GPUMem import Futhark.IR.GPUMem qualified as GPU import Futhark.IR.MCMem import Futhark.IR.MCMem qualified as MC import Futhark.IR.SeqMem import Futhark.Optimise.ArrayShortCircuiting.DataStructs import Futhark.Util -- | Maps a name indentifying a Stm to the last uses in that Stm. type LUTabFun = M.Map VName Names -- | LU-table for the constants, and for each function. type LUTabProg = (LUTabFun, M.Map Name LUTabFun) type LastUseOp rep = Op (Aliases rep) -> Names -> LastUseM rep (LUTabFun, Names, Names) -- | 'LastUseReader' allows us to abstract over representations by supplying the -- 'onOp' function. data LastUseReader rep = LastUseReader { onOp :: LastUseOp rep, scope :: Scope (Aliases rep) } -- | Maps a variable or memory block to its aliases. type AliasTab = M.Map VName Names newtype LastUseM rep a = LastUseM (StateT AliasTab (Reader (LastUseReader rep)) a) deriving ( Monad, Functor, Applicative, MonadReader (LastUseReader rep), MonadState AliasTab ) instance (RepTypes (Aliases rep)) => HasScope (Aliases rep) (LastUseM rep) where askScope = asks scope instance (RepTypes (Aliases rep)) => LocalScope (Aliases rep) (LastUseM rep) where localScope sc (LastUseM m) = LastUseM $ do local (\rd -> rd {scope = scope rd <> sc}) m type Constraints rep = ( LocalScope (Aliases rep) (LastUseM rep), HasMemBlock (Aliases rep), AliasableRep rep ) runLastUseM :: LastUseOp rep -> LastUseM rep a -> a runLastUseM onOp (LastUseM m) = runReader (evalStateT m mempty) (LastUseReader onOp mempty) aliasLookup :: VName -> LastUseM rep Names aliasLookup vname = gets $ fromMaybe mempty . 
M.lookup vname lastUseProg :: (Constraints rep) => Prog (Aliases rep) -> LastUseM rep LUTabProg lastUseProg prog = let bound_in_consts = progConsts prog & concatMap (patNames . stmPat) & namesFromList consts = progConsts prog funs = progFuns prog in inScopeOf consts $ do (consts_lu, _) <- lastUseStms consts mempty mempty lus <- mapM (lastUseFun bound_in_consts) funs pure (consts_lu, M.fromList $ zip (map funDefName funs) lus) lastUseFun :: (Constraints rep) => Names -> FunDef (Aliases rep) -> LastUseM rep LUTabFun lastUseFun bound_in_consts f = inScopeOf f $ fst <$> lastUseBody (funDefBody f) (mempty, bound_in_consts) -- | Perform last-use analysis on a 'Prog' in 'SeqMem' lastUseSeqMem :: Prog (Aliases SeqMem) -> LUTabProg lastUseSeqMem = runLastUseM lastUseSeqOp . lastUseProg -- | Perform last-use analysis on a 'Prog' in 'GPUMem' lastUseGPUMem :: Prog (Aliases GPUMem) -> LUTabProg lastUseGPUMem = runLastUseM (lastUseMemOp lastUseGPUOp) . lastUseProg -- | Perform last-use analysis on a 'Prog' in 'MCMem' lastUseMCMem :: Prog (Aliases MCMem) -> LUTabProg lastUseMCMem = runLastUseM (lastUseMemOp lastUseMCOp) . lastUseProg -- | Performing the last-use analysis on a body. -- -- The implementation consists of a bottom-up traversal of the body's statements -- in which the the variables lastly used in a statement are computed as the -- difference between the free-variables in that stmt and the set of variables -- known to be used after that statement. lastUseBody :: (Constraints rep) => -- | The body of statements Body (Aliases rep) -> -- | The current last-use table, tupled with the known set of already used names (LUTabFun, Names) -> -- | The result is: -- (i) an updated last-use table, -- (ii) an updated set of used names (including the binding). LastUseM rep (LUTabFun, Names) lastUseBody bdy@(Body _ stms result) (lutab, used_nms) = -- perform analysis bottom-up in bindings: results are known to be used, -- hence they are added to the used_nms set. 
inScopeOf stms $ do (lutab', _) <- lastUseStms stms (lutab, used_nms) $ namesToList $ freeIn $ map resSubExp result -- Clean up the used names by recomputing the aliasing transitive-closure -- of the free names in body based on the current alias table @alstab@. used_in_body <- aliasTransitiveClosure $ freeIn bdy pure (lutab', used_nms <> used_in_body) -- | Performing the last-use analysis on a body. -- -- The implementation consists of a bottom-up traversal of the body's statements -- in which the the variables lastly used in a statement are computed as the -- difference between the free-variables in that stmt and the set of variables -- known to be used after that statement. lastUseKernelBody :: (Constraints rep) => -- | The body of statements KernelBody (Aliases rep) -> -- | The current last-use table, tupled with the known set of already used names (LUTabFun, Names) -> -- | The result is: -- (i) an updated last-use table, -- (ii) an updated set of used names (including the binding). LastUseM rep (LUTabFun, Names) lastUseKernelBody bdy@(KernelBody _ stms result) (lutab, used_nms) = inScopeOf stms $ do -- perform analysis bottom-up in bindings: results are known to be used, -- hence they are added to the used_nms set. (lutab', _) <- lastUseStms stms (lutab, used_nms) $ namesToList $ freeIn result -- Clean up the used names by recomputing the aliasing transitive-closure -- of the free names in body based on the current alias table @alstab@. 
used_in_body <- aliasTransitiveClosure $ freeIn bdy pure (lutab', used_nms <> used_in_body) lastUseStms :: (Constraints rep) => Stms (Aliases rep) -> (LUTabFun, Names) -> [VName] -> LastUseM rep (LUTabFun, Names) lastUseStms Empty (lutab, nms) res_nms = do aliases <- concatMapM aliasLookup res_nms pure (lutab, nms <> aliases <> namesFromList res_nms) lastUseStms (stm@(Let pat _ e) :<| stms) (lutab, nms) res_nms = inScopeOf stm $ do let extra_alias = case e of BasicOp (Update _ old _ _) -> oneName old BasicOp (FlatUpdate old _ _) -> oneName old _ -> mempty -- We build up aliases top-down updateAliasing extra_alias pat -- But compute last use bottom-up (lutab', nms') <- lastUseStms stms (lutab, nms) res_nms (lutab'', nms'') <- lastUseStm stm (lutab', nms') pure (lutab'', nms'') lastUseStm :: (Constraints rep) => Stm (Aliases rep) -> (LUTabFun, Names) -> LastUseM rep (LUTabFun, Names) lastUseStm (Let pat _ e) (lutab, used_nms) = do -- analyse the expression and get the -- (i) a new last-use table (in case the @e@ contains bodies of stmts) -- (ii) the set of variables lastly used in the current binding. -- (iii) aliased transitive-closure of used names, and (lutab', last_uses, used_nms') <- lastUseExp e used_nms sc <- asks scope let lu_mems = namesToList last_uses & mapMaybe (`getScopeMemInfo` sc) & map memName & namesFromList & flip namesSubtract used_nms -- filter-out the binded names from the set of used variables, -- since they go out of scope, and update the last-use table. let patnms = patNames pat used_nms'' = used_nms' `namesSubtract` namesFromList patnms lutab'' = M.union lutab' $ M.insert (head patnms) (last_uses <> lu_mems) lutab pure (lutab'', used_nms'') -------------------------------- -- | Last-Use Analysis for an expression. lastUseExp :: (Constraints rep) => -- | The expression to analyse Exp (Aliases rep) -> -- | The set of used names "after" this expression Names -> -- | Result: -- 1. 
an extra LUTab recording the last use for expression's inner bodies, -- 2. the set of last-used vars in the expression at this level, -- 3. the updated used names, now including expression's free vars. LastUseM rep (LUTabFun, Names, Names) lastUseExp (Match _ cases body _) used_nms = do -- For an if-then-else, we duplicate the last use at each body level, meaning -- we record the last use of the outer statement, and also the last use in the -- statement in the inner bodies. We can safely ignore the if-condition as it is -- a boolean scalar. (lutab_cases, used_cases) <- bimap mconcat mconcat . unzip <$> mapM (flip lastUseBody (M.empty, used_nms) . caseBody) cases (lutab', body_used_nms) <- lastUseBody body (M.empty, used_nms) let free_in_body = freeIn body let free_in_cases = freeIn cases let used_nms' = used_cases <> body_used_nms (_, last_used_arrs) <- lastUsedInNames used_nms $ free_in_body <> free_in_cases pure (lutab_cases <> lutab', last_used_arrs, used_nms') lastUseExp (Loop var_ses form body) used_nms0 = localScope (scopeOfLoopForm form) $ do free_in_body <- aliasTransitiveClosure $ freeIn body -- compute the aliasing transitive closure of initializers that are not last-uses var_inis <- catMaybes <$> mapM (initHelper (free_in_body <> used_nms0)) var_ses let -- To record last-uses inside the loop body, we call 'lastUseBody' with used-names -- being: (free_in_body - loop-variants-a) + used_nms0. As such we disable cases b) -- and c) to produce loop-variant last uses inside the loop, and also we prevent -- the free-loop-variables to having last uses inside the loop. free_in_body' = free_in_body `namesSubtract` namesFromList (map fst var_inis) used_nms = used_nms0 <> free_in_body' <> freeIn (bodyResult body) (body_lutab, _) <- lastUseBody body (mempty, used_nms) -- add var_inis_a to the body_lutab, i.e., record the last-use of -- initializer in the corresponding loop variant. 
let lutab_res = body_lutab <> M.fromList var_inis -- the result used names are: fpar_nms = namesFromList $ map (identName . paramIdent . fst) var_ses used_nms' = (free_in_body <> freeIn (map snd var_ses)) `namesSubtract` fpar_nms used_nms_res = used_nms0 <> used_nms' <> freeIn (bodyResult body) -- the last-uses at loop-statement level are the loop free variables that -- do not belong to @used_nms0@; this includes the initializers of b), @lu_ini_b@ lu_arrs = used_nms' `namesSubtract` used_nms0 pure (lutab_res, lu_arrs, used_nms_res) where initHelper free_and_used (fp, se) = do names <- aliasTransitiveClosure $ maybe mempty oneName $ subExpVar se if names `namesIntersect` free_and_used then pure Nothing else pure $ Just (identName $ paramIdent fp, names) lastUseExp (Op op) used_nms = do on_op <- reader onOp on_op op used_nms lastUseExp e used_nms = do let free_in_e = freeIn e (used_nms', lu_vars) <- lastUsedInNames used_nms free_in_e pure (M.empty, lu_vars, used_nms') lastUseMemOp :: (inner (Aliases rep) -> Names -> LastUseM rep (LUTabFun, Names, Names)) -> MemOp inner (Aliases rep) -> Names -> LastUseM rep (LUTabFun, Names, Names) lastUseMemOp _ (Alloc se sp) used_nms = do let free_in_e = freeIn se <> freeIn sp (used_nms', lu_vars) <- lastUsedInNames used_nms free_in_e pure (M.empty, lu_vars, used_nms') lastUseMemOp onInner (Inner op) used_nms = onInner op used_nms lastUseSegOp :: (Constraints rep) => SegOp lvl (Aliases rep) -> Names -> LastUseM rep (LUTabFun, Names, Names) lastUseSegOp (SegMap _ _ tps kbody) used_nms = do (used_nms', lu_vars) <- lastUsedInNames used_nms $ freeIn tps (body_lutab, used_nms'') <- lastUseKernelBody kbody (mempty, used_nms') pure (body_lutab, lu_vars, used_nms' <> used_nms'') lastUseSegOp (SegRed _ _ sbos tps kbody) used_nms = do (lutab_sbo, lu_vars_sbo, used_nms_sbo) <- lastUseSegBinOp sbos used_nms (used_nms', lu_vars) <- lastUsedInNames used_nms_sbo $ freeIn tps (body_lutab, used_nms'') <- lastUseKernelBody kbody (mempty, used_nms') 
pure (M.union lutab_sbo body_lutab, lu_vars <> lu_vars_sbo, used_nms_sbo <> used_nms' <> used_nms'') lastUseSegOp (SegScan _ _ sbos tps kbody) used_nms = do (lutab_sbo, lu_vars_sbo, used_nms_sbo) <- lastUseSegBinOp sbos used_nms (used_nms', lu_vars) <- lastUsedInNames used_nms_sbo $ freeIn tps (body_lutab, used_nms'') <- lastUseKernelBody kbody (mempty, used_nms') pure (M.union lutab_sbo body_lutab, lu_vars <> lu_vars_sbo, used_nms_sbo <> used_nms' <> used_nms'') lastUseSegOp (SegHist _ _ hos tps kbody) used_nms = do (lutab_sbo, lu_vars_sbo, used_nms_sbo) <- lastUseHistOp hos used_nms (used_nms', lu_vars) <- lastUsedInNames used_nms_sbo $ freeIn tps (body_lutab, used_nms'') <- lastUseKernelBody kbody (mempty, used_nms') pure (M.union lutab_sbo body_lutab, lu_vars <> lu_vars_sbo, used_nms_sbo <> used_nms' <> used_nms'') lastUseGPUOp :: HostOp NoOp (Aliases GPUMem) -> Names -> LastUseM GPUMem (LUTabFun, Names, Names) lastUseGPUOp (GPU.OtherOp NoOp) used_nms = pure (mempty, mempty, used_nms) lastUseGPUOp (SizeOp sop) used_nms = do (used_nms', lu_vars) <- lastUsedInNames used_nms $ freeIn sop pure (mempty, lu_vars, used_nms') lastUseGPUOp (GPUBody tps body) used_nms = do (used_nms', lu_vars) <- lastUsedInNames used_nms $ freeIn tps (body_lutab, used_nms'') <- lastUseBody body (mempty, used_nms') pure (body_lutab, lu_vars, used_nms' <> used_nms'') lastUseGPUOp (SegOp op) used_nms = lastUseSegOp op used_nms lastUseMCOp :: MCOp NoOp (Aliases MCMem) -> Names -> LastUseM MCMem (LUTabFun, Names, Names) lastUseMCOp (MC.OtherOp NoOp) used_nms = pure (mempty, mempty, used_nms) lastUseMCOp (MC.ParOp par_op op) used_nms = do (lutab_par_op, lu_vars_par_op, used_names_par_op) <- maybe (pure mempty) (`lastUseSegOp` used_nms) par_op (lutab_op, lu_vars_op, used_names_op) <- lastUseSegOp op used_nms pure ( lutab_par_op <> lutab_op, lu_vars_par_op <> lu_vars_op, used_names_par_op <> used_names_op ) lastUseSegBinOp :: (Constraints rep) => [SegBinOp (Aliases rep)] -> Names -> LastUseM rep 
(LUTabFun, Names, Names) lastUseSegBinOp sbos used_nms = do (lutab, lu_vars, used_nms') <- unzip3 <$> mapM helper sbos pure (mconcat lutab, mconcat lu_vars, mconcat used_nms') where helper (SegBinOp _ l@(Lambda _ _ body) neutral shp) = inScopeOf l $ do (used_nms', lu_vars) <- lastUsedInNames used_nms $ freeIn neutral <> freeIn shp (body_lutab, used_nms'') <- lastUseBody body (mempty, used_nms') pure (body_lutab, lu_vars, used_nms'') lastUseHistOp :: (Constraints rep) => [HistOp (Aliases rep)] -> Names -> LastUseM rep (LUTabFun, Names, Names) lastUseHistOp hos used_nms = do (lutab, lu_vars, used_nms') <- unzip3 <$> mapM helper hos pure (mconcat lutab, mconcat lu_vars, mconcat used_nms') where helper (HistOp shp rf dest neutral shp' l@(Lambda _ _ body)) = inScopeOf l $ do (used_nms', lu_vars) <- lastUsedInNames used_nms $ freeIn shp <> freeIn rf <> freeIn dest <> freeIn neutral <> freeIn shp' (body_lutab, used_nms'') <- lastUseBody body (mempty, used_nms') pure (body_lutab, lu_vars, used_nms'') lastUseSeqOp :: Op (Aliases SeqMem) -> Names -> LastUseM SeqMem (LUTabFun, Names, Names) lastUseSeqOp (Alloc se sp) used_nms = do let free_in_e = freeIn se <> freeIn sp (used_nms', lu_vars) <- lastUsedInNames used_nms free_in_e pure (mempty, lu_vars, used_nms') lastUseSeqOp (Inner NoOp) used_nms = do pure (mempty, mempty, used_nms) ------------------------------------------------------ -- | Given already used names and newly encountered 'Names', return an updated -- set used names and the set of names that were last used here. -- -- For a given name @x@ in the new uses, if neither @x@ nor any of its aliases -- are present in the set of used names, this is a last use of @x@. lastUsedInNames :: -- | Used names Names -> -- | New uses Names -> LastUseM rep (Names, Names) lastUsedInNames used_nms new_uses = do -- a use of an argument x is also a use of any variable in x alias set -- so we update the alias-based transitive-closure of used names. 
new_uses_with_aliases <- aliasTransitiveClosure new_uses -- if neither a variable x, nor any of its alias set have been used before (in -- the backward traversal), then it is a last use of both that variable and -- all other variables in its alias set last_uses <- filterM isLastUse $ namesToList new_uses last_uses' <- aliasTransitiveClosure $ namesFromList last_uses pure (used_nms <> new_uses_with_aliases, last_uses') where isLastUse x = do with_aliases <- aliasTransitiveClosure $ oneName x pure $ not $ with_aliases `namesIntersect` used_nms -- | Compute the transitive closure of the aliases of a set of 'Names'. aliasTransitiveClosure :: Names -> LastUseM rep Names aliasTransitiveClosure args = do res <- foldl (<>) args <$> mapM aliasLookup (namesToList args) if res == args then pure res else aliasTransitiveClosure res -- | For each 'PatElem' in the 'Pat', add its aliases to the 'AliasTab' in -- 'LastUseM'. Additionally, 'Names' are added as aliases of all the 'PatElemT'. updateAliasing :: (AliasesOf dec) => -- | Extra names that all 'PatElem' should alias. Names -> -- | Pattern to process Pat dec -> LastUseM rep () updateAliasing extra_aliases = mapM_ update . 
patElems where update :: (AliasesOf dec) => PatElem dec -> LastUseM rep () update (PatElem name dec) = do let aliases = aliasesOf dec aliases' <- aliasTransitiveClosure $ extra_aliases <> aliases modify $ M.insert name aliases' futhark-0.25.27/src/Futhark/Analysis/MemAlias.hs000066400000000000000000000125631475065116200213530ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} module Futhark.Analysis.MemAlias ( analyzeSeqMem, analyzeGPUMem, aliasesOf, MemAliases, ) where import Control.Monad import Control.Monad.Reader import Data.Bifunctor import Data.Function ((&)) import Data.Functor ((<&>)) import Data.Map qualified as M import Data.Maybe (fromMaybe, mapMaybe) import Data.Set qualified as S import Futhark.IR.GPUMem import Futhark.IR.SeqMem import Futhark.Util import Futhark.Util.Pretty -- For our purposes, memory aliases are a bijective function: If @a@ aliases -- @b@, @b@ also aliases @a@. However, this relationship is not transitive. Consider for instance the following: -- -- @ -- let xs@mem_1 = -- if ... then -- replicate i 0 @ mem_2 -- else -- replicate j 1 @ mem_3 -- @ -- -- Here, @mem_1@ aliases both @mem_2@ and @mem_3@, each of which alias @mem_1@ -- but not each other. 
newtype MemAliases = MemAliases (M.Map VName Names) deriving (Show, Eq) instance Semigroup MemAliases where (MemAliases m1) <> (MemAliases m2) = MemAliases $ M.unionWith (<>) m1 m2 instance Monoid MemAliases where mempty = MemAliases mempty instance Pretty MemAliases where pretty (MemAliases m) = stack $ map f $ M.toList m where f (v, vs) = pretty v <+> "aliases:" indent 2 (oneLine $ pretty vs) addAlias :: VName -> VName -> MemAliases -> MemAliases addAlias v1 v2 m = m <> singleton v1 (oneName v2) <> singleton v2 mempty singleton :: VName -> Names -> MemAliases singleton v ns = MemAliases $ M.singleton v ns aliasesOf :: MemAliases -> VName -> Names aliasesOf (MemAliases m) v = fromMaybe mempty $ M.lookup v m isIn :: VName -> MemAliases -> Bool isIn v (MemAliases m) = v `S.member` M.keysSet m newtype Env inner = Env {onInner :: MemAliases -> inner -> MemAliasesM inner MemAliases} type MemAliasesM inner a = Reader (Env inner) a analyzeHostOp :: MemAliases -> HostOp NoOp GPUMem -> MemAliasesM (HostOp NoOp GPUMem) MemAliases analyzeHostOp m (SegOp (SegMap _ _ _ kbody)) = analyzeStms (kernelBodyStms kbody) m analyzeHostOp m (SegOp (SegRed _ _ _ _ kbody)) = analyzeStms (kernelBodyStms kbody) m analyzeHostOp m (SegOp (SegScan _ _ _ _ kbody)) = analyzeStms (kernelBodyStms kbody) m analyzeHostOp m (SegOp (SegHist _ _ _ _ kbody)) = analyzeStms (kernelBodyStms kbody) m analyzeHostOp m SizeOp {} = pure m analyzeHostOp m GPUBody {} = pure m analyzeHostOp m (OtherOp NoOp) = pure m analyzeStm :: (Mem rep inner, LetDec rep ~ LetDecMem) => MemAliases -> Stm rep -> MemAliasesM (inner rep) MemAliases analyzeStm m (Let (Pat [PatElem vname _]) _ (Op (Alloc _ _))) = pure $ m <> singleton vname mempty analyzeStm m (Let _ _ (Op (Inner inner))) = do on_inner <- asks onInner on_inner m inner analyzeStm m (Let pat _ (Match _ cases defbody _)) = do let bodies = defbody : map caseBody cases m' <- foldM (flip analyzeStms) m $ map bodyStms bodies foldMap (zip (patNames pat) . map resSubExp . 
bodyResult) bodies & mapMaybe (filterFun m') & foldr (uncurry addAlias) m' & pure analyzeStm m (Let pat _ (Loop params _ body)) = do let m_init = map snd params & zip (patNames pat) & mapMaybe (filterFun m) & foldr (uncurry addAlias) m m_params = mapMaybe (filterFun m_init . first paramName) params & foldr (uncurry addAlias) m_init m_body <- analyzeStms (bodyStms body) m_params zip (patNames pat) (map resSubExp $ bodyResult body) & mapMaybe (filterFun m_body) & foldr (uncurry addAlias) m_body & pure analyzeStm m _ = pure m filterFun :: MemAliases -> (VName, SubExp) -> Maybe (VName, VName) filterFun m' (v, Var v') | v' `isIn` m' = Just (v, v') filterFun _ _ = Nothing analyzeStms :: (Mem rep inner, LetDec rep ~ LetDecMem) => Stms rep -> MemAliases -> MemAliasesM (inner rep) MemAliases analyzeStms = flip $ foldM analyzeStm analyzeFun :: (Mem rep inner, LetDec rep ~ LetDecMem) => FunDef rep -> MemAliasesM (inner rep) (Name, MemAliases) analyzeFun f = funDefParams f & mapMaybe justMem & mconcat & analyzeStms (bodyStms $ funDefBody f) <&> (funDefName f,) where justMem (Param _ v (MemMem _)) = Just $ singleton v mempty justMem _ = Nothing transitiveClosure :: MemAliases -> MemAliases transitiveClosure ma@(MemAliases m) = M.foldMapWithKey ( \k ns -> namesToList ns & foldMap (aliasesOf ma) & singleton k ) m <> ma -- | Produce aliases for constants and for each function. analyzeSeqMem :: Prog SeqMem -> (MemAliases, M.Map Name MemAliases) analyzeSeqMem prog = completeBijection $ runReader (analyze prog) $ Env $ \x _ -> pure x -- | Produce aliases for constants and for each function. 
analyzeGPUMem :: Prog GPUMem -> (MemAliases, M.Map Name MemAliases) analyzeGPUMem prog = completeBijection $ runReader (analyze prog) $ Env analyzeHostOp analyze :: (Mem rep inner, LetDec rep ~ LetDecMem) => Prog rep -> MemAliasesM (inner rep) (MemAliases, M.Map Name MemAliases) analyze prog = (,) <$> (progConsts prog & flip analyzeStms mempty <&> fixPoint transitiveClosure) <*> (progFuns prog & mapM analyzeFun <&> M.fromList <&> M.map (fixPoint transitiveClosure)) completeBijection :: (MemAliases, M.Map Name MemAliases) -> (MemAliases, M.Map Name MemAliases) completeBijection (a, bs) = (f a, fmap f bs) where f ma@(MemAliases m) = M.foldMapWithKey (\k ns -> foldMap (`singleton` oneName k) (namesToList ns)) m <> ma futhark-0.25.27/src/Futhark/Analysis/Metrics.hs000066400000000000000000000111231475065116200212600ustar00rootroot00000000000000-- | Abstract Syntax Tree metrics. This is used in the @futhark test@ -- program, for the @structure@ stanzas. module Futhark.Analysis.Metrics ( AstMetrics (..), progMetrics, -- * Extensibility OpMetrics (..), seen, inside, MetricsM, stmMetrics, lambdaMetrics, bodyMetrics, ) where import Control.Monad import Control.Monad.Writer import Data.List (tails) import Data.Map.Strict qualified as M import Data.Text (Text) import Data.Text qualified as T import Futhark.Analysis.Metrics.Type import Futhark.IR import Futhark.Util (showText) -- | Compute the metrics for some operation. 
class OpMetrics op where opMetrics :: op -> MetricsM () instance (OpMetrics a) => OpMetrics (Maybe a) where opMetrics Nothing = pure () opMetrics (Just x) = opMetrics x instance OpMetrics (NoOp rep) where opMetrics NoOp = pure () newtype CountMetrics = CountMetrics [([Text], Text)] instance Semigroup CountMetrics where CountMetrics x <> CountMetrics y = CountMetrics $ x <> y instance Monoid CountMetrics where mempty = CountMetrics mempty actualMetrics :: CountMetrics -> AstMetrics actualMetrics (CountMetrics metrics) = AstMetrics $ M.fromListWith (+) $ concatMap expand metrics where expand (ctx, k) = [ (T.intercalate "/" (ctx' ++ [k]), 1) | ctx' <- tails $ "" : ctx ] -- | This monad is used for computing metrics. It internally keeps -- track of what we've seen so far. Use 'seen' to add more stuff. newtype MetricsM a = MetricsM {runMetricsM :: Writer CountMetrics a} deriving ( Monad, Applicative, Functor, MonadWriter CountMetrics ) -- | Add this node to the current tally. seen :: Text -> MetricsM () seen k = tell $ CountMetrics [([], k)] -- | Enclose a metrics counting operation. Most importantly, this -- prefixes the name of the context to all the metrics computed in the -- enclosed operation. inside :: Text -> MetricsM () -> MetricsM () inside what m = seen what >> censor addWhat m where addWhat (CountMetrics metrics) = CountMetrics (map addWhat' metrics) addWhat' (ctx, k) = (what : ctx, k) -- | Compute the metrics for a program. progMetrics :: (OpMetrics (Op rep)) => Prog rep -> AstMetrics progMetrics prog = actualMetrics $ execWriter $ runMetricsM $ do mapM_ funDefMetrics $ progFuns prog mapM_ stmMetrics $ progConsts prog funDefMetrics :: (OpMetrics (Op rep)) => FunDef rep -> MetricsM () funDefMetrics = bodyMetrics . funDefBody -- | Compute metrics for this body. bodyMetrics :: (OpMetrics (Op rep)) => Body rep -> MetricsM () bodyMetrics = mapM_ stmMetrics . bodyStms -- | Compute metrics for this statement. 
stmMetrics :: (OpMetrics (Op rep)) => Stm rep -> MetricsM () stmMetrics = expMetrics . stmExp expMetrics :: (OpMetrics (Op rep)) => Exp rep -> MetricsM () expMetrics (BasicOp op) = seen "BasicOp" >> basicOpMetrics op expMetrics (Loop _ ForLoop {} body) = inside "Loop" $ seen "ForLoop" >> bodyMetrics body expMetrics (Loop _ WhileLoop {} body) = inside "Loop" $ seen "WhileLoop" >> bodyMetrics body expMetrics (Match _ [Case [Just (BoolValue True)] tb] fb _) = inside "If" $ do inside "True" $ bodyMetrics tb inside "False" $ bodyMetrics fb expMetrics (Match _ cases defbody _) = inside "Match" $ do forM_ (zip [0 ..] cases) $ \(i, c) -> inside (showText (i :: Int)) $ bodyMetrics $ caseBody c inside "default" $ bodyMetrics defbody expMetrics Apply {} = seen "Apply" expMetrics (WithAcc _ lam) = inside "WithAcc" $ lambdaMetrics lam expMetrics (Op op) = opMetrics op basicOpMetrics :: BasicOp -> MetricsM () basicOpMetrics (SubExp _) = seen "SubExp" basicOpMetrics (Opaque _ _) = seen "Opaque" basicOpMetrics ArrayVal {} = seen "ArrayVal" basicOpMetrics ArrayLit {} = seen "ArrayLit" basicOpMetrics BinOp {} = seen "BinOp" basicOpMetrics UnOp {} = seen "UnOp" basicOpMetrics ConvOp {} = seen "ConvOp" basicOpMetrics CmpOp {} = seen "CmpOp" basicOpMetrics Assert {} = seen "Assert" basicOpMetrics Index {} = seen "Index" basicOpMetrics Update {} = seen "Update" basicOpMetrics FlatIndex {} = seen "FlatIndex" basicOpMetrics FlatUpdate {} = seen "FlatUpdate" basicOpMetrics Concat {} = seen "Concat" basicOpMetrics Manifest {} = seen "Manifest" basicOpMetrics Iota {} = seen "Iota" basicOpMetrics Replicate {} = seen "Replicate" basicOpMetrics Scratch {} = seen "Scratch" basicOpMetrics Reshape {} = seen "Reshape" basicOpMetrics Rearrange {} = seen "Rearrange" basicOpMetrics UpdateAcc {} = seen "UpdateAcc" -- | Compute metrics for this lambda. lambdaMetrics :: (OpMetrics (Op rep)) => Lambda rep -> MetricsM () lambdaMetrics = bodyMetrics . 
lambdaBody futhark-0.25.27/src/Futhark/Analysis/Metrics/000077500000000000000000000000001475065116200207265ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Analysis/Metrics/Type.hs000066400000000000000000000016031475065116200222030ustar00rootroot00000000000000-- | The data type definition for "Futhark.Analysis.Metrics", factored -- out to simplify the module import hierarchies when working on the -- test modules. module Futhark.Analysis.Metrics.Type (AstMetrics (..)) where import Data.Map.Strict qualified as M import Data.Text (Text) import Data.Text qualified as T -- | AST metrics are simply a collection from identifiable node names -- to the number of times that node appears. newtype AstMetrics = AstMetrics (M.Map Text Int) instance Show AstMetrics where show (AstMetrics m) = unlines $ map metric $ M.toList m where metric (k, v) = T.unpack k ++ " " ++ show v instance Read AstMetrics where readsPrec _ s = maybe [] success $ mapM onLine $ lines s where onLine l = case words l of [k, x] | [(n, "")] <- reads x -> Just (T.pack k, n) _ -> Nothing success m = [(AstMetrics $ M.fromList m, "")] futhark-0.25.27/src/Futhark/Analysis/PrimExp.hs000066400000000000000000000642751475065116200212560ustar00rootroot00000000000000{-# OPTIONS_GHC -fno-warn-redundant-constraints #-} -- | A primitive expression is an expression where the non-leaves are -- primitive operators. Our representation does not guarantee that -- the expression is type-correct. 
module Futhark.Analysis.PrimExp ( PrimExp (..), TPrimExp (..), isInt8, isInt16, isInt32, isInt64, isBool, isF16, isF32, isF64, evalPrimExp, primExpType, primExpSizeAtLeast, coerceIntPrimExp, leafExpTypes, true, false, fromBool, constFoldPrimExp, -- * Construction module Language.Futhark.Primitive, NumExp (..), IntExp (..), FloatExp (..), sExt, zExt, (.&&.), (.||.), (.<.), (.<=.), (.>.), (.>=.), (.==.), (.&.), (.|.), (.^.), (.>>.), (.<<.), bNot, sMax32, sMin32, sMax64, sMin64, sExt32, sExt64, zExt32, zExt64, sExtAs, fMin16, fMin32, fMin64, fMax16, fMax32, fMax64, condExp, -- * Untyped construction (~*~), (~/~), (~+~), (~-~), (~==~), ) where import Control.Category import Control.Monad import Data.Map qualified as M import Data.Set qualified as S import Data.Text qualified as T import Data.Traversable import Futhark.IR.Prop.Names import Futhark.Util.IntegralExp import Futhark.Util.Pretty import Language.Futhark.Primitive import Prelude hiding (id, (.)) -- | A primitive expression parametrised over the representation of -- free variables. Note that the 'Functor', 'Traversable', and 'Num' -- instances perform automatic (but simple) constant folding. -- -- Note also that the 'Num' instance assumes 'OverflowUndef' -- semantics! 
data PrimExp v = LeafExp v PrimType | ValueExp PrimValue | BinOpExp BinOp (PrimExp v) (PrimExp v) | CmpOpExp CmpOp (PrimExp v) (PrimExp v) | UnOpExp UnOp (PrimExp v) | ConvOpExp ConvOp (PrimExp v) | FunExp T.Text [PrimExp v] PrimType deriving (Eq, Ord, Show) instance Functor PrimExp where fmap = fmapDefault instance Foldable PrimExp where foldMap = foldMapDefault instance Traversable PrimExp where traverse f (LeafExp v t) = LeafExp <$> f v <*> pure t traverse _ (ValueExp v) = pure $ ValueExp v traverse f (BinOpExp op x y) = BinOpExp op <$> traverse f x <*> traverse f y traverse f (CmpOpExp op x y) = CmpOpExp op <$> traverse f x <*> traverse f y traverse f (ConvOpExp op x) = ConvOpExp op <$> traverse f x traverse f (UnOpExp op x) = UnOpExp op <$> traverse f x traverse f (FunExp h args t) = FunExp h <$> traverse (traverse f) args <*> pure t instance (FreeIn v) => FreeIn (PrimExp v) where freeIn' = foldMap freeIn' -- | A 'PrimExp' tagged with a phantom type used to provide type-safe -- construction. Does not guarantee that the underlying expression is -- actually type correct. newtype TPrimExp t v = TPrimExp {untyped :: PrimExp v} deriving (Eq, Ord, Show) instance Functor (TPrimExp t) where fmap = fmapDefault instance Foldable (TPrimExp t) where foldMap = foldMapDefault instance Traversable (TPrimExp t) where traverse f (TPrimExp e) = TPrimExp <$> traverse f e instance (FreeIn v) => FreeIn (TPrimExp t v) where freeIn' = freeIn' . untyped -- | This expression is of type t'Int8'. isInt8 :: PrimExp v -> TPrimExp Int8 v isInt8 = TPrimExp -- | This expression is of type t'Int16'. isInt16 :: PrimExp v -> TPrimExp Int16 v isInt16 = TPrimExp -- | This expression is of type t'Int32'. isInt32 :: PrimExp v -> TPrimExp Int32 v isInt32 = TPrimExp -- | This expression is of type t'Int64'. isInt64 :: PrimExp v -> TPrimExp Int64 v isInt64 = TPrimExp -- | This is a boolean expression. isBool :: PrimExp v -> TPrimExp Bool v isBool = TPrimExp -- | This expression is of type t'Half'. 
isF16 :: PrimExp v -> TPrimExp Half v isF16 = TPrimExp -- | This expression is of type t'Float'. isF32 :: PrimExp v -> TPrimExp Float v isF32 = TPrimExp -- | This expression is of type t'Double'. isF64 :: PrimExp v -> TPrimExp Double v isF64 = TPrimExp -- | True if the 'PrimExp' has at least this many nodes. This can be -- much more efficient than comparing with 'length' for large -- 'PrimExp's, as this function is lazy. primExpSizeAtLeast :: Int -> PrimExp v -> Bool primExpSizeAtLeast k = maybe True (>= k) . descend 0 where descend i _ | i >= k = Nothing descend i LeafExp {} = Just (i + 1) descend i ValueExp {} = Just (i + 1) descend i (BinOpExp _ x y) = do x' <- descend (i + 1) x descend x' y descend i (CmpOpExp _ x y) = do x' <- descend (i + 1) x descend x' y descend i (ConvOpExp _ x) = descend (i + 1) x descend i (UnOpExp _ x) = descend (i + 1) x descend i (FunExp _ args _) = foldM descend (i + 1) args -- | Perform quick and dirty constant folding on the top level of a -- PrimExp. This is necessary because we want to consider -- e.g. equality modulo constant folding. 
constFoldPrimExp :: PrimExp v -> PrimExp v constFoldPrimExp (BinOpExp Add {} x y) | zeroIshExp x = y | zeroIshExp y = x constFoldPrimExp (BinOpExp Sub {} x y) | zeroIshExp y = x constFoldPrimExp (BinOpExp Mul {} x y) | oneIshExp x = y | oneIshExp y = x | zeroIshExp x, IntType it <- primExpType y = ValueExp $ IntValue $ intValue it (0 :: Int) | zeroIshExp y, IntType it <- primExpType x = ValueExp $ IntValue $ intValue it (0 :: Int) constFoldPrimExp (BinOpExp SDiv {} x y) | oneIshExp y = x constFoldPrimExp (BinOpExp SQuot {} x y) | oneIshExp y = x constFoldPrimExp (BinOpExp UDiv {} x y) | oneIshExp y = x constFoldPrimExp (BinOpExp bop (ValueExp x) (ValueExp y)) | Just z <- doBinOp bop x y = ValueExp z constFoldPrimExp (BinOpExp LogAnd x y) | oneIshExp x = y | oneIshExp y = x | zeroIshExp x = x | zeroIshExp y = y constFoldPrimExp (BinOpExp LogOr x y) | oneIshExp x = x | oneIshExp y = y | zeroIshExp x = y | zeroIshExp y = x constFoldPrimExp (UnOpExp Abs {} x) | not $ negativeIshExp x = x constFoldPrimExp (UnOpExp (Neg _) (ValueExp (BoolValue x))) = ValueExp $ BoolValue $ not x constFoldPrimExp (BinOpExp UMod {} x y) | sameIshExp x y, IntType it <- primExpType x = ValueExp $ IntValue $ intValue it (0 :: Integer) constFoldPrimExp (BinOpExp SMod {} x y) | sameIshExp x y, IntType it <- primExpType x = ValueExp $ IntValue $ intValue it (0 :: Integer) constFoldPrimExp (BinOpExp SRem {} x y) | sameIshExp x y, IntType it <- primExpType x = ValueExp $ IntValue $ intValue it (0 :: Integer) constFoldPrimExp e = e constFoldCmpExp :: (Eq v) => PrimExp v -> PrimExp v constFoldCmpExp (CmpOpExp (CmpEq _) x y) | x == y = untyped true constFoldCmpExp (CmpOpExp (CmpEq _) (ValueExp x) (ValueExp y)) | x /= y = untyped false constFoldCmpExp e = constFoldPrimExp e -- | The class of numeric types that can be used for constructing -- 'TPrimExp's. class NumExp t where -- | Construct a typed expression from an integer. 
fromInteger' :: Integer -> TPrimExp t v -- | Construct a numeric expression from a boolean expression. This -- can be used to encode arithmetic control flow. fromBoolExp :: TPrimExp Bool v -> TPrimExp t v -- | The class of integer types that can be used for constructing -- 'TPrimExp's. class (NumExp t) => IntExp t where -- | The type of an expression, known to be an integer type. expIntType :: TPrimExp t v -> IntType instance NumExp Int8 where fromInteger' = isInt8 . ValueExp . IntValue . Int8Value . fromInteger fromBoolExp = isInt8 . ConvOpExp (BToI Int8) . untyped instance IntExp Int8 where expIntType = const Int8 instance NumExp Int16 where fromInteger' = isInt16 . ValueExp . IntValue . Int16Value . fromInteger fromBoolExp = isInt16 . ConvOpExp (BToI Int16) . untyped instance IntExp Int16 where expIntType = const Int16 instance NumExp Int32 where fromInteger' = isInt32 . ValueExp . IntValue . Int32Value . fromInteger fromBoolExp = isInt32 . ConvOpExp (BToI Int32) . untyped instance IntExp Int32 where expIntType = const Int32 instance NumExp Int64 where fromInteger' = isInt64 . ValueExp . IntValue . Int64Value . fromInteger fromBoolExp = isInt64 . ConvOpExp (BToI Int64) . untyped instance IntExp Int64 where expIntType = const Int64 -- | The class of floating-point types that can be used for -- constructing 'TPrimExp's. class (NumExp t) => FloatExp t where -- | Construct a typed expression from a rational. fromRational' :: Rational -> TPrimExp t v -- | The type of an expression, known to be a floating-point type. expFloatType :: TPrimExp t v -> FloatType instance NumExp Half where fromInteger' = isF16 . ValueExp . FloatValue . Float16Value . fromInteger fromBoolExp = isF16 . ConvOpExp (SIToFP Int16 Float16) . ConvOpExp (BToI Int16) . untyped instance NumExp Float where fromInteger' = isF32 . ValueExp . FloatValue . Float32Value . fromInteger fromBoolExp = isF32 . ConvOpExp (SIToFP Int32 Float32) . ConvOpExp (BToI Int32) . 
untyped instance NumExp Double where fromInteger' = TPrimExp . ValueExp . FloatValue . Float64Value . fromInteger fromBoolExp = isF64 . ConvOpExp (SIToFP Int32 Float64) . ConvOpExp (BToI Int32) . untyped instance FloatExp Half where fromRational' = TPrimExp . ValueExp . FloatValue . Float16Value . fromRational expFloatType = const Float16 instance FloatExp Float where fromRational' = TPrimExp . ValueExp . FloatValue . Float32Value . fromRational expFloatType = const Float32 instance FloatExp Double where fromRational' = TPrimExp . ValueExp . FloatValue . Float64Value . fromRational expFloatType = const Float64 instance (NumExp t, Pretty v) => Num (TPrimExp t v) where TPrimExp x + TPrimExp y | Just z <- msum [ asIntOp (`Add` OverflowUndef) x y, asFloatOp FAdd x y ] = TPrimExp $ constFoldPrimExp z | otherwise = numBad "+" (x, y) TPrimExp x - TPrimExp y | Just z <- msum [ asIntOp (`Sub` OverflowUndef) x y, asFloatOp FSub x y ] = TPrimExp $ constFoldPrimExp z | otherwise = numBad "-" (x, y) TPrimExp x * TPrimExp y | Just z <- msum [ asIntOp (`Mul` OverflowUndef) x y, asFloatOp FMul x y ] = TPrimExp $ constFoldPrimExp z | otherwise = numBad "*" (x, y) abs (TPrimExp x) | IntType t <- primExpType x = TPrimExp $ constFoldPrimExp $ UnOpExp (Abs t) x | FloatType t <- primExpType x = TPrimExp $ constFoldPrimExp $ UnOpExp (FAbs t) x | otherwise = numBad "abs" x signum (TPrimExp x) | IntType t <- primExpType x = TPrimExp $ UnOpExp (SSignum t) x | otherwise = numBad "signum" x fromInteger = fromInteger' instance (FloatExp t, Pretty v) => Fractional (TPrimExp t v) where TPrimExp x / TPrimExp y | Just z <- msum [asFloatOp FDiv x y] = TPrimExp $ constFoldPrimExp z | otherwise = numBad "/" (x, y) fromRational = fromRational' instance (Pretty v) => Floating (TPrimExp Half v) where x ** y = isF16 $ BinOpExp (FPow Float16) (untyped x) (untyped y) pi = isF16 $ ValueExp $ FloatValue $ Float16Value pi exp x = isF16 $ FunExp "exp16" [untyped x] $ FloatType Float16 log x = isF16 $ FunExp 
"log16" [untyped x] $ FloatType Float16 sin x = isF16 $ FunExp "sin16" [untyped x] $ FloatType Float16 cos x = isF16 $ FunExp "cos16" [untyped x] $ FloatType Float16 tan x = isF16 $ FunExp "tan16" [untyped x] $ FloatType Float16 asin x = isF16 $ FunExp "asin16" [untyped x] $ FloatType Float16 acos x = isF16 $ FunExp "acos16" [untyped x] $ FloatType Float16 atan x = isF16 $ FunExp "atan16" [untyped x] $ FloatType Float16 sinh x = isF16 $ FunExp "sinh16" [untyped x] $ FloatType Float16 cosh x = isF16 $ FunExp "cosh16" [untyped x] $ FloatType Float16 tanh x = isF16 $ FunExp "tanh16" [untyped x] $ FloatType Float16 asinh x = isF16 $ FunExp "asinh16" [untyped x] $ FloatType Float16 acosh x = isF16 $ FunExp "acosh16" [untyped x] $ FloatType Float16 atanh x = isF16 $ FunExp "atanh16" [untyped x] $ FloatType Float16 instance (Pretty v) => Floating (TPrimExp Float v) where x ** y = isF32 $ BinOpExp (FPow Float32) (untyped x) (untyped y) pi = isF32 $ ValueExp $ FloatValue $ Float32Value pi exp x = isF32 $ FunExp "exp32" [untyped x] $ FloatType Float32 log x = isF32 $ FunExp "log32" [untyped x] $ FloatType Float32 sin x = isF32 $ FunExp "sin32" [untyped x] $ FloatType Float32 cos x = isF32 $ FunExp "cos32" [untyped x] $ FloatType Float32 tan x = isF32 $ FunExp "tan32" [untyped x] $ FloatType Float32 asin x = isF32 $ FunExp "asin32" [untyped x] $ FloatType Float32 acos x = isF32 $ FunExp "acos32" [untyped x] $ FloatType Float32 atan x = isF32 $ FunExp "atan32" [untyped x] $ FloatType Float32 sinh x = isF32 $ FunExp "sinh32" [untyped x] $ FloatType Float32 cosh x = isF32 $ FunExp "cosh32" [untyped x] $ FloatType Float32 tanh x = isF32 $ FunExp "tanh32" [untyped x] $ FloatType Float32 asinh x = isF32 $ FunExp "asinh32" [untyped x] $ FloatType Float32 acosh x = isF32 $ FunExp "acosh32" [untyped x] $ FloatType Float32 atanh x = isF32 $ FunExp "atanh32" [untyped x] $ FloatType Float32 instance (Pretty v) => Floating (TPrimExp Double v) where x ** y = isF64 $ BinOpExp (FPow Float64) 
(untyped x) (untyped y) pi = isF64 $ ValueExp $ FloatValue $ Float64Value pi exp x = isF64 $ FunExp "exp64" [untyped x] $ FloatType Float64 log x = isF64 $ FunExp "log64" [untyped x] $ FloatType Float64 sin x = isF64 $ FunExp "sin64" [untyped x] $ FloatType Float64 cos x = isF64 $ FunExp "cos64" [untyped x] $ FloatType Float64 tan x = isF64 $ FunExp "tan64" [untyped x] $ FloatType Float64 asin x = isF64 $ FunExp "asin64" [untyped x] $ FloatType Float64 acos x = isF64 $ FunExp "acos64" [untyped x] $ FloatType Float64 atan x = isF64 $ FunExp "atan64" [untyped x] $ FloatType Float64 sinh x = isF64 $ FunExp "sinh64" [untyped x] $ FloatType Float64 cosh x = isF64 $ FunExp "cosh64" [untyped x] $ FloatType Float64 tanh x = isF64 $ FunExp "tanh64" [untyped x] $ FloatType Float64 asinh x = isF64 $ FunExp "asinh64" [untyped x] $ FloatType Float64 acosh x = isF64 $ FunExp "acosh64" [untyped x] $ FloatType Float64 atanh x = isF64 $ FunExp "atanh64" [untyped x] $ FloatType Float64 instance (IntExp t, Pretty v) => IntegralExp (TPrimExp t v) where TPrimExp x `div` TPrimExp y | Just z <- msum [ asIntOp (`SDiv` Unsafe) x y, asFloatOp FDiv x y ] = TPrimExp $ constFoldPrimExp z | otherwise = numBad "div" (x, y) TPrimExp x `mod` TPrimExp y | Just z <- msum [asIntOp (`SMod` Unsafe) x y] = TPrimExp $ constFoldPrimExp z | otherwise = numBad "mod" (x, y) TPrimExp x `quot` TPrimExp y | oneIshExp y = TPrimExp x | Just z <- msum [asIntOp (`SQuot` Unsafe) x y] = TPrimExp $ constFoldPrimExp z | otherwise = numBad "quot" (x, y) TPrimExp x `rem` TPrimExp y | Just z <- msum [asIntOp (`SRem` Unsafe) x y] = TPrimExp $ constFoldPrimExp z | otherwise = numBad "rem" (x, y) TPrimExp x `divUp` TPrimExp y | Just z <- msum [asIntOp (`SDivUp` Unsafe) x y] = TPrimExp $ constFoldPrimExp z | otherwise = numBad "divRoundingUp" (x, y) TPrimExp x `pow` TPrimExp y | Just z <- msum [ asIntOp Pow x y, asFloatOp FPow x y ] = TPrimExp $ constFoldPrimExp z | otherwise = numBad "pow" (x, y) sgn (TPrimExp (ValueExp 
(IntValue i))) = Just $ signum $ valueIntegral i sgn _ = Nothing -- | Lifted logical conjunction. (.&&.) :: (Eq v) => TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v TPrimExp x .&&. TPrimExp y = TPrimExp $ constFoldPrimExp $ BinOpExp LogAnd x y -- | Lifted logical conjunction. (.||.) :: (Eq v) => TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v TPrimExp x .||. TPrimExp y = TPrimExp $ constFoldPrimExp $ BinOpExp LogOr x y -- | Lifted relational operators; assuming signed numbers in case of -- integers. (.<.), (.>.), (.<=.), (.>=.), (.==.) :: (Eq v) => TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v TPrimExp x .<. TPrimExp y = TPrimExp $ constFoldCmpExp $ CmpOpExp cmp x y where cmp = case primExpType x of IntType t -> CmpSlt t FloatType t -> FCmpLt t _ -> CmpLlt TPrimExp x .<=. TPrimExp y = TPrimExp $ constFoldCmpExp $ CmpOpExp cmp x y where cmp = case primExpType x of IntType t -> CmpSle t FloatType t -> FCmpLe t _ -> CmpLle TPrimExp x .==. TPrimExp y = TPrimExp $ constFoldCmpExp $ CmpOpExp (CmpEq t) x y where t = primExpType x `min` primExpType y x .>. y = y .<. x x .>=. y = y .<=. x -- | Lifted bitwise operators. The right-shift is logical, *not* arithmetic. (.&.), (.|.), (.^.), (.>>.), (.<<.) :: (Eq v) => TPrimExp t v -> TPrimExp t v -> TPrimExp t v bitPrimExp :: (Eq v) => (IntType -> BinOp) -> TPrimExp t v -> TPrimExp t v -> TPrimExp t v bitPrimExp op (TPrimExp x) (TPrimExp y) = TPrimExp $ constFoldPrimExp $ BinOpExp (op $ primExpIntType x) x y (.&.) = bitPrimExp And (.|.) = bitPrimExp Or (.^.) = bitPrimExp Xor (.>>.) = bitPrimExp LShr (.<<.) = bitPrimExp Shl infix 4 .==., .<., .>., .<=., .>=. infixr 3 .&&. infixr 2 .||. -- | Untyped smart constructor for sign extension that does a bit of -- constant folding. 
sExt :: IntType -> PrimExp v -> PrimExp v
sExt to_t e =
  case e of
    -- Constants are extended right away instead of wrapping them in a
    -- conversion node.
    ValueExp (IntValue v) -> ValueExp $ IntValue $ doSExt v to_t
    _
      | from_t == to_t -> e
      | otherwise -> ConvOpExp (SExt from_t to_t) e
  where
    from_t = primExpIntType e

-- | Untyped smart constructor for zero extension.  Integer constants
-- are extended immediately, and an expression that already has the
-- target type is returned unchanged.
zExt :: IntType -> PrimExp v -> PrimExp v
zExt to_t e =
  case e of
    ValueExp (IntValue v) -> ValueExp $ IntValue $ doZExt v to_t
    _
      | from_t == to_t -> e
      | otherwise -> ConvOpExp (ZExt from_t to_t) e
  where
    from_t = primExpIntType e

-- | Build an integer binary operation at the type of the first
-- operand, or 'Nothing' if that operand is not of integer type.
asIntOp :: (IntType -> BinOp) -> PrimExp v -> PrimExp v -> Maybe (PrimExp v)
asIntOp mk x y =
  case primExpType x of
    IntType t -> Just $ BinOpExp (mk t) x y
    _ -> Nothing

-- | Build a floating-point binary operation at the type of the first
-- operand, or 'Nothing' if that operand is not of float type.
asFloatOp :: (FloatType -> BinOp) -> PrimExp v -> PrimExp v -> Maybe (PrimExp v)
asFloatOp mk x y =
  case primExpType x of
    FloatType t -> Just $ BinOpExp (mk t) x y
    _ -> Nothing

-- | Abort with an error naming the offending method and its argument.
numBad :: (Pretty a) => String -> a -> b
numBad method arg =
  error $ concat ["Invalid argument to PrimExp method ", method, ": ", prettyString arg]

-- | Evaluate a 'PrimExp' in the given monad, using the provided
-- function to obtain the value of each leaf.  Invokes 'fail' (through
-- 'evalBad') when an operator rejects its operands, or when a
-- 'FunExp' names a function that is not in 'primFuns' or rejects its
-- arguments.
evalPrimExp :: (Pretty v, MonadFail m) => (v -> m PrimValue) -> PrimExp v -> m PrimValue
evalPrimExp leaf = go
  where
    -- Turn a 'Maybe' result into a monadic one, running @bad@ on 'Nothing'.
    orFail bad = maybe bad pure

    go (LeafExp v _) = leaf v
    go (ValueExp v) = pure v
    go (BinOpExp op x y) = do
      vx <- go x
      vy <- go y
      orFail (evalBad op (x, y)) $ doBinOp op vx vy
    go (CmpOpExp op x y) = do
      vx <- go x
      vy <- go y
      orFail (evalBad op (x, y)) $ BoolValue <$> doCmpOp op vx vy
    go (UnOpExp op x) = do
      vx <- go x
      orFail (evalBad op x) $ doUnOp op vx
    go (ConvOpExp op x) = do
      vx <- go x
      orFail (evalBad op x) $ doConvOp op vx
    go (FunExp h args _) = do
      vargs <- mapM go args
      orFail (evalBad h args) $
        M.lookup h primFuns >>= \(_, _, fun) -> fun vargs

-- | Report a type error encountered by 'evalPrimExp' via 'fail'.
evalBad :: (Pretty a, Pretty b, MonadFail m) => a -> b -> m c
evalBad op arg =
  fail . concat $
    ["evalPrimExp: Type error when applying ", prettyString op, " to ", prettyString arg]

-- | The type of values returned by a 'PrimExp'.  This function
-- returning does not imply that the 'PrimExp' is type-correct.
primExpType :: PrimExp v -> PrimType
primExpType e =
  case e of
    LeafExp _ t -> t
    ValueExp v -> primValueType v
    BinOpExp op _ _ -> binOpType op
    CmpOpExp {} -> Bool
    UnOpExp op _ -> unOpType op
    ConvOpExp op _ -> snd (convOpType op)
    FunExp _ _ t -> t

-- | Is the expression a constant zero of some sort?
zeroIshExp :: PrimExp v -> Bool
zeroIshExp e
  | ValueExp v <- e = zeroIsh v
  | otherwise = False

-- | Is the expression a constant one of some sort?
oneIshExp :: PrimExp v -> Bool
oneIshExp e
  | ValueExp v <- e = oneIsh v
  | otherwise = False

-- | Is the expression a constant negative of some sort?
negativeIshExp :: PrimExp v -> Bool
negativeIshExp e
  | ValueExp v <- e = negativeIsh v
  | otherwise = False

-- | Are both expressions constants holding equal values?
sameIshExp :: PrimExp v -> PrimExp v -> Bool
sameIshExp e1 e2
  | ValueExp a <- e1,
    ValueExp b <- e2 =
      a == b
  | otherwise = False

-- | If the given 'PrimExp' is a constant of the wrong integer type,
-- coerce it to the given integer type.  This is a workaround for an
-- issue in the 'Num' instance.
coerceIntPrimExp :: IntType -> PrimExp v -> PrimExp v
coerceIntPrimExp t e =
  case e of
    ValueExp (IntValue v) -> ValueExp $ IntValue $ doSExt v t
    _ -> e

-- | The integer type of an expression.  Falls back to 'Int64' when the
-- expression does not have an integer type.
primExpIntType :: PrimExp v -> IntType
primExpIntType e
  | IntType t <- primExpType e = t
  | otherwise = Int64

-- | Constant boolean 'TPrimExp's.
true, false :: TPrimExp Bool v
true = TPrimExp (ValueExp (BoolValue True))
false = TPrimExp (ValueExp (BoolValue False))

-- | Conversion from a Haskell 'Bool' to a constant 'TPrimExp'.
fromBool :: Bool -> TPrimExp Bool v
fromBool True = true
fromBool False = false

-- | Boolean negation, built as a 'Neg' unary operation on 'Bool'.
bNot :: TPrimExp Bool v -> TPrimExp Bool v
bNot e = TPrimExp (UnOpExp (Neg Bool) (untyped e))

-- | SMax on 32-bit integers.
sMax32 :: TPrimExp Int32 v -> TPrimExp Int32 v -> TPrimExp Int32 v
sMax32 x = TPrimExp . BinOpExp (SMax Int32) (untyped x) . untyped

-- | SMin on 32-bit integers.
sMin32 :: TPrimExp Int32 v -> TPrimExp Int32 v -> TPrimExp Int32 v
sMin32 x = TPrimExp . BinOpExp (SMin Int32) (untyped x) . untyped

-- | SMax on 64-bit integers.
sMax64 :: TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMax64 x = TPrimExp . BinOpExp (SMax Int64) (untyped x) . untyped

-- | SMin on 64-bit integers.
sMin64 :: TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64 x = TPrimExp . BinOpExp (SMin Int64) (untyped x) . untyped

-- | Sign-extend to 32 bit integer.
sExt32 :: (IntExp t) => TPrimExp t v -> TPrimExp Int32 v
sExt32 e = isInt32 $ sExt Int32 $ untyped e

-- | Sign-extend to 64 bit integer.
sExt64 :: (IntExp t) => TPrimExp t v -> TPrimExp Int64 v
sExt64 e = isInt64 $ sExt Int64 $ untyped e

-- | Zero-extend to 32 bit integer.
zExt32 :: (IntExp t) => TPrimExp t v -> TPrimExp Int32 v
zExt32 e = isInt32 $ zExt Int32 $ untyped e

-- | Zero-extend to 64 bit integer.
zExt64 :: (IntExp t) => TPrimExp t v -> TPrimExp Int64 v
zExt64 e = isInt64 $ zExt Int64 $ untyped e

-- | 16-bit float minimum.
fMin16 :: TPrimExp Half v -> TPrimExp Half v -> TPrimExp Half v
fMin16 x = isF16 . BinOpExp (FMin Float16) (untyped x) . untyped

-- | 32-bit float minimum.
fMin32 :: TPrimExp Float v -> TPrimExp Float v -> TPrimExp Float v
fMin32 x = isF32 . BinOpExp (FMin Float32) (untyped x) . untyped

-- | 64-bit float minimum.
fMin64 :: TPrimExp Double v -> TPrimExp Double v -> TPrimExp Double v
fMin64 x = isF64 . BinOpExp (FMin Float64) (untyped x) . untyped

-- | 16-bit float maximum.
fMax16 :: TPrimExp Half v -> TPrimExp Half v -> TPrimExp Half v
fMax16 x = isF16 . BinOpExp (FMax Float16) (untyped x) . untyped

-- | 32-bit float maximum.
fMax32 :: TPrimExp Float v -> TPrimExp Float v -> TPrimExp Float v
fMax32 x = isF32 . BinOpExp (FMax Float32) (untyped x) . untyped

-- | 64-bit float maximum.
fMax64 :: TPrimExp Double v -> TPrimExp Double v -> TPrimExp Double v
fMax64 x = isF64 . BinOpExp (FMax Float64) (untyped x) . untyped

-- | Conditional expression.  It is represented as a call to the
-- primitive function @'condFun' t@, where @t@ is the type of the first
-- branch.
condExp :: TPrimExp Bool v -> TPrimExp t v -> TPrimExp t v -> TPrimExp t v
condExp c a b =
  let ty = primExpType (untyped a)
   in TPrimExp $ FunExp (condFun ty) [untyped c, untyped a, untyped b] ty

-- | Convert the result of some integer expression to have the same
-- integer type as another expression, using sign extension.
sExtAs ::
  (IntExp to, IntExp from) =>
  TPrimExp from v ->
  TPrimExp to v ->
  TPrimExp to v
sExtAs e like = TPrimExp $ sExt (expIntType like) $ untyped e

-- Prettyprinting instances

-- Operators and functions are printed in prefix form; operator
-- operands are each wrapped in parentheses.
instance (Pretty v) => Pretty (PrimExp v) where
  pretty e =
    case e of
      LeafExp v _ -> pretty v
      ValueExp v -> pretty v
      BinOpExp op x y -> pretty op <+> operand x <+> operand y
      CmpOpExp op x y -> pretty op <+> operand x <+> operand y
      ConvOpExp op x -> pretty op <+> operand x
      UnOpExp op x -> pretty op <+> operand x
      FunExp h args _ -> pretty h <+> parens (commasep $ map pretty args)
    where
      operand x = parens (pretty x)

instance (Pretty v) => Pretty (TPrimExp t v) where
  pretty (TPrimExp e) = pretty e

-- | Produce a mapping from the leaves of the 'PrimExp' to their
-- designated types.
leafExpTypes :: (Ord a) => PrimExp a -> S.Set (a, PrimType)
leafExpTypes e =
  case e of
    LeafExp x ptp -> S.singleton (x, ptp)
    ValueExp _ -> S.empty
    UnOpExp _ e' -> leafExpTypes e'
    ConvOpExp _ e' -> leafExpTypes e'
    BinOpExp _ a b -> leafExpTypes a `S.union` leafExpTypes b
    CmpOpExp _ a b -> leafExpTypes a `S.union` leafExpTypes b
    FunExp _ args _ -> S.unions (map leafExpTypes args)

-- | Multiplication of untyped 'PrimExp's, which must have the same
-- type.  Uses 'OverflowWrap' for integer operations.  The operator is
-- chosen from the type of the first operand.
(~*~) :: PrimExp v -> PrimExp v -> PrimExp v
x ~*~ y =
  let op = case primExpType x of
        IntType it -> Mul it OverflowWrap
        FloatType ft -> FMul ft
        Bool -> LogAnd
        Unit -> LogAnd
   in BinOpExp op x y

-- | Division of untyped 'PrimExp's, which must have the same
-- type.  For integers, this is unsafe signed division.
(~/~) :: PrimExp v -> PrimExp v -> PrimExp v
x ~/~ y =
  let op = case primExpType x of
        IntType it -> SDiv it Unsafe
        FloatType ft -> FDiv ft
        Bool -> LogAnd
        Unit -> LogAnd
   in BinOpExp op x y

-- | Addition of untyped 'PrimExp's, which must have the same type.
-- Uses 'OverflowWrap' for integer operations.
(~+~) :: PrimExp v -> PrimExp v -> PrimExp v
x ~+~ y =
  let op = case primExpType x of
        IntType it -> Add it OverflowWrap
        FloatType ft -> FAdd ft
        Bool -> LogOr
        Unit -> LogOr
   in BinOpExp op x y

-- | Subtraction of untyped 'PrimExp's, which must have the same type.
-- Uses 'OverflowWrap' for integer operations.
(~-~) :: PrimExp v -> PrimExp v -> PrimExp v
x ~-~ y =
  let op = case primExpType x of
        IntType it -> Sub it OverflowWrap
        FloatType ft -> FSub ft
        Bool -> LogOr
        Unit -> LogOr
   in BinOpExp op x y

-- | Equality of untyped 'PrimExp's, which must have the same type.
(~==~) :: PrimExp v -> PrimExp v -> PrimExp v x ~==~ y = CmpOpExp (CmpEq t) x y where t = primExpType x infix 7 ~*~, ~/~ infix 6 ~+~, ~-~ infix 4 ~==~ futhark-0.25.27/src/Futhark/Analysis/PrimExp/000077500000000000000000000000001475065116200207045ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Analysis/PrimExp/Convert.hs000066400000000000000000000132051475065116200226610ustar00rootroot00000000000000{-# OPTIONS_GHC -fno-warn-orphans #-} -- | Converting back and forth between 'PrimExp's. Use the 'ToExp' -- instance to convert to Futhark expressions. module Futhark.Analysis.PrimExp.Convert ( primExpFromExp, primExpFromSubExp, pe32, le32, pe64, le64, f32pe, f32le, f64pe, f64le, primExpFromSubExpM, replaceInPrimExp, replaceInPrimExpM, substituteInPrimExp, primExpSlice, subExpSlice, -- * Module reexport module Futhark.Analysis.PrimExp, ) where import Control.Monad.Fail qualified as Fail import Control.Monad.Identity import Data.Map.Strict qualified as M import Data.Maybe import Futhark.Analysis.PrimExp import Futhark.Construct import Futhark.IR instance (ToExp v) => ToExp (PrimExp v) where toExp (BinOpExp op x y) = BasicOp <$> (BinOp op <$> toSubExp "binop_x" x <*> toSubExp "binop_y" y) toExp (CmpOpExp op x y) = BasicOp <$> (CmpOp op <$> toSubExp "cmpop_x" x <*> toSubExp "cmpop_y" y) toExp (UnOpExp op x) = BasicOp <$> (UnOp op <$> toSubExp "unop_x" x) toExp (ConvOpExp op x) = BasicOp <$> (ConvOp op <$> toSubExp "convop_x" x) toExp (ValueExp v) = pure $ BasicOp $ SubExp $ Constant v toExp (FunExp h args t) = Apply (nameFromText h) <$> args' <*> pure [(primRetType t, mempty)] <*> pure (Safe, mempty, []) where args' = zip <$> mapM (toSubExp "apply_arg") args <*> pure (repeat Observe) toExp (LeafExp v _) = toExp v instance (ToExp v) => ToExp (TPrimExp t v) where toExp = toExp . untyped -- | Convert an expression to a 'PrimExp'. The provided function is -- used to convert expressions that are not trivially 'PrimExp's. 
-- In practice the function is only ever applied to variable names:
-- constants are turned into 'ValueExp's directly, without consulting
-- it.
--
-- 'BinOp', 'CmpOp', 'UnOp', 'ConvOp' and 'SubExp' basic operations are
-- converted, as are calls to built-in functions whose single return
-- type is primitive.  Any other expression makes the conversion
-- 'fail' with the message @"Not a PrimExp"@.
primExpFromExp ::
  (Fail.MonadFail m, RepTypes rep) =>
  (VName -> m (PrimExp v)) ->
  Exp rep ->
  m (PrimExp v)
primExpFromExp f (BasicOp (BinOp op x y)) =
  BinOpExp op <$> primExpFromSubExpM f x <*> primExpFromSubExpM f y
primExpFromExp f (BasicOp (CmpOp op x y)) =
  CmpOpExp op <$> primExpFromSubExpM f x <*> primExpFromSubExpM f y
primExpFromExp f (BasicOp (UnOp op x)) =
  UnOpExp op <$> primExpFromSubExpM f x
primExpFromExp f (BasicOp (ConvOp op x)) =
  ConvOpExp op <$> primExpFromSubExpM f x
primExpFromExp f (BasicOp (SubExp se)) =
  primExpFromSubExpM f se
primExpFromExp f (Apply fname args ts _)
  -- Only built-in functions with exactly one primitive result can be
  -- represented as a 'FunExp'.
  | isBuiltInFunction fname,
    [Prim t] <- map (declExtTypeOf . fst) ts =
      FunExp (nameToText fname) <$> mapM (primExpFromSubExpM f . fst) args <*> pure t
primExpFromExp _ _ = fail "Not a PrimExp"

-- | Like 'primExpFromExp', but for a t'SubExp'.  The function is
-- applied to variables; constants become 'ValueExp's.
primExpFromSubExpM :: (Applicative m) => (VName -> m (PrimExp v)) -> SubExp -> m (PrimExp v)
primExpFromSubExpM f (Var v) = f v
primExpFromSubExpM _ (Constant v) = pure $ ValueExp v

-- | Convert a t'SubExp' of the given type.  Variables become
-- 'LeafExp's annotated with that type.  Constants become 'ValueExp's,
-- whose type comes from the value itself, so the type argument is not
-- used for them.
primExpFromSubExp :: PrimType -> SubExp -> PrimExp VName
primExpFromSubExp t (Var v) = LeafExp v t
primExpFromSubExp _ (Constant v) = ValueExp v

-- | Shorthand for constructing a 'TPrimExp' of type v'Int32'.
pe32 :: SubExp -> TPrimExp Int32 VName
pe32 = isInt32 . primExpFromSubExp int32

-- | Shorthand for constructing a 'TPrimExp' of type v'Int32', from a leaf.
le32 :: a -> TPrimExp Int32 a
le32 = isInt32 . flip LeafExp int32

-- | Shorthand for constructing a 'TPrimExp' of type v'Int64'.
pe64 :: SubExp -> TPrimExp Int64 VName
pe64 = isInt64 . primExpFromSubExp int64

-- | Shorthand for constructing a 'TPrimExp' of type v'Int64', from a leaf.
le64 :: a -> TPrimExp Int64 a
le64 = isInt64 . flip LeafExp int64

-- | Shorthand for constructing a 'TPrimExp' of type v'Float32'.
f32pe :: SubExp -> TPrimExp Float VName
f32pe = isF32 . primExpFromSubExp float32

-- | Shorthand for constructing a 'TPrimExp' of type v'Float32', from a leaf.
f32le :: a -> TPrimExp Float a
f32le = isF32 . flip LeafExp float32

-- | Shorthand for constructing a 'TPrimExp' of type v'Float64'.
f64pe :: SubExp -> TPrimExp Double VName
f64pe = isF64 . primExpFromSubExp float64

-- | Shorthand for constructing a 'TPrimExp' of type v'Float64', from a leaf.
f64le :: a -> TPrimExp Double a
f64le = isF64 . flip LeafExp float64

-- | Apply a monadic transformation to the leaves in a 'PrimExp'.  Every
-- rebuilt 'BinOpExp' node is passed through 'constFoldPrimExp'; the
-- other node kinds are rebuilt as-is.
replaceInPrimExpM ::
  (Monad m) =>
  (a -> PrimType -> m (PrimExp b)) ->
  PrimExp a ->
  m (PrimExp b)
replaceInPrimExpM f (LeafExp v pt) =
  f v pt
replaceInPrimExpM _ (ValueExp v) =
  pure $ ValueExp v
replaceInPrimExpM f (BinOpExp bop pe1 pe2) =
  constFoldPrimExp
    <$> (BinOpExp bop <$> replaceInPrimExpM f pe1 <*> replaceInPrimExpM f pe2)
replaceInPrimExpM f (CmpOpExp cop pe1 pe2) =
  CmpOpExp cop <$> replaceInPrimExpM f pe1 <*> replaceInPrimExpM f pe2
replaceInPrimExpM f (UnOpExp uop pe) =
  UnOpExp uop <$> replaceInPrimExpM f pe
replaceInPrimExpM f (ConvOpExp cop pe) =
  ConvOpExp cop <$> replaceInPrimExpM f pe
replaceInPrimExpM f (FunExp h args t) =
  FunExp h <$> mapM (replaceInPrimExpM f) args <*> pure t

-- | As 'replaceInPrimExpM', but in the identity monad.
replaceInPrimExp ::
  (a -> PrimType -> PrimExp b) ->
  PrimExp a ->
  PrimExp b
replaceInPrimExp f e = runIdentity $ replaceInPrimExpM f' e
  where
    f' x y = pure $ f x y

-- | Substitute names in a 'PrimExp' with the 'PrimExp's they map to.
-- Names that are not in the map are kept as leaves of their original
-- type.
substituteInPrimExp ::
  (Ord v) =>
  M.Map v (PrimExp v) ->
  PrimExp v ->
  PrimExp v
substituteInPrimExp tab = replaceInPrimExp $ \v t ->
  fromMaybe (LeafExp v t) $ M.lookup v tab

-- | Convert a t'SubExp' slice to a 'PrimExp' slice.
primExpSlice :: Slice SubExp -> Slice (TPrimExp Int64 VName)
primExpSlice = fmap pe64

-- | Convert a 'PrimExp' slice to a t'SubExp' slice.
subExpSlice :: (MonadBuilder m) => Slice (TPrimExp Int64 VName) -> m (Slice SubExp) subExpSlice = traverse $ toSubExp "slice" futhark-0.25.27/src/Futhark/Analysis/PrimExp/Parse.hs000066400000000000000000000033461475065116200223200ustar00rootroot00000000000000-- | Building blocks for parsing prim primexpressions. *Not* an infix -- representation. module Futhark.Analysis.PrimExp.Parse ( pPrimExp, pPrimValue, -- * Module reexport module Futhark.Analysis.PrimExp, ) where import Data.Functor import Data.Text qualified as T import Data.Void import Futhark.Analysis.PrimExp import Futhark.Util.Pretty (prettyText) import Language.Futhark.Primitive.Parse import Text.Megaparsec pBinOp :: Parsec Void T.Text BinOp pBinOp = choice $ map p allBinOps where p op = keyword (prettyText op) $> op pCmpOp :: Parsec Void T.Text CmpOp pCmpOp = choice $ map p allCmpOps where p op = keyword (prettyText op) $> op pUnOp :: Parsec Void T.Text UnOp pUnOp = choice $ map p allUnOps where p op = keyword (prettyText op) $> op pConvOp :: Parsec Void T.Text ConvOp pConvOp = choice $ map p allConvOps where p op = keyword (prettyText op) $> op parens :: Parsec Void T.Text a -> Parsec Void T.Text a parens = between (lexeme "(") (lexeme ")") -- | Parse a 'PrimExp' given a leaf parser. 
pPrimExp :: PrimType -> Parsec Void T.Text v -> Parsec Void T.Text (PrimExp v) pPrimExp t pLeaf = choice [ flip LeafExp t <$> pLeaf, ValueExp <$> pPrimValue, pBinOp >>= binOpExp, pCmpOp >>= cmpOpExp, pConvOp >>= convOpExp, pUnOp >>= unOpExp, parens $ pPrimExp t pLeaf ] where binOpExp op = BinOpExp op <$> pPrimExp (binOpType op) pLeaf <*> pPrimExp (binOpType op) pLeaf cmpOpExp op = CmpOpExp op <$> pPrimExp (cmpOpType op) pLeaf <*> pPrimExp (cmpOpType op) pLeaf convOpExp op = ConvOpExp op <$> pPrimExp (fst (convOpType op)) pLeaf unOpExp op = UnOpExp op <$> pPrimExp (unOpType op) pLeaf futhark-0.25.27/src/Futhark/Analysis/PrimExp/Simplify.hs000066400000000000000000000035051475065116200230370ustar00rootroot00000000000000-- | Defines simplification functions for 'PrimExp's. module Futhark.Analysis.PrimExp.Simplify (simplifyPrimExp, simplifyExtPrimExp) where import Futhark.Analysis.PrimExp import Futhark.IR import Futhark.Optimise.Simplify.Engine as Engine -- | Simplify a 'PrimExp', including copy propagation. If a 'LeafExp' -- refers to a name that is a 'Constant', the node turns into a -- 'ValueExp'. simplifyPrimExp :: (SimplifiableRep rep) => PrimExp VName -> SimpleM rep (PrimExp VName) simplifyPrimExp = simplifyAnyPrimExp onLeaf where onLeaf v pt = do se <- simplify $ Var v case se of Var v' -> pure $ LeafExp v' pt Constant pv -> pure $ ValueExp pv -- | Like 'simplifyPrimExp', but where leaves may be 'Ext's. 
simplifyExtPrimExp :: (SimplifiableRep rep) => PrimExp (Ext VName) -> SimpleM rep (PrimExp (Ext VName)) simplifyExtPrimExp = simplifyAnyPrimExp onLeaf where onLeaf (Free v) pt = do se <- simplify $ Var v case se of Var v' -> pure $ LeafExp (Free v') pt Constant pv -> pure $ ValueExp pv onLeaf (Ext i) pt = pure $ LeafExp (Ext i) pt simplifyAnyPrimExp :: (SimplifiableRep rep) => (a -> PrimType -> SimpleM rep (PrimExp a)) -> PrimExp a -> SimpleM rep (PrimExp a) simplifyAnyPrimExp f (LeafExp v pt) = f v pt simplifyAnyPrimExp _ (ValueExp pv) = pure $ ValueExp pv simplifyAnyPrimExp f (BinOpExp bop e1 e2) = BinOpExp bop <$> simplifyAnyPrimExp f e1 <*> simplifyAnyPrimExp f e2 simplifyAnyPrimExp f (CmpOpExp cmp e1 e2) = CmpOpExp cmp <$> simplifyAnyPrimExp f e1 <*> simplifyAnyPrimExp f e2 simplifyAnyPrimExp f (UnOpExp op e) = UnOpExp op <$> simplifyAnyPrimExp f e simplifyAnyPrimExp f (ConvOpExp conv e) = ConvOpExp conv <$> simplifyAnyPrimExp f e simplifyAnyPrimExp f (FunExp h args t) = FunExp h <$> mapM (simplifyAnyPrimExp f) args <*> pure t futhark-0.25.27/src/Futhark/Analysis/PrimExp/Table.hs000066400000000000000000000137221475065116200222740ustar00rootroot00000000000000-- | Compute a mapping from variables to their corresponding (fully -- expanded) PrimExps. module Futhark.Analysis.PrimExp.Table ( primExpTable, PrimExpTable, -- * Extensibility PrimExpAnalysis (..), -- * Testing stmToPrimExps, ) where import Control.Monad.State.Strict import Data.Foldable import Data.Map.Strict qualified as M import Futhark.Analysis.PrimExp import Futhark.Analysis.PrimExp.Convert import Futhark.IR.Aliases import Futhark.IR.GPU import Futhark.IR.GPUMem import Futhark.IR.MC import Futhark.IR.MCMem -- | Maps variables to maybe PrimExps. Will map to nothing if it -- cannot be resolved to a PrimExp. For all uses of this analysis atm. -- a variable can be considered inscrutable if it cannot be resolved -- to a primexp. 
type PrimExpTable = M.Map VName (Maybe (PrimExp VName)) -- | A class for extracting PrimExps from what is inside an op. class PrimExpAnalysis rep where opPrimExp :: Scope rep -> Op rep -> State PrimExpTable () primExpTable :: (PrimExpAnalysis rep, RepTypes rep) => Prog rep -> PrimExpTable primExpTable prog = initialState <> foldMap' (uncurry funToPrimExp) scopesAndFuns where scopesAndFuns = do let fun_defs = progFuns prog let scopes = map getScope fun_defs zip scopes fun_defs getScope funDef = scopeOf (progConsts prog) <> scopeOfFParams (funDefParams funDef) -- We need to have the dummy "slice" in the analysis for our "slice hack". initialState = M.singleton (VName "slice" 0) $ Just $ LeafExp (VName "slice" 0) $ IntType Int64 funToPrimExp :: (PrimExpAnalysis rep, RepTypes rep) => Scope rep -> FunDef rep -> PrimExpTable funToPrimExp scope fundef = execState (bodyToPrimExps scope (funDefBody fundef)) mempty -- | Adds the statements of a body to the PrimExpTable bodyToPrimExps :: (PrimExpAnalysis rep, RepTypes rep) => Scope rep -> Body rep -> State PrimExpTable () bodyToPrimExps scope body = mapM_ (stmToPrimExps scope') (bodyStms body) where scope' = scope <> scopeOf (bodyStms body) -- | Adds the statements of a kernel body to the PrimExpTable kernelToBodyPrimExps :: (PrimExpAnalysis rep, RepTypes rep) => Scope rep -> KernelBody rep -> State PrimExpTable () kernelToBodyPrimExps scope kbody = mapM_ (stmToPrimExps scope') (kernelBodyStms kbody) where scope' = scope <> scopeOf (kernelBodyStms kbody) -- | Adds a statement to the PrimExpTable. If it can't be resolved as a `PrimExp`, -- it will be added as `Nothing`. stmToPrimExps :: forall rep. (PrimExpAnalysis rep, RepTypes rep) => Scope rep -> Stm rep -> State PrimExpTable () stmToPrimExps scope stm = do table <- get case stm of (Let (Pat pat_elems) _ e) | Just primExp <- primExpFromExp (toPrimExp scope table) e -> -- The statement can be resolved as a `PrimExp`. 
-- For each pattern element, insert the PrimExp in the table forM_ pat_elems $ \pe -> modify $ M.insert (patElemName pe) (Just primExp) | otherwise -> do -- The statement can't be resolved as a `PrimExp`. walk $ stmExp stm -- Traverse the rest of the AST Get the -- updated PrimExpTable after traversing the AST table' <- get -- Add pattern elements that can't be resolved as `PrimExp` -- to the `PrimExpTable` as `Nothing` forM_ pat_elems $ \pe -> case M.lookup (patElemName pe) table' of Nothing -> modify $ M.insert (patElemName pe) Nothing Just _ -> pure () where walk e = do -- Handle most cases using the walker walkExpM walker e -- Additionally, handle loop parameters case e of Loop _ (ForLoop i t _) _ -> modify $ M.insert i $ Just $ LeafExp i $ IntType t _ -> pure () walker = (identityWalker @rep) { walkOnBody = \body_scope -> bodyToPrimExps (scope <> body_scope), walkOnOp = opPrimExp scope, walkOnFParam = paramToPrimExp -- Loop parameters } -- Adds a loop parameter to the PrimExpTable paramToPrimExp :: FParam rep -> State PrimExpTable () paramToPrimExp param = do let name = paramName param -- Construct a `PrimExp` from the type of the parameter -- and add it to the `PrimExpTable` case typeOf $ paramDec param of -- TODO: Handle other types? Prim pt -> modify $ M.insert name (Just $ LeafExp name pt) _ -> pure () -- | Checks if a name is in the PrimExpTable and construct a `PrimExp` -- if it is not toPrimExp :: (RepTypes rep) => Scope rep -> PrimExpTable -> VName -> Maybe (PrimExp VName) toPrimExp scope table name = case M.lookup name table of Just maybePrimExp | Just primExp <- maybePrimExp -> Just primExp -- Already in the table _ -> case fmap typeOf . 
M.lookup name $ scope of (Just (Prim pt)) -> Just $ LeafExp name pt _ -> Nothing -- | Adds the parameters of a SegOp as well as the statements in its -- body to the PrimExpTable segOpToPrimExps :: (PrimExpAnalysis rep, RepTypes rep) => Scope rep -> SegOp lvl rep -> State PrimExpTable () segOpToPrimExps scope op = do forM_ (map fst $ unSegSpace $ segSpace op) $ \name -> modify $ M.insert name $ Just $ LeafExp name int64 kernelToBodyPrimExps scope (segBody op) instance PrimExpAnalysis GPU where opPrimExp scope gpu_op | (SegOp op) <- gpu_op = segOpToPrimExps scope op | (SizeOp _) <- gpu_op = pure () | (GPUBody _ body) <- gpu_op = bodyToPrimExps scope body | (Futhark.IR.GPUMem.OtherOp _) <- gpu_op = pure () instance PrimExpAnalysis MC where opPrimExp scope mc_op | (ParOp maybe_par_segop seq_segop) <- mc_op = do -- Add the statements in the parallel part of the ParOp to the PrimExpTable case maybe_par_segop of Nothing -> pure () Just _ -> forM_ maybe_par_segop $ segOpToPrimExps scope -- Add the statements in the sequential part of the ParOp to the PrimExpTable segOpToPrimExps scope seq_segop | (Futhark.IR.MCMem.OtherOp _) <- mc_op = pure () futhark-0.25.27/src/Futhark/Analysis/SymbolTable.hs000066400000000000000000000443071475065116200221010ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} module Futhark.Analysis.SymbolTable ( SymbolTable (bindings, loopDepth, availableAtClosestLoop, simplifyMemory), empty, fromScope, toScope, -- * Entries Entry, deepen, entryAccInput, entryDepth, entryLetBoundDec, entryIsSize, entryStm, entryFParam, entryLParam, -- * Lookup elem, lookup, lookupStm, lookupExp, lookupBasicOp, lookupType, lookupSubExp, lookupAliases, lookupLoopVar, lookupLoopParam, aliases, available, subExpAvailable, consume, index, index', Indexed (..), indexedAddCerts, IndexOp (..), -- * Insertion insertStm, insertStms, insertFParams, insertLParam, insertLoopVar, insertLoopMerge, -- * Misc hideCertified, noteAccTokens, ) where import Control.Arrow ((&&&)) 
import Control.Monad
import Data.List (elemIndex, foldl')
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Ord
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR hiding (FParam, lookupType)
import Futhark.IR qualified as AST
import Futhark.IR.Prop.Aliases qualified as Aliases
import Prelude hiding (elem, lookup)

-- | A symbol table: an 'Entry' for every name it knows about, plus
-- information about the loop nesting at the current point.
data SymbolTable rep = SymbolTable
  { -- | Number of enclosing loops; incremented by 'deepen'.
    loopDepth :: Int,
    -- | The entry for each name known to the table.
    bindings :: M.Map VName (Entry rep),
    -- | Which names are available just before the most enclosing
    -- loop?
    availableAtClosestLoop :: Names,
    -- | We are in a situation where we should
    -- simplify/hoist/un-existentialise memory as much as possible -
    -- typically, inside a kernel.
    simplifyMemory :: Bool
  }

-- | Combining two tables takes the larger loop depth, unions the
-- bindings (left-biased on clashing names, as for 'M.Map') and the
-- available names, and enables memory simplification if either side
-- has it enabled.
instance Semigroup (SymbolTable rep) where
  table1 <> table2 =
    SymbolTable
      { loopDepth = max (loopDepth table1) (loopDepth table2),
        bindings = bindings table1 <> bindings table2,
        availableAtClosestLoop =
          availableAtClosestLoop table1
            <> availableAtClosestLoop table2,
        simplifyMemory = simplifyMemory table1 || simplifyMemory table2
      }

-- | The identity is 'empty'.
instance Monoid (SymbolTable rep) where
  mempty = empty

-- | A table at loop depth 0 with no bindings, no names available at the
-- closest loop, and memory simplification disabled.
empty :: SymbolTable rep
-- Positional fields: loopDepth, bindings, availableAtClosestLoop,
-- simplifyMemory.
empty = SymbolTable 0 M.empty mempty False

-- | Build a table from a 'Scope' by inserting every binding, starting
-- from 'empty', with 'insertFreeVar'.
fromScope :: (ASTRep rep) => Scope rep -> SymbolTable rep
fromScope = M.foldlWithKey' insertFreeVar' empty
  where
    -- Argument order adapted to what 'M.foldlWithKey'' expects.
    insertFreeVar' m k dec = insertFreeVar k dec m

-- | Recover a 'Scope' from the table, keeping only the 'NameInfo' of
-- each entry.
toScope :: SymbolTable rep -> Scope rep
toScope = M.map entryInfo . bindings

-- | Enter a loop: increment the loop depth and record every name bound
-- so far as available just before this (now closest) loop.
deepen :: SymbolTable rep -> SymbolTable rep
deepen vtable =
  vtable
    { loopDepth = loopDepth vtable + 1,
      availableAtClosestLoop = namesFromList $ M.keys $ bindings vtable
    }

-- | The result of indexing a delayed array.
data Indexed
  = -- | A PrimExp based on the indexes (that is, without
    -- accessing any actual array).
    Indexed Certs (PrimExp VName)
  | -- | The indexing corresponds to another (perhaps more
    -- advantageous) array.
    IndexedArray Certs VName [TPrimExp Int64 VName]

-- | Combine additional certificates (placed first) with the ones
-- already attached to an 'Indexed'.
indexedAddCerts :: Certs -> Indexed -> Indexed
indexedAddCerts cs1 (Indexed cs2 v) = Indexed (cs1 <> cs2) v
indexedAddCerts cs1 (IndexedArray cs2 arr v) = IndexedArray (cs1 <> cs2) arr v

-- | The free names of an 'Indexed' are those of its certificates and
-- its index expression(s), plus the array name for 'IndexedArray'.
instance FreeIn Indexed where
  freeIn' (Indexed cs v) = freeIn' cs <> freeIn' v
  freeIn' (IndexedArray cs arr v) = freeIn' cs <> freeIn' arr <> freeIn' v

-- | Indexing a delayed array if possible.
type IndexArray = [TPrimExp Int64 VName] -> Maybe Indexed

-- | Everything the table knows about a single name.
data Entry rep = Entry
  { -- | True if consumed.
    entryConsumed :: Bool,
    -- | The loop depth of the table at the time the entry was
    -- inserted.
    entryDepth :: Int,
    -- | True if this name has been used as an array size,
    -- implying that it is non-negative.
    entryIsSize :: Bool,
    -- | For names that are tokens of an accumulator, this is the
    -- corresponding combining function and neutral element.
    entryAccInput :: Maybe (WithAccInput rep),
    -- | How the name was bound, with binding-specific information.
    entryType :: EntryType rep
  }

-- | The ways in which a name can be bound.
data EntryType rep
  = -- | A 'ForLoop' iteration variable.
    LoopVar (LoopVarEntry rep)
  | -- | Bound by a statement.
    LetBound (LetBoundEntry rep)
  | -- | A parameter with 'FParamInfo' (this includes loop parameters;
    -- see 'fparamMerge').
    FParam (FParamEntry rep)
  | -- | A parameter with 'LParamInfo'.
    LParam (LParamEntry rep)
  | -- | A name described only by its 'NameInfo'.
    FreeVar (FreeVarEntry rep)

-- | Information about a loop iteration variable.
data LoopVarEntry rep = LoopVarEntry
  { -- | The integer type of the variable.
    loopVarType :: IntType,
    -- | The bound of the loop, as returned by 'lookupLoopVar'.
    loopVarBound :: SubExp
  }

-- | Information about a name bound by a statement.
data LetBoundEntry rep = LetBoundEntry
  { -- | The decoration of the name's pattern element.
    letBoundDec :: LetDec rep,
    letBoundAliases :: Names,
    -- | The statement that binds the name.
    letBoundStm :: Stm rep,
    -- | Index a delayed array, if possible.  The 'Int' is the
    -- position of the name in the statement's pattern.
    letBoundIndex :: Int -> IndexArray
  }

-- | Information about a name bound as an 'FParam'.
data FParamEntry rep = FParamEntry
  { fparamDec :: FParamInfo rep,
    fparamAliases :: Names,
    -- | If a loop parameter, the initial value and the eventual
    -- result.  The result need not be in scope in the symbol table.
    fparamMerge :: Maybe (SubExp, SubExp)
  }

-- | Information about a name bound as an 'LParam'.
data LParamEntry rep = LParamEntry
  { lparamDec :: LParamInfo rep,
    lparamAliases :: Names,
    -- | Index a delayed array, if possible.
    lparamIndex :: IndexArray
  }

-- | Information about a name known only through its 'NameInfo'.
data FreeVarEntry rep = FreeVarEntry
  { freeVarDec :: NameInfo rep,
    freeVarAliases :: Names,
    -- | Index a delayed array, if possible.  Takes the name being
    -- indexed.
    freeVarIndex :: VName -> IndexArray
  }

-- | The type of an entry is the type of its 'NameInfo'.
instance (ASTRep rep) => Typed (Entry rep) where
  typeOf = typeOf . entryInfo

-- | The 'NameInfo' of an entry, as it would appear in a 'Scope'.
entryInfo :: Entry rep -> NameInfo rep
entryInfo e = case entryType e of
  LetBound entry -> LetName $ letBoundDec entry
  LoopVar entry -> IndexName $ loopVarType entry
  FParam entry -> FParamName $ fparamDec entry
  LParam entry -> LParamName $ lparamDec entry
  FreeVar entry -> freeVarDec entry

-- | The let-binding information of an entry, if it is let-bound.
isLetBound :: Entry rep -> Maybe (LetBoundEntry rep)
isLetBound e = case entryType e of
  LetBound entry -> Just entry
  _ -> Nothing

-- | The statement binding an entry, if it is let-bound.
entryStm :: Entry rep -> Maybe (Stm rep)
entryStm = fmap letBoundStm . isLetBound

-- | The 'FParamInfo' of an entry, if it is an 'FParam'.
entryFParam :: Entry rep -> Maybe (FParamInfo rep)
entryFParam e = case entryType e of
  FParam e' -> Just $ fparamDec e'
  _ -> Nothing

-- | The 'LParamInfo' of an entry, if it is an 'LParam'.
entryLParam :: Entry rep -> Maybe (LParamInfo rep)
entryLParam e = case entryType e of
  LParam e' -> Just $ lparamDec e'
  _ -> Nothing

-- | The pattern element decoration of an entry, if it is let-bound.
entryLetBoundDec :: Entry rep -> Maybe (LetDec rep)
entryLetBoundDec = fmap letBoundDec . isLetBound

-- | The aliases recorded for an entry.
entryAliases :: EntryType rep -> Names
entryAliases (LetBound e) = letBoundAliases e
entryAliases (FParam e) = fparamAliases e
entryAliases (LParam e) = lparamAliases e
entryAliases (FreeVar e) = freeVarAliases e
entryAliases (LoopVar _) = mempty -- Integers have no aliases.

-- | You almost always want 'available' instead of this one.
elem :: VName -> SymbolTable rep -> Bool
elem name = isJust . lookup name

-- | Look up the entry for a name.  Unlike 'available', this does not
-- check whether the name has been consumed.
lookup :: VName -> SymbolTable rep -> Maybe (Entry rep)
lookup name = M.lookup name . bindings

-- | The statement binding a name, if the name is let-bound.
lookupStm :: VName -> SymbolTable rep -> Maybe (Stm rep)
lookupStm name vtable = entryStm =<< lookup name vtable

-- | The expression and certificates of the statement binding a name.
lookupExp :: VName -> SymbolTable rep -> Maybe (Exp rep, Certs)
lookupExp name vtable = (stmExp &&& stmCerts) <$> lookupStm name vtable

-- | Like 'lookupExp', but only succeeds for 'BasicOp' expressions.
lookupBasicOp :: VName -> SymbolTable rep -> Maybe (BasicOp, Certs)
lookupBasicOp name vtable = case lookupExp name vtable of
  Just (BasicOp e, cs) -> Just (e, cs)
  _ -> Nothing

-- | The type of a name, if it is in the table.
lookupType :: (ASTRep rep) => VName -> SymbolTable rep -> Maybe Type
lookupType name vtable = typeOf <$> lookup name vtable

-- | The type of a t'SubExp': constants have the primitive type of their
-- value, and variables are looked up in the table.
lookupSubExpType :: (ASTRep rep) => SubExp -> SymbolTable rep -> Maybe Type
lookupSubExpType (Var v) = lookupType v
lookupSubExpType (Constant v) = const $ Just $ Prim $ primValueType v

-- | If the name is bound by a statement of the form
-- @BasicOp (SubExp se)@, return @se@ and the statement's certificates.
lookupSubExp :: VName -> SymbolTable rep -> Maybe (SubExp, Certs)
lookupSubExp name vtable = do
  (e, cs) <- lookupExp name vtable
  case e of
    BasicOp (SubExp se) -> Just (se, cs)
    _ -> Nothing

-- | The aliases recorded for a name; empty if the name is unknown.
lookupAliases :: VName -> SymbolTable rep -> Names
lookupAliases name vtable =
  maybe mempty (entryAliases . entryType) $
    M.lookup name (bindings vtable)

-- | If the given variable name is the name of a 'ForLoop' parameter,
-- then return the bound of that loop.
lookupLoopVar :: VName -> SymbolTable rep -> Maybe SubExp
lookupLoopVar name vtable = do
  -- Pattern-match failure in the Maybe monad yields Nothing for
  -- names that are not loop variables.
  LoopVar e <- entryType <$> M.lookup name (bindings vtable)
  pure $ loopVarBound e

-- | Look up the initial value and eventual result of a loop
-- parameter.  Note that the result almost certainly refers to
-- something that is not part of the symbol table.
lookupLoopParam :: VName -> SymbolTable rep -> Maybe (SubExp, SubExp)
lookupLoopParam name vtable = do
  -- Names that are not FParams yield Nothing via the failed match.
  FParam e <- entryType <$> M.lookup name (bindings vtable)
  fparamMerge e

-- | Do these two names alias each other?  This is expected to be a
-- commutative relationship, so the order of arguments does not
-- matter.
aliases :: VName -> VName -> SymbolTable rep -> Bool
aliases x y vtable = x == y || (x `nameIn` lookupAliases y vtable)

-- | In symbol table and not consumed.
available :: VName -> SymbolTable rep -> Bool
available name = maybe False (not . entryConsumed) . M.lookup name . bindings

-- | Constant or 'available'
subExpAvailable :: SubExp -> SymbolTable rep -> Bool
subExpAvailable (Var name) = available name
subExpAvailable Constant {} = const True

-- | Symbolically index an array variable with subexpression indices.
-- Fails if any index is not of a primitive type known to the table.
index ::
  (ASTRep rep) =>
  VName ->
  [SubExp] ->
  SymbolTable rep ->
  Maybe Indexed
index name is table = do
  is' <- mapM asPrimExp is
  index' name is' table
  where
    asPrimExp i = do
      Prim t <- lookupSubExpType i table
      pure $ TPrimExp $ primExpFromSubExp t i

-- | Symbolically index an array variable with 'PrimExp' indices, using
-- the indexing function stored in its entry.
index' ::
  VName ->
  [TPrimExp Int64 VName] ->
  SymbolTable rep ->
  Maybe Indexed
index' name is vtable = do
  entry <- lookup name vtable
  case entryType entry of
    LetBound entry'
      -- k is the position of the name in the binding pattern.
      | Just k <-
          elemIndex name . patNames . stmPat $
            letBoundStm entry' ->
          letBoundIndex entry' k is
    FreeVar entry' ->
      freeVarIndex entry' name is
    LParam entry' -> lparamIndex entry' is
    _ -> Nothing

-- | Representation-specific operations that can be indexed
-- symbolically.  The default method knows nothing.
class IndexOp op where
  indexOp ::
    (ASTRep rep, IndexOp (Op rep)) =>
    SymbolTable rep ->
    Int ->
    op ->
    [TPrimExp Int64 VName] ->
    Maybe Indexed
  indexOp _ _ _ _ = Nothing

instance IndexOp (NoOp rep)

-- | Symbolically index result number @k@ of an expression.
indexExp ::
  (IndexOp (Op rep), ASTRep rep) =>
  SymbolTable rep ->
  Exp rep ->
  Int ->
  IndexArray
indexExp vtable (Op op) k is =
  indexOp vtable k op is
-- Element i of an iota is i * stride + start (wrapping arithmetic).
indexExp _ (BasicOp (Iota _ x s to_it)) _ [i] =
  Just $
    Indexed mempty $
      ( sExt to_it (untyped i)
          `mul` primExpFromSubExp (IntType to_it) s
      )
        `add` primExpFromSubExp (IntType to_it) x
  where
    mul = BinOpExp (Mul to_it OverflowWrap)
    add = BinOpExp (Add to_it OverflowWrap)
-- Fully indexing a replicate of a primitive value gives that value.
indexExp table (BasicOp (Replicate (Shape ds) v)) _ is
  | length ds == length is,
    Just (Prim t) <- lookupSubExpType v table =
      Just $ Indexed mempty $ primExpFromSubExp t v
-- Otherwise drop the replicated dimensions and index the source array.
indexExp table (BasicOp (Replicate s (Var v))) _ is = do
  guard $ v `available` table
  guard $ s /= mempty
  index' v (drop (shapeRank s) is) table
indexExp table (BasicOp (Reshape _ newshape v)) _ is
  | Just oldshape <- arrayDims <$> lookupType v table =
      -- TODO: handle coercions more efficiently.
      let is' =
            reshapeIndex
              (map pe64 oldshape)
              (map pe64 $ shapeDims newshape)
              is
       in index' v is' table
-- Translate indices into a slice into indices into the sliced array.
indexExp table (BasicOp (Index v slice)) _ is = do
  guard $ v `available` table
  index' v (adjust (unSlice slice) is) table
  where
    adjust (DimFix j : js') is' =
      pe64 j : adjust js' is'
    adjust (DimSlice j _ s : js') (i : is') =
      let i_t_s = i * pe64 s
          j_p_i_t_s = pe64 j + i_t_s
       in j_p_i_t_s : adjust js' is'
    adjust _ _ = []
indexExp _ _ _ _ = Nothing

-- | Build the entry for one element of a statement's pattern.  The
-- indexing function attaches the statement's certificates to any
-- symbolic result.
defBndEntry ::
  (ASTRep rep, IndexOp (Op rep)) =>
  SymbolTable rep ->
  PatElem (LetDec rep) ->
  Names ->
  Stm rep ->
  LetBoundEntry rep
defBndEntry vtable patElem als stm =
  LetBoundEntry
    { letBoundDec = patElemDec patElem,
      letBoundAliases = als,
      letBoundStm = stm,
      letBoundIndex = \k ->
        fmap (indexedAddCerts (stmAuxCerts $ stmAux stm))
          . indexExp vtable (stmExp stm) k
    }

-- | One entry per pattern element of the statement, with aliases
-- expanded transitively through the current table.
bindingEntries ::
  (Aliases.Aliased rep, IndexOp (Op rep)) =>
  Stm rep ->
  SymbolTable rep ->
  [LetBoundEntry rep]
bindingEntries stm@(Let pat _ _) vtable = do
  pat_elem <- patElems pat
  pure $ defBndEntry vtable pat_elem (expandAliases (Aliases.aliasesOf pat_elem) vtable) stm

-- | Apply 'M.adjust' with the same function to several keys.
adjustSeveral :: (Ord k) => (v -> v) -> [k] -> M.Map k v -> M.Map k v
adjustSeveral f = flip $ foldl' $ flip $ M.adjust f

-- | Insert a fresh, unconsumed entry at the current loop depth.  Any
-- variables used in the entry's array dimensions are marked as sizes.
insertEntry ::
  (ASTRep rep) =>
  VName ->
  EntryType rep ->
  SymbolTable rep ->
  SymbolTable rep
insertEntry name entry vtable =
  let entry' =
        Entry
          { entryConsumed = False,
            entryDepth = loopDepth vtable,
            entryIsSize = False,
            entryAccInput = Nothing,
            entryType = entry
          }
      dims = mapMaybe subExpVar $ arrayDims $ typeOf entry'
      isSize e = e {entryIsSize = True}
   in vtable
        { bindings =
            adjustSeveral isSize dims $
              M.insert name entry' $
                bindings vtable
        }

-- | Insert several entries, left to right.
insertEntries ::
  (ASTRep rep) =>
  [(VName, EntryType rep)] ->
  SymbolTable rep ->
  SymbolTable rep
insertEntries entries vtable =
  foldl' add vtable entries
  where
    add vtable' (name, entry) = insertEntry name entry vtable'

-- | Insert the bindings of a statement.  Each new name is also recorded
-- as an alias of everything it aliases (reverse edges), and everything
-- the statement consumes is marked consumed.
insertStm ::
  (IndexOp (Op rep), Aliases.Aliased rep) =>
  Stm rep ->
  SymbolTable rep ->
  SymbolTable rep
insertStm stm vtable =
  flip (foldl' $ flip consume) (namesToList stm_consumed) $
    flip (foldl' addRevAliases) (zip names entries) $
      insertEntries (zip names $ map LetBound entries) vtable
  where
    entries = bindingEntries stm vtable
    names = patNames $ stmPat stm
    stm_consumed = expandAliases (Aliases.consumedInStm stm) vtable
    -- Add 'name' to the alias set of every entry it aliases.
    addRevAliases vtable' (name, LetBoundEntry {letBoundAliases = als}) =
      vtable' {bindings = adjustSeveral update inedges $ bindings vtable'}
      where
        inedges = namesToList $ expandAliases als vtable'
        update e = e {entryType = update' $ entryType e}
        update' (LetBound entry) =
          LetBound
            entry
              { letBoundAliases = oneName name <> letBoundAliases entry
              }
        update' (FParam entry) =
          FParam
            entry
              { fparamAliases = oneName name <> fparamAliases entry
              }
        update' (LParam entry) =
          LParam
            entry
              { lparamAliases = oneName name <> lparamAliases entry
              }
        update' (FreeVar entry) =
          FreeVar
            entry
              { freeVarAliases = oneName name <> freeVarAliases entry
              }
        -- Loop variables never carry aliases.
        update' e = e

-- | Insert several statements in order.
insertStms ::
  (IndexOp (Op rep), Aliases.Aliased rep) =>
  Stms rep ->
  SymbolTable rep ->
  SymbolTable rep
insertStms stms vtable = foldl' (flip insertStm) vtable $ stmsToList stms

-- | The given names plus their recorded aliases (one level of lookup).
expandAliases :: Names -> SymbolTable rep -> Names
expandAliases names vtable = names <> aliasesOfAliases
  where
    aliasesOfAliases =
      mconcat . map (`lookupAliases` vtable) . namesToList $ names

-- | Bind a function parameter.  It starts with no aliases and no loop
-- merge information.
insertFParam ::
  (ASTRep rep) =>
  AST.FParam rep ->
  SymbolTable rep ->
  SymbolTable rep
insertFParam fparam = insertEntry name entry
  where
    name = AST.paramName fparam
    entry =
      FParam
        FParamEntry
          { fparamDec = AST.paramDec fparam,
            fparamAliases = mempty,
            fparamMerge = Nothing
          }

-- | Bind several function parameters, left to right.
insertFParams ::
  (ASTRep rep) =>
  [AST.FParam rep] ->
  SymbolTable rep ->
  SymbolTable rep
insertFParams fparams symtable = foldl' (flip insertFParam) symtable fparams

-- | Bind a lambda parameter.  It has no aliases and cannot be indexed
-- symbolically.
insertLParam :: (ASTRep rep) => LParam rep -> SymbolTable rep -> SymbolTable rep
insertLParam param = insertEntry name bind
  where
    bind =
      LParam
        LParamEntry
          { lparamDec = AST.paramDec param,
            lparamAliases = mempty,
            lparamIndex = const Nothing
          }
    name = AST.paramName param

-- | Insert entries corresponding to the parameters of a loop (not
-- distinguishing context and value part).  Apart from the parameter
-- itself, we also insert the initial value and the subexpression
-- providing the final value.  Note that the latter is likely not in
-- scope in the symbol table at this point.  This is OK, and can still be
-- used to help some loop optimisations detect invariant loop
-- parameters.
insertLoopMerge ::
  (ASTRep rep) =>
  [(AST.FParam rep, SubExp, SubExpRes)] ->
  SymbolTable rep ->
  SymbolTable rep
insertLoopMerge merge vtable = foldl' addParam vtable merge
  where
    -- Parameters are added in list order; each one remembers its
    -- initial value and the subexpression that feeds the next iteration.
    addParam vt (p, initial, SubExpRes _ res) =
      let info =
            FParamEntry
              { fparamDec = AST.paramDec p,
                fparamAliases = mempty,
                fparamMerge = Just (initial, res)
              }
       in insertEntry (paramName p) (FParam info) vt

-- | Bind a loop iteration variable of the given integer type, together
-- with the bound of its loop.
insertLoopVar :: (ASTRep rep) => VName -> IntType -> SubExp -> SymbolTable rep -> SymbolTable rep
insertLoopVar name it bound =
  insertEntry name $
    LoopVar LoopVarEntry {loopVarType = it, loopVarBound = bound}

-- | Bind a variable known only by its 'NameInfo': it has no recorded
-- aliases and cannot be indexed symbolically.
insertFreeVar :: (ASTRep rep) => VName -> NameInfo rep -> SymbolTable rep -> SymbolTable rep
insertFreeVar name dec =
  insertEntry name $
    FreeVar
      FreeVarEntry
        { freeVarDec = dec,
          freeVarAliases = mempty,
          freeVarIndex = \_ _ -> Nothing
        }

-- | Mark a variable, and everything recorded as aliased by it, as
-- consumed.  Consumed entries stay in the table but are no longer
-- 'available'.
consume :: VName -> SymbolTable rep -> SymbolTable rep
consume consumee vtable =
  vtable {bindings = adjustSeveral markConsumed victims $ bindings vtable}
  where
    victims = namesToList $ expandAliases (oneName consumee) vtable
    markConsumed e = e {entryConsumed = True}

-- | Hide definitions of those entries that satisfy some predicate.
-- A hidden entry becomes a 'FreeVar' that keeps its 'NameInfo' and
-- aliases, but forgets how it was defined.
hideIf :: (Entry rep -> Bool) -> SymbolTable rep -> SymbolTable rep
hideIf hide vtable = vtable {bindings = conceal <$> bindings vtable}
  where
    conceal entry =
      if hide entry
        then entry {entryType = FreeVar asFree}
        else entry
      where
        asFree =
          FreeVarEntry
            { freeVarDec = entryInfo entry,
              freeVarIndex = \_ _ -> Nothing,
              freeVarAliases = entryAliases $ entryType entry
            }

-- | Hide these definitions, if they are protected by certificates in
-- the set of names.
hideCertified :: Names -> SymbolTable rep -> SymbolTable rep
hideCertified to_hide = hideIf certified
  where
    -- Only let-bound entries carry a statement, and hence certificates.
    certified entry = case entryStm entry of
      Nothing -> False
      Just stm -> any (`nameIn` to_hide) $ unCerts $ stmCerts stm

-- | Note that these names are tokens for the corresponding
-- accumulators.  The names must already be present in the symbol
-- table.
noteAccTokens ::
  [(VName, WithAccInput rep)] ->
  SymbolTable rep ->
  SymbolTable rep
noteAccTokens = flip (foldl' f)
  where
    -- Names missing from the table are silently skipped.
    f vtable (v, accum) =
      case M.lookup v $ bindings vtable of
        Nothing -> vtable
        Just e ->
          vtable
            { bindings =
                M.insert v (e {entryAccInput = Just accum}) $
                  bindings vtable
            }
futhark-0.25.27/src/Futhark/Analysis/UsageTable.hs000066400000000000000000000134551475065116200217000ustar00rootroot00000000000000{-# LANGUAGE Strict #-}

-- | A usage-table is sort of a bottom-up symbol table, describing how
-- (and if) a variable is used.
module Futhark.Analysis.UsageTable
  ( UsageTable,
    without,
    lookup,
    used,
    expand,
    isConsumed,
    isInResult,
    isUsedDirectly,
    isSize,
    usages,
    usage,
    consumedUsage,
    inResultUsage,
    sizeUsage,
    sizeUsages,
    withoutU,
    Usages,
    consumedU,
    presentU,
    usageInStm,
    usageInPat,
  )
where

import Data.Bits
import Data.Foldable qualified as Foldable
import Data.IntMap.Strict qualified as IM
import Data.List qualified as L
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Prelude hiding (lookup)

-- | A usage table.  It is keyed by the 'baseTag' of each variable.
newtype UsageTable = UsageTable (IM.IntMap Usages)
  deriving (Eq, Show)

-- | Combining two tables unions the usages of each variable.
instance Semigroup UsageTable where
  UsageTable table1 <> UsageTable table2 =
    UsageTable $ IM.unionWith (<>) table1 table2

instance Monoid UsageTable where
  mempty = UsageTable mempty

-- | Remove these entries from the usage table.
without :: UsageTable -> [VName] -> UsageTable
without (UsageTable table) =
  UsageTable . Foldable.foldl (flip IM.delete) table . map baseTag

-- | Look up a variable in the usage table.
lookup :: VName -> UsageTable -> Maybe Usages
lookup name (UsageTable table) = IM.lookup (baseTag name) table

-- | Does the variable have usages satisfying the predicate?  'False'
-- if the variable is absent from the table.
lookupPred :: (Usages -> Bool) -> VName -> UsageTable -> Bool
lookupPred f name = maybe False f . lookup name

-- | Is the variable present in the usage table?  That is, has it been used?
used :: VName -> UsageTable -> Bool
used = lookupPred $ const True

-- | Expand the usage table based on aliasing information.
expand :: (VName -> Names) -> UsageTable -> UsageTable
expand look (UsageTable table) =
  UsageTable $ IM.foldlWithKey' propagate table table
  where
    -- Every alias of a variable inherits the variable's usages, except
    -- for being directly present.
    propagate acc tag uses =
      IM.foldl' (mark (uses `withoutU` presentU)) acc $
        namesIntMap $
          look $
            VName (nameFromString "") tag
    mark inherited acc alias =
      IM.insertWith (<>) (baseTag alias) inherited acc

-- | Does the variable carry every usage bit set in the first argument?
is :: Usages -> VName -> UsageTable -> Bool
is wanted = lookupPred (matches wanted)

-- | Has the variable been consumed?
isConsumed :: VName -> UsageTable -> Bool
isConsumed = is consumedU

-- | Has the variable been used in the 'Result' of a body?
isInResult :: VName -> UsageTable -> Bool
isInResult = is inResultU

-- | Has the given name been used directly (i.e. could we rename it or
-- remove it without anyone noticing?)
isUsedDirectly :: VName -> UsageTable -> Bool
isUsedDirectly = is presentU

-- | Is this name used as the size of something (array or memory block)?
isSize :: VName -> UsageTable -> Bool
isSize = is sizeU

-- | A usage table in which each of the given names is marked as present.
usages :: Names -> UsageTable
usages names = UsageTable $ presentU <$ namesIntMap names

-- | A single-entry usage table recording exactly the given usages.
usage :: VName -> Usages -> UsageTable
usage name uses = UsageTable (IM.singleton (baseTag name) uses)

-- | A single-entry usage table recording that the variable was consumed.
consumedUsage :: VName -> UsageTable
consumedUsage name = usage name consumedU

-- | A single-entry usage table recording that the variable appears in
-- the 'Result' of a body.
inResultUsage :: VName -> UsageTable
inResultUsage name = usage name inResultU

-- | A single-entry usage table recording that the variable was used as
-- an array or memory size.
sizeUsage :: VName -> UsageTable
sizeUsage name = usage name sizeU

-- | Construct a usage table where the given names have been used as
-- an array or memory size.
sizeUsages :: Names -> UsageTable
sizeUsages = UsageTable . IM.map (const (sizeU <> presentU)) . namesIntMap

-- | A description of how a single variable has been used.
newtype Usages = Usages Int -- Bitmap representation for speed.
  deriving (Eq, Ord, Show)

-- | Combining usages is the bitwise union of their flags.
instance Semigroup Usages where
  Usages x <> Usages y = Usages $ x .|. y

instance Monoid Usages where
  mempty = Usages 0

-- | A kind of usage.  Each one occupies a distinct bit.
consumedU, inResultU, presentU, sizeU :: Usages
consumedU = Usages 1
inResultU = Usages 2
presentU = Usages 4
sizeU = Usages 8

-- | Check whether the bits that are set in the first argument are
-- also set in the second.
matches :: Usages -> Usages -> Bool
matches (Usages x) (Usages y) = x == (x .&. y)

-- | x - y, but for 'Usages'.
withoutU :: Usages -> Usages -> Usages
withoutU (Usages x) (Usages y) = Usages $ x .&. complement y

-- Only consumption inside a body is recorded here.
usageInBody :: (Aliased rep) => Body rep -> UsageTable
usageInBody = foldMap consumedUsage . namesToList . consumedInBody

-- | Produce a usage table reflecting the use of the free variables in
-- a single statement.
usageInStm :: (Aliased rep) => Stm rep -> UsageTable
usageInStm (Let pat rep e) =
  mconcat
    [ usageInPat pat `without` patNames pat,
      usages $ freeIn rep,
      usageInExp e,
      usages (freeIn e)
    ]

-- | Usage table reflecting use in pattern.  In particular, free
-- variables in the decorations are considered used as sizes, even if
-- they are also bound in this pattern.
usageInPat :: (FreeIn t) => Pat t -> UsageTable
usageInPat = sizeUsages . foldMap freeIn . patElems

-- Consumption performed by an expression.  Plain free-variable use is
-- added separately by 'usageInStm'.
usageInExp :: (Aliased rep) => Exp rep -> UsageTable
usageInExp (Apply _ args _ _) =
  mconcat
    [ consumedUsage v
      | (Var v, Consume) <- args
    ]
usageInExp e@Loop {} =
  foldMap consumedUsage $ namesToList $ consumedInExp e
usageInExp (Match _ cases defbody _) =
  foldMap (usageInBody . caseBody) cases <> usageInBody defbody
usageInExp (WithAcc inputs lam) =
  foldMap inputUsage inputs <> usageInBody (lambdaBody lam)
  where
    inputUsage (_, arrs, _) = foldMap consumedUsage arrs
usageInExp (BasicOp (Update _ src _ _)) =
  consumedUsage src
usageInExp (BasicOp (FlatUpdate src _ _)) =
  consumedUsage src
usageInExp (Op op) =
  mconcat $ map consumedUsage (namesToList $ consumedInOp op)
usageInExp (BasicOp _) = mempty
futhark-0.25.27/src/Futhark/Bench.hs000066400000000000000000000336261475065116200171220ustar00rootroot00000000000000-- | Facilities for handling Futhark benchmark results.  A Futhark
-- benchmark program is just like a Futhark test program.
module Futhark.Bench
  ( RunResult (..),
    DataResult (..),
    BenchResult (..),
    Result (..),
    encodeBenchResults,
    decodeBenchResults,
    binaryName,
    benchmarkDataset,
    RunOptions (..),
    prepareBenchmarkProgram,
    CompileOptions (..),
    module Futhark.Profile,
  )
where

import Control.Applicative
import Control.Monad
import Control.Monad.Except (ExceptT, MonadError (..), liftEither, runExceptT)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Aeson qualified as JSON
import Data.Aeson.Key qualified as JSON
import Data.Aeson.KeyMap qualified as JSON
import Data.ByteString.Char8 qualified as SBS
import Data.ByteString.Lazy.Char8 qualified as LBS
import Data.DList qualified as DL
import Data.Map qualified as M
import Data.Maybe
import Data.Text qualified as T
import Data.Time.Clock
import Data.Vector.Unboxed qualified as U
import Futhark.Profile
import Futhark.Server
import Futhark.Test
import Futhark.Util (showText)
import Statistics.Autocorrelation (autocorrelation)
import Statistics.Sample (fastStdDev, mean)
import System.Exit
import System.FilePath
import System.Process.ByteString (readProcessWithExitCode)
import System.Timeout (timeout)

-- | The runtime of a single successful run.
newtype RunResult = RunResult {runMicroseconds :: Int}
  deriving (Eq, Show)

-- | The measurements resulting from various successful runs of a
-- benchmark (same dataset).
data Result = Result
  { -- | Runtime of every run.
    runResults :: [RunResult],
    -- | Memory usage.
    memoryMap :: M.Map T.Text Int,
    -- | The error output produced during execution.  Often 'Nothing'
    -- for space reasons, and otherwise only the output from the last
    -- run.
    stdErr :: Maybe T.Text,
    -- | Profiling report.  This will have been measured based on the
    -- last run.
    report :: Maybe ProfilingReport
  }
  deriving (Eq, Show)

-- | The results for a single named dataset is either an error message, or
-- a result.
data DataResult = DataResult T.Text (Either T.Text Result)
  deriving (Eq, Show)

-- | The results for all datasets for some benchmark program.
data BenchResult = BenchResult
  { benchResultProg :: FilePath,
    benchResultResults :: [DataResult]
  }
  deriving (Eq, Show)

-- Wrappers that give lists of results their JSON object encodings.
newtype DataResults = DataResults {unDataResults :: [DataResult]}

newtype BenchResults = BenchResults {unBenchResults :: [BenchResult]}

-- A standalone 'Result' is encoded as a four-element JSON array.
instance JSON.ToJSON Result where
  toJSON (Result runres memmap err profiling) =
    JSON.toJSON (runres, memmap, err, profiling)

instance JSON.FromJSON Result where
  parseJSON = fmap f . JSON.parseJSON
    where
      f (runres, memmap, err, profiling) = Result runres memmap err profiling

-- A run is encoded as its bare microsecond count.
instance JSON.ToJSON RunResult where
  toJSON = JSON.toJSON . runMicroseconds

instance JSON.FromJSON RunResult where
  parseJSON = fmap RunResult . JSON.parseJSON

-- Datasets are encoded as an object keyed by dataset description.
instance JSON.ToJSON DataResults where
  toJSON (DataResults rs) =
    JSON.object $ map dataResultJSON rs
  toEncoding (DataResults rs) =
    JSON.pairs $ mconcat $ map (uncurry (JSON..=) . dataResultJSON) rs

-- A dataset value is first tried as a result object; if that fails it
-- is parsed as an error message.
instance JSON.FromJSON DataResults where
  parseJSON = JSON.withObject "datasets" $ \o ->
    DataResults <$> mapM datasetResult (JSON.toList o)
    where
      datasetResult (k, v) =
        DataResult (JSON.toText k)
          <$> ((Right <$> success v) <|> (Left <$> JSON.parseJSON v))
      success = JSON.withObject "result" $ \o ->
        Result
          <$> o JSON..: "runtimes"
          <*> o JSON..: "bytes"
          <*> o JSON..:? "stderr"
          <*> o JSON..:? "profiling"

-- | The key/value pair for one dataset.  A failed dataset maps to its
-- error message.  A successful one maps to an object with "runtimes" and
-- "bytes", plus "stderr" and "profiling" only when those are present.
-- The two @<> case@ lines sit to the left of the case alternatives, so
-- layout closes each case before the next @<>@.
dataResultJSON :: DataResult -> (JSON.Key, JSON.Value)
dataResultJSON (DataResult desc (Left err)) =
  (JSON.fromText desc, JSON.toJSON err)
dataResultJSON (DataResult desc (Right (Result runtimes bytes progerr_opt profiling_opt))) =
  ( JSON.fromText desc,
    JSON.object $
      [ ("runtimes", JSON.toJSON $ map runMicroseconds runtimes),
        ("bytes", JSON.toJSON bytes)
      ]
        <> case progerr_opt of
          Nothing -> []
          Just progerr -> [("stderr", JSON.toJSON progerr)]
        <> case profiling_opt of
          Nothing -> []
          Just profiling -> [("profiling", JSON.toJSON profiling)]
  )

-- | The key/value pair for one benchmark program: its path mapped to an
-- object holding its "datasets".
benchResultJSON :: BenchResult -> (JSON.Key, JSON.Value)
benchResultJSON (BenchResult prog r) =
  ( JSON.fromString prog,
    JSON.object [("datasets", JSON.toJSON $ DataResults r)]
  )

instance JSON.ToJSON BenchResults where
  toJSON (BenchResults rs) =
    JSON.object $ map benchResultJSON rs

instance JSON.FromJSON BenchResults where
  parseJSON = JSON.withObject "benchmarks" $ \o ->
    BenchResults <$> mapM onBenchmark (JSON.toList o)
    where
      onBenchmark (k, v) =
        BenchResult (JSON.toString k)
          <$> JSON.withObject "benchmark" onBenchmark' v
      onBenchmark' o =
        fmap unDataResults . JSON.parseJSON =<< o JSON..: "datasets"

-- | Transform benchmark results to a JSON bytestring.
encodeBenchResults :: [BenchResult] -> LBS.ByteString
encodeBenchResults = JSON.encode . BenchResults

-- | Decode benchmark results from a JSON bytestring.
decodeBenchResults :: LBS.ByteString -> Either String [BenchResult]
decodeBenchResults = fmap unBenchResults . JSON.eitherDecode'

--- Running benchmarks

-- | How to run a benchmark.
data RunOptions = RunOptions
  { -- | Applies both to initial and convergence phase.
    runMinRuns :: Int,
    -- | Minimum wall-clock time to spend in the initial phase.
    runMinTime :: NominalDiffTime,
    -- | Timeout, in seconds, for benchmarking one dataset.
    runTimeout :: Int,
    runVerbose :: Int,
    -- | If true, run the convergence phase.
    runConvergencePhase :: Bool,
    -- | Stop convergence once this much time has passed.
    runConvergenceMaxTime :: NominalDiffTime,
    -- | Invoked for every runtime measured during the run.  Can be
    -- used to provide a progress bar.
    runResultAction :: (Int, Maybe Double) -> IO (),
    -- | Perform a final run at the end with profiling information
    -- enabled.
    runProfile :: Bool
  }

-- | A list of @(autocorrelation,rse)@ pairs.  When the
-- autocorrelation is above the first element and the RSE is above the
-- second element, we want more runs.
convergenceCriteria :: [(Double, Double)]
convergenceCriteria =
  [ (0.95, 0.0010),
    (0.75, 0.0015),
    (0.65, 0.0025),
    (0.45, 0.0050),
    (0.00, 0.0100)
  ]

-- Returns the next run count: half the runs done so far if any
-- convergence criterion says more runs are needed, otherwise 0.
nextRunCount :: Int -> Double -> Double -> Int
nextRunCount runs rse acor = if any check convergenceCriteria then div runs 2 else 0
  where
    check (acor_lb, rse_lb) = acor > acor_lb && rse > rse_lb

type BenchM = ExceptT T.Text IO

-- Do the minimum number of runs.  Keeps running until both the minimum
-- run count and the minimum wall-clock time have been reached.  At least
-- one run is always done, so later phases have a measurement to use.
runMinimum ::
  BenchM (RunResult, [T.Text]) ->
  RunOptions ->
  Int ->
  NominalDiffTime ->
  DL.DList (RunResult, [T.Text]) ->
  BenchM (DL.DList (RunResult, [T.Text]))
runMinimum do_run opts runs_done elapsed r = do
  let actions = do
        x <- do_run
        liftIO $ runResultAction opts (runMicroseconds (fst x), Nothing)
        pure x

  -- Figure out how much we have left to do.
  let min_runs = max 1 (runMinRuns opts)
      todo
        | runs_done < min_runs =
            min_runs - runs_done
        | elapsed >= runMinTime opts =
            -- Minimum time already reached.
            0
        | elapsed <= 0 =
            -- The clock registered no elapsed time, so we cannot
            -- extrapolate a per-run time; double the run count instead.
            runs_done
        | otherwise =
            -- Guesstimate how many runs we need to reach the minimum
            -- time.  Here runs_done >= 1 and elapsed > 0, so both
            -- divisions are well defined.
            let time_per_run = elapsed / fromIntegral runs_done
             in ceiling ((runMinTime opts - elapsed) / time_per_run)

  if todo <= 0
    then pure r
    else do
      before <- liftIO getCurrentTime
      r' <- DL.fromList <$> replicateM todo actions
      after <- liftIO getCurrentTime
      let elapsed' = elapsed + diffUTCTime after before
      runMinimum do_run opts (runs_done + todo) elapsed' (r <> r')

-- Do more runs until a convergence criterion is reached, or until the
-- measured runtime exceeds 'runConvergenceMaxTime'.
runConvergence ::
  BenchM (RunResult, [T.Text]) ->
  RunOptions ->
  DL.DList (RunResult, [T.Text]) ->
  BenchM (DL.DList (RunResult, [T.Text]))
runConvergence do_run opts initial_r =
  let runtimes = resultRuntimes (DL.toList initial_r)
      (n, _, rse, acor) = runtimesMetrics runtimes
   in -- If the runtimes collected during the runMinimum phase are
      -- unstable enough that we need more in order to converge, we throw
      -- away the runMinimum runtimes.  This is because they often exhibit
      -- behaviour similar to a "warmup" period, and hence function as
      -- outliers that poison the metrics we use to determine convergence.
      -- By throwing them away we converge much faster, and still get the
      -- right result.
      case nextRunCount n rse acor of
        x
          | x > 0,
            runConvergencePhase opts ->
              moreRuns mempty mempty rse (x `max` runMinRuns opts)
          | otherwise ->
              pure initial_r
  where
    -- Runtimes in microseconds.
    resultRuntimes =
      U.fromList . map (fromIntegral . runMicroseconds . fst)

    -- (run count, total runtime, relative standard error, lag-1
    -- autocorrelation).  The runtimes are microseconds, but the total
    -- is compared against 'runConvergenceMaxTime' (a duration in
    -- seconds), so it is converted to seconds here.
    runtimesMetrics runtimes =
      let n = U.length runtimes
          rse = (fastStdDev runtimes / sqrt (fromIntegral n)) / mean runtimes
          (x, _, _) = autocorrelation runtimes
          total_seconds = realToFrac (U.sum runtimes / 1000000) :: NominalDiffTime
       in ( n,
            total_seconds,
            rse,
            fromMaybe 1 (x U.!? 1)
          )

    sample rse = do
      x <- do_run
      liftIO $ runResultAction opts (runMicroseconds (fst x), Just rse)
      pure x

    moreRuns runtimes r rse x = do
      r' <- replicateM x $ sample rse
      loop (runtimes <> resultRuntimes r') (r <> DL.fromList r')

    loop runtimes r = do
      let (n, total, rse, acor) = runtimesMetrics runtimes
      case nextRunCount n rse acor of
        x
          | x > 0,
            total < runConvergenceMaxTime opts ->
              moreRuns runtimes r rse x
          | otherwise ->
              pure r

-- | Run the benchmark program on the indicated dataset.
benchmarkDataset ::
  Server ->
  RunOptions ->
  FutharkExe ->
  FilePath ->
  T.Text ->
  Values ->
  Maybe Success ->
  FilePath ->
  IO (Either T.Text ([RunResult], T.Text, ProfilingReport))
benchmarkDataset server opts futhark program entry input_spec expected_spec ref_out = runExceptT $ do
  -- Server-side variable names for the entry point's outputs and inputs.
  output_types <- cmdEither $ cmdOutputs server entry
  input_types <- cmdEither $ cmdInputs server entry
  let outs = ["out" <> showText i | i <- [0 .. length output_types - 1]]
      ins = ["in" <> showText i | i <- [0 .. length input_types - 1]]

  -- Start from a clean server with profiling paused; it is only
  -- unpaused for the optional profiled run at the end.
  cmdMaybe . liftIO $ cmdClear server
  cmdMaybe . liftIO $ cmdPauseProfiling server

  let freeOuts = cmdMaybe (cmdFree server outs)
      freeIns = cmdMaybe (cmdFree server ins)
      loadInput = valuesAsVars server (zip ins $ map inputType input_types) futhark dir input_spec
      reloadInput = freeIns >> loadInput

  loadInput

  -- Parse a "runtime: N" line from the output of a call.
  let runtime l
        | Just l' <- T.stripPrefix "runtime: " l,
          [(x, "")] <- reads $ T.unpack l' =
            Just x
        | otherwise =
            Nothing

      -- One measured call.  Consumed inputs are reloaded so the next
      -- call sees fresh data, and exactly one runtime line is required.
      doRun = do
        call_lines <- cmdEither (cmdCall server entry outs ins)
        when (any inputConsumed input_types) reloadInput
        case mapMaybe runtime call_lines of
          [call_runtime] -> pure (RunResult call_runtime, call_lines)
          [] -> throwError "Could not find runtime in output."
          ls -> throwError $ "Ambiguous runtimes: " <> showText ls

  -- 'runTimeout' is in seconds, whereas 'timeout' takes microseconds.
  maybe_call_logs <- liftIO . timeout (runTimeout opts * 1000000) . runExceptT $ do
    -- First one uncounted warmup run.
    void $ cmdEither $ cmdCall server entry outs ins

    ys <- runMinimum (freeOuts *> doRun) opts 0 0 mempty

    xs <- runConvergence (freeOuts *> doRun) opts ys

    -- Possibly a profiled run at the end.
    profile_log <-
      if not (runProfile opts)
        then pure Nothing
        else do
          cmdMaybe . liftIO $ cmdUnpauseProfiling server
          profile_log <- freeOuts *> doRun
          cmdMaybe . liftIO $ cmdPauseProfiling server
          pure $ Just profile_log

    vs <- readResults server outs <* freeOuts

    pure (vs, DL.toList xs, profile_log)

  (vs, call_logs, profile_log) <- case maybe_call_logs of
    Nothing ->
      throwError . T.pack $
        "Execution exceeded "
          ++ show (runTimeout opts)
          ++ " seconds."
    Just x -> liftEither x

  freeIns

  report <- cmdEither $ cmdReport server

  report' <-
    maybe (throwError "Program produced invalid profiling report.") pure $
      profilingReportFromText (T.unlines report)

  maybe_expected <-
    liftIO $ maybe (pure Nothing) (fmap Just . getExpectedValues) expected_spec

  case maybe_expected of
    Just expected -> checkResult program expected vs
    Nothing -> pure ()

  -- The returned text is the output of the profiled run if there was
  -- one, and otherwise that of the last measured run.
  pure
    ( map fst call_logs,
      T.unlines $ snd $ fromMaybe (last call_logs) profile_log,
      report'
    )
  where
    getExpectedValues (SuccessValues vs) =
      getValues futhark dir vs
    -- Generated expected values are read from the reference output file.
    getExpectedValues SuccessGenerateValues =
      getExpectedValues $ SuccessValues $ InFile ref_out

    dir = takeDirectory program

-- | How to compile a benchmark.
data CompileOptions = CompileOptions
  { compFuthark :: String,
    compBackend :: String,
    compOptions :: [String]
  }

-- | Error message used when the compiler executable cannot be run.
progNotFound :: String -> String
progNotFound s = s ++ ": command not found"

-- | Compile and produce reference datasets.
prepareBenchmarkProgram ::
  (MonadIO m) =>
  Maybe Int ->
  CompileOptions ->
  FilePath ->
  [InputOutputs] ->
  m (Either (String, Maybe SBS.ByteString) ())
prepareBenchmarkProgram concurrency opts program cases = do
  let futhark = compFuthark opts

  -- Reference outputs are generated with the "c" backend before the
  -- benchmark itself is compiled.
  ref_res <- runExceptT $ ensureReferenceOutput concurrency (FutharkExe futhark) "c" program cases
  case ref_res of
    Left err ->
      pure $
        Left
          ( "Reference output generation for "
              <> program
              <> " failed:\n"
              <> T.unpack err,
            Nothing
          )
    Right () -> do
      -- Compile the program as a server executable.
      (futcode, _, futerr) <-
        liftIO $
          readProcessWithExitCode
            futhark
            ( [compBackend opts, program, "-o", binaryName program, "--server"]
                <> compOptions opts
            )
            ""

      -- Exit code 127 is treated as "compiler not found"; any other
      -- failure reports the compiler's stderr.
      case futcode of
        ExitSuccess -> pure $ Right ()
        ExitFailure 127 -> pure $ Left (progNotFound futhark, Nothing)
        ExitFailure _ -> pure $ Left ("Compilation of " ++ program ++ " failed:\n", Just futerr)
futhark-0.25.27/src/Futhark/Builder.hs000066400000000000000000000170561475065116200174700ustar00rootroot00000000000000{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | This module defines a convenience monad/typeclass for building
-- ASTs.  The fundamental building block is 'BuilderT' and its
-- execution functions, but it is usually easier to use 'Builder'.
--
-- See "Futhark.Construct" for a high-level description.
module Futhark.Builder
  ( -- * A concrete @MonadBuilder@ monad.
    BuilderT,
    runBuilderT,
    runBuilderT_,
    runBuilderT',
    runBuilderT'_,
    BuilderOps (..),
    Builder,
    runBuilder,
    runBuilder_,
    runBodyBuilder,
    runLambdaBuilder,

    -- * The 'MonadBuilder' typeclass
    module Futhark.Builder.Class,
  )
where

import Control.Arrow (second)
import Control.Monad.Error.Class
import Control.Monad.Reader
import Control.Monad.State.Strict
import Control.Monad.Writer
import Data.Map.Strict qualified as M
import Futhark.Builder.Class
import Futhark.IR

-- | A 'BuilderT' (and by extension, a 'Builder') is only an instance of
-- 'MonadBuilder' for representations that implement this type class,
-- which contains methods for constructing statements.
class (ASTRep rep) => BuilderOps rep where
  -- | Construct the expression decoration for a pattern and expression.
  mkExpDecB ::
    (MonadBuilder m, Rep m ~ rep) =>
    Pat (LetDec rep) ->
    Exp rep ->
    m (ExpDec rep)
  -- | Construct a body from statements and a result.
  mkBodyB ::
    (MonadBuilder m, Rep m ~ rep) =>
    Stms rep ->
    Result ->
    m (Body rep)
  -- | Construct a statement binding the given names to an expression.
  mkLetNamesB ::
    (MonadBuilder m, Rep m ~ rep) =>
    [VName] ->
    Exp rep ->
    m (Stm rep)

  -- The defaults delegate to the pure 'Buildable' constructors.
  default mkExpDecB ::
    (MonadBuilder m, Buildable rep) =>
    Pat (LetDec rep) ->
    Exp rep ->
    m (ExpDec rep)
  mkExpDecB pat e = pure $ mkExpDec pat e

  default mkBodyB ::
    (MonadBuilder m, Buildable rep) =>
    Stms rep ->
    Result ->
    m (Body rep)
  mkBodyB stms res = pure $ mkBody stms res

  default mkLetNamesB ::
    (MonadBuilder m, Rep m ~ rep, Buildable rep) =>
    [VName] ->
    Exp rep ->
    m (Stm rep)
  mkLetNamesB = mkLetNames

-- | A monad transformer that tracks statements and provides a
-- 'MonadBuilder' instance, assuming that the underlying monad provides
-- a name source.  In almost all cases, this is what you will use for
-- constructing statements (possibly as part of a larger monad stack).
-- If you find yourself needing to implement 'MonadBuilder' from
-- scratch, then it is likely that you are making a mistake.
--
-- The state holds the statements added so far and the current scope.
newtype BuilderT rep m a = BuilderT (StateT (Stms rep, Scope rep) m a)
  deriving (Functor, Monad, Applicative)

instance MonadTrans (BuilderT rep) where
  lift = BuilderT . lift

-- | The most commonly used builder monad.
type Builder rep = BuilderT rep (State VNameSource) instance (MonadFreshNames m) => MonadFreshNames (BuilderT rep m) where getNameSource = lift getNameSource putNameSource = lift . putNameSource instance (ASTRep rep, Monad m) => HasScope rep (BuilderT rep m) where lookupType name = do t <- BuilderT $ gets $ M.lookup name . snd case t of Nothing -> do known <- BuilderT $ gets $ M.keys . snd error . unlines $ [ "BuilderT.lookupType: unknown variable " ++ prettyString name, "Known variables: ", unwords $ map prettyString known ] Just t' -> pure $ typeOf t' askScope = BuilderT $ gets snd instance (ASTRep rep, Monad m) => LocalScope rep (BuilderT rep m) where localScope types (BuilderT m) = BuilderT $ do modify $ second (M.union types) x <- m modify $ second (`M.difference` types) pure x instance (MonadFreshNames m, BuilderOps rep) => MonadBuilder (BuilderT rep m) where type Rep (BuilderT rep m) = rep mkExpDecM = mkExpDecB mkBodyM = mkBodyB mkLetNamesM = mkLetNamesB addStms stms = BuilderT $ modify $ \(cur_stms, scope) -> (cur_stms <> stms, scope `M.union` scopeOf stms) collectStms m = do (old_stms, old_scope) <- BuilderT get BuilderT $ put (mempty, old_scope) x <- m (new_stms, _) <- BuilderT get BuilderT $ put (old_stms, old_scope) pure (x, new_stms) -- | Run a builder action given an initial scope, returning a value and -- the statements added ('addStm') during the action. runBuilderT :: (MonadFreshNames m) => BuilderT rep m a -> Scope rep -> m (a, Stms rep) runBuilderT (BuilderT m) scope = do (x, (stms, _)) <- runStateT m (mempty, scope) pure (x, stms) -- | Like 'runBuilderT', but return only the statements. runBuilderT_ :: (MonadFreshNames m) => BuilderT rep m () -> Scope rep -> m (Stms rep) runBuilderT_ m = fmap snd . runBuilderT m -- | Like 'runBuilderT', but get the initial scope from the current -- monad. 
runBuilderT' :: (MonadFreshNames m, HasScope somerep m, SameScope somerep rep) => BuilderT rep m a -> m (a, Stms rep) runBuilderT' m = do scope <- askScope runBuilderT m $ castScope scope -- | Like 'runBuilderT_', but get the initial scope from the current -- monad. runBuilderT'_ :: (MonadFreshNames m, HasScope somerep m, SameScope somerep rep) => BuilderT rep m a -> m (Stms rep) runBuilderT'_ = fmap snd . runBuilderT' -- | Run a builder action, returning a value and the statements added -- ('addStm') during the action. Assumes that the current monad -- provides initial scope and name source. runBuilder :: (MonadFreshNames m, HasScope somerep m, SameScope somerep rep) => Builder rep a -> m (a, Stms rep) runBuilder m = do types <- askScope modifyNameSource $ runState $ runBuilderT m $ castScope types -- | Like 'runBuilder', but throw away the result and just return the -- added statements. runBuilder_ :: (MonadFreshNames m, HasScope somerep m, SameScope somerep rep) => Builder rep a -> m (Stms rep) runBuilder_ = fmap snd . runBuilder -- | Run a builder that produces a 'Result' and construct a body that -- contains that result alongside the statements produced during the -- builder. runBodyBuilder :: ( Buildable rep, MonadFreshNames m, HasScope somerep m, SameScope somerep rep ) => Builder rep Result -> m (Body rep) runBodyBuilder = fmap (uncurry $ flip insertStms) . runBuilder . fmap (mkBody mempty) -- | Given lambda parameters, Run a builder action that produces the -- statements and returns the 'Result' of the lambda body. runLambdaBuilder :: ( Buildable rep, MonadFreshNames m, HasScope somerep m, SameScope somerep rep ) => [LParam rep] -> Builder rep Result -> m (Lambda rep) runLambdaBuilder params m = do ((res, ret), stms) <- runBuilder . localScope (scopeOfLParams params) $ do res <- m ret <- mapM subExpResType res pure (res, ret) pure $ Lambda params ret $ mkBody stms res -- Utility instance defintions for MTL classes. 
These require -- UndecidableInstances, but save on typing elsewhere. mapInner :: (Monad m) => ( m (a, (Stms rep, Scope rep)) -> m (b, (Stms rep, Scope rep)) ) -> BuilderT rep m a -> BuilderT rep m b mapInner f (BuilderT m) = BuilderT $ do s <- get (x, s') <- lift $ f $ runStateT m s put s' pure x instance (MonadReader r m) => MonadReader r (BuilderT rep m) where ask = BuilderT $ lift ask local f = mapInner $ local f instance (MonadState s m) => MonadState s (BuilderT rep m) where get = BuilderT $ lift get put = BuilderT . lift . put instance (MonadWriter w m) => MonadWriter w (BuilderT rep m) where tell = BuilderT . lift . tell pass = mapInner $ \m -> pass $ do ((x, f), s) <- m pure ((x, s), f) listen = mapInner $ \m -> do ((x, s), y) <- listen m pure ((x, y), s) instance (MonadError e m) => MonadError e (BuilderT rep m) where throwError = lift . throwError catchError (BuilderT m) f = BuilderT $ catchError m $ unBuilder . f where unBuilder (BuilderT m') = m' futhark-0.25.27/src/Futhark/Builder/000077500000000000000000000000001475065116200171235ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Builder/Class.hs000066400000000000000000000124541475065116200205320ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | This module defines a convenience typeclass for creating -- normalised programs. -- -- See "Futhark.Construct" for a high-level description. module Futhark.Builder.Class ( Buildable (..), mkLet, mkLet', MonadBuilder (..), insertStms, insertStm, letBind, letBindNames, collectStms_, bodyBind, attributing, auxing, module Futhark.MonadFreshNames, ) where import Data.Kind qualified import Futhark.IR import Futhark.MonadFreshNames -- | The class of representations that can be constructed solely from -- an expression, within some monad. Very important: the methods -- should not have any significant side effects! They may be called -- more often than you think, and the results thrown away. 
If used -- exclusively within a 'MonadBuilder' instance, it is acceptable for -- them to create new bindings, however. class ( ASTRep rep, FParamInfo rep ~ DeclType, LParamInfo rep ~ Type, RetType rep ~ DeclExtType, BranchType rep ~ ExtType ) => Buildable rep where mkExpPat :: [Ident] -> Exp rep -> Pat (LetDec rep) mkExpDec :: Pat (LetDec rep) -> Exp rep -> ExpDec rep mkBody :: Stms rep -> Result -> Body rep mkLetNames :: (MonadFreshNames m, HasScope rep m) => [VName] -> Exp rep -> m (Stm rep) -- | A monad that supports the creation of bindings from expressions -- and bodies from bindings, with a specific rep. This is the main -- typeclass that a monad must implement in order for it to be useful -- for generating or modifying Futhark code. Most importantly -- maintains a current state of 'Stms' (as well as a 'Scope') that -- have been added with 'addStm'. -- -- Very important: the methods should not have any significant side -- effects! They may be called more often than you think, and the -- results thrown away. It is acceptable for them to create new -- bindings, however. class ( ASTRep (Rep m), MonadFreshNames m, Applicative m, Monad m, LocalScope (Rep m) m ) => MonadBuilder m where type Rep m :: Data.Kind.Type mkExpDecM :: Pat (LetDec (Rep m)) -> Exp (Rep m) -> m (ExpDec (Rep m)) mkBodyM :: Stms (Rep m) -> Result -> m (Body (Rep m)) mkLetNamesM :: [VName] -> Exp (Rep m) -> m (Stm (Rep m)) -- | Add a statement to the 'Stms' under construction. addStm :: Stm (Rep m) -> m () addStm = addStms . oneStm -- | Add multiple statements to the 'Stms' under construction. addStms :: Stms (Rep m) -> m () -- | Obtain the statements constructed during a monadic action, -- instead of adding them to the state. collectStms :: m a -> m (a, Stms (Rep m)) -- | Add the provided certificates to any statements added during -- execution of the action. certifying :: Certs -> m a -> m a certifying = censorStms . fmap . certify -- | Apply a function to the statements added by this action. 
censorStms :: (MonadBuilder m) => (Stms (Rep m) -> Stms (Rep m)) -> m a -> m a censorStms f m = do (x, stms) <- collectStms m addStms $ f stms pure x -- | Add the given attributes to any statements added by this action. attributing :: (MonadBuilder m) => Attrs -> m a -> m a attributing attrs = censorStms $ fmap onStm where onStm (Let pat aux e) = Let pat aux {stmAuxAttrs = attrs <> stmAuxAttrs aux} e -- | Add the certificates and attributes to any statements added by -- this action. auxing :: (MonadBuilder m) => StmAux anyrep -> m a -> m a auxing (StmAux cs attrs _) = censorStms $ fmap onStm where onStm (Let pat aux e) = Let pat aux' e where aux' = aux { stmAuxAttrs = attrs <> stmAuxAttrs aux, stmAuxCerts = cs <> stmAuxCerts aux } -- | Add a statement with the given pattern and expression. letBind :: (MonadBuilder m) => Pat (LetDec (Rep m)) -> Exp (Rep m) -> m () letBind pat e = addStm =<< Let pat <$> (defAux <$> mkExpDecM pat e) <*> pure e -- | Construct a 'Stm' from identifiers for the context- and value -- part of the pattern, as well as the expression. mkLet :: (Buildable rep) => [Ident] -> Exp rep -> Stm rep mkLet ids e = let pat = mkExpPat ids e dec = mkExpDec pat e in Let pat (defAux dec) e -- | Like mkLet, but also take attributes and certificates from the -- given 'StmAux'. mkLet' :: (Buildable rep) => [Ident] -> StmAux a -> Exp rep -> Stm rep mkLet' ids (StmAux cs attrs _) e = let pat = mkExpPat ids e dec = mkExpDec pat e in Let pat (StmAux cs attrs dec) e -- | Add a statement with the given pattern element names and -- expression. letBindNames :: (MonadBuilder m) => [VName] -> Exp (Rep m) -> m () letBindNames names e = addStm =<< mkLetNamesM names e -- | As 'collectStms', but throw away the ordinary result. collectStms_ :: (MonadBuilder m) => m a -> m (Stms (Rep m)) collectStms_ = fmap snd . collectStms -- | Add the statements of the body, then return the body result. 
bodyBind :: (MonadBuilder m) => Body (Rep m) -> m Result bodyBind (Body _ stms res) = do addStms stms pure res -- | Add several bindings at the outermost level of a t'Body'. insertStms :: (Buildable rep) => Stms rep -> Body rep -> Body rep insertStms stms1 (Body _ stms2 res) = mkBody (stms1 <> stms2) res -- | Add a single binding at the outermost level of a t'Body'. insertStm :: (Buildable rep) => Stm rep -> Body rep -> Body rep insertStm = insertStms . oneStm futhark-0.25.27/src/Futhark/CLI/000077500000000000000000000000001475065116200161445ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/CLI/Autotune.hs000066400000000000000000000446321475065116200203150ustar00rootroot00000000000000-- | @futhark autotune@ module Futhark.CLI.Autotune (main) where import Control.Monad import Data.ByteString.Char8 qualified as SBS import Data.Function (on) import Data.List (elemIndex, intersect, minimumBy, sort, sortOn) import Data.Map qualified as M import Data.Maybe import Data.Set qualified as S import Data.Text qualified as T import Data.Text.IO qualified as T import Data.Tree import Futhark.Bench import Futhark.Server import Futhark.Test import Futhark.Util (maxinum, showText) import Futhark.Util.Options import System.Directory import System.Environment (getExecutablePath) import System.Exit import System.FilePath import Text.Read (readMaybe) import Text.Regex.TDFA data AutotuneOptions = AutotuneOptions { optBackend :: String, optFuthark :: Maybe String, optMinRuns :: Int, optTuning :: Maybe String, optExtraOptions :: [String], optVerbose :: Int, optTimeout :: Int, optSkipCompilation :: Bool, optDefaultThreshold :: Int, optTestSpec :: Maybe FilePath } initialAutotuneOptions :: AutotuneOptions initialAutotuneOptions = AutotuneOptions { optBackend = "opencl", optFuthark = Nothing, optMinRuns = 10, optTuning = Just "tuning", optExtraOptions = [], optVerbose = 0, optTimeout = 600, optSkipCompilation = False, optDefaultThreshold = thresholdMax, optTestSpec = Nothing } 
compileOptions :: AutotuneOptions -> IO CompileOptions compileOptions opts = do futhark <- maybe getExecutablePath pure $ optFuthark opts pure $ CompileOptions { compFuthark = futhark, compBackend = optBackend opts, compOptions = mempty } runOptions :: Int -> AutotuneOptions -> RunOptions runOptions timeout_s opts = RunOptions { runMinRuns = optMinRuns opts, runMinTime = 0.5, runTimeout = timeout_s, runVerbose = optVerbose opts, runConvergencePhase = True, runConvergenceMaxTime = fromIntegral timeout_s, runResultAction = const $ pure (), runProfile = False } type Path = [(T.Text, Int)] regexBlocks :: Regex -> T.Text -> Maybe [T.Text] regexBlocks regex s = do (_, _, _, groups) <- matchM regex s :: Maybe (T.Text, T.Text, T.Text, [T.Text]) Just groups comparisons :: T.Text -> [(T.Text, Int)] comparisons = mapMaybe isComparison . T.lines where regex = makeRegex ("Compared ([^ ]+) <= (-?[0-9]+)" :: String) isComparison l = do [thresh, val] <- regexBlocks regex l val' <- readMaybe $ T.unpack val pure (thresh, val') type RunDataset = Server -> Int -> Path -> IO (Either String ([(T.Text, Int)], Int)) type DatasetName = T.Text serverOptions :: AutotuneOptions -> [String] serverOptions opts = "--default-threshold" : show (optDefaultThreshold opts) : "-L" : optExtraOptions opts checkCmd :: Either CmdFailure a -> IO a checkCmd = either (error . T.unpack . T.unlines . 
failureMsg) pure setTuningParam :: Server -> T.Text -> Int -> IO () setTuningParam server name val = void $ checkCmd =<< cmdSetTuningParam server name (showText val) setTuningParams :: Server -> Path -> IO () setTuningParams server = mapM_ (uncurry $ setTuningParam server) restoreTuningParams :: AutotuneOptions -> Server -> Path -> IO () restoreTuningParams opts server = mapM_ opt where opt (name, _) = setTuningParam server name (optDefaultThreshold opts) prepare :: AutotuneOptions -> FutharkExe -> FilePath -> IO [(DatasetName, RunDataset, T.Text)] prepare opts futhark prog = do spec <- maybe (testSpecFromProgramOrDie prog) testSpecFromFileOrDie $ optTestSpec opts copts <- compileOptions opts truns <- case testAction spec of RunCases ios _ _ | not $ null ios -> do when (optVerbose opts > 1) $ putStrLn $ unwords ("Entry points:" : map (T.unpack . iosEntryPoint) ios) if optSkipCompilation opts then do exists <- doesFileExist $ binaryName prog if exists then pure ios else do putStrLn $ binaryName prog ++ " does not exist, but --skip-compilation passed." exitFailure else do res <- prepareBenchmarkProgram Nothing copts prog ios case res of Left (err, errstr) -> do putStrLn err maybe (pure ()) SBS.putStrLn errstr exitFailure Right () -> pure ios _ -> fail "Unsupported test spec." let runnableDataset entry_point trun = case runExpectedResult trun of Succeeds expected | null (runTags trun `intersect` ["notune", "disable"]) -> Just ( runDescription trun, \server -> run server entry_point trun expected ) _ -> Nothing fmap concat . 
forM truns $ \ios -> do let cases = mapMaybe (runnableDataset $ iosEntryPoint ios) (iosTestRuns ios) forM cases $ \(dataset, do_run) -> pure (dataset, do_run, iosEntryPoint ios) where run server entry_point trun expected timeout path = do let bestRuntime (runres, errout, _) = ( comparisons errout, minimum $ map runMicroseconds runres ) ropts = runOptions timeout opts when (optVerbose opts > 1) $ putStrLn ("Trying path: " ++ show path) -- Setting the tuning parameters is a stateful action, so we -- must be careful to restore the defaults below. This is -- because we rely on parameters not in 'path' to have their -- default value. setTuningParams server path either (Left . T.unpack) (Right . bestRuntime) <$> benchmarkDataset server ropts futhark prog entry_point (runInput trun) expected (testRunReferenceOutput prog entry_point trun) <* restoreTuningParams opts server path --- Benchmarking a program data DatasetResult = DatasetResult [(T.Text, Int)] Double deriving (Show) --- Finding initial comparisons. --- Extracting threshold hierarchy. type ThresholdForest = Forest (T.Text, Bool) thresholdMin, thresholdMax :: Int thresholdMin = 1 thresholdMax = 2000000000 -- | Depth-first list of thresholds to tune in order, and a -- corresponding assignment of ancestor thresholds to ensure that they -- are used. 
tuningPaths :: ThresholdForest -> [(T.Text, Path)] tuningPaths = concatMap (treePaths []) where treePaths ancestors (Node (v, _) children) = concatMap (onChild ancestors v) children ++ [(v, ancestors)] onChild ancestors v child@(Node (_, cmp) _) = treePaths (ancestors ++ [(v, t cmp)]) child t False = thresholdMax t True = thresholdMin allTuningParams :: Server -> IO [(T.Text, T.Text)] allTuningParams server = do entry_points <- checkCmd =<< cmdEntryPoints server param_names <- concat <$> mapM (checkCmd <=< cmdTuningParams server) entry_points param_classes <- mapM (checkCmd <=< cmdTuningParamClass server) param_names pure $ zip param_names param_classes thresholdForest :: Server -> IO ThresholdForest thresholdForest server = do thresholds <- mapMaybe findThreshold <$> allTuningParams server let root (v, _) = ((v, False), []) pure $ unfoldForest (unfold thresholds) $ map root $ filter (null . snd) thresholds where regex = makeRegex ("threshold\\(([^ ]+,)(.*)\\)" :: T.Text) findThreshold :: (T.Text, T.Text) -> Maybe (T.Text, [(T.Text, Bool)]) findThreshold (name, param_class) = do [_, grp] <- regexBlocks regex param_class pure ( name, filter (not . T.null . fst) $ map ( \x -> if "!" `T.isPrefixOf` x then (T.drop 1 x, False) else (x, True) ) $ T.words grp ) unfold thresholds ((parent, parent_cmp), ancestors) = let ancestors' = parent : ancestors isChild (v, v_ancestors) = do cmp <- lookup parent v_ancestors guard $ sort (map fst v_ancestors) == sort (parent : ancestors) pure ((v, cmp), ancestors') in ((parent, parent_cmp), mapMaybe isChild thresholds) -- | The performance difference in percentage that triggers a non-monotonicity -- warning. This is to account for slight variantions in run-time. 
epsilon :: Double epsilon = 1.02 --- Doing the atual tuning tuneThreshold :: AutotuneOptions -> Server -> [(DatasetName, RunDataset, T.Text)] -> (Path, M.Map DatasetName Int) -> (T.Text, Path) -> IO (Path, M.Map DatasetName Int) tuneThreshold opts server datasets (already_tuned, best_runtimes0) (v, _v_path) = do (tune_result, best_runtimes) <- foldM tuneDataset (Nothing, best_runtimes0) datasets case tune_result of Nothing -> pure ((v, thresholdMin) : already_tuned, best_runtimes) Just (_, threshold) -> pure ((v, threshold) : already_tuned, best_runtimes) where tuneDataset :: (Maybe (Int, Int), M.Map DatasetName Int) -> (DatasetName, RunDataset, T.Text) -> IO (Maybe (Int, Int), M.Map DatasetName Int) tuneDataset (thresholds, best_runtimes) (dataset_name, run, entry_point) = do relevant <- checkCmd =<< cmdTuningParams server entry_point if v `notElem` relevant then do when (optVerbose opts > 0) $ T.putStrLn $ T.unwords [v, "is irrelevant for", entry_point] pure (thresholds, best_runtimes) else do T.putStrLn $ T.unwords [ "Tuning", v, "on entry point", entry_point, "and dataset", dataset_name ] sample_run <- run server (optTimeout opts) ((v, maybe thresholdMax snd thresholds) : already_tuned) case sample_run of Left err -> do -- If the sampling run fails, we treat it as zero information. -- One of our ancestor thresholds will have be set such that -- this path is never taken. 
when (optVerbose opts > 0) $ putStrLn $ "Sampling run failed:\n" ++ err pure (thresholds, best_runtimes) Right (cmps, t) -> do let (tMin, tMax) = fromMaybe (thresholdMin, thresholdMax) thresholds let ePars = S.toAscList $ S.map snd $ S.filter (candidateEPar (tMin, tMax)) $ S.fromList cmps runner :: Int -> Int -> IO (Maybe Int) runner timeout' threshold = do res <- run server timeout' ((v, threshold) : already_tuned) case res of Right (_, runTime) -> pure $ Just runTime _ -> pure Nothing when (optVerbose opts > 1) $ putStrLn $ unwords ("Got ePars: " : map show ePars) (best_t, newMax) <- binarySearch runner (t, tMax) ePars let newMinIdx = do i <- pred <$> elemIndex newMax ePars if i < 0 then fail "Invalid lower index" else pure i let newMin = maxinum $ catMaybes [Just tMin, fmap (ePars !!) newMinIdx] best_runtimes' <- case dataset_name `M.lookup` best_runtimes of Just rt | fromIntegral rt * epsilon < fromIntegral best_t -> do T.putStrLn $ T.unwords [ "WARNING! Possible non-monotonicity detected. Previous best run-time for dataset", dataset_name, " was", showText rt, "but after tuning threshold", v, "it is", showText best_t ] pure best_runtimes _ -> pure $ M.insertWith min dataset_name best_t best_runtimes pure (Just (newMin, newMax), best_runtimes') bestPair :: [(Int, Int)] -> (Int, Int) bestPair = minimumBy (compare `on` fst) timeout :: Int -> Int -- We wish to let datasets run for the untuned time + 20% + 1 second. 
timeout elapsed = ceiling (fromIntegral elapsed * 1.2 :: Double) + 1 candidateEPar :: (Int, Int) -> (T.Text, Int) -> Bool candidateEPar (tMin, tMax) (threshold, ePar) = ePar > tMin && ePar < tMax && threshold == v binarySearch :: (Int -> Int -> IO (Maybe Int)) -> (Int, Int) -> [Int] -> IO (Int, Int) binarySearch runner best@(best_t, best_e_par) xs = case splitAt (length xs `div` 2) xs of (lower, middle : middle' : upper) -> do when (optVerbose opts > 0) $ putStrLn $ unwords [ "Trying e_par", show middle, "and", show middle' ] candidate <- runner (timeout best_t) middle candidate' <- runner (timeout best_t) middle' case (candidate, candidate') of (Just new_t, Just new_t') -> if new_t < new_t' then -- recurse into lower half binarySearch runner (bestPair [(new_t, middle), best]) lower else -- recurse into upper half binarySearch runner (bestPair [(new_t', middle'), best]) upper (Just new_t, Nothing) -> -- recurse into lower half binarySearch runner (bestPair [(new_t, middle), best]) lower (Nothing, Just new_t') -> -- recurse into upper half binarySearch runner (bestPair [(new_t', middle'), best]) upper (Nothing, Nothing) -> do when (optVerbose opts > 2) $ putStrLn $ unwords [ "Timing failed for candidates", show middle, "and", show middle' ] pure (best_t, best_e_par) (_, _) -> do when (optVerbose opts > 0) $ putStrLn $ unwords ["Trying e_pars", show xs] candidates <- catMaybes . zipWith (fmap . flip (,)) xs <$> mapM (runner $ timeout best_t) xs pure $ bestPair $ best : candidates --- CLI tune :: AutotuneOptions -> FilePath -> IO Path tune opts prog = do futhark <- fmap FutharkExe $ maybe getExecutablePath pure $ optFuthark opts putStrLn $ "Compiling " ++ prog ++ "..." datasets <- prepare opts futhark prog putStrLn $ "Running with options: " ++ unwords (serverOptions opts) let progbin = "." 
dropExtension prog withServer (futharkServerCfg progbin (serverOptions opts)) $ \server -> do forest <- thresholdForest server when (optVerbose opts > 0) $ putStrLn $ ("Threshold forest:\n" <>) $ drawForest (map (fmap show) forest) fmap fst . foldM (tuneThreshold opts server datasets) ([], mempty) $ tuningPaths forest runAutotuner :: AutotuneOptions -> FilePath -> IO () runAutotuner opts prog = do best <- tune opts prog let tuning = T.unlines $ do (s, n) <- sortOn fst best pure $ s <> "=" <> showText n case optTuning opts of Nothing -> pure () Just suffix -> do T.writeFile (prog <.> suffix) tuning putStrLn $ "Wrote " ++ prog <.> suffix T.putStrLn $ "Result of autotuning:\n" <> tuning supportedBackends :: [String] supportedBackends = ["opencl", "cuda", "hip"] commandLineOptions :: [FunOptDescr AutotuneOptions] commandLineOptions = [ Option "r" ["runs"] ( ReqArg ( \n -> case reads n of [(n', "")] | n' >= 0 -> Right $ \config -> config {optMinRuns = n'} _ -> Left $ optionsError $ "'" ++ n ++ "' is not a non-negative integer." 
) "RUNS" ) "Run each test case this many times.", Option [] ["backend"] ( ReqArg ( \backend -> if backend `elem` supportedBackends then Right $ \config -> config {optBackend = backend} else Left $ optionsError $ "autotuning is only supported for these backends: " <> unwords supportedBackends ) "BACKEND" ) "The backend used (defaults to 'opencl').", Option [] ["futhark"] ( ReqArg (\prog -> Right $ \config -> config {optFuthark = Just prog}) "PROGRAM" ) "The binary used for operations (defaults to 'futhark').", Option [] ["pass-option"] ( ReqArg ( \opt -> Right $ \config -> config {optExtraOptions = opt : optExtraOptions config} ) "OPT" ) "Pass this option to programs being run.", Option [] ["tuning"] ( ReqArg (\s -> Right $ \config -> config {optTuning = Just s}) "EXTENSION" ) "Write tuning files with this extension (default: .tuning).", Option [] ["timeout"] ( ReqArg ( \n -> case reads n of [(n', "")] -> Right $ \config -> config {optTimeout = n'} _ -> Left $ optionsError $ "'" ++ n ++ "' is not a non-negative integer." ) "SECONDS" ) "Initial tuning timeout for each dataset. Later tuning runs are based off of the runtime of the first run.", Option [] ["skip-compilation"] (NoArg $ Right $ \config -> config {optSkipCompilation = True}) "Use already compiled program.", Option "v" ["verbose"] (NoArg $ Right $ \config -> config {optVerbose = optVerbose config + 1}) "Enable logging. Pass multiple times for more.", Option [] ["spec-file"] (ReqArg (\s -> Right $ \config -> config {optTestSpec = Just s}) "FILE") "Use test specification from this file." ] -- | Run @futhark autotune@ main :: String -> [String] -> IO () main = mainWithOptions initialAutotuneOptions commandLineOptions "options... 
program" $ \progs config -> case progs of [prog] -> Just $ runAutotuner config prog _ -> Nothing futhark-0.25.27/src/Futhark/CLI/Bench.hs000066400000000000000000000517061475065116200175300ustar00rootroot00000000000000-- | @futhark bench@ module Futhark.CLI.Bench (main) where import Control.Exception import Control.Monad import Control.Monad.IO.Class (liftIO) import Data.Bifunctor (first) import Data.ByteString.Char8 qualified as SBS import Data.ByteString.Lazy.Char8 qualified as LBS import Data.Either import Data.Function ((&)) import Data.IORef import Data.List (intersect, sortBy) import Data.Map qualified as M import Data.Maybe import Data.Ord import Data.Text qualified as T import Data.Text.IO qualified as T import Data.Time.Clock (NominalDiffTime, UTCTime, diffUTCTime, getCurrentTime) import Data.Vector.Unboxed qualified as U import Futhark.Bench import Futhark.Server import Futhark.Test import Futhark.Util (atMostChars, fancyTerminal, pmapIO, showText) import Futhark.Util.Options import Futhark.Util.Pretty (AnsiStyle, Color (..), annotate, bold, color, line, pretty, prettyText, putDoc) import Futhark.Util.ProgressBar import Statistics.Resampling (Estimator (..), resample) import Statistics.Resampling.Bootstrap (bootstrapBCA) import Statistics.Types (cl95, confIntLDX, confIntUDX, estError, estPoint) import System.Console.ANSI (clearLine) import System.Directory import System.Environment import System.Exit import System.FilePath import System.IO import System.Random.MWC (create) import Text.Printf import Text.Regex.TDFA putStyleLn :: AnsiStyle -> T.Text -> IO () putStyleLn s t = putDoc $ annotate s (pretty t <> line) putRedLn, putBoldRedLn, putBoldLn :: T.Text -> IO () putRedLn = putStyleLn (color Red) putBoldRedLn = putStyleLn (color Red <> bold) putBoldLn = putStyleLn bold data BenchOptions = BenchOptions { optBackend :: String, optFuthark :: Maybe String, optRunner :: String, optMinRuns :: Int, optMinTime :: NominalDiffTime, optExtraOptions :: [String], 
optCompilerOptions :: [String], optJSON :: Maybe FilePath, optTimeout :: Int, optSkipCompilation :: Bool, optExcludeCase :: [T.Text], optIgnoreFiles :: [Regex], optEntryPoint :: Maybe String, optTuning :: Maybe String, optCacheExt :: Maybe String, optConvergencePhase :: Bool, optConvergenceMaxTime :: NominalDiffTime, optConcurrency :: Maybe Int, optProfile :: Bool, optVerbose :: Int, optTestSpec :: Maybe FilePath } initialBenchOptions :: BenchOptions initialBenchOptions = BenchOptions { optBackend = "c", optFuthark = Nothing, optRunner = "", optMinRuns = 10, optMinTime = 0.5, optExtraOptions = [], optCompilerOptions = [], optJSON = Nothing, optTimeout = -1, optSkipCompilation = False, optExcludeCase = ["nobench", "disable"], optIgnoreFiles = [], optEntryPoint = Nothing, optTuning = Just "tuning", optCacheExt = Nothing, optConvergencePhase = True, optConvergenceMaxTime = 5 * 60, optConcurrency = Nothing, optProfile = False, optVerbose = 0, optTestSpec = Nothing } runBenchmarks :: BenchOptions -> [FilePath] -> IO () runBenchmarks opts paths = do -- We force line buffering to ensure that we produce running output. -- Otherwise, CI tools and the like may believe we are hung and kill -- us. hSetBuffering stdout LineBuffering benchmarks <- filter (not . ignored . fst) <$> testSpecsFromPathsOrDie paths -- Try to avoid concurrency at both program and data set level. let opts' = if length paths /= 1 then opts {optConcurrency = Just 1} else opts (skipped_benchmarks, compiled_benchmarks) <- partitionEithers <$> pmapIO (optConcurrency opts) (compileBenchmark opts') benchmarks when (anyFailedToCompile skipped_benchmarks) exitFailure putStrLn $ "Reporting arithmetic mean runtime of at least " <> show (optMinRuns opts) <> " runs for each dataset (min " <> show (optMinTime opts) <> ")." when (optConvergencePhase opts) . putStrLn $ "More runs automatically performed for up to " <> show (optConvergenceMaxTime opts) <> " to ensure accurate measurement." futhark <- FutharkExe . 
compFuthark <$> compileOptions opts maybe_results <- mapM (runBenchmark opts futhark) (sortBy (comparing fst) compiled_benchmarks) let results = concat $ catMaybes maybe_results case optJSON opts of Nothing -> pure () Just file -> LBS.writeFile file $ encodeBenchResults results when (any isNothing maybe_results || anyFailed results) exitFailure where ignored f = any (`match` f) $ optIgnoreFiles opts anyFailed :: [BenchResult] -> Bool anyFailed = any failedBenchResult where failedBenchResult (BenchResult _ xs) = any failedResult xs failedResult (DataResult _ Left {}) = True failedResult _ = False anyFailedToCompile :: [SkipReason] -> Bool anyFailedToCompile = not . all (== Skipped) data SkipReason = Skipped | FailedToCompile deriving (Eq) compileOptions :: BenchOptions -> IO CompileOptions compileOptions opts = do futhark <- maybe getExecutablePath pure $ optFuthark opts pure $ CompileOptions { compFuthark = futhark, compBackend = optBackend opts, compOptions = optCompilerOptions opts } compileBenchmark :: BenchOptions -> (FilePath, ProgramTest) -> IO (Either SkipReason (FilePath, [InputOutputs])) compileBenchmark opts (program, program_spec) = do spec <- maybe (pure program_spec) testSpecFromFileOrDie $ optTestSpec opts case testAction spec of RunCases cases _ _ | null $ optExcludeCase opts `intersect` testTags spec <> testTags program_spec, any hasRuns cases -> if optSkipCompilation opts then do exists <- doesFileExist $ binaryName program if exists then pure $ Right (program, cases) else do putStrLn $ binaryName program ++ " does not exist, but --skip-compilation passed." 
pure $ Left FailedToCompile else do putStr $ "Compiling " ++ program ++ "...\n" compile_opts <- compileOptions opts res <- prepareBenchmarkProgram (optConcurrency opts) compile_opts program cases case res of Left (err, errstr) -> do putRedLn $ T.pack err maybe (pure ()) SBS.putStrLn errstr pure $ Left FailedToCompile Right () -> pure $ Right (program, cases) _ -> pure $ Left Skipped where hasRuns (InputOutputs _ runs) = not $ null runs withProgramServer :: FilePath -> FilePath -> [String] -> (Server -> IO a) -> IO (Maybe a) withProgramServer program runner extra_options f = do -- Explicitly prefixing the current directory is necessary for -- readProcessWithExitCode to find the binary when binOutputf has -- no path component. let binOutputf = dropExtension program binpath = "." binOutputf (to_run, to_run_args) | null runner = (binpath, extra_options) | otherwise = (runner, binpath : extra_options) liftIO $ (Just <$> withServer (futharkServerCfg to_run to_run_args) f) `catch` onError where onError :: SomeException -> IO (Maybe a) onError e = do putBoldRedLn $ "\nFailed to run " <> T.pack program putRedLn $ showText e pure Nothing -- Truncate dataset name display after this many characters. maxDatasetNameLength :: Int maxDatasetNameLength = 40 runBenchmark :: BenchOptions -> FutharkExe -> (FilePath, [InputOutputs]) -> IO (Maybe [BenchResult]) runBenchmark opts futhark (program, cases) = do (tuning_opts, tuning_desc) <- determineTuning (optTuning opts) program let runopts = optExtraOptions opts ++ tuning_opts ++ determineCache (optCacheExt opts) program ++ if optProfile opts then ["--profile", "--log"] else [] withProgramServer program (optRunner opts) runopts $ \server -> mapM (forInputOutputs server tuning_desc) $ filter relevant cases where forInputOutputs server tuning_desc (InputOutputs entry_name runs) = do putBoldLn $ "\n" <> T.pack program' <> T.pack tuning_desc <> ":" BenchResult program' . 
catMaybes <$> mapM (runBenchmarkCase server opts futhark program entry_name pad_to) runs where program' = if entry_name == "main" then program else program ++ ":" ++ T.unpack entry_name relevant = maybe (const True) (==) (optEntryPoint opts) . T.unpack . iosEntryPoint len = T.length . atMostChars maxDatasetNameLength . runDescription pad_to = foldl max 0 $ concatMap (map len . iosTestRuns) cases runOptions :: ((Int, Maybe Double) -> IO ()) -> BenchOptions -> RunOptions runOptions f opts = RunOptions { runMinRuns = optMinRuns opts, runMinTime = optMinTime opts, runTimeout = optTimeout opts, runVerbose = optVerbose opts, runConvergencePhase = optConvergencePhase opts, runConvergenceMaxTime = optConvergenceMaxTime opts, runResultAction = f, runProfile = optProfile opts } descText :: T.Text -> Int -> T.Text descText desc pad_to = desc <> ": " <> T.replicate (pad_to - T.length desc) " " progress :: Double -> T.Text progress elapsed = progressBar ( ProgressBar { progressBarSteps = 10, progressBarBound = 1, progressBarElapsed = elapsed } ) interimResult :: Int -> Int -> Double -> T.Text interimResult us_sum runs elapsed = T.pack (printf "%10.0fμs " avg) <> progress elapsed <> (" " <> prettyText runs <> " runs") where avg :: Double avg = fromIntegral us_sum / fromIntegral runs convergenceBar :: (T.Text -> IO ()) -> IORef Int -> Int -> Int -> Double -> IO () convergenceBar p spin_count us_sum i rse' = do spin_idx <- readIORef spin_count let spin = progressSpinner spin_idx p $ T.pack $ printf "%10.0fμs %s (RSE of mean: %2.4f; %4d runs)" avg spin rse' i writeIORef spin_count (spin_idx + 1) where avg :: Double avg = fromIntegral us_sum / fromIntegral i data BenchPhase = Initial | Convergence mkProgressPrompt :: BenchOptions -> Int -> T.Text -> UTCTime -> IO ((Maybe Int, Maybe Double) -> IO ()) mkProgressPrompt opts pad_to dataset_desc start_time | fancyTerminal = do count <- newIORef (0, 0) phase_var <- newIORef Initial spin_count <- newIORef 0 pure $ \(us, rse) -> do T.putStr 
"\r" -- Go to start of line. let p s = T.putStr $ descText (atMostChars maxDatasetNameLength dataset_desc) pad_to <> s (us_sum, i) <- readIORef count now <- liftIO getCurrentTime let determineProgress i' = let time_elapsed = toDouble (realToFrac (diffUTCTime now start_time) / optMinTime opts) runs_elapsed = fromIntegral i' / fromIntegral (optMinRuns opts) in -- The progress bar is the _shortest_ of the -- time-based or runs-based estimate. This is -- intended to avoid a situation where the progress -- bar is full but stuff is still happening. On the -- other hand, it means it can sometimes shrink. min time_elapsed runs_elapsed phase <- readIORef phase_var case (us, phase, rse) of (Nothing, _, _) -> let elapsed = determineProgress i in p $ T.pack (replicate 13 ' ') <> progress elapsed (Just us', Initial, Nothing) -> do let us_sum' = us_sum + us' i' = i + 1 writeIORef count (us_sum', i') let elapsed = determineProgress i' p $ interimResult us_sum' i' elapsed (Just us', Initial, Just rse') -> do -- Switched from phase 1 to convergence; discard all -- prior results. writeIORef count (us', 1) writeIORef phase_var Convergence convergenceBar p spin_count us' 1 rse' (Just us', Convergence, Just rse') -> do let us_sum' = us_sum + us' i' = i + 1 writeIORef count (us_sum', i') convergenceBar p spin_count us_sum' i' rse' (Just _, Convergence, Nothing) -> pure () -- Probably should not happen. putStr " " -- Just to move the cursor away from the progress bar. hFlush stdout | otherwise = do T.putStr $ descText dataset_desc pad_to hFlush stdout pure $ const $ pure () where toDouble = fromRational . toRational reportResult :: [RunResult] -> (Double, Double) -> IO () reportResult results (ci_lower, ci_upper) = do let runtimes = map (fromIntegral . 
runMicroseconds) results avg = sum runtimes / fromIntegral (length runtimes) :: Double putStrLn $ printf "%10.0fμs (95%% CI: [%10.1f, %10.1f])" avg ci_lower ci_upper runBenchmarkCase :: Server -> BenchOptions -> FutharkExe -> FilePath -> T.Text -> Int -> TestRun -> IO (Maybe DataResult) runBenchmarkCase _ _ _ _ _ _ (TestRun _ _ RunTimeFailure {} _ _) = pure Nothing -- Not our concern, we are not a testing tool. runBenchmarkCase _ opts _ _ _ _ (TestRun tags _ _ _ _) | any (`elem` tags) $ optExcludeCase opts = pure Nothing runBenchmarkCase server opts futhark program entry pad_to tr = do let (TestRun _ input_spec (Succeeds expected_spec) _ dataset_desc) = tr start_time <- liftIO getCurrentTime prompt <- mkProgressPrompt opts pad_to dataset_desc start_time -- Report the dataset name before running the program, so that if an -- error occurs it's easier to see where. prompt (Nothing, Nothing) res <- benchmarkDataset server (runOptions (prompt . first Just) opts) futhark program entry input_spec expected_spec (testRunReferenceOutput program entry tr) when fancyTerminal $ do clearLine T.putStr "\r" T.putStr $ descText (atMostChars maxDatasetNameLength dataset_desc) pad_to case res of Left err -> liftIO $ do putStrLn "" putRedLn err pure $ Just $ DataResult dataset_desc $ Left err Right (runtimes, errout, report) -> do let vec_runtimes = U.fromList $ map (fromIntegral . runMicroseconds) runtimes g <- create resampled <- liftIO $ resample g [Mean] 70000 vec_runtimes let bootstrapCI = case bootstrapBCA cl95 vec_runtimes resampled of boot : _ -> ( estPoint boot - confIntLDX (estError boot), estPoint boot + confIntUDX (estError boot) ) _ -> (0, 0) reportResult runtimes bootstrapCI -- We throw away the 'errout' unless profiling is enabled, -- because it is otherwise useless and adds too much to the -- .json file size. 
let errout' = guard (optProfile opts) >> Just errout report' = guard (optProfile opts) >> Just report Result runtimes (getMemoryUsage report) errout' report' & Right & DataResult dataset_desc & Just & pure getMemoryUsage :: ProfilingReport -> M.Map T.Text Int getMemoryUsage = fmap fromInteger . profilingMemory commandLineOptions :: [FunOptDescr BenchOptions] commandLineOptions = [ Option "r" ["runs"] ( ReqArg ( \n -> case reads n of [(n', "")] | n' > 0 -> Right $ \config -> config { optMinRuns = n' } _ -> Left . optionsError $ "'" ++ n ++ "' is not a positive integer." ) "RUNS" ) "Run each test case this many times.", Option [] ["backend"] ( ReqArg (\backend -> Right $ \config -> config {optBackend = backend}) "PROGRAM" ) "The compiler used (defaults to 'futhark-c').", Option [] ["futhark"] ( ReqArg (\prog -> Right $ \config -> config {optFuthark = Just prog}) "PROGRAM" ) "The binary used for operations (defaults to same binary as 'futhark bench').", Option [] ["runner"] (ReqArg (\prog -> Right $ \config -> config {optRunner = prog}) "PROGRAM") "The program used to run the Futhark-generated programs (defaults to nothing).", Option "p" ["pass-option"] ( ReqArg ( \opt -> Right $ \config -> config {optExtraOptions = opt : optExtraOptions config} ) "OPT" ) "Pass this option to programs being run.", Option [] ["pass-compiler-option"] ( ReqArg ( \opt -> Right $ \config -> config {optCompilerOptions = opt : optCompilerOptions config} ) "OPT" ) "Pass this option to the compiler (or typechecker if in -t mode).", Option [] ["json"] ( ReqArg ( \file -> Right $ \config -> config {optJSON = Just file} ) "FILE" ) "Scatter results in JSON format here.", Option [] ["timeout"] ( ReqArg ( \n -> case reads n of [(n', "")] | n' < max_timeout -> Right $ \config -> config {optTimeout = fromIntegral n'} _ -> Left . optionsError $ "'" ++ n ++ "' is not an integer smaller than" ++ show max_timeout ++ "." 
) "SECONDS" ) "Number of seconds before a dataset is aborted.", Option [] ["skip-compilation"] (NoArg $ Right $ \config -> config {optSkipCompilation = True}) "Use already compiled server-mode program.", Option [] ["exclude-case"] ( ReqArg ( \s -> Right $ \config -> config {optExcludeCase = T.pack s : optExcludeCase config} ) "TAG" ) "Do not run test cases with this tag.", Option [] ["ignore-files"] ( ReqArg ( \s -> Right $ \config -> config {optIgnoreFiles = makeRegex s : optIgnoreFiles config} ) "REGEX" ) "Ignore files matching this regular expression.", Option "e" ["entry-point"] ( ReqArg ( \s -> Right $ \config -> config {optEntryPoint = Just s} ) "NAME" ) "Only run this entry point.", Option [] ["tuning"] ( ReqArg (\s -> Right $ \config -> config {optTuning = Just s}) "EXTENSION" ) "Look for tuning files with this extension (defaults to .tuning).", Option [] ["cache-extension"] ( ReqArg (\s -> Right $ \config -> config {optCacheExt = Just s}) "EXTENSION" ) "Use cache files with this extension (none by default).", Option [] ["no-tuning"] (NoArg $ Right $ \config -> config {optTuning = Nothing}) "Do not load tuning files.", Option [] ["no-convergence-phase"] (NoArg $ Right $ \config -> config {optConvergencePhase = False}) "Do not run convergence phase.", Option [] ["convergence-max-seconds"] ( ReqArg ( \n -> case reads n of [(n', "")] | n' > 0 -> Right $ \config -> config {optConvergenceMaxTime = fromInteger n'} _ -> Left . optionsError $ "'" ++ n ++ "' is not a positive integer." ) "NUM" ) "Limit convergence phase to this number of seconds.", Option [] ["concurrency"] ( ReqArg ( \n -> case reads n of [(n', "")] | n' > 0 -> Right $ \config -> config {optConcurrency = Just n'} _ -> Left . optionsError $ "'" ++ n ++ "' is not a positive integer." 
) "NUM" ) "Number of benchmarks to prepare (not run) concurrently.", Option [] ["spec-file"] (ReqArg (\s -> Right $ \config -> config {optTestSpec = Just s}) "FILE") "Use test specification from this file.", Option "v" ["verbose"] (NoArg $ Right $ \config -> config {optVerbose = optVerbose config + 1}) "Enable logging. Pass multiple times for more.", Option "P" ["profile"] (NoArg $ Right $ \config -> config {optProfile = True}) "Collect profiling information." ] where max_timeout :: Int max_timeout = maxBound `div` 1000000 excludeBackend :: BenchOptions -> BenchOptions excludeBackend config = config { optExcludeCase = "no_" <> T.pack (optBackend config) : optExcludeCase config } -- | Run @futhark bench@. main :: String -> [String] -> IO () main = mainWithOptions initialBenchOptions commandLineOptions "options... programs..." $ \progs config -> case progs of [] -> Nothing _ | optProfile config && isNothing (optJSON config) -> Just $ optionsError "--profile cannot be used without --json." | otherwise -> Just $ runBenchmarks (excludeBackend config) progs futhark-0.25.27/src/Futhark/CLI/Benchcmp.hs000066400000000000000000000265161475065116200202310ustar00rootroot00000000000000-- | @futhark benchcmp@ module Futhark.CLI.Benchcmp (main) where import Control.Exception (catch) import Data.Bifunctor (Bifunctor (bimap, first, second)) import Data.ByteString.Lazy.Char8 qualified as LBS import Data.Either qualified as E import Data.List qualified as L import Data.Map qualified as M import Data.Text qualified as T import Data.Vector qualified as V import Futhark.Bench import Futhark.Util (showText) import Futhark.Util.Options (mainWithOptions) import Statistics.Sample qualified as S import System.Console.ANSI (hSupportsANSI) import System.IO (stdout) import Text.Printf (printf) -- | Record that summerizes a comparison between two benchmarks. data SpeedUp = SpeedUp { -- | What factor the benchmark is improved by. speedup :: Double, -- | Memory usage. 
memoryUsage :: M.Map T.Text Double, -- | If the speedup was significant. significant :: Bool } deriving (Show) -- | Terminal colors used when printing the comparisons. Some of these are not -- colors ways of emphasising text. data Colors = Colors { -- | The header color. header :: T.Text, -- | Okay color okblue :: T.Text, -- | A second okay color okgreen :: T.Text, -- | Warning color. warning :: T.Text, -- | When something fails. failing :: T.Text, -- | Default color. endc :: T.Text, -- | Bold text. bold :: T.Text, -- | Underline text. underline :: T.Text } -- | Colors to use for a terminal device. ttyColors :: Colors ttyColors = Colors { header = "\ESC[95m", okblue = "\ESC[94m", okgreen = "\ESC[92m", warning = "\ESC[93m", failing = "\ESC[91m", endc = "\ESC[0m", bold = "\ESC[1m", underline = "\ESC[4m" } -- | Colors to use for a non-terminal device. nonTtyColors :: Colors nonTtyColors = Colors { header = "", okblue = "", okgreen = "", warning = "", failing = "", endc = "", bold = "", underline = "" } -- | Reads a file without throwing an error. readFileSafely :: T.Text -> IO (Either T.Text LBS.ByteString) readFileSafely filepath = (Right <$> LBS.readFile (T.unpack filepath)) `catch` couldNotRead where couldNotRead e = pure $ Left $ showText (e :: IOError) -- | Converts DataResults to a Map with the text as a key. toDataResultsMap :: [DataResult] -> M.Map T.Text (Either T.Text Result) toDataResultsMap = M.fromList . fmap toTuple where toTuple (DataResult dataset dataResults) = (dataset, dataResults) -- | Converts BenchResults to a Map with the file path as a key. toBenchResultsMap :: [BenchResult] -> M.Map T.Text (M.Map T.Text (Either T.Text Result)) toBenchResultsMap = M.fromList . fmap toTuple where toTuple (BenchResult path dataResults) = (T.pack path, toDataResultsMap dataResults) -- | Given a file path to a json file which has the form of a futhark benchmark -- result, it will try to parse the file to a Map of Maps. 
The key -- in the outer most dictionary is a file path the inner key is the dataset. decodeFileBenchResultsMap :: T.Text -> IO (Either T.Text (M.Map T.Text (M.Map T.Text (Either T.Text Result)))) decodeFileBenchResultsMap path = do file <- readFileSafely path pure $ toBenchResultsMap <$> (file >>= (first T.pack . decodeBenchResults)) -- | Will return a text with an error saying there is a missing program in a -- given result. formatMissingProg :: T.Text -> T.Text -> T.Text -> T.Text formatMissingProg = ((T.pack .) .) . printf "In %s but not %s: program %s" -- | Will return a text with an error saying there is a missing dataset in a -- given result. formatMissingData :: T.Text -> T.Text -> T.Text -> T.Text -> T.Text formatMissingData = (((T.pack .) .) .) . printf "In %s but not %s: program %s dataset %s" -- | Will return texts that say there are a missing program. formatManyMissingProg :: T.Text -> T.Text -> [T.Text] -> [T.Text] formatManyMissingProg a_path b_path = zipWith3 formatMissingProg a_paths b_paths where a_paths = repeat a_path b_paths = repeat b_path -- | Will return texts that say there are missing datasets for a program. formatManyMissingData :: T.Text -> T.Text -> T.Text -> [T.Text] -> [T.Text] formatManyMissingData prog a_path b_path = L.zipWith4 formatMissingData a_paths b_paths progs where a_paths = repeat a_path b_paths = repeat b_path progs = repeat prog -- | Finds the keys two Maps does not have in common and returns a appropiate -- error based on the functioned passed. missingResults :: (T.Text -> T.Text -> [T.Text] -> [T.Text]) -> T.Text -> T.Text -> M.Map T.Text a -> M.Map T.Text b -> [T.Text] missingResults toMissingMap a_path b_path a_results b_results = missing where a_keys = M.keys a_results b_keys = M.keys b_results a_missing = toMissingMap a_path b_path $ a_keys L.\\ b_keys b_missing = toMissingMap b_path a_path $ b_keys L.\\ a_keys missing = a_missing `L.union` b_missing -- | Compares the memory usage of two results. 
computeMemoryUsage :: M.Map T.Text Int -> M.Map T.Text Int -> M.Map T.Text Double computeMemoryUsage a b = M.intersectionWith divide b $ M.filter (/= 0) a where divide x y = fromIntegral x / fromIntegral y -- | Compares two results and thereby computes the Speed Up records. compareResult :: Result -> Result -> SpeedUp compareResult a b = SpeedUp { speedup = speedup', significant = significant', memoryUsage = memory_usage } where runResultToDouble :: RunResult -> Double runResultToDouble = fromIntegral . runMicroseconds toVector = V.fromList . (runResultToDouble <$>) . runResults a_memory_usage = memoryMap a b_memory_usage = memoryMap b a_run_results = toVector a b_run_results = toVector b a_std = S.stdDev a_run_results b_std = S.stdDev b_run_results a_mean = S.mean a_run_results b_mean = S.mean b_run_results diff = abs $ a_mean - b_mean speedup' = a_mean / b_mean significant' = diff > a_std / 2 + b_std / 2 memory_usage = computeMemoryUsage a_memory_usage b_memory_usage -- | Given two Maps containing datasets as keys and results as values, compare -- the results and return the errors in a tuple. compareDataResults :: T.Text -> T.Text -> T.Text -> M.Map T.Text (Either T.Text Result) -> M.Map T.Text (Either T.Text Result) -> (M.Map T.Text SpeedUp, ([T.Text], [T.Text])) compareDataResults prog a_path b_path a_data b_data = result where formatMissing = formatManyMissingData prog partition = E.partitionEithers . fmap sequence . M.toList (a_errors, a_data') = second M.fromList $ partition a_data (b_errors, b_data') = second M.fromList $ partition b_data missing = missingResults formatMissing a_path b_path a_data' b_data' exists = M.intersectionWith compareResult a_data' b_data' errors = a_errors ++ b_errors result = (exists, (errors, missing)) -- | Given two Maps containing program file paths as keys and values as datasets -- with results. Compare the results for each dataset in each program and -- return the errors in a tuple. 
compareBenchResults :: T.Text -> T.Text -> M.Map T.Text (M.Map T.Text (Either T.Text Result)) -> M.Map T.Text (M.Map T.Text (Either T.Text Result)) -> (M.Map T.Text (M.Map T.Text SpeedUp), ([T.Text], [T.Text])) compareBenchResults a_path b_path a_bench b_bench = (exists, errors_missing) where missing = missingResults formatManyMissingProg a_path b_path a_bench b_bench result = M.intersectionWithKey auxiliary a_bench b_bench auxiliary prog = compareDataResults prog a_path b_path exists = M.filter (not . null) $ fst <$> result errors_missing' = bimap concat concat . unzip . M.elems $ snd <$> result errors_missing = second (missing ++) errors_missing' -- | Formats memory usage such that it is human readable. If the memory usage -- is not significant an empty text is returned. memoryFormatter :: Colors -> T.Text -> Double -> T.Text memoryFormatter colors key value | value < 0.99 = memoryFormat $ okgreen colors | value > 1.01 = memoryFormat $ failing colors | otherwise = "" where memoryFormat c = T.pack $ printf "%s%4.2fx@%s%s" c value key endc' endc' = endc colors -- | Given a SpeedUp record the memory usage will be formatted to a colored -- human readable text. toMemoryText :: Colors -> SpeedUp -> T.Text toMemoryText colors data_result | T.null memory_text = "" | otherwise = " (mem: " <> memory_text <> ")" where memory_text = M.foldrWithKey formatFolder "" memory memory = memoryUsage data_result formatFolder key value lst = lst <> memoryFormatter colors key value -- | Given a text shorten it to a given length and add a suffix as the last -- word. shorten :: Int -> T.Text -> T.Text -> T.Text shorten c end string | T.length string > c = (T.unwords . init $ T.words shortened) <> " " <> end | otherwise = string where end_len = T.length end (shortened, _) = T.splitAt (c - end_len) string -- | Given a text add padding to the right of the text in form of spaces. rightPadding :: Int -> T.Text -> T.Text rightPadding c = T.pack . 
printf s where s = "%-" <> show c <> "s" -- | Given a SpeedUp record print the SpeedUp in a human readable manner. printSpeedUp :: Colors -> T.Text -> SpeedUp -> IO () printSpeedUp colors dataset data_result = do let color | significant data_result && speedup data_result > 1.01 = okgreen colors | significant data_result && speedup data_result < 0.99 = failing colors | otherwise = "" let short_dataset = rightPadding 64 . (<> ":") $ shorten 63 "[...]" dataset let memoryText = toMemoryText colors data_result let speedup' = speedup data_result let endc' = endc colors let format = " %s%s%10.2fx%s%s" putStrLn $ printf format short_dataset color speedup' endc' memoryText -- | Given a Map of SpeedUp records where the key is the program, print the -- SpeedUp in a human readable manner. printProgSpeedUps :: Colors -> T.Text -> M.Map T.Text SpeedUp -> IO () printProgSpeedUps colors prog bench_result = do putStrLn "" putStrLn $ printf "%s%s%s%s" (header colors) (bold colors) prog (endc colors) mapM_ (uncurry (printSpeedUp colors)) $ M.toList bench_result -- | Given a Map of programs with dataset speedups and relevant errors, print -- the errors and print the speedups in a human readable manner. printComparisons :: Colors -> M.Map T.Text (M.Map T.Text SpeedUp) -> ([T.Text], [T.Text]) -> IO () printComparisons colors speedups (errors, missing) = do mapM_ (putStrLn . T.unpack) $ L.sort missing mapM_ (putStrLn . T.unpack) $ L.sort errors mapM_ (uncurry (printProgSpeedUps colors)) $ M.toList speedups -- | Run @futhark benchcmp@ main :: String -> [String] -> IO () main = mainWithOptions () [] " " f where f [a_path', b_path'] () = Just $ do let a_path = T.pack a_path' let b_path = T.pack b_path' a_either <- decodeFileBenchResultsMap a_path b_either <- decodeFileBenchResultsMap b_path isTty <- hSupportsANSI stdout let colors = if isTty then ttyColors else nonTtyColors let comparePrint = (uncurry (printComparisons colors) .) . 
compareBenchResults a_path b_path case (a_either, b_either) of (Left a, Left b) -> putStrLn . T.unpack $ (a <> "\n" <> b) (Left a, _) -> putStrLn . T.unpack $ a (_, Left b) -> putStrLn . T.unpack $ b (Right a, Right b) -> comparePrint a b f _ _ = Nothing futhark-0.25.27/src/Futhark/CLI/C.hs000066400000000000000000000006761475065116200166730ustar00rootroot00000000000000-- | @futhark c@ module Futhark.CLI.C (main) where import Futhark.Actions (compileCAction) import Futhark.Compiler.CLI import Futhark.Passes (seqmemPipeline) -- | Run @futhark c@ main :: String -> [String] -> IO () main = compilerMain () [] "Compile sequential C" "Generate sequential C code from optimised Futhark program." seqmemPipeline $ \fcfg () mode outpath prog -> actionProcedure (compileCAction fcfg mode outpath) prog futhark-0.25.27/src/Futhark/CLI/CUDA.hs000066400000000000000000000007001475065116200172110ustar00rootroot00000000000000-- | @futhark cuda@ module Futhark.CLI.CUDA (main) where import Futhark.Actions (compileCUDAAction) import Futhark.Compiler.CLI import Futhark.Passes (gpumemPipeline) -- | Run @futhark cuda@. main :: String -> [String] -> IO () main = compilerMain () [] "Compile CUDA" "Generate CUDA/C code from optimised Futhark program." 
gpumemPipeline $ \fcfg () mode outpath prog -> actionProcedure (compileCUDAAction fcfg mode outpath) prog futhark-0.25.27/src/Futhark/CLI/Check.hs000066400000000000000000000022771475065116200175250ustar00rootroot00000000000000-- | @futhark check@ module Futhark.CLI.Check (main) where import Control.Monad import Control.Monad.IO.Class import Futhark.Compiler import Futhark.Util.Options import Futhark.Util.Pretty (hPutDoc) import Language.Futhark.Warnings import System.Exit import System.IO data CheckConfig = CheckConfig {checkWarn :: Bool, checkWerror :: Bool} newCheckConfig :: CheckConfig newCheckConfig = CheckConfig True False options :: [FunOptDescr CheckConfig] options = [ Option "w" [] (NoArg $ Right $ \cfg -> cfg {checkWarn = False}) "Disable all warnings.", Option [] ["Werror"] (NoArg $ Right $ \cfg -> cfg {checkWerror = True}) "Treat warnings as errors." ] -- | Run @futhark check@. main :: String -> [String] -> IO () main = mainWithOptions newCheckConfig options "program" $ \args cfg -> case args of [file] -> Just $ do (warnings, _, _) <- readProgramOrDie file when (checkWarn cfg && anyWarnings warnings) $ do liftIO $ hPutDoc stderr $ prettyWarnings warnings when (checkWerror cfg) $ do hPutStrLn stderr "\nTreating above warnings as errors due to --Werror." 
exitWith $ ExitFailure 2 _ -> Nothing futhark-0.25.27/src/Futhark/CLI/Datacmp.hs000066400000000000000000000032231475065116200200510ustar00rootroot00000000000000-- | @futhark datacmp@ module Futhark.CLI.Datacmp (main) where import Control.Exception import Data.ByteString.Lazy.Char8 qualified as BS import Futhark.Data.Compare import Futhark.Data.Reader import Futhark.Util.Options import System.Exit import System.IO readFileSafely :: String -> IO (Either String BS.ByteString) readFileSafely filepath = (Right <$> BS.readFile filepath) `catch` couldNotRead where couldNotRead e = pure $ Left $ show (e :: IOError) -- | Run @futhark datacmp@ main :: String -> [String] -> IO () main = mainWithOptions () [] " " f where f [file_a, file_b] () = Just $ do file_contents_a_maybe <- readFileSafely file_a file_contents_b_maybe <- readFileSafely file_b case (file_contents_a_maybe, file_contents_b_maybe) of (Left err_msg, _) -> do hPutStrLn stderr err_msg exitFailure (_, Left err_msg) -> do hPutStrLn stderr err_msg exitFailure (Right contents_a, Right contents_b) -> do let vs_a_maybe = readValues contents_a let vs_b_maybe = readValues contents_b case (vs_a_maybe, vs_b_maybe) of (Nothing, _) -> do hPutStrLn stderr $ "Error reading values from " ++ file_a exitFailure (_, Nothing) -> do hPutStrLn stderr $ "Error reading values from " ++ file_b exitFailure (Just vs_a, Just vs_b) -> case compareSeveralValues (Tolerance 0.002) vs_a vs_b of [] -> pure () es -> do mapM_ print es exitWith $ ExitFailure 2 f _ _ = Nothing futhark-0.25.27/src/Futhark/CLI/Dataset.hs000066400000000000000000000245251475065116200200750ustar00rootroot00000000000000{-# LANGUAGE Strict #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | @futhark dataset@ module Futhark.CLI.Dataset (main) where import Control.Monad import Control.Monad.ST import Data.Binary qualified as Bin import Data.ByteString.Lazy.Char8 qualified as BS import Data.Text qualified as T import Data.Text.IO qualified as T import Data.Vector.Generic (freeze) 
import Data.Vector.Storable qualified as SVec import Data.Vector.Storable.Mutable qualified as USVec import Data.Word import Futhark.Data qualified as V import Futhark.Data.Reader (readValues) import Futhark.Util (convFloat) import Futhark.Util.Options import Language.Futhark.Parser import Language.Futhark.Pretty () import Language.Futhark.Prop (UncheckedTypeExp) import Language.Futhark.Syntax hiding ( FloatValue (..), IntValue (..), PrimValue (..), ValueType, ) import System.Exit import System.IO import System.Random (mkStdGen, uniformR) import System.Random.Stateful (UniformRange (..)) -- | Run @futhark dataset@. main :: String -> [String] -> IO () main = mainWithOptions initialDataOptions commandLineOptions "options..." f where f [] config | null $ optOrders config = Just $ do maybe_vs <- readValues <$> BS.getContents case maybe_vs of Nothing -> do hPutStrLn stderr "Malformed data on standard input." exitFailure Just vs -> case format config of Text -> mapM_ (T.putStrLn . V.valueText) vs Binary -> mapM_ (BS.putStr . Bin.encode) vs Type -> mapM_ (T.putStrLn . V.valueTypeText . V.valueType) vs | otherwise = Just $ zipWithM_ ($) (optOrders config) [fromIntegral (optSeed config) ..] f _ _ = Nothing data OutputFormat = Text | Binary | Type deriving (Eq, Ord, Show) data DataOptions = DataOptions { optSeed :: Int, optRange :: RandomConfiguration, optOrders :: [Word64 -> IO ()], format :: OutputFormat } initialDataOptions :: DataOptions initialDataOptions = DataOptions 1 initialRandomConfiguration [] Text commandLineOptions :: [FunOptDescr DataOptions] commandLineOptions = [ Option "s" ["seed"] ( ReqArg ( \n -> case reads n of [(n', "")] -> Right $ \config -> config {optSeed = n'} _ -> Left $ do hPutStrLn stderr $ "'" ++ n ++ "' is not an integer." 
exitFailure ) "SEED" ) "The seed to use when initialising the RNG.", Option "g" ["generate"] ( ReqArg ( \t -> case tryMakeGenerator t of Right g -> Right $ \config -> config { optOrders = optOrders config ++ [g (optRange config) (format config)] } Left err -> Left $ do T.hPutStrLn stderr err exitFailure ) "TYPE" ) "Generate a random value of this type.", Option [] ["text"] (NoArg $ Right $ \opts -> opts {format = Text}) "Output data in text format (default; must precede --generate).", Option "b" ["binary"] (NoArg $ Right $ \opts -> opts {format = Binary}) "Output data in binary Futhark format (must precede --generate).", Option "t" ["type"] (NoArg $ Right $ \opts -> opts {format = Type}) "Output the type (textually) rather than the value (must precede --generate).", setRangeOption "i8" seti8Range, setRangeOption "i16" seti16Range, setRangeOption "i32" seti32Range, setRangeOption "i64" seti64Range, setRangeOption "u8" setu8Range, setRangeOption "u16" setu16Range, setRangeOption "u32" setu32Range, setRangeOption "u64" setu64Range, setRangeOption "f16" setf16Range, setRangeOption "f32" setf32Range, setRangeOption "f64" setf64Range ] setRangeOption :: (Read a) => String -> (Range a -> RandomConfiguration -> RandomConfiguration) -> FunOptDescr DataOptions setRangeOption tname set = Option "" [name] ( ReqArg ( \b -> let (lower, rest) = span (/= ':') b upper = drop 1 rest in case (reads lower, reads upper) of ([(lower', "")], [(upper', "")]) -> Right $ \config -> config {optRange = set (lower', upper') $ optRange config} _ -> Left $ do hPutStrLn stderr $ "Invalid bounds for " ++ tname ++ ": " ++ b exitFailure ) "MIN:MAX" ) $ "Range of " ++ tname ++ " values." where name = tname ++ "-bounds" tryMakeGenerator :: String -> Either T.Text (RandomConfiguration -> OutputFormat -> Word64 -> IO ()) tryMakeGenerator t | Just vs <- readValues $ BS.pack t = pure $ \_ fmt _ -> mapM_ (outValue fmt) vs | otherwise = do t' <- toValueType =<< either (Left . 
syntaxErrorMsg) Right (parseType name (T.pack t)) pure $ \conf fmt seed -> do let v = randomValue conf t' seed outValue fmt v where name = "option " ++ t outValue Text = T.putStrLn . V.valueText outValue Binary = BS.putStr . Bin.encode outValue Type = T.putStrLn . V.valueTypeText . V.valueType toValueType :: UncheckedTypeExp -> Either T.Text V.ValueType toValueType TETuple {} = Left "Cannot handle tuples yet." toValueType TERecord {} = Left "Cannot handle records yet." toValueType TEApply {} = Left "Cannot handle type applications yet." toValueType TEArrow {} = Left "Cannot generate functions." toValueType TESum {} = Left "Cannot handle sumtypes yet." toValueType TEDim {} = Left "Cannot handle existential sizes." toValueType (TEParens t _) = toValueType t toValueType (TEUnique t _) = toValueType t toValueType (TEArray d t _) = do d' <- constantDim d V.ValueType ds t' <- toValueType t pure $ V.ValueType (d' : ds) t' where constantDim (SizeExp (IntLit k _ _) _) = Right $ fromInteger k constantDim _ = Left "Array has non-constant dimension declaration." toValueType (TEVar (QualName [] v) _) | Just t <- lookup v m = Right $ V.ValueType [] t where m = map f [minBound .. maxBound] f t = (nameFromText (V.primTypeText t), t) toValueType (TEVar v _) = Left $ "Unknown type " <> prettyText v -- | Closed interval, as in @System.Random@. type Range a = (a, a) data RandomConfiguration = RandomConfiguration { i8Range :: Range Int8, i16Range :: Range Int16, i32Range :: Range Int32, i64Range :: Range Int64, u8Range :: Range Word8, u16Range :: Range Word16, u32Range :: Range Word32, u64Range :: Range Word64, f16Range :: Range Half, f32Range :: Range Float, f64Range :: Range Double } -- The following lines provide evidence about how Haskells record -- system sucks. 
seti8Range :: Range Int8 -> RandomConfiguration -> RandomConfiguration seti8Range bounds config = config {i8Range = bounds} seti16Range :: Range Int16 -> RandomConfiguration -> RandomConfiguration seti16Range bounds config = config {i16Range = bounds} seti32Range :: Range Int32 -> RandomConfiguration -> RandomConfiguration seti32Range bounds config = config {i32Range = bounds} seti64Range :: Range Int64 -> RandomConfiguration -> RandomConfiguration seti64Range bounds config = config {i64Range = bounds} setu8Range :: Range Word8 -> RandomConfiguration -> RandomConfiguration setu8Range bounds config = config {u8Range = bounds} setu16Range :: Range Word16 -> RandomConfiguration -> RandomConfiguration setu16Range bounds config = config {u16Range = bounds} setu32Range :: Range Word32 -> RandomConfiguration -> RandomConfiguration setu32Range bounds config = config {u32Range = bounds} setu64Range :: Range Word64 -> RandomConfiguration -> RandomConfiguration setu64Range bounds config = config {u64Range = bounds} setf16Range :: Range Half -> RandomConfiguration -> RandomConfiguration setf16Range bounds config = config {f16Range = bounds} setf32Range :: Range Float -> RandomConfiguration -> RandomConfiguration setf32Range bounds config = config {f32Range = bounds} setf64Range :: Range Double -> RandomConfiguration -> RandomConfiguration setf64Range bounds config = config {f64Range = bounds} initialRandomConfiguration :: RandomConfiguration initialRandomConfiguration = RandomConfiguration (minBound, maxBound) (minBound, maxBound) (minBound, maxBound) (minBound, maxBound) (minBound, maxBound) (minBound, maxBound) (minBound, maxBound) (minBound, maxBound) (0.0, 1.0) (0.0, 1.0) (0.0, 1.0) randomValue :: RandomConfiguration -> V.ValueType -> Word64 -> V.Value randomValue conf (V.ValueType ds t) seed = case t of V.I8 -> gen i8Range V.I8Value V.I16 -> gen i16Range V.I16Value V.I32 -> gen i32Range V.I32Value V.I64 -> gen i64Range V.I64Value V.U8 -> gen u8Range V.U8Value V.U16 -> gen 
u16Range V.U16Value V.U32 -> gen u32Range V.U32Value V.U64 -> gen u64Range V.U64Value V.F16 -> gen f16Range V.F16Value V.F32 -> gen f32Range V.F32Value V.F64 -> gen f64Range V.F64Value V.Bool -> gen (const (False, True)) V.BoolValue where gen range final = randomVector (range conf) final ds seed randomVector :: (SVec.Storable v, UniformRange v) => Range v -> (SVec.Vector Int -> SVec.Vector v -> V.Value) -> [Int] -> Word64 -> V.Value randomVector range final ds seed = runST $ do -- Use some nice impure computation where we can preallocate a -- vector of the desired size, populate it via the random number -- generator, and then finally reutrn a frozen binary vector. arr <- USVec.new n let fill g i | i < n = do let (v, g') = uniformR range g USVec.write arr i v g' `seq` fill g' $! i + 1 | otherwise = pure () fill (mkStdGen $ fromIntegral seed) 0 final (SVec.fromList ds) . SVec.convert <$> freeze arr where n = product ds -- XXX: The following instance is an orphan. Maybe it could be -- avoided with some newtype trickery or refactoring, but it's so -- convenient this way. 
instance UniformRange Half where uniformRM (a, b) g = (convFloat :: Float -> Half) <$> uniformRM (convFloat a, convFloat b) g futhark-0.25.27/src/Futhark/CLI/Defs.hs000066400000000000000000000034011475065116200173570ustar00rootroot00000000000000-- | @futhark defs@ module Futhark.CLI.Defs (main) where import Data.Sequence qualified as Seq import Data.Text qualified as T import Data.Text.IO qualified as T import Futhark.Compiler import Futhark.Util.Loc import Futhark.Util.Options import Language.Futhark data DefKind = Value | Module | ModuleType | Type data Def = Def DefKind Name Loc kindText :: DefKind -> T.Text kindText Value = "value" kindText Module = "module" kindText ModuleType = "module type" kindText Type = "type" printDef :: Def -> IO () printDef (Def k name loc) = do T.putStrLn $ T.unwords [kindText k, nameToText name, T.pack (locStr loc)] defsInProg :: UncheckedProg -> Seq.Seq Def defsInProg = foldMap defsInDec . progDecs where defsInDec (ValDec vb) = Seq.singleton $ Def Value (valBindName vb) (locOf vb) defsInDec (TypeDec tb) = Seq.singleton $ Def Type (typeAlias tb) (locOf tb) defsInDec (LocalDec d _) = defsInDec d defsInDec (OpenDec me _) = defsInModExp me defsInDec (ModDec mb) = defsInModExp $ modExp mb defsInDec ModTypeDec {} = mempty defsInDec ImportDec {} = mempty defsInModExp ModVar {} = mempty defsInModExp (ModParens me _) = defsInModExp me defsInModExp ModImport {} = mempty defsInModExp (ModDecs ds _) = foldMap defsInDec ds defsInModExp (ModApply me1 me2 _ _ _) = defsInModExp me1 <> defsInModExp me2 defsInModExp (ModAscript me _ _ _) = defsInModExp me defsInModExp (ModLambda _ _ me _) = defsInModExp me -- | Run @futhark defs@. main :: String -> [String] -> IO () main = mainWithOptions () [] "program" $ \args () -> case args of [file] -> Just $ do prog <- readUntypedProgramOrDie file mapM_ printDef . foldMap (defsInProg . snd) $ filter (not . isBuiltin . 
fst) prog _ -> Nothing futhark-0.25.27/src/Futhark/CLI/Dev.hs000066400000000000000000000732061475065116200172260ustar00rootroot00000000000000-- | Futhark Compiler Driver module Futhark.CLI.Dev (main) where import Control.Category (id) import Control.Monad import Control.Monad.State import Data.Kind qualified import Data.List (intersperse) import Data.Maybe import Data.Text qualified as T import Data.Text.IO qualified as T import Futhark.Actions import Futhark.Analysis.AccessPattern (Analyse) import Futhark.Analysis.Alias qualified as Alias import Futhark.Analysis.Metrics (OpMetrics) import Futhark.Compiler.CLI hiding (compilerMain) import Futhark.IR (Op, Prog, prettyString) import Futhark.IR.Aliases (AliasableRep) import Futhark.IR.GPU qualified as GPU import Futhark.IR.GPUMem qualified as GPUMem import Futhark.IR.MC qualified as MC import Futhark.IR.MCMem qualified as MCMem import Futhark.IR.Parse import Futhark.IR.SOACS qualified as SOACS import Futhark.IR.Seq qualified as Seq import Futhark.IR.SeqMem qualified as SeqMem import Futhark.IR.TypeCheck (Checkable, checkProg) import Futhark.Internalise.ApplyTypeAbbrs as ApplyTypeAbbrs import Futhark.Internalise.Defunctionalise as Defunctionalise import Futhark.Internalise.Defunctorise as Defunctorise import Futhark.Internalise.FullNormalise as FullNormalise import Futhark.Internalise.LiftLambdas as LiftLambdas import Futhark.Internalise.Monomorphise as Monomorphise import Futhark.Internalise.ReplaceRecords as ReplaceRecords import Futhark.Optimise.ArrayLayout import Futhark.Optimise.ArrayShortCircuiting qualified as ArrayShortCircuiting import Futhark.Optimise.CSE import Futhark.Optimise.DoubleBuffer import Futhark.Optimise.Fusion import Futhark.Optimise.GenRedOpt import Futhark.Optimise.HistAccs import Futhark.Optimise.InliningDeadFun import Futhark.Optimise.MemoryBlockMerging qualified as MemoryBlockMerging import Futhark.Optimise.ReduceDeviceSyncs (reduceDeviceSyncs) import Futhark.Optimise.Sink import 
Futhark.Optimise.TileLoops import Futhark.Optimise.Unstream import Futhark.Pass import Futhark.Pass.AD import Futhark.Pass.ExpandAllocations import Futhark.Pass.ExplicitAllocations.GPU qualified as GPU import Futhark.Pass.ExplicitAllocations.MC qualified as MC import Futhark.Pass.ExplicitAllocations.Seq qualified as Seq import Futhark.Pass.ExtractKernels import Futhark.Pass.ExtractMulticore import Futhark.Pass.FirstOrderTransform import Futhark.Pass.LiftAllocations as LiftAllocations import Futhark.Pass.LowerAllocations as LowerAllocations import Futhark.Pass.Simplify import Futhark.Passes import Futhark.Util.Log import Futhark.Util.Options import Futhark.Util.Pretty qualified as PP import Language.Futhark.Core (locStr, nameFromString) import Language.Futhark.Parser (SyntaxError (..), parseFuthark) import System.Exit import System.FilePath import System.IO import Prelude hiding (id) -- | What to do with the program after it has been read. data FutharkPipeline = -- | Just print it. PrettyPrint | -- | Run the type checker; print type errors. TypeCheck | -- | Run this pipeline. Pipeline [UntypedPass] | -- | Partially evaluate away the module language. Defunctorise | -- | Defunctorise and normalise. FullNormalise | -- | Defunctorise, normalise and monomorphise. Monomorphise | -- | Defunctorise, normalise, monomorphise and lambda-lift. LiftLambdas | -- | Defunctorise, normalise, monomorphise, lambda-lift and defunctionalise. Defunctionalise data Config = Config { futharkConfig :: FutharkConfig, -- | Nothing is distinct from a empty pipeline - -- it means we don't even run the internaliser. futharkPipeline :: FutharkPipeline, futharkCompilerMode :: CompilerMode, futharkAction :: UntypedAction, -- | If true, prints programs as raw ASTs instead -- of their prettyprinted form. futharkPrintAST :: Bool } -- | Get a Futhark pipeline from the configuration - an empty one if -- none exists. getFutharkPipeline :: Config -> [UntypedPass] getFutharkPipeline = toPipeline . 
futharkPipeline where toPipeline (Pipeline p) = p toPipeline _ = [] data UntypedPassState = SOACS (Prog SOACS.SOACS) | GPU (Prog GPU.GPU) | MC (Prog MC.MC) | Seq (Prog Seq.Seq) | GPUMem (Prog GPUMem.GPUMem) | MCMem (Prog MCMem.MCMem) | SeqMem (Prog SeqMem.SeqMem) getSOACSProg :: UntypedPassState -> Maybe (Prog SOACS.SOACS) getSOACSProg (SOACS prog) = Just prog getSOACSProg _ = Nothing class Representation s where -- | A human-readable description of the representation expected or -- contained, usable for error messages. representation :: s -> String instance Representation UntypedPassState where representation (SOACS _) = "SOACS" representation (GPU _) = "GPU" representation (MC _) = "MC" representation (Seq _) = "Seq" representation (GPUMem _) = "GPUMem" representation (MCMem _) = "MCMem" representation (SeqMem _) = "SeqMem" instance PP.Pretty UntypedPassState where pretty (SOACS prog) = PP.pretty prog pretty (GPU prog) = PP.pretty prog pretty (MC prog) = PP.pretty prog pretty (Seq prog) = PP.pretty prog pretty (SeqMem prog) = PP.pretty prog pretty (MCMem prog) = PP.pretty prog pretty (GPUMem prog) = PP.pretty prog newtype UntypedPass = UntypedPass ( UntypedPassState -> PipelineConfig -> FutharkM UntypedPassState ) type BackendAction rep = FutharkConfig -> CompilerMode -> FilePath -> Action rep data UntypedAction = SOACSAction (Action SOACS.SOACS) | GPUAction (Action GPU.GPU) | GPUMemAction (BackendAction GPUMem.GPUMem) | MCMemAction (BackendAction MCMem.MCMem) | SeqMemAction (BackendAction SeqMem.SeqMem) | PolyAction ( forall (rep :: Data.Kind.Type). 
( AliasableRep rep, (OpMetrics (Op rep)), Analyse rep ) => Action rep ) instance Representation UntypedAction where representation (SOACSAction _) = "SOACS" representation (GPUAction _) = "GPU" representation (GPUMemAction _) = "GPUMem" representation (MCMemAction _) = "MCMem" representation (SeqMemAction _) = "SeqMem" representation PolyAction {} = "" newConfig :: Config newConfig = Config newFutharkConfig (Pipeline []) ToExecutable action False where action = PolyAction printAction changeFutharkConfig :: (FutharkConfig -> FutharkConfig) -> Config -> Config changeFutharkConfig f cfg = cfg {futharkConfig = f $ futharkConfig cfg} type FutharkOption = FunOptDescr Config passOption :: String -> UntypedPass -> String -> [String] -> FutharkOption passOption desc pass short long = Option short long ( NoArg $ Right $ \cfg -> cfg {futharkPipeline = Pipeline $ getFutharkPipeline cfg ++ [pass]} ) desc kernelsMemProg :: String -> UntypedPassState -> FutharkM (Prog GPUMem.GPUMem) kernelsMemProg _ (GPUMem prog) = pure prog kernelsMemProg name rep = externalErrorS $ "Pass '" <> name <> "' expects GPUMem representation, but got " <> representation rep soacsProg :: String -> UntypedPassState -> FutharkM (Prog SOACS.SOACS) soacsProg _ (SOACS prog) = pure prog soacsProg name rep = externalErrorS $ "Pass '" <> name <> "' expects SOACS representation, but got " <> representation rep kernelsProg :: String -> UntypedPassState -> FutharkM (Prog GPU.GPU) kernelsProg _ (GPU prog) = pure prog kernelsProg name rep = externalErrorS $ "Pass '" <> name <> "' expects GPU representation, but got " <> representation rep seqMemProg :: String -> UntypedPassState -> FutharkM (Prog SeqMem.SeqMem) seqMemProg _ (SeqMem prog) = pure prog seqMemProg name rep = externalErrorS $ "Pass '" <> name <> "' expects SeqMem representation, but got " <> representation rep mcProg :: String -> UntypedPassState -> FutharkM (Prog MC.MC) mcProg _ (MC prog) = pure prog mcProg name rep = externalErrorS $ "Pass " ++ name ++ 
" expects MC representation, but got " ++ representation rep mcMemProg :: String -> UntypedPassState -> FutharkM (Prog MCMem.MCMem) mcMemProg _ (MCMem prog) = pure prog mcMemProg name rep = externalErrorS $ "Pass '" <> name <> "' expects MCMem representation, but got " <> representation rep typedPassOption :: (Checkable torep) => (String -> UntypedPassState -> FutharkM (Prog fromrep)) -> (Prog torep -> UntypedPassState) -> Pass fromrep torep -> String -> FutharkOption typedPassOption getProg putProg pass short = passOption (passDescription pass) (UntypedPass perform) short long where perform s config = do prog <- getProg (passName pass) s putProg <$> runPipeline (onePass pass) config prog long = [passLongOption pass] soacsPassOption :: Pass SOACS.SOACS SOACS.SOACS -> String -> FutharkOption soacsPassOption = typedPassOption soacsProg SOACS kernelsPassOption :: Pass GPU.GPU GPU.GPU -> String -> FutharkOption kernelsPassOption = typedPassOption kernelsProg GPU mcPassOption :: Pass MC.MC MC.MC -> String -> FutharkOption mcPassOption = typedPassOption mcProg MC seqMemPassOption :: Pass SeqMem.SeqMem SeqMem.SeqMem -> String -> FutharkOption seqMemPassOption = typedPassOption seqMemProg SeqMem mcMemPassOption :: Pass MCMem.MCMem MCMem.MCMem -> String -> FutharkOption mcMemPassOption = typedPassOption mcMemProg MCMem kernelsMemPassOption :: Pass GPUMem.GPUMem GPUMem.GPUMem -> String -> FutharkOption kernelsMemPassOption = typedPassOption kernelsMemProg GPUMem simplifyOption :: String -> FutharkOption simplifyOption short = passOption (passDescription pass) (UntypedPass perform) short long where perform (SOACS prog) config = SOACS <$> runPipeline (onePass simplifySOACS) config prog perform (GPU prog) config = GPU <$> runPipeline (onePass simplifyGPU) config prog perform (MC prog) config = MC <$> runPipeline (onePass simplifyMC) config prog perform (Seq prog) config = Seq <$> runPipeline (onePass simplifySeq) config prog perform (SeqMem prog) config = SeqMem <$> runPipeline 
(onePass simplifySeqMem) config prog perform (GPUMem prog) config = GPUMem <$> runPipeline (onePass simplifyGPUMem) config prog perform (MCMem prog) config = MCMem <$> runPipeline (onePass simplifyMCMem) config prog long = [passLongOption pass] pass = simplifySOACS allocateOption :: String -> FutharkOption allocateOption short = passOption (passDescription pass) (UntypedPass perform) short long where perform (GPU prog) config = GPUMem <$> runPipeline (onePass GPU.explicitAllocations) config prog perform (Seq prog) config = SeqMem <$> runPipeline (onePass Seq.explicitAllocations) config prog perform (MC prog) config = MCMem <$> runPipeline (onePass MC.explicitAllocations) config prog perform s _ = externalErrorS $ "Pass '" <> passDescription pass <> "' cannot operate on " <> representation s long = [passLongOption pass] pass = Seq.explicitAllocations cseOption :: String -> FutharkOption cseOption short = passOption (passDescription pass) (UntypedPass perform) short long where perform (SOACS prog) config = SOACS <$> runPipeline (onePass $ performCSE True) config prog perform (GPU prog) config = GPU <$> runPipeline (onePass $ performCSE True) config prog perform (MC prog) config = MC <$> runPipeline (onePass $ performCSE True) config prog perform (Seq prog) config = Seq <$> runPipeline (onePass $ performCSE True) config prog perform (SeqMem prog) config = SeqMem <$> runPipeline (onePass $ performCSE False) config prog perform (GPUMem prog) config = GPUMem <$> runPipeline (onePass $ performCSE False) config prog perform (MCMem prog) config = MCMem <$> runPipeline (onePass $ performCSE False) config prog long = [passLongOption pass] pass = performCSE True :: Pass SOACS.SOACS SOACS.SOACS sinkOption :: String -> FutharkOption sinkOption short = passOption (passDescription pass) (UntypedPass perform) short long where perform (GPU prog) config = GPU <$> runPipeline (onePass sinkGPU) config prog perform (MC prog) config = MC <$> runPipeline (onePass sinkMC) config prog 
perform s _ = externalErrorS $ "Pass '" ++ passDescription pass ++ "' cannot operate on " ++ representation s long = [passLongOption pass] pass = sinkGPU pipelineOption :: (UntypedPassState -> Maybe (Prog fromrep)) -> String -> (Prog torep -> UntypedPassState) -> String -> Pipeline fromrep torep -> String -> [String] -> FutharkOption pipelineOption getprog repdesc repf desc pipeline = passOption desc $ UntypedPass pipelinePass where pipelinePass rep config = case getprog rep of Just prog -> repf <$> runPipeline pipeline config prog Nothing -> externalErrorS $ "Expected " ++ repdesc ++ " representation, but got " ++ representation rep soacsPipelineOption :: String -> Pipeline SOACS.SOACS SOACS.SOACS -> String -> [String] -> FutharkOption soacsPipelineOption = pipelineOption getSOACSProg "SOACS" SOACS unstreamOption :: String -> FutharkOption unstreamOption short = passOption (passDescription pass) (UntypedPass perform) short long where perform (GPU prog) config = GPU <$> runPipeline (onePass unstreamGPU) config prog perform (MC prog) config = MC <$> runPipeline (onePass unstreamMC) config prog perform s _ = externalErrorS $ "Pass '" ++ passDescription pass ++ "' cannot operate on " ++ representation s long = [passLongOption pass] pass = unstreamGPU commandLineOptions :: [FutharkOption] commandLineOptions = [ Option "v" ["verbose"] (OptArg (Right . changeFutharkConfig . 
incVerbosity) "FILE") "Print verbose output on standard error; wrong program to FILE.", Option [] ["Werror"] (NoArg $ Right $ changeFutharkConfig $ \opts -> opts {futharkWerror = True}) "Treat warnings as errors.", Option "w" [] (NoArg $ Right $ changeFutharkConfig $ \opts -> opts {futharkWarn = False}) "Disable all warnings.", Option "t" ["type-check"] ( NoArg $ Right $ \opts -> opts {futharkPipeline = TypeCheck} ) "Print on standard output the type-checked program.", Option [] ["no-check"] ( NoArg $ Right $ changeFutharkConfig $ \opts -> opts {futharkTypeCheck = False} ) "Disable type-checking.", Option [] ["pretty-print"] ( NoArg $ Right $ \opts -> opts {futharkPipeline = PrettyPrint} ) "Parse and prettyString-print the AST of the given program.", Option [] ["backend"] ( ReqArg ( \arg -> do action <- case arg of "c" -> Right $ SeqMemAction compileCAction "multicore" -> Right $ MCMemAction compileMulticoreAction "opencl" -> Right $ GPUMemAction compileOpenCLAction "hip" -> Right $ GPUMemAction compileHIPAction "cuda" -> Right $ GPUMemAction compileCUDAAction "wasm" -> Right $ SeqMemAction compileCtoWASMAction "wasm-multicore" -> Right $ MCMemAction compileMulticoreToWASMAction "ispc" -> Right $ MCMemAction compileMulticoreToISPCAction "python" -> Right $ SeqMemAction compilePythonAction "pyopencl" -> Right $ GPUMemAction compilePyOpenCLAction _ -> Left $ error $ "Invalid backend: " <> arg Right $ \opts -> opts {futharkAction = action} ) "c|multicore|opencl|cuda|hip|python|pyopencl" ) "Run this compiler backend on pipeline result.", Option [] ["compile-imp-seq"] ( NoArg $ Right $ \opts -> opts {futharkAction = SeqMemAction $ \_ _ _ -> impCodeGenAction} ) "Translate pipeline result to ImpSequential and write it on stdout.", Option [] ["compile-imp-gpu"] ( NoArg $ Right $ \opts -> opts {futharkAction = GPUMemAction $ \_ _ _ -> kernelImpCodeGenAction} ) "Translate pipeline result to ImpGPU and write it on stdout.", Option [] ["compile-imp-multicore"] ( NoArg $ Right 
$ \opts -> opts {futharkAction = MCMemAction $ \_ _ _ -> multicoreImpCodeGenAction} ) "Translate pipeline result to ImpMC write it on stdout.", Option "p" ["print"] (NoArg $ Right $ \opts -> opts {futharkAction = PolyAction printAction}) "Print the resulting IR (default action).", Option [] ["print-aliases"] (NoArg $ Right $ \opts -> opts {futharkAction = PolyAction printAliasesAction}) "Print the resulting IR with aliases.", Option [] ["fusion-graph"] (NoArg $ Right $ \opts -> opts {futharkAction = SOACSAction printFusionGraph}) "Print fusion graph.", Option [] ["print-last-use-gpu"] ( NoArg $ Right $ \opts -> opts {futharkAction = GPUMemAction $ \_ _ _ -> printLastUseGPU} ) "Print last use information ss.", Option [] ["print-interference-gpu"] ( NoArg $ Right $ \opts -> opts {futharkAction = GPUMemAction $ \_ _ _ -> printInterferenceGPU} ) "Print interference information.", Option [] ["print-mem-alias-gpu"] ( NoArg $ Right $ \opts -> opts {futharkAction = GPUMemAction $ \_ _ _ -> printMemAliasGPU} ) "Print memory alias information.", Option "z" ["memory-access-pattern"] (NoArg $ Right $ \opts -> opts {futharkAction = PolyAction printMemoryAccessAnalysis}) "Print the result of analysing memory access patterns. 
Currently only for --gpu --mc.", Option [] ["call-graph"] (NoArg $ Right $ \opts -> opts {futharkAction = SOACSAction callGraphAction}) "Print the resulting call graph.", Option "m" ["metrics"] (NoArg $ Right $ \opts -> opts {futharkAction = PolyAction metricsAction}) "Print AST metrics of the resulting internal representation on standard output.", Option [] ["defunctorise"] (NoArg $ Right $ \opts -> opts {futharkPipeline = Defunctorise}) "Partially evaluate all module constructs and print the residual program.", Option [] ["normalise"] (NoArg $ Right $ \opts -> opts {futharkPipeline = FullNormalise}) "Fully normalise the program.", Option [] ["monomorphise"] (NoArg $ Right $ \opts -> opts {futharkPipeline = Monomorphise}) "Monomorphise the program.", Option [] ["lift-lambdas"] (NoArg $ Right $ \opts -> opts {futharkPipeline = LiftLambdas}) "Lambda-lift the program.", Option [] ["defunctionalise"] (NoArg $ Right $ \opts -> opts {futharkPipeline = Defunctionalise}) "Defunctionalise the program.", Option [] ["ast"] (NoArg $ Right $ \opts -> opts {futharkPrintAST = True}) "Output ASTs instead of prettyprinted programs.", Option [] ["safe"] (NoArg $ Right $ changeFutharkConfig $ \opts -> opts {futharkSafe = True}) "Ignore 'unsafe'.", Option [] ["entry-points"] ( ReqArg ( \arg -> Right $ changeFutharkConfig $ \opts -> opts { futharkEntryPoints = nameFromString arg : futharkEntryPoints opts } ) "NAME" ) "Treat this function as an additional entry point.", Option [] ["library"] (NoArg $ Right $ \opts -> opts {futharkCompilerMode = ToLibrary}) "Generate a library instead of an executable.", Option [] ["executable"] (NoArg $ Right $ \opts -> opts {futharkCompilerMode = ToExecutable}) "Generate an executable instead of a library (set by default).", Option [] ["server"] (NoArg $ Right $ \opts -> opts {futharkCompilerMode = ToServer}) "Generate a server executable.", typedPassOption soacsProg Seq firstOrderTransform "f", soacsPassOption fuseSOACs "o", soacsPassOption 
inlineAggressively [], soacsPassOption inlineConservatively [], soacsPassOption removeDeadFunctions [], soacsPassOption applyAD [], soacsPassOption applyADInnermost [], kernelsPassOption optimiseArrayLayoutGPU [], mcPassOption optimiseArrayLayoutMC [], kernelsPassOption optimiseGenRed [], kernelsPassOption tileLoops [], kernelsPassOption histAccsGPU [], unstreamOption [], sinkOption [], kernelsPassOption reduceDeviceSyncs [], typedPassOption soacsProg GPU extractKernels [], typedPassOption soacsProg MC extractMulticore [], allocateOption "a", kernelsMemPassOption doubleBufferGPU [], mcMemPassOption doubleBufferMC [], kernelsMemPassOption expandAllocations [], kernelsMemPassOption MemoryBlockMerging.optimise [], seqMemPassOption LiftAllocations.liftAllocationsSeqMem [], kernelsMemPassOption LiftAllocations.liftAllocationsGPUMem [], seqMemPassOption LowerAllocations.lowerAllocationsSeqMem [], kernelsMemPassOption LowerAllocations.lowerAllocationsGPUMem [], seqMemPassOption ArrayShortCircuiting.optimiseSeqMem [], mcMemPassOption ArrayShortCircuiting.optimiseMCMem [], kernelsMemPassOption ArrayShortCircuiting.optimiseGPUMem [], cseOption [], simplifyOption "e", soacsPipelineOption "Run the default optimised pipeline" standardPipeline "s" ["standard"], pipelineOption getSOACSProg "GPU" GPU "Run the default optimised kernels pipeline" gpuPipeline [] ["gpu"], pipelineOption getSOACSProg "GPUMem" GPUMem "Run the full GPU compilation pipeline" gpumemPipeline [] ["gpu-mem"], pipelineOption getSOACSProg "Seq" Seq "Run the sequential CPU compilation pipeline" seqPipeline [] ["seq"], pipelineOption getSOACSProg "SeqMem" SeqMem "Run the sequential CPU+memory compilation pipeline" seqmemPipeline [] ["seq-mem"], pipelineOption getSOACSProg "MC" MC "Run the multicore compilation pipeline" mcPipeline [] ["mc"], pipelineOption getSOACSProg "MCMem" MCMem "Run the multicore+memory compilation pipeline" mcmemPipeline [] ["mc-mem"] ] incVerbosity :: Maybe FilePath -> FutharkConfig -> 
FutharkConfig incVerbosity file cfg = cfg {futharkVerbose = (v, file `mplus` snd (futharkVerbose cfg))} where v = case fst $ futharkVerbose cfg of NotVerbose -> Verbose Verbose -> VeryVerbose VeryVerbose -> VeryVerbose -- | Entry point. Non-interactive, except when reading interpreter -- input from standard input. main :: String -> [String] -> IO () main = mainWithOptions newConfig commandLineOptions "options... program" compile where compile [file] config = Just $ do res <- runFutharkM (m file config) $ fst $ futharkVerbose $ futharkConfig config case res of Left err -> do dumpError (futharkConfig config) err exitWith $ ExitFailure 2 Right () -> pure () compile _ _ = Nothing m file config = do let p :: (Show a, PP.Pretty a) => [a] -> IO () p = mapM_ putStrLn . intersperse "" . map (if futharkPrintAST config then show else prettyString) readProgram' = readProgramFile (futharkEntryPoints (futharkConfig config)) file case futharkPipeline config of PrettyPrint -> liftIO $ do maybe_prog <- parseFuthark file <$> T.readFile file case maybe_prog of Left (SyntaxError loc err) -> fail $ "Syntax error at " <> locStr loc <> ":\n" <> T.unpack err Right prog | futharkPrintAST config -> print prog | otherwise -> putStrLn $ prettyString prog TypeCheck -> do (_, imports, _) <- readProgram' liftIO $ forM_ (map snd imports) $ \fm -> putStrLn $ if futharkPrintAST config then show $ fileProg fm else prettyString $ fileProg fm Defunctorise -> do (_, imports, src) <- readProgram' liftIO $ p $ flip evalState src $ Defunctorise.transformProg imports >>= ApplyTypeAbbrs.transformProg FullNormalise -> do (_, imports, src) <- readProgram' liftIO $ p $ flip evalState src $ Defunctorise.transformProg imports >>= ApplyTypeAbbrs.transformProg >>= FullNormalise.transformProg LiftLambdas -> do (_, imports, src) <- readProgram' liftIO $ p $ flip evalState src $ Defunctorise.transformProg imports >>= ApplyTypeAbbrs.transformProg >>= FullNormalise.transformProg >>= ReplaceRecords.transformProg >>= 
LiftLambdas.transformProg Monomorphise -> do (_, imports, src) <- readProgram' liftIO $ p $ flip evalState src $ Defunctorise.transformProg imports >>= ApplyTypeAbbrs.transformProg >>= FullNormalise.transformProg >>= ReplaceRecords.transformProg >>= LiftLambdas.transformProg >>= Monomorphise.transformProg Defunctionalise -> do (_, imports, src) <- readProgram' liftIO $ p $ flip evalState src $ Defunctorise.transformProg imports >>= ApplyTypeAbbrs.transformProg >>= FullNormalise.transformProg >>= ReplaceRecords.transformProg >>= LiftLambdas.transformProg >>= Monomorphise.transformProg >>= Defunctionalise.transformProg Pipeline {} -> do let (base, ext) = splitExtension file readCore parse construct = do logMsg $ "Reading " <> file <> "..." input <- liftIO $ T.readFile file logMsg ("Parsing..." :: T.Text) case parse file input of Left err -> externalErrorS $ T.unpack err Right prog -> do logMsg ("Typechecking..." :: T.Text) case checkProg $ Alias.aliasAnalysis prog of Left err -> externalErrorS $ show err Right () -> runPolyPasses config base $ construct prog handlers = [ ( ".fut", do prog <- runPipelineOnProgram (futharkConfig config) id file runPolyPasses config base (SOACS prog) ), (".fut_soacs", readCore parseSOACS SOACS), (".fut_seq", readCore parseSeq Seq), (".fut_seq_mem", readCore parseSeqMem SeqMem), (".fut_gpu", readCore parseGPU GPU), (".fut_gpu_mem", readCore parseGPUMem GPUMem), (".fut_mc", readCore parseMC MC), (".fut_mc_mem", readCore parseMCMem MCMem) ] case lookup ext handlers of Just handler -> handler Nothing -> externalErrorS $ unwords [ "Unsupported extension", show ext, ". 
Supported extensions:", unwords $ map fst handlers ] runPolyPasses :: Config -> FilePath -> UntypedPassState -> FutharkM () runPolyPasses config base initial_prog = do end_prog <- foldM (runPolyPass pipeline_config) initial_prog (getFutharkPipeline config) case (end_prog, futharkAction config) of (SOACS prog, SOACSAction action) -> otherAction action prog (GPU prog, GPUAction action) -> otherAction action prog (SeqMem prog, SeqMemAction action) -> backendAction prog action (GPUMem prog, GPUMemAction action) -> backendAction prog action (MCMem prog, MCMemAction action) -> backendAction prog action (SOACS soacs_prog, PolyAction acs) -> otherAction acs soacs_prog (GPU kernels_prog, PolyAction acs) -> otherAction acs kernels_prog (MC mc_prog, PolyAction acs) -> otherAction acs mc_prog (Seq seq_prog, PolyAction acs) -> otherAction acs seq_prog (GPUMem mem_prog, PolyAction acs) -> otherAction acs mem_prog (SeqMem mem_prog, PolyAction acs) -> otherAction acs mem_prog (MCMem mem_prog, PolyAction acs) -> otherAction acs mem_prog (_, action) -> externalErrorS $ "Action expects " ++ representation action ++ " representation, but got " ++ representation end_prog ++ "." logMsg ("Done." 
:: String) where backendAction prog actionf = do let action = actionf (futharkConfig config) (futharkCompilerMode config) base otherAction action prog otherAction action prog = do logMsg $ "Running action " ++ actionName action actionProcedure action prog pipeline_config = PipelineConfig { pipelineVerbose = fst (futharkVerbose $ futharkConfig config) > NotVerbose, pipelineValidate = futharkTypeCheck $ futharkConfig config } runPolyPass :: PipelineConfig -> UntypedPassState -> UntypedPass -> FutharkM UntypedPassState runPolyPass pipeline_config s (UntypedPass f) = f s pipeline_config futhark-0.25.27/src/Futhark/CLI/Doc.hs000066400000000000000000000073651475065116200172200ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE TemplateHaskell #-} -- | @futhark doc@ module Futhark.CLI.Doc (main) where import Control.Monad import Control.Monad.State import Data.FileEmbed import Data.List (nubBy) import Data.Text qualified as T import Data.Text.IO qualified as T import Data.Text.Lazy qualified as LT import Futhark.Compiler (Imports, dumpError, fileProg, newFutharkConfig, readProgramFiles) import Futhark.Doc.Generator import Futhark.Pipeline (FutharkM, Verbosity (..), runFutharkM) import Futhark.Util (directoryContents) import Futhark.Util.Options import Language.Futhark.Semantic (mkInitialImport) import Language.Futhark.Syntax (DocComment (..), progDoc) import System.Directory (createDirectoryIfMissing) import System.Exit import System.FilePath import System.IO import Text.Blaze.Html.Renderer.Text cssFile :: T.Text cssFile = $(embedStringFile "rts/futhark-doc/style.css") data DocConfig = DocConfig { docOutput :: Maybe FilePath, docVerbose :: Bool } initialDocConfig :: DocConfig initialDocConfig = DocConfig { docOutput = Nothing, docVerbose = False } printDecs :: DocConfig -> FilePath -> [FilePath] -> Imports -> IO () printDecs cfg dir files imports = do let direct_imports = map (mkInitialImport . normalise . 
dropExtension) files (file_htmls, _warnings) = renderFiles direct_imports $ filter (not . ignored) imports mapM_ (write . fmap (LT.toStrict . renderHtml)) file_htmls write ("style.css", cssFile) where write :: (FilePath, T.Text) -> IO () write (name, content) = do let file = dir makeRelative "/" name when (docVerbose cfg) $ hPutStrLn stderr $ "Writing " <> file createDirectoryIfMissing True $ takeDirectory file T.writeFile file content -- Some files are not worth documenting; typically because -- they contain tests. The current crude mechanism is to -- recognise them by a file comment containing "ignore". ignored (_, fm) = case progDoc (fileProg fm) of Just (DocComment s _) -> T.strip s == "ignore" _ -> False type DocOption = OptDescr (Either (IO ()) (DocConfig -> DocConfig)) commandLineOptions :: [DocOption] commandLineOptions = [ Option "o" ["output-directory"] ( ReqArg (\dirname -> Right $ \config -> config {docOutput = Just dirname}) "DIR" ) "Directory in which to put generated documentation.", Option "v" ["verbose"] (NoArg $ Right $ \config -> config {docVerbose = True}) "Print status messages on stderr." ] futFiles :: FilePath -> IO [FilePath] futFiles dir = filter isFut <$> directoryContents dir where isFut = (== ".fut") . takeExtension -- | Run @futhark doc@. main :: String -> [String] -> IO () main = mainWithOptions initialDocConfig commandLineOptions "options... -o outdir programs..." f where f [dir] config = Just $ do res <- runFutharkM (m config dir) Verbose case res of Left err -> liftIO $ do dumpError newFutharkConfig err exitWith $ ExitFailure 2 Right () -> pure () f _ _ = Nothing m :: DocConfig -> FilePath -> FutharkM () m config dir = case docOutput config of Nothing -> liftIO $ do hPutStrLn stderr "Must specify output directory with -o." exitWith $ ExitFailure 1 Just outdir -> do files <- liftIO $ futFiles dir when (docVerbose config) $ liftIO $ do mapM_ (hPutStrLn stderr . ("Found source file " <>)) files hPutStrLn stderr "Reading files..." 
(_w, imports, _vns) <- readProgramFiles [] files liftIO $ printDecs config outdir files $ nubBy sameImport imports sameImport (x, _) (y, _) = x == y futhark-0.25.27/src/Futhark/CLI/Eval.hs000066400000000000000000000105211475065116200173660ustar00rootroot00000000000000-- | @futhark eval@ module Futhark.CLI.Eval (main) where import Control.Exception import Control.Monad import Control.Monad.Except (ExceptT, runExceptT, throwError) import Control.Monad.Free.Church import Control.Monad.IO.Class (MonadIO, liftIO) import Data.Map qualified as M import Data.Maybe import Data.Text qualified as T import Data.Text.IO qualified as T import Futhark.Compiler import Futhark.MonadFreshNames import Futhark.Pipeline import Futhark.Util.Options import Futhark.Util.Pretty import Language.Futhark import Language.Futhark.Interpreter qualified as I import Language.Futhark.Parser import Language.Futhark.Semantic qualified as T import Language.Futhark.TypeChecker qualified as I import Language.Futhark.TypeChecker qualified as T import System.Exit import System.FilePath import System.IO import Prelude -- | Run @futhark eval@. main :: String -> [String] -> IO () main = mainWithOptions interpreterConfig options "options... " run where run [] _ = Nothing run exprs config = Just $ runExprs exprs config runExprs :: [String] -> InterpreterConfig -> IO () runExprs exprs cfg = do let InterpreterConfig _ file = cfg maybe_new_state <- newFutharkiState cfg file (src, env, ctx) <- case maybe_new_state of Left reason -> do hPutDocLn stderr reason exitWith $ ExitFailure 2 Right s -> pure s mapM_ (runExpr src env ctx) exprs -- Use parseExp, checkExp, then interpretExp. 
-- | Parse, type-check and interpret a single expression in the given
-- environment, printing the resulting value on stdout.  Syntax errors,
-- type errors, ambiguous types and interpreter errors terminate the
-- process with exit code 1.
runExpr :: VNameSource -> T.Env -> I.Ctx -> String -> IO ()
runExpr src env ctx str = do
  uexp <- case parseExp "" (T.pack str) of
    Left (SyntaxError _ serr) -> do
      T.hPutStrLn stderr serr
      exitWith $ ExitFailure 1
    Right e -> pure e
  fexp <- case T.checkExp [] src env uexp of
    (_, Left terr) -> do
      hPutDoc stderr $ I.prettyTypeError terr
      exitWith $ ExitFailure 1
    (_, Right ([], e)) -> pure e
    -- Leftover type parameters mean the expression's type is not fully
    -- determined, so it cannot be interpreted.
    (_, Right (tparams, e)) -> do
      putDocLn $ "Inferred type of expression: " <> align (pretty (typeOf e))
      T.putStrLn $
        "The following types are ambiguous: "
          <> T.intercalate ", " (map (nameToText . toName . typeParamName) tparams)
      exitWith $ ExitFailure 1
  pval <- runInterpreterNoBreak $ I.interpretExp ctx fexp
  case pval of
    Left err -> do
      hPutDoc stderr $ I.prettyInterpreterError err
      exitWith $ ExitFailure 1
    Right val -> putDoc $ I.prettyValue val <> hardline

-- | Configuration of @futhark eval@.
data InterpreterConfig = InterpreterConfig
  { -- | Whether to print warnings produced while reading the file.
    interpreterPrintWarnings :: Bool,
    -- | File to load before evaluating expressions (set by @-f@).
    interpreterFile :: Maybe String
  }

-- | Default configuration: warnings enabled, no file loaded.
interpreterConfig :: InterpreterConfig
interpreterConfig = InterpreterConfig True Nothing

-- | Command line options of @futhark eval@.
options :: [FunOptDescr InterpreterConfig]
options =
  [ Option
      "f"
      ["file"]
      ( ReqArg
          ( \entry -> Right $ \config ->
              config {interpreterFile = Just entry}
          )
          "NAME"
      )
      "The file to load before evaluating expressions.",
    Option
      "w"
      ["no-warnings"]
      (NoArg $ Right $ \config -> config {interpreterPrintWarnings = False})
      "Do not print warnings."
  ]

-- | Read and type-check the given file (if any), interpret all of its
-- imports, and return the name source, the type environment of the
-- last import, and an interpreter context whose environment is that of
-- the last import.  IO exceptions during reading are turned into
-- compiler errors.
newFutharkiState ::
  InterpreterConfig ->
  Maybe FilePath ->
  IO (Either (Doc AnsiStyle) (VNameSource, T.Env, I.Ctx))
newFutharkiState cfg maybe_file = runExceptT $ do
  (ws, imports, src) <-
    badOnLeft prettyCompilerError
      =<< liftIO
        ( runExceptT (readProgramFiles [] $ maybeToList maybe_file)
            `catch` \(err :: IOException) ->
              pure (externalErrorS (show err))
        )
  when (interpreterPrintWarnings cfg) $
    liftIO $
      hPutDoc stderr $
        prettyWarnings ws

  ictx <-
    foldM
      (\ctx -> badOnLeft I.prettyInterpreterError <=< runInterpreterNoBreak . I.interpretImport ctx)
      I.initialCtx
      $ map (fmap fileProg) imports
  -- The last import is the file the user asked for.
  let (tenv, ienv) =
        let (iname, fm) = last imports
         in ( fileScope fm,
              ictx {I.ctxEnv = I.ctxImports ictx M.! iname}
            )

  pure (src, tenv, ienv)
  where
    badOnLeft :: (err -> err') -> Either err a -> ExceptT err' IO a
    badOnLeft _ (Right x) = pure x
    badOnLeft p (Left err) = throwError $ p err

-- | Run an interpreter computation: traces are printed on stdout,
-- breakpoints are ignored, and interpreter errors are returned as
-- 'Left'.
runInterpreterNoBreak :: (MonadIO m) => F I.ExtOp a -> m (Either I.InterpreterError a)
runInterpreterNoBreak m = runF m (pure . Right) intOp
  where
    intOp (I.ExtOpError err) = pure $ Left err
    intOp (I.ExtOpTrace w v c) = do
      liftIO $ putDocLn $ pretty w <> ":" <+> align (unAnnotate v)
      c
    intOp (I.ExtOpBreak _ _ _ c) = c

futhark-0.25.27/src/Futhark/CLI/Fmt.hs000066400000000000000000000027731475065116200172370ustar00rootroot00000000000000
-- | @futhark fmt@
module Futhark.CLI.Fmt (main) where

import Control.Monad (forM_, unless)
import Data.Text qualified as T
import Data.Text.IO qualified as T
import Futhark.Fmt.Printer
import Futhark.Util.Options
import Futhark.Util.Pretty (docText, hPutDoc, putDoc)
import Language.Futhark
import Language.Futhark.Parser (SyntaxError (..))
import System.Exit
import System.IO

-- | Configuration of @futhark fmt@.
newtype FmtCfg = FmtCfg
  { -- | Only check formatting (set by @--check@); do not rewrite files.
    cfgCheck :: Bool
  }

-- | Default configuration: rewrite files in place.
initialFmtCfg :: FmtCfg
initialFmtCfg = FmtCfg {cfgCheck = False}

-- | Command line options of @futhark fmt@.
fmtOptions :: [FunOptDescr FmtCfg]
fmtOptions =
  [ Option
      ""
      ["check"]
      (NoArg $ Right $ \cfg -> cfg {cfgCheck = True})
      "Check whether file is correctly formatted."
  ]

-- | Run @futhark fmt@.
--
-- With no files, the program on stdin is formatted to stdout.
-- Otherwise each file is rewritten in place, or, with @--check@, the
-- first file whose formatting differs is reported on stderr (together
-- with its formatted text) and the process exits with failure.
main :: String -> [String] -> IO ()
main = mainWithOptions initialFmtCfg fmtOptions "[FILES]" $ \args cfg ->
  case args of
    [] -> Just $ putDoc =<< onInput "" =<< T.getContents
    files ->
      Just $ forM_ files $ \file -> do
        file_s <- T.readFile file
        doc <- onInput file file_s
        if cfgCheck cfg
          then unless (docText doc == file_s) $ do
            T.hPutStrLn stderr $ T.pack file <> ": not formatted correctly."
            T.hPutStr stderr $ docText doc
            exitFailure
          else withFile file WriteMode $ \h -> hPutDoc h doc
  where
    -- Format a program, exiting with failure on a syntax error.
    onInput fname s =
      case fmtToDoc fname s of
        Left (SyntaxError loc err) -> do
          T.hPutStrLn stderr $ locText loc <> ":\n" <> prettyText err
          exitFailure
        Right fmt -> pure fmt

futhark-0.25.27/src/Futhark/CLI/HIP.hs000066400000000000000000000006711475065116200171240ustar00rootroot00000000000000
-- | @futhark hip@
module Futhark.CLI.HIP (main) where

import Futhark.Actions (compileHIPAction)
import Futhark.Compiler.CLI
import Futhark.Passes (gpumemPipeline)

-- | Run @futhark hip@.
main :: String -> [String] -> IO ()
main = compilerMain
  ()
  []
  "Compile HIP"
  "Generate HIP/C code from optimised Futhark program."
  gpumemPipeline
  $ \fcfg () mode outpath prog ->
    actionProcedure (compileHIPAction fcfg mode outpath) prog

futhark-0.25.27/src/Futhark/CLI/LSP.hs000066400000000000000000000025501475065116200171400ustar00rootroot00000000000000
{-# LANGUAGE ExplicitNamespaces #-}

-- | @futhark lsp@
module Futhark.CLI.LSP (main) where

import Control.Monad.IO.Class (MonadIO (liftIO))
import Data.IORef (newIORef)
import Futhark.LSP.Handlers (handlers)
import Futhark.LSP.State (emptyState)
import Language.LSP.Protocol.Types
  ( SaveOptions (SaveOptions),
    TextDocumentSyncKind (TextDocumentSyncKind_Incremental),
    TextDocumentSyncOptions (..),
    type (|?) (InR),
  )
import Language.LSP.Server

-- | Run @futhark lsp@
main :: String -> [String] -> IO ()
main _prog _args = do
  -- Shared server state, handed to every handler.
  state_mvar <- newIORef emptyState
  _ <-
    runServer $
      ServerDefinition
        { onConfigChange = const $ pure (),
          configSection = "Futhark",
          parseConfig = const . const $ Right (),
          defaultConfig = (),
          doInitialize = \env _req -> pure $ Right env,
          staticHandlers = handlers state_mvar,
          interpretHandler = \env -> Iso (runLspT env) liftIO,
          options = defaultOptions {optTextDocumentSync = Just syncOptions}
        }
  pure ()

-- | Document synchronisation settings advertised to the client:
-- incremental change notifications, no will-save notifications.
syncOptions :: TextDocumentSyncOptions
syncOptions =
  TextDocumentSyncOptions
    { _openClose = Just False,
      _change = Just TextDocumentSyncKind_Incremental,
      _willSave = Just False,
      _willSaveWaitUntil = Just False,
      _save = Just $ InR $ SaveOptions $ Just False
    }

futhark-0.25.27/src/Futhark/CLI/Literate.hs000066400000000000000000001165761475065116200202710ustar00rootroot00000000000000
-- | @futhark literate@
--
-- Also contains various utility definitions used by "Futhark.CLI.Script".
module Futhark.CLI.Literate
  ( main,
    Options (..),
    initialOptions,
    scriptCommandLineOptions,
    prepareServer,
  )
where

import Codec.BMP qualified as BMP
import Control.Monad
import Control.Monad.Except
import Control.Monad.State hiding (State)
import Data.Bifunctor (first, second)
import Data.Bits
import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as LBS
import Data.Char
import Data.Functor (($>))
import Data.Int (Int64)
import Data.List qualified as L
import Data.Map qualified as M
import Data.Maybe
import Data.Set qualified as S
import Data.Text qualified as T
import Data.Text.Encoding qualified as T
import Data.Text.IO qualified as T
import Data.Text.Read qualified as T
import Data.Vector.Storable qualified as SVec
import Data.Vector.Storable.ByteString qualified as SVec
import Data.Void
import Data.Word (Word32, Word8)
import Futhark.Data
import Futhark.Script
import Futhark.Server
import Futhark.Test
import Futhark.Test.Values
import Futhark.Util
  ( directoryContents,
    fancyTerminal,
    hashText,
    nubOrd,
    runProgramWithExitCode,
    showText,
  )
import Futhark.Util.Options
import Futhark.Util.Pretty (prettyText, prettyTextOneLine)
import Futhark.Util.Pretty qualified as PP
import Futhark.Util.ProgressBar
import System.Directory
  ( copyFile,
    createDirectoryIfMissing,
    doesFileExist,
    getCurrentDirectory,
    removePathForcibly,
    setCurrentDirectory,
  )
import System.Environment (getExecutablePath)
import System.Exit
import System.FilePath
import System.IO
import System.IO.Error (isDoesNotExistError)
import System.IO.Temp (withSystemTempDirectory, withSystemTempFile)
import Text.Megaparsec hiding (State, failure, token)
import Text.Megaparsec.Char
import Text.Printf

-- | Parameters of an @:img@ directive.
newtype ImgParams = ImgParams
  { -- | Explicit output file name (no directory component).
    imgFile :: Maybe FilePath
  }
  deriving (Show)

-- | No parameters given.
defaultImgParams :: ImgParams
defaultImgParams = ImgParams {imgFile = Nothing}

-- | Parameters of a @:video@ directive.
data VideoParams = VideoParams
  { videoFPS :: Maybe Int,
    videoLoop :: Maybe Bool,
    videoAutoplay :: Maybe Bool,
    videoFormat :: Maybe T.Text,
    -- | Explicit output file name (no directory component).
    videoFile :: Maybe FilePath
  }
  deriving (Show)

-- | No parameters given.
defaultVideoParams :: VideoParams
defaultVideoParams =
  VideoParams
    { videoFPS = Nothing,
      videoLoop = Nothing,
      videoAutoplay = Nothing,
      videoFormat = Nothing,
      videoFile = Nothing
    }

-- | Parameters of an @:audio@ directive.
data AudioParams = AudioParams
  { audioSamplingFrequency :: Maybe Int,
    audioCodec :: Maybe T.Text
  }
  deriving (Show)

-- | No parameters given.
defaultAudioParams :: AudioParams
defaultAudioParams =
  AudioParams
    { audioSamplingFrequency = Nothing,
      audioCodec = Nothing
    }

-- | A directive embedded in a literate program comment.
data Directive
  = DirectiveRes Exp
  | DirectiveBrief Directive
  | DirectiveCovert Directive
  | DirectiveImg Exp ImgParams
  | -- | The optional size is @(width, height)@ in source order.
    DirectivePlot Exp (Maybe (Int, Int))
  | DirectiveGnuplot Exp T.Text
  | DirectiveVideo Exp VideoParams
  | DirectiveAudio Exp AudioParams
  deriving (Show)

-- | The entry points referenced by the expression of a directive.
varsInDirective :: Directive -> S.Set EntryName
varsInDirective (DirectiveRes e) = varsInExp e
varsInDirective (DirectiveBrief d) = varsInDirective d
varsInDirective (DirectiveCovert d) = varsInDirective d
varsInDirective (DirectiveImg e _) = varsInExp e
varsInDirective (DirectivePlot e _) = varsInExp e
varsInDirective (DirectiveGnuplot e _) = varsInExp e
varsInDirective (DirectiveVideo e _) = varsInExp e
varsInDirective (DirectiveAudio e _) = varsInExp e

-- | Pretty-print a directive.  When the flag is 'True', parameters are
-- included; 'DirectiveBrief' and 'DirectiveCovert' always print their
-- inner directive without parameters.
pprDirective :: Bool -> Directive -> PP.Doc a
pprDirective _ (DirectiveRes e) =
  "> " <> PP.align (PP.pretty e)
pprDirective _ (DirectiveBrief f) =
  pprDirective False f
pprDirective _ (DirectiveCovert f) =
  pprDirective False f
pprDirective _ (DirectiveImg e params) =
  ("> :img " <> PP.align (PP.pretty e))
    <> if null params' then mempty else ";" <> PP.hardline <> PP.stack params'
  where
    params' = catMaybes [p "file" imgFile PP.pretty]
    p s f pretty = do
      x <- f params
      Just $ s <> ": " <> pretty x
-- The size is stored as (width, height), exactly as parsed by
-- 'parsePlotParams' and consumed by 'processDirective', so it is
-- printed back in that order.
pprDirective True (DirectivePlot e (Just (w, h))) =
  PP.stack
    [ "> :plot2d " <> PP.pretty e <> ";",
      "size: (" <> PP.pretty w <> "," <> PP.pretty h <> ")"
    ]
pprDirective _ (DirectivePlot e _) =
  "> :plot2d " <> PP.align (PP.pretty e)
pprDirective True (DirectiveGnuplot e script) =
  PP.stack $
    "> :gnuplot " <> PP.align (PP.pretty e) <> ";"
      : map PP.pretty (T.lines script)
pprDirective False (DirectiveGnuplot e _) =
  "> :gnuplot " <> PP.align (PP.pretty e)
pprDirective False (DirectiveVideo e _) =
  "> :video " <> PP.align (PP.pretty e)
pprDirective True (DirectiveVideo e params) =
  ("> :video " <> PP.pretty e)
    <> if null params'
      then mempty
      else ";" <> PP.hardline <> PP.stack params'
  where
    params' =
      catMaybes
        [ p "fps" videoFPS PP.pretty,
          p "loop" videoLoop ppBool,
          p "autoplay" videoAutoplay ppBool,
          p "format" videoFormat PP.pretty,
          p "file" videoFile PP.pretty
        ]
    ppBool b = if b then "true" else "false"
    p s f pretty = do
      x <- f params
      Just $ s <> ": " <> pretty x
pprDirective _ (DirectiveAudio e params) =
  ("> :audio " <> PP.pretty e)
    <> if null params'
      then mempty
      else ";" <> PP.hardline <> PP.stack params'
  where
    params' =
      catMaybes
        [ p "sampling_frequency" audioSamplingFrequency PP.pretty,
          p "codec" audioCodec PP.pretty
        ]
    p s f pretty = do
      x <- f params
      Just $ s <> ": " <> pretty x

instance PP.Pretty Directive where
  pretty = pprDirective True

-- | A block of a literate program.
data Block
  = BlockCode T.Text
  | BlockComment T.Text
  | -- | A directive together with its (comment-prefix stripped) source text.
    BlockDirective Directive T.Text
  deriving (Show)

-- | The entry points referenced by the directives of a script.
varsInScripts :: [Block] -> S.Set EntryName
varsInScripts = foldMap varsInBlock
  where
    varsInBlock (BlockDirective d _) = varsInDirective d
    varsInBlock BlockCode {} = mempty
    varsInBlock BlockComment {} = mempty

type Parser = Parsec Void T.Text

-- | Skip horizontal space, and continue onto following lines that
-- start with @--@ (so directives may span several comment lines).
postlexeme :: Parser ()
postlexeme = void $ hspace *> optional (try $ eol *> "--" *> postlexeme)

lexeme :: Parser a -> Parser a
lexeme p = p <* postlexeme

token :: T.Text -> Parser ()
token = void . try . lexeme . string

parseInt :: Parser Int
parseInt = lexeme $ read <$> some (satisfy isDigit)

-- | The rest of the current line; consumes the line ending if present.
restOfLine :: Parser T.Text
restOfLine = takeWhileP Nothing (/= '\n') <* (void eol <|> eof)

-- | One or more consecutive @--@ comment lines, with the comment marker
-- (and one following space, if present) removed.
parseBlockComment :: Parser T.Text
parseBlockComment = T.unlines <$> some line
  where
    line = "--" *> optional " " *> restOfLine

-- | A test block (@-- ==@ header followed by comment lines), kept verbatim
-- as code.
parseTestBlock :: Parser T.Text
parseTestBlock =
  T.unlines <$> ((:) <$> header <*> remainder)
  where
    header = "-- ==" <* eol
    remainder = map ("-- " <>) . T.lines <$> parseBlockComment

-- | One or more non-comment lines, with leading and trailing blank lines
-- removed.
parseBlockCode :: Parser T.Text
parseBlockCode = T.unlines . noblanks <$> some line
  where
    noblanks = reverse . dropWhile T.null . reverse . dropWhile T.null
    line = try (notFollowedBy "--") *> notFollowedBy eof *> restOfLine

-- | Optional @size: (width,height)@ parameter of @:plot2d@.
parsePlotParams :: Parser (Maybe (Int, Int))
parsePlotParams =
  optional $
    ";"
      *> hspace
      *> eol
      *> token "-- size:"
      *> token "("
      *> ((,) <$> parseInt <* token "," <*> parseInt)
      <* token ")"

-- | Run the parser, failing with the given message (without consuming
-- input) if its result does not satisfy the predicate.
withPredicate :: (a -> Bool) -> String -> Parser a -> Parser a
withPredicate f msg p = do
  r <- lookAhead p
  if f r then p else fail msg

-- | A bare file name (no directory component).
parseFilePath :: Parser FilePath
parseFilePath =
  withPredicate ok "filename must not have directory component" p
  where
    p = T.unpack <$> lexeme (takeWhileP Nothing (not . isSpace))
    ok f = takeFileName f == f

parseImgParams :: Parser ImgParams
parseImgParams =
  fmap (fromMaybe defaultImgParams) $
    optional $
      ";" *> hspace *> eol *> "-- " *> parseParams defaultImgParams
  where
    parseParams params =
      choice
        [ choice
            [pFile params]
            >>= parseParams,
          pure params
        ]
    pFile params = do
      token "file:"
      b <- parseFilePath
      pure params {imgFile = Just b}

-- | Parameters of @:video@.  Accepts @fps@, @loop@, @autoplay@,
-- @format@ and @file@, matching what 'pprDirective' prints.
parseVideoParams :: Parser VideoParams
parseVideoParams =
  fmap (fromMaybe defaultVideoParams) $
    optional $
      ";" *> hspace *> eol *> "-- " *> parseParams defaultVideoParams
  where
    parseParams params =
      choice
        [ choice
            [pLoop params, pFPS params, pAutoplay params, pFormat params, pFile params]
            >>= parseParams,
          pure params
        ]
    parseBool = token "true" $> True <|> token "false" $> False
    pLoop params = do
      token "loop:"
      b <- parseBool
      pure params {videoLoop = Just b}
    pFPS params = do
      token "fps:"
      fps <- parseInt
      pure params {videoFPS = Just fps}
    pAutoplay params = do
      token "autoplay:"
      b <- parseBool
      pure params {videoAutoplay = Just b}
    pFormat params = do
      token "format:"
      s <- lexeme $ takeWhileP Nothing (not . isSpace)
      pure params {videoFormat = Just s}
    -- Previously missing, which made 'videoFile' impossible to set.
    pFile params = do
      token "file:"
      b <- parseFilePath
      pure params {videoFile = Just b}

parseAudioParams :: Parser AudioParams
parseAudioParams =
  fmap (fromMaybe defaultAudioParams) $
    optional $
      ";" *> hspace *> eol *> "-- " *> parseParams defaultAudioParams
  where
    parseParams params =
      choice
        [ choice
            [pSamplingFrequency params, pCodec params]
            >>= parseParams,
          pure params
        ]
    pSamplingFrequency params = do
      token "sampling_frequency:"
      hz <- parseInt
      pure params {audioSamplingFrequency = Just hz}
    pCodec params = do
      token "codec:"
      s <- lexeme $ takeWhileP Nothing (not . isSpace)
      pure params {audioCodec = Just s}

-- | Succeeds (without consuming input) only at column 1.
atStartOfLine :: Parser ()
atStartOfLine = do
  col <- sourceColumn <$> getSourcePos
  when (col /= pos1) empty

afterExp :: Parser ()
afterExp = choice [atStartOfLine, choice [void eol, eof]]

-- | Run a parser and also pass the exact source text it consumed.
withParsedSource :: Parser a -> (a -> T.Text -> b) -> Parser b
withParsedSource p f = do
  s <- getInput
  bef <- getOffset
  x <- p
  aft <- getOffset
  pure $ f x $ T.take (aft - bef) s

-- | Remove a leading @-- @ (or, failing that, two characters) from
-- every line.
stripCommentPrefix :: T.Text -> T.Text
stripCommentPrefix = T.unlines . map onLine . T.lines
  where
    onLine s
      | "-- " `T.isPrefixOf` s = T.drop 3 s
      | otherwise = T.drop 2 s

parseBlock :: Parser Block
parseBlock =
  choice
    [ withParsedSource (token "-- >" *> parseDirective) $ \d s ->
        BlockDirective d $ stripCommentPrefix s,
      BlockCode <$> parseTestBlock,
      BlockCode <$> parseBlockCode,
      BlockComment <$> parseBlockComment
    ]
  where
    parseDirective =
      choice
        [ DirectiveRes <$> parseExp postlexeme <* afterExp,
          directiveName "covert"
            $> DirectiveCovert
            <*> parseDirective,
          directiveName "brief"
            $> DirectiveBrief
            <*> parseDirective,
          directiveName "img"
            $> DirectiveImg
            <*> parseExp postlexeme
            <*> parseImgParams
            <* choice [void eol, eof],
          directiveName "plot2d"
            $> DirectivePlot
            <*> parseExp postlexeme
            <*> parsePlotParams
            <* choice [void eol, eof],
          directiveName "gnuplot"
            $> DirectiveGnuplot
            <*> parseExp postlexeme
            <*> (";" *> hspace *> eol *> parseBlockComment),
          -- Like the other directives, a video directive may end the file.
          directiveName "video"
            $> DirectiveVideo
            <*> parseExp postlexeme
            <*> parseVideoParams
            <* choice [void eol, eof],
          directiveName "audio"
            $> DirectiveAudio
            <*> parseExp postlexeme
            <*> parseAudioParams
            <* choice [void eol, eof]
        ]
    directiveName s = try $ token (":" <> s)

-- | Parse a literate program, rendering parse errors as text.
parseProg :: FilePath -> T.Text -> Either T.Text [Block]
parseProg fname s =
  either (Left . T.pack . errorBundlePretty) Right $
    parse (many parseBlock <* eof) fname s

-- | Read and parse a literate program, exiting with failure on a parse
-- error.
parseProgFile :: FilePath -> IO [Block]
parseProgFile prog = do
  pres <- parseProg prog <$> T.readFile prog
  case pres of
    Left err -> do
      T.hPutStr stderr err
      exitFailure
    Right script ->
      pure script

-- | The collection of file paths (all inside the image directory)
-- produced during directive execution.
type Files = S.Set FilePath

newtype State = State {stateFiles :: Files}

newtype ScriptM a = ScriptM (ExceptT T.Text (StateT State IO) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadError T.Text,
      MonadFail,
      MonadIO,
      MonadState State
    )

-- | Run a script computation from an empty file set, returning its
-- result and the files it recorded.
runScriptM :: ScriptM a -> IO (Either T.Text a, Files)
runScriptM (ScriptM m) = second stateFiles <$> runStateT (runExceptT m) s
  where
    s = State mempty

-- | Run the action with the path of a fresh (closed) temporary file.
-- Files recorded by the action are propagated, and errors rethrown.
withTempFile :: (FilePath -> ScriptM a) -> ScriptM a
withTempFile f =
  join . liftIO . withSystemTempFile "futhark-literate" $ \tmpf tmpf_h -> do
    hClose tmpf_h
    (res, files) <- runScriptM (f tmpf)
    pure $ do
      modify $ \s -> s {stateFiles = files <> stateFiles s}
      either throwError pure res

-- | Like 'withTempFile', but for a temporary directory.
withTempDir :: (FilePath -> ScriptM a) -> ScriptM a
withTempDir f =
  join . liftIO . withSystemTempDirectory "futhark-literate" $ \dir -> do
    (res, files) <- runScriptM (f dir)
    pure $ do
      modify $ \s -> s {stateFiles = files <> stateFiles s}
      either throwError pure res

-- | Map values (1.0 is white) to grey pixels; the byte is taken modulo
-- 256.
greyFloatToImg ::
  (RealFrac a, SVec.Storable a) =>
  SVec.Vector a ->
  SVec.Vector Word32
greyFloatToImg = SVec.map grey
  where
    grey i =
      let i' = round (i * 255) .&. 0xFF
       in (i' `shiftL` 16) .|. (i' `shiftL` 8) .|. i'

-- | Replicate each byte value into the R, G and B channels.
greyByteToImg ::
  (Integral a, SVec.Storable a) =>
  SVec.Vector a ->
  SVec.Vector Word32
greyByteToImg = SVec.map grey
  where
    grey i =
      (fromIntegral i `shiftL` 16) .|. (fromIntegral i `shiftL` 8) .|. fromIntegral i

-- BMPs are RGBA and bottom-up where we assume images are top-down
-- and ARGB.  We fix this up before encoding the BMP.  This is
-- probably a little slower than it has to be.
vecToBMP :: Int -> Int -> SVec.Vector Word32 -> LBS.ByteString
vecToBMP h w = BMP.renderBMP . BMP.packRGBA32ToBMP24 w h . SVec.vectorToByteString . frobVec
  where
    frobVec vec = SVec.generate (h * w * 4) (pix vec)
    -- Byte l of the output: pixel l `div` 4 with rows flipped, and
    -- channel l `mod` 4 selected in R, G, B, A order from the ARGB word.
    pix vec l =
      let (i, j) = (l `div` 4) `divMod` w
          argb = vec SVec.! ((h - 1 - i) * w + j)
          c = (argb `shiftR` (24 - ((l + 1) `mod` 4) * 8)) .&. 0xFF
       in fromIntegral c :: Word8

-- | Encode a two-dimensional value as a BMP image.  32-bit integers are
-- ARGB pixels; floats, bytes and booleans become greyscale.  Returns
-- 'Nothing' for other types or ranks.
valueToBMP :: Value -> Maybe LBS.ByteString
valueToBMP v@(U32Value _ bytes)
  | [h, w] <- valueShape v =
      Just $ vecToBMP h w bytes
valueToBMP v@(I32Value _ bytes)
  | [h, w] <- valueShape v =
      Just $ vecToBMP h w $ SVec.map fromIntegral bytes
valueToBMP v@(F32Value _ bytes)
  | [h, w] <- valueShape v =
      Just $ vecToBMP h w $ greyFloatToImg bytes
valueToBMP v@(U8Value _ bytes)
  | [h, w] <- valueShape v =
      Just $ vecToBMP h w $ greyByteToImg bytes
valueToBMP v@(F64Value _ bytes)
  | [h, w] <- valueShape v =
      Just $ vecToBMP h w $ greyFloatToImg bytes
valueToBMP v@(BoolValue _ bytes)
  | [h, w] <- valueShape v =
      Just $ vecToBMP h w $ greyByteToImg $ SVec.map ((*) 255 . fromEnum) bytes
valueToBMP _ = Nothing

-- | Encode every element of the outermost dimension as a BMP.
valueToBMPs :: Value -> Maybe [LBS.ByteString]
valueToBMPs = mapM valueToBMP . valueElems

-- | Run an external program with the given stdin, returning its stdout.
-- Failure to start it, or a non-zero exit code, is thrown as an error
-- that includes stderr.
system ::
  (MonadIO m, MonadError T.Text m) =>
  FilePath ->
  [String] ->
  T.Text ->
  m T.Text
system prog options input = do
  res <- liftIO $ runProgramWithExitCode prog options $ T.encodeUtf8 input
  case res of
    Left err ->
      throwError $ prog' <> " failed: " <> showText err
    Right (ExitSuccess, stdout_t, _) ->
      pure $ T.pack stdout_t
    Right (ExitFailure code', _, stderr_t) ->
      throwError $
        prog'
          <> " failed with exit code "
          <> showText code'
          <> " and stderr:\n"
          <> T.pack stderr_t
  where
    prog' = "'" <> T.pack prog <> "'"

-- | Columns of whitespace-separated values, one column per input value.
formatDataForGnuplot :: [Value] -> T.Text
formatDataForGnuplot = T.unlines . map line . L.transpose . map valueElems
  where
    line = T.unwords . map prettyText

-- | Markdown for embedding an image (or other media) file.
imgBlock :: FilePath -> T.Text
imgBlock f = "![](" <> T.pack f <> ")\n"

-- | Markdown for embedding a video, with @loop@/@autoplay@ attributes
-- when they were specified.
videoBlock :: VideoParams -> FilePath -> T.Text
videoBlock opts f = "![](" <> T.pack f <> ")" <> opts' <> "\n"
  where
    -- Only the attributes that were actually set, so no stray spaces
    -- appear when just one of them is given.
    attrs = filter (not . T.null) [loop, autoplay]
    opts'
      | null attrs = mempty
      | otherwise = "{" <> T.unwords attrs <> "}"
    boolOpt s prop
      | Just b <- prop opts =
          if b then s <> "=\"true\"" else s <> "=\"false\""
      | otherwise =
          mempty
    loop = boolOpt "loop" videoLoop
    autoplay = boolOpt "autoplay" videoAutoplay

-- | A tuple of one-dimensional arrays that all have the same length.
plottable :: CompoundValue -> Maybe [Value]
plottable (ValueTuple vs) = do
  (vs', ns') <- mapAndUnzipM inspect vs
  guard $ length (nubOrd ns') == 1
  Just vs'
  where
    inspect (ValueAtom v)
      | [n] <- valueShape v = Just (v, n)
    inspect _ = Nothing
plottable _ = Nothing

-- | Write each named data set to its own temporary file, then call the
-- continuation with the names and the gnuplot assignments
-- (@name='file'@), both in the original order.
withGnuplotData ::
  [(T.Text, T.Text)] ->
  [(T.Text, [Value])] ->
  ([T.Text] -> [T.Text] -> ScriptM a) ->
  ScriptM a
withGnuplotData sets [] cont = uncurry cont $ unzip $ reverse sets
withGnuplotData sets ((f, vs) : xys) cont =
  withTempFile $ \fname -> do
    liftIO $ T.writeFile fname $ formatDataForGnuplot vs
    withGnuplotData ((f, f <> "='" <> T.pack fname <> "'") : sets) xys cont

-- | Read a BMP file as a @[height][width]@ array of ARGB pixels
-- (top row first).
loadBMP :: FilePath -> ScriptM (Compound Value)
loadBMP bmpfile = do
  res <- liftIO $ BMP.readBMP bmpfile
  case res of
    Left err ->
      throwError $ "Failed to read BMP:\n" <> showText err
    Right bmp -> do
      let bmp_bs = BMP.unpackBMPToRGBA32 bmp
          (w, h) = BMP.bmpDimensions bmp
          shape = SVec.fromList [fromIntegral h, fromIntegral w]
          pix l =
            let (i, j) = l `divMod` w
                l' = (h - 1 - i) * w + j
                r = fromIntegral $ bmp_bs `BS.index` (l' * 4)
                g = fromIntegral $ bmp_bs `BS.index` (l' * 4 + 1)
                b = fromIntegral $ bmp_bs `BS.index` (l' * 4 + 2)
                a = fromIntegral $ bmp_bs `BS.index` (l' * 4 + 3)
             in (a `shiftL` 24) .|. (r `shiftL` 16) .|. (g `shiftL` 8) .|. b
      pure $ ValueAtom $ U32Value shape $ SVec.generate (w * h) pix

-- | Load any image ImageMagick's @convert@ understands, via a temporary
-- BMP file.
loadImage :: FilePath -> ScriptM (Compound Value)
loadImage imgfile =
  withTempDir $ \dir -> do
    let bmpfile = dir </> takeBaseName imgfile `replaceExtension` "bmp"
    void $ system "convert" [imgfile, "-type", "TrueColorAlpha", bmpfile] mempty
    loadBMP bmpfile

-- | Read raw little-endian f64 samples as a @[channels][samples]@ array.
-- The channel count must be positive.
loadPCM :: Int -> FilePath -> ScriptM (Compound Value)
loadPCM num_channels pcmfile = do
  contents <- liftIO $ LBS.readFile pcmfile
  let v = SVec.byteStringToVector $ LBS.toStrict contents
      channel_length = SVec.length v `div` num_channels
      shape =
        SVec.fromList
          [ fromIntegral num_channels,
            fromIntegral channel_length
          ]
      -- The input is interleaved: sample s of channel c is at index
      -- s * num_channels + c.  `backPermuter` maps a row-major output
      -- index to that input index.
      backPermuter i = (i `mod` channel_length) * num_channels + i `div` channel_length
      -- Exactly as many elements as the shape describes; a trailing
      -- partial frame is dropped instead of being indexed out of bounds.
      perm = SVec.generate (num_channels * channel_length) backPermuter
  pure $ ValueAtom $ F64Value shape $ SVec.backpermute v perm

-- | Load an audio file (decoded with @ffmpeg@) as a
-- @[channels][samples]@ array of f64 samples.
loadAudio :: FilePath -> ScriptM (Compound Value)
loadAudio audiofile = do
  s <-
    system
      "ffprobe"
      [audiofile, "-show_entries", "stream=channels", "-select_streams", "a", "-of", "compact=p=0:nk=1", "-v", "0"]
      mempty
  case T.decimal s of
    -- Zero channels would make 'loadPCM' divide by zero.
    Right (num_channels, _)
      | num_channels > 0 ->
          withTempDir $ \dir -> do
            let pcmfile = dir </> takeBaseName audiofile `replaceExtension` "pcm"
            void $
              system
                "ffmpeg"
                ["-i", audiofile, "-c:a", "pcm_f64le", "-map", "0", "-f", "data", pcmfile]
                mempty
            loadPCM num_channels pcmfile
    _ ->
      throwError "$loadaudio failed to detect the number of channels in the audio input"

-- | Builtins available to @futhark literate@ scripts: @$loadimg@ and
-- @$loadaudio@ (taking a file name as a byte array), plus everything
-- provided by 'scriptBuiltin'.
literateBuiltin :: EvalBuiltin ScriptM
literateBuiltin "loadimg" vs =
  case vs of
    [ValueAtom v]
      | Just path <- getValue v -> do
          let path' = map (chr . fromIntegral) (path :: [Word8])
          loadImage path'
    _ ->
      throwError $
        "$loadimg does not accept arguments of types: "
          <> T.intercalate ", " (map (prettyText . fmap valueType) vs)
literateBuiltin "loadaudio" vs =
  case vs of
    [ValueAtom v]
      | Just path <- getValue v -> do
          let path' = map (chr . fromIntegral) (path :: [Word8])
          loadAudio path'
    _ ->
      throwError $
        "$loadaudio does not accept arguments of types: "
          <> T.intercalate ", " (map (prettyText . fmap valueType) vs)
literateBuiltin f vs =
  scriptBuiltin "." f vs

-- | Some of these only make sense for @futhark literate@, but enough
-- are also sensible for @futhark script@ that we can share them.
data Options = Options
  { scriptBackend :: String,
    scriptFuthark :: Maybe FilePath,
    scriptExtraOptions :: [String],
    scriptCompilerOptions :: [String],
    scriptSkipCompilation :: Bool,
    scriptOutput :: Maybe FilePath,
    scriptVerbose :: Int,
    scriptStopOnError :: Bool,
    scriptBinary :: Bool,
    scriptExps :: [Either FilePath T.Text]
  }

-- | The configuration before any user-provided options are processed.
initialOptions :: Options
initialOptions =
  Options
    { scriptBackend = "c",
      scriptFuthark = Nothing,
      scriptExtraOptions = [],
      scriptCompilerOptions = [],
      scriptSkipCompilation = False,
      scriptOutput = Nothing,
      scriptVerbose = 0,
      scriptStopOnError = False,
      scriptBinary = False,
      scriptExps = []
    }

-- | Environment for processing directives.
data Env = Env
  { -- | Directory in which generated files are placed.
    envImgDir :: FilePath,
    envOpts :: Options,
    envServer :: ScriptServer,
    -- | Used as the prefix of generated file names.
    envHash :: T.Text
  }

-- | Return the path of an output file in the image directory, running
-- the generator on it only if the file does not already exist.  The
-- name is the desired one if given, otherwise the environment hash
-- followed by the template.  The file is recorded in the state.
newFile ::
  Env ->
  (Maybe FilePath, FilePath) ->
  (FilePath -> ScriptM ()) ->
  ScriptM FilePath
newFile env (fname_desired, template) m = do
  let fname_base = fromMaybe (T.unpack (envHash env) <> "-" <> template) fname_desired
      fname = envImgDir env </> fname_base
  exists <- liftIO $ doesFileExist fname
  liftIO $ createDirectoryIfMissing True $ envImgDir env
  when (exists && scriptVerbose (envOpts env) > 0) $
    liftIO . T.hPutStrLn stderr $
      "Using existing file: " <> T.pack fname
  unless exists $ do
    when (scriptVerbose (envOpts env) > 0) $
      liftIO . T.hPutStrLn stderr $
        "Generating new file: " <> T.pack fname
    m fname
  modify $ \s -> s {stateFiles = S.insert fname $ stateFiles s}
  pure fname

-- | Like 'newFile', but return the contents of the file.
newFileContents ::
  Env ->
  (Maybe FilePath, FilePath) ->
  (FilePath -> ScriptM ()) ->
  ScriptM T.Text
newFileContents env f m =
  liftIO . T.readFile =<< newFile env f m

-- | Execute a directive, returning the Markdown that shows its result.
processDirective :: Env -> Directive -> ScriptM T.Text
processDirective env (DirectiveBrief d) =
  processDirective env d
processDirective env (DirectiveCovert d) =
  processDirective env d
processDirective env (DirectiveRes e) = do
  result <- newFileContents env (Nothing, "eval.txt") $ \resultf -> do
    v <- either nope pure =<< evalExpToGround literateBuiltin (envServer env) e
    liftIO $ T.writeFile resultf $ prettyText v
  pure $ T.unlines ["```", result, "```"]
  where
    nope t =
      throwError $ "Cannot show value of type " <> prettyText t
--
processDirective env (DirectiveImg e params) = do
  fmap imgBlock . newFile env (imgFile params, "img.png") $ \pngfile -> do
    maybe_v <- evalExpToGround literateBuiltin (envServer env) e
    case maybe_v of
      Right (ValueAtom v)
        | Just bmp <- valueToBMP v -> do
            withTempDir $ \dir -> do
              let bmpfile = dir </> "img.bmp"
              liftIO $ LBS.writeFile bmpfile bmp
              void $ system "convert" [bmpfile, pngfile] mempty
      Right v ->
        nope $ fmap valueType v
      Left t ->
        nope t
  where
    nope t =
      throwError $
        "Cannot create image from value of type " <> prettyText t
--
processDirective env (DirectivePlot e size) = do
  fmap imgBlock . newFile env (Nothing, "plot.png") $ \pngfile -> do
    maybe_v <- evalExpToGround literateBuiltin (envServer env) e
    case maybe_v of
      Right v
        | Just vs <- plottable2d v ->
            plotWith [(Nothing, vs)] pngfile
      Right (ValueRecord m)
        | Just m' <- traverse plottable2d m -> do
            plotWith (map (first Just) $ M.toList m') pngfile
      Right v ->
        throwError $ "Cannot plot value of type " <> prettyText (fmap valueType v)
      Left t ->
        throwError $ "Cannot plot opaque value of type " <> prettyText t
  where
    plottable2d v = do
      [x, y] <- plottable v
      Just [x, y]

    tag (Nothing, xys) j = ("data" <> showText (j :: Int), xys)
    tag (Just f, xys) _ = (f, xys)

    plotWith xys pngfile =
      withGnuplotData [] (zipWith tag xys [0 ..]) $ \fs sets -> do
        let size' = T.pack $
              case size of
                Nothing -> "500,500"
                Just (w, h) -> show w ++ "," ++ show h
            plotCmd f title =
              let title' = case title of
                    Nothing -> "notitle"
                    Just x -> "title '" <> x <> "'"
               in f <> " " <> title' <> " with lines"
            cmds = T.intercalate ", " (zipWith plotCmd fs (map fst xys))
            script =
              T.unlines
                [ "set terminal png size " <> size' <> " enhanced",
                  "set output '" <> T.pack pngfile <> "'",
                  "set key outside",
                  T.unlines sets,
                  "plot " <> cmds
                ]
        void $ system "gnuplot" [] script
--
processDirective env (DirectiveGnuplot e script) = do
  fmap imgBlock . newFile env (Nothing, "plot.png") $ \pngfile -> do
    maybe_v <- evalExpToGround literateBuiltin (envServer env) e
    case maybe_v of
      Right (ValueRecord m)
        | Just m' <- traverse plottable m ->
            plotWith (M.toList m') pngfile
      Right v ->
        throwError $ "Cannot plot value of type " <> prettyText (fmap valueType v)
      Left t ->
        throwError $ "Cannot plot opaque value of type " <> prettyText t
  where
    plotWith xys pngfile = withGnuplotData [] xys $ \_ sets -> do
      let script' =
            T.unlines
              [ "set terminal png enhanced",
                "set output '" <> T.pack pngfile <> "'",
                T.unlines sets,
                script
              ]
      void $ system "gnuplot" [] script'
--
processDirective env (DirectiveVideo e params) = do
  unless (format `elem` ["webm", "gif"]) $
    throwError $
      "Unknown video format: " <> format

  let file = (videoFile params, "video" <.> T.unpack format)
  fmap (videoBlock params) . newFile env file $ \videofile -> do
    v <- evalExp literateBuiltin (envServer env) e
    let nope =
          throwError $
            "Cannot produce video from value of type " <> prettyText (fmap scriptValueType v)
    case v of
      -- An array of frames.
      ValueAtom SValue {} -> do
        ValueAtom arr <- getExpValue (envServer env) v
        case valueToBMPs arr of
          Nothing -> nope
          Just bmps ->
            withTempDir $ \dir -> do
              zipWithM_ (writeBMPFile dir) [0 ..] bmps
              onWebM videofile =<< bmpsToVideo dir
      -- A step function, an initial state and a frame count.
      ValueTuple [stepfun, initial, num_frames]
        | ValueAtom (SFun stepfun' _ [_, _] closure) <- stepfun,
          ValueAtom (SValue "i64" _) <- num_frames -> do
            Just (ValueAtom num_frames') <-
              mapM getValue <$> getExpValue (envServer env) num_frames
            withTempDir $ \dir -> do
              let num_frames_int = fromIntegral (num_frames' :: Int64)
              renderFrames dir (stepfun', map ValueAtom closure) initial num_frames_int
              onWebM videofile =<< bmpsToVideo dir
      _ ->
        nope
  where
    framerate = fromMaybe 30 $ videoFPS params
    format = fromMaybe "webm" $ videoFormat params
    bmpfile dir j = dir </> printf "frame%010d.bmp" (j :: Int)

    (progressStep, progressDone)
      | fancyTerminal,
        scriptVerbose (envOpts env) > 0 =
          ( \j num_frames -> liftIO $ do
              T.putStr $
                "\r"
                  <> progressBar
                    (ProgressBar 40 (fromIntegral num_frames - 1) (fromIntegral j))
                  <> "generating frame "
                  <> prettyText (j + 1)
                  <> "/"
                  <> prettyText num_frames
                  <> " "
              hFlush stdout,
            liftIO $ T.putStrLn ""
          )
      | otherwise =
          (\_ _ -> pure (), pure ())

    renderFrames dir (stepfun, closure) initial num_frames = do
      foldM_ frame initial [0 .. num_frames - 1]
      progressDone
      where
        -- Apply the step function to the previous state, write the
        -- produced frame, and free the server-side values no longer
        -- needed.
        frame old_state j = do
          progressStep j num_frames
          v <-
            evalExp literateBuiltin (envServer env)
              . Call (FuncFut stepfun)
              . map valueToExp
              $ closure ++ [old_state]
          freeValue (envServer env) old_state

          let nope =
                throwError $
                  "Cannot handle step function return type: "
                    <> prettyText (fmap scriptValueType v)

          case v of
            ValueTuple [arr_v@(ValueAtom SValue {}), new_state] -> do
              ValueAtom arr <- getExpValue (envServer env) arr_v
              freeValue (envServer env) arr_v
              case valueToBMP arr of
                Nothing -> nope
                Just bmp -> do
                  writeBMPFile dir j bmp
                  pure new_state
            _ -> nope

    writeBMPFile dir j bmp =
      liftIO $ LBS.writeFile (bmpfile dir j) bmp

    bmpsToVideo dir = do
      void $
        system
          "ffmpeg"
          [ "-y",
            "-r",
            show framerate,
            "-i",
            dir </> "frame%010d.bmp",
            "-c:v",
            "libvpx-vp9",
            "-pix_fmt",
            "yuv420p",
            "-b:v",
            "2M",
            dir </> "video.webm"
          ]
          mempty
      pure $ dir </> "video.webm"

    onWebM videofile webmfile
      | format == "gif" =
          void $ system "ffmpeg" ["-i", webmfile, videofile] mempty
      | otherwise =
          liftIO $ copyFile webmfile videofile
--
processDirective env (DirectiveAudio e params) = do
  fmap imgBlock . newFile env (Nothing, "output." <> T.unpack output_format) $ \audiofile -> do
    withTempDir $ \dir -> do
      maybe_v <- evalExpToGround literateBuiltin (envServer env) e
      (input_format, raw_files) <- toRawFiles dir maybe_v
      -- One ffmpeg input per channel, merged into a single stream.
      void $
        system
          "ffmpeg"
          ( concatMap
              ( \raw_file ->
                  [ "-f",
                    input_format,
                    "-ar",
                    show sampling_frequency,
                    "-i",
                    raw_file
                  ]
              )
              raw_files
              ++ [ "-f",
                   T.unpack output_format,
                   "-filter_complex",
                   concatMap
                     (\i -> "[" <> show i <> ":a]")
                     [0 .. length raw_files - 1]
                     <> "amerge=inputs="
                     <> show (length raw_files)
                     <> "[a]",
                   "-map",
                   "[a]",
                   audiofile
                 ]
          )
          mempty
  where
    writeRaw dir name v = do
      let rawfile = dir </> name
      let Just bytes = toBytes v
      liftIO $ LBS.writeFile rawfile $ LBS.fromStrict bytes
    -- A one-dimensional array is a single channel; a two-dimensional
    -- array has one channel per row.
    toRawFiles dir (Right (ValueAtom v))
      | length (valueShape v) == 1,
        Just input_format <- toFfmpegFormat v = do
          writeRaw dir "raw.pcm" v
          pure (input_format, [dir </> "raw.pcm"])
      | length (valueShape v) == 2,
        Just input_format <- toFfmpegFormat v = do
          (input_format,)
            <$> zipWithM
              ( \v' i -> do
                  let file_name = "raw-" <> show i <> ".pcm"
                  writeRaw dir file_name v'
                  pure $ dir </> file_name
              )
              (valueElems v)
              [0 :: Int ..]
    toRawFiles _ v = nope $ fmap (fmap valueType) v
    toFfmpegFormat I8Value {} = Just "s8"
    toFfmpegFormat U8Value {} = Just "u8"
    toFfmpegFormat I16Value {} = Just "s16le"
    toFfmpegFormat U16Value {} = Just "u16le"
    toFfmpegFormat I32Value {} = Just "s32le"
    toFfmpegFormat U32Value {} = Just "u32le"
    toFfmpegFormat F32Value {} = Just "f32le"
    toFfmpegFormat F64Value {} = Just "f64le"
    toFfmpegFormat _ = Nothing
    toBytes (I8Value _ bytes) = Just $ SVec.vectorToByteString bytes
    toBytes (U8Value _ bytes) = Just $ SVec.vectorToByteString bytes
    toBytes (I16Value _ bytes) = Just $ SVec.vectorToByteString bytes
    toBytes (U16Value _ bytes) = Just $ SVec.vectorToByteString bytes
    toBytes (I32Value _ bytes) = Just $ SVec.vectorToByteString bytes
    toBytes (U32Value _ bytes) = Just $ SVec.vectorToByteString bytes
    toBytes (F32Value _ bytes) = Just $ SVec.vectorToByteString bytes
    toBytes (F64Value _ bytes) = Just $ SVec.vectorToByteString bytes
    toBytes _ = Nothing
    output_format = fromMaybe "wav" $ audioCodec params
    sampling_frequency = fromMaybe 44100 $ audioSamplingFrequency params
    nope _ = throwError "Cannot create audio from value"

-- Did this script block succeed or fail?
-- Note: constructor order is significant.  The derived 'Ord' instance
-- places 'Failure' below 'Success', so the minimum over a list of
-- outcomes is 'Failure' as soon as any element is 'Failure'.
data Failure = Failure | Success
  deriving (Eq, Ord, Show)

-- | Turn a single block of a literate program into Markdown.
--
-- Returns whether the block succeeded, the Markdown text to emit for
-- it, and the files that were produced while processing it.  Code and
-- comment blocks always succeed and never produce files.
processBlock :: Env -> Block -> IO (Failure, T.Text, Files)
processBlock _ (BlockCode code) =
  pure $
    if T.null code
      then (Success, mempty, mempty)
      else (Success, "```futhark\n" <> code <> "```\n", mempty)
processBlock _ (BlockComment pretty) =
  pure (Success, pretty, mempty)
processBlock env (BlockDirective directive text) = do
  when (scriptVerbose (envOpts env) > 0)
    . T.hPutStrLn stderr
    . PP.docText
    $ "Processing " <> PP.align (PP.pretty directive) <> "..."
  (outcome, files) <- runScriptM $ processDirective directiveEnv directive
  either (reportFailure files) (reportSuccess files) outcome
  where
    -- Mix the directive's pretty-printed text into the running hash,
    -- so each directive is processed with its own hash value.
    directiveEnv =
      env {envHash = hashText (envHash env <> prettyText directive)}

    -- The Markdown echo of the directive itself: nothing for covert
    -- directives, a pretty-printed form for brief ones, and the
    -- original directive text otherwise.
    prompt
      | DirectiveCovert _ <- directive = mempty
      | DirectiveBrief _ <- directive =
          "```\n" <> PP.docText (pprDirective False directive) <> "\n```\n"
      | otherwise = "```\n" <> text <> "```\n"

    reportSuccess files output =
      pure (Success, prompt <> "\n" <> output, files)

    -- Log the error to stderr, abort the whole run if the user asked
    -- for that, and otherwise embed the error in the Markdown output.
    reportFailure files err = do
      T.hPutStr stderr $
        prettyTextOneLine directive <> " failed:\n" <> err <> "\n"
      when (scriptStopOnError (envOpts env)) exitFailure
      pure
        ( Failure,
          T.unlines [prompt, "**FAILED**", "```", err, "```"],
          files
        )

-- | Delete every entry in the environment's image directory that is
-- not a member of @keep_files@.  A missing image directory is treated
-- as empty; any other error while listing it is rethrown.
cleanupImgDir :: Env -> Files -> IO ()
cleanupImgDir env keep_files = do
  existing <- directoryContents (envImgDir env) `catchError` ignoreMissing
  mapM_ removeUnused $ filter (\f -> not (f `S.member` keep_files)) existing
  where
    ignoreMissing e
      | isDoesNotExistError e = pure []
      | otherwise = throwError e
    removeUnused f = do
      when (scriptVerbose (envOpts env) > 0) $
        T.hPutStrLn stderr ("Deleting unused file: " <> T.pack f)
      removePathForcibly f

-- | Process all blocks of a script in order, then delete any file in
-- the image directory that no block produced.  The overall outcome is
-- 'Failure' if any block failed.  The per-block Markdown is joined with
-- newlines.
processScript :: Env -> [Block] -> IO (Failure, T.Text)
processScript env script = do
  results <- mapM (processBlock env) script
  let (failures, outputs, files) = unzip3 results
  cleanupImgDir env (mconcat files)
  pure (minimum (Success : failures), T.intercalate "\n" outputs)

-- | Common command line options that transform 'Options'.
scriptCommandLineOptions :: [FunOptDescr Options]
scriptCommandLineOptions =
  [ Option
      []
      ["backend"]
      (ReqArg (Right . setBackend) "PROGRAM")
      "The compiler used (defaults to 'c').",
    Option
      []
      ["futhark"]
      (ReqArg (Right . setFuthark) "PROGRAM")
      "The binary used for operations (defaults to same binary as 'futhark script').",
    Option
      "p"
      ["pass-option"]
      (ReqArg (Right . addExtraOption) "OPT")
      "Pass this option to programs being run.",
    Option
      []
      ["pass-compiler-option"]
      (ReqArg (Right . addCompilerOption) "OPT")
      "Pass this option to the compiler.",
    Option
      []
      ["skip-compilation"]
      (NoArg (Right skipCompilation))
      "Use already compiled program.",
    Option
      "v"
      ["verbose"]
      (NoArg (Right increaseVerbosity))
      "Enable logging. Pass multiple times for more."
  ]
  where
    -- Each helper is a pure update of the 'Options' record.
    setBackend backend config = config {scriptBackend = backend}
    setFuthark prog config = config {scriptFuthark = Just prog}
    -- Repeated options are prepended, so the list ends up in reverse
    -- command-line order.
    addExtraOption opt config =
      config {scriptExtraOptions = opt : scriptExtraOptions config}
    addCompilerOption opt config =
      config {scriptCompilerOptions = opt : scriptCompilerOptions config}
    skipCompilation config = config {scriptSkipCompilation = True}
    increaseVerbosity config =
      config {scriptVerbose = scriptVerbose config + 1}

-- | All options accepted by @futhark literate@: the shared script
-- options plus those that only concern the Markdown output.
commandLineOptions :: [FunOptDescr Options]
commandLineOptions =
  scriptCommandLineOptions ++ literateOnly
  where
    literateOnly =
      [ Option
          "o"
          ["output"]
          (ReqArg (Right . setOutput) "FILE")
          "Override output file. Image directory is set to basename appended with -img/.",
        Option
          []
          ["stop-on-error"]
          (NoArg (Right stopOnError))
          "Stop and do not produce output file if any directive fails."
      ]
    setOutput file config = config {scriptOutput = Just file}
    stopOnError config = config {scriptStopOnError = True}

-- | Start up (and eventually shut down) a Futhark server
-- corresponding to the provided program.  If the program has a @.fut@
-- extension, it will be compiled automatically.
prepareServer :: FilePath -> Options -> (ScriptServer -> IO a) -> IO a
prepareServer prog opts f = do
  futhark <- maybe getExecutablePath pure (scriptFuthark opts)
  when (is_fut && not (scriptSkipCompilation opts)) $
    compileWith futhark
  withScriptServer server_cfg f
  where
    verbosity = scriptVerbose opts
    is_fut = takeExtension prog == ".fut"
    compile_options = "--server" : scriptCompilerOptions opts

    -- Compile the program in server mode with the given futhark
    -- binary.  On a compilation error, print the message to stderr
    -- and terminate the process.
    compileWith futhark = do
      when (verbosity > 0) $
        T.hPutStrLn stderr ("Compiling " <> T.pack prog <> "...")
      when (verbosity > 1) $
        T.hPutStrLn stderr (T.pack (unwords compile_options))
      res <-
        runExceptT $
          compileProgram compile_options (FutharkExe futhark) (scriptBackend opts) prog
      case res of
        Left err -> do
          T.hPutStrLn stderr err
          exitFailure
        Right _ -> pure ()

    -- A .fut source is run via the executable named after it without
    -- the extension; any other file is taken to be the executable.
    server_prog
      | is_fut = dropExtension prog
      | otherwise = prog

    -- Server line callback: lines whose first argument is "call" or
    -- "startup" are echoed to stdout; everything else is dropped.
    echoLine "call" l = T.putStrLn l
    echoLine "startup" l = T.putStrLn l
    echoLine _ _ = pure ()

    -- The "." prefix presumably makes a relative program name run as
    -- a path rather than be looked up on $PATH -- NOTE(review): confirm.
    server_cfg =
      (futharkServerCfg ("." </> server_prog) (scriptExtraOptions opts))
        { cfgOnLine = if verbosity > 0 then echoLine else \_ _ -> pure ()
        }

-- | Run @futhark literate@.
main :: String -> [String] -> IO () main = mainWithOptions initialOptions commandLineOptions "program" $ \args opts -> case args of [prog] -> Just $ do futhark <- maybe getExecutablePath pure $ scriptFuthark opts let onError err = do T.hPutStrLn stderr err exitFailure proghash <- either onError pure <=< runExceptT $ system futhark ["hash", prog] mempty script <- parseProgFile prog orig_dir <- getCurrentDirectory let entryOpt v = "--entry-point=" ++ T.unpack v opts' = opts { scriptCompilerOptions = map entryOpt (S.toList (varsInScripts script)) <> scriptCompilerOptions opts } prepareServer prog opts' $ \server -> do let mdfile = fromMaybe (prog `replaceExtension` "md") $ scriptOutput opts prog_dir = takeDirectory prog imgdir = dropExtension (takeFileName mdfile) <> "-img" env = Env { envServer = server, envOpts = opts, envHash = proghash, envImgDir = imgdir } when (scriptVerbose opts > 0) $ do T.hPutStrLn stderr $ "Executing from " <> T.pack prog_dir setCurrentDirectory prog_dir (failure, md) <- processScript env script T.writeFile (orig_dir mdfile) md when (failure == Failure) exitFailure _ -> Nothing futhark-0.25.27/src/Futhark/CLI/Main.hs000066400000000000000000000145621475065116200173740ustar00rootroot00000000000000-- | The main function for the @futhark@ command line program. 
module Futhark.CLI.Main (main) where import Control.Exception import Data.List (sortOn) import Data.Maybe import Data.Text.IO qualified as T import Futhark.CLI.Autotune qualified as Autotune import Futhark.CLI.Bench qualified as Bench import Futhark.CLI.Benchcmp qualified as Benchcmp import Futhark.CLI.C qualified as C import Futhark.CLI.CUDA qualified as CCUDA import Futhark.CLI.Check qualified as Check import Futhark.CLI.Datacmp qualified as Datacmp import Futhark.CLI.Dataset qualified as Dataset import Futhark.CLI.Defs qualified as Defs import Futhark.CLI.Dev qualified as Dev import Futhark.CLI.Doc qualified as Doc import Futhark.CLI.Eval qualified as Eval import Futhark.CLI.Fmt qualified as Fmt import Futhark.CLI.HIP qualified as HIP import Futhark.CLI.LSP qualified as LSP import Futhark.CLI.Literate qualified as Literate import Futhark.CLI.Misc qualified as Misc import Futhark.CLI.Multicore qualified as Multicore import Futhark.CLI.MulticoreISPC qualified as MulticoreISPC import Futhark.CLI.MulticoreWASM qualified as MulticoreWASM import Futhark.CLI.OpenCL qualified as OpenCL import Futhark.CLI.Pkg qualified as Pkg import Futhark.CLI.Profile qualified as Profile import Futhark.CLI.PyOpenCL qualified as PyOpenCL import Futhark.CLI.Python qualified as Python import Futhark.CLI.Query qualified as Query import Futhark.CLI.REPL qualified as REPL import Futhark.CLI.Run qualified as Run import Futhark.CLI.Script qualified as Script import Futhark.CLI.Test qualified as Test import Futhark.CLI.WASM qualified as WASM import Futhark.Error import Futhark.Util (maxinum, showText) import Futhark.Util.Options import GHC.IO.Encoding (setLocaleEncoding) import GHC.IO.Exception (IOErrorType (..), IOException (..)) import System.Environment import System.Exit import System.IO import Prelude type Command = String -> [String] -> IO () commands :: [(String, (Command, String))] commands = sortOn fst [ ("dev", (Dev.main, "Run compiler passes directly.")), ("eval", (Eval.main, 
"Evaluate Futhark expressions passed in as arguments")), ("repl", (REPL.main, "Run interactive Read-Eval-Print-Loop.")), ("run", (Run.main, "Run a program through the (slow!) interpreter.")), ("c", (C.main, "Compile to sequential C.")), ("opencl", (OpenCL.main, "Compile to C calling OpenCL.")), ("cuda", (CCUDA.main, "Compile to C calling CUDA.")), ("hip", (HIP.main, "Compile to C calling HIP.")), ("multicore", (Multicore.main, "Compile to multicore C.")), ("python", (Python.main, "Compile to sequential Python.")), ("pyopencl", (PyOpenCL.main, "Compile to Python calling PyOpenCL.")), ("wasm", (WASM.main, "Compile to WASM with sequential C")), ("wasm-multicore", (MulticoreWASM.main, "Compile to WASM with multicore C")), ("ispc", (MulticoreISPC.main, "Compile to multicore ISPC")), ("test", (Test.main, "Test Futhark programs.")), ("bench", (Bench.main, "Benchmark Futhark programs.")), ("dataset", (Dataset.main, "Generate random test data.")), ("datacmp", (Datacmp.main, "Compare Futhark data files for equality.")), ("dataget", (Misc.mainDataget, "Extract test data.")), ("doc", (Doc.main, "Generate documentation for Futhark code.")), ("pkg", (Pkg.main, "Manage local packages.")), ("check", (Check.main, "Type-check a program.")), ("check-syntax", (Misc.mainCheckSyntax, "Syntax-check a program.")), ("imports", (Misc.mainImports, "Print all non-builtin imported Futhark files.")), ("hash", (Misc.mainHash, "Print hash of program AST.")), ("autotune", (Autotune.main, "Autotune threshold parameters.")), ("defs", (Defs.main, "Show location and name of all definitions.")), ("query", (Query.main, "Query semantic information about program.")), ("literate", (Literate.main, "Process a literate Futhark program.")), ("script", (Script.main, "Run FutharkScript expressions.")), ("lsp", (LSP.main, "Run LSP server.")), ("thanks", (Misc.mainThanks, "Express gratitude.")), ("tokens", (Misc.mainTokens, "Print tokens from Futhark file.")), ("benchcmp", (Benchcmp.main, "Compare two benchmark 
results.")), ("profile", (Profile.main, "Analyse profiling data.")), ("fmt", (Fmt.main, "Reformat Futhark source file.")) ] msg :: String msg = unlines $ [" options...", "Commands:", ""] ++ [ " " <> cmd <> replicate (k - length cmd) ' ' <> desc | (cmd, (_, desc)) <- commands ] where k = maxinum (map (length . fst) commands) + 3 -- | Catch all IO exceptions and print a better error message if they -- happen. reportingIOErrors :: IO () -> IO () reportingIOErrors = flip catches [ Handler onExit, Handler onICE, Handler onIOException, Handler onError ] where onExit :: ExitCode -> IO () onExit = throwIO onICE :: InternalError -> IO () onICE (Error CompilerLimitation s) = do T.hPutStrLn stderr "Known compiler limitation encountered. Sorry." T.hPutStrLn stderr "Revise your program or try a different Futhark compiler." T.hPutStrLn stderr s exitWith $ ExitFailure 1 onICE (Error CompilerBug s) = do T.hPutStrLn stderr "Internal compiler error." T.hPutStrLn stderr "Please report this at https://github.com/diku-dk/futhark/issues." T.hPutStrLn stderr s exitWith $ ExitFailure 1 onError :: SomeException -> IO () onError e | Just UserInterrupt <- asyncExceptionFromException e = pure () -- This corresponds to CTRL-C, which is not an error. | otherwise = do T.hPutStrLn stderr "Internal compiler error (unhandled IO exception)." T.hPutStrLn stderr "Please report this at https://github.com/diku-dk/futhark/issues" T.hPutStrLn stderr $ showText e exitWith $ ExitFailure 1 onIOException :: IOException -> IO () onIOException e | ioe_type e == ResourceVanished = exitWith $ ExitFailure 1 | otherwise = throw e -- | The @futhark@ executable. main :: IO () main = reportingIOErrors $ do hSetEncoding stdout utf8 hSetEncoding stderr utf8 setLocaleEncoding utf8 args <- getArgs prog <- getProgName case args of cmd : args' | Just (m, _) <- lookup cmd commands -> m (unwords [prog, cmd]) args' _ -> mainWithOptions () [] msg (const . 
const Nothing) prog args futhark-0.25.27/src/Futhark/CLI/Misc.hs000066400000000000000000000105271475065116200174000ustar00rootroot00000000000000-- | Various small subcommands that are too simple to deserve their own file. module Futhark.CLI.Misc ( mainImports, mainHash, mainDataget, mainCheckSyntax, mainThanks, mainTokens, ) where import Control.Monad import Control.Monad.State import Data.ByteString.Lazy qualified as BS import Data.Function (on) import Data.List (nubBy) import Data.Loc (L (..), startPos) import Data.Text qualified as T import Data.Text.IO qualified as T import Futhark.Compiler import Futhark.Test import Futhark.Util (hashText, interactWithFileSafely) import Futhark.Util.Options import Futhark.Util.Pretty (prettyTextOneLine) import Language.Futhark.Parser.Lexer (scanTokensText) import Language.Futhark.Prop (isBuiltin) import Language.Futhark.Semantic (includeToString) import System.Environment (getExecutablePath) import System.Exit import System.FilePath import System.IO import System.Random -- | @futhark imports@ mainImports :: String -> [String] -> IO () mainImports = mainWithOptions () [] "program" $ \args () -> case args of [file] -> Just $ do (_, prog_imports, _) <- readProgramOrDie file liftIO . putStr . unlines . map (++ ".fut") . filter (not . isBuiltin) $ map (includeToString . fst) prog_imports _ -> Nothing -- | @futhark hash@ mainHash :: String -> [String] -> IO () mainHash = mainWithOptions () [] "program" $ \args () -> case args of [file] -> Just $ do prog <- filter (not . isBuiltin . fst) <$> readUntypedProgramOrDie file -- The 'map snd' is an attempt to get rid of the file names so -- they won't affect the hashing. 
liftIO $ T.putStrLn $ hashText $ prettyTextOneLine $ map snd prog _ -> Nothing -- | @futhark dataget@ mainDataget :: String -> [String] -> IO () mainDataget = mainWithOptions () [] "program dataset" $ \args () -> case args of [file, dataset] -> Just $ dataget file $ T.pack dataset _ -> Nothing where dataget prog dataset = do let dir = takeDirectory prog runs <- testSpecRuns <$> testSpecFromProgramOrDie prog let exact = filter ((dataset ==) . runDescription) runs infixes = filter ((dataset `T.isInfixOf`) . runDescription) runs futhark <- FutharkExe <$> getExecutablePath case nubBy ((==) `on` runDescription) $ if null exact then infixes else exact of [x] -> BS.putStr =<< getValuesBS futhark dir (runInput x) [] -> do T.hPutStr stderr $ "No dataset '" <> dataset <> "'.\n" T.hPutStr stderr "Available datasets:\n" mapM_ (T.hPutStrLn stderr . (" " <>) . runDescription) runs exitFailure runs' -> do T.hPutStr stderr $ "Dataset '" <> dataset <> "' ambiguous:\n" mapM_ (T.hPutStrLn stderr . (" " <>) . runDescription) runs' exitFailure testSpecRuns = testActionRuns . testAction testActionRuns CompileTimeFailure {} = [] testActionRuns (RunCases ios _ _) = concatMap iosTestRuns ios -- | @futhark check-syntax@ mainCheckSyntax :: String -> [String] -> IO () mainCheckSyntax = mainWithOptions () [] "program" $ \args () -> case args of [file] -> Just $ void $ readUntypedProgramOrDie file _ -> Nothing -- | @futhark thanks@ mainThanks :: String -> [String] -> IO () mainThanks = mainWithOptions () [] "" $ \args () -> case args of [] -> Just $ do i <- randomRIO (0, n - 1) putStrLn $ responses !! i _ -> Nothing where n = length responses responses = [ "You're welcome!", "Tell all your friends about Futhark!", "Likewise!", "And thank you in return for trying the language!", "It's our pleasure!", "Have fun with Futhark!" 
] -- | @futhark tokens@ mainTokens :: String -> [String] -> IO () mainTokens = mainWithOptions () [] "program" $ \args () -> case args of [file] -> Just $ do res <- interactWithFileSafely (scanTokensText (startPos file) <$> T.readFile file) case res of Nothing -> do hPutStrLn stderr $ file <> ": file not found." exitWith $ ExitFailure 2 Just (Left e) -> do hPrint stderr e exitWith $ ExitFailure 2 Just (Right (Left e)) -> do hPrint stderr e exitWith $ ExitFailure 2 Just (Right (Right tokens)) -> mapM_ printToken tokens _ -> Nothing where printToken (L _ token) = print token futhark-0.25.27/src/Futhark/CLI/Multicore.hs000066400000000000000000000007461475065116200204520ustar00rootroot00000000000000-- | @futhark multicore@ module Futhark.CLI.Multicore (main) where import Futhark.Actions (compileMulticoreAction) import Futhark.Compiler.CLI import Futhark.Passes (mcmemPipeline) -- | Run @futhark multicore@. main :: String -> [String] -> IO () main = compilerMain () [] "Compile to multicore C" "Generate multicore C code from optimised Futhark program." mcmemPipeline $ \fcfg () mode outpath prog -> actionProcedure (compileMulticoreAction fcfg mode outpath) prog futhark-0.25.27/src/Futhark/CLI/MulticoreISPC.hs000066400000000000000000000007741475065116200211320ustar00rootroot00000000000000-- | @futhark multicore@ module Futhark.CLI.MulticoreISPC (main) where import Futhark.Actions (compileMulticoreToISPCAction) import Futhark.Compiler.CLI import Futhark.Passes (mcmemPipeline) -- | Run @futhark multicore@. main :: String -> [String] -> IO () main = compilerMain () [] "Compile to multicore ISPC" "Generate multicore ISPC code from optimised Futhark program." 
mcmemPipeline $ \fcfg () mode outpath prog -> actionProcedure (compileMulticoreToISPCAction fcfg mode outpath) prog futhark-0.25.27/src/Futhark/CLI/MulticoreWASM.hs000066400000000000000000000010251475065116200211310ustar00rootroot00000000000000-- | @futhark wasm-multicore@ module Futhark.CLI.MulticoreWASM (main) where import Futhark.Actions (compileMulticoreToWASMAction) import Futhark.Compiler.CLI import Futhark.Passes (mcmemPipeline) -- | Run @futhark c@ main :: String -> [String] -> IO () main = compilerMain () [] "Compile to multicore WASM" "Generate multicore WASM with the multicore C backend code from optimised Futhark program." mcmemPipeline $ \fcfg () mode outpath prog -> actionProcedure (compileMulticoreToWASMAction fcfg mode outpath) prog futhark-0.25.27/src/Futhark/CLI/OpenCL.hs000066400000000000000000000007151475065116200176230ustar00rootroot00000000000000-- | @futhark opencl@ module Futhark.CLI.OpenCL (main) where import Futhark.Actions (compileOpenCLAction) import Futhark.Compiler.CLI import Futhark.Passes (gpumemPipeline) -- | Run @futhark opencl@ main :: String -> [String] -> IO () main = compilerMain () [] "Compile OpenCL" "Generate OpenCL/C code from optimised Futhark program." 
gpumemPipeline $ \fcfg () mode outpath prog -> actionProcedure (compileOpenCLAction fcfg mode outpath) prog futhark-0.25.27/src/Futhark/CLI/Pkg.hs000066400000000000000000000333461475065116200172320ustar00rootroot00000000000000-- | @futhark pkg@ module Futhark.CLI.Pkg (main) where import Control.Monad import Control.Monad.IO.Class import Control.Monad.Reader import Control.Monad.State import Data.List (intercalate) import Data.Map qualified as M import Data.Maybe import Data.Monoid import Data.Text qualified as T import Data.Text.IO qualified as T import Futhark.Pkg.Info import Futhark.Pkg.Solve import Futhark.Pkg.Types import Futhark.Util (directoryContents, maxinum) import Futhark.Util.Log import Futhark.Util.Options import System.Directory import System.Environment import System.Exit import System.FilePath import System.IO import System.IO.Temp (withSystemTempDirectory) import Prelude --- Installing packages installInDir :: CacheDir -> BuildList -> FilePath -> PkgM () installInDir cachedir (BuildList bl) dir = forM_ (M.toList bl) $ \(p, v) -> do info <- lookupPackageRev cachedir p v (filedir, files) <- getFiles $ pkgGetFiles info -- The directory in the local file system that will contain the -- package files. let pdir = dir T.unpack p -- Remove any existing directory for this package. This is a bit -- inefficient, as the likelihood that the old ``lib`` directory -- already contains the correct version is rather high. We should -- have a way to recognise this situation, and not download the -- zipball in that case. liftIO $ removePathForcibly pdir forM_ files $ \file -> do let from = filedir file to = pdir file liftIO $ createDirectoryIfMissing True $ takeDirectory to logMsg $ "Copying " <> from <> "\n" <> "to " <> to liftIO $ copyFile from to libDir, libNewDir, libOldDir :: FilePath (libDir, libNewDir, libOldDir) = ("lib", "lib~new", "lib~old") -- | Install the packages listed in the build list in the @lib@ -- directory of the current working directory. 
Since we are touching -- the file system, we are going to be very paranoid. In particular, -- we want to avoid corrupting the @lib@ directory if something fails -- along the way. -- -- The procedure is as follows: -- -- 1) Create a directory @lib~new@. Delete an existing @lib~new@ if -- necessary. -- -- 2) Populate @lib~new@ based on the build list. -- -- 3) Rename @lib@ to @lib~old@. Delete an existing @lib~old@ if -- necessary. -- -- 4) Rename @lib~new@ to @lib@ -- -- 5) If the current package has package path @p@, move @lib~old/p@ to -- @lib~new/p@. -- -- 6) Delete @lib~old@. -- -- Since POSIX at least guarantees atomic renames, the only place this -- can fail is between steps 3, 4, and 5. In that case, at least the -- @lib~old@ will still exist and can be put back by the user. installBuildList :: CacheDir -> Maybe PkgPath -> BuildList -> PkgM () installBuildList cachedir p bl = do libdir_exists <- liftIO $ doesDirectoryExist libDir -- 1 liftIO $ do removePathForcibly libNewDir createDirectoryIfMissing False libNewDir -- 2 installInDir cachedir bl libNewDir -- 3 when libdir_exists $ liftIO $ do removePathForcibly libOldDir renameDirectory libDir libOldDir -- 4 liftIO $ renameDirectory libNewDir libDir -- 5 case pkgPathFilePath <$> p of Just pfp | libdir_exists -> liftIO $ do pkgdir_exists <- doesDirectoryExist $ libOldDir pfp when pkgdir_exists $ do -- Ensure the parent directories exist so that we can move the -- package directory directly. createDirectoryIfMissing True $ takeDirectory $ libDir pfp renameDirectory (libOldDir pfp) (libDir pfp) _ -> pure () -- 6 when libdir_exists $ liftIO $ removePathForcibly libOldDir getPkgManifest :: PkgM PkgManifest getPkgManifest = do file_exists <- liftIO $ doesFileExist futharkPkg dir_exists <- liftIO $ doesDirectoryExist futharkPkg case (file_exists, dir_exists) of (True, _) -> liftIO $ parsePkgManifestFromFile futharkPkg (_, True) -> fail $ futharkPkg <> " exists, but it is a directory! What in Odin's beard..." 
_ -> liftIO $ do T.putStrLn $ T.pack futharkPkg <> " not found - pretending it's empty." pure $ newPkgManifest Nothing putPkgManifest :: PkgManifest -> PkgM () putPkgManifest = liftIO . T.writeFile futharkPkg . prettyPkgManifest --- The CLI newtype PkgConfig = PkgConfig {pkgVerbose :: Bool} -- | The monad in which futhark-pkg runs. newtype PkgM a = PkgM {unPkgM :: ReaderT PkgConfig (StateT (PkgRegistry PkgM) IO) a} deriving (Functor, Applicative, MonadIO, MonadReader PkgConfig) instance Monad PkgM where PkgM m >>= f = PkgM $ m >>= unPkgM . f instance MonadFail PkgM where fail s = liftIO $ do prog <- getProgName putStrLn $ prog ++ ": " ++ s exitFailure instance MonadPkgRegistry PkgM where putPkgRegistry = PkgM . put getPkgRegistry = PkgM get instance MonadLogger PkgM where addLog l = do verbose <- asks pkgVerbose when verbose $ liftIO $ T.hPutStrLn stderr $ toText l runPkgM :: PkgConfig -> PkgM a -> IO a runPkgM cfg (PkgM m) = evalStateT (runReaderT m cfg) mempty cmdMain :: String -> ([String] -> PkgConfig -> Maybe (IO ())) -> String -> [String] -> IO () cmdMain = mainWithOptions (PkgConfig False) options where options = [ Option "v" ["verbose"] (NoArg $ Right $ \cfg -> cfg {pkgVerbose = True}) "Write running diagnostics to stderr." ] doFmt :: String -> [String] -> IO () doFmt = mainWithOptions () [] "" $ \args () -> case args of [] -> Just $ do m <- parsePkgManifestFromFile futharkPkg T.writeFile futharkPkg $ prettyPkgManifest m _ -> Nothing withCacheDir :: (CacheDir -> IO a) -> IO a withCacheDir f = withSystemTempDirectory "futhark-pkg" $ f . CacheDir doCheck :: String -> [String] -> IO () doCheck = cmdMain "check" $ \args cfg -> case args of [] -> Just . 
withCacheDir $ \cachedir -> runPkgM cfg $ do m <- getPkgManifest bl <- solveDeps cachedir $ pkgRevDeps m liftIO $ T.putStrLn "Dependencies chosen:" liftIO $ T.putStr $ prettyBuildList bl case commented $ manifestPkgPath m of Nothing -> pure () Just p -> do let pdir = "lib" T.unpack p pdir_exists <- liftIO $ doesDirectoryExist pdir unless pdir_exists $ liftIO $ do T.putStrLn $ "Problem: the directory " <> T.pack pdir <> " does not exist." exitFailure anything <- liftIO $ any ((== ".fut") . takeExtension) <$> directoryContents ("lib" T.unpack p) unless anything $ liftIO $ do T.putStrLn $ "Problem: the directory " <> T.pack pdir <> " does not contain any .fut files." exitFailure _ -> Nothing doSync :: String -> [String] -> IO () doSync = cmdMain "" $ \args cfg -> case args of [] -> Just . withCacheDir $ \cachedir -> runPkgM cfg $ do m <- getPkgManifest bl <- solveDeps cachedir $ pkgRevDeps m installBuildList cachedir (commented $ manifestPkgPath m) bl _ -> Nothing doAdd :: String -> [String] -> IO () doAdd = cmdMain "PKGPATH" $ \args cfg -> case args of [p, v] | Right v' <- parseVersion $ T.pack v -> Just $ withCacheDir $ \cachedir -> runPkgM cfg $ doAdd' cachedir (T.pack p) v' [p] -> Just $ withCacheDir $ \cachedir -> runPkgM cfg $ -- Look up the newest revision of the package. doAdd' cachedir (T.pack p) =<< lookupNewestRev cachedir (T.pack p) _ -> Nothing where doAdd' cachedir p v = do m <- getPkgManifest -- See if this package (and its dependencies) even exists. We -- do this by running the solver with the dependencies already -- in the manifest, plus this new one. The Monoid instance for -- PkgRevDeps is left-biased, so we are careful to use the new -- version for this package. _ <- solveDeps cachedir $ PkgRevDeps (M.singleton p (v, Nothing)) <> pkgRevDeps m -- We either replace any existing occurence of package 'p', or -- we add a new one. 
p_info <- lookupPackageRev cachedir p v let hash = case (_svMajor v, _svMinor v, _svPatch v) of -- We do not perform hash-pinning for -- (0,0,0)-versions, because these already embed a -- specific revision ID into their version number. (0, 0, 0) -> Nothing _ -> Just $ pkgRevCommit p_info req = Required p v hash (m', prev_r) = addRequiredToManifest req m case prev_r of Just prev_r' | requiredPkgRev prev_r' == v -> liftIO $ T.putStrLn $ "Package already at version " <> prettySemVer v <> "; nothing to do." | otherwise -> liftIO $ T.putStrLn $ "Replaced " <> p <> " " <> prettySemVer (requiredPkgRev prev_r') <> " => " <> prettySemVer v <> "." Nothing -> liftIO $ T.putStrLn $ "Added new required package " <> p <> " " <> prettySemVer v <> "." putPkgManifest m' liftIO $ T.putStrLn "Remember to run 'futhark pkg sync'." doRemove :: String -> [String] -> IO () doRemove = cmdMain "PKGPATH" $ \args cfg -> case args of [p] -> Just $ runPkgM cfg $ doRemove' $ T.pack p _ -> Nothing where doRemove' p = do m <- getPkgManifest case removeRequiredFromManifest p m of Nothing -> liftIO $ do T.putStrLn $ "No package " <> p <> " found in " <> T.pack futharkPkg <> "." exitFailure Just (m', r) -> do putPkgManifest m' liftIO $ T.putStrLn $ "Removed " <> p <> " " <> prettySemVer (requiredPkgRev r) <> "." doInit :: String -> [String] -> IO () doInit = cmdMain "PKGPATH" $ \args cfg -> case args of [p] -> Just $ runPkgM cfg $ doCreate' $ T.pack p _ -> Nothing where validPkgPath p = not $ any (`elem` [".", ".."]) $ splitDirectories $ T.unpack p doCreate' p = do unless (validPkgPath p) . liftIO $ do T.putStrLn $ "Not a valid package path: " <> p T.putStrLn "Note: package paths are usually URIs." T.putStrLn "Note: 'futhark init' is only needed when creating a package, not to use packages." exitFailure exists <- liftIO $ (||) <$> doesFileExist futharkPkg <*> doesDirectoryExist futharkPkg when exists $ liftIO $ do T.putStrLn $ T.pack futharkPkg <> " already exists." 
exitFailure liftIO $ createDirectoryIfMissing True $ "lib" T.unpack p liftIO $ T.putStrLn $ "Created directory " <> T.pack ("lib" T.unpack p) <> "." putPkgManifest $ newPkgManifest $ Just p liftIO $ T.putStrLn $ "Wrote " <> T.pack futharkPkg <> "." doUpgrade :: String -> [String] -> IO () doUpgrade = cmdMain "" $ \args cfg -> case args of [] -> Just . withCacheDir $ \cachedir -> runPkgM cfg $ do m <- getPkgManifest rs <- traverse (mapM (traverse (upgrade cachedir))) $ manifestRequire m putPkgManifest m {manifestRequire = rs} if rs == manifestRequire m then liftIO $ T.putStrLn "Nothing to upgrade." else liftIO $ T.putStrLn "Remember to run 'futhark pkg sync'." _ -> Nothing where upgrade cachedir req = do v <- lookupNewestRev cachedir $ requiredPkg req h <- pkgRevCommit <$> lookupPackageRev cachedir (requiredPkg req) v when (v /= requiredPkgRev req) $ liftIO $ T.putStrLn $ "Upgraded " <> requiredPkg req <> " " <> prettySemVer (requiredPkgRev req) <> " => " <> prettySemVer v <> "." pure req { requiredPkgRev = v, requiredHash = Just h } doVersions :: String -> [String] -> IO () doVersions = cmdMain "PKGPATH" $ \args cfg -> case args of [p] -> Just $ withCacheDir $ \cachedir -> runPkgM cfg $ doVersions' cachedir $ T.pack p _ -> Nothing where doVersions' cachedir = mapM_ (liftIO . T.putStrLn . prettySemVer) . M.keys . pkgVersions <=< lookupPackage cachedir -- | Run @futhark pkg@. main :: String -> [String] -> IO () main prog args = do -- Avoid Git asking for credentials. We prefer failure. 
liftIO $ setEnv "GIT_TERMINAL_PROMPT" "0" let commands = [ ( "add", (doAdd, "Add another required package to futhark.pkg.") ), ( "check", (doCheck, "Check that futhark.pkg is satisfiable.") ), ( "init", (doInit, "Create a new futhark.pkg and a lib/ skeleton.") ), ( "fmt", (doFmt, "Reformat futhark.pkg.") ), ( "sync", (doSync, "Populate lib/ as specified by futhark.pkg.") ), ( "remove", (doRemove, "Remove a required package from futhark.pkg.") ), ( "upgrade", (doUpgrade, "Upgrade all packages to newest versions.") ), ( "versions", (doVersions, "List available versions for a package.") ) ] usage = "options... <" <> intercalate "|" (map fst commands) <> ">" case args of cmd : args' | Just (m, _) <- lookup cmd commands -> m (unwords [prog, cmd]) args' _ -> do let bad _ () = Just $ do let k = maxinum (map (length . fst) commands) + 3 usageMsg . T.unlines $ [" ...:", "", "Commands:"] ++ [ " " <> T.pack cmd <> T.pack (replicate (k - length cmd) ' ') <> desc | (cmd, (_, desc)) <- commands ] mainWithOptions () [] usage bad prog args where usageMsg s = do T.putStrLn $ "Usage: " <> T.pack prog <> " [--version] [--help] " <> s exitFailure futhark-0.25.27/src/Futhark/CLI/Profile.hs000066400000000000000000000147351475065116200201120ustar00rootroot00000000000000-- | @futhark profile@ module Futhark.CLI.Profile (main) where import Control.Exception (catch) import Data.ByteString.Lazy.Char8 qualified as BS import Data.List qualified as L import Data.Map qualified as M import Data.Text qualified as T import Data.Text.IO qualified as T import Futhark.Bench import Futhark.Util (showText) import Futhark.Util.Options import System.Directory (createDirectoryIfMissing, removePathForcibly) import System.Exit import System.FilePath import System.IO import Text.Printf commonPrefix :: (Eq e) => [e] -> [e] -> [e] commonPrefix _ [] = [] commonPrefix [] _ = [] commonPrefix (x : xs) (y : ys) | x == y = x : commonPrefix xs ys | otherwise = [] longestCommonPrefix :: [FilePath] -> FilePath 
longestCommonPrefix [] = "" longestCommonPrefix (x : xs) = foldr commonPrefix x xs memoryReport :: M.Map T.Text Integer -> T.Text memoryReport = T.unlines . ("Peak memory usage in bytes" :) . map f . M.toList where f (space, bytes) = space <> ": " <> showText bytes padRight :: Int -> T.Text -> T.Text padRight k s = s <> T.replicate (k - T.length s) " " padLeft :: Int -> T.Text -> T.Text padLeft k s = T.replicate (k - T.length s) " " <> s data EvSummary = EvSummary { evCount :: Integer, evSum :: Double, evMin :: Double, evMax :: Double } tabulateEvents :: [ProfilingEvent] -> T.Text tabulateEvents = mkRows . M.toList . M.fromListWith comb . map pair where pair (ProfilingEvent name dur _) = (name, EvSummary 1 dur dur dur) comb (EvSummary xn xdur xmin xmax) (EvSummary yn ydur ymin ymax) = EvSummary (xn + yn) (xdur + ydur) (min xmin ymin) (max xmax ymax) numpad = 15 mkRows rows = let longest = foldl max numpad $ map (T.length . fst) rows header = headerRow longest splitter = T.map (const '-') header bottom = T.unwords [ showText (sum (map (evCount . snd) rows)), "events with a total runtime of", T.pack $ printf "%.2fμs" $ sum $ map (evSum . snd) rows ] in T.unlines $ header : splitter : map (mkRow longest) rows <> [splitter, bottom] headerRow longest = T.unwords [ padLeft longest "Cost centre", padLeft numpad "count", padLeft numpad "sum", padLeft numpad "avg", padLeft numpad "min", padLeft numpad "max" ] mkRow longest (name, ev) = T.unwords [ padRight longest name, padLeft numpad (showText (evCount ev)), padLeft numpad $ T.pack $ printf "%.2fμs" (evSum ev), padLeft numpad $ T.pack $ printf "%.2fμs" $ evSum ev / fromInteger (evCount ev), padLeft numpad $ T.pack $ printf "%.2fμs" (evMin ev), padLeft numpad $ T.pack $ printf "%.2fμs" (evMax ev) ] timeline :: [ProfilingEvent] -> T.Text timeline = T.unlines . L.intercalate [""] . 
map onEvent where onEvent (ProfilingEvent name duration description) = [name, "Duration: " <> showText duration <> " μs"] <> T.lines description data TargetFiles = TargetFiles { summaryFile :: FilePath, timelineFile :: FilePath } writeAnalysis :: TargetFiles -> ProfilingReport -> IO () writeAnalysis tf r = do T.writeFile (summaryFile tf) $ memoryReport (profilingMemory r) <> "\n\n" <> tabulateEvents (profilingEvents r) T.writeFile (timelineFile tf) $ timeline (profilingEvents r) prepareDir :: FilePath -> IO FilePath prepareDir json_path = do let top_dir = takeFileName json_path -<.> "prof" T.hPutStrLn stderr $ "Writing results to " <> T.pack top_dir <> "/" removePathForcibly top_dir pure top_dir analyseProfilingReport :: FilePath -> ProfilingReport -> IO () analyseProfilingReport json_path r = do top_dir <- prepareDir json_path createDirectoryIfMissing True top_dir let tf = TargetFiles { summaryFile = top_dir "summary", timelineFile = top_dir "timeline" } writeAnalysis tf r analyseBenchResults :: FilePath -> [BenchResult] -> IO () analyseBenchResults json_path bench_results = do top_dir <- prepareDir json_path T.hPutStrLn stderr $ "Stripping '" <> T.pack prefix <> "' from program paths." mapM_ (onBenchResult top_dir) bench_results where prefix = longestCommonPrefix $ map benchResultProg bench_results -- Eliminate characters that are filesystem-meaningful. 
escape '/' = '_' escape c = c problem prog_name name what = T.hPutStrLn stderr $ prog_name <> " dataset " <> name <> ": " <> what onBenchResult top_dir (BenchResult prog_path data_results) = do let (prog_path', entry) = span (/= ':') prog_path prog_name = drop (length prefix) prog_path' prog_dir = top_dir dropExtension prog_name drop 1 entry createDirectoryIfMissing True prog_dir mapM_ (onDataResult prog_dir (T.pack prog_name)) data_results onDataResult _ prog_name (DataResult name (Left _)) = problem prog_name name "execution failed" onDataResult prog_dir prog_name (DataResult name (Right res)) = do let name' = prog_dir T.unpack (T.map escape name) case stdErr res of Nothing -> problem prog_name name "no log recorded" Just text -> T.writeFile (name' <.> ".log") text case report res of Nothing -> problem prog_name name "no profiling information" Just r -> let tf = TargetFiles { summaryFile = name' <> ".summary", timelineFile = name' <> ".timeline" } in writeAnalysis tf r readFileSafely :: FilePath -> IO (Either String BS.ByteString) readFileSafely filepath = (Right <$> BS.readFile filepath) `catch` couldNotRead where couldNotRead e = pure $ Left $ show (e :: IOError) onFile :: FilePath -> IO () onFile json_path = do s <- readFileSafely json_path case s of Left a -> do hPutStrLn stderr a exitWith $ ExitFailure 2 Right s' -> case decodeBenchResults s' of Left _ -> case decodeProfilingReport s' of Nothing -> do hPutStrLn stderr $ "Cannot recognise " <> json_path <> " as benchmark results or a profiling report." Just pr -> analyseProfilingReport json_path pr Right br -> analyseBenchResults json_path br -- | Run @futhark profile@. 
main :: String -> [String] -> IO () main = mainWithOptions () [] "[files]" f where f files () = Just $ mapM_ onFile files futhark-0.25.27/src/Futhark/CLI/PyOpenCL.hs000066400000000000000000000007201475065116200201300ustar00rootroot00000000000000-- | @futhark pyopencl@ module Futhark.CLI.PyOpenCL (main) where import Futhark.Actions (compilePyOpenCLAction) import Futhark.Compiler.CLI import Futhark.Passes -- | Run @futhark pyopencl@. main :: String -> [String] -> IO () main = compilerMain () [] "Compile PyOpenCL" "Generate Python + OpenCL code from optimised Futhark program." gpumemPipeline $ \fcfg () mode outpath prog -> actionProcedure (compilePyOpenCLAction fcfg mode outpath) prog futhark-0.25.27/src/Futhark/CLI/Python.hs000066400000000000000000000007101475065116200177570ustar00rootroot00000000000000-- | @futhark py@ module Futhark.CLI.Python (main) where import Futhark.Actions (compilePythonAction) import Futhark.Compiler.CLI import Futhark.Passes -- | Run @futhark py@ main :: String -> [String] -> IO () main = compilerMain () [] "Compile sequential Python" "Generate sequential Python code from optimised Futhark program." seqmemPipeline $ \fcfg () mode outpath prog -> actionProcedure (compilePythonAction fcfg mode outpath) prog futhark-0.25.27/src/Futhark/CLI/Query.hs000066400000000000000000000026411475065116200176100ustar00rootroot00000000000000-- | @futhark query@ module Futhark.CLI.Query (main) where import Futhark.Compiler import Futhark.Util.Loc import Futhark.Util.Options import Language.Futhark.Query import Language.Futhark.Syntax import Text.Read (readMaybe) -- | Run @futhark query@. main :: String -> [String] -> IO () main = mainWithOptions () [] "program line col" $ \args () -> case args of [file, line, col] -> do line' <- readMaybe line col' <- readMaybe col Just $ do (_, imports, _) <- readProgramOrDie file -- The 'offset' part of the Pos is not used and can be arbitrary. 
case atPos imports $ Pos file line' col' 0 of Nothing -> putStrLn "No information available." Just (AtName qn def loc) -> do putStrLn $ "Name: " ++ prettyString qn putStrLn $ "Position: " ++ locStr (srclocOf loc) case def of Nothing -> pure () Just (BoundTerm t defloc) -> do putStrLn $ "Type: " ++ prettyString t putStrLn $ "Definition: " ++ locStr (srclocOf defloc) Just (BoundType defloc) -> putStrLn $ "Definition: " ++ locStr (srclocOf defloc) Just (BoundModule defloc) -> putStrLn $ "Definition: " ++ locStr (srclocOf defloc) Just (BoundModuleType defloc) -> putStrLn $ "Definition: " ++ locStr (srclocOf defloc) _ -> Nothing futhark-0.25.27/src/Futhark/CLI/REPL.hs000066400000000000000000000460061475065116200172500ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-} -- | @futhark repl@ module Futhark.CLI.REPL (main) where import Control.Exception import Control.Monad import Control.Monad.Except import Control.Monad.Free.Church import Control.Monad.State import Data.Char import Data.List (intersperse, isPrefixOf) import Data.List.NonEmpty qualified as NE import Data.Map qualified as M import Data.Maybe import Data.Text qualified as T import Data.Text.IO qualified as T import Data.Version import Futhark.Compiler import Futhark.Format (parseFormatString) import Futhark.MonadFreshNames import Futhark.Util (fancyTerminal, showText) import Futhark.Util.Options import Futhark.Util.Pretty (AnsiStyle, Color (..), Doc, align, annotate, bgColorDull, bold, brackets, color, docText, docTextForHandle, hardline, italicized, oneLine, pretty, putDoc, putDocLn, unAnnotate, (<+>)) import Futhark.Version import Language.Futhark import Language.Futhark.Interpreter qualified as I import Language.Futhark.Parser import Language.Futhark.Semantic qualified as T import Language.Futhark.TypeChecker qualified as T import NeatInterpolation (text) import System.Console.Haskeline qualified as Haskeline import System.Directory import System.IO (stdout) import Text.Read (readMaybe) banner :: Doc 
AnsiStyle banner = mconcat . map ((<> hardline) . decorate . pretty) $ [ "┃╱╱ ┃╲ ┃ ┃╲ ┃╲ ╱" :: T.Text, "┃╱ ┃ ╲ ┃╲ ┃╲ ┃╱ ╱ ", "┃ ┃ ╲ ┃╱ ┃ ┃╲ ╲ ", "┃ ┃ ╲ ┃ ┃ ┃ ╲ ╲" ] where decorate = annotate (bgColorDull Red <> bold <> color White) -- | Run @futhark repl@. main :: String -> [String] -> IO () main = mainWithOptions () [] "options... [program.fut]" run where run [] _ = Just $ repl Nothing run [prog] _ = Just $ repl $ Just prog run _ _ = Nothing data StopReason = EOF | Stop | Exit | Load FilePath | Interrupt replSettings :: Haskeline.Settings IO replSettings = Haskeline.setComplete replComplete Haskeline.defaultSettings repl :: Maybe FilePath -> IO () repl maybe_prog = do when fancyTerminal $ do putDoc banner putStrLn $ "Version " ++ showVersion version ++ "." putStrLn "Copyright (C) DIKU, University of Copenhagen, released under the ISC license." putStrLn "" putDoc $ "Run" <+> annotate bold ":help" <+> "for a list of commands." putStrLn "" let toploop s = do (stop, s') <- Haskeline.handleInterrupt (pure (Left Interrupt, s)) . 
Haskeline.withInterrupt $ runStateT (runExceptT $ runFutharkiM $ forever readEvalPrint) s case stop of Left Stop -> finish s' Left EOF -> finish s' Left Exit -> finish s' Left Interrupt -> do liftIO $ T.putStrLn "Interrupted" toploop s' {futharkiCount = futharkiCount s' + 1} Left (Load file) -> do liftIO $ T.putStrLn $ "Loading " <> T.pack file maybe_new_state <- liftIO $ newFutharkiState (futharkiCount s) (futharkiProg s) $ Just file case maybe_new_state of Right new_state -> toploop new_state Left err -> do liftIO $ putDocLn err toploop s' Right _ -> pure () finish _s = pure () maybe_init_state <- liftIO $ newFutharkiState 0 noLoadedProg maybe_prog s <- case maybe_init_state of Left prog_err -> do noprog_init_state <- liftIO $ newFutharkiState 0 noLoadedProg Nothing case noprog_init_state of Left err -> error $ "Failed to initialise interpreter state: " <> T.unpack (docText err) Right s -> do liftIO $ putDocLn prog_err pure s {futharkiLoaded = maybe_prog} Right s -> pure s Haskeline.runInputT replSettings $ toploop s putStrLn "Leaving 'futhark repl'." -- | Representation of breaking at a breakpoint, to allow for -- navigating through the stack frames and such. data Breaking = Breaking { breakingStack :: NE.NonEmpty I.StackFrame, -- | Index of the current breakpoint (with 0 being the outermost). breakingAt :: Int } data FutharkiState = FutharkiState { futharkiProg :: LoadedProg, futharkiCount :: Int, futharkiEnv :: (T.Env, I.Ctx), -- | Are we currently stopped at a breakpoint? futharkiBreaking :: Maybe Breaking, -- | Skip breakpoints at these locations. futharkiSkipBreaks :: [Loc], futharkiBreakOnNaN :: Bool, -- | The currently loaded file. futharkiLoaded :: Maybe FilePath } extendEnvs :: LoadedProg -> (T.Env, I.Ctx) -> [ImportName] -> (T.Env, I.Ctx) extendEnvs prog (tenv, ictx) opens = (tenv', ictx') where tenv' = T.envWithImports t_imports tenv ictx' = I.ctxWithImports i_envs ictx t_imports = filter ((`elem` opens) . 
fst) $ lpImports prog i_envs = map snd $ filter ((`elem` opens) . fst) $ M.toList $ I.ctxImports ictx newFutharkiState :: Int -> LoadedProg -> Maybe FilePath -> IO (Either (Doc AnsiStyle) FutharkiState) newFutharkiState count prev_prog maybe_file = runExceptT $ do let files = maybeToList maybe_file -- Put code through the type checker. prog <- badOnLeft prettyProgErrors =<< liftIO (reloadProg prev_prog files M.empty) liftIO $ putDoc $ prettyWarnings $ lpWarnings prog -- Then into the interpreter. ictx <- foldM (\ctx -> badOnLeft (pretty . show) <=< runInterpreterNoBreak . I.interpretImport ctx) I.initialCtx $ map (fmap fileProg) (lpImports prog) let (tenv, ienv) = let (iname, fm) = last $ lpImports prog in ( fileScope fm, ictx {I.ctxEnv = I.ctxImports ictx M.! iname} ) pure FutharkiState { futharkiProg = prog, futharkiCount = count, futharkiEnv = (tenv, ienv), futharkiBreaking = Nothing, futharkiSkipBreaks = mempty, futharkiBreakOnNaN = False, futharkiLoaded = maybe_file } where badOnLeft :: (err -> err') -> Either err a -> ExceptT err' IO a badOnLeft _ (Right x) = pure x badOnLeft p (Left err) = throwError $ p err getPrompt :: FutharkiM String getPrompt = do i <- gets futharkiCount fmap T.unpack $ liftIO $ docTextForHandle stdout $ annotate bold $ brackets (pretty i) <> "> " -- The ExceptT part is more of a continuation, really. newtype FutharkiM a = FutharkiM {runFutharkiM :: ExceptT StopReason (StateT FutharkiState (Haskeline.InputT IO)) a} deriving ( Functor, Applicative, Monad, MonadState FutharkiState, MonadIO, MonadError StopReason ) readEvalPrint :: FutharkiM () readEvalPrint = do prompt <- getPrompt line <- inputLine prompt breaking <- gets futharkiBreaking case T.uncons line of Nothing | isJust breaking -> throwError Stop | otherwise -> pure () Just (':', command) -> do let (cmdname, rest) = T.break isSpace command arg = T.dropWhileEnd isSpace $ T.dropWhile isSpace rest case filter ((cmdname `T.isPrefixOf`) . 
fst) commands of [] -> liftIO $ T.putStrLn $ "Unknown command '" <> cmdname <> "'" [(_, (cmdf, _))] -> cmdf arg matches -> liftIO . T.putStrLn $ "Ambiguous command; could be one of " <> mconcat (intersperse ", " (map fst matches)) _ -> do -- Read a declaration or expression. case parseDecOrExp prompt line of Left (SyntaxError _ err) -> liftIO $ T.putStrLn err Right (Left d) -> onDec d Right (Right e) -> do valOrErr <- onExp e case valOrErr of Left err -> liftIO $ putDocLn err Right val -> liftIO $ putDocLn $ I.prettyValue val modify $ \s -> s {futharkiCount = futharkiCount s + 1} where inputLine prompt = do inp <- FutharkiM $ lift $ lift $ Haskeline.getInputLine prompt case inp of Just s -> pure $ T.pack s Nothing -> throwError EOF getIt :: FutharkiM (Imports, VNameSource, T.Env, I.Ctx) getIt = do imports <- gets $ lpImports . futharkiProg src <- gets $ lpNameSource . futharkiProg (tenv, ienv) <- gets futharkiEnv pure (imports, src, tenv, ienv) onDec :: UncheckedDec -> FutharkiM () onDec d = do old_imports <- gets $ lpImports . futharkiProg cur_import <- gets $ T.mkInitialImport . fromMaybe "." . futharkiLoaded let mkImport = T.mkImportFrom cur_import files = map (T.includeToFilePath . mkImport . fst) $ decImports d cur_prog <- gets futharkiProg imp_r <- liftIO $ extendProg cur_prog files M.empty case imp_r of Left e -> liftIO $ putDoc $ prettyProgErrors e Right prog -> do env <- gets futharkiEnv let (tenv, ienv) = extendEnvs prog env $ map (T.mkInitialImport . fst) $ decImports d imports = lpImports prog src = lpNameSource prog case T.checkDec imports src tenv cur_import d of (_, Left e) -> liftIO $ putDoc $ T.prettyTypeErrorNoLoc e (_, Right (tenv', d', src')) -> do let new_imports = filter ((`notElem` map fst old_imports) . 
fst) imports int_r <- runInterpreter $ do let onImport ienv' (s, imp) = I.interpretImport ienv' (s, T.fileProg imp) ienv' <- foldM onImport ienv new_imports I.interpretDec ienv' d' case int_r of Left err -> liftIO $ print err Right ienv' -> modify $ \s -> s { futharkiEnv = (tenv', ienv'), futharkiProg = prog {lpNameSource = src'} } onExp :: UncheckedExp -> FutharkiM (Either (Doc AnsiStyle) I.Value) onExp e = do (imports, src, tenv, ienv) <- getIt case T.checkExp imports src tenv e of (_, Left err) -> pure $ Left $ T.prettyTypeErrorNoLoc err (_, Right (tparams, e')) | null tparams -> do r <- runInterpreter $ I.interpretExp ienv e' case r of Left err -> pure $ Left $ pretty $ showText err Right v -> pure $ Right v | otherwise -> pure $ Left $ ("Inferred type of expression: " <> align (pretty (typeOf e'))) <> hardline <> pretty ( "The following types are ambiguous: " <> T.intercalate ", " (map (nameToText . toName . typeParamName) tparams) ) <> hardline prettyBreaking :: Breaking -> T.Text prettyBreaking b = prettyStacktrace (breakingAt b) $ map locText $ NE.toList $ breakingStack b -- Are we currently willing to break for this reason? Among othe -- things, we do not want recursive breakpoints. It could work fine -- technically, but is probably too confusing to be useful. breakForReason :: FutharkiState -> I.StackFrame -> I.BreakReason -> Bool breakForReason s _ I.BreakNaN | not $ futharkiBreakOnNaN s = False breakForReason s top _ = isNothing (futharkiBreaking s) && locOf top `notElem` futharkiSkipBreaks s runInterpreter :: F I.ExtOp a -> FutharkiM (Either I.InterpreterError a) runInterpreter m = runF m (pure . 
Right) intOp where intOp (I.ExtOpError err) = pure $ Left err intOp (I.ExtOpTrace w v c) = do liftIO $ putDocLn $ pretty w <> ":" <+> unAnnotate v c intOp (I.ExtOpBreak w why callstack c) = do s <- get let why' = case why of I.BreakPoint -> "Breakpoint" I.BreakNaN -> "NaN produced" top = NE.head callstack ctx = I.stackFrameCtx top tenv = I.typeCheckerEnv $ I.ctxEnv ctx breaking = Breaking callstack 0 -- Are we supposed to respect this breakpoint? when (breakForReason s top why) $ do liftIO $ T.putStrLn $ why' <> " at " <> locText w liftIO $ T.putStrLn $ prettyBreaking breaking liftIO $ T.putStrLn " to continue." -- Note the cleverness to preserve the Haskeline session (for -- line history and such). (stop, s') <- FutharkiM . lift . lift $ runStateT (runExceptT $ runFutharkiM $ forever readEvalPrint) s { futharkiEnv = (tenv, ctx), futharkiCount = futharkiCount s + 1, futharkiBreaking = Just breaking } case stop of Left (Load file) -> throwError $ Load file _ -> do liftIO $ putStrLn "Continuing..." put s { futharkiCount = futharkiCount s', futharkiSkipBreaks = futharkiSkipBreaks s' <> futharkiSkipBreaks s, futharkiBreakOnNaN = futharkiBreakOnNaN s' } c runInterpreterNoBreak :: (MonadIO m) => F I.ExtOp a -> m (Either I.InterpreterError a) runInterpreterNoBreak m = runF m (pure . Right) intOp where intOp (I.ExtOpError err) = pure $ Left err intOp (I.ExtOpTrace w v c) = do liftIO $ putDocLn $ pretty w <> ":" <+> align (unAnnotate v) c intOp (I.ExtOpBreak _ I.BreakNaN _ c) = c intOp (I.ExtOpBreak w _ _ c) = do liftIO $ T.putStrLn $ locText w <> ": " <> "ignoring breakpoint when computating constant." 
c replComplete :: Haskeline.CompletionFunc IO replComplete = loadComplete where loadComplete (prev, aft) | ":load " `isPrefixOf` reverse prev = Haskeline.completeFilename (prev, aft) | otherwise = Haskeline.noCompletion (prev, aft) type Command = T.Text -> FutharkiM () loadCommand :: Command loadCommand file = do loaded <- gets futharkiLoaded case (T.null file, loaded) of (True, Just loaded') -> throwError $ Load loaded' (True, Nothing) -> liftIO $ T.putStrLn "No file specified and no file previously loaded." (False, _) -> throwError $ Load $ T.unpack file genTypeCommand :: (String -> T.Text -> Either SyntaxError a) -> (Imports -> VNameSource -> T.Env -> a -> (Warnings, Either T.TypeError b)) -> (b -> Doc AnsiStyle) -> Command genTypeCommand f g h e = do prompt <- getPrompt case f prompt e of Left (SyntaxError _ err) -> liftIO $ T.putStrLn err Right e' -> do (imports, src, tenv, _) <- getIt case snd $ g imports src tenv e' of Left err -> liftIO $ putDoc $ T.prettyTypeErrorNoLoc err Right x -> liftIO $ putDocLn $ h x typeCommand :: Command typeCommand = genTypeCommand parseExp T.checkExp $ \(ps, e) -> oneLine (pretty (typeOf e)) <> if not (null ps) then annotate italicized $ "\n\nPolymorphic in" <+> mconcat (intersperse " " $ map pretty ps) <> "." else mempty mtypeCommand :: Command mtypeCommand = genTypeCommand parseModExp T.checkModExp $ pretty . fst formatCommand :: Command formatCommand input = do case parseFormatString input of Left err -> liftIO $ T.putStrLn err Right parts -> do prompt <- getPrompt case mapM (traverse $ parseExp prompt) parts of Left (SyntaxError _ err) -> liftIO $ T.putStr err Right parts' -> do parts'' <- mapM sequenceA <$> mapM (traverse onExp) parts' case parts'' of Left err -> liftIO $ putDoc err Right parts''' -> liftIO . T.putStrLn . mconcat $ map (either id (docText . I.prettyValue)) parts''' unbreakCommand :: Command unbreakCommand _ = do top <- gets $ fmap (NE.head . breakingStack) . 
futharkiBreaking case top of Nothing -> liftIO $ putStrLn "Not currently stopped at a breakpoint." Just top' -> do modify $ \s -> s {futharkiSkipBreaks = locOf top' : futharkiSkipBreaks s} throwError Stop nanbreakCommand :: Command nanbreakCommand _ = do modify $ \s -> s {futharkiBreakOnNaN = not $ futharkiBreakOnNaN s} b <- gets futharkiBreakOnNaN liftIO $ putStrLn $ if b then "Now treating NaNs as breakpoints." else "No longer treating NaNs as breakpoints." frameCommand :: Command frameCommand which = do maybe_stack <- gets $ fmap breakingStack . futharkiBreaking case (maybe_stack, readMaybe $ T.unpack which) of (Just stack, Just i) | frame : _ <- NE.drop i stack -> do let breaking = Breaking stack i ctx = I.stackFrameCtx frame tenv = I.typeCheckerEnv $ I.ctxEnv ctx modify $ \s -> s { futharkiEnv = (tenv, ctx), futharkiBreaking = Just breaking } liftIO $ T.putStrLn $ prettyBreaking breaking (Just _, _) -> liftIO $ putStrLn $ "Invalid stack index: " ++ T.unpack which (Nothing, _) -> liftIO $ putStrLn "Not stopped at a breakpoint." pwdCommand :: Command pwdCommand _ = liftIO $ putStrLn =<< getCurrentDirectory cdCommand :: Command cdCommand dir | T.null dir = liftIO $ putStrLn "Usage: ':cd '." | otherwise = liftIO $ setCurrentDirectory (T.unpack dir) `catch` \(err :: IOException) -> print err helpCommand :: Command helpCommand _ = liftIO $ forM_ commands $ \(cmd, (_, desc)) -> do putDoc $ annotate bold $ ":" <> pretty cmd <> hardline T.putStrLn $ T.replicate (1 + T.length cmd) "─" T.putStr desc T.putStrLn "" T.putStrLn "" quitCommand :: Command quitCommand _ = throwError Exit commands :: [(T.Text, (Command, T.Text))] commands = [ ( "load", ( loadCommand, [text| Load a Futhark source file. Usage: > :load foo.fut If the loading succeeds, any expressions entered subsequently can use the declarations in the source file. Only one source file can be loaded at a time. Using the :load command a second time will replace the previously loaded file. 
It will also replace any declarations entered at the REPL. |] ) ), ( "format", ( formatCommand, [text| Use format strings to print arbitrary futhark expressions. Usage: > :format The value of foo: {foo}. The value of 2+2={2+2} |] ) ), ( "type", ( typeCommand, [text| Show the type of an expression, which must fit on a single line. |] ) ), ( "mtype", ( mtypeCommand, [text| Show the type of a module expression, which must fit on a single line. |] ) ), ( "unbreak", ( unbreakCommand, [text| Skip all future occurrences of the current breakpoint. |] ) ), ( "nanbreak", ( nanbreakCommand, [text| Toggle treating operators that produce new NaNs as breakpoints. We consider a NaN to be "new" if none of the arguments to the operator in question is a NaN. |] ) ), ( "frame", ( frameCommand, [text| While at a break point, jump to another stack frame, whose variables can then be inspected. Resuming from the breakpoint will jump back to the innermost stack frame. |] ) ), ( "pwd", ( pwdCommand, [text| Print the current working directory. |] ) ), ( "cd", ( cdCommand, [text| Change the current working directory. |] ) ), ( "help", ( helpCommand, [text| Print a list of commands and a description of their behaviour. |] ) ), ( "quit", ( quitCommand, [text| Exit REPL. 
|] ) ) ] futhark-0.25.27/src/Futhark/CLI/Run.hs000066400000000000000000000105401475065116200172440ustar00rootroot00000000000000-- | @futhark run@ module Futhark.CLI.Run (main) where import Control.Exception import Control.Monad import Control.Monad.Except (ExceptT, runExceptT, throwError) import Control.Monad.Free.Church import Control.Monad.IO.Class (MonadIO, liftIO) import Data.ByteString.Lazy qualified as BS import Data.Map qualified as M import Data.Maybe import Data.Text.IO qualified as T import Futhark.Compiler import Futhark.Data.Reader (readValues) import Futhark.Pipeline import Futhark.Util.Options import Futhark.Util.Pretty (AnsiStyle, Doc, align, hPutDoc, hPutDocLn, pretty, unAnnotate, (<+>)) import Language.Futhark import Language.Futhark.Interpreter qualified as I import Language.Futhark.Semantic qualified as T import System.Exit import System.FilePath import System.IO import Prelude -- | Run @futhark run@. main :: String -> [String] -> IO () main = mainWithOptions interpreterConfig options "options... " run where run [prog] config = Just $ interpret config prog run _ _ = Nothing interpret :: InterpreterConfig -> FilePath -> IO () interpret config fp = do pr <- newFutharkiState config fp (tenv, ienv) <- case pr of Left err -> do hPutDoc stderr err exitFailure Right env -> pure env let entry = interpreterEntryPoint config vr <- readValues <$> BS.getContents inps <- case vr of Nothing -> do T.hPutStrLn stderr "Incorrectly formatted input data." 
exitFailure Just vs -> pure vs (fname, ret) <- case M.lookup (T.Term, entry) $ T.envNameMap tenv of Just fname | Just (T.BoundV _ t) <- M.lookup (qualLeaf fname) $ T.envVtable tenv -> pure (fname, toStructural $ snd $ unfoldFunType t) _ -> do T.hPutStrLn stderr $ "Invalid entry point: " <> prettyText entry exitFailure case I.interpretFunction ienv (qualLeaf fname) inps of Left err -> do T.hPutStrLn stderr err exitFailure Right run -> do run' <- runInterpreter' run case run' of Left err -> do hPrint stderr err exitFailure Right res -> case (I.fromTuple res, isTupleRecord ret) of (Just vs, Just ts) -> zipWithM_ putValue vs ts _ -> putValue res ret putValue :: I.Value -> TypeBase () () -> IO () putValue v t | I.isEmptyArray v = T.putStrLn $ I.prettyEmptyArray t v | otherwise = T.putStrLn $ I.valueText v data InterpreterConfig = InterpreterConfig { interpreterEntryPoint :: Name, interpreterPrintWarnings :: Bool } interpreterConfig :: InterpreterConfig interpreterConfig = InterpreterConfig defaultEntryPoint True options :: [FunOptDescr InterpreterConfig] options = [ Option "e" ["entry-point"] ( ReqArg ( \entry -> Right $ \config -> config {interpreterEntryPoint = nameFromString entry} ) "NAME" ) "The entry point to execute.", Option "w" ["no-warnings"] (NoArg $ Right $ \config -> config {interpreterPrintWarnings = False}) "Do not print warnings." ] newFutharkiState :: InterpreterConfig -> FilePath -> IO (Either (Doc AnsiStyle) (T.Env, I.Ctx)) newFutharkiState cfg file = runExceptT $ do (ws, imports, _src) <- badOnLeft prettyCompilerError =<< liftIO ( runExceptT (readProgramFile [] file) `catch` \(err :: IOException) -> pure (externalErrorS (show err)) ) when (interpreterPrintWarnings cfg) $ liftIO $ hPutDoc stderr $ prettyWarnings ws let loadImport ctx = badOnLeft I.prettyInterpreterError <=< runInterpreter' . 
I.interpretImport ctx ictx <- foldM loadImport I.initialCtx $ map (fmap fileProg) imports let (tenv, ienv) = let (iname, fm) = last imports in ( fileScope fm, ictx {I.ctxEnv = I.ctxImports ictx M.! iname} ) pure (tenv, ienv) where badOnLeft :: (err -> err') -> Either err a -> ExceptT err' IO a badOnLeft _ (Right x) = pure x badOnLeft p (Left err) = throwError $ p err runInterpreter' :: (MonadIO m) => F I.ExtOp a -> m (Either I.InterpreterError a) runInterpreter' m = runF m (pure . Right) intOp where intOp (I.ExtOpError err) = pure $ Left err intOp (I.ExtOpTrace w v c) = do liftIO $ hPutDocLn stderr $ pretty w <> ":" <+> align (unAnnotate v) c intOp (I.ExtOpBreak _ _ _ c) = c futhark-0.25.27/src/Futhark/CLI/Script.hs000066400000000000000000000100701475065116200177420ustar00rootroot00000000000000-- | @futhark script@ module Futhark.CLI.Script (main) where import Control.Monad.Except import Control.Monad.IO.Class (MonadIO, liftIO) import Data.Binary qualified as Bin import Data.ByteString.Lazy.Char8 qualified as BS import Data.Char (chr) import Data.Text qualified as T import Data.Text.IO qualified as T import Futhark.CLI.Literate ( Options (..), initialOptions, prepareServer, scriptCommandLineOptions, ) import Futhark.Script import Futhark.Test.Values (Compound (..), getValue, valueType) import Futhark.Util.Options import Futhark.Util.Pretty (prettyText) import System.Exit import System.IO commandLineOptions :: [FunOptDescr Options] commandLineOptions = scriptCommandLineOptions ++ [ Option "D" ["debug"] ( NoArg $ Right $ \config -> config { scriptExtraOptions = "-D" : scriptExtraOptions config, scriptVerbose = scriptVerbose config + 1 } ) "Enable debugging.", Option "L" ["log"] ( NoArg $ Right $ \config -> config { scriptExtraOptions = "-L" : scriptExtraOptions config, scriptVerbose = scriptVerbose config + 1 } ) "Enable logging.", Option "b" ["binary"] (NoArg $ Right $ \config -> config {scriptBinary = True}) "Produce binary output.", Option "f" ["file"] ( ReqArg 
(\f -> Right $ \config -> config {scriptExps = scriptExps config ++ [Left f]}) "FILE" ) "Run FutharkScript from this file.", Option "e" ["expression"] ( ReqArg (\s -> Right $ \config -> config {scriptExps = scriptExps config ++ [Right (T.pack s)]}) "EXP" ) "Run this expression." ] parseScriptFile :: FilePath -> IO Exp parseScriptFile f = do s <- T.readFile f case parseExpFromText f s of Left e -> do T.hPutStrLn stderr e exitFailure Right e -> pure e getExp :: Either FilePath T.Text -> IO Exp getExp (Left f) = parseScriptFile f getExp (Right s) = case parseExpFromText "command line option" s of Left e -> do T.hPutStrLn stderr e exitFailure Right e -> pure e -- A few extra procedures that are not handled by scriptBuiltin. extScriptBuiltin :: (MonadError T.Text m, MonadIO m) => EvalBuiltin m extScriptBuiltin "store" [ValueAtom fv, ValueAtom vv] | Just path <- getValue fv = do let path' = map (chr . fromIntegral) (path :: [Bin.Word8]) liftIO $ BS.writeFile path' $ Bin.encode vv pure $ ValueTuple [] extScriptBuiltin "store" vs = throwError $ "$store does not accept arguments of types: " <> T.intercalate ", " (map (prettyText . fmap valueType) vs) extScriptBuiltin f vs = scriptBuiltin "." f vs -- | Run @futhark script@. 
main :: String -> [String] -> IO () main = mainWithOptions initialOptions commandLineOptions "PROGRAM [EXP]" $ \args opts -> case args of [prog, script] -> Just $ main' prog opts $ scriptExps opts ++ [Right $ T.pack script] [prog] -> Just $ main' prog opts $ scriptExps opts _ -> Nothing where main' prog opts scripts = do scripts' <- mapM getExp scripts prepareServer prog opts $ \s -> do r <- runExceptT $ do vs <- mapM (evalExp extScriptBuiltin s) scripts' case reverse vs of [] -> pure Nothing v : _ -> Just <$> getExpValue s v case r of Left e -> do T.hPutStrLn stderr e exitFailure Right Nothing -> pure () Right (Just v) -> if scriptBinary opts then case v of ValueAtom v' -> BS.putStr $ Bin.encode v' _ -> T.hPutStrLn stderr "Result value cannot be represented in binary format." else T.putStrLn $ prettyText v futhark-0.25.27/src/Futhark/CLI/Test.hs000066400000000000000000000710561475065116200174300ustar00rootroot00000000000000{-# LANGUAGE LambdaCase #-} -- | @futhark test@ module Futhark.CLI.Test (main) where import Control.Applicative.Lift (Errors, Lift (..), failure, runErrors) import Control.Concurrent import Control.Concurrent.Async import Control.Exception import Control.Monad import Control.Monad.Except (ExceptT (..), MonadError, runExceptT, withExceptT) import Control.Monad.Except qualified as E import Control.Monad.IO.Class (MonadIO, liftIO) import Control.Monad.Trans.Class (lift) import Data.ByteString qualified as SBS import Data.ByteString.Lazy qualified as LBS import Data.List (delete, partition) import Data.Map.Strict qualified as M import Data.Text qualified as T import Data.Text.Encoding qualified as T import Data.Text.IO qualified as T import Data.Time.Clock.System (SystemTime (..), getSystemTime) import Futhark.Analysis.Metrics.Type import Futhark.Server import Futhark.Test import Futhark.Util (atMostChars, fancyTerminal, showText) import Futhark.Util.Options import Futhark.Util.Pretty (annotate, bgColor, bold, hardline, pretty, putDoc, vsep) import 
Futhark.Util.Table import System.Console.ANSI (clearFromCursorToScreenEnd, clearLine, cursorUpLine) import System.Console.Terminal.Size qualified as Terminal import System.Environment import System.Exit import System.FilePath import System.IO import System.Process.ByteString (readProcessWithExitCode) import Text.Regex.TDFA --- Test execution -- The use of [T.Text] here is somewhat kludgy. We use it to track how -- many errors have occurred during testing of a single program (which -- may have multiple entry points). This should really not be done at -- the monadic level - a test failing should be handled explicitly. type TestM = ExceptT [T.Text] IO -- Taken from transformers-0.5.5.0. eitherToErrors :: Either e a -> Errors e a eitherToErrors = either failure Pure throwError :: (MonadError [e] m) => e -> m a throwError e = E.throwError [e] runTestM :: TestM () -> IO TestResult runTestM = fmap (either Failure $ const Success) . runExceptT liftExcept :: ExceptT T.Text IO a -> TestM a liftExcept = either (E.throwError . pure) pure <=< liftIO . runExceptT context :: T.Text -> TestM a -> TestM a context s = withExceptT $ \case [] -> [] (e : es') -> (s <> ":\n" <> e) : es' context1 :: (Monad m) => T.Text -> ExceptT T.Text m a -> ExceptT T.Text m a context1 s = withExceptT $ \e -> s <> ":\n" <> e accErrors :: [TestM a] -> TestM [a] accErrors tests = do eithers <- lift $ mapM runExceptT tests let errors = traverse eitherToErrors eithers ExceptT $ pure $ runErrors errors accErrors_ :: [TestM a] -> TestM () accErrors_ = void . accErrors data TestResult = Success | Failure [T.Text] deriving (Eq, Show) pureTestResults :: IO [TestResult] -> TestM () pureTestResults m = do errs <- foldr collectErrors mempty <$> liftIO m unless (null errs) $ E.throwError $ concat errs where collectErrors Success errs = errs collectErrors (Failure err) errs = err : errs -- | The longest we are willing to wait for a test, in microseconds. 
timeout :: Int timeout = 5 * 60 * 1000000 withProgramServer :: FilePath -> FilePath -> [String] -> (Server -> IO [TestResult]) -> TestM () withProgramServer program runner extra_options f = do -- Explicitly prefixing the current directory is necessary for -- readProcessWithExitCode to find the binary when binOutputf has -- no path component. let binOutputf = dropExtension program binpath = "." binOutputf (to_run, to_run_args) | null runner = (binpath, extra_options) | otherwise = (runner, binpath : extra_options) prog_ctx = "Running " <> T.pack (unwords $ binpath : extra_options) context prog_ctx . pureTestResults . liftIO $ withServer (futharkServerCfg to_run to_run_args) $ \server -> race (threadDelay timeout) (f server) >>= \case Left _ -> do abortServer server fail $ "test timeout after " <> show timeout <> " microseconds" Right r -> pure r data TestMode = -- | Only type check. TypeCheck | -- | Only compile (do not run). Compile | -- | Only internalise (do not run). Internalise | -- | Test compiled code. Compiled | -- | Test interpreted code. Interpreted | -- | Perform structure tests. 
Structure deriving (Eq, Show) data TestCase = TestCase { _testCaseMode :: TestMode, testCaseProgram :: FilePath, testCaseTest :: ProgramTest, _testCasePrograms :: ProgConfig } deriving (Show) instance Eq TestCase where x == y = testCaseProgram x == testCaseProgram y instance Ord TestCase where x `compare` y = testCaseProgram x `compare` testCaseProgram y data RunResult = ErrorResult T.Text | SuccessResult [Value] progNotFound :: T.Text -> T.Text progNotFound s = s <> ": command not found" optimisedProgramMetrics :: ProgConfig -> StructurePipeline -> FilePath -> TestM AstMetrics optimisedProgramMetrics programs pipeline program = case pipeline of SOACSPipeline -> check ["-s"] GpuPipeline -> check ["--gpu"] MCPipeline -> check ["--mc"] SeqMemPipeline -> check ["--seq-mem"] GpuMemPipeline -> check ["--gpu-mem"] MCMemPipeline -> check ["--mc-mem"] NoPipeline -> check [] where check opt = do futhark <- liftIO $ maybe getExecutablePath pure $ configFuthark programs let opts = ["dev"] ++ opt ++ ["--metrics", program] (code, output, err) <- liftIO $ readProcessWithExitCode futhark opts "" let output' = T.decodeUtf8 output case code of ExitSuccess | [(m, [])] <- reads $ T.unpack output' -> pure m | otherwise -> throwError $ "Could not read metrics output:\n" <> output' ExitFailure 127 -> throwError $ progNotFound $ T.pack futhark ExitFailure _ -> throwError $ T.decodeUtf8 err testMetrics :: ProgConfig -> FilePath -> StructureTest -> TestM () testMetrics programs program (StructureTest pipeline (AstMetrics expected)) = context "Checking metrics" $ do actual <- optimisedProgramMetrics programs pipeline program accErrors_ $ map (ok actual) $ M.toList expected where maybePipeline :: StructurePipeline -> T.Text maybePipeline SOACSPipeline = "(soacs) " maybePipeline GpuPipeline = "(gpu) " maybePipeline MCPipeline = "(mc) " maybePipeline SeqMemPipeline = "(seq-mem) " maybePipeline GpuMemPipeline = "(gpu-mem) " maybePipeline MCMemPipeline = "(mc-mem) " maybePipeline NoPipeline = "" 
ok (AstMetrics metrics) (name, expected_occurences) = case M.lookup name metrics of Nothing | expected_occurences > 0 -> throwError $ name <> maybePipeline pipeline <> " should have occurred " <> showText expected_occurences <> " times, but did not occur at all in optimised program." Just actual_occurences | expected_occurences /= actual_occurences -> throwError $ name <> maybePipeline pipeline <> " should have occurred " <> showText expected_occurences <> " times, but occurred " <> showText actual_occurences <> " times." _ -> pure () testWarnings :: [WarningTest] -> SBS.ByteString -> TestM () testWarnings warnings futerr = accErrors_ $ map testWarning warnings where testWarning (ExpectedWarning regex_s regex) | not (match regex $ T.unpack $ T.decodeUtf8 futerr) = throwError $ "Expected warning:\n " <> regex_s <> "\nGot warnings:\n " <> T.decodeUtf8 futerr | otherwise = pure () runInterpretedEntry :: FutharkExe -> FilePath -> InputOutputs -> TestM () runInterpretedEntry (FutharkExe futhark) program (InputOutputs entry run_cases) = let dir = takeDirectory program runInterpretedCase run@(TestRun _ inputValues _ index _) = unless (any (`elem` runTags run) ["compiled", "script"]) $ context ("Entry point: " <> entry <> "; dataset: " <> runDescription run) $ do input <- T.unlines . 
map valueText <$> getValues (FutharkExe futhark) dir inputValues expectedResult' <- getExpectedResult (FutharkExe futhark) program entry run (code, output, err) <- liftIO $ readProcessWithExitCode futhark ["run", "-e", T.unpack entry, program] $ T.encodeUtf8 input case code of ExitFailure 127 -> throwError $ progNotFound $ T.pack futhark _ -> liftExcept $ compareResult entry index program expectedResult' =<< runResult program code output err in accErrors_ $ map runInterpretedCase run_cases runTestCase :: TestCase -> TestM () runTestCase (TestCase mode program testcase progs) = do futhark <- liftIO $ maybe getExecutablePath pure $ configFuthark progs let checkctx = mconcat [ "Type-checking with '", T.pack futhark, " check ", T.pack program, "'" ] case testAction testcase of CompileTimeFailure expected_error -> unless (mode `elem` [Structure, Internalise]) . context checkctx $ do (code, _, err) <- liftIO $ readProcessWithExitCode futhark ["check", program] "" case code of ExitSuccess -> throwError "Expected failure\n" ExitFailure 127 -> throwError $ progNotFound $ T.pack futhark ExitFailure 1 -> throwError $ T.decodeUtf8 err ExitFailure _ -> liftExcept $ checkError expected_error $ T.decodeUtf8 err RunCases {} | mode == TypeCheck -> do let options = ["check", program] ++ configExtraCompilerOptions progs context checkctx $ do (code, _, err) <- liftIO $ readProcessWithExitCode futhark options "" case code of ExitSuccess -> pure () ExitFailure 127 -> throwError $ progNotFound $ T.pack futhark ExitFailure _ -> throwError $ T.decodeUtf8 err | mode == Internalise -> do let options = ["dev", program] ++ configExtraCompilerOptions progs context checkctx $ do (code, _, err) <- liftIO $ readProcessWithExitCode futhark options "" case code of ExitSuccess -> pure () ExitFailure 127 -> throwError $ progNotFound $ T.pack futhark ExitFailure _ -> throwError $ T.decodeUtf8 err RunCases ios structures warnings -> do -- Compile up-front and reuse same executable for several entry 
points. let backend = configBackend progs extra_compiler_options = configExtraCompilerOptions progs when (mode `elem` [Compiled, Interpreted]) $ context "Generating reference outputs" $ -- We probably get the concurrency at the test program level, -- so force just one data set at a time here. withExceptT pure $ ensureReferenceOutput (Just 1) (FutharkExe futhark) "c" program ios when (mode == Structure) $ mapM_ (testMetrics progs program) structures when (mode `elem` [Compile, Compiled]) $ context ("Compiling with --backend=" <> T.pack backend) $ do compileTestProgram extra_compiler_options (FutharkExe futhark) backend program warnings unless (mode == Compile) $ do (tuning_opts, _) <- liftIO $ determineTuning (configTuning progs) program let extra_options = determineCache (configCacheExt progs) program ++ tuning_opts ++ configExtraOptions progs runner = configRunner progs context "Running compiled program" $ withProgramServer program runner extra_options $ \server -> do let run = runCompiledEntry (FutharkExe futhark) server program concat <$> mapM run ios when (mode == Interpreted) $ context "Interpreting" $ accErrors_ $ map (runInterpretedEntry (FutharkExe futhark) program) ios liftCommand :: (MonadError T.Text m, MonadIO m) => IO (Maybe CmdFailure) -> m () liftCommand m = do r <- liftIO m case r of Just (CmdFailure _ err) -> E.throwError $ T.unlines err Nothing -> pure () runCompiledEntry :: FutharkExe -> Server -> FilePath -> InputOutputs -> IO [TestResult] runCompiledEntry futhark server program (InputOutputs entry run_cases) = do output_types <- cmdOutputs server entry input_types <- cmdInputs server entry case (,) <$> output_types <*> input_types of Left (CmdFailure _ err) -> pure [Failure err] Right (output_types', input_types') -> do let outs = ["out" <> showText i | i <- [0 .. length output_types' - 1]] ins = ["in" <> showText i | i <- [0 .. length input_types' - 1]] onRes = either (Failure . pure) (const Success) mapM (fmap onRes . 
runCompiledCase input_types' outs ins) run_cases where dir = takeDirectory program runCompiledCase input_types outs ins run = runExceptT $ do let TestRun _ input_spec _ index _ = run case_ctx = "Entry point: " <> entry <> "; dataset: " <> runDescription run context1 case_ctx $ do expected <- getExpectedResult futhark program entry run valuesAsVars server (zip ins (map inputType input_types)) futhark dir input_spec call_r <- liftIO $ cmdCall server entry outs ins liftCommand $ cmdFree server ins res <- case call_r of Left (CmdFailure _ err) -> pure $ ErrorResult $ T.unlines err Right _ -> SuccessResult <$> readResults server outs <* liftCommand (cmdFree server outs) compareResult entry index program expected res checkError :: (MonadError T.Text m) => ExpectedError -> T.Text -> m () checkError (ThisError regex_s regex) err | not (match regex $ T.unpack err) = E.throwError $ "Expected error:\n " <> regex_s <> "\nGot error:\n" <> T.unlines (map (" " <>) (T.lines err)) checkError _ _ = pure () runResult :: (MonadIO m, MonadError T.Text m) => FilePath -> ExitCode -> SBS.ByteString -> SBS.ByteString -> m RunResult runResult program ExitSuccess stdout_s _ = case valuesFromByteString "stdout" $ LBS.fromStrict stdout_s of Left e -> do let actualf = program `addExtension` "actual" liftIO $ SBS.writeFile actualf stdout_s E.throwError $ T.pack e <> "\n(See " <> T.pack actualf <> ")" Right vs -> pure $ SuccessResult vs runResult _ (ExitFailure _) _ stderr_s = pure $ ErrorResult $ T.decodeUtf8 stderr_s compileTestProgram :: [String] -> FutharkExe -> String -> FilePath -> [WarningTest] -> TestM () compileTestProgram extra_options futhark backend program warnings = do (_, futerr) <- withExceptT pure $ compileProgram ("--server" : extra_options) futhark backend program testWarnings warnings futerr compareResult :: (MonadIO m, MonadError T.Text m) => T.Text -> Int -> FilePath -> ExpectedResult [Value] -> RunResult -> m () compareResult _ _ _ (Succeeds Nothing) SuccessResult {} = pure 
() compareResult entry index program (Succeeds (Just expected_vs)) (SuccessResult actual_vs) = checkResult (program <.> T.unpack entry <.> show index) expected_vs actual_vs compareResult _ _ _ (RunTimeFailure expectedError) (ErrorResult actualError) = checkError expectedError actualError compareResult _ _ _ (Succeeds _) (ErrorResult err) = E.throwError $ "Function failed with error:\n" <> err compareResult _ _ _ (RunTimeFailure f) (SuccessResult _) = E.throwError $ "Program succeeded, but expected failure:\n " <> showText f --- --- Test manager --- data TestStatus = TestStatus { testStatusRemain :: [TestCase], testStatusRun :: [TestCase], testStatusTotal :: Int, testStatusFail :: Int, testStatusPass :: Int, testStatusRuns :: Int, testStatusRunsRemain :: Int, testStatusRunPass :: Int, testStatusRunFail :: Int } catching :: IO TestResult -> IO TestResult catching m = m `catch` save where save :: SomeException -> IO TestResult save e = pure $ Failure [showText e] doTest :: TestCase -> IO TestResult doTest = catching . runTestM . runTestCase makeTestCase :: TestConfig -> TestMode -> (FilePath, ProgramTest) -> TestCase makeTestCase config mode (file, spec) = TestCase mode file spec $ configPrograms config data ReportMsg = TestStarted TestCase | TestDone TestCase TestResult runTest :: MVar TestCase -> MVar ReportMsg -> IO () runTest testmvar resmvar = forever $ do test <- takeMVar testmvar putMVar resmvar $ TestStarted test res <- doTest test putMVar resmvar $ TestDone test res excludedTest :: TestConfig -> TestCase -> Bool excludedTest config = any (`elem` configExclude config) . testTags . testCaseTest -- | Exclude those test cases that have tags we do not wish to run. 
excludeCases :: TestConfig -> TestCase -> TestCase excludeCases config tcase = tcase {testCaseTest = onTest $ testCaseTest tcase} where onTest (ProgramTest desc tags action) = ProgramTest desc tags $ onAction action onAction (RunCases ios stest wtest) = RunCases (map onIOs ios) stest wtest onAction action = action onIOs (InputOutputs entry runs) = InputOutputs entry $ filter (not . any excluded . runTags) runs excluded = (`elem` configExclude config) putStatusTable :: TestStatus -> IO () putStatusTable ts = hPutTable stdout rows 1 where rows = [ [mkEntry "" mempty, passed, failed, mkEntry "remaining" mempty], map (`mkEntry` mempty) ["programs", passedProgs, failedProgs, remainProgs'], map (`mkEntry` mempty) ["runs", passedRuns, failedRuns, remainRuns'] ] passed = mkEntry "passed" $ color Green failed = mkEntry "failed" $ color Red passedProgs = show $ testStatusPass ts failedProgs = show $ testStatusFail ts totalProgs = show $ testStatusTotal ts totalRuns = show $ testStatusRuns ts passedRuns = show $ testStatusRunPass ts failedRuns = show $ testStatusRunFail ts remainProgs = show . length $ testStatusRemain ts remainProgs' = remainProgs ++ "/" ++ totalProgs remainRuns = show $ testStatusRunsRemain ts remainRuns' = remainRuns ++ "/" ++ totalRuns tableLines :: Int tableLines = 8 spaceTable :: IO () spaceTable = putStr $ replicate tableLines '\n' reportTable :: TestStatus -> IO () reportTable ts = do moveCursorToTableTop putStatusTable ts clearLine w <- maybe 80 Terminal.width <$> Terminal.size T.putStrLn $ atMostChars (w - T.length labelstr) running where running = labelstr <> (T.unwords . reverse . map (T.pack . testCaseProgram) . 
testStatusRun) ts labelstr = "Now testing: " reportLine :: MVar SystemTime -> TestStatus -> IO () reportLine time_mvar ts = modifyMVar_ time_mvar $ \time -> do time_now <- getSystemTime if systemSeconds time_now - systemSeconds time >= period then do T.putStrLn $ showText (testStatusFail ts) <> " failed, " <> showText (testStatusPass ts) <> " passed, " <> showText num_remain <> " to go." pure time_now else pure time where num_remain = length $ testStatusRemain ts period = 60 moveCursorToTableTop :: IO () moveCursorToTableTop = cursorUpLine tableLines runTests :: TestConfig -> [FilePath] -> IO () runTests config paths = do -- We force line buffering to ensure that we produce running output. -- Otherwise, CI tools and the like may believe we are hung and kill -- us. hSetBuffering stdout LineBuffering let mode = configTestMode config all_tests <- map (makeTestCase config mode) <$> testSpecsFromPathsOrDie paths testmvar <- newEmptyMVar reportmvar <- newEmptyMVar concurrency <- maybe getNumCapabilities pure $ configConcurrency config replicateM_ concurrency $ forkIO $ runTest testmvar reportmvar let (excluded, included) = partition (excludedTest config) all_tests _ <- forkIO $ mapM_ (putMVar testmvar . 
excludeCases config) included time_mvar <- newMVar $ MkSystemTime 0 0 let fancy = not (configLineOutput config) && fancyTerminal report | fancy = reportTable | otherwise = reportLine time_mvar clear | fancy = clearFromCursorToScreenEnd | otherwise = pure () numTestCases tc = case testAction $ testCaseTest tc of CompileTimeFailure _ -> 1 RunCases ios sts wts -> length (concatMap iosTestRuns ios) + length sts + length wts getResults ts | null (testStatusRemain ts) = report ts >> pure ts | otherwise = do report ts msg <- takeMVar reportmvar case msg of TestStarted test -> getResults $ ts {testStatusRun = test : testStatusRun ts} TestDone test res -> do let ts' = ts { testStatusRemain = test `delete` testStatusRemain ts, testStatusRun = test `delete` testStatusRun ts, testStatusRunsRemain = testStatusRunsRemain ts - numTestCases test } case res of Success -> do let ts'' = ts' { testStatusRunPass = testStatusRunPass ts' + numTestCases test } getResults $ ts'' {testStatusPass = testStatusPass ts + 1} Failure s -> do when fancy moveCursorToTableTop clear putDoc $ annotate (bold <> bgColor Red) (pretty (testCaseProgram test) <> ":") <> hardline <> vsep (map pretty s) <> hardline when fancy spaceTable getResults $ ts' { testStatusFail = testStatusFail ts' + 1, testStatusRunPass = testStatusRunPass ts' + max 0 (numTestCases test - length s), testStatusRunFail = testStatusRunFail ts' + min (numTestCases test) (length s) } when fancy spaceTable ts <- getResults TestStatus { testStatusRemain = included, testStatusRun = [], testStatusTotal = length included, testStatusFail = 0, testStatusPass = 0, testStatusRuns = sum $ map numTestCases included, testStatusRunsRemain = sum $ map numTestCases included, testStatusRunPass = 0, testStatusRunFail = 0 } -- Removes "Now testing" output. if fancy then cursorUpLine 1 >> clearLine else putStrLn $ show (testStatusPass ts) <> "/" <> show (testStatusTotal ts) <> " passed." unless (null excluded) . 
putStrLn $ show (length excluded) ++ " program(s) excluded." exitWith $ case testStatusFail ts of 0 -> ExitSuccess _ -> ExitFailure 1 --- --- Configuration and command line parsing --- data TestConfig = TestConfig { configTestMode :: TestMode, configPrograms :: ProgConfig, configExclude :: [T.Text], configLineOutput :: Bool, configConcurrency :: Maybe Int } defaultConfig :: TestConfig defaultConfig = TestConfig { configTestMode = Compiled, configExclude = ["disable"], configPrograms = ProgConfig { configBackend = "c", configFuthark = Nothing, configRunner = "", configExtraOptions = [], configExtraCompilerOptions = [], configTuning = Just "tuning", configCacheExt = Nothing }, configLineOutput = False, configConcurrency = Nothing } data ProgConfig = ProgConfig { configBackend :: String, configFuthark :: Maybe FilePath, configRunner :: FilePath, configExtraCompilerOptions :: [String], configTuning :: Maybe String, configCacheExt :: Maybe String, -- | Extra options passed to the programs being run. 
configExtraOptions :: [String] } deriving (Show) changeProgConfig :: (ProgConfig -> ProgConfig) -> TestConfig -> TestConfig changeProgConfig f config = config {configPrograms = f $ configPrograms config} setBackend :: FilePath -> ProgConfig -> ProgConfig setBackend backend config = config {configBackend = backend} setFuthark :: FilePath -> ProgConfig -> ProgConfig setFuthark futhark config = config {configFuthark = Just futhark} setRunner :: FilePath -> ProgConfig -> ProgConfig setRunner runner config = config {configRunner = runner} addCompilerOption :: String -> ProgConfig -> ProgConfig addCompilerOption option config = config {configExtraCompilerOptions = configExtraCompilerOptions config ++ [option]} addOption :: String -> ProgConfig -> ProgConfig addOption option config = config {configExtraOptions = configExtraOptions config ++ [option]} commandLineOptions :: [FunOptDescr TestConfig] commandLineOptions = [ Option "t" ["typecheck"] (NoArg $ Right $ \config -> config {configTestMode = TypeCheck}) "Only perform type-checking", Option "i" ["interpreted"] (NoArg $ Right $ \config -> config {configTestMode = Interpreted}) "Only interpret", Option "c" ["compiled"] (NoArg $ Right $ \config -> config {configTestMode = Compiled}) "Only run compiled code (the default)", Option "C" ["compile"] (NoArg $ Right $ \config -> config {configTestMode = Compile}) "Only compile, do not run.", Option "s" ["structure"] (NoArg $ Right $ \config -> config {configTestMode = Structure}) "Perform structure tests.", Option "I" ["internalise"] (NoArg $ Right $ \config -> config {configTestMode = Internalise}) "Only run the compiler frontend.", Option [] ["no-terminal", "notty"] (NoArg $ Right $ \config -> config {configLineOutput = True}) "Provide simpler line-based output.", Option [] ["backend"] (ReqArg (Right . changeProgConfig . setBackend) "BACKEND") "Backend used for compilation (defaults to 'c').", Option [] ["futhark"] (ReqArg (Right . changeProgConfig . 
setFuthark) "PROGRAM") "Program to run for subcommands (defaults to same binary as 'futhark test').", Option [] ["runner"] (ReqArg (Right . changeProgConfig . setRunner) "PROGRAM") "The program used to run the Futhark-generated programs (defaults to nothing).", Option [] ["exclude"] ( ReqArg ( \tag -> Right $ \config -> config {configExclude = T.pack tag : configExclude config} ) "TAG" ) "Exclude test programs that define this tag.", Option "p" ["pass-option"] (ReqArg (Right . changeProgConfig . addOption) "OPT") "Pass this option to programs being run.", Option [] ["pass-compiler-option"] (ReqArg (Right . changeProgConfig . addCompilerOption) "OPT") "Pass this option to the compiler (or typechecker if in -t mode).", Option [] ["no-tuning"] (NoArg $ Right $ changeProgConfig $ \config -> config {configTuning = Nothing}) "Do not load tuning files.", Option [] ["cache-extension"] ( ReqArg (\s -> Right $ changeProgConfig $ \config -> config {configCacheExt = Just s}) "EXTENSION" ) "Use cache files with this extension (none by default).", Option [] ["concurrency"] ( ReqArg ( \n -> case reads n of [(n', "")] | n' > 0 -> Right $ \config -> config {configConcurrency = Just n'} _ -> Left . optionsError $ "'" ++ n ++ "' is not a positive integer." ) "NUM" ) "Number of tests to run concurrently." ] excludeBackend :: TestConfig -> TestConfig excludeBackend config = config { configExclude = "no_" <> T.pack (configBackend (configPrograms config)) : configExclude config } -- | Run @futhark test@. main :: String -> [String] -> IO () main = mainWithOptions defaultConfig commandLineOptions "options... programs..." 
$ \progs config -> case progs of [] -> Nothing _ -> Just $ runTests (excludeBackend config) progs futhark-0.25.27/src/Futhark/CLI/WASM.hs000066400000000000000000000007411475065116200172510ustar00rootroot00000000000000-- | @futhark wasm@ module Futhark.CLI.WASM (main) where import Futhark.Actions (compileCtoWASMAction) import Futhark.Compiler.CLI import Futhark.Passes (seqmemPipeline) -- | Run @futhark c@ main :: String -> [String] -> IO () main = compilerMain () [] "Compile to WASM" "Generate WASM with the sequential C backend code from optimised Futhark program." seqmemPipeline $ \fcfg () mode outpath prog -> actionProcedure (compileCtoWASMAction fcfg mode outpath) prog futhark-0.25.27/src/Futhark/CodeGen/000077500000000000000000000000001475065116200170415ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/CodeGen/Backends/000077500000000000000000000000001475065116200205535ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/CodeGen/Backends/CCUDA.hs000066400000000000000000000125741475065116200217370ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-} -- | Code generation for CUDA. 
module Futhark.CodeGen.Backends.CCUDA ( compileProg, GC.CParts (..), GC.asLibrary, GC.asExecutable, GC.asServer, ) where import Data.Map qualified as M import Data.Text qualified as T import Futhark.CodeGen.Backends.GPU import Futhark.CodeGen.Backends.GenericC qualified as GC import Futhark.CodeGen.Backends.GenericC.Options import Futhark.CodeGen.ImpCode.OpenCL import Futhark.CodeGen.ImpGen.CUDA qualified as ImpGen import Futhark.CodeGen.RTS.C (backendsCudaH) import Futhark.IR.GPUMem hiding ( CmpSizeLe, GetSize, GetSizeMax, ) import Futhark.MonadFreshNames import Language.C.Quote.OpenCL qualified as C import NeatInterpolation (untrimming) mkBoilerplate :: T.Text -> [(Name, KernelConstExp)] -> M.Map Name KernelSafety -> [PrimType] -> [FailureMsg] -> GC.CompilerM OpenCL () () mkBoilerplate cuda_program macros kernels types failures = do generateGPUBoilerplate cuda_program macros backendsCudaH (M.keys kernels) types failures GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_add_nvrtc_option(struct futhark_context_config *cfg, const char* opt);|] GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_set_device(struct futhark_context_config *cfg, const char* s);|] GC.headerDecl GC.InitDecl [C.cedecl|const char* futhark_context_config_get_program(struct futhark_context_config *cfg);|] GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_set_program(struct futhark_context_config *cfg, const char* s);|] GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_dump_ptx_to(struct futhark_context_config *cfg, const char* s);|] GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_load_ptx_from(struct futhark_context_config *cfg, const char* s);|] cliOptions :: [Option] cliOptions = gpuOptions ++ [ Option { optionLongName = "dump-cuda", optionShortName = Nothing, optionArgument = RequiredArgument "FILE", optionDescription = "Dump the embedded CUDA kernels to the indicated file.", optionAction = [C.cstm|{const char* prog = 
futhark_context_config_get_program(cfg); if (dump_file(optarg, prog, strlen(prog)) != 0) { fprintf(stderr, "%s: %s\n", optarg, strerror(errno)); exit(1); } exit(0);}|] }, Option { optionLongName = "load-cuda", optionShortName = Nothing, optionArgument = RequiredArgument "FILE", optionDescription = "Instead of using the embedded CUDA kernels, load them from the indicated file.", optionAction = [C.cstm|{ size_t n; const char *s = slurp_file(optarg, &n); if (s == NULL) { fprintf(stderr, "%s: %s\n", optarg, strerror(errno)); exit(1); } futhark_context_config_set_program(cfg, s); }|] }, Option { optionLongName = "dump-ptx", optionShortName = Nothing, optionArgument = RequiredArgument "FILE", optionDescription = "Dump the PTX-compiled version of the embedded kernels to the indicated file.", optionAction = [C.cstm|{futhark_context_config_dump_ptx_to(cfg, optarg); entry_point = NULL;}|] }, Option { optionLongName = "load-ptx", optionShortName = Nothing, optionArgument = RequiredArgument "FILE", optionDescription = "Load PTX code from the indicated file.", optionAction = [C.cstm|futhark_context_config_load_ptx_from(cfg, optarg);|] }, Option { optionLongName = "nvrtc-option", optionShortName = Nothing, optionArgument = RequiredArgument "OPT", optionDescription = "Add an additional build option to the string passed to NVRTC.", optionAction = [C.cstm|futhark_context_config_add_nvrtc_option(cfg, optarg);|] } ] cudaMemoryType :: GC.MemoryType OpenCL () cudaMemoryType "device" = pure [C.cty|typename CUdeviceptr|] cudaMemoryType space = error $ "GPU backend does not support '" ++ space ++ "' memory space." -- | Compile the program to C with calls to CUDA. 
compileProg :: (MonadFreshNames m) => T.Text -> Prog GPUMem -> m (ImpGen.Warnings, GC.CParts)
compileProg version prog = do
  ( ws,
    Program cuda_code cuda_prelude macros kernels types params failures prog'
    ) <-
    ImpGen.compileProg prog
  (ws,)
    <$> GC.compileProg
      "cuda"
      version
      params
      operations
      (mkBoilerplate (cuda_prelude <> cuda_code) macros kernels types failures)
      cuda_includes
      (Space "device", [Space "device", DefaultSpace])
      cliOptions
      prog'
  where
    operations :: GC.Operations OpenCL ()
    operations =
      gpuOperations
        { GC.opsMemoryType = cudaMemoryType,
          -- Critical sections are bracketed by pushing ctx->cu_ctx as
          -- the current CUDA context and popping it again afterwards.
          GC.opsCritical =
            ( [C.citems|CUDA_SUCCEED_FATAL(cuCtxPushCurrent(ctx->cu_ctx));|],
              [C.citems|CUDA_SUCCEED_FATAL(cuCtxPopCurrent(&ctx->cu_ctx));|]
            )
        }
    -- NOTE(review): the header names of the three #include lines below
    -- are missing in this copy of the file (a bare #include is not
    -- valid C).  They presumably named the CUDA driver/NVRTC headers;
    -- confirm against upstream and restore them.
    cuda_includes =
      [untrimming|
       #include
       #include
       #include
      |]
futhark-0.25.27/src/Futhark/CodeGen/Backends/COpenCL.hs000066400000000000000000000215201475065116200223320ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-}

-- | Code generation for C with OpenCL.
module Futhark.CodeGen.Backends.COpenCL
  ( compileProg,
    GC.CParts (..),
    GC.asLibrary,
    GC.asExecutable,
    GC.asServer,
  )
where

import Control.Monad.State
import Data.Map qualified as M
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GPU
import Futhark.CodeGen.Backends.GenericC qualified as GC
import Futhark.CodeGen.Backends.GenericC.Options
import Futhark.CodeGen.ImpCode.OpenCL
import Futhark.CodeGen.ImpGen.OpenCL qualified as ImpGen
import Futhark.CodeGen.OpenCL.Heuristics
import Futhark.CodeGen.RTS.C (backendsOpenclH)
import Futhark.IR.GPUMem hiding
  ( CmpSizeLe,
    GetSize,
    GetSizeMax,
  )
import Futhark.MonadFreshNames
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C
import NeatInterpolation (untrimming)

-- | Turn one size heuristic into a C statement.  The statement assigns
-- the heuristic's value to the targeted configuration field, but only
-- if that field is still zero and the selected device's platform name
-- contains the heuristic's platform string and its device type matches.
sizeHeuristicsCode :: SizeHeuristic -> C.Stm
sizeHeuristicsCode (SizeHeuristic platform_name device_type which (TPrimExp what)) =
  [C.cstm|
   if ($exp:which' == 0 &&
       strstr(option->platform_name, $string:platform_name) != NULL &&
       (option->device_type & $exp:(clDeviceType device_type)) == $exp:(clDeviceType device_type)) {
     $items:get_size
   }|]
  where
    clDeviceType DeviceGPU = [C.cexp|CL_DEVICE_TYPE_GPU|]
    clDeviceType DeviceCPU = [C.cexp|CL_DEVICE_TYPE_CPU|]

    -- The context/configuration field that this heuristic sets.
    which' = case which of
      LockstepWidth -> [C.cexp|ctx->lockstep_width|]
      NumBlocks -> [C.cexp|ctx->cfg->gpu.default_grid_size|]
      BlockSize -> [C.cexp|ctx->cfg->gpu.default_block_size|]
      TileSize -> [C.cexp|ctx->cfg->gpu.default_tile_size|]
      RegTileSize -> [C.cexp|ctx->cfg->gpu.default_reg_tile_size|]
      Threshold -> [C.cexp|ctx->cfg->gpu.default_threshold|]

    -- The device queries needed by the expression (collected in the
    -- state map), followed by the assignment itself.
    get_size =
      let (e, m) = runState (GC.compilePrimExp onLeaf what) mempty
       in concat (M.elems m) ++ [[C.citem|$exp:which' = $exp:e;|]]

    onLeaf (DeviceInfo s) = do
      let s' = "CL_DEVICE_" ++ s
          v = s ++ "_val"
      m <- get
      -- Queries are stored under their full CL_DEVICE_* name, so look
      -- them up by that same key; each query is emitted only once.
      case M.lookup s' m of
        Nothing ->
          -- XXX: Cheating with the type here; works for the infos we
          -- currently use because we zero-initialise and assume a
          -- little-endian platform, but should be made more
          -- size-aware in the future.
          modify $
            M.insert
              s'
              [C.citems|size_t $id:v = 0;
                        clGetDeviceInfo(ctx->device, $id:s',
                                        sizeof($id:v), &$id:v,
                                        NULL);|]
        Just _ -> pure ()
      pure [C.cexp|$id:v|]

-- | Emit the OpenCL-specific boilerplate: the generic GPU boilerplate,
-- the @post_opencl_setup@ function that applies 'sizeHeuristicsTable',
-- and the public header declarations of the OpenCL configuration API.
mkBoilerplate ::
  T.Text ->
  [(Name, KernelConstExp)] ->
  M.Map Name KernelSafety ->
  [PrimType] ->
  [FailureMsg] ->
  GC.CompilerM OpenCL () ()
mkBoilerplate opencl_program macros kernels types failures = do
  generateGPUBoilerplate
    opencl_program
    macros
    backendsOpenclH
    (M.keys kernels)
    types
    failures

  GC.earlyDecl
    [C.cedecl|void post_opencl_setup(struct futhark_context *ctx, struct opencl_device_option *option) {
             $stms:(map sizeHeuristicsCode sizeHeuristicsTable)
             }|]

  GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_add_build_option(struct futhark_context_config *cfg, const char* opt);|]
  GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_set_device(struct futhark_context_config *cfg, const char* s);|]
  GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_set_platform(struct futhark_context_config *cfg, const char* s);|]
  GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_select_device_interactively(struct futhark_context_config *cfg);|]
  GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_list_devices(struct futhark_context_config *cfg);|]
  GC.headerDecl GC.InitDecl [C.cedecl|const char* futhark_context_config_get_program(struct futhark_context_config *cfg);|]
  GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_set_program(struct futhark_context_config *cfg, const char* s);|]
  GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_dump_binary_to(struct futhark_context_config *cfg, const char* s);|]
  GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_load_binary_from(struct futhark_context_config *cfg, const char* s);|]
  GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_set_command_queue(struct futhark_context_config *cfg, typename cl_command_queue);|]
  GC.headerDecl GC.MiscDecl [C.cedecl|typename cl_command_queue futhark_context_get_command_queue(struct futhark_context* ctx);|]

-- | Command-line options of OpenCL executables: the common GPU options
-- plus the OpenCL-specific ones below.
cliOptions :: [Option]
cliOptions =
  gpuOptions
    ++ [ Option
           { optionLongName = "platform",
             optionShortName = Just 'p',
             optionArgument = RequiredArgument "NAME",
             optionDescription = "Use the first OpenCL platform whose name contains the given string.",
             optionAction = [C.cstm|futhark_context_config_set_platform(cfg, optarg);|]
           },
         Option
           { optionLongName = "dump-opencl",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "FILE",
             optionDescription = "Dump the embedded OpenCL program to the indicated file.",
             optionAction =
               [C.cstm|{const char* prog = futhark_context_config_get_program(cfg);
                        if (dump_file(optarg, prog, strlen(prog)) != 0) {
                          fprintf(stderr, "%s: %s\n", optarg, strerror(errno));
                          exit(1);
                        }
                        exit(0);}|]
           },
         Option
           { optionLongName = "load-opencl",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "FILE",
             optionDescription = "Instead of using the embedded OpenCL program, load it from the indicated file.",
             optionAction =
               [C.cstm|{ size_t n; const char *s = slurp_file(optarg, &n);
                         if (s == NULL) { fprintf(stderr, "%s: %s\n", optarg, strerror(errno)); exit(1); }
                         futhark_context_config_set_program(cfg, s);
                       }|]
           },
         Option
           { optionLongName = "dump-opencl-binary",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "FILE",
             optionDescription = "Dump the compiled version of the embedded OpenCL program to the indicated file.",
             optionAction =
               [C.cstm|{futhark_context_config_dump_binary_to(cfg, optarg);
                        entry_point = NULL;}|]
           },
         Option
           { optionLongName = "load-opencl-binary",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "FILE",
             optionDescription = "Load an OpenCL binary from the indicated file.",
             optionAction = [C.cstm|futhark_context_config_load_binary_from(cfg, optarg);|]
           },
         Option
           { optionLongName = "build-option",
             optionShortName = Nothing,
             optionArgument = RequiredArgument "OPT",
             optionDescription = "Add an additional build option to the string passed to clBuildProgram().",
             optionAction = [C.cstm|futhark_context_config_add_build_option(cfg, optarg);|]
           },
         Option
           { optionLongName = "list-devices",
             optionShortName = Nothing,
             optionArgument = NoArgument,
             optionDescription = "List all OpenCL devices and platforms available on the system.",
             optionAction =
               [C.cstm|{futhark_context_config_list_devices(cfg);
                        entry_point = NULL;}|]
           }
       ]

-- | The C type of a memory block; only the "device" space is supported.
openclMemoryType :: GC.MemoryType OpenCL ()
openclMemoryType "device" = pure [C.cty|typename cl_mem|]
openclMemoryType space =
  error $ "GPU backend does not support '" ++ space ++ "' memory space."

-- | Compile the program to C with calls to OpenCL.
compileProg :: (MonadFreshNames m) => T.Text -> Prog GPUMem -> m (ImpGen.Warnings, GC.CParts)
compileProg version prog = do
  ( ws,
    Program opencl_code opencl_prelude macros kernels types params failures prog'
    ) <-
    ImpGen.compileProg prog
  (ws,)
    <$> GC.compileProg
      "opencl"
      version
      params
      operations
      (mkBoilerplate (opencl_prelude <> opencl_code) macros kernels types failures)
      opencl_includes
      (Space "device", [Space "device", DefaultSpace])
      cliOptions
      prog'
  where
    -- The generic GPU operations, with OpenCL's cl_mem as the type of
    -- device memory.
    operations :: GC.Operations OpenCL ()
    operations =
      gpuOperations
        { GC.opsMemoryType = openclMemoryType
        }
    -- Pin the OpenCL 1.2 API.  On macOS the header lives under OpenCL/
    -- and deprecation warnings are silenced; elsewhere it is CL/cl.h.
    opencl_includes =
      [untrimming|
       #define CL_TARGET_OPENCL_VERSION 120
       #define CL_USE_DEPRECATED_OPENCL_1_2_APIS
       #ifdef __APPLE__
       #define CL_SILENCE_DEPRECATION
       #include <OpenCL/cl.h>
       #else
       #include <CL/cl.h>
       #endif
      |]
futhark-0.25.27/src/Futhark/CodeGen/Backends/GPU.hs000066400000000000000000000506071475065116200215520ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-}

-- | C code generation for GPU, in general.
--
-- This module generates code that targets the tiny GPU API
-- abstraction layer we define in the runtime system.
module Futhark.CodeGen.Backends.GPU
  ( gpuOperations,
    gpuOptions,
    generateGPUBoilerplate,
  )
where

import Control.Monad
import Control.Monad.Identity
import Data.Bifunctor (bimap)
import Data.Map qualified as M
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericC qualified as GC
import Futhark.CodeGen.Backends.GenericC.Options
import Futhark.CodeGen.Backends.GenericC.Pretty (expText, idText)
import Futhark.CodeGen.Backends.SimpleRep (primStorageType, toStorage)
import Futhark.CodeGen.ImpCode.OpenCL
import Futhark.CodeGen.RTS.C (gpuH, gpuPrototypesH)
import Futhark.MonadFreshNames
import Futhark.Util (chunk)
import Futhark.Util.Pretty (prettyTextOneLine)
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C

-- | Emit a static C wrapper function, named "gpu_kernel_" followed by
-- the kernel name, that launches the kernel via @gpu_launch_kernel@.
-- The failure-reporting arguments required by the kernel's safety
-- level are passed before the ordinary arguments.  If the product of
-- the grid and block dimensions is zero, the wrapper launches nothing
-- and returns @FUTHARK_SUCCESS@.  Returns the name of the wrapper.
genKernelFunction ::
  KernelName ->
  KernelSafety ->
  [C.Param] ->
  [(C.Exp, C.Exp)] ->
  GC.CompilerM op s Name
genKernelFunction kernel_name safety arg_params arg_set = do
  let kernel_fname = "gpu_kernel_" <> kernel_name
  GC.libDecl
    [C.cedecl|static int $id:kernel_fname
               (struct futhark_context* ctx,
                unsigned int grid_x, unsigned int grid_y, unsigned int grid_z,
                unsigned int block_x, unsigned int block_y, unsigned int block_z,
                unsigned int shared_bytes, $params:arg_params) {
    if (grid_x * grid_y * grid_z * block_x * block_y * block_z != 0) {
      void* args[$int:num_args] = { $inits:(failure_inits<>args_inits) };
      size_t args_sizes[$int:num_args] = { $inits:(failure_sizes<>args_sizes) };
      return gpu_launch_kernel(ctx, ctx->program->$id:kernel_name,
                               $string:(prettyString kernel_name),
                               (const typename int32_t[]){grid_x, grid_y, grid_z},
                               (const typename int32_t[]){block_x, block_y, block_z},
                               shared_bytes, $int:num_args, args, args_sizes);
    }
    return FUTHARK_SUCCESS;
  }|]
  pure kernel_fname
  where
    num_args = numFailureParams safety + length arg_set
    expToInit e = [C.cinit|$exp:e|]
    -- Each element of arg_set is a (size expression, pointer expression) pair.
    (args_sizes, args_inits) = bimap (map expToInit) (map expToInit) $ unzip arg_set
    -- Only as many failure parameters as the safety level demands.
    (failure_inits, failure_sizes) =
      unzip . take (numFailureParams safety) $
        [ ([C.cinit|&ctx->global_failure|], [C.cinit|sizeof(ctx->global_failure)|]),
          ([C.cinit|&ctx->failure_is_an_option|], [C.cinit|sizeof(ctx->failure_is_an_option)|]),
          ([C.cinit|&ctx->global_failure_args|], [C.cinit|sizeof(ctx->global_failure_args)|])
        ]

-- | The C expression reading a tuning parameter from the context.
getParamByKey :: Name -> C.Exp
getParamByKey key = [C.cexp|*ctx->tuning_params.$id:key|]

-- | The C expression for a kernel constant: either a tuning parameter
-- or the context's @max_@ field for a size class.
kernelConstToExp :: KernelConst -> C.Exp
kernelConstToExp (SizeConst key _) =
  getParamByKey key
kernelConstToExp (SizeMaxConst size_class) =
  [C.cexp|ctx->$id:field|]
  where
    field = "max_" <> prettyString size_class

-- | Compile a block dimension, which is either an ordinary expression
-- or a kernel-constant expression.
compileBlockDim :: BlockDim -> GC.CompilerM op s C.Exp
compileBlockDim (Left e) = GC.compileExp e
compileBlockDim (Right e) = pure $ compileConstExp e

-- | Generate a kernel launch: compile the arguments and dimensions,
-- emit the wrapper with 'genKernelFunction', call it, and jump to
-- @cleanup@ if it does not return @FUTHARK_SUCCESS@.  For kernels with
-- at least 'SafetyFull', @ctx->failure_is_an_option@ is set afterwards.
genLaunchKernel ::
  KernelSafety ->
  KernelName ->
  Count Bytes (TExp Int64) ->
  [KernelArg] ->
  [Exp] ->
  [BlockDim] ->
  GC.CompilerM op s ()
genLaunchKernel safety kernel_name shared_memory args num_tblocks tblock_size = do
  (arg_params, arg_params_inits, call_args) <-
    unzip3 <$> zipWithM mkArgs [(0 :: Int) ..] args

  (grid_x, grid_y, grid_z) <- mkDims <$> mapM GC.compileExp num_tblocks
  (block_x, block_y, block_z) <- mkDims <$> mapM compileBlockDim tblock_size

  kernel_fname <- genKernelFunction kernel_name safety arg_params arg_params_inits

  shared_memory' <- GC.compileExp $ untyped $ unCount shared_memory

  GC.stm
    [C.cstm|{
           err = $id:kernel_fname(ctx,
                                  $exp:grid_x, $exp:grid_y, $exp:grid_z,
                                  $exp:block_x, $exp:block_y, $exp:block_z,
                                  $exp:shared_memory',
                                  $args:call_args);
           if (err != FUTHARK_SUCCESS) { goto cleanup; }
           }|]

  when (safety >= SafetyFull) $
    GC.stm [C.cstm|ctx->failure_is_an_option = 1;|]
  where
    -- Pad to three dimensions with 1s (all 0s when no dimensions are
    -- given); dimensions beyond the third are ignored.
    mkDims [] = ([C.cexp|0|], [C.cexp|0|], [C.cexp|0|])
    mkDims [x] = (x, [C.cexp|1|], [C.cexp|1|])
    mkDims [x, y] = (x, y, [C.cexp|1|])
    mkDims (x : y : z : _) = (x, y, z)

    -- For each argument: the wrapper's C parameter, its (size, pointer)
    -- pair for the argument arrays, and the value passed at the call.
    mkArgs i (ValueKArg e t) = do
      let arg = "arg" <> show i
      e' <- GC.compileExp e
      pure
        ( [C.cparam|$ty:(primStorageType t) $id:arg|],
          ([C.cexp|sizeof($id:arg)|], [C.cexp|&$id:arg|]),
          toStorage t e'
        )
    mkArgs i (MemKArg v) = do
      let arg = "arg" <> show i
      v' <- GC.rawMem v
      pure
        ( [C.cparam|typename gpu_mem $id:arg|],
          ([C.cexp|sizeof($id:arg)|], [C.cexp|&$id:arg|]),
          v'
        )

-- | Compile the GPU-specific operations of the imperative code.
callKernel :: GC.OpCompiler OpenCL ()
callKernel (GetSize v key) =
  GC.stm [C.cstm|$id:v = $exp:(getParamByKey key);|]
callKernel (CmpSizeLe v key x) = do
  x' <- GC.compileExp x
  GC.stm [C.cstm|$id:v = $exp:(getParamByKey key) <= $exp:x';|]
  -- Output size information if logging is enabled.  The autotuner
  -- depends on the format of this output, so use caution if changing
  -- it.
  GC.stm
    [C.cstm|if (ctx->logging) {
    fprintf(ctx->log, "Compared %s <= %ld: %s.\n", $string:(T.unpack (prettyTextOneLine key)), (long)$exp:x', $id:v ? "true" : "false");
    }|]
callKernel (GetSizeMax v size_class) = do
  let e = kernelConstToExp $ SizeMaxConst size_class
  GC.stm [C.cstm|$id:v = $exp:e;|]
callKernel (LaunchKernel safety kernel_name shared_memory args num_tblocks tblock_size) =
  genLaunchKernel safety kernel_name shared_memory args num_tblocks tblock_size

-- | Device-to-device LMAD copy, dispatched to a runtime function
-- specialised on the element size in bytes.
copygpu2gpu :: GC.DoCopy op s
copygpu2gpu _ t shape dst (dstoffset, dststride) src (srcoffset, srcstride) = do
  let fname = "lmad_copy_gpu2gpu_" <> show (primByteSize t :: Int) <> "b"
      r = length shape
      dststride_inits = [[C.cinit|$exp:e|] | Count e <- dststride]
      srcstride_inits = [[C.cinit|$exp:e|] | Count e <- srcstride]
      shape_inits = [[C.cinit|$exp:e|] | Count e <- shape]
  GC.stm
    [C.cstm|
         if ((err =
                $id:fname(ctx, $int:r,
                          $exp:dst, $exp:(unCount dstoffset),
                          (typename int64_t[]){ $inits:dststride_inits },
                          $exp:src, $exp:(unCount srcoffset),
                          (typename int64_t[]){ $inits:srcstride_inits },
                          (typename int64_t[]){ $inits:shape_inits })) != 0) {
           goto cleanup;
         }
     |]

-- | Shared code generator for the two host/device LMAD copies.  The
-- named runtime function takes the element size in bytes, the barrier
-- flag, the rank, destination, source and shape.  Jumps to @cleanup@
-- if the call returns non-zero.
lmadCopyHostGPU :: String -> GC.DoCopy op s
lmadCopyHostGPU fname sync t shape dst (dstoffset, dststride) src (srcoffset, srcstride) = do
  let r = length shape
      dststride_inits = [[C.cinit|$exp:e|] | Count e <- dststride]
      srcstride_inits = [[C.cinit|$exp:e|] | Count e <- srcstride]
      shape_inits = [[C.cinit|$exp:e|] | Count e <- shape]
  GC.stm
    [C.cstm|
         if ((err =
                $id:fname(ctx, $int:(primByteSize t::Int), $exp:(syncArg sync), $int:r,
                          $exp:dst, $exp:(unCount dstoffset),
                          (typename int64_t[]){ $inits:dststride_inits },
                          $exp:src, $exp:(unCount srcoffset),
                          (typename int64_t[]){ $inits:srcstride_inits },
                          (typename int64_t[]){ $inits:shape_inits })) != 0) {
           goto cleanup;
         }
     |]

-- | Host-to-device LMAD copy.
copyhost2gpu :: GC.DoCopy op s
copyhost2gpu = lmadCopyHostGPU "lmad_copy_host2gpu"

-- | Device-to-host LMAD copy.
copygpu2host :: GC.DoCopy op s
copygpu2host = lmadCopyHostGPU "lmad_copy_gpu2host"

-- | LMAD copy implementations for the space pairs involving device memory.
gpuCopies :: M.Map (Space, Space) (GC.DoCopy op s)
gpuCopies =
  M.fromList
    [ ((Space "device", Space "device"), copygpu2gpu),
      ((Space "device", DefaultSpace), copyhost2gpu),
      ((DefaultSpace, Space "device"), copygpu2host)
    ]

-- | Add a dynamically managed context field for every kernel, created
-- with @gpu_create_kernel@ and released with @gpu_free_kernel@.
createKernels :: [KernelName] -> GC.CompilerM op s ()
createKernels kernels = forM_ kernels $ \name ->
  GC.contextFieldDyn
    (C.toIdent name mempty)
    [C.cty|typename gpu_kernel|]
    [C.cstm|gpu_create_kernel(ctx, &ctx->program->$id:name, $string:(T.unpack (idText (C.toIdent name mempty))));|]
    [C.cstm|gpu_free_kernel(ctx, ctx->program->$id:name);|]

-- | Allocation in the "device" space via @gpu_alloc@; the call's result
-- is discarded, and the size variable is passed so it can be updated.
allocateGPU :: GC.Allocate op ()
allocateGPU mem size tag "device" =
  GC.stm
    [C.cstm|(void)gpu_alloc(ctx, ctx->log, (size_t)$exp:size, $exp:tag, &$exp:mem, (size_t*)&$exp:size);|]
allocateGPU _ _ _ space =
  error $ "Cannot allocate in '" ++ space ++ "' memory space."

-- | Deallocation in the "device" space via @gpu_free@.
deallocateGPU :: GC.Deallocate op ()
deallocateGPU mem size tag "device" =
  GC.stm [C.cstm|(void)gpu_free(ctx, $exp:mem, $exp:size, $exp:tag);|]
deallocateGPU _ _ _ space =
  error $ "Cannot deallocate in '" ++ space ++ "' space"

-- It is often faster to do a blocking clEnqueueReadBuffer() than to
-- do an async clEnqueueReadBuffer() followed by a clFinish(), even
-- with an in-order command queue.  This is safe if and only if there
-- are no possible outstanding failures.
-- | Read one scalar from "device" memory into a fresh local variable.
-- If a kernel failure may be outstanding (@ctx->failure_is_an_option@),
-- the context is synchronised, and a failed sync jumps to @cleanup@.
readScalarGPU :: GC.ReadScalar op ()
readScalarGPU mem i t "device" _ = do
  val <- newVName "read_res"
  GC.decl [C.cdecl|$ty:t $id:val;|]
  GC.stm
    [C.cstm|if ((err = gpu_scalar_from_device(ctx, &$id:val, $exp:mem, $exp:i * sizeof($ty:t), sizeof($ty:t))) != 0) { goto cleanup; }|]
  GC.stm
    [C.cstm|if (ctx->failure_is_an_option && futhark_context_sync(ctx) != 0)
            { err = 1; goto cleanup; }|]
  pure [C.cexp|$id:val|]
readScalarGPU _ _ _ space _ =
  error $ "Cannot read from '" ++ space ++ "' memory space."

-- TODO: Optimised special case when the scalar is a constant, in
-- which case we can do the write asynchronously.
-- | Write one scalar to "device" memory through a local temporary.
writeScalarGPU :: GC.WriteScalar op ()
writeScalarGPU mem i t "device" _ val = do
  val' <- newVName "write_tmp"
  GC.item [C.citem|$ty:t $id:val' = $exp:val;|]
  GC.stm
    [C.cstm|if ((err = gpu_scalar_to_device(ctx, $exp:mem, $exp:i * sizeof($ty:t), sizeof($ty:t), &$id:val')) != 0) { goto cleanup; }|]
writeScalarGPU _ _ _ space _ _ =
  error $ "Cannot write to '" ++ space ++ "' memory space."

-- | The C boolean passed to the runtime's copy functions for a barrier flag.
syncArg :: GC.CopyBarrier -> C.Exp
syncArg GC.CopyBarrier = [C.cexp|true|]
syncArg GC.CopyNoBarrier = [C.cexp|false|]

-- | Plain byte copies between device and/or host memory.
copyGPU :: GC.Copy OpenCL ()
copyGPU _ dstmem dstidx (Space "device") srcmem srcidx (Space "device") nbytes =
  GC.stm
    [C.cstm|err = gpu_memcpy(ctx, $exp:dstmem, $exp:dstidx, $exp:srcmem, $exp:srcidx, $exp:nbytes);|]
copyGPU b dstmem dstidx DefaultSpace srcmem srcidx (Space "device") nbytes =
  GC.stm
    [C.cstm|err = memcpy_gpu2host(ctx, $exp:(syncArg b), $exp:dstmem, $exp:dstidx, $exp:srcmem, $exp:srcidx, $exp:nbytes);|]
copyGPU b dstmem dstidx (Space "device") srcmem srcidx DefaultSpace nbytes =
  GC.stm
    [C.cstm|err = memcpy_host2gpu(ctx, $exp:(syncArg b), $exp:dstmem, $exp:dstidx, $exp:srcmem, $exp:srcidx, $exp:nbytes);|]
copyGPU _ _ _ destspace _ _ srcspace _ =
  error $ "Cannot copy to " ++ show destspace ++ " from " ++ show srcspace

-- | The generic GPU operations, layered on top of 'GC.defaultOperations'.
gpuOperations :: GC.Operations OpenCL ()
gpuOperations =
  GC.defaultOperations
    { GC.opsCompiler = callKernel,
      GC.opsWriteScalar = writeScalarGPU,
      GC.opsReadScalar = readScalarGPU,
      GC.opsAllocate = allocateGPU,
      GC.opsDeallocate = deallocateGPU,
      GC.opsCopy = copyGPU,
      GC.opsCopies = gpuCopies <> GC.opsCopies GC.defaultOperations,
      GC.opsFatMemory = True
    }

-- | Options that are common to multiple GPU-like backends.
gpuOptions :: [Option]
gpuOptions =
  [ Option
      { optionLongName = "device",
        optionShortName = Just 'd',
        optionArgument = RequiredArgument "NAME",
        optionDescription = "Use the first device whose name contains the given string.",
        optionAction = [C.cstm|futhark_context_config_set_device(cfg, optarg);|]
      },
    Option
      { optionLongName = "default-thread-block-size",
        optionShortName = Nothing,
        optionArgument = RequiredArgument "INT",
        optionDescription = "The default size of thread blocks that are launched.",
        optionAction = [C.cstm|futhark_context_config_set_default_thread_block_size(cfg, atoi(optarg));|]
      },
    Option
      { optionLongName = "default-grid-size",
        optionShortName = Nothing,
        optionArgument = RequiredArgument "INT",
        optionDescription = "The default number of thread blocks that are launched.",
        optionAction = [C.cstm|futhark_context_config_set_default_grid_size(cfg, atoi(optarg));|]
      },
    Option
      { optionLongName = "default-group-size",
        optionShortName = Nothing,
        optionArgument = RequiredArgument "INT",
        optionDescription = "Alias for --default-thread-block-size.",
        optionAction = [C.cstm|futhark_context_config_set_default_group_size(cfg, atoi(optarg));|]
      },
    -- The old name for the number of thread blocks, i.e. the grid size.
    Option
      { optionLongName = "default-num-groups",
        optionShortName = Nothing,
        optionArgument = RequiredArgument "INT",
        optionDescription = "Alias for --default-grid-size.",
        optionAction = [C.cstm|futhark_context_config_set_default_num_groups(cfg, atoi(optarg));|]
      },
    Option
      { optionLongName = "default-tile-size",
        optionShortName = Nothing,
        optionArgument = RequiredArgument "INT",
        optionDescription = "The default tile size for two-dimensional tiling.",
        optionAction = [C.cstm|futhark_context_config_set_default_tile_size(cfg, atoi(optarg));|]
      },
    Option
      { optionLongName = "default-reg-tile-size",
        optionShortName = Nothing,
        optionArgument = RequiredArgument "INT",
        optionDescription = "The default register tile size for two-dimensional tiling.",
        optionAction = [C.cstm|futhark_context_config_set_default_reg_tile_size(cfg, atoi(optarg));|]
      },
    Option
      { optionLongName = "default-registers",
        optionShortName = Nothing,
        optionArgument = RequiredArgument "INT",
        optionDescription = "The amount of register memory in bytes.",
        optionAction = [C.cstm|futhark_context_config_set_default_registers(cfg, atoi(optarg));|]
      },
    Option
      { optionLongName = "default-cache",
        optionShortName = Nothing,
        optionArgument = RequiredArgument "INT",
        optionDescription = "The amount of cache memory in bytes.",
        optionAction = [C.cstm|futhark_context_config_set_default_cache(cfg, atoi(optarg));|]
      },
    Option
      { optionLongName = "default-threshold",
        optionShortName = Nothing,
        optionArgument = RequiredArgument "INT",
        optionDescription = "The default parallelism threshold.",
        optionAction = [C.cstm|futhark_context_config_set_default_threshold(cfg, atoi(optarg));|]
      },
    Option
      { optionLongName = "unified-memory",
        optionShortName = Nothing,
        optionArgument = RequiredArgument "INT",
        optionDescription = "Whether to use unified memory",
        optionAction = [C.cstm|futhark_context_config_set_unified_memory(cfg, atoi(optarg));|]
      }
  ]

-- | Number of runtime arguments an error message takes.
errorMsgNumArgs :: ErrorMsg a -> Int
errorMsgNumArgs = length . errorMsgArgTypes

-- | Generate @get_failure_msg@, which maps a failure index and its
-- argument array to a formatted message (via @msgprintf@), or to an
-- "Unknown error" string for indices that are out of range.
failureMsgFunction :: [FailureMsg] -> C.Definition
failureMsgFunction failures =
  let printfEscape =
        let escapeChar '%' = "%%"
            escapeChar c = [c]
         in concatMap escapeChar
      onPart (ErrorString s) = printfEscape $ T.unpack s
      -- FIXME: bogus for non-ints.
      onPart ErrorVal {} = "%lld"
      onFailure i (FailureMsg emsg@(ErrorMsg parts) backtrace) =
        let msg = concatMap onPart parts ++ "\n" ++ printfEscape backtrace
            msgargs = [[C.cexp|args[$int:j]|] | j <- [0 .. errorMsgNumArgs emsg - 1]]
         in [C.cstm|case $int:i: {return msgprintf($string:msg, $args:msgargs); break;}|]
      failure_cases =
        zipWith onFailure [(0 :: Int) ..] failures
   in [C.cedecl|static char* get_failure_msg(int failure_idx, typename int64_t args[]) {
                  (void)args;
                  switch (failure_idx) { $stms:failure_cases }
                  return strdup("Unknown error. This is a compiler bug.");
                }|]

-- | Compile a kernel-constant expression to a C expression.
compileConstExp :: KernelConstExp -> C.Exp
compileConstExp e = runIdentity $ GC.compilePrimExp (pure . kernelConstToExp) e

-- | Called after most code has been generated to generate the bulk of
-- the boilerplate.
generateGPUBoilerplate ::
  T.Text ->
  [(Name, KernelConstExp)] ->
  T.Text ->
  [Name] ->
  [PrimType] ->
  [FailureMsg] ->
  GC.CompilerM OpenCL () ()
generateGPUBoilerplate gpu_program macros backendH kernels types failures = do
  createKernels kernels
  let gpu_program_fragments =
        -- Some C compilers limit the size of literal strings, so
        -- chunk the entire program into small bits here, and
        -- concatenate it again at runtime.
        [[C.cinit|$string:s|] | s <- chunk 2000 $ T.unpack gpu_program]
      program_fragments = gpu_program_fragments ++ [[C.cinit|NULL|]]
      f64_required
        | FloatType Float64 `elem` types = [C.cexp|1|]
        | otherwise = [C.cexp|0|]
      max_failure_args = foldl max 0 $ map (errorMsgNumArgs . failureError) failures
      setMacro i (name, e) =
        [C.cstm|{names[$int:i] = $string:(nameToString name);
                 values[$int:i] = $esc:e';}|]
        where
          e' = T.unpack $ expText $ compileConstExp e
  mapM_
    GC.earlyDecl
    [C.cunit|static const int max_failure_args = $int:max_failure_args;
             static const int f64_required = $exp:f64_required;
             static const char *gpu_program[] = {$inits:program_fragments};
             $esc:(T.unpack gpuPrototypesH)
             $esc:(T.unpack backendH)
             $esc:(T.unpack gpuH)
             static int gpu_macros(struct futhark_context *ctx, char*** names_out, typename int64_t** values_out) {
               int num_macros = $int:(length macros);
               char** names = malloc(num_macros * sizeof(char*));
               typename int64_t* values = malloc(num_macros * sizeof(int64_t));
               $stms:(zipWith setMacro [(0::Int)..] macros)
               *names_out = names;
               *values_out = values;
               return num_macros;
             }
            |]
  GC.earlyDecl $ failureMsgFunction failures

  GC.generateProgramStruct

  GC.onClear
    [C.citem|if (ctx->error == NULL) { gpu_free_all(ctx); }|]

  GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_set_default_thread_block_size(struct futhark_context_config *cfg, int size);|]
  GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_set_default_grid_size(struct futhark_context_config *cfg, int size);|]
  GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_set_default_group_size(struct futhark_context_config *cfg, int size);|]
  GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_set_default_num_groups(struct futhark_context_config *cfg, int size);|]
  GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_set_default_tile_size(struct futhark_context_config *cfg, int size);|]
  GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_set_default_reg_tile_size(struct futhark_context_config *cfg, int size);|]
  GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_set_default_registers(struct futhark_context_config *cfg, int size);|]
  GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_set_default_cache(struct futhark_context_config *cfg, int size);|]
  GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_set_default_threshold(struct futhark_context_config *cfg, int size);|]
  GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_set_unified_memory(struct futhark_context_config* cfg, int flag);|]
futhark-0.25.27/src/Futhark/CodeGen/Backends/GenericC.hs000066400000000000000000000575041475065116200226010ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-}

-- | C code generation for whole programs, built on
-- "Futhark.CodeGen.Backends.GenericC.Monad".  Most of this module is
-- concerned with constructing the C API.
module Futhark.CodeGen.Backends.GenericC
  ( compileProg,
    compileProg',
    defaultOperations,
    ParamMap,
    CParts (..),
    asLibrary,
    asExecutable,
    asServer,
    module Futhark.CodeGen.Backends.GenericC.Monad,
    module Futhark.CodeGen.Backends.GenericC.Code,
  )
where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (second)
import Data.DList qualified as DL
import Data.List qualified as L
import Data.Loc
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericC.CLI (cliDefs)
import Futhark.CodeGen.Backends.GenericC.Code
import Futhark.CodeGen.Backends.GenericC.EntryPoints
import Futhark.CodeGen.Backends.GenericC.Fun
import Futhark.CodeGen.Backends.GenericC.Monad
import Futhark.CodeGen.Backends.GenericC.Options
import Futhark.CodeGen.Backends.GenericC.Pretty
import Futhark.CodeGen.Backends.GenericC.Server (serverDefs)
import Futhark.CodeGen.Backends.GenericC.Types
import Futhark.CodeGen.ImpCode
import Futhark.CodeGen.RTS.C (cacheH, contextH, contextPrototypesH, copyH, errorsH, eventListH, freeListH, halfH, lockH, timingH, utilH)
import Futhark.IR.GPU.Sizes
import Futhark.Manifest qualified as Manifest
import Futhark.MonadFreshNames
import Futhark.Util (zEncodeText)
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C
import NeatInterpolation (untrimming)

-- | Default function call: the context and the addresses of the
-- destinations are passed before the ordinary arguments, and a
-- non-zero return jumps to @cleanup@ with @err = 1@.
defCall :: CallCompiler op s
defCall dests fname args = do
  let out_args = [[C.cexp|&$id:d|] | d <- dests]
      args' = [C.cexp|ctx|] : out_args ++ args
  item [C.citem|if ($id:(funName fname)($args:args') != 0) { err = 1; goto cleanup; }|]

-- | Default error handling: record the formatted message and backtrace
-- with @set_error@, then jump to @cleanup@ with @FUTHARK_PROGRAM_ERROR@.
defError :: ErrorCompiler op s
defError msg stacktrace = do
  (formatstr, formatargs) <- errorMsgString msg
  let formatstr' = "Error: " <> formatstr <> "\n\nBacktrace:\n%s"
  items
    [C.citems|set_error(ctx, msgprintf($string:formatstr', $args:formatargs, $string:stacktrace));
              err = FUTHARK_PROGRAM_ERROR;
              goto cleanup;|]

-- | LMAD copy in default (host) memory, dispatched to a runtime
-- function specialised on the element size (1, 2, 4 or 8 bytes).
lmadcopyCPU :: DoCopy op s
lmadcopyCPU _ t shape dst (dstoffset, dststride) src (srcoffset, srcstride) = do
  let fname :: String
      (fname, ty) =
        case primByteSize t :: Int of
          1 -> ("lmad_copy_1b", [C.cty|typename uint8_t|])
          2 -> ("lmad_copy_2b", [C.cty|typename uint16_t|])
          4 -> ("lmad_copy_4b", [C.cty|typename uint32_t|])
          8 -> ("lmad_copy_8b", [C.cty|typename uint64_t|])
          k -> error $ "lmadcopyCPU: " <> show k
      r = length shape
      dststride_inits = [[C.cinit|$exp:e|] | Count e <- dststride]
      srcstride_inits = [[C.cinit|$exp:e|] | Count e <- srcstride]
      shape_inits = [[C.cinit|$exp:e|] | Count e <- shape]
  stm
    [C.cstm|
         $id:fname(ctx, $int:r,
                   ($ty:ty*) $exp:dst, $exp:(unCount dstoffset),
                   (typename int64_t[]){ $inits:dststride_inits },
                   ($ty:ty*) $exp:src, $exp:(unCount srcoffset),
                   (typename int64_t[]){ $inits:srcstride_inits },
                   (typename int64_t[]){ $inits:shape_inits });|]

-- | A set of operations that fail for every operation involving
-- non-default memory spaces.  Uses plain pointers and @malloc@ for
-- memory management.
defaultOperations :: Operations op s
defaultOperations =
  Operations
    { opsWriteScalar = defWriteScalar,
      opsReadScalar = defReadScalar,
      opsAllocate = defAllocate,
      opsDeallocate = defDeallocate,
      opsCopy = defCopy,
      opsCopies = M.singleton (DefaultSpace, DefaultSpace) lmadcopyCPU,
      opsMemoryType = defMemoryType,
      opsCompiler = defCompiler,
      opsFatMemory = True,
      opsError = defError,
      opsCall = defCall,
      opsCritical = mempty
    }
  where
    defWriteScalar _ _ _ _ _ =
      error "Cannot write to non-default memory space because I am dumb"
    defReadScalar _ _ _ _ =
      error "Cannot read from non-default memory space"
    defAllocate _ _ _ =
      error "Cannot allocate in non-default memory space"
    defDeallocate _ _ =
      error "Cannot deallocate in non-default memory space"
    defCopy _ destmem destoffset DefaultSpace srcmem srcoffset DefaultSpace size =
      copyMemoryDefaultSpace destmem destoffset srcmem srcoffset size
    defCopy _ _ _ _ _ _ _ _ =
      error "Cannot copy to or from non-default memory space"
    defMemoryType _ =
      error "Has no type for non-default memory space"
    defCompiler _ =
      error "The default compiler cannot compile extended operations"

-- | Render, as C text, every accumulated header declaration whose
-- section satisfies the predicate.
declsCode :: (HeaderSection -> Bool) -> CompilerState s -> T.Text
declsCode p =
  definitionsText
    . concatMap (DL.toList . snd)
    . filter (p . fst)
    . M.toList
    . compHeaderDecls

initDecls, arrayDecls, opaqueDecls, opaqueTypeDecls, entryDecls, miscDecls :: CompilerState s -> T.Text
initDecls = declsCode (== InitDecl)
arrayDecls = declsCode isArrayDecl
  where
    isArrayDecl ArrayDecl {} = True
    isArrayDecl _ = False
opaqueTypeDecls = declsCode isOpaqueTypeDecl
  where
    isOpaqueTypeDecl OpaqueTypeDecl {} = True
    isOpaqueTypeDecl _ = False
opaqueDecls = declsCode isOpaqueDecl
  where
    isOpaqueDecl OpaqueDecl {} = True
    isOpaqueDecl _ = False
entryDecls = declsCode (== EntryDecl)
miscDecls = declsCode (== MiscDecl)

-- | Define the reference-counted memory block struct for a memory
-- space, together with its unref, alloc and set functions.  The struct
-- is emitted as an early declaration; the three functions are returned,
-- along with a block item that appends the space's peak usage to a
-- string builder named @builder@.
defineMemorySpace :: Space -> CompilerM op s ([C.Definition], C.BlockItem)
defineMemorySpace space = do
  rm <- rawMemCType space
  earlyDecl
    [C.cedecl|struct $id:sname { int *references;
                                 $ty:rm mem;
                                 typename int64_t size;
                                 const char *desc; };|]

  -- Unreferencing a memory block consists of decreasing its reference
  -- count and freeing the corresponding memory if the count reaches
  -- zero.
  free <- collect $ freeRawMem [C.cexp|block->mem|] [C.cexp|block->size|] space [C.cexp|desc|]
  ctx_ty <- contextType
  let unrefdef =
        [C.cedecl|int $id:(fatMemUnRef space) ($ty:ctx_ty *ctx, $ty:mty *block, const char *desc) {
  if (block->references != NULL) {
    *(block->references) -= 1;
    if (ctx->detail_memory) {
      fprintf(ctx->log, "Unreferencing block %s (allocated as %s) in %s: %d references remaining.\n",
                      desc, block->desc, $string:spacedesc, *(block->references));
    }
    if (*(block->references) == 0) {
      ctx->$id:usagename -= block->size;
      $items:free
      free(block->references);
      if (ctx->detail_memory) {
        fprintf(ctx->log, "%lld bytes freed (now allocated: %lld bytes)\n",
                (long long) block->size, (long long) ctx->$id:usagename);
      }
    }
    block->references = NULL;
  }
  return 0;
}|]

  -- When allocating a memory block we initialise the reference count to 1.
  alloc <-
    collect $
      allocRawMem [C.cexp|block->mem|] [C.cexp|size|] space [C.cexp|desc|]
  -- The panic message has exactly three conversions (size, desc,
  -- space), so exactly three values are passed for them.
  let allocdef =
        [C.cedecl|int $id:(fatMemAlloc space) ($ty:ctx_ty *ctx, $ty:mty *block,
                                        typename int64_t size, const char *desc) {
  if (size < 0) {
    futhark_panic(1, "Negative allocation of %lld bytes attempted for %s in %s.\n",
          (long long)size, desc, $string:spacedesc);
  }
  int ret = $id:(fatMemUnRef space)(ctx, block, desc);

  if (ret != FUTHARK_SUCCESS) {
    return ret;
  }

  if (ctx->detail_memory) {
    fprintf(ctx->log, "Allocating %lld bytes for %s in %s (currently allocated: %lld bytes).\n",
            (long long) size,
            desc, $string:spacedesc,
            (long long) ctx->$id:usagename);
  }

  $items:alloc

  if (ctx->error == NULL) {
    block->references = (int*) malloc(sizeof(int));
    *(block->references) = 1;
    block->size = size;
    block->desc = desc;
    long long new_usage = ctx->$id:usagename + size;
    if (ctx->detail_memory) {
      fprintf(ctx->log, "Received block of %lld bytes; now allocated: %lld bytes",
              (long long)block->size, new_usage);
    }
    ctx->$id:usagename = new_usage;
    if (new_usage > ctx->$id:peakname) {
      ctx->$id:peakname = new_usage;
      if (ctx->detail_memory) {
        fprintf(ctx->log, " (new peak).\n");
      }
    } else if (ctx->detail_memory) {
        fprintf(ctx->log, ".\n");
    }
    return FUTHARK_SUCCESS;
  } else {
    // We are naively assuming that any memory allocation error is due to OOM.
    // We preserve the original error so that a savvy user can perhaps find
    // glory despite our naiveté.

    // We cannot use set_error() here because we want to replace the old error.
    lock_lock(&ctx->error_lock);
    char *old_error = ctx->error;
    ctx->error = msgprintf("Failed to allocate memory in %s.\nAttempted allocation: %12lld bytes\nCurrently allocated: %12lld bytes\n%s",
                           $string:spacedesc,
                           (long long) size,
                           (long long) ctx->$id:usagename,
                           old_error);
    free(old_error);
    lock_unlock(&ctx->error_lock);
    return FUTHARK_OUT_OF_MEMORY;
  }
  }|]

  -- Memory setting - unreference the destination and increase the
  -- count of the source by one.
  let setdef =
        [C.cedecl|int $id:(fatMemSet space) ($ty:ctx_ty *ctx, $ty:mty *lhs, $ty:mty *rhs, const char *lhs_desc) {
  int ret = $id:(fatMemUnRef space)(ctx, lhs, lhs_desc);
  if (rhs->references != NULL) {
    (*(rhs->references))++;
  }
  *lhs = *rhs;
  return ret;
}
|]

  onClear [C.citem|ctx->$id:peakname = 0;|]

  let peakmsg = "\"" <> spacedesc <> "\": %lld"
  pure
    ( [unrefdef, allocdef, setdef],
      [C.citem|str_builder(&builder, $string:peakmsg, (long long) ctx->$id:peakname);|]
    )
  where
    mty = fatMemType space
    -- Context field names, struct name and human-readable description
    -- for this space.
    (peakname, usagename, sname, spacedesc) = case space of
      Space sid ->
        ( C.toIdent ("peak_mem_usage_" ++ sid) noLoc,
          C.toIdent ("cur_mem_usage_" ++ sid) noLoc,
          C.toIdent ("memblock_" ++ sid) noLoc,
          "space '" ++ sid ++ "'"
        )
      _ ->
        ( "peak_mem_usage_default",
          "cur_mem_usage_default",
          "memblock",
          "default space"
        )

-- | The result of compilation to C is multiple parts, which can be
-- put together in various ways.  The obvious way is to concatenate
-- all of them, which yields a CLI program.  Another is to compile the
-- library part by itself, and use the header file to call into it.
data CParts = CParts
  { cHeader :: T.Text,
    -- | Utility definitions that must be visible
    -- to both CLI and library parts.
    cUtils :: T.Text,
    cCLI :: T.Text,
    cServer :: T.Text,
    cLib :: T.Text,
    -- | The manifest, in JSON format.
    cJsonManifest :: T.Text
  }

-- | Preamble that defines @_GNU_SOURCE@ (if not already defined)
-- before any header is included.
gnuSource :: T.Text
gnuSource =
  [untrimming|
// We need to define _GNU_SOURCE before
// _any_ headers files are imported to get
// the usage statistics of a thread (i.e. have RUSAGE_THREAD) on GNU/Linux
// https://manpages.courier-mta.org/htmlman2/getrusage.2.html
#ifndef _GNU_SOURCE // Avoid possible double-definition warning.
#define _GNU_SOURCE
#endif
|]

-- We may generate variables that are never used (e.g. for
-- certificates) or functions that are never called (e.g. unused
-- intrinsics), and generated code may have other cosmetic issues that
-- compilers warn about.  We disable these warnings to not clutter the
-- compilation logs.
-- | C pragmas that silence the clang/GCC warnings which generated code
-- routinely triggers: unused functions, variables, constants and
-- labels, set-but-unused variables, and -Wparentheses.
disableWarnings :: T.Text
disableWarnings = [untrimming| #ifdef __clang__ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-variable" #pragma clang diagnostic ignored "-Wunused-const-variable" #pragma clang diagnostic ignored "-Wparentheses" #pragma clang diagnostic ignored "-Wunused-label" #pragma clang diagnostic ignored "-Wunused-but-set-variable" #elif __GNUC__ #pragma GCC diagnostic ignored "-Wunused-function" #pragma GCC diagnostic ignored "-Wunused-variable" #pragma GCC diagnostic ignored "-Wunused-const-variable" #pragma GCC diagnostic ignored "-Wparentheses" #pragma GCC diagnostic ignored "-Wunused-label" #pragma GCC diagnostic ignored "-Wunused-but-set-variable" #endif |]

-- | Produce header, implementation, and manifest files.
--
-- The header gets a @#pragma once@ guard.  The implementation is
-- 'gnuSource', 'disableWarnings', the header, the shared utilities and
-- the library code, in that order.
asLibrary :: CParts -> (T.Text, T.Text, T.Text)
asLibrary parts =
  ( "#pragma once\n\n" <> cHeader parts,
    gnuSource
      <> disableWarnings
      <> cHeader parts
      <> cUtils parts
      <> cLib parts,
    cJsonManifest parts
  )

-- | As executable with command-line interface.
asExecutable :: CParts -> T.Text
asExecutable parts =
  gnuSource
    <> disableWarnings
    <> cHeader parts
    <> cUtils parts
    <> cCLI parts
    <> cLib parts

-- | As server executable.
asServer :: CParts -> T.Text
asServer parts =
  gnuSource
    <> disableWarnings
    <> cHeader parts
    <> cUtils parts
    <> cServer parts
    <> cLib parts

-- | The names of those tuning parameters whose set of using functions
-- (the second component of each map value) contains the given name.
relevantParams :: Name -> ParamMap -> [Name]
relevantParams fname m =
  map fst $ filter ((fname `S.member`) . snd . snd) $ M.toList m

-- | Compile a program to 'CParts', also returning the final compiler
-- state.
--
-- * @def@ is the initial user state handed to 'runCompilerM'.
-- * @extra@ runs right after the tuning-parameter declarations have
--   been generated.
-- * @header_extra@ is spliced into both the public header and the
--   library prelude.
-- * @arr_space@ is passed to 'generateAPITypes'.
-- * Allocation/reference-counting functions are generated for each
--   space in @spaces@.
compileProg' ::
  (MonadFreshNames m) =>
  T.Text ->
  T.Text ->
  ParamMap ->
  Operations op s ->
  s ->
  CompilerM op s () ->
  T.Text ->
  (Space, [Space]) ->
  [Option] ->
  Definitions op ->
  m (CParts, CompilerState s)
compileProg' backend version params ops def extra header_extra (arr_space, spaces) options prog = do
  src <- getNameSource
  let ((prototypes, definitions, entry_point_decls, manifest), endstate) =
        runCompilerM ops src def compileProgAction
      initdecls = initDecls endstate
      entrydecls = entryDecls endstate
      arraydecls = arrayDecls endstate
      opaquetypedecls = opaqueTypeDecls endstate
      opaquedecls = opaqueDecls endstate
      miscdecls = miscDecls endstate

  -- NOTE(review): in this copy of the file the C text inside the
  -- 'untrimming' templates below is flattened onto single lines, and the
  -- @#include@ directives carry no header names.  Confirm against the
  -- upstream source before relying on the generated text.
  let headerdefs =
        [untrimming| // Headers #include #include #include #include #include $header_extra #ifdef __cplusplus extern "C" { #endif // Initialisation $initdecls // Arrays $arraydecls // Opaque values $opaquetypedecls $opaquedecls // Entry points $entrydecls // Miscellaneous $miscdecls #define FUTHARK_BACKEND_$backend $errorsH #ifdef __cplusplus } #endif |]

  let utildefs =
        [untrimming| #include #include #include #include #include // If NDEBUG is set, the assert() macro will do nothing. Since Futhark // (unfortunately) makes use of assert() for error detection (and even some // side effects), we want to avoid that.
#undef NDEBUG #include #include #define SCALAR_FUN_ATTR static inline $utilH $cacheH $halfH $timingH $lockH $freeListH $eventListH |]

  let early_decls = definitionsText $ DL.toList $ compEarlyDecls endstate
      lib_decls = definitionsText $ DL.toList $ compLibDecls endstate
      clidefs = cliDefs options manifest
      serverdefs = serverDefs options manifest
      libdefs =
        [untrimming| #ifdef _MSC_VER #define inline __inline #endif #include #include #include #include #include $header_extra #define FUTHARK_F64_ENABLED $cScalarDefs $contextPrototypesH $early_decls $contextH $copyH #define FUTHARK_FUN_ATTR static $prototypes $lib_decls $definitions $entry_point_decls |]

  pure
    ( CParts
        { cHeader = headerdefs,
          cUtils = utildefs,
          cCLI = clidefs,
          cServer = serverdefs,
          cLib = libdefs,
          cJsonManifest = Manifest.manifestToJSON manifest
        },
      endstate
    )
  where
    Definitions types consts (Functions funs) = prog

    -- Produces, in order: function prototypes, function definitions,
    -- entry-point definitions, and the manifest.
    compileProgAction = do
      (memfuns, memreport) <- mapAndUnzipM defineMemorySpace spaces

      get_consts <- compileConstants consts

      ctx_ty <- contextType

      (prototypes, functions) <-
        mapAndUnzipM (compileFun get_consts [[C.cparam|$ty:ctx_ty *ctx|]]) funs

      -- Only functions for which onEntryPoint yields a result become
      -- entry points.
      (entry_points, entry_points_manifest) <-
        fmap (unzip . catMaybes) $ forM funs $ \(fname, fun) ->
          onEntryPoint get_consts (relevantParams fname params) fname fun

      headerDecl InitDecl [C.cedecl|struct futhark_context_config;|]
      headerDecl InitDecl [C.cedecl|struct futhark_context_config* futhark_context_config_new(void);|]
      headerDecl InitDecl [C.cedecl|void futhark_context_config_free(struct futhark_context_config* cfg);|]
      headerDecl InitDecl [C.cedecl|int futhark_context_config_set_tuning_param(struct futhark_context_config *cfg, const char *param_name, size_t new_value);|]

      headerDecl InitDecl [C.cedecl|struct futhark_context;|]
      headerDecl InitDecl [C.cedecl|struct futhark_context* futhark_context_new(struct futhark_context_config* cfg);|]
      -- The argument is a context, not a configuration; name it like the
      -- other context functions (prototype parameter names do not affect
      -- the ABI).
      headerDecl InitDecl [C.cedecl|void futhark_context_free(struct futhark_context* ctx);|]
      headerDecl MiscDecl [C.cedecl|int futhark_context_sync(struct futhark_context* ctx);|]

      generateTuningParams params
      extra

      -- Point each tuning_params field at its slot in the configuration's
      -- array, in the key order of the parameter map.
      let set_tuning_params =
            zipWith
              (\i k -> [C.cstm|ctx->tuning_params.$id:k = &ctx->cfg->tuning_params[$int:i];|])
              [(0 :: Int) ..]
              $ M.keys params
      earlyDecl [C.cedecl|static void set_tuning_params(struct futhark_context* ctx) { (void)ctx; $stms:set_tuning_params }|]

      mapM_ earlyDecl $ concat memfuns
      type_funs <- generateAPITypes arr_space types

      headerDecl InitDecl [C.cedecl|void futhark_context_config_set_debugging(struct futhark_context_config* cfg, int flag);|]
      headerDecl InitDecl [C.cedecl|void futhark_context_config_set_profiling(struct futhark_context_config* cfg, int flag);|]
      headerDecl InitDecl [C.cedecl|void futhark_context_config_set_logging(struct futhark_context_config* cfg, int flag);|]
      headerDecl MiscDecl [C.cedecl|void futhark_context_config_set_cache_file(struct futhark_context_config* cfg, const char *f);|]
      headerDecl InitDecl [C.cedecl|int futhark_get_tuning_param_count(void);|]
      headerDecl InitDecl [C.cedecl|const char* futhark_get_tuning_param_name(int);|]
      headerDecl InitDecl [C.cedecl|const char* futhark_get_tuning_param_class(int);|]

      headerDecl MiscDecl [C.cedecl|char* futhark_context_get_error(struct futhark_context* ctx);|]
      headerDecl MiscDecl [C.cedecl|void futhark_context_set_logging_file(struct futhark_context* ctx, typename FILE* f);|]
      headerDecl MiscDecl [C.cedecl|void futhark_context_pause_profiling(struct futhark_context* ctx);|]
      headerDecl MiscDecl [C.cedecl|void futhark_context_unpause_profiling(struct futhark_context* ctx);|]

      generateCommonLibFuns memreport

      pure
        ( definitionsText prototypes,
          funcsText functions,
          definitionsText entry_points,
          Manifest.Manifest (M.fromList entry_points_manifest) type_funs backend version
        )

-- | Compile imperative program to a C program.  Always uses the
-- function named "main" as entry point, so make sure it is defined.
compileProg ::
  (MonadFreshNames m) =>
  T.Text ->
  T.Text ->
  ParamMap ->
  Operations op () ->
  CompilerM op () () ->
  T.Text ->
  (Space, [Space]) ->
  [Option] ->
  Definitions op ->
  m CParts
compileProg backend version params ops extra header_extra (arr_space, spaces) options prog =
  -- Same as compileProg', with a unit user state and the final compiler
  -- state discarded.
  fst <$> compileProg' backend version params ops () extra header_extra (arr_space, spaces) options prog

-- | Emit the C declarations that describe the tuning parameters:
--
-- * @struct tuning_params@, with one @int64_t*@ field per parameter;
-- * @num_tuning_params@;
-- * the tables @tuning_param_names@, @tuning_param_vars@ (z-encoded
--   names) and @tuning_param_classes@, each terminated by @NULL@;
-- * @tuning_param_defaults@, terminated by @0@.  A class without a
--   default contributes 0.
--
-- All tables follow the key order of the parameter map.
generateTuningParams :: ParamMap -> CompilerM op a ()
generateTuningParams params = do
  let (param_names, (param_classes, _param_users)) =
        second unzip $ unzip $ M.toList params
      strinit s = [C.cinit|$string:(T.unpack s)|]
      intinit x = [C.cinit|$int:x|]
      size_name_inits = map (strinit . prettyText) param_names
      size_var_inits = map (strinit . zEncodeText . prettyText) param_names
      size_class_inits = map (strinit . prettyText) param_classes
      size_default_inits = map (intinit . fromMaybe 0 . sizeDefault) param_classes
      size_decls = map (\k -> [C.csdecl|typename int64_t *$id:k;|]) param_names
      num_params = length params
  -- NOTE(review): 'dummy' presumably keeps the struct non-empty when
  -- there are no tuning parameters; confirm.
  earlyDecl [C.cedecl|struct tuning_params { int dummy; $sdecls:size_decls };|]
  earlyDecl [C.cedecl|static const int num_tuning_params = $int:num_params;|]
  earlyDecl [C.cedecl|static const char *tuning_param_names[] = { $inits:size_name_inits, NULL };|]
  earlyDecl [C.cedecl|static const char *tuning_param_vars[] = { $inits:size_var_inits, NULL };|]
  earlyDecl [C.cedecl|static const char *tuning_param_classes[] = { $inits:size_class_inits, NULL };|]
  earlyDecl [C.cedecl|static typename int64_t tuning_param_defaults[] = { $inits:size_default_inits, 0 };|]

-- | Generate the public @context_report@ and @context_clear_caches@
-- functions.
--
-- The argument holds the per-memory-space report fragments.  They are
-- spliced, comma-separated, into the @"memory"@ object of the report,
-- which has the form @{"memory":{...},"events":[...]}@.
generateCommonLibFuns :: [C.BlockItem] -> CompilerM op s ()
generateCommonLibFuns memreport = do
  ctx <- contextType
  ops <- asks envOperations
  sync <- publicName "context_sync"
  let comma = [C.citem|str_builder_char(&builder, ',');|]
  -- context_report syncs first.  It returns NULL if the sync fails or if
  -- the event list cannot be reported (freeing the partial string);
  -- otherwise it returns the builder's string.  NOTE(review): the caller
  -- presumably owns and must free that string; confirm.
  publicDef_ "context_report" MiscDecl $ \s ->
    ( [C.cedecl|char* $id:s($ty:ctx *ctx);|],
      [C.cedecl|char* $id:s($ty:ctx *ctx) { if ($id:sync(ctx) != 0) { return NULL; } struct str_builder builder;
str_builder_init(&builder); str_builder_char(&builder, '{'); str_builder_str(&builder, "\"memory\":{"); $items:(L.intersperse comma memreport) str_builder_str(&builder, "},\"events\":["); if (report_events_in_list(&ctx->event_list, &builder) != 0) { free(builder.str); return NULL; } else { str_builder_str(&builder, "]}"); return builder.str; } }|]
    )

  -- context_clear_caches runs every registered clear item inside the
  -- backend's critical section and returns nonzero iff ctx->error is set.
  clears <- gets $ DL.toList . compClearItems
  publicDef_ "context_clear_caches" MiscDecl $ \s ->
    ( [C.cedecl|int $id:s($ty:ctx* ctx);|],
      [C.cedecl|int $id:s($ty:ctx* ctx) { $items:(criticalSection ops clears) return ctx->error != NULL; }|]
    )

-- | Declare @struct constants@ and generate the static library functions
-- @init_constants@ and @free_constants@.
--
-- Returns block items that bind each constant to a local variable of
-- the same name, initialised from @ctx->constants@.
compileConstants :: Constants op -> CompilerM op s [C.BlockItem]
compileConstants (Constants ps init_consts) = do
  ctx_ty <- contextType
  const_fields <- mapM constParamField ps
  earlyDecl [C.cedecl|struct constants { int dummy; $sdecls:const_fields };|]

  inNewFunction $ do
    -- We locally define macros for the constants, so that when we
    -- generate assignments to local variables, we actually assign into
    -- the constants struct.  This is not needed for functions, because
    -- they can only read constants, not write them.
    let (defs, undefs) = unzip $ map constMacro ps
    init_consts' <- collect $ do
      mapM_ resetMemConst ps
      compileCode init_consts
    decl_mem <- declAllocatedMem
    free_mem <- freeAllocatedMem
    -- Body order: #defines, memory declarations, initialisation code,
    -- release of locally allocated memory, #undefs.  The 'cleanup' label
    -- is where err is returned.
    libDecl [C.cedecl|static int init_constants($ty:ctx_ty *ctx) { (void)ctx; int err = 0; $items:defs $items:decl_mem $items:init_consts' $items:free_mem $items:undefs cleanup: return err; }|]

  inNewFunction $ do
    -- Only memory constants need releasing; scalars are no-ops.
    free_consts <- collect $ mapM_ freeConst ps
    libDecl [C.cedecl|static int free_constants($ty:ctx_ty *ctx) { (void)ctx; $items:free_consts return 0; }|]

  mapM getConst ps
  where
    -- Field of struct constants holding the given constant.
    constParamField (ScalarParam name bt) = do
      let ctp = primTypeToCType bt
      pure [C.csdecl|$ty:ctp $id:name;|]
    constParamField (MemParam name space) = do
      ty <- memToCType name space
      pure [C.csdecl|$ty:ty $id:name;|]

    -- A #define/#undef pair redirecting the constant's C name to
    -- (ctx->constants-><name>).
    constMacro p = ([C.citem|$escstm:def|], [C.citem|$escstm:undef|])
      where
        p' = T.unpack $ idText (C.toIdent (paramName p) mempty)
        def = "#define " ++ p' ++ " (" ++ "ctx->constants->" ++ p' ++ ")"
        undef = "#undef " ++ p'

    resetMemConst ScalarParam {} = pure ()
    resetMemConst (MemParam name space) = resetMem name space

    freeConst ScalarParam {} = pure ()
    freeConst (MemParam name space) = unRefMem [C.cexp|ctx->constants->$id:name|] space

    -- Local binding of a constant, read from ctx->constants.
    getConst (ScalarParam name bt) = do
      let ctp = primTypeToCType bt
      pure [C.citem|$ty:ctp $id:name = ctx->constants->$id:name;|]
    getConst (MemParam name space) = do
      ty <- memToCType name space
      pure [C.citem|$ty:ty $id:name = ctx->constants->$id:name;|]
futhark-0.25.27/src/Futhark/CodeGen/Backends/GenericC/000077500000000000000000000000001475065116200222325ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/CodeGen/Backends/GenericC/CLI.hs000066400000000000000000000421151475065116200232000ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-}
-- | Code generation for standalone executables.
module Futhark.CodeGen.Backends.GenericC.CLI
  ( cliDefs,
  )
where

import Data.List (unzip5)
import Data.Map qualified as M
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericC.Options
import Futhark.CodeGen.Backends.GenericC.Pretty
import Futhark.CodeGen.Backends.SimpleRep
  ( cproduct,
    escapeName,
    primAPIType,
    primStorageType,
    scalarToPrim,
  )
import Futhark.CodeGen.RTS.C (tuningH, valuesH)
import Futhark.Manifest
import Futhark.Util.Pretty (prettyString)
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C

-- | Command-line options shared by every standalone executable.
--
-- The option actions refer to @cfg@ (the context configuration passed to
-- @parse_options@ by the generated @main@) and to @optarg@.
-- NOTE(review): @optarg@ is presumably getopt's; confirm against
-- 'generateOptionParser'.
genericOptions :: [Option]
genericOptions =
  [ Option
      { optionLongName = "write-runtime-to",
        optionShortName = Just 't',
        optionArgument = RequiredArgument "FILE",
        optionDescription = "Print the time taken to execute the program to the indicated file, an integral number of microseconds.",
        optionAction = set_runtime_file
      },
    Option
      { optionLongName = "runs",
        optionShortName = Just 'r',
        optionArgument = RequiredArgument "INT",
        optionDescription = "Perform NUM runs of the program.",
        optionAction = set_num_runs
      },
    Option
      { optionLongName = "debugging",
        optionShortName = Just 'D',
        optionArgument = NoArgument,
        optionDescription = "Perform possibly expensive internal correctness checks and verbose logging.",
        optionAction = [C.cstm|{futhark_context_config_set_debugging(cfg, 1);}|]
      },
    Option
      { optionLongName = "log",
        optionShortName = Just 'L',
        optionArgument = NoArgument,
        optionDescription = "Print various low-overhead logging information to stderr while running.",
        optionAction = [C.cstm|{futhark_context_config_set_logging(cfg, 1);}|]
      },
    Option
      { optionLongName = "profile",
        optionShortName = Just 'P',
        optionArgument = NoArgument,
        optionDescription = "Enable the collection of profiling information.",
        optionAction = [C.cstm|futhark_context_config_set_profiling(cfg, 1);|]
      },
    Option
      { optionLongName = "entry-point",
        optionShortName = Just 'e',
        optionArgument = RequiredArgument "NAME",
        optionDescription = "The entry point to run. Defaults to main.",
        -- A NULL entry_point means "initialise and exit" (see cliDefs),
        -- so it is never overridden here.  NOTE(review): the option that
        -- sets it to NULL is not visible in this module.
        optionAction = [C.cstm|if (entry_point != NULL) entry_point = optarg;|]
      },
    Option
      { optionLongName = "binary-output",
        optionShortName = Just 'b',
        optionArgument = NoArgument,
        optionDescription = "Print the program result in the binary output format.",
        optionAction = [C.cstm|binary_output = 1;|]
      },
    Option
      { optionLongName = "no-print-result",
        optionShortName = Just 'n',
        optionArgument = NoArgument,
        optionDescription = "Do not print the program result.",
        optionAction = [C.cstm|print_result = 0;|]
      },
    Option
      { optionLongName = "help",
        optionShortName = Just 'h',
        optionArgument = NoArgument,
        optionDescription = "Print help information and exit.",
        optionAction =
          [C.cstm|{
                   printf("Usage: %s [OPTION]...\nOptions:\n\n%s\nFor more information, consult the Futhark User's Guide or the man pages.\n",
                          fut_progname, option_descriptions);
                   exit(0);
                  }|]
      },
    Option
      { optionLongName = "print-params",
        optionShortName = Nothing,
        optionArgument = NoArgument,
        optionDescription = "Print all tuning parameters that can be set with --param or --tuning.",
        optionAction =
          [C.cstm|{
                int n = futhark_get_tuning_param_count();
                for (int i = 0; i < n; i++) {
                  printf("%s (%s)\n", futhark_get_tuning_param_name(i),
                                      futhark_get_tuning_param_class(i));
                }
                exit(0);
              }|]
      },
    Option
      { optionLongName = "param",
        optionShortName = Nothing,
        optionArgument = RequiredArgument "ASSIGNMENT",
        optionDescription = "Set a tuning parameter to the given value.",
        -- Expects NAME=VALUE; the '=' is overwritten with a NUL so that
        -- optarg can be used as the name.
        optionAction =
          [C.cstm|{
                char *name = optarg;
                char *equals = strstr(optarg, "=");
                char *value_str = equals != NULL ? equals+1 : optarg;
                int value = atoi(value_str);
                if (equals != NULL) {
                  *equals = 0;
                  if (futhark_context_config_set_tuning_param(cfg, name, (size_t)value) != 0) {
                    futhark_panic(1, "Unknown size: %s\n", name);
                  }
                } else {
                  futhark_panic(1, "Invalid argument for size option: %s\n", optarg);
                }}|]
      },
    Option
      { optionLongName = "tuning",
        optionShortName = Nothing,
        optionArgument = RequiredArgument "FILE",
        optionDescription = "Read size=value assignments from the given file.",
        optionAction =
          [C.cstm|{
                char *ret = load_tuning_file(optarg, cfg, (int(*)(void*, const char*, size_t))
                                                          futhark_context_config_set_tuning_param);
                if (ret != NULL) {
                  futhark_panic(1, "When loading tuning file '%s': %s\n", optarg, ret);
                }}|]
      },
    Option
      { optionLongName = "cache-file",
        optionShortName = Nothing,
        optionArgument = RequiredArgument "FILE",
        optionDescription = "Store program cache here.",
        optionAction = [C.cstm|futhark_context_config_set_cache_file(cfg, optarg);|]
      }
  ]
  where
    set_runtime_file =
      [C.cstm|{
          runtime_file = fopen(optarg, "w");
          if (runtime_file == NULL) {
            futhark_panic(1, "Cannot open %s: %s\n", optarg, strerror(errno));
          }
        }|]
    -- Requesting a run count also enables the untimed warmup run.
    set_num_runs =
      [C.cstm|{
          num_runs = atoi(optarg);
          perform_warmup = 1;
          if (num_runs <= 0) {
            futhark_panic(1, "Need a positive number of runs, not %s\n", optarg);
          }
        }|]

-- | Generate code for reading input number @i@, of external type
-- @tname@, from standard input.
--
-- Returns, in order:
--
-- 1. items that declare and parse the value;
-- 2. a statement converting the parsed data to its API value (run
--    before every call);
-- 3. a statement freeing that API value (run after every call);
-- 4. a statement freeing the raw parsed data (run once, after all
--    runs);
-- 5. the argument expression.
--
-- Values of opaque type cannot be read; reading one panics.
readInput :: Manifest -> Int -> T.Text -> ([C.BlockItem], C.Stm, C.Stm, C.Stm, C.Exp)
readInput manifest i tname =
  case M.lookup tname $ manifestTypes manifest of
    Nothing ->
      let (_, t) = scalarToPrim tname
          dest = "read_value_" ++ show i
          info = T.unpack tname <> "_info"
       in ( -- errno is cleared first, as in the array case, so that the
            -- error message cannot report a stale errno left by an
            -- earlier call (such as isatty() in main).
            [C.citems|
             $ty:(primStorageType t) $id:dest;
             errno = 0;
             if (read_scalar(stdin, &$id:info, &$id:dest) != 0) {
               futhark_panic(1, "Error when reading input #%d of type %s (errno: %s).\n",
                             $int:i,
                             $string:(T.unpack tname),
                             strerror(errno));
             };|],
            [C.cstm|;|],
            [C.cstm|;|],
            [C.cstm|;|],
            [C.cexp|$id:dest|]
          )
    Just (TypeOpaque desc _ _) ->
      ( [C.citems|futhark_panic(1, "Cannot read input #%d of type %s\n", $int:i, $string:(T.unpack desc));|],
        [C.cstm|;|],
        [C.cstm|;|],
        [C.cstm|;|],
        [C.cexp|NULL|]
      )
    Just (TypeArray t et rank ops) ->
      let dest = "read_value_" ++ show i
          shape = "read_shape_" ++ show i
          arr = "read_arr_" ++ show i

          ty = [C.cty|typename $id:t|]
          dims_exps = [[C.cexp|$id:shape[$int:j]|] | j <- [0 .. rank - 1]]
          t' = uncurry primAPIType $ scalarToPrim et

          new_array = arrayNew ops
          free_array = arrayFree ops
          info = T.unpack et <> "_info"

          items =
            [C.citems|
               $ty:ty $id:dest;
               typename int64_t $id:shape[$int:rank];
               $ty:t' *$id:arr = NULL;
               errno = 0;
               if (read_array(stdin,
                              &$id:info,
                              (void**) &$id:arr,
                              $id:shape,
                              $int:rank)
                   != 0) {
                 futhark_panic(1, "Cannot read input #%d of type %s (errno: %s).\n",
                               $int:i,
                               $string:(T.unpack tname),
                               strerror(errno));
               }|]
       in ( items,
            [C.cstm|assert(($id:dest = $id:new_array(ctx, $id:arr, $args:dims_exps)) != NULL);|],
            [C.cstm|assert($id:free_array(ctx, $id:dest) == 0);|],
            [C.cstm|free($id:arr);|],
            [C.cexp|$id:dest|]
          )

-- | 'readInput' for every input, numbered from 0.
readInputs :: Manifest -> [T.Text] -> [([C.BlockItem], C.Stm, C.Stm, C.Stm, C.Exp)]
readInputs manifest = zipWith (readInput manifest) [0 ..]

-- | For each output type, return:
--
-- * the declaration of its @result_i@ variable;
-- * an expression naming that variable (its address is passed to the
--   entry point);
-- * a statement freeing it, which is a no-op for scalars.
prepareOutputs :: Manifest -> [T.Text] -> [(C.BlockItem, C.Exp, C.Stm)]
prepareOutputs manifest = zipWith prepareResult [(0 :: Int) ..]
  where
    prepareResult i tname = do
      let result = "result_" ++ show i

      case M.lookup tname $ manifestTypes manifest of
        Nothing ->
          let (s, pt) = scalarToPrim tname
              ty = primAPIType s pt
           in ( [C.citem|$ty:ty $id:result;|],
                [C.cexp|$id:result|],
                [C.cstm|;|]
              )
        Just (TypeArray t _ _ ops) ->
          ( [C.citem|typename $id:t $id:result;|],
            [C.cexp|$id:result|],
            [C.cstm|assert($id:(arrayFree ops)(ctx, $id:result) == 0);|]
          )
        Just (TypeOpaque t ops _) ->
          ( [C.citem|typename $id:t $id:result;|],
            [C.cexp|$id:result|],
            [C.cstm|assert($id:(opaqueFree ops)(ctx, $id:result) == 0);|]
          )

-- | Return a statement printing the given external value.
--
-- Opaque values have no external representation.  Printing one sets
-- @retval = 1@ and jumps to @print_end@ in the enclosing CLI function.
printStm :: Manifest -> T.Text -> C.Exp -> C.Stm
printStm manifest tname e =
  case M.lookup tname $ manifestTypes manifest of
    Nothing ->
      let info = tname <> "_info"
       in [C.cstm|write_scalar(stdout, binary_output, &$id:info, &$exp:e);|]
    Just (TypeOpaque desc _ _) ->
      [C.cstm|{
         fprintf(stderr, "Values of type \"%s\" have no external representation.\n", $string:(T.unpack desc));
         retval = 1;
         goto print_end;
       }|]
    Just (TypeArray _ et rank ops) ->
      let et' = uncurry primAPIType $ scalarToPrim et
          values_array = arrayValues ops
          shape_array = arrayShape ops
          num_elems =
            cproduct [[C.cexp|$id:shape_array(ctx, $exp:e)[$int:i]|] | i <- [0 .. rank - 1]]
          info = et <> "_info"
       in -- Copy the elements into a host buffer, sync, then print.
          [C.cstm|{
                 $ty:et' *arr = calloc($exp:num_elems, $id:info.size);
                 assert(arr != NULL);
                 assert($id:values_array(ctx, $exp:e, arr) == 0);
                 assert(futhark_context_sync(ctx) == 0);
                 write_array(stdout, binary_output, &$id:info, arr,
                             $id:shape_array(ctx, $exp:e), $int:rank);
                 free(arr);
               }|]

-- | Print each value, each followed by a newline.
printResult :: Manifest -> [(T.Text, C.Exp)] -> [C.Stm]
printResult manifest = concatMap f
  where
    f (v, e) = [printStm manifest v e, [C.cstm|printf("\n");|]]

-- | Generate the CLI wrapper function @futrts_cli_entry_<name>@ for one
-- entry point, plus its initializer for the @entry_points@ table.
--
-- The wrapper:
--
-- 1. reads all inputs from stdin and requires EOF afterwards;
-- 2. does an untimed warmup run if @perform_warmup@ is set;
-- 3. runs the entry point @num_runs@ times, writing each timed run's
--    duration to @runtime_file@ (when one is open) and profiling only
--    the last run;
-- 4. prints the results unless @print_result@ is cleared.
cliEntryPoint ::
  Manifest -> T.Text -> EntryPoint -> (C.Definition, C.Initializer)
cliEntryPoint manifest entry_point_name (EntryPoint cfun _tuning_params outputs inputs) =
  let (input_items, pack_input, free_input, free_parsed, input_args) =
        unzip5 $ readInputs manifest $ map inputType inputs

      (output_decls, output_vals, free_outputs) =
        unzip3 $ prepareOutputs manifest $ map outputType outputs

      printstms =
        printResult manifest $ zip (map outputType outputs) output_vals

      cli_entry_point_function_name =
        "futrts_cli_entry_" <> T.unpack (escapeName entry_point_name)

      pause_profiling = "futhark_context_pause_profiling" :: T.Text
      unpause_profiling = "futhark_context_unpause_profiling" :: T.Text

      addrOf e = [C.cexp|&$exp:e|]

      -- One complete run: pack inputs, call, sync, time, free packed inputs.
      run_it =
        [C.citems|
                int r;
                // Run the program once.
                $stms:pack_input
                if (futhark_context_sync(ctx) != 0) {
                  futhark_panic(1, "%s", futhark_context_get_error(ctx));
                };
                // Only profile last run.
                if (profile_run) {
                  $id:unpause_profiling(ctx);
                }
                t_start = get_wall_time();
                r = $id:cfun(ctx,
                             $args:(map addrOf output_vals),
                             $args:input_args);
                if (r != 0) {
                  futhark_panic(1, "%s", futhark_context_get_error(ctx));
                }
                if (futhark_context_sync(ctx) != 0) {
                  futhark_panic(1, "%s", futhark_context_get_error(ctx));
                };
                if (profile_run) {
                  $id:pause_profiling(ctx);
                }
                t_end = get_wall_time();
                long int elapsed_usec = t_end - t_start;
                if (time_runs && runtime_file != NULL) {
                  fprintf(runtime_file, "%lld\n", (long long) elapsed_usec);
                  fflush(runtime_file);
                }
                $stms:free_input
              |]
   in ( [C.cedecl|
   static int $id:cli_entry_point_function_name(struct futhark_context *ctx) {
     typename int64_t t_start, t_end;
     int time_runs = 0, profile_run = 0;
     int retval = 0;

     // We do not want to profile all the initialisation.
     $id:pause_profiling(ctx);

     // Declare and read input.
     set_binary_mode(stdin);
     $items:(mconcat input_items)

     if (end_of_input(stdin) != 0) {
       futhark_panic(1, "Expected EOF on stdin after reading input for \"%s\".\n", $string:(prettyString entry_point_name));
     }

     $items:output_decls

     // Warmup run
     if (perform_warmup) {
       $items:run_it
       $stms:free_outputs
     }
     time_runs = 1;
     // Proper run.
     for (int run = 0; run < num_runs; run++) {
       // Only profile last run.
       profile_run = run == num_runs -1;
       $items:run_it
       if (run < num_runs-1) {
         $stms:free_outputs
       }
     }

     // Free the parsed input.
     $stms:free_parsed

     if (print_result) {
       // Print the final result.
       if (binary_output) {
         set_binary_mode(stdout);
       }
       $stms:printstms
     }
   print_end: {}
     $stms:free_outputs
     return retval;
   }|],
        [C.cinit|{ .name = $string:(T.unpack entry_point_name), .fun = $id:cli_entry_point_function_name }|]
      )

{-# NOINLINE cliDefs #-}

-- | Generate Futhark standalone executable code.
cliDefs :: [Option] -> Manifest -> T.Text
cliDefs options manifest =
  -- The generated translation unit contains:
  --
  -- * the option parser (the generic options followed by the
  --   backend-specific @options@);
  -- * one CLI wrapper per manifest entry point, from cliEntryPoint;
  -- * a table of those wrappers, searched by name;
  -- * a main() that parses options, creates the context, and runs the
  --   entry point named by @entry_point@ (default "main"), listing the
  --   available names if it is unknown.
  let option_parser =
        generateOptionParser "parse_options" $ genericOptions ++ options

      (cli_entry_point_decls, entry_point_inits) =
        unzip $ map (uncurry (cliEntryPoint manifest)) $ M.toList $ manifestEntryPoints manifest
   in -- NOTE(review): in this copy of the file the template below is
      -- flattened, its #include directives carry no header names, and the
      -- "No entry point" message is split across two lines.  Confirm
      -- against upstream.
      definitionsText
        [C.cunit| $esc:("#include ") $esc:("#include ") $esc:("#include ") $esc:("#include ") $esc:(T.unpack valuesH) static int binary_output = 0; static int print_result = 1; static typename FILE *runtime_file; static int perform_warmup = 0; static int num_runs = 1; // If the entry point is NULL, the program will terminate after doing initialisation and such. static const char *entry_point = "main"; $esc:(T.unpack tuningH) $func:option_parser $edecls:cli_entry_point_decls typedef int entry_point_fun(struct futhark_context*); struct entry_point_entry { const char *name; entry_point_fun *fun; }; int main(int argc, char** argv) { int retval = 0; fut_progname = argv[0]; struct futhark_context_config *cfg = futhark_context_config_new(); assert(cfg != NULL); int parsed_options = parse_options(cfg, argc, argv); argc -= parsed_options; argv += parsed_options; if (argc != 0) { futhark_panic(1, "Excess non-option: %s\n", argv[0]); } struct futhark_context *ctx = futhark_context_new(cfg); assert (ctx != NULL); char* error = futhark_context_get_error(ctx); if (error != NULL) { futhark_panic(1, "%s", error); } struct entry_point_entry entry_points[] = { $inits:entry_point_inits }; if (entry_point != NULL) { int num_entry_points = sizeof(entry_points) / sizeof(entry_points[0]); entry_point_fun *entry_point_fun = NULL; for (int i = 0; i < num_entry_points; i++) { if (strcmp(entry_points[i].name, entry_point) == 0) { entry_point_fun = entry_points[i].fun; break; } } if (entry_point_fun == NULL) { fprintf(stderr, "No entry point '%s'. Select another with --entry-point.
Options are:\n", entry_point); for (int i = 0; i < num_entry_points; i++) { fprintf(stderr, "%s\n", entry_points[i].name); } return 1; } if (isatty(fileno(stdin))) { fprintf(stderr, "Reading input from TTY.\n"); fprintf(stderr, "Send EOF (CTRL-d) after typing all input values.\n"); } retval = entry_point_fun(ctx); if (runtime_file != NULL) { fclose(runtime_file); } } futhark_context_free(ctx); futhark_context_config_free(cfg); return retval; }|]
futhark-0.25.27/src/Futhark/CodeGen/Backends/GenericC/Code.hs000066400000000000000000000422731475065116200234500ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-} {-# OPTIONS_GHC -fno-warn-orphans #-}
-- | Translation of ImpCode Exp and Code to C.
module Futhark.CodeGen.Backends.GenericC.Code
  ( compilePrimExp,
    compileExp,
    compileCode,
    compileDest,
    compileArg,
    compileCopy,
    compileCopyWith,
    errorMsgString,
    linearCode,
  )
where

import Control.Monad
import Control.Monad.Identity
import Control.Monad.Reader (asks)
import Data.Map qualified as M
import Data.Maybe
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericC.Monad
import Futhark.CodeGen.Backends.GenericC.Pretty (expText, idText, typeText)
import Futhark.CodeGen.ImpCode
import Futhark.IR.Prop (isBuiltInFunction)
import Futhark.MonadFreshNames
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C

-- | Turn an error message into a printf-style format string and the C
-- argument expressions that go with it.
--
-- Each part contributes one conversion:
--
-- * literal strings are passed as @%s@ arguments, so a @%@ inside them
--   is not interpreted;
-- * booleans print as @"true"@ or @"false"@, and unit as @"()"@;
-- * 64-bit integers are cast to @long long int@ for @%lld@;
-- * 16- and 32-bit floats are cast to @double@ for @%f@.
errorMsgString :: ErrorMsg Exp -> CompilerM op s (String, [C.Exp])
errorMsgString (ErrorMsg parts) = do
  let boolStr e = [C.cexp|($exp:e) ? "true" : "false"|]
      asLongLong e = [C.cexp|(long long int)$exp:e|]
      asDouble e = [C.cexp|(double)$exp:e|]
      onPart (ErrorString s) = pure ("%s", [C.cexp|$string:(T.unpack s)|])
      onPart (ErrorVal Bool x) = ("%s",) . boolStr <$> compileExp x
      onPart (ErrorVal Unit _) = pure ("%s", [C.cexp|"()"|])
      onPart (ErrorVal (IntType Int8) x) = ("%hhd",) <$> compileExp x
      onPart (ErrorVal (IntType Int16) x) = ("%hd",) <$> compileExp x
      onPart (ErrorVal (IntType Int32) x) = ("%d",) <$> compileExp x
      onPart (ErrorVal (IntType Int64) x) = ("%lld",) . asLongLong <$> compileExp x
      onPart (ErrorVal (FloatType Float16) x) = ("%f",) . asDouble <$> compileExp x
      onPart (ErrorVal (FloatType Float32) x) = ("%f",) . asDouble <$> compileExp x
      onPart (ErrorVal (FloatType Float64) x) = ("%f",) <$> compileExp x
  (formatstrs, formatargs) <- mapAndUnzipM onPart parts
  pure (mconcat formatstrs, formatargs)

-- | Tell me how to compile a @v@, and I'll compile any @PrimExp v@ for
-- you.
--
-- Operators with a direct C counterpart are emitted inline.  Every other
-- operator becomes a call to a C function named after the
-- pretty-printed operator.
compilePrimExp :: (Monad m) => (v -> m C.Exp) -> PrimExp v -> m C.Exp
compilePrimExp _ (ValueExp val) =
  pure $ C.toExp val mempty
compilePrimExp f (LeafExp v _) =
  f v
compilePrimExp f (UnOpExp Complement {} x) = do
  x' <- compilePrimExp f x
  pure [C.cexp|~$exp:x'|]
compilePrimExp f (UnOpExp SSignum {} x) = do
  x' <- compilePrimExp f x
  pure [C.cexp|($exp:x' > 0 ? 1 : 0) - ($exp:x' < 0 ? 1 : 0)|]
compilePrimExp f (UnOpExp USignum {} x) = do
  x' <- compilePrimExp f x
  -- The trailing "!= 0" maps any nonzero difference (including -1, when
  -- the value compares below 0) to 1.  NOTE(review): that can happen
  -- only if unsigned values are held in signed C types; confirm.
  pure [C.cexp|($exp:x' > 0 ? 1 : 0) - ($exp:x' < 0 ?
1 : 0) != 0|]
compilePrimExp f (UnOpExp (Neg Bool) x) = do
  x' <- compilePrimExp f x
  pure [C.cexp|!$exp:x'|]
compilePrimExp f (UnOpExp Neg {} x) = do
  x' <- compilePrimExp f x
  pure [C.cexp|-$exp:x'|]
compilePrimExp f (UnOpExp op x) = do
  x' <- compilePrimExp f x
  pure [C.cexp|$id:(prettyString op)($exp:x')|]
compilePrimExp f (CmpOpExp cmp x y) = do
  x' <- compilePrimExp f x
  y' <- compilePrimExp f y
  pure $ case cmp of
    CmpEq {} -> [C.cexp|$exp:x' == $exp:y'|]
    FCmpLt {} -> [C.cexp|$exp:x' < $exp:y'|]
    FCmpLe {} -> [C.cexp|$exp:x' <= $exp:y'|]
    CmpLlt {} -> [C.cexp|$exp:x' < $exp:y'|]
    CmpLle {} -> [C.cexp|$exp:x' <= $exp:y'|]
    _ -> [C.cexp|$id:(prettyString cmp)($exp:x', $exp:y')|]
compilePrimExp f (ConvOpExp conv x) = do
  x' <- compilePrimExp f x
  pure [C.cexp|$id:(prettyString conv)($exp:x')|]
compilePrimExp f (BinOpExp bop x y) = do
  x' <- compilePrimExp f x
  y' <- compilePrimExp f y
  -- Note that integer addition, subtraction, and multiplication with
  -- OverflowWrap are not handled by explicit operators, but rather by
  -- functions.  This is because we want to implicitly convert them to
  -- unsigned numbers, so we can do overflow without invoking
  -- undefined behaviour.
  pure $ case bop of
    Add _ OverflowUndef -> [C.cexp|$exp:x' + $exp:y'|]
    Sub _ OverflowUndef -> [C.cexp|$exp:x' - $exp:y'|]
    Mul _ OverflowUndef -> [C.cexp|$exp:x' * $exp:y'|]
    FAdd {} -> [C.cexp|$exp:x' + $exp:y'|]
    FSub {} -> [C.cexp|$exp:x' - $exp:y'|]
    FMul {} -> [C.cexp|$exp:x' * $exp:y'|]
    FDiv {} -> [C.cexp|$exp:x' / $exp:y'|]
    Xor {} -> [C.cexp|$exp:x' ^ $exp:y'|]
    And {} -> [C.cexp|$exp:x' & $exp:y'|]
    Or {} -> [C.cexp|$exp:x' | $exp:y'|]
    LogAnd {} -> [C.cexp|$exp:x' && $exp:y'|]
    LogOr {} -> [C.cexp|$exp:x' || $exp:y'|]
    _ -> [C.cexp|$id:(prettyString bop)($exp:x', $exp:y')|]
compilePrimExp f (FunExp h args _) = do
  args' <- mapM (compilePrimExp f) args
  pure [C.cexp|$id:(funName (nameFromText h))($args:args')|]

-- | Compile prim expression to C expression.
-- The leaf callback renders every variable as a plain C identifier.
compileExp :: Exp -> CompilerM op s C.Exp
compileExp = compilePrimExp $ \v -> pure [C.cexp|$id:v|]

-- | Render a typed expression as C by compiling its untyped form,
-- rendering every leaf variable as a plain C identifier.  This runs in
-- 'Identity', so no compiler state is needed.
instance C.ToExp (TExp t) where
  toExp e _ = runIdentity . compilePrimExp (\v -> pure [C.cexp|$id:v|]) $ untyped e

-- | Flatten a tree of ':>>:' sequences into the list of statements it
-- denotes, in execution order (left before right).
linearCode :: Code op -> [Code op]
linearCode = reverse . go []
  where
    go acc (x :>>: y) = go (go acc x) y
    go acc x = x : acc

-- | If a binary operator can be written as a C compound assignment
-- (@+=@, @-=@, @*=@) with exactly the same semantics, return a
-- function that builds that assignment.
--
-- Only 'OverflowUndef' integer arithmetic qualifies.  For
-- 'OverflowWrap', 'compilePrimExp' deliberately goes through helper
-- functions that compute via unsigned numbers, so that wrapping does
-- not rely on C's undefined signed overflow.  A bare compound
-- assignment would bypass those helpers.
assignmentOperator :: BinOp -> Maybe (VName -> C.Exp -> C.Exp)
assignmentOperator (Add _ OverflowUndef) = Just $ \d e -> [C.cexp|$id:d += $exp:e|]
assignmentOperator (Sub _ OverflowUndef) = Just $ \d e -> [C.cexp|$id:d -= $exp:e|]
assignmentOperator (Mul _ OverflowUndef) = Just $ \d e -> [C.cexp|$id:d *= $exp:e|]
assignmentOperator _ = Nothing

-- | Generate a C expression that reads element @iexp@ of memory @src@
-- as a value of the given type.  The access depends on the memory space:
--
--   * scalar spaces are plain C arrays and are indexed directly;
--   * the default space is dereferenced through a (possibly volatile)
--     pointer to the storage type;
--   * any other space is delegated to the backend's 'opsReadScalar'.
--
-- 'Unit' reads produce the unit value without touching memory.
generateRead ::
  C.Exp ->
  C.Exp ->
  PrimType ->
  Space ->
  Volatility ->
  CompilerM op s C.Exp
generateRead _ _ Unit _ _ =
  pure [C.cexp|$exp:(UnitValue)|]
generateRead src iexp _ ScalarSpace {} _ =
  pure [C.cexp|$exp:src[$exp:iexp]|]
generateRead src iexp restype DefaultSpace vol =
  pure . fromStorage restype $
    derefPointer
      src
      iexp
      [C.cty|$tyquals:(volQuals vol) $ty:(primStorageType restype)*|]
generateRead src iexp restype (Space space) vol = do
  reader <- asks (opsReadScalar . envOperations)
  fromStorage restype <$> reader src iexp (primStorageType restype) space vol

-- | Emit a C statement that writes @elemexp@ to element @idx@ of memory
-- @dest@.  It uses the same space-based dispatch as 'generateRead', and
-- other spaces are delegated to 'opsWriteScalar'.  'Unit' writes emit
-- nothing.
generateWrite ::
  C.Exp ->
  C.Exp ->
  PrimType ->
  Space ->
  Volatility ->
  C.Exp ->
  CompilerM op s ()
generateWrite _ _ Unit _ _ _ = pure ()
generateWrite dest idx _ ScalarSpace {} _ elemexp = do
  stm [C.cstm|$exp:dest[$exp:idx] = $exp:elemexp;|]
generateWrite dest idx elemtype DefaultSpace vol elemexp = do
  let deref =
        derefPointer
          dest
          idx
          [C.cty|$tyquals:(volQuals vol) $ty:(primStorageType elemtype)*|]
      elemexp' = toStorage elemtype elemexp
  stm [C.cstm|$exp:deref = $exp:elemexp';|]
generateWrite dest idx elemtype (Space space) vol elemexp = do
  writer <- asks (opsWriteScalar . envOperations)
  writer dest idx (primStorageType elemtype) space vol (toStorage elemtype elemexp)

-- | Read a scalar from the named memory block at the given element
-- index.
compileRead :: VName -> Count u (TPrimExp t VName) -> PrimType -> Space -> Volatility -> CompilerM op s C.Exp
compileRead src (Count iexp) restype space vol = do
  src' <- rawMem src
  iexp' <- compileExp (untyped iexp)
  generateRead src' iexp' restype space vol

-- | Whether a memory block must be wrapped in a fat (reference-counted)
-- memory struct when it crosses a function call.  This is the case
-- when fat memory is used for the default space, but this particular
-- block is a cached raw pointer.
memNeedsWrapping :: VName -> CompilerM op s Bool
memNeedsWrapping v = do
  refcount <- fatMemory DefaultSpace
  cached <- isJust <$> cacheMem v
  pure $ refcount && cached

-- | Compile an argument to a function application.
compileArg :: Arg -> CompilerM op s C.Exp
compileArg (MemArg m) = do
  -- The function might expect fat memory, so if this is a
  -- lexical/cached raw pointer, we have to wrap it in a struct.
  wrap <- memNeedsWrapping m
  if wrap
    then pure [C.cexp|($ty:(fatMemType DefaultSpace)) {.references = NULL, .mem = $exp:m}|]
    else pure [C.cexp|$exp:m|]
compileArg (ExpArg e) = compileExp e

-- | Prepare a destination for function application.  Returns the name
-- to pass as the destination, plus statements to run after the call.
compileDest :: VName -> CompilerM op s (VName, [C.Stm])
compileDest v = do
  -- The function result may be fat memory, so if the target is a raw
  -- pointer, we have to wrap it in a struct and unwrap it afterwards.
  wrap <- memNeedsWrapping v
  if wrap
    then do
      v' <- newVName $ baseString v <> "_struct"
      item [C.citem|$ty:(fatMemType DefaultSpace) $id:v' = {.references = NULL, .mem = $exp:v};|]
      pure (v', [C.cstms|$id:v = $id:v'.mem;|])
    else pure (v, mempty)

-- | Compile an ImpCode statement.  The resulting C block items are
-- appended to the current item buffer (see 'item' and 'stm').
compileCode :: Code op -> CompilerM op s ()
compileCode (Op op) =
  join $ asks (opsCompiler . envOperations) <*> pure op
compileCode Skip = pure ()
compileCode (Comment s code) = do
  xs <- collect $ compileCode code
  let comment = "// " ++ T.unpack s
  stm [C.cstm|$comment:comment { $items:xs } |]
compileCode (TracePrint msg) = do
  (formatstr, formatargs) <- errorMsgString msg
  stm [C.cstm|fprintf(ctx->log, $string:formatstr, $args:formatargs);|]
compileCode (DebugPrint s (Just e)) = do
  e' <- compileExp e
  stm [C.cstm|if (ctx->debugging) { fprintf(ctx->log, $string:fmtstr, $exp:s, ($ty:ety)$exp:e', '\n'); }|]
  where
    -- NOTE(review): integers are cast to a signed long long int but
    -- printed with %llu, so negative values print as large unsigned
    -- numbers.  Confirm whether this is intended.
    (fmt, ety) = case primExpType e of
      IntType _ -> ("llu", [C.cty|long long int|])
      FloatType _ -> ("f", [C.cty|double|])
      _ -> ("d", [C.cty|int|])
    fmtstr = "%s: %" ++ fmt ++ "%c"
compileCode (DebugPrint s Nothing) =
  stm [C.cstm|if (ctx->debugging) { fprintf(ctx->log, "%s\n", $exp:s); }|]
-- :>>: is treated in a special way to detect declare-set pairs in
-- order to generate prettier code.
compileCode (c1 :>>: c2) = go (linearCode (c1 :>>: c2))
  where
    -- A scalar declaration immediately followed by its first
    -- assignment, read, or builtin call becomes one initialised
    -- declaration.
    go (DeclareScalar name vol t : SetScalar dest e : code)
      | name == dest = do
          let ct = primTypeToCType t
          e' <- compileExp e
          item [C.citem|$tyquals:(volQuals vol) $ty:ct $id:name = $exp:e';|]
          go code
    go (DeclareScalar name vol t : Read dest src i restype space read_vol : code)
      | name == dest = do
          let ct = primTypeToCType t
          e <- compileRead src i restype space read_vol
          item [C.citem|$tyquals:(volQuals vol) $ty:ct $id:name = $exp:e;|]
          go code
    go (DeclareScalar name vol t : Call [dest] fname args : code)
      | name == dest,
        isBuiltInFunction fname = do
          let ct = primTypeToCType t
          args' <- mapM compileArg args
          item [C.citem|$tyquals:(volQuals vol) $ty:ct $id:name = $id:(funName fname)($args:args');|]
          go code
    go (x : xs) = compileCode x >> go xs
    go [] = pure ()
compileCode (Assert e msg (loc, locs)) = do
  e' <- compileExp e
  err <-
    collect . join $
      asks (opsError . envOperations) <*> pure msg <*> pure stacktrace
  stm [C.cstm|if (!$exp:e') { $items:err }|]
  where
    stacktrace = T.unpack $ prettyStacktrace 0 $ map locText $ loc : locs
compileCode (Allocate _ _ ScalarSpace {}) =
  -- Handled by the declaration of the memory block, which is
  -- translated to an actual array.
  pure ()
compileCode (Allocate name (Count (TPrimExp e)) space) = do
  size <- compileExp e
  cached <- cacheMem name
  case cached of
    -- Cached blocks are only grown (via lexical_realloc) when the
    -- current size is too small.
    Just cur_size ->
      stm [C.cstm|if ($exp:cur_size < $exp:size) { err = lexical_realloc(ctx, &$exp:name, &$exp:cur_size, $exp:size); if (err != FUTHARK_SUCCESS) { goto cleanup; } }|]
    _ ->
      allocMem name size space [C.cstm|{err = 1; goto cleanup;}|]
compileCode (Free name space) = do
  cached <- isJust <$> cacheMem name
  unless cached $ unRefMem name space
compileCode (For i bound body) = do
  let i' = C.toIdent i
      t = primTypeToCType $ primExpType bound
  bound' <- compileExp bound
  body' <- collect $ compileCode body
  stm [C.cstm|for ($ty:t $id:i' = 0; $id:i' < $exp:bound'; $id:i'++) { $items:body' }|]
compileCode (While cond body) = do
  cond' <- compileExp $ untyped cond
  body' <- collect $ compileCode body
  stm [C.cstm|while ($exp:cond') { $items:body' }|]
compileCode (If cond tbranch fbranch) = do
  cond' <- compileExp $ untyped cond
  tbranch' <- collect $ compileCode tbranch
  fbranch' <- collect $ compileCode fbranch
  -- Avoid emitting empty branches, and turn "else { if ... }" into
  -- "else if ..." for readability.
  stm $ case (tbranch', fbranch') of
    (_, []) ->
      [C.cstm|if ($exp:cond') { $items:tbranch' }|]
    ([], _) ->
      [C.cstm|if (!($exp:cond')) { $items:fbranch' }|]
    (_, [C.BlockStm x@C.If {}]) ->
      [C.cstm|if ($exp:cond') { $items:tbranch' } else $stm:x|]
    _ ->
      [C.cstm|if ($exp:cond') { $items:tbranch' } else { $items:fbranch' }|]
compileCode (Copy t shape (dst, dstspace) (dstoffset, dststrides) (src, srcspace) (srcoffset, srcstrides)) = do
  -- Prefer a backend-specific copy function for this (dst, src) space
  -- pair, and otherwise fall back to element-wise loops.
  cp <- asks $ M.lookup (dstspace, srcspace) . opsCopies . envOperations
  case cp of
    Just cp'
      | t /= Unit -> do
          shape' <- traverse (traverse (compileExp . untyped)) shape
          dst' <- rawMem dst
          src' <- rawMem src
          dstoffset' <- traverse (compileExp . untyped) dstoffset
          dststrides' <- traverse (traverse (compileExp . untyped)) dststrides
          srcoffset' <- traverse (compileExp . untyped) srcoffset
          srcstrides' <- traverse (traverse (compileExp . untyped)) srcstrides
          cp' CopyBarrier t shape' dst' (dstoffset', dststrides') src' (srcoffset', srcstrides')
    _ ->
      compileCopy t shape (dst, dstspace) (dstoffset, dststrides) (src, srcspace) (srcoffset, srcstrides)
compileCode (Write _ _ Unit _ _ _) = pure ()
compileCode (Write dst (Count idx) elemtype space vol elemexp) = do
  dst' <- rawMem dst
  idx' <- compileExp (untyped idx)
  elemexp' <- compileExp elemexp
  generateWrite dst' idx' elemtype space vol elemexp'
compileCode (Read x src i restype space vol) = do
  e <- compileRead src i restype space vol
  stm [C.cstm|$id:x = $exp:e;|]
compileCode (DeclareMem name space) =
  declMem name space
compileCode (DeclareScalar name vol t) = do
  let ct = primTypeToCType t
  decl [C.cdecl|$tyquals:(volQuals vol) $ty:ct $id:name;|]
compileCode (DeclareArray name t vs) = do
  name_realtype <- newVName $ baseString name ++ "_realtype"
  let ct = primTypeToCType t
  case vs of
    ArrayValues vs' -> do
      -- To handle very large literal arrays (which are inefficient
      -- with language-c-quote, see #2160), we do our own formatting and inject it as a string.
      let array_decl =
            "static "
              <> typeText ct
              <> " "
              <> idText (C.toIdent name_realtype mempty)
              <> "["
              <> prettyText (length vs')
              <> "] = { "
              <> T.intercalate "," (map (expText . flip C.toExp mempty) vs')
              <> "};"
      earlyDecl [C.cedecl|$esc:(T.unpack array_decl)|]
    ArrayZeros n ->
      earlyDecl [C.cedecl|static $ty:ct $id:name_realtype[$int:n];|]
  -- Fake a memory block.
  item [C.citem|struct memblock $id:name = (struct memblock){NULL, (unsigned char*)$id:name_realtype, 0, $string:(prettyString name)};|]
-- For assignments of the form 'x = x OP e', we generate C assignment
-- operators to make the resulting code slightly nicer.  This has no
-- effect on performance.  Only operators for which
-- 'assignmentOperator' yields a semantically identical compound
-- assignment are rewritten; everything else uses the general case.
compileCode (SetScalar dest (BinOpExp op (LeafExp x _) y))
  | dest == x,
    Just f <- assignmentOperator op = do
      y' <- compileExp y
      stm [C.cstm|$exp:(f dest y');|]
compileCode (SetScalar dest src) = do
  src' <- compileExp src
  stm [C.cstm|$id:dest = $exp:src';|]
compileCode (SetMem dest src space) =
  setMem dest src space
compileCode (Call [dest] fname args)
  | isBuiltInFunction fname = do
      args' <- mapM compileArg args
      stm [C.cstm|$id:dest = $id:(funName fname)($args:args');|]
compileCode (Call dests fname args) = do
  (dests', unpack_dest) <- mapAndUnzipM compileDest dests
  join $
    asks (opsCall . envOperations)
      <*> pure dests'
      <*> pure fname
      <*> mapM compileArg args
  stms $ mconcat unpack_dest

-- | Compile a 'Copy' using sequential nested loops, but
-- parameterised over how to do the reads and writes.  Both sides are
-- addressed as offset + sum (index * stride), with one loop per
-- dimension of @shape@.
compileCopyWith ::
  [Count Elements (TExp Int64)] ->
  (C.Exp -> C.Exp -> CompilerM op s ()) ->
  ( Count Elements (TExp Int64),
    [Count Elements (TExp Int64)]
  ) ->
  (C.Exp -> CompilerM op s C.Exp) ->
  ( Count Elements (TExp Int64),
    [Count Elements (TExp Int64)]
  ) ->
  CompilerM op s ()
compileCopyWith shape doWrite dst_lmad doRead src_lmad = do
  let (dstoffset, dststrides) = dst_lmad
      (srcoffset, srcstrides) = src_lmad
  shape' <- mapM (compileExp . untyped . unCount) shape
  body <- collect $ do
    dst_i <-
      compileExp . untyped . unCount $
        dstoffset + sum (zipWith (*) is' dststrides)
    src_i <-
      compileExp . untyped . unCount $
        srcoffset + sum (zipWith (*) is' srcstrides)
    doWrite dst_i =<< doRead src_i
  items $ loops (zip is shape') body
  where
    r = length shape
    -- One loop index per dimension.
    is = map (VName "i") [0 .. r - 1]
    is' :: [Count Elements (TExp Int64)]
    is' = map (elements . le64) is
    loops [] body = body
    loops ((i, n) : ins) body =
      [C.citems|for (typename int64_t $id:i = 0; $id:i < $exp:n; $id:i++) { $items:(loops ins body) }|]

-- | Compile a 'Copy' using sequential nested loops and
-- t'Read'/t'Write' of individual scalars.
-- This always works, but can
-- be pretty slow if those reads and writes are costly.
compileCopy ::
  PrimType ->
  [Count Elements (TExp Int64)] ->
  (VName, Space) ->
  ( Count Elements (TExp Int64),
    [Count Elements (TExp Int64)]
  ) ->
  (VName, Space) ->
  ( Count Elements (TExp Int64),
    [Count Elements (TExp Int64)]
  ) ->
  CompilerM op s ()
compileCopy t shape (dst, dstspace) dst_lmad (src, srcspace) src_lmad = do
  src' <- rawMem src
  dst' <- rawMem dst
  -- Every element is moved with a non-volatile scalar read and write in
  -- the respective memory spaces.
  let doWrite dst_i = generateWrite dst' dst_i t dstspace Nonvolatile
      doRead src_i = generateRead src' src_i t srcspace Nonvolatile
  compileCopyWith shape doWrite dst_lmad doRead src_lmad
futhark-0.25.27/src/Futhark/CodeGen/Backends/GenericC/EntryPoints.hs000066400000000000000000000205041475065116200250650ustar0000000000000000{-# LANGUAGE QuasiQuotes #-}

-- | Generate the entry point packing/unpacking code.
module Futhark.CodeGen.Backends.GenericC.EntryPoints
  ( onEntryPoint,
  )
where

import Control.Monad
import Control.Monad.Reader (asks)
import Data.Maybe
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericC.Monad
import Futhark.CodeGen.Backends.GenericC.Types (opaqueToCType, valueTypeToCType)
import Futhark.CodeGen.ImpCode
import Futhark.Manifest qualified as Manifest
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C

-- | The 'ValueType' (signedness, rank, element type) described by a
-- 'ValueDesc'.  Scalars have rank 0, and arrays have one rank per
-- shape dimension.
valueDescToType :: ValueDesc -> ValueType
valueDescToType (ScalarValue pt signed _) =
  ValueType signed (Rank 0) pt
valueDescToType (ArrayValue _ _ pt signed shape) =
  ValueType signed (Rank (length shape)) pt

-- | Generate the code that unpacks the public C API arguments
-- @in0, in1, ...@ into the internal variables named by the given
-- 'ExternalValue's.  For each argument, returns its C parameter
-- declaration and an optional conjunction of shape checks.  Also
-- returns the generated unpacking statements.
prepareEntryInputs ::
  [ExternalValue] ->
  CompilerM op s ([(C.Param, Maybe C.Exp)], [C.BlockItem])
prepareEntryInputs args = collect' $ zipWithM prepare [(0 :: Int) ..] args
  where
    -- Names bound directly by some argument; dimensions with these names
    -- are only checked, never assigned from an array's shape.
    arg_names = namesFromList $ concatMap evNames args
    evNames (OpaqueValue _ vds) = map vdName vds
    evNames (TransparentValue vd) = [vdName vd]
    vdName (ArrayValue v _ _ _ _) = v
    vdName (ScalarValue _ _ v) = v

    prepare pno (TransparentValue vd) = do
      let pname = "in" ++ show pno
      (ty, check) <- prepareValue Public [C.cexp|$id:pname|] vd
      pure
        ( [C.cparam|const $ty:ty $id:pname|],
          if null check then Nothing else Just $ allTrue check
        )
    prepare pno (OpaqueValue desc vds) = do
      ty <- opaqueToCType desc
      let pname = "in" ++ show pno
          -- Both kinds of opaque member are accessed as a tuple field.
          field i ScalarValue {} = [C.cexp|$id:pname->$id:(tupleField i)|]
          field i ArrayValue {} = [C.cexp|$id:pname->$id:(tupleField i)|]
      checks <- map snd <$> zipWithM (prepareValue Private) (zipWith field [0 ..] vds) vds
      pure
        ( [C.cparam|const $ty:ty *$id:pname|],
          if all null checks
            then Nothing
            else Just $ allTrue $ concat checks
        )

    prepareValue _ src (ScalarValue pt signed name) = do
      let pt' = primAPIType signed pt
          src' = fromStorage pt $ C.toExp src mempty
      stm [C.cstm|$id:name = $exp:src';|]
      pure (pt', [])
    prepareValue pub src vd@(ArrayValue mem _ _ _ shape) = do
      ty <- valueTypeToCType pub $ valueDescToType vd

      stm [C.cstm|$exp:mem = $exp:src->mem;|]

      -- A dimension variable that no argument binds is assigned from the
      -- array's shape.  Every dimension also contributes an equality
      -- check.  The checks are evaluated only after all arguments are
      -- unpacked (see onEntryPoint), so a dimension shared by several
      -- arrays is validated against each of them.
      let rank = length shape
          maybeCopyDim (Var d) i
            | d `notNameIn` arg_names =
                ( Just [C.cstm|$id:d = $exp:src->shape[$int:i];|],
                  [C.cexp|$id:d == $exp:src->shape[$int:i]|]
                )
          maybeCopyDim x i =
            ( Nothing,
              [C.cexp|$exp:x == $exp:src->shape[$int:i]|]
            )

      let (sets, checks) = unzip $ zipWith maybeCopyDim shape [0 .. rank - 1]
      stms $ catMaybes sets

      pure ([C.cty|$ty:ty*|], checks)

-- | Generate the code that packs internal result variables into the
-- public C API output parameters @out0, out1, ...@.  Returns the
-- parameter declarations together with the packing statements.  Arrays
-- and opaques are returned through freshly @malloc@ed structs, and
-- scalars are written through a plain pointer.
prepareEntryOutputs :: [ExternalValue] -> CompilerM op s ([C.Param], [C.BlockItem])
prepareEntryOutputs = collect' . zipWithM prepare [(0 :: Int) ..]
  where
    prepare pno (TransparentValue vd) = do
      let pname = "out" ++ show pno
      ty <- valueTypeToCType Public $ valueDescToType vd

      case vd of
        ArrayValue {} -> do
          -- NOTE(review): the malloc happens inside assert(); if the
          -- generated C were ever compiled with NDEBUG it would be
          -- skipped.  Confirm NDEBUG is never defined for generated code.
          stm [C.cstm|assert((*$id:pname = ($ty:ty*) malloc(sizeof($ty:ty))) != NULL);|]
          prepareValue [C.cexp|*$id:pname|] vd
          pure [C.cparam|$ty:ty **$id:pname|]
        ScalarValue {} -> do
          prepareValue [C.cexp|*$id:pname|] vd
          pure [C.cparam|$ty:ty *$id:pname|]
    prepare pno (OpaqueValue desc vds) = do
      let pname = "out" ++ show pno
      ty <- opaqueToCType desc
      vd_ts <- mapM (valueTypeToCType Private . valueDescToType) vds

      stm [C.cstm|assert((*$id:pname = ($ty:ty*) malloc(sizeof($ty:ty))) != NULL);|]

      -- Array members of the opaque each get their own allocated struct.
      -- Scalar members are stored directly in the tuple field.
      forM_ (zip3 [0 ..] vd_ts vds) $ \(i, ct, vd) -> do
        let field = [C.cexp|((*$id:pname)->$id:(tupleField i))|]
        case vd of
          ScalarValue {} -> pure ()
          ArrayValue {} -> do
            stm [C.cstm|assert(($exp:field = ($ty:ct*) malloc(sizeof($ty:ct))) != NULL);|]
        prepareValue field vd

      pure [C.cparam|$ty:ty **$id:pname|]

    prepareValue dest (ScalarValue t _ name) =
      let name' = toStorage t $ C.toExp name mempty
       in stm [C.cstm|$exp:dest = $exp:name';|]
    prepareValue dest (ArrayValue mem _ _ _ shape) = do
      stm [C.cstm|$exp:dest->mem = $id:mem;|]

      let rank = length shape
          maybeCopyDim (Constant x) i =
            [C.cstm|$exp:dest->shape[$int:i] = $exp:x;|]
          maybeCopyDim (Var d) i =
            [C.cstm|$exp:dest->shape[$int:i] = $id:d;|]
      stms $ zipWith maybeCopyDim shape [0 .. rank - 1]

-- | The entry point's name before the public prefix is added:
-- @entry_@ followed by the escaped Futhark name.
entryName :: Name -> T.Text
entryName = ("entry_" <>) . escapeName . nameToText

-- | Generate the public C wrapper for a function carrying an
-- 'EntryPoint' annotation, and register its header declaration.
-- Returns the wrapper definition together with the entry point's
-- manifest entry.  Functions without an entry point annotation yield
-- 'Nothing'.
--
-- Inside the backend's critical section, the generated wrapper:
--
--   1. unpacks the API arguments;
--   2. validates array shapes;
--   3. calls the internal function;
--   4. on success, runs @get_consts@ and packs the results.
--
-- It returns the internal function's error code.
onEntryPoint ::
  [C.BlockItem] ->
  [Name] ->
  Name ->
  Function op ->
  CompilerM op s (Maybe (C.Definition, (T.Text, Manifest.EntryPoint)))
onEntryPoint _ _ _ (Function Nothing _ _ _) = pure Nothing
onEntryPoint get_consts relevant_params fname (Function (Just (EntryPoint ename results args)) outputs inputs _) = inNewFunction $ do
  let out_args = map (\p -> [C.cexp|&$id:(paramName p)|]) outputs
      in_args = map (\p -> [C.cexp|$id:(paramName p)|]) inputs

  inputdecls <- collect $ mapM_ stubParam inputs
  outputdecls <- collect $ mapM_ stubParam outputs
  decl_mem <- declAllocatedMem

  entry_point_function_name <- publicName $ entryName ename

  (inputs', unpack_entry_inputs) <- prepareEntryInputs $ map snd args
  let (entry_point_input_params, entry_point_input_checks) = unzip inputs'

  (entry_point_output_params, pack_entry_outputs) <-
    prepareEntryOutputs $ map snd results

  ctx_ty <- contextType

  headerDecl
    EntryDecl
    [C.cedecl|int $id:entry_point_function_name ($ty:ctx_ty *ctx, $params:entry_point_output_params, $params:entry_point_input_params);|]

  -- Only emit the size check when at least one argument has a check.
  let checks = catMaybes entry_point_input_checks
      check_input =
        if null checks
          then []
          else
            [C.citems| if (!($exp:(allTrue (catMaybes entry_point_input_checks)))) { ret = 1; set_error(ctx, msgprintf("Error: entry point arguments have invalid sizes.\n")); }|]

      critical =
        [C.citems| $items:decl_mem $items:unpack_entry_inputs $items:check_input if (ret == 0) { ret = $id:(funName fname)(ctx, $args:out_args, $args:in_args); if (ret == 0) { $items:get_consts $items:pack_entry_outputs } } |]

  ops <- asks envOperations

  let cdef =
        [C.cedecl| int $id:entry_point_function_name ($ty:ctx_ty *ctx, $params:entry_point_output_params, $params:entry_point_input_params) { $items:inputdecls $items:outputdecls int ret = 0; $items:(criticalSection ops critical) return ret; } |]

      manifest =
        Manifest.EntryPoint
          { Manifest.entryPointCFun = entry_point_function_name,
            Manifest.entryPointTuningParams = map nameToText relevant_params,
            -- Note that our convention about what is "input/output"
            -- and what is "results/args" is different between the
            -- manifest and ImpCode.
            Manifest.entryPointOutputs = map outputManifest results,
            Manifest.entryPointInputs = map inputManifest args
          }

  pure $ Just (cdef, (nameToText ename, manifest))
  where
    -- Declare a local for each internal parameter; scalars start out
    -- holding a blank value of their type.
    stubParam (MemParam name space) =
      declMem name space
    stubParam (ScalarParam name ty) = do
      let ty' = primTypeToCType ty
      decl [C.cdecl|$ty:ty' $id:name = $exp:(blankPrimValue ty);|]

    -- The Futhark-level type name recorded in the manifest.
    vdType (TransparentValue (ScalarValue pt signed _)) =
      prettySigned (signed == Unsigned) pt
    vdType (TransparentValue (ArrayValue _ _ pt signed shape)) =
      mconcat (replicate (length shape) "[]")
        <> prettySigned (signed == Unsigned) pt
    vdType (OpaqueValue name _) =
      nameToText name

    outputManifest (u, vd) =
      Manifest.Output
        { Manifest.outputType = vdType vd,
          Manifest.outputUnique = u == Unique
        }
    inputManifest ((v, u), vd) =
      Manifest.Input
        { Manifest.inputName = nameToText v,
          Manifest.inputType = vdType vd,
          Manifest.inputUnique = u == Unique
        }
futhark-0.25.27/src/Futhark/CodeGen/Backends/GenericC/Fun.hs000066400000000000000000000104351475065116200233210ustar0000000000000000{-# LANGUAGE QuasiQuotes #-}

-- | C code generation for functions.
module Futhark.CodeGen.Backends.GenericC.Fun
  ( compileFun,
    compileVoidFun,
    module Futhark.CodeGen.Backends.GenericC.Monad,
    module Futhark.CodeGen.Backends.GenericC.Code,
  )
where

import Control.Monad
import Futhark.CodeGen.Backends.GenericC.Code
import Futhark.CodeGen.Backends.GenericC.Monad
import Futhark.CodeGen.ImpCode
import Futhark.MonadFreshNames
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C

-- | Emit a function body in three steps: declare locals for the output
-- parameters, compile the body code, then store each output through
-- its corresponding output pointer.
compileFunBody :: [C.Exp] -> [Param] -> Code op -> CompilerM op s ()
compileFunBody output_ptrs outputs code = do
  mapM_ declareOutput outputs
  compileCode code
  zipWithM_ setRetVal' output_ptrs outputs
  where
    declareOutput (MemParam name space) =
      declMem name space
    declareOutput (ScalarParam name pt) = do
      let ctp = primTypeToCType pt
      decl [C.cdecl|$ty:ctp $id:name;|]

    setRetVal' p (MemParam name space) =
      -- It is required that the memory block is already initialised
      -- (although it may be NULL).
      setMem [C.cexp|*$exp:p|] name space
    setRetVal' p (ScalarParam name _) =
      stm [C.cstm|*$exp:p = $id:name;|]

-- | The C parameter declaration for an input parameter.  Scalars are
-- passed by value, and memory blocks use the type chosen by
-- 'memToCType'.
compileInput :: Param -> CompilerM op s C.Param
compileInput (ScalarParam name bt) = do
  let ctp = primTypeToCType bt
  pure [C.cparam|$ty:ctp $id:name|]
compileInput (MemParam name space) = do
  ty <- memToCType name space
  pure [C.cparam|$ty:ty $id:name|]

-- | The C parameter declaration for an output parameter, which is
-- passed as a pointer under a fresh name.  Also returns an expression
-- referring to that pointer.
compileOutput :: Param -> CompilerM op s (C.Param, C.Exp)
compileOutput (ScalarParam name bt) = do
  let ctp = primTypeToCType bt
  p_name <- newVName $ "out_" ++ baseString name
  pure ([C.cparam|$ty:ctp *$id:p_name|], [C.cexp|$id:p_name|])
compileOutput (MemParam name space) = do
  ty <- memToCType name space
  p_name <- newVName $ baseString name ++ "_p"
  pure ([C.cparam|$ty:ty *$id:p_name|], [C.cexp|$id:p_name|])

-- | Compile an ImpCode function to a C prototype and definition.  The
-- function is marked @FUTHARK_FUN_ATTR@ and returns an @int@ error
-- code.  The @extra@ parameters come first, followed by the output
-- pointers and the inputs.  The body ends in a @cleanup@ label, which
-- frees the function's cached and declared memory before returning
-- @err@.
compileFun :: [C.BlockItem] -> [C.Param] -> (Name, Function op) -> CompilerM op s (C.Definition, C.Func)
compileFun get_constants extra (fname, func@(Function _ outputs inputs body)) = inNewFunction $ do
  (outparams, out_ptrs) <- mapAndUnzipM compileOutput outputs
  inparams <- mapM compileInput inputs

  cachingMemory (lexicalMemoryUsage func) $ \decl_cached free_cached -> do
    body' <- collect $ compileFunBody out_ptrs outputs body
    -- Declarations and frees are collected after the body, so that
    -- memory declared while compiling the body is included.
    decl_mem <- declAllocatedMem
    free_mem <- freeAllocatedMem

    let futhark_function =
          C.DeclSpec [] [C.EscTypeQual "FUTHARK_FUN_ATTR" mempty] (C.Tint Nothing mempty) mempty

    pure
      ( [C.cedecl|$spec:futhark_function $id:(funName fname)($params:extra, $params:outparams, $params:inparams);|],
        [C.cfun|$spec:futhark_function $id:(funName fname)($params:extra, $params:outparams, $params:inparams) { $stms:ignores int err = 0; $items:decl_cached $items:decl_mem $items:get_constants $items:body' cleanup: { $stms:free_cached $items:free_mem } return err; }|]
      )
  where
    -- Ignore all the boilerplate parameters, just in case we don't
    -- actually need to use them.
    ignores = [[C.cstm|(void)$id:p;|] | C.Param (Just p) _ _ _ <- extra]

-- | Generate code for a function that returns void (meaning it cannot
-- fail) and has no extra parameters (meaning it cannot allocate
-- memory non-lexically or do anything fancy).
compileVoidFun :: [C.BlockItem] -> (Name, Function op) -> CompilerM op s (C.Definition, C.Func)
compileVoidFun get_constants (fname, func@(Function _ outputs inputs body)) = inNewFunction $ do
  (outparams, out_ptrs) <- mapAndUnzipM compileOutput outputs
  inparams <- mapM compileInput inputs

  -- Unlike compileFun, there is no context/extra parameter, no error
  -- code and no cleanup label; only cached (lexical) memory is freed,
  -- at the end of the body.
  cachingMemory (lexicalMemoryUsage func) $ \decl_cached free_cached -> do
    body' <- collect $ compileFunBody out_ptrs outputs body

    let futhark_function =
          C.DeclSpec [] [C.EscTypeQual "FUTHARK_FUN_ATTR" mempty] (C.Tvoid mempty) mempty

    pure
      ( [C.cedecl|$spec:futhark_function $id:(funName fname)($params:outparams, $params:inparams);|],
        [C.cfun|$spec:futhark_function $id:(funName fname)($params:outparams, $params:inparams) { $items:decl_cached $items:get_constants $items:body' $stms:free_cached }|]
      )
futhark-0.25.27/src/Futhark/CodeGen/Backends/GenericC/Monad.hs000066400000000000000000000500241475065116200236250ustar0000000000000000{-# LANGUAGE QuasiQuotes #-}

-- | C code generator framework.
module Futhark.CodeGen.Backends.GenericC.Monad
  ( -- * Pluggable compiler
    Operations (..),
    Publicness (..),
    OpCompiler,
    ErrorCompiler,
    CallCompiler,
    PointerQuals,
    MemoryType,
    WriteScalar,
    writeScalarPointerWithQuals,
    ReadScalar,
    readScalarPointerWithQuals,
    Allocate,
    Deallocate,
    CopyBarrier (..),
    Copy,
    DoCopy,

    -- * Monadic compiler interface
    CompilerM,
    CompilerState (..),
    CompilerEnv (..),
    getUserState,
    modifyUserState,
    generateProgramStruct,
    runCompilerM,
    inNewFunction,
    cachingMemory,
    volQuals,
    rawMem,
    item,
    items,
    stm,
    stms,
    decl,
    headerDecl,
    publicDef,
    publicDef_,
    onClear,
    HeaderSection (..),
    libDecl,
    earlyDecl,
    publicName,
    contextField,
    contextFieldDyn,
    memToCType,
    cacheMem,
    fatMemory,
    rawMemCType,
    freeRawMem,
    allocRawMem,
    fatMemType,
    declAllocatedMem,
    freeAllocatedMem,
    collect,
    collect',
    contextType,
    configType,

    -- * Building Blocks
    copyMemoryDefaultSpace,
    derefPointer,
    setMem,
    allocMem,
    unRefMem,
    declMem,
    resetMem,
    fatMemAlloc,
    fatMemSet,
    fatMemUnRef,
    criticalSection,
    module Futhark.CodeGen.Backends.SimpleRep,
  )
where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (first)
import Data.DList qualified as DL
import Data.List (unzip4)
import Data.Loc
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericC.Pretty
import Futhark.CodeGen.Backends.SimpleRep
import Futhark.CodeGen.ImpCode
import Futhark.MonadFreshNames
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C

-- How public an array type definition should be.  Public types show up
-- in the generated API, while private types are used only to
-- implement the members of opaques.
data Publicness = Private | Public
  deriving (Eq, Ord, Show)

type ArrayType = (Signedness, PrimType, Int)

-- | Mutable state threaded through 'CompilerM'.
data CompilerState s = CompilerState
  { compArrayTypes :: M.Map ArrayType Publicness,
    compEarlyDecls :: DL.DList C.Definition,
    compNameSrc :: VNameSource,
    compUserState :: s,
    compHeaderDecls :: M.Map HeaderSection (DL.DList C.Definition),
    compLibDecls :: DL.DList C.Definition,
    -- | Fields of the generated @struct program@: name, type, optional
    -- static initialiser, and optional (setup, teardown) statements.
    compCtxFields :: DL.DList (C.Id, C.Type, Maybe C.Exp, Maybe (C.Stm, C.Stm)),
    compClearItems :: DL.DList C.BlockItem,
    -- | Fat memory blocks declared in the current function, whose
    -- declarations and frees are emitted in bulk.
    compDeclaredMem :: [(VName, Space)],
    -- | The block items generated so far for the current code fragment.
    compItems :: DL.DList C.BlockItem
  }

-- | An initial compiler state holding the given name source and user
-- state, with everything else empty.
newCompilerState :: VNameSource -> s -> CompilerState s
newCompilerState src s =
  CompilerState
    { compArrayTypes = mempty,
      compEarlyDecls = mempty,
      compNameSrc = src,
      compUserState = s,
      compHeaderDecls = mempty,
      compLibDecls = mempty,
      compCtxFields = mempty,
      compClearItems = mempty,
      compDeclaredMem = mempty,
      compItems = mempty
    }

-- | In which part of the header file we put the declaration.  This is
-- to ensure that the header file remains structured and readable.
data HeaderSection
  = ArrayDecl Name
  | OpaqueTypeDecl Name
  | OpaqueDecl Name
  | EntryDecl
  | MiscDecl
  | InitDecl
  deriving (Eq, Ord)

-- | A substitute expression compiler, tried before the main
-- compilation function.
type OpCompiler op s = op -> CompilerM op s ()

-- | Generates the C code to run when an assertion fails, given the
-- error message and a rendered stack trace.
type ErrorCompiler op s = ErrorMsg Exp -> String -> CompilerM op s ()

-- | The address space qualifiers for a pointer of the given type with
-- the given annotation.
type PointerQuals = String -> [C.TypeQual]

-- | The type of a memory block in the given memory space.
type MemoryType op s = SpaceId -> CompilerM op s C.Type

-- | Write a scalar to the given memory block with the given element
-- index and in the given memory space.
type WriteScalar op s =
  C.Exp -> C.Exp -> C.Type -> SpaceId -> Volatility -> C.Exp -> CompilerM op s ()

-- | Read a scalar from the given memory block with the given element
-- index and in the given memory space.
type ReadScalar op s =
  C.Exp -> C.Exp -> C.Type -> SpaceId -> Volatility -> CompilerM op s C.Exp

-- | Allocate a memory block of the given size and with the given tag
-- in the given memory space, saving a reference in the given variable
-- name.
type Allocate op s =
  C.Exp -> C.Exp -> C.Exp -> SpaceId -> CompilerM op s ()

-- | De-allocate the given memory block, with the given tag, with the
-- given size, which is in the given memory space.
type Deallocate op s = C.Exp -> C.Exp -> C.Exp -> SpaceId -> CompilerM op s ()

-- | Whether a copying operation should implicitly function as a
-- barrier regarding further operations on the source.  This is a
-- rather subtle detail and is mostly useful for letting some
-- device/GPU copies be asynchronous (#1664).
data CopyBarrier
  = -- | The copy implicitly functions as a barrier.
    CopyBarrier
  | -- | Explicit context synchronisation should be done
    -- before the source or target is used.
    CopyNoBarrier
  deriving (Eq, Show)

-- | Copy from one memory block to another.
type Copy op s =
  CopyBarrier ->
  C.Exp ->
  C.Exp ->
  Space ->
  C.Exp ->
  C.Exp ->
  Space ->
  C.Exp ->
  CompilerM op s ()

-- | Perform a 'Copy'.  It is expected that these functions are
-- each specialised on which spaces they operate on, so that is not part of their arguments.
type DoCopy op s =
  CopyBarrier ->
  PrimType ->
  [Count Elements C.Exp] ->
  C.Exp ->
  ( Count Elements C.Exp,
    [Count Elements C.Exp]
  ) ->
  C.Exp ->
  ( Count Elements C.Exp,
    [Count Elements C.Exp]
  ) ->
  CompilerM op s ()

-- | Call a function.
type CallCompiler op s = [VName] -> Name -> [C.Exp] -> CompilerM op s ()

-- | The backend-specific hooks that specialise the generic C code
-- generator.
data Operations op s = Operations
  { opsWriteScalar :: WriteScalar op s,
    opsReadScalar :: ReadScalar op s,
    opsAllocate :: Allocate op s,
    opsDeallocate :: Deallocate op s,
    opsCopy :: Copy op s,
    opsMemoryType :: MemoryType op s,
    opsCompiler :: OpCompiler op s,
    opsError :: ErrorCompiler op s,
    opsCall :: CallCompiler op s,
    -- | @(dst,src)@-space mapping to copy functions.
    opsCopies :: M.Map (Space, Space) (DoCopy op s),
    -- | If true, use reference counting.  Otherwise, bare
    -- pointers.
    opsFatMemory :: Bool,
    -- | Code to bracket critical sections.
    opsCritical :: ([C.BlockItem], [C.BlockItem])
  }

-- | Block items that unreference every memory block recorded in
-- 'compDeclaredMem'.
freeAllocatedMem :: CompilerM op s [C.BlockItem]
freeAllocatedMem = collect $ mapM_ (uncurry unRefMem) =<< gets compDeclaredMem

-- | Block items that declare and reset every memory block recorded in
-- 'compDeclaredMem'.
declAllocatedMem :: CompilerM op s [C.BlockItem]
declAllocatedMem = collect $ mapM_ f =<< gets compDeclaredMem
  where
    f (name, space) = do
      ty <- memToCType name space
      decl [C.cdecl|$ty:ty $id:name;|]
      resetMem name space

-- | The read-only environment of 'CompilerM'.
data CompilerEnv op s = CompilerEnv
  { envOperations :: Operations op s,
    -- | Mapping memory blocks to sizes.  These memory blocks are CPU
    -- memory that we know are used in particularly simple ways (no
    -- reference counting necessary).  To cut down on allocator
    -- pressure, we keep these allocations around for a long time, and
    -- record their sizes so we can reuse them if possible (and
    -- realloc() when needed).
    envCachedMem :: M.Map C.Exp VName
  }

-- | The registered program fields, split into struct member
-- declarations, setup statements (static initialisers followed by
-- dynamic setup code), and teardown statements.
contextContents :: CompilerM op s ([C.FieldGroup], [C.Stm], [C.Stm])
contextContents = do
  (field_names, field_types, field_values, field_frees) <-
    gets $ unzip4 . DL.toList . compCtxFields
  let fields =
        [ [C.csdecl|$ty:ty $id:name;|]
        | (name, ty) <- zip field_names field_types
        ]
      init_fields =
        [ [C.cstm|ctx->program->$id:name = $exp:e;|]
        | (name, Just e) <- zip field_names field_values
        ]
      (setup, free) = unzip $ catMaybes field_frees
  pure (fields, init_fields <> setup, free)

-- | Emit, as early declarations, the C @struct program@ holding all
-- registered fields, together with its @setup_program@ and
-- @teardown_program@ functions.
generateProgramStruct :: CompilerM op s ()
generateProgramStruct = do
  (fields, init_fields, free_fields) <- contextContents
  mapM_
    earlyDecl
    [C.cunit|struct program { int dummy; $sdecls:fields }; static void setup_program(struct futhark_context* ctx) { (void)ctx; int error = 0; (void)error; ctx->program = malloc(sizeof(struct program)); $stms:init_fields } static void teardown_program(struct futhark_context *ctx) { (void)ctx; int error = 0; (void)error; $stms:free_fields free(ctx->program); }|]

-- | The C code generation monad: a reader over 'CompilerEnv' on top of
-- state over 'CompilerState'.
newtype CompilerM op s a
  = CompilerM (ReaderT (CompilerEnv op s) (State (CompilerState s)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState (CompilerState s),
      MonadReader (CompilerEnv op s)
    )

instance MonadFreshNames (CompilerM op s) where
  getNameSource = gets compNameSrc
  putNameSource src = modify $ \s -> s {compNameSrc = src}

-- | Run a 'CompilerM' action with the given operations, name source
-- and user state, and an empty memory cache.  Returns the result
-- together with the final compiler state.
runCompilerM ::
  Operations op s ->
  VNameSource ->
  s ->
  CompilerM op s a ->
  (a, CompilerState s)
runCompilerM ops src userstate (CompilerM m) =
  runState
    (runReaderT m (CompilerEnv ops mempty))
    (newCompilerState src userstate)

-- | The backend-specific user state.
getUserState :: CompilerM op s s
getUserState = gets compUserState

-- | Apply a function to the backend-specific user state.
modifyUserState :: (s -> s) -> CompilerM op s ()
modifyUserState f = modify $ \compstate ->
  compstate {compUserState = f $ compUserState compstate}

-- | Run an action and return the block items it generated, instead of
-- appending them to the current buffer.
collect :: CompilerM op s () -> CompilerM op s [C.BlockItem]
collect m = snd <$> collect' m

-- | Like 'collect', but also returns the action's result.  The outer
-- item buffer is saved before the action and restored afterwards.
collect' :: CompilerM op s a -> CompilerM op s (a, [C.BlockItem])
collect' m = do
  old <- gets compItems
  modify $ \s -> s {compItems = mempty}
  x <- m
  new <- gets compItems
  modify $ \s -> s {compItems = old}
  pure (x, DL.toList new)

-- | Used when we, inside an existing 'CompilerM' action, want to
-- generate code for a new
-- function. Use this so that the compiler understands that previously
-- declared memory doesn't need to be freed inside this action.  The
-- cache of lexically managed memory ('envCachedMem') is also emptied
-- for the duration of the action.  The outer declared-memory list is
-- restored afterwards.
inNewFunction :: CompilerM op s a -> CompilerM op s a
inNewFunction action = do
  saved <- gets compDeclaredMem
  setDeclared mempty
  result <- local withoutCache action
  setDeclared saved
  pure result
  where
    setDeclared mems = modify $ \st -> st {compDeclaredMem = mems}
    withoutCache env = env {envCachedMem = mempty}

-- | Append a single block item to the code currently being generated.
item :: C.BlockItem -> CompilerM op s ()
item x = items [x]

-- | Append several block items, in order, to the code currently being
-- generated.
items :: [C.BlockItem] -> CompilerM op s ()
items xs = modify $ \st -> st {compItems = compItems st <> DL.fromList xs}

-- | Whether memory in the given space is represented as fat,
-- reference-counted memory blocks.  Scalar spaces never are; for every
-- other space this is decided by 'opsFatMemory'.
fatMemory :: Space -> CompilerM op s Bool
fatMemory space = case space of
  ScalarSpace {} -> pure False
  _ -> asks (opsFatMemory . envOperations)

-- | Look up a memory block, keyed by its C expression, in the cache of
-- lexically managed allocations.  If found, returns the name of the
-- variable recording its current size.
cacheMem :: (C.ToExp a) => a -> CompilerM op s (Maybe VName)
cacheMem mem = do
  cached <- asks envCachedMem
  pure $ M.lookup (C.toExp mem noLoc) cached

-- | Construct a publicly visible definition using the specified name
-- as the template.  The first definition produced by the callback goes
-- into the header file, and the second (the implementation) becomes an
-- early declaration.  Returns the public name.
publicDef ::
  T.Text ->
  HeaderSection ->
  (T.Text -> (C.Definition, C.Definition)) ->
  CompilerM op s T.Text
publicDef s h f = do
  name <- publicName s
  let (headerDef, implDef) = f name
  headerDecl h headerDef
  earlyDecl implDef
  pure name

-- | As 'publicDef', but ignores the public name.
publicDef_ :: T.Text -> HeaderSection -> (T.Text -> (C.Definition, C.Definition)) -> CompilerM op s () publicDef_ s h f = void $ publicDef s h f headerDecl :: HeaderSection -> C.Definition -> CompilerM op s () headerDecl sec def = modify $ \s -> s { compHeaderDecls = M.unionWith (<>) (compHeaderDecls s) (M.singleton sec (DL.singleton def)) } libDecl :: C.Definition -> CompilerM op s () libDecl def = modify $ \s -> s {compLibDecls = compLibDecls s <> DL.singleton def} earlyDecl :: C.Definition -> CompilerM op s () earlyDecl def = modify $ \s -> s {compEarlyDecls = compEarlyDecls s <> DL.singleton def} contextField :: C.Id -> C.Type -> Maybe C.Exp -> CompilerM op s () contextField name ty initial = modify $ \s -> s {compCtxFields = compCtxFields s <> DL.singleton (name, ty, initial, Nothing)} contextFieldDyn :: C.Id -> C.Type -> C.Stm -> C.Stm -> CompilerM op s () contextFieldDyn name ty create free = modify $ \s -> s {compCtxFields = compCtxFields s <> DL.singleton (name, ty, Nothing, Just (create, free))} onClear :: C.BlockItem -> CompilerM op s () onClear x = modify $ \s -> s {compClearItems = compClearItems s <> DL.singleton x} stm :: C.Stm -> CompilerM op s () stm s = item [C.citem|$stm:s|] stms :: [C.Stm] -> CompilerM op s () stms = mapM_ stm decl :: C.InitGroup -> CompilerM op s () decl x = item [C.citem|$decl:x;|] -- | Public names must have a consitent prefix. publicName :: T.Text -> CompilerM op s T.Text publicName s = pure $ "futhark_" <> s memToCType :: VName -> Space -> CompilerM op s C.Type memToCType v space = do refcount <- fatMemory space cached <- isJust <$> cacheMem v if refcount && not cached then pure $ fatMemType space else rawMemCType space rawMemCType :: Space -> CompilerM op s C.Type rawMemCType DefaultSpace = pure defaultMemBlockType rawMemCType (Space sid) = join $ asks (opsMemoryType . 
envOperations) <*> pure sid rawMemCType (ScalarSpace [] t) = pure [C.cty|$ty:(primTypeToCType t)[1]|] rawMemCType (ScalarSpace ds t) = pure [C.cty|$ty:(primTypeToCType t)[$exp:(cproduct ds')]|] where ds' = map (`C.toExp` noLoc) ds fatMemType :: Space -> C.Type fatMemType space = [C.cty|struct $id:name|] where name = case space of Space sid -> "memblock_" ++ sid _ -> "memblock" fatMemSet :: Space -> String fatMemSet (Space sid) = "memblock_set_" ++ sid fatMemSet _ = "memblock_set" fatMemAlloc :: Space -> String fatMemAlloc (Space sid) = "memblock_alloc_" ++ sid fatMemAlloc _ = "memblock_alloc" fatMemUnRef :: Space -> String fatMemUnRef (Space sid) = "memblock_unref_" ++ sid fatMemUnRef _ = "memblock_unref" rawMem :: VName -> CompilerM op s C.Exp rawMem v = rawMem' <$> fat <*> pure v where fat = asks ((&&) . opsFatMemory . envOperations) <*> (isNothing <$> cacheMem v) rawMem' :: (C.ToExp a) => Bool -> a -> C.Exp rawMem' True e = [C.cexp|$exp:e.mem|] rawMem' False e = [C.cexp|$exp:e|] allocRawMem :: (C.ToExp a, C.ToExp b, C.ToExp c) => a -> b -> Space -> c -> CompilerM op s () allocRawMem dest size space desc = case space of Space sid -> join $ asks (opsAllocate . envOperations) <*> pure [C.cexp|$exp:dest|] <*> pure [C.cexp|$exp:size|] <*> pure [C.cexp|$exp:desc|] <*> pure sid _ -> stm [C.cstm|host_alloc(ctx, (size_t)$exp:size, $exp:desc, (size_t*)&$exp:size, (void*)&$exp:dest);|] freeRawMem :: (C.ToExp a, C.ToExp b, C.ToExp c) => a -> b -> Space -> c -> CompilerM op s () freeRawMem mem size space desc = case space of Space sid -> do free_mem <- asks (opsDeallocate . 
envOperations) free_mem [C.cexp|$exp:mem|] [C.cexp|$exp:size|] [C.cexp|$exp:desc|] sid _ -> item [C.citem|host_free(ctx, (size_t)$exp:size, $exp:desc, (void*)$exp:mem);|] declMem :: VName -> Space -> CompilerM op s () declMem name space = do cached <- isJust <$> cacheMem name fat <- fatMemory space unless cached $ if fat then modify $ \s -> s {compDeclaredMem = (name, space) : compDeclaredMem s} else do ty <- memToCType name space decl [C.cdecl|$ty:ty $id:name;|] resetMem :: (C.ToExp a) => a -> Space -> CompilerM op s () resetMem mem space = do refcount <- fatMemory space cached <- isJust <$> cacheMem mem if cached then stm [C.cstm|$exp:mem = NULL;|] else when refcount $ stm [C.cstm|$exp:mem.references = NULL;|] setMem :: (C.ToExp a, C.ToExp b) => a -> b -> Space -> CompilerM op s () setMem dest src space = do refcount <- fatMemory space let src_s = T.unpack $ expText $ C.toExp src noLoc if refcount then stm [C.cstm|if ($id:(fatMemSet space)(ctx, &$exp:dest, &$exp:src, $string:src_s) != 0) { return 1; }|] else case space of ScalarSpace ds _ -> do i' <- newVName "i" let i = C.toIdent i' it = primTypeToCType $ IntType Int32 ds' = map (`C.toExp` noLoc) ds bound = cproduct ds' stm [C.cstm|for ($ty:it $id:i = 0; $id:i < $exp:bound; $id:i++) { $exp:dest[$id:i] = $exp:src[$id:i]; }|] _ -> stm [C.cstm|$exp:dest = $exp:src;|] unRefMem :: (C.ToExp a) => a -> Space -> CompilerM op s () unRefMem mem space = do refcount <- fatMemory space cached <- isJust <$> cacheMem mem let mem_s = T.unpack $ expText $ C.toExp mem noLoc when (refcount && not cached) $ stm [C.cstm|if ($id:(fatMemUnRef space)(ctx, &$exp:mem, $string:mem_s) != 0) { return 1; }|] allocMem :: (C.ToExp a, C.ToExp b) => a -> b -> Space -> C.Stm -> CompilerM op s () allocMem mem size space on_failure = do refcount <- fatMemory space let mem_s = T.unpack $ expText $ C.toExp mem noLoc if refcount then stm [C.cstm|if ($id:(fatMemAlloc space)(ctx, &$exp:mem, $exp:size, $string:mem_s)) { $stm:on_failure }|] else do 
freeRawMem mem size space mem_s allocRawMem mem size space [C.cexp|desc|] copyMemoryDefaultSpace :: C.Exp -> C.Exp -> C.Exp -> C.Exp -> C.Exp -> CompilerM op s () copyMemoryDefaultSpace destmem destidx srcmem srcidx nbytes = stm [C.cstm|if ($exp:nbytes > 0) { memmove($exp:destmem + $exp:destidx, $exp:srcmem + $exp:srcidx, $exp:nbytes); }|] cachingMemory :: M.Map VName Space -> ([C.BlockItem] -> [C.Stm] -> CompilerM op s a) -> CompilerM op s a cachingMemory lexical f = do -- We only consider lexical 'DefaultSpace' memory blocks to be -- cached. This is not a deep technical restriction, but merely a -- heuristic based on GPU memory usually involving larger -- allocations, that do not suffer from the overhead of reference -- counting. Beware: there is code elsewhere in codegen that -- assumes lexical memory is DefaultSpace too. let cached = M.keys $ M.filter (== DefaultSpace) lexical cached' <- forM cached $ \mem -> do size <- newVName $ prettyString mem <> "_cached_size" pure (mem, size) let lexMem env = env { envCachedMem = M.fromList (map (first (`C.toExp` noLoc)) cached') <> envCachedMem env } declCached (mem, size) = [ [C.citem|typename int64_t $id:size = 0;|], [C.citem|$ty:defaultMemBlockType $id:mem = NULL;|] ] freeCached (mem, _) = [C.cstm|free($id:mem);|] local lexMem $ f (concatMap declCached cached') (map freeCached cached') derefPointer :: C.Exp -> C.Exp -> C.Type -> C.Exp derefPointer ptr i res_t = [C.cexp|(($ty:res_t)$exp:ptr)[$exp:i]|] volQuals :: Volatility -> [C.TypeQual] volQuals Volatile = [C.ctyquals|volatile|] volQuals Nonvolatile = [] writeScalarPointerWithQuals :: PointerQuals -> WriteScalar op s writeScalarPointerWithQuals quals_f dest i elemtype space vol v = do let quals' = volQuals vol ++ quals_f space deref = derefPointer dest i [C.cty|$tyquals:quals' $ty:elemtype*|] stm [C.cstm|$exp:deref = $exp:v;|] readScalarPointerWithQuals :: PointerQuals -> ReadScalar op s readScalarPointerWithQuals quals_f dest i elemtype space vol = do let quals' = 
volQuals vol ++ quals_f space pure $ derefPointer dest i [C.cty|$tyquals:quals' $ty:elemtype*|] criticalSection :: Operations op s -> [C.BlockItem] -> [C.BlockItem] criticalSection ops x = [C.citems|lock_lock(&ctx->lock); $items:(fst (opsCritical ops)) $items:x $items:(snd (opsCritical ops)) lock_unlock(&ctx->lock); |] -- | The generated code must define a context struct with this name. contextType :: CompilerM op s C.Type contextType = do name <- publicName "context" pure [C.cty|struct $id:name|] -- | The generated code must define a configuration struct with this -- name. configType :: CompilerM op s C.Type configType = do name <- publicName "context_config" pure [C.cty|struct $id:name|] futhark-0.25.27/src/Futhark/CodeGen/Backends/GenericC/Options.hs000066400000000000000000000115301475065116200242210ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-} -- | This module defines a generator for @getopt_long@ based command -- line argument parsing. Each option is associated with arbitrary C -- code that will perform side effects, usually by setting some global -- variables. module Futhark.CodeGen.Backends.GenericC.Options ( Option (..), OptionArgument (..), generateOptionParser, ) where import Data.Char (isSpace) import Data.Function ((&)) import Data.List (intercalate) import Data.Maybe import Language.C.Quote.C qualified as C import Language.C.Syntax qualified as C -- | Specification if a single command line option. The option must -- have a long name, and may also have a short name. -- -- In the action, the option argument (if any) is stored as in the -- @char*@-typed variable @optarg@. data Option = Option { optionLongName :: String, optionShortName :: Maybe Char, optionArgument :: OptionArgument, optionDescription :: String, optionAction :: C.Stm } -- | Whether an option accepts an argument. data OptionArgument = NoArgument | -- | The 'String' becomes part of the help pretty. 
RequiredArgument String | OptionalArgument -- | Generate an option parser as a function of the given name, that -- accepts the given command line options. The result is a function -- that should be called with @argc@ and @argv@. The function returns -- the number of @argv@ elements that have been processed. -- -- If option parsing fails for any reason, the entire process will -- terminate with error code 1. generateOptionParser :: String -> [Option] -> C.Func generateOptionParser fname options = [C.cfun|int $id:fname(struct futhark_context_config *cfg, int argc, char* const argv[]) { int $id:chosen_option; static struct option long_options[] = { $inits:option_fields, {0, 0, 0, 0} }; static char* option_descriptions = $string:option_descriptions; while (($id:chosen_option = getopt_long(argc, argv, $string:option_string, long_options, NULL)) != -1) { $stms:option_applications if ($id:chosen_option == ':') { futhark_panic(-1, "Missing argument for option %s\n", argv[optind-1]); } if ($id:chosen_option == '?') { fprintf(stderr, "Usage: %s [OPTIONS]...\nOptions:\n\n%s\n", fut_progname, $string:option_descriptions); futhark_panic(1, "Unknown option: %s\n", argv[optind-1]); } } return optind; } |] where chosen_option = "ch" option_string = ':' : optionString options option_applications = optionApplications chosen_option options option_fields = optionFields options option_descriptions = describeOptions options trim :: String -> String trim = f . f where f = reverse . dropWhile isSpace describeOptions :: [Option] -> String describeOptions opts = let in unlines $ fmap extendDesc with_short_descs where with_short_descs = fmap (\opt -> (opt, shortDesc opt)) opts max_short_desc_len = maximum $ fmap (length . 
snd) with_short_descs extendDesc :: (Option, String) -> String extendDesc (opt, short) = take (max_short_desc_len + 1) (short ++ repeat ' ') ++ ( optionDescription opt & lines & fmap trim & intercalate ('\n' : replicate (max_short_desc_len + 1) ' ') ) shortDesc :: Option -> String shortDesc opt = concat [ " ", maybe "" (\c -> "-" ++ [c] ++ "/") $ optionShortName opt, "--" ++ optionLongName opt, case optionArgument opt of NoArgument -> "" RequiredArgument what -> " " ++ what OptionalArgument -> " [ARG]" ] optionFields :: [Option] -> [C.Initializer] optionFields = zipWith field [(1 :: Int) ..] where field i option = [C.cinit| { $string:(optionLongName option), $id:arg, NULL, $int:i } |] where arg = case optionArgument option of NoArgument -> "no_argument" :: String RequiredArgument _ -> "required_argument" OptionalArgument -> "optional_argument" optionApplications :: String -> [Option] -> [C.Stm] optionApplications chosen_option = zipWith check [(1 :: Int) ..] where check i option = [C.cstm|if ($exp:cond) $stm:(optionAction option)|] where cond = case optionShortName option of Nothing -> [C.cexp|$id:chosen_option == $int:i|] Just c -> [C.cexp|($id:chosen_option == $int:i) || ($id:chosen_option == $char:c)|] optionString :: [Option] -> String optionString = concat . mapMaybe optionStringChunk where optionStringChunk option = do short <- optionShortName option pure $ short : case optionArgument option of NoArgument -> "" RequiredArgument _ -> ":" OptionalArgument -> "::" futhark-0.25.27/src/Futhark/CodeGen/Backends/GenericC/Pretty.hs000066400000000000000000000022661475065116200240630ustar00rootroot00000000000000{-# OPTIONS_GHC -fno-warn-orphans #-} -- | Compatibility shims for mainland-pretty; the prettyprinting -- library used by language-c-quote. 
module Futhark.CodeGen.Backends.GenericC.Pretty
  ( expText,
    definitionsText,
    typeText,
    idText,
    funcText,
    funcsText,
  )
where

import Data.Text qualified as T
import Language.C.Pretty ()
import Language.C.Syntax qualified as C
import Text.PrettyPrint.Mainland qualified as MPP
import Text.PrettyPrint.Mainland.Class qualified as MPP

-- | Lay out a document using a line width of 8000 columns.
render :: MPP.Doc -> String
render doc = MPP.pretty 8000 doc

-- | Lay out anything mainland-pretty can prettyprint, as 'T.Text'.
-- Every exported function is built on this.
pprText :: (MPP.Pretty a) => a -> T.Text
pprText x = T.pack (render (MPP.ppr x))

-- | Prettyprint a C expression.
expText :: C.Exp -> T.Text
expText = pprText

-- | Prettyprint a list of C definitions, each followed by a newline.
definitionsText :: [C.Definition] -> T.Text
definitionsText defs = T.unlines [pprText def | def <- defs]

-- | Prettyprint a single C type.
typeText :: C.Type -> T.Text
typeText = pprText

-- | Prettyprint a single identifier.
idText :: C.Id -> T.Text
idText = pprText

-- | Prettyprint a single function.
funcText :: C.Func -> T.Text
funcText = pprText

-- | Prettyprint a list of functions, each followed by a newline.
funcsText :: [C.Func] -> T.Text
funcsText funcs = T.unlines (map funcText funcs)
futhark-0.25.27/src/Futhark/CodeGen/Backends/GenericC/Server.hs000066400000000000000000000350561475065116200240450ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-} -- | Code generation for server executables.
module Futhark.CodeGen.Backends.GenericC.Server ( serverDefs, ) where import Data.Bifunctor (first, second) import Data.Map qualified as M import Data.Text qualified as T import Futhark.CodeGen.Backends.GenericC.Options import Futhark.CodeGen.Backends.GenericC.Pretty import Futhark.CodeGen.Backends.SimpleRep import Futhark.CodeGen.RTS.C (serverH, tuningH, valuesH) import Futhark.Manifest import Futhark.Util (zEncodeText) import Language.C.Quote.OpenCL qualified as C import Language.C.Syntax qualified as C import Language.Futhark.Core (nameFromText) {- | Command-line options accepted by every generated server executable. 'serverDefs' appends any backend-specific options to these. -} genericOptions :: [Option] genericOptions = [ Option { optionLongName = "debugging", optionShortName = Just 'D', optionArgument = NoArgument, optionDescription = "Perform possibly expensive internal correctness checks and verbose logging.", optionAction = [C.cstm|futhark_context_config_set_debugging(cfg, 1);|] }, Option { optionLongName = "log", optionShortName = Just 'L', optionArgument = NoArgument, optionDescription = "Print various low-overhead logging information while running.", optionAction = [C.cstm|futhark_context_config_set_logging(cfg, 1);|] }, Option { optionLongName = "profile", optionShortName = Just 'P', optionArgument = NoArgument, optionDescription = "Enable the collection of profiling information.", optionAction = [C.cstm|futhark_context_config_set_profiling(cfg, 1);|] }, Option { optionLongName = "help", optionShortName = Just 'h', optionArgument = NoArgument, optionDescription = "Print help information and exit.", optionAction = [C.cstm|{ printf("Usage: %s [OPTIONS]...\nOptions:\n\n%s\nFor more information, consult the Futhark User's Guide or the man pages.\n", fut_progname, option_descriptions); exit(0); }|] }, Option { optionLongName = "print-params", optionShortName = Nothing, optionArgument = NoArgument, optionDescription = "Print all tuning parameters that can be set with --param or --tuning.", optionAction = [C.cstm|{ int n = futhark_get_tuning_param_count(); for (int i = 0; i < n; i++) {
printf("%s (%s)\n", futhark_get_tuning_param_name(i), futhark_get_tuning_param_class(i)); } exit(0); }|] }, Option { optionLongName = "param", optionShortName = Nothing, optionArgument = RequiredArgument "ASSIGNMENT", optionDescription = "Set a tuning parameter to the given value.", optionAction = [C.cstm|{ char *name = optarg; char *equals = strstr(optarg, "="); char *value_str = equals != NULL ? equals+1 : optarg; int value = atoi(value_str); if (equals != NULL) { *equals = 0; if (futhark_context_config_set_tuning_param(cfg, name, value) != 0) { futhark_panic(1, "Unknown size: %s\n", name); } } else { futhark_panic(1, "Invalid argument for size option: %s\n", optarg); }}|] }, Option { optionLongName = "tuning", optionShortName = Nothing, optionArgument = RequiredArgument "FILE", optionDescription = "Read size=value assignments from the given file.", optionAction = [C.cstm|{ char *ret = load_tuning_file(optarg, cfg, (int(*)(void*, const char*, size_t)) futhark_context_config_set_tuning_param); if (ret != NULL) { futhark_panic(1, "When loading tuning file '%s': %s\n", optarg, ret); }}|] }, Option { optionLongName = "cache-file", optionShortName = Nothing, optionArgument = RequiredArgument "FILE", optionDescription = "Store program cache here.", optionAction = [C.cstm|futhark_context_config_set_cache_file(cfg, optarg);|] } ] {- | Name of the C variable holding the server's runtime descriptor for the named type. The type name is z-encoded, presumably so that arbitrary type names yield valid C identifiers. -} typeStructName :: T.Text -> T.Text typeStructName tname = "type_" <> zEncodeText tname {- | The C type of values of a manifest type. Array and opaque types use the C type declared in the manifest. Names absent from the manifest's type table are treated as scalars and use their primitive API type. -} cType :: Manifest -> TypeName -> C.Type cType manifest tname = case M.lookup tname $ manifestTypes manifest of Just (TypeArray ctype _ _ _) -> [C.cty|typename $id:(T.unpack ctype)|] Just (TypeOpaque ctype _ _) -> [C.cty|typename $id:(T.unpack ctype)|] Nothing -> uncurry primAPIType $ scalarToPrim tname -- The first component is a forward declaration, so we don't have to -- worry about ordering.
{- | Server boilerplate for one manifest type. Returns three things: a declaration of the type's @struct type@ descriptor, an initialiser taking the descriptor's address, and the definitions of the descriptor. Those definitions include auxiliary data: array operations for arrays, or opaque operations plus record information for opaque types. -} typeBoilerplate :: Manifest -> (T.Text, Type) -> (C.Definition, C.Initializer, [C.Definition]) typeBoilerplate _ (tname, TypeArray _ et rank ops) = let type_name = typeStructName tname aux_name = type_name <> "_aux" info_name = et <> "_info" shape_args = [[C.cexp|shape[$int:i]|] | i <- [0 .. rank - 1]] array_new_wrap = arrayNew ops <> "_wrap" in ( [C.cedecl|const struct type $id:type_name;|], [C.cinit|&$id:type_name|], [C.cunit| void* $id:array_new_wrap(struct futhark_context *ctx, const void* p, const typename int64_t* shape) { return $id:(arrayNew ops)(ctx, p, $args:shape_args); } const struct array_aux $id:aux_name = { .name = $string:(T.unpack tname), .rank = $int:rank, .info = &$id:info_name, .new = (typename array_new_fn)$id:array_new_wrap, .free = (typename array_free_fn)$id:(arrayFree ops), .shape = (typename array_shape_fn)$id:(arrayShape ops), .values = (typename array_values_fn)$id:(arrayValues ops) }; const struct type $id:type_name = { .name = $string:(T.unpack tname), .restore = (typename restore_fn)restore_array, .store = (typename store_fn)store_array, .free = (typename free_fn)free_array, .aux = &$id:aux_name };|] ) typeBoilerplate manifest (tname, TypeOpaque c_type_name ops extra_ops) = let type_name = typeStructName tname aux_name = type_name <> "_aux" (record_edecls, record_init) = recordDefs type_name extra_ops in ( [C.cedecl|const struct type $id:type_name;|], [C.cinit|&$id:type_name|], record_edecls ++ [C.cunit| const struct opaque_aux $id:aux_name = { .store = (typename opaque_store_fn)$id:(opaqueStore ops), .restore = (typename opaque_restore_fn)$id:(opaqueRestore ops), .free = (typename opaque_free_fn)$id:(opaqueFree ops) }; const struct type $id:type_name = { .name = $string:(T.unpack tname), .restore = (typename restore_fn)restore_opaque, .store = (typename store_fn)store_opaque, .free = (typename free_fn)free_opaque, .aux = &$id:aux_name, .record = $init:record_init };|] ) where {- Record types additionally get a table of fields and a constructor wrapper; other opaque types get a NULL record pointer. -} recordDefs type_name (Just (OpaqueRecord (RecordOps fields
new))) = let new_wrap = new <> "_wrap" record_name = type_name <> "_record" fields_name = type_name <> "_fields" onField i (RecordField name field_tname project) = let field_c_type = cType manifest field_tname field_v = "v" <> show (i :: Int) in ( [C.cinit|{.name = $string:(T.unpack name), .type = &$id:(typeStructName field_tname), .project = (typename project_fn)$id:project }|], [C.citem|const $ty:field_c_type $id:field_v = *(const $ty:field_c_type*)fields[$int:i];|], [C.cexp|$id:field_v|] ) (field_inits, get_fields, field_args) = unzip3 $ zipWith onField [0 ..] fields in ( [C.cunit| const struct field $id:fields_name[] = { $inits:field_inits }; int $id:new_wrap(struct futhark_context* ctx, void** outp, const void* fields[]) { typename $id:c_type_name *out = (typename $id:c_type_name*) outp; $items:get_fields return $id:new(ctx, out, $args:field_args); } const struct record $id:record_name = { .num_fields = $int:(length fields), .fields = $id:fields_name, .new = $id:new_wrap };|], [C.cinit|&$id:record_name|] ) recordDefs _ _ = ([], [C.cinit|NULL|]) {- | Descriptor boilerplate for every type in the manifest, as produced by 'typeBoilerplate'. -} entryTypeBoilerplate :: Manifest -> ([C.Definition], [C.Initializer], [C.Definition]) entryTypeBoilerplate manifest = second concat . unzip3 . map (typeBoilerplate manifest) . M.toList . manifestTypes $ manifest {- | Definitions and the @struct entry_point@ initialiser for one entry point. The definitions are tables of its input and output types and their uniqueness, the names of its tuning parameters, and a @call_@ wrapper. The wrapper unpacks the untyped output and input pointer arrays and then calls the entry point's C function. -} oneEntryBoilerplate :: Manifest -> (T.Text, EntryPoint) -> ([C.Definition], C.Initializer) oneEntryBoilerplate manifest (name, EntryPoint cfun tuning_params outputs inputs) = let call_f = "call_" <> nameFromText name out_types = map outputType outputs in_types = map inputType inputs out_types_name = nameFromText name <> "_out_types" in_types_name = nameFromText name <> "_in_types" out_unique_name = nameFromText name <> "_out_unique" in_unique_name = nameFromText name <> "_in_unique" tuning_params_name = nameFromText name <> "_tuning_params" {- The (void) casts presumably silence unused-parameter warnings when there are no outputs or inputs. -} (out_items, out_args) | null out_types = ([C.citems|(void)outs;|], mempty) | otherwise = unzip $ zipWith loadOut [0 ..]
out_types (in_items, in_args) | null in_types = ([C.citems|(void)ins;|], mempty) | otherwise = unzip $ zipWith loadIn [0 ..] in_types in ( [C.cunit| const struct type* $id:out_types_name[] = { $inits:(map typeStructInit out_types), NULL }; bool $id:out_unique_name[] = { $inits:(map outputUniqueInit outputs) }; const struct type* $id:in_types_name[] = { $inits:(map typeStructInit in_types), NULL }; bool $id:in_unique_name[] = { $inits:(map inputUniqueInit inputs) }; const char* $id:tuning_params_name[] = { $inits:(map textInit tuning_params), NULL }; int $id:call_f(struct futhark_context *ctx, void **outs, void **ins) { $items:out_items $items:in_items return $id:cfun(ctx, $args:out_args, $args:in_args); } |], [C.cinit|{ .name = $string:(T.unpack name), .f = $id:call_f, .tuning_params = $id:tuning_params_name, .in_types = $id:in_types_name, .out_types = $id:out_types_name, .in_unique = $id:in_unique_name, .out_unique = $id:out_unique_name }|] ) where typeStructInit tname = [C.cinit|&$id:(typeStructName tname)|] inputUniqueInit = uniqueInit . inputUnique outputUniqueInit = uniqueInit .
outputUnique uniqueInit True = [C.cinit|true|] uniqueInit False = [C.cinit|false|] {- Outputs are passed to the C function as the pointers stored in @outs@; inputs are read through the pointers in @ins@ and passed by value. -} loadOut i tname = let v = "out" ++ show (i :: Int) in ( [C.citem|$ty:(cType manifest tname) *$id:v = outs[$int:i];|], [C.cexp|$id:v|] ) loadIn i tname = let v = "in" ++ show (i :: Int) in ( [C.citem|$ty:(cType manifest tname) $id:v = *($ty:(cType manifest tname)*)ins[$int:i];|], [C.cexp|$id:v|] ) textInit t = [C.cinit|$string:(T.unpack t)|] {- | Entry point boilerplate for every entry point in the manifest. -} entryBoilerplate :: Manifest -> ([C.Definition], [C.Initializer]) entryBoilerplate manifest = first concat $ unzip $ map (oneEntryBoilerplate manifest) $ M.toList $ manifestEntryPoints manifest {- | All server boilerplate for a manifest: the C definitions (type declarations first), the type descriptor initialisers (built-in scalar types first), and the entry point initialisers. -} mkBoilerplate :: Manifest -> ([C.Definition], [C.Initializer], [C.Initializer]) mkBoilerplate manifest = let (type_decls, type_inits, type_defs) = entryTypeBoilerplate manifest (entry_defs, entry_inits) = entryBoilerplate manifest scalar_type_inits = map scalarTypeInit scalar_types in (type_decls ++ type_defs ++ entry_defs, scalar_type_inits ++ type_inits, entry_inits) where scalarTypeInit tname = [C.cinit|&$id:(typeStructName tname)|] scalar_types = [ "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "f16", "f32", "f64", "bool" ] {-# NOINLINE serverDefs #-} -- | Generate Futhark server executable code. serverDefs :: [Option] -> Manifest -> T.Text serverDefs options manifest = let option_parser = generateOptionParser "parse_options" $ genericOptions ++ options (boilerplate_defs, type_inits, entry_point_inits) = mkBoilerplate manifest in definitionsText [C.cunit| $esc:("#include ") $esc:("#include ") $esc:("#include ") // If the entry point is NULL, the program will terminate after doing initialisation and such. It is not used for anything else in server mode.
static const char *entry_point = "main"; $esc:(T.unpack valuesH) $esc:(T.unpack serverH) $esc:(T.unpack tuningH) $edecls:boilerplate_defs const struct type* types[] = { $inits:type_inits, NULL }; struct entry_point entry_points[] = { $inits:entry_point_inits, { .name = NULL } }; struct futhark_prog prog = { .types = types, .entry_points = entry_points }; $func:option_parser int main(int argc, char** argv) { fut_progname = argv[0]; struct futhark_context_config *cfg = futhark_context_config_new(); assert(cfg != NULL); int parsed_options = parse_options(cfg, argc, argv); argc -= parsed_options; argv += parsed_options; if (argc != 0) { futhark_panic(1, "Excess non-option: %s\n", argv[0]); } struct futhark_context *ctx = futhark_context_new(cfg); assert (ctx != NULL); futhark_context_set_logging_file(ctx, stdout); char* error = futhark_context_get_error(ctx); if (error != NULL) { futhark_panic(1, "Error during context initialisation:\n%s", error); } if (entry_point != NULL) { run_server(&prog, cfg, ctx); } futhark_context_free(ctx); futhark_context_config_free(cfg); } |] futhark-0.25.27/src/Futhark/CodeGen/Backends/GenericC/Types.hs000066400000000000000000001070711475065116200237000ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-} -- | Code generation for public API types.
module Futhark.CodeGen.Backends.GenericC.Types ( generateAPITypes, valueTypeToCType, opaqueToCType, ) where import Control.Monad import Control.Monad.Reader (asks) import Control.Monad.State (gets, modify) import Data.List qualified as L import Data.Map.Strict qualified as M import Data.Maybe import Data.Text qualified as T import Futhark.CodeGen.Backends.GenericC.Monad import Futhark.CodeGen.Backends.GenericC.Pretty import Futhark.CodeGen.ImpCode import Futhark.Manifest qualified as Manifest import Futhark.Util (chunks, mapAccumLM, zEncodeText) import Language.C.Quote.OpenCL qualified as C import Language.C.Syntax qualified as C {- | The public C struct type of the named opaque type. -} opaqueToCType :: Name -> CompilerM op s C.Type opaqueToCType desc = do name <- publicName $ opaqueName desc pure [C.cty|struct $id:name|] {- | The C type of a value type: the primitive API type for scalars, and the public array struct for arrays. Array types are also recorded in 'compArrayTypes', with publicness combined using 'max', presumably so that their API functions get generated later. -} valueTypeToCType :: Publicness -> ValueType -> CompilerM op s C.Type valueTypeToCType _ (ValueType signed (Rank 0) pt) = pure $ primAPIType signed pt valueTypeToCType pub (ValueType signed (Rank rank) pt) = do name <- publicName $ arrayName pt signed rank let add = M.insertWith max (signed, pt, rank) pub modify $ \s -> s {compArrayTypes = add $ compArrayTypes s} pure [C.cty|struct $id:name|] {- | Allocate memory for a new array and store its shape. On allocation failure the generated code sets @err = 1@, so a C variable @err@ must be in scope. -} prepareNewMem :: (C.ToExp arr, C.ToExp dim) => arr -> Space -> [dim] -> PrimType -> CompilerM op s () prepareNewMem arr space shape pt = do let rank = length shape arr_size = cproduct [[C.cexp|$exp:k|] | k <- shape] resetMem [C.cexp|$exp:arr->mem|] space allocMem [C.cexp|$exp:arr->mem|] [C.cexp|$exp:arr_size * $int:(primByteSize pt::Int)|] space [C.cstm|err = 1;|] forM_ (zip [0 ..
rank - 1] shape) $ \(i, dim_s) -> stm [C.cstm|$exp:arr->shape[$int:i] = $exp:dim_s;|] {- | Generate the C API for one array type: @new@, @new_raw@, @free@, @values@, @values_raw@, @shape@ and @index@ functions. Prototypes go in the public header for public array types and among the library definitions otherwise. Returns the names of the generated functions. -} arrayLibraryFunctions :: Publicness -> Space -> PrimType -> Signedness -> Int -> CompilerM op s Manifest.ArrayOps arrayLibraryFunctions pub space pt signed rank = do let pt' = primAPIType signed pt name = arrayName pt signed rank arr_name = "futhark_" <> name array_type = [C.cty|struct $id:arr_name|] new_array <- publicName $ "new_" <> name new_raw_array <- publicName $ "new_raw_" <> name free_array <- publicName $ "free_" <> name values_array <- publicName $ "values_" <> name values_raw_array <- publicName $ "values_raw_" <> name shape_array <- publicName $ "shape_" <> name index_array <- publicName $ "index_" <> name let shape_names = ["dim" <> prettyText i | i <- [0 .. rank - 1]] shape_params = [[C.cparam|typename int64_t $id:k|] | k <- shape_names] shape = [[C.cexp|$id:k|] | k <- shape_names] index_names = ["i" <> prettyText i | i <- [0 .. rank - 1]] index_params = [[C.cparam|typename int64_t $id:k|] | k <- index_names] arr_size = cproduct shape arr_size_array = cproduct [[C.cexp|arr->shape[$int:i]|] | i <- [0 .. rank - 1]] copy <- asks $ opsCopy . envOperations memty <- rawMemCType space new_body <- collect $ do prepareNewMem [C.cexp|arr|] space shape pt copy CopyNoBarrier [C.cexp|arr->mem.mem|] [C.cexp|0|] space [C.cexp|(const unsigned char*)data|] [C.cexp|0|] DefaultSpace [C.cexp|((size_t)$exp:arr_size) * $int:(primByteSize pt::Int)|] new_raw_body <- collect $ do resetMem [C.cexp|arr->mem|] space stm [C.cstm|arr->mem.mem = data;|] forM_ [0 .. rank - 1] $ \i -> let dim_s = "dim" ++ show i in stm [C.cstm|arr->shape[$int:i] = $id:dim_s;|] free_body <- collect $ unRefMem [C.cexp|arr->mem|] space values_body <- collect $ copy CopyNoBarrier [C.cexp|(unsigned char*)data|] [C.cexp|0|] DefaultSpace [C.cexp|arr->mem.mem|] [C.cexp|0|] space [C.cexp|((size_t)$exp:arr_size_array) * $int:(primByteSize pt::Int)|] {- Strides are counted in elements and are row-major: the stride of dimension r is the product of all later dimensions. -} let arr_strides = do r <- [0 ..
rank - 1] pure $ cproduct [[C.cexp|arr->shape[$int:i]|] | i <- [r + 1 .. rank - 1]] {- Byte offset of the indexed element: the element size times the sum of each index multiplied by its stride. -} index_exp = cproduct [ [C.cexp|$int:(primByteSize pt::Int)|], csum (zipWith (\x y -> [C.cexp|$id:x * $exp:y|]) index_names arr_strides) ] in_bounds = allTrue [ [C.cexp|$id:p >= 0 && $id:p < arr->shape[$int:i]|] | (p, i) <- zip index_names [0 .. rank - 1] ] index_body <- collect $ copy CopyNoBarrier [C.cexp|(unsigned char*)out|] [C.cexp|0|] DefaultSpace [C.cexp|arr->mem.mem|] index_exp space [C.cexp|$int:(primByteSize pt::Int)|] ctx_ty <- contextType ops <- asks envOperations let proto = case pub of Public -> headerDecl (ArrayDecl (nameFromText name)) Private -> libDecl proto [C.cedecl|struct $id:arr_name;|] proto [C.cedecl|$ty:array_type* $id:new_array($ty:ctx_ty *ctx, const $ty:pt' *data, $params:shape_params);|] proto [C.cedecl|$ty:array_type* $id:new_raw_array($ty:ctx_ty *ctx, $ty:memty data, $params:shape_params);|] proto [C.cedecl|int $id:free_array($ty:ctx_ty *ctx, $ty:array_type *arr);|] proto [C.cedecl|int $id:values_array($ty:ctx_ty *ctx, $ty:array_type *arr, $ty:pt' *data);|] proto [C.cedecl|int $id:index_array($ty:ctx_ty *ctx, $ty:pt' *out, $ty:array_type *arr, $params:index_params);|] proto [C.cedecl|$ty:memty $id:values_raw_array($ty:ctx_ty *ctx, $ty:array_type *arr);|] proto [C.cedecl|const typename int64_t* $id:shape_array($ty:ctx_ty *ctx, $ty:array_type *arr);|] mapM_ libDecl [C.cunit| $ty:array_type* $id:new_array($ty:ctx_ty *ctx, const $ty:pt' *data, $params:shape_params) { int err = 0; $ty:array_type* bad = NULL; $ty:array_type *arr = ($ty:array_type*) malloc(sizeof($ty:array_type)); if (arr == NULL) { return bad; } $items:(criticalSection ops new_body) if (err != 0) { free(arr); return bad; } return arr; } $ty:array_type* $id:new_raw_array($ty:ctx_ty *ctx, $ty:memty data, $params:shape_params) { int err = 0; $ty:array_type* bad = NULL; $ty:array_type *arr = ($ty:array_type*) malloc(sizeof($ty:array_type)); if (arr == NULL) { return bad; } $items:(criticalSection
ops new_raw_body) return arr; } int $id:free_array($ty:ctx_ty *ctx, $ty:array_type *arr) { $items:(criticalSection ops free_body) free(arr); return 0; } int $id:values_array($ty:ctx_ty *ctx, $ty:array_type *arr, $ty:pt' *data) { int err = 0; $items:(criticalSection ops values_body) return err; } int $id:index_array($ty:ctx_ty *ctx, $ty:pt' *out, $ty:array_type *arr, $params:index_params) { int err = 0; if ($exp:in_bounds) { $items:(criticalSection ops index_body) } else { err = 1; set_error(ctx, strdup("Index out of bounds.")); } return err; } $ty:memty $id:values_raw_array($ty:ctx_ty *ctx, $ty:array_type *arr) { (void)ctx; return arr->mem.mem; } const typename int64_t* $id:shape_array($ty:ctx_ty *ctx, $ty:array_type *arr) { (void)ctx; return arr->shape; } |] pure $ Manifest.ArrayOps { Manifest.arrayFree = free_array, Manifest.arrayShape = shape_array, Manifest.arrayValues = values_array, Manifest.arrayNew = new_array, Manifest.arrayNewRaw = new_raw_array, Manifest.arrayValuesRaw = values_raw_array, Manifest.arrayIndex = index_array } {- | Find the definition of a named opaque type; calls 'error' if the name is unknown. -} lookupOpaqueType :: Name -> OpaqueTypes -> OpaqueType lookupOpaqueType v (OpaqueTypes types) = case lookup v types of Just t -> t Nothing -> error $ "Unknown opaque type: " ++ show v {- | The flat list of value types making up an opaque type's representation. Record fields of opaque type are expanded recursively. -} opaquePayload :: OpaqueTypes -> OpaqueType -> [ValueType] opaquePayload _ (OpaqueType ts) = ts opaquePayload _ (OpaqueSum ts _) = ts opaquePayload _ (OpaqueArray _ _ ts) = ts opaquePayload types (OpaqueRecord fs) = concatMap f fs where f (_, TypeOpaque s) = opaquePayload types $ lookupOpaqueType s types f (_, TypeTransparent v) = [v] opaquePayload types (OpaqueRecordArray _ _ fs) = concatMap f fs where f (_, TypeOpaque s) = opaquePayload types $ lookupOpaqueType s types f (_, TypeTransparent v) = [v] {- | The C type of an entry point type. -} entryPointTypeToCType :: Publicness -> EntryPointType -> CompilerM op s C.Type entryPointTypeToCType _ (TypeOpaque desc) = opaqueToCType desc entryPointTypeToCType pub (TypeTransparent vt) = valueTypeToCType pub vt {- | The name of an entry point type as used in the manifest. -} entryTypeName :: EntryPointType ->
Manifest.TypeName entryTypeName (TypeOpaque desc) = nameToText desc entryTypeName (TypeTransparent vt) = prettyText vt -- | Figure out which of the members of an opaque type correspond to -- which fields: one chunk of members per field, in field order. recordFieldPayloads :: OpaqueTypes -> [EntryPointType] -> [a] -> [[a]] recordFieldPayloads types = chunks . map typeLength where typeLength (TypeTransparent _) = 1 typeLength (TypeOpaque desc) = length $ opaquePayload types $ lookupOpaqueType desc types {- | Generate the body of a field projection, which reads from @obj@ into a C variable @v@; also returns the C type of @v@. Scalars are assigned directly. Arrays, and the array members of opaque values, are copied into fresh allocations with their memory reference count incremented, inside a critical section. -} projectField :: Operations op s -> EntryPointType -> [(Int, ValueType)] -> CompilerM op s (C.Type, [C.BlockItem]) projectField _ (TypeTransparent (ValueType sign (Rank 0) pt)) [(i, _)] = do pure ( primAPIType sign pt, [C.citems|v = obj->$id:(tupleField i);|] ) projectField ops (TypeTransparent vt) [(i, _)] = do ct <- valueTypeToCType Public vt pure ( [C.cty|$ty:ct *|], criticalSection ops [C.citems|v = malloc(sizeof($ty:ct)); memcpy(v, obj->$id:(tupleField i), sizeof($ty:ct)); (void)(*(v->mem.references))++;|] ) projectField _ (TypeTransparent _) rep = error $ "projectField: invalid representation of transparent type: " ++ show rep projectField ops (TypeOpaque f_desc) components = do ct <- opaqueToCType f_desc let setField j (i, ValueType _ (Rank r) _) = if r == 0 then [C.citems|v->$id:(tupleField j) = obj->$id:(tupleField i);|] else [C.citems|v->$id:(tupleField j) = malloc(sizeof(*v->$id:(tupleField j))); *v->$id:(tupleField j) = *obj->$id:(tupleField i); (void)(*(v->$id:(tupleField j)->mem.references))++;|] pure ( [C.cty|$ty:ct *|], criticalSection ops [C.citems|v = malloc(sizeof($ty:ct)); $items:(concat (zipWith setField [0..]
components))|] ) {- | Generate one public @project_@ function per record field, declared in the opaque type's header section, and return the manifest description of each field. A field name is z-encoded when using it unencoded would not give a valid C name. -} recordProjectFunctions :: OpaqueTypes -> Name -> [(Name, EntryPointType)] -> [ValueType] -> CompilerM op s [Manifest.RecordField] recordProjectFunctions types desc fs vds = do opaque_type <- opaqueToCType desc ctx_ty <- contextType ops <- asks envOperations let onField ((f, et), elems) = do let f' = if isValidCName $ opaqueName desc <> "_" <> nameToText f then nameToText f else zEncodeText (nameToText f) project <- publicName $ "project_" <> opaqueName desc <> "_" <> f' (et_ty, project_items) <- projectField ops et elems headerDecl (OpaqueDecl desc) [C.cedecl|int $id:project($ty:ctx_ty *ctx, $ty:et_ty *out, const $ty:opaque_type *obj);|] libDecl [C.cedecl|int $id:project($ty:ctx_ty *ctx, $ty:et_ty *out, const $ty:opaque_type *obj) { (void)ctx; $ty:et_ty v; $items:project_items *out = v; return 0; }|] pure $ Manifest.RecordField (nameToText f) (entryTypeName et) project mapM onField . zip fs . recordFieldPayloads types (map snd fs) $ zip [0 ..] vds {- | Statement storing @e@ into member @i@ of the record @v@ under construction. Scalars are assigned directly; arrays, which @e@ points to, are copied into a fresh allocation and have their memory reference count incremented. -} setFieldField :: (C.ToExp a) => Int -> a -> ValueType -> C.Stm setFieldField i e (ValueType _ (Rank r) _) | r == 0 = [C.cstm|v->$id:(tupleField i) = $exp:e;|] | otherwise = [C.cstm|{v->$id:(tupleField i) = malloc(sizeof(*$exp:e)); *v->$id:(tupleField i) = *$exp:e; (void)(*(v->$id:(tupleField i)->mem.references))++;}|] {- | Constructor parameters, each named @f_@ followed by its field name, and member-initialising items for a record type. A field of opaque type contributes all of its members, so the running member offset advances by the length of its payload. -} recordNewSetFields :: OpaqueTypes -> [(Name, EntryPointType)] -> [ValueType] -> CompilerM op s ([C.Id], [C.Param], [C.BlockItem]) recordNewSetFields types fs = fmap (L.unzip3 . snd) . mapAccumLM onField 0 . zip fs .
recordFieldPayloads types (map snd fs) where onField offset ((f, et), f_vts) = do let param_name = C.toIdent ("f_" <> f) mempty case et of TypeTransparent (ValueType sign (Rank 0) pt) -> do let ct = primAPIType sign pt pure ( offset + 1, ( param_name, [C.cparam|const $ty:ct $id:param_name|], [C.citem|v->$id:(tupleField offset) = $id:param_name;|] ) ) TypeTransparent vt -> do ct <- valueTypeToCType Public vt pure ( offset + 1, ( param_name, [C.cparam|const $ty:ct* $id:param_name|], [C.citem|{v->$id:(tupleField offset) = malloc(sizeof($ty:ct)); *v->$id:(tupleField offset) = *$id:param_name; (void)(*(v->$id:(tupleField offset)->mem.references))++;}|] ) ) TypeOpaque f_desc -> do ct <- opaqueToCType f_desc let param_fields = do i <- [0 ..] pure [C.cexp|$id:param_name->$id:(tupleField i)|] {- param_fields is infinite; zipWith3 below truncates it to the length of this field's payload. -} pure ( offset + length f_vts, ( param_name, [C.cparam|const $ty:ct* $id:param_name|], [C.citem|{$stms:(zipWith3 setFieldField [offset ..] param_fields f_vts)}|] ) ) {- | Generate the public @new_@ constructor of a record opaque type, declared in the type's header section, and return its name. NOTE(review): the generated code does not check the result of malloc. -} recordNewFunctions :: OpaqueTypes -> Name -> [(Name, EntryPointType)] -> [ValueType] -> CompilerM op s Manifest.CFuncName recordNewFunctions types desc fs vds = do opaque_type <- opaqueToCType desc ctx_ty <- contextType ops <- asks envOperations new <- publicName $ "new_" <> opaqueName desc (_, params, new_stms) <- recordNewSetFields types fs vds headerDecl (OpaqueDecl desc) [C.cedecl|int $id:new($ty:ctx_ty *ctx, $ty:opaque_type** out, $params:params);|] libDecl [C.cedecl|int $id:new($ty:ctx_ty *ctx, $ty:opaque_type** out, $params:params) { $ty:opaque_type* v = malloc(sizeof($ty:opaque_type)); $items:(criticalSection ops new_stms) *out = v; return FUTHARK_SUCCESS; }|] pure new -- Because records and arrays-of-records are very similar in their -- actual representation, we can reuse most of the code. Only indexing -- requires something special.
-- | Projection functions for an array of records.  The component
-- layout is the same as for a single record, so this is exactly
-- 'recordProjectFunctions'.
recordArrayProjectFunctions ::
  OpaqueTypes ->
  Name ->
  [(Name, EntryPointType)] ->
  [ValueType] ->
  CompilerM op s [Manifest.RecordField]
recordArrayProjectFunctions = recordProjectFunctions

-- | Generate the public @zip_<type>@ C function, which builds an array
-- of records from one argument per field.  The generated function
-- fails (returning 1 and setting an error) unless the outer @rank@
-- dimensions of all arguments agree; the field-setting code is shared
-- with record construction via 'recordNewSetFields'.
recordArrayZipFunctions ::
  OpaqueTypes ->
  Name ->
  [(Name, EntryPointType)] ->
  [ValueType] ->
  Int ->
  CompilerM op s Manifest.CFuncName
recordArrayZipFunctions types desc fs vds rank = do
  opaque_type <- opaqueToCType desc
  ctx_ty <- contextType
  ops <- asks envOperations
  new <- publicName $ "zip_" <> opaqueName desc

  (param_names, params, new_stms) <- recordNewSetFields types fs vds

  headerDecl
    (OpaqueDecl desc)
    [C.cedecl|int $id:new($ty:ctx_ty *ctx, $ty:opaque_type** out, $params:params);|]
  libDecl
    [C.cedecl|int $id:new($ty:ctx_ty *ctx, $ty:opaque_type** out, $params:params) {
                if (!$exp:(sameShape param_names)) {
                  set_error(ctx, strdup("Cannot zip arrays with different shapes."));
                  return 1;
                }
                $ty:opaque_type* v = malloc(sizeof($ty:opaque_type));
                $items:(criticalSection ops new_stms)
                *out = v;
                return FUTHARK_SUCCESS;
              }|]
  pure new
  where
    -- The outer @rank@ dimensions of a zip argument, as C expressions.
    valueShape TypeTransparent {} p =
      [[C.cexp|$id:p->shape[$int:i]|] | i <- [0 .. rank - 1]]
    -- We know that the opaque value must contain arrays.
    valueShape TypeOpaque {} p =
      [[C.cexp|$id:p->$id:(tupleField 0)->shape[$int:i]|] | i <- [0 .. rank - 1]]
    -- True iff, for every outer dimension, all arguments agree on its size.
    sameShape param_names =
      allTrue $ map allEqual $ L.transpose $ zipWith valueShape (map snd fs) param_names

-- | Generate the public @index_<type>@ C function, which copies the
-- element at a given @rank@-dimensional index of an opaque array into
-- a freshly allocated value of the element type @elemtype@.  The index
-- is bounds-checked against the shape of the first component.
--
-- NOTE(review): in the generated C, if @err@ is nonzero after the
-- component copies, @v@ (and any component arrays already allocated
-- for it) is never freed; confirm whether this leak is acceptable.
recordArrayIndexFunctions ::
  Space ->
  OpaqueTypes ->
  Name ->
  Int ->
  Name ->
  [ValueType] ->
  CompilerM op s Manifest.CFuncName
recordArrayIndexFunctions space _types desc rank elemtype vds = do
  index_f <- publicName $ "index_" <> opaqueName desc
  ctx_ty <- contextType
  array_ct <- opaqueToCType desc
  obj_ct <- opaqueToCType elemtype
  copy <- asks $ opsCopy . envOperations

  index_items <- collect $ zipWithM_ (setField copy) [0 ..] vds

  headerDecl
    (OpaqueDecl desc)
    [C.cedecl|int $id:index_f($ty:ctx_ty *ctx, $ty:obj_ct **out, $ty:array_ct *arr, $params:index_params);|]
  libDecl
    [C.cedecl|int $id:index_f($ty:ctx_ty *ctx, $ty:obj_ct **out, $ty:array_ct *arr, $params:index_params) {
      int err = 0;
      if ($exp:in_bounds) {
        $ty:obj_ct* v = malloc(sizeof($ty:obj_ct));
        $items:index_items
        if (err == 0) {
          *out = v;
        }
      } else {
        err = 1;
        set_error(ctx, strdup("Index out of bounds."));
      }
      return err;
    }|]
  pure index_f
  where
    -- Names of the C index parameters: i0, i1, ...
    index_names = ["i" <> prettyText i | i <- [0 .. rank - 1]]
    index_params = [[C.cparam|typename int64_t $id:k|] | k <- index_names]
    -- Byte offset of the indexed element in a row-major array of rank
    -- @r@ with the given shape.  Only the first @rank@ dimensions are
    -- indexed; dimensions beyond @rank@ contribute to the strides.
    indexExp pt r shape =
      cproduct
        [ [C.cexp|$int:(primByteSize pt::Int)|],
          csum (zipWith (\x y -> [C.cexp|$id:x * $exp:y|]) index_names strides)
        ]
      where
        strides = do
          d <- [0 .. r - 1]
          pure $ cproduct [[C.cexp|$exp:shape[$int:i]|] | i <- [d + 1 .. r - 1]]
    in_bounds =
      allTrue
        [ [C.cexp|$id:p >= 0 && $id:p < arr->$id:(tupleField 0)->shape[$int:i]|]
        | (p, i) <- zip index_names [0 .. rank - 1]
        ]
    -- Copy component @j@ of the array, at the requested index, into the
    -- corresponding field of @v@.
    setField copy j (ValueType _ (Rank r) pt)
      | r == rank =
          -- Easy case: just copy the scalar from the array into the
          -- variable.
          copy
            CopyNoBarrier
            [C.cexp|(unsigned char*)&v->$id:(tupleField j)|]
            [C.cexp|0|]
            DefaultSpace
            [C.cexp|arr->$id:(tupleField j)->mem.mem|]
            (indexExp pt rank [C.cexp|arr->$id:(tupleField j)->shape|])
            space
            [C.cexp|$int:(primByteSize pt::Int)|]
      | otherwise = do
          -- Tricky case, where we first have to allocate memory.
          let shape = do
                i <- [rank .. r - 1]
                pure [C.cexp|arr->$id:(tupleField j)->shape[$int:i]|]
          stm [C.cstm|v->$id:(tupleField j) = malloc(sizeof(*v->$id:(tupleField j)));|]
          prepareNewMem [C.cexp|v->$id:(tupleField j)|] space shape pt
          -- Now we can copy into the freshly allocated memory.
          copy
            CopyNoBarrier
            [C.cexp|v->$id:(tupleField j)->mem.mem|]
            [C.cexp|0|]
            space
            [C.cexp|arr->$id:(tupleField j)->mem.mem|]
            (indexExp pt r [C.cexp|arr->$id:(tupleField j)->shape|])
            space
            $ cproduct ([C.cexp|$int:(primByteSize pt::Int)|] : shape)

-- | Generate the public @shape_<type>@ C function for an array of
-- records (also used for opaque arrays).
recordArrayShapeFunctions :: Name -> CompilerM op s Manifest.CFuncName
recordArrayShapeFunctions desc = do
  shape_f <- publicName $ "shape_" <> opaqueName desc
  ctx_ty <- contextType
  array_ct <- opaqueToCType desc

  -- We know that the opaque value consists of arrays of at least the
  -- expected rank, and which have the same outer shape, so we just
  -- return the shape of the first one.
  headerDecl
    (OpaqueDecl desc)
    [C.cedecl|const typename int64_t* $id:shape_f($ty:ctx_ty *ctx, $ty:array_ct *arr);|]
  libDecl
    [C.cedecl|const typename int64_t* $id:shape_f($ty:ctx_ty *ctx, $ty:array_ct *arr) {
      (void)ctx;
      return arr->$id:(tupleField 0)->shape;
    }|]
  pure shape_f

-- | Indexing an opaque array uses the same code as indexing an array
-- of records.
opaqueArrayIndexFunctions ::
  Space ->
  OpaqueTypes ->
  Name ->
  Int ->
  Name ->
  [ValueType] ->
  CompilerM op s Manifest.CFuncName
opaqueArrayIndexFunctions = recordArrayIndexFunctions

-- | The shape of an opaque array is obtained the same way as for an
-- array of records.
opaqueArrayShapeFunctions :: Name -> CompilerM op s Manifest.CFuncName
opaqueArrayShapeFunctions = recordArrayShapeFunctions

-- | Generate, for every variant of a sum type, a constructor
-- (@new_<type>_<variant>@) and a destructor
-- (@destruct_<type>_<variant>@).  Each payload element is an entry
-- point type paired with the indices of the opaque value's components
-- that hold it; component 0 holds the variant tag.
sumVariants ::
  Name ->
  [(Name, [(EntryPointType, [Int])])] ->
  [ValueType] ->
  CompilerM op s [Manifest.SumVariant]
sumVariants desc variants vds = do
  opaque_ty <- opaqueToCType desc
  ctx_ty <- contextType
  ops <- asks envOperations

  let onVariant i (name, payload) = do
        construct <- publicName $ "new_" <> opaqueName desc <> "_" <> nameToText name
        destruct <- publicName $ "destruct_" <> opaqueName desc <> "_" <> nameToText name

        constructFunction ops ctx_ty opaque_ty i construct payload
        destructFunction ops ctx_ty opaque_ty i destruct payload

        pure $
          Manifest.SumVariant
            { Manifest.sumVariantName = nameToText name,
              Manifest.sumVariantPayload = map (entryTypeName . fst) payload,
              Manifest.sumVariantConstruct = construct,
              Manifest.sumVariantDestruct = destruct
            }

  zipWithM onVariant [0 :: Int ..] variants
  where
    -- Emit the constructor for variant number @i@: it stores the tag,
    -- copies the payload into its components, and initialises every
    -- component the variant does not use.
    constructFunction ops ctx_ty opaque_ty i fname payload = do
      (params, new_stms) <- unzip <$> zipWithM constructPayload [0 ..] payload

      let used = concatMap snd payload
      set_unused_stms <-
        mapM setUnused $ filter ((`notElem` used) . fst) (zip [0 ..] vds)

      headerDecl
        (OpaqueDecl desc)
        [C.cedecl|int $id:fname($ty:ctx_ty *ctx, $ty:opaque_ty **out, $params:params);|]
      libDecl
        [C.cedecl|int $id:fname($ty:ctx_ty *ctx, $ty:opaque_ty **out, $params:params) {
          (void)ctx;
          $ty:opaque_ty* v = malloc(sizeof($ty:opaque_ty));
          v->$id:(tupleField 0) = $int:i;
          { $items:(criticalSection ops new_stms) }
          // Set other fields
          { $items:set_unused_stms }
          *out = v;
          return FUTHARK_SUCCESS;
        }|]

    -- We must initialise some of the fields that are unused in this
    -- variant; specifically the ones corresponding to arrays.  This
    -- has the unfortunate effect that all arrays in the nonused
    -- constructor are set to have size 0.
    setUnused (_, ValueType _ (Rank 0) _) =
      pure [C.citem|{}|]
    setUnused (i, ValueType signed (Rank rank) pt) = do
      new_array <- publicName $ "new_" <> arrayName pt signed rank
      let dims = replicate rank [C.cexp|0|]
      pure [C.citem|v->$id:(tupleField i) = $id:new_array(ctx, NULL, $args:dims);|]

    -- Parameter declaration and initialisation code for payload
    -- element @j@ of a constructor.
    constructPayload j (et, is) = do
      let param_name = "v" <> show (j :: Int)
      case et of
        TypeTransparent (ValueType sign (Rank 0) pt) -> do
          let ct = primAPIType sign pt
              i = head is
          pure
            ( [C.cparam|const $ty:ct $id:param_name|],
              [C.citem|v->$id:(tupleField i) = $id:param_name;|]
            )
        TypeTransparent vt -> do
          ct <- valueTypeToCType Public vt
          let i = head is
          -- The array struct is copied and its reference count bumped,
          -- so the sum value shares the caller's memory.
          pure
            ( [C.cparam|const $ty:ct* $id:param_name|],
              [C.citem|{v->$id:(tupleField i) = malloc(sizeof($ty:ct));
                        memcpy(v->$id:(tupleField i), $id:param_name, sizeof(const $ty:ct));
                        (void)(*(v->$id:(tupleField i)->mem.references))++;}|]
            )
        TypeOpaque f_desc -> do
          ct <- opaqueToCType f_desc
          let param_fields = do
                i <- [0 ..]
                pure [C.cexp|$id:param_name->$id:(tupleField i)|]
              vts = map (vds !!) is
          pure
            ( [C.cparam|const $ty:ct* $id:param_name|],
              [C.citem|{$stms:(zipWith3 setFieldField is param_fields vts)}|]
            )

    -- Emit the destructor for variant number @i@; the generated code
    -- asserts that the value really has that tag.
    destructFunction ops ctx_ty opaque_ty i fname payload = do
      (params, destruct_stms) <- unzip <$> zipWithM (destructPayload ops) [0 ..] payload
      headerDecl
        (OpaqueDecl desc)
        [C.cedecl|int $id:fname($ty:ctx_ty *ctx, $params:params, const $ty:opaque_ty *obj);|]
      libDecl
        [C.cedecl|int $id:fname($ty:ctx_ty *ctx, $params:params, const $ty:opaque_ty *obj) {
          (void)ctx;
          assert(obj->$id:(tupleField 0) == $int:i);
          $stms:destruct_stms
          return FUTHARK_SUCCESS;
        }|]

    -- Output parameter and extraction code for payload element @j@,
    -- reusing the record projection code.
    destructPayload ops j (et, is) = do
      let param_name = "v" <> show (j :: Int)
      (ct, project_items) <- projectField ops et $ zip is $ map (vds !!) is
      pure
        ( [C.cparam|$ty:ct* $id:param_name|],
          [C.cstm|{$ty:ct v;
                   $items:project_items
                   *$id:param_name = v;
                  }|]
        )

-- | Generate the public @variant_<type>@ C function, which returns the
-- tag of a sum value.
sumVariantFunction :: Name -> CompilerM op s Manifest.CFuncName
sumVariantFunction desc = do
  opaque_ty <- opaqueToCType desc
  ctx_ty <- contextType
  variant <- publicName $ "variant_" <> opaqueName desc
  headerDecl
    (OpaqueDecl desc)
    [C.cedecl|int $id:variant($ty:ctx_ty *ctx, const $ty:opaque_ty* v);|]
  -- This depends on the assumption that the first value always
  -- encodes the variant.
  libDecl
    [C.cedecl|int $id:variant($ty:ctx_ty *ctx, const $ty:opaque_ty* v) {
      (void)ctx;
      return v->$id:(tupleField 0);
    }|]
  pure variant

-- | Generate the operations specific to a kind of opaque type (beyond
-- free/store/restore).  Plain 'OpaqueType's have none.
opaqueExtraOps ::
  Space ->
  OpaqueTypes ->
  Name ->
  OpaqueType ->
  [ValueType] ->
  CompilerM op s (Maybe Manifest.OpaqueExtraOps)
opaqueExtraOps _ _ _ (OpaqueType _) _ =
  pure Nothing
opaqueExtraOps _ _types desc (OpaqueSum _ cs) vds =
  Just . Manifest.OpaqueSum
    <$> ( Manifest.SumOps
            <$> sumVariants desc cs vds
            <*> sumVariantFunction desc
        )
opaqueExtraOps _ types desc (OpaqueRecord fs) vds =
  Just . Manifest.OpaqueRecord
    <$> ( Manifest.RecordOps
            <$> recordProjectFunctions types desc fs vds
            <*> recordNewFunctions types desc fs vds
        )
opaqueExtraOps space types desc (OpaqueRecordArray rank elemtype fs) vds =
  Just . Manifest.OpaqueRecordArray
    <$> ( Manifest.RecordArrayOps rank (nameToText elemtype)
            <$> recordArrayProjectFunctions types desc fs vds
            <*> recordArrayZipFunctions types desc fs vds rank
            <*> recordArrayIndexFunctions space types desc rank elemtype vds
            <*> recordArrayShapeFunctions desc
        )
opaqueExtraOps space types desc (OpaqueArray rank elemtype _) vds =
  Just . Manifest.OpaqueArray
    <$> ( Manifest.OpaqueArrayOps rank (nameToText elemtype)
            <$> opaqueArrayIndexFunctions space types desc rank elemtype vds
            <*> opaqueArrayShapeFunctions desc
        )

-- | Generate the @free_@, @store_@ and @restore_@ C functions for an
-- opaque type, plus whatever extra operations 'opaqueExtraOps'
-- provides for it.  The stored form is the concatenation of the
-- components, each preceded by the header written by
-- 'storeValueHeader'.
opaqueLibraryFunctions ::
  Space ->
  OpaqueTypes ->
  Name ->
  OpaqueType ->
  CompilerM op s (Manifest.OpaqueOps, Maybe Manifest.OpaqueExtraOps)
opaqueLibraryFunctions space types desc ot = do
  name <- publicName $ opaqueName desc
  free_opaque <- publicName $ "free_" <> opaqueName desc
  store_opaque <- publicName $ "store_" <> opaqueName desc
  restore_opaque <- publicName $ "restore_" <> opaqueName desc

  let opaque_type = [C.cty|struct $id:name|]

      -- Free one array component (scalars need nothing).
      freeComponent i (ValueType signed (Rank rank) pt) = unless (rank == 0) $ do
        let field = tupleField i
        free_array <- publicName $ "free_" <> arrayName pt signed rank
        -- Protect against NULL here, because we also want to use this
        -- to free partially loaded opaques.
        stm
          [C.cstm|if (obj->$id:field != NULL && (tmp = $id:free_array(ctx, obj->$id:field)) != 0) {
                ret = tmp;
             }|]

      -- Size in bytes and serialisation code for one component.
      -- NOTE(review): the array helpers below hard-code the "futhark_"
      -- prefix, whereas freeComponent goes through 'publicName'; confirm
      -- they stay in agreement.
      storeComponent i (ValueType sign (Rank 0) pt) =
        let field = tupleField i
         in ( storageSize pt 0 [C.cexp|NULL|],
              storeValueHeader sign pt 0 [C.cexp|NULL|] [C.cexp|out|]
                ++ [C.cstms|memcpy(out, &obj->$id:field, sizeof(obj->$id:field));
                            out += sizeof(obj->$id:field);|]
            )
      storeComponent i (ValueType sign (Rank rank) pt) =
        let arr_name = arrayName pt sign rank
            field = tupleField i
            shape_array = "futhark_shape_" <> arr_name
            values_array = "futhark_values_" <> arr_name
            shape' = [C.cexp|$id:shape_array(ctx, obj->$id:field)|]
            num_elems = cproduct [[C.cexp|$exp:shape'[$int:j]|] | j <- [0 .. rank - 1]]
         in ( storageSize pt rank shape',
              storeValueHeader sign pt rank shape' [C.cexp|out|]
                ++ [C.cstms|ret |= $id:values_array(ctx, obj->$id:field, (void*)out);
                            out += $exp:num_elems * sizeof($ty:(primStorageType pt));|]
            )

  ctx_ty <- contextType

  let vds = opaquePayload types ot
  free_body <- collect $ zipWithM_ freeComponent [0 ..] vds

  -- Compute the total size into *n; if p is non-NULL, allocate *p when
  -- it is NULL and then serialise every component into it.
  store_body <- collect $ do
    let (sizes, stores) = unzip $ zipWith storeComponent [0 ..] vds
        size_vars = map (("size_" ++) . show) [0 .. length sizes - 1]
        size_sum = csum [[C.cexp|$id:size|] | size <- size_vars]
    forM_ (zip size_vars sizes) $ \(v, e) ->
      item [C.citem|typename int64_t $id:v = $exp:e;|]
    stm [C.cstm|*n = $exp:size_sum;|]
    stm [C.cstm|if (p != NULL && *p == NULL) { *p = malloc(*n); }|]
    stm [C.cstm|if (p != NULL) { unsigned char *out = *p; $stms:(concat stores) }|]

  -- Each call reads one component's header and advances src past its
  -- data, returning the statements that actually fill in the field.
  -- Array fields are set to NULL up front so that a failed load can be
  -- cleaned up with the free code.
  let restoreComponent i (ValueType sign (Rank 0) pt) = do
        let field = tupleField i
            dataptr = "data_" ++ show i
        stms $ loadValueHeader sign pt 0 [C.cexp|NULL|] [C.cexp|src|]
        item [C.citem|const void* $id:dataptr = src;|]
        stm [C.cstm|src += sizeof(obj->$id:field);|]
        pure [C.cstms|memcpy(&obj->$id:field, $id:dataptr, sizeof(obj->$id:field));|]
      restoreComponent i (ValueType sign (Rank rank) pt) = do
        let field = tupleField i
            arr_name = arrayName pt sign rank
            new_array = "futhark_new_" <> arr_name
            dataptr = "data_" <> prettyText i
            shapearr = "shape_" <> prettyText i
            dims = [[C.cexp|$id:shapearr[$int:j]|] | j <- [0 .. rank - 1]]
            num_elems = cproduct dims
        item [C.citem|typename int64_t $id:shapearr[$int:rank] = {0};|]
        stms $ loadValueHeader sign pt rank [C.cexp|$id:shapearr|] [C.cexp|src|]
        item [C.citem|const void* $id:dataptr = src;|]
        stm [C.cstm|obj->$id:field = NULL;|]
        stm [C.cstm|src += $exp:num_elems * sizeof($ty:(primStorageType pt));|]
        pure
          [C.cstms|
             obj->$id:field = $id:new_array(ctx, $id:dataptr, $args:dims);
             if (obj->$id:field == NULL) { err = 1; }|]

  load_body <- collect $ do
    loads <- concat <$> zipWithM restoreComponent [0 ..] (opaquePayload types ot)
    stm
      [C.cstm|if (err == 0) {
                $stms:loads
              }|]

  headerDecl (OpaqueTypeDecl desc) [C.cedecl|struct $id:name;|]
  headerDecl (OpaqueDecl desc) [C.cedecl|int $id:free_opaque($ty:ctx_ty *ctx, $ty:opaque_type *obj);|]
  headerDecl (OpaqueDecl desc) [C.cedecl|int $id:store_opaque($ty:ctx_ty *ctx, const $ty:opaque_type *obj, void **p, size_t *n);|]
  headerDecl (OpaqueDecl desc) [C.cedecl|$ty:opaque_type* $id:restore_opaque($ty:ctx_ty *ctx, const void *p);|]

  extra_ops <- opaqueExtraOps space types desc ot vds

  -- We do not need to enclose most bodies in a critical section,
  -- because when we operate on the components of the opaque, we are
  -- calling public API functions that do their own locking.  The
  -- exception is projection, where we fiddle with reference counts.
  mapM_
    libDecl
    [C.cunit|
          int $id:free_opaque($ty:ctx_ty *ctx, $ty:opaque_type *obj) {
            (void)ctx;
            int ret = 0, tmp;
            $items:free_body
            free(obj);
            return ret;
          }

          int $id:store_opaque($ty:ctx_ty *ctx,
                               const $ty:opaque_type *obj, void **p, size_t *n) {
            (void)ctx;
            int ret = 0;
            $items:store_body
            return ret;
          }

          $ty:opaque_type* $id:restore_opaque($ty:ctx_ty *ctx,
                                              const void *p) {
            (void)ctx;
            int err = 0;
            const unsigned char *src = p;
            $ty:opaque_type* obj = malloc(sizeof($ty:opaque_type));
            $items:load_body
            if (err != 0) {
              int ret = 0, tmp;
              $items:free_body
              free(obj);
              obj = NULL;
            }
            return obj;
          }
    |]

  pure
    ( Manifest.OpaqueOps
        { Manifest.opaqueFree = free_opaque,
          Manifest.opaqueStore = store_opaque,
          Manifest.opaqueRestore = restore_opaque
        },
      extra_ops
    )

-- | Declare the C struct for an array type (a memory block plus its
-- shape) and generate its library functions.  Only public array types
-- are reported back for the manifest.
generateArray ::
  Space ->
  ((Signedness, PrimType, Int), Publicness) ->
  CompilerM op s (Maybe (T.Text, Manifest.Type))
generateArray space ((signed, pt, rank), pub) = do
  name <- publicName $ arrayName pt signed rank
  let memty = fatMemType space
  libDecl [C.cedecl|struct $id:name { $ty:memty mem; typename int64_t shape[$int:rank]; };|]
  ops <- arrayLibraryFunctions pub space pt signed rank
  let pt_name = prettySigned (signed == Unsigned) pt
      pretty_name = mconcat (replicate rank "[]") <> pt_name
      arr_type = [C.cty|struct $id:name*|]
  case pub of
    Public ->
      pure $
        Just
          ( pretty_name,
            Manifest.TypeArray (typeText arr_type) pt_name rank ops
          )
    Private ->
      pure Nothing

-- | Declare the C struct for an opaque type, with one field per payload
-- component (scalars inline, arrays by pointer), and generate its
-- library functions.
generateOpaque ::
  Space ->
  OpaqueTypes ->
  (Name, OpaqueType) ->
  CompilerM op s (T.Text, Manifest.Type)
generateOpaque space types (desc, ot) = do
  name <- publicName $ opaqueName desc
  members <- zipWithM field (opaquePayload types ot) [(0 :: Int) ..]
  libDecl [C.cedecl|struct $id:name { $sdecls:members };|]
  (ops, extra_ops) <- opaqueLibraryFunctions space types desc ot
  let opaque_type = [C.cty|struct $id:name*|]
  pure
    ( nameToText desc,
      Manifest.TypeOpaque (typeText opaque_type) ops extra_ops
    )
  where
    field vt@(ValueType _ (Rank r) _) i = do
      ct <- valueTypeToCType Private vt
      pure $
        if r == 0
          then [C.csdecl|$ty:ct $id:(tupleField i);|]
          else [C.csdecl|$ty:ct *$id:(tupleField i);|]

-- | Generate all array and opaque types of the external API, returning
-- their manifest entries keyed by type name.
generateAPITypes :: Space -> OpaqueTypes -> CompilerM op s (M.Map T.Text Manifest.Type)
generateAPITypes arr_space types@(OpaqueTypes opaques) = do
  mapM_ (findNecessaryArrays . snd) opaques
  array_ts <- mapM (generateArray arr_space) . M.toList =<< gets compArrayTypes
  opaque_ts <- mapM (generateOpaque arr_space types) opaques
  pure $ M.fromList $ catMaybes array_ts <> opaque_ts
  where
    -- Ensure that array types will be generated before the opaque
    -- types that allow projection of them.  This is because the
    -- projection functions somewhat uglily directly poke around in
    -- the innards to increment reference counts.
    findNecessaryArrays (OpaqueType _) =
      pure ()
    findNecessaryArrays (OpaqueArray {}) =
      pure ()
    findNecessaryArrays (OpaqueRecordArray _ _ fs) =
      mapM_ (entryPointTypeToCType Public . snd) fs
    findNecessaryArrays (OpaqueSum _ variants) =
      mapM_ (mapM_ (entryPointTypeToCType Public . fst) . snd) variants
    findNecessaryArrays (OpaqueRecord fs) =
      mapM_ (entryPointTypeToCType Public . snd) fs
futhark-0.25.27/src/Futhark/CodeGen/Backends/GenericPython.hs000066400000000000000000001330301475065116200236650ustar00000000000000
-- | A generic Python code generator which is polymorphic in the type
-- of the operations.  Concretely, we use this to handle both
-- sequential and PyOpenCL Python code.
module Futhark.CodeGen.Backends.GenericPython
  ( compileProg,
    CompilerMode,
    Constructor (..),
    emptyConstructor,
    compileName,
    compileVar,
    compileDim,
    compileExp,
    compilePrimExp,
    compileCode,
    compilePrimValue,
    compilePrimType,
    compilePrimToNp,
    compilePrimToExtNp,
    fromStorage,
    toStorage,
    Operations (..),
    DoCopy,
    defaultOperations,
    unpackDim,
    CompilerM (..),
    OpCompiler,
    WriteScalar,
    ReadScalar,
    Allocate,
    Copy,
    EntryOutput,
    EntryInput,
    CompilerEnv (..),
    CompilerState (..),
    stm,
    atInit,
    collect',
    collect,
    simpleCall,
  )
where

import Control.Monad
import Control.Monad.RWS hiding (reader, writer)
import Data.Char (isAlpha, isAlphaNum)
import Data.Map qualified as M
import Data.Maybe
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericPython.AST
import Futhark.CodeGen.Backends.GenericPython.Options
import Futhark.CodeGen.ImpCode (Count (..), Elements, TExp, elements, le64, untyped)
import Futhark.CodeGen.ImpCode qualified as Imp
import Futhark.CodeGen.RTS.Python
import Futhark.Compiler.Config (CompilerMode (..))
import Futhark.IR.Prop (isBuiltInFunction, subExpVars)
import Futhark.IR.Syntax.Core (Space (..))
import Futhark.MonadFreshNames
import Futhark.Util (zEncodeText)
import Futhark.Util.Pretty (prettyString, prettyText)
import Language.Futhark.Primitive hiding (Bool)

-- | A substitute expression compiler, tried before the main
-- compilation function.
type OpCompiler op s = op -> CompilerM op s ()

-- | Write a scalar to the given memory block with the given index and
-- in the given memory space.
type WriteScalar op s =
  PyExp ->
  PyExp ->
  PrimType ->
  Imp.SpaceId ->
  PyExp ->
  CompilerM op s ()

-- | Read a scalar from the given memory block with the given index and
-- in the given memory space.
type ReadScalar op s =
  PyExp ->
  PyExp ->
  PrimType ->
  Imp.SpaceId ->
  CompilerM op s PyExp

-- | Allocate a memory block of the given size in the given memory
-- space, saving a reference in the given variable name.
type Allocate op s =
  PyExp ->
  PyExp ->
  Imp.SpaceId ->
  CompilerM op s ()

-- | Copy from one memory block to another.
type Copy op s =
  PyExp ->
  PyExp ->
  Imp.Space ->
  PyExp ->
  PyExp ->
  Imp.Space ->
  PyExp ->
  PrimType ->
  CompilerM op s ()

-- | Perform an 'Imp.Copy'.  It is expected that these functions
-- are each specialised on which spaces they operate on, so that is
-- not part of their arguments.
type DoCopy op s =
  PrimType ->
  [Count Elements PyExp] ->
  PyExp ->
  ( Count Elements PyExp,
    [Count Elements PyExp]
  ) ->
  PyExp ->
  ( Count Elements PyExp,
    [Count Elements PyExp]
  ) ->
  CompilerM op s ()

-- | Construct the Python array being returned from an entry point.
type EntryOutput op s =
  VName ->
  Imp.SpaceId ->
  PrimType ->
  Imp.Signedness ->
  [Imp.DimSize] ->
  CompilerM op s PyExp

-- | Unpack the array being passed to an entry point.
type EntryInput op s =
  PyExp ->
  Imp.SpaceId ->
  PrimType ->
  Imp.Signedness ->
  [Imp.DimSize] ->
  PyExp ->
  CompilerM op s ()

-- | The backend-specific hooks through which a concrete Python backend
-- customises the generic code generator.
data Operations op s = Operations
  { -- | Write a scalar into a memory block in a given memory space.
    opsWriteScalar :: WriteScalar op s,
    -- | Read a scalar from a memory block in a given memory space.
    opsReadScalar :: ReadScalar op s,
    -- | Allocate a memory block in a given memory space.
    opsAllocate :: Allocate op s,
    -- | @(dst,src)@-space mapping to copy functions.
    opsCopies :: M.Map (Space, Space) (DoCopy op s),
    -- | Compiler for the backend's extended operations.
    opsCompiler :: OpCompiler op s,
    -- | Build the Python value for an array returned from an entry point.
    opsEntryOutput :: EntryOutput op s,
    -- | Unpack an array passed as an entry point argument.
    opsEntryInput :: EntryInput op s
  }

-- | A set of operations that fail (with 'error') for every operation
-- involving non-default memory spaces, and for extended operations.
-- The only copy provided is 'lmadcopyCPU', for copies within the
-- default space.
defaultOperations :: Operations op s
defaultOperations =
  Operations
    { opsWriteScalar = defWriteScalar,
      opsReadScalar = defReadScalar,
      opsAllocate = defAllocate,
      opsCopies = M.singleton (DefaultSpace, DefaultSpace) lmadcopyCPU,
      opsCompiler = defCompiler,
      opsEntryOutput = defEntryOutput,
      opsEntryInput = defEntryInput
    }
  where
    defWriteScalar _ _ _ _ _ =
      error "Cannot write to non-default memory space because I am dumb"
    defReadScalar _ _ _ _ =
      error "Cannot read from non-default memory space"
    defAllocate _ _ _ =
      error "Cannot allocate in non-default memory space"
    defCompiler _ =
      error "The default compiler cannot compile extended operations"
    defEntryOutput _ _ _ _ =
      error "Cannot return array not in default memory space"
    defEntryInput _ _ _ _ =
      error "Cannot accept array not in default memory space"

-- | Read-only environment of the code generator.
data CompilerEnv op s = CompilerEnv
  { -- | The backend-specific operations.
    envOperations :: Operations op s,
    -- | Python expressions to use in place of particular variable
    -- names; 'withConstantSubsts' uses this to redirect program
    -- constants to lookups in @self.constants@.
    envVarExp :: M.Map String PyExp
  }

-- Shorthands for projecting single operations out of the environment.

envOpCompiler :: CompilerEnv op s -> OpCompiler op s
envOpCompiler = opsCompiler . envOperations

envReadScalar :: CompilerEnv op s -> ReadScalar op s
envReadScalar = opsReadScalar . envOperations

envWriteScalar :: CompilerEnv op s -> WriteScalar op s
envWriteScalar = opsWriteScalar . envOperations

envAllocate :: CompilerEnv op s -> Allocate op s
envAllocate = opsAllocate . envOperations

envEntryOutput :: CompilerEnv op s -> EntryOutput op s
envEntryOutput = opsEntryOutput . envOperations

envEntryInput :: CompilerEnv op s -> EntryInput op s
envEntryInput = opsEntryInput . envOperations

-- | An environment with the given operations and no variable
-- substitutions.
newCompilerEnv :: Operations op s -> CompilerEnv op s
newCompilerEnv ops =
  CompilerEnv
    { envOperations = ops,
      envVarExp = mempty
    }

-- | Mutable state of the code generator.
data CompilerState s = CompilerState
  { -- | Source of fresh names.
    compNameSrc :: VNameSource,
    -- | Statements appended (in order) to the generated class's
    -- @__init__@ method; see 'atInit'.
    compInit :: [PyStmt],
    -- | Backend-specific state.
    compUserState :: s
  }

-- | Initial state: the given name source and user state, and no
-- initialisation statements.
newCompilerState :: VNameSource -> s -> CompilerState s
newCompilerState src s =
  CompilerState
    { compNameSrc = src,
      compInit = [],
      compUserState = s
    }

-- | The code generation monad.  Emitted Python statements are
-- accumulated in the writer component.
newtype CompilerM op s a = CompilerM (RWS (CompilerEnv op s) [PyStmt] (CompilerState s) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState (CompilerState s),
      MonadReader (CompilerEnv op s),
      MonadWriter [PyStmt]
    )

instance MonadFreshNames (CompilerM op s) where
  getNameSource = gets compNameSrc
  putNameSource src = modify $ \s -> s {compNameSrc = src}

-- | Run an action and return the statements it emitted, instead of
-- emitting them in the enclosing context.
collect :: CompilerM op s () -> CompilerM op s [PyStmt]
collect m = pass $ do
  ((), w) <- listen m
  pure (w, const mempty)

-- | Like 'collect', but also return the result of the action.
collect' :: CompilerM op s a -> CompilerM op s (a, [PyStmt])
collect' m = pass $ do
  (x, w) <- listen m
  pure ((x, w), const mempty)

-- | Register a statement to be appended to the generated class's
-- @__init__@ method.
atInit :: PyStmt -> CompilerM op s ()
atInit x = modify $ \s ->
  s {compInit = compInit s ++ [x]}

-- | Emit a single statement.
stm :: PyStmt -> CompilerM op s ()
stm x = tell [x]

-- | The Python method name used for a compiled Futhark function.
futharkFun :: T.Text -> T.Text
futharkFun s = "futhark_" <> zEncodeText s

-- | The Python variables holding a function's results.
compileOutput :: [Imp.Param] -> [PyExp]
compileOutput = map (Var . compileName . Imp.paramName)

-- | Run a code generator action and return its result.  The
-- statements it emitted at top level are discarded.
runCompilerM ::
  Operations op s ->
  VNameSource ->
  s ->
  CompilerM op s a ->
  a
runCompilerM ops src userstate (CompilerM m) =
  fst $ evalRWS m (newCompilerEnv ops) (newCompilerState src userstate)

-- | Command line options understood by both generated executables and
-- generated servers.
standardOptions :: [Option]
standardOptions =
  [ Option
      { optionLongName = "tuning",
        optionShortName = Nothing,
        optionArgument = RequiredArgument "open",
        optionAction = [Exp $ simpleCall "read_tuning_file" [Var "sizes", Var "optarg"]]
      },
    -- Does not actually do anything for Python backends.
    Option
      { optionLongName = "cache-file",
        optionShortName = Nothing,
        optionArgument = RequiredArgument "str",
        optionAction = [Pass]
      },
    Option
      { optionLongName = "log",
        optionShortName = Just 'L',
        optionArgument = NoArgument,
        optionAction = [Pass]
      }
  ]

-- | Command line options of generated executables: the standard ones
-- plus runtime reporting, repeated runs, entry point selection and
-- binary output.
executableOptions :: [Option]
executableOptions =
  standardOptions
    ++ [ Option
           { optionLongName = "write-runtime-to",
             optionShortName = Just 't',
             optionArgument = RequiredArgument "str",
             optionAction =
               [ If
                   (Var "runtime_file")
                   [Exp $ simpleCall "runtime_file.close" []]
                   [],
                 Assign (Var "runtime_file") $
                   simpleCall "open" [Var "optarg", String "w"]
               ]
           },
         Option
           { optionLongName = "runs",
             optionShortName = Just 'r',
             optionArgument = RequiredArgument "str",
             optionAction =
               [ Assign (Var "num_runs") $ Var "optarg",
                 Assign (Var "do_warmup_run") $ Bool True
               ]
           },
         Option
           { optionLongName = "entry-point",
             optionShortName = Just 'e',
             optionArgument = RequiredArgument "str",
             optionAction = [Assign (Var "entry_point") $ Var "optarg"]
           },
         Option
           { optionLongName = "binary-output",
             optionShortName = Just 'b',
             optionArgument = NoArgument,
             optionAction = [Assign (Var "binary_output") $ Bool True]
           }
       ]

-- | All external values of an entry point: results first, then
-- arguments.
functionExternalValues :: Imp.EntryPoint -> [Imp.ExternalValue]
functionExternalValues entry =
  map snd (Imp.entryPointResults entry) ++ map snd (Imp.entryPointArgs entry)

-- | Is this name a valid Python identifier?  If not, it should be escaped
-- before being emitted.
isValidPyName :: T.Text -> Bool
isValidPyName = maybe True check . T.uncons
  where
    check (c, cs) = isAlpha c && T.all constituent cs
    constituent c = isAlphaNum c || c == '_'

-- | If the provided text is a valid identifier, then return it
-- verbatim.  Otherwise, escape it such that it becomes valid.
escapeName :: Name -> T.Text
escapeName v
  | isValidPyName v' = v'
  | otherwise = zEncodeText v'
  where
    v' = nameToText v

-- | For every opaque type occurring in an entry point, the textual
-- types of its components (e.g. @"[][]i32"@), keyed by the opaque
-- type's name.
opaqueDefs :: Imp.Functions a -> M.Map T.Text [PyExp]
opaqueDefs (Imp.Functions funs) =
  mconcat
    . map evd
    . concatMap functionExternalValues
    . mapMaybe (Imp.functionEntry . snd)
    $ funs
  where
    evd Imp.TransparentValue {} = mempty
    evd (Imp.OpaqueValue name vds) = M.singleton (nameToText name) $ map (String . vd) vds
    vd (Imp.ScalarValue pt s _) = readTypeEnum pt s
    vd (Imp.ArrayValue _ _ pt s dims) = mconcat (replicate (length dims) "[]") <> readTypeEnum pt s

-- | The class generated by the code generator must have a
-- constructor, although it can be vacuous.
data Constructor = Constructor [String] [PyStmt]

-- | A constructor that takes no arguments and does nothing.
emptyConstructor :: Constructor
emptyConstructor = Constructor ["self"] [Pass]

-- | Build the @__init__@ method, appending the statements registered
-- with 'atInit' to the constructor's own body.
constructorToFunDef :: Constructor -> [PyStmt] -> PyFunDef
constructorToFunDef (Constructor params body) at_init =
  Def "__init__" params $ body <> at_init

-- | Compile an Imp program to the text of a Python module.
--
-- * 'ToLibrary': a class exposing the entry points.
--
-- * 'ToServer': the class, plus a @Server@ that runs it.
--
-- * 'ToExecutable': the class, command line parsing, and dispatch to
--   the entry point chosen with @--entry-point@.
compileProg ::
  (MonadFreshNames m) =>
  CompilerMode ->
  String ->
  Constructor ->
  [PyStmt] ->
  [PyStmt] ->
  Operations op s ->
  s ->
  [PyStmt] ->
  [Option] ->
  Imp.Definitions op ->
  m T.Text
compileProg mode class_name constructor imports defines ops userstate sync options prog = do
  src <- getNameSource
  let prog' = runCompilerM ops src userstate compileProg'
  pure . prettyText . PyProg $
    imports
      ++ [ Import "argparse" Nothing,
           Assign (Var "sizes") $ Dict []
         ]
      ++ defines
      ++ [ Escape valuesPy,
           Escape memoryPy,
           Escape panicPy,
           Escape tuningPy,
           Escape scalarPy,
           Escape serverPy
         ]
      ++ prog'
  where
    Imp.Definitions _types consts (Imp.Functions funs) = prog
    compileProg' = withConstantSubsts consts $ do
      compileConstants consts

      definitions <- mapM compileFunc funs
      at_inits <- gets compInit

      let constructor' = constructorToFunDef constructor at_inits

      case mode of
        ToLibrary -> do
          (entry_points, entry_point_types) <-
            unzip . catMaybes <$> mapM (compileEntryFun sync DoNotReturnTiming) funs
          pure
            [ ClassDef $
                Class class_name $
                  Assign (Var "entry_points") (Dict entry_point_types)
                    : Assign
                      (Var "opaques")
                      (Dict $ zip (map String opaque_names) (map Tuple opaque_payloads))
                    : map FunDef (constructor' : definitions ++ entry_points)
            ]
        ToServer -> do
          (entry_points, entry_point_types) <-
            unzip . catMaybes <$> mapM (compileEntryFun sync ReturnTiming) funs
          pure $
            parse_options_server
              ++ [ ClassDef
                     ( Class class_name $
                         Assign (Var "entry_points") (Dict entry_point_types)
                           : Assign
                             (Var "opaques")
                             (Dict $ zip (map String opaque_names) (map Tuple opaque_payloads))
                           : map FunDef (constructor' : definitions ++ entry_points)
                     ),
                   Assign (Var "server") (simpleCall "Server" [simpleCall class_name []]),
                   Exp $ simpleCall "server.run" []
                 ]
        ToExecutable -> do
          let classinst = Assign (Var "self") $ simpleCall class_name []
          (entry_point_defs, entry_point_names, entry_points) <-
            unzip3 . catMaybes <$> mapM (callEntryFun sync) funs
          pure $
            parse_options_executable
              ++ ClassDef
                ( Class class_name $
                    map FunDef $
                      constructor' : definitions
                )
              : classinst
              : map FunDef entry_point_defs
              ++ selectEntryPoint entry_point_names entry_points

    parse_options_executable =
      Assign (Var "runtime_file") None
        : Assign (Var "do_warmup_run") (Bool False)
        : Assign (Var "num_runs") (Integer 1)
        : Assign (Var "entry_point") (String "main")
        : Assign (Var "binary_output") (Bool False)
        : generateOptionParser (executableOptions ++ options)

    parse_options_server =
      generateOptionParser (standardOptions ++ options)

    (opaque_names, opaque_payloads) =
      unzip $ M.toList $ opaqueDefs $ Imp.defFuns prog

    -- Look up the entry point named on the command line and run it,
    -- or exit listing the available ones.  The flag named in the
    -- message must match the "entry-point" option in
    -- 'executableOptions'.
    selectEntryPoint entry_point_names entry_points =
      [ Assign (Var "entry_points") $
          Dict $
            zip (map String entry_point_names) entry_points,
        Assign (Var "entry_point_fun") $
          simpleCall "entry_points.get" [Var "entry_point"],
        If
          (BinOp "==" (Var "entry_point_fun") None)
          [ Exp $
              simpleCall
                "sys.exit"
                [ Call
                    ( Field
                        (String "No entry point '{}'. Select another with --entry-point. Options are:\n{}")
                        "format"
                    )
                    [ Arg $ Var "entry_point",
                      Arg $
                        Call
                          (Field (String "\n") "join")
                          [Arg $ simpleCall "entry_points.keys" []]
                    ]
                ]
          ]
          [Exp $ simpleCall "entry_point_fun" []]
      ]

-- | Run an action in an environment where every program constant is
-- compiled as a lookup in the @self.constants@ dictionary.
withConstantSubsts :: Imp.Constants op -> CompilerM op s a -> CompilerM op s a
withConstantSubsts (Imp.Constants ps _) =
  local $ \env -> env {envVarExp = foldMap constExp ps}
  where
    constExp p =
      M.singleton
        (compileName $ Imp.paramName p)
        (Index (Var "self.constants") $ IdxExp $ String $ prettyText $ Imp.paramName p)

-- | Arrange for @self.constants@ to be created and the constants'
-- initialisation code to run in the generated class's constructor.
compileConstants :: Imp.Constants op -> CompilerM op s ()
compileConstants (Imp.Constants _ init_consts) = do
  atInit $ Assign (Var "self.constants") $ Dict []
  mapM_ atInit =<< collect (compileCode init_consts)

-- | Compile an Imp function to a Python method returning its outputs
-- (a tuple unless there is exactly one).
compileFunc :: (Name, Imp.Function op) -> CompilerM op s PyFunDef
compileFunc (fname, Imp.Function _ outputs inputs body) = do
  body' <- collect $ compileCode body
  let inputs' = map (compileName . Imp.paramName) inputs
  let ret = Return $ tupleOrSingle $ compileOutput outputs
  pure $
    Def (T.unpack $ futharkFun $ nameToText fname) ("self" : inputs') $
      body' ++ [ret]

-- | A single expression as itself; any other number as a tuple.
tupleOrSingle :: [PyExp] -> PyExp
tupleOrSingle [e] = e
tupleOrSingle es = Tuple es

-- | A 'Call' where the function is a variable and every argument is a
-- simple 'Arg'.
simpleCall :: String -> [PyExp] -> PyExp
simpleCall fname = Call (Var fname) . map Arg

-- | The (z-encoded) Python identifier for a variable name.
compileName :: VName -> String
compileName = T.unpack . zEncodeText . prettyText

-- | Compile a dimension size to a Python expression.
compileDim :: Imp.DimSize -> CompilerM op s PyExp
compileDim (Imp.Constant v) = pure $ compilePrimValue v
compileDim (Imp.Var v) = compileVar v

-- | Check dimension @i@ of an entry point argument array.  A constant
-- size is asserted equal to the array's shape; a variable size is bound
-- to the shape if it is still @None@, and otherwise asserted equal.
unpackDim :: PyExp -> Imp.DimSize -> Int32 -> CompilerM op s ()
unpackDim arr_name (Imp.Constant c) i = do
  let shape_name = Field arr_name "shape"
  let constant_c = compilePrimValue c
  let constant_i = Integer $ toInteger i
  stm $
    Assert (BinOp "==" constant_c (Index shape_name $ IdxExp constant_i)) $
      String "Entry point arguments have invalid sizes."
unpackDim arr_name (Imp.Var var) i = do let shape_name = Field arr_name "shape" src = Index shape_name $ IdxExp $ Integer $ toInteger i var' <- compileVar var stm $ If (BinOp "==" var' None) [Assign var' $ simpleCall "np.int64" [src]] [ Assert (BinOp "==" var' src) $ String "Error: entry point arguments have invalid sizes." ] entryPointOutput :: Imp.ExternalValue -> CompilerM op s PyExp entryPointOutput (Imp.OpaqueValue desc vs) = simpleCall "opaque" . (String (prettyText desc) :) <$> mapM (entryPointOutput . Imp.TransparentValue) vs entryPointOutput (Imp.TransparentValue (Imp.ScalarValue bt ept name)) = do name' <- compileVar name pure $ simpleCall tf [name'] where tf = compilePrimToExtNp bt ept entryPointOutput (Imp.TransparentValue (Imp.ArrayValue mem (Imp.Space sid) bt ept dims)) = do pack_output <- asks envEntryOutput pack_output mem sid bt ept dims entryPointOutput (Imp.TransparentValue (Imp.ArrayValue mem _ bt ept dims)) = do mem' <- compileVar mem dims' <- mapM compileDim dims pure $ simpleCall "np.reshape" [ Index (Call (Field mem' "view") [Arg $ Var $ compilePrimToExtNp bt ept]) (IdxRange (Integer 0) (foldl1 (BinOp "*") dims')), Tuple dims' ] badInput :: Int -> PyExp -> T.Text -> PyStmt badInput i e t = Raise $ simpleCall "TypeError" [ Call (Field (String err_msg) "format") [Arg (String t), Arg $ simpleCall "type" [e], Arg e] ] where err_msg = T.unlines [ "Argument #" <> prettyText i <> " has invalid value", "Futhark type: {}", "Argument has Python type {} and value: {}" ] badInputType :: Int -> PyExp -> T.Text -> PyExp -> PyExp -> PyStmt badInputType i e t de dg = Raise $ simpleCall "TypeError" [ Call (Field (String err_msg) "format") [Arg (String t), Arg $ simpleCall "type" [e], Arg e, Arg de, Arg dg] ] where err_msg = T.unlines [ "Argument #" <> prettyText i <> " has invalid value", "Futhark type: {}", "Argument has Python type {} and value: {}", "Expected array with elements of dtype: {}", "The array given has elements of dtype: {}" ] badInputDim :: 
Int -> PyExp -> T.Text -> Int -> PyStmt badInputDim i e typ dimf = Raise $ simpleCall "TypeError" [ Call (Field (String err_msg) "format") [Arg eft, Arg aft] ] where eft = String (mconcat (replicate dimf "[]") <> typ) aft = BinOp "+" (BinOp "*" (String "[]") (Field e "ndim")) (String typ) err_msg = T.unlines [ "Argument #" <> prettyText i <> " has invalid value", "Dimensionality mismatch", "Expected Futhark type: {}", "Bad Python value passed", "Actual Futhark type: {}" ] declEntryPointInputSizes :: [Imp.ExternalValue] -> CompilerM op s () declEntryPointInputSizes = mapM_ onSize . concatMap sizes where sizes (Imp.TransparentValue v) = valueSizes v sizes (Imp.OpaqueValue _ vs) = concatMap valueSizes vs valueSizes (Imp.ArrayValue _ _ _ _ dims) = subExpVars dims valueSizes Imp.ScalarValue {} = [] onSize v = stm $ Assign (Var (compileName v)) None entryPointInput :: (Int, Imp.ExternalValue, PyExp) -> CompilerM op s () entryPointInput (i, Imp.OpaqueValue desc vs, e) = do let type_is_ok = BinOp "and" (simpleCall "isinstance" [e, Var "opaque"]) (BinOp "==" (Field e "desc") (String (nameToText desc))) stm $ If (UnOp "not" type_is_ok) [badInput i e (nameToText desc)] [] mapM_ entryPointInput $ zip3 (repeat i) (map Imp.TransparentValue vs) $ map (Index (Field e "data") . IdxExp . Integer) [0 ..] entryPointInput (i, Imp.TransparentValue (Imp.ScalarValue bt s name), e) = do vname' <- compileVar name let -- HACK: A Numpy int64 will signal an OverflowError if we pass -- it a number bigger than 2**63. This does not happen if we -- pass e.g. int8 a number bigger than 2**7. As a workaround, -- we first go through the corresponding ctypes type, which does -- not have this problem. 
ctobject = compilePrimType bt npobject = compilePrimToNp bt npcall = simpleCall npobject [ case bt of IntType Int64 -> simpleCall ctobject [e] _ -> e ] stm $ Try [Assign vname' npcall] [ Catch (Tuple [Var "TypeError", Var "AssertionError"]) [badInput i e $ prettySigned (s == Imp.Unsigned) bt] ] entryPointInput (i, Imp.TransparentValue (Imp.ArrayValue mem (Imp.Space sid) bt ept dims), e) = do unpack_input <- asks envEntryInput mem' <- compileVar mem unpack <- collect $ unpack_input mem' sid bt ept dims e stm $ Try unpack [ Catch (Tuple [Var "TypeError", Var "AssertionError"]) [ badInput i e $ mconcat (replicate (length dims) "[]") <> prettySigned (ept == Imp.Unsigned) bt ] ] entryPointInput (i, Imp.TransparentValue (Imp.ArrayValue mem _ t s dims), e) = do let type_is_wrong = UnOp "not" $ BinOp "in" (simpleCall "type" [e]) $ List [Var "np.ndarray"] let dtype_is_wrong = UnOp "not" $ BinOp "==" (Field e "dtype") $ Var $ compilePrimToExtNp t s let dim_is_wrong = UnOp "not" $ BinOp "==" (Field e "ndim") $ Integer $ toInteger $ length dims stm $ If type_is_wrong [ badInput i e $ mconcat (replicate (length dims) "[]") <> prettySigned (s == Imp.Unsigned) t ] [] stm $ If dtype_is_wrong [ badInputType i e (mconcat (replicate (length dims) "[]") <> prettySigned (s == Imp.Unsigned) t) (simpleCall "np.dtype" [Var (compilePrimToExtNp t s)]) (Field e "dtype") ] [] stm $ If dim_is_wrong [badInputDim i e (prettySigned (s == Imp.Unsigned) t) (length dims)] [] zipWithM_ (unpackDim e) dims [0 ..] 
dest <- compileVar mem let unwrap_call = simpleCall "unwrapArray" [e] stm $ Assign dest unwrap_call extValueDescName :: Imp.ExternalValue -> T.Text extValueDescName (Imp.TransparentValue v) = extName $ T.pack $ compileName $ valueDescVName v extValueDescName (Imp.OpaqueValue desc []) = extName $ zEncodeText $ nameToText desc extValueDescName (Imp.OpaqueValue desc (v : _)) = extName $ zEncodeText (nameToText desc) <> "_" <> prettyText (baseTag (valueDescVName v)) extName :: T.Text -> T.Text extName = (<> "_ext") valueDescVName :: Imp.ValueDesc -> VName valueDescVName (Imp.ScalarValue _ _ vname) = vname valueDescVName (Imp.ArrayValue vname _ _ _ _) = vname -- Key into the FUTHARK_PRIMTYPES dict. readTypeEnum :: PrimType -> Imp.Signedness -> T.Text readTypeEnum (IntType Int8) Imp.Unsigned = "u8" readTypeEnum (IntType Int16) Imp.Unsigned = "u16" readTypeEnum (IntType Int32) Imp.Unsigned = "u32" readTypeEnum (IntType Int64) Imp.Unsigned = "u64" readTypeEnum (IntType Int8) Imp.Signed = "i8" readTypeEnum (IntType Int16) Imp.Signed = "i16" readTypeEnum (IntType Int32) Imp.Signed = "i32" readTypeEnum (IntType Int64) Imp.Signed = "i64" readTypeEnum (FloatType Float16) _ = "f16" readTypeEnum (FloatType Float32) _ = "f32" readTypeEnum (FloatType Float64) _ = "f64" readTypeEnum Imp.Bool _ = "bool" readTypeEnum Unit _ = "bool" readInput :: Imp.ExternalValue -> PyStmt readInput (Imp.OpaqueValue desc _) = Raise $ simpleCall "Exception" [String $ "Cannot read argument of type " <> nameToText desc <> "."] readInput decl@(Imp.TransparentValue (Imp.ScalarValue bt ept _)) = let type_name = readTypeEnum bt ept in Assign (Var $ T.unpack $ extValueDescName decl) $ simpleCall "read_value" [String type_name] readInput decl@(Imp.TransparentValue (Imp.ArrayValue _ _ bt ept dims)) = let type_name = readTypeEnum bt ept in Assign (Var $ T.unpack $ extValueDescName decl) $ simpleCall "read_value" [String $ mconcat (replicate (length dims) "[]") <> type_name] printValue :: [(Imp.ExternalValue, 
PyExp)] -> CompilerM op s [PyStmt] printValue = fmap concat . mapM (uncurry printValue') where -- We copy non-host arrays to the host before printing. This is -- done in a hacky way - we assume the value has a .get()-method -- that returns an equivalent Numpy array. This works for PyOpenCL, -- but we will probably need yet another plugin mechanism here in -- the future. printValue' (Imp.OpaqueValue desc _) _ = pure [ Exp $ simpleCall "sys.stdout.write" [String $ "# nameToText desc <> ">"] ] printValue' (Imp.TransparentValue (Imp.ArrayValue mem (Space _) bt ept shape)) e = printValue' (Imp.TransparentValue (Imp.ArrayValue mem DefaultSpace bt ept shape)) $ simpleCall (prettyString e ++ ".get") [] printValue' (Imp.TransparentValue _) e = pure [ Exp $ Call (Var "write_value") [ Arg e, ArgKeyword "binary" (Var "binary_output") ], Exp $ simpleCall "sys.stdout.write" [String "\n"] ] prepareEntry :: Imp.EntryPoint -> (Name, Imp.Function op) -> CompilerM op s ( [String], [PyStmt], [PyStmt], [PyStmt], [(Imp.ExternalValue, PyExp)] ) prepareEntry (Imp.EntryPoint _ results args) (fname, Imp.Function _ outputs inputs _) = do let output_paramNames = map (compileName . Imp.paramName) outputs funTuple = tupleOrSingle $ fmap Var output_paramNames prepareIn <- collect $ do declEntryPointInputSizes $ map snd args mapM_ entryPointInput . zip3 [0 ..] (map snd args) $ map (Var . T.unpack . extValueDescName . snd) args (res, prepareOut) <- collect' $ mapM (entryPointOutput . snd) results let argexps_lib = map (compileName . Imp.paramName) inputs fname' = "self." <> futharkFun (nameToText fname) -- We ignore overflow errors and the like for executable entry -- points. These are (somewhat) well-defined in Futhark. ignore s = ArgKeyword s $ String "ignore" errstate = Call (Var "np.errstate") $ map ignore ["divide", "over", "under", "invalid"] call argexps = [ With errstate [Assign funTuple $ simpleCall (T.unpack fname') (fmap Var argexps)] ] pure ( map (T.unpack . extValueDescName . 
snd) args, prepareIn, call argexps_lib, prepareOut, zip (map snd results) res ) data ReturnTiming = ReturnTiming | DoNotReturnTiming compileEntryFun :: [PyStmt] -> ReturnTiming -> (Name, Imp.Function op) -> CompilerM op s (Maybe (PyFunDef, (PyExp, PyExp))) compileEntryFun sync timing fun | Just entry <- Imp.functionEntry $ snd fun = do let ename = Imp.entryPointName entry (params, prepareIn, body_lib, prepareOut, res) <- prepareEntry entry fun let (maybe_sync, ret) = case timing of DoNotReturnTiming -> ( [], Return $ tupleOrSingle $ map snd res ) ReturnTiming -> ( sync, Return $ Tuple [ Var "runtime", tupleOrSingle $ map snd res ] ) (pts, rts) = entryTypes entry do_run = Assign (Var "time_start") (simpleCall "time.time" []) : body_lib ++ maybe_sync ++ [ Assign (Var "runtime") $ BinOp "-" (toMicroseconds (simpleCall "time.time" [])) (toMicroseconds (Var "time_start")) ] pure $ Just ( Def (T.unpack (escapeName ename)) ("self" : params) $ prepareIn ++ do_run ++ prepareOut ++ sync ++ [ret], ( String (nameToText ename), Tuple [ String (escapeName ename), List (map String pts), List (map String rts) ] ) ) | otherwise = pure Nothing entryTypes :: Imp.EntryPoint -> ([T.Text], [T.Text]) entryTypes (Imp.EntryPoint _ res args) = (map descArg args, map desc res) where descArg ((_, u), d) = desc (u, d) desc (u, Imp.OpaqueValue d _) = prettyText u <> nameToText d desc (u, Imp.TransparentValue (Imp.ScalarValue pt s _)) = prettyText u <> readTypeEnum pt s desc (u, Imp.TransparentValue (Imp.ArrayValue _ _ pt s dims)) = prettyText u <> mconcat (replicate (length dims) "[]") <> readTypeEnum pt s callEntryFun :: [PyStmt] -> (Name, Imp.Function op) -> CompilerM op s (Maybe (PyFunDef, T.Text, PyExp)) callEntryFun _ (_, Imp.Function Nothing _ _ _) = pure Nothing callEntryFun pre_timing fun@(fname, Imp.Function (Just entry) _ _ _) = do let Imp.EntryPoint ename _ decl_args = entry (_, prepare_in, body_bin, _, res) <- prepareEntry entry fun let str_input = map (readInput . 
snd) decl_args end_of_input = [Exp $ simpleCall "end_of_input" [String $ prettyText fname]] exitcall = [Exp $ simpleCall "sys.exit" [Field (String "Assertion.{} failed") "format(e)"]] except' = Catch (Var "AssertionError") exitcall do_run = body_bin ++ pre_timing (do_run_with_timing, close_runtime_file) = addTiming do_run do_warmup_run = If (Var "do_warmup_run") do_run [] do_num_runs = For "i" (simpleCall "range" [simpleCall "int" [Var "num_runs"]]) do_run_with_timing str_output <- printValue res let fname' = "entry_" ++ T.unpack (escapeName fname) pure $ Just ( Def fname' [] $ str_input ++ end_of_input ++ prepare_in ++ [Try [do_warmup_run, do_num_runs] [except']] ++ [close_runtime_file] ++ str_output, nameToText ename, Var fname' ) addTiming :: [PyStmt] -> ([PyStmt], PyStmt) addTiming statements = ( [Assign (Var "time_start") $ simpleCall "time.time" []] ++ statements ++ [ Assign (Var "time_end") $ simpleCall "time.time" [], If (Var "runtime_file") print_runtime [] ], If (Var "runtime_file") [Exp $ simpleCall "runtime_file.close" []] [] ) where print_runtime = [ Exp $ simpleCall "runtime_file.write" [ simpleCall "str" [ BinOp "-" (toMicroseconds (Var "time_end")) (toMicroseconds (Var "time_start")) ] ], Exp $ simpleCall "runtime_file.write" [String "\n"], Exp $ simpleCall "runtime_file.flush" [] ] toMicroseconds :: PyExp -> PyExp toMicroseconds x = simpleCall "int" [BinOp "*" x $ Integer 1000000] compileUnOp :: Imp.UnOp -> String compileUnOp op = case op of Neg Imp.Bool -> "not" Neg _ -> "-" Complement {} -> "~" Abs {} -> "abs" FAbs {} -> "abs" SSignum {} -> "ssignum" USignum {} -> "usignum" FSignum {} -> "np.sign" compileBinOpLike :: (Monad m) => (v -> m PyExp) -> Imp.PrimExp v -> Imp.PrimExp v -> m (PyExp, PyExp, String -> m PyExp) compileBinOpLike f x y = do x' <- compilePrimExp f x y' <- compilePrimExp f y let simple s = pure $ BinOp s x' y' pure (x', y', simple) -- | The ctypes type corresponding to a 'PrimType'. 
-- Half-precision floats are mapped to a 16-bit unsigned ctypes integer.
-- NOTE(review): 'toStorage' below wraps half-precision bit patterns in
-- @ct.c_int16@ (signed) instead; confirm this mismatch is intended.
compilePrimType :: PrimType -> String
compilePrimType (IntType Int8) = "ct.c_int8"
compilePrimType (IntType Int16) = "ct.c_int16"
compilePrimType (IntType Int32) = "ct.c_int32"
compilePrimType (IntType Int64) = "ct.c_int64"
compilePrimType (FloatType Float16) = "ct.c_uint16"
compilePrimType (FloatType Float32) = "ct.c_float"
compilePrimType (FloatType Float64) = "ct.c_double"
compilePrimType Imp.Bool = "ct.c_bool"
compilePrimType Unit = "ct.c_bool"

-- | The Numpy scalar type used to hold a value of the given
-- 'PrimType' inside generated code.  Booleans and 'Unit' are both
-- held as @np.byte@.
compilePrimToNp :: Imp.PrimType -> String
compilePrimToNp (IntType Int8) = "np.int8"
compilePrimToNp (IntType Int16) = "np.int16"
compilePrimToNp (IntType Int32) = "np.int32"
compilePrimToNp (IntType Int64) = "np.int64"
compilePrimToNp (FloatType Float16) = "np.float16"
compilePrimToNp (FloatType Float32) = "np.float32"
compilePrimToNp (FloatType Float64) = "np.float64"
compilePrimToNp Imp.Bool = "np.byte"
compilePrimToNp Unit = "np.byte"

-- | The Numpy type for a 'PrimType' as seen by entry-point callers,
-- taking signedness into account.  Unsigned integers become the
-- @np.uintN@ types and booleans become @np.bool_@; every other
-- combination is the same as 'compilePrimToNp'.
compilePrimToExtNp :: Imp.PrimType -> Imp.Signedness -> String
compilePrimToExtNp (IntType Int8) Imp.Unsigned = "np.uint8"
compilePrimToExtNp (IntType Int16) Imp.Unsigned = "np.uint16"
compilePrimToExtNp (IntType Int32) Imp.Unsigned = "np.uint32"
compilePrimToExtNp (IntType Int64) Imp.Unsigned = "np.uint64"
compilePrimToExtNp Imp.Bool _ = "np.bool_"
compilePrimToExtNp bt _ = compilePrimToNp bt

-- | Convert from scalar to storage representation for the given type.
-- Half-precision values are first turned into their bit pattern by the
-- runtime helper @futhark_to_bits16@ and then wrapped in @ct.c_int16@;
-- all other types are wrapped in their 'compilePrimType' constructor.
toStorage :: PrimType -> PyExp -> PyExp
toStorage t e =
  case t of
    FloatType Float16 ->
      simpleCall "ct.c_int16" [simpleCall "futhark_to_bits16" [e]]
    _ ->
      simpleCall (compilePrimType t) [e]

-- | Convert from storage to scalar representation for the given type.
-- Half-precision values are decoded from their 16-bit pattern by the
-- runtime helper @futhark_from_bits16@; other types are wrapped in
-- their 'compilePrimToNp' constructor.
-- NOTE(review): on the 'DefaultSpace' read path the argument comes
-- from a @ct.c_uint16@ read (see 'compilePrimType'); confirm that
-- @np.int16@ accepts bit patterns with the sign bit set.
fromStorage :: PrimType -> PyExp -> PyExp
fromStorage (FloatType Float16) e =
  simpleCall "futhark_from_bits16" [simpleCall "np.int16" [e]]
fromStorage t e = simpleCall (compilePrimToNp t) [e]

-- | Compile a primitive constant to a Python expression.
--
-- Every numeric constant is wrapped in the Numpy scalar constructor of
-- its Futhark type, so that it has the same Python type as computed
-- values of that type.  Infinities and NaN are spelled via @np.inf@ and
-- @np.nan@ inside that constructor; this includes f64, which previously
-- produced a bare (Python float) @np.inf@.
compilePrimValue :: Imp.PrimValue -> PyExp
compilePrimValue (IntValue (Int8Value v)) =
  simpleCall "np.int8" [Integer $ toInteger v]
compilePrimValue (IntValue (Int16Value v)) =
  simpleCall "np.int16" [Integer $ toInteger v]
compilePrimValue (IntValue (Int32Value v)) =
  simpleCall "np.int32" [Integer $ toInteger v]
compilePrimValue (IntValue (Int64Value v)) =
  simpleCall "np.int64" [Integer $ toInteger v]
compilePrimValue (FloatValue fv) =
  case fv of
    Float16Value v -> floatLit "np.float16" v
    Float32Value v -> floatLit "np.float32" v
    Float64Value v -> floatLit "np.float64" v
  where
    -- A float constant of any width, wrapped in the given Numpy
    -- constructor.  Non-finite values cannot go through 'Float', so
    -- they are written out as Numpy names.
    floatLit :: (RealFloat t) => String -> t -> PyExp
    floatLit ctor v
      | isInfinite v =
          Var $ ctor <> (if v > 0 then "(np.inf)" else "(-np.inf)")
      | isNaN v =
          Var $ ctor <> "(np.nan)"
      | otherwise =
          simpleCall ctor [Float $ fromRational $ toRational v]
compilePrimValue (BoolValue v) = Bool v
compilePrimValue UnitValue = Var "np.byte(0)"

-- | Compile a variable reference.  If the compiled name has an entry in
-- the environment's 'envVarExp' map (as installed by
-- 'withConstantSubsts'), that expression is used; otherwise the name
-- becomes a plain Python variable.
compileVar :: VName -> CompilerM op s PyExp
compileVar v = asks $ fromMaybe (Var v') . M.lookup v' . envVarExp
  where
    v' = compileName v

-- | Tell me how to compile a @v@, and I'll compile any @PrimExp v@ for you.
-- Binary and comparison operators that have a direct Python infix form
-- are emitted inline.  Every other operator, and every conversion,
-- becomes a call to a function named after the pretty-printed operator.
compilePrimExp :: (Monad m) => (v -> m PyExp) -> Imp.PrimExp v -> m PyExp
compilePrimExp _ (Imp.ValueExp v) = pure $ compilePrimValue v
compilePrimExp f (Imp.LeafExp v _) = f v
compilePrimExp f (Imp.BinOpExp op x y) = do
  (x', y', infixWith) <- compileBinOpLike f x y
  maybe (pure $ simpleCall (prettyString op) [x', y']) infixWith $ pyBinOp op
  where
    pyBinOp Add {} = Just "+"
    pyBinOp Sub {} = Just "-"
    pyBinOp Mul {} = Just "*"
    pyBinOp FAdd {} = Just "+"
    pyBinOp FSub {} = Just "-"
    pyBinOp FMul {} = Just "*"
    pyBinOp FDiv {} = Just "/"
    pyBinOp FMod {} = Just "%"
    pyBinOp Xor {} = Just "^"
    pyBinOp And {} = Just "&"
    pyBinOp Or {} = Just "|"
    pyBinOp Shl {} = Just "<<"
    pyBinOp LogAnd {} = Just "and"
    pyBinOp LogOr {} = Just "or"
    pyBinOp _ = Nothing
compilePrimExp f (Imp.ConvOpExp conv x) =
  (\x' -> simpleCall (prettyString conv) [x']) <$> compilePrimExp f x
compilePrimExp f (Imp.CmpOpExp cmp x y) = do
  (x', y', infixWith) <- compileBinOpLike f x y
  maybe (pure $ simpleCall (prettyString cmp) [x', y']) infixWith $ pyCmpOp cmp
  where
    pyCmpOp CmpEq {} = Just "=="
    pyCmpOp FCmpLt {} = Just "<"
    pyCmpOp FCmpLe {} = Just "<="
    pyCmpOp CmpLlt = Just "<"
    pyCmpOp CmpLle = Just "<="
    pyCmpOp _ = Nothing
compilePrimExp f (Imp.UnOpExp op operand) =
  UnOp (compileUnOp op) <$> compilePrimExp f operand
compilePrimExp f (Imp.FunExp h args _) =
  simpleCall (T.unpack (futharkFun (prettyText h))) <$> mapM (compilePrimExp f) args

-- | Compile an 'Imp.Exp', resolving variables with 'compileVar'.
compileExp :: Imp.Exp -> CompilerM op s PyExp
compileExp = compilePrimExp compileVar

-- | Turn an error message into a Python @%@-format string plus the
-- list of values to interpolate.  Literal text is passed through a
-- @%s@ argument; values are formatted by type: @%d@ for integers, @%f@
-- for floats, and @%r@ for booleans and unit.
errorMsgString :: Imp.ErrorMsg Imp.Exp -> CompilerM op s (T.Text, [PyExp])
errorMsgString (Imp.ErrorMsg parts) = do
  (specs, values) <- mapAndUnzipM onPart parts
  pure (mconcat specs, values)
  where
    onPart (Imp.ErrorString s) = pure ("%s", String s)
    onPart (Imp.ErrorVal t x) = (conversion t,) <$> compileExp x
    conversion :: PrimType -> T.Text
    conversion IntType {} = "%d"
    conversion FloatType {} = "%f"
    conversion Imp.Bool = "%r"
    conversion Unit = "%r"

-- | Read a scalar of the given type at an element index.  Unit reads
-- produce the unit constant without touching memory, 'DefaultSpace'
-- reads go through the runtime's @indexArray@, and other named spaces
-- are delegated to the backend's 'envReadScalar'.
generateRead :: PyExp -> PyExp -> PrimType -> Space -> CompilerM op s PyExp
generateRead src iexp pt space =
  case (pt, space) of
    (Unit, _) ->
      pure $ compilePrimValue UnitValue
    (_, ScalarSpace {}) ->
      error "GenericPython.generateRead: ScalarSpace"
    (_, DefaultSpace) ->
      pure . fromStorage pt $
        simpleCall "indexArray" [src, iexp, Var (compilePrimType pt)]
    (_, Space sid) -> do
      reader <- asks envReadScalar
      reader src iexp pt sid

-- | Write a scalar of the given type at an element index.  Unit writes
-- emit nothing, 'DefaultSpace' writes go through the runtime's
-- @writeScalarArray@, and other named spaces are delegated to the
-- backend's 'envWriteScalar'.
generateWrite :: PyExp -> PyExp -> PrimType -> Space -> PyExp -> CompilerM op s ()
generateWrite dst iexp pt space elemexp =
  case (pt, space) of
    (Unit, _) ->
      pure ()
    (_, ScalarSpace {}) ->
      error "GenericPython.generateWrite: ScalarSpace"
    (_, Space sid) -> do
      writer <- asks envWriteScalar
      writer dst iexp pt sid elemexp
    (_, DefaultSpace) ->
      stm $ Exp $ simpleCall "writeScalarArray" [dst, iexp, elemexp]

-- | Compile a 'Copy' using sequential nested loops, but
-- parameterised over how to do the reads and writes.  Each LMAD is an
-- (offset, strides) pair; the flat element index is the offset plus
-- the dot product of the loop indices and the strides.
compileCopyWith ::
  [Count Elements (TExp Int64)] ->
  (PyExp -> PyExp -> CompilerM op s ()) ->
  ( Count Elements (TExp Int64),
    [Count Elements (TExp Int64)]
  ) ->
  (PyExp -> CompilerM op s PyExp) ->
  ( Count Elements (TExp Int64),
    [Count Elements (TExp Int64)]
  ) ->
  CompilerM op s ()
compileCopyWith shape doWrite dst_lmad doRead src_lmad = do
  extents <- mapM (compileExp . untyped . unCount) shape
  innermost <- collect $ do
    dst_i <- flatIndex dst_lmad
    src_i <- flatIndex src_lmad
    doWrite dst_i =<< doRead src_i
  -- Outermost loop first: the first dimension wraps all the others.
  mapM_ stm $ foldr nest innermost (zip ivars extents)
  where
    ivars = map (VName "i") [0 .. length shape - 1]
    ivars' :: [Count Elements (TExp Int64)]
    ivars' = map (elements . le64) ivars
    flatIndex (base, strides) =
      compileExp . untyped . unCount $ base + sum (zipWith (*) ivars' strides)
    nest (i, n) stms = [For (compileName i) (simpleCall "range" [n]) stms]

-- | Compile a 'Copy' using sequential nested loops and
-- 'Imp.Read'/'Imp.Write' of individual scalars.  This always works,
-- but can be pretty slow if those reads and writes are costly.
-- Reads and writes go through 'generateRead' and 'generateWrite' for
-- the source and destination memory spaces respectively.
compileCopy ::
  PrimType ->
  [Count Elements (TExp Int64)] ->
  (VName, Space) ->
  ( Count Elements (TExp Int64),
    [Count Elements (TExp Int64)]
  ) ->
  (VName, Space) ->
  ( Count Elements (TExp Int64),
    [Count Elements (TExp Int64)]
  ) ->
  CompilerM op s ()
compileCopy t shape (dst, dstspace) dst_lmad (src, srcspace) src_lmad = do
  src' <- compileVar src
  dst' <- compileVar dst
  let doWrite dst_i = generateWrite dst' dst_i t dstspace
      doRead src_i = generateRead src' src_i t srcspace
  compileCopyWith shape doWrite dst_lmad doRead src_lmad

-- | Compile an imperative code fragment, emitting the resulting Python
-- statements with 'stm'.
compileCode :: Imp.Code op -> CompilerM op s ()
compileCode Imp.DebugPrint {} =
  pure ()
compileCode Imp.TracePrint {} =
  pure ()
compileCode (Imp.Op op) =
  join $ asks envOpCompiler <*> pure op
compileCode (Imp.If cond tb fb) = do
  cond' <- compileExp $ Imp.untyped cond
  tb' <- collect $ compileCode tb
  fb' <- collect $ compileCode fb
  stm $ If cond' tb' fb'
compileCode (c1 Imp.:>>: c2) = do
  compileCode c1
  compileCode c2
compileCode (Imp.While cond body) = do
  cond' <- compileExp $ Imp.untyped cond
  body' <- collect $ compileCode body
  stm $ While cond' body'
compileCode (Imp.For i bound body) = do
  bound' <- compileExp bound
  let i' = compileName i
  body' <- collect $ compileCode body
  counter <- prettyString <$> newVName "counter"
  one <- prettyString <$> newVName "one"
  -- The loop index is kept as a Numpy scalar of the bound's type and
  -- incremented by a typed one, rather than using the Python range
  -- counter directly.
  stm $ Assign (Var i') $ simpleCall (compilePrimToNp (Imp.primExpType bound)) [Integer 0]
  stm $ Assign (Var one) $ simpleCall (compilePrimToNp (Imp.primExpType bound)) [Integer 1]
  stm $
    For counter (simpleCall "range" [bound']) $
      body' ++ [AssignOp "+" (Var i') (Var one)]
compileCode (Imp.SetScalar name exp1) =
  stm =<< Assign <$> compileVar name <*> compileExp exp1
compileCode Imp.DeclareMem {} = pure ()
compileCode (Imp.DeclareScalar v _ Unit) = do
  v' <- compileVar v
  stm $ Assign v' $ Var "True"
compileCode Imp.DeclareScalar {} = pure ()
compileCode (Imp.DeclareArray name t vs) = do
  let arr_name = compileName name <> "_arr"
  -- It is important to store the Numpy array in a temporary variable
  -- to prevent it from going "out-of-scope" before calling
  -- unwrapArray (which internally uses the .ctype method); see
  -- https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.ctypes.html
  stm $ Assign (Var arr_name) $ case vs of
    Imp.ArrayValues vs' ->
      Call
        (Var "np.array")
        [ Arg $ List $ map compilePrimValue vs',
          ArgKeyword "dtype" $ Var $ compilePrimToNp t
        ]
    Imp.ArrayZeros n ->
      Call
        (Var "np.zeros")
        [ Arg $ Integer $ fromIntegral n,
          ArgKeyword "dtype" $ Var $ compilePrimToNp t
        ]
  name' <- compileVar name
  stm $ Assign name' $ simpleCall "unwrapArray" [Var arr_name]
compileCode (Imp.Comment s code) = do
  code' <- collect $ compileCode code
  stm $ Comment (T.unpack s) code'
compileCode (Imp.Assert e msg (loc, locs)) = do
  e' <- compileExp e
  (formatstr, formatargs) <- errorMsgString msg
  -- The backtrace is interpolated through its own %s argument instead
  -- of being spliced into the format string: a '%' in a source
  -- location would otherwise be read as a conversion specifier and
  -- make the Python '%' operator fail.
  stm $
    Assert
      e'
      ( BinOp
          "%"
          (String $ "Error: " <> formatstr <> "\n\nBacktrace:\n%s")
          (Tuple $ formatargs ++ [String stacktrace])
      )
  where
    stacktrace = prettyStacktrace 0 $ map locText $ loc : locs
compileCode (Imp.Call dests fname args) = do
  args' <- mapM compileArg args
  dests' <- tupleOrSingle <$> mapM compileVar dests
  let fname'
        | isBuiltInFunction fname = futharkFun (prettyText fname)
        | otherwise = "self." <> futharkFun (prettyText fname)
      call' = simpleCall (T.unpack fname') args'
  -- If the function returns nothing (is called only for side
  -- effects), take care not to assign to an empty tuple.
  stm $
    if null dests
      then Exp call'
      else Assign dests' call'
  where
    compileArg (Imp.MemArg m) = compileVar m
    compileArg (Imp.ExpArg e) = compileExp e
compileCode (Imp.SetMem dest src _) =
  stm =<< Assign <$> compileVar dest <*> compileVar src
compileCode (Imp.Allocate name (Imp.Count (Imp.TPrimExp e)) (Imp.Space space)) =
  join $
    asks envAllocate
      <*> compileVar name
      <*> compileExp e
      <*> pure space
compileCode (Imp.Allocate name (Imp.Count (Imp.TPrimExp e)) _) = do
  e' <- compileExp e
  let allocate' = simpleCall "allocateMem" [e']
  stm =<< Assign <$> compileVar name <*> pure allocate'
compileCode (Imp.Free name _) =
  stm =<< Assign <$> compileVar name <*> pure None
compileCode (Imp.Copy t shape (dst, dstspace) (dstoffset, dststrides) (src, srcspace) (srcoffset, srcstrides)) = do
  -- Use a backend-provided copy for this pair of spaces if there is
  -- one; otherwise fall back to scalar-by-scalar loops.
  cp <- asks $ M.lookup (dstspace, srcspace) . opsCopies . envOperations
  case cp of
    Nothing ->
      compileCopy t shape (dst, dstspace) (dstoffset, dststrides) (src, srcspace) (srcoffset, srcstrides)
    Just cp' -> do
      shape' <- traverse (traverse (compileExp . untyped)) shape
      dst' <- compileVar dst
      src' <- compileVar src
      dstoffset' <- traverse (compileExp . untyped) dstoffset
      dststrides' <- traverse (traverse (compileExp . untyped)) dststrides
      srcoffset' <- traverse (compileExp . untyped) srcoffset
      srcstrides' <- traverse (traverse (compileExp . untyped)) srcstrides
      cp' t shape' dst' (dstoffset', dststrides') src' (srcoffset', srcstrides')
compileCode (Imp.Write dst (Imp.Count idx) pt space _ elemexp) = do
  dst' <- compileVar dst
  idx' <- compileExp $ Imp.untyped idx
  elemexp' <- compileExp elemexp
  generateWrite dst' idx' pt space elemexp'
compileCode (Imp.Read x src (Imp.Count iexp) pt space _) = do
  x' <- compileVar x
  iexp' <- compileExp $ untyped iexp
  src' <- compileVar src
  stm . Assign x' =<< generateRead src' iexp' pt space
compileCode Imp.Skip = pure ()

-- | A 'DoCopy' that emits a call to the runtime's @lmad_copy@, passing
-- the element ctype, the destination and source memory with their
-- offsets and strides, and the shape.
lmadcopyCPU :: DoCopy op s
lmadcopyCPU t shape dst (dstoffset, dststride) src (srcoffset, srcstride) =
  stm . Exp . simpleCall "lmad_copy" $
    [ Var (compilePrimType t),
      dst,
      unCount dstoffset,
      List (map unCount dststride),
      src,
      unCount srcoffset,
      List (map unCount srcstride),
      List (map unCount shape)
    ]
futhark-0.25.27/src/Futhark/CodeGen/Backends/GenericPython/000077500000000000000000000000001475065116200233315ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/CodeGen/Backends/GenericPython/AST.hs000066400000000000000000000130311475065116200243120ustar00rootroot00000000000000
module Futhark.CodeGen.Backends.GenericPython.AST
  ( PyExp (..),
    PyIdx (..),
    PyArg (..),
    PyStmt (..),
    module Language.Futhark.Core,
    PyProg (..),
    PyExcept (..),
    PyFunDef (..),
    PyClassDef (..),
  )
where

import Data.Text qualified as T
import Futhark.Util.Pretty
import Language.Futhark.Core
import Numeric (showHex)

data UnOp
  = -- | Boolean negation.
    Not
  | -- | Bitwise complement.
    Complement
  | -- | Numerical negation.
    Negate
  | -- | Absolute/numerical value.
    Abs
  deriving (Eq, Show)

-- | Python expressions.
data PyExp
  = Integer Integer
  | Bool Bool
  | Float Double
  | String T.Text
  | RawStringLiteral T.Text
  | Var String
  | BinOp String PyExp PyExp
  | UnOp String PyExp
  | Cond PyExp PyExp PyExp
  | Index PyExp PyIdx
  | Call PyExp [PyArg]
  | Tuple [PyExp]
  | List [PyExp]
  | Field PyExp String
  | Dict [(PyExp, PyExp)]
  | Lambda String PyExp
  | None
  deriving (Eq, Show)

-- | A subscript: either a slice @from:to@ or a single index.
data PyIdx
  = IdxRange PyExp PyExp
  | IdxExp PyExp
  deriving (Eq, Show)

-- | A call argument, either keyword or positional.
data PyArg
  = ArgKeyword String PyExp
  | Arg PyExp
  deriving (Eq, Show)

-- | Python statements.
data PyStmt
  = If PyExp [PyStmt] [PyStmt]
  | Try [PyStmt] [PyExcept]
  | While PyExp [PyStmt]
  | For String PyExp [PyStmt]
  | With PyExp [PyStmt]
  | Assign PyExp PyExp
  | AssignOp String PyExp PyExp
  | Comment String [PyStmt]
  | Assert PyExp PyExp
  | Raise PyExp
  | Exp PyExp
  | Return PyExp
  | Pass
  | -- Definition-like statements.
    Import String (Maybe String)
  | FunDef PyFunDef
  | ClassDef PyClassDef
  | -- Some arbitrary string of Python code.
    Escape T.Text
  deriving (Eq, Show)

-- | An @except@ clause; the caught exception is bound to @e@.
data PyExcept = Catch PyExp [PyStmt]
  deriving (Eq, Show)

-- | A function definition: name, parameter names, body.
data PyFunDef = Def String [String] [PyStmt]
  deriving (Eq, Show)

-- | A class definition: name and body.
data PyClassDef = Class String [PyStmt]
  deriving (Eq, Show)

-- | A complete Python program.
newtype PyProg = PyProg [PyStmt]
  deriving (Eq, Show)

-- | Render text as a double-quoted Python 3 string literal.
--
-- Printable ASCII is emitted as-is, with double quote and backslash
-- escaped; newline, carriage return and tab use their short escapes;
-- everything else uses a four- or eight-digit hexadecimal @u@/@U@
-- escape.  Haskell's 'show' is deliberately not used: its decimal
-- numeric escapes, control-character mnemonics and the empty @&@
-- escape mean something different (or nothing) in Python.
pyStringLiteral :: T.Text -> String
pyStringLiteral s = "\"" ++ concatMap escape (T.unpack s) ++ "\""
  where
    escape '"' = "\\\""
    escape '\\' = "\\\\"
    escape '\n' = "\\n"
    escape '\r' = "\\r"
    escape '\t' = "\\t"
    escape c
      | c >= ' ' && c <= '~' = [c]
      | fromEnum c < 0x10000 = hexEscape "\\u" 4 c
      | otherwise = hexEscape "\\U" 8 c
    hexEscape prefix width c =
      let digits = showHex (fromEnum c) ""
       in prefix ++ replicate (width - length digits) '0' ++ digits

instance Pretty PyIdx where
  pretty (IdxExp e) = pretty e
  pretty (IdxRange from to) = pretty from <> ":" <> pretty to

instance Pretty PyArg where
  pretty (ArgKeyword k e) = pretty k <> equals <> pretty e
  pretty (Arg e) = pretty e

instance Pretty PyExp where
  pretty (Integer x) = pretty x
  pretty (Bool x) = pretty x
  pretty (Float x)
    | isInfinite x = if x > 0 then "float('inf')" else "float('-inf')"
    | isNaN x = "float('nan')"
    | otherwise = pretty x
  pretty (String x) = pretty $ pyStringLiteral x
  pretty (RawStringLiteral s) = "\"\"\"" <> pretty s <> "\"\"\""
  pretty (Var n) = pretty $ map (\x -> if x == '\'' then 'm' else x) n
  pretty (Field e s) = pretty e <> "." <> pretty s
  pretty (BinOp s e1 e2) = parens (pretty e1 <+> pretty s <+> pretty e2)
  pretty (UnOp s e) = pretty s <> parens (pretty e)
  pretty (Cond e1 e2 e3) = pretty e2 <+> "if" <+> pretty e1 <+> "else" <+> pretty e3
  pretty (Index src idx) = pretty src <> brackets (pretty idx)
  pretty (Call fun exps) = pretty fun <> parens (commasep $ map pretty exps)
  pretty (Tuple [dim]) = parens (pretty dim <> ",")
  pretty (Tuple dims) = parens (commasep $ map pretty dims)
  pretty (List es) = brackets $ commasep $ map pretty es
  pretty (Dict kvs) = braces $ commasep $ map ppElem kvs
    where
      ppElem (k, v) = pretty k <> colon <+> pretty v
  pretty (Lambda p e) = "lambda" <+> pretty p <> ":" <+> pretty e
  pretty None = "None"

instance Pretty PyStmt where
  pretty (If cond [] []) =
    "if" <+> pretty cond <> ":"
      </> indent 2 "pass"
  pretty (If cond [] fbranch) =
    "if" <+> pretty cond <> ":"
      </> indent 2 "pass"
      </> "else:"
      </> indent 2 (stack $ map pretty fbranch)
  pretty (If cond tbranch []) =
    "if" <+> pretty cond <> ":"
      </> indent 2 (stack $ map pretty tbranch)
  pretty (If cond tbranch fbranch) =
    "if" <+> pretty cond <> ":"
      </> indent 2 (stack $ map pretty tbranch)
      </> "else:"
      </> indent 2 (stack $ map pretty fbranch)
  pretty (Try pystms pyexcepts) =
    "try:"
      </> indent 2 (stack $ map pretty pystms)
      </> stack (map pretty pyexcepts)
  pretty (While cond body) =
    "while" <+> pretty cond <> ":"
      </> indent 2 (stack $ map pretty body)
  pretty (For i what body) =
    "for" <+> pretty i <+> "in" <+> pretty what <> ":"
      </> indent 2 (stack $ map pretty body)
  pretty (With what body) =
    "with" <+> pretty what <> ":"
      </> indent 2 (stack $ map pretty body)
  pretty (Assign e1 e2) = pretty e1 <+> "=" <+> pretty e2
  pretty (AssignOp op e1 e2) = pretty e1 <+> pretty (op ++ "=") <+> pretty e2
  pretty (Comment s body) = "#" <> pretty s </> stack (map pretty body)
  pretty (Assert e1 e2) = "assert" <+> pretty e1 <> "," <+> pretty e2
  pretty (Raise e) = "raise" <+> pretty e
  pretty (Exp c) = pretty c
  pretty (Return e) = "return" <+> pretty e
  pretty Pass = "pass"
  pretty (Import from (Just as)) = "import" <+> pretty from <+> "as" <+> pretty as
  pretty (Import from Nothing) = "import" <+> pretty from
  pretty (FunDef d) = pretty d
  pretty (ClassDef d) = pretty d
  pretty (Escape s) = stack $ map pretty $ T.lines s

instance Pretty PyFunDef where
  pretty (Def fname params body) =
    "def" <+> pretty fname <> parens (commasep $ map pretty params) <> ":"
      </> indent 2 (stack (map pretty body))

instance Pretty PyClassDef where
  pretty (Class cname body) =
    "class" <+> pretty cname <> ":"
      </> indent 2 (stack (map pretty body))

instance Pretty PyExcept where
  pretty (Catch pyexp stms) =
    "except" <+> pretty pyexp <+> "as e:"
      </> indent 2 (vsep $ map pretty stms)

instance Pretty PyProg where
  pretty (PyProg stms) = vsep (map pretty stms)
futhark-0.25.27/src/Futhark/CodeGen/Backends/GenericPython/Options.hs000066400000000000000000000054721475065116200253300ustar00rootroot00000000000000
-- | This module defines a generator for @argparse@-based command-line
-- argument parsing.  Each option is associated with arbitrary
-- Python code that will perform side effects, usually by setting some
-- global variables.
module Futhark.CodeGen.Backends.GenericPython.Options ( Option (..), OptionArgument (..), generateOptionParser, ) where import Data.Text qualified as T import Futhark.CodeGen.Backends.GenericPython.AST -- | Specification if a single command line option. The option must -- have a long name, and may also have a short name. -- -- When the statement is being executed, the argument (if any) will be -- stored in the variable @optarg@. data Option = Option { optionLongName :: T.Text, optionShortName :: Maybe Char, optionArgument :: OptionArgument, optionAction :: [PyStmt] } -- | Whether an option accepts an argument. data OptionArgument = NoArgument | RequiredArgument String | OptionalArgument -- | Generate option parsing code that accepts the given command line options. Will read from @sys.argv@. -- -- If option parsing fails for any reason, the entire process will -- terminate with error code 1. generateOptionParser :: [Option] -> [PyStmt] generateOptionParser options = [ Assign (Var "parser") ( Call (Var "argparse.ArgumentParser") [ ArgKeyword "description" $ String "A compiled Futhark program." ] ) ] ++ map parseOption options ++ [ Assign (Var "parser_result") $ Call (Var "vars") [Arg $ Call (Var "parser.parse_args") [Arg $ Var "sys.argv[1:]"]] ] ++ map executeOption options where parseOption option = Exp $ Call (Var "parser.add_argument") $ map (Arg . String) name_args ++ argument_args where name_args = maybe id (\x l -> ("-" <> T.singleton x) : l) (optionShortName option) ["--" <> optionLongName option] argument_args = case optionArgument option of RequiredArgument t -> [ ArgKeyword "action" (String "append"), ArgKeyword "default" $ List [], ArgKeyword "type" $ Var t ] NoArgument -> [ ArgKeyword "action" (String "append_const"), ArgKeyword "default" $ List [], ArgKeyword "const" None ] OptionalArgument -> [ ArgKeyword "action" (String "append"), ArgKeyword "default" $ List [], ArgKeyword "nargs" $ String "?" 
] executeOption option = For "optarg" ( Index (Var "parser_result") $ IdxExp $ String $ fieldName option ) $ optionAction option fieldName = T.map escape . optionLongName where escape '-' = '_' escape c = c futhark-0.25.27/src/Futhark/CodeGen/Backends/GenericWASM.hs000066400000000000000000000257461475065116200231710ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-} module Futhark.CodeGen.Backends.GenericWASM ( GC.CParts (..), GC.asLibrary, GC.asExecutable, GC.asServer, EntryPointType, JSEntryPoint (..), emccExportNames, javascriptWrapper, extToString, runServer, libraryExports, ) where import Data.List (intercalate) import Data.Text qualified as T import Futhark.CodeGen.Backends.GenericC qualified as GC import Futhark.CodeGen.Backends.SimpleRep (opaqueName) import Futhark.CodeGen.ImpCode.Sequential qualified as Imp import Futhark.CodeGen.RTS.JavaScript import Futhark.Util (nubOrd, showText) import Language.Futhark.Primitive import NeatInterpolation (text) extToString :: Imp.ExternalValue -> String extToString (Imp.TransparentValue (Imp.ArrayValue vn _ pt s dimSize)) = concat (replicate (length dimSize) "[]") ++ extToString (Imp.TransparentValue (Imp.ScalarValue pt s vn)) extToString (Imp.TransparentValue (Imp.ScalarValue (FloatType Float16) _ _)) = "f16" extToString (Imp.TransparentValue (Imp.ScalarValue (FloatType Float32) _ _)) = "f32" extToString (Imp.TransparentValue (Imp.ScalarValue (FloatType Float64) _ _)) = "f64" extToString (Imp.TransparentValue (Imp.ScalarValue (IntType Int8) Imp.Signed _)) = "i8" extToString (Imp.TransparentValue (Imp.ScalarValue (IntType Int16) Imp.Signed _)) = "i16" extToString (Imp.TransparentValue (Imp.ScalarValue (IntType Int32) Imp.Signed _)) = "i32" extToString (Imp.TransparentValue (Imp.ScalarValue (IntType Int64) Imp.Signed _)) = "i64" extToString (Imp.TransparentValue (Imp.ScalarValue (IntType Int8) Imp.Unsigned _)) = "u8" extToString (Imp.TransparentValue (Imp.ScalarValue (IntType Int16) Imp.Unsigned _)) = "u16" 
extToString (Imp.TransparentValue (Imp.ScalarValue (IntType Int32) Imp.Unsigned _)) = "u32" extToString (Imp.TransparentValue (Imp.ScalarValue (IntType Int64) Imp.Unsigned _)) = "u64" extToString (Imp.TransparentValue (Imp.ScalarValue Bool _ _)) = "bool" extToString (Imp.TransparentValue (Imp.ScalarValue Unit _ _)) = error "extToString: Unit" extToString (Imp.OpaqueValue oname _) = T.unpack $ opaqueName oname type EntryPointType = String data JSEntryPoint = JSEntryPoint { name :: String, parameters :: [EntryPointType], ret :: [EntryPointType] } emccExportNames :: [JSEntryPoint] -> [String] emccExportNames jses = map (\jse -> "'_futhark_entry_" ++ T.unpack (GC.escapeName (T.pack (name jse))) ++ "'") jses ++ map (\arg -> "'" ++ gfn "new" arg ++ "'") arrays ++ map (\arg -> "'" ++ gfn "free" arg ++ "'") arrays ++ map (\arg -> "'" ++ gfn "shape" arg ++ "'") arrays ++ map (\arg -> "'" ++ gfn "values_raw" arg ++ "'") arrays ++ map (\arg -> "'" ++ gfn "values" arg ++ "'") arrays ++ map (\arg -> "'" ++ "_futhark_free_" ++ arg ++ "'") opaques ++ [ "_futhark_context_config_new", "_futhark_context_config_free", "_futhark_context_new", "_futhark_context_free", "_futhark_context_get_error" ] where arrays = filter isArray typs opaques = filter isOpaque typs typs = nubOrd $ concatMap (\jse -> parameters jse ++ ret jse) jses gfn typ str = "_futhark_" ++ typ ++ "_" ++ baseType str ++ "_" ++ show (dim str) ++ "d" javascriptWrapper :: [JSEntryPoint] -> T.Text javascriptWrapper entryPoints = T.unlines [ serverJs, valuesJs, wrapperclassesJs, classFutharkContext entryPoints ] classFutharkContext :: [JSEntryPoint] -> T.Text classFutharkContext entryPoints = T.unlines [ "class FutharkContext {", constructor entryPoints, getFreeFun, getEntryPointsFun, getErrorFun, T.unlines $ map toFutharkArray arrays, T.unlines $ map jsWrapEntryPoint entryPoints, "}", [text| async function newFutharkContext() { var wasm = await loadWASM(); return new FutharkContext(wasm); } |] ] where arrays = filter 
isArray typs typs = nubOrd $ concatMap (\jse -> parameters jse ++ ret jse) entryPoints constructor :: [JSEntryPoint] -> T.Text constructor jses = [text| constructor(wasm, num_threads) { this.wasm = wasm; this.cfg = this.wasm._futhark_context_config_new(); if (num_threads) this.wasm._futhark_context_config_set_num_threads(this.cfg, num_threads); this.ctx = this.wasm._futhark_context_new(this.cfg); this.entry_points = { ${entries} }; } |] where entries = T.intercalate "," $ map dicEntry jses getFreeFun :: T.Text getFreeFun = [text| free() { this.wasm._futhark_context_free(this.ctx); this.wasm._futhark_context_config_free(this.cfg); } |] getEntryPointsFun :: T.Text getEntryPointsFun = [text| get_entry_points() { return this.entry_points; } |] getErrorFun :: T.Text getErrorFun = [text| get_error() { var ptr = this.wasm._futhark_context_get_error(this.ctx); var len = HEAP8.subarray(ptr).indexOf(0); var str = String.fromCharCode(...HEAP8.subarray(ptr, ptr + len)); this.wasm._free(ptr); return str; } |] dicEntry :: JSEntryPoint -> T.Text dicEntry jse = [text| "${ename}" : ["${fname}", ${params}, ${rets}] |] where fname = GC.escapeName $ T.pack $ name jse ename = T.pack $ name jse params = showText $ parameters jse rets = showText $ ret jse jsWrapEntryPoint :: JSEntryPoint -> T.Text jsWrapEntryPoint jse = [text| ${func_name}(${inparams}) { var out = [${outparams}].map(n => this.wasm._malloc(n)); var to_free = []; var do_free = () => { out.forEach(this.wasm._free); to_free.forEach(f => f.free()); }; ${paramsToPtr} if (this.wasm._futhark_entry_${func_name}(this.ctx, ...out, ${ins}) > 0) { do_free(); throw this.get_error(); } ${results} do_free(); return ${res}; } |] where func_name = GC.escapeName $ T.pack $ name jse alp = [0 .. length (parameters jse) - 1] inparams = T.pack $ intercalate ", " ["in" ++ show i | i <- alp] ins = T.pack $ intercalate ", " [maybeDerefence ("in" ++ show i) $ parameters jse !! 
i | i <- alp] paramsToPtr = T.pack $ unlines $ filter ("" /=) [arrayPointer ("in" ++ show i) $ parameters jse !! i | i <- alp] alr = [0 .. length (ret jse) - 1] outparams = T.pack $ intercalate ", " [show $ typeSize $ ret jse !! i | i <- alr] results = T.pack $ unlines [makeResult i $ ret jse !! i | i <- alr] res_array = intercalate ", " ["result" ++ show i | i <- alr] res = T.pack $ if length (ret jse) == 1 then "result0" else "[" ++ res_array ++ "]" maybeDerefence :: String -> String -> String maybeDerefence arg typ = if isScalar typ then arg else arg ++ ".ptr" arrayPointer :: String -> String -> String arrayPointer arg typ = if isArray typ then " if (" ++ arg ++ " instanceof Array) { " ++ reassign ++ "; to_free.push(" ++ arg ++ "); }" else "" where reassign = arg ++ " = this.new_" ++ signature ++ "_from_jsarray(" ++ arg ++ ")" signature = baseType typ ++ "_" ++ show (dim typ) ++ "d" makeResult :: Int -> String -> String makeResult i typ = " var result" ++ show i ++ " = " ++ if isArray typ then "this.new_" ++ signature ++ "_from_ptr(" ++ readout ++ ");" else if isOpaque typ then "new FutharkOpaque(this, " ++ readout ++ ", this.wasm._futhark_free_" ++ typ ++ ");" else readout ++ if typ == "bool" then "!==0;" else ";" where res = "out[" ++ show i ++ "]" readout = typeHeap typ ++ "[" ++ res ++ " >> " ++ show (typeShift typ) ++ "]" signature = baseType typ ++ "_" ++ show (dim typ) ++ "d" baseType :: String -> String baseType ('[' : ']' : end) = baseType end baseType typ = typ dim :: String -> Int dim ('[' : ']' : end) = dim end + 1 dim _ = 0 isArray :: String -> Bool isArray typ = take 2 typ == "[]" isOpaque :: String -> Bool isOpaque typ = take 6 typ == "opaque" isScalar :: String -> Bool isScalar typ = not (isArray typ || isOpaque typ) typeSize :: String -> Integer typeSize typ = case typ of "i8" -> 1 "i16" -> 2 "i32" -> 4 "i64" -> 8 "u8" -> 1 "u16" -> 2 "u32" -> 4 "u64" -> 8 "f16" -> 2 "f32" -> 4 "f64" -> 8 "bool" -> 1 _ -> 4 typeShift :: String -> Integer 
typeShift typ = case typ of "i8" -> 0 "i16" -> 1 "i32" -> 2 "i64" -> 3 "u8" -> 0 "u16" -> 1 "u32" -> 2 "u64" -> 3 "f16" -> 1 "f32" -> 2 "f64" -> 3 "bool" -> 0 _ -> 2 typeHeap :: String -> String typeHeap typ = case typ of "i8" -> "this.wasm.HEAP8" "i16" -> "this.wasm.HEAP16" "i32" -> "this.wasm.HEAP32" "i64" -> "this.wasm.HEAP64" "u8" -> "this.wasm.HEAPU8" "u16" -> "this.wasm.HEAPU16" "u32" -> "this.wasm.HEAPU32" "u64" -> "(new BigUint64Array(this.wasm.HEAP64.buffer))" "f16" -> "this.wasm.HEAPU16" "f32" -> "this.wasm.HEAPF32" "f64" -> "this.wasm.HEAPF64" "bool" -> "this.wasm.HEAP8" _ -> "this.wasm.HEAP32" toFutharkArray :: String -> T.Text toFutharkArray typ = [text| ${new}_from_jsarray(${arraynd_p}) { return this.${new}(${arraynd_flat_p}, ${arraynd_dims_p}); } ${new}(array, ${dims}) { console.assert(array.length === ${dims_multiplied}, 'len=%s,dims=%s', array.length, [${dims}].toString()); var copy = this.wasm._malloc(array.length << ${shift}); ${heapType}.set(array, copy >> ${shift}); var ptr = ${fnew}(this.ctx, copy, ${bigint_dims}); this.wasm._free(copy); return this.${new}_from_ptr(ptr); } ${new}_from_ptr(ptr) { return new FutharkArray(this, ptr, ${args}); } |] where d = dim typ ftype = baseType typ heap = typeHeap ftype signature = ftype ++ "_" ++ show d ++ "d" new = T.pack $ "new_" ++ signature fnew = T.pack $ "this.wasm._futhark_new_" ++ signature fshape = "this.wasm._futhark_shape_" ++ signature fvalues = "this.wasm._futhark_values_raw_" ++ signature ffree = "this.wasm._futhark_free_" ++ signature arraynd = "array" ++ show d ++ "d" shift = showText (typeShift ftype) heapType = T.pack heap arraynd_flat = if d > 1 then arraynd ++ ".flat()" else arraynd arraynd_dims = intercalate ", " [arraynd ++ mult i "[0]" ++ ".length" | i <- [0 .. d - 1]] dims = T.pack $ intercalate ", " ["d" ++ show i | i <- [0 .. d - 1]] dims_multiplied = T.pack $ intercalate "*" ["Number(d" ++ show i ++ ")" | i <- [0 .. 
d - 1]] bigint_dims = T.pack $ intercalate ", " ["BigInt(d" ++ show i ++ ")" | i <- [0 .. d - 1]] mult i s = concat $ replicate i s (arraynd_p, arraynd_flat_p, arraynd_dims_p) = (T.pack arraynd, T.pack arraynd_flat, T.pack arraynd_dims) args = T.pack $ intercalate ", " ["'" ++ ftype ++ "'", show d, heap, fshape, fvalues, ffree] -- | Javascript code that can be appended to the generated module to -- run a Futhark server instance on startup. runServer :: T.Text runServer = [text| Module.onRuntimeInitialized = () => { var context = new FutharkContext(Module); var server = new Server(context); server.run(); }|] -- | The names exported by the generated module. libraryExports :: T.Text libraryExports = "export {newFutharkContext, FutharkContext, FutharkArray, FutharkOpaque};" futhark-0.25.27/src/Futhark/CodeGen/Backends/HIP.hs000066400000000000000000000102501475065116200215250ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-} -- | Code generation for HIP. module Futhark.CodeGen.Backends.HIP ( compileProg, GC.CParts (..), GC.asLibrary, GC.asExecutable, GC.asServer, ) where import Data.Map qualified as M import Data.Text qualified as T import Futhark.CodeGen.Backends.GPU import Futhark.CodeGen.Backends.GenericC qualified as GC import Futhark.CodeGen.Backends.GenericC.Options import Futhark.CodeGen.ImpCode.OpenCL import Futhark.CodeGen.ImpGen.HIP qualified as ImpGen import Futhark.CodeGen.RTS.C (backendsHipH) import Futhark.IR.GPUMem hiding ( CmpSizeLe, GetSize, GetSizeMax, ) import Futhark.MonadFreshNames import Language.C.Quote.OpenCL qualified as C import NeatInterpolation (untrimming) mkBoilerplate :: T.Text -> [(Name, KernelConstExp)] -> M.Map Name KernelSafety -> [PrimType] -> [FailureMsg] -> GC.CompilerM OpenCL () () mkBoilerplate hip_program macros kernels types failures = do generateGPUBoilerplate hip_program macros backendsHipH (M.keys kernels) types failures GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_add_build_option(struct 
futhark_context_config *cfg, const char* opt);|] GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_set_device(struct futhark_context_config *cfg, const char* s);|] GC.headerDecl GC.InitDecl [C.cedecl|const char* futhark_context_config_get_program(struct futhark_context_config *cfg);|] GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_set_program(struct futhark_context_config *cfg, const char* s);|] cliOptions :: [Option] cliOptions = gpuOptions ++ [ Option { optionLongName = "dump-hip", optionShortName = Nothing, optionArgument = RequiredArgument "FILE", optionDescription = "Dump the embedded HIP kernels to the indicated file.", optionAction = [C.cstm|{const char* prog = futhark_context_config_get_program(cfg); if (dump_file(optarg, prog, strlen(prog)) != 0) { fprintf(stderr, "%s: %s\n", optarg, strerror(errno)); exit(1); } exit(0);}|] }, Option { optionLongName = "load-hip", optionShortName = Nothing, optionArgument = RequiredArgument "FILE", optionDescription = "Instead of using the embedded HIP kernels, load them from the indicated file.", optionAction = [C.cstm|{ size_t n; const char *s = slurp_file(optarg, &n); if (s == NULL) { fprintf(stderr, "%s: %s\n", optarg, strerror(errno)); exit(1); } futhark_context_config_set_program(cfg, s); }|] }, Option { optionLongName = "build-option", optionShortName = Nothing, optionArgument = RequiredArgument "OPT", optionDescription = "Add an additional build option to the string passed to HIPRTC.", optionAction = [C.cstm|futhark_context_config_add_build_option(cfg, optarg);|] } ] hipMemoryType :: GC.MemoryType OpenCL () hipMemoryType "device" = pure [C.cty|typename hipDeviceptr_t|] hipMemoryType space = error $ "GPU backend does not support '" ++ space ++ "' memory space." -- | Compile the program to C with calls to HIP. 
compileProg :: (MonadFreshNames m) => T.Text -> Prog GPUMem -> m (ImpGen.Warnings, GC.CParts) compileProg version prog = do ( ws, Program hip_code hip_prelude macros kernels types params failures prog' ) <- ImpGen.compileProg prog (ws,) <$> GC.compileProg "hip" version params operations (mkBoilerplate (hip_prelude <> hip_code) macros kernels types failures) hip_includes (Space "device", [Space "device", DefaultSpace]) cliOptions prog' where operations :: GC.Operations OpenCL () operations = gpuOperations { GC.opsMemoryType = hipMemoryType } hip_includes = [untrimming| #define __HIP_PLATFORM_AMD__ #include #include |] futhark-0.25.27/src/Futhark/CodeGen/Backends/MulticoreC.hs000066400000000000000000000443431475065116200231650ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-} -- | C code generator. This module can convert a correct ImpCode -- program to an equivalent C program. module Futhark.CodeGen.Backends.MulticoreC ( compileProg, GC.CParts (..), GC.asLibrary, GC.asExecutable, GC.asServer, operations, cliOptions, compileOp, ValueType (..), paramToCType, prepareTaskStruct, closureFreeStructField, generateParLoopFn, addTimingFields, functionTiming, functionIterations, multicoreDef, multicoreName, DefSpecifier, atomicOps, ) where import Control.Monad import Data.Loc import Data.Map qualified as M import Data.Maybe import Data.Text qualified as T import Futhark.CodeGen.Backends.GenericC qualified as GC import Futhark.CodeGen.Backends.GenericC.Options import Futhark.CodeGen.Backends.MulticoreC.Boilerplate (generateBoilerplate) import Futhark.CodeGen.Backends.SimpleRep import Futhark.CodeGen.ImpCode.Multicore hiding (ValueType) import Futhark.CodeGen.ImpGen.Multicore qualified as ImpGen import Futhark.IR.MCMem (MCMem, Prog) import Futhark.MonadFreshNames import Language.C.Quote.OpenCL qualified as C import Language.C.Syntax qualified as C -- | Compile the program to ImpCode with multicore operations. 
compileProg :: (MonadFreshNames m) => T.Text -> Prog MCMem -> m (ImpGen.Warnings, GC.CParts) compileProg version = traverse ( GC.compileProg "multicore" version mempty operations generateBoilerplate "#include \n" (DefaultSpace, [DefaultSpace]) cliOptions ) <=< ImpGen.compileProg -- | Multicore-related command line options. cliOptions :: [Option] cliOptions = [ Option { optionLongName = "num-threads", optionShortName = Nothing, optionArgument = RequiredArgument "INT", optionAction = [C.cstm|futhark_context_config_set_num_threads(cfg, atoi(optarg));|], optionDescription = "Set number of threads used for execution." } ] -- | Operations for generating multicore code. operations :: GC.Operations Multicore s operations = GC.defaultOperations { GC.opsCompiler = compileOp, GC.opsCritical = -- The thread entering an API function is always considered -- the "first worker" - note that this might differ from the -- thread that created the context! This likely only matters -- for entry points, since they are the only API functions -- that contain parallel operations. 
( [C.citems|worker_local = &ctx->scheduler.workers[0];|], [] ) } closureFreeStructField :: VName -> Name closureFreeStructField v = nameFromString "free_" <> nameFromString (prettyString v) closureRetvalStructField :: VName -> Name closureRetvalStructField v = nameFromString "retval_" <> nameFromString (prettyString v) data ValueType = Prim PrimType | MemBlock | RawMem compileFreeStructFields :: [VName] -> [(C.Type, ValueType)] -> [C.FieldGroup] compileFreeStructFields = zipWith field where field name (ty, Prim _) = [C.csdecl|$ty:ty $id:(closureFreeStructField name);|] field name (_, _) = [C.csdecl|$ty:defaultMemBlockType $id:(closureFreeStructField name);|] compileRetvalStructFields :: [VName] -> [(C.Type, ValueType)] -> [C.FieldGroup] compileRetvalStructFields = zipWith field where field name (ty, Prim _) = [C.csdecl|$ty:ty *$id:(closureRetvalStructField name);|] field name (_, _) = [C.csdecl|$ty:defaultMemBlockType $id:(closureRetvalStructField name);|] compileSetStructValues :: (C.ToIdent a) => a -> [VName] -> [(C.Type, ValueType)] -> [C.Stm] compileSetStructValues struct = zipWith field where field name (_, Prim pt) = [C.cstm|$id:struct.$id:(closureFreeStructField name)=$exp:(toStorage pt (C.toExp name noLoc));|] field name (_, MemBlock) = [C.cstm|$id:struct.$id:(closureFreeStructField name)=$id:name.mem;|] field name (_, RawMem) = [C.cstm|$id:struct.$id:(closureFreeStructField name)=$id:name;|] compileSetRetvalStructValues :: (C.ToIdent a) => a -> [VName] -> [(C.Type, ValueType)] -> [C.Stm] compileSetRetvalStructValues struct vnames we = concat $ zipWith field vnames we where field name (ct, Prim _) = [C.cstms|$id:struct.$id:(closureRetvalStructField name)=(($ty:ct*)&$id:name); $escstm:("#if defined(ISPC)") $id:struct.$id:(closureRetvalStructField name)+= programIndex; $escstm:("#endif")|] field name (_, MemBlock) = [C.cstms|$id:struct.$id:(closureRetvalStructField name)=$id:name.mem;|] field name (_, RawMem) = 
[C.cstms|$id:struct.$id:(closureRetvalStructField name)=$id:name;|] compileGetRetvalStructVals :: (C.ToIdent a) => a -> [VName] -> [(C.Type, ValueType)] -> [C.InitGroup] compileGetRetvalStructVals struct = zipWith field where field name (ty, Prim pt) = let inner = [C.cexp|*$id:struct->$id:(closureRetvalStructField name)|] in [C.cdecl|$ty:ty $id:name = $exp:(fromStorage pt inner);|] field name (ty, _) = [C.cdecl|$ty:ty $id:name = {.desc = $string:(prettyString name), .mem = $id:struct->$id:(closureRetvalStructField name), .size = 0, .references = NULL};|] compileGetStructVals :: (C.ToIdent a) => a -> [VName] -> [(C.Type, ValueType)] -> [C.InitGroup] compileGetStructVals struct = zipWith field where field name (ty, Prim pt) = let inner = [C.cexp|$id:struct->$id:(closureFreeStructField name)|] in [C.cdecl|$ty:ty $id:name = $exp:(fromStorage pt inner);|] field name (ty, _) = [C.cdecl|$ty:ty $id:name = {.desc = $string:(prettyString name), .mem = $id:struct->$id:(closureFreeStructField name), .size = 0, .references = NULL};|] compileWriteBackResVals :: (C.ToIdent a) => a -> [VName] -> [(C.Type, ValueType)] -> [C.Stm] compileWriteBackResVals struct = zipWith field where field name (_, Prim pt) = [C.cstm|*$id:struct->$id:(closureRetvalStructField name) = $exp:(toStorage pt (C.toExp name noLoc));|] field name (_, _) = [C.cstm|$id:struct->$id:(closureRetvalStructField name) = $id:name.mem;|] paramToCType :: Param -> GC.CompilerM op s (C.Type, ValueType) paramToCType (ScalarParam _ pt) = do let t = primStorageType pt pure (t, Prim pt) paramToCType (MemParam name space') = mcMemToCType name space' mcMemToCType :: VName -> Space -> GC.CompilerM op s (C.Type, ValueType) mcMemToCType v space = do refcount <- GC.fatMemory space cached <- isJust <$> GC.cacheMem v pure ( GC.fatMemType space, if refcount && not cached then MemBlock else RawMem ) benchmarkCode :: Name -> [C.BlockItem] -> GC.CompilerM op s [C.BlockItem] benchmarkCode name code = do event <- newVName "event" pure 
[C.citems| struct mc_event* $id:event = mc_event_new(ctx); if ($id:event != NULL) { $id:event->bef = get_wall_time(); } $items:code if ($id:event != NULL) { $id:event->aft = get_wall_time(); lock_lock(&ctx->event_list_lock); add_event(ctx, $string:(nameToString name), strdup("nothing further"), $id:event, (typename event_report_fn)mc_event_report); lock_unlock(&ctx->event_list_lock); }|] functionTiming :: Name -> C.Id functionTiming = (`C.toIdent` mempty) . (<> "_total_time") functionIterations :: Name -> C.Id functionIterations = (`C.toIdent` mempty) . (<> "_total_iter") addTimingFields :: Name -> GC.CompilerM op s () addTimingFields name = do GC.contextField (functionTiming name) [C.cty|typename int64_t|] $ Just [C.cexp|0|] GC.contextField (functionIterations name) [C.cty|typename int64_t|] $ Just [C.cexp|0|] multicoreName :: String -> GC.CompilerM op s Name multicoreName s = do s' <- newVName ("futhark_mc_" ++ s) pure $ nameFromString $ baseString s' ++ "_" ++ show (baseTag s') type DefSpecifier s = String -> (Name -> GC.CompilerM Multicore s C.Definition) -> GC.CompilerM Multicore s Name multicoreDef :: DefSpecifier s multicoreDef s f = do s' <- multicoreName s GC.libDecl =<< f s' pure s' generateParLoopFn :: (C.ToIdent a) => M.Map VName Space -> String -> MCCode -> a -> [(VName, (C.Type, ValueType))] -> [(VName, (C.Type, ValueType))] -> GC.CompilerM Multicore s Name generateParLoopFn lexical basename code fstruct free retval = do let (fargs, fctypes) = unzip free let (retval_args, retval_ctypes) = unzip retval multicoreDef basename $ \s -> do fbody <- benchmarkCode s <=< GC.inNewFunction $ GC.cachingMemory lexical $ \decl_cached free_cached -> GC.collect $ do mapM_ GC.item [C.citems|$decls:(compileGetStructVals fstruct fargs fctypes)|] mapM_ GC.item [C.citems|$decls:(compileGetRetvalStructVals fstruct retval_args retval_ctypes)|] code' <- GC.collect $ GC.compileCode code mapM_ GC.item decl_cached mapM_ GC.item =<< GC.declAllocatedMem mapM_ GC.item code' 
free_mem <- GC.freeAllocatedMem GC.stm [C.cstm|cleanup: {$stms:free_cached $items:free_mem}|] pure [C.cedecl|int $id:s(void *args, typename int64_t iterations, int tid, struct scheduler_info info) { int err = 0; int subtask_id = tid; struct $id:fstruct *$id:fstruct = (struct $id:fstruct*) args; struct futhark_context *ctx = $id:fstruct->ctx; $items:fbody if (err == 0) { $stms:(compileWriteBackResVals fstruct retval_args retval_ctypes) } return err; }|] prepareTaskStruct :: DefSpecifier s -> String -> [VName] -> [(C.Type, ValueType)] -> [VName] -> [(C.Type, ValueType)] -> GC.CompilerM Multicore s Name prepareTaskStruct def name free_args free_ctypes retval_args retval_ctypes = do let makeStruct s = pure [C.cedecl|struct $id:s { struct futhark_context *ctx; $sdecls:(compileFreeStructFields free_args free_ctypes) $sdecls:(compileRetvalStructFields retval_args retval_ctypes) };|] fstruct <- def name makeStruct let fstruct' = fstruct <> "_" GC.decl [C.cdecl|struct $id:fstruct $id:fstruct';|] GC.stm [C.cstm|$id:fstruct'.ctx = ctx;|] GC.stms [C.cstms|$stms:(compileSetStructValues fstruct' free_args free_ctypes)|] GC.stms [C.cstms|$stms:(compileSetRetvalStructValues fstruct' retval_args retval_ctypes)|] pure fstruct -- Generate a segop function for top_level and potentially nested SegOp code compileOp :: GC.OpCompiler Multicore s compileOp (GetLoopBounds start end) = do GC.stm [C.cstm|$id:start = start;|] GC.stm [C.cstm|$id:end = end;|] compileOp (GetTaskId v) = GC.stm [C.cstm|$id:v = subtask_id;|] compileOp (GetNumTasks v) = GC.stm [C.cstm|$id:v = info.nsubtasks;|] compileOp (SegOp name params seq_task par_task retvals (SchedulerInfo e sched)) = do let (ParallelTask seq_code) = seq_task free_ctypes <- mapM paramToCType params retval_ctypes <- mapM paramToCType retvals let free_args = map paramName params retval_args = map paramName retvals free = zip free_args free_ctypes retval = zip retval_args retval_ctypes e' <- GC.compileExp e let lexical = lexicalMemoryUsageMC 
TraverseKernels $ Function Nothing [] params seq_code fstruct <- prepareTaskStruct multicoreDef "task" free_args free_ctypes retval_args retval_ctypes fpar_task <- generateParLoopFn lexical (name ++ "_task") seq_code fstruct free retval addTimingFields fpar_task let ftask_name = fstruct <> "_task" GC.decl [C.cdecl|struct scheduler_segop $id:ftask_name;|] GC.stm [C.cstm|$id:ftask_name.args = &$id:(fstruct <> "_");|] GC.stm [C.cstm|$id:ftask_name.top_level_fn = $id:fpar_task;|] GC.stm [C.cstm|$id:ftask_name.name = $string:(nameToString fpar_task);|] GC.stm [C.cstm|$id:ftask_name.iterations = $exp:e';|] -- Create the timing fields for the task GC.stm [C.cstm|$id:ftask_name.task_time = &ctx->program->$id:(functionTiming fpar_task);|] GC.stm [C.cstm|$id:ftask_name.task_iter = &ctx->program->$id:(functionIterations fpar_task);|] case sched of Dynamic -> GC.stm [C.cstm|$id:ftask_name.sched = DYNAMIC;|] Static -> GC.stm [C.cstm|$id:ftask_name.sched = STATIC;|] -- Generate the nested segop function if available case par_task of Just (ParallelTask nested_code) -> do let lexical_nested = lexicalMemoryUsageMC TraverseKernels $ Function Nothing [] params nested_code fnpar_task <- generateParLoopFn lexical_nested (name ++ "_nested_task") nested_code fstruct free retval GC.stm [C.cstm|$id:ftask_name.nested_fn = $id:fnpar_task;|] Nothing -> GC.stm [C.cstm|$id:ftask_name.nested_fn=NULL;|] let ftask_err = fpar_task <> "_err" code = [C.citems|int $id:ftask_err = scheduler_prepare_task(&ctx->scheduler, &$id:ftask_name); if ($id:ftask_err != 0) { err = $id:ftask_err; goto cleanup; }|] mapM_ GC.item code compileOp (ParLoop s' body free) = do free_ctypes <- mapM paramToCType free let free_args = map paramName free let lexical = lexicalMemoryUsageMC TraverseKernels $ Function Nothing [] free body fstruct <- prepareTaskStruct multicoreDef (s' ++ "_parloop_struct") free_args free_ctypes mempty mempty ftask <- multicoreDef (s' ++ "_parloop") $ \s -> do fbody <- benchmarkCode s <=< 
GC.inNewFunction $ GC.cachingMemory lexical $ \decl_cached free_cached -> GC.collect $ do GC.items [C.citems|$decls:(compileGetStructVals fstruct free_args free_ctypes)|] GC.decl [C.cdecl|typename int64_t iterations = end-start;|] body' <- GC.collect $ GC.compileCode body mapM_ GC.item decl_cached mapM_ GC.item =<< GC.declAllocatedMem free_mem <- GC.freeAllocatedMem mapM_ GC.item body' GC.stm [C.cstm|cleanup: {$stms:free_cached $items:free_mem}|] pure [C.cedecl|static int $id:s(void *args, typename int64_t start, typename int64_t end, int subtask_id, int tid) { (void)subtask_id; (void)tid; int err = 0; struct $id:fstruct *$id:fstruct = (struct $id:fstruct*) args; struct futhark_context *ctx = $id:fstruct->ctx; $items:fbody return err; }|] let ftask_name = ftask <> "_task" GC.decl [C.cdecl|struct scheduler_parloop $id:ftask_name;|] GC.stm [C.cstm|$id:ftask_name.name = $string:(nameToString ftask);|] GC.stm [C.cstm|$id:ftask_name.fn = $id:ftask;|] GC.stm [C.cstm|$id:ftask_name.args = &$id:(fstruct <> "_");|] GC.stm [C.cstm|$id:ftask_name.iterations = iterations;|] GC.stm [C.cstm|$id:ftask_name.info = info;|] let ftask_err = ftask <> "_err" ftask_total = ftask <> "_total" code' <- benchmarkCode ftask_total [C.citems|int $id:ftask_err = scheduler_execute_task(&ctx->scheduler, &$id:ftask_name); if ($id:ftask_err != 0) { err = $id:ftask_err; goto cleanup; }|] mapM_ GC.item code' compileOp (Atomic aop) = atomicOps aop (\ty _ -> pure [C.cty|$ty:ty*|]) compileOp (ISPCKernel body _) = scopedBlock body compileOp (ForEach i from bound body) = do let i' = C.toIdent i t = primTypeToCType $ primExpType bound from' <- GC.compileExp from bound' <- GC.compileExp bound body' <- GC.collect $ GC.compileCode body GC.stm [C.cstm|for ($ty:t $id:i' = $exp:from'; $id:i' < $exp:bound'; $id:i'++) { $items:body' }|] compileOp (ForEachActive i body) = do GC.decl [C.cdecl|typename int64_t $id:i = 0;|] scopedBlock body compileOp (ExtractLane dest tar _) = do tar' <- GC.compileExp tar GC.stm 
[C.cstm|$id:dest = $exp:tar';|] scopedBlock :: MCCode -> GC.CompilerM Multicore s () scopedBlock code = do inner <- GC.collect $ GC.compileCode code GC.stm [C.cstm|{$items:inner}|] doAtomic :: (C.ToIdent a1) => a1 -> VName -> Count u (TExp Int32) -> Exp -> String -> C.Type -> (C.Type -> VName -> GC.CompilerM op s C.Type) -> GC.CompilerM op s () doAtomic old arr ind val op ty castf = do ind' <- GC.compileExp $ untyped $ unCount ind val' <- GC.compileExp val cast <- castf ty arr arr' <- GC.rawMem arr GC.stm [C.cstm|$id:old = $id:op(&(($ty:cast)$exp:arr')[$exp:ind'], ($ty:ty) $exp:val', __ATOMIC_RELAXED);|] atomicOps :: AtomicOp -> (C.Type -> VName -> GC.CompilerM op s C.Type) -> GC.CompilerM op s () atomicOps (AtomicCmpXchg t old arr ind res val) castf = do ind' <- GC.compileExp $ untyped $ unCount ind new_val' <- GC.compileExp val cast <- castf [C.cty|$ty:(GC.primTypeToCType t)|] arr arr' <- GC.rawMem arr GC.stm [C.cstm|$id:res = $id:op(&(($ty:cast)$exp:arr')[$exp:ind'], &$id:old, $exp:new_val', 0, __ATOMIC_SEQ_CST, __ATOMIC_RELAXED);|] where op :: String op = "__atomic_compare_exchange_n" atomicOps (AtomicXchg t old arr ind val) castf = do ind' <- GC.compileExp $ untyped $ unCount ind val' <- GC.compileExp val cast <- castf [C.cty|$ty:(GC.primTypeToCType t)|] arr GC.stm [C.cstm|$id:old = $id:op(&(($ty:cast)$id:arr.mem)[$exp:ind'], $exp:val', __ATOMIC_SEQ_CST);|] where op :: String op = "__atomic_exchange_n" atomicOps (AtomicAdd t old arr ind val) castf = doAtomic old arr ind val "__atomic_fetch_add" [C.cty|$ty:(GC.intTypeToCType t)|] castf atomicOps (AtomicSub t old arr ind val) castf = doAtomic old arr ind val "__atomic_fetch_sub" [C.cty|$ty:(GC.intTypeToCType t)|] castf atomicOps (AtomicAnd t old arr ind val) castf = doAtomic old arr ind val "__atomic_fetch_and" [C.cty|$ty:(GC.intTypeToCType t)|] castf atomicOps (AtomicOr t old arr ind val) castf = doAtomic old arr ind val "__atomic_fetch_or" [C.cty|$ty:(GC.intTypeToCType t)|] castf atomicOps (AtomicXor t old arr 
ind val) castf = doAtomic old arr ind val "__atomic_fetch_xor" [C.cty|$ty:(GC.intTypeToCType t)|] castf futhark-0.25.27/src/Futhark/CodeGen/Backends/MulticoreC/000077500000000000000000000000001475065116200226215ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/CodeGen/Backends/MulticoreC/Boilerplate.hs000066400000000000000000000014131475065116200254160ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-} -- | Boilerplate for multicore C code. module Futhark.CodeGen.Backends.MulticoreC.Boilerplate (generateBoilerplate) where import Data.Text qualified as T import Futhark.CodeGen.Backends.GenericC qualified as GC import Futhark.CodeGen.RTS.C (backendsMulticoreH, schedulerH) import Language.C.Quote.OpenCL qualified as C -- | Generate the necessary boilerplate. generateBoilerplate :: GC.CompilerM op s () generateBoilerplate = do mapM_ GC.earlyDecl [C.cunit|$esc:(T.unpack schedulerH)|] mapM_ GC.earlyDecl [C.cunit|$esc:(T.unpack backendsMulticoreH)|] GC.headerDecl GC.InitDecl [C.cedecl|void futhark_context_config_set_num_threads(struct futhark_context_config *cfg, int n);|] GC.generateProgramStruct {-# NOINLINE generateBoilerplate #-} futhark-0.25.27/src/Futhark/CodeGen/Backends/MulticoreISPC.hs000066400000000000000000001224711475065116200235400ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-} -- | C code generator. This module can convert a correct ImpCode -- program to an equivalent ISPC program. 
module Futhark.CodeGen.Backends.MulticoreISPC
  ( compileProg,
    GC.CParts (..),
    GC.asLibrary,
    GC.asExecutable,
    GC.asServer,
    operations,
    ISPCState,
  )
where

import Control.Lens (each, over)
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor
import Data.DList qualified as DL
import Data.List (unzip4)
import Data.Map qualified as M
import Data.Maybe
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericC qualified as GC
import Futhark.CodeGen.Backends.GenericC.Pretty
import Futhark.CodeGen.Backends.MulticoreC qualified as MC
import Futhark.CodeGen.Backends.MulticoreC.Boilerplate (generateBoilerplate)
import Futhark.CodeGen.Backends.SimpleRep
import Futhark.CodeGen.ImpCode.Multicore
import Futhark.CodeGen.ImpGen.Multicore qualified as ImpGen
import Futhark.CodeGen.RTS.C (errorsH, ispcUtilH, uniformH)
import Futhark.IR.MCMem (MCMem, Prog)
import Futhark.IR.Prop (isBuiltInFunction)
import Futhark.MonadFreshNames
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C
import NeatInterpolation (untrimming)

type ISPCCompilerM a = GC.CompilerM Multicore ISPCState a

-- | Transient state tracked by the ISPC backend.
data ISPCState = ISPCState
  { -- | Definitions destined for the generated ISPC source file.
    sDefs :: DL.DList C.Definition,
    sUniform :: Names
  }

-- | The ISPC @uniform@ type qualifier, emitted verbatim.
uniform :: C.TypeQual
uniform = C.EscTypeQual "uniform" noLoc

-- | The ISPC @unmasked@ qualifier, emitted verbatim.
unmasked :: C.TypeQual
unmasked = C.EscTypeQual "unmasked" noLoc

-- | The ISPC @export@ qualifier, emitted verbatim.
export :: C.TypeQual
export = C.EscTypeQual "export" noLoc

-- | The ISPC @varying@ type qualifier, emitted verbatim.
varying :: C.TypeQual
varying = C.EscTypeQual "varying" noLoc

-- | Compile the program to C and ISPC code using multicore operations.
-- The result pairs the ordinary C parts with the text of the ISPC
-- source file.
compileProg ::
  (MonadFreshNames m) =>
  T.Text ->
  Prog MCMem ->
  m (ImpGen.Warnings, (GC.CParts, T.Text))
compileProg version prog = do
  (ws, defs) <- ImpGen.compileProg prog
  let Functions funs = defFuns defs
  -- 'traverse' over the pair runs the C compiler on @defs@ and passes
  -- the warnings through unchanged.
  (ws', (cparts, endstate)) <-
    traverse
      ( GC.compileProg'
          "ispc"
          version
          mempty
          operations
          (ISPCState mempty mempty)
          ( do
              generateBoilerplate
              mapM_ compileBuiltinFun funs
          )
          -- NOTE(review): the include target of this directive appears to
          -- be empty; confirm the intended header.
          "#include \n"
          (DefaultSpace, [DefaultSpace])
          MC.cliOptions
      )
      (ws, defs)
  let ispc_decls = definitionsText $ DL.toList $ sDefs $ GC.compUserState endstate

  -- The bool #define is a workaround for an ISPC bug: stdbool doesn't
  -- get included.  Every preprocessor directive must sit on its own
  -- line, otherwise the directive swallows the rest of the prelude.
  let ispcdefs =
        [untrimming|
#define bool uint8
typedef int64 int64_t;
typedef int32 int32_t;
typedef int16 int16_t;
typedef int8 int8_t;
typedef int8 char;
typedef unsigned int64 uint64_t;
typedef unsigned int32 uint32_t;
typedef unsigned int16 uint16_t;
typedef unsigned int8 uint8_t;
#define volatile

#define SCALAR_FUN_ATTR static inline
$errorsH

#define INFINITY (floatbits((uniform int)0x7f800000))
#define NAN (floatbits((uniform int)0x7fc00000))
#define fabs(x) abs(x)
#define FUTHARK_F64_ENABLED
$cScalarDefs

$uniformH

$ispcUtilH

$ispc_decls|]

  pure (ws', (cparts, ispcdefs))

-- | Compiler operations specific to the ISPC multicore backend.
operations :: GC.Operations Multicore ISPCState
operations =
  MC.operations
    { GC.opsCompiler = compileOp,
      -- FIXME: the default codegen for LMAD copies does not work for ISPC.
      GC.opsCopies = mempty
    }

-- | Append a definition to the ISPC output.
ispcDecl :: C.Definition -> ISPCCompilerM ()
ispcDecl def =
  GC.modifyUserState (\s -> s {sDefs = sDefs s <> DL.singleton def})

-- | Prepend a definition to the ISPC output, so it precedes everything
-- added so far.
ispcEarlyDecl :: C.Definition -> ISPCCompilerM ()
ispcEarlyDecl def =
  GC.modifyUserState (\s -> s {sDefs = DL.singleton def <> sDefs s})

-- | Make a fresh multicore name, add the definition built from it to
-- the ISPC output, and return the name.
ispcDef :: MC.DefSpecifier ISPCState
ispcDef s f = do
  s' <- MC.multicoreName s
  ispcDecl =<< f s'
  pure s'

-- | Expose a struct to both ISPC and C.
sharedDef :: MC.DefSpecifier ISPCState
sharedDef s f = do
  s' <- MC.multicoreName s
  -- Build the definition once so the ISPC and C sides are guaranteed to
  -- receive the very same definition.
  def <- f s'
  ispcDecl def
  GC.earlyDecl def
  pure s'

-- | ISPC has no string literals, so this makes one in C and exposes it via an
-- external function, returning the name.
makeStringLiteral :: String -> ISPCCompilerM Name
makeStringLiteral str = do
  -- C side: a function returning the literal.
  name <- MC.multicoreDef "strlit_shim" $ \s ->
    pure [C.cedecl|char* $id:s() { return $string:str; }|]
  -- ISPC side: an extern declaration of that function.
  ispcDecl
    [C.cedecl|extern "C" $tyqual:unmasked $tyqual:uniform char* $tyqual:uniform $id:name();|]
  pure name

-- | Set memory in ISPC.  The C text of the source expression is passed
-- as the description string; on failure the generated code returns 1.
setMem :: (C.ToExp a, C.ToExp b) => a -> b -> Space -> ISPCCompilerM ()
setMem dest src space = do
  let src_s = T.unpack $ expText $ C.toExp src noLoc
  strlit <- makeStringLiteral src_s
  GC.stm
    [C.cstm|if ($id:(GC.fatMemSet space)(ctx, &$exp:dest, &$exp:src, $id:strlit()) != 0) {
          $escstm:("unmasked { return 1; }")
        }|]

-- | Unref memory in ISPC.  Lexically cached memory is skipped here
-- ('cachingMemory' frees it separately).
unRefMem :: (C.ToExp a) => a -> Space -> ISPCCompilerM ()
unRefMem mem space = do
  cached <- isJust <$> GC.cacheMem mem
  -- Only build the description-string shim when an unref call is
  -- actually emitted; otherwise it would be an unused C function plus an
  -- unused ISPC extern declaration.
  unless cached $ do
    let mem_s = T.unpack $ expText $ C.toExp mem noLoc
    strlit <- makeStringLiteral mem_s
    GC.stm
      [C.cstm|if ($id:(GC.fatMemUnRef space)(ctx, &$exp:mem, $id:strlit()) != 0) {
            $escstm:("unmasked { return 1; }")
          }|]

-- | Allocate memory in ISPC, running @on_failure@ if allocation fails.
allocMem ::
  (C.ToExp a, C.ToExp b) =>
  a ->
  b ->
  Space ->
  C.Stm ->
  ISPCCompilerM ()
allocMem mem size space on_failure = do
  let mem_s = T.unpack $ expText $ C.toExp mem noLoc
  strlit <- makeStringLiteral mem_s
  GC.stm
    [C.cstm|if ($id:(GC.fatMemAlloc space)(ctx, &$exp:mem, $exp:size, $id:strlit())) {
                 $stm:on_failure
               }|]

-- | Free memory in ISPC: unref every memory block declared in the
-- current function, returning the generated statements.
freeAllocatedMem :: ISPCCompilerM [C.BlockItem]
freeAllocatedMem = GC.collect $ mapM_ (uncurry unRefMem) =<< gets GC.compDeclaredMem

-- | Given an ImpCode function, generate all the required machinery for calling
-- it from ISPC, in both varying and uniform contexts.  This involves working
-- around the fact that ISPC cannot pass structs by value to external functions.
compileBuiltinFun :: (Name, Function op) -> ISPCCompilerM ()
compileBuiltinFun (fname, func@(Function _ outputs inputs _))
  | isNothing $ functionEntry func = do
      -- The context parameter, in ISPC form (uniform pointer) and C form,
      -- plus the argument expression forwarding it.
      let extra = [[C.cparam|$tyqual:uniform struct futhark_context * $tyqual:uniform ctx|]]
          extra_c = [[C.cparam|struct futhark_context * ctx|]]
          extra_exp = [[C.cexp|$id:p|] | C.Param (Just p) _ _ _ <- extra]

      -- Parameters/arguments of the plain C wrapper (@<fun>_extern@).
      (inparams_c, in_args_c) <- mapAndUnzipM (compileInputsExtern []) inputs
      (outparams_c, out_args_c) <- mapAndUnzipM (compileOutputsExtern []) outputs

      -- The same parameters qualified as uniform, for the ISPC-side
      -- @extern "C"@ declaration of that wrapper.
      (inparams_extern, _) <- mapAndUnzipM (compileInputsExtern [C.ctyquals|$tyqual:uniform|]) inputs
      (outparams_extern, _) <- mapAndUnzipM (compileOutputsExtern [C.ctyquals|$tyqual:uniform|]) outputs

      -- Uniform ISPC overload: memory blocks are forwarded by address.
      (inparams_uni, in_args_noderef) <- mapAndUnzipM compileInputsUniform inputs
      (outparams_uni, out_args_noderef) <- mapAndUnzipM compileOutputsUniform outputs

      -- Varying ISPC overload: values are staged in per-lane uniform
      -- arrays (prebody), the wrapper is called once per active lane, and
      -- scalar outputs are copied back afterwards (postbody).
      (inparams_varying, in_args_vary, prebody_in') <- unzip3 <$> mapM compileInputsVarying inputs
      (outparams_varying, out_args_vary, prebody_out', postbody_out') <- unzip4 <$> mapM compileOutputsVarying outputs
      let (prebody_in, prebody_out, postbody_out) = over each concat (prebody_in', prebody_out', postbody_out')

      -- C wrapper that dereferences memory-block pointers and calls the
      -- real function.
      GC.libDecl
        [C.cedecl|int $id:(funName fname <> "_extern")($params:extra_c, $params:outparams_c, $params:inparams_c) {
                  return $id:(funName fname)($args:extra_exp, $args:out_args_c, $args:in_args_c);
                }|]

      let ispc_extern =
            [C.cedecl|extern "C" $tyqual:unmasked $tyqual:uniform int $id:((funName fname) <> "_extern")
                      ($params:extra, $params:outparams_extern, $params:inparams_extern);|]

          ispc_uniform =
            [C.cedecl|$tyqual:uniform int $id:(funName fname)
                    ($params:extra, $params:outparams_uni, $params:inparams_uni) {
                      return $id:(funName (fname<>"_extern"))(
                        $args:extra_exp,
                        $args:out_args_noderef,
                        $args:in_args_noderef);
                    }|]

          -- Error codes from every lane are OR-ed together.
          ispc_varying =
            [C.cedecl|$tyqual:uniform int $id:(funName fname)
                    ($params:extra, $params:outparams_varying, $params:inparams_varying) {
                        $tyqual:uniform int err = 0;
                        $items:prebody_in
                        $items:prebody_out
                        $escstm:("foreach_active (i)")
                        {
                          err |= $id:(funName $ fname<>"_extern")(
                            $args:extra_exp,
                            $args:out_args_vary,
                            $args:in_args_vary);
                        }
                        $items:postbody_out
                        return err;
                    }|]

      -- 'ispcEarlyDecl' prepends, so the resulting order in the ISPC file
      -- is extern, then uniform, then varying (declaration before use).
      mapM_ ispcEarlyDecl [ispc_varying, ispc_uniform, ispc_extern]
  | otherwise = pure ()
  where
    -- Inputs of the C wrapper: scalars by value, memory blocks by pointer
    -- (dereferenced when forwarded).
    compileInputsExtern vari (ScalarParam name bt) = do
      let ctp = GC.primTypeToCType bt
      pure ([C.cparam|$tyquals:vari $ty:ctp $id:name|], [C.cexp|$id:name|])
    compileInputsExtern vari (MemParam name space) = do
      ty <- GC.memToCType name space
      pure ([C.cparam|$tyquals:vari $ty:ty * $tyquals:vari $id:name|], [C.cexp|*$id:name|])

    -- Outputs of the C wrapper: always pointers, forwarded as-is.
    compileOutputsExtern vari (ScalarParam name bt) = do
      p_name <- newVName $ "out_" ++ baseString name
      let ctp = GC.primTypeToCType bt
      pure ([C.cparam|$tyquals:vari $ty:ctp * $tyquals:vari $id:p_name|], [C.cexp|$id:p_name|])
    compileOutputsExtern vari (MemParam name space) = do
      ty <- GC.memToCType name space
      p_name <- newVName $ baseString name ++ "_p"
      pure ([C.cparam|$tyquals:vari $ty:ty * $tyquals:vari $id:p_name|], [C.cexp|$id:p_name|])

    -- Uniform overload inputs: memory blocks taken by value, passed by address.
    compileInputsUniform (ScalarParam name bt) = do
      let ctp = GC.primTypeToCType bt
          params = [C.cparam|$tyqual:uniform $ty:ctp $id:name|]
          args = [C.cexp|$id:name|]
      pure (params, args)
    compileInputsUniform (MemParam name space) = do
      ty <- GC.memToCType name space
      let params = [C.cparam|$tyqual:uniform $ty:ty $id:name|]
          args = [C.cexp|&$id:name|]
      pure (params, args)

    compileOutputsUniform (ScalarParam name bt) = do
      p_name <- newVName $ "out_" ++ baseString name
      let ctp = GC.primTypeToCType bt
          params = [C.cparam|$tyqual:uniform $ty:ctp *$tyqual:uniform $id:p_name|]
          args = [C.cexp|$id:p_name|]
      pure (params, args)
    compileOutputsUniform (MemParam name space) = do
      ty <- GC.memToCType name space
      p_name <- newVName $ baseString name ++ "_p"
      let params = [C.cparam|$tyqual:uniform $ty:ty $id:p_name|]
          args = [C.cexp|&$id:p_name|]
      pure (params, args)

    -- Varying overload inputs: scalars are read per lane with extract();
    -- memory blocks are copied into a per-lane uniform array.
    compileInputsVarying (ScalarParam name bt) = do
      let ctp = GC.primTypeToCType bt
          params = [C.cparam|$ty:ctp $id:name|]
          args = [C.cexp|extract($id:name,i)|]
          pre_body = []
      pure (params, args, pre_body)
    compileInputsVarying (MemParam name space) = do
      typ <- GC.memToCType name space
      newvn <- newVName $ "aos_" <> baseString name
      let params = [C.cparam|$ty:typ $id:name|]
          args = [C.cexp|&$id:(newvn)[i]|]
          pre_body =
            [C.citems|$tyqual:uniform $ty:typ $id:(newvn)[programCount];
                               $id:(newvn)[programIndex] = $id:name;|]
      pure (params, args, pre_body)

    -- Varying overload outputs: scalars go through a per-lane array and
    -- are written back through the output pointer afterwards.
    compileOutputsVarying (ScalarParam name bt) = do
      p_name <- newVName $ "out_" ++ baseString name
      deref_name <- newVName $ "aos_" ++ baseString name
      vari_p_name <- newVName $ "convert_" ++ baseString name
      let ctp = GC.primTypeToCType bt
          pre_body =
            [C.citems|$tyqual:varying $ty:ctp $id:vari_p_name = *$id:p_name;
                                $tyqual:uniform $ty:ctp $id:deref_name[programCount];
                                $id:deref_name[programIndex] = $id:vari_p_name;|]
          post_body = [C.citems|*$id:p_name = $id:(deref_name)[programIndex];|]
          params = [C.cparam|$tyqual:varying $ty:ctp * $tyqual:uniform $id:p_name|]
          args = [C.cexp|&$id:(deref_name)[i]|]
      pure (params, args, pre_body, post_body)
    -- NOTE(review): unlike the scalar case, memory outputs are taken by
    -- value and have no post_body copying results back; confirm callers
    -- do not rely on updated memory blocks from the varying overload.
    compileOutputsVarying (MemParam name space) = do
      typ <- GC.memToCType name space
      newvn <- newVName $ "aos_" <> baseString name
      let params = [C.cparam|$ty:typ $id:name|]
          args = [C.cexp|&$id:(newvn)[i]|]
          pre_body =
            [C.citems|$tyqual:uniform $ty:typ $id:(newvn)[programCount];
                       $id:(newvn)[programIndex] = $id:name;|]
      pure (params, args, pre_body, [])

-- | Handle logging an error message in ISPC.  A C shim formats and sets
-- the error; the ISPC code calls it from each active lane, sets @err@
-- to @FUTHARK_PROGRAM_ERROR@ and returns @err@.
handleError :: ErrorMsg Exp -> String -> ISPCCompilerM ()
handleError msg stacktrace = do
  -- Get format string
  (formatstr, formatargs) <- GC.errorMsgString msg
  let formatstr' = "Error: " <> formatstr <> "\n\nBacktrace:\n%s"

  -- Get args types and names for shim
  let arg_types = errorMsgArgTypes msg
  arg_names <- mapM (newVName . const "arg") arg_types
  let params = zipWith (\ty name -> [C.cparam|$ty:(GC.primTypeToCType ty) $id:name|]) arg_types arg_names
  let params_uni = zipWith (\ty name -> [C.cparam|$tyqual:uniform $ty:(GC.primTypeToCType ty) $id:name|]) arg_types arg_names

  -- Make shim
  -- NOTE(review): format arguments for value parts are replaced by the raw
  -- shim parameter names, dropping whatever expression GC.errorMsgString
  -- produced for them; confirm the format specifiers still match the
  -- uncast parameter types.
  let formatargs' = mapArgNames msg formatargs arg_names
  shim <- MC.multicoreDef "assert_shim" $ \s -> do
    pure
      [C.cedecl|void $id:s(struct futhark_context* ctx, $params:params) {
        set_error(ctx, msgprintf($string:formatstr', $args:formatargs', $string:stacktrace));
      }|]
  ispcDecl
    [C.cedecl|extern "C" $tyqual:unmasked void $id:shim($tyqual:uniform struct futhark_context* $tyqual:uniform, $params:params_uni);|]

  -- Call the shim
  args <- getErrorValExps msg
  uni <- newVName "uni"
  let args' = map (\x -> [C.cexp|extract($exp:x, $id:uni)|]) args
  GC.items
    [C.citems|
      $escstm:("foreach_active(" <> prettyString uni <> ")")
      {
        $id:shim(ctx, $args:args');
        err = FUTHARK_PROGRAM_ERROR;
      }
      $escstm:("unmasked { return err; }")|]
  where
    getErrorVal (ErrorString _) = Nothing
    getErrorVal (ErrorVal _ v) = Just v

    getErrorValExps (ErrorMsg m) = mapM compileExp $ mapMaybe getErrorVal m

    -- Walk message parts alongside the format arguments: value parts take
    -- the next shim parameter name, string parts keep their original
    -- format argument.
    mapArgNames' (x : xs) (y : ys) (t : ts)
      | isJust $ getErrorVal x = [C.cexp|$id:t|] : mapArgNames' xs ys ts
      | otherwise = y : mapArgNames' xs ys (t : ts)
    mapArgNames' _ ys [] = ys
    mapArgNames' _ _ _ = []

    mapArgNames (ErrorMsg parts) = mapArgNames' parts

-- | Given the name and type of a parameter, return the C type used to
-- represent it. We use uniform pointers to varying values for lexical
-- memory blocks, as this generally results in fewer gathers/scatters.
getMemType :: VName -> PrimType -> ISPCCompilerM C.Type
getMemType dest elemtype = do
  -- A block with a cached size is a lexical block.
  isLexical <- isJust <$> GC.cacheMem dest
  pure $
    if isLexical
      then [C.cty|$tyqual:varying $ty:(primStorageType elemtype)* uniform|]
      else [C.cty|$ty:(primStorageType elemtype)*|]

-- | Compile an expression to ISPC.  Finite double and half literals get
-- explicit @d@ and @f16@ suffixes (non-finite ones are delegated to
-- 'GC.compileExp').  A handful of operators are emitted as native
-- operators; every other operator becomes a call to the function named
-- after it.
compileExp :: Exp -> ISPCCompilerM C.Exp
compileExp e@(ValueExp (FloatValue (Float64Value v)))
  | isInfinite v || isNaN v = GC.compileExp e
  | otherwise = pure [C.cexp|$esc:(prettyString v <> "d")|]
compileExp e@(ValueExp (FloatValue (Float16Value v)))
  | isInfinite v || isNaN v = GC.compileExp e
  | otherwise = pure [C.cexp|$esc:(prettyString v <> "f16")|]
compileExp (ValueExp val) =
  pure $ C.toExp val mempty
compileExp (LeafExp v _) =
  pure [C.cexp|$id:v|]
compileExp (UnOpExp op x) = do
  a <- compileExp x
  pure $ case op of
    Complement {} -> [C.cexp|~$exp:a|]
    Neg Bool -> [C.cexp|!$exp:a|]
    Neg {} -> [C.cexp|-$exp:a|]
    FAbs Float32 -> [C.cexp|(float)fabs($exp:a)|]
    FAbs Float64 -> [C.cexp|fabs($exp:a)|]
    SSignum {} -> [C.cexp|($exp:a > 0 ? 1 : 0) - ($exp:a < 0 ? 1 : 0)|]
    -- Unsigned signum: 1 for any non-zero bit pattern, 0 otherwise.
    USignum {} -> [C.cexp|($exp:a > 0 ? 1 : 0) - ($exp:a < 0 ? 1 : 0) != 0|]
    _ -> [C.cexp|$id:(prettyString op)($exp:a)|]
compileExp (CmpOpExp cmp x y) = do
  a <- compileExp x
  b <- compileExp y
  pure $ case cmp of
    CmpEq {} -> [C.cexp|$exp:a == $exp:b|]
    FCmpLt {} -> [C.cexp|$exp:a < $exp:b|]
    FCmpLe {} -> [C.cexp|$exp:a <= $exp:b|]
    CmpLlt {} -> [C.cexp|$exp:a < $exp:b|]
    CmpLle {} -> [C.cexp|$exp:a <= $exp:b|]
    _ -> [C.cexp|$id:(prettyString cmp)($exp:a, $exp:b)|]
compileExp (ConvOpExp conv x) =
  (\a -> [C.cexp|$id:(prettyString conv)($exp:a)|]) <$> compileExp x
compileExp (BinOpExp bop x y) = do
  a <- compileExp x
  b <- compileExp y
  pure $ case bop of
    Add _ OverflowUndef -> [C.cexp|$exp:a + $exp:b|]
    Sub _ OverflowUndef -> [C.cexp|$exp:a - $exp:b|]
    Mul _ OverflowUndef -> [C.cexp|$exp:a * $exp:b|]
    FAdd {} -> [C.cexp|$exp:a + $exp:b|]
    FSub {} -> [C.cexp|$exp:a - $exp:b|]
    FMul {} -> [C.cexp|$exp:a * $exp:b|]
    FDiv {} -> [C.cexp|$exp:a / $exp:b|]
    Xor {} -> [C.cexp|$exp:a ^ $exp:b|]
    And {} -> [C.cexp|$exp:a & $exp:b|]
    Or {} -> [C.cexp|$exp:a | $exp:b|]
    LogAnd {} -> [C.cexp|$exp:a && $exp:b|]
    LogOr {} -> [C.cexp|$exp:a || $exp:b|]
    _ -> [C.cexp|$id:(prettyString bop)($exp:a, $exp:b)|]
compileExp (FunExp h args _) = do
  compiledArgs <- mapM compileExp args
  let fn = funName (nameFromText h)
  pure [C.cexp|$id:fn($args:compiledArgs)|]

-- | Compile a block of code with ISPC specific semantics, falling back
-- to generic C when this semantics is not needed.
-- All recursive constructors are duplicated here, since not doing so
-- would cause us to enter regular generic C codegen with no escape.
compileCode :: MCCode -> ISPCCompilerM ()
compileCode (Comment s code) = do
  xs <- GC.collect $ compileCode code
  let comment = "// " ++ T.unpack s
  GC.stm [C.cstm|$comment:comment { $items:xs } |]
-- Scalars are declared with the qualifiers chosen by variability analysis.
compileCode (DeclareScalar name _ t) = do
  let ct = GC.primTypeToCType t
  quals <- getVariabilityQuals name
  GC.decl [C.cdecl|$tyquals:quals $ty:ct $id:name;|]
-- Static arrays live in C; ISPC reaches them through a shim that wraps
-- the C array in a memblock.
compileCode (DeclareArray name t vs) = do
  name_realtype <- newVName $ baseString name ++ "_realtype"
  let ct = GC.primTypeToCType t
  case vs of
    ArrayValues vs' -> do
      let vs'' = [[C.cinit|$exp:v|] | v <- vs']
      GC.earlyDecl [C.cedecl|static $ty:ct $id:name_realtype[$int:(length vs')] = {$inits:vs''};|]
    ArrayZeros n ->
      GC.earlyDecl [C.cedecl|static $ty:ct $id:name_realtype[$int:n];|]
  -- Make an exported C shim to access a faked memory block.
  shim <- MC.multicoreDef "get_static_array_shim" $ \f ->
    pure
      [C.cedecl|struct memblock $id:f(struct futhark_context* ctx) {
                  return (struct memblock){NULL,(unsigned char*)$id:name_realtype,0};
                }|]
  ispcDecl
    [C.cedecl|extern "C" $tyqual:unmasked $tyqual:uniform struct memblock $tyqual:uniform
                        $id:shim($tyqual:uniform struct futhark_context* $tyqual:uniform ctx);|]
  -- Call it
  GC.item [C.citem|$tyqual:uniform struct memblock $id:name = $id:shim(ctx);|]
-- Sequences are linearised so that a declaration immediately followed
-- by an assignment to the same scalar becomes a single initialised
-- declaration.
compileCode (c1 :>>: c2) = go (GC.linearCode (c1 :>>: c2))
  where
    go (DeclareScalar name _ t : SetScalar dest e : code)
      | name == dest = do
          let ct = GC.primTypeToCType t
          e' <- compileExp e
          quals <- getVariabilityQuals name
          GC.item [C.citem|$tyquals:quals $ty:ct $id:name = $exp:e';|]
          go code
    go (x : xs) = compileCode x >> go xs
    go [] = pure ()
-- Cached (lexical) memory is grown with lexical_realloc only when too
-- small; other memory is allocated through allocMem.
compileCode (Allocate name (Count (TPrimExp e)) space) = do
  size <- compileExp e
  cached <- GC.cacheMem name
  case cached of
    Just cur_size ->
      GC.stm
        [C.cstm|if ($exp:cur_size < $exp:size) {
                  err = lexical_realloc(ctx, &$exp:name, &$exp:cur_size, $exp:size);
                  if (err != FUTHARK_SUCCESS) {
                    $escstm:("unmasked { return err; }")
                  }
                }|]
    _ ->
      allocMem name size space [C.cstm|$escstm:("unmasked { return 1; }")|]
compileCode (SetMem dest src space) =
  setMem dest src space
-- Both branches emit the same store; when the index constant-folds to a
-- literal, the folded index is compiled instead of the original.
compileCode (Write dest (Count idx) elemtype DefaultSpace _ elemexp)
  | isConstExp (untyped idx) = do
      dest' <- GC.rawMem dest
      idxexp <- compileExp $ constFoldPrimExp $ untyped idx
      deref <-
        GC.derefPointer dest' [C.cexp|($tyquals:([varying]) typename int64_t)$exp:idxexp|]
          <$> getMemType dest elemtype
      elemexp' <- toStorage elemtype <$> compileExp elemexp
      GC.stm [C.cstm|$exp:deref = $exp:elemexp';|]
  | otherwise = do
      dest' <- GC.rawMem dest
      idxexp <- compileExp $ untyped idx
      deref <-
        GC.derefPointer dest' [C.cexp|($tyquals:([varying]) typename int64_t)$exp:idxexp|]
          <$> getMemType dest elemtype
      elemexp' <- toStorage elemtype <$> compileExp elemexp
      GC.stm [C.cstm|$exp:deref = $exp:elemexp';|]
  where
    isConstExp = isSimple . constFoldPrimExp
    isSimple (ValueExp _) = True
    isSimple _ = False
compileCode (Read x src (Count iexp) restype DefaultSpace _) = do
  src' <- GC.rawMem src
  e <-
    fmap (fromStorage restype) $
      GC.derefPointer src' <$> compileExp (untyped iexp) <*> getMemType src restype
  GC.stm [C.cstm|$id:x = $exp:e;|]
-- Copies use the generic LMAD copy loop with ISPC-aware element reads
-- and writes.
compileCode (Copy t shape (dst, DefaultSpace) dst_lmad (src, DefaultSpace) src_lmad) = do
  dst' <- GC.rawMem dst
  src' <- GC.rawMem src
  let doWrite dst_i ve = do
        deref <-
          GC.derefPointer dst' [C.cexp|($tyquals:([varying]) typename int64_t)$exp:dst_i|]
            <$> getMemType dst t
        GC.stm [C.cstm|$exp:deref = $exp:(toStorage t ve);|]
      doRead src_i =
        fromStorage t . GC.derefPointer src' src_i <$> getMemType src t
  GC.compileCopyWith shape doWrite dst_lmad doRead src_lmad
compileCode (Free name space) = do
  cached <- isJust <$> GC.cacheMem name
  unless cached $ unRefMem name space
compileCode (For i bound body)
  -- The special-case here is to avoid certain pathological/contrived
  -- programs that construct statically known zero-element arrays.
  -- Due to the way we do constant-fold index functions, this produces
  -- code that looks like it has uniform/varying mismatches (i.e. race
  -- conditions) to ISPC, even though that code is never actually run.
  | isZero bound = pure ()
  | otherwise = do
      let i' = C.toIdent i
          t = GC.primTypeToCType $ primExpType bound
      bound' <- compileExp bound
      body' <- GC.collect $ compileCode body
      quals <- getVariabilityQuals i
      GC.stm
        [C.cstm|for ($tyquals:quals $ty:t $id:i' = 0; $id:i' < $exp:bound'; $id:i'++) {
                  $items:body'
                }|]
  where
    isZero (ValueExp v) = zeroIsh v
    isZero _ = False
compileCode (While cond body) = do
  cond' <- compileExp $ untyped cond
  body' <- GC.collect $ compileCode body
  GC.stm
    [C.cstm|while ($exp:cond') {
            $items:body'
          }|]
-- Empty branches are omitted from the generated conditional.
compileCode (If cond tbranch fbranch) = do
  cond' <- compileExp $ untyped cond
  tbranch' <- GC.collect $ compileCode tbranch
  fbranch' <- GC.collect $ compileCode fbranch
  GC.stm $ case (tbranch', fbranch') of
    (_, []) ->
      [C.cstm|if ($exp:cond') { $items:tbranch' }|]
    ([], _) ->
      [C.cstm|if (!($exp:cond')) { $items:fbranch' }|]
    _ ->
      [C.cstm|if ($exp:cond') { $items:tbranch' } else { $items:fbranch' }|]
-- Built-in functions with a single destination are called directly;
-- everything else gets ctx and out-pointers prepended, and a non-zero
-- return value makes the generated code return 1.
compileCode (Call dests fname args) = do
  (dests', unpack_dest) <- mapAndUnzipM GC.compileDest dests
  defCallIspc dests' fname =<< mapM GC.compileArg args
  GC.stms $ mconcat unpack_dest
  where
    defCallIspc dests' fname' args' = do
      let out_args = [[C.cexp|&$id:d|] | d <- dests']
          args''
            | isBuiltInFunction fname' = args'
            | otherwise = [C.cexp|ctx|] : out_args ++ args'
      case dests' of
        [d]
          | isBuiltInFunction fname' ->
              GC.stm [C.cstm|$id:d = $id:(funName fname')($args:args'');|]
        _ ->
          GC.item
            [C.citem|
            if ($id:(funName fname')($args:args'') != 0) {
              $escstm:("unmasked { return 1; }")
            }|]
compileCode (Assert e msg (loc, locs)) = do
  e' <- compileExp e
  err <- GC.collect $ handleError msg stacktrace
  GC.stm [C.cstm|if (!$exp:e') { $items:err }|]
  where
    stacktrace = T.unpack $ prettyStacktrace 0 $ map locText $ loc : locs
compileCode code =
  GC.compileCode code

-- | Prepare a struct with memory allocated in the scope and populate
-- its fields with values.  Lexical blocks (pointer and size) are stored
-- by value; fat memblocks are stored by address.  Returns the struct
-- type name; the local instance is named with a trailing underscore.
prepareMemStruct :: [(VName, VName)] -> [VName] -> ISPCCompilerM Name
prepareMemStruct lexmems fatmems = do
  let lex_defs = concatMap lexMemDef lexmems
  let fat_defs = map fatMemDef fatmems
  name <- ispcDef "mem_struct" $ \s -> do
    pure
      [C.cedecl|struct $id:s {
        $sdecls:lex_defs
        $sdecls:fat_defs
      };|]
  let name' = name <> "_"
  GC.decl [C.cdecl|$tyqual:uniform struct $id:name $id:name';|]
  forM_ (concatMap (\(a, b) -> [a, b]) lexmems) $ \m ->
    GC.stm [C.cstm|$id:name'.$id:m = $id:m;|]
  forM_ fatmems $ \m ->
    GC.stm [C.cstm|$id:name'.$id:m = &$id:m;|]
  pure name
  where
    lexMemDef (name, size) =
      [ [C.csdecl|$tyqual:varying unsigned char * $tyqual:uniform $id:name;|],
        [C.csdecl|$tyqual:varying size_t $id:size;|]
      ]
    fatMemDef name =
      [C.csdecl|$tyqual:varying struct memblock * $tyqual:uniform $id:name;|]

-- | Get memory from the memory struct (accessed through a pointer) into
-- local variables
compileGetMemStructVals :: Name -> [(VName, VName)] -> [VName] -> ISPCCompilerM ()
compileGetMemStructVals struct lexmems fatmems = do
  forM_ fatmems $ \m ->
    GC.decl [C.cdecl|struct memblock $id:m = *$id:struct->$id:m;|]
  forM_ lexmems $ \(m, s) -> do
    GC.decl [C.cdecl|$tyqual:varying unsigned char * $tyqual:uniform $id:m = $id:struct->$id:m;|]
    GC.decl [C.cdecl|size_t $id:s = $id:struct->$id:s;|]

-- | Write back potentially changed memory addresses and sizes to the
-- memory struct (accessed through a pointer)
compileWritebackMemStructVals :: Name -> [(VName, VName)] -> [VName] -> ISPCCompilerM ()
compileWritebackMemStructVals struct lexmems fatmems = do
  forM_ fatmems $ \m ->
    GC.stm [C.cstm|*$id:struct->$id:m = $id:m;|]
  forM_ lexmems $ \(m, s) -> do
    GC.stm [C.cstm|$id:struct->$id:m = $id:m;|]
    GC.stm [C.cstm|$id:struct->$id:s = $id:s;|]

-- | Read back potentially changed memory addresses and sizes from the
-- memory struct (a local struct value, hence @.@) into local variables
compileReadbackMemStructVals :: Name -> [(VName, VName)] -> [VName] -> ISPCCompilerM ()
compileReadbackMemStructVals struct lexmems fatmems = do
  forM_ fatmems $ \m ->
    GC.stm [C.cstm|$id:m = *$id:struct.$id:m;|]
  forM_ lexmems $ \(m, s) -> do
    GC.stm [C.cstm|$id:m = $id:struct.$id:m;|]
    GC.stm [C.cstm|$id:s = $id:struct.$id:s;|]

-- | Declare uniform locals for the free variables stored in a task
-- struct.  Primitive fields are read directly; memory fields are wrapped
-- in a fresh memblock with no size and no reference counter.
compileGetStructVals ::
  Name ->
  [VName] ->
  [(C.Type, MC.ValueType)] ->
  ISPCCompilerM [C.BlockItem]
compileGetStructVals struct a b = concat <$> zipWithM field a b
  where
    struct' = struct <> "_"
    field name (ty, MC.Prim pt) = do
      let inner = [C.cexp|$id:struct'->$id:(MC.closureFreeStructField name)|]
      pure [C.citems|$tyqual:uniform $ty:ty $id:name = $exp:(fromStorage pt inner);|]
    field name (_, _) = do
      strlit <- makeStringLiteral $ prettyString name
      pure
        [C.citems|$tyqual:uniform struct memblock $id:name;
                     $id:name.desc = $id:strlit();
                     $id:name.mem = $id:struct'->$id:(MC.closureFreeStructField name);
                     $id:name.size = 0;
                     $id:name.references = NULL;|]

-- | Can the given code produce an error?  If so, we can't use foreach
-- loops, since they don't allow for early-outs in error handling.
mayProduceError :: MCCode -> Bool
mayProduceError (x :>>: y) = mayProduceError x || mayProduceError y
mayProduceError (If _ x y) = mayProduceError x || mayProduceError y
mayProduceError (For _ _ x) = mayProduceError x
mayProduceError (While _ x) = mayProduceError x
mayProduceError (Comment _ x) = mayProduceError x
mayProduceError (Op (ForEachActive _ body)) = mayProduceError body
mayProduceError (Op (ForEach _ _ _ body)) = mayProduceError body
mayProduceError (Op SegOp {}) = True
mayProduceError Allocate {} = True
mayProduceError Assert {} = True
mayProduceError SetMem {} = True
mayProduceError Free {} = True
mayProduceError Call {} = True
mayProduceError _ = False

-- Generate a segop function for top_level and potentially nested SegOp code
compileOp :: GC.OpCompiler Multicore ISPCState
-- A SegOp becomes a scheduler task plus a C shim (schedule_shim) that
-- fills in a scheduler_segop and calls scheduler_prepare_task.  Under
-- ISPC the shim is called once per active lane (stopping at the first
-- error); in plain C it is called directly.
compileOp (SegOp name params seq_task par_task retvals (SchedulerInfo e sched)) = do
  let (ParallelTask seq_code) = seq_task
  free_ctypes <- mapM MC.paramToCType params
  retval_ctypes <- mapM MC.paramToCType retvals
  let free_args = map paramName params
      retval_args = map paramName retvals
      free = zip free_args free_ctypes
      retval = zip retval_args retval_ctypes

  e' <- compileExp e

  let lexical = lexicalMemoryUsageMC OpaqueKernels $ Function Nothing [] params seq_code

  fstruct <-
    MC.prepareTaskStruct sharedDef "task" free_args free_ctypes retval_args retval_ctypes

  fpar_task <- MC.generateParLoopFn lexical (name ++ "_task") seq_code fstruct free retval
  MC.addTimingFields fpar_task

  let ftask_name = fstruct <> "_task"

  to_c <- GC.collect $ do
    GC.decl [C.cdecl|struct scheduler_segop $id:ftask_name;|]
    GC.stm [C.cstm|$id:ftask_name.args = args;|]
    GC.stm [C.cstm|$id:ftask_name.top_level_fn = $id:fpar_task;|]
    GC.stm [C.cstm|$id:ftask_name.name = $string:(nameToString fpar_task);|]
    GC.stm [C.cstm|$id:ftask_name.iterations = iterations;|]
    -- Create the timing fields for the task
    GC.stm [C.cstm|$id:ftask_name.task_time = &ctx->program->$id:(MC.functionTiming fpar_task);|]
    GC.stm [C.cstm|$id:ftask_name.task_iter = &ctx->program->$id:(MC.functionIterations fpar_task);|]

    case sched of
      Dynamic -> GC.stm [C.cstm|$id:ftask_name.sched = DYNAMIC;|]
      Static -> GC.stm [C.cstm|$id:ftask_name.sched = STATIC;|]

    -- Generate the nested segop function if available
    case par_task of
      Just (ParallelTask nested_code) -> do
        let lexical_nested = lexicalMemoryUsageMC OpaqueKernels $ Function Nothing [] params nested_code
        fnpar_task <- MC.generateParLoopFn lexical_nested (name ++ "_nested_task") nested_code fstruct free retval
        GC.stm [C.cstm|$id:ftask_name.nested_fn = $id:fnpar_task;|]
      Nothing ->
        GC.stm [C.cstm|$id:ftask_name.nested_fn=NULL;|]

    GC.stm [C.cstm|return scheduler_prepare_task(&ctx->scheduler, &$id:ftask_name);|]

  schedn <- MC.multicoreDef "schedule_shim" $ \s ->
    pure
      [C.cedecl|int $id:s(struct futhark_context* ctx, void* args, typename int64_t iterations) {
        $items:to_c
    }|]

  ispcDecl
    [C.cedecl|extern "C" $tyqual:unmasked $tyqual:uniform int $id:schedn
                          (struct futhark_context $tyqual:uniform * $tyqual:uniform ctx,
                          struct $id:fstruct $tyqual:uniform * $tyqual:uniform args,
                          $tyqual:uniform int iterations);|]

  -- Each lane's copy of the task struct is staged in a uniform array so
  -- the shim can be given a uniform pointer to it.
  aos_name <- newVName "aos"
  GC.items
    [C.citems|
    $escstm:("#if defined(ISPC)")
    $tyqual:uniform struct $id:fstruct $id:aos_name[programCount];
    $id:aos_name[programIndex] = $id:(fstruct <> "_");
    $escstm:("foreach_active (i)")
    {
      if (err == 0) {
        err = $id:schedn(ctx, &$id:aos_name[i], extract($exp:e', i));
      }
    }
    if (err != 0) {
      $escstm:("unmasked { return err; }")
    }
    $escstm:("#else")
    err = $id:schedn(ctx, &$id:(fstruct <> "_"), $exp:e');
    if (err != 0) {
      goto cleanup;
    }
    $escstm:("#endif")|]
-- An ISPCKernel becomes an exported ISPC function (loop_ispc) called
-- from C.  When the body may fail, it is moved into an inner
-- unmasked function so errors can return early; memory that the
-- body may reallocate is passed through a memory struct and read back.
compileOp (ISPCKernel body free) = do
  free_ctypes <- mapM MC.paramToCType free
  let free_args = map paramName free

  let lexical = lexicalMemoryUsageMC OpaqueKernels $ Function Nothing [] free body
  -- Generate ISPC kernel
  fstruct <- MC.prepareTaskStruct sharedDef "param_struct" free_args free_ctypes [] []
  let fstruct' = fstruct <> "_"

  ispcShim <- ispcDef "loop_ispc" $ \s -> do
    mainBody <- GC.inNewFunction $
      analyzeVariability body $
        cachingMemory lexical $ \decl_cached free_cached lexmems ->
          GC.collect $ do
            GC.decl [C.cdecl|$tyqual:uniform struct futhark_context * $tyqual:uniform ctx = $id:fstruct'->ctx;|]
            GC.items =<< compileGetStructVals fstruct free_args free_ctypes
            body' <- GC.collect $ compileCode body
            mapM_ GC.item decl_cached
            mapM_ GC.item =<< GC.declAllocatedMem

            -- Make inner kernel for error handling, if needed
            if mayProduceError body
              then do
                fatmems <- gets (map fst . GC.compDeclaredMem)
                mstruct <- prepareMemStruct lexmems fatmems
                let mstruct' = mstruct <> "_"
                innerShim <- ispcDef "inner_ispc" $ \t -> do
                  innerBody <- GC.collect $ do
                    GC.decl [C.cdecl|$tyqual:uniform struct futhark_context * $tyqual:uniform ctx = $id:fstruct'->ctx;|]
                    GC.items =<< compileGetStructVals fstruct free_args free_ctypes
                    compileGetMemStructVals mstruct' lexmems fatmems
                    GC.decl [C.cdecl|$tyqual:uniform int err = 0;|]
                    mapM_ GC.item body'
                    compileWritebackMemStructVals mstruct' lexmems fatmems
                    GC.stm [C.cstm|return err;|]
                  pure
                    [C.cedecl|
                    static $tyqual:unmasked inline $tyqual:uniform int $id:t(
                      $tyqual:uniform typename int64_t start,
                      $tyqual:uniform typename int64_t end,
                      struct $id:fstruct $tyqual:uniform * $tyqual:uniform $id:fstruct',
                      struct $id:mstruct $tyqual:uniform * $tyqual:uniform $id:mstruct') {
                        $items:innerBody
                    }|]
                -- Call the kernel and read back potentially changed memory
                GC.decl [C.cdecl|$tyqual:uniform int err = $id:innerShim(start, end, $id:fstruct', &$id:mstruct');|]
                compileReadbackMemStructVals mstruct' lexmems fatmems
              else do
                GC.decl [C.cdecl|$tyqual:uniform int err = 0;|]
                mapM_ GC.item body'

            -- Shared exit path: free cached and allocated memory, then return err.
            free_mem <- freeAllocatedMem
            GC.stm [C.cstm|cleanup: {$stms:free_cached $items:free_mem}|]
            GC.stm [C.cstm|return err;|]
    GC.earlyDecl [C.cedecl|int $id:s(typename int64_t start, typename int64_t end, struct $id:fstruct * $id:fstruct');|]
    pure
      [C.cedecl|
        $tyqual:export $tyqual:uniform int $id:s($tyqual:uniform typename int64_t start,
                                          $tyqual:uniform typename int64_t end,
                                          struct $id:fstruct $tyqual:uniform * $tyqual:uniform $id:fstruct' ) {
          $items:mainBody
        }|]

  -- Generate C code to call into ISPC kernel
  GC.items
    [C.citems|
    err = $id:ispcShim(start, end, & $id:fstruct');
    if (err != 0) {
      goto cleanup;
    }|]
-- If the body may fail, the loop is spelled out by hand (full rounds of
-- programCount iterations, then the remainder) instead of using foreach.
compileOp (ForEach i from bound body) = do
  from' <- compileExp from
  bound' <- compileExp bound
  body' <- GC.collect $ compileCode body
  if mayProduceError body
    then
      GC.stms
        [C.cstms|
      for ($tyqual:uniform typename int64_t i = 0; i < (($exp:bound' - $exp:from') / programCount); i++) {
        typename int64_t $id:i = $exp:from' + programIndex + i * programCount;
        $items:body'
      }
      if (programIndex < (($exp:bound' - $exp:from') % programCount)) {
        typename int64_t $id:i = $exp:from' + programIndex + ((($exp:bound' - $exp:from') / programCount) * programCount);
        $items:body'
      }
    |]
    else
      GC.stms
        [C.cstms|
        $escstm:(T.unpack ("foreach (" <> prettyText i <> " = " <> expText from' <> " ... " <> expText bound' <> ")")) {
          $items:body'
        }|]
-- Run the body once per lane, one lane at a time.
compileOp (ForEachActive name body) = do
  body' <- GC.collect $ compileCode body
  GC.stms
    [C.cstms|
    for ($tyqual:uniform unsigned int $id:name = 0; $id:name < programCount; $id:name++) {
      if (programIndex == $id:name) {
        $items:body'
      }
    }|]
compileOp (ExtractLane dest (ValueExp v) _) =
  -- extract() on constants is not allowed (type is uniform, not
  -- varying), so just turn them into an assignment.
  GC.stm [C.cstm|$id:dest = $exp:v;|]
compileOp (ExtractLane dest tar lane) = do
  tar' <- compileExp tar
  lane' <- compileExp lane
  GC.stm [C.cstm|$id:dest = extract($exp:tar', $exp:lane');|]
-- Atomics index cached memory through a uniform pointer to varying data.
compileOp (Atomic aop) =
  MC.atomicOps aop $ \ty arr -> do
    cached <- isJust <$> GC.cacheMem arr
    if cached
      then pure [C.cty|$tyqual:varying $ty:ty* $tyqual:uniform|]
      else pure [C.cty|$ty:ty*|]
compileOp op = MC.compileOp op

-- | Like @GenericC.cachingMemory@, but adapted for ISPC codegen.
cachingMemory ::
  M.Map VName Space ->
  ([C.BlockItem] -> [C.Stm] -> [(VName, VName)] -> GC.CompilerM op s a) ->
  GC.CompilerM op s a
cachingMemory lexical f = do
  -- Only memory blocks in the default space are cached.
  let cached = M.keys $ M.filter (== DefaultSpace) lexical

  -- Pair every cached block with a fresh variable holding its size.
  cached' <- forM cached $ \mem -> do
    size <- newVName $ prettyString mem <> "_cached_size"
    pure (mem, size)

  let lexMem env =
        env
          { GC.envCachedMem =
              M.fromList (map (first (`C.toExp` noLoc)) cached')
                <> GC.envCachedMem env
          }

      -- Declarations for the size variable and the (initially NULL)
      -- uniform pointer to varying bytes.
      declCached (mem, size) =
        [ [C.citem|size_t $id:size = 0;|],
          [C.citem|$tyqual:varying unsigned char * $tyqual:uniform $id:mem = NULL;|]
        ]

      freeCached (mem, _) =
        [C.cstm|free($id:mem);|]

  local lexMem $ f (concatMap declCached cached') (map freeCached cached') cached'

-- Variability analysis

-- | For each variable, the set of names its value depends on.
type Dependencies = M.Map VName Names

-- | Whether a variable may be given the ISPC @uniform@ qualifier
-- ('Uniform') or must be left varying ('Varying').
data Variability = Uniform | Varying
  deriving (Eq, Ord, Show)

-- | The reader environment holds the names that the current control
-- flow depends on (conditions of enclosing @If@/@While@, bounds of
-- enclosing @For@); the state accumulates 'Dependencies'.
newtype VariabilityM a
  = VariabilityM (ReaderT Names (State Dependencies) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState Dependencies,
      MonadReader Names
    )

-- | Run the analysis from an empty environment and empty dependency
-- map, returning the collected dependencies.
execVariabilityM :: VariabilityM a -> Dependencies
execVariabilityM (VariabilityM m) =
  execState (runReaderT m mempty) mempty

-- | Extend the set of dependencies with a new one
--
-- The control-flow names from the reader environment are recorded on
-- every assignment, not only the first.  Otherwise a variable that is
-- first assigned outside and later re-assigned inside a branch on a
-- varying condition would lose that dependency and be deemed uniform.
addDeps :: VName -> Names -> VariabilityM ()
addDeps v ns = do
  deps <- get
  env <- ask
  put $ M.insertWith (<>) v (ns <> env) deps

-- | Find all the dependencies in a body of code
findDeps :: MCCode -> VariabilityM ()
findDeps (x :>>: y) = do
  findDeps x
  findDeps y
findDeps (If cond x y) =
  local (<> freeIn cond) $ do
    findDeps x
    findDeps y
findDeps (For idx bound x) = do
  addDeps idx free
  local (<> free) $ findDeps x
  where
    free = freeIn bound
findDeps (While cond x) = do
  local (<> freeIn cond) $ findDeps x
findDeps (Comment _ x) =
  findDeps x
findDeps (Op (SegOp _ free _ _ retvals _)) =
  -- Every return value depends on all the free parameters.
  mapM_
    ( \x ->
        addDeps (paramName x) $
          namesFromList $
            map paramName free
    )
    retvals
findDeps (Op (ForEach _ _ _ body)) =
  findDeps body
findDeps (Op (ForEachActive _ body)) =
  findDeps body
findDeps (SetScalar name e) =
  addDeps name $ freeIn e
findDeps (Call tars _ args) =
  mapM_ (\x -> addDeps x $ freeIn args) tars
findDeps (Read x arr (Count iexp) _ DefaultSpace _) = do
  addDeps x $ freeIn (untyped iexp)
  addDeps x $ oneName arr
findDeps (Op (GetLoopBounds x y)) = do
  addDeps x mempty
  addDeps y mempty
findDeps (Op (ExtractLane x _ _)) = do
  addDeps x mempty
findDeps (Op (Atomic (AtomicCmpXchg _ old arr ind res val))) = do
  addDeps res $ freeIn arr <> freeIn ind <> freeIn val
  addDeps old $ freeIn arr <> freeIn ind <> freeIn val
-- Anything else records nothing; variables absent from the map are
-- never considered uniform by the caller.
findDeps _ = pure ()

-- | Take a list of dependencies and iterate them to a fixed point.
depsFixedPoint :: Dependencies -> Dependencies
depsFixedPoint deps =
  if deps == deps'
    then deps
    else depsFixedPoint deps'
  where
    grow names =
      names <> foldMap (\n -> M.findWithDefault mempty n deps) (namesIntMap names)
    deps' = M.map grow deps

-- | Find roots of variance.  These are memory blocks declared in
-- the current scope as well as loop indices of foreach loops.
findVarying :: MCCode -> [VName]
findVarying (x :>>: y) = findVarying x ++ findVarying y
findVarying (If _ x y) = findVarying x ++ findVarying y
findVarying (For _ _ x) = findVarying x
findVarying (While _ x) = findVarying x
findVarying (Comment _ x) = findVarying x
findVarying (Op (ForEachActive _ body)) = findVarying body
findVarying (Op (ForEach idx _ _ body)) = idx : findVarying body
findVarying (DeclareMem mem _) = [mem]
findVarying _ = []

-- | Analyze variability in a body of code and run an action with
-- info about that variability in the compiler state.
analyzeVariability :: MCCode -> ISPCCompilerM a -> ISPCCompilerM a analyzeVariability code m = do let roots = findVarying code let deps = depsFixedPoint $ execVariabilityM $ findDeps code let safelist = M.filter (\b -> all (`notNameIn` b) roots) deps let safe = namesFromList $ M.keys safelist pre_state <- GC.getUserState GC.modifyUserState (\s -> s {sUniform = safe}) a <- m GC.modifyUserState (\s -> s {sUniform = sUniform pre_state}) pure a -- | Get the variability of a variable getVariability :: VName -> ISPCCompilerM Variability getVariability name = do uniforms <- sUniform <$> GC.getUserState pure $ if name `nameIn` uniforms then Uniform else Varying -- | Get the variability qualifiers of a variable getVariabilityQuals :: VName -> ISPCCompilerM [C.TypeQual] getVariabilityQuals name = variQuals <$> getVariability name where variQuals Uniform = [C.ctyquals|$tyqual:uniform|] variQuals Varying = [] futhark-0.25.27/src/Futhark/CodeGen/Backends/MulticoreWASM.hs000066400000000000000000000043501475065116200235440ustar00rootroot00000000000000-- | C code generator. This module can convert a correct ImpCode -- program to an equivalent C program. This C program is expected to -- be converted to WebAssembly, so we also produce the intended -- JavaScript wrapper. module Futhark.CodeGen.Backends.MulticoreWASM ( compileProg, runServer, libraryExports, GC.CParts (..), GC.asLibrary, GC.asExecutable, GC.asServer, ) where import Data.Maybe import Data.Text qualified as T import Futhark.CodeGen.Backends.GenericC qualified as GC import Futhark.CodeGen.Backends.GenericWASM import Futhark.CodeGen.Backends.MulticoreC qualified as MC import Futhark.CodeGen.Backends.MulticoreC.Boilerplate (generateBoilerplate) import Futhark.CodeGen.ImpCode.Multicore qualified as Imp import Futhark.CodeGen.ImpGen.Multicore qualified as ImpGen import Futhark.IR.MCMem import Futhark.MonadFreshNames -- | Compile Futhark program to wasm-multicore program (some assembly -- required). 
-- -- The triple that is returned consists of -- -- * Generated C code (to be passed to Emscripten). -- -- * JavaScript wrapper code that presents a nicer interface to the -- Emscripten-produced code (this should be put in a @.class.js@ -- file by itself). -- -- * Options that should be passed to @emcc@. compileProg :: (MonadFreshNames m) => T.Text -> Prog MCMem -> m (ImpGen.Warnings, (GC.CParts, T.Text, [String])) compileProg version prog = do (ws, prog') <- ImpGen.compileProg prog prog'' <- GC.compileProg "wasm_multicore" version mempty MC.operations generateBoilerplate "" (DefaultSpace, [DefaultSpace]) MC.cliOptions prog' pure ( ws, ( prog'', javascriptWrapper (fRepMyRep prog'), "_futhark_context_config_set_num_threads" : emccExportNames (fRepMyRep prog') ) ) fRepMyRep :: Imp.Definitions Imp.Multicore -> [JSEntryPoint] fRepMyRep prog = let Imp.Functions fs = Imp.defFuns prog function (Imp.Function entry _ _ _) = do Imp.EntryPoint n res args <- entry Just $ JSEntryPoint { name = nameToString n, parameters = map (extToString . snd) args, ret = map (extToString . snd) res } in mapMaybe (function . snd) fs futhark-0.25.27/src/Futhark/CodeGen/Backends/PyOpenCL.hs000066400000000000000000000330211475065116200225370ustar00rootroot00000000000000-- | Code generation for Python with OpenCL. 
module Futhark.CodeGen.Backends.PyOpenCL
  ( compileProg,
  )
where

import Control.Monad
import Control.Monad.Identity
import Data.Map qualified as M
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericPython hiding (compileProg)
import Futhark.CodeGen.Backends.GenericPython qualified as GP
import Futhark.CodeGen.Backends.GenericPython.AST
import Futhark.CodeGen.Backends.GenericPython.Options
import Futhark.CodeGen.Backends.PyOpenCL.Boilerplate
import Futhark.CodeGen.ImpCode (Count (..))
import Futhark.CodeGen.ImpCode.OpenCL qualified as Imp
import Futhark.CodeGen.ImpGen.OpenCL qualified as ImpGen
import Futhark.CodeGen.RTS.Python (openclPy)
import Futhark.IR.GPUMem (GPUMem, Prog)
import Futhark.MonadFreshNames
import Futhark.Util (zEncodeText)
import Futhark.Util.Pretty (prettyString, prettyText)

-- | Compile the program to Python with calls to OpenCL.
--
-- Runs the OpenCL ImpCode generator, then hands the result to the
-- generic Python code generator together with the PyOpenCL-specific
-- module globals, imports, class constructor, command-line options and
-- operation hooks.  Returns the ImpGen warnings alongside the Python
-- source text.
compileProg ::
  (MonadFreshNames m) =>
  CompilerMode ->
  String ->
  Prog GPUMem ->
  m (ImpGen.Warnings, T.Text)
compileProg mode class_name prog = do
  ( ws,
    Imp.Program
      opencl_code
      opencl_prelude
      macros
      kernels
      types
      sizes
      failures
      prog'
    ) <-
    ImpGen.compileProg prog
  -- prepare the strings for assigning the kernels and set them as global
  -- (one Python line per kernel: self.<kernel>_var = program.<kernel>;
  -- the text is spliced into the constructor through openClInit).
  let assign =
        unlines $
          map
            ( \x ->
                prettyString $
                  Assign
                    (Var (T.unpack ("self." <> zEncodeText (nameToText x) <> "_var")))
                    (Var $ T.unpack $ "program." <> zEncodeText (nameToText x))
            )
            $ M.keys kernels

  -- Module-level Python globals.  Several are overwritten by the
  -- command-line options below, and the constructor's keyword defaults
  -- refer to them by name.  fut_opencl_src holds the full OpenCL source
  -- (prelude followed by the generated kernels).
  let defines =
        [ Assign (Var "synchronous") $ Bool False,
          Assign (Var "preferred_platform") None,
          Assign (Var "build_options") $ List [],
          Assign (Var "preferred_device") None,
          Assign (Var "default_threshold") None,
          Assign (Var "default_group_size") None,
          Assign (Var "default_num_groups") None,
          Assign (Var "default_tile_size") None,
          Assign (Var "default_reg_tile_size") None,
          Assign (Var "fut_opencl_src") $ RawStringLiteral $ opencl_prelude <> opencl_code
        ]

  -- Python imports of the generated module; the OpenCL runtime support
  -- code (openclPy) is inlined verbatim via Escape.
  let imports =
        [ Import "sys" Nothing,
          Import "numpy" $ Just "np",
          Import "ctypes" $ Just "ct",
          Escape openclPy,
          Import "pyopencl.array" Nothing,
          Import "time" Nothing
        ]

  -- Parameter list of the generated class constructor; its body is the
  -- initialisation code produced by openClInit.
  let constructor =
        Constructor
          [ "self",
            "build_options=build_options",
            "command_queue=None",
            "interactive=False",
            "platform_pref=preferred_platform",
            "device_pref=preferred_device",
            "default_group_size=default_group_size",
            "default_num_groups=default_num_groups",
            "default_tile_size=default_tile_size",
            "default_reg_tile_size=default_reg_tile_size",
            "default_threshold=default_threshold",
            "sizes=sizes"
          ]
          [Escape $ openClInit macros types assign sizes failures]
      -- Command-line options of the generated executable.  Each action
      -- assigns the option argument (optarg) to one of the globals above;
      -- --build-option appends to the build_options list instead.
      options =
        [ Option
            { optionLongName = "platform",
              optionShortName = Just 'p',
              optionArgument = RequiredArgument "str",
              optionAction = [Assign (Var "preferred_platform") $ Var "optarg"]
            },
          Option
            { optionLongName = "device",
              optionShortName = Just 'd',
              optionArgument = RequiredArgument "str",
              optionAction = [Assign (Var "preferred_device") $ Var "optarg"]
            },
          Option
            { optionLongName = "build-option",
              optionShortName = Nothing,
              optionArgument = RequiredArgument "str",
              optionAction =
                [ Assign (Var "build_options") $
                    BinOp "+" (Var "build_options") $
                      List [Var "optarg"]
                ]
            },
          Option
            { optionLongName = "default-threshold",
              optionShortName = Nothing,
              optionArgument = RequiredArgument "int",
              optionAction = [Assign (Var "default_threshold") $ Var "optarg"]
            },
          Option
            { optionLongName = "default-group-size",
              optionShortName = Nothing,
              optionArgument = RequiredArgument "int",
              optionAction = [Assign (Var "default_group_size") $ Var "optarg"]
            },
          Option
            { optionLongName = "default-num-groups",
              optionShortName = Nothing,
              optionArgument = RequiredArgument "int",
              optionAction = [Assign (Var "default_num_groups") $ Var "optarg"]
            },
          Option
            { optionLongName = "default-tile-size",
              optionShortName = Nothing,
              optionArgument = RequiredArgument "int",
              optionAction = [Assign (Var "default_tile_size") $ Var "optarg"]
            },
          Option
            { optionLongName = "default-reg-tile-size",
              optionShortName = Nothing,
              optionArgument = RequiredArgument "int",
              optionAction = [Assign (Var "default_reg_tile_size") $ Var "optarg"]
            },
          -- --param takes a "param_assignment" argument; the action
          -- performs params[optarg[0]] = optarg[1].
          Option
            { optionLongName = "param",
              optionShortName = Nothing,
              optionArgument = RequiredArgument "param_assignment",
              optionAction =
                [ Assign
                    ( Index
                        (Var "params")
                        ( IdxExp
                            ( Index
                                (Var "optarg")
                                (IdxExp (Integer 0))
                            )
                        )
                    )
                    (Index (Var "optarg") (IdxExp (Integer 1)))
                ]
            }
        ]

  -- The [sync(self)] statement list is handed to GP.compileProg; how it
  -- is used (e.g. around timed entry-point runs) is decided there —
  -- NOTE(review): confirm against GenericPython.compileProg.
  (ws,)
    <$> GP.compileProg
      mode
      class_name
      constructor
      imports
      defines
      operations
      ()
      [Exp $ simpleCall "sync" [Var "self"]]
      options
      prog'
  where
    -- Backend hooks.  Device-to-device copies are added on top of the
    -- copy table from defaultOperations.
    operations :: Operations Imp.OpenCL ()
    operations =
      Operations
        { opsCompiler = callKernel,
          opsWriteScalar = writeOpenCLScalar,
          opsReadScalar = readOpenCLScalar,
          opsAllocate = allocateOpenCLBuffer,
          opsCopies =
            M.insert (Imp.Space "device", Imp.Space "device") copygpu2gpu $
              opsCopies defaultOperations,
          opsEntryOutput = packArrayOutput,
          opsEntryInput = unpackArrayInput
        }

-- We have many casts to 'long', because PyOpenCL may get confused at
-- the 32-bit numbers that ImpCode uses for offsets and the like.
-- | Wrap a Python expression in a call to @np.int64@.
asLong :: PyExp -> PyExp
asLong = simpleCall "np.int64" . pure

-- | Look up a tuning parameter, by its printed name, in the
-- @self.sizes@ dictionary of the generated class.
getParamByKey :: Name -> PyExp
getParamByKey = Index (Var "self.sizes") . IdxExp . String . prettyText

-- | Python expression for a kernel constant: either a tuning-parameter
-- lookup or the @self.max_<size class>@ attribute.
kernelConstToExp :: Imp.KernelConst -> PyExp
kernelConstToExp kc =
  case kc of
    Imp.SizeConst key _ -> getParamByKey key
    Imp.SizeMaxConst size_class -> Var $ "self.max_" <> prettyString size_class

-- | Compile a constant kernel expression.  The leaves are translated
-- purely, so the 'Identity' monad suffices.
compileConstExp :: Imp.KernelConstExp -> PyExp
compileConstExp = runIdentity . compilePrimExp (pure . kernelConstToExp)

-- | Compile a thread-block dimension.  Dynamic dimensions are wrapped
-- in 'asLong'; constant ones go through 'compileConstExp'.
compileBlockDim :: Imp.BlockDim -> CompilerM op s PyExp
compileBlockDim = either (fmap asLong . compileExp) (pure . compileConstExp)

-- | Compile an OpenCL-specific ImpCode operation to Python.
callKernel :: OpCompiler Imp.OpenCL ()
callKernel (Imp.GetSize v key) = do
  target <- compileVar v
  stm $ Assign target $ getParamByKey key
callKernel (Imp.CmpSizeLe v key x) = do
  target <- compileVar v
  bound <- compileExp x
  stm . Assign target $ BinOp "<=" (getParamByKey key) bound
callKernel (Imp.GetSizeMax v size_class) = do
  target <- compileVar v
  stm . Assign target . kernelConstToExp $ Imp.SizeMaxConst size_class
callKernel (Imp.LaunchKernel safety name shared_memory args num_threadblocks workgroup_size) = do
  grid <- mapM (fmap asLong . compileExp) num_threadblocks
  block <- mapM compileBlockDim workgroup_size
  shmem <- compileExp . Imp.untyped $ Imp.unCount shared_memory
  -- Global size per dimension is blocks * block size; the launch is
  -- skipped entirely when the product over all dimensions is zero.
  let global_size = zipWith times grid block
      total_threads = foldl times (Integer 1) global_size
  launch <- collect $ launchKernel name safety global_size block shmem args
  stm $ If (BinOp "!=" total_threads (Integer 0)) launch []
  -- After launching a kernel with full safety checks, set
  -- self.failure_is_an_option to 1.
  when (safety >= Imp.SafetyFull) . stm . Assign (Var "self.failure_is_an_option") $
    compilePrimValue (Imp.IntValue (Imp.Int32Value 1))
  where
    times = BinOp "*"

-- | Emit the Python statements that set a kernel's arguments and
-- enqueue it with the given global and local sizes, followed by a sync
-- when running synchronously.
launchKernel ::
  Imp.KernelName ->
  Imp.KernelSafety ->
  [PyExp] ->
  [PyExp] ->
  PyExp ->
  [Imp.KernelArg] ->
  CompilerM op s ()
launchKernel kernel_name safety kernel_dims threadblock_dims shared_memory args = do
  kernel_args <- mapM processKernelArg args
  let set_args_call =
        simpleCall (T.unpack $ kernel_var <> ".set_args") $
          concat [[local_mem], failure_args, kernel_args]
      enqueue_call =
        simpleCall
          "cl.enqueue_nd_range_kernel"
          [ Var "self.queue",
            Var (T.unpack kernel_var),
            Tuple kernel_dims,
            Tuple threadblock_dims
          ]
  mapM_ (stm . Exp) [set_args_call, enqueue_call]
  finishIfSynchronous
  where
    kernel_var = "self." <> zEncodeText (nameToText kernel_name) <> "_var"
    -- The local-memory buffer (sized max(shared_memory, 1)) is always
    -- the first kernel argument.
    local_mem = simpleCall "cl.LocalMemory" [simpleCall "max" [shared_memory, Integer 1]]
    -- As many failure-reporting arguments as the safety level requires.
    failure_args =
      take
        (Imp.numFailureParams safety)
        [ Var "self.global_failure",
          Var "self.failure_is_an_option",
          Var "self.global_failure_args"
        ]
    processKernelArg :: Imp.KernelArg -> CompilerM op s PyExp
    processKernelArg (Imp.ValueKArg e bt) = toStorage bt <$> compileExp e
    processKernelArg (Imp.MemKArg v) = compileVar v

-- | Write a single scalar into device memory at element offset @i@.
writeOpenCLScalar :: WriteScalar Imp.OpenCL ()
writeOpenCLScalar mem i bt "device" val =
  stm . Exp $ Call (Var "cl.enqueue_copy") copy_args
  where
    host_array =
      Call (Var "np.array") [Arg val, ArgKeyword "dtype" $ Var $ compilePrimType bt]
    byte_offset = BinOp "*" (asLong i) (Integer $ Imp.primByteSize bt)
    copy_args =
      [ Arg $ Var "self.queue",
        Arg mem,
        Arg host_array,
        ArgKeyword "dst_offset" byte_offset,
        ArgKeyword "is_blocking" $ Var "synchronous"
      ]
writeOpenCLScalar _ _ _ space _ =
  error $ "Cannot write to '" ++ space ++ "' memory space."

-- | Read a single scalar from device memory at element offset @i@ into a
-- fresh one-element NumPy array, and return that array's first element.
readOpenCLScalar :: ReadScalar Imp.OpenCL ()
readOpenCLScalar mem i bt "device" = do
  res_name <- newVName "read_res"
  let res = Var $ prettyString res_name
      dtype = ArgKeyword "dtype" (Var $ compilePrimType bt)
      byte_offset = BinOp "*" (asLong i) (Integer $ Imp.primByteSize bt)
  stm . Assign res $ Call (Var "np.empty") [Arg $ Integer 1, dtype]
  stm . Exp . Call (Var "cl.enqueue_copy") $
    [ Arg $ Var "self.queue",
      Arg res,
      Arg mem,
      ArgKeyword "src_offset" byte_offset,
      ArgKeyword "is_blocking" $ Var "synchronous"
    ]
  -- Explicit sync after the (possibly non-blocking) copy, presumably so
  -- the host array is filled before it is indexed.
  stm . Exp $ simpleCall "sync" [Var "self"]
  pure . Index res . IdxExp $ Integer 0
readOpenCLScalar _ _ _ space =
  error $ "Cannot read from '" ++ space ++ "' memory space."
allocateOpenCLBuffer :: Allocate Imp.OpenCL () allocateOpenCLBuffer mem size "device" = stm $ Assign mem $ simpleCall "opencl_alloc" [Var "self", size, String $ prettyText mem] allocateOpenCLBuffer _ _ space = error $ "Cannot allocate in '" ++ space ++ "' space" packArrayOutput :: EntryOutput Imp.OpenCL () packArrayOutput mem "device" bt ept dims = do mem' <- compileVar mem dims' <- mapM compileDim dims pure $ Call (Var "cl.array.Array") [ Arg $ Var "self.queue", Arg $ Tuple $ dims' <> [Integer 0 | bt == Imp.Unit], Arg $ Var $ compilePrimToExtNp bt ept, ArgKeyword "data" mem' ] packArrayOutput _ sid _ _ _ = error $ "Cannot return array from " ++ sid ++ " space." unpackArrayInput :: EntryInput Imp.OpenCL () unpackArrayInput mem "device" t s dims e = do let type_is_ok = BinOp "and" (BinOp "in" (simpleCall "type" [e]) (List [Var "np.ndarray", Var "cl.array.Array"])) (BinOp "==" (Field e "dtype") (Var (compilePrimToExtNp t s))) stm $ Assert type_is_ok $ String "Parameter has unexpected type" zipWithM_ (unpackDim e) dims [0 ..] let memsize' = simpleCall "np.int64" [Field e "nbytes"] pyOpenCLArrayCase = [Assign mem $ Field e "data"] numpyArrayCase <- collect $ do allocateOpenCLBuffer mem memsize' "device" stm $ ifNotZeroSize memsize' $ Exp $ Call (Var "cl.enqueue_copy") [ Arg $ Var "self.queue", Arg mem, Arg $ Call (Var "normaliseArray") [Arg e], ArgKeyword "is_blocking" $ Var "synchronous" ] stm $ If (BinOp "==" (simpleCall "type" [e]) (Var "cl.array.Array")) pyOpenCLArrayCase numpyArrayCase unpackArrayInput _ sid _ _ _ _ = error $ "Cannot accept array from " ++ sid ++ " space." ifNotZeroSize :: PyExp -> PyStmt -> PyStmt ifNotZeroSize e s = If (BinOp "!=" e (Integer 0)) [s] [] finishIfSynchronous :: CompilerM op s () finishIfSynchronous = stm $ If (Var "synchronous") [Exp $ simpleCall "sync" [Var "self"]] [] copygpu2gpu :: DoCopy op s copygpu2gpu t shape dst (dstoffset, dststride) src (srcoffset, srcstride) = do stm . Exp . 
simpleCall "lmad_copy_gpu2gpu" $ [ Var "self", Var (compilePrimType t), dst, unCount dstoffset, List (map unCount dststride), src, unCount srcoffset, List (map unCount srcstride), List (map unCount shape) ] futhark-0.25.27/src/Futhark/CodeGen/Backends/PyOpenCL/000077500000000000000000000000001475065116200222045ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/CodeGen/Backends/PyOpenCL/Boilerplate.hs000066400000000000000000000121621475065116200250040ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-} -- | Various boilerplate definitions for the PyOpenCL backend. module Futhark.CodeGen.Backends.PyOpenCL.Boilerplate ( openClInit, ) where import Control.Monad.Identity import Data.Map qualified as M import Data.Text qualified as T import Futhark.CodeGen.Backends.GenericPython qualified as Py import Futhark.CodeGen.Backends.GenericPython.AST import Futhark.CodeGen.ImpCode.OpenCL ( ErrorMsg (..), ErrorMsgPart (..), FailureMsg (..), KernelConst (..), KernelConstExp, ParamMap, PrimType (..), errorMsgArgTypes, sizeDefault, untyped, ) import Futhark.CodeGen.OpenCL.Heuristics import Futhark.Util.Pretty (prettyString, prettyText) import NeatInterpolation (text) errorMsgNumArgs :: ErrorMsg a -> Int errorMsgNumArgs = length . errorMsgArgTypes getParamByKey :: Name -> PyExp getParamByKey key = Index (Var "self.sizes") (IdxExp $ String $ prettyText key) kernelConstToExp :: KernelConst -> PyExp kernelConstToExp (SizeConst key _) = getParamByKey key kernelConstToExp (SizeMaxConst size_class) = Var $ "self.max_" <> prettyString size_class compileConstExp :: KernelConstExp -> PyExp compileConstExp e = runIdentity $ Py.compilePrimExp (pure . kernelConstToExp) e -- | Python code (as a string) that calls the -- @initiatialize_opencl_object@ procedure. Should be put in the -- class constructor. 
openClInit :: [(Name, KernelConstExp)] -> [PrimType] -> String -> ParamMap -> [FailureMsg] -> T.Text openClInit constants types assign sizes failures = [text| size_heuristics=$size_heuristics self.global_failure_args_max = $max_num_args self.failure_msgs=$failure_msgs constants = $constants' program = initialise_opencl_object(self, program_src=fut_opencl_src, build_options=build_options, command_queue=command_queue, interactive=interactive, platform_pref=platform_pref, device_pref=device_pref, default_group_size=default_group_size, default_num_groups=default_num_groups, default_tile_size=default_tile_size, default_reg_tile_size=default_reg_tile_size, default_threshold=default_threshold, size_heuristics=size_heuristics, required_types=$types', user_sizes=sizes, all_sizes=$sizes', constants=constants) $assign' |] where assign' = T.pack assign size_heuristics = prettyText $ sizeHeuristicsToPython sizeHeuristicsTable types' = prettyText $ map (show . prettyString) types -- Looks enough like Python. sizes' = prettyText $ sizeClassesToPython sizes max_num_args = prettyText $ foldl max 0 $ map (errorMsgNumArgs . failureError) failures failure_msgs = prettyText $ List $ map formatFailure failures onConstant (name, e) = Tuple [ String (nameToText name), Lambda "" (compileConstExp e) ] constants' = prettyText $ List $ map onConstant constants formatFailure :: FailureMsg -> PyExp formatFailure (FailureMsg (ErrorMsg parts) backtrace) = String $ mconcat (map onPart parts) <> "\n" <> formatEscape backtrace where formatEscape = let escapeChar '{' = "{{" escapeChar '}' = "}}" escapeChar c = T.singleton c in mconcat . map escapeChar onPart (ErrorString s) = formatEscape $ T.unpack s onPart ErrorVal {} = "{}" sizeClassesToPython :: ParamMap -> PyExp sizeClassesToPython = Dict . map f . M.toList where f (size_name, (size_class, _)) = ( String $ prettyText size_name, Dict [ (String "class", String $ prettyText size_class), ( String "value", maybe None (Integer . 
fromIntegral) $ sizeDefault size_class ) ] ) sizeHeuristicsToPython :: [SizeHeuristic] -> PyExp sizeHeuristicsToPython = List . map f where f (SizeHeuristic platform_name device_type which what) = Tuple [ String (T.pack platform_name), clDeviceType device_type, which', what' ] where clDeviceType DeviceGPU = Var "cl.device_type.GPU" clDeviceType DeviceCPU = Var "cl.device_type.CPU" which' = case which of LockstepWidth -> String "lockstep_width" NumBlocks -> String "num_groups" BlockSize -> String "group_size" TileSize -> String "tile_size" RegTileSize -> String "reg_tile_size" Threshold -> String "threshold" what' = Lambda "device" $ runIdentity $ Py.compilePrimExp onLeaf $ untyped what onLeaf (DeviceInfo s) = pure $ Py.simpleCall "device.get_info" [Py.simpleCall "getattr" [Var "cl.device_info", String (T.pack s)]] futhark-0.25.27/src/Futhark/CodeGen/Backends/SequentialC.hs000066400000000000000000000022711475065116200233260ustar00rootroot00000000000000-- | C code generator. This module can convert a correct ImpCode -- program to an equivalent C program. The C code is strictly -- sequential, but can handle the full Futhark language. module Futhark.CodeGen.Backends.SequentialC ( compileProg, GC.CParts (..), GC.asLibrary, GC.asExecutable, GC.asServer, ) where import Control.Monad import Data.Text qualified as T import Futhark.CodeGen.Backends.GenericC qualified as GC import Futhark.CodeGen.Backends.SequentialC.Boilerplate import Futhark.CodeGen.ImpCode.Sequential qualified as Imp import Futhark.CodeGen.ImpGen.Sequential qualified as ImpGen import Futhark.IR.SeqMem import Futhark.MonadFreshNames -- | Compile the program to sequential C. 
compileProg :: (MonadFreshNames m) => T.Text -> Prog SeqMem -> m (ImpGen.Warnings, GC.CParts) compileProg version = traverse ( GC.compileProg "c" version mempty operations generateBoilerplate mempty (DefaultSpace, [DefaultSpace]) [] ) <=< ImpGen.compileProg where operations :: GC.Operations Imp.Sequential () operations = GC.defaultOperations { GC.opsCompiler = const $ pure () } futhark-0.25.27/src/Futhark/CodeGen/Backends/SequentialC/000077500000000000000000000000001475065116200227705ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/CodeGen/Backends/SequentialC/Boilerplate.hs000066400000000000000000000010611475065116200255640ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-} -- | Boilerplate for sequential C code. module Futhark.CodeGen.Backends.SequentialC.Boilerplate (generateBoilerplate) where import Data.Text qualified as T import Futhark.CodeGen.Backends.GenericC qualified as GC import Futhark.CodeGen.RTS.C (backendsCH) import Language.C.Quote.OpenCL qualified as C -- | Generate the necessary boilerplate. generateBoilerplate :: GC.CompilerM op s () generateBoilerplate = do GC.earlyDecl [C.cedecl|$esc:(T.unpack backendsCH)|] GC.generateProgramStruct {-# NOINLINE generateBoilerplate #-} futhark-0.25.27/src/Futhark/CodeGen/Backends/SequentialPython.hs000066400000000000000000000024041475065116200244230ustar00rootroot00000000000000-- | Code generation for sequential Python. module Futhark.CodeGen.Backends.SequentialPython ( compileProg, ) where import Control.Monad import Data.Text qualified as T import Futhark.CodeGen.Backends.GenericPython qualified as GenericPython import Futhark.CodeGen.Backends.GenericPython.AST import Futhark.CodeGen.ImpCode.Sequential qualified as Imp import Futhark.CodeGen.ImpGen.Sequential qualified as ImpGen import Futhark.IR.SeqMem import Futhark.MonadFreshNames -- | Compile the program to Python. 
compileProg :: (MonadFreshNames m) => GenericPython.CompilerMode -> String -> Prog SeqMem -> m (ImpGen.Warnings, T.Text) compileProg mode class_name = ImpGen.compileProg >=> traverse ( GenericPython.compileProg mode class_name GenericPython.emptyConstructor imports defines operations () [] [] ) where imports = [ Import "sys" Nothing, Import "numpy" $ Just "np", Import "ctypes" $ Just "ct", Import "time" Nothing ] defines = [] operations :: GenericPython.Operations Imp.Sequential () operations = GenericPython.defaultOperations { GenericPython.opsCompiler = const $ pure () } futhark-0.25.27/src/Futhark/CodeGen/Backends/SequentialWASM.hs000066400000000000000000000042471475065116200237200ustar00rootroot00000000000000-- | C code generator. This module can convert a correct ImpCode -- program to an equivalent C program. This C program is expected to -- be converted to WebAssembly, so we also produce the intended -- JavaScript wrapper. module Futhark.CodeGen.Backends.SequentialWASM ( compileProg, runServer, libraryExports, GC.CParts (..), GC.asLibrary, GC.asExecutable, GC.asServer, ) where import Data.Maybe import Data.Text qualified as T import Futhark.CodeGen.Backends.GenericC qualified as GC import Futhark.CodeGen.Backends.GenericWASM import Futhark.CodeGen.Backends.SequentialC.Boilerplate import Futhark.CodeGen.ImpCode.Sequential qualified as Imp import Futhark.CodeGen.ImpGen.Sequential qualified as ImpGen import Futhark.IR.SeqMem import Futhark.MonadFreshNames -- | Compile Futhark program to wasm program (some assembly -- required). -- -- The triple that is returned consists of -- -- * Generated C code (to be passed to Emscripten). -- -- * JavaScript wrapper code that presents a nicer interface to the -- Emscripten-produced code (this should be put in a @.class.js@ -- file by itself). -- -- * Options that should be passed to @emcc@. 
compileProg :: (MonadFreshNames m) => T.Text -> Prog SeqMem -> m (ImpGen.Warnings, (GC.CParts, T.Text, [String])) compileProg version prog = do (ws, prog') <- ImpGen.compileProg prog prog'' <- GC.compileProg "wasm" version mempty operations generateBoilerplate "" (DefaultSpace, [DefaultSpace]) [] prog' pure (ws, (prog'', javascriptWrapper (fRepMyRep prog'), emccExportNames (fRepMyRep prog'))) where operations :: GC.Operations Imp.Sequential () operations = GC.defaultOperations { GC.opsCompiler = const $ pure () } fRepMyRep :: Imp.Program -> [JSEntryPoint] fRepMyRep prog = let Imp.Functions fs = Imp.defFuns prog function (Imp.Function entry _ _ _) = do Imp.EntryPoint n res args <- entry Just $ JSEntryPoint { name = nameToString n, parameters = map (extToString . snd) args, ret = map (extToString . snd) res } in mapMaybe (function . snd) fs futhark-0.25.27/src/Futhark/CodeGen/Backends/SimpleRep.hs000066400000000000000000000265741475065116200230250ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | Simple C runtime representation. -- -- Most types use the same memory and scalar variable representation. -- For those that do not (as of this writing, only `Float16`), we use -- 'primStorageType' for the array element representation, and -- 'primTypeToCType' for their scalar representation. Use 'toStorage' -- and 'fromStorage' to convert back and forth. 
module Futhark.CodeGen.Backends.SimpleRep
  ( tupleField,
    funName,
    defaultMemBlockType,
    intTypeToCType,
    primTypeToCType,
    primStorageType,
    primAPIType,
    arrayName,
    opaqueName,
    isValidCName,
    escapeName,
    toStorage,
    fromStorage,
    cproduct,
    csum,
    allEqual,
    allTrue,
    scalarToPrim,

    -- * Primitive value operations
    cScalarDefs,

    -- * Storing/restoring values in byte sequences
    storageSize,
    storeValueHeader,
    loadValueHeader,
  )
where

import Control.Monad (void)
import Data.Char (isAlpha, isAlphaNum, isDigit)
import Data.Text qualified as T
import Data.Void (Void)
import Futhark.CodeGen.ImpCode
import Futhark.CodeGen.RTS.C (scalarF16H, scalarH)
import Futhark.Util (hashText, showText, zEncodeText)
import Language.C.Quote.C qualified as C
import Language.C.Syntax qualified as C
import Text.Megaparsec
import Text.Megaparsec.Char (space)

-- | The C type corresponding to a signed integer type.
intTypeToCType :: IntType -> C.Type
intTypeToCType Int8 = [C.cty|typename int8_t|]
intTypeToCType Int16 = [C.cty|typename int16_t|]
intTypeToCType Int32 = [C.cty|typename int32_t|]
intTypeToCType Int64 = [C.cty|typename int64_t|]

-- | The C type corresponding to an unsigned integer type.
uintTypeToCType :: IntType -> C.Type
uintTypeToCType Int8 = [C.cty|typename uint8_t|]
uintTypeToCType Int16 = [C.cty|typename uint16_t|]
uintTypeToCType Int32 = [C.cty|typename uint32_t|]
uintTypeToCType Int64 = [C.cty|typename uint64_t|]

-- | The C type corresponding to a primitive type, as used for scalar
-- variables.  Integers are assumed to be signed (see 'primAPIType' for
-- a sign-aware variant).  'Unit' is represented like 'Bool'.
primTypeToCType :: PrimType -> C.Type
primTypeToCType (IntType t) = intTypeToCType t
primTypeToCType (FloatType Float16) = [C.cty|typename f16|]
primTypeToCType (FloatType Float32) = [C.cty|float|]
primTypeToCType (FloatType Float64) = [C.cty|double|]
primTypeToCType Bool = [C.cty|typename bool|]
primTypeToCType Unit = [C.cty|typename bool|]

-- | The C storage type for arrays of this primitive type.
primStorageType :: PrimType -> C.Type
primStorageType pt = case pt of
  -- Half-precision floats are stored as their raw 16-bit pattern.
  FloatType Float16 -> [C.cty|typename uint16_t|]
  _ -> primTypeToCType pt

-- | The C type used in the public API for a primitive type.  Integers
-- get the C type matching the given signedness; everything else uses
-- its storage type.
primAPIType :: Signedness -> PrimType -> C.Type
primAPIType sign pt = case (sign, pt) of
  (Signed, IntType it) -> intTypeToCType it
  (Unsigned, IntType it) -> uintTypeToCType it
  _ -> primStorageType pt

-- | Wrap a scalar-representation expression so that it yields the
-- storage representation of the given type.
toStorage :: PrimType -> C.Exp -> C.Exp
toStorage pt e
  | FloatType Float16 <- pt = [C.cexp|futrts_to_bits16($exp:e)|]
  | otherwise = e

-- | Wrap a storage-representation expression so that it yields the
-- scalar representation of the given type.
fromStorage :: PrimType -> C.Exp -> C.Exp
fromStorage pt e
  | FloatType Float16 <- pt = [C.cexp|futrts_from_bits16($exp:e)|]
  | otherwise = e

-- | @tupleField i@ is the field name used for component @i@ of a tuple.
tupleField :: Int -> String
tupleField i = 'v' : show i

-- | @funName f@ is the name of the C function generated for the
-- Futhark function @f@.
funName :: Name -> T.Text
funName f = "futrts_" <> zEncodeText (nameToText f)

-- | The type of memory blocks in the default memory space.
defaultMemBlockType :: C.Type
defaultMemBlockType = [C.cty|unsigned char*|]

-- | The name of the struct type exposed for arrays of the given
-- element type, signedness and rank.
arrayName :: PrimType -> Signedness -> Int -> T.Text
arrayName pt signed rank =
  T.concat [prettySigned is_unsigned pt, "_", prettyText rank, "d"]
  where
    is_unsigned = signed == Unsigned

-- | Can this text be emitted into C as an identifier as-is?  If not,
-- it should be escaped first.
isValidCName :: T.Text -> Bool
isValidCName name = case T.uncons name of
  Nothing -> True
  Just (first, rest) ->
    isAlpha first && T.all (\c -> isAlphaNum c || c == '_') rest

-- | Return the text unchanged if it is a valid C identifier, and a
-- z-encoded (hence valid) version of it otherwise.
escapeName :: T.Text -> T.Text
escapeName v =
  if isValidCName v
    then v
    else zEncodeText v

-- | Valid C identifier name?
-- The first character must be neither an underscore nor a digit, and
-- every character must be alphanumeric or an underscore.  The empty
-- text is not valid.
valid :: T.Text -> Bool
valid s = case T.uncons s of
  -- Avoid the partial 'T.head', which crashed on empty input.
  Nothing -> False
  Just (c, _) -> c /= '_' && not (isDigit c) && T.all ok s
  where
    ok c = isAlphaNum c || c == '_'

-- | Find a nice C type name for the Futhark type.  This solely
-- serves to make the generated header file easy to read, and we can
-- always fall back on an ugly hash.
findPrettyName :: T.Text -> Either String T.Text
findPrettyName =
  either (Left . errorBundlePretty) Right . parse (p <* eof) "type name"
  where
    p :: Parsec Void T.Text T.Text
    p = choice [pArr, pTup, pAtom]
    -- "[][]t" becomes "arr2d_t".
    pArr = do
      dims <- some "[]"
      (("arr" <> showText (length dims) <> "d_") <>) <$> p
    -- "(a, b)" becomes "tup2_a_b".
    pTup = between "(" ")" $ do
      ts <- p `sepBy` pComma
      pure $ "tup" <> showText (length ts) <> "_" <> T.intercalate "_" ts
    pAtom = T.pack <$> some (satisfy (`notElem` ("[]{}()," :: String)))
    pComma = void $ "," <* space

-- | The name of exposed opaque types.  A readable name is used when
-- one can be derived; otherwise the name falls back to a hash.
opaqueName :: Name -> T.Text
opaqueName "()" = "opaque_unit" -- Hopefully this ad-hoc convenience won't bite us.
opaqueName s
  | Right v <- findPrettyName s',
    valid v =
      "opaque_" <> v
  | valid s' = "opaque_" <> s'
  where
    s' = nameToText s
opaqueName s = "opaque_" <> hashText (nameToText s)

-- | The 'PrimType' (and sign) corresponding to a human-readable scalar
-- type name (e.g. @f64@).  Beware: partial!
scalarToPrim :: T.Text -> (Signedness, PrimType)
scalarToPrim "bool" = (Signed, Bool)
scalarToPrim "i8" = (Signed, IntType Int8)
scalarToPrim "i16" = (Signed, IntType Int16)
scalarToPrim "i32" = (Signed, IntType Int32)
scalarToPrim "i64" = (Signed, IntType Int64)
scalarToPrim "u8" = (Unsigned, IntType Int8)
scalarToPrim "u16" = (Unsigned, IntType Int16)
scalarToPrim "u32" = (Unsigned, IntType Int32)
scalarToPrim "u64" = (Unsigned, IntType Int64)
scalarToPrim "f16" = (Signed, FloatType Float16)
scalarToPrim "f32" = (Signed, FloatType Float32)
scalarToPrim "f64" = (Signed, FloatType Float64)
scalarToPrim tname = error $ "scalarToPrim: " <> T.unpack tname

-- | Return an expression multiplying together the given expressions.
-- If an empty list is given, the expression @1@ is returned.
cproduct :: [C.Exp] -> C.Exp
cproduct [] = [C.cexp|1|]
cproduct (e : es) = foldl mult e es
  where
    mult x y = [C.cexp|$exp:x * $exp:y|]

-- | Return an expression summing the given expressions.
-- If an empty list is given, the expression @0@ is returned.
csum :: [C.Exp] -> C.Exp
csum [] = [C.cexp|0|]
csum (e : es) = foldl add e es
  where
    add x y = [C.cexp|$exp:x + $exp:y|]

-- | An expression that is true if these are also all true.
allTrue :: [C.Exp] -> C.Exp
allTrue [] = [C.cexp|true|]
allTrue [x] = x
allTrue (x : xs) = [C.cexp|$exp:x && $exp:(allTrue xs)|]

-- | An expression that is true if these expressions are all equal by
-- @==@.  Lists with fewer than two elements yield @true@.
allEqual :: [C.Exp] -> C.Exp
allEqual [x, y] = [C.cexp|$exp:x == $exp:y|]
allEqual (x : y : xs) = [C.cexp|$exp:x == $exp:y && $exp:(allEqual(y:xs))|]
allEqual _ = [C.cexp|true|]

instance C.ToIdent Name where
  toIdent = C.toIdent . zEncodeText . nameToText

-- Orphan!
instance C.ToIdent T.Text where
  toIdent = C.toIdent . T.unpack

instance C.ToIdent VName where
  toIdent = C.toIdent . zEncodeText . prettyText

instance C.ToExp VName where
  toExp v _ = [C.cexp|$id:v|]

instance C.ToExp IntValue where
  toExp (Int8Value k) _ = [C.cexp|(typename int8_t)$int:k|]
  toExp (Int16Value k) _ = [C.cexp|(typename int16_t)$int:k|]
  toExp (Int32Value k) _ = [C.cexp|$int:k|]
  toExp (Int64Value k) _ = [C.cexp|(typename int64_t)$int:k|]

instance C.ToExp FloatValue where
  toExp (Float16Value x) _
    | isInfinite x =
        if x > 0 then [C.cexp|(typename f16)INFINITY|] else [C.cexp|(typename f16)-INFINITY|]
    | isNaN x =
        [C.cexp|(typename f16)NAN|]
    | otherwise =
        [C.cexp|(typename f16)$float:(fromRational (toRational x))|]
  toExp (Float32Value x) _
    | isInfinite x =
        if x > 0 then [C.cexp|INFINITY|] else [C.cexp|-INFINITY|]
    | isNaN x =
        [C.cexp|NAN|]
    | otherwise =
        [C.cexp|$float:x|]
  toExp (Float64Value x) _
    | isInfinite x =
        if x > 0 then [C.cexp|INFINITY|] else [C.cexp|-INFINITY|]
    | isNaN x =
        [C.cexp|NAN|]
    | otherwise =
        [C.cexp|$double:x|]

instance C.ToExp PrimValue where
  toExp (IntValue v) = C.toExp v
  toExp (FloatValue v) = C.toExp v
  toExp (BoolValue True) = C.toExp (1 :: Int8)
  toExp (BoolValue False) = C.toExp (0 :: Int8)
  toExp UnitValue = C.toExp (0 :: Int8)

instance C.ToExp SubExp where
  toExp (Var v) = C.toExp v
  toExp (Constant c) = C.toExp c

-- | Implementations of scalar operations.
cScalarDefs :: T.Text
cScalarDefs = scalarH <> scalarF16H

-- | @storageSize pt rank shape@ produces an expression giving size
-- taken when storing this value in the binary value format.  It is
-- assumed that the @shape@ is an array with @rank@ dimensions.
storageSize :: PrimType -> Int -> C.Exp -> C.Exp
storageSize pt rank shape =
  [C.cexp|$int:header_size +
          $int:rank * sizeof(typename int64_t) +
          $exp:(cproduct dims) * sizeof($ty:(primStorageType pt))|]
  where
    header_size :: Int
    header_size = 1 + 1 + 1 + 4 -- magic 'b', version, rank, 4-character type name
    dims = [[C.cexp|$exp:shape[$int:i]|] | i <- [0 .. rank - 1]]

-- The four-character type tag written into (and expected in) a value
-- header.
typeStr :: Signedness -> PrimType -> String
typeStr sign pt =
  case (sign, pt) of
    (_, Bool) -> "bool"
    (_, Unit) -> "bool"
    (_, FloatType Float16) -> " f16"
    (_, FloatType Float32) -> " f32"
    (_, FloatType Float64) -> " f64"
    (Signed, IntType Int8) -> "  i8"
    (Signed, IntType Int16) -> " i16"
    (Signed, IntType Int32) -> " i32"
    (Signed, IntType Int64) -> " i64"
    (Unsigned, IntType Int8) -> "  u8"
    (Unsigned, IntType Int16) -> " u16"
    (Unsigned, IntType Int32) -> " u32"
    (Unsigned, IntType Int64) -> " u64"

-- | Produce code for storing the header (everything besides the
-- actual payload) for a value of this type.  The header is the byte
-- @'b'@, the version @2@, the rank, the four-character type name, and
-- (for non-scalars) the shape as @rank@ 64-bit integers.  @dest@ is
-- advanced past everything written.
storeValueHeader :: Signedness -> PrimType -> Int -> C.Exp -> C.Exp -> [C.Stm]
storeValueHeader sign pt rank shape dest =
  [C.cstms|
          *$exp:dest++ = 'b';
          *$exp:dest++ = 2;
          *$exp:dest++ = $int:rank;
          memcpy($exp:dest, $string:(typeStr sign pt), 4);
          $exp:dest += 4;
          $stms:copy_shape
          |]
  where
    copy_shape
      | rank == 0 = []
      | otherwise =
          [C.cstms|
                memcpy($exp:dest, $exp:shape, $int:rank*sizeof(typename int64_t));
                $exp:dest += $int:rank*sizeof(typename int64_t);|]

-- | Produce code for loading the header (everything besides the
-- actual payload) for a value of this type.  The generated statements
-- OR any mismatch into a C variable @err@ that must already be in
-- scope; the shape is only read (and @src@ only advanced past it) when
-- no mismatch has been detected.
loadValueHeader :: Signedness -> PrimType -> Int -> C.Exp -> C.Exp -> [C.Stm]
loadValueHeader sign pt rank shape src =
  [C.cstms|
     err |= (*$exp:src++ != 'b');
     err |= (*$exp:src++ != 2);
     err |= (*$exp:src++ != $int:rank);
     err |= (memcmp($exp:src, $string:(typeStr sign pt), 4) != 0);
     $exp:src += 4;
     if (err == 0) {
       $stms:load_shape
       $exp:src += $int:rank*sizeof(typename int64_t);
     }|]
  where
    load_shape
      | rank == 0 = []
      -- Read from the given source expression, not from whatever C
      -- variable happens to be called "src" at the use site.
      | otherwise = [C.cstms|memcpy($exp:shape, $exp:src, $int:rank*sizeof(typename int64_t));|]
futhark-0.25.27/src/Futhark/CodeGen/ImpCode.hs000066400000000000000000000643261475065116200207300ustar00rootroot00000000000000{-# LANGUAGE Strict #-}

-- | ImpCode is an imperative intermediate language used as a stepping
-- stone in code generation.  The functional core IR
-- ("Futhark.IR.Syntax") gets translated into ImpCode by
-- "Futhark.CodeGen.ImpGen".  Later we then translate ImpCode to, for
-- example, C.
--
-- == Basic design
--
-- ImpCode distinguishes between /statements/ ('Code'), which may have
-- side effects, and /expressions/ ('Exp') which do not.  Expressions
-- involve only scalars and have a type.  The actual expression
-- definition is in "Futhark.Analysis.PrimExp", specifically
-- 'Futhark.Analysis.PrimExp.PrimExp' and its phantom-typed variant
-- 'Futhark.Analysis.PrimExp.TPrimExp'.
--
-- 'Code' is a generic representation parametrised on an extensible
-- arbitrary operation, represented by the 'Op' constructor.  Specific
-- instantiations of ImpCode, such as
-- "Futhark.CodeGen.ImpCode.Multicore", will pass in a specific kind
-- of operation to express backend-specific functionality (in the case
-- of multicore, this is
-- 'Futhark.CodeGen.ImpCode.Multicore.Multicore').
--
-- == Arrays and memory
--
-- ImpCode does not have arrays. 'DeclareArray' is for declaring
-- constant array literals, not arrays in general.  Instead, ImpCode
-- deals only with memory.  Array operations present in core IR
-- programs are turned into 'Write', v'Read', and 'Copy'
-- operations that use flat indexes and offsets based on the index
-- function of the original array.
--
-- == Scoping
--
-- ImpCode is much simpler than the functional core IR; partly because
-- we hope to do less work on it.  We don't have real optimisation
-- passes on ImpCode.  One result of this simplicity is that ImpCode
-- has a fairly naive view of scoping.  The /only/ things that can
-- bring new names into scope are 'DeclareMem', 'DeclareScalar',
-- 'DeclareArray', 'For', and function parameters.  In particular,
-- 'Op's /cannot/ bind parameters.  The standard workaround is to
-- define 'Op's that retrieve the value of an implicit parameter and
-- assign it to a variable declared with the normal
-- mechanisms.  'Futhark.CodeGen.ImpCode.Multicore.GetLoopBounds' is an
-- example of this pattern.
--
-- == Inspiration
--
-- ImpCode was originally inspired by the paper "Defunctionalizing
-- Push Arrays" (FHPC '14).
module Futhark.CodeGen.ImpCode
  ( Definitions (..),
    Functions (..),
    Function,
    FunctionT (..),
    EntryPoint (..),
    Constants (..),
    ValueDesc (..),
    ExternalValue (..),
    Param (..),
    paramName,
    MemSize,
    DimSize,
    Code (..),
    PrimValue (..),
    Exp,
    TExp,
    Volatility (..),
    Arg (..),
    var,
    ArrayContents (..),
    declaredIn,
    lexicalMemoryUsage,
    declsFirst,
    calledFuncs,
    callGraph,
    ParamMap,

    -- * Typed enumerations
    Bytes,
    Elements,
    elements,
    bytes,
    withElemType,

    -- * Re-exports from other modules.
    prettyText,
    prettyString,
    module Futhark.IR.Syntax.Core,
    module Language.Futhark.Core,
    module Language.Futhark.Primitive,
    module Futhark.Analysis.PrimExp,
    module Futhark.Analysis.PrimExp.Convert,
    module Futhark.IR.GPU.Sizes,
    module Futhark.IR.Prop.Names,
  )
where

import Data.Bifunctor (second)
import Data.List (intersperse, partition)
import Data.Map qualified as M
import Data.Ord (comparing)
import Data.Set qualified as S
import Data.Text qualified as T
import Data.Traversable
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.GPU.Sizes (Count (..), SizeClass (..))
import Futhark.IR.Pretty ()
import Futhark.IR.Prop.Names
import Futhark.IR.Syntax.Core
  ( EntryPointType (..),
    ErrorMsg (..),
    ErrorMsgPart (..),
    OpaqueType (..),
    OpaqueTypes (..),
    Rank (..),
    Signedness (..),
    Space (..),
    SpaceId,
    SubExp (..),
    ValueType (..),
    errorMsgArgTypes,
  )
import Futhark.Util (nubByOrd)
import Futhark.Util.Pretty hiding (space)
import Language.Futhark.Core
import Language.Futhark.Primitive

-- | The size of a memory block.
type MemSize = SubExp

-- | The size of an array.
type DimSize = SubExp

-- | An ImpCode function parameter.
data Param
  = -- | A memory block in the given memory space.
    MemParam VName Space
  | -- | A scalar of the given primitive type.
    ScalarParam VName PrimType
  deriving (Eq, Show)

-- | The name of a parameter.
paramName :: Param -> VName
paramName (MemParam name _) = name
paramName (ScalarParam name _) = name

-- | A collection of imperative functions and constants.
data Definitions a = Definitions
  { defTypes :: OpaqueTypes,
    defConsts :: Constants a,
    defFuns :: Functions a
  }
  deriving (Show)

instance Functor Definitions where
  fmap f (Definitions types consts funs) =
    Definitions types (fmap f consts) (fmap f funs)

-- | A collection of imperative functions.
newtype Functions a = Functions {unFunctions :: [(Name, Function a)]}
  deriving (Show)

instance Semigroup (Functions a) where
  Functions x <> Functions y = Functions $ x ++ y

instance Monoid (Functions a) where
  mempty = Functions []

-- | A collection of imperative constants.
data Constants a = Constants
  { -- | The constants that are made available to the functions.
    constsDecl :: [Param],
    -- | Setting the value of the constants.  Note that this must not
    -- contain declarations of the names defined in 'constsDecl'.
    constsInit :: Code a
  }
  deriving (Show)

instance Functor Constants where
  fmap f (Constants params code) = Constants params (fmap f code)

instance Monoid (Constants a) where
  mempty = Constants mempty mempty

-- | Declarations are concatenated with duplicates (compared by the
-- printed parameter name) removed; initialisation code is sequenced.
instance Semigroup (Constants a) where
  Constants ps1 c1 <> Constants ps2 c2 =
    Constants
      (nubByOrd (comparing (prettyString . paramName)) $ ps1 <> ps2)
      (c1 <> c2)

-- | A description of an externally meaningful value.
data ValueDesc
  = -- | An array with memory block, memory space, element type,
    -- signedness of element type (if applicable), and shape.
    ArrayValue VName Space PrimType Signedness [DimSize]
  | -- | A scalar value with signedness if applicable.
    ScalarValue PrimType Signedness VName
  deriving (Eq, Show)

-- | An externally visible value.  This can be an opaque value
-- (covering several physical internal values), or a single value that
-- can be used externally.  Uniqueness is not part of this type; it is
-- recorded next to each value in 'EntryPoint'.
data ExternalValue
  = -- | An opaque value: the 'Name' describes the opaque type, and the
    -- 'ValueDesc's are the underlying values it consists of.
    OpaqueValue Name [ValueDesc]
  | -- | A single value that is directly usable externally.
    TransparentValue ValueDesc
  deriving (Show)

-- | Information about how this function can be called from the outside world.
data EntryPoint = EntryPoint
  { entryPointName :: Name,
    entryPointResults :: [(Uniqueness, ExternalValue)],
    entryPointArgs :: [((Name, Uniqueness), ExternalValue)]
  }
  deriving (Show)

-- | An imperative function, containing the body as well as its
-- low-level inputs and outputs, as well as its high-level arguments
-- and results.  The latter are only present if the function is an entry
-- point.
data FunctionT a = Function
  { functionEntry :: Maybe EntryPoint,
    functionOutput :: [Param],
    functionInput :: [Param],
    functionBody :: Code a
  }
  deriving (Show)

-- | Type alias for namespace control.
type Function = FunctionT

-- | The contents of a statically declared constant array.  Such
-- arrays are always unidimensional, and reshaped if necessary in the
-- code that uses them.
data ArrayContents
  = -- | Precisely these values.
    ArrayValues [PrimValue]
  | -- | This many zeroes.
    ArrayZeros Int
  deriving (Show)

-- | A block of imperative code.  Parameterised by an 'Op', which
-- allows extensibility.  Concrete uses of this type will instantiate
-- the type parameter with e.g. a construct for launching GPU kernels.
data Code a
  = -- | No-op.  Crucial for the 'Monoid' instance.
    Skip
  | -- | Statement composition.  Crucial for the 'Semigroup' instance.
    Code a :>>: Code a
  | -- | A for-loop iterating the given number of times.
    -- The loop parameter starts counting from zero and will
    -- have the same (integer) type as the bound.  The bound
    -- is evaluated just once, before the loop is entered.
    For VName Exp (Code a)
  | -- | While loop.  The conditional is (of course)
    -- re-evaluated before every iteration of the loop.
    While (TExp Bool) (Code a)
  | -- | Declare a memory block variable that will point to
    -- memory in the given memory space.  Note that this is
    -- distinct from allocation.  The memory block must be the
    -- target of either an 'Allocate' or a 'SetMem' before it
    -- can be used for reading or writing.
    DeclareMem VName Space
  | -- | Declare a scalar variable with an initially undefined value.
    DeclareScalar VName Volatility PrimType
  | -- | Create a DefaultSpace array containing the given values.  The
    -- lifetime of the array will be the entire application.  This is
    -- mostly used for constant arrays.
    DeclareArray VName PrimType ArrayContents
  | -- | Memory space must match the corresponding
    -- 'DeclareMem'.
    Allocate VName (Count Bytes (TExp Int64)) Space
  | -- | Indicate that some memory block will never again be
    -- referenced via the indicated variable.  However, it
    -- may still be accessed through aliases.  It is only
    -- safe to actually deallocate the memory block if this
    -- is the last reference.  There is no guarantee that
    -- all memory blocks will be freed with this statement.
    -- Backends are free to ignore it entirely.
    Free VName Space
  | -- | @Copy pt shape dest dest_lmad src src_lmad@.  Each
    -- @(offset, strides)@ pair is counted in elements and describes
    -- how @shape@ is laid out in the corresponding memory block.
    Copy
      PrimType
      [Count Elements (TExp Int64)]
      (VName, Space)
      ( Count Elements (TExp Int64),
        [Count Elements (TExp Int64)]
      )
      (VName, Space)
      ( Count Elements (TExp Int64),
        [Count Elements (TExp Int64)]
      )
  | -- | @Write mem i t space vol v@ writes the value @v@ to
    -- @mem@ offset by @i@ elements of type @t@.  The
    -- 'Space' argument is the memory space of @mem@
    -- (technically redundant, but convenient).
    Write VName (Count Elements (TExp Int64)) PrimType Space Volatility Exp
  | -- | Set a scalar variable.
    SetScalar VName Exp
  | -- | Read a scalar from memory.  The first 'VName' is
    -- the target scalar variable, and the remaining arguments have
    -- the same meaning as with 'Write'.
    Read VName VName (Count Elements (TExp Int64)) PrimType Space Volatility
  | -- | Make the first memory block variable refer to the memory of
    -- the second.  Both must be in the given space.
    SetMem VName VName Space
  | -- | Function call.  The results are written to the
    -- provided 'VName' variables.
    Call [VName] Name [Arg]
  | -- | Conditional execution.
    If (TExp Bool) (Code a) (Code a)
  | -- | Assert that something must be true.  Should it turn
    -- out not to be true, then report a failure along with
    -- the given error message.
    Assert Exp (ErrorMsg Exp) (SrcLoc, [SrcLoc])
  | -- | Has the same semantics as the contained code, but
    -- the comment should show up in generated code for ease
    -- of inspection.
    Comment T.Text (Code a)
  | -- | Print the given value to the screen, somehow
    -- annotated with the given string as a description.  If
    -- no type/value pair, just print the string.  This has
    -- no semantic meaning, but is used entirely for
    -- debugging.  Code generators are free to ignore this
    -- statement.
    DebugPrint String (Maybe Exp)
  | -- | Log the given message, *without* a trailing linebreak (unless
    -- part of the message).
    TracePrint (ErrorMsg Exp)
  | -- | Perform an extensible operation.
    Op a
  deriving (Show)

-- | The volatility of a memory access or variable.  Feel free to
-- ignore this for backends where it makes no sense (anything but C
-- and similar low-level things)
data Volatility = Volatile | Nonvolatile
  deriving (Eq, Ord, Show)

-- | 'Skip' is dropped rather than composed.
instance Semigroup (Code a) where
  Skip <> y = y
  x <> Skip = x
  x <> y = x :>>: y

instance Monoid (Code a) where
  mempty = Skip

-- | Find those memory blocks that are used only lexically.  That is,
-- are not used as the source or target of a 'SetMem', or are the
-- result of the function, nor passed as arguments to other functions.
-- This is interesting because such memory blocks do not need
-- reference counting, but can be managed in a purely stack-like
-- fashion.
--
-- We do not look inside any 'Op's.  We assume that no 'Op' is going
-- to 'SetMem' a memory block declared outside it.
lexicalMemoryUsage :: Function a -> M.Map VName Space
lexicalMemoryUsage func =
  M.filterWithKey (const . (`notNameIn` nonlexical)) $
    declared $
      functionBody func
  where
    nonlexical =
      set (functionBody func)
        <> namesFromList (map paramName (functionOutput func))

    -- Generic traversal of the structured statements.
    go f (x :>>: y) = f x <> f y
    go f (If _ x y) = f x <> f y
    go f (For _ _ x) = f x
    go f (While _ x) = f x
    go f (Comment _ x) = f x
    go _ _ = mempty

    declared (DeclareMem mem space) = M.singleton mem space
    declared x = go declared x

    set (SetMem x y _) = namesFromList [x, y]
    set (Call dests _ args) =
      -- Some of the dests might not be memory, but it does not matter.
      namesFromList dests <> foldMap onArg args
      where
        onArg ExpArg {} = mempty
        onArg (MemArg x) = oneName x
    set x = go set x

-- | Reorder the code such that all 'DeclareScalar' and 'DeclareMem'
-- statements appear first, within each nested block ('If', 'For',
-- 'While', and 'Comment' bodies are reordered individually).  This is
-- always possible, because 'DeclareScalar' and 'DeclareMem' do not
-- depend on any local bindings.
declsFirst :: Code a -> Code a
declsFirst = mconcat . uncurry (<>) . partition isDecl . listify
  where
    listify (c1 :>>: c2) = listify c1 <> listify c2
    listify (If cond c1 c2) = [If cond (declsFirst c1) (declsFirst c2)]
    listify (For i e c) = [For i e (declsFirst c)]
    listify (While cond c) = [While cond (declsFirst c)]
    -- Treat comments like the other nested constructs.
    listify (Comment s c) = [Comment s (declsFirst c)]
    listify c = [c]
    isDecl (DeclareScalar {}) = True
    isDecl (DeclareMem {}) = True
    isDecl _ = False

-- | The set of functions that are called by this code.  Accepts a
-- function for determining function calls in 'Op's.
calledFuncs :: (a -> S.Set Name) -> Code a -> S.Set Name
calledFuncs _ (Call _ v _) = S.singleton v
calledFuncs f (Op x) = f x
calledFuncs f (x :>>: y) = calledFuncs f x <> calledFuncs f y
calledFuncs f (If _ x y) = calledFuncs f x <> calledFuncs f y
calledFuncs f (For _ _ x) = calledFuncs f x
calledFuncs f (While _ x) = calledFuncs f x
calledFuncs f (Comment _ x) = calledFuncs f x
calledFuncs _ _ = mempty

-- | Compute call graph, as per 'calledFuncs', but also include
-- transitive calls.
callGraph :: (a -> S.Set Name) -> Functions a -> M.Map Name (S.Set Name)
callGraph f (Functions funs) =
  loop $ M.fromList $ map (second $ calledFuncs f . functionBody) funs
  where
    -- Iterate until a fixed point: each step adds the callees of every
    -- callee already known.
    loop cur =
      let grow v = maybe (S.singleton v) (S.insert v) (M.lookup v cur)
          next = M.map (foldMap grow) cur
       in if next == cur then cur else loop next

-- | A mapping from names of tuning parameters to their class, as well
-- as which functions make use of them (including transitively).
type ParamMap = M.Map Name (SizeClass, S.Set Name)

-- | A side-effect free expression whose execution will produce a
-- single primitive value.
type Exp = PrimExp VName

-- | Like 'Exp', but with a required/known type.
type TExp t = TPrimExp t VName

-- | A function call argument.
data Arg
  = -- | A scalar expression.
    ExpArg Exp
  | -- | A memory block.
    MemArg VName
  deriving (Show)

-- | Phantom type for a count of elements.
data Elements

-- | Phantom type for a count of bytes.
data Bytes

-- | This expression counts elements.
elements :: a -> Count Elements a
elements = Count

-- | This expression counts bytes.
bytes :: a -> Count Bytes a
bytes = Count

-- | Convert a count of elements into a count of bytes, given the
-- per-element size.
withElemType :: Count Elements (TExp Int64) -> PrimType -> Count Bytes (TExp Int64)
withElemType (Count e) t = bytes $ sExt64 e * primByteSize t

-- | Turn a 'VName' into an 'Exp'.
var :: VName -> PrimType -> Exp
var = LeafExp

-- Prettyprinting definitions.
instance (Pretty op) => Pretty (Definitions op) where
  pretty (Definitions types consts funs) =
    pretty types </> pretty consts </> pretty funs

instance (Pretty op) => Pretty (Functions op) where
  pretty (Functions funs) = stack $ intersperse mempty $ map ppFun funs
    where
      ppFun (name, fun) =
        "Function " <> pretty name <> colon </> indent 2 (pretty fun)

instance (Pretty op) => Pretty (Constants op) where
  pretty (Constants decls code) =
    "Constants:"
      </> indent 2 (stack $ map pretty decls)
      </> mempty
      </> "Initialisation:"
      </> indent 2 (pretty code)

instance Pretty EntryPoint where
  pretty (EntryPoint name results args) =
    "Name:"
      </> indent 2 (dquotes (pretty name))
      </> "Arguments:"
      </> indent 2 (stack $ map ppArg args)
      </> "Results:"
      </> indent 2 (stack $ map ppRes results)
    where
      ppArg ((p, u), t) = pretty p <+> ":" <+> ppRes (u, t)
      ppRes (u, t) = pretty u <> pretty t

instance (Pretty op) => Pretty (FunctionT op) where
  pretty (Function entry outs ins body) =
    "Inputs:"
      </> indent 2 (stack $ map pretty ins)
      </> "Outputs:"
      </> indent 2 (stack $ map pretty outs)
      </> "Entry:"
      </> indent 2 (pretty entry)
      </> "Body:"
      </> indent 2 (pretty body)

instance Pretty Param where
  pretty (ScalarParam name ptype) = pretty ptype <+> pretty name
  pretty (MemParam name space) = "mem" <> pretty space <> " " <> pretty name

instance Pretty ValueDesc where
  pretty (ScalarValue t ept name) =
    pretty t <+> pretty name <> ept'
    where
      ept' = case ept of
        Unsigned -> " (unsigned)"
        Signed -> mempty
  pretty (ArrayValue mem space et ept shape) =
    foldMap (brackets . pretty) shape
      <> (pretty et <+> "at" <+> pretty mem <> pretty space <+> ept')
    where
      ept' = case ept of
        Unsigned -> " (unsigned)"
        Signed -> mempty

instance Pretty ExternalValue where
  pretty (TransparentValue v) = pretty v
  pretty (OpaqueValue desc vs) =
    "opaque"
      <+> dquotes (pretty desc)
      <+> nestedBlock "{" "}" (stack $ map pretty vs)

instance Pretty ArrayContents where
  pretty (ArrayValues vs) = braces (commasep $ map pretty vs)
  pretty (ArrayZeros n) = braces "0" <+> "*" <+> pretty n

instance (Pretty op) => Pretty (Code op) where
  pretty (Op op) = pretty op
  pretty Skip = "skip"
  pretty (c1 :>>: c2) = pretty c1 </> pretty c2
  pretty (For i limit body) =
    "for"
      <+> pretty i
      <+> langle
      <+> pretty limit
      <+> "{"
      </> indent 2 (pretty body)
      </> "}"
  pretty (While cond body) =
    "while"
      <+> pretty cond
      <+> "{"
      </> indent 2 (pretty body)
      </> "}"
  pretty (DeclareMem name space) =
    "var" <+> pretty name <> ": mem" <> pretty space
  pretty (DeclareScalar name vol t) =
    "var" <+> pretty name <> ":" <+> vol' <> pretty t
    where
      vol' = case vol of
        Volatile -> "volatile "
        Nonvolatile -> mempty
  pretty (DeclareArray name t vs) =
    "array"
      <+> pretty name
      <+> ":"
      <+> pretty t
      <+> equals
      <+> pretty vs
  pretty (Allocate name e space) =
    pretty name <+> "<-" <+> "malloc" <> parens (pretty e) <> pretty space
  pretty (Free name space) =
    "free" <> parens (pretty name) <> pretty space
  pretty (Write name i bt space vol val) =
    pretty name
      <> langle
      <> vol'
      <> pretty bt
      <> pretty space
      <> rangle
      <> brackets (pretty i)
      <+> "<-"
      <+> pretty val
    where
      vol' = case vol of
        Volatile -> "volatile "
        Nonvolatile -> mempty
  pretty (Read name v is bt space vol) =
    pretty name
      <+> "<-"
      <+> pretty v
      <> langle
      <> vol'
      <> pretty bt
      <> pretty space
      <> rangle
      <> brackets (pretty is)
    where
      vol' = case vol of
        Volatile -> "volatile "
        Nonvolatile -> mempty
  pretty (SetScalar name val) =
    pretty name <+> "<-" <+> pretty val
  pretty (SetMem dest from DefaultSpace) =
    pretty dest <+> "<-" <+> pretty from
  pretty (SetMem dest from space) =
    pretty dest <+> "<-" <+> pretty from <+> "@" <> pretty space
  pretty (Assert e msg _) =
    "assert" <> parens (commasep [pretty msg, pretty e])
  pretty (Copy t shape (dst, dstspace) (dstoffset, dststrides) (src, srcspace) (srcoffset, srcstrides)) =
    ("lmadcopy_" <> pretty (length shape) <> "d_" <> pretty t)
      <> (parens . align)
        ( foldMap (brackets . pretty) shape
            <> ","
            </> p dst dstspace dstoffset dststrides
            <> ","
            </> p src srcspace srcoffset srcstrides
        )
    where
      p mem space offset strides =
        pretty mem
          <> pretty space
          <> "+"
          <> pretty offset
          <+> foldMap (brackets . pretty) strides
  pretty (If cond tbranch fbranch) =
    "if"
      <+> pretty cond
      <+> "then {"
      </> indent 2 (pretty tbranch)
      </> "} else"
      <+> case fbranch of
        -- Print an else-if chain without extra nesting.
        If {} -> pretty fbranch
        _ ->
          "{" </> indent 2 (pretty fbranch) </> "}"
  pretty (Call [] fname args) =
    "call" <+> pretty fname <> parens (commasep $ map pretty args)
  pretty (Call dests fname args) =
    "call"
      <+> commasep (map pretty dests)
      <+> "<-"
      <+> pretty fname
      <> parens (commasep $ map pretty args)
  pretty (Comment s code) =
    "--" <+> pretty s </> pretty code
  pretty (DebugPrint desc (Just e)) =
    "debug" <+> parens (commasep [pretty (show desc), pretty e])
  pretty (DebugPrint desc Nothing) =
    "debug" <+> parens (pretty (show desc))
  pretty (TracePrint msg) =
    "trace" <+> parens (pretty msg)

instance Pretty Arg where
  pretty (MemArg m) = pretty m
  pretty (ExpArg e) = pretty e

instance Functor Functions where
  fmap = fmapDefault

instance Foldable Functions where
  foldMap = foldMapDefault

instance Traversable Functions where
  traverse f (Functions funs) =
    Functions <$> traverse f' funs
    where
      f' (name, fun) = (name,) <$> traverse f fun

instance Functor FunctionT where
  fmap = fmapDefault

instance Foldable FunctionT where
  foldMap = foldMapDefault

instance Traversable FunctionT where
  traverse f (Function entry outs ins body) =
    Function entry outs ins <$> traverse f body

instance Functor Code where
  fmap = fmapDefault

instance Foldable Code where
  foldMap = foldMapDefault

-- | Only 'Op' payloads are visited; every other constructor is rebuilt
-- unchanged around its (recursively traversed) sub-code.
instance Traversable Code where
  traverse f (x :>>: y) =
    (:>>:) <$> traverse f x <*> traverse f y
  traverse f (For i bound code) =
    For i bound <$> traverse f code
  traverse f (While cond code) =
    While cond <$> traverse f code
  traverse f (If cond x y) =
    If cond <$> traverse f x <*> traverse f y
  traverse f (Op kernel) =
    Op <$> f kernel
  traverse _ Skip =
    pure Skip
  traverse _ (DeclareMem name space) =
    pure $ DeclareMem name space
  traverse _ (DeclareScalar name vol bt) =
    pure $ DeclareScalar name vol bt
  traverse _ (DeclareArray name t vs) =
    pure $ DeclareArray name t vs
  traverse _ (Allocate name size s) =
    pure $ Allocate name size s
  traverse _ (Free name space) =
    pure $ Free name space
  traverse _ (Copy t shape (dst, dstspace) (dstoffset, dststrides) (src, srcspace) (srcoffset, srcstrides)) =
    pure $ Copy t shape (dst, dstspace) (dstoffset, dststrides) (src, srcspace) (srcoffset, srcstrides)
  traverse _ (Write name i bt space vol val) =
    pure $ Write name i bt space vol val
  traverse _ (Read x name i bt space vol) =
    pure $ Read x name i bt space vol
  traverse _ (SetScalar name val) =
    pure $ SetScalar name val
  traverse _ (SetMem dest from space) =
    pure $ SetMem dest from space
  traverse _ (Assert e msg loc) =
    pure $ Assert e msg loc
  traverse _ (Call dests fname args) =
    pure $ Call dests fname args
  traverse f (Comment s code) =
    Comment s <$> traverse f code
  traverse _ (DebugPrint s v) =
    pure $ DebugPrint s v
  traverse _ (TracePrint msg) =
    pure $ TracePrint msg

-- | The names declared with 'DeclareMem', 'DeclareScalar', and
-- 'DeclareArray' in the given code, together with the loop variables
-- bound by 'For'.
declaredIn :: Code a -> Names declaredIn (DeclareMem name _) = oneName name declaredIn (DeclareScalar name _ _) = oneName name declaredIn (DeclareArray name _ _) = oneName name declaredIn (If _ t f) = declaredIn t <> declaredIn f declaredIn (x :>>: y) = declaredIn x <> declaredIn y declaredIn (For i _ body) = oneName i <> declaredIn body declaredIn (While _ body) = declaredIn body declaredIn (Comment _ body) = declaredIn body declaredIn _ = mempty instance FreeIn EntryPoint where freeIn' (EntryPoint _ res args) = freeIn' (map snd res) <> freeIn' (map snd args) instance (FreeIn a) => FreeIn (Functions a) where freeIn' (Functions fs) = foldMap (onFun . snd) fs where onFun f = fvBind pnames $ freeIn' (functionBody f) <> freeIn' (functionEntry f) where pnames = namesFromList $ map paramName $ functionInput f <> functionOutput f instance FreeIn ValueDesc where freeIn' (ArrayValue mem _ _ _ dims) = freeIn' mem <> freeIn' dims freeIn' ScalarValue {} = mempty instance FreeIn ExternalValue where freeIn' (TransparentValue vd) = freeIn' vd freeIn' (OpaqueValue _ vds) = foldMap freeIn' vds instance (FreeIn a) => FreeIn (Code a) where freeIn' (x :>>: y) = fvBind (declaredIn x) $ freeIn' x <> freeIn' y freeIn' Skip = mempty freeIn' (For i bound body) = fvBind (oneName i) $ freeIn' bound <> freeIn' body freeIn' (While cond body) = freeIn' cond <> freeIn' body freeIn' (DeclareMem _ space) = freeIn' space freeIn' DeclareScalar {} = mempty freeIn' DeclareArray {} = mempty freeIn' (Allocate name size space) = freeIn' name <> freeIn' size <> freeIn' space freeIn' (Free name _) = freeIn' name freeIn' (Copy _ shape (dst, _) (dstoffset, dststrides) (src, _) (srcoffset, srcstrides)) = freeIn' shape <> freeIn' dst <> freeIn' dstoffset <> freeIn' dststrides <> freeIn' src <> freeIn' srcoffset <> freeIn' srcstrides freeIn' (SetMem x y _) = freeIn' x <> freeIn' y freeIn' (Write v i _ _ _ e) = freeIn' v <> freeIn' i <> freeIn' e freeIn' (Read x v i _ _ _) = freeIn' x <> freeIn' v <> freeIn' i 
freeIn' (SetScalar x y) = freeIn' x <> freeIn' y freeIn' (Call dests _ args) = freeIn' dests <> freeIn' args freeIn' (If cond t f) = freeIn' cond <> freeIn' t <> freeIn' f freeIn' (Assert e msg _) = freeIn' e <> foldMap freeIn' msg freeIn' (Op op) = freeIn' op freeIn' (Comment _ code) = freeIn' code freeIn' (DebugPrint _ v) = maybe mempty freeIn' v freeIn' (TracePrint msg) = foldMap freeIn' msg instance FreeIn Arg where freeIn' (MemArg m) = freeIn' m freeIn' (ExpArg e) = freeIn' e futhark-0.25.27/src/Futhark/CodeGen/ImpCode/000077500000000000000000000000001475065116200203615ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/CodeGen/ImpCode/GPU.hs000066400000000000000000000260161475065116200213550ustar00rootroot00000000000000-- | Variation of "Futhark.CodeGen.ImpCode" that contains the notion -- of a kernel invocation. module Futhark.CodeGen.ImpCode.GPU ( Program, HostCode, KernelCode, KernelConst (..), KernelConstExp, HostOp (..), KernelOp (..), Fence (..), AtomicOp (..), BlockDim, Kernel (..), KernelUse (..), module Futhark.CodeGen.ImpCode, module Futhark.IR.GPU.Sizes, ) where import Futhark.CodeGen.ImpCode import Futhark.IR.GPU.Sizes import Futhark.IR.Pretty () import Futhark.Util.Pretty -- | A program that calls kernels. type Program = Definitions HostOp -- | Host-level code that can call kernels. type HostCode = Code HostOp -- | Code inside a kernel. type KernelCode = Code KernelOp -- | A run-time constant related to kernels. data KernelConst = SizeConst Name SizeClass | SizeMaxConst SizeClass deriving (Eq, Ord, Show) -- | An expression whose variables are kernel constants. type KernelConstExp = PrimExp KernelConst -- | An operation that runs on the host (CPU). data HostOp = CallKernel Kernel | GetSize VName Name SizeClass | CmpSizeLe VName Name SizeClass Exp | GetSizeMax VName SizeClass deriving (Show) -- | The size of one dimension of a block. type BlockDim = Either Exp KernelConstExp -- | A generic kernel containing arbitrary kernel code. 
data Kernel = Kernel { kernelBody :: Code KernelOp, -- | The host variables referenced by the kernel. kernelUses :: [KernelUse], kernelNumBlocks :: [Exp], kernelBlockSize :: [BlockDim], -- | A short descriptive and _unique_ name - should be -- alphanumeric and without spaces. kernelName :: Name, -- | If true, this kernel does not need to check whether we are in -- a failing state, as it can cope. Intuitively, it means that the -- kernel does not depend on any non-scalar parameters to make -- control flow decisions. Replication, transpose, and copy -- kernels are examples of this. kernelFailureTolerant :: Bool, -- | If true, multi-versioning branches will consider this kernel -- when considering the shared memory requirements. Set this to -- false for kernels that do their own checking. kernelCheckSharedMemory :: Bool } deriving (Show) -- | Information about a host-level variable that is used inside this -- kernel. When generating the actual kernel code, this is used to -- deduce which parameters are needed. 
data KernelUse = ScalarUse VName PrimType | MemoryUse VName | ConstUse VName KernelConstExp deriving (Eq, Ord, Show) instance Pretty KernelConst where pretty (SizeConst key size_class) = "get_size" <> parens (commasep [pretty key, pretty size_class]) pretty (SizeMaxConst size_class) = "get_max_size" <> parens (pretty size_class) instance FreeIn KernelConst where freeIn' SizeConst {} = mempty freeIn' (SizeMaxConst _) = mempty instance Pretty KernelUse where pretty (ScalarUse name t) = oneLine $ "scalar_copy" <> parens (commasep [pretty name, pretty t]) pretty (MemoryUse name) = oneLine $ "mem_copy" <> parens (commasep [pretty name]) pretty (ConstUse name e) = oneLine $ "const" <> parens (commasep [pretty name, pretty e]) instance Pretty HostOp where pretty (GetSize dest key size_class) = pretty dest <+> "<-" <+> "get_size" <> parens (commasep [pretty key, pretty size_class]) pretty (GetSizeMax dest size_class) = pretty dest <+> "<-" <+> "get_size_max" <> parens (pretty size_class) pretty (CmpSizeLe dest name size_class x) = pretty dest <+> "<-" <+> "get_size" <> parens (commasep [pretty name, pretty size_class]) <+> "<" <+> pretty x pretty (CallKernel c) = pretty c instance FreeIn HostOp where freeIn' (CallKernel c) = freeIn' c freeIn' (CmpSizeLe dest _ _ x) = freeIn' dest <> freeIn' x freeIn' (GetSizeMax dest _) = freeIn' dest freeIn' (GetSize dest _ _) = freeIn' dest instance FreeIn Kernel where freeIn' kernel = freeIn' ( kernelBody kernel, kernelNumBlocks kernel, kernelBlockSize kernel ) instance Pretty Kernel where pretty kernel = "kernel" <+> brace ( "blocks" <+> brace (pretty $ kernelNumBlocks kernel) "tblock_size" <+> brace (list $ map pSize $ kernelBlockSize kernel) "uses" <+> brace (stack $ map pretty $ kernelUses kernel) "failure_tolerant" <+> brace (pretty $ kernelFailureTolerant kernel) "check_shared_memory" <+> brace (pretty $ kernelCheckSharedMemory kernel) "body" <+> brace (pretty $ kernelBody kernel) ) where pSize (Left x) = "dyn" <+> pretty x pSize 
(Right x) = "const" <+> pretty x -- | When we do a barrier or fence, is it at the local or global -- level? By the 'Ord' instance, global is greater than local. data Fence = FenceLocal | FenceGlobal deriving (Show, Eq, Ord) -- | An operation that occurs within a kernel body. data KernelOp = GetBlockId VName Int | GetLocalId VName Int | GetLocalSize VName Int | GetLockstepWidth VName | Atomic Space AtomicOp | Barrier Fence | MemFence Fence | SharedAlloc VName (Count Bytes (TExp Int64)) | -- | Perform a barrier and also check whether any -- threads have failed an assertion. Make sure all -- threads would reach all 'ErrorSync's if any of them -- do. A failing assertion will jump to the next -- following 'ErrorSync', so make sure it's not inside -- control flow or similar. ErrorSync Fence deriving (Show) -- | Atomic operations return the value stored before the update. This -- old value is stored in the first 'VName' (except for -- 'AtomicWrite'). The second 'VName' is the memory block to update. -- The 'Exp' is the new value. data AtomicOp = AtomicAdd IntType VName VName (Count Elements (TExp Int64)) Exp | AtomicFAdd FloatType VName VName (Count Elements (TExp Int64)) Exp | AtomicSMax IntType VName VName (Count Elements (TExp Int64)) Exp | AtomicSMin IntType VName VName (Count Elements (TExp Int64)) Exp | AtomicUMax IntType VName VName (Count Elements (TExp Int64)) Exp | AtomicUMin IntType VName VName (Count Elements (TExp Int64)) Exp | AtomicAnd IntType VName VName (Count Elements (TExp Int64)) Exp | AtomicOr IntType VName VName (Count Elements (TExp Int64)) Exp | AtomicXor IntType VName VName (Count Elements (TExp Int64)) Exp | AtomicCmpXchg PrimType VName VName (Count Elements (TExp Int64)) Exp Exp | AtomicXchg PrimType VName VName (Count Elements (TExp Int64)) Exp | -- | Corresponds to a write followed by a memory fence. The old -- value is not read. 
AtomicWrite PrimType VName (Count Elements (TExp Int64)) Exp deriving (Show) instance FreeIn AtomicOp where freeIn' (AtomicAdd _ _ arr i x) = freeIn' arr <> freeIn' i <> freeIn' x freeIn' (AtomicFAdd _ _ arr i x) = freeIn' arr <> freeIn' i <> freeIn' x freeIn' (AtomicSMax _ _ arr i x) = freeIn' arr <> freeIn' i <> freeIn' x freeIn' (AtomicSMin _ _ arr i x) = freeIn' arr <> freeIn' i <> freeIn' x freeIn' (AtomicUMax _ _ arr i x) = freeIn' arr <> freeIn' i <> freeIn' x freeIn' (AtomicUMin _ _ arr i x) = freeIn' arr <> freeIn' i <> freeIn' x freeIn' (AtomicAnd _ _ arr i x) = freeIn' arr <> freeIn' i <> freeIn' x freeIn' (AtomicOr _ _ arr i x) = freeIn' arr <> freeIn' i <> freeIn' x freeIn' (AtomicXor _ _ arr i x) = freeIn' arr <> freeIn' i <> freeIn' x freeIn' (AtomicCmpXchg _ _ arr i x y) = freeIn' arr <> freeIn' i <> freeIn' x <> freeIn' y freeIn' (AtomicXchg _ _ arr i x) = freeIn' arr <> freeIn' i <> freeIn' x freeIn' (AtomicWrite _ arr i x) = freeIn' arr <> freeIn' i <> freeIn' x instance Pretty KernelOp where pretty (GetBlockId dest i) = pretty dest <+> "<-" <+> "get_tblock_id" <> parens (pretty i) pretty (GetLocalId dest i) = pretty dest <+> "<-" <+> "get_local_id" <> parens (pretty i) pretty (GetLocalSize dest i) = pretty dest <+> "<-" <+> "get_local_size" <> parens (pretty i) pretty (GetLockstepWidth dest) = pretty dest <+> "<-" <+> "get_lockstep_width()" pretty (Barrier FenceLocal) = "local_barrier()" pretty (Barrier FenceGlobal) = "global_barrier()" pretty (MemFence FenceLocal) = "mem_fence_local()" pretty (MemFence FenceGlobal) = "mem_fence_global()" pretty (SharedAlloc name size) = pretty name <+> equals <+> "shared_alloc" <> parens (pretty size) pretty (ErrorSync FenceLocal) = "error_sync_local()" pretty (ErrorSync FenceGlobal) = "error_sync_global()" pretty (Atomic _ (AtomicAdd t old arr ind x)) = pretty old <+> "<-" <+> "atomic_add_" <> pretty t <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) pretty (Atomic _ (AtomicFAdd t old arr 
ind x)) = pretty old <+> "<-" <+> "atomic_fadd_" <> pretty t <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) pretty (Atomic _ (AtomicSMax t old arr ind x)) = pretty old <+> "<-" <+> "atomic_smax" <> pretty t <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) pretty (Atomic _ (AtomicSMin t old arr ind x)) = pretty old <+> "<-" <+> "atomic_smin" <> pretty t <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) pretty (Atomic _ (AtomicUMax t old arr ind x)) = pretty old <+> "<-" <+> "atomic_umax" <> pretty t <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) pretty (Atomic _ (AtomicUMin t old arr ind x)) = pretty old <+> "<-" <+> "atomic_umin" <> pretty t <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) pretty (Atomic _ (AtomicAnd t old arr ind x)) = pretty old <+> "<-" <+> "atomic_and" <> pretty t <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) pretty (Atomic _ (AtomicOr t old arr ind x)) = pretty old <+> "<-" <+> "atomic_or" <> pretty t <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) pretty (Atomic _ (AtomicXor t old arr ind x)) = pretty old <+> "<-" <+> "atomic_xor" <> pretty t <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) pretty (Atomic _ (AtomicCmpXchg t old arr ind x y)) = pretty old <+> "<-" <+> "atomic_cmp_xchg" <> pretty t <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x, pretty y]) pretty (Atomic _ (AtomicXchg t old arr ind x)) = pretty old <+> "<-" <+> "atomic_xchg" <> pretty t <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) pretty (Atomic _ (AtomicWrite t arr ind x)) = "atomic_write" <> pretty t <> parens (commasep [pretty arr <> brackets (pretty ind), pretty x]) instance FreeIn KernelOp where freeIn' (Atomic _ op) = freeIn' op freeIn' (SharedAlloc _ size) = freeIn' size freeIn' _ = mempty brace :: Doc a -> Doc a brace body = " {" indent 2 body "}" 
futhark-0.25.27/src/Futhark/CodeGen/ImpCode/Multicore.hs000066400000000000000000000177251475065116200226740ustar00rootroot00000000000000-- | Multicore imperative code. module Futhark.CodeGen.ImpCode.Multicore ( Program, Multicore (..), MCCode, Scheduling (..), SchedulerInfo (..), AtomicOp (..), ParallelTask (..), KernelHandling (..), lexicalMemoryUsageMC, module Futhark.CodeGen.ImpCode, ) where import Data.Map qualified as M import Futhark.CodeGen.ImpCode import Futhark.Util.Pretty -- | An imperative multicore program. type Program = Functions Multicore -- | A multicore operation. data Multicore = SegOp String [Param] ParallelTask (Maybe ParallelTask) [Param] SchedulerInfo | ParLoop String MCCode [Param] | -- | A kernel of ISPC code, or a scoped block in regular C. ISPCKernel MCCode [Param] | -- | A foreach loop in ISPC, or a regular for loop in C. ForEach VName Exp Exp MCCode | -- | A foreach_active loop in ISPC, or a single execution in C. ForEachActive VName MCCode | -- | Extract a value from a given lane and assign it to a variable. -- This is just a regular assignment in C. ExtractLane VName Exp Exp | -- | Retrieve inclusive start and exclusive end indexes of the -- chunk we are supposed to be executing. Only valid immediately -- inside a 'ParLoop' construct! GetLoopBounds VName VName | -- | Retrieve the task ID that is currently executing. Only valid -- immediately inside a 'ParLoop' construct! GetTaskId VName | -- | Retrieve the number of subtasks to execute. Only valid -- immediately inside a 'SegOp' or 'ParLoop' construct! GetNumTasks VName | Atomic AtomicOp -- | Multicore code. type MCCode = Code Multicore -- | Atomic operations return the value stored before the update. -- This old value is stored in the first 'VName'. The second 'VName' -- is the memory block to update. The 'Exp' is the new value. 
data AtomicOp = AtomicAdd IntType VName VName (Count Elements (TExp Int32)) Exp | AtomicSub IntType VName VName (Count Elements (TExp Int32)) Exp | AtomicAnd IntType VName VName (Count Elements (TExp Int32)) Exp | AtomicOr IntType VName VName (Count Elements (TExp Int32)) Exp | AtomicXor IntType VName VName (Count Elements (TExp Int32)) Exp | AtomicXchg PrimType VName VName (Count Elements (TExp Int32)) Exp | AtomicCmpXchg PrimType VName VName (Count Elements (TExp Int32)) VName Exp deriving (Show) instance FreeIn AtomicOp where freeIn' (AtomicAdd _ _ arr i x) = freeIn' arr <> freeIn' i <> freeIn' x freeIn' (AtomicSub _ _ arr i x) = freeIn' arr <> freeIn' i <> freeIn' x freeIn' (AtomicAnd _ _ arr i x) = freeIn' arr <> freeIn' i <> freeIn' x freeIn' (AtomicOr _ _ arr i x) = freeIn' arr <> freeIn' i <> freeIn' x freeIn' (AtomicXor _ _ arr i x) = freeIn' arr <> freeIn' i <> freeIn' x freeIn' (AtomicCmpXchg _ _ arr i retval x) = freeIn' arr <> freeIn' i <> freeIn' x <> freeIn' retval freeIn' (AtomicXchg _ _ arr i x) = freeIn' arr <> freeIn' i <> freeIn' x -- | Information about parallel work that is do be done. This is -- passed to the scheduler to help it make scheduling decisions. data SchedulerInfo = SchedulerInfo { -- | The number of total iterations for a task. iterations :: Exp, -- | The type scheduling for the task. scheduling :: Scheduling } -- | A task for a v'SegOp'. 
newtype ParallelTask = ParallelTask MCCode -- | Whether the Scheduler should schedule the tasks as Dynamic -- or it is restainted to Static data Scheduling = Dynamic | Static instance Pretty Scheduling where pretty Dynamic = "Dynamic" pretty Static = "Static" instance Pretty SchedulerInfo where pretty (SchedulerInfo i sched) = stack [ nestedBlock "scheduling {" "}" (pretty sched), nestedBlock "iter {" "}" (pretty i) ] instance Pretty ParallelTask where pretty (ParallelTask code) = pretty code instance Pretty Multicore where pretty (GetLoopBounds start end) = pretty (start, end) <+> "<-" <+> "get_loop_bounds()" pretty (GetTaskId v) = pretty v <+> "<-" <+> "get_task_id()" pretty (GetNumTasks v) = pretty v <+> "<-" <+> "get_num_tasks()" pretty (SegOp s free seq_code par_code retval scheduler) = "SegOp" <+> pretty s <+> nestedBlock "{" "}" ppbody where ppbody = stack [ pretty scheduler, nestedBlock "free {" "}" (pretty free), nestedBlock "seq {" "}" (pretty seq_code), maybe mempty (nestedBlock "par {" "}" . 
pretty) par_code, nestedBlock "retvals {" "}" (pretty retval) ] pretty (ParLoop s body params) = "parloop" <+> pretty s nestedBlock "{" "}" ppbody where ppbody = stack [ nestedBlock "params {" "}" (pretty params), nestedBlock "body {" "}" (pretty body) ] pretty (Atomic _) = "AtomicOp" pretty (ISPCKernel body _) = "ispc" <+> nestedBlock "{" "}" (pretty body) pretty (ForEach i from to body) = "foreach" <+> pretty i <+> "=" <+> pretty from <+> "to" <+> pretty to <+> nestedBlock "{" "}" (pretty body) pretty (ForEachActive i body) = "foreach_active" <+> pretty i <+> nestedBlock "{" "}" (pretty body) pretty (ExtractLane dest tar lane) = pretty dest <+> "<-" <+> "extract" <+> parens (commasep $ map pretty [tar, lane]) instance FreeIn SchedulerInfo where freeIn' (SchedulerInfo iter _) = freeIn' iter instance FreeIn ParallelTask where freeIn' (ParallelTask code) = freeIn' code instance FreeIn Multicore where freeIn' (GetLoopBounds start end) = freeIn' (start, end) freeIn' (GetTaskId v) = freeIn' v freeIn' (GetNumTasks v) = freeIn' v freeIn' (SegOp _ _ par_code seq_code _ info) = freeIn' par_code <> freeIn' seq_code <> freeIn' info freeIn' (ParLoop _ body _) = freeIn' body freeIn' (Atomic aop) = freeIn' aop freeIn' (ISPCKernel body _) = freeIn' body freeIn' (ForEach i from to body) = fvBind (oneName i) (freeIn' body <> freeIn' from <> freeIn' to) freeIn' (ForEachActive i body) = fvBind (oneName i) (freeIn' body) freeIn' (ExtractLane dest tar lane) = freeIn' dest <> freeIn' tar <> freeIn' lane -- | Whether 'lexicalMemoryUsageMC' should look inside nested kernels -- or not. data KernelHandling = TraverseKernels | OpaqueKernels -- | Like @lexicalMemoryUsage@, but traverses some inner multicore ops. lexicalMemoryUsageMC :: KernelHandling -> Function Multicore -> M.Map VName Space lexicalMemoryUsageMC gokernel func = M.filterWithKey (const . 
(`notNameIn` nonlexical)) $ declared $ functionBody func where nonlexical = set (functionBody func) <> namesFromList (map paramName (functionOutput func)) go f (x :>>: y) = f x <> f y go f (If _ x y) = f x <> f y go f (For _ _ x) = f x go f (While _ x) = f x go f (Comment _ x) = f x go f (Op op) = goOp f op go _ _ = mempty -- We want SetMems and declarations to be visible through custom control flow -- so we don't erroneously treat a memblock that could be lexical as needing -- refcounting. Importantly, for ISPC, we do not look into kernels, since they -- go into new functions. For the Multicore backend, we can do it, though. goOp f (ForEach _ _ _ body) = go f body goOp f (ForEachActive _ body) = go f body goOp f (ISPCKernel body _) = case gokernel of TraverseKernels -> go f body OpaqueKernels -> mempty goOp _ _ = mempty declared (DeclareMem mem spc) = M.singleton mem spc declared x = go declared x set (SetMem x y _) = namesFromList [x, y] set (Call _ _ args) = foldMap onArg args where onArg ExpArg {} = mempty onArg (MemArg x) = oneName x -- Critically, don't treat inputs to nested segops as lexical when generating -- ISPC, since we want to use AoS memory for lexical blocks, which is -- incompatible with pointer assignmentes visible in C. set (Op (SegOp _ params _ _ retvals _)) = case gokernel of TraverseKernels -> mempty OpaqueKernels -> namesFromList $ map paramName params <> map paramName retvals set x = go set x futhark-0.25.27/src/Futhark/CodeGen/ImpCode/OpenCL.hs000066400000000000000000000065151475065116200220440ustar00rootroot00000000000000-- | Imperative code with an OpenCL component. -- -- Apart from ordinary imperative code, this also carries around an -- OpenCL program as a string, as well as a list of kernels defined by -- the OpenCL program. -- -- The imperative code has been augmented with a 'LaunchKernel' -- operation that allows one to execute an OpenCL kernel. 
module Futhark.CodeGen.ImpCode.OpenCL ( Program (..), KernelName, KernelArg (..), CLCode, OpenCL (..), KernelSafety (..), numFailureParams, KernelTarget (..), FailureMsg (..), BlockDim, KernelConst (..), KernelConstExp, module Futhark.CodeGen.ImpCode, module Futhark.IR.GPU.Sizes, ) where import Data.Map qualified as M import Data.Text qualified as T import Futhark.CodeGen.ImpCode import Futhark.CodeGen.ImpCode.GPU (BlockDim, KernelConst (..), KernelConstExp) import Futhark.IR.GPU.Sizes import Futhark.Util.Pretty -- | An program calling OpenCL kernels. data Program = Program { openClProgram :: T.Text, -- | Must be prepended to the program. openClPrelude :: T.Text, -- | Definitions to be passed as macro definitions to the kernel -- compiler. openClMacroDefs :: [(Name, KernelConstExp)], openClKernelNames :: M.Map KernelName KernelSafety, -- | So we can detect whether the device is capable. openClUsedTypes :: [PrimType], -- | Runtime-configurable constants. openClParams :: ParamMap, -- | Assertion failure error messages. openClFailures :: [FailureMsg], hostDefinitions :: Definitions OpenCL } -- | Something that can go wrong in a kernel. Part of the machinery -- for reporting error messages from within kernels. data FailureMsg = FailureMsg { failureError :: ErrorMsg Exp, failureBacktrace :: String } -- | A piece of code calling OpenCL. type CLCode = Code OpenCL -- | The name of a kernel. type KernelName = Name -- | An argument to be passed to a kernel. data KernelArg = -- | Pass the value of this scalar expression as argument. ValueKArg Exp PrimType | -- | Pass this pointer as argument. MemKArg VName deriving (Show) -- | Whether a kernel can potentially fail (because it contains bounds -- checks and such). data MayFail = MayFail | CannotFail deriving (Show) -- | Information about bounds checks and how sensitive it is to -- errors. Ordered by least demanding to most. data KernelSafety = -- | Does not need to know if we are in a failing state, and also -- cannot fail. 
SafetyNone | -- | Needs to be told if there's a global failure, and that's it, -- and cannot fail. SafetyCheap | -- | Needs all parameters, may fail itself. SafetyFull deriving (Eq, Ord, Show) -- | How many leading failure arguments we must pass when launching a -- kernel with these safety characteristics. numFailureParams :: KernelSafety -> Int numFailureParams SafetyNone = 0 numFailureParams SafetyCheap = 1 numFailureParams SafetyFull = 3 -- | Host-level OpenCL operation. data OpenCL = LaunchKernel KernelSafety KernelName (Count Bytes (TExp Int64)) [KernelArg] [Exp] [BlockDim] | GetSize VName Name | CmpSizeLe VName Name Exp | GetSizeMax VName SizeClass deriving (Show) -- | The target platform when compiling imperative code to a 'Program' data KernelTarget = TargetOpenCL | TargetCUDA | TargetHIP deriving (Eq) instance Pretty OpenCL where pretty = pretty . show futhark-0.25.27/src/Futhark/CodeGen/ImpCode/Sequential.hs000066400000000000000000000007201475065116200230260ustar00rootroot00000000000000-- | Sequential imperative code. module Futhark.CodeGen.ImpCode.Sequential ( Program, Sequential, module Futhark.CodeGen.ImpCode, ) where import Futhark.CodeGen.ImpCode import Futhark.Util.Pretty -- | An imperative program. type Program = Definitions Sequential -- | Phantom type for identifying sequential imperative code. 
data Sequential instance Pretty Sequential where pretty _ = mempty instance FreeIn Sequential where freeIn' _ = mempty futhark-0.25.27/src/Futhark/CodeGen/ImpGen.hs000066400000000000000000001751271475065116200205710ustar00rootroot00000000000000{-# LANGUAGE LambdaCase #-} {-# LANGUAGE Strict #-} {-# LANGUAGE TypeFamilies #-} module Futhark.CodeGen.ImpGen ( -- * Entry Points compileProg, -- * Pluggable Compiler OpCompiler, ExpCompiler, CopyCompiler, StmsCompiler, AllocCompiler, Operations (..), defaultOperations, MemLoc (..), sliceMemLoc, MemEntry (..), ScalarEntry (..), -- * Monadic Compiler Interface ImpM, localDefaultSpace, askFunction, newVNameForFun, nameForFun, askEnv, localEnv, localOps, VTable, getVTable, localVTable, subImpM, subImpM_, emit, emitFunction, hasFunction, collect, collect', VarEntry (..), ArrayEntry (..), -- * Lookups lookupVar, lookupArray, lookupArraySpace, lookupMemory, lookupAcc, askAttrs, -- * Building Blocks TV, MkTV (..), tvSize, tvExp, tvVar, ToExp (..), compileAlloc, everythingVolatile, compileBody, compileBody', compileLoopBody, defCompileStms, compileStms, compileExp, defCompileExp, fullyIndexArray, fullyIndexArray', copy, copyDWIM, copyDWIMFix, lmadCopy, typeSize, inBounds, caseMatch, -- * Constructing code. 
newVName, dLParams, dFParams, addLoopVar, dScope, dArray, dPrim, dPrimS, dPrimSV, dPrimVol, dPrim_, dPrimV_, dPrimV, dPrimVE, dIndexSpace, dIndexSpace', sFor, sWhile, sComment, sIf, sWhen, sUnless, sOp, sDeclareMem, sAlloc, sAlloc_, sArray, sArrayInMem, sAllocArray, sAllocArrayPerm, sStaticArray, sWrite, sUpdate, sLoopNest, sLoopSpace, (<--), (<~~), function, genConstants, warn, module Language.Futhark.Warnings, ) where import Control.Monad import Control.Monad.Reader import Control.Monad.State import Control.Parallel.Strategies import Data.Bifunctor (first) import Data.DList qualified as DL import Data.Either import Data.List (find) import Data.List.NonEmpty (NonEmpty (..)) import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S import Data.String import Data.Text qualified as T import Futhark.CodeGen.ImpCode ( Bytes, Count, Elements, elements, ) import Futhark.CodeGen.ImpCode qualified as Imp import Futhark.Construct hiding (ToExp (..)) import Futhark.IR.Mem import Futhark.IR.Mem.LMAD qualified as LMAD import Futhark.IR.SOACS (SOACS) import Futhark.Util import Futhark.Util.IntegralExp import Futhark.Util.Pretty hiding (nest, space) import Language.Futhark.Warnings import Prelude hiding (mod, quot) -- | How to compile an t'Op'. type OpCompiler rep r op = Pat (LetDec rep) -> Op rep -> ImpM rep r op () -- | How to compile some 'Stms'. type StmsCompiler rep r op = Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op () -- | How to compile an 'Exp'. type ExpCompiler rep r op = Pat (LetDec rep) -> Exp rep -> ImpM rep r op () type CopyCompiler rep r op = PrimType -> MemLoc -> MemLoc -> ImpM rep r op () -- | An alternate way of compiling an allocation. 
type AllocCompiler rep r op = VName -> Count Bytes (Imp.TExp Int64) -> ImpM rep r op ()

-- | The set of pluggable compilers used by 'ImpM'.  A value of this
-- type is passed to 'compileProg', 'subImpM', and 'localOps'.
data Operations rep r op = Operations
  { -- | How to compile expressions.
    opsExpCompiler :: ExpCompiler rep r op,
    -- | How to compile 'Op' expressions.
    opsOpCompiler :: OpCompiler rep r op,
    -- | How to compile sequences of statements.
    opsStmsCompiler :: StmsCompiler rep r op,
    -- | How to compile copies between memory locations.
    opsCopyCompiler :: CopyCompiler rep r op,
    -- | Alternative allocation compilers, keyed on memory space.
    opsAllocCompilers :: M.Map Space (AllocCompiler rep r op)
  }

-- | An operations set that uses the given 'Op' compiler and the
-- default compilers for everything else: 'defCompileExp' for
-- expressions, 'defCompileStms' for statements, 'lmadCopy' for
-- copies, and no alternative allocation compilers.
defaultOperations ::
  (Mem rep inner, FreeIn op) =>
  OpCompiler rep r op ->
  Operations rep r op
defaultOperations opc =
  Operations
    { opsExpCompiler = defCompileExp,
      opsOpCompiler = opc,
      opsStmsCompiler = defCompileStms,
      opsCopyCompiler = lmadCopy,
      opsAllocCompilers = mempty
    }

-- | When an array is declared, this is where it is stored.
data MemLoc = MemLoc
  { -- | The memory block holding the array.
    memLocName :: VName,
    -- | The shape of the array.
    memLocShape :: [Imp.DimSize],
    -- | The LMAD describing where the array elements are located
    -- within the memory block.
    memLocLMAD :: LMAD.LMAD (Imp.TExp Int64)
  }
  deriving (Eq, Show)

-- | Restrict a 'MemLoc' to a slice by slicing its LMAD.  The memory
-- block and the recorded shape are left unchanged.
sliceMemLoc :: MemLoc -> Slice (Imp.TExp Int64) -> MemLoc
sliceMemLoc (MemLoc mem shape lmad) slice =
  MemLoc mem shape $ LMAD.slice lmad slice

-- | Like 'sliceMemLoc', but for a flat slice.
flatSliceMemLoc :: MemLoc -> FlatSlice (Imp.TExp Int64) -> MemLoc
flatSliceMemLoc (MemLoc mem shape lmad) slice =
  MemLoc mem shape $ LMAD.flatSlice lmad slice

-- | Symbol table information for an array variable.
data ArrayEntry = ArrayEntry
  { entryArrayLoc :: MemLoc,
    entryArrayElemType :: PrimType
  }
  deriving (Show)

-- | The shape of an array, as recorded in its location.
entryArrayShape :: ArrayEntry -> [Imp.DimSize]
entryArrayShape = memLocShape . entryArrayLoc

-- | Symbol table information for a memory block variable.
newtype MemEntry = MemEntry {entryMemSpace :: Imp.Space}
  deriving (Show)

-- | Symbol table information for a scalar variable.
newtype ScalarEntry = ScalarEntry
  { entryScalarType :: PrimType
  }
  deriving (Show)

-- | The symbol table entry of a variable.  Every variable in the
-- symbol table is associated with one of these.  The 'Maybe' field
-- holds the expression that bound the variable, when it was recorded.
data VarEntry rep
  = ArrayVar (Maybe (Exp rep)) ArrayEntry
  | ScalarVar (Maybe (Exp rep)) ScalarEntry
  | MemVar (Maybe (Exp rep)) MemEntry
  | AccVar (Maybe (Exp rep)) (VName, Shape, [Type])
  deriving (Show)

-- | Where a value produced by compiled code is to be stored.
data ValueDestination
  = -- | Store in this scalar variable.
    ScalarDestination VName
  | -- | Store in this memory block variable.
    MemoryDestination VName
  | -- | The 'MemLoc' is 'Just' if a copy is
    -- required.  If it is 'Nothing', then a
    -- copy/assignment of a memory block somewhere
    -- takes care of this array.
    ArrayDestination (Maybe MemLoc)
  deriving (Show)

-- | The read-only environment of 'ImpM'.
data Env rep r op = Env
  { envExpCompiler :: ExpCompiler rep r op,
    envStmsCompiler :: StmsCompiler rep r op,
    envOpCompiler :: OpCompiler rep r op,
    envCopyCompiler :: CopyCompiler rep r op,
    envAllocCompilers :: M.Map Space (AllocCompiler rep r op),
    -- | Set by 'newEnv' and overridden by 'localDefaultSpace'.
    envDefaultSpace :: Imp.Space,
    -- | Starts as 'Imp.Nonvolatile'; set to 'Imp.Volatile' by
    -- 'everythingVolatile'.
    envVolatility :: Imp.Volatility,
    -- | User-extensible environment.
    envEnv :: r,
    -- | Name of the function we are compiling, if any.
    envFunction :: Maybe Name,
    -- | The set of attributes that are active on the enclosing
    -- statements (including the one we are currently compiling).
    envAttrs :: Attrs
  }

-- | Construct the initial environment from a user environment, a set
-- of operations, and a default memory space.  All compilers,
-- including the allocation compilers, are taken from the given
-- 'Operations'.  This matches what 'subImpM' and 'localOps' do.
newEnv :: r -> Operations rep r op -> Imp.Space -> Env rep r op
newEnv r ops ds =
  Env
    { envExpCompiler = opsExpCompiler ops,
      envStmsCompiler = opsStmsCompiler ops,
      envOpCompiler = opsOpCompiler ops,
      envCopyCompiler = opsCopyCompiler ops,
      -- Previously hard-coded to 'mempty', which silently discarded
      -- any allocation compilers supplied at the top level.
      envAllocCompilers = opsAllocCompilers ops,
      envDefaultSpace = ds,
      envVolatility = Imp.Nonvolatile,
      envEnv = r,
      envFunction = Nothing,
      envAttrs = mempty
    }

-- | The symbol table used during compilation.
type VTable rep = M.Map VName (VarEntry rep)

-- | The mutable state of 'ImpM'.
data ImpState rep r op = ImpState
  { stateVTable :: VTable rep,
    stateFunctions :: Imp.Functions op,
    -- | The code emitted so far (see 'emit' and 'collect').
    stateCode :: Imp.Code op,
    stateConstants :: Imp.Constants op,
    stateWarnings :: Warnings,
    -- | Maps each accumulator to the arrays backing it and, for a
    -- generalised reduction, its update function and neutral
    -- elements.  Entries are added when compiling 'WithAcc'.  This
    -- works because an array name can only become part of a single
    -- accumulator throughout its lifetime.  An accumulator whose
    -- operator is 'Nothing' is scatter-like.
    stateAccs :: M.Map VName ([VName], Maybe (Lambda rep, [SubExp])),
    stateNameSource :: VNameSource
  }

-- | An empty state that uses the given name source.
newState :: VNameSource -> ImpState rep r op
newState = ImpState mempty mempty mempty mempty mempty mempty

-- | The code generation monad: a reader over 'Env' on top of a state
-- monad over 'ImpState'.
newtype ImpM rep r op a
  = ImpM (ReaderT (Env rep r op) (State (ImpState rep r op)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState (ImpState rep r op),
      MonadReader (Env rep r op)
    )

instance MonadFreshNames (ImpM rep r op) where
  getNameSource = gets stateNameSource
  putNameSource src = modify $ \s -> s {stateNameSource = src}

-- This exposes the symbol table as a 'SOACS' scope, i.e. without
-- memory information.  It cannot be a memory-representation scope,
-- because the stored LMADs have the wrong leaves ('Imp.TExp' rather
-- than 'VName').
instance HasScope SOACS (ImpM rep r op) where
  askScope = gets $ M.map (LetName . entryType) . stateVTable
    where
      entryType (MemVar _ memEntry) =
        Mem (entryMemSpace memEntry)
      entryType (ArrayVar _ arrayEntry) =
        Array
          (entryArrayElemType arrayEntry)
          (Shape $ entryArrayShape arrayEntry)
          NoUniqueness
      entryType (ScalarVar _ scalarEntry) =
        Prim $ entryScalarType scalarEntry
      entryType (AccVar _ (acc, ispace, ts)) =
        Acc acc ispace ts NoUniqueness

-- | Run an 'ImpM' action in an environment built by 'newEnv', starting
-- from the given state.
runImpM ::
  ImpM rep r op a ->
  r ->
  Operations rep r op ->
  Imp.Space ->
  ImpState rep r op ->
  (a, ImpState rep r op)
runImpM (ImpM m) r ops space = runState (runReaderT m $ newEnv r ops space)

-- | Like 'subImpM', but discards the result of the action.
subImpM_ ::
  r' ->
  Operations rep r' op' ->
  ImpM rep r' op' a ->
  ImpM rep r op (Imp.Code op')
subImpM_ r ops m = snd <$> subImpM r ops m

-- | Run an action with a different user environment and set of
-- operations.  The action's result and the code it emitted are
-- returned; the code is not added to the current code.
--
-- The action sees the current symbol table and accumulator map.
-- Afterwards, only its name source and warnings are propagated back.
-- Its symbol table changes, functions, and constants are discarded.
subImpM ::
  r' ->
  Operations rep r' op' ->
  ImpM rep r' op' a ->
  ImpM rep r op (a, Imp.Code op')
subImpM r ops (ImpM m) = do
  env <- ask
  s <- get

  let env' =
        env
          { envExpCompiler = opsExpCompiler ops,
            envStmsCompiler = opsStmsCompiler ops,
            envCopyCompiler = opsCopyCompiler ops,
            envOpCompiler = opsOpCompiler ops,
            envAllocCompilers = opsAllocCompilers ops,
            envEnv = r
          }
      s' =
        ImpState
          { stateVTable = stateVTable s,
            stateFunctions = mempty,
            stateCode = mempty,
            stateNameSource = stateNameSource s,
            stateConstants = mempty,
            stateWarnings = mempty,
            stateAccs = stateAccs s
          }
      (x, s'') = runState (runReaderT m env') s'

  putNameSource $ stateNameSource s''
  warnings $ stateWarnings s''
  pure (x, stateCode s'')

-- | Execute a code generation action, returning the code that was
-- emitted.
collect :: ImpM rep r op () -> ImpM rep r op (Imp.Code op)
collect = fmap snd . collect'

-- | Like 'collect', but also returns the result of the action.  The
-- captured code is not added to the current code, which is restored
-- once the action is done.
collect' :: ImpM rep r op a -> ImpM rep r op (a, Imp.Code op)
collect' m = do
  prev_code <- gets stateCode
  modify $ \s -> s {stateCode = mempty}
  x <- m
  new_code <- gets stateCode
  modify $ \s -> s {stateCode = prev_code}
  pure (x, new_code)

-- | Emit some generated imperative code.
emit :: Imp.Code op -> ImpM rep r op ()
emit code = modify $ \s -> s {stateCode = stateCode s <> code}

-- | Record warnings.  They are prepended to the warnings accumulated
-- so far.
warnings :: Warnings -> ImpM rep r op ()
warnings ws = modify $ \s -> s {stateWarnings = ws <> stateWarnings s}

-- | Emit a warning about something the user should be aware of.
warn :: (Located loc) => loc -> [loc] -> T.Text -> ImpM rep r op ()
warn loc locs problem =
  warnings $ singleWarning' (locOf loc) (map locOf locs) (pretty problem)

-- | Emit a function in the generated code.  It is prepended to the
-- list of functions generated so far.
emitFunction :: Name -> Imp.Function op -> ImpM rep r op ()
emitFunction fname fun = do
  Imp.Functions fs <- gets stateFunctions
  modify $ \s -> s {stateFunctions = Imp.Functions $ (fname, fun) : fs}

-- | Check if a function of a given name exists.
hasFunction :: Name -> ImpM rep r op Bool
hasFunction fname = gets $ \s ->
  let Imp.Functions fs = stateFunctions s
   in isJust $ lookup fname fs

-- | Build a symbol table for top-level constants.  Each bound name is
-- recorded together with the expression that bound it.
constsVTable :: (Mem rep inner) => Stms rep -> VTable rep
constsVTable = foldMap stmVtable
  where
    stmVtable (Let pat _ e) =
      foldMap (peVtable e) $ patElems pat
    peVtable e (PatElem name dec) =
      M.singleton name $ memBoundToVarEntry (Just e) $ letDecMem dec

-- | Compile a whole program.
--
-- The functions are compiled in parallel ('parMap'), each starting
-- from the same name source and with the constants in the symbol
-- table.  Their states are then combined; duplicate function names
-- are collapsed via 'M.fromList', so the last occurrence wins.
-- Finally the constants are compiled, with the names free in the
-- generated functions passed as the names that must remain live.
compileProg ::
  (Mem rep inner, FreeIn op, MonadFreshNames m) =>
  r ->
  Operations rep r op ->
  Imp.Space ->
  Prog rep ->
  m (Warnings, Imp.Definitions op)
compileProg r ops space (Prog types consts funs) =
  modifyNameSource $ \src ->
    let (_, ss) =
          unzip $ parMap rpar (compileFunDef' src) funs
        free_in_funs =
          freeIn $ mconcat $ map stateFunctions ss
        ((), s') =
          runImpM (compileConsts free_in_funs consts) r ops space $
            combineStates ss
     in ( ( stateWarnings s',
            Imp.Definitions
              types
              (foldMap stateConstants ss <> stateConstants s')
              (stateFunctions s')
          ),
          stateNameSource s'
        )
  where
    compileFunDef' src fdef =
      runImpM
        (compileFunDef types fdef)
        r
        ops
        space
        (newState src) {stateVTable = constsVTable consts}

    -- Merge the per-function states: functions (deduplicated by
    -- name), name sources, and warnings.
    combineStates ss =
      let Imp.Functions funs' = mconcat $ map stateFunctions ss
          src = mconcat (map stateNameSource ss)
       in (newState src)
            { stateFunctions =
                Imp.Functions $ M.toList $ M.fromList funs',
              stateWarnings =
                mconcat $ map stateWarnings ss
            }

-- | Compile the top-level constants inside 'genConstants'.  The given
-- names are treated as live after the constant statements.
compileConsts :: Names -> Stms rep -> ImpM rep r op ()
compileConsts used_consts stms = genConstants $ do
  compileStms used_consts stms $ pure ()
  pure (used_consts, ())

-- | Look up an opaque type by name.  Crashes if it is unknown.
lookupOpaqueType :: Name -> OpaqueTypes -> OpaqueType
lookupOpaqueType v (OpaqueTypes types) =
  case lookup v types of
    Just t -> t
    Nothing -> error $ "Unknown opaque type: " ++ show v

-- | The signedness component of a 'ValueType'.
valueTypeSign :: ValueType -> Signedness
valueTypeSign (ValueType sign _ _) = sign

-- | The signedness of each value making up an entry point type.
-- Record-like opaque types are flattened by recursing into their
-- fields.
entryPointSignedness :: OpaqueTypes -> EntryPointType -> [Signedness]
entryPointSignedness _ (TypeTransparent vt) = [valueTypeSign vt]
entryPointSignedness types (TypeOpaque desc) =
  case lookupOpaqueType desc types of
    OpaqueType vts -> map valueTypeSign vts
    OpaqueArray _ _ vts -> map valueTypeSign vts
    OpaqueRecordArray _ _ fs -> foldMap (entryPointSignedness types . snd) fs
    OpaqueRecord fs -> foldMap (entryPointSignedness types . snd) fs
    OpaqueSum vts _ -> map valueTypeSign vts

-- | How many value parameters are accepted by this entry point?  This
-- is used to determine which of the function parameters correspond to
-- the parameters of the original function (they must all come at the
-- end).
entryPointSize :: OpaqueTypes -> EntryPointType -> Int
entryPointSize _ (TypeTransparent _) = 1
entryPointSize types (TypeOpaque desc) =
  case lookupOpaqueType desc types of
    OpaqueType vts -> length vts
    OpaqueArray _ _ vts -> length vts
    OpaqueRecordArray _ _ fs -> sum $ map (entryPointSize types . snd) fs
    OpaqueRecord fs -> sum $ map (entryPointSize types . snd) fs
    OpaqueSum vts _ -> length vts

-- | Compile a function parameter.  Scalars and memory blocks become
-- 'Imp.Param's ('Left').  Arrays become 'ArrayDecl's ('Right') and
-- produce no 'Imp.Param' of their own.
compileInParam ::
  (Mem rep inner) =>
  FParam rep ->
  ImpM rep r op (Either Imp.Param ArrayDecl)
compileInParam fparam = case paramDec fparam of
  MemPrim bt ->
    pure $ Left $ Imp.ScalarParam name bt
  MemMem space ->
    pure $ Left $ Imp.MemParam name space
  MemArray bt shape _ (ArrayIn mem lmad) ->
    pure $ Right $ ArrayDecl name bt $ MemLoc mem (shapeDims shape) lmad
  MemAcc {} ->
    error "Functions may not have accumulator parameters."
  where
    name = paramName fparam

-- | An array parameter: its name, element type, and location.
data ArrayDecl = ArrayDecl VName PrimType MemLoc

-- | Compile the parameters of a function.  If entry point parameter
-- information is given, also describe the external (entry point)
-- values.  For that, the value parameters are taken to be the last
-- ones, as counted by 'entryPointSize'.
compileInParams ::
  (Mem rep inner) =>
  OpaqueTypes ->
  [FParam rep] ->
  Maybe [EntryParam] ->
  ImpM rep r op ([Imp.Param], [ArrayDecl], Maybe [((Name, Uniqueness), Imp.ExternalValue)])
compileInParams types params eparams = do
  (inparams, arrayds) <- partitionEithers <$> mapM compileInParam params
  let findArray x = find (isArrayDecl x) arrayds

      -- Memory space of every memory block parameter.
      summaries = M.fromList $ mapMaybe memSummary params
        where
          memSummary param
            | MemMem space <- paramDec param =
                Just (paramName param, space)
            | otherwise =
                Nothing

      findMemInfo :: VName -> Maybe Space
      findMemInfo = flip M.lookup summaries

      mkValueDesc fparam signedness =
        case (findArray $ paramName fparam, paramType fparam) of
          (Just (ArrayDecl _ bt (MemLoc mem shape _)), _) -> do
            memspace <- findMemInfo mem
            Just $ Imp.ArrayValue mem memspace bt signedness shape
          (_, Prim bt) ->
            Just $ Imp.ScalarValue bt signedness $ paramName fparam
          _ ->
            Nothing

      -- An opaque entry parameter consumes 'entryPointSize' function
      -- parameters; a transparent one consumes exactly one.
      mkExts (EntryParam v u et@(TypeOpaque desc) : epts) fparams =
        let signs = entryPointSignedness types et
            n = entryPointSize types et
            (fparams', rest) = splitAt n fparams
         in ( (v, u),
              Imp.OpaqueValue
                desc
                (catMaybes $ zipWith mkValueDesc fparams' signs)
            )
              : mkExts epts rest
      mkExts (EntryParam v u (TypeTransparent (ValueType s _ _)) : epts) (fparam : fparams) =
        maybeToList (((v, u),) . Imp.TransparentValue <$> mkValueDesc fparam s)
          ++ mkExts epts fparams
      mkExts _ _ = []

  pure
    ( inparams,
      arrayds,
      case eparams of
        Just eparams' ->
          let num_val_params = sum (map (entryPointSize types . entryParamType) eparams')
              (_ctx_params, val_params) = splitAt (length params - num_val_params) params
           in Just $ mkExts eparams' val_params
        Nothing -> Nothing
    )
  where
    isArrayDecl x (ArrayDecl y _ _) = x == y

-- | Compile a function return.  Scalars and memory blocks get a fresh
-- output parameter.  Arrays get no parameter and an
-- 'ArrayDestination' without a location.
compileOutParam ::
  FunReturns -> ImpM rep r op (Maybe Imp.Param, ValueDestination)
compileOutParam (MemPrim t) = do
  name <- newVName "prim_out"
  pure (Just $ Imp.ScalarParam name t, ScalarDestination name)
compileOutParam (MemMem space) = do
  name <- newVName "mem_out"
  pure (Just $ Imp.MemParam name space, MemoryDestination name)
compileOutParam MemArray {} =
  pure (Nothing, ArrayDestination Nothing)
compileOutParam MemAcc {} =
  error "Functions may not return accumulators."

-- | Describe the external (entry point) values returned by a function.
-- The leading returns that are not counted by 'entryPointSize' are
-- treated as context.  @maybe_params@ holds the output parameters in
-- the same order as the returns.
compileExternalValues ::
  (Mem rep inner) =>
  OpaqueTypes ->
  [RetType rep] ->
  [EntryResult] ->
  [Maybe Imp.Param] ->
  ImpM rep r op [(Uniqueness, Imp.ExternalValue)]
compileExternalValues types orig_rts orig_epts maybe_params = do
  let (ctx_rts, val_rts) =
        splitAt
          (length orig_rts - sum (map (entryPointSize types . entryResultType) orig_epts))
          orig_rts

  let nthOut i = case maybeNth i maybe_params of
        Just (Just p) -> Imp.paramName p
        Just Nothing -> error $ "Output " ++ show i ++ " not a param."
        Nothing -> error $ "Param " ++ show i ++ " does not exist."

      mkValueDesc _ signedness (MemArray t shape _ ret) = do
        (mem, space) <-
          case ret of
            ReturnsNewBlock space j _lmad ->
              pure (nthOut j, space)
            ReturnsInBlock mem _lmad -> do
              space <- entryMemSpace <$> lookupMemory mem
              pure (mem, space)
        pure $ Imp.ArrayValue mem space t signedness $ map f $ shapeDims shape
        where
          -- Existential sizes refer to other output parameters.
          f (Free v) = v
          f (Ext i) = Var $ nthOut i
      mkValueDesc i signedness (MemPrim bt) =
        pure $ Imp.ScalarValue bt signedness $ nthOut i
      mkValueDesc _ _ MemAcc {} =
        error "mkValueDesc: unexpected MemAcc output."
      mkValueDesc _ _ MemMem {} =
        error "mkValueDesc: unexpected MemMem output."

      -- @i@ is the index of the current return among all returns
      -- (context included).
      mkExts i (EntryResult u et@(TypeOpaque desc) : epts) rets = do
        let signs = entryPointSignedness types et
            n = entryPointSize types et
            (rets', rest) = splitAt n rets
        vds <- forM (zip3 [i ..] signs rets') $ \(j, s, r) -> mkValueDesc j s r
        ((u, Imp.OpaqueValue desc vds) :) <$> mkExts (i + n) epts rest
      mkExts i (EntryResult u (TypeTransparent (ValueType s _ _)) : epts) (ret : rets) = do
        vd <- mkValueDesc i s ret
        ((u, Imp.TransparentValue vd) :) <$> mkExts (i + 1) epts rets
      mkExts _ _ _ = pure []

  mkExts (length ctx_rts) orig_epts val_rts

-- | Compile the returns of a function into output parameters and
-- destinations.  External value descriptions are produced only when
-- entry point results are given.
compileOutParams ::
  (Mem rep inner) =>
  OpaqueTypes ->
  [RetType rep] ->
  Maybe [EntryResult] ->
  ImpM rep r op (Maybe [(Uniqueness, Imp.ExternalValue)], [Imp.Param], [ValueDestination])
compileOutParams types orig_rts maybe_orig_epts = do
  (maybe_params, dests) <- mapAndUnzipM compileOutParam orig_rts
  evs <- case maybe_orig_epts of
    Just orig_epts ->
      Just <$> compileExternalValues types orig_rts orig_epts maybe_params
    Nothing -> pure Nothing
  pure (evs, catMaybes maybe_params, dests)

-- | Compile a function definition and emit it with 'emitFunction'.
-- Entry point information is attached only when the function is an
-- entry point.  While compiling, 'envFunction' is set to the entry
-- point name if there is one, and otherwise to the function name.
compileFunDef ::
  (Mem rep inner) =>
  OpaqueTypes ->
  FunDef rep ->
  ImpM rep r op ()
compileFunDef types (FunDef entry _ fname rettype params body) =
  local (\env -> env {envFunction = name_entry `mplus` Just fname}) $ do
    ((outparams, inparams, results, args), body') <- collect' compile
    let entry' = case (name_entry, results, args) of
          (Just name_entry', Just results', Just args') ->
            Just $ Imp.EntryPoint name_entry' results' args'
          _ ->
            Nothing
    emitFunction fname $ Imp.Function entry' outparams inparams body'
  where
    (name_entry, params_entry, ret_entry) = case entry of
      Nothing -> (Nothing, Nothing, Nothing)
      Just (x, y, z) -> (Just x, Just y, Just z)
    compile = do
      (inparams, arrayds, args) <- compileInParams types params params_entry
      (results, outparams, dests) <- compileOutParams types (map fst rettype) ret_entry
      addFParams params
      addArrays arrayds

      let Body _ stms ses = body
      compileStms (freeIn ses) stms $
        forM_ (zip dests ses) $ \(d, SubExpRes _ se) -> copyDWIMDest d [] se []

      pure (outparams, inparams, results, args)

-- | Compile a body, copying its results to the destinations of the
-- given pattern.
compileBody :: Pat (LetDec rep) -> Body rep -> ImpM rep r op ()
compileBody pat (Body _ stms ses) = do
  dests <- destinationFromPat pat
  compileStms (freeIn ses) stms $
    forM_ (zip dests ses) $ \(d, SubExpRes _ se) -> copyDWIMDest d [] se []

-- | Compile a body, copying its results to the given parameters.
compileBody' :: [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' params (Body _ stms ses) =
  compileStms (freeIn ses) stms $
    forM_ (zip params ses) $ \(param, SubExpRes _ se) -> copyDWIM (paramName param) [] se []

-- | Compile a loop body, assigning its results to the merge
-- parameters.  Only scalar and memory merge parameters are assigned
-- here.
compileLoopBody :: (Typed dec) => [Param dec] -> Body rep -> ImpM rep r op ()
compileLoopBody mergeparams (Body _ stms ses) = do
  -- We cannot write the results to the merge parameters immediately,
  -- as some of the results may actually *be* merge parameters, and
  -- would thus be clobbered.  Therefore, we first copy to new
  -- variables mirroring the merge parameters, and then copy this
  -- buffer to the merge parameters.  This is cheap, because only
  -- scalars and memory block references are copied; other merge
  -- parameters are not assigned here.
  tmpnames <- mapM (newVName . (++ "_tmp") . baseString . paramName) mergeparams
  compileStms (freeIn ses) stms $ do
    copy_to_merge_params <- forM (zip3 mergeparams tmpnames ses) $ \(p, tmp, SubExpRes _ se) ->
      case typeOf p of
        Prim pt -> do
          emit $ Imp.DeclareScalar tmp Imp.Nonvolatile pt
          emit $ Imp.SetScalar tmp $ toExp' pt se
          pure $ emit $ Imp.SetScalar (paramName p) $ Imp.var tmp pt
        Mem space | Var v <- se -> do
          emit $ Imp.DeclareMem tmp space
          emit $ Imp.SetMem tmp v space
          pure $ emit $ Imp.SetMem (paramName p) tmp space
        _ -> pure $ pure ()
    -- Only now, after all temporaries are set, overwrite the merge
    -- parameters.
    sequence_ copy_to_merge_params

-- | Compile statements with the statement compiler from the
-- environment.  The 'Names' are those live after the statements, and
-- the action is passed to that compiler to be compiled after the
-- statements.
compileStms :: Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms alive_after_stms all_stms m = do
  cb <- asks envStmsCompiler
  cb alive_after_stms all_stms m

-- | The default statement compiler.  It declares the variables bound
-- by each statement, compiles it, and finally runs the given action.
defCompileStms ::
  (Mem rep inner, FreeIn op) =>
  Names ->
  Stms rep ->
  ImpM rep r op () ->
  ImpM rep r op ()
defCompileStms alive_after_stms all_stms m =
  -- We keep track of any memory blocks produced by the statements,
  -- and after the last time that memory block is used, we insert a
  -- Free.  This is very conservative, but can cut down on lifetimes
  -- in some cases.
  void $ compileStms' mempty $ stmsToList all_stms
  where
    -- Returns the names used by the code generated from the given
    -- statements and everything after them.
    compileStms' allocs (Let pat aux e : bs) = do
      dVars (Just e) (patElems pat)

      e_code <-
        localAttrs (stmAuxAttrs aux) $
          collect $
            compileExp pat e
      (live_after, bs_code) <- collect' $ compileStms' (patternAllocs pat <> allocs) bs
      let dies_here v =
            (v `notNameIn` live_after) && (v `nameIn` freeIn e_code)
          to_free = S.filter (dies_here . fst) allocs

      emit e_code
      mapM_ (emit . uncurry Imp.Free) to_free
      emit bs_code

      pure $ freeIn e_code <> live_after
    compileStms' _ [] = do
      code <- collect m
      emit code
      pure $ freeIn code <> alive_after_stms

    patternAllocs = S.fromList . mapMaybe isMemPatElem . patElems
    isMemPatElem pe = case patElemType pe of
      Mem space -> Just (patElemName pe, space)
      _ -> Nothing

-- | Compile an expression with the expression compiler from the
-- environment.
compileExp :: Pat (LetDec rep) -> Exp rep -> ImpM rep r op ()
compileExp pat e = do
  ec <- asks envExpCompiler
  ec pat e

-- | Generate an expression that is true if the subexpressions match
-- the case pattern.  A 'Nothing' in the pattern matches anything.
caseMatch :: [SubExp] -> [Maybe PrimValue] -> Imp.TExp Bool
caseMatch ses vs = foldl (.&&.) true (zipWith cmp ses vs)
  where
    cmp se (Just (BoolValue True)) =
      isBool $ toExp' Bool se
    cmp se (Just v) =
      isBool $ toExp' (primValueType v) se ~==~ ValueExp v
    cmp _ Nothing = true

-- | The default expression compiler.
defCompileExp ::
  (Mem rep inner) =>
  Pat (LetDec rep) ->
  Exp rep ->
  ImpM rep r op ()
-- Cases are tested in order, with the default body as the final
-- alternative.
defCompileExp pat (Match ses cases defbody _) =
  foldr f (compileBody pat defbody) cases
  where
    f (Case vs body) = sIf (caseMatch ses vs) (compileBody pat body)
defCompileExp pat (Apply fname args _ _) = do
  dest <- destinationFromPat pat
  targets <- funcallTargets dest
  args' <- catMaybes <$> mapM compileArg args
  emit $ Imp.Call targets fname args'
  where
    -- Only scalar and memory arguments are passed explicitly.
    compileArg (se, _) = do
      t <- subExpType se
      case (se, t) of
        (_, Prim pt) -> pure $ Just $ Imp.ExpArg $ toExp' pt se
        (Var v, Mem {}) -> pure $ Just $ Imp.MemArg v
        _ -> pure Nothing
defCompileExp pat (BasicOp op) = defCompileBasicOp pat op
defCompileExp pat (Loop merge form body) = do
  attrs <- askAttrs
  when ("unroll" `inAttrs` attrs) $
    warn (noLoc :: SrcLoc) [] "#[unroll] on loop with unknown number of iterations." -- FIXME: no location.

  dFParams params
  -- Only rank-0 merge parameters get explicit initial values here.
  forM_ merge $ \(p, se) ->
    when ((== 0) $ arrayRank $ paramType p) $
      copyDWIM (paramName p) [] se []

  let doBody = compileLoopBody params body

  case form of
    ForLoop i _ bound -> do
      bound' <- toExp bound
      sFor' i bound' doBody
    WhileLoop cond ->
      sWhile (TPrimExp $ Imp.var cond Bool) doBody

  pat_dests <- destinationFromPat pat
  forM_ (zip pat_dests $ map (Var . paramName . fst) merge) $ \(d, r) ->
    copyDWIMDest d [] r []
  where
    params = map fst merge
defCompileExp pat (WithAcc inputs lam) = do
  dLParams $ lambdaParams lam
  -- Register the backing arrays and operator of each accumulator.
  forM_ (zip inputs $ lambdaParams lam) $ \((_, arrs, op), p) ->
    modify $ \s ->
      s {stateAccs = M.insert (paramName p) (arrs, op) $ stateAccs s}
  compileStms mempty (bodyStms $ lambdaBody lam) $ do
    let nonacc_res = drop num_accs (bodyResult (lambdaBody lam))
        nonacc_pat_names = takeLast (length nonacc_res) (patNames pat)
    forM_ (zip nonacc_pat_names nonacc_res) $ \(v, SubExpRes _ se) ->
      copyDWIM v [] se []
  where
    num_accs = length inputs
defCompileExp pat (Op op) = do
  opc <- asks envOpCompiler
  opc pat op

-- | Emit a trace print of a scalar value, prefixed by the given label.
tracePrim :: T.Text -> PrimType -> SubExp -> ImpM rep r op ()
tracePrim s t se =
  emit . Imp.TracePrint $
    ErrorMsg [ErrorString (s <> ": "), ErrorVal t (toExp' t se), ErrorString "\n"]

-- | Emit a trace print of every element of an array, prefixed by the
-- given label.
traceArray :: T.Text -> PrimType -> Shape -> SubExp -> ImpM rep r op ()
traceArray s t shape se = do
  emit . Imp.TracePrint $ ErrorMsg [ErrorString (s <> ": ")]
  sLoopNest shape $ \is -> do
    arr_elem <- dPrimS "arr_elem" t
    copyDWIMFix arr_elem [] se is
    emit . Imp.TracePrint $ ErrorMsg [ErrorVal t (toExp' t arr_elem), " "]
  emit . Imp.TracePrint $ ErrorMsg ["\n"]

-- | The default compiler for 'BasicOp's.  Crashes on patterns it does
-- not support.
defCompileBasicOp ::
  (Mem rep inner) =>
  Pat (LetDec rep) ->
  BasicOp ->
  ImpM rep r op ()
defCompileBasicOp (Pat [pe]) (SubExp se) =
  copyDWIM (patElemName pe) [] se []
defCompileBasicOp (Pat [pe]) (Opaque op se) = do
  copyDWIM (patElemName pe) [] se []
  case op of
    OpaqueNil -> pure ()
    OpaqueTrace s -> sComment ("Trace: " <> s) $ do
      se_t <- subExpType se
      case se_t of
        Prim t -> tracePrim s t se
        Array t shape _ -> traceArray s t shape se
        _ ->
          warn [mempty :: SrcLoc] mempty $
            s <> ": cannot trace value of this (core) type: " <> prettyText se_t
defCompileBasicOp (Pat [pe]) (UnOp op e) = do
  e' <- toExp e
  patElemName pe <~~ Imp.UnOpExp op e'
defCompileBasicOp (Pat [pe]) (ConvOp conv e) = do
  e' <- toExp e
  patElemName pe <~~ Imp.ConvOpExp conv e'
defCompileBasicOp (Pat [pe]) (BinOp bop x y) = do
  x' <- toExp x
  y' <- toExp y
  patElemName pe <~~ Imp.BinOpExp bop x' y'
defCompileBasicOp (Pat [pe]) (CmpOp bop x y) = do
  x' <- toExp x
  y' <- toExp y
  patElemName pe <~~ Imp.CmpOpExp bop x' y'
defCompileBasicOp _ (Assert e msg loc) = do
  e' <- toExp e
  msg' <- traverse toExp msg
  emit $ Imp.Assert e' msg' loc

  attrs <- askAttrs
  when (AttrComp "warn" ["safety_checks"] `inAttrs` attrs) $
    uncurry warn loc "Safety check required at run-time."
-- Only full indexing (to a single element) generates code here.
defCompileBasicOp (Pat [pe]) (Index src slice)
  | Just idxs <- sliceIndices slice =
      copyDWIM (patElemName pe) [] (Var src) $ map (DimFix . pe64) idxs
defCompileBasicOp _ Index {} =
  pure ()
defCompileBasicOp (Pat [pe]) (Update safety _ slice se) =
  case safety of
    Unsafe -> write
    Safe -> sWhen (inBounds slice' dims) write
  where
    slice' = fmap pe64 slice
    dims = map pe64 $ arrayDims $ patElemType pe
    write = sUpdate (patElemName pe) slice' se
defCompileBasicOp _ FlatIndex {} =
  pure ()
defCompileBasicOp (Pat [pe]) (FlatUpdate _ slice v) = do
  pe_loc <- entryArrayLoc <$> lookupArray (patElemName pe)
  v_loc <- entryArrayLoc <$> lookupArray v
  let pe_loc' = flatSliceMemLoc pe_loc $ fmap pe64 slice
  copy (elemType (patElemType pe)) pe_loc' v_loc
defCompileBasicOp (Pat [pe]) (Replicate shape se)
  | Acc {} <- patElemType pe = pure ()
  | shape == mempty =
      copyDWIM (patElemName pe) [] se []
  | otherwise =
      sLoopNest shape $ \is -> copyDWIMFix (patElemName pe) is se []
defCompileBasicOp _ Scratch {} =
  pure ()
-- Element i is e + i * s, computed with overflow-undefined arithmetic.
defCompileBasicOp (Pat [pe]) (Iota n e s it) = do
  e' <- toExp e
  s' <- toExp s
  sFor "i" (pe64 n) $ \i -> do
    let i' = sExt it $ untyped i
    x <-
      dPrimV "x" . TPrimExp $
        BinOpExp (Add it OverflowUndef) e' $
          BinOpExp (Mul it OverflowUndef) i' s'
    copyDWIMFix (patElemName pe) [i] (Var (tvVar x)) []
defCompileBasicOp (Pat [pe]) (Manifest _ src) =
  copyDWIM (patElemName pe) [] (Var src) []
-- Each operand is copied into the destination at a running offset
-- along dimension i.
defCompileBasicOp (Pat [pe]) (Concat i (x :| ys) _) = do
  offs_glb <- dPrimV "tmp_offs" 0

  forM_ (x : ys) $ \y -> do
    y_dims <- arrayDims <$> lookupType y
    let rows = case drop i y_dims of
          [] -> error $ "defCompileBasicOp Concat: empty array shape for " ++ prettyString y
          r : _ -> pe64 r
        skip_dims = take i y_dims
        sliceAllDim d = DimSlice 0 d 1
        skip_slices = map (sliceAllDim . pe64) skip_dims
        destslice = skip_slices ++ [DimSlice (tvExp offs_glb) rows 1]
    copyDWIM (patElemName pe) destslice (Var y) []
    offs_glb <-- tvExp offs_glb + rows
-- The values are emitted as a static array in 'DefaultSpace' and then
-- copied to the destination.
defCompileBasicOp (Pat [pe]) (ArrayVal vs t) = do
  dest_mem <- entryArrayLoc <$> lookupArray (patElemName pe)
  static_array <- newVNameForFun "static_array"
  emit $ Imp.DeclareArray static_array t $ Imp.ArrayValues vs
  let static_src =
        MemLoc static_array [intConst Int64 $ fromIntegral $ length vs] $
          LMAD.iota 0 [fromIntegral $ length vs]
  addVar static_array $ MemVar Nothing $ MemEntry DefaultSpace
  copy t dest_mem static_src
-- A non-empty literal of constants is compiled as an 'ArrayVal';
-- otherwise each element is written separately.
defCompileBasicOp (Pat [pe]) (ArrayLit es _)
  | Just vs@(v : _) <- mapM isLiteral es = do
      let t = primValueType v
      defCompileBasicOp (Pat [pe]) (ArrayVal vs t)
  | otherwise =
      forM_ (zip [0 ..] es) $ \(i, e) ->
        copyDWIMFix (patElemName pe) [fromInteger i] e []
  where
    isLiteral (Constant v) = Just v
    isLiteral _ = Nothing
defCompileBasicOp _ Rearrange {} =
  pure ()
defCompileBasicOp _ Reshape {} =
  pure ()
defCompileBasicOp _ (UpdateAcc safety acc is vs) = sComment "UpdateAcc" $ do
  -- We are abusing the comment mechanism to wrap the operator in
  -- braces when we end up generating code.  This is necessary because
  -- we might otherwise end up declaring lambda parameters (if any)
  -- multiple times, as they are duplicated every time we do an
  -- UpdateAcc for the same accumulator.
  let is' = map pe64 is

  -- We need to figure out whether we are updating a scatter-like
  -- accumulator or a generalised reduction.  This also binds the
  -- index parameters.
  (_, _, arrs, dims, op) <- lookupAcc acc is'

  let boundsCheck = case safety of
        Safe -> sWhen (inBounds (Slice (map DimFix is')) dims)
        _ -> id

  boundsCheck $
    case op of
      Nothing ->
        -- Scatter-like.
        forM_ (zip arrs vs) $ \(arr, v) -> copyDWIMFix arr is' v []
      Just lam -> do
        -- Generalised reduction.
        dLParams $ lambdaParams lam
        let (x_params, y_params) =
              splitAt (length vs) $ map paramName $ lambdaParams lam

        forM_ (zip x_params arrs) $ \(xp, arr) ->
          copyDWIMFix xp [] (Var arr) is'

        forM_ (zip y_params vs) $ \(yp, v) ->
          copyDWIM yp [] v []

        compileStms mempty (bodyStms $ lambdaBody lam) $
          forM_ (zip arrs (bodyResult (lambdaBody lam))) $ \(arr, SubExpRes _ se) ->
            copyDWIMFix arr is' se []
defCompileBasicOp pat e =
  error $
    "ImpGen.defCompileBasicOp: Invalid pattern\n "
      ++ prettyString pat
      ++ "\nfor expression\n "
      ++ prettyString e

-- | Note: a hack to be used only for functions.  Registers the arrays
-- in the symbol table without emitting any declarations.
addArrays :: [ArrayDecl] -> ImpM rep r op ()
addArrays = mapM_ addArray
  where
    addArray (ArrayDecl name bt location) =
      addVar name $
        ArrayVar
          Nothing
          ArrayEntry
            { entryArrayLoc = location,
              entryArrayElemType = bt
            }

-- | Like 'dFParams', but does not create new declarations.
-- Note: a hack to be used only for functions.
addFParams :: (Mem rep inner) => [FParam rep] -> ImpM rep r op ()
addFParams = mapM_ addFParam
  where
    addFParam fparam =
      addVar (paramName fparam) $
        memBoundToVarEntry Nothing $
          noUniquenessReturns $
            paramDec fparam

-- | Another hack: registers an integer loop variable in the symbol
-- table without emitting a declaration.
addLoopVar :: VName -> IntType -> ImpM rep r op ()
addLoopVar i it = addVar i $ ScalarVar Nothing $ ScalarEntry $ IntType it

-- | Declare the variables bound by pattern elements, recording the
-- binding expression if given.
dVars ::
  (Mem rep inner) =>
  Maybe (Exp rep) ->
  [PatElem (LetDec rep)] ->
  ImpM rep r op ()
dVars e = mapM_ dVar
  where
    dVar = dScope e . scopeOfPatElem

-- | Declare function parameters.
dFParams :: (Mem rep inner) => [FParam rep] -> ImpM rep r op ()
dFParams = dScope Nothing . scopeOfFParams

-- | Declare lambda parameters.
dLParams :: (Mem rep inner) => [LParam rep] -> ImpM rep r op ()
dLParams = dScope Nothing . scopeOfLParams

-- | Declare a fresh volatile scalar variable, initialised to the given
-- expression.
dPrimVol :: String -> PrimType -> Imp.TExp t -> ImpM rep r op (TV t)
dPrimVol name t e = do
  name' <- newVName name
  emit $ Imp.DeclareScalar name' Imp.Volatile t
  addVar name' $ ScalarVar Nothing $ ScalarEntry t
  name' <~~ untyped e
  pure $ TV name' t

-- | Declare a non-volatile scalar variable with the given name and
-- type, and add it to the symbol table.
dPrim_ :: VName -> PrimType -> ImpM rep r op ()
dPrim_ name t = do
  emit $ Imp.DeclareScalar name Imp.Nonvolatile t
  addVar name $ ScalarVar Nothing $ ScalarEntry t

-- | Create variable of some provided dynamic type.  You'll need this
-- when you are compiling program code of Haskell-level unknown type.
-- For other things, use other functions.
dPrimS :: String -> PrimType -> ImpM rep r op VName
dPrimS name t = do
  name' <- newVName name
  dPrim_ name' t
  pure name'

-- | Create 'TV' of some provided dynamic type.  No guarantee that the
-- dynamic type matches the inferred type.
dPrimSV :: String -> PrimType -> ImpM rep r op (TV t)
dPrimSV name t = TV <$> dPrimS name t <*> pure t

-- | Create 'TV' of some fixed type.
dPrim :: (MkTV t) => String -> ImpM rep r op (TV t)
dPrim name = do
  name' <- newVName name
  let tv = mkTV name'
  dPrim_ name' $ tvType tv
  pure tv

-- | Declare a scalar variable with the given name, initialised to the
-- given expression (whose type determines the variable's type).
dPrimV_ :: VName -> Imp.TExp t -> ImpM rep r op ()
dPrimV_ name e = do
  dPrim_ name t
  TV name t <-- e
  where
    t = primExpType $ untyped e

-- | Declare a fresh scalar variable initialised to the given
-- expression, and return it as a 'TV'.
dPrimV :: String -> Imp.TExp t -> ImpM rep r op (TV t)
dPrimV name e = do
  name' <- dPrimS name pt
  let tv = TV name' pt
  tv <-- e
  pure tv
  where
    pt = primExpType $ untyped e

-- | Like 'dPrimV', but returns the variable as an expression.
dPrimVE :: String -> Imp.TExp t -> ImpM rep r op (Imp.TExp t)
dPrimVE name e = do
  name' <- dPrimS name pt
  let tv = TV name' pt
  tv <-- e
  pure $ tvExp tv
  where
    pt = primExpType $ untyped e

-- | Convert memory information to a symbol table entry, recording the
-- binding expression if given.
memBoundToVarEntry ::
  Maybe (Exp rep) ->
  MemBound NoUniqueness ->
  VarEntry rep
memBoundToVarEntry e (MemPrim bt) =
  ScalarVar e ScalarEntry {entryScalarType = bt}
memBoundToVarEntry e (MemMem space) =
  MemVar e $ MemEntry space
memBoundToVarEntry e (MemAcc acc ispace ts _) =
  AccVar e (acc, ispace, ts)
memBoundToVarEntry e (MemArray bt shape _ (ArrayIn mem lmad)) =
  let location = MemLoc mem (shapeDims shape) lmad
   in ArrayVar
        e
        ArrayEntry
          { entryArrayLoc = location,
            entryArrayElemType = bt
          }

-- | Extract the memory information from a 'NameInfo'.
infoDec ::
  (Mem rep inner) =>
  NameInfo rep ->
  MemInfo SubExp NoUniqueness MemBind
infoDec (LetName dec) = letDecMem dec
infoDec (FParamName dec) = noUniquenessReturns dec
infoDec (LParamName dec) = dec
infoDec (IndexName it) = MemPrim $ IntType it

-- | Declare a variable and add it to the symbol table.  Memory blocks
-- and scalars get an explicit declaration; arrays and accumulators
-- are only added to the symbol table.
dInfo ::
  (Mem rep inner) =>
  Maybe (Exp rep) ->
  VName ->
  NameInfo rep ->
  ImpM rep r op ()
dInfo e name info = do
  let entry = memBoundToVarEntry e $ infoDec info
  case entry of
    MemVar _ entry' ->
      emit $ Imp.DeclareMem name $ entryMemSpace entry'
    ScalarVar _ entry' ->
      emit $ Imp.DeclareScalar name Imp.Nonvolatile $ entryScalarType entry'
    ArrayVar _ _ ->
      pure ()
    AccVar {} ->
      pure ()
  addVar name entry

-- | Declare every variable in a scope (see 'dInfo').
dScope ::
  (Mem rep inner) =>
  Maybe (Exp rep) ->
  Scope rep ->
  ImpM rep r op ()
dScope e = mapM_ (uncurry $ dInfo e) . M.toList

-- | Add an array located in the given memory block to the symbol
-- table.  No code is emitted.
dArray :: VName -> PrimType -> ShapeBase SubExp -> VName -> LMAD -> ImpM rep r op ()
dArray name pt shape mem lmad =
  addVar name $ ArrayVar Nothing $ ArrayEntry location pt
  where
    location = MemLoc mem (shapeDims shape) lmad

-- | Run an action with 'envVolatility' set to 'Imp.Volatile'.
everythingVolatile :: ImpM rep r op a -> ImpM rep r op a
everythingVolatile = local $ \env -> env {envVolatility = Imp.Volatile}

-- | The names that a function call writes to directly.  Array
-- destinations contribute no names.
funcallTargets :: [ValueDestination] -> ImpM rep r op [VName]
funcallTargets dests =
  concat <$> mapM funcallTarget dests
  where
    funcallTarget (ScalarDestination name) =
      pure [name]
    funcallTarget (ArrayDestination _) =
      pure []
    funcallTarget (MemoryDestination name) =
      pure [name]

-- | A typed variable, which we can turn into a typed expression, or
-- use as the target for an assignment.  This is used to aid in type
-- safety when doing code generation, by keeping the types straight.
-- It is still easy to cheat when you need to.
data TV t = TV VName PrimType

-- | A type class that helps ensure that the type annotation in a
-- 'TV' is correct.
class MkTV t where
  -- | Create a typed variable from a name; the dynamic type is
  -- determined by @t@.
  mkTV :: VName -> TV t

  -- | Extract type from a 'TV'.
  tvType :: TV t -> PrimType

instance MkTV Bool where
  mkTV v = TV v Bool
  tvType _ = Bool

instance MkTV Int8 where
  mkTV v = TV v (IntType Int8)
  tvType _ = IntType Int8

instance MkTV Int16 where
  mkTV v = TV v (IntType Int16)
  tvType _ = IntType Int16

instance MkTV Int32 where
  mkTV v = TV v (IntType Int32)
  tvType _ = IntType Int32

instance MkTV Int64 where
  mkTV v = TV v (IntType Int64)
  tvType _ = IntType Int64

instance MkTV Half where
  mkTV v = TV v (FloatType Float16)
  tvType _ = FloatType Float16

instance MkTV Float where
  mkTV v = TV v (FloatType Float32)
  tvType _ = FloatType Float32

instance MkTV Double where
  mkTV v = TV v (FloatType Float64)
  tvType _ = FloatType Float64

-- | Convert a typed variable to a size (a SubExp).
tvSize :: TV t -> Imp.DimSize
tvSize = Var . tvVar

-- | Convert a typed variable to a similarly typed expression.
tvExp :: TV t -> Imp.TExp t
tvExp (TV v t) = Imp.TPrimExp $ Imp.var v t

-- | Extract the underlying variable name from a typed variable.
tvVar :: TV t -> VName
tvVar (TV v _) = v

-- | Compile things to 'Imp.Exp'.
class ToExp a where
  -- | Compile to an 'Imp.Exp', where the type (which must still be a
  -- primitive) is deduced monadically.
  toExp :: a -> ImpM rep r op Imp.Exp

  -- | Compile where we know the type in advance.
  toExp' :: PrimType -> a -> Imp.Exp

-- Variables are looked up in the symbol table by 'toExp' and must be
-- scalars.
instance ToExp SubExp where
  toExp (Constant v) =
    pure $ Imp.ValueExp v
  toExp (Var v) =
    lookupVar v >>= \case
      ScalarVar _ (ScalarEntry pt) ->
        pure $ Imp.var v pt
      _ -> error $ "toExp SubExp: SubExp is not a primitive type: " ++ prettyString v

  toExp' _ (Constant v) = Imp.ValueExp v
  toExp' t (Var v) = Imp.var v t

instance ToExp VName where
  toExp = toExp . Var
  toExp' t = toExp' t . Var

instance ToExp (PrimExp VName) where
  toExp = pure
  toExp' _ = id

-- | Add (or replace) a symbol table entry.
addVar :: VName -> VarEntry rep -> ImpM rep r op ()
addVar name entry =
  modify $ \s -> s {stateVTable = M.insert name entry $ stateVTable s}

-- | Run an action with a different default memory space.
localDefaultSpace :: Imp.Space -> ImpM rep r op a -> ImpM rep r op a
localDefaultSpace space = local (\env -> env {envDefaultSpace = space})

-- | The name of the function being compiled, if any.
askFunction :: ImpM rep r op (Maybe Name)
askFunction = asks envFunction

-- | Generate a 'VName', prefixed with 'askFunction' if it exists.
newVNameForFun :: String -> ImpM rep r op VName
newVNameForFun s = do
  fname <- fmap nameToString <$> askFunction
  newVName $ maybe "" (++ ".") fname ++ s

-- | Generate a 'Name', prefixed with 'askFunction' if it exists.
nameForFun :: String -> ImpM rep r op Name
nameForFun s = do
  fname <- askFunction
  pure $ maybe "" (<> ".") fname <> nameFromString s

-- | The user-extensible environment.
askEnv :: ImpM rep r op r
askEnv = asks envEnv

-- | Run an action with a modified user-extensible environment.
localEnv :: (r -> r) -> ImpM rep r op a -> ImpM rep r op a
localEnv f = local $ \env -> env {envEnv = f $ envEnv env}

-- | The active attributes, including those for the statement
-- currently being compiled.
askAttrs :: ImpM rep r op Attrs
askAttrs = asks envAttrs

-- | Add more attributes to what is returned by 'askAttrs'.
localAttrs :: Attrs -> ImpM rep r op a -> ImpM rep r op a
localAttrs attrs = local $ \env -> env {envAttrs = attrs <> envAttrs env}

-- | Run an action with a different set of operations (all compilers,
-- including the allocation compilers).
localOps :: Operations rep r op -> ImpM rep r op a -> ImpM rep r op a
localOps ops = local $ \env ->
  env
    { envExpCompiler = opsExpCompiler ops,
      envStmsCompiler = opsStmsCompiler ops,
      envCopyCompiler = opsCopyCompiler ops,
      envOpCompiler = opsOpCompiler ops,
      envAllocCompilers = opsAllocCompilers ops
    }

-- | Get the current symbol table.
getVTable :: ImpM rep r op (VTable rep)
getVTable = gets stateVTable

-- | Replace the current symbol table.
putVTable :: VTable rep -> ImpM rep r op ()
putVTable vtable = modify $ \s -> s {stateVTable = vtable}

-- | Run an action with a modified symbol table.  All changes to the
-- symbol table will be reverted once the action is done!
localVTable :: (VTable rep -> VTable rep) -> ImpM rep r op a -> ImpM rep r op a
localVTable modify_table action = do
  saved_vtable <- getVTable
  putVTable (modify_table saved_vtable)
  result <- action
  result <$ putVTable saved_vtable

-- | Look up the symbol table entry of a variable, failing with an
-- internal error if the variable is not in the table.
lookupVar :: VName -> ImpM rep r op (VarEntry rep)
lookupVar name =
  gets (M.lookup name . stateVTable)
    >>= maybe (error $ "Unknown variable: " ++ prettyString name) pure

-- | Look up a variable that must be bound to an array.
lookupArray :: VName -> ImpM rep r op ArrayEntry
lookupArray name =
  lookupVar name >>= \case
    ArrayVar _ entry -> pure entry
    _ -> error $ "ImpGen.lookupArray: not an array: " ++ prettyString name

-- | Look up a variable that must be bound to a memory block.
lookupMemory :: VName -> ImpM rep r op MemEntry
lookupMemory name =
  lookupVar name >>= \case
    MemVar _ entry -> pure entry
    _ -> error $ "Unknown memory block: " ++ prettyString name

-- | In which memory space is this array allocated?
lookupArraySpace :: VName -> ImpM rep r op Space
lookupArraySpace arr = do
  arr_entry <- lookupArray arr
  entryMemSpace <$> lookupMemory (memLocName (entryArrayLoc arr_entry))

-- | Look up an accumulator.  Returns the name of the underlying
-- accumulator, the memory space of its first array, its arrays, its
-- shape as index expressions, and its update operator (if any).
--
-- In the case of a histogram-like accumulator, also sets the index
-- parameters: the leading operator parameters are bound to the given
-- indices, and the returned operator has those parameters removed.
lookupAcc ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Space, [VName], [Imp.TExp Int64], Maybe (Lambda rep))
lookupAcc name is =
  lookupVar name >>= \case
    AccVar _ (acc, ispace, _) ->
      gets (M.lookup acc . stateAccs) >>= \case
        Nothing ->
          error $ "ImpGen.lookupAcc: unlisted accumulator: " ++ prettyString name
        Just ([], _) ->
          error $ "Accumulator with no arrays: " ++ prettyString name
        Just (arrs@(arr : _), maybe_op) -> do
          space <- lookupArraySpace arr
          op' <- case maybe_op of
            Nothing -> pure Nothing
            Just (op, _) -> do
              let (i_params, ps) = splitAt (length is) $ lambdaParams op
              zipWithM_ dPrimV_ (map paramName i_params) is
              pure $ Just op {lambdaParams = ps}
          pure (acc, space, arrs, map pe64 (shapeDims ispace), op')
    _ -> error $ "ImpGen.lookupAcc: not an accumulator: " ++ prettyString name

-- | Compute a 'ValueDestination' for each element of a pattern, based on
-- what the element's name is bound to in the symbol table.  Array and
-- accumulator elements get an 'ArrayDestination' without a location.
destinationFromPat :: Pat (LetDec rep) -> ImpM rep r op [ValueDestination]
destinationFromPat = traverse destinationFor . patElems
  where
    destinationFor pe = do
      let name = patElemName pe
      lookupVar name >>= \case
        ArrayVar _ (ArrayEntry MemLoc {} _) -> pure $ ArrayDestination Nothing
        MemVar {} -> pure $ MemoryDestination name
        ScalarVar {} -> pure $ ScalarDestination name
        AccVar {} -> pure $ ArrayDestination Nothing

-- | Compute the memory block, memory space, and element offset of a
-- single element of the named array.  The indices are expected to cover
-- every dimension of the array.
fullyIndexArray ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray name indices =
  lookupArray name >>= \arr -> fullyIndexArray' (entryArrayLoc arr) indices

-- | Like 'fullyIndexArray', but for an explicit 'MemLoc'.  The offset
-- is obtained by indexing the location's LMAD with 'LMAD.index'.
fullyIndexArray' ::
  MemLoc ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray' (MemLoc mem _ lmad) indices = do
  mem_entry <- lookupMemory mem
  pure (mem, entryMemSpace mem_entry, elements $ LMAD.index lmad indices)

-- More complicated read/write operations that use index functions.
-- | Copy the elements described by the source 'MemLoc' to the
-- destination 'MemLoc' using the copy compiler from the environment.
-- Nothing is generated if both locations are in the same memory block
-- and their LMADs are statically equivalent; otherwise the copy is
-- guarded by a run-time check that skips it when the LMADs turn out to
-- be dynamically equal within the same memory block.
copy :: CopyCompiler rep r op
copy bt dst@(MemLoc dst_name _ dst_lmad) src@(MemLoc src_name _ src_lmad) = do
  let same_mem = dst_name == src_name
  -- If we can statically determine that the two index functions are
  -- equivalent, don't do anything.
  unless (same_mem && dst_lmad `LMAD.equivalent` src_lmad) $
    -- It is also possible that we can dynamically determine that the
    -- two index functions are equivalent.
    sUnless (fromBool same_mem .&&. LMAD.dynamicEqualsLMAD dst_lmad src_lmad) $ do
      -- If none of the above is true, actually do the copy.
      cc <- asks envCopyCompiler
      cc bt dst src

-- | A 'CopyCompiler' that emits a single 'Imp.Copy' over the shape of
-- the destination.  Each side is described by its memory block and
-- space, together with the offset and per-dimension strides of its LMAD,
-- measured in elements.
lmadCopy :: CopyCompiler rep r op
lmadCopy t dstloc srcloc = do
  let dstmem = memLocName dstloc
      srcmem = memLocName srcloc
      dstlmad = memLocLMAD dstloc
      srclmad = memLocLMAD srcloc
  srcspace <- entryMemSpace <$> lookupMemory srcmem
  dstspace <- entryMemSpace <$> lookupMemory dstmem
  emit $
    Imp.Copy
      t
      (elements <$> LMAD.shape dstlmad)
      (dstmem, dstspace)
      (offsetAndStrides dstlmad)
      (srcmem, srcspace)
      (offsetAndStrides srclmad)
  where
    -- The offset and the strides of an LMAD, in elements.
    offsetAndStrides lmad =
      let lmad' = elements <$> lmad
       in (LMAD.offset lmad', map LMAD.ldStride $ LMAD.dims lmad')

-- | Copy from here to there; both destination and source may be
-- indexed.
--
-- When both slices consist only of 'DimFix' indices and each fixes
-- every dimension of its location, a single element is copied through
-- a temporary scalar.  Otherwise both slices are passed through
-- 'fullSliceNum', their ranks must agree, and the sliced locations are
-- copied with 'copy'.  No code is produced when the two sliced
-- locations are identical.  The generated code is returned, not emitted.
copyArrayDWIM ::
  PrimType ->
  MemLoc ->
  [DimIndex (Imp.TExp Int64)] ->
  MemLoc ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op (Imp.Code op)
copyArrayDWIM bt destloc dest_slice srcloc src_slice
  | Just dest_is <- fixedIndices destloc dest_slice,
    Just src_is <- fixedIndices srcloc src_slice = do
      (dest_mem, dest_space, dest_offset) <- fullyIndexArray' destloc dest_is
      (src_mem, src_space, src_offset) <- fullyIndexArray' srcloc src_is
      vol <- asks envVolatility
      collect $ do
        tmp <- dPrimS "tmp" bt
        emit $ Imp.Read tmp src_mem src_offset bt src_space vol
        emit . Imp.Write dest_mem dest_offset bt dest_space vol $ Imp.var tmp bt
  | dest_rank /= src_rank =
      error $
        "copyArrayDWIM: cannot copy to "
          ++ prettyString (memLocName destloc)
          ++ " from "
          ++ prettyString (memLocName srcloc)
          ++ " because ranks do not match ("
          ++ prettyString dest_rank
          ++ " vs "
          ++ prettyString src_rank
          ++ ")"
  | destloc' == srcloc' =
      pure mempty -- The copy would be a no-op.
  | otherwise =
      collect $ copy bt destloc' srcloc'
  where
    -- The indices of a slice, provided that it is made up entirely of
    -- 'DimFix'es and fixes every dimension of the location.
    fixedIndices loc slice = do
      idxs <- mapM dimFix slice
      if length idxs == length (memLocShape loc) then Just idxs else Nothing

    dest_full = fullSliceNum (map pe64 (memLocShape destloc)) dest_slice
    src_full = fullSliceNum (map pe64 (memLocShape srcloc)) src_slice
    dest_rank = length $ sliceDims dest_full
    src_rank = length $ sliceDims src_full
    destloc' = sliceMemLoc destloc dest_full
    srcloc' = sliceMemLoc srcloc src_full

-- Like 'copyDWIM', but the target is a 'ValueDestination' instead of
-- a variable name.
--
-- A constant source must not be sliced.  It is either assigned to a
-- scalar destination or written at the (fixed) indices of an array
-- destination.  For a variable source:
--
-- * memory blocks may only be copied to memory destinations;
-- * scalar-to-scalar copies become assignments;
-- * single array elements are read or written when the slice fixes
--   every dimension;
-- * array-to-array copies are delegated to 'copyArrayDWIM'.
--
-- Once those checks pass, an 'ArrayDestination' without a location or
-- an accumulator source generates no code.  Anything else is an
-- internal error.
copyDWIMDest ::
  ValueDestination ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op ()
copyDWIMDest _ _ (Constant v) (_ : _) =
  error $
    unwords ["copyDWIMDest: constant source", prettyString v, "cannot be indexed."]
copyDWIMDest pat dest_slice (Constant v) [] =
  case mapM dimFix dest_slice of
    Nothing ->
      error $
        unwords ["copyDWIMDest: constant source", prettyString v, "with slice destination."]
    Just dest_is ->
      case pat of
        ScalarDestination name ->
          emit $ Imp.SetScalar name $ Imp.ValueExp v
        MemoryDestination {} ->
          error $
            unwords ["copyDWIMDest: constant source", prettyString v, "cannot be written to memory destination."]
        ArrayDestination (Just dest_loc) -> do
          (dest_mem, dest_space, dest_i) <- fullyIndexArray' dest_loc dest_is
          vol <- asks envVolatility
          emit $ Imp.Write dest_mem dest_i bt dest_space vol $ Imp.ValueExp v
        ArrayDestination Nothing ->
          error "copyDWIMDest: ArrayDestination Nothing"
  where
    bt = primValueType v
copyDWIMDest dest dest_slice (Var src) src_slice = do
  src_entry <- lookupVar src
  case (dest, src_entry) of
    (MemoryDestination mem, MemVar _ (MemEntry space)) ->
      emit $ Imp.SetMem mem src space
    (MemoryDestination {}, _) ->
      error $
        unwords ["copyDWIMDest: cannot write", prettyString src, "to memory destination."]
    (_, MemVar {}) ->
      error $
        unwords ["copyDWIMDest: source", prettyString src, "is a memory block."]
    (_, ScalarVar _ (ScalarEntry _))
      | not $ null src_slice ->
          error $
            unwords ["copyDWIMDest: prim-typed source", prettyString src, "with slice", prettyString src_slice]
    (ScalarDestination name, _)
      | not $ null dest_slice ->
          error $
            unwords ["copyDWIMDest: prim-typed target", prettyString name, "with slice", prettyString dest_slice]
    (ScalarDestination name, ScalarVar _ (ScalarEntry pt)) ->
      emit $ Imp.SetScalar name $ Imp.var src pt
    (ScalarDestination name, ArrayVar _ arr)
      | Just src_is <- mapM dimFix src_slice,
        length src_slice == length (entryArrayShape arr) -> do
          let bt = entryArrayElemType arr
          (mem, space, i) <- fullyIndexArray' (entryArrayLoc arr) src_is
          vol <- asks envVolatility
          emit $ Imp.Read name mem i bt space vol
      | otherwise ->
          error $
            unwords
              [ "copyDWIMDest: prim-typed target",
                prettyString name,
                "and array-typed source",
                prettyString src,
                "of shape",
                prettyString (entryArrayShape arr),
                "sliced with",
                prettyString src_slice
              ]
    (ArrayDestination (Just dest_loc), ArrayVar _ src_arr) -> do
      let src_loc = entryArrayLoc src_arr
          bt = entryArrayElemType src_arr
      emit =<< copyArrayDWIM bt dest_loc dest_slice src_loc src_slice
    (ArrayDestination (Just dest_loc), ScalarVar _ (ScalarEntry bt))
      | Just dest_is <- mapM dimFix dest_slice,
        length dest_is == length (memLocShape dest_loc) -> do
          (dest_mem, dest_space, dest_i) <- fullyIndexArray' dest_loc dest_is
          vol <- asks envVolatility
          emit $ Imp.Write dest_mem dest_i bt dest_space vol (Imp.var src bt)
      | otherwise ->
          error $
            unwords
              [ "copyDWIMDest: array-typed target and prim-typed source",
                prettyString src,
                "with slice",
                prettyString dest_slice
              ]
    (ArrayDestination Nothing, _) ->
      -- Nothing to do; something else set some memory somewhere.
      pure ()
    (_, AccVar {}) ->
      -- Nothing to do; accumulators are phantoms.
      pure ()

-- | Copy from here to there; both destination and source may be
-- indexed.  If so, they better be arrays of enough dimensions.
-- This function will generally just Do What I Mean, and Do The Right
-- Thing.  Both destination and source must be in scope.
copyDWIM ::
  VName ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op ()
copyDWIM dest dest_slice src src_slice = do
  dest_entry <- lookupVar dest
  let dest_target =
        case dest_entry of
          ScalarVar _ _ -> ScalarDestination dest
          ArrayVar _ (ArrayEntry (MemLoc mem shape lmad) _) ->
            ArrayDestination $ Just $ MemLoc mem shape lmad
          MemVar _ _ -> MemoryDestination dest
          AccVar {} ->
            -- Does not matter; accumulators are phantoms.
            ArrayDestination Nothing
  copyDWIMDest dest_target dest_slice src src_slice

-- | As 'copyDWIM', but implicitly 'DimFix'es the indexes.
copyDWIMFix ::
  VName ->
  [Imp.TExp Int64] ->
  SubExp ->
  [Imp.TExp Int64] ->
  ImpM rep r op ()
copyDWIMFix dest dest_is src src_is =
  copyDWIM dest (map DimFix dest_is) src (map DimFix src_is)

-- | @compileAlloc pat size space@ allocates @size@ bytes of memory in
-- @space@, writing the result to @pat@, which must contain a single
-- memory-typed element.  If an allocation compiler is registered for
-- @space@ in the environment it is used; otherwise a plain
-- 'Imp.Allocate' is emitted.
compileAlloc ::
  (Mem rep inner) => Pat (LetDec rep) -> SubExp -> Space -> ImpM rep r op ()
compileAlloc (Pat [mem]) e space = do
  let e' = Imp.bytes $ pe64 e
  allocator <- asks $ M.lookup space . envAllocCompilers
  case allocator of
    Nothing -> emit $ Imp.Allocate (patElemName mem) e' space
    Just allocator' -> allocator' (patElemName mem) e'
compileAlloc pat _ _ =
  error $ "compileAlloc: Invalid pattern: " ++ prettyString pat

-- | The number of bytes needed to represent the array in a
-- straightforward contiguous format, as an t'Int64' expression.
typeSize :: Type -> Count Bytes (Imp.TExp Int64)
typeSize t =
  Imp.bytes $ primByteSize (elemType t) * product (map pe64 (arrayDims t))

-- | Is this indexing in-bounds for an array of the given shape? This
-- is useful for things like scatter, which ignores out-of-bounds
-- writes.
--
-- A 'DimFix' index must lie in @[0, d)@.  For a 'DimSlice', the start
-- must be non-negative and the last index touched must be below @d@.
-- Slice entries and dimensions are paired with 'zipWith', so surplus
-- entries in either list are ignored.  A slice with no dimensions is
-- trivially in bounds.
inBounds :: Slice (Imp.TExp Int64) -> [Imp.TExp Int64] -> Imp.TExp Bool
inBounds (Slice slice) dims =
  let condInBounds (DimFix i) d =
        0 .<=. i .&&. i .<. d
      condInBounds (DimSlice i n s) d =
        0 .<=. i .&&. i + (n - 1) * s .<. d
   in -- Fold from 'true' rather than using 'foldl1', which would crash
      -- on an empty (zero-dimensional) slice.
      foldl (.&&.) true $ zipWith condInBounds slice dims

--- Building blocks for constructing code.
-- | Emit an 'Imp.For' loop with loop variable @i@, bounded by @bound@.
-- The loop variable gets the integer type of @bound@.  A non-integer
-- bound is reported as an internal error.
sFor' :: VName -> Imp.Exp -> ImpM rep r op () -> ImpM rep r op ()
sFor' i bound body = do
  addLoopVar i loop_var_type
  emit . Imp.For i bound =<< collect body
  where
    loop_var_type =
      case primExpType bound of
        IntType int_type -> int_type
        t -> error $ "sFor': bound " ++ prettyString bound ++ " is of type " ++ prettyString t

-- | Like 'sFor'', but with a typed bound.  A fresh loop variable is
-- created from the given name, and the body receives it as a typed
-- expression.
sFor :: String -> Imp.TExp t -> (Imp.TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor desc bound body = do
  i <- newVName desc
  let bound' = untyped bound
      i_exp = TPrimExp $ Imp.var i $ primExpType bound'
  sFor' i bound' $ body i_exp

-- | Emit an 'Imp.While' loop with the given condition and body.
sWhile :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhile cond body = emit . Imp.While cond =<< collect body

-- | Execute a code generation action, wrapping the generated code
-- within a 'Imp.Comment' with the given description.
sComment :: T.Text -> ImpM rep r op () -> ImpM rep r op ()
sComment desc action = emit . Imp.Comment desc =<< collect action

-- | Emit a conditional.  Both branches are always generated.  If the
-- condition is literally 'true' or 'false', only the corresponding
-- branch is emitted, without an 'Imp.If'.
sIf :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf cond tbranch fbranch = do
  tcode <- collect tbranch
  fcode <- collect fbranch
  emit $ pickCode tcode fcode
  where
    -- Avoid generating a branch if the condition is known statically.
    pickCode yes no
      | cond == true = yes
      | cond == false = no
      | otherwise = Imp.If cond yes no

-- | Generate code that runs the action only when the condition holds.
sWhen :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen cond tbranch = sIf cond tbranch $ pure ()

-- | Generate code that runs the action only when the condition does not
-- hold.
sUnless :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless cond fbranch = sIf cond (pure ()) fbranch

-- | Emit an 'Imp.Op' wrapping the given operation.
sOp :: op -> ImpM rep r op ()
sOp op = emit $ Imp.Op op

-- | Declare a memory block in the given space, with a fresh name derived
-- from the given one, and add it to the symbol table.
sDeclareMem :: String -> Space -> ImpM rep r op VName
sDeclareMem desc space = do
  mem <- newVName desc
  emit $ Imp.DeclareMem mem space
  addVar mem . MemVar Nothing $ MemEntry space
  pure mem

-- | Allocate memory for an already-declared memory block.  The
-- allocation compiler registered for the space is used if there is one;
-- otherwise a plain 'Imp.Allocate' is emitted.
sAlloc_ :: VName -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM rep r op ()
sAlloc_ name' size' space =
  asks (M.lookup space . envAllocCompilers) >>= \case
    Just allocator -> allocator name' size'
    Nothing -> emit $ Imp.Allocate name' size' space

-- | Declare a fresh memory block and allocate the given number of bytes
-- for it.
sAlloc :: String -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM rep r op VName
sAlloc desc size space = do
  mem <- sDeclareMem desc space
  mem <$ sAlloc_ mem size space

-- | Bind a fresh array name to the given memory block, laid out
-- according to the given LMAD.  Only the symbol table is updated.
sArray :: String -> PrimType -> ShapeBase SubExp -> VName -> LMAD -> ImpM rep r op VName
sArray desc bt shape mem lmad = do
  arr <- newVName desc
  arr <$ dArray arr bt shape mem lmad

-- | Declare an array in row-major order in the given memory block.
sArrayInMem :: String -> PrimType -> ShapeBase SubExp -> VName -> ImpM rep r op VName
sArrayInMem desc pt shape mem =
  sArray desc pt shape mem . LMAD.iota 0 $
    map (isInt64 . primExpFromSubExp int64) (shapeDims shape)

-- | Like 'sAllocArray', but permute the in-memory representation of the
-- indices as specified.
sAllocArrayPerm :: String -> PrimType -> ShapeBase SubExp -> Space -> [Int] -> ImpM rep r op VName
sAllocArrayPerm desc pt shape space perm = do
  mem <- sAlloc (desc ++ "_mem") (typeSize (Array pt shape NoUniqueness)) space
  let mem_dims =
        map (isInt64 . primExpFromSubExp int64) . rearrangeShape perm $ shapeDims shape
  sArray desc pt shape mem $
    LMAD.permute (LMAD.iota 0 mem_dims) (rearrangeInverse perm)

-- | Allocate a fresh array in the given space.  Uses linear/iota index
-- function.
sAllocArray :: String -> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray desc pt shape space =
  sAllocArrayPerm desc pt shape space $ take (shapeRank shape) [0 ..]

-- | Uses linear/iota index function.
-- The contents are declared with 'Imp.DeclareArray' in the default
-- memory space, under a name prefixed by the current function (if any).
sStaticArray :: String -> PrimType -> Imp.ArrayContents -> ImpM rep r op VName
sStaticArray desc pt vs = do
  mem <- newVNameForFun $ desc ++ "_mem"
  emit $ Imp.DeclareArray mem pt vs
  addVar mem . MemVar Nothing $ MemEntry DefaultSpace
  sArray desc pt (Shape [intConst Int64 $ toInteger len]) mem $
    LMAD.iota 0 [fromIntegral len]
  where
    -- The number of elements in the static array.
    len = case vs of
      Imp.ArrayValues vals -> length vals
      Imp.ArrayZeros n -> fromIntegral n

-- | Write a scalar expression to a single element of an array.  The
-- element type is taken from the expression.
sWrite :: VName -> [Imp.TExp Int64] -> Imp.Exp -> ImpM rep r op ()
sWrite arr is v = do
  (mem, space, elem_offset) <- fullyIndexArray arr is
  volatility <- asks envVolatility
  emit $ Imp.Write mem elem_offset (primExpType v) space volatility v

-- | Copy a 'SubExp' into the given slice of an array.
sUpdate :: VName -> Slice (Imp.TExp Int64) -> SubExp -> ImpM rep r op ()
sUpdate arr (Slice dim_indices) v = copyDWIM arr dim_indices v []

-- | Create a sequential 'Imp.For' loop covering a space of the given
-- shape.  The function is called with the indexes for a given
-- iteration, outermost dimension first.
sLoopSpace :: [Imp.TExp t] -> ([Imp.TExp t] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopSpace dims f = go [] dims
  where
    go acc [] = f $ reverse acc
    go acc (d : ds) = sFor "nest_i" d $ \i -> go (i : acc) ds

-- | Like 'sLoopSpace', but for the dimensions of a 'Shape'.
sLoopNest :: Shape -> ([Imp.TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest shape = sLoopSpace $ map pe64 $ shapeDims shape

-- | Untyped assignment.
(<~~) :: VName -> Imp.Exp -> ImpM rep r op ()
(<~~) x = emit . Imp.SetScalar x

infixl 3 <~~

-- | Typed assignment.
(<--) :: TV t -> Imp.TExp t -> ImpM rep r op ()
TV x _ <-- e = x <~~ untyped e

infixl 3 <--

-- | Constructing an ad-hoc function that does not
-- correspond to any of the IR functions in the input program.
function :: Name -> [Imp.Param] -> [Imp.Param] -> ImpM rep r op () -> ImpM rep r op () function fname outputs inputs m = local newFunction $ do body <- collect $ do mapM_ addParam $ outputs ++ inputs m emitFunction fname $ Imp.Function Nothing outputs inputs body where addParam (Imp.MemParam name space) = addVar name $ MemVar Nothing $ MemEntry space addParam (Imp.ScalarParam name bt) = addVar name $ ScalarVar Nothing $ ScalarEntry bt newFunction env = env {envFunction = Just fname} -- Fish out those top-level declarations in the constant -- initialisation code that are free in the functions. constParams :: Names -> Imp.Code a -> (DL.DList Imp.Param, Imp.Code a) constParams used (x Imp.:>>: y) = constParams used x <> constParams used y constParams used (Imp.DeclareMem name space) | name `nameIn` used = ( DL.singleton $ Imp.MemParam name space, mempty ) constParams used (Imp.DeclareScalar name _ t) | name `nameIn` used = ( DL.singleton $ Imp.ScalarParam name t, mempty ) constParams used s@(Imp.DeclareArray name _ _) | name `nameIn` used = ( DL.singleton $ Imp.MemParam name DefaultSpace, s ) constParams _ s = (mempty, s) -- | Generate constants that get put outside of all functions. Will -- be executed at program startup. Action must return the names that -- should should be made available. This one has real sharp edges. Do -- not use inside 'subImpM'. Do not use any variable from the context. genConstants :: ImpM rep r op (Names, a) -> ImpM rep r op a genConstants m = do ((avail, a), code) <- collect' m let consts = uncurry Imp.Constants $ first DL.toList $ constParams avail code modify $ \s -> s {stateConstants = stateConstants s <> consts} pure a dSlices :: [Imp.TExp Int64] -> ImpM rep r op [Imp.TExp Int64] dSlices = fmap (drop 1 . snd) . 
dSlices' where dSlices' [] = pure (1, [1]) dSlices' (n : ns) = do (prod, ns') <- dSlices' ns n' <- dPrimVE "slice" $ n * prod pure (n', n' : ns') -- | @dIndexSpace f dims i@ computes a list of indices into an -- array with dimension @dims@ given the flat index @i@. The -- resulting list will have the same size as @dims@. Intermediate -- results are passed to @f@. dIndexSpace :: [(VName, Imp.TExp Int64)] -> Imp.TExp Int64 -> ImpM rep r op () dIndexSpace vs_ds j = do slices <- dSlices (map snd vs_ds) loop (zip (map fst vs_ds) slices) j where loop ((v, size) : rest) i = do dPrimV_ v (i `quot` size) i' <- dPrimVE "remnant" $ i - Imp.le64 v * size loop rest i' loop _ _ = pure () -- | Like 'dIndexSpace', but invent some new names for the indexes -- based on the given template. dIndexSpace' :: String -> [Imp.TExp Int64] -> Imp.TExp Int64 -> ImpM rep r op [Imp.TExp Int64] dIndexSpace' desc ds j = do ivs <- replicateM (length ds) (newVName desc) dIndexSpace (zip ivs ds) j pure $ map Imp.le64 ivs futhark-0.25.27/src/Futhark/CodeGen/ImpGen/000077500000000000000000000000001475065116200202205ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/CodeGen/ImpGen/CUDA.hs000066400000000000000000000010171475065116200212670ustar00rootroot00000000000000-- | Code generation for ImpCode with CUDA kernels. module Futhark.CodeGen.ImpGen.CUDA ( compileProg, Warnings, ) where import Data.Bifunctor (second) import Futhark.CodeGen.ImpCode.OpenCL import Futhark.CodeGen.ImpGen.GPU import Futhark.CodeGen.ImpGen.GPU.ToOpenCL import Futhark.IR.GPUMem import Futhark.MonadFreshNames -- | Compile the program to ImpCode with CUDA kernels. compileProg :: (MonadFreshNames m) => Prog GPUMem -> m (Warnings, Program) compileProg prog = second kernelsToCUDA <$> compileProgCUDA prog futhark-0.25.27/src/Futhark/CodeGen/ImpGen/GPU.hs000066400000000000000000000233411475065116200212120ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | Compile a 'GPUMem' program to imperative code with kernels. 
-- This is mostly (but not entirely) the same process no matter if we
-- are targeting OpenCL, CUDA, or HIP.  The important distinctions (the
-- host level code) are introduced later.
module Futhark.CodeGen.ImpGen.GPU
  ( compileProgOpenCL,
    compileProgCUDA,
    compileProgHIP,
    Warnings,
  )
where

import Control.Monad
import Data.List qualified as L
import Data.Map qualified as M
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen hiding (compileProg)
import Futhark.CodeGen.ImpGen qualified
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.SegHist
import Futhark.CodeGen.ImpGen.GPU.SegMap
import Futhark.CodeGen.ImpGen.GPU.SegRed
import Futhark.CodeGen.ImpGen.GPU.SegScan
import Futhark.Error
import Futhark.IR.GPUMem
import Futhark.MonadFreshNames
import Futhark.Util.IntegralExp (divUp, nextMul)
import Prelude hiding (quot, rem)

-- | Host-level code generation operations.  They use the GPU-specific
-- expression and operation compilers, 'lmadCopy' for copies, the default
-- statement compiler, and no custom allocation compilers.
callKernelOperations :: Operations GPUMem HostEnv Imp.HostOp
callKernelOperations =
  Operations
    { opsAllocCompilers = mempty,
      opsStmsCompiler = defCompileStms,
      opsOpCompiler = opCompiler,
      opsCopyCompiler = lmadCopy,
      opsExpCompiler = expCompiler
    }

-- | The atomic operations available for each target, keyed on the
-- binary operator they implement.  Both tables contain integer add,
-- floating-point add, signed and unsigned min/max, and bitwise
-- and/or/xor, for 32-bit and 64-bit types.
openclAtomics, cudaAtomics :: AtomicBinOp
(openclAtomics, cudaAtomics) = (lookupIn opencl, lookupIn cuda)
  where
    lookupIn table op = lookup op table
    atomicsFor it ft =
      [ (Add it OverflowUndef, Imp.AtomicAdd it),
        (FAdd ft, Imp.AtomicFAdd ft),
        (SMax it, Imp.AtomicSMax it),
        (SMin it, Imp.AtomicSMin it),
        (UMax it, Imp.AtomicUMax it),
        (UMin it, Imp.AtomicUMin it),
        (And it, Imp.AtomicAnd it),
        (Or it, Imp.AtomicOr it),
        (Xor it, Imp.AtomicXor it)
      ]
    opencl = atomicsFor Int32 Float32 ++ atomicsFor Int64 Float64
    -- CUDA supports the same operations as OpenCL.
    cuda = opencl
-- Compile a 'GPUMem' program with the given host environment, using the
-- GPU-specific code generation operations and @Imp.Space "device"@ as
-- the space argument of the generic compiler.
compileProg :: (MonadFreshNames m) => HostEnv -> Prog GPUMem -> m (Warnings, Imp.Program)
compileProg env =
  Futhark.CodeGen.ImpGen.compileProg env callKernelOperations (Imp.Space "device")

-- | Compile a 'GPUMem' program to low-level parallel code, with OpenCL,
-- CUDA, or HIP characteristics respectively.  The HIP variant uses the
-- same atomics table as the CUDA one.
compileProgOpenCL,
  compileProgCUDA,
  compileProgHIP ::
    (MonadFreshNames m) => Prog GPUMem -> m (Warnings, Imp.Program)
compileProgOpenCL = compileProg (HostEnv openclAtomics OpenCL mempty)
compileProgCUDA = compileProg (HostEnv cudaAtomics CUDA mempty)
compileProgHIP = compileProg (HostEnv cudaAtomics HIP mempty)

-- | Compile a GPU-specific operation bound to the given pattern: memory
-- allocations, size queries and computations, segmented operations, and
-- 'GPUBody'.  A 'GPUBody' is launched via 'sKernelThread' with a block
-- count and block size of one.
opCompiler :: Pat LetDecMem -> Op GPUMem -> CallKernelGen ()
opCompiler dest (Alloc e space) =
  compileAlloc dest e space
opCompiler (Pat [pe]) (Inner (SizeOp (GetSize key size_class))) = do
  entry <- askFunction
  sOp . Imp.GetSize (patElemName pe) (keyWithEntryPoint entry key) $
    sizeClassWithEntryPoint entry size_class
opCompiler (Pat [pe]) (Inner (SizeOp (CmpSizeLe key size_class x))) = do
  entry <- askFunction
  x' <- toExp x
  sOp $
    Imp.CmpSizeLe
      (patElemName pe)
      (keyWithEntryPoint entry key)
      (sizeClassWithEntryPoint entry size_class)
      x'
opCompiler (Pat [pe]) (Inner (SizeOp (GetSizeMax size_class))) =
  sOp (Imp.GetSizeMax (patElemName pe) size_class)
opCompiler (Pat [pe]) (Inner (SizeOp (CalcNumBlocks w64 max_num_tblocks_key tblock_size))) = do
  entry <- askFunction
  max_num_tblocks :: TV Int64 <- dPrim "max_num_tblocks"
  sOp . Imp.GetSize (tvVar max_num_tblocks) (keyWithEntryPoint entry max_num_tblocks_key) $
    sizeClassWithEntryPoint entry SizeGrid
  -- If 'w' is small, we launch fewer blocks than we normally would, as
  -- we don't want any idle blocks.  We also don't want zero blocks.
  --
  -- The calculations are done with 64-bit integers to avoid overflow
  -- issues.
  let wanted_tblocks = pe64 w64 `divUp` pe64 tblock_size
      num_tblocks = sMax64 1 $ sMin64 wanted_tblocks $ sExt64 $ tvExp max_num_tblocks
  mkTV (patElemName pe) <-- sExt32 num_tblocks
opCompiler dest (Inner (SegOp op)) =
  segOpCompiler dest op
opCompiler (Pat pes) (Inner (GPUBody _ (Body _ stms res))) = do
  tid <- newVName "tid"
  let one = Count (intConst Int64 1)
      copyResult (pe, SubExpRes _ se) = copyDWIMFix (patElemName pe) [0] se []
  sKernelThread "gpuseq" tid (defKernelAttrs one one) $
    compileStms (freeIn res) stms $
      mapM_ copyResult (zip pes res)
opCompiler pat e =
  compilerBugS $
    "opCompiler: Invalid pattern\n "
      ++ prettyString pat
      ++ "\nfor expression\n "
      ++ prettyString e

-- | Prefix the keys in the path of a 'Imp.SizeThreshold' with the
-- entry point name (if any).  Other size classes are left unchanged.
sizeClassWithEntryPoint :: Maybe Name -> Imp.SizeClass -> Imp.SizeClass
sizeClassWithEntryPoint fname size_class =
  case size_class of
    Imp.SizeThreshold path def ->
      Imp.SizeThreshold [(keyWithEntryPoint fname k, v) | (k, v) <- path] def
    _ -> size_class

-- | Dispatch a segmented operation to its specific compiler.  Only
-- 'SegMap' accepts any level; the other operations must be at the
-- 'SegThread' level.
segOpCompiler :: Pat LetDecMem -> SegOp SegLevel GPUMem -> CallKernelGen ()
segOpCompiler pat segop =
  case segop of
    SegMap lvl space _ kbody ->
      compileSegMap pat lvl space kbody
    SegRed lvl@SegThread {} space reds _ kbody ->
      compileSegRed pat lvl space reds kbody
    SegScan lvl@SegThread {} space scans _ kbody ->
      compileSegScan pat lvl space scans kbody
    SegHist lvl@SegThread {} space ops _ kbody ->
      compileSegHist pat lvl space ops kbody
    _ ->
      compilerBugS $
        "segOpCompiler: unexpected "
          ++ prettyString (segLevel segop)
          ++ " for rhs of pattern "
          ++ prettyString pat

-- Create boolean expression that checks whether all kernels in the
-- enclosed code do not use more shared memory than we have available.
-- We look at *all* the kernels here (that have
-- 'Imp.kernelCheckSharedMemory' set), even those that might be
-- otherwise protected by their own multi-versioning branches deeper
-- down. Currently the compiler will not generate multi-versioning
-- that makes this a problem, but it might in the future.
--
-- Returns 'Nothing' when the allocation sizes refer to variables for
-- which the given predicate does not hold, since the check cannot be
-- made at this point.  Each shared allocation is rounded up to a
-- multiple of 8 bytes before the total is compared with the capacity.
checkSharedMemoryReqs :: (VName -> Bool) -> Imp.HostCode -> CallKernelGen (Maybe (Imp.TExp Bool))
checkSharedMemoryReqs in_scope code
  | all in_scope $ namesToList $ freeIn alloc_sizes = do
      capacity :: TV Int64 <- dPrim "shared_memory_capacity"
      sOp $ Imp.GetSizeMax (tvVar capacity) SizeSharedMemory
      let fits size = unCount size .<=. sExt64 (tvExp capacity)
      pure . Just $ L.foldl' (.&&.) true $ map fits alloc_sizes
  | otherwise =
      -- Some size involves a variable that is not known at this point,
      -- so we cannot check the requirements.
      pure Nothing
  where
    -- Total (aligned) shared memory needed by each checked kernel.
    alloc_sizes =
      map (sum . map alignedSize . foldMap sharedAllocSize . Imp.kernelBody) $
        foldMap checkedKernel code

    checkedKernel (Imp.CallKernel k)
      | Imp.kernelCheckSharedMemory k = [k]
    checkedKernel _ = []

    sharedAllocSize (Imp.SharedAlloc _ size) = [size]
    sharedAllocSize _ = []

    -- These allocations will actually be padded to an 8-byte aligned
    -- size, so we take that into account when checking whether they
    -- fit.
    alignedSize x = nextMul x 8

-- | Compile a 'WithAcc'.  Every accumulator whose operator requires
-- locking (according to the host's atomics) gets its own zero-filled
-- lock array from 'genZeroes', made available through 'hostLocks'.  The
-- 'WithAcc' itself is then compiled with 'defCompileExp'.
withAcc ::
  Pat LetDecMem ->
  [(Shape, [VName], Maybe (Lambda GPUMem, [SubExp]))] ->
  Lambda GPUMem ->
  CallKernelGen ()
withAcc pat inputs lam = do
  atomics <- hostAtomics <$> askEnv
  allocLocks atomics $ zip (map paramName $ lambdaParams lam) inputs
  where
    -- Once every input has been considered, compile the WithAcc itself.
    allocLocks _ [] =
      defCompileExp pat $ WithAcc inputs lam
    allocLocks atomics ((acc, (_, _, Just (op_lam, _))) : rest)
      | AtomicLocking _ <- atomicUpdateLocking atomics op_lam = do
          let num_locks = 100151
          locks_arr <- genZeroes "withacc_locks" num_locks
          let withLocks env =
                env {hostLocks = M.insert acc (Locks locks_arr num_locks) (hostLocks env)}
          localEnv withLocks $ allocLocks atomics rest
    allocLocks atomics (_ : rest) =
      allocLocks atomics rest

expCompiler :: ExpCompiler GPUMem HostEnv Imp.HostOp
-- We generate a simple kernel for iota and replicate.
expCompiler (Pat [pe]) (BasicOp (Iota n x s et)) = do x' <- toExp x s' <- toExp s sIota (patElemName pe) (pe64 n) x' s' et expCompiler (Pat [pe]) (BasicOp (Replicate shape se)) | Acc {} <- patElemType pe = pure () | otherwise = if shapeRank shape == 0 then copyDWIM (patElemName pe) [] se [] else sReplicate (patElemName pe) se -- Allocation in the "shared" space is just a placeholder. expCompiler _ (Op (Alloc _ (Space "shared"))) = pure () expCompiler pat (WithAcc inputs lam) = withAcc pat inputs lam -- This is a multi-versioning Match created by incremental flattening. -- We need to augment the conditional with a check that any local -- memory requirements in tbranch are compatible with the hardware. -- We do not check anything for defbody, as we assume that it will -- always be safe (and what would we do if none of the branches would -- work?). expCompiler dest (Match cond (first_case : cases) defbranch sort@(MatchDec _ MatchEquiv)) = do scope <- askScope tcode <- collect $ compileBody dest $ caseBody first_case fcode <- collect $ expCompiler dest $ Match cond cases defbranch sort check <- checkSharedMemoryReqs (`M.member` scope) tcode let matches = caseMatch cond (casePat first_case) emit $ case check of Nothing -> fcode Just ok -> Imp.If (matches .&&. 
ok) tcode fcode expCompiler dest e = defCompileExp dest e futhark-0.25.27/src/Futhark/CodeGen/ImpGen/GPU/000077500000000000000000000000001475065116200206535ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/CodeGen/ImpGen/GPU/Base.hs000066400000000000000000001425711475065116200220730ustar00rootroot00000000000000{-# LANGUAGE LambdaCase #-} {-# LANGUAGE TypeFamilies #-} module Futhark.CodeGen.ImpGen.GPU.Base ( KernelConstants (..), kernelGlobalThreadId, kernelLocalThreadId, kernelBlockId, threadOperations, keyWithEntryPoint, CallKernelGen, InKernelGen, Locks (..), HostEnv (..), Target (..), KernelEnv (..), blockReduce, blockScan, blockLoop, isActive, sKernel, sKernelThread, KernelAttrs (..), defKernelAttrs, lvlKernelAttrs, allocLocal, compileThreadResult, virtualiseBlocks, kernelLoop, blockCoverSpace, fenceForArrays, updateAcc, genZeroes, isPrimParam, kernelConstToExp, getChunkSize, getSize, -- * Host-level bulk operations sReplicate, sIota, -- * Atomics AtomicBinOp, atomicUpdateLocking, Locking (..), AtomicUpdate (..), DoAtomicUpdate, writeAtomic, ) where import Control.Monad import Data.List qualified as L import Data.Map.Strict qualified as M import Data.Maybe import Futhark.CodeGen.ImpCode.GPU qualified as Imp import Futhark.CodeGen.ImpGen import Futhark.Error import Futhark.IR.GPUMem import Futhark.IR.Mem.LMAD qualified as LMAD import Futhark.Transform.Rename import Futhark.Util (dropLast, nubOrd, splitFromEnd) import Futhark.Util.IntegralExp (divUp, quot, rem) import Prelude hiding (quot, rem) -- | Which target are we ultimately generating code for? While most -- of the kernels code is the same, there are some cases where we -- generate special code based on the ultimate low-level API we are -- targeting. data Target = CUDA | OpenCL | HIP -- | Information about the locks available for accumulators. 
data Locks = Locks { locksArray :: VName, locksCount :: Int } data HostEnv = HostEnv { hostAtomics :: AtomicBinOp, hostTarget :: Target, hostLocks :: M.Map VName Locks } data KernelEnv = KernelEnv { kernelAtomics :: AtomicBinOp, kernelConstants :: KernelConstants, kernelLocks :: M.Map VName Locks } type CallKernelGen = ImpM GPUMem HostEnv Imp.HostOp type InKernelGen = ImpM GPUMem KernelEnv Imp.KernelOp data KernelConstants = KernelConstants { kernelGlobalThreadIdVar :: TV Int32, kernelLocalThreadIdVar :: TV Int32, kernelBlockIdVar :: TV Int32, kernelNumBlocksCount :: Count NumBlocks SubExp, kernelBlockSizeCount :: Count BlockSize SubExp, kernelNumBlocks :: Imp.TExp Int64, kernelBlockSize :: Imp.TExp Int64, kernelNumThreads :: Imp.TExp Int32, kernelWaveSize :: Imp.TExp Int32, -- | A mapping from dimensions of nested SegOps to already -- computed local thread IDs. Only valid in non-virtualised case. kernelLocalIdMap :: M.Map [SubExp] [Imp.TExp Int32], -- | Mapping from dimensions of nested SegOps to how many -- iterations the virtualisation loop needs. kernelChunkItersMap :: M.Map [SubExp] (Imp.TExp Int32) } kernelGlobalThreadId, kernelLocalThreadId, kernelBlockId :: KernelConstants -> Imp.TExp Int32 kernelGlobalThreadId = tvExp . kernelGlobalThreadIdVar kernelLocalThreadId = tvExp . kernelLocalThreadIdVar kernelBlockId = tvExp . kernelBlockIdVar keyWithEntryPoint :: Maybe Name -> Name -> Name keyWithEntryPoint fname key = nameFromString $ maybe "" ((++ ".") . nameToString) fname ++ nameToString key allocLocal :: AllocCompiler GPUMem r Imp.KernelOp allocLocal mem size = sOp $ Imp.SharedAlloc mem size threadAlloc :: Pat LetDecMem -> SubExp -> Space -> InKernelGen () threadAlloc (Pat [_]) _ ScalarSpace {} = -- Handled by the declaration of the memory block, which is then -- translated to an actual scalar variable during C code generation. 
pure () threadAlloc (Pat [mem]) _ _ = compilerLimitationS $ "Cannot allocate memory block " ++ prettyString mem ++ " in kernel thread." threadAlloc dest _ _ = error $ "Invalid target for in-kernel allocation: " ++ show dest updateAcc :: Safety -> VName -> [SubExp] -> [SubExp] -> InKernelGen () updateAcc safety acc is vs = sComment "UpdateAcc" $ do -- See the ImpGen implementation of UpdateAcc for general notes. let is' = map pe64 is (c, space, arrs, dims, op) <- lookupAcc acc is' let boundsCheck = case safety of Safe -> sWhen (inBounds (Slice (map DimFix is')) dims) _ -> id boundsCheck $ case op of Nothing -> forM_ (zip arrs vs) $ \(arr, v) -> copyDWIMFix arr is' v [] Just lam -> do dLParams $ lambdaParams lam let (_x_params, y_params) = splitAt (length vs) $ map paramName $ lambdaParams lam forM_ (zip y_params vs) $ \(yp, v) -> copyDWIM yp [] v [] atomics <- kernelAtomics <$> askEnv case atomicUpdateLocking atomics lam of AtomicPrim f -> f space arrs is' AtomicCAS f -> f space arrs is' AtomicLocking f -> do c_locks <- M.lookup c . kernelLocks <$> askEnv case c_locks of Just (Locks locks num_locks) -> do let locking = Locking locks 0 1 0 $ pure . (`rem` fromIntegral num_locks) . flattenIndex dims f locking space arrs is' Nothing -> error $ "Missing locks for " ++ prettyString acc -- | Generate a constant device array of 32-bit integer zeroes with -- the given number of elements. Initialised with a replicate. genZeroes :: String -> Int -> CallKernelGen VName genZeroes desc n = genConstants $ do counters_mem <- sAlloc (desc <> "_mem") (4 * fromIntegral n) (Space "device") let shape = Shape [intConst Int64 (fromIntegral n)] counters <- sArrayInMem desc int32 shape counters_mem sReplicate counters $ intConst Int32 0 pure (namesFromList [counters_mem], counters) compileThreadExp :: ExpCompiler GPUMem KernelEnv Imp.KernelOp compileThreadExp (Pat [pe]) (BasicOp (Opaque _ se)) = -- Cannot print in GPU code. 
copyDWIM (patElemName pe) [] se [] -- The static arrays stuff does not work inside kernels. compileThreadExp (Pat [dest]) (BasicOp (ArrayVal vs t)) = compileThreadExp (Pat [dest]) (BasicOp (ArrayLit (map Constant vs) (Prim t))) compileThreadExp (Pat [dest]) (BasicOp (ArrayLit es _)) = forM_ (zip [0 ..] es) $ \(i, e) -> copyDWIMFix (patElemName dest) [fromIntegral (i :: Int64)] e [] compileThreadExp _ (BasicOp (UpdateAcc safety acc is vs)) = updateAcc safety acc is vs compileThreadExp dest e = defCompileExp dest e -- | Assign iterations of a for-loop to all threads in the kernel. -- The passed-in function is invoked with the (symbolic) iteration. -- The body must contain thread-level code. For multidimensional -- loops, use 'blockCoverSpace'. kernelLoop :: (IntExp t) => Imp.TExp t -> Imp.TExp t -> Imp.TExp t -> (Imp.TExp t -> InKernelGen ()) -> InKernelGen () kernelLoop tid num_threads n f = localOps threadOperations $ if n == num_threads then f tid else do num_chunks <- dPrimVE "num_chunks" $ n `divUp` num_threads sFor "chunk_i" num_chunks $ \chunk_i -> do i <- dPrimVE "i" $ chunk_i * num_threads + tid sWhen (i .<. n) $ f i -- | Assign iterations of a for-loop to threads in the threadblock. The -- passed-in function is invoked with the (symbolic) iteration. For -- multidimensional loops, use 'blockCoverSpace'. blockLoop :: (IntExp t) => Imp.TExp t -> (Imp.TExp t -> InKernelGen ()) -> InKernelGen () blockLoop n f = do constants <- kernelConstants <$> askEnv kernelLoop (kernelLocalThreadId constants `sExtAs` n) (kernelBlockSize constants `sExtAs` n) n f -- | Iterate collectively though a multidimensional space, such that -- all threads in the block participate. The passed-in function is -- invoked with a (symbolic) point in the index space. 
blockCoverSpace :: (IntExp t) => [Imp.TExp t] -> ([Imp.TExp t] -> InKernelGen ()) -> InKernelGen () blockCoverSpace ds f = do constants <- kernelConstants <$> askEnv let tblock_size = kernelBlockSize constants case splitFromEnd 1 ds of -- Optimise the case where the inner dimension of the space is -- equal to the block size. (ds', [last_d]) | last_d == (tblock_size `sExtAs` last_d) -> do let ltid = kernelLocalThreadId constants `sExtAs` last_d sLoopSpace ds' $ \ds_is -> f $ ds_is ++ [ltid] _ -> blockLoop (product ds) $ f . unflattenIndex ds -- Which fence do we need to protect shared access to this memory space? fenceForSpace :: Space -> Imp.Fence fenceForSpace (Space "shared") = Imp.FenceLocal fenceForSpace _ = Imp.FenceGlobal -- | If we are touching these arrays, which kind of fence do we need? fenceForArrays :: [VName] -> InKernelGen Imp.Fence fenceForArrays = fmap (L.foldl' max Imp.FenceLocal) . mapM need where need arr = fmap (fenceForSpace . entryMemSpace) . lookupMemory . memLocName . entryArrayLoc =<< lookupArray arr isPrimParam :: (Typed p) => Param p -> Bool isPrimParam = primType . paramType kernelConstToExp :: Imp.KernelConstExp -> CallKernelGen Imp.Exp kernelConstToExp = traverse f where f (Imp.SizeMaxConst c) = do v <- dPrimS (prettyString c) int64 sOp $ Imp.GetSizeMax v c pure v f (Imp.SizeConst k c) = do v <- dPrimS (nameToString k) int64 sOp $ Imp.GetSize v k c pure v -- | Given available register and a list of parameter types, compute -- the largest available chunk size given the parameters for which we -- want chunking and the available resources. Used in -- 'SegScan.SinglePass.compileSegScan', and 'SegRed.compileSegRed' -- (with primitive non-commutative operators only). 
getChunkSize :: [Type] -> Imp.KernelConstExp getChunkSize types = do let max_tblock_size = Imp.SizeMaxConst SizeThreadBlock max_block_mem = Imp.SizeMaxConst SizeSharedMemory max_block_reg = Imp.SizeMaxConst SizeRegisters k_mem = le64 max_block_mem `quot` le64 max_tblock_size k_reg = le64 max_block_reg `quot` le64 max_tblock_size types' = map elemType $ filter primType types sizes = map primByteSize types' sum_sizes = sum sizes sum_sizes' = sum (map (sMax64 4 . primByteSize) types') `quot` 4 max_size = maximum sizes mem_constraint = max k_mem sum_sizes `quot` max_size reg_constraint = (k_reg - 1 - sum_sizes') `quot` (2 * sum_sizes') untyped $ sMax64 1 $ sMin64 mem_constraint reg_constraint inChunkScan :: KernelConstants -> Maybe (Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool) -> Imp.TExp Int64 -> Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool -> [VName] -> InKernelGen () -> Lambda GPUMem -> InKernelGen () inChunkScan constants seg_flag arrs_full_size lockstep_width block_size active arrs barrier scan_lam = everythingVolatile $ do skip_threads <- dPrim "skip_threads" let actual_params = lambdaParams scan_lam (x_params, y_params) = splitAt (length actual_params `div` 2) actual_params y_to_x = forM_ (zip x_params y_params) $ \(x, y) -> when (isPrimParam x) $ copyDWIM (paramName x) [] (Var (paramName y)) [] -- Set initial y values sComment "read input for in-block scan" $ sWhen active $ do zipWithM_ readInitial y_params arrs -- Since the final result is expected to be in x_params, we may -- need to copy it there for the first thread in the block. sWhen (in_block_id .==. 0) y_to_x when array_scan barrier let op_to_x in_block_thread_active | Nothing <- seg_flag = localOps threadOperations . sWhen in_block_thread_active $ compileBody' x_params $ lambdaBody scan_lam | Just flag_true <- seg_flag = do inactive <- dPrimVE "inactive" $ flag_true (ltid32 - tvExp skip_threads) ltid32 sWhen (in_block_thread_active .&&. 
inactive) $ forM_ (zip x_params y_params) $ \(x, y) -> copyDWIM (paramName x) [] (Var (paramName y)) [] -- The convoluted control flow is to ensure all threads -- hit this barrier (if applicable). when array_scan barrier localOps threadOperations . sWhen in_block_thread_active . sUnless inactive $ compileBody' x_params $ lambdaBody scan_lam maybeBarrier = sWhen (lockstep_width .<=. tvExp skip_threads) barrier sComment "in-block scan (hopefully no barriers needed)" $ do skip_threads <-- 1 sWhile (tvExp skip_threads .<. block_size) $ do thread_active <- dPrimVE "thread_active" $ tvExp skip_threads .<=. in_block_id .&&. active sWhen thread_active . sComment "read operands" $ zipWithM_ (readParam (sExt64 $ tvExp skip_threads)) x_params arrs sComment "perform operation" $ op_to_x thread_active maybeBarrier sWhen thread_active . sComment "write result" $ sequence_ $ zipWith3 writeResult x_params y_params arrs maybeBarrier skip_threads <-- tvExp skip_threads * 2 where block_id = ltid32 `quot` block_size in_block_id = ltid32 - block_id * block_size ltid32 = kernelLocalThreadId constants ltid = sExt64 ltid32 gtid = sExt64 $ kernelGlobalThreadId constants array_scan = not $ all primType $ lambdaReturnType scan_lam readInitial p arr | isPrimParam p = copyDWIMFix (paramName p) [] (Var arr) [ltid] | otherwise = copyDWIMFix (paramName p) [] (Var arr) [gtid] readParam behind p arr | isPrimParam p = copyDWIMFix (paramName p) [] (Var arr) [ltid - behind] | otherwise = copyDWIMFix (paramName p) [] (Var arr) [gtid - behind + arrs_full_size] writeResult x y arr = do when (isPrimParam x) $ copyDWIMFix arr [ltid] (Var $ paramName x) [] copyDWIM (paramName y) [] (Var $ paramName x) [] blockScan :: Maybe (Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool) -> Imp.TExp Int64 -> Imp.TExp Int64 -> Lambda GPUMem -> [VName] -> InKernelGen () blockScan seg_flag arrs_full_size w lam arrs = do constants <- kernelConstants <$> askEnv renamed_lam <- renameLambda lam let ltid32 = kernelLocalThreadId 
constants ltid = sExt64 ltid32 (x_params, y_params) = splitAt (length arrs) $ lambdaParams lam dLParams (lambdaParams lam ++ lambdaParams renamed_lam) ltid_in_bounds <- dPrimVE "ltid_in_bounds" $ ltid .<. w fence <- fenceForArrays arrs -- The scan works by splitting the block into chunks, which are -- scanned separately. Typically, these chunks are at most the -- lockstep width, which enables barrier-free execution inside them. -- -- We hardcode the chunk size here. The only requirement is that it -- should not be less than the square root of the block size. With -- 32, we will work on blocks of size 1024 or smaller, which fits -- every device Troels has seen. Still, it would be nicer if it were -- a runtime parameter. Some day. let chunk_size = 32 simd_width = kernelWaveSize constants chunk_id = ltid32 `quot` chunk_size in_chunk_id = ltid32 - chunk_id * chunk_size doInChunkScan seg_flag' active = inChunkScan constants seg_flag' arrs_full_size simd_width chunk_size active arrs barrier array_scan = not $ all primType $ lambdaReturnType lam barrier | array_scan = sOp $ Imp.Barrier Imp.FenceGlobal | otherwise = sOp $ Imp.Barrier fence errorsync | array_scan = sOp $ Imp.ErrorSync Imp.FenceGlobal | otherwise = sOp $ Imp.ErrorSync Imp.FenceLocal block_offset = sExt64 (kernelBlockId constants) * kernelBlockSize constants writeBlockResult p arr | isPrimParam p = copyDWIMFix arr [sExt64 chunk_id] (Var $ paramName p) [] | otherwise = copyDWIMFix arr [block_offset + sExt64 chunk_id] (Var $ paramName p) [] readPrevBlockResult p arr | isPrimParam p = copyDWIMFix (paramName p) [] (Var arr) [sExt64 chunk_id - 1] | otherwise = copyDWIMFix (paramName p) [] (Var arr) [block_offset + sExt64 chunk_id - 1] doInChunkScan seg_flag ltid_in_bounds lam barrier let is_first_block = chunk_id .==. 
0 when array_scan $ do sComment "save correct values for first block" $ sWhen is_first_block $ forM_ (zip x_params arrs) $ \(x, arr) -> unless (isPrimParam x) $ copyDWIMFix arr [arrs_full_size + block_offset + sExt64 chunk_size + ltid] (Var $ paramName x) [] barrier let last_in_block = in_chunk_id .==. chunk_size - 1 sComment "last thread of block 'i' writes its result to offset 'i'" $ sWhen (last_in_block .&&. ltid_in_bounds) $ everythingVolatile $ zipWithM_ writeBlockResult x_params arrs barrier let first_block_seg_flag = do flag_true <- seg_flag Just $ \from to -> flag_true (from * chunk_size + chunk_size - 1) (to * chunk_size + chunk_size - 1) sComment "scan the first block, after which offset 'i' contains carry-in for block 'i+1'" $ doInChunkScan first_block_seg_flag (is_first_block .&&. ltid_in_bounds) renamed_lam errorsync when array_scan $ do sComment "move correct values for first block back a block" $ sWhen is_first_block $ forM_ (zip x_params arrs) $ \(x, arr) -> unless (isPrimParam x) $ copyDWIMFix arr [arrs_full_size + block_offset + ltid] (Var arr) [arrs_full_size + block_offset + sExt64 chunk_size + ltid] barrier no_carry_in <- dPrimVE "no_carry_in" $ is_first_block .||. bNot ltid_in_bounds let read_carry_in = sUnless no_carry_in $ do forM_ (zip x_params y_params) $ \(x, y) -> copyDWIM (paramName y) [] (Var (paramName x)) [] zipWithM_ readPrevBlockResult x_params arrs op_to_x | Nothing <- seg_flag = sUnless no_carry_in $ compileBody' x_params $ lambdaBody lam | Just flag_true <- seg_flag = do inactive <- dPrimVE "inactive" $ flag_true (chunk_id * chunk_size - 1) ltid32 sUnless no_carry_in . sWhen inactive . forM_ (zip x_params y_params) $ \(x, y) -> copyDWIM (paramName x) [] (Var (paramName y)) [] -- The convoluted control flow is to ensure all threads -- hit this barrier (if applicable). 
when array_scan barrier sUnless no_carry_in $ sUnless inactive $ compileBody' x_params $ lambdaBody lam write_final_result = forM_ (zip x_params arrs) $ \(p, arr) -> when (isPrimParam p) $ copyDWIMFix arr [ltid] (Var $ paramName p) [] sComment "carry-in for every block except the first" $ localOps threadOperations $ do sComment "read operands" read_carry_in sComment "perform operation" op_to_x sComment "write final result" $ sUnless no_carry_in write_final_result barrier sComment "restore correct values for first block" $ sWhen (is_first_block .&&. ltid_in_bounds) $ forM_ (zip3 x_params y_params arrs) $ \(x, y, arr) -> if isPrimParam y then copyDWIMFix arr [ltid] (Var $ paramName y) [] else copyDWIMFix (paramName x) [] (Var arr) [arrs_full_size + block_offset + ltid] barrier blockReduce :: Imp.TExp Int32 -> Lambda GPUMem -> [VName] -> InKernelGen () blockReduce w lam arrs = do offset <- dPrim "offset" blockReduceWithOffset offset w lam arrs blockReduceWithOffset :: TV Int32 -> Imp.TExp Int32 -> Lambda GPUMem -> [VName] -> InKernelGen () blockReduceWithOffset offset w lam arrs = do constants <- kernelConstants <$> askEnv let local_tid = kernelLocalThreadId constants barrier | all primType $ lambdaReturnType lam = sOp $ Imp.Barrier Imp.FenceLocal | otherwise = sOp $ Imp.Barrier Imp.FenceGlobal errorsync | all primType $ lambdaReturnType lam = sOp $ Imp.ErrorSync Imp.FenceLocal | otherwise = sOp $ Imp.ErrorSync Imp.FenceGlobal readReduceArgument param arr = do let i = local_tid + tvExp offset copyDWIMFix (paramName param) [] (Var arr) [sExt64 i] writeReduceOpResult param arr = when (isPrimParam param) $ copyDWIMFix arr [sExt64 local_tid] (Var $ paramName param) [] writeArrayOpResult param arr = unless (isPrimParam param) $ copyDWIMFix arr [sExt64 local_tid] (Var $ paramName param) [] let (reduce_acc_params, reduce_arr_params) = splitAt (length arrs) $ lambdaParams lam skip_waves <- dPrimV "skip_waves" (1 :: Imp.TExp Int32) dLParams $ lambdaParams lam offset <-- (0 :: 
Imp.TExp Int32) sComment "participating threads read initial accumulator" $ localOps threadOperations . sWhen (local_tid .<. w) $ zipWithM_ readReduceArgument reduce_acc_params arrs let do_reduce = localOps threadOperations $ do sComment "read array element" $ zipWithM_ readReduceArgument reduce_arr_params arrs sComment "apply reduction operation" $ compileBody' reduce_acc_params $ lambdaBody lam sComment "write result of operation" $ zipWithM_ writeReduceOpResult reduce_acc_params arrs in_wave_reduce = everythingVolatile do_reduce wave_size = kernelWaveSize constants tblock_size = kernelBlockSize constants wave_id = local_tid `quot` wave_size in_wave_id = local_tid - wave_id * wave_size num_waves = (sExt32 tblock_size + wave_size - 1) `quot` wave_size arg_in_bounds = local_tid + tvExp offset .<. w doing_in_wave_reductions = tvExp offset .<. wave_size apply_in_in_wave_iteration = (in_wave_id .&. (2 * tvExp offset - 1)) .==. 0 in_wave_reductions = do offset <-- (1 :: Imp.TExp Int32) sWhile doing_in_wave_reductions $ do sWhen (arg_in_bounds .&&. apply_in_in_wave_iteration) in_wave_reduce offset <-- tvExp offset * 2 doing_cross_wave_reductions = tvExp skip_waves .<. num_waves is_first_thread_in_wave = in_wave_id .==. 0 wave_not_skipped = (wave_id .&. (2 * tvExp skip_waves - 1)) .==. 0 apply_in_cross_wave_iteration = arg_in_bounds .&&. is_first_thread_in_wave .&&. wave_not_skipped cross_wave_reductions = sWhile doing_cross_wave_reductions $ do barrier offset <-- tvExp skip_waves * wave_size sWhen apply_in_cross_wave_iteration do_reduce skip_waves <-- tvExp skip_waves * 2 in_wave_reductions cross_wave_reductions errorsync unless (all isPrimParam reduce_acc_params) $ sComment "Copy array-typed operands to result array" $ sWhen (local_tid .==. 
0) $ localOps threadOperations $ zipWithM_ writeArrayOpResult reduce_acc_params arrs compileThreadOp :: OpCompiler GPUMem KernelEnv Imp.KernelOp compileThreadOp pat (Alloc size space) = threadAlloc pat size space compileThreadOp pat _ = compilerBugS $ "compileThreadOp: cannot compile rhs of binding " ++ prettyString pat -- | Perform a scalar write followed by a fence. writeAtomic :: VName -> [Imp.TExp Int64] -> SubExp -> [Imp.TExp Int64] -> InKernelGen () writeAtomic dst dst_is src src_is = do t <- stripArray (length dst_is) <$> lookupType dst sLoopSpace (map pe64 (arrayDims t)) $ \is -> do let pt = elemType t (dst_mem, dst_space, dst_offset) <- fullyIndexArray dst (dst_is ++ is) case src_is ++ is of [] -> sOp . Imp.Atomic dst_space $ Imp.AtomicWrite pt dst_mem dst_offset (toExp' pt src) _ -> do tmp <- dPrimSV "tmp" pt copyDWIMFix (tvVar tmp) [] src (src_is ++ is) sOp . Imp.Atomic dst_space $ Imp.AtomicWrite pt dst_mem dst_offset (untyped (tvExp tmp)) -- | Locking strategy used for an atomic update. data Locking = Locking { -- | Array containing the lock. lockingArray :: VName, -- | Value for us to consider the lock free. lockingIsUnlocked :: Imp.TExp Int32, -- | What to write when we lock it. lockingToLock :: Imp.TExp Int32, -- | What to write when we unlock it. lockingToUnlock :: Imp.TExp Int32, -- | A transformation from the logical lock index to the -- physical position in the array. This can also be used -- to make the lock array smaller. lockingMapping :: [Imp.TExp Int64] -> [Imp.TExp Int64] } -- | A function for generating code for an atomic update. Assumes -- that the bucket is in-bounds. type DoAtomicUpdate rep r = Space -> [VName] -> [Imp.TExp Int64] -> ImpM rep r Imp.KernelOp () -- | The mechanism that will be used for performing the atomic update. -- Approximates how efficient it will be. Ordered from most to least -- efficient. data AtomicUpdate rep r = -- | Supported directly by primitive. 
AtomicPrim (DoAtomicUpdate rep r) | -- | Can be done by efficient swaps. AtomicCAS (DoAtomicUpdate rep r) | -- | Requires explicit locking. AtomicLocking (Locking -> DoAtomicUpdate rep r) -- | Is there an atomic t'BinOp' corresponding to this t'BinOp'? type AtomicBinOp = BinOp -> Maybe (VName -> VName -> Count Imp.Elements (Imp.TExp Int64) -> Imp.Exp -> Imp.AtomicOp) -- | Do an atomic update corresponding to a binary operator lambda. atomicUpdateLocking :: AtomicBinOp -> Lambda GPUMem -> AtomicUpdate GPUMem KernelEnv atomicUpdateLocking atomicBinOp lam | Just ops_and_ts <- lamIsBinOp lam, all (\(_, t, _, _) -> primBitSize t `elem` [32, 64]) ops_and_ts = primOrCas ops_and_ts $ \space arrs bucket -> -- If the operator is a vectorised binary operator on 32/64-bit -- values, we can use a particularly efficient -- implementation. If the operator has an atomic implementation -- we use that, otherwise it is still a binary operator which -- can be implemented by atomic compare-and-swap if 32/64 bits. forM_ (zip arrs ops_and_ts) $ \(a, (op, t, x, y)) -> do -- Common variables. old <- dPrimS "old" t (arr', _a_space, bucket_offset) <- fullyIndexArray a bucket case opHasAtomicSupport space old arr' bucket_offset op of Just f -> sOp $ f $ Imp.var y t Nothing -> atomicUpdateCAS space t a old bucket x $ x <~~ Imp.BinOpExp op (Imp.var x t) (Imp.var y t) where opHasAtomicSupport space old arr' bucket' bop = do let atomic f = Imp.Atomic space . f old arr' bucket' atomic <$> atomicBinOp bop primOrCas ops | all isPrim ops = AtomicPrim | otherwise = AtomicCAS isPrim (op, _, _, _) = isJust $ atomicBinOp op -- If the operator functions purely on single 32/64-bit values, we can -- use an implementation based on CAS, no matter what the operator -- does. 
atomicUpdateLocking _ op | [Prim t] <- lambdaReturnType op, [xp, _] <- lambdaParams op, primBitSize t `elem` [32, 64] = AtomicCAS $ \space [arr] bucket -> do old <- dPrimS "old" t atomicUpdateCAS space t arr old bucket (paramName xp) $ compileBody' [xp] (lambdaBody op) atomicUpdateLocking _ op = AtomicLocking $ \locking space arrs bucket -> do old <- dPrim "old" continue <- dPrimVol "continue" Bool true -- Correctly index into locks. (locks', _locks_space, locks_offset) <- fullyIndexArray (lockingArray locking) $ lockingMapping locking bucket -- Critical section let try_acquire_lock = sOp $ Imp.Atomic space $ Imp.AtomicCmpXchg int32 (tvVar old) locks' locks_offset (untyped $ lockingIsUnlocked locking) (untyped $ lockingToLock locking) lock_acquired = tvExp old .==. lockingIsUnlocked locking -- Even the releasing is done with an atomic rather than a -- simple write, for memory coherency reasons. release_lock = sOp $ Imp.Atomic space $ Imp.AtomicCmpXchg int32 (tvVar old) locks' locks_offset (untyped $ lockingToLock locking) (untyped $ lockingToUnlock locking) break_loop = continue <-- false -- Preparing parameters. It is assumed that the caller has already -- filled the arr_params. We copy the current value to the -- accumulator parameters. -- -- Note the use of 'everythingVolatile' when reading and writing the -- buckets. This was necessary to ensure correct execution on a -- newer NVIDIA GPU (RTX 2080). The 'volatile' modifiers likely -- make the writes pass through the (SM-local) L1 cache, which is -- necessary here, because we are really doing device-wide -- synchronisation without atomics (naughty!). 
let (acc_params, _arr_params) = splitAt (length arrs) $ lambdaParams op bind_acc_params = everythingVolatile $ sComment "bind lhs" $ forM_ (zip acc_params arrs) $ \(acc_p, arr) -> copyDWIMFix (paramName acc_p) [] (Var arr) bucket let op_body = sComment "execute operation" $ compileBody' acc_params $ lambdaBody op do_hist = everythingVolatile $ sComment "update global result" $ zipWithM_ (writeArray bucket) arrs $ map (Var . paramName) acc_params -- While-loop: Try to insert your value sWhile (tvExp continue) $ do try_acquire_lock sWhen lock_acquired $ do dLParams acc_params bind_acc_params op_body do_hist release_lock break_loop where writeArray bucket arr val = writeAtomic arr bucket val [] atomicUpdateCAS :: Space -> PrimType -> VName -> VName -> [Imp.TExp Int64] -> VName -> InKernelGen () -> InKernelGen () atomicUpdateCAS space t arr old bucket x do_op = do -- Code generation target: -- -- old = d_his[idx]; -- do { -- assumed = old; -- x = do_op(assumed, y); -- old = atomicCAS(&d_his[idx], assumed, tmp); -- } while(assumed != old); assumed <- dPrimS "assumed" t run_loop <- dPrimV "run_loop" true -- XXX: CUDA may generate really bad code if this is not a volatile -- read. Unclear why. The later reads are volatile, so maybe -- that's it. 
everythingVolatile $ copyDWIMFix old [] (Var arr) bucket (arr', _a_space, bucket_offset) <- fullyIndexArray arr bucket -- While-loop: Try to insert your value let (toBits, fromBits) = case t of FloatType Float16 -> ( \v -> Imp.FunExp "to_bits16" [v] int16, \v -> Imp.FunExp "from_bits16" [v] t ) FloatType Float32 -> ( \v -> Imp.FunExp "to_bits32" [v] int32, \v -> Imp.FunExp "from_bits32" [v] t ) FloatType Float64 -> ( \v -> Imp.FunExp "to_bits64" [v] int64, \v -> Imp.FunExp "from_bits64" [v] t ) _ -> (id, id) int | primBitSize t == 16 = int16 | primBitSize t == 32 = int32 | otherwise = int64 sWhile (tvExp run_loop) $ do assumed <~~ Imp.var old t x <~~ Imp.var assumed t do_op old_bits_v <- newVName "old_bits" dPrim_ old_bits_v int let old_bits = Imp.var old_bits_v int sOp . Imp.Atomic space $ Imp.AtomicCmpXchg int old_bits_v arr' bucket_offset (toBits (Imp.var assumed t)) (toBits (Imp.var x t)) old <~~ fromBits old_bits let won = CmpOpExp (CmpEq int) (toBits (Imp.var assumed t)) old_bits sWhen (isBool won) (run_loop <-- false) computeKernelUses :: (FreeIn a) => a -> [VName] -> CallKernelGen [Imp.KernelUse] computeKernelUses kernel_body bound_in_kernel = do let actually_free = freeIn kernel_body `namesSubtract` namesFromList bound_in_kernel -- Compute the variables that we need to pass to the kernel. nubOrd <$> readsFromSet actually_free readsFromSet :: Names -> CallKernelGen [Imp.KernelUse] readsFromSet = fmap catMaybes . mapM f . 
namesToList where f var = do t <- lookupType var vtable <- getVTable case t of Array {} -> pure Nothing Acc {} -> pure Nothing Mem (Space "shared") -> pure Nothing Mem {} -> pure $ Just $ Imp.MemoryUse var Prim bt -> isConstExp vtable (Imp.var var bt) >>= \case Just ce -> pure $ Just $ Imp.ConstUse var ce Nothing -> pure $ Just $ Imp.ScalarUse var bt isConstExp :: VTable GPUMem -> Imp.Exp -> ImpM rep r op (Maybe Imp.KernelConstExp) isConstExp vtable size = do fname <- askFunction let onLeaf name _ = lookupConstExp name lookupConstExp name = constExp =<< hasExp =<< M.lookup name vtable constExp (Op (Inner (SizeOp (GetSize key c)))) = Just $ LeafExp (Imp.SizeConst (keyWithEntryPoint fname key) c) int32 constExp (Op (Inner (SizeOp (GetSizeMax c)))) = Just $ LeafExp (Imp.SizeMaxConst c) int32 constExp e = primExpFromExp lookupConstExp e pure $ replaceInPrimExpM onLeaf size where hasExp (ArrayVar e _) = e hasExp (AccVar e _) = e hasExp (ScalarVar e _) = e hasExp (MemVar e _) = e kernelInitialisationSimple :: Count NumBlocks SubExp -> Count BlockSize SubExp -> CallKernelGen (KernelConstants, InKernelGen ()) kernelInitialisationSimple num_tblocks tblock_size = do global_tid <- newVName "global_tid" local_tid <- newVName "local_tid" tblock_id <- newVName "block_id" wave_size <- newVName "wave_size" inner_tblock_size <- newVName "tblock_size" let num_tblocks' = Imp.pe64 (unCount num_tblocks) tblock_size' = Imp.pe64 (unCount tblock_size) constants = KernelConstants { kernelGlobalThreadIdVar = mkTV global_tid, kernelLocalThreadIdVar = mkTV local_tid, kernelBlockIdVar = mkTV tblock_id, kernelNumBlocksCount = num_tblocks, kernelBlockSizeCount = tblock_size, kernelNumBlocks = num_tblocks', kernelBlockSize = tblock_size', kernelNumThreads = sExt32 (tblock_size' * num_tblocks'), kernelWaveSize = Imp.le32 wave_size, kernelLocalIdMap = mempty, kernelChunkItersMap = mempty } let set_constants = do dPrim_ local_tid int32 dPrim_ inner_tblock_size int32 dPrim_ wave_size int32 dPrim_ 
tblock_id int32 sOp (Imp.GetLocalId local_tid 0) sOp (Imp.GetLocalSize inner_tblock_size 0) sOp (Imp.GetLockstepWidth wave_size) sOp (Imp.GetBlockId tblock_id 0) dPrimV_ global_tid $ le32 tblock_id * le32 inner_tblock_size + le32 local_tid pure (constants, set_constants) isActive :: [(VName, SubExp)] -> Imp.TExp Bool isActive limit = case actives of [] -> true x : xs -> foldl (.&&.) x xs where (is, ws) = unzip limit actives = zipWith active is $ map pe64 ws active i = (Imp.le64 i .<.) -- | Change every memory block to be in the global address space, -- except those who are in the shared memory space. This only affects -- generated code - we still need to make sure that the memory is -- actually present on the device (and declared as variables in the -- kernel). makeAllMemoryGlobal :: CallKernelGen a -> CallKernelGen a makeAllMemoryGlobal = localDefaultSpace (Imp.Space "global") . localVTable (M.map globalMemory) where globalMemory (MemVar _ entry) | entryMemSpace entry /= Space "shared" = MemVar Nothing entry {entryMemSpace = Imp.Space "global"} globalMemory entry = entry simpleKernelBlocks :: Imp.TExp Int64 -> Imp.TExp Int64 -> CallKernelGen (Imp.TExp Int32, Count NumBlocks SubExp, Count BlockSize SubExp) simpleKernelBlocks max_num_tblocks kernel_size = do tblock_size <- dPrim "tblock_size" fname <- askFunction let tblock_size_key = keyWithEntryPoint fname $ nameFromString $ prettyString $ tvVar tblock_size sOp $ Imp.GetSize (tvVar tblock_size) tblock_size_key Imp.SizeThreadBlock virt_num_tblocks <- dPrimVE "virt_num_tblocks" $ kernel_size `divUp` tvExp tblock_size num_tblocks <- dPrimV "num_tblocks" $ virt_num_tblocks `sMin64` max_num_tblocks pure (sExt32 virt_num_tblocks, Count $ tvSize num_tblocks, Count $ tvSize tblock_size) simpleKernelConstants :: Imp.TExp Int64 -> String -> CallKernelGen ( (Imp.TExp Int64 -> InKernelGen ()) -> InKernelGen (), KernelConstants ) simpleKernelConstants kernel_size desc = do -- For performance reasons, codegen assumes that the 
thread count is -- never more than will fit in an i32. This means we need to cap -- the number of blocks here. The cap is set much higher than any -- GPU will possibly need. Feel free to come back and laugh at me -- in the future. let max_num_tblocks = 1024 * 1024 thread_gtid <- newVName $ desc ++ "_gtid" thread_ltid <- newVName $ desc ++ "_ltid" tblock_id <- newVName $ desc ++ "_gid" inner_tblock_size <- newVName "tblock_size" (virt_num_tblocks, num_tblocks, tblock_size) <- simpleKernelBlocks max_num_tblocks kernel_size let tblock_size' = Imp.pe64 $ unCount tblock_size num_tblocks' = Imp.pe64 $ unCount num_tblocks constants = KernelConstants { kernelGlobalThreadIdVar = mkTV thread_gtid, kernelLocalThreadIdVar = mkTV thread_ltid, kernelBlockIdVar = mkTV tblock_id, kernelNumBlocksCount = num_tblocks, kernelBlockSizeCount = tblock_size, kernelNumBlocks = num_tblocks', kernelBlockSize = tblock_size', kernelNumThreads = sExt32 (tblock_size' * num_tblocks'), kernelWaveSize = 0, kernelLocalIdMap = mempty, kernelChunkItersMap = mempty } wrapKernel m = do dPrim_ thread_ltid int32 dPrim_ inner_tblock_size int32 dPrim_ tblock_id int32 sOp (Imp.GetLocalId thread_ltid 0) sOp (Imp.GetLocalSize inner_tblock_size 0) sOp (Imp.GetBlockId tblock_id 0) dPrimV_ thread_gtid $ le32 tblock_id * le32 inner_tblock_size + le32 thread_ltid virtualiseBlocks SegVirt virt_num_tblocks $ \virt_tblock_id -> do global_tid <- dPrimVE "global_tid" $ sExt64 virt_tblock_id * sExt64 (le32 inner_tblock_size) + sExt64 (kernelLocalThreadId constants) m global_tid pure (wrapKernel, constants) -- | For many kernels, we may not have enough physical blocks to cover -- the logical iteration space. Some blocks thus have to perform -- double duty; we put an outer loop to accomplish this. The -- advantage over just launching a bazillion threads is that the cost -- of memory expansion should be proportional to the number of -- *physical* threads (hardware parallelism), not the amount of -- application parallelism. 
virtualiseBlocks ::
  SegVirt ->
  Imp.TExp Int32 ->
  (Imp.TExp Int32 -> InKernelGen ()) ->
  InKernelGen ()
virtualiseBlocks virt required_blocks m =
  case virt of
    SegVirt -> do
      -- Each physical block visits the virtual block IDs
      -- phys_id, phys_id + num_blocks, phys_id + 2*num_blocks, ...
      -- that are below required_blocks.
      num_phys_blocks <- sExt32 . kernelNumBlocks . kernelConstants <$> askEnv
      phys_id <- dPrim "phys_tblock_id"
      sOp $ Imp.GetBlockId (tvVar phys_id) 0
      let phys_id_e = tvExp phys_id
      iterations <-
        dPrimVE "iterations" $
          (required_blocks - phys_id_e) `divUp` num_phys_blocks
      sFor "i" iterations $ \i -> do
        virt_id <- dPrimV "virt_tblock_id" $ phys_id_e + i * num_phys_blocks
        m $ tvExp virt_id
        -- Every thread must be done with this virtual block before the
        -- physical block is reused for the next one.
        sOp $ Imp.ErrorSync Imp.FenceGlobal
    _ -> do
      -- Without virtualisation the physical block ID is used directly.
      block_id <- kernelBlockIdVar . kernelConstants <$> askEnv
      m $ tvExp block_id

-- | Extra settings that control how a kernel is generated.
data KernelAttrs = KernelAttrs
  { -- | Whether the kernel can execute correctly even if an earlier
    -- kernel has failed.
    kAttrFailureTolerant :: Bool,
    -- | Whether the code launching this kernel performs its own
    -- check of shared memory capacity.
    kAttrCheckSharedMemory :: Bool,
    -- | The number of thread blocks.
    kAttrNumBlocks :: Count NumBlocks SubExp,
    -- | The size of each thread block.
    kAttrBlockSize :: Count BlockSize SubExp,
    -- | Variables that are specially in scope inside the kernel.
    -- Operationally, these will be available at kernel compile time
    -- (which happens at run-time, with access to machine-specific
    -- information).
    kAttrConstExps :: M.Map VName Imp.KernelConstExp
  }

-- | Kernel attributes for the given grid.  The kernel is not failure
-- tolerant, the launch checks shared memory, and there are no
-- constant expressions.
defKernelAttrs ::
  Count NumBlocks SubExp ->
  Count BlockSize SubExp ->
  KernelAttrs
defKernelAttrs blocks size =
  KernelAttrs
    { kAttrNumBlocks = blocks,
      kAttrBlockSize = size,
      kAttrConstExps = mempty,
      kAttrFailureTolerant = False,
      kAttrCheckSharedMemory = True
    }

-- | Retrieve a size of the given 'SizeClass' and bind it to a fresh
-- variable whose name is based on the given description.
getSize :: String -> SizeClass -> CallKernelGen (TV Int64)
getSize desc size_class = do
  size_var <- dPrim desc
  fun <- askFunction
  let var_name = tvVar size_var
      key = keyWithEntryPoint fun . nameFromString $ prettyString var_name
  sOp $ Imp.GetSize var_name key size_class
  pure size_var

-- | Determine the 'KernelAttrs' for a 'SegLevel'.  When the level
-- carries no explicit grid, the block size and number of blocks are
-- obtained with 'getSize'.
lvlKernelAttrs :: SegLevel -> CallKernelGen KernelAttrs
lvlKernelAttrs lvl =
  case lvl of
    SegThread _ grid -> fromGrid grid
    SegBlock _ grid -> fromGrid grid
    SegThreadInBlock {} -> error "lvlKernelAttrs: SegThreadInBlock"
  where
    fromGrid (Just (KernelGrid num_tblocks tblock_size)) =
      pure $ defKernelAttrs num_tblocks tblock_size
    fromGrid Nothing = do
      tblock_size <- getSize "tblock_size" Imp.SizeThreadBlock
      num_tblocks <- getSize "num_tblocks" Imp.SizeGrid
      pure $ defKernelAttrs (Count $ tvSize num_tblocks) (Count $ tvSize tblock_size)

-- | Generate a kernel with the given operations.  Inside the kernel,
-- @v@ is bound to the value @flatf@ computes from the kernel
-- constants before the body is run.
sKernel ::
  Operations GPUMem KernelEnv Imp.KernelOp ->
  (KernelConstants -> Imp.TExp Int64) ->
  String ->
  VName ->
  KernelAttrs ->
  InKernelGen () ->
  CallKernelGen ()
sKernel ops flatf name v attrs f = do
  (constants, set_constants) <-
    kernelInitialisationSimple (kAttrNumBlocks attrs) (kAttrBlockSize attrs)
  kernel_name <- nameForFun $ concat [name, "_", show (baseTag v)]
  let kernel_body = do
        set_constants
        dPrimV_ v $ flatf constants
        f
  sKernelOp attrs constants ops kernel_name kernel_body

-- | 'sKernel' with thread-level operations.  The bound variable is
-- the sign-extended global thread ID.
sKernelThread ::
  String ->
  VName ->
  KernelAttrs ->
  InKernelGen () ->
  CallKernelGen ()
sKernelThread = sKernel threadOperations flat_id
  where
    flat_id = sExt64 . kernelGlobalThreadId

-- | Emit a call to a kernel whose body is the given in-kernel code.
-- The body is compiled with the given operations and kernel constants.
sKernelOp ::
  KernelAttrs ->
  KernelConstants ->
  Operations GPUMem KernelEnv Imp.KernelOp ->
  Name ->
  InKernelGen () ->
  CallKernelGen ()
sKernelOp attrs constants ops name m = do
  HostEnv atomics _ locks <- askEnv
  body <- makeAllMemoryGlobal $ subImpM_ (KernelEnv atomics constants locks) ops m
  let const_exps = kAttrConstExps attrs
  uses <- computeKernelUses body $ M.keys const_exps
  tblock_size <- blockSizeOrConst $ untyped $ kernelBlockSize constants
  emit . Imp.Op . Imp.CallKernel $
    Imp.Kernel
      { Imp.kernelName = name,
        Imp.kernelBody = body,
        Imp.kernelUses = uses <> [Imp.ConstUse cv ce | (cv, ce) <- M.toList const_exps],
        Imp.kernelNumBlocks = [untyped $ kernelNumBlocks constants],
        Imp.kernelBlockSize = [tblock_size],
        Imp.kernelFailureTolerant = kAttrFailureTolerant attrs,
        Imp.kernelCheckSharedMemory = kAttrCheckSharedMemory attrs
      }
  where
    -- Use a kernel constant for the block size when the expression is
    -- recognised as one; otherwise keep the expression itself.
    blockSizeOrConst e = do
      vtable <- getVTable
      maybe (Left e) Right <$> isConstExp vtable e

-- | Like 'sKernelOp', but the attributes are the defaults for the grid
-- described by the kernel constants, with failure tolerance as given.
sKernelFailureTolerant ::
  Bool ->
  Operations GPUMem KernelEnv Imp.KernelOp ->
  KernelConstants ->
  Name ->
  InKernelGen () ->
  CallKernelGen ()
sKernelFailureTolerant tol ops constants name =
  sKernelOp attrs constants ops name
  where
    grid_attrs =
      defKernelAttrs (kernelNumBlocksCount constants) (kernelBlockSizeCount constants)
    attrs = grid_attrs {kAttrFailureTolerant = tol}

-- | Operations for thread-level kernel code.  These use the
-- thread-level op and expression compilers, 'lmadCopy' for copies,
-- and 'allocLocal' for shared memory.
threadOperations :: Operations GPUMem KernelEnv Imp.KernelOp
threadOperations =
  ops0
    { opsExpCompiler = compileThreadExp,
      opsCopyCompiler = lmadCopy,
      opsAllocCompilers = M.singleton (Space "shared") allocLocal,
      opsStmsCompiler = const $ defCompileStms mempty
    }
  where
    ops0 = defaultOperations compileThreadOp

-- | Perform a Replicate with a dedicated kernel.
sReplicateKernel :: VName -> SubExp -> CallKernelGen ()
sReplicateKernel arr se = do
  se_t <- subExpType se
  arr_t <- lookupType arr
  let outer_dims = dropLast (arrayRank se_t) $ arrayDims arr_t
      dims = map pe64 $ outer_dims ++ arrayDims se_t
  n <- dPrimVE "replicate_n" . product $ map sExt64 dims
  (virtualise, constants) <- simpleKernelConstants n "replicate"
  fname <- askFunction
  let gtid_tag = baseTag . tvVar $ kernelGlobalThreadIdVar constants
      name = keyWithEntryPoint fname . nameFromString $ "replicate_" ++ show gtid_tag
  sKernelFailureTolerant True threadOperations constants name . virtualise $ \gtid -> do
    -- The thread ID is unravelled over the full result shape.  Only the
    -- trailing indices select an element of the replicated value.
    is' <- dIndexSpace' "rep_i" dims gtid
    sWhen (gtid .<. n) $
      copyDWIMFix arr is' se $ drop (length outer_dims) is'

-- | Name of the builtin replicate function for a primitive type.
replicateName :: PrimType -> String
replicateName = ("replicate_" ++) . prettyString

-- | Return the name of the builtin function that replicates a value of
-- the given type into a memory block.  The function is defined first
-- if it does not exist yet.
replicateForType :: PrimType -> CallKernelGen Name
replicateForType bt = do
  let fname = nameFromString $ "builtin#" <> replicateName bt
  exists_already <- hasFunction fname
  unless exists_already $ do
    mem <- newVName "mem"
    num_elems <- newVName "num_elems"
    val <- newVName "val"
    let shape = Shape [Var num_elems]
        lmad = LMAD.iota 0 $ map pe64 $ shapeDims shape
        params =
          [ Imp.MemParam mem (Space "device"),
            Imp.ScalarParam num_elems int64,
            Imp.ScalarParam val bt
          ]
    function fname [] params $
      sArray "arr" bt shape mem lmad >>= \arr -> sReplicateKernel arr (Var val)
  pure fname

-- | Decide whether replicating @v@ into @arr@ is a plain fill, meaning
-- a scalar value written into an array whose LMAD is direct.  If so,
-- return code that calls the builtin replicate function; otherwise
-- return 'Nothing'.
replicateIsFill :: VName -> SubExp -> CallKernelGen (Maybe (CallKernelGen ()))
replicateIsFill arr v = do
  ArrayEntry (MemLoc arr_mem arr_shape arr_lmad) _ <- lookupArray arr
  v_t <- subExpType v
  pure $ case v_t of
    Prim pt
      | LMAD.isDirect arr_lmad -> Just $ do
          fname <- replicateForType pt
          let num_elems = untyped $ product $ map pe64 arr_shape
          emit $
            Imp.Call [] fname [Imp.MemArg arr_mem, Imp.ExpArg num_elems, Imp.ExpArg $ toExp' pt v]
    _ -> Nothing

-- | Perform a Replicate.  When it is a plain fill, the shared builtin
-- function is used; otherwise a dedicated kernel is generated.
sReplicate :: VName -> SubExp -> CallKernelGen ()
sReplicate arr se = do
  is_fill <- replicateIsFill arr se
  maybe (sReplicateKernel arr se) id is_fill

-- | Perform an Iota with a dedicated kernel.
sIotaKernel ::
  VName ->
  Imp.TExp Int64 ->
  Imp.Exp ->
  Imp.Exp ->
  IntType ->
  CallKernelGen ()
sIotaKernel arr n x s et = do
  destloc <- entryArrayLoc <$> lookupArray arr
  (virtualise, constants) <- simpleKernelConstants n "iota"
  fname <- askFunction
  let gtid_tag = show . baseTag . tvVar $ kernelGlobalThreadIdVar constants
      name =
        keyWithEntryPoint fname . nameFromString $
          concat ["iota_", prettyString et, "_", gtid_tag]
      -- Element i of the result is x + i*s, computed with wrapping
      -- arithmetic in the element type.
      iotaElem i =
        BinOpExp (Add et OverflowWrap) (BinOpExp (Mul et OverflowWrap) (Imp.sExt et i) s) x
  sKernelFailureTolerant True threadOperations constants name . virtualise $ \gtid ->
    sWhen (gtid .<. n) $ do
      (destmem, destspace, destidx) <- fullyIndexArray' destloc [gtid]
      emit . Imp.Write destmem destidx (IntType et) destspace Imp.Nonvolatile $
        iotaElem (untyped gtid)

-- | Name of the builtin iota function for an integer type.
iotaName :: IntType -> String
iotaName = ("iota_" ++) . prettyString

-- | Return the name of the builtin iota function for the given integer
-- type.  The function is defined first if it does not exist yet.
iotaForType :: IntType -> CallKernelGen Name
iotaForType bt = do
  let fname = nameFromString $ "builtin#" <> iotaName bt
      int_t = IntType bt
  exists_already <- hasFunction fname
  unless exists_already $ do
    mem <- newVName "mem"
    n <- newVName "n"
    x <- newVName "x"
    s <- newVName "s"
    let shape = Shape [Var n]
        lmad = LMAD.iota 0 $ map pe64 $ shapeDims shape
        params =
          [ Imp.MemParam mem (Space "device"),
            Imp.ScalarParam n int64,
            Imp.ScalarParam x int_t,
            Imp.ScalarParam s int_t
          ]
    function fname [] params $ do
      arr <- sArray "arr" int_t shape mem lmad
      sIotaKernel arr (Imp.le64 n) (Imp.var x int_t) (Imp.var s int_t) bt
  pure fname

-- | Perform an Iota.  A shared builtin function is used when the
-- array's LMAD is direct; otherwise a dedicated kernel is generated.
sIota ::
  VName ->
  Imp.TExp Int64 ->
  Imp.Exp ->
  Imp.Exp ->
  IntType ->
  CallKernelGen ()
sIota arr n x s et = do
  ArrayEntry (MemLoc arr_mem _ arr_lmad) _ <- lookupArray arr
  if not (LMAD.isDirect arr_lmad)
    then sIotaKernel arr n x s et
    else do
      -- A direct layout can use the shared builtin function.
      fname <- iotaForType et
      emit . Imp.Call [] fname $ Imp.MemArg arr_mem : map Imp.ExpArg [untyped n, x, s]

-- | Store a thread-level kernel result into the given pattern element.
compileThreadResult ::
  SegSpace ->
  PatElem LetDecMem ->
  KernelResult ->
  InKernelGen ()
compileThreadResult space pe res =
  case res of
    Returns _ _ what -> do
      let gtids = [Imp.le64 gtid | (gtid, _) <- unSegSpace space]
      copyDWIMFix (patElemName pe) gtids what []
    WriteReturns _ arr dests -> do
      arr_dims <- map pe64 . arrayDims <$> lookupType arr
      forM_ dests $ \(slice, v) -> do
        let slice_e = fmap pe64 slice
        -- Writes outside the bounds of the array are skipped.
        sWhen (inBounds slice_e arr_dims) $
          copyDWIM (patElemName pe) (unSlice slice_e) v []
    RegTileReturns {} ->
      compilerLimitationS "compileThreadResult: RegTileReturns not yet handled."
    TileReturns {} ->
      compilerBugS "compileThreadResult: TileReturns unhandled."
futhark-0.25.27/src/Futhark/CodeGen/ImpGen/GPU/Block.hs000066400000000000000000000703331475065116200222470ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-}

-- | Generation of kernels with block-level bodies.
module Futhark.CodeGen.ImpGen.GPU.Block
  ( sKernelBlock,
    compileBlockResult,
    blockOperations,

    -- * Precomputation
    Precomputed,
    precomputeConstants,
    precomputedConstants,
    atomicUpdateLocking,
  )
where

import Control.Monad
import Data.Bifunctor
import Data.List (partition, zip4)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.Construct (fullSliceNum)
import Futhark.Error
import Futhark.IR.GPUMem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.Transform.Rename
import Futhark.Util (chunks, mapAccumLM, takeLast)
import Futhark.Util.IntegralExp (divUp, rem)
import Prelude hiding (quot, rem)

-- | @flattenArray k flat arr@ flattens the outer @k@ dimensions of
-- @arr@ to @flat@.  (Make sure @flat@ is the product of those
-- dimensions or you'll have a bad time.)  No data is copied; only the
-- LMAD is reshaped.
flattenArray :: Int -> TV Int64 -> VName -> ImpM rep r op VName
flattenArray k flat arr = do
  ArrayEntry arr_loc pt <- lookupArray arr
  let flat_shape = Shape $ Var (tvVar flat) : drop k (memLocShape arr_loc)
  sArray (baseString arr ++ "_flat") pt flat_shape (memLocName arr_loc) $
    fromMaybe (error "flattenArray") $
      LMAD.reshape (memLocLMAD arr_loc) (map pe64 $ shapeDims flat_shape)

-- | @sliceArray start size arr@ creates a view of @arr@ that covers
-- @size@ elements of the outer dimension, starting at @start@.  The
-- view shares the memory of @arr@.
sliceArray :: Imp.TExp Int64 -> TV Int64 -> VName -> ImpM rep r op VName
sliceArray start size arr = do
  MemLoc mem _ lmad <- entryArrayLoc <$> lookupArray arr
  arr_t <- lookupType arr
  let slice =
        fullSliceNum
          (map Imp.pe64 (arrayDims arr_t))
          [DimSlice start (tvExp size) 1]
  sArray
    (baseString arr ++ "_chunk")
    (elemType arr_t)
    (arrayShape arr_t `setOuterDim` Var (tvVar size))
    mem
    $ LMAD.slice lmad slice

-- | @applyLambda lam dests args@ emits code that:
--
-- 1. Binds each parameter of @lam@ to the corresponding element of
--    @args@, interpreted as a (name,slice) pair (as in 'copyDWIM').
--    Use an empty list for a scalar.
--
-- 2. Executes the body of @lam@.
--
-- 3. Binds the t'SubExp's that are the 'Result' of @lam@ to the
--    provided @dest@s, again interpreted as the destination for a
--    'copyDWIM'.
applyLambda ::
  (Mem rep inner) =>
  Lambda rep ->
  [(VName, [DimIndex (Imp.TExp Int64)])] ->
  [(SubExp, [DimIndex (Imp.TExp Int64)])] ->
  ImpM rep r op ()
applyLambda lam dests args = do
  dLParams $ lambdaParams lam
  forM_ (zip (lambdaParams lam) args) $ \(p, (arg, arg_slice)) ->
    copyDWIM (paramName p) [] arg arg_slice
  compileStms mempty (bodyStms $ lambdaBody lam) $ do
    let res = map resSubExp $ bodyResult $ lambdaBody lam
    forM_ (zip dests res) $ \((dest, dest_slice), se) ->
      copyDWIM dest dest_slice se []

-- | As applyLambda, but first rename the names in the lambda.  This
-- makes it safe to apply it in multiple places.  (It might be safe
-- anyway, but you have to be more careful - use this if you are in
-- doubt.)
applyRenamedLambda ::
  (Mem rep inner) =>
  Lambda rep ->
  [(VName, [DimIndex (Imp.TExp Int64)])] ->
  [(SubExp, [DimIndex (Imp.TExp Int64)])] ->
  ImpM rep r op ()
applyRenamedLambda lam dests args = do
  lam_renamed <- renameLambda lam
  applyLambda lam_renamed dests args

-- | @blockChunkLoop w m@ splits the range @[0,w)@ into consecutive
-- chunks of at most the block size.  It calls @m chunk_start chunk_size@
-- for each chunk in turn.  The last chunk may be smaller than the rest.
blockChunkLoop ::
  Imp.TExp Int32 ->
  (Imp.TExp Int32 -> TV Int64 -> InKernelGen ()) ->
  InKernelGen ()
blockChunkLoop w m = do
  constants <- kernelConstants <$> askEnv
  let max_chunk_size = sExt32 $ kernelBlockSize constants
  num_chunks <- dPrimVE "num_chunks" $ w `divUp` max_chunk_size
  sFor "chunk_i" num_chunks $ \chunk_i -> do
    chunk_start <-
      dPrimVE "chunk_start" $ chunk_i * max_chunk_size
    chunk_end <-
      dPrimVE "chunk_end" $ sMin32 w (chunk_start + max_chunk_size)
    chunk_size <-
      dPrimV "chunk_size" $ sExt64 $ chunk_end - chunk_start
    m chunk_start chunk_size

-- | A block-level scan over @w@ elements, where @w@ may exceed the block
-- size.  The arrays are scanned chunk by chunk with 'blockChunkLoop'.
-- Before a chunk is scanned, thread 0 combines the last element of the
-- previous chunk into the first element of this chunk.  This is skipped
-- when the optional segment predicate says the two elements lie in
-- different segments.
virtualisedBlockScan ::
  Maybe (Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool) ->
  Imp.TExp Int32 ->
  Lambda GPUMem ->
  [VName] ->
  InKernelGen ()
virtualisedBlockScan seg_flag w lam arrs = do
  blockChunkLoop w $ \chunk_start chunk_size -> do
    constants <- kernelConstants <$> askEnv
    let ltid = kernelLocalThreadId constants
        crosses_segment =
          case seg_flag of
            Nothing -> false
            Just flag_true ->
              flag_true (sExt32 (chunk_start - 1)) (sExt32 chunk_start)
    sComment "possibly incorporate carry" $
      sWhen (chunk_start .>. 0 .&&. ltid .==. 0 .&&. bNot crosses_segment) $ do
        carry_idx <- dPrimVE "carry_idx" $ sExt64 chunk_start - 1
        applyRenamedLambda
          lam
          (map (,[DimFix $ sExt64 chunk_start]) arrs)
          ( map ((,[DimFix carry_idx]) . Var) arrs
              ++ map ((,[DimFix $ sExt64 chunk_start]) . Var) arrs
          )

    arrs_chunks <- mapM (sliceArray (sExt64 chunk_start) chunk_size) arrs

    sOp $ Imp.ErrorSync Imp.FenceLocal

    blockScan seg_flag (sExt64 w) (tvExp chunk_size) lam arrs_chunks

-- | Copy compiler for block-level code.  If both source and destination
-- are in scalar space, the copy is done directly with 'lmadCopy'.
-- Otherwise the element copies are spread over the block with
-- 'blockCoverSpace', followed by a local barrier.
copyInBlock :: CopyCompiler GPUMem KernelEnv Imp.KernelOp
copyInBlock pt destloc srcloc = do
  dest_space <- entryMemSpace <$> lookupMemory (memLocName destloc)
  src_space <- entryMemSpace <$> lookupMemory (memLocName srcloc)

  let src_lmad = memLocLMAD srcloc
      dims = LMAD.shape src_lmad
      rank = length dims

  case (dest_space, src_space) of
    (ScalarSpace destds _, ScalarSpace srcds _) -> do
      -- Scalar-space memory only holds the innermost dimensions, so the
      -- outer ones are fixed at index 0.
      let fullDim d = DimSlice 0 d 1
          destslice' =
            Slice $
              replicate (rank - length destds) (DimFix 0)
                ++ takeLast (length destds) (map fullDim dims)
          srcslice' =
            Slice $
              replicate (rank - length srcds) (DimFix 0)
                ++ takeLast (length srcds) (map fullDim dims)
      lmadCopy
        pt
        (sliceMemLoc destloc destslice')
        (sliceMemLoc srcloc srcslice')
    _ -> do
      blockCoverSpace (map sExt32 dims) $ \is ->
        lmadCopy
          pt
          (sliceMemLoc destloc (Slice $ map (DimFix . sExt64) is))
          (sliceMemLoc srcloc (Slice $ map (DimFix . sExt64) is))
      sOp $ Imp.Barrier Imp.FenceLocal

-- | The local thread ID unravelled into indices over the given
-- dimensions.  A precomputed entry from 'kernelLocalIdMap' is used
-- when one exists for these dimensions.
localThreadIDs :: [SubExp] -> InKernelGen [Imp.TExp Int64]
localThreadIDs dims = do
  ltid <- sExt64 . kernelLocalThreadId . kernelConstants <$> askEnv
  let dims' = map pe64 dims
  maybe (dIndexSpace' "ltid" dims' ltid) (pure . map sExt64)
    . M.lookup dims
    . kernelLocalIdMap
    . kernelConstants
    =<< askEnv

-- | Split the dimensions of a 'SegSpace' into two groups.  The first
-- component holds the dimensions whose positions are listed in the
-- 'SegSeqDims'; the second holds the rest.
partitionSeqDims :: SegSeqDims -> SegSpace -> ([(VName, SubExp)], [(VName, SubExp)])
partitionSeqDims (SegSeqDims seq_is) space =
  bimap (map fst) (map fst) $
    partition ((`elem` seq_is) . snd) (zip (unSegSpace space) [0 ..])

-- | Bind the flat index of the 'SegSpace' to the local thread ID.
compileFlatId :: SegSpace -> InKernelGen ()
compileFlatId space = do
  ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
  dPrimV_ (segFlat space) $ sExt64 ltid

-- | Construct the necessary lock arrays for an intra-block histogram.
-- Returns, for each operator, a function that performs an atomic update
-- at a given index.  At most one lock array is allocated.  It lives in
-- shared memory, has one lock per block-size element, and is shared by
-- every operator that needs locking.
prepareIntraBlockSegHist ::
  Shape ->
  Count BlockSize SubExp ->
  [HistOp GPUMem] ->
  InKernelGen [[Imp.TExp Int64] -> InKernelGen ()]
prepareIntraBlockSegHist segments tblock_size =
  fmap snd . mapAccumLM onOp Nothing
  where
    onOp l op = do
      constants <- kernelConstants <$> askEnv
      atomicBinOp <- kernelAtomics <$> askEnv

      let local_subhistos = histDest op

      case (l, atomicUpdateLocking atomicBinOp $ histOp op) of
        (_, AtomicPrim f) -> pure (l, f (Space "shared") local_subhistos)
        (_, AtomicCAS f) -> pure (l, f (Space "shared") local_subhistos)
        (Just l', AtomicLocking f) -> pure (l, f l' (Space "shared") local_subhistos)
        (Nothing, AtomicLocking f) -> do
          locks <- newVName "locks"

          let num_locks = pe64 $ unCount tblock_size
              dims = map pe64 $ shapeDims (segments <> histOpShape op <> histShape op)
              l' = Locking locks 0 1 0 (pure . (`rem` num_locks) . flattenIndex dims)
              locks_t = Array int32 (Shape [unCount tblock_size]) NoUniqueness

          locks_mem <- sAlloc "locks_mem" (typeSize locks_t) $ Space "shared"
          dArray locks int32 (arrayShape locks_t) locks_mem $
            LMAD.iota 0 . map pe64 . arrayDims $
              locks_t

          sComment "All locks start out unlocked" $
            blockCoverSpace [kernelBlockSize constants] $ \is ->
              copyDWIMFix locks is (intConst Int32 0) []

          pure (Just l', f l' (Space "shared") local_subhistos)

-- | Run @m@ for every point of the 'SegSpace', with the space's index
-- variables bound.  Points are distributed over the threads of the
-- block as the 'SegVirt' dictates.  A single dimension equal to the
-- block size is handled as 'SegNoVirtFull' with no sequential
-- dimensions.
blockCoverSegSpace :: SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
blockCoverSegSpace virt space m = do
  let (ltids, dims) = unzip $ unSegSpace space
      dims' = map pe64 dims

  constants <- kernelConstants <$> askEnv
  let tblock_size = kernelBlockSize constants
  -- Maybe we can statically detect that this is actually a
  -- SegNoVirtFull and generate ever-so-slightly simpler code.
  let virt' = if dims' == [tblock_size] then SegNoVirtFull (SegSeqDims []) else virt
  case virt' of
    SegVirt -> do
      iters <- M.lookup dims . kernelChunkItersMap . kernelConstants <$> askEnv
      case iters of
        Nothing -> do
          iterations <- dPrimVE "iterations" $ product $ map sExt32 dims'
          blockLoop iterations $ \i -> do
            dIndexSpace (zip ltids dims') $ sExt64 i
            m
        Just num_chunks -> localOps threadOperations $ do
          let ltid = kernelLocalThreadId constants
          sFor "chunk_i" num_chunks $ \chunk_i -> do
            i <- dPrimVE "i" $ chunk_i * sExt32 tblock_size + ltid
            dIndexSpace (zip ltids dims') $ sExt64 i
            sWhen (inBounds (Slice (map (DimFix . le64) ltids)) dims') m
    SegNoVirt -> localOps threadOperations $ do
      zipWithM_ dPrimV_ ltids =<< localThreadIDs dims
      sWhen (isActive $ zip ltids dims) m
    SegNoVirtFull seq_dims -> do
      let ((ltids_seq, dims_seq), (ltids_par, dims_par)) =
            bimap unzip unzip $ partitionSeqDims seq_dims space
      sLoopNest (Shape dims_seq) $ \is_seq -> do
        zipWithM_ dPrimV_ ltids_seq is_seq
        localOps threadOperations $ do
          zipWithM_ dPrimV_ ltids_par =<< localThreadIDs dims_par
          m

-- | Expression compiler for block-level code.
compileBlockExp :: ExpCompiler GPUMem KernelEnv Imp.KernelOp
compileBlockExp (Pat [pe]) (BasicOp (Opaque _ se)) =
  -- Cannot print in GPU code.
  copyDWIM (patElemName pe) [] se []
-- The static arrays stuff does not work inside kernels.
compileBlockExp (Pat [dest]) (BasicOp (ArrayVal vs t)) =
  compileBlockExp (Pat [dest]) (BasicOp (ArrayLit (map Constant vs) (Prim t)))
compileBlockExp (Pat [dest]) (BasicOp (ArrayLit es _)) =
  forM_ (zip [0 ..] es) $ \(i, e) ->
    copyDWIMFix (patElemName dest) [fromIntegral (i :: Int64)] e []
compileBlockExp _ (BasicOp (UpdateAcc safety acc is vs)) = do
  -- Only thread 0 updates the accumulator; the barrier publishes the
  -- update to the rest of the block.
  ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
  sWhen (ltid .==. 0) $ updateAcc safety acc is vs
  sOp $ Imp.Barrier Imp.FenceLocal
compileBlockExp (Pat [dest]) (BasicOp (Replicate ds se))
  | ds /= mempty = do
      flat <- newVName "rep_flat"
      is <- replicateM (arrayRank dest_t) (newVName "rep_i")
      let is' = map le64 is
      blockCoverSegSpace SegVirt (SegSpace flat $ zip is $ arrayDims dest_t) $
        copyDWIMFix (patElemName dest) is' se (drop (shapeRank ds) is')
      sOp $ Imp.Barrier Imp.FenceLocal
  where
    dest_t = patElemType dest
compileBlockExp (Pat [dest]) (BasicOp (Iota n e s it)) = do
  n' <- toExp n
  e' <- toExp e
  s' <- toExp s
  blockLoop (TPrimExp n') $ \i' -> do
    -- The loop index is an i64, but the element arithmetic is in type
    -- @it@, so the index is converted first (as in 'sIotaKernel').
    x <-
      dPrimV "x" $
        TPrimExp $
          BinOpExp (Add it OverflowUndef) e' $
            BinOpExp (Mul it OverflowUndef) (Imp.sExt it (untyped i')) s'
    copyDWIMFix (patElemName dest) [i'] (Var (tvVar x)) []
  sOp $ Imp.Barrier Imp.FenceLocal

-- When generating code for a scalar in-place update, we must make
-- sure that only one thread performs the write.  When writing an
-- array, the block-level copy code will take care of doing the right
-- thing.
compileBlockExp (Pat [pe]) (BasicOp (Update safety _ slice se))
  | null $ sliceDims slice = do
      sOp $ Imp.Barrier Imp.FenceLocal
      ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
      sWhen (ltid .==. 0) $
        case safety of
          Unsafe -> write
          Safe -> sWhen (inBounds slice' dims) write
      sOp $ Imp.Barrier Imp.FenceLocal
  where
    slice' = fmap pe64 slice
    dims = map pe64 $ arrayDims $ patElemType pe
    write = copyDWIM (patElemName pe) (unSlice slice') se []
compileBlockExp dest e = do
  -- It is messy to jump into control flow for error handling.
  -- Avoid that by always doing an error sync here.  Potential
  -- improvement: only do this if any errors are pending (this could
  -- also be handled in later codegen).
  when (doSync e) $ sOp $ Imp.ErrorSync Imp.FenceLocal
  defCompileExp dest e
  where
    doSync Loop {} = True
    doSync Match {} = True
    doSync _ = False

-- | Memory allocation in block-level code.  Only scalar-space and
-- shared-memory allocations are supported.
blockAlloc :: Pat LetDecMem -> SubExp -> Space -> InKernelGen ()
blockAlloc (Pat [_]) _ ScalarSpace {} =
  -- Handled by the declaration of the memory block, which is then
  -- translated to an actual scalar variable during C code generation.
  pure ()
blockAlloc (Pat [mem]) size (Space "shared") =
  allocLocal (patElemName mem) $ Imp.bytes $ pe64 size
blockAlloc (Pat [mem]) _ _ =
  compilerLimitationS $ "Cannot allocate memory block " ++ prettyString mem ++ " in kernel block."
blockAlloc dest _ _ =
  error $ "Invalid target for in-kernel allocation: " ++ show dest

-- | Op compiler for block-level code.  It handles allocations and the
-- nested SegMap, SegScan, SegRed and SegHist operations.
compileBlockOp :: OpCompiler GPUMem KernelEnv Imp.KernelOp
compileBlockOp pat (Alloc size space) =
  blockAlloc pat size space
compileBlockOp pat (Inner (SegOp (SegMap lvl space _ body))) = do
  compileFlatId space

  blockCoverSegSpace (segVirt lvl) space $
    compileStms mempty (kernelBodyStms body) $
      zipWithM_ (compileThreadResult space) (patElems pat) $
        kernelBodyResult body
  sOp $ Imp.ErrorSync Imp.FenceLocal
compileBlockOp pat (Inner (SegOp (SegScan lvl space scans _ body))) = do
  compileFlatId space

  let (ltids, dims) = unzip $ unSegSpace space
      dims' = map pe64 dims

  blockCoverSegSpace (segVirt lvl) space $
    compileStms mempty (kernelBodyStms body) $
      forM_ (zip (patNames pat) $ kernelBodyResult body) $ \(dest, res) ->
        copyDWIMFix
          dest
          (map Imp.le64 ltids)
          (kernelResultSubExp res)
          []

  fence <- fenceForArrays $ patNames pat
  sOp $ Imp.ErrorSync fence

  let segment_size = last dims'
      crossesSegment from to =
        (sExt64 to - sExt64 from) .>. (sExt64 to `rem` segment_size)

  -- blockScan needs to treat the scan output as a one-dimensional
  -- array of scan elements, so we invent some new flattened arrays
  -- here.
  dims_flat <- dPrimV "dims_flat" $ product dims'
  let scan = head scans
      num_scan_results = length $ segBinOpNeutral scan
  arrs_flat <-
    mapM (flattenArray (length dims') dims_flat) $
      take num_scan_results $
        patNames pat

  case segVirt lvl of
    SegVirt ->
      virtualisedBlockScan
        (Just crossesSegment)
        (sExt32 $ tvExp dims_flat)
        (segBinOpLambda scan)
        arrs_flat
    _ ->
      blockScan
        (Just crossesSegment)
        (product dims')
        (product dims')
        (segBinOpLambda scan)
        arrs_flat
compileBlockOp pat (Inner (SegOp (SegRed lvl space ops _ body))) = do
  compileFlatId space

  let dims' = map pe64 dims
      mkTempArr t =
        sAllocArray "red_arr" (elemType t) (Shape dims <> arrayShape t) $ Space "shared"

  tmp_arrs <- mapM mkTempArr $ concatMap (lambdaReturnType . segBinOpLambda) ops
  blockCoverSegSpace (segVirt lvl) space $
    compileStms mempty (kernelBodyStms body) $ do
      let (red_res, map_res) =
            splitAt (segBinOpResults ops) $ kernelBodyResult body
      forM_ (zip tmp_arrs red_res) $ \(dest, res) ->
        copyDWIMFix dest (map Imp.le64 ltids) (kernelResultSubExp res) []
      zipWithM_ (compileThreadResult space) map_pes map_res

  sOp $ Imp.ErrorSync Imp.FenceLocal

  let tmps_for_ops = chunks (map (length . segBinOpNeutral) ops) tmp_arrs
  case segVirt lvl of
    SegVirt -> virtCase dims' tmps_for_ops
    _ -> nonvirtCase dims' tmps_for_ops
  where
    (ltids, dims) = unzip $ unSegSpace space
    (red_pes, map_pes) = splitAt (segBinOpResults ops) $ patElems pat

    -- Single segment, virtualised: reduce chunk by chunk.  The running
    -- result in red_pes is carried into the first element of each
    -- subsequent chunk.
    virtCase [dim'] tmps_for_ops = do
      ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
      blockChunkLoop (sExt32 dim') $ \chunk_start chunk_size -> do
        sComment "possibly incorporate carry" $
          sWhen (chunk_start .>. 0 .&&. ltid .==. 0) $
            forM_ (zip ops tmps_for_ops) $ \(op, tmps) ->
              applyRenamedLambda
                (segBinOpLambda op)
                (map (,[DimFix $ sExt64 chunk_start]) tmps)
                ( map ((,[]) . Var . patElemName) red_pes
                    ++ map ((,[DimFix $ sExt64 chunk_start]) . Var) tmps
                )

        sOp $ Imp.ErrorSync Imp.FenceLocal

        forM_ (zip ops tmps_for_ops) $ \(op, tmps) -> do
          tmps_chunks <- mapM (sliceArray (sExt64 chunk_start) chunk_size) tmps
          blockReduce (sExt32 (tvExp chunk_size)) (segBinOpLambda op) tmps_chunks

        sOp $ Imp.ErrorSync Imp.FenceLocal

        sComment "Save result of reduction." $
          forM_ (zip red_pes $ concat tmps_for_ops) $ \(pe, arr) ->
            copyDWIMFix (patElemName pe) [] (Var arr) [sExt64 chunk_start]

    -- Multiple segments, virtualised: turned into a segmented scan.
    virtCase dims' tmps_for_ops = do
      dims_flat <- dPrimV "dims_flat" $ product dims'
      let segment_size = last dims'
          crossesSegment from to =
            (sExt64 to - sExt64 from) .>. (sExt64 to `rem` sExt64 segment_size)

      forM_ (zip ops tmps_for_ops) $ \(op, tmps) -> do
        tmps_flat <- mapM (flattenArray (length dims') dims_flat) tmps
        virtualisedBlockScan
          (Just crossesSegment)
          (sExt32 $ tvExp dims_flat)
          (segBinOpLambda op)
          tmps_flat

      sOp $ Imp.ErrorSync Imp.FenceLocal

      sComment "Save result of reduction." $
        forM_ (zip red_pes $ concat tmps_for_ops) $ \(pe, arr) ->
          copyDWIM
            (patElemName pe)
            []
            (Var arr)
            (map (unitSlice 0) (init dims') ++ [DimFix $ last dims' - 1])

      sOp $ Imp.Barrier Imp.FenceLocal

    -- Nonsegmented case (or rather, a single segment) - this we can
    -- handle directly with a block-level reduction.
    nonvirtCase [dim'] tmps_for_ops = do
      forM_ (zip ops tmps_for_ops) $ \(op, tmps) ->
        blockReduce (sExt32 dim') (segBinOpLambda op) tmps
      sOp $ Imp.ErrorSync Imp.FenceLocal
      sComment "Save result of reduction." $
        forM_ (zip red_pes $ concat tmps_for_ops) $ \(pe, arr) ->
          copyDWIMFix (patElemName pe) [] (Var arr) [0]
      sOp $ Imp.ErrorSync Imp.FenceLocal

    -- Segmented intra-block reductions are turned into (regular)
    -- segmented scans.  It is possible that this can be done
    -- better, but at least this approach is simple.
    nonvirtCase dims' tmps_for_ops = do
      -- blockScan operates on flattened arrays.  This does not
      -- involve copying anything; merely playing with the index
      -- function.
      dims_flat <- dPrimV "dims_flat" $ product dims'
      let segment_size = last dims'
          crossesSegment from to =
            (sExt64 to - sExt64 from) .>. (sExt64 to `rem` sExt64 segment_size)

      forM_ (zip ops tmps_for_ops) $ \(op, tmps) -> do
        tmps_flat <- mapM (flattenArray (length dims') dims_flat) tmps
        blockScan
          (Just crossesSegment)
          (product dims')
          (product dims')
          (segBinOpLambda op)
          tmps_flat

      sOp $ Imp.ErrorSync Imp.FenceLocal

      sComment "Save result of reduction." $
        forM_ (zip red_pes $ concat tmps_for_ops) $ \(pe, arr) ->
          copyDWIM
            (patElemName pe)
            []
            (Var arr)
            (map (unitSlice 0) (init dims') ++ [DimFix $ last dims' - 1])

      sOp $ Imp.Barrier Imp.FenceLocal
compileBlockOp pat (Inner (SegOp (SegHist lvl space ops _ kbody))) = do
  compileFlatId space
  let (ltids, dims) = unzip $ unSegSpace space

  -- We don't need the red_pes, because it is guaranteed by our type
  -- rules that they occupy the same memory as the destinations for
  -- the ops.
  let num_red_res = length ops + sum (map (length . histNeutral) ops)
      (_red_pes, map_pes) =
        splitAt num_red_res $ patElems pat

  tblock_size <- kernelBlockSizeCount . kernelConstants <$> askEnv
  ops' <- prepareIntraBlockSegHist (Shape $ init dims) tblock_size ops

  -- Ensure that all locks have been initialised.
  sOp $ Imp.Barrier Imp.FenceLocal

  blockCoverSegSpace (segVirt lvl) space $
    compileStms mempty (kernelBodyStms kbody) $ do
      let (red_res, map_res) = splitAt num_red_res $ kernelBodyResult kbody
          (red_is, red_vs) = splitAt (length ops) $ map kernelResultSubExp red_res
      zipWithM_ (compileThreadResult space) map_pes map_res

      let vs_per_op = chunks (map (length . histDest) ops) red_vs

      forM_ (zip4 red_is vs_per_op ops' ops) $
        \(bin, op_vs, do_op, HistOp dest_shape _ _ _ shape lam) -> do
          let bin' = pe64 bin
              dest_shape' = map pe64 $ shapeDims dest_shape
              bin_in_bounds = inBounds (Slice [DimFix bin']) dest_shape'
              bin_is = map Imp.le64 (init ltids) ++ [bin']
              vs_params = takeLast (length op_vs) $ lambdaParams lam

          sComment "perform atomic updates" $
            sWhen bin_in_bounds $ do
              dLParams $ lambdaParams lam
              sLoopNest shape $ \is -> do
                forM_ (zip vs_params op_vs) $ \(p, v) ->
                  copyDWIMFix (paramName p) [] v is
                do_op (bin_is ++ is)

  sOp $ Imp.ErrorSync Imp.FenceLocal
compileBlockOp pat _ =
  compilerBugS $ "compileBlockOp: cannot compile rhs of binding " ++ prettyString pat

-- | Operations for compiling block-level kernel bodies.
blockOperations :: Operations GPUMem KernelEnv Imp.KernelOp
blockOperations =
  (defaultOperations compileBlockOp)
    { opsCopyCompiler = copyInBlock,
      opsExpCompiler = compileBlockExp,
      opsStmsCompiler = \_ -> defCompileStms mempty,
      opsAllocCompilers =
        M.fromList [(Space "shared", allocLocal)]
    }

-- | Is this subexpression an array whose memory block is in shared
-- memory?
arrayInSharedMemory :: SubExp -> InKernelGen Bool
arrayInSharedMemory (Var name) = do
  res <- lookupVar name
  case res of
    ArrayVar _ entry ->
      (Space "shared" ==) . entryMemSpace
        <$> lookupMemory (memLocName (entryArrayLoc entry))
    _ -> pure False
arrayInSharedMemory Constant {} = pure False

-- | 'sKernel' with block-level operations.  The bound variable is the
-- sign-extended block ID.
sKernelBlock ::
  String ->
  VName ->
  KernelAttrs ->
  InKernelGen () ->
  CallKernelGen ()
sKernelBlock = sKernel blockOperations $ sExt64 . kernelBlockId

-- | Store the result of a block-level kernel body into the given
-- pattern element.
compileBlockResult ::
  SegSpace ->
  PatElem LetDecMem ->
  KernelResult ->
  InKernelGen ()
compileBlockResult _ pe (TileReturns _ [(w, per_block_elems)] what) = do
  n <- pe64 . arraySize 0 <$> lookupType what

  constants <- kernelConstants <$> askEnv
  let ltid = sExt64 $ kernelLocalThreadId constants
      offset =
        pe64 per_block_elems
          * sExt64 (kernelBlockId constants)

  -- Avoid loop for the common case where each thread is statically
  -- known to write at most one element.
  localOps threadOperations $
    if pe64 per_block_elems == kernelBlockSize constants
      then
        sWhen (ltid + offset .<. pe64 w) $
          copyDWIMFix (patElemName pe) [ltid + offset] (Var what) [ltid]
      else sFor "i" (n `divUp` kernelBlockSize constants) $ \i -> do
        j <- dPrimVE "j" $ kernelBlockSize constants * i + ltid
        sWhen (j + offset .<. pe64 w) $
          copyDWIMFix (patElemName pe) [j + offset] (Var what) [j]
compileBlockResult space pe (TileReturns _ dims what) = do
  let gids = map fst $ unSegSpace space
      out_tile_sizes = map (pe64 . snd) dims
      block_is = zipWith (*) (map Imp.le64 gids) out_tile_sizes
  local_is <- localThreadIDs $ map snd dims
  is_for_thread <-
    mapM (dPrimV "thread_out_index") $
      zipWith (+) block_is local_is

  localOps threadOperations $
    sWhen (isActive $ zip (map tvVar is_for_thread) $ map fst dims) $
      copyDWIMFix (patElemName pe) (map tvExp is_for_thread) (Var what) local_is
compileBlockResult space pe (RegTileReturns _ dims_n_tiles what) = do
  constants <- kernelConstants <$> askEnv

  let gids = map fst $ unSegSpace space
      (dims, block_tiles, reg_tiles) = unzip3 dims_n_tiles
      block_tiles' = map pe64 block_tiles
      reg_tiles' = map pe64 reg_tiles

  -- Which block tile is this block responsible for?
  let block_tile_is = map Imp.le64 gids

  -- Within the block tile, which register tile is this thread
  -- responsible for?
  reg_tile_is <-
    dIndexSpace' "reg_tile_i" block_tiles' $ sExt64 $ kernelLocalThreadId constants

  -- Compute output array slice for the register tile belonging to
  -- this thread.
  let regTileSliceDim (block_tile, block_tile_i) (reg_tile, reg_tile_i) = do
        tile_dim_start <-
          dPrimVE "tile_dim_start" $
            reg_tile * (block_tile * block_tile_i + reg_tile_i)
        pure $ DimSlice tile_dim_start reg_tile 1
  reg_tile_slices <-
    Slice
      <$> zipWithM
        regTileSliceDim
        (zip block_tiles' block_tile_is)
        (zip reg_tiles' reg_tile_is)

  localOps threadOperations $
    sLoopNest (Shape reg_tiles) $ \is_in_reg_tile -> do
      let dest_is = fixSlice reg_tile_slices is_in_reg_tile
          src_is = reg_tile_is ++ is_in_reg_tile
      sWhen (foldl1 (.&&.) $ zipWith (.<.) dest_is $ map pe64 dims) $
        copyDWIMFix (patElemName pe) dest_is (Var what) src_is
compileBlockResult space pe (Returns _ _ what) = do
  constants <- kernelConstants <$> askEnv
  in_shared_memory <- arrayInSharedMemory what
  let gids = map (Imp.le64 . fst) $ unSegSpace space

  if not in_shared_memory
    then
      localOps threadOperations $
        sWhen (kernelLocalThreadId constants .==. 0) $
          copyDWIMFix (patElemName pe) gids what []
    else
      -- If the result of the block is an array in shared memory, we
      -- store it by collective copying among all the threads of the
      -- block.  TODO: also do this if the array is in global memory
      -- (but this is a bit more tricky, synchronisation-wise).
      copyDWIMFix (patElemName pe) gids what []
compileBlockResult _ _ WriteReturns {} =
  compilerLimitationS "compileBlockResult: WriteReturns not handled yet."

-- | The sizes of nested iteration spaces in the kernel.
type SegOpSizes = S.Set [SubExp]

-- | Various useful precomputed information for block-level SegOps.
data Precomputed = Precomputed
  { pcSegOpSizes :: SegOpSizes,
    pcChunkItersMap :: M.Map [SubExp] (Imp.TExp Int32)
  }

-- | Find the sizes of nested parallelism in a t'SegOp' body.
--
-- Each element of the resulting set is one list of dimension sizes.
-- The statements inspected, and what each of them contributes, are:
--
-- * A 'SegOp': the sizes of its segment space.  For a
--   'SegNoVirtFull' level, only the dimensions in the second
--   component returned by 'partitionSeqDims' are kept
--   (NOTE(review): presumably the non-sequential dimensions -- confirm).
--
-- * A single-result 'Replicate', 'Iota' or 'Manifest': the shape of
--   the array it produces.
--
-- * 'Match' and 'Loop': their bodies are searched recursively.
--
-- The bodies of the 'SegOp's themselves are not searched.
segOpSizes :: Stms GPUMem -> SegOpSizes
segOpSizes = onStms
  where
    onStms = foldMap onStm

    onStm (Let _ _ (Op (Inner (SegOp op)))) =
      case segVirt $ segLevel op of
        SegNoVirtFull seq_dims ->
          S.singleton $ map snd $ snd $ partitionSeqDims seq_dims $ segSpace op
        _ -> S.singleton $ map snd $ unSegSpace $ segSpace op
    onStm (Let (Pat [pe]) _ (BasicOp (Replicate {}))) =
      S.singleton $ arrayDims $ patElemType pe
    onStm (Let (Pat [pe]) _ (BasicOp (Iota {}))) =
      S.singleton $ arrayDims $ patElemType pe
    onStm (Let (Pat [pe]) _ (BasicOp (Manifest {}))) =
      S.singleton $ arrayDims $ patElemType pe
    -- Recurse into every branch of a 'Match', including the default.
    onStm (Let _ _ (Match _ cases defbody _)) =
      foldMap (onStms . bodyStms . caseBody) cases <> onStms (bodyStms defbody)
    onStm (Let _ _ (Loop _ _ body)) =
      onStms (bodyStms body)
    onStm _ = mempty

-- | Precompute various constants and useful information.
--
-- For every size list returned by 'segOpSizes', this emits a 32-bit
-- variable @num_chunks@ holding the product of the sizes divided
-- (rounding up) by the thread block size, i.e. how many block-sized
-- chunks are needed to cover that iteration space.
precomputeConstants ::
  Count BlockSize (Imp.TExp Int64) ->
  Stms GPUMem ->
  CallKernelGen Precomputed
precomputeConstants tblock_size stms = do
  let sizes = segOpSizes stms
  iters_map <- M.fromList <$> mapM mkMap (S.toList sizes)
  pure $ Precomputed sizes iters_map
  where
    mkMap dims = do
      let n = product $ map Imp.pe64 dims
      num_chunks <- dPrimVE "num_chunks" $ sExt32 $ n `divUp` unCount tblock_size
      pure (dims, num_chunks)

-- | Make use of various precomputed constants.
--
-- Before running the given action, the local thread ID is decomposed
-- into a multi-dimensional index for every size list in
-- 'pcSegOpSizes' (using 'dIndexSpace'').  The resulting 32-bit
-- indexes, and the chunk counts in 'pcChunkItersMap', then overwrite
-- 'kernelLocalIdMap' and 'kernelChunkItersMap' in the kernel
-- constants seen by the action.
precomputedConstants :: Precomputed -> InKernelGen a -> InKernelGen a
precomputedConstants pre m = do
  ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
  new_ids <- M.fromList <$> mapM (mkMap ltid) (S.toList (pcSegOpSizes pre))
  -- Record update: these two fields are replaced, not merged.
  let f env =
        env
          { kernelConstants =
              (kernelConstants env)
                { kernelLocalIdMap = new_ids,
                  kernelChunkItersMap = pcChunkItersMap pre
                }
          }
  localEnv f m
  where
    -- Unflatten the local thread ID into the index space 'dims'.
    mkMap ltid dims = do
      let dims' = map pe64 dims
      ids' <- dIndexSpace' "ltid_pre" dims' (sExt64 ltid)
      pure (dims, map sExt32 ids')
futhark-0.25.27/src/Futhark/CodeGen/ImpGen/GPU/SegHist.hs000066400000000000000000001265271475065116200225720ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-}

-- | Our compilation strategy for 'SegHist' is based around avoiding
-- bin conflicts. We do this by splitting the input into chunks, and
-- for each chunk computing a single subhistogram. Then we combine
-- the subhistograms using an ordinary segmented reduction ('SegRed').
--
-- There are some branches around to efficiently handle the case where
-- we use only a single subhistogram (because it's large), so that we
-- respect the asymptotics, and do not copy the destination array.
--
-- We also use a heuristic strategy for computing subhistograms in
-- shared memory when possible. Given:
--
-- H: total size of histograms in bytes, including any lock arrays.
--
-- G: block size
--
-- T: number of bytes of shared memory each thread can be given without
-- impacting occupancy (determined experimentally, e.g. 32).
--
-- LMAX: maximum amount of shared memory per threadblock (hard limit).
--
-- We wish to compute:
--
-- COOP: cooperation level (number of threads per subhistogram)
--
-- LH: number of shared memory subhistograms
--
-- We do this as:
--
-- COOP = ceil(H / T)
-- LH = ceil((G*T)/H)
-- if COOP <= G && H <= LMAX then
--   use shared memory
-- else
--   use global memory
module Futhark.CodeGen.ImpGen.GPU.SegHist (compileSegHist) where

import Control.Monad
import Data.List qualified as L
import Data.Map qualified as M
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.SegRed (compileSegRed')
import Futhark.Construct (fullSliceNum)
import Futhark.IR.GPUMem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.Pass.ExplicitAllocations ()
import Futhark.Transform.Substitute
import Futhark.Util (chunks, mapAccumLM, maxinum, splitFromEnd, takeLast)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)

-- | Bookkeeping for the subhistogram array of one histogram destination.
data SubhistosInfo = SubhistosInfo
  { -- | The array holding all subhistograms for the destination.
    subhistosArray :: VName,
    -- | Action that sets up the memory of 'subhistosArray'.  It is
    -- kept as an action so that it can be run once the number of
    -- subhistograms is known.
    subhistosAlloc :: CallKernelGen ()
  }

-- | Everything needed to compute one (possibly multi-destination)
-- histogram operator.
data SegHistSlug = SegHistSlug
  { -- | The histogram operator itself.
    slugOp :: HistOp GPUMem,
    -- | Number of subhistograms; assigned by whichever compilation
    -- strategy (global or shared memory) is chosen.
    slugNumSubhistos :: TV Int64,
    -- | One entry per destination array of 'slugOp'.
    slugSubhistos :: [SubhistosInfo],
    -- | How updates to a bin are applied (plain atomic, CAS, or locking).
    slugAtomicUpdate :: AtomicUpdate GPUMem KernelEnv
  }

-- | Number of bytes occupied by one complete histogram of this
-- operator: the sum, over every operator result type, of that type
-- replicated over the histogram shape followed by the operator shape.
histSpaceUsage ::
  HistOp GPUMem ->
  Imp.Count Imp.Bytes (Imp.TExp Int64)
histSpaceUsage op =
  sum
    . map
      ( typeSize
          . (`arrayOfShape` (histShape op <> histOpShape op))
      )
    $ lambdaReturnType
    $ histOp op

-- | Total number of bins: the product of the histogram shape.
histSize :: HistOp GPUMem -> Imp.TExp Int64
histSize = product . map pe64 . shapeDims . histShape

-- | Number of dimensions of the histogram shape.
histRank :: HistOp GPUMem -> Int
histRank = shapeRank . histShape

-- | Figure out how much memory is needed per histogram, both
-- segmented and unsegmented, and compute some other auxiliary
-- information.
--
-- Returns the number of bytes used by one histogram of this operator,
-- that quantity multiplied by the number of segments, and a
-- 'SegHistSlug' bundling the operator with its subhistogram
-- bookkeeping and the atomic update strategy selected from the host
-- atomics configuration.
computeHistoUsage ::
  SegSpace ->
  HistOp GPUMem ->
  CallKernelGen
    ( Imp.Count Imp.Bytes (Imp.TExp Int64),
      Imp.Count Imp.Bytes (Imp.TExp Int64),
      SegHistSlug
    )
computeHistoUsage space op = do
  -- All but the last dimension of the segment space are segment
  -- dimensions; the last one is the per-segment input dimension.
  let segment_dims = init $ unSegSpace space
      num_segments = length segment_dims

  -- Create names for the intermediate array memory blocks,
  -- memory block sizes, arrays, and number of subhistograms.
  -- 'num_subhistos' is only declared here; it is assigned later by
  -- the chosen compilation strategy.
  num_subhistos <- dPrim "num_subhistos"
  subhisto_infos <- forM (zip (histDest op) (histNeutral op)) $ \(dest, ne) -> do
    dest_t <- lookupType dest
    dest_mem <- entryArrayLoc <$> lookupArray dest

    subhistos_mem <-
      sDeclareMem (baseString dest ++ "_subhistos_mem") (Space "device")

    -- Layout: segment dimensions, then one entry per subhistogram,
    -- then the non-segment dimensions of the destination.
    let subhistos_shape =
          Shape (map snd segment_dims ++ [tvSize num_subhistos])
            <> stripDims num_segments (arrayShape dest_t)
    subhistos <-
      sArray
        (baseString dest ++ "_subhistos")
        (elemType dest_t)
        subhistos_shape
        subhistos_mem
        $ LMAD.iota 0
        $ map pe64
        $ shapeDims subhistos_shape

    -- The setup is returned as an action rather than run here.
    pure $
      SubhistosInfo subhistos $ do
        -- A single subhistogram simply aliases the destination memory,
        -- avoiding any copy.
        let unitHistoCase =
              emit $
                Imp.SetMem subhistos_mem (memLocName dest_mem) $
                  Space "device"

            -- Otherwise allocate fresh memory, fill it with the neutral
            -- element, and copy the prior destination contents into
            -- subhistogram 0 of every segment.
            multiHistoCase = do
              let num_elems =
                    product $ map pe64 $ shapeDims subhistos_shape
                  subhistos_mem_size =
                    Imp.bytes $
                      Imp.unCount (Imp.elements num_elems `Imp.withElemType` elemType dest_t)

              sAlloc_ subhistos_mem subhistos_mem_size $ Space "device"
              sReplicate subhistos ne
              subhistos_t <- lookupType subhistos
              let slice =
                    fullSliceNum (map pe64 $ arrayDims subhistos_t) $
                      map (unitSlice 0 . pe64 . snd) segment_dims
                        ++ [DimFix 0]
              sUpdate subhistos slice $ Var dest

        sIf (tvExp num_subhistos .==. 1) unitHistoCase multiHistoCase

  let h = histSpaceUsage op
      segmented_h = h * product (map (Imp.bytes . pe64) $ init $ segSpaceDims space)

  atomics <- hostAtomics <$> askEnv

  pure
    ( h,
      segmented_h,
      SegHistSlug op num_subhistos subhisto_infos $
        atomicUpdateLocking atomics $
          histOp op
    )

-- | Produce the global-memory update function for a histogram
-- operator.  If the operator requires locking and no lock array has
-- been created yet (the first argument is 'Nothing'), a fresh
-- zero-initialised lock array is created and returned, so that callers
-- threading the 'Maybe Locking' through several operators reuse it.
prepareAtomicUpdateGlobal ::
  Maybe Locking ->
  Shape ->
  [VName] ->
  SegHistSlug ->
  CallKernelGen
    ( Maybe Locking,
      [Imp.TExp Int64] -> InKernelGen ()
    )
prepareAtomicUpdateGlobal l segments dests slug =
  -- We need a separate lock array if the operators are not all of a
  -- particularly simple form that permits pure atomic operations.
  case (l, slugAtomicUpdate slug) of
    (_, AtomicPrim f) -> pure (l, f (Space "global") dests)
    (_, AtomicCAS f) -> pure (l, f (Space "global") dests)
    (Just l', AtomicLocking f) -> pure (l, f l' (Space "global") dests)
    (Nothing, AtomicLocking f) -> do
      -- The number of locks used here is too low, but since we are
      -- currently forced to inline a huge list, I'm keeping it down
      -- for now. Some quick experiments suggested that it has little
      -- impact anyway (maybe the locking case is just too slow).
      --
      -- A fun solution would also be to use a simple hashing
      -- algorithm to ensure good distribution of locks.
      let num_locks = 100151
          -- The lock for a bin is found by flattening the bin's index
          -- and taking the remainder modulo 'num_locks'.  The
          -- dimensions must be listed in the same order as the index
          -- handed to the update function (and as the subhistogram
          -- array layout): segment indexes, the subhistogram index,
          -- the bin index, and finally the 'histOpShape' indexes.
          -- (This only differs from the previous ordering when
          -- 'histOpShape' is non-empty.)
          dims =
            map pe64 $
              shapeDims segments
                ++ [tvSize (slugNumSubhistos slug)]
                ++ shapeDims (histShape (slugOp slug))
                ++ shapeDims (histOpShape (slugOp slug))

      locks <- genZeroes "hist_locks" num_locks

      let l' = Locking locks 0 1 0 (pure . (`rem` fromIntegral num_locks) . flattenIndex dims)
      pure (Just l', f l' (Space "global") dests)

-- | Some kernel bodies are not safe (or efficient) to execute
-- multiple times.
data Passage = MustBeSinglePass | MayBeMultiPass deriving (Eq, Ord) bodyPassage :: KernelBody GPUMem -> Passage bodyPassage kbody | mempty == consumedInKernelBody (aliasAnalyseKernelBody mempty kbody) = MayBeMultiPass | otherwise = MustBeSinglePass prepareIntermediateArraysGlobal :: Passage -> Shape -> Imp.TExp Int32 -> Imp.TExp Int64 -> [SegHistSlug] -> CallKernelGen ( Imp.TExp Int32, [[Imp.TExp Int64] -> InKernelGen ()] ) prepareIntermediateArraysGlobal passage segments hist_T hist_N slugs = do -- The paper formulae assume there is only one histogram, but in our -- implementation there can be multiple that have been horisontally -- fused. We do a bit of trickery with summings and averages to -- pretend there is really only one. For the case of a single -- histogram, the actual calculations should be the same as in the -- paper. -- The sum of all Hs. hist_H <- dPrimVE "hist_H" $ sum $ map (histSize . slugOp) slugs hist_RF <- dPrimVE "hist_RF" $ sum (map (r64 . pe64 . histRaceFactor . slugOp) slugs) / L.genericLength slugs hist_el_size <- dPrimVE "hist_el_size" $ sum $ map slugElAvgSize slugs hist_C_max <- dPrimVE "hist_C_max" $ fMin64 (r64 hist_T) $ r64 hist_H / hist_k_ct_min hist_M_min <- dPrimVE "hist_M_min" $ sMax32 1 $ sExt32 $ t64 $ r64 hist_T / hist_C_max -- Equivalent to F_L2*L2 in paper. hist_L2 <- getSize "hist_L2" Imp.SizeCache let hist_L2_ln_sz = 16 * 4 -- L2 cache line size approximation hist_RACE_exp <- dPrimVE "hist_RACE_exp" $ fMax64 1 $ (hist_k_RF * hist_RF) / (hist_L2_ln_sz / r64 hist_el_size) hist_S <- dPrim "hist_S" -- For sparse histograms (H exceeds N) we only want a single chunk. sIf (hist_N .<. 
hist_H) (hist_S <-- (1 :: Imp.TExp Int32)) $ hist_S <-- case passage of MayBeMultiPass -> sExt32 $ (sExt64 hist_M_min * hist_H * sExt64 hist_el_size) `divUp` t64 (hist_F_L2 * r64 (tvExp hist_L2) * hist_RACE_exp) MustBeSinglePass -> 1 emit $ Imp.DebugPrint "Race expansion factor (RACE^exp)" $ Just $ untyped hist_RACE_exp emit $ Imp.DebugPrint "Number of chunks (S)" $ Just $ untyped $ tvExp hist_S histograms <- snd <$> mapAccumLM (onOp (tvExp hist_L2) hist_M_min (tvExp hist_S) hist_RACE_exp) Nothing slugs pure (tvExp hist_S, histograms) where hist_k_ct_min = 2 -- Chosen experimentally hist_k_RF = 0.75 -- Chosen experimentally hist_F_L2 = 0.4 -- Chosen experimentally r64 = isF64 . ConvOpExp (SIToFP Int32 Float64) . untyped t64 = isInt64 . ConvOpExp (FPToSI Float64 Int64) . untyped -- "Average element size" as computed by a formula that also takes -- locking into account. slugElAvgSize slug@(SegHistSlug op _ _ do_op) = case do_op of AtomicLocking {} -> slugElSize slug `quot` (1 + L.genericLength (lambdaReturnType (histOp op))) _ -> slugElSize slug `quot` L.genericLength (lambdaReturnType (histOp op)) -- "Average element size" as computed by a formula that also takes -- locking into account. slugElSize (SegHistSlug op _ _ do_op) = sExt32 . unCount . sum $ case do_op of AtomicLocking {} -> map (typeSize . (`arrayOfShape` histOpShape op)) $ Prim int32 : lambdaReturnType (histOp op) _ -> map (typeSize . 
(`arrayOfShape` histOpShape op)) $ lambdaReturnType (histOp op) onOp hist_L2 hist_M_min hist_S hist_RACE_exp l slug = do let SegHistSlug op num_subhistos subhisto_info do_op = slug hist_H = histSize op hist_H_chk <- dPrimVE "hist_H_chk" $ hist_H `divUp` sExt64 hist_S emit $ Imp.DebugPrint "Chunk size (H_chk)" $ Just $ untyped hist_H_chk hist_k_max <- dPrimVE "hist_k_max" $ fMin64 (hist_F_L2 * (r64 hist_L2 / r64 (slugElSize slug)) * hist_RACE_exp) (r64 hist_N) / r64 hist_T hist_u <- dPrimVE "hist_u" $ case do_op of AtomicPrim {} -> 2 _ -> 1 hist_C <- dPrimVE "hist_C" $ fMin64 (r64 hist_T) $ r64 (hist_u * hist_H_chk) / hist_k_max -- Number of subhistograms per result histogram. hist_M <- dPrimVE "hist_M" $ case slugAtomicUpdate slug of AtomicPrim {} -> 1 _ -> sMax32 hist_M_min $ sExt32 $ t64 $ r64 hist_T / hist_C emit $ Imp.DebugPrint "Elements/thread in L2 cache (k_max)" $ Just $ untyped hist_k_max emit $ Imp.DebugPrint "Multiplication degree (M)" $ Just $ untyped hist_M emit $ Imp.DebugPrint "Cooperation level (C)" $ Just $ untyped hist_C -- num_subhistos is the variable we use to communicate back. num_subhistos <-- sExt64 hist_M -- Initialise sub-histograms. -- -- If hist_M is 1, then we just reuse the original -- destination. The idea is to avoid a copy if we are writing a -- small number of values into a very large prior histogram. dests <- forM (zip (histDest op) subhisto_info) $ \(dest, info) -> do dest_mem <- entryArrayLoc <$> lookupArray dest sub_mem <- fmap memLocName $ entryArrayLoc <$> lookupArray (subhistosArray info) let unitHistoCase = emit $ Imp.SetMem sub_mem (memLocName dest_mem) $ Space "device" multiHistoCase = subhistosAlloc info sIf (hist_M .==. 
1) unitHistoCase multiHistoCase pure $ subhistosArray info (l', do_op') <- prepareAtomicUpdateGlobal l segments dests slug pure (l', do_op') histKernelGlobalPass :: [PatElem LetDecMem] -> Count NumBlocks SubExp -> Count BlockSize SubExp -> SegSpace -> [SegHistSlug] -> KernelBody GPUMem -> [[Imp.TExp Int64] -> InKernelGen ()] -> Imp.TExp Int32 -> Imp.TExp Int32 -> CallKernelGen () histKernelGlobalPass map_pes num_tblocks tblock_size space slugs kbody histograms hist_S chk_i = do let (space_is, space_sizes) = unzip $ unSegSpace space space_sizes_64 = map (sExt64 . pe64) space_sizes total_w_64 = product space_sizes_64 hist_H_chks <- forM (map (histSize . slugOp) slugs) $ \w -> dPrimVE "hist_H_chk" $ w `divUp` sExt64 hist_S sKernelThread "seghist_global" (segFlat space) (defKernelAttrs num_tblocks tblock_size) $ do constants <- kernelConstants <$> askEnv -- Compute subhistogram index for each thread, per histogram. subhisto_inds <- forM slugs $ \slug -> dPrimVE "subhisto_ind" $ sExt32 (kernelGlobalThreadId constants) `quot` ( kernelNumThreads constants `divUp` sExt32 (tvExp (slugNumSubhistos slug)) ) -- Loop over flat offsets into the input and output. The -- calculation is done with 64-bit integers to avoid overflow, -- but the final unflattened segment indexes are 32 bit. let gtid = sExt64 $ kernelGlobalThreadId constants num_threads = sExt64 $ kernelNumThreads constants kernelLoop gtid num_threads total_w_64 $ \offset -> do -- Construct segment indices. dIndexSpace (zip space_is space_sizes_64) offset -- We execute the bucket function once and update each histogram serially. -- We apply the bucket function if j=offset+ltid is less than -- num_elements. This also involves writing to the mapout -- arrays. let input_in_bounds = offset .<. 
total_w_64 sWhen input_in_bounds $ compileStms mempty (kernelBodyStms kbody) $ do let (red_res, map_res) = splitFromEnd (length map_pes) $ kernelBodyResult kbody sComment "save map-out results" $ forM_ (zip map_pes map_res) $ \(pe, res) -> copyDWIMFix (patElemName pe) (map (Imp.le64 . fst) $ unSegSpace space) (kernelResultSubExp res) [] let red_res_split = splitHistResults (map slugOp slugs) $ map kernelResultSubExp red_res sComment "perform atomic updates" $ forM_ (L.zip5 (map slugOp slugs) histograms red_res_split subhisto_inds hist_H_chks) $ \( HistOp dest_shape _ _ _ shape lam, do_op, (bucket, vs'), subhisto_ind, hist_H_chk ) -> do let chk_beg = sExt64 chk_i * hist_H_chk bucket' = map pe64 bucket dest_shape' = map pe64 $ shapeDims dest_shape flat_bucket = flattenIndex dest_shape' bucket' bucket_in_bounds = chk_beg .<=. flat_bucket .&&. flat_bucket .<. (chk_beg + hist_H_chk) .&&. inBounds (Slice (map DimFix bucket')) dest_shape' vs_params = takeLast (length vs') $ lambdaParams lam sWhen bucket_in_bounds $ do let bucket_is = map Imp.le64 (init space_is) ++ [sExt64 subhisto_ind] ++ unflattenIndex dest_shape' flat_bucket dLParams $ lambdaParams lam sLoopNest shape $ \is -> do forM_ (zip vs_params vs') $ \(p, res) -> copyDWIMFix (paramName p) [] res is do_op (bucket_is ++ is) histKernelGlobal :: [PatElem LetDecMem] -> Count NumBlocks SubExp -> Count BlockSize SubExp -> SegSpace -> [SegHistSlug] -> KernelBody GPUMem -> CallKernelGen () histKernelGlobal map_pes num_tblocks tblock_size space slugs kbody = do let num_tblocks' = fmap pe64 num_tblocks tblock_size' = fmap pe64 tblock_size let (_space_is, space_sizes) = unzip $ unSegSpace space num_threads = sExt32 $ unCount num_tblocks' * unCount tblock_size' emit $ Imp.DebugPrint "## Using global memory" Nothing (hist_S, histograms) <- prepareIntermediateArraysGlobal (bodyPassage kbody) (Shape (init space_sizes)) num_threads (pe64 $ last space_sizes) slugs sFor "chk_i" hist_S $ \chk_i -> histKernelGlobalPass map_pes 
num_tblocks tblock_size space slugs kbody histograms hist_S chk_i type InitLocalHistograms = [ ( [VName], SubExp -> InKernelGen ( [VName], [Imp.TExp Int64] -> InKernelGen () ) ) ] prepareIntermediateArraysLocal :: TV Int32 -> Count NumBlocks (Imp.TExp Int64) -> [SegHistSlug] -> CallKernelGen InitLocalHistograms prepareIntermediateArraysLocal num_subhistos_per_block blocks_per_segment = mapM onOp where onOp (SegHistSlug op num_subhistos subhisto_info do_op) = do num_subhistos <-- sExt64 (unCount blocks_per_segment) emit $ Imp.DebugPrint "Number of subhistograms in global memory per segment" $ Just $ untyped $ tvExp num_subhistos mk_op <- case do_op of AtomicPrim f -> pure $ const $ pure f AtomicCAS f -> pure $ const $ pure f AtomicLocking f -> pure $ \hist_H_chk -> do let lock_shape = Shape [tvSize num_subhistos_per_block, hist_H_chk] let dims = map pe64 $ shapeDims lock_shape locks <- sAllocArray "locks" int32 lock_shape $ Space "shared" sComment "All locks start out unlocked" $ blockCoverSpace dims $ \is -> copyDWIMFix locks is (intConst Int32 0) [] pure $ f $ Locking locks 0 1 0 id -- Initialise local-memory sub-histograms. These are -- represented as two-dimensional arrays. let init_local_subhistos hist_H_chk = do local_subhistos <- forM (histType op) $ \t -> do let subhisto_shape = setOuterDims (arrayShape t) (histRank op) (Shape [hist_H_chk]) sAllocArray "subhistogram_local" (elemType t) (Shape [tvSize num_subhistos_per_block] <> subhisto_shape) (Space "shared") do_op' <- mk_op hist_H_chk pure (local_subhistos, do_op' (Space "shared") local_subhistos) -- Initialise global-memory sub-histograms. 
glob_subhistos <- forM subhisto_info $ \info -> do subhistosAlloc info pure $ subhistosArray info pure (glob_subhistos, init_local_subhistos) histKernelLocalPass :: TV Int32 -> Count NumBlocks (Imp.TExp Int64) -> [PatElem LetDecMem] -> Count NumBlocks SubExp -> Count BlockSize SubExp -> SegSpace -> [SegHistSlug] -> KernelBody GPUMem -> InitLocalHistograms -> Imp.TExp Int32 -> Imp.TExp Int32 -> CallKernelGen () histKernelLocalPass num_subhistos_per_block_var blocks_per_segment map_pes num_tblocks tblock_size space slugs kbody init_histograms hist_S chk_i = do let (space_is, space_sizes) = unzip $ unSegSpace space segment_is = init space_is segment_dims = init space_sizes (i_in_segment, segment_size) = last $ unSegSpace space num_subhistos_per_block = tvExp num_subhistos_per_block_var segment_size' = pe64 segment_size num_segments <- dPrimVE "num_segments" $ product $ map pe64 segment_dims hist_H_chks <- forM (map slugOp slugs) $ \op -> dPrimV "hist_H_chk" $ histSize op `divUp` sExt64 hist_S histo_sizes <- forM (zip slugs hist_H_chks) $ \(slug, hist_H_chk) -> do let histo_dims = tvExp hist_H_chk : map pe64 (shapeDims (histOpShape (slugOp slug))) histo_size <- dPrimVE "histo_size" $ product histo_dims let block_hists_size = sExt64 num_subhistos_per_block * histo_size init_per_thread <- dPrimVE "init_per_thread" $ sExt32 $ block_hists_size `divUp` pe64 (unCount tblock_size) pure (histo_dims, histo_size, init_per_thread) let attrs = (defKernelAttrs num_tblocks tblock_size) {kAttrCheckSharedMemory = False} sKernelThread "seghist_local" (segFlat space) attrs $ virtualiseBlocks SegVirt (sExt32 $ unCount blocks_per_segment * num_segments) $ \tblock_id -> do constants <- kernelConstants <$> askEnv flat_segment_id <- dPrimVE "flat_segment_id" $ tblock_id `quot` sExt32 (unCount blocks_per_segment) gid_in_segment <- dPrimVE "gid_in_segment" $ tblock_id `rem` sExt32 (unCount blocks_per_segment) -- This pgtid is kind of a "virtualised physical" gtid - not the -- same thing as the 
gtid used for the SegHist itself. pgtid_in_segment <- dPrimVE "pgtid_in_segment" $ gid_in_segment * sExt32 (kernelBlockSize constants) + kernelLocalThreadId constants threads_per_segment <- dPrimVE "threads_per_segment" $ sExt32 $ unCount blocks_per_segment * kernelBlockSize constants -- Set segment indices. zipWithM_ dPrimV_ segment_is $ unflattenIndex (map pe64 segment_dims) $ sExt64 flat_segment_id histograms <- forM (zip init_histograms hist_H_chks) $ \((glob_subhistos, init_local_subhistos), hist_H_chk) -> do (local_subhistos, do_op) <- init_local_subhistos $ Var $ tvVar hist_H_chk pure (zip glob_subhistos local_subhistos, hist_H_chk, do_op) -- Find index of local subhistograms updated by this thread. We -- try to ensure, as much as possible, that threads in the same -- warp use different subhistograms, to avoid conflicts. thread_local_subhisto_i <- dPrimVE "thread_local_subhisto_i" $ kernelLocalThreadId constants `rem` num_subhistos_per_block let onSlugs f = forM_ (zip3 slugs histograms histo_sizes) $ \(slug, (dests, hist_H_chk, _), (histo_dims, histo_size, init_per_thread)) -> f slug dests (tvExp hist_H_chk) histo_dims histo_size init_per_thread let onAllHistograms f = onSlugs $ \slug dests hist_H_chk histo_dims histo_size init_per_thread -> do let block_hists_size = num_subhistos_per_block * sExt32 histo_size forM_ (zip dests (histNeutral $ slugOp slug)) $ \((dest_global, dest_local), ne) -> sFor "local_i" init_per_thread $ \i -> do j <- dPrimVE "j" $ i * sExt32 (kernelBlockSize constants) + kernelLocalThreadId constants j_offset <- dPrimVE "j_offset" $ num_subhistos_per_block * sExt32 histo_size * gid_in_segment + j local_subhisto_i <- dPrimVE "local_subhisto_i" $ j `quot` sExt32 histo_size let local_bucket_is = unflattenIndex histo_dims $ sExt64 $ j `rem` sExt32 histo_size nested_hist_size = map pe64 $ shapeDims $ histShape $ slugOp slug global_bucket_is = unflattenIndex nested_hist_size (head local_bucket_is + sExt64 chk_i * hist_H_chk) ++ tail 
local_bucket_is global_subhisto_i <- dPrimVE "global_subhisto_i" $ j_offset `quot` sExt32 histo_size sWhen (j .<. block_hists_size) $ f dest_local dest_global (slugOp slug) ne local_subhisto_i global_subhisto_i local_bucket_is global_bucket_is sComment "initialize histograms in shared memory" $ onAllHistograms $ \dest_local dest_global op ne local_subhisto_i global_subhisto_i local_bucket_is global_bucket_is -> sComment "First subhistogram is initialised from global memory; others with neutral element." $ do dest_global_shape <- map pe64 . arrayDims <$> lookupType dest_global let global_is = map Imp.le64 segment_is ++ [0] ++ global_bucket_is local_is = sExt64 local_subhisto_i : local_bucket_is global_in_bounds = inBounds (Slice (map DimFix global_is)) dest_global_shape sIf (global_subhisto_i .==. 0 .&&. global_in_bounds) (copyDWIMFix dest_local local_is (Var dest_global) global_is) ( sLoopNest (histOpShape op) $ \is -> copyDWIMFix dest_local (local_is ++ is) ne [] ) sOp $ Imp.Barrier Imp.FenceLocal kernelLoop (sExt64 pgtid_in_segment) (sExt64 threads_per_segment) segment_size' $ \ie -> do dPrimV_ i_in_segment ie -- We execute the bucket function once and update each histogram -- serially. This also involves writing to the mapout arrays if -- this is the first chunk. compileStms mempty (kernelBodyStms kbody) $ do let (red_res, map_res) = splitFromEnd (length map_pes) $ map kernelResultSubExp $ kernelBodyResult kbody sWhen (chk_i .==. 
0) $ sComment "save map-out results" $ forM_ (zip map_pes map_res) $ \(pe, se) -> copyDWIMFix (patElemName pe) (map Imp.le64 space_is) se [] let red_res_split = splitHistResults (map slugOp slugs) red_res forM_ (zip3 (map slugOp slugs) histograms red_res_split) $ \( HistOp dest_shape _ _ _ shape lam, (_, hist_H_chk, do_op), (bucket, vs') ) -> do let chk_beg = sExt64 chk_i * tvExp hist_H_chk bucket' = map pe64 bucket dest_shape' = map pe64 $ shapeDims dest_shape flat_bucket = flattenIndex dest_shape' bucket' bucket_in_bounds = inBounds (Slice (map DimFix bucket')) dest_shape' .&&. chk_beg .<=. flat_bucket .&&. flat_bucket .<. (chk_beg + tvExp hist_H_chk) bucket_is = [sExt64 thread_local_subhisto_i, flat_bucket - chk_beg] vs_params = takeLast (length vs') $ lambdaParams lam sComment "perform atomic updates" $ sWhen bucket_in_bounds $ do dLParams $ lambdaParams lam sLoopNest shape $ \is -> do forM_ (zip vs_params vs') $ \(p, v) -> copyDWIMFix (paramName p) [] v is do_op (bucket_is ++ is) sOp $ Imp.ErrorSync Imp.FenceGlobal sComment "Compact the multiple shared memory subhistograms to result in global memory" $ onSlugs $ \slug dests hist_H_chk histo_dims _histo_size bins_per_thread -> do trunc_H <- dPrimV "trunc_H" . sMin64 hist_H_chk $ histSize (slugOp slug) - sExt64 chk_i * head histo_dims let trunc_histo_dims = tvExp trunc_H : map pe64 (shapeDims (histOpShape (slugOp slug))) trunc_histo_size <- dPrimVE "histo_size" $ sExt32 $ product trunc_histo_dims sFor "local_i" bins_per_thread $ \i -> do j <- dPrimVE "j" $ i * sExt32 (kernelBlockSize constants) + kernelLocalThreadId constants sWhen (j .<. trunc_histo_size) $ do -- We are responsible for compacting the flat bin 'j', which -- we immediately unflatten. 
let local_bucket_is = unflattenIndex histo_dims $ sExt64 j nested_hist_size = map pe64 $ shapeDims $ histShape $ slugOp slug global_bucket_is = unflattenIndex nested_hist_size (head local_bucket_is + sExt64 chk_i * hist_H_chk) ++ tail local_bucket_is dLParams $ lambdaParams $ histOp $ slugOp slug let (global_dests, local_dests) = unzip dests (xparams, yparams) = splitAt (length local_dests) $ lambdaParams $ histOp $ slugOp slug sComment "Read values from subhistogram 0." $ forM_ (zip xparams local_dests) $ \(xp, subhisto) -> copyDWIMFix (paramName xp) [] (Var subhisto) (0 : local_bucket_is) sComment "Accumulate based on values in other subhistograms." $ sFor "subhisto_id" (num_subhistos_per_block - 1) $ \subhisto_id -> do forM_ (zip yparams local_dests) $ \(yp, subhisto) -> copyDWIMFix (paramName yp) [] (Var subhisto) (sExt64 subhisto_id + 1 : local_bucket_is) compileBody' xparams $ lambdaBody $ histOp $ slugOp slug sComment "Put final bucket value in global memory." $ do let global_is = map Imp.le64 segment_is ++ [sExt64 tblock_id `rem` unCount blocks_per_segment] ++ global_bucket_is forM_ (zip xparams global_dests) $ \(xp, global_dest) -> copyDWIMFix global_dest global_is (Var $ paramName xp) [] histKernelLocal :: TV Int32 -> Count NumBlocks (Imp.TExp Int64) -> [PatElem LetDecMem] -> Count NumBlocks SubExp -> Count BlockSize SubExp -> SegSpace -> Imp.TExp Int32 -> [SegHistSlug] -> KernelBody GPUMem -> CallKernelGen () histKernelLocal num_subhistos_per_block_var blocks_per_segment map_pes num_tblocks tblock_size space hist_S slugs kbody = do let num_subhistos_per_block = tvExp num_subhistos_per_block_var emit $ Imp.DebugPrint "Number of local subhistograms per block" $ Just $ untyped num_subhistos_per_block init_histograms <- prepareIntermediateArraysLocal num_subhistos_per_block_var blocks_per_segment slugs sFor "chk_i" hist_S $ \chk_i -> histKernelLocalPass num_subhistos_per_block_var blocks_per_segment map_pes num_tblocks tblock_size space slugs kbody 
init_histograms hist_S chk_i

-- | The maximum number of passes we are willing to accept for this
-- kind of atomic update.
slugMaxLocalMemPasses :: SegHistSlug -> Int
slugMaxLocalMemPasses slug =
  case slugAtomicUpdate slug of
    AtomicPrim _ -> 3
    AtomicCAS _ -> 4
    AtomicLocking _ -> 6

-- | Decide whether a histogram can be computed with the shared-memory
-- strategy, and build the code generator that does so.
--
-- Returns @(pick_local, run)@: @pick_local@ is a run-time condition
-- that holds when the shared-memory strategy is applicable, and @run@
-- emits the shared-memory histogram code via 'histKernelLocal'.
--
-- Arguments, in order: the map-out pattern elements, the number of
-- threads (T), the segment space, the histogram size (H), the
-- histogram element size (el_size), the number of input elements per
-- histogram (N), an unused argument (the call site passes the race
-- factor here), the histogram slugs, and the kernel body.
--
-- The thread block size used is the maximum thread block size
-- ('Imp.GetSizeMax'), and the derived quantities are M (subhistograms
-- per block), C (threads cooperating on one subhistogram) and S
-- (number of passes/chunks over the histogram).
localMemoryCase ::
  [PatElem LetDecMem] ->
  Imp.TExp Int32 ->
  SegSpace ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int32 ->
  [SegHistSlug] ->
  KernelBody GPUMem ->
  CallKernelGen (Imp.TExp Bool, CallKernelGen ())
localMemoryCase map_pes hist_T space hist_H hist_el_size hist_N _ slugs kbody = do
  let space_sizes = segSpaceDims space
      segment_dims = init space_sizes
      segmented = not $ null segment_dims

  hist_L <- getSize "hist_L" Imp.SizeSharedMemory

  max_tblock_size :: TV Int64 <- dPrim "max_tblock_size"
  sOp $ Imp.GetSizeMax (tvVar max_tblock_size) Imp.SizeThreadBlock
  -- XXX: we need to record for later use that max_tblock_size is the
  -- result of GetSizeMax. This is an ugly hack that reflects our
  -- inability to track which variables are actually constants.
  let withSizeMax vtable =
        case M.lookup (tvVar max_tblock_size) vtable of
          Just (ScalarVar _ se) ->
            M.insert
              (tvVar max_tblock_size)
              (ScalarVar (Just (Op (Inner (SizeOp (GetSizeMax SizeThreadBlock))))) se)
              vtable
          _ -> vtable
  let tblock_size = Imp.Count $ Var $ tvVar max_tblock_size
  num_tblocks <-
    fmap (Imp.Count . tvSize) $
      dPrimV "num_tblocks" $
        sExt64 hist_T `divUp` pe64 (unCount tblock_size)
  let num_tblocks' = pe64 <$> num_tblocks
      tblock_size' = pe64 <$> tblock_size

  -- Conversions between integer and floating-point expressions.
  let r64 = isF64 . ConvOpExp (SIToFP Int64 Float64) . untyped
      t64 = isInt64 . ConvOpExp (FPToSI Float64 Int64) . untyped

  -- M approximation.
  hist_m' <-
    dPrimVE "hist_m_prime" $
      r64
        ( sMin64
            (sExt64 (tvExp hist_L `quot` hist_el_size))
            (hist_N `divUp` sExt64 (unCount num_tblocks'))
        )
        / r64 hist_H

  let hist_B = unCount tblock_size'

  -- M in the paper, but not adjusted for asymptotic efficiency.
  hist_M0 <- dPrimVE "hist_M0" $ sMax64 1 $ sMin64 (t64 hist_m') hist_B

  -- Minimal sequential chunking factor.
  let q_small = 2

  -- The number of segments/histograms produced.
  hist_Nout <- dPrimVE "hist_Nout" $ product $ map pe64 segment_dims

  hist_Nin <- dPrimVE "hist_Nin" $ pe64 $ last space_sizes

  -- Maximum M for work efficiency.
  work_asymp_M_max <-
    if segmented
      then do
        hist_T_hist_min <-
          dPrimVE "hist_T_hist_min" $
            sExt32 $
              sMin64 (sExt64 hist_Nin * sExt64 hist_Nout) (sExt64 hist_T)
                `divUp` sExt64 hist_Nout

        -- Number of blocks, rounded up.
        let r = hist_T_hist_min `divUp` sExt32 hist_B

        dPrimVE "work_asymp_M_max" $ hist_Nin `quot` (sExt64 r * hist_H)
      else
        dPrimVE "work_asymp_M_max" $
          (hist_Nout * hist_N)
            `quot` ( (q_small * unCount num_tblocks' * hist_H)
                       `quot` L.genericLength slugs
                   )

  -- Number of subhistograms per result histogram.
  hist_M <- dPrimV "hist_M" $ sExt32 $ sMin64 hist_M0 work_asymp_M_max

  -- hist_M may be zero (which we'll check for below), but we need it
  -- for some divisions first, so crudely make a nonzero form.
  let hist_M_nonzero = sMax32 1 $ tvExp hist_M

  -- "Cooperation factor" - the number of threads cooperatively
  -- working on the same (sub)histogram.
  hist_C <- dPrimVE "hist_C" $ hist_B `divUp` sExt64 hist_M_nonzero

  emit $ Imp.DebugPrint "local hist_M0" $ Just $ untyped hist_M0
  emit $ Imp.DebugPrint "local work asymp M max" $ Just $ untyped work_asymp_M_max
  emit $ Imp.DebugPrint "local C" $ Just $ untyped hist_C
  emit $ Imp.DebugPrint "local B" $ Just $ untyped hist_B
  emit $ Imp.DebugPrint "local M" $ Just $ untyped $ tvExp hist_M
  emit $
    Imp.DebugPrint "shared memory needed" $
      Just $
        untyped $
          hist_H * hist_el_size * sExt64 (tvExp hist_M)

  -- local_mem_needed is what we need to keep a single bucket in local
  -- memory - this is an absolute minimum.  We can fit anything else
  -- by doing multiple passes, although more than a few is
  -- (heuristically) not efficient.
  local_mem_needed <-
    dPrimVE "local_mem_needed" $
      hist_el_size * sExt64 (tvExp hist_M)

  -- We add one to the memory requirement because if the chunk
  -- otherwise *exactly* fits, it might actually *not* fit in the case
  -- of a multi-value operator, as we individually round up the sizes
  -- of the component arrays.  (Very rare edge case.)
  hist_S <-
    dPrimVE "hist_S" . sExt32 $
      (hist_H * local_mem_needed + 1) `divUp` tvExp hist_L
  -- Upper bound on the number of passes: the largest per-slug limit
  -- from 'slugMaxLocalMemPasses', or exactly one pass when the body
  -- must not be re-executed.
  let max_S = case bodyPassage kbody of
        MustBeSinglePass -> 1
        MayBeMultiPass -> fromIntegral $ maximum $ map slugMaxLocalMemPasses slugs

  blocks_per_segment <-
    if segmented
      then
        fmap Count $
          dPrimVE "blocks_per_segment" $
            unCount num_tblocks' `divUp` hist_Nout
      else pure num_tblocks'

  -- We only use shared memory if the number of updates per histogram
  -- at least matches the histogram size, as otherwise it is not
  -- asymptotically efficient.  This mostly matters for the segmented
  -- case.
  let pick_local =
        hist_Nin .>=. hist_H
          .&&. (local_mem_needed .<=. tvExp hist_L)
          .&&. (hist_S .<=. max_S)
          .&&. hist_C .<=. hist_B
          .&&. tvExp hist_M .>. 0

      run = do
        emit $ Imp.DebugPrint "## Using shared memory" Nothing
        emit $ Imp.DebugPrint "Histogram size (H)" $ Just $ untyped hist_H
        emit $ Imp.DebugPrint "Multiplication degree (M)" $ Just $ untyped $ tvExp hist_M
        emit $ Imp.DebugPrint "Cooperation level (C)" $ Just $ untyped hist_C
        emit $ Imp.DebugPrint "Number of chunks (S)" $ Just $ untyped hist_S
        when segmented $
          emit $
            Imp.DebugPrint "Blocks per segment" $
              Just $
                untyped $
                  unCount blocks_per_segment
        localVTable withSizeMax $
          histKernelLocal
            hist_M
            blocks_per_segment
            map_pes
            num_tblocks
            tblock_size
            space
            hist_S
            slugs
            kbody

  pure (pick_local, run)

-- | Generate code for a segmented histogram called from the host.
-- Nothing is generated unless @seg_h@ (the subhistogram memory over
-- all segments) is nonzero.  Otherwise the histogram kernel strategy
-- is chosen at run time: the shared-memory strategy when the
-- condition computed by 'localMemoryCase' holds, and
-- 'histKernelGlobal' otherwise.  Afterwards, for each operator, the
-- subhistograms are combined into the result: with a single
-- subhistogram its memory is reused directly for the result; with
-- several, they are reduced with a commutative 'SegRed'.
compileSegHist ::
  Pat LetDecMem ->
  SegLevel ->
  SegSpace ->
  [HistOp GPUMem] ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegHist (Pat pes) lvl space ops kbody = do
  KernelAttrs {kAttrNumBlocks = num_tblocks, kAttrBlockSize = tblock_size} <-
    lvlKernelAttrs lvl

  -- Most of this function is not the histogram part itself, but
  -- rather figuring out whether to use a local or global memory
  -- strategy, as well as collapsing the subhistograms produced (which
  -- are always in global memory, but their number may vary).
  let num_tblocks' = fmap pe64 num_tblocks
      tblock_size' = fmap pe64 tblock_size
      dims = map pe64 $ segSpaceDims space

      -- NOTE(review): the split point adds @length ops@ to the number
      -- of neutral elements, while the histogram results consumed
      -- below are grouped by 'histDest' lengths.  Any excess is taken
      -- away from @map_pes@ - confirm this offset is intended.
      num_red_res = length ops + sum (map (length . histNeutral) ops)
      (all_red_pes, map_pes) = splitAt num_red_res pes
      segment_size = last dims

  (op_hs, op_seg_hs, slugs) <- unzip3 <$> mapM (computeHistoUsage space) ops
  h <- dPrimVE "h" $ Imp.unCount $ sum op_hs
  seg_h <- dPrimVE "seg_h" $ Imp.unCount $ sum op_seg_hs

  -- Check for emptiness to avoid division-by-zero.
  sUnless (seg_h .==. 0) $ do
    -- Maximum block size (or actual, in this case).
    let hist_B = unCount tblock_size'

    -- Size of a histogram.
    hist_H <- dPrimVE "hist_H" $ sum $ map histSize ops

    -- Size of a single histogram element.  Actually the weighted
    -- average of histogram elements in cases where we have more than
    -- one histogram operation, plus any locks.
    let lockSize slug = case slugAtomicUpdate slug of
          AtomicLocking {} -> Just $ primByteSize int32
          _ -> Nothing
    hist_el_size <-
      dPrimVE "hist_el_size" $
        L.foldl' (+) (h `divUp` hist_H) $
          mapMaybe lockSize slugs

    -- Input elements contributing to each histogram.
    hist_N <- dPrimVE "hist_N" segment_size

    -- Compute RF as the average RF over all the histograms.
    hist_RF <-
      dPrimVE "hist_RF" $
        sExt32 $
          sum (map (pe64 . histRaceFactor . slugOp) slugs)
            `quot` L.genericLength slugs

    let hist_T = sExt32 $ unCount num_tblocks' * unCount tblock_size'
    emit $ Imp.DebugPrint "\n# SegHist" Nothing
    emit $ Imp.DebugPrint "Number of threads (T)" $ Just $ untyped hist_T
    emit $ Imp.DebugPrint "Desired block size (B)" $ Just $ untyped hist_B
    emit $ Imp.DebugPrint "Histogram size (H)" $ Just $ untyped hist_H
    emit $ Imp.DebugPrint "Input elements per histogram (N)" $ Just $ untyped hist_N
    emit $
      Imp.DebugPrint "Number of segments" $
        Just $
          untyped $
            product $
              map (pe64 . snd) segment_dims
    emit $ Imp.DebugPrint "Histogram element size (el_size)" $ Just $ untyped hist_el_size
    emit $ Imp.DebugPrint "Race factor (RF)" $ Just $ untyped hist_RF
    emit $ Imp.DebugPrint "Memory per set of subhistograms per segment" $ Just $ untyped h
    emit $ Imp.DebugPrint "Memory per set of subhistograms times segments" $ Just $ untyped seg_h

    -- Run-time choice between the shared-memory and global-memory
    -- histogram kernels.
    (use_shared_memory, run_in_shared_memory) <-
      localMemoryCase map_pes hist_T space hist_H hist_el_size hist_N hist_RF slugs kbody

    sIf use_shared_memory run_in_shared_memory $
      histKernelGlobal map_pes num_tblocks tblock_size space slugs kbody

    let pes_per_op = chunks (map (length . histDest) ops) all_red_pes

    -- Combine the subhistograms of each operator into its result.
    forM_ (zip3 slugs pes_per_op ops) $ \(slug, red_pes, op) -> do
      let num_histos = slugNumSubhistos slug
          subhistos = map subhistosArray $ slugSubhistos slug

      let unitHistoCase =
            -- This is OK because the memory blocks are at least as
            -- large as the ones we are supposed to use for the result.
            forM_ (zip red_pes subhistos) $ \(pe, subhisto) -> do
              pe_mem <-
                memLocName . entryArrayLoc
                  <$> lookupArray (patElemName pe)
              subhisto_mem <-
                memLocName . entryArrayLoc
                  <$> lookupArray subhisto
              emit $ Imp.SetMem pe_mem subhisto_mem $ Space "device"

      sIf (tvExp num_histos .==. 1) unitHistoCase $ do
        -- For the segmented reduction, we keep the segment dimensions
        -- unchanged.  To this, we add two dimensions: one over the number
        -- of buckets, and one over the number of subhistograms.  This
        -- inner dimension is the one that is collapsed in the reduction.
        bucket_ids <-
          replicateM (shapeRank (histShape op)) (newVName "bucket_id")
        subhistogram_id <- newVName "subhistogram_id"
        vector_ids <-
          replicateM (shapeRank (histOpShape op)) (newVName "vector_id")

        flat_gtid <- newVName "flat_gtid"

        let grid = KernelGrid num_tblocks tblock_size
            segred_space =
              SegSpace flat_gtid $
                segment_dims
                  ++ zip bucket_ids (shapeDims (histShape op))
                  ++ zip vector_ids (shapeDims $ histOpShape op)
                  ++ [(subhistogram_id, Var $ tvVar num_histos)]

            -- The operator may have references to the old flat thread
            -- ID, which we must update to point at the new one.
            subst = M.singleton (segFlat space) flat_gtid

        let segred_op = SegBinOp Commutative (substituteNames subst $ histOp op) (histNeutral op) mempty

        compileSegRed' (Pat red_pes) grid segred_space [segred_op] $ \red_cont ->
          red_cont . flip map subhistos $ \subhisto ->
            ( Var subhisto,
              map Imp.le64 $
                map fst segment_dims ++ [subhistogram_id] ++ bucket_ids ++ vector_ids
            )

    emit $ Imp.DebugPrint "" Nothing
  where
    -- All dimensions of the space except the innermost one.
    segment_dims = init $ unSegSpace space
futhark-0.25.27/src/Futhark/CodeGen/ImpGen/GPU/SegMap.hs000066400000000000000000000046651475065116200223760ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-}

-- | Code generation for 'SegMap' is quite straightforward.  The only
-- trick is virtualisation in case the physical number of threads is
-- not sufficient to cover the logical thread space.  This is handled
-- by having actual threadblocks run a loop to imitate multiple threadblocks.
module Futhark.CodeGen.ImpGen.GPU.SegMap (compileSegMap) where

import Control.Monad
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.Block
import Futhark.IR.GPUMem
import Futhark.Util.IntegralExp (divUp)
import Prelude hiding (quot, rem)

-- | Compile 'SegMap' instance code.
-- The generated kernel depends on the level: for 'SegThread' each
-- point of the segment space is handled by one thread, for 'SegBlock'
-- each point is handled by a whole thread block, and
-- 'SegThreadInBlock' is rejected with 'error'.
compileSegMap ::
  Pat LetDecMem ->
  SegLevel ->
  SegSpace ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegMap pat lvl space kbody = do
  attrs <- lvlKernelAttrs lvl

  let (is, dims) = unzip $ unSegSpace space
      dims' = map pe64 dims
      tblock_size' = pe64 <$> kAttrBlockSize attrs

  emit $ Imp.DebugPrint "\n# SegMap" Nothing
  case lvl of
    SegThread {} -> do
      -- Enough (virtual) blocks for one thread per point of the
      -- space, rounded up to a whole number of blocks.
      virt_num_tblocks <- dPrimVE "virt_num_tblocks" $ sExt32 $ product dims' `divUp` unCount tblock_size'
      sKernelThread "segmap" (segFlat space) attrs $
        virtualiseBlocks (segVirt lvl) virt_num_tblocks $ \tblock_id -> do
          local_tid <- kernelLocalThreadId . kernelConstants <$> askEnv

          -- Flat thread index derived from the virtual block id.
          global_tid <-
            dPrimVE "global_tid" $
              sExt64 tblock_id * sExt64 (unCount tblock_size')
                + sExt64 local_tid

          dIndexSpace (zip is dims') global_tid

          -- Only threads for which 'isActive' holds run the body;
          -- presumably this excludes the padding threads introduced
          -- by rounding up above.
          sWhen (isActive $ unSegSpace space) $
            compileStms mempty (kernelBodyStms kbody) $
              zipWithM_ (compileThreadResult space) (patElems pat) $
                kernelBodyResult kbody
    SegBlock {} -> do
      pc <- precomputeConstants tblock_size' $ kernelBodyStms kbody
      -- One (virtual) block per point of the space.
      virt_num_tblocks <- dPrimVE "virt_num_tblocks" $ sExt32 $ product dims'

      sKernelBlock "segmap_intrablock" (segFlat space) attrs $ do
        precomputedConstants pc $
          virtualiseBlocks (segVirt lvl) virt_num_tblocks $ \tblock_id -> do
            dIndexSpace (zip is dims') $ sExt64 tblock_id

            compileStms mempty (kernelBodyStms kbody) $
              zipWithM_ (compileBlockResult space) (patElems pat) $
                kernelBodyResult kbody
    SegThreadInBlock {} ->
      error "compileSegMap: SegThreadInBlock"
  emit $ Imp.DebugPrint "" Nothing
futhark-0.25.27/src/Futhark/CodeGen/ImpGen/GPU/SegRed.hs000066400000000000000000001225371475065116200223720ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-}

-- | We generate code for non-segmented/single-segment SegRed using
-- the basic approach outlined in the paper "Design and GPGPU
-- Performance of Futhark’s Redomap Construct" (ARRAY '16).  The main
-- deviations are:
--
-- * While we still use two-phase reduction, we use only a single
-- kernel, with the final threadblock to write a result (tracked via
-- an atomic counter) performing the final reduction as well.
--
-- * Instead of depending on storage layout transformations to handle
-- non-commutative reductions efficiently, we slide a
-- @tblocksize@-sized window over the input, and perform a parallel
-- reduction for each window.  This sacrifices the notion of
-- efficient sequentialisation, but is sometimes faster and
-- definitely simpler and more predictable (and uses less auxiliary
-- storage).
--
-- For segmented reductions we use the approach from "Strategies for
-- Regular Segmented Reductions on GPU" (FHPC '17).  This involves
-- having two different strategies, and dynamically deciding which one
-- to use based on the number of segments and segment size. We use the
-- (static) @tblock_size@ to decide which of the following two
-- strategies to choose:
--
-- * Large: uses one or more blocks to process a single segment. If
-- multiple blocks are used per segment, the intermediate reduction
-- results must be recursively reduced, until there is only a single
-- value per segment.
--
-- Each thread /can/ read multiple elements, which will greatly
-- increase performance; however, if the reduction is
-- non-commutative we will have to use a less efficient traversal
-- (with interim block-wide reductions) to enable coalesced memory
-- accesses, just as in the non-segmented case.
--
-- * Small: is used to let each block process *multiple* segments
-- within a block. We will only use this approach when we can
-- process at least two segments within a single block. In those
-- cases, we would allocate a /whole/ block per segment with the
-- large strategy, but at most 50% of the threads in the block would
-- have any element to read, which becomes highly inefficient.
--
-- An optimization specifically targeted at non-segmented and large-segments
-- segmented reductions with non-commutative operators is made: The stage one
-- main loop is essentially stripmined by a factor *chunk*, inserting collective
-- copies via shared memory of each reduction parameter going into the
-- intra-block (partial) reductions. This saves a factor *chunk* number of
-- intra-block reductions at the cost of some overhead in collective copies.
module Futhark.CodeGen.ImpGen.GPU.SegRed
  ( compileSegRed,
    compileSegRed',
    DoSegBody,
  )
where

import Control.Monad
import Data.List (genericLength, zip4)
import Data.Map qualified as M
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.Error
import Futhark.IR.GPUMem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.Transform.Rename
import Futhark.Util (chunks, mapAccumLM)
import Futhark.Util.IntegralExp (divUp, nextMul, quot, rem)
import Prelude hiding (quot, rem)

-- | Run a monadic action over two lists in lockstep, discarding the
-- results.  The lists are paired with 'zip', so surplus elements of
-- the longer list are ignored.
forM2_ :: (Monad m) => [a] -> [b] -> (a -> b -> m c) -> m ()
forM2_ xs ys f = forM_ (zip xs ys) (uncurry f)

-- | The maximum number of operators we support in a single SegRed.
-- This limit arises out of the static allocation of counters.
maxNumOps :: Int
maxNumOps = 20

-- | Code generation for the body of the SegRed, taking a continuation
-- for saving the results of the body.  The results should be
-- represented as a pairing of a t'SubExp' along with a list of
-- indexes into that t'SubExp' for reading the result.
type DoSegBody = ([(SubExp, [Imp.TExp Int64])] -> InKernelGen ()) -> InKernelGen ()

-- | Datatype used to distinguish between and work with the different sets of
-- intermediate memory we need for the different ReduceKinds.
data SegRedIntermediateArrays = GeneralSegRedInterms { blockRedArrs :: [VName] } | NoncommPrimSegRedInterms { collCopyArrs :: [VName], blockRedArrs :: [VName], privateChunks :: [VName] } -- | Compile 'SegRed' instance to host-level code with calls to -- various kernels. compileSegRed :: Pat LetDecMem -> SegLevel -> SegSpace -> [SegBinOp GPUMem] -> KernelBody GPUMem -> CallKernelGen () compileSegRed pat lvl space segbinops map_kbody = do emit $ Imp.DebugPrint "\n# SegRed" Nothing KernelAttrs {kAttrNumBlocks = num_tblocks, kAttrBlockSize = tblock_size} <- lvlKernelAttrs lvl let grid = KernelGrid num_tblocks tblock_size compileSegRed' pat grid space segbinops $ \red_cont -> sComment "apply map function" $ compileStms mempty (kernelBodyStms map_kbody) $ do let (red_res, map_res) = splitAt (segBinOpResults segbinops) $ kernelBodyResult map_kbody let mapout_arrs = drop (segBinOpResults segbinops) $ patElems pat unless (null mapout_arrs) $ sComment "write map-out result(s)" $ do zipWithM_ (compileThreadResult space) mapout_arrs map_res red_cont $ map ((,[]) . kernelResultSubExp) red_res emit $ Imp.DebugPrint "" Nothing paramOf :: SegBinOp GPUMem -> [Param LParamMem] paramOf (SegBinOp _ op ne _) = take (length ne) $ lambdaParams op isPrimSegBinOp :: SegBinOp GPUMem -> Bool isPrimSegBinOp segbinop = all primType (lambdaReturnType $ segBinOpLambda segbinop) && shapeRank (segBinOpShape segbinop) == 0 -- | Like 'compileSegRed', but where the body is a monadic action. compileSegRed' :: Pat LetDecMem -> KernelGrid -> SegSpace -> [SegBinOp GPUMem] -> DoSegBody -> CallKernelGen () compileSegRed' pat grid space segbinops map_body_cont | genericLength segbinops > maxNumOps = compilerLimitationS $ ("compileSegRed': at most " <> show maxNumOps <> " reduction operators are supported.\n") <> ("Pattern: " <> prettyString pat) | otherwise = do chunk_v <- dPrimV "chunk_size" . 
isInt64 =<< kernelConstToExp chunk_const case unSegSpace space of [(_, Constant (IntValue (Int64Value 1))), _] -> compileReduction (chunk_v, chunk_const) nonsegmentedReduction _ -> do let segment_size = pe64 $ last $ segSpaceDims space use_small_segments = segment_size * 2 .<. pe64 (unCount tblock_size) * tvExp chunk_v sIf use_small_segments (compileReduction (chunk_v, chunk_const) smallSegmentsReduction) (compileReduction (chunk_v, chunk_const) largeSegmentsReduction) where compileReduction chunk f = f pat num_tblocks tblock_size chunk space segbinops map_body_cont param_types = map paramType $ concatMap paramOf segbinops num_tblocks = gridNumBlocks grid tblock_size = gridBlockSize grid chunk_const = if Noncommutative `elem` map segBinOpComm segbinops && all isPrimSegBinOp segbinops then getChunkSize param_types else Imp.ValueExp $ IntValue $ intValue Int64 (1 :: Int64) -- | Prepare intermediate arrays for the reduction. Prim-typed -- arguments go in shared memory (so we need to do the allocation of -- those arrays inside the kernel), while array-typed arguments go in -- global memory. Allocations for the latter have already been -- performed. This policy is baked into how the allocations are done -- in ExplicitAllocations. -- -- For more info about the intermediate arrays used for the different reduction -- kernels, see note [IntermArrays]. makeIntermArrays :: Imp.TExp Int64 -> SubExp -> SubExp -> [SegBinOp GPUMem] -> InKernelGen [SegRedIntermediateArrays] makeIntermArrays tblock_id tblock_size chunk segbinops | Noncommutative <- mconcat (map segBinOpComm segbinops), all isPrimSegBinOp segbinops = noncommPrimSegRedInterms | otherwise = generalSegRedInterms tblock_id tblock_size segbinops where params = map paramOf segbinops noncommPrimSegRedInterms = do block_worksize <- tvSize <$> dPrimV "block_worksize" block_worksize_E -- compute total amount of lmem. 
let sum_ x y = nextMul x y + tblock_size_E * y block_reds_lmem_requirement = foldl sum_ 0 $ concat elem_sizes collcopy_lmem_requirement = block_worksize_E * max_elem_size lmem_total_size = Imp.bytes $ collcopy_lmem_requirement `sMax64` block_reds_lmem_requirement -- offsets into the total pool of lmem for each block reduction array. (_, offsets) <- forAccumLM2D 0 elem_sizes $ \byte_offs elem_size -> (,byte_offs `quot` elem_size) <$> dPrimVE "offset" (sum_ byte_offs elem_size) -- total pool of local mem. lmem <- sAlloc "local_mem" lmem_total_size (Space "shared") let arrInLMem ptype name len_se offset = sArray (name ++ "_" ++ prettyString ptype) ptype (Shape [len_se]) lmem $ LMAD.iota offset [pe64 len_se] forM (zipWith zip params offsets) $ \ps_and_offsets -> do (coll_copy_arrs, block_red_arrs, priv_chunks) <- fmap unzip3 $ forM ps_and_offsets $ \(p, offset) -> do let ptype = elemType $ paramType p (,,) <$> arrInLMem ptype "coll_copy_arr" block_worksize 0 <*> arrInLMem ptype "block_red_arr" tblock_size offset <*> sAllocArray ("chunk_" ++ prettyString ptype) ptype (Shape [chunk]) (ScalarSpace [chunk] ptype) pure $ NoncommPrimSegRedInterms coll_copy_arrs block_red_arrs priv_chunks tblock_size_E = pe64 tblock_size block_worksize_E = tblock_size_E * pe64 chunk paramSize = primByteSize . elemType . 
paramType elem_sizes = map (map paramSize) params max_elem_size = maximum $ concat elem_sizes forAccumLM2D acc ls f = mapAccumLM (mapAccumLM f) acc ls generalSegRedInterms :: Imp.TExp Int64 -> SubExp -> [SegBinOp GPUMem] -> InKernelGen [SegRedIntermediateArrays] generalSegRedInterms tblock_id tblock_size segbinops = fmap (map GeneralSegRedInterms) $ forM (map paramOf segbinops) $ mapM $ \p -> case paramDec p of MemArray pt shape _ (ArrayIn mem _) -> do let shape' = Shape [tblock_size] <> shape let shape_E = map pe64 $ shapeDims shape' sArray ("red_arr_" ++ prettyString pt) pt shape' mem $ LMAD.iota (tblock_id * product shape_E) shape_E _ -> do let pt = elemType $ paramType p shape = Shape [tblock_size] sAllocArray ("red_arr_" ++ prettyString pt) pt shape $ Space "shared" -- | Arrays for storing block results. -- -- The block-result arrays have an extra dimension because they are -- also used for keeping vectorised accumulators for first-stage -- reduction, if necessary. If necessary, this dimension has size -- tblock_size, and otherwise 1. When actually storing block results, -- the first index is set to 0. groupResultArrays :: SubExp -> SubExp -> [SegBinOp GPUMem] -> CallKernelGen [[VName]] groupResultArrays num_virtblocks tblock_size segbinops = forM segbinops $ \(SegBinOp _ lam _ shape) -> forM (lambdaReturnType lam) $ \t -> do let pt = elemType t extra_dim | primType t, shapeRank shape == 0 = intConst Int64 1 | otherwise = tblock_size full_shape = Shape [extra_dim, num_virtblocks] <> shape <> arrayShape t -- Move the tblocksize dimension last to ensure coalesced -- memory access. perm = [1 .. 
shapeRank full_shape - 1] ++ [0] sAllocArrayPerm "segred_tmp" pt full_shape (Space "device") perm type DoCompileSegRed = Pat LetDecMem -> Count NumBlocks SubExp -> Count BlockSize SubExp -> (TV Int64, Imp.KernelConstExp) -> SegSpace -> [SegBinOp GPUMem] -> DoSegBody -> CallKernelGen () nonsegmentedReduction :: DoCompileSegRed nonsegmentedReduction (Pat segred_pes) num_tblocks tblock_size (chunk_v, chunk_const) space segbinops map_body_cont = do let (gtids, dims) = unzip $ unSegSpace space chunk = tvExp chunk_v num_tblocks_se = unCount num_tblocks tblock_size_se = unCount tblock_size tblock_size' = pe64 tblock_size_se global_tid = Imp.le64 $ segFlat space n = pe64 $ last dims counters <- genZeroes "counters" maxNumOps reds_block_res_arrs <- groupResultArrays num_tblocks_se tblock_size_se segbinops num_threads <- fmap tvSize $ dPrimV "num_threads" $ pe64 num_tblocks_se * tblock_size' let attrs = (defKernelAttrs num_tblocks tblock_size) { kAttrConstExps = M.singleton (tvVar chunk_v) chunk_const } sKernelThread "segred_nonseg" (segFlat space) attrs $ do constants <- kernelConstants <$> askEnv let ltid = kernelLocalThreadId constants let tblock_id = kernelBlockId constants interms <- makeIntermArrays (sExt64 tblock_id) tblock_size_se (tvSize chunk_v) segbinops sync_arr <- sAllocArray "sync_arr" Bool (Shape [intConst Int32 1]) $ Space "shared" -- Since this is the nonsegmented case, all outer segment IDs must -- necessarily be 0. forM_ gtids $ \v -> dPrimV_ v (0 :: Imp.TExp Int64) q <- dPrimVE "q" $ n `divUp` (sExt64 (kernelNumThreads constants) * chunk) slugs <- mapM (segBinOpSlug ltid tblock_id) $ zip3 segbinops interms reds_block_res_arrs new_lambdas <- reductionStageOne gtids n global_tid q chunk (pe64 num_threads) slugs map_body_cont let segred_pess = chunks (map (length . 
segBinOpNeutral) segbinops) segred_pes forM_ (zip4 segred_pess slugs new_lambdas [0 ..]) $ \(pes, slug, new_lambda, i) -> reductionStageTwo pes tblock_id [0] 0 (sExt64 $ kernelNumBlocks constants) slug new_lambda counters sync_arr (fromInteger i)

-- | Segmented reduction where several whole segments are packed into each
-- threadblock (@segments_per_block = tblock_size `quot` segment_size@).
-- Every thread maps one element into a shared-memory array, a segmented
-- block-level scan is run over that array, and the last scanned element of
-- each segment is written to the result.
smallSegmentsReduction :: DoCompileSegRed
smallSegmentsReduction (Pat segred_pes) num_tblocks tblock_size _ space segbinops map_body_cont = do
  let (gtids, dims) = unzip $ unSegSpace space
      dims' = map pe64 dims
      segment_size = last dims'

  -- Careful to avoid division by zero now.
  segment_size_nonzero <-
    dPrimVE "segment_size_nonzero" $ sMax64 1 segment_size

  let tblock_size_se = unCount tblock_size
      -- The block count must come from 'num_tblocks' (not the block
      -- size), otherwise 'num_threads' below is wrong.
      num_tblocks_se = unCount num_tblocks
      num_tblocks' = pe64 num_tblocks_se
      tblock_size' = pe64 tblock_size_se
  num_threads <- fmap tvSize $ dPrimV "num_threads" $ num_tblocks' * tblock_size'

  let num_segments = product $ init dims'
      segments_per_block = tblock_size' `quot` segment_size_nonzero
      required_blocks = sExt32 $ num_segments `divUp` segments_per_block

  emit $ Imp.DebugPrint "# SegRed-small" Nothing
  emit $ Imp.DebugPrint "num_segments" $ Just $ untyped num_segments
  emit $ Imp.DebugPrint "segment_size" $ Just $ untyped segment_size
  emit $ Imp.DebugPrint "segments_per_block" $ Just $ untyped segments_per_block
  emit $ Imp.DebugPrint "required_blocks" $ Just $ untyped required_blocks

  sKernelThread "segred_small" (segFlat space) (defKernelAttrs num_tblocks tblock_size) $ do
    constants <- kernelConstants <$> askEnv
    -- Use the (physical) block id, as 'largeSegmentsReduction' does; the
    -- block size is not an id.
    let tblock_id = sExt64 $ kernelBlockId constants
        ltid = sExt64 $ kernelLocalThreadId constants

    interms <- generalSegRedInterms tblock_id tblock_size_se segbinops
    let reds_arrs = map blockRedArrs interms

    -- We probably do not have enough actual threadblocks to cover the
    -- entire iteration space.  Some blocks thus have to perform double
    -- duty; we put an outer loop to accomplish this.
    virtualiseBlocks SegVirt required_blocks $ \virttblock_id -> do
      -- Compute the 'n' input indices.  The outer 'n-1' correspond to
      -- the segment ID, and are computed from the block id.  The inner
      -- is computed from the local thread id, and may be out-of-bounds.
      -- Both divisions use 'segment_size_nonzero' so that an empty
      -- segment does not cause a division by zero; the result is only
      -- used when 'segment_size > 0' (see the guard below).
      let segment_index =
            (ltid `quot` segment_size_nonzero)
              + (sExt64 virttblock_id * sExt64 segments_per_block)
          index_within_segment = ltid `rem` segment_size_nonzero

      dIndexSpace (zip (init gtids) (init dims')) segment_index
      dPrimV_ (last gtids) index_within_segment

      let in_bounds =
            map_body_cont $ \red_res ->
              sComment "save results to be reduced" $ do
                let red_dests = map (,[ltid]) (concat reds_arrs)
                forM2_ red_dests red_res $ \(d, d_is) (res, res_is) ->
                  copyDWIMFix d d_is res res_is
          out_of_bounds =
            forM2_ segbinops reds_arrs $ \(SegBinOp _ _ nes _) red_arrs ->
              forM2_ red_arrs nes $ \arr ne ->
                copyDWIMFix arr [ltid] ne []

      sComment "apply map function if in bounds" $
        sIf
          ( segment_size .>. 0
              .&&. isActive (init $ zip gtids dims)
              .&&. ltid .<. segment_size * segments_per_block
          )
          in_bounds
          out_of_bounds

      sOp $ Imp.ErrorSync Imp.FenceLocal -- Also implicitly barrier.

      let crossesSegment from to =
            (sExt64 to - sExt64 from) .>. (sExt64 to `rem` segment_size)
      sWhen (segment_size .>. 0) $
        sComment "perform segmented scan to imitate reduction" $
          forM2_ segbinops reds_arrs $ \(SegBinOp _ red_op _ _) red_arrs ->
            blockScan
              (Just crossesSegment)
              (sExt64 $ pe64 num_threads)
              (segment_size * segments_per_block)
              red_op
              red_arrs

      sOp $ Imp.Barrier Imp.FenceLocal

      -- Thread 'ltid' is responsible for the result of segment number
      -- 'ltid' within this virtual block, if that segment exists.
      let writes_segment_result =
            (sExt64 virttblock_id * segments_per_block + sExt64 ltid .<. num_segments)
              .&&. (ltid .<. segments_per_block)
      sComment "save final values of segments" $
        sWhen writes_segment_result $
          forM2_ segred_pes (concat reds_arrs) $ \pe arr -> do
            -- Figure out which segment result this thread should write...
            let flat_segment_index =
                  sExt64 virttblock_id * segments_per_block + sExt64 ltid
                gtids' =
                  unflattenIndex (init dims') flat_segment_index
            copyDWIMFix
              (patElemName pe)
              gtids'
              (Var arr)
              [(ltid + 1) * segment_size_nonzero - 1]

      -- Finally another barrier, because we will be writing to the
      -- shared memory array first thing in the next iteration.
      sOp $ Imp.Barrier Imp.FenceLocal

-- | Segmented reduction where each segment is handled by
-- @blocks_per_segment@ threadblocks.  Stage one produces one accumulator per
-- block.  If a segment has a single block, thread 0 writes that accumulator
-- directly to the result; otherwise stage two combines the per-block
-- results, using a counter per (operator, segment slot) to detect the last
-- block to finish.
largeSegmentsReduction :: DoCompileSegRed
largeSegmentsReduction (Pat segred_pes) num_tblocks tblock_size (chunk_v, chunk_const) space segbinops map_body_cont = do
  let (gtids, dims) = unzip $ unSegSpace space
      dims' = map pe64 dims
      num_segments = product $ init dims'
      segment_size = last dims'
      num_tblocks' = pe64 $ unCount num_tblocks
      tblock_size_se = unCount tblock_size
      tblock_size' = pe64 tblock_size_se
      chunk = tvExp chunk_v

  blocks_per_segment <-
    dPrimVE "blocks_per_segment" $
      num_tblocks' `divUp` sMax64 1 num_segments

  q <-
    dPrimVE "q" $
      segment_size `divUp` (tblock_size' * blocks_per_segment * chunk)

  num_virtblocks <-
    dPrimV "num_virtblocks" $
      blocks_per_segment * num_segments
  threads_per_segment <-
    dPrimVE "threads_per_segment" $
      blocks_per_segment * tblock_size'

  emit $ Imp.DebugPrint "# SegRed-large" Nothing
  emit $ Imp.DebugPrint "num_segments" $ Just $ untyped num_segments
  emit $ Imp.DebugPrint "segment_size" $ Just $ untyped segment_size
  emit $ Imp.DebugPrint "num_virtblocks" $ Just $ untyped $ tvExp num_virtblocks
  emit $ Imp.DebugPrint "num_tblocks" $ Just $ untyped num_tblocks'
  emit $ Imp.DebugPrint "tblock_size" $ Just $ untyped tblock_size'
  emit $ Imp.DebugPrint "q" $ Just $ untyped q
  emit $ Imp.DebugPrint "blocks_per_segment" $ Just $ untyped blocks_per_segment

  reds_block_res_arrs <- groupResultArrays (tvSize num_virtblocks) tblock_size_se segbinops

  -- In principle we should have a counter for every segment.  Since
  -- the number of segments is a dynamic quantity, we would have to
  -- allocate and zero out an array here, which is expensive.
  -- However, we exploit the fact that the number of segments being
  -- reduced at any point in time is limited by the number of
  -- threadblocks. If we bound the number of threadblocks, we can get away
  -- with using that many counters.  FIXME: Is this limit checked
  -- anywhere?  There are other places in the compiler that will fail
  -- if the block count exceeds the maximum block size, which is at
  -- most 1024 anyway.
  --
  -- Each operator gets its own disjoint range of 'counters_per_op'
  -- counters: operator @i@ uses indices
  -- @[i * counters_per_op, (i + 1) * counters_per_op)@, so the whole
  -- array holds 'maxNumOps * counters_per_op' counters.
  let counters_per_op = 1024
      num_counters = maxNumOps * counters_per_op
  counters <- genZeroes "counters" $ fromIntegral num_counters

  let attrs =
        (defKernelAttrs num_tblocks tblock_size)
          { kAttrConstExps = M.singleton (tvVar chunk_v) chunk_const
          }

  sKernelThread "segred_large" (segFlat space) attrs $ do
    constants <- kernelConstants <$> askEnv
    let tblock_id = sExt64 $ kernelBlockId constants
        ltid = kernelLocalThreadId constants

    interms <- makeIntermArrays tblock_id tblock_size_se (tvSize chunk_v) segbinops
    sync_arr <- sAllocArray "sync_arr" Bool (Shape [intConst Int32 1]) $ Space "shared"

    -- We probably do not have enough actual threadblocks to cover the
    -- entire iteration space.  Some blocks thus have to perform double
    -- duty; we put an outer loop to accomplish this.
    virtualiseBlocks SegVirt (sExt32 (tvExp num_virtblocks)) $ \virttblock_id -> do
      let segment_gtids = init gtids

      flat_segment_id <-
        dPrimVE "flat_segment_id" $
          sExt64 virttblock_id `quot` blocks_per_segment

      global_tid <-
        dPrimVE "global_tid" $
          (sExt64 virttblock_id * sExt64 tblock_size' + sExt64 ltid)
            `rem` threads_per_segment

      let first_block_for_segment = flat_segment_id * blocks_per_segment
      dIndexSpace (zip segment_gtids (init dims')) flat_segment_id
      dPrim_ (last gtids) int64
      let n = pe64 $ last dims

      slugs <-
        mapM (segBinOpSlug ltid virttblock_id) $
          zip3 segbinops interms reds_block_res_arrs
      new_lambdas <-
        reductionStageOne
          gtids
          n
          global_tid
          q
          chunk
          threads_per_segment
          slugs
          map_body_cont

      let segred_pess =
            chunks
              (map (length . segBinOpNeutral) segbinops)
              segred_pes

          multiple_blocks_per_segment =
            forM_ (zip4 segred_pess slugs new_lambdas [0 ..]) $
              \(pes, slug, new_lambda, i) -> do
                let counter_idx =
                      fromIntegral (i * counters_per_op)
                        + (flat_segment_id `rem` fromIntegral counters_per_op)
                reductionStageTwo
                  pes
                  virttblock_id
                  (map Imp.le64 segment_gtids)
                  first_block_for_segment
                  blocks_per_segment
                  slug
                  new_lambda
                  counters
                  sync_arr
                  counter_idx

          one_block_per_segment =
            sComment "first thread in block saves final result to memory" $
              forM2_ slugs segred_pess $ \slug pes ->
                sWhen (ltid .==. 0) $
                  forM2_ pes (slugAccs slug) $ \v (acc, acc_is) ->
                    copyDWIMFix (patElemName v) (map Imp.le64 segment_gtids) (Var acc) acc_is

      sIf (blocks_per_segment .==. 1) one_block_per_segment multiple_blocks_per_segment

-- | Auxiliary information for a single reduction. A slug holds the `SegBinOp`
-- operator for a single reduction, the different arrays required throughout
-- stages one and two, and a global mem destination for the final result of the
-- particular reduction.
data SegBinOpSlug = SegBinOpSlug
  { slugOp :: SegBinOp GPUMem,
    -- | Intermediate arrays needed for the given reduction.
    slugInterms :: SegRedIntermediateArrays,
    -- | Place(s) to store block accumulator(s) in stage 1 reduction.
    slugAccs :: [(VName, [Imp.TExp Int64])],
    -- | Global memory destination(s) for the final result(s) for this
    -- particular reduction.
    blockResArrs :: [VName]
  }

-- | Build a slug for one operator.  Scalar (primitive, rank-0) parameters get
-- a fresh scalar accumulator; otherwise the corresponding block result array
-- doubles as the accumulator, indexed by @[ltid, tblock_id]@.
segBinOpSlug ::
  Imp.TExp Int32 ->
  Imp.TExp Int32 ->
  (SegBinOp GPUMem, SegRedIntermediateArrays, [VName]) ->
  InKernelGen SegBinOpSlug
segBinOpSlug ltid tblock_id (op, interms, block_res_arrs) = do
  accs <- zipWithM mkAcc (lambdaParams (segBinOpLambda op)) block_res_arrs
  pure $ SegBinOpSlug op interms accs block_res_arrs
  where
    mkAcc p block_res_arr
      | Prim t <- paramType p,
        shapeRank (segBinOpShape op) == 0 = do
          block_res_acc <- dPrimS (baseString (paramName p) <> "_block_res_acc") t
          pure (block_res_acc, [])
      -- if this is a non-primitive reduction, the global mem result array will
      -- double as accumulator.
      | otherwise =
          pure (block_res_arr, [sExt64 ltid, sExt64 tblock_id])

-- | The operator lambda of a slug.
slugLambda :: SegBinOpSlug -> Lambda GPUMem
slugLambda = segBinOpLambda . slugOp

-- | The body of a slug's operator lambda.
slugBody :: SegBinOpSlug -> Body GPUMem
slugBody = lambdaBody . slugLambda

-- | All parameters (accumulator and next-value) of a slug's operator.
slugParams :: SegBinOpSlug -> [LParam GPUMem]
slugParams = lambdaParams . slugLambda

-- | Neutral element(s) of a slug's operator.
slugNeutral :: SegBinOpSlug -> [SubExp]
slugNeutral = segBinOpNeutral . slugOp

-- | Vector shape of a slug's operator.
slugShape :: SegBinOpSlug -> Shape
slugShape = segBinOpShape . slugOp

-- | Combined commutativity of a list of slugs.
slugsComm :: [SegBinOpSlug] -> Commutativity
slugsComm = mconcat . map (segBinOpComm . slugOp)

-- | Split a slug's parameters into (accumulator params, next-value params).
slugSplitParams :: SegBinOpSlug -> ([LParam GPUMem], [LParam GPUMem])
slugSplitParams slug = splitAt (length (slugNeutral slug)) $ slugParams slug

-- | Arrays used for the block-level reduction of a slug.
slugBlockRedArrs :: SegBinOpSlug -> [VName]
slugBlockRedArrs = blockRedArrs . slugInterms

-- | Per-thread chunk arrays of a slug.
slugPrivChunks :: SegBinOpSlug -> [VName]
slugPrivChunks = privateChunks . slugInterms

-- | Shared-memory arrays used for collective copies of a slug.
slugCollCopyArrs :: SegBinOpSlug -> [VName]
slugCollCopyArrs = collCopyArrs . slugInterms

-- | Stage one of a reduction.  Declares the operator parameters, initialises
-- the per-block accumulators with the neutral elements, and then runs either
-- the chunked body (non-commutative, all-primitive operators) or the general
-- body.  Returns renamed copies of the operator lambdas, which are used for
-- block-level reductions here and in stage two.
reductionStageOne ::
  [VName] ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  [SegBinOpSlug] ->
  DoSegBody ->
  InKernelGen [Lambda GPUMem]
reductionStageOne gtids n global_tid q chunk threads_per_segment slugs body_cont = do
  constants <- kernelConstants <$> askEnv
  let glb_ind_var = mkTV (last gtids)
      ltid = sExt64 $ kernelLocalThreadId constants

  dScope Nothing $ scopeOfLParams $ concatMap slugParams slugs

  sComment "ne-initialise the outer (per-block) accumulator(s)" $ do
    forM_ slugs $ \slug ->
      forM2_ (slugAccs slug) (slugNeutral slug) $ \(acc, acc_is) ne ->
        sLoopNest (slugShape slug) $ \vec_is ->
          copyDWIMFix acc (acc_is ++ vec_is) ne []

  new_lambdas <- mapM (renameLambda . slugLambda) slugs
  let tblock_size = sExt32 $ kernelBlockSize constants

  let doBlockReduce =
        forM2_ slugs new_lambdas $ \slug new_lambda -> do
          let accs = slugAccs slug
          let params = slugParams slug
          sLoopNest (slugShape slug) $ \vec_is -> do
            let block_red_arrs = slugBlockRedArrs slug
            sComment "store accs. prims go in lmem; non-prims in params (in global mem)" $
              forM_ (zip3 block_red_arrs accs params) $ \(arr, (acc, acc_is), p) ->
                if isPrimParam p
                  then copyDWIMFix arr [ltid] (Var acc) (acc_is ++ vec_is)
                  else copyDWIMFix (paramName p) [] (Var acc) (acc_is ++ vec_is)

            sOp $ Imp.ErrorSync Imp.FenceLocal -- Also implicitly barrier.
            blockReduce tblock_size new_lambda block_red_arrs
            sOp $ Imp.Barrier Imp.FenceLocal

            sComment "thread 0 updates per-block acc(s); rest reset to ne" $ do
              sIf
                (ltid .==. 0)
                ( forM2_ accs (lambdaParams new_lambda) $ \(acc, acc_is) p ->
                    copyDWIMFix acc (acc_is ++ vec_is) (Var $ paramName p) []
                )
                ( forM2_ accs (slugNeutral slug) $ \(acc, acc_is) ne ->
                    copyDWIMFix acc (acc_is ++ vec_is) ne []
                )

  case (slugsComm slugs, all (isPrimSegBinOp . slugOp) slugs) of
    (Noncommutative, True) ->
      noncommPrimParamsStageOneBody
        slugs
        body_cont
        glb_ind_var
        global_tid
        q
        n
        chunk
        doBlockReduce
    _ ->
      generalStageOneBody
        slugs
        body_cont
        glb_ind_var
        global_tid
        q
        n
        threads_per_segment
        doBlockReduce

  pure new_lambdas

-- | Stage-one body for the general case.  Each thread folds @q@ elements into
-- its accumulators.  For commutative operators consecutive threads read
-- elements strided by @threads_per_segment@, and a single block reduction is
-- done at the end.  Otherwise each iteration covers a contiguous
-- block-sized range and is followed by a block reduction.
generalStageOneBody ::
  [SegBinOpSlug] ->
  DoSegBody ->
  TV Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  InKernelGen () ->
  InKernelGen ()
generalStageOneBody slugs body_cont glb_ind_var global_tid q n threads_per_segment doBlockReduce = do
  let is_comm = slugsComm slugs == Commutative

  constants <- kernelConstants <$> askEnv
  let tblock_size = kernelBlockSize constants
      ltid = sExt64 $ kernelLocalThreadId constants

  -- this block's id within its designated segment, and this block's initial
  -- global offset.
  tblock_id_in_segment <- dPrimVE "tblock_id_in_segment" $ global_tid `quot` tblock_size
  block_base_offset <- dPrimVE "block_base_offset" $ tblock_id_in_segment * q * tblock_size

  sFor "i" q $ \i -> do
    block_offset <- dPrimVE "block_offset" $ block_base_offset + i * tblock_size
    glb_ind_var <-- if is_comm then global_tid + threads_per_segment * i else block_offset + ltid

    sWhen (tvExp glb_ind_var .<. n) $
      sComment "apply map function(s)" $
        body_cont $ \all_red_res -> do
          let maps_res = chunks (map (length . slugNeutral) slugs) all_red_res

          forM2_ slugs maps_res $ \slug map_res ->
            sLoopNest (slugShape slug) $ \vec_is -> do
              let (acc_params, next_params) = slugSplitParams slug
              sComment "load accumulator(s)" $
                forM2_ acc_params (slugAccs slug) $ \p (acc, acc_is) ->
                  copyDWIMFix (paramName p) [] (Var acc) (acc_is ++ vec_is)
              sComment "load next value(s)" $
                forM2_ next_params map_res $ \p (res, res_is) ->
                  copyDWIMFix (paramName p) [] res (res_is ++ vec_is)
              sComment "apply reduction operator(s)" $
                compileStms mempty (bodyStms $ slugBody slug) $
                  sComment "store in accumulator(s)" $
                    forM2_ (slugAccs slug) (map resSubExp $ bodyResult $ slugBody slug) $
                      \(acc, acc_is) se ->
                        copyDWIMFix acc (acc_is ++ vec_is) se []

    unless is_comm doBlockReduce
  sOp $ Imp.ErrorSync Imp.FenceLocal
  when is_comm doBlockReduce

-- | Stage-one body for non-commutative reductions whose operators all have
-- primitive parameters.  Each iteration of the outer @q@ loop loads @chunk@
-- elements per thread into private chunks, transposes them through shared
-- memory (collective copy) so that each thread holds @chunk@ consecutive
-- elements, folds them sequentially, and then performs a block reduction.
noncommPrimParamsStageOneBody ::
  [SegBinOpSlug] ->
  DoSegBody ->
  TV Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  InKernelGen () ->
  InKernelGen ()
noncommPrimParamsStageOneBody slugs body_cont glb_ind_var global_tid q n chunk doLMemBlockReduce = do
  constants <- kernelConstants <$> askEnv
  let tblock_size = kernelBlockSize constants
      ltid = sExt64 $ kernelLocalThreadId constants

  -- this block's id within its designated segment; the stride made per block in
  -- the outer `i < q` loop; and this block's initial global offset.
  tblock_id_in_segment <- dPrimVE "block_offset_in_segment" $ global_tid `quot` tblock_size
  block_stride <- dPrimVE "block_stride" $ tblock_size * chunk
  block_base_offset <- dPrimVE "block_base_offset" $ tblock_id_in_segment * q * block_stride

  let chunkLoop = sFor "k" chunk

  sFor "i" q $ \i -> do
    -- block offset in this iteration.
    block_offset <- dPrimVE "block_offset" $ block_base_offset + i * block_stride
    chunkLoop $ \k -> do
      loc_ind <- dPrimVE "loc_ind" $ k * tblock_size + ltid
      glb_ind_var <-- block_offset + loc_ind

      sIf
        (tvExp glb_ind_var .<. n)
        ( body_cont $ \all_red_res -> do
            let slugs_res = chunks (map (length . slugNeutral) slugs) all_red_res
            forM2_ slugs slugs_res $ \slug slug_res -> do
              let priv_chunks = slugPrivChunks slug
              sComment "write map result(s) to private chunk(s)" $
                forM2_ priv_chunks slug_res $ \priv_chunk (res, res_is) ->
                  copyDWIMFix priv_chunk [k] res res_is
        )
        -- if out of bounds, fill chunk(s) with neutral element(s)
        ( forM_ slugs $ \slug ->
            forM2_ (slugPrivChunks slug) (slugNeutral slug) $
              \priv_chunk ne ->
                copyDWIMFix priv_chunk [k] ne []
        )

    sOp $ Imp.ErrorSync Imp.FenceLocal

    sComment "effectualize collective copies in shared memory" $ do
      forM_ slugs $ \slug -> do
        let coll_copy_arrs = slugCollCopyArrs slug
        let priv_chunks = slugPrivChunks slug

        forM2_ coll_copy_arrs priv_chunks $ \lmem_arr priv_chunk -> do
          chunkLoop $ \k -> do
            lmem_idx <- dPrimVE "lmem_idx" $ ltid + k * tblock_size
            copyDWIMFix lmem_arr [lmem_idx] (Var priv_chunk) [k]

          sOp $ Imp.Barrier Imp.FenceLocal

          chunkLoop $ \k -> do
            lmem_idx <- dPrimVE "lmem_idx" $ ltid * chunk + k
            copyDWIMFix priv_chunk [k] (Var lmem_arr) [lmem_idx]

          sOp $ Imp.Barrier Imp.FenceLocal

    sComment "per-thread sequential reduction of private chunk(s)" $ do
      chunkLoop $ \k ->
        forM_ slugs $ \slug -> do
          let accs = map fst $ slugAccs slug
          let (acc_ps, next_ps) = slugSplitParams slug
          let ps_accs_chunks = zip4 acc_ps next_ps accs (slugPrivChunks slug)

          sComment "load params for all reductions" $ do
            forM_ ps_accs_chunks $ \(acc_p, next_p, acc, priv_chunk) -> do
              copyDWIM (paramName acc_p) [] (Var acc) []
              copyDWIMFix (paramName next_p) [] (Var priv_chunk) [k]

          sComment "apply reduction operator(s)" $ do
            let binop_ress = map resSubExp $ bodyResult $ slugBody slug
            compileStms mempty (bodyStms $ slugBody slug) $
              forM2_ accs binop_ress $ \acc binop_res ->
                copyDWIM acc [] binop_res []

    doLMemBlockReduce
  sOp $ Imp.ErrorSync Imp.FenceLocal

-- | Stage two of a multi-block-per-segment reduction.  Thread 0 of each block
-- writes the block accumulator to @block_res_arrs@ at @[0, tblock_id]@ and
-- atomically increments the counter at @counter_idx@.  The block that sees
-- the counter's old value equal to @blocks_per_segment - 1@ resets the
-- counter, reduces all per-block results of the segment, and writes the
-- final result to @segred_pes@ at @segment_gtids@.
reductionStageTwo ::
  [PatElem LetDecMem] ->
  Imp.TExp Int32 ->
  [Imp.TExp Int64] ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  SegBinOpSlug ->
  Lambda GPUMem ->
  VName ->
  VName ->
  Imp.TExp Int64 ->
  InKernelGen ()
reductionStageTwo segred_pes tblock_id segment_gtids first_block_for_segment blocks_per_segment slug new_lambda counters sync_arr counter_idx = do
  constants <- kernelConstants <$> askEnv

  let ltid32 = kernelLocalThreadId constants
      ltid = sExt64 ltid32
      tblock_size = kernelBlockSize constants

  let (acc_params, next_params) = slugSplitParams slug
      nes = slugNeutral slug
      red_arrs = slugBlockRedArrs slug
      block_res_arrs = blockResArrs slug

  old_counter <- dPrim "old_counter"
  (counter_mem, _, counter_offset) <- fullyIndexArray counters [counter_idx]
  sComment "first thread in block saves block result to global memory" $
    sWhen (ltid32 .==. 0) $ do
      forM_ (take (length nes) $ zip block_res_arrs (slugAccs slug)) $ \(v, (acc, acc_is)) ->
        writeAtomic v [0, sExt64 tblock_id] (Var acc) acc_is
      -- Increment the counter, thus stating that our result is
      -- available.
      sOp $
        Imp.Atomic DefaultSpace $
          Imp.AtomicAdd Int32 (tvVar old_counter) counter_mem counter_offset $
            untyped (1 :: Imp.TExp Int32)
      -- Now check if we were the last block to write our result.  If
      -- so, it is our responsibility to produce the final result.
      sWrite sync_arr [0] $ untyped $ tvExp old_counter .==. sExt32 (blocks_per_segment - 1)

  sOp $ Imp.Barrier Imp.FenceGlobal

  is_last_block <- dPrim "is_last_block"
  copyDWIMFix (tvVar is_last_block) [] (Var sync_arr) [0]
  sWhen (tvExp is_last_block) $ do
    -- The final block has written its result (and it was
    -- us!), so read in all the block results and perform the
    -- final stage of the reduction.  But first, we reset the
    -- counter so it is ready for next time.  This is done
    -- with an atomic to avoid warnings about write/write
    -- races in oclgrind.
    sWhen (ltid32 .==. 0) $
      sOp $
        Imp.Atomic DefaultSpace $
          Imp.AtomicAdd Int32 (tvVar old_counter) counter_mem counter_offset $
            untyped $
              sExt32 (negate blocks_per_segment)

    sLoopNest (slugShape slug) $ \vec_is -> do
      unless (null $ slugShape slug) $
        sOp (Imp.Barrier Imp.FenceLocal)

      -- There is no guarantee that the number of threadblocks for the
      -- segment is less than the threadblock size, so each thread may
      -- have to read multiple elements.  We do this in a sequential
      -- way that may induce non-coalesced accesses, but the total
      -- number of accesses should be tiny here.
      --
      -- TODO: here we *could* insert a collective copy of the num_tblocks
      -- per-block results. However, it may not be beneficial, since num_tblocks
      -- is not necessarily larger than tblock_size, meaning the number of
      -- uncoalesced reads here may be insignificant. In fact, if we happen to
      -- have a num_tblocks < tblock_size, then the collective copy would add
      -- unnecessary overhead. Also, this code is only executed by a single
      -- block.
      sComment "read in the per-block-results" $ do
        read_per_thread <-
          dPrimVE "read_per_thread" $
            blocks_per_segment `divUp` sExt64 tblock_size

        forM2_ acc_params nes $ \p ne ->
          copyDWIM (paramName p) [] ne []

        sFor "i" read_per_thread $ \i -> do
          block_res_id <-
            dPrimVE "block_res_id" $
              ltid * read_per_thread + i
          index_of_block_res <-
            dPrimVE "index_of_block_res" $
              first_block_for_segment + block_res_id

          sWhen (block_res_id .<. blocks_per_segment) $ do
            forM2_ next_params block_res_arrs $ \p block_res_arr ->
              copyDWIMFix
                (paramName p)
                []
                (Var block_res_arr)
                ([0, index_of_block_res] ++ vec_is)

            compileStms mempty (bodyStms $ slugBody slug) $
              forM2_ acc_params (map resSubExp $ bodyResult $ slugBody slug) $ \p se ->
                copyDWIMFix (paramName p) [] se []

      forM2_ acc_params red_arrs $ \p arr ->
        when (isPrimParam p) $
          copyDWIMFix arr [ltid] (Var $ paramName p) []

      sOp $ Imp.ErrorSync Imp.FenceLocal

      sComment "reduce the per-block results" $ do
        blockReduce (sExt32 tblock_size) new_lambda red_arrs

        sComment "and back to memory with the final result" $
          sWhen (ltid32 .==. 0) $
            forM2_ segred_pes (lambdaParams new_lambda) $ \pe p ->
              copyDWIMFix
                (patElemName pe)
                (segment_gtids ++ vec_is)
                (Var $ paramName p)
                []

-- Note [IntermArrays]
--
-- Intermediate memory for the nonsegmented and large segments non-commutative
-- reductions with all primitive parameters:
--
--   These kernels need shared memory for 1) the initial collective copy, 2) the
--   (virtualized) block reductions, and (TODO: this one not implemented yet!)
--   3) the final single-block collective copy. There are no dependencies
--   between these three stages, so we can reuse the same pool of local mem for
--   all three. These intermediates all go into local mem because of the
--   assumption of primitive parameter types.
--
--   Let `elem_sizes` be a list of element type sizes for the reduction
--   operators in a given redomap fusion. Then the amount of local mem needed
--   across the three steps is:
--
--   1) The initial collective copy from global to thread-private memory
--   requires `tblock_size * CHUNK * max elem_sizes`, since the collective copies
--   are performed in sequence (ie. inputs to different reduction operators need
--   not be held in local mem simultaneously).
-- 2) The intra-block reductions of shared memory held per-thread results -- require `tblock_size * sum elem_sizes` bytes, since per-thread results for -- all fused reductions are block-reduced simultaneously. -- 3) If tblock_size < num_tblocks, then after the final single-block collective -- copy, a thread-sequential reduction reduces the number of per-block partial -- results from num_tblocks down to tblock_size for each reduction array, such -- that they will each fit in the final intra-block reduction. This requires -- `num_tblocks * max elem_sizes`. -- -- In summary, the total amount of local mem needed is the maximum between: -- 1) initial collective copy: tblock_size * CHUNK * max elem_sizes -- 2) intra-block reductions: tblock_size * sum elem_sizes -- 3) final collective copy: num_tblocks * max elem_sizes -- -- The amount of local mem will most likely be decided by 1) in most cases, -- unless the number of fused operators is very high *or* if we have a -- `num_tblocks > tblock_size * CHUNK`, but this is unlikely, in which case 2) -- and 3), respectively, will dominate. -- -- Aside from shared memory, these kernels also require a CHUNK-sized array of -- thread-private register memory per reduction operator. -- -- For all other reductions, ie. commutative reductions, reductions with at -- least one non-primitive operator, and small segments reductions: -- -- These kernels use shared memory only for the intra-block reductions, and -- since they do not use chunking or CHUNK, they all require onlly `tblock_size -- * max elem_sizes` bytes of shared memory and no thread-private register mem. futhark-0.25.27/src/Futhark/CodeGen/ImpGen/GPU/SegScan.hs000066400000000000000000000057061475065116200225420ustar00rootroot00000000000000-- | Code generation for 'SegScan'. Dispatches to either a -- single-pass or two-pass implementation, depending on the nature of -- the scan and the chosen abckend. 
module Futhark.CodeGen.ImpGen.GPU.SegScan (compileSegScan) where

import Control.Monad
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen hiding (compileProg)
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.SegScan.SinglePass qualified as SinglePass
import Futhark.CodeGen.ImpGen.GPU.SegScan.TwoPass qualified as TwoPass
import Futhark.IR.GPUMem

-- The single-pass scan does not support multiple operators, so jam
-- them together here.  The fused lambda takes the accumulator
-- parameters of every operator (in order), followed by the next-value
-- parameters of every operator, and returns all results concatenated.
combineScanOps :: [SegBinOp GPUMem] -> SegBinOp GPUMem
combineScanOps ops =
  SegBinOp
    { segBinOpComm = mconcat $ map segBinOpComm ops,
      segBinOpLambda = jammed,
      segBinOpNeutral = ops >>= segBinOpNeutral,
      segBinOpShape = mempty -- Assumed
    }
  where
    lams = map segBinOpLambda ops
    bodies = map lambdaBody lams
    -- An operator's first half of parameters (one per return value) are
    -- the accumulators; the remainder are the next values.
    splitLamParams lam = splitAt (length $ lambdaReturnType lam) (lambdaParams lam)
    (acc_params, next_params) = unzip $ map splitLamParams lams
    jammed =
      Lambda
        { lambdaParams = concat acc_params ++ concat next_params,
          lambdaReturnType = lams >>= lambdaReturnType,
          lambdaBody = Body () (mconcat $ map bodyStms bodies) (bodies >>= bodyResult)
        }

-- | Does any expression in the body, including those in nested bodies,
-- satisfy the predicate?
bodyHas :: (Exp GPUMem -> Bool) -> Body GPUMem -> Bool
bodyHas p body = any (expHas . stmExp) (bodyStms body)
  where
    -- The walker aborts (yields 'Nothing') as soon as a nested body
    -- contains a matching expression.
    expHas e = p e || isNothing (walkExpM nestedWalker e)
    nestedWalker =
      identityWalker
        { walkOnBody = \_ nested -> when (bodyHas p nested) Nothing
        }

-- | The operators can use the single-pass scan if each is scalar (empty
-- shape), returns only primitive values, and contains no assertions; in
-- that case they are fused with 'combineScanOps'.
canBeSinglePass :: [SegBinOp GPUMem] -> Maybe (SegBinOp GPUMem)
canBeSinglePass scan_ops = do
  guard $ all suitable scan_ops
  pure $ combineScanOps scan_ops
  where
    suitable op =
      let lam = segBinOpLambda op
       in segBinOpShape op == mempty
            && all primType (lambdaReturnType lam)
            && not (bodyHas isAssert (lambdaBody lam))
    isAssert e = case e of
      BasicOp Assert {} -> True
      _ -> False

-- | Compile 'SegScan' instance to host-level code with calls to
-- various kernels.
compileSegScan :: Pat LetDecMem -> SegLevel -> SegSpace -> [SegBinOp GPUMem] -> KernelBody GPUMem -> CallKernelGen () compileSegScan pat lvl space scan_ops map_kbody = sWhen (0 .<. n) $ do emit $ Imp.DebugPrint "\n# SegScan" Nothing target <- hostTarget <$> askEnv case (targetSupportsSinglePass target, canBeSinglePass scan_ops) of (True, Just scan_ops') -> SinglePass.compileSegScan pat lvl space scan_ops' map_kbody _ -> TwoPass.compileSegScan pat lvl space scan_ops map_kbody emit $ Imp.DebugPrint "" Nothing where n = product $ map pe64 $ segSpaceDims space targetSupportsSinglePass CUDA = True targetSupportsSinglePass HIP = True targetSupportsSinglePass _ = False futhark-0.25.27/src/Futhark/CodeGen/ImpGen/GPU/SegScan/000077500000000000000000000000001475065116200221765ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/CodeGen/ImpGen/GPU/SegScan/SinglePass.hs000066400000000000000000000627011475065116200246100ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | Code generation for segmented and non-segmented scans. Uses a -- fast single-pass algorithm, but which only works on NVIDIA GPUs and -- with some constraints on the operator. We use this when we can. 
module Futhark.CodeGen.ImpGen.GPU.SegScan.SinglePass (compileSegScan) where import Control.Monad import Data.List (zip4, zip7) import Data.Map qualified as M import Data.Maybe import Futhark.CodeGen.ImpCode.GPU qualified as Imp import Futhark.CodeGen.ImpGen import Futhark.CodeGen.ImpGen.GPU.Base import Futhark.IR.GPUMem import Futhark.IR.Mem.LMAD qualified as LMAD import Futhark.Transform.Rename import Futhark.Util (mapAccumLM, takeLast) import Futhark.Util.IntegralExp (IntegralExp (mod, rem), divUp, nextMul, quot) import Prelude hiding (mod, quot, rem) xParams, yParams :: SegBinOp GPUMem -> [LParam GPUMem] xParams scan = take (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan)) yParams scan = drop (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan)) createLocalArrays :: Count BlockSize SubExp -> SubExp -> [PrimType] -> InKernelGen (VName, [VName], [VName], VName, [VName]) createLocalArrays (Count block_size) chunk types = do let block_sizeE = pe64 block_size workSize = pe64 chunk * block_sizeE prefixArraysSize = foldl (\acc tySize -> nextMul acc tySize + tySize * block_sizeE) 0 $ map primByteSize types maxTransposedArraySize = foldl1 sMax64 $ map (\ty -> workSize * primByteSize ty) types warp_size :: (Num a) => a warp_size = 32 maxWarpExchangeSize = foldl (\acc tySize -> nextMul acc tySize + tySize * fromInteger warp_size) 0 $ map primByteSize types maxLookbackSize = maxWarpExchangeSize + warp_size size = Imp.bytes $ maxLookbackSize `sMax64` prefixArraysSize `sMax64` maxTransposedArraySize (_, byteOffsets) <- mapAccumLM ( \off tySize -> do off' <- dPrimVE "byte_offsets" $ nextMul off tySize + pe64 block_size * tySize pure (off', off) ) 0 $ map primByteSize types (_, warpByteOffsets) <- mapAccumLM ( \off tySize -> do off' <- dPrimVE "warp_byte_offset" $ nextMul off tySize + warp_size * tySize pure (off', off) ) warp_size $ map primByteSize types sComment "Allocate reusable shared memory" $ pure () localMem <- sAlloc "local_mem" size 
(Space "shared") transposeArrayLength <- dPrimV "trans_arr_len" workSize sharedId <- sArrayInMem "shared_id" int32 (Shape [constant (1 :: Int32)]) localMem transposedArrays <- forM types $ \ty -> sArrayInMem "local_transpose_arr" ty (Shape [tvSize transposeArrayLength]) localMem prefixArrays <- forM (zip byteOffsets types) $ \(off, ty) -> do let off' = off `quot` primByteSize ty sArray "local_prefix_arr" ty (Shape [block_size]) localMem $ LMAD.iota off' [pe64 block_size] warpscan <- sArrayInMem "warpscan" int8 (Shape [constant (warp_size :: Int64)]) localMem warpExchanges <- forM (zip warpByteOffsets types) $ \(off, ty) -> do let off' = off `quot` primByteSize ty sArray "warp_exchange" ty (Shape [constant (warp_size :: Int64)]) localMem $ LMAD.iota off' [warp_size] pure (sharedId, transposedArrays, prefixArrays, warpscan, warpExchanges) statusX, statusA, statusP :: (Num a) => a statusX = 0 statusA = 1 statusP = 2 inBlockScanLookback :: KernelConstants -> Imp.TExp Int64 -> VName -> [VName] -> Lambda GPUMem -> InKernelGen () inBlockScanLookback constants arrs_full_size flag_arr arrs scan_lam = everythingVolatile $ do flg_x :: TV Int8 <- dPrim "flg_x" flg_y :: TV Int8 <- dPrim "flg_y" let flg_param_x = Param mempty (tvVar flg_x) (MemPrim p_int8) flg_param_y = Param mempty (tvVar flg_y) (MemPrim p_int8) flg_y_exp = tvExp flg_y statusP_e = statusP :: Imp.TExp Int8 statusX_e = statusX :: Imp.TExp Int8 dLParams (lambdaParams scan_lam) skip_threads <- dPrim "skip_threads" let in_block_thread_active = tvExp skip_threads .<=. 
in_block_id actual_params = lambdaParams scan_lam (x_params, y_params) = splitAt (length actual_params `div` 2) actual_params y_to_x = forM_ (zip x_params y_params) $ \(x, y) -> when (primType (paramType x)) $ copyDWIM (paramName x) [] (Var (paramName y)) [] y_to_x_flg = copyDWIM (tvVar flg_x) [] (Var (tvVar flg_y)) [] -- Set initial y values sComment "read input for in-block scan" $ do zipWithM_ readInitial (flg_param_y : y_params) (flag_arr : arrs) -- Since the final result is expected to be in x_params, we may -- need to copy it there for the first thread in the block. sWhen (in_block_id .==. 0) $ do y_to_x y_to_x_flg when array_scan barrier let op_to_x = do sIf (flg_y_exp .==. statusP_e .||. flg_y_exp .==. statusX_e) ( do y_to_x_flg y_to_x ) (compileBody' x_params $ lambdaBody scan_lam) sComment "in-block scan (hopefully no barriers needed)" $ do skip_threads <-- 1 sWhile (tvExp skip_threads .<. block_size) $ do sWhen in_block_thread_active $ do sComment "read operands" $ zipWithM_ (readParam (sExt64 $ tvExp skip_threads)) (flg_param_x : x_params) (flag_arr : arrs) sComment "perform operation" op_to_x sComment "write result" $ sequence_ $ zipWith3 writeResult (flg_param_x : x_params) (flg_param_y : y_params) (flag_arr : arrs) skip_threads <-- tvExp skip_threads * 2 where p_int8 = IntType Int8 block_size = 32 block_id = ltid32 `quot` block_size in_block_id = ltid32 - block_id * block_size ltid32 = kernelLocalThreadId constants ltid = sExt64 ltid32 gtid = sExt64 $ kernelGlobalThreadId constants array_scan = not $ all primType $ lambdaReturnType scan_lam barrier | array_scan = sOp $ Imp.Barrier Imp.FenceGlobal | otherwise = sOp $ Imp.Barrier Imp.FenceLocal readInitial p arr | primType $ paramType p = copyDWIMFix (paramName p) [] (Var arr) [ltid] | otherwise = copyDWIMFix (paramName p) [] (Var arr) [gtid] readParam behind p arr | primType $ paramType p = copyDWIMFix (paramName p) [] (Var arr) [ltid - behind] | otherwise = copyDWIMFix (paramName p) [] (Var arr) 
[gtid - behind + arrs_full_size] writeResult x y arr = do when (isPrimParam x) $ copyDWIMFix arr [ltid] (Var $ paramName x) [] copyDWIM (paramName y) [] (Var $ paramName x) [] -- | Compile 'SegScan' instance to host-level code with calls to a -- single-pass kernel. compileSegScan :: Pat LetDecMem -> SegLevel -> SegSpace -> SegBinOp GPUMem -> KernelBody GPUMem -> CallKernelGen () compileSegScan pat lvl space scan_op map_kbody = do attrs <- lvlKernelAttrs lvl let Pat all_pes = pat scanop_nes = segBinOpNeutral scan_op n = product $ map pe64 $ segSpaceDims space tys' = lambdaReturnType $ segBinOpLambda scan_op tys = map elemType tys' tblock_size_e = pe64 $ unCount $ kAttrBlockSize attrs num_phys_blocks_e = pe64 $ unCount $ kAttrNumBlocks attrs let chunk_const = getChunkSize tys' chunk_v <- dPrimV "chunk_size" . isInt64 =<< kernelConstToExp chunk_const let chunk = tvExp chunk_v num_virt_blocks <- tvSize <$> dPrimV "num_virt_blocks" (n `divUp` (tblock_size_e * chunk)) let num_virt_blocks_e = pe64 num_virt_blocks num_virt_threads <- dPrimVE "num_virt_threads" $ num_virt_blocks_e * tblock_size_e let (gtids, dims) = unzip $ unSegSpace space dims' = map pe64 dims segmented = length dims' > 1 not_segmented_e = fromBool $ not segmented segment_size = last dims' emit $ Imp.DebugPrint "Sequential elements per thread (chunk)" $ Just $ untyped chunk statusFlags <- sAllocArray "status_flags" int8 (Shape [num_virt_blocks]) (Space "device") sReplicate statusFlags $ intConst Int8 statusX (aggregateArrays, incprefixArrays) <- fmap unzip $ forM tys $ \ty -> (,) <$> sAllocArray "aggregates" ty (Shape [num_virt_blocks]) (Space "device") <*> sAllocArray "incprefixes" ty (Shape [num_virt_blocks]) (Space "device") global_id <- genZeroes "global_dynid" 1 let attrs' = attrs {kAttrConstExps = M.singleton (tvVar chunk_v) chunk_const} sKernelThread "segscan" (segFlat space) attrs' $ do chunk32 <- dPrimVE "chunk_size_32b" $ sExt32 $ tvExp chunk_v constants <- kernelConstants <$> askEnv let ltid32 
= kernelLocalThreadId constants ltid = sExt64 ltid32 (sharedId, transposedArrays, prefixArrays, warpscan, exchanges) <- createLocalArrays (kAttrBlockSize attrs) (tvSize chunk_v) tys -- We wrap the entire kernel body in a virtualisation loop to -- handle the case where we do not have enough thread blocks to -- cover the iteration space. Dynamic block indexing has no -- implication on this, since each block simply fetches a new -- dynamic ID upon entry into the virtualisation loop. -- -- We could use virtualiseBlocks, but this introduces a barrier which is -- redundant in this case, and also we don't need to base virtual block IDs -- on the loop variable, but rather on the dynamic IDs. phys_block_id <- dPrim "phys_block_id" sOp $ Imp.GetBlockId (tvVar phys_block_id) 0 iters <- dPrimVE "virtloop_bound" $ (num_virt_blocks_e - tvExp phys_block_id) `divUp` num_phys_blocks_e sFor "virtloop_i" iters $ const $ do dyn_id <- dPrim "dynamic_id" sComment "First thread in block fetches this block's dynamic_id" $ do sWhen (ltid32 .==. 0) $ do (globalIdMem, _, globalIdOff) <- fullyIndexArray global_id [0] sOp $ Imp.Atomic DefaultSpace $ Imp.AtomicAdd Int32 (tvVar dyn_id) globalIdMem (Count $ unCount globalIdOff) (untyped (1 :: Imp.TExp Int32)) sComment "Set dynamic id for this block" $ do copyDWIMFix sharedId [0] (tvSize dyn_id) [] sComment "First thread in last (virtual) block resets global dynamic_id" $ do sWhen (tvExp dyn_id .==. 
num_virt_blocks_e - 1) $ copyDWIMFix global_id [0] (intConst Int32 0) [] let local_barrier = Imp.Barrier Imp.FenceLocal local_fence = Imp.MemFence Imp.FenceLocal global_fence = Imp.MemFence Imp.FenceGlobal sOp local_barrier copyDWIMFix (tvVar dyn_id) [] (Var sharedId) [0] sOp local_barrier block_offset <- dPrimVE "block_offset" $ sExt64 (tvExp dyn_id) * chunk * tblock_size_e sgm_idx <- dPrimVE "sgm_idx" $ block_offset `mod` segment_size boundary <- dPrimVE "boundary" $ sExt32 $ sMin64 (chunk * tblock_size_e) (segment_size - sgm_idx) segsize_compact <- dPrimVE "segsize_compact" $ sExt32 $ sMin64 (chunk * tblock_size_e) segment_size private_chunks <- forM tys $ \ty -> sAllocArray "private" ty (Shape [tvSize chunk_v]) (ScalarSpace [tvSize chunk_v] ty) thd_offset <- dPrimVE "thd_offset" $ block_offset + ltid sComment "Load and map" $ sFor "i" chunk $ \i -> do -- The map's input index virt_tid <- dPrimVE "virt_tid" $ thd_offset + i * tblock_size_e dIndexSpace (zip gtids dims') virt_tid -- Perform the map let in_bounds = compileStms mempty (kernelBodyStms map_kbody) $ do let (all_scan_res, map_res) = splitAt (segBinOpResults [scan_op]) $ kernelBodyResult map_kbody -- Write map results to their global memory destinations forM_ (zip (takeLast (length map_res) all_pes) map_res) $ \(dest, src) -> copyDWIMFix (patElemName dest) (map Imp.le64 gtids) (kernelResultSubExp src) [] -- Write to-scan results to private memory. forM_ (zip private_chunks $ map kernelResultSubExp all_scan_res) $ \(dest, src) -> copyDWIMFix dest [i] src [] out_of_bounds = forM_ (zip private_chunks scanop_nes) $ \(dest, ne) -> copyDWIMFix dest [i] ne [] sIf (virt_tid .<. 
n) in_bounds out_of_bounds sOp $ Imp.ErrorSync Imp.FenceLocal sComment "Transpose scan inputs" $ do forM_ (zip transposedArrays private_chunks) $ \(trans, priv) -> do sFor "i" chunk $ \i -> do sharedIdx <- dPrimVE "sharedIdx" $ ltid + i * tblock_size_e copyDWIMFix trans [sharedIdx] (Var priv) [i] sOp local_barrier sFor "i" chunk $ \i -> do sharedIdx <- dPrimV "sharedIdx" $ ltid * chunk + i copyDWIMFix priv [sExt64 i] (Var trans) [sExt64 $ tvExp sharedIdx] sOp local_barrier sComment "Per thread scan" $ do -- We don't need to touch the first element, so only m-1 -- iterations here. sFor "i" (chunk - 1) $ \i -> do let xs = map paramName $ xParams scan_op ys = map paramName $ yParams scan_op -- determine if start of segment new_sgm <- if segmented then do gidx <- dPrimVE "gidx" $ (ltid32 * chunk32) + 1 dPrimVE "new_sgm" $ (gidx + sExt32 i - boundary) `mod` segsize_compact .==. 0 else pure false -- skip scan of first element in segment sUnless new_sgm $ do forM_ (zip4 private_chunks xs ys tys) $ \(src, x, y, ty) -> do dPrim_ x ty dPrim_ y ty copyDWIMFix x [] (Var src) [i] copyDWIMFix y [] (Var src) [i + 1] compileStms mempty (bodyStms $ lambdaBody $ segBinOpLambda scan_op) $ forM_ (zip private_chunks $ map resSubExp $ bodyResult $ lambdaBody $ segBinOpLambda scan_op) $ \(dest, res) -> copyDWIMFix dest [i + 1] res [] sComment "Publish results in shared memory" $ do forM_ (zip prefixArrays private_chunks) $ \(dest, src) -> copyDWIMFix dest [ltid] (Var src) [chunk - 1] sOp local_barrier let crossesSegment = do guard segmented Just $ \from to -> let from' = (from + 1) * chunk32 - 1 to' = (to + 1) * chunk32 - 1 in (to' - from') .>. 
(to' + segsize_compact - boundary) `mod` segsize_compact scan_op1 <- renameLambda $ segBinOpLambda scan_op accs <- mapM (dPrimSV "acc") tys sComment "Scan results (with warp scan)" $ do blockScan crossesSegment tblock_size_e num_virt_threads scan_op1 prefixArrays sOp $ Imp.ErrorSync Imp.FenceLocal let firstThread acc prefixes = copyDWIMFix (tvVar acc) [] (Var prefixes) [sExt64 tblock_size_e - 1] notFirstThread acc prefixes = copyDWIMFix (tvVar acc) [] (Var prefixes) [ltid - 1] sIf (ltid32 .==. 0) (zipWithM_ firstThread accs prefixArrays) (zipWithM_ notFirstThread accs prefixArrays) sOp local_barrier prefixes <- forM (zip scanop_nes tys) $ \(ne, ty) -> dPrimV "prefix" $ TPrimExp $ toExp' ty ne blockNewSgm <- dPrimVE "block_new_sgm" $ sgm_idx .==. 0 sComment "Perform lookback" $ do sWhen (blockNewSgm .&&. ltid32 .==. 0) $ do everythingVolatile $ forM_ (zip accs incprefixArrays) $ \(acc, incprefixArray) -> copyDWIMFix incprefixArray [tvExp dyn_id] (tvSize acc) [] sOp global_fence everythingVolatile $ copyDWIMFix statusFlags [tvExp dyn_id] (intConst Int8 statusP) [] forM_ (zip scanop_nes accs) $ \(ne, acc) -> copyDWIMFix (tvVar acc) [] ne [] -- end sWhen let warp_size = kernelWaveSize constants sWhen (bNot blockNewSgm .&&. ltid32 .<. warp_size) $ do sWhen (ltid32 .==. 0) $ do sIf (not_segmented_e .||. boundary .==. 
sExt32 (tblock_size_e * chunk)) ( do everythingVolatile $ forM_ (zip aggregateArrays accs) $ \(aggregateArray, acc) -> copyDWIMFix aggregateArray [tvExp dyn_id] (tvSize acc) [] sOp global_fence everythingVolatile $ copyDWIMFix statusFlags [tvExp dyn_id] (intConst Int8 statusA) [] ) ( do everythingVolatile $ forM_ (zip incprefixArrays accs) $ \(incprefixArray, acc) -> copyDWIMFix incprefixArray [tvExp dyn_id] (tvSize acc) [] sOp global_fence everythingVolatile $ copyDWIMFix statusFlags [tvExp dyn_id] (intConst Int8 statusP) [] ) everythingVolatile $ copyDWIMFix warpscan [0] (Var statusFlags) [tvExp dyn_id - 1] -- sWhen sOp local_fence status :: TV Int8 <- dPrim "status" copyDWIMFix (tvVar status) [] (Var warpscan) [0] sIf (tvExp status .==. statusP) ( sWhen (ltid32 .==. 0) $ everythingVolatile $ forM_ (zip prefixes incprefixArrays) $ \(prefix, incprefixArray) -> copyDWIMFix (tvVar prefix) [] (Var incprefixArray) [tvExp dyn_id - 1] ) ( do readOffset <- dPrimV "readOffset" $ sExt32 $ tvExp dyn_id - sExt64 (kernelWaveSize constants) let loopStop = warp_size * (-1) sameSegment readIdx | segmented = let startIdx = sExt64 (tvExp readIdx + 1) * tblock_size_e * chunk - 1 in block_offset - startIdx .<=. sgm_idx | otherwise = true sWhile (tvExp readOffset .>. loopStop) $ do readI <- dPrimV "read_i" $ tvExp readOffset + ltid32 aggrs <- forM (zip scanop_nes tys) $ \(ne, ty) -> dPrimV "aggr" $ TPrimExp $ toExp' ty ne flag <- dPrimV "flag" (statusX :: Imp.TExp Int8) everythingVolatile . sWhen (tvExp readI .>=. 0) $ do sIf (sameSegment readI) ( do copyDWIMFix (tvVar flag) [] (Var statusFlags) [sExt64 $ tvExp readI] sIf (tvExp flag .==. statusP) ( forM_ (zip incprefixArrays aggrs) $ \(incprefix, aggr) -> copyDWIMFix (tvVar aggr) [] (Var incprefix) [sExt64 $ tvExp readI] ) ( sWhen (tvExp flag .==. 
statusA) $ do forM_ (zip aggrs aggregateArrays) $ \(aggr, aggregate) -> copyDWIMFix (tvVar aggr) [] (Var aggregate) [sExt64 $ tvExp readI] ) ) (copyDWIMFix (tvVar flag) [] (intConst Int8 statusP) []) -- end sIf -- end sWhen forM_ (zip exchanges aggrs) $ \(exchange, aggr) -> copyDWIMFix exchange [ltid] (tvSize aggr) [] copyDWIMFix warpscan [ltid] (tvSize flag) [] -- execute warp-parallel reduction but only if the last read flag in not STATUS_P copyDWIMFix (tvVar flag) [] (Var warpscan) [sExt64 warp_size - 1] sWhen (tvExp flag .<. statusP) $ do lam' <- renameLambda scan_op1 inBlockScanLookback constants num_virt_threads warpscan exchanges lam' -- all threads of the warp read the result of reduction copyDWIMFix (tvVar flag) [] (Var warpscan) [sExt64 warp_size - 1] forM_ (zip aggrs exchanges) $ \(aggr, exchange) -> copyDWIMFix (tvVar aggr) [] (Var exchange) [sExt64 warp_size - 1] -- update read offset sIf (tvExp flag .==. statusP) (readOffset <-- loopStop) ( sWhen (tvExp flag .==. statusA) $ do readOffset <-- tvExp readOffset - zExt32 warp_size ) -- update prefix if flag different than STATUS_X: sWhen (tvExp flag .>. statusX) $ do lam <- renameLambda scan_op1 let (xs, ys) = splitAt (length tys) $ map paramName $ lambdaParams lam forM_ (zip xs aggrs) $ \(x, aggr) -> dPrimV_ x (tvExp aggr) forM_ (zip ys prefixes) $ \(y, prefix) -> dPrimV_ y (tvExp prefix) compileStms mempty (bodyStms $ lambdaBody lam) $ forM_ (zip3 prefixes tys $ map resSubExp $ bodyResult $ lambdaBody lam) $ \(prefix, ty, res) -> prefix <-- TPrimExp (toExp' ty res) sOp local_fence ) -- end sWhile -- end sIf sWhen (ltid32 .==. 0) $ do scan_op2 <- renameLambda scan_op1 let xs = map paramName $ take (length tys) $ lambdaParams scan_op2 ys = map paramName $ drop (length tys) $ lambdaParams scan_op2 sWhen (boundary .==. 
sExt32 (tblock_size_e * chunk)) $ do forM_ (zip xs prefixes) $ \(x, prefix) -> dPrimV_ x $ tvExp prefix forM_ (zip ys accs) $ \(y, acc) -> dPrimV_ y $ tvExp acc compileStms mempty (bodyStms $ lambdaBody scan_op2) $ everythingVolatile $ forM_ (zip incprefixArrays $ map resSubExp $ bodyResult $ lambdaBody scan_op2) $ \(incprefixArray, res) -> copyDWIMFix incprefixArray [tvExp dyn_id] res [] sOp global_fence everythingVolatile $ copyDWIMFix statusFlags [tvExp dyn_id] (intConst Int8 statusP) [] forM_ (zip exchanges prefixes) $ \(exchange, prefix) -> copyDWIMFix exchange [0] (tvSize prefix) [] forM_ (zip3 accs tys scanop_nes) $ \(acc, ty, ne) -> tvVar acc <~~ toExp' ty ne -- end sWhen -- end sWhen sWhen (bNot $ tvExp dyn_id .==. 0) $ do sOp local_barrier forM_ (zip exchanges prefixes) $ \(exchange, prefix) -> copyDWIMFix (tvVar prefix) [] (Var exchange) [0] sOp local_barrier -- end sWhen -- end sComment scan_op3 <- renameLambda scan_op1 scan_op4 <- renameLambda scan_op1 sComment "Distribute results" $ do let (xs, ys) = splitAt (length tys) $ map paramName $ lambdaParams scan_op3 (xs', ys') = splitAt (length tys) $ map paramName $ lambdaParams scan_op4 forM_ (zip7 prefixes accs xs xs' ys ys' tys) $ \(prefix, acc, x, x', y, y', ty) -> do dPrim_ x ty dPrim_ y ty dPrimV_ x' $ tvExp prefix dPrimV_ y' $ tvExp acc sIf (ltid32 * chunk32 .<. boundary .&&. bNot blockNewSgm) ( compileStms mempty (bodyStms $ lambdaBody scan_op4) $ forM_ (zip3 xs tys $ map resSubExp $ bodyResult $ lambdaBody scan_op4) $ \(x, ty, res) -> x <~~ toExp' ty res ) (forM_ (zip xs accs) $ \(x, acc) -> copyDWIMFix x [] (Var $ tvVar acc) []) -- calculate where previous thread stopped, to determine number of -- elements left before new segment. stop <- dPrimVE "stopping_point" $ segsize_compact - (ltid32 * chunk32 - 1 + segsize_compact - boundary) `rem` segsize_compact sFor "i" chunk $ \i -> do sWhen (sExt32 i .<. 
stop - 1) $ do forM_ (zip private_chunks ys) $ \(src, y) -> -- only include prefix for the first segment part per thread copyDWIMFix y [] (Var src) [i] compileStms mempty (bodyStms $ lambdaBody scan_op3) $ forM_ (zip private_chunks $ map resSubExp $ bodyResult $ lambdaBody scan_op3) $ \(dest, res) -> copyDWIMFix dest [i] res [] sComment "Transpose scan output and Write it to global memory in coalesced fashion" $ do forM_ (zip3 transposedArrays private_chunks $ map patElemName all_pes) $ \(locmem, priv, dest) -> do -- sOp local_barrier sFor "i" chunk $ \i -> do sharedIdx <- dPrimV "sharedIdx" $ sExt64 (ltid * chunk) + i copyDWIMFix locmem [tvExp sharedIdx] (Var priv) [i] sOp local_barrier sFor "i" chunk $ \i -> do flat_idx <- dPrimVE "flat_idx" $ thd_offset + i * tblock_size_e dIndexSpace (zip gtids dims') flat_idx sWhen (flat_idx .<. n) $ do copyDWIMFix dest (map Imp.le64 gtids) (Var locmem) [sExt64 $ flat_idx - block_offset] sOp local_barrier {-# NOINLINE compileSegScan #-} futhark-0.25.27/src/Futhark/CodeGen/ImpGen/GPU/SegScan/TwoPass.hs000066400000000000000000000456761475065116200241540ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | Code generation for segmented and non-segmented scans. Uses a -- fairly inefficient two-pass algorithm, but can handle anything. module Futhark.CodeGen.ImpGen.GPU.SegScan.TwoPass (compileSegScan) where import Control.Monad import Control.Monad.State import Data.List qualified as L import Data.Maybe import Futhark.CodeGen.ImpCode.GPU qualified as Imp import Futhark.CodeGen.ImpGen import Futhark.CodeGen.ImpGen.GPU.Base import Futhark.IR.GPUMem import Futhark.IR.Mem.LMAD qualified as LMAD import Futhark.Transform.Rename import Futhark.Util (takeLast) import Futhark.Util.IntegralExp (divUp, quot, rem) import Prelude hiding (quot, rem) -- Aggressively try to reuse memory for different SegBinOps, because -- we will run them sequentially after another. 
makeLocalArrays :: Count BlockSize SubExp -> SubExp -> [SegBinOp GPUMem] -> InKernelGen [[VName]] makeLocalArrays (Count tblock_size) num_threads scans = do (arrs, mems_and_sizes) <- runStateT (mapM onScan scans) mempty let maxSize sizes = Imp.bytes $ L.foldl' sMax64 1 $ map Imp.unCount sizes forM_ mems_and_sizes $ \(sizes, mem) -> sAlloc_ mem (maxSize sizes) (Space "shared") pure arrs where onScan (SegBinOp _ scan_op nes _) = do let (scan_x_params, _scan_y_params) = splitAt (length nes) $ lambdaParams scan_op (arrs, used_mems) <- fmap unzip $ forM scan_x_params $ \p -> case paramDec p of MemArray pt shape _ (ArrayIn mem _) -> do let shape' = Shape [num_threads] <> shape arr <- lift . sArray "scan_arr" pt shape' mem $ LMAD.iota 0 (map pe64 $ shapeDims shape') pure (arr, []) _ -> do let pt = elemType $ paramType p shape = Shape [tblock_size] (sizes, mem') <- getMem pt shape arr <- lift $ sArrayInMem "scan_arr" pt shape mem' pure (arr, [(sizes, mem')]) modify (<> concat used_mems) pure arrs getMem pt shape = do let size = typeSize $ Array pt shape NoUniqueness mems <- get case (L.find ((size `elem`) . 
fst) mems, mems) of (Just mem, _) -> do modify $ L.delete mem pure mem (Nothing, (size', mem) : mems') -> do put mems' pure (size : size', mem) (Nothing, []) -> do mem <- lift $ sDeclareMem "scan_arr_mem" $ Space "shared" pure ([size], mem) type CrossesSegment = Maybe (Imp.TExp Int64 -> Imp.TExp Int64 -> Imp.TExp Bool) localArrayIndex :: KernelConstants -> Type -> Imp.TExp Int64 localArrayIndex constants t = if primType t then sExt64 (kernelLocalThreadId constants) else sExt64 (kernelGlobalThreadId constants) barrierFor :: Lambda GPUMem -> (Bool, Imp.Fence, InKernelGen ()) barrierFor scan_op = (array_scan, fence, sOp $ Imp.Barrier fence) where array_scan = not $ all primType $ lambdaReturnType scan_op fence | array_scan = Imp.FenceGlobal | otherwise = Imp.FenceLocal xParams, yParams :: SegBinOp GPUMem -> [LParam GPUMem] xParams scan = take (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan)) yParams scan = drop (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan)) writeToScanValues :: [VName] -> ([PatElem LetDecMem], SegBinOp GPUMem, [KernelResult]) -> InKernelGen () writeToScanValues gtids (pes, scan, scan_res) | shapeRank (segBinOpShape scan) > 0 = forM_ (zip pes scan_res) $ \(pe, res) -> copyDWIMFix (patElemName pe) (map Imp.le64 gtids) (kernelResultSubExp res) [] | otherwise = forM_ (zip (yParams scan) scan_res) $ \(p, res) -> copyDWIMFix (paramName p) [] (kernelResultSubExp res) [] readToScanValues :: [Imp.TExp Int64] -> [PatElem LetDecMem] -> SegBinOp GPUMem -> InKernelGen () readToScanValues is pes scan | shapeRank (segBinOpShape scan) > 0 = forM_ (zip (yParams scan) pes) $ \(p, pe) -> copyDWIMFix (paramName p) [] (Var (patElemName pe)) is | otherwise = pure () readCarries :: Imp.TExp Int64 -> Imp.TExp Int64 -> [Imp.TExp Int64] -> [Imp.TExp Int64] -> [PatElem LetDecMem] -> SegBinOp GPUMem -> InKernelGen () readCarries chunk_id chunk_offset dims' vec_is pes scan | shapeRank (segBinOpShape scan) > 0 = do ltid <- 
kernelLocalThreadId . kernelConstants <$> askEnv -- We may have to reload the carries from the output of the -- previous chunk. sIf (chunk_id .>. 0 .&&. ltid .==. 0) ( do let is = unflattenIndex dims' $ chunk_offset - 1 forM_ (zip (xParams scan) pes) $ \(p, pe) -> copyDWIMFix (paramName p) [] (Var (patElemName pe)) (is ++ vec_is) ) ( forM_ (zip (xParams scan) (segBinOpNeutral scan)) $ \(p, ne) -> copyDWIMFix (paramName p) [] ne [] ) | otherwise = pure () -- | Produce partially scanned intervals; one per threadblock. scanStage1 :: Pat LetDecMem -> Count NumBlocks SubExp -> Count BlockSize SubExp -> SegSpace -> [SegBinOp GPUMem] -> KernelBody GPUMem -> CallKernelGen (TV Int32, Imp.TExp Int64, CrossesSegment) scanStage1 (Pat all_pes) num_tblocks tblock_size space scans kbody = do let num_tblocks' = fmap pe64 num_tblocks tblock_size' = fmap pe64 tblock_size num_threads <- dPrimV "num_threads" $ sExt32 $ unCount num_tblocks' * unCount tblock_size' let (gtids, dims) = unzip $ unSegSpace space dims' = map pe64 dims let num_elements = product dims' elems_per_thread = num_elements `divUp` sExt64 (tvExp num_threads) elems_per_group = unCount tblock_size' * elems_per_thread let crossesSegment = case reverse dims' of segment_size : _ : _ -> Just $ \from to -> (to - from) .>. (to `rem` segment_size) _ -> Nothing sKernelThread "scan_stage1" (segFlat space) (defKernelAttrs num_tblocks tblock_size) $ do constants <- kernelConstants <$> askEnv all_local_arrs <- makeLocalArrays tblock_size (tvSize num_threads) scans -- The variables from scan_op will be used for the carry and such -- in the big chunking loop. 
forM_ scans $ \scan -> do dScope Nothing $ scopeOfLParams $ lambdaParams $ segBinOpLambda scan forM_ (zip (xParams scan) (segBinOpNeutral scan)) $ \(p, ne) -> copyDWIMFix (paramName p) [] ne [] sFor "j" elems_per_thread $ \j -> do chunk_offset <- dPrimV "chunk_offset" $ sExt64 (kernelBlockSize constants) * j + sExt64 (kernelBlockId constants) * elems_per_group flat_idx <- dPrimV "flat_idx" $ tvExp chunk_offset + sExt64 (kernelLocalThreadId constants) -- Construct segment indices. zipWithM_ dPrimV_ gtids $ unflattenIndex dims' $ tvExp flat_idx let per_scan_pes = segBinOpChunks scans all_pes in_bounds = foldl1 (.&&.) $ zipWith (.<.) (map Imp.le64 gtids) dims' when_in_bounds = compileStms mempty (kernelBodyStms kbody) $ do let (all_scan_res, map_res) = splitAt (segBinOpResults scans) $ kernelBodyResult kbody per_scan_res = segBinOpChunks scans all_scan_res sComment "write to-scan values to parameters" $ mapM_ (writeToScanValues gtids) $ zip3 per_scan_pes scans per_scan_res sComment "write mapped values results to global memory" $ forM_ (zip (takeLast (length map_res) all_pes) map_res) $ \(pe, se) -> copyDWIMFix (patElemName pe) (map Imp.le64 gtids) (kernelResultSubExp se) [] sComment "threads in bounds read input" $ sWhen in_bounds when_in_bounds unless (all (null . 
segBinOpShape) scans) $ sOp $ Imp.Barrier Imp.FenceGlobal forM_ (zip3 per_scan_pes scans all_local_arrs) $ \(pes, scan@(SegBinOp _ scan_op nes vec_shape), local_arrs) -> sComment "do one intra-group scan operation" $ do let rets = lambdaReturnType scan_op scan_x_params = xParams scan (array_scan, fence, barrier) = barrierFor scan_op when array_scan barrier sLoopNest vec_shape $ \vec_is -> do sComment "maybe restore some to-scan values to parameters, or read neutral" $ sIf in_bounds ( do readToScanValues (map Imp.le64 gtids ++ vec_is) pes scan readCarries j (tvExp chunk_offset) dims' vec_is pes scan ) ( forM_ (zip (yParams scan) (segBinOpNeutral scan)) $ \(p, ne) -> copyDWIMFix (paramName p) [] ne [] ) sComment "combine with carry and write to shared memory" $ compileStms mempty (bodyStms $ lambdaBody scan_op) $ forM_ (zip3 rets local_arrs $ map resSubExp $ bodyResult $ lambdaBody scan_op) $ \(t, arr, se) -> copyDWIMFix arr [localArrayIndex constants t] se [] let crossesSegment' = do f <- crossesSegment Just $ \from to -> let from' = sExt64 from + tvExp chunk_offset to' = sExt64 to + tvExp chunk_offset in f from' to' sOp $ Imp.ErrorSync fence -- We need to avoid parameter name clashes. 
scan_op_renamed <- renameLambda scan_op blockScan crossesSegment' (sExt64 $ tvExp num_threads) (sExt64 $ kernelBlockSize constants) scan_op_renamed local_arrs sComment "threads in bounds write partial scan result" $ sWhen in_bounds $ forM_ (zip3 rets pes local_arrs) $ \(t, pe, arr) -> copyDWIMFix (patElemName pe) (map Imp.le64 gtids ++ vec_is) (Var arr) [localArrayIndex constants t] barrier let load_carry = forM_ (zip local_arrs scan_x_params) $ \(arr, p) -> copyDWIMFix (paramName p) [] (Var arr) [ if primType $ paramType p then sExt64 (kernelBlockSize constants) - 1 else (sExt64 (kernelBlockId constants) + 1) * sExt64 (kernelBlockSize constants) - 1 ] load_neutral = forM_ (zip nes scan_x_params) $ \(ne, p) -> copyDWIMFix (paramName p) [] ne [] sComment "first thread reads last element as carry-in for next iteration" $ do crosses_segment <- dPrimVE "crosses_segment" $ case crossesSegment of Nothing -> false Just f -> f ( tvExp chunk_offset + sExt64 (kernelBlockSize constants) - 1 ) ( tvExp chunk_offset + sExt64 (kernelBlockSize constants) ) should_load_carry <- dPrimVE "should_load_carry" $ kernelLocalThreadId constants .==. 0 .&&. bNot crosses_segment sWhen should_load_carry load_carry when array_scan barrier sUnless should_load_carry load_neutral barrier pure (num_threads, elems_per_group, crossesSegment) scanStage2 :: Pat LetDecMem -> TV Int32 -> Imp.TExp Int64 -> Count NumBlocks SubExp -> CrossesSegment -> SegSpace -> [SegBinOp GPUMem] -> CallKernelGen () scanStage2 (Pat all_pes) stage1_num_threads elems_per_group num_tblocks crossesSegment space scans = do let (gtids, dims) = unzip $ unSegSpace space dims' = map pe64 dims -- Our group size is the number of groups for the stage 1 kernel. 
let tblock_size = Count $ unCount num_tblocks let crossesSegment' = do f <- crossesSegment Just $ \from to -> f ((sExt64 from + 1) * elems_per_group - 1) ((sExt64 to + 1) * elems_per_group - 1) sKernelThread "scan_stage2" (segFlat space) (defKernelAttrs (Count (intConst Int64 1)) tblock_size) $ do constants <- kernelConstants <$> askEnv per_scan_local_arrs <- makeLocalArrays tblock_size (tvSize stage1_num_threads) scans let per_scan_rets = map (lambdaReturnType . segBinOpLambda) scans per_scan_pes = segBinOpChunks scans all_pes flat_idx <- dPrimV "flat_idx" $ (sExt64 (kernelLocalThreadId constants) + 1) * elems_per_group - 1 -- Construct segment indices. zipWithM_ dPrimV_ gtids $ unflattenIndex dims' $ tvExp flat_idx forM_ (L.zip4 scans per_scan_local_arrs per_scan_rets per_scan_pes) $ \(SegBinOp _ scan_op nes vec_shape, local_arrs, rets, pes) -> sLoopNest vec_shape $ \vec_is -> do let glob_is = map Imp.le64 gtids ++ vec_is in_bounds = foldl1 (.&&.) $ zipWith (.<.) (map Imp.le64 gtids) dims' when_in_bounds = forM_ (zip3 rets local_arrs pes) $ \(t, arr, pe) -> copyDWIMFix arr [localArrayIndex constants t] (Var $ patElemName pe) glob_is when_out_of_bounds = forM_ (zip3 rets local_arrs nes) $ \(t, arr, ne) -> copyDWIMFix arr [localArrayIndex constants t] ne [] (_, _, barrier) = barrierFor scan_op sComment "threads in bound read carries; others get neutral element" $ sIf in_bounds when_in_bounds when_out_of_bounds barrier blockScan crossesSegment' (sExt64 $ tvExp stage1_num_threads) (sExt64 $ kernelBlockSize constants) scan_op local_arrs sComment "threads in bounds write scanned carries" $ sWhen in_bounds $ forM_ (zip3 rets pes local_arrs) $ \(t, pe, arr) -> copyDWIMFix (patElemName pe) glob_is (Var arr) [localArrayIndex constants t] scanStage3 :: Pat LetDecMem -> Count NumBlocks SubExp -> Count BlockSize SubExp -> Imp.TExp Int64 -> CrossesSegment -> SegSpace -> [SegBinOp GPUMem] -> CallKernelGen () scanStage3 (Pat all_pes) num_tblocks tblock_size elems_per_group 
crossesSegment space scans = do let tblock_size' = fmap pe64 tblock_size (gtids, dims) = unzip $ unSegSpace space dims' = map pe64 dims required_groups <- dPrimVE "required_groups" $ sExt32 $ product dims' `divUp` sExt64 (unCount tblock_size') sKernelThread "scan_stage3" (segFlat space) (defKernelAttrs num_tblocks tblock_size) $ virtualiseBlocks SegVirt required_groups $ \virt_tblock_id -> do constants <- kernelConstants <$> askEnv -- Compute our logical index. flat_idx <- dPrimVE "flat_idx" $ sExt64 virt_tblock_id * sExt64 (unCount tblock_size') + sExt64 (kernelLocalThreadId constants) zipWithM_ dPrimV_ gtids $ unflattenIndex dims' flat_idx -- Figure out which group this element was originally in. orig_group <- dPrimV "orig_group" $ flat_idx `quot` elems_per_group -- Then the index of the carry-in of the preceding group. carry_in_flat_idx <- dPrimV "carry_in_flat_idx" $ tvExp orig_group * elems_per_group - 1 -- Figure out the logical index of the carry-in. let carry_in_idx = unflattenIndex dims' $ tvExp carry_in_flat_idx -- Apply the carry if we are not in the scan results for the first -- group, and are not the last element in such a group (because -- then the carry was updated in stage 2), and we are not crossing -- a segment boundary. let in_bounds = foldl1 (.&&.) $ zipWith (.<.) (map Imp.le64 gtids) dims' crosses_segment = fromMaybe false $ crossesSegment <*> pure (tvExp carry_in_flat_idx) <*> pure flat_idx is_a_carry = flat_idx .==. (tvExp orig_group + 1) * elems_per_group - 1 no_carry_in = tvExp orig_group .==. 0 .||. is_a_carry .||. 
crosses_segment let per_scan_pes = segBinOpChunks scans all_pes sWhen in_bounds $ sUnless no_carry_in $ forM_ (zip per_scan_pes scans) $ \(pes, SegBinOp _ scan_op nes vec_shape) -> do dScope Nothing $ scopeOfLParams $ lambdaParams scan_op let (scan_x_params, scan_y_params) = splitAt (length nes) $ lambdaParams scan_op sLoopNest vec_shape $ \vec_is -> do forM_ (zip scan_x_params pes) $ \(p, pe) -> copyDWIMFix (paramName p) [] (Var $ patElemName pe) (carry_in_idx ++ vec_is) forM_ (zip scan_y_params pes) $ \(p, pe) -> copyDWIMFix (paramName p) [] (Var $ patElemName pe) (map Imp.le64 gtids ++ vec_is) compileBody' scan_x_params $ lambdaBody scan_op forM_ (zip scan_x_params pes) $ \(p, pe) -> copyDWIMFix (patElemName pe) (map Imp.le64 gtids ++ vec_is) (Var $ paramName p) [] -- | Compile 'SegScan' instance to host-level code with calls to -- various kernels. compileSegScan :: Pat LetDecMem -> SegLevel -> SegSpace -> [SegBinOp GPUMem] -> KernelBody GPUMem -> CallKernelGen () compileSegScan pat lvl space scans kbody = do attrs <- lvlKernelAttrs lvl -- Since stage 2 involves a group size equal to the number of groups -- used for stage 1, we have to cap this number to the maximum group -- size. stage1_max_num_tblocks <- dPrim "stage1_max_num_tblocks" sOp $ Imp.GetSizeMax (tvVar stage1_max_num_tblocks) SizeThreadBlock stage1_num_tblocks <- fmap (Imp.Count . tvSize) $ dPrimV "stage1_num_tblocks" $ sMin64 (tvExp stage1_max_num_tblocks) $ pe64 . Imp.unCount . 
kAttrNumBlocks $ attrs (stage1_num_threads, elems_per_group, crossesSegment) <- scanStage1 pat stage1_num_tblocks (kAttrBlockSize attrs) space scans kbody emit $ Imp.DebugPrint "elems_per_group" $ Just $ untyped elems_per_group scanStage2 pat stage1_num_threads elems_per_group stage1_num_tblocks crossesSegment space scans scanStage3 pat (kAttrNumBlocks attrs) (kAttrBlockSize attrs) elems_per_group crossesSegment space scans futhark-0.25.27/src/Futhark/CodeGen/ImpGen/GPU/ToOpenCL.hs000066400000000000000000000777251475065116200226540ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-} -- | This module defines a translation from imperative code with -- kernels to imperative code with OpenCL or CUDA calls. module Futhark.CodeGen.ImpGen.GPU.ToOpenCL ( kernelsToOpenCL, kernelsToCUDA, kernelsToHIP, ) where import Control.Monad import Control.Monad.Reader import Control.Monad.State import Data.Bifunctor (second) import Data.Foldable (toList) import Data.List qualified as L import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S import Data.Text qualified as T import Futhark.CodeGen.Backends.GenericC.Fun qualified as GC import Futhark.CodeGen.Backends.GenericC.Pretty import Futhark.CodeGen.Backends.SimpleRep import Futhark.CodeGen.ImpCode.GPU hiding (Program) import Futhark.CodeGen.ImpCode.GPU qualified as ImpGPU import Futhark.CodeGen.ImpCode.OpenCL hiding (Program) import Futhark.CodeGen.ImpCode.OpenCL qualified as ImpOpenCL import Futhark.CodeGen.RTS.C (atomicsH, halfH) import Futhark.CodeGen.RTS.CUDA (preludeCU) import Futhark.CodeGen.RTS.OpenCL (copyCL, preludeCL, transposeCL) import Futhark.Error (compilerLimitationS) import Futhark.MonadFreshNames import Futhark.Util (zEncodeText) import Futhark.Util.IntegralExp (rem) import Language.C.Quote.OpenCL qualified as C import Language.C.Syntax qualified as C import NeatInterpolation (untrimming) import Prelude hiding (rem) -- | Generate HIP host and device code. 
kernelsToHIP :: ImpGPU.Program -> ImpOpenCL.Program
kernelsToHIP = translateGPU TargetHIP

-- | Generate CUDA host and device code.
kernelsToCUDA :: ImpGPU.Program -> ImpOpenCL.Program
kernelsToCUDA = translateGPU TargetCUDA

-- | Generate OpenCL host and device code.
kernelsToOpenCL :: ImpGPU.Program -> ImpOpenCL.Program
kernelsToOpenCL = translateGPU TargetOpenCL

-- | Translate a kernels-program to an OpenCL-program.
--
-- The host constants and every host function are traversed with
-- 'onHostOp'. As it rewrites each host operation, it accumulates the
-- generated kernels, device functions, used primitive types, size
-- parameters, failure messages and constants in a 'ToOpenCL' state.
-- That final state is then unpacked into the fields of the resulting
-- program.
translateGPU ::
  KernelTarget ->
  ImpGPU.Program ->
  ImpOpenCL.Program
translateGPU target prog =
  ImpOpenCL.Program
    { openClProgram = T.unlines [src | (_, src) <- M.elems kernel_map],
      openClPrelude = full_prelude,
      openClMacroDefs = clConstants final_state,
      openClKernelNames = M.map fst kernel_map,
      openClUsedTypes = S.toList used_prims,
      openClParams = findParamUsers env host_defs $ cleanSizes $ clSizes final_state,
      openClFailures = clFailures final_state,
      hostDefinitions = host_defs
    }
  where
    env = envFromProg prog

    -- Rewrite the host code; kernel bookkeeping accumulates in the state.
    translateHost = do
      let ImpGPU.Definitions types (ImpGPU.Constants ps consts) (ImpGPU.Functions funs) = prog
      consts' <- traverse (onHostOp target) consts
      funs' <- forM funs $ \(fname, fun) -> do
        fun' <- traverse (onHostOp target) fun
        pure (fname, fun')
      pure $
        ImpOpenCL.Definitions
          types
          (ImpOpenCL.Constants ps consts')
          (ImpOpenCL.Functions funs')

    (host_defs, final_state) =
      runState (runReaderT translateHost env) initialOpenCL

    kernel_map = clGPU final_state
    used_prims = clUsedTypes final_state
    (device_protos, device_bodies) = unzip $ M.elems $ clDevFuns final_state

    -- Only the OpenCL prelude depends on which primitive types are used.
    target_prelude = case target of
      TargetOpenCL -> genOpenClPrelude used_prims
      TargetCUDA -> genCUDAPrelude
      TargetHIP -> genHIPPrelude

    full_prelude =
      T.unlines
        [ target_prelude,
          definitionsText device_protos,
          T.unlines device_bodies
        ]

-- | Due to simplifications after kernel extraction, some threshold
-- parameters may contain KernelPaths that reference threshold
-- parameters that no longer exist. We remove these here.
cleanSizes :: M.Map Name SizeClass -> M.Map Name SizeClass
cleanSizes m = M.map clean m
  where
    -- Keep only the path entries whose parameter is still a key of the
    -- map. A direct map lookup avoids scanning the key list per entry.
    clean (SizeThreshold path def) =
      SizeThreshold (filter ((`M.member` m) . fst) path) def
    clean s = s

-- | Pair every size parameter's 'SizeClass' with the set of host
-- functions that read it. A function reads a parameter if its own body
-- contains a 'GetSize' or 'CmpSizeLe' for it, or if any function listed
-- for it in the environment's call graph does so directly.
findParamUsers ::
  Env ->
  Definitions ImpOpenCL.OpenCL ->
  M.Map Name SizeClass ->
  ParamMap
findParamUsers env defs = M.mapWithKey onParam
  where
    cg = envCallGraph env

    -- The size parameter read by a host operation, if any.
    getSize (ImpOpenCL.GetSize _ v) = Just v
    getSize (ImpOpenCL.CmpSizeLe _ v _) = Just v
    getSize (ImpOpenCL.GetSizeMax {}) = Nothing
    getSize (ImpOpenCL.LaunchKernel {}) = Nothing

    directUseInFun fun = mapMaybe getSize $ toList $ functionBody fun
    direct_uses = map (second directUseInFun) $ unFunctions $ defFuns defs

    calledBy fname = M.findWithDefault mempty fname cg
    indirectUseInFun fname =
      ( fname,
        foldMap snd $ filter ((`S.member` calledBy fname) . fst) direct_uses
      )
    -- A function may appear twice here (direct and indirect entries);
    -- the S.fromList below removes duplicates.
    indirect_uses = direct_uses <> map (indirectUseInFun . fst) direct_uses

    onParam k c = (c, S.fromList $ map fst $ filter ((k `elem`) . snd) indirect_uses)

-- | Translate an address-space name into the corresponding OpenCL C
-- type qualifier. Unknown names are a compiler error.
pointerQuals :: String -> [C.TypeQual]
pointerQuals "global" = [C.ctyquals|__global|]
pointerQuals "shared" = [C.ctyquals|__local|]
pointerQuals "private" = [C.ctyquals|__private|]
pointerQuals "constant" = [C.ctyquals|__constant|]
pointerQuals "write_only" = [C.ctyquals|__write_only|]
pointerQuals "read_only" = [C.ctyquals|__read_only|]
pointerQuals "kernel" = [C.ctyquals|__kernel|]
-- OpenCL does not actually have a "device" space, but we use it in
-- the compiler pipeline to defer to memory on the device, as opposed
-- to the host. From a kernel's perspective, this is "global".
pointerQuals "device" = pointerQuals "global"
pointerQuals s = error $ "'" ++ s ++ "' is not an OpenCL kernel address space."

-- In-kernel name and per-threadblock size in bytes.
type SharedMemoryUse = (VName, Count Bytes (TExp Int64))

-- | User state of the C code generator while compiling a single kernel
-- or device function (see 'genGPUCode').
data KernelState = KernelState
  { kernelSharedMemory :: [SharedMemoryUse],
    kernelFailures :: [FailureMsg],
    kernelNextSync :: Int,
    -- | Has a potential failure occurred since the last
    -- ErrorSync?
    kernelSyncPending :: Bool,
    kernelHasBarriers :: Bool
  }

-- | Fresh kernel state that starts from the given list of failures,
-- with no shared memory, sync counter 0 and no pending error or barriers.
newKernelState :: [FailureMsg] -> KernelState
newKernelState failures = KernelState mempty failures 0 False False

-- | The C label for the next error synchronisation point, e.g. @error_0@.
errorLabel :: KernelState -> String
errorLabel = ("error_" ++) . show . kernelNextSync

-- | Information accumulated while translating host code.
data ToOpenCL = ToOpenCL
  { clGPU :: M.Map KernelName (KernelSafety, T.Text),
    clDevFuns :: M.Map Name (C.Definition, T.Text),
    clUsedTypes :: S.Set PrimType,
    clSizes :: M.Map Name SizeClass,
    clFailures :: [FailureMsg],
    clConstants :: [(Name, KernelConstExp)]
  }

initialOpenCL :: ToOpenCL
initialOpenCL = ToOpenCL mempty mempty mempty mempty mempty mempty

-- | Read-only information about the whole program.
data Env = Env
  { envFuns :: ImpGPU.Functions ImpGPU.HostOp,
    envFunsMayFail :: S.Set Name,
    envCallGraph :: M.Map Name (S.Set Name)
  }

-- | Does this code contain an 'Assert', or an operation satisfying the
-- given predicate, anywhere inside its (possibly nested) control flow?
codeMayFail :: (a -> Bool) -> ImpGPU.Code a -> Bool
codeMayFail _ (Assert {}) = True
codeMayFail f (Op x) = f x
codeMayFail f (x :>>: y) = codeMayFail f x || codeMayFail f y
codeMayFail f (For _ _ x) = codeMayFail f x
codeMayFail f (While _ x) = codeMayFail f x
codeMayFail f (If _ x y) = codeMayFail f x || codeMayFail f y
codeMayFail f (Comment _ x) = codeMayFail f x
codeMayFail _ _ = False

-- | A host operation may fail only if it launches a kernel whose body
-- may fail.
hostOpMayFail :: ImpGPU.HostOp -> Bool
hostOpMayFail (CallKernel k) = codeMayFail kernelOpMayFail $ kernelBody k
hostOpMayFail _ = False

kernelOpMayFail :: ImpGPU.KernelOp -> Bool
kernelOpMayFail = const False

-- | The names of all functions that may fail. A function may fail if its
-- own body may fail, or if any function in its call graph entry does.
funsMayFail :: M.Map Name (S.Set Name) -> ImpGPU.Functions ImpGPU.HostOp -> S.Set Name
funsMayFail cg (Functions funs) =
  S.fromList $ map fst $ filter mayFail funs
  where
    -- Kept as a set so that the membership tests below are logarithmic
    -- rather than linear in the number of functions.
    base_mayfail =
      S.fromList . map fst $
        filter (codeMayFail hostOpMayFail . ImpGPU.functionBody . snd) funs
    mayFail (fname, _) =
      fname `S.member` base_mayfail
        || any (`S.member` base_mayfail) (M.findWithDefault mempty fname cg)

envFromProg :: ImpGPU.Program -> Env
envFromProg prog = Env funs (funsMayFail cg funs) cg
  where
    funs = defFuns prog
    cg = ImpGPU.callGraph calledInHostOp funs

lookupFunction :: Name -> Env -> Maybe (ImpGPU.Function HostOp)
lookupFunction fname = lookup fname . unFunctions . envFuns

functionMayFail :: Name -> Env -> Bool
functionMayFail fname = S.member fname . envFunsMayFail

type OnKernelM = ReaderT Env (State ToOpenCL)

-- | Record a tuning parameter and its size class.
addSize :: Name -> SizeClass -> OnKernelM ()
addSize key sclass =
  modify $ \s -> s {clSizes = M.insert key sclass $ clSizes s}

-- | Translate a single host operation. Kernel launches are compiled via
-- 'onKernel'; size queries are recorded with 'addSize' and passed on.
onHostOp :: KernelTarget -> HostOp -> OnKernelM OpenCL
onHostOp target (CallKernel k) = onKernel target k
onHostOp _ (ImpGPU.GetSize v key size_class) = do
  addSize key size_class
  pure $ ImpOpenCL.GetSize v key
onHostOp _ (ImpGPU.CmpSizeLe v key size_class x) = do
  addSize key size_class
  pure $ ImpOpenCL.CmpSizeLe v key x
onHostOp _ (ImpGPU.GetSizeMax v size_class) =
  pure $ ImpOpenCL.GetSizeMax v size_class

-- | Run a C code generation action with the in-kernel operations and a
-- fresh 'KernelState' seeded with the given failures.
genGPUCode ::
  Env ->
  OpsMode ->
  KernelCode ->
  [FailureMsg] ->
  GC.CompilerM KernelOp KernelState a ->
  (a, GC.CompilerState KernelState)
genGPUCode env mode body failures =
  GC.runCompilerM
    (inKernelOperations env mode body)
    blankNameSource
    (newKernelState failures)

-- Compilation of a device function that is not invoked from the
-- host, but is invoked by (perhaps multiple) kernels.
generateDeviceFun :: Name -> ImpGPU.Function ImpGPU.KernelOp -> OnKernelM () generateDeviceFun fname device_func = do when (any memParam $ functionInput device_func) bad env <- ask failures <- gets clFailures let (func, kstate) = if functionMayFail fname env then let params = [ [C.cparam|__global int *global_failure|], [C.cparam|__global typename int64_t *global_failure_args|] ] (f, cstate) = genGPUCode env FunMode (declsFirst $ functionBody device_func) failures $ GC.compileFun mempty params (fname, device_func) in (f, GC.compUserState cstate) else let (f, cstate) = genGPUCode env FunMode (declsFirst $ functionBody device_func) failures $ GC.compileVoidFun mempty (fname, device_func) in (f, GC.compUserState cstate) modify $ \s -> s { clUsedTypes = typesInCode (functionBody device_func) <> clUsedTypes s, clDevFuns = M.insert fname (second funcText func) $ clDevFuns s, clFailures = kernelFailures kstate } -- Important to do this after the 'modify' call, so we propagate the -- right clFailures. void $ ensureDeviceFuns $ functionBody device_func where memParam MemParam {} = True memParam ScalarParam {} = False bad = compilerLimitationS "Cannot generate GPU functions that use arrays." -- Ensure that this device function is available, but don't regenerate -- it if it already exists. ensureDeviceFun :: Name -> ImpGPU.Function ImpGPU.KernelOp -> OnKernelM () ensureDeviceFun fname host_func = do exists <- gets $ M.member fname . clDevFuns unless exists $ generateDeviceFun fname host_func calledInHostOp :: HostOp -> S.Set Name calledInHostOp (CallKernel k) = calledFuncs calledInKernelOp $ kernelBody k calledInHostOp _ = mempty calledInKernelOp :: KernelOp -> S.Set Name calledInKernelOp = const mempty ensureDeviceFuns :: ImpGPU.KernelCode -> OnKernelM [Name] ensureDeviceFuns code = do let called = calledFuncs calledInKernelOp code fmap catMaybes . 
forM (S.toList called) $ \fname -> do def <- asks $ lookupFunction fname case def of Just host_func -> do -- Functions are a priori always considered host-level, so we have -- to convert them to device code. This is where most of our -- limitations on device-side functions (no arrays, no parallelism) -- comes from. let device_func = fmap toDevice host_func ensureDeviceFun fname device_func pure $ Just fname Nothing -> pure Nothing where bad = compilerLimitationS "Cannot generate GPU functions that contain parallelism." toDevice :: HostOp -> KernelOp toDevice _ = bad isConst :: BlockDim -> Maybe KernelConstExp isConst (Left (ValueExp (IntValue x))) = Just $ ValueExp (IntValue x) isConst (Right e) = Just e isConst _ = Nothing onKernel :: KernelTarget -> Kernel -> OnKernelM OpenCL onKernel target kernel = do called <- ensureDeviceFuns $ kernelBody kernel -- Crucial that this is done after 'ensureDeviceFuns', as the device -- functions may themselves define failure points. failures <- gets clFailures env <- ask let (kernel_body, cstate) = genGPUCode env KernelMode (kernelBody kernel) failures . GC.collect $ do body <- GC.collect $ GC.compileCode $ declsFirst $ kernelBody kernel -- No need to free, as we cannot allocate memory in kernels. mapM_ GC.item =<< GC.declAllocatedMem mapM_ GC.item body kstate = GC.compUserState cstate (kernel_consts, (const_defs, const_undefs)) = second unzip $ unzip $ mapMaybe (constDef (kernelName kernel)) $ kernelUses kernel let (_, shared_memory_init) = L.mapAccumL prepareSharedMemory [C.cexp|0|] (kernelSharedMemory kstate) shared_memory_bytes = sum $ map (padTo8 . snd) $ kernelSharedMemory kstate let (use_params, unpack_params) = unzip $ mapMaybe useAsParam $ kernelUses kernel -- The local_failure variable is an int despite only really storing -- a single bit of information, as some OpenCL implementations -- (e.g. AMD) does not like byte-sized shared memory (and the others -- likely pad to a whole word anyway). 
let (safety, error_init) -- We conservatively assume that any called function can fail. | not $ null called = ( SafetyFull, [C.citems|volatile __local int local_failure; // Harmless for all threads to write this. local_failure = 0;|] ) | length (kernelFailures kstate) == length failures = if kernelFailureTolerant kernel then (SafetyNone, []) else -- No possible failures in this kernel, so if we make -- it past an initial check, then we are good to go. ( SafetyCheap, [C.citems|if (*global_failure >= 0) { return; }|] ) | otherwise = if not (kernelHasBarriers kstate) then ( SafetyFull, [C.citems|if (*global_failure >= 0) { return; }|] ) else ( SafetyFull, [C.citems| volatile __local int local_failure; if (failure_is_an_option) { int failed = *global_failure >= 0; if (failed) { return; } } // All threads write this value - it looks like CUDA has a compiler bug otherwise. local_failure = 0; barrier(CLK_LOCAL_MEM_FENCE); |] ) failure_params = [ [C.cparam|__global int *global_failure|], [C.cparam|int failure_is_an_option|], [C.cparam|__global typename int64_t *global_failure_args|] ] (shared_memory_param, prepare_shared_memory) = case target of TargetOpenCL -> ( [[C.cparam|__local typename uint64_t* shared_mem_aligned|]], [C.citems|__local unsigned char* shared_mem = (__local unsigned char*)shared_mem_aligned;|] ) TargetCUDA -> (mempty, mempty) TargetHIP -> (mempty, mempty) params = shared_memory_param ++ take (numFailureParams safety) failure_params ++ use_params (attribute_consts, attribute) = case mapM isConst $ kernelBlockSize kernel of Just [x, y, z] -> ( [(xv, x), (yv, y), (zv, z)], "FUTHARK_KERNEL_SIZED" <> prettyText (xv, yv, zv) <> "\n" ) where xv = nameFromText $ zEncodeText $ nameToText name <> "_dim1" yv = nameFromText $ zEncodeText $ nameToText name <> "_dim2" zv = nameFromText $ zEncodeText $ nameToText name <> "_dim3" Just [x, y] -> ( [(xv, x), (yv, y)], "FUTHARK_KERNEL_SIZED" <> prettyText (xv, yv, 1 :: Int) <> "\n" ) where xv = nameFromText $ zEncodeText 
$ nameToText name <> "_dim1" yv = nameFromText $ zEncodeText $ nameToText name <> "_dim2" Just [x] -> ( [(xv, x)], "FUTHARK_KERNEL_SIZED" <> prettyText (xv, 1 :: Int, 1 :: Int) <> "\n" ) where xv = nameFromText $ zEncodeText $ nameToText name <> "_dim1" _ -> (mempty, "FUTHARK_KERNEL\n") kernel_fun = attribute <> funcText [C.cfun|void $id:name ($params:params) { $items:(mconcat unpack_params) $items:const_defs $items:prepare_shared_memory $items:(mconcat shared_memory_init) $items:error_init $items:kernel_body $id:(errorLabel kstate): return; $items:const_undefs }|] modify $ \s -> s { clGPU = M.insert name (safety, kernel_fun) $ clGPU s, clUsedTypes = typesInKernel kernel <> clUsedTypes s, clFailures = kernelFailures kstate, clConstants = attribute_consts <> kernel_consts <> clConstants s } -- The error handling stuff is automatically added later. let args = kernelArgs kernel pure $ LaunchKernel safety name shared_memory_bytes args num_tblocks tblock_size where name = kernelName kernel num_tblocks = kernelNumBlocks kernel tblock_size = kernelBlockSize kernel padTo8 e = e + ((8 - (e `rem` 8)) `rem` 8) prepareSharedMemory offset (mem, Count size) = let offset_v = nameFromText $ prettyText mem <> "_offset" in ( [C.cexp|$id:offset_v|], [C.citems| volatile __local $ty:defaultMemBlockType $id:mem = &shared_mem[$exp:offset]; const typename int64_t $id:offset_v = $exp:offset + $exp:(padTo8 size); |] ) useAsParam :: KernelUse -> Maybe (C.Param, [C.BlockItem]) useAsParam (ScalarUse name pt) = do let name_bits = zEncodeText (prettyText name) <> "_bits" ctp = case pt of -- OpenCL does not permit bool as a kernel parameter type. 
Bool -> [C.cty|unsigned char|] Unit -> [C.cty|unsigned char|] _ -> primStorageType pt if ctp == primTypeToCType pt then Just ([C.cparam|$ty:ctp $id:name|], []) else let name_bits_e = [C.cexp|$id:name_bits|] in Just ( [C.cparam|$ty:ctp $id:name_bits|], [[C.citem|$ty:(primTypeToCType pt) $id:name = $exp:(fromStorage pt name_bits_e);|]] ) useAsParam (MemoryUse name) = Just ([C.cparam|__global $ty:defaultMemBlockType $id:name|], []) useAsParam ConstUse {} = Nothing -- Constants are #defined as macros. Since a constant name in one -- kernel might potentially (although unlikely) also be used for -- something else in another kernel, we #undef them after the kernel. constDef :: Name -> KernelUse -> Maybe ((Name, KernelConstExp), (C.BlockItem, C.BlockItem)) constDef kernel_name (ConstUse v e) = Just ( (nameFromText v', e), ( [C.citem|$escstm:(T.unpack def)|], [C.citem|$escstm:(T.unpack undef)|] ) ) where v' = zEncodeText $ nameToText kernel_name <> "." <> prettyText v def = "#define " <> idText (C.toIdent v mempty) <> " (" <> v' <> ")" undef = "#undef " <> idText (C.toIdent v mempty) constDef _ _ = Nothing commonPrelude :: T.Text commonPrelude = halfH <> cScalarDefs <> atomicsH <> transposeCL <> copyCL genOpenClPrelude :: S.Set PrimType -> T.Text genOpenClPrelude ts = "#define FUTHARK_OPENCL\n" <> enable_f64 <> preludeCL <> commonPrelude where enable_f64 | FloatType Float64 `S.member` ts = [untrimming|#define FUTHARK_F64_ENABLED|] | otherwise = mempty genCUDAPrelude :: T.Text genCUDAPrelude = "#define FUTHARK_CUDA\n" <> preludeCU <> commonPrelude genHIPPrelude :: T.Text genHIPPrelude = "#define FUTHARK_HIP\n" <> preludeCU <> commonPrelude kernelArgs :: Kernel -> [KernelArg] kernelArgs = mapMaybe useToArg . 
kernelUses where useToArg (MemoryUse mem) = Just $ MemKArg mem useToArg (ScalarUse v pt) = Just $ ValueKArg (LeafExp v pt) pt useToArg ConstUse {} = Nothing nextErrorLabel :: GC.CompilerM KernelOp KernelState String nextErrorLabel = errorLabel <$> GC.getUserState incErrorLabel :: GC.CompilerM KernelOp KernelState () incErrorLabel = GC.modifyUserState $ \s -> s {kernelNextSync = kernelNextSync s + 1} pendingError :: Bool -> GC.CompilerM KernelOp KernelState () pendingError b = GC.modifyUserState $ \s -> s {kernelSyncPending = b} hasCommunication :: ImpGPU.KernelCode -> Bool hasCommunication = any communicates where communicates ErrorSync {} = True communicates Barrier {} = True communicates _ = False -- Whether we are generating code for a kernel or a device function. -- This has minor effects, such as exactly how failures are -- propagated. data OpsMode = KernelMode | FunMode deriving (Eq) inKernelOperations :: Env -> OpsMode -> ImpGPU.KernelCode -> GC.Operations KernelOp KernelState inKernelOperations env mode body = GC.Operations { GC.opsCompiler = kernelOps, GC.opsMemoryType = kernelMemoryType, GC.opsWriteScalar = kernelWriteScalar, GC.opsReadScalar = kernelReadScalar, GC.opsAllocate = cannotAllocate, GC.opsDeallocate = cannotDeallocate, GC.opsCopy = copyInKernel, GC.opsCopies = mempty, GC.opsFatMemory = False, GC.opsError = errorInKernel, GC.opsCall = callInKernel, GC.opsCritical = mempty } where has_communication = hasCommunication body fence FenceLocal = [C.cexp|CLK_LOCAL_MEM_FENCE|] fence FenceGlobal = [C.cexp|CLK_GLOBAL_MEM_FENCE | CLK_LOCAL_MEM_FENCE|] kernelOps :: GC.OpCompiler KernelOp KernelState kernelOps (GetBlockId v i) = GC.stm [C.cstm|$id:v = get_tblock_id($int:i);|] kernelOps (GetLocalId v i) = GC.stm [C.cstm|$id:v = get_local_id($int:i);|] kernelOps (GetLocalSize v i) = GC.stm [C.cstm|$id:v = get_local_size($int:i);|] kernelOps (GetLockstepWidth v) = GC.stm [C.cstm|$id:v = LOCKSTEP_WIDTH;|] kernelOps (Barrier f) = do GC.stm 
[C.cstm|barrier($exp:(fence f));|] GC.modifyUserState $ \s -> s {kernelHasBarriers = True} kernelOps (MemFence FenceLocal) = GC.stm [C.cstm|mem_fence_local();|] kernelOps (MemFence FenceGlobal) = GC.stm [C.cstm|mem_fence_global();|] kernelOps (SharedAlloc name size) = do name' <- newVName $ prettyString name ++ "_backing" GC.modifyUserState $ \s -> s {kernelSharedMemory = (name', size) : kernelSharedMemory s} GC.stm [C.cstm|$id:name = (__local unsigned char*) $id:name';|] kernelOps (ErrorSync f) = do label <- nextErrorLabel pending <- kernelSyncPending <$> GC.getUserState when pending $ do pendingError False GC.stm [C.cstm|$id:label: barrier($exp:(fence f));|] GC.stm [C.cstm|if (local_failure) { return; }|] GC.stm [C.cstm|barrier($exp:(fence f));|] GC.modifyUserState $ \s -> s {kernelHasBarriers = True} incErrorLabel kernelOps (Atomic space aop) = atomicOps space aop atomicCast s t = do let volatile = [C.ctyquals|volatile|] let quals = case s of Space sid -> pointerQuals sid _ -> pointerQuals "global" pure [C.cty|$tyquals:(volatile++quals) $ty:t|] atomicSpace (Space sid) = sid atomicSpace _ = "global" doAtomic s t old arr ind val op ty = do ind' <- GC.compileExp $ untyped $ unCount ind val' <- GC.compileExp val cast <- atomicCast s ty GC.stm [C.cstm|$id:old = $id:op'(&(($ty:cast *)$id:arr)[$exp:ind'], ($ty:ty) $exp:val');|] where op' = op ++ "_" ++ prettyString t ++ "_" ++ atomicSpace s doAtomicCmpXchg s t old arr ind cmp val ty = do ind' <- GC.compileExp $ untyped $ unCount ind cmp' <- GC.compileExp cmp val' <- GC.compileExp val cast <- atomicCast s ty GC.stm [C.cstm|$id:old = $id:op(&(($ty:cast *)$id:arr)[$exp:ind'], $exp:cmp', $exp:val');|] where op = "atomic_cmpxchg_" ++ prettyString t ++ "_" ++ atomicSpace s doAtomicXchg s t old arr ind val ty = do cast <- atomicCast s ty ind' <- GC.compileExp $ untyped $ unCount ind val' <- GC.compileExp val GC.stm [C.cstm|$id:old = $id:op(&(($ty:cast *)$id:arr)[$exp:ind'], $exp:val');|] where op = "atomic_chg_" ++ 
prettyString t ++ "_" ++ atomicSpace s -- First the 64-bit operations. atomicOps s (AtomicAdd Int64 old arr ind val) = doAtomic s Int64 old arr ind val "atomic_add" [C.cty|typename int64_t|] atomicOps s (AtomicFAdd Float64 old arr ind val) = doAtomic s Float64 old arr ind val "atomic_fadd" [C.cty|double|] atomicOps s (AtomicSMax Int64 old arr ind val) = doAtomic s Int64 old arr ind val "atomic_smax" [C.cty|typename int64_t|] atomicOps s (AtomicSMin Int64 old arr ind val) = doAtomic s Int64 old arr ind val "atomic_smin" [C.cty|typename int64_t|] atomicOps s (AtomicUMax Int64 old arr ind val) = doAtomic s Int64 old arr ind val "atomic_umax" [C.cty|unsigned int64_t|] atomicOps s (AtomicUMin Int64 old arr ind val) = doAtomic s Int64 old arr ind val "atomic_umin" [C.cty|unsigned int64_t|] atomicOps s (AtomicAnd Int64 old arr ind val) = doAtomic s Int64 old arr ind val "atomic_and" [C.cty|typename int64_t|] atomicOps s (AtomicOr Int64 old arr ind val) = doAtomic s Int64 old arr ind val "atomic_or" [C.cty|typename int64_t|] atomicOps s (AtomicXor Int64 old arr ind val) = doAtomic s Int64 old arr ind val "atomic_xor" [C.cty|typename int64_t|] atomicOps s (AtomicCmpXchg (IntType Int64) old arr ind cmp val) = doAtomicCmpXchg s (IntType Int64) old arr ind cmp val [C.cty|typename int64_t|] atomicOps s (AtomicXchg (IntType Int64) old arr ind val) = doAtomicXchg s (IntType Int64) old arr ind val [C.cty|typename int64_t|] -- atomicOps s (AtomicAdd t old arr ind val) = doAtomic s t old arr ind val "atomic_add" [C.cty|int|] atomicOps s (AtomicFAdd t old arr ind val) = doAtomic s t old arr ind val "atomic_fadd" [C.cty|float|] atomicOps s (AtomicSMax t old arr ind val) = doAtomic s t old arr ind val "atomic_smax" [C.cty|int|] atomicOps s (AtomicSMin t old arr ind val) = doAtomic s t old arr ind val "atomic_smin" [C.cty|int|] atomicOps s (AtomicUMax t old arr ind val) = doAtomic s t old arr ind val "atomic_umax" [C.cty|unsigned int|] atomicOps s (AtomicUMin t old arr ind val) = 
doAtomic s t old arr ind val "atomic_umin" [C.cty|unsigned int|] atomicOps s (AtomicAnd t old arr ind val) = doAtomic s t old arr ind val "atomic_and" [C.cty|int|] atomicOps s (AtomicOr t old arr ind val) = doAtomic s t old arr ind val "atomic_or" [C.cty|int|] atomicOps s (AtomicXor t old arr ind val) = doAtomic s t old arr ind val "atomic_xor" [C.cty|int|] atomicOps s (AtomicCmpXchg t old arr ind cmp val) = doAtomicCmpXchg s t old arr ind cmp val [C.cty|int|] atomicOps s (AtomicXchg t old arr ind val) = doAtomicXchg s t old arr ind val [C.cty|int|] atomicOps s (AtomicWrite t arr ind val) = do ind' <- GC.compileExp $ untyped $ unCount ind val' <- toStorage t <$> GC.compileExp val let quals = case s of Space sid -> pointerQuals sid _ -> pointerQuals "global" GC.stm [C.cstm|(($tyquals:quals $ty:(primStorageType t)*)$id:arr)[$exp:ind'] = $exp:val';|] GC.stm $ case s of Space "shared" -> [C.cstm|mem_fence_local();|] _ -> [C.cstm|mem_fence_global();|] cannotAllocate :: GC.Allocate KernelOp KernelState cannotAllocate _ = error "Cannot allocate memory in kernel" cannotDeallocate :: GC.Deallocate KernelOp KernelState cannotDeallocate _ _ = error "Cannot deallocate memory in kernel" copyInKernel :: GC.Copy KernelOp KernelState copyInKernel _ _ _ _ _ _ _ _ = error "Cannot bulk copy in kernel." 
kernelMemoryType space = pure [C.cty|$tyquals:(pointerQuals space) $ty:defaultMemBlockType|] kernelWriteScalar = GC.writeScalarPointerWithQuals pointerQuals kernelReadScalar = GC.readScalarPointerWithQuals pointerQuals whatNext = do label <- nextErrorLabel pendingError True pure $ if has_communication then [C.citems|local_failure = 1; goto $id:label;|] else if mode == FunMode then [C.citems|return 1;|] else [C.citems|return;|] callInKernel dests fname args | functionMayFail fname env = do let out_args = [[C.cexp|&$id:d|] | d <- dests] args' = [C.cexp|global_failure|] : [C.cexp|global_failure_args|] : out_args ++ args what_next <- whatNext GC.item [C.citem|if ($id:(funName fname)($args:args') != 0) { $items:what_next; }|] | otherwise = do let out_args = [[C.cexp|&$id:d|] | d <- dests] args' = out_args ++ args GC.item [C.citem|$id:(funName fname)($args:args');|] errorInKernel msg@(ErrorMsg parts) backtrace = do n <- length . kernelFailures <$> GC.getUserState GC.modifyUserState $ \s -> s {kernelFailures = kernelFailures s ++ [FailureMsg msg backtrace]} let setArgs _ [] = pure [] setArgs i (ErrorString {} : parts') = setArgs i parts' -- FIXME: bogus for non-ints. 
setArgs i (ErrorVal _ x : parts') = do x' <- GC.compileExp x stms <- setArgs (i + 1) parts' pure $ [C.cstm|global_failure_args[$int:i] = (typename int64_t)$exp:x';|] : stms argstms <- setArgs (0 :: Int) parts what_next <- whatNext GC.stm [C.cstm|{ if (atomic_cmpxchg_i32_global(global_failure, -1, $int:n) == -1) { $stms:argstms; } $items:what_next }|] --- Checking requirements typesInKernel :: Kernel -> S.Set PrimType typesInKernel kernel = typesInCode $ kernelBody kernel typesInCode :: ImpGPU.KernelCode -> S.Set PrimType typesInCode Skip = mempty typesInCode (c1 :>>: c2) = typesInCode c1 <> typesInCode c2 typesInCode (For _ e c) = typesInExp e <> typesInCode c typesInCode (While (TPrimExp e) c) = typesInExp e <> typesInCode c typesInCode DeclareMem {} = mempty typesInCode (DeclareScalar _ _ t) = S.singleton t typesInCode (DeclareArray _ t _) = S.singleton t typesInCode (Allocate _ (Count (TPrimExp e)) _) = typesInExp e typesInCode Free {} = mempty typesInCode (Copy _ shape _ (Count (TPrimExp dstoffset), dststrides) _ (Count (TPrimExp srcoffset), srcstrides)) = foldMap (typesInExp . untyped . unCount) shape <> typesInExp dstoffset <> foldMap (typesInExp . untyped . unCount) dststrides <> typesInExp srcoffset <> foldMap (typesInExp . untyped . 
unCount) srcstrides typesInCode (Write _ (Count (TPrimExp e1)) t _ _ e2) = typesInExp e1 <> S.singleton t <> typesInExp e2 typesInCode (Read _ _ (Count (TPrimExp e1)) t _ _) = typesInExp e1 <> S.singleton t typesInCode (SetScalar _ e) = typesInExp e typesInCode SetMem {} = mempty typesInCode (Call _ _ es) = mconcat $ map typesInArg es where typesInArg MemArg {} = mempty typesInArg (ExpArg e) = typesInExp e typesInCode (If (TPrimExp e) c1 c2) = typesInExp e <> typesInCode c1 <> typesInCode c2 typesInCode (Assert e _ _) = typesInExp e typesInCode (Comment _ c) = typesInCode c typesInCode (DebugPrint _ v) = maybe mempty typesInExp v typesInCode (TracePrint msg) = foldMap typesInExp msg typesInCode Op {} = mempty typesInExp :: Exp -> S.Set PrimType typesInExp (ValueExp v) = S.singleton $ primValueType v typesInExp (BinOpExp _ e1 e2) = typesInExp e1 <> typesInExp e2 typesInExp (CmpOpExp _ e1 e2) = typesInExp e1 <> typesInExp e2 typesInExp (ConvOpExp op e) = S.fromList [from, to] <> typesInExp e where (from, to) = convOpType op typesInExp (UnOpExp _ e) = typesInExp e typesInExp (FunExp _ args t) = S.singleton t <> mconcat (map typesInExp args) typesInExp LeafExp {} = mempty futhark-0.25.27/src/Futhark/CodeGen/ImpGen/HIP.hs000066400000000000000000000010121475065116200211660ustar00rootroot00000000000000-- | Code generation for ImpCode with HIP kernels. module Futhark.CodeGen.ImpGen.HIP ( compileProg, Warnings, ) where import Data.Bifunctor (second) import Futhark.CodeGen.ImpCode.OpenCL import Futhark.CodeGen.ImpGen.GPU import Futhark.CodeGen.ImpGen.GPU.ToOpenCL import Futhark.IR.GPUMem import Futhark.MonadFreshNames -- | Compile the program to ImpCode with HIP kernels. 
compileProg :: (MonadFreshNames m) => Prog GPUMem -> m (Warnings, Program) compileProg prog = second kernelsToHIP <$> compileProgHIP prog futhark-0.25.27/src/Futhark/CodeGen/ImpGen/Multicore.hs000066400000000000000000000162571475065116200225320ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | Code generation for ImpCode with multicore operations. module Futhark.CodeGen.ImpGen.Multicore ( Futhark.CodeGen.ImpGen.Multicore.compileProg, Warnings, ) where import Control.Monad import Data.Map qualified as M import Futhark.CodeGen.ImpCode.Multicore qualified as Imp import Futhark.CodeGen.ImpGen import Futhark.CodeGen.ImpGen.Multicore.Base import Futhark.CodeGen.ImpGen.Multicore.SegHist import Futhark.CodeGen.ImpGen.Multicore.SegMap import Futhark.CodeGen.ImpGen.Multicore.SegRed import Futhark.CodeGen.ImpGen.Multicore.SegScan import Futhark.IR.MCMem import Futhark.MonadFreshNames import Futhark.Util.IntegralExp (rem) import Prelude hiding (quot, rem) opCompiler :: OpCompiler MCMem HostEnv Imp.Multicore opCompiler dest (Alloc e space) = compileAlloc dest e space opCompiler dest (Inner op) = compileMCOp dest op parallelCopy :: CopyCompiler MCMem HostEnv Imp.Multicore parallelCopy pt destloc srcloc = do seq_code <- collect $ localOps inThreadOps $ do body <- genCopy free_params <- freeParams body emit $ Imp.Op $ Imp.ParLoop "copy" body free_params free_params <- freeParams seq_code s <- prettyString <$> newVName "copy" iterations <- dPrimVE "iterations" $ product $ map pe64 srcshape let scheduling = Imp.SchedulerInfo (untyped iterations) Imp.Static emit . Imp.Op $ Imp.SegOp s free_params (Imp.ParallelTask seq_code) Nothing [] scheduling where MemLoc destmem _ _ = destloc MemLoc srcmem srcshape _ = srcloc genCopy = collect . inISPC . 
generateChunkLoop "copy" Vectorized $ \i -> do is <- dIndexSpace' "i" (map pe64 srcshape) i (_, destspace, destidx) <- fullyIndexArray' destloc is (_, srcspace, srcidx) <- fullyIndexArray' srcloc is tmp <- dPrimS "tmp" pt emit $ Imp.Read tmp srcmem srcidx pt srcspace Imp.Nonvolatile emit $ Imp.Write destmem destidx pt destspace Imp.Nonvolatile $ Imp.var tmp pt topLevelOps, inThreadOps :: Operations MCMem HostEnv Imp.Multicore inThreadOps = (defaultOperations opCompiler) { opsExpCompiler = compileMCExp } topLevelOps = (defaultOperations opCompiler) { opsExpCompiler = compileMCExp, opsCopyCompiler = parallelCopy } updateAcc :: Safety -> VName -> [SubExp] -> [SubExp] -> MulticoreGen () updateAcc safety acc is vs = sComment "UpdateAcc" $ do -- See the ImpGen implementation of UpdateAcc for general notes. let is' = map pe64 is (c, _space, arrs, dims, op) <- lookupAcc acc is' let boundsCheck = case safety of Safe -> sWhen (inBounds (Slice (map DimFix is')) dims) _ -> id boundsCheck $ case op of Nothing -> forM_ (zip arrs vs) $ \(arr, v) -> copyDWIMFix arr is' v [] Just lam -> do dLParams $ lambdaParams lam let (_x_params, y_params) = splitAt (length vs) $ map paramName $ lambdaParams lam forM_ (zip y_params vs) $ \(yp, v) -> copyDWIM yp [] v [] atomics <- hostAtomics <$> askEnv case atomicUpdateLocking atomics lam of AtomicPrim f -> f arrs is' AtomicCAS f -> f arrs is' AtomicLocking f -> do c_locks <- M.lookup c . hostLocks <$> askEnv case c_locks of Just (Locks locks num_locks) -> do let locking = Locking locks 0 1 0 $ pure . (`rem` fromIntegral num_locks) . 
flattenIndex dims f locking arrs is' Nothing -> error $ "Missing locks for " ++ prettyString acc withAcc :: Pat LetDecMem -> [(Shape, [VName], Maybe (Lambda MCMem, [SubExp]))] -> Lambda MCMem -> MulticoreGen () withAcc pat inputs lam = do atomics <- hostAtomics <$> askEnv locksForInputs atomics $ zip accs inputs where accs = map paramName $ lambdaParams lam locksForInputs _ [] = defCompileExp pat $ WithAcc inputs lam locksForInputs atomics ((c, (_, _, op)) : inputs') | Just (op_lam, _) <- op, AtomicLocking _ <- atomicUpdateLocking atomics op_lam = do let num_locks = 100151 locks_arr <- sStaticArray "withacc_locks" int32 $ Imp.ArrayZeros num_locks let locks = Locks locks_arr num_locks extend env = env {hostLocks = M.insert c locks $ hostLocks env} localEnv extend $ locksForInputs atomics inputs' | otherwise = locksForInputs atomics inputs' compileMCExp :: ExpCompiler MCMem HostEnv Imp.Multicore compileMCExp _ (BasicOp (UpdateAcc safety acc is vs)) = updateAcc safety acc is vs compileMCExp pat (WithAcc inputs lam) = withAcc pat inputs lam compileMCExp dest e = defCompileExp dest e compileMCOp :: Pat LetDecMem -> MCOp NoOp MCMem -> ImpM MCMem HostEnv Imp.Multicore () compileMCOp _ (OtherOp NoOp) = pure () compileMCOp pat (ParOp par_op op) = do let space = getSpace op dPrimV_ (segFlat space) (0 :: Imp.TExp Int64) iterations <- getIterationDomain op space seq_code <- collect $ localOps inThreadOps $ do nsubtasks <- dPrim "nsubtasks" sOp $ Imp.GetNumTasks $ tvVar nsubtasks emit =<< compileSegOp pat op nsubtasks retvals <- getReturnParams pat op let scheduling_info = Imp.SchedulerInfo (untyped iterations) par_task <- case par_op of Just nested_op -> do let space' = getSpace nested_op dPrimV_ (segFlat space') (0 :: Imp.TExp Int64) par_code <- collect $ do nsubtasks <- dPrim "nsubtasks" sOp $ Imp.GetNumTasks $ tvVar nsubtasks emit =<< compileSegOp pat nested_op nsubtasks pure $ Just $ Imp.ParallelTask par_code Nothing -> pure Nothing s <- segOpString op let seq_task = 
Imp.ParallelTask seq_code free_params <- filter (`notElem` retvals) <$> freeParams (par_task, seq_task) emit . Imp.Op $ Imp.SegOp s free_params seq_task par_task retvals $ scheduling_info (decideScheduling' op seq_code) compileSegOp :: Pat LetDecMem -> SegOp () MCMem -> TV Int32 -> ImpM MCMem HostEnv Imp.Multicore Imp.MCCode compileSegOp pat (SegHist _ space histops _ kbody) ntasks = compileSegHist pat space histops kbody ntasks compileSegOp pat (SegScan _ space scans _ kbody) ntasks = compileSegScan pat space scans kbody ntasks compileSegOp pat (SegRed _ space reds _ kbody) ntasks = compileSegRed pat space reds kbody ntasks compileSegOp pat (SegMap _ space _ kbody) _ = compileSegMap pat space kbody -- GCC supported primitve atomic Operations -- TODO: Add support for 1, 2, and 16 bytes too gccAtomics :: AtomicBinOp gccAtomics = flip lookup cpu where cpu = [ (Add Int32 OverflowUndef, Imp.AtomicAdd Int32), (Sub Int32 OverflowUndef, Imp.AtomicSub Int32), (And Int32, Imp.AtomicAnd Int32), (Xor Int32, Imp.AtomicXor Int32), (Or Int32, Imp.AtomicOr Int32), (Add Int64 OverflowUndef, Imp.AtomicAdd Int64), (Sub Int64 OverflowUndef, Imp.AtomicSub Int64), (And Int64, Imp.AtomicAnd Int64), (Xor Int64, Imp.AtomicXor Int64), (Or Int64, Imp.AtomicOr Int64) ] -- | Compile the program. 
compileProg :: (MonadFreshNames m) => Prog MCMem -> m (Warnings, Imp.Definitions Imp.Multicore) compileProg = Futhark.CodeGen.ImpGen.compileProg (HostEnv gccAtomics mempty) topLevelOps Imp.DefaultSpace futhark-0.25.27/src/Futhark/CodeGen/ImpGen/Multicore/000077500000000000000000000000001475065116200221635ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/CodeGen/ImpGen/Multicore/Base.hs000066400000000000000000000441361475065116200234010ustar00rootroot00000000000000module Futhark.CodeGen.ImpGen.Multicore.Base ( extractAllocations, compileThreadResult, Locks (..), HostEnv (..), AtomicBinOp, MulticoreGen, decideScheduling, decideScheduling', renameSegBinOp, freeParams, renameHistOpLambda, atomicUpdateLocking, AtomicUpdate (..), DoAtomicUpdate, Locking (..), getSpace, getLoopBounds, getIterationDomain, getReturnParams, segOpString, ChunkLoopVectorization (..), generateChunkLoop, generateUniformizeLoop, extractVectorLane, inISPC, toParam, sLoopNestVectorized, ) where import Control.Monad import Data.Bifunctor import Data.Map qualified as M import Data.Maybe import Futhark.CodeGen.ImpCode.Multicore qualified as Imp import Futhark.CodeGen.ImpGen import Futhark.Error import Futhark.IR.MCMem import Futhark.Transform.Rename import Prelude hiding (quot, rem) -- | Is there an atomic t'BinOp' corresponding to this t'BinOp'? type AtomicBinOp = BinOp -> Maybe (VName -> VName -> Imp.Count Imp.Elements (Imp.TExp Int32) -> Imp.Exp -> Imp.AtomicOp) -- | Information about the locks available for accumulators. 
data Locks = Locks { locksArray :: VName, locksCount :: Int } data HostEnv = HostEnv { hostAtomics :: AtomicBinOp, hostLocks :: M.Map VName Locks } type MulticoreGen = ImpM MCMem HostEnv Imp.Multicore segOpString :: SegOp () MCMem -> MulticoreGen String segOpString SegMap {} = pure "segmap" segOpString SegRed {} = pure "segred" segOpString SegScan {} = pure "segscan" segOpString SegHist {} = pure "seghist" arrParam :: VName -> MulticoreGen Imp.Param arrParam arr = do name_entry <- lookupVar arr case name_entry of ArrayVar _ (ArrayEntry (MemLoc mem _ _) _) -> pure $ Imp.MemParam mem DefaultSpace _ -> error $ "arrParam: could not handle array " ++ show arr toParam :: VName -> TypeBase shape u -> MulticoreGen [Imp.Param] toParam name (Prim pt) = pure [Imp.ScalarParam name pt] toParam name (Mem space) = pure [Imp.MemParam name space] toParam name Array {} = pure <$> arrParam name toParam _name Acc {} = pure [] -- FIXME? Are we sure this works? getSpace :: SegOp () MCMem -> SegSpace getSpace (SegHist _ space _ _ _) = space getSpace (SegRed _ space _ _ _) = space getSpace (SegScan _ space _ _ _) = space getSpace (SegMap _ space _ _) = space getLoopBounds :: MulticoreGen (Imp.TExp Int64, Imp.TExp Int64) getLoopBounds = do start <- dPrim "start" end <- dPrim "end" emit $ Imp.Op $ Imp.GetLoopBounds (tvVar start) (tvVar end) pure (tvExp start, tvExp end) getIterationDomain :: SegOp () MCMem -> SegSpace -> MulticoreGen (Imp.TExp Int64) getIterationDomain SegMap {} space = do let ns = map snd $ unSegSpace space ns_64 = map pe64 ns pure $ product ns_64 getIterationDomain _ space = do let ns = map snd $ unSegSpace space ns_64 = map pe64 ns case unSegSpace space of [_] -> pure $ product ns_64 -- A segmented SegOp is over the segments -- so we drop the last dimension, which is -- executed sequentially _ -> pure $ product $ init ns_64 -- When the SegRed's return value is a scalar -- we perform a call by value-result in the segop function getReturnParams :: Pat LetDecMem -> SegOp () 
MCMem -> MulticoreGen [Imp.Param] getReturnParams pat SegRed {} = -- It's a good idea to make sure any prim values are initialised, as -- we will load them (redundantly) in the task code, and -- uninitialised values are UB. fmap concat . forM (patElems pat) $ \pe -> do case patElemType pe of Prim pt -> patElemName pe <~~ ValueExp (blankPrimValue pt) _ -> pure () toParam (patElemName pe) (patElemType pe) getReturnParams _ _ = pure mempty renameSegBinOp :: [SegBinOp MCMem] -> MulticoreGen [SegBinOp MCMem] renameSegBinOp segbinops = forM segbinops $ \(SegBinOp comm lam ne shape) -> do lam' <- renameLambda lam pure $ SegBinOp comm lam' ne shape compileThreadResult :: SegSpace -> PatElem LetDecMem -> KernelResult -> MulticoreGen () compileThreadResult space pe (Returns _ _ what) = do let is = map (Imp.le64 . fst) $ unSegSpace space copyDWIMFix (patElemName pe) is what [] compileThreadResult _ _ WriteReturns {} = compilerBugS "compileThreadResult: WriteReturns unhandled." compileThreadResult _ _ TileReturns {} = compilerBugS "compileThreadResult: TileReturns unhandled." compileThreadResult _ _ RegTileReturns {} = compilerBugS "compileThreadResult: RegTileReturns unhandled." 
-- | One Imp parameter list entry for every variable free in the given
-- value, with types taken from the current scope.
freeParams :: (FreeIn a) => a -> MulticoreGen [Imp.Param]
freeParams code = do
  let free_vars = namesToList $ freeIn code
  var_types <- traverse lookupType free_vars
  concat <$> sequence (zipWith toParam free_vars var_types)

-- | Conservative check used to pick a scheduling strategy.  Returns
-- 'False' as soon as a while-loop is reachable through sequencing,
-- for-loops, branches, comments, or the nested multicore constructs
-- handled by the op helper; everything else counts as balanced.
isLoadBalanced :: Imp.MCCode -> Bool
isLoadBalanced code =
  case code of
    x Imp.:>>: y -> bothBalanced x y
    Imp.For _ _ body -> isLoadBalanced body
    Imp.If _ tbranch fbranch -> bothBalanced tbranch fbranch
    Imp.Comment _ body -> isLoadBalanced body
    Imp.While {} -> False
    Imp.Op op -> opBalanced op
    _ -> True
  where
    bothBalanced x y = isLoadBalanced x && isLoadBalanced y
    opBalanced (Imp.ParLoop _ body _) = isLoadBalanced body
    opBalanced (Imp.ForEachActive _ body) = isLoadBalanced body
    opBalanced (Imp.ForEach _ _ _ body) = isLoadBalanced body
    opBalanced (Imp.ISPCKernel body _) = isLoadBalanced body
    opBalanced _ = True

-- | Scheduling for a particular 'SegOp': only a 'SegMap' consults its
-- code; histograms, scans and reductions are always scheduled statically.
decideScheduling' :: SegOp () rep -> Imp.MCCode -> Imp.Scheduling
decideScheduling' op code =
  case op of
    SegMap {} -> decideScheduling code
    _ -> Imp.Static

-- | Static scheduling for load-balanced code, dynamic otherwise.
decideScheduling :: Imp.MCCode -> Imp.Scheduling
decideScheduling code
  | isLoadBalanced code = Imp.Static
  | otherwise = Imp.Dynamic

-- | Try to extract invariant allocations.  If we assume that the
-- given 'Imp.MCCode' is the body of a 'SegOp', then it is always safe
-- to move the immediate allocations to the prebody.
--
-- The result is a pair of (hoisted allocations, remaining code).
extractAllocations :: Imp.MCCode -> (Imp.MCCode, Imp.MCCode)
extractAllocations segop_code = hoistAllocs segop_code
  where
    -- Every name declared somewhere inside the code being processed.
    declared = Imp.declaredIn segop_code

    hoistAllocs (Imp.DeclareMem mem space) =
      -- Hoisting declarations out is always safe.
      (Imp.DeclareMem mem space, mempty)
    hoistAllocs (Imp.Allocate mem size space)
      -- The size must not mention anything declared inside the code;
      -- otherwise the allocation falls through to the catch-all case.
      | not $ freeIn size `namesIntersect` declared =
          (Imp.Allocate mem size space, mempty)
    hoistAllocs (x Imp.:>>: y) = hoistAllocs x <> hoistAllocs y
    -- Nothing is hoisted out of loops.
    hoistAllocs loop@(Imp.While _ _) = (mempty, loop)
    hoistAllocs loop@(Imp.For _ _ _) = (mempty, loop)
    hoistAllocs (Imp.Comment s inner) = second (Imp.Comment s) (hoistAllocs inner)
    -- Frees are dropped entirely.
    hoistAllocs Imp.Free {} = mempty
    hoistAllocs (Imp.If cond tbranch fbranch) =
      let (t_allocs, tbranch') = hoistAllocs tbranch
          (f_allocs, fbranch') = hoistAllocs fbranch
       in (t_allocs <> f_allocs, Imp.If cond tbranch' fbranch')
    hoistAllocs (Imp.Op (Imp.ParLoop s body params)) =
      let -- Pull allocations out of the nested body first, then decide
          -- which of those can move further out from here.
          (body_allocs, body') = extractAllocations body
          (outer_allocs, here_allocs) = hoistAllocs body_allocs
          -- Parameters that are now declared by the hoisted allocations
          -- no longer need to be passed in.
          stillFree = (`notNameIn` Imp.declaredIn body_allocs) . Imp.paramName
          params' = filter stillFree params
       in (outer_allocs, here_allocs <> Imp.Op (Imp.ParLoop s body' params'))
    hoistAllocs other = (mempty, other)

-- | Indicates whether to vectorize a chunk loop or keep it sequential.
-- We use this to allow falling back to sequential chunk loops in cases
-- we don't care about trying to vectorize.
data ChunkLoopVectorization
  = -- | Emit the chunk loop as a vectorized @ForEach@.
    Vectorized
  | -- | Emit the chunk loop as an ordinary sequential @For@.
    Scalar

-- | Emit code for the chunk loop, given an action that generates code
-- for a single iteration.
--
-- The action is called with the (symbolic) index of the current
-- iteration.
generateChunkLoop :: String -> ChunkLoopVectorization -> (Imp.TExp Int64 -> MulticoreGen ()) -> MulticoreGen () generateChunkLoop desc Scalar m = do (start, end) <- getLoopBounds n <- dPrimVE "n" $ end - start i <- newVName (desc <> "_i") (body_allocs, body) <- fmap extractAllocations $ collect $ do addLoopVar i Int64 m $ start + Imp.le64 i emit body_allocs -- Emit either foreach or normal for loop let bound = untyped n emit $ Imp.For i bound body generateChunkLoop desc Vectorized m = do (start, end) <- getLoopBounds n <- dPrimVE "n" $ end - start i <- newVName (desc <> "_i") (body_allocs, body) <- fmap extractAllocations $ collect $ do addLoopVar i Int64 m $ Imp.le64 i emit body_allocs -- Emit either foreach or normal for loop let from = untyped start let bound = untyped (start + n) emit $ Imp.Op $ Imp.ForEach i from bound body -- | Emit code for a sequential loop over each vector lane, given -- and action that generates code for a single iteration. The action -- is called with the symbolic index of the current iteration. generateUniformizeLoop :: (Imp.TExp Int64 -> MulticoreGen ()) -> MulticoreGen () generateUniformizeLoop m = do i <- newVName "uni_i" body <- collect $ do addLoopVar i Int64 m $ Imp.le64 i emit $ Imp.Op $ Imp.ForEachActive i body -- | Given a piece of code, if that code performs an assignment, turn -- that assignment into an extraction of element from a vector on the -- right hand side, using a passed index for the extraction. Other code -- is left as is. extractVectorLane :: Imp.TExp Int64 -> MulticoreGen Imp.MCCode -> MulticoreGen () extractVectorLane j code = do let ut_exp = untyped j code' <- code case code' of Imp.SetScalar vname e -> do typ <- lookupType vname case typ of -- ISPC v1.17 does not support extract on f16 yet.. 
-- Thus we do this stupid conversion to f32 Prim (FloatType Float16) -> do tv :: TV Float <- dPrim "hack_extract_f16" emit $ Imp.SetScalar (tvVar tv) e emit $ Imp.Op $ Imp.ExtractLane vname (untyped $ tvExp tv) ut_exp _ -> emit $ Imp.Op $ Imp.ExtractLane vname e ut_exp _ -> emit code' -- | Given an action that may generate some code, put that code -- into an ISPC kernel. inISPC :: MulticoreGen () -> MulticoreGen () inISPC code = do code' <- collect code free <- freeParams code' emit $ Imp.Op $ Imp.ISPCKernel code' free ------------------------------- ------- SegRed helpers ------- ------------------------------- sForVectorized' :: VName -> Imp.Exp -> MulticoreGen () -> MulticoreGen () sForVectorized' i bound body = do let it = case primExpType bound of IntType bound_t -> bound_t t -> error $ "sFor': bound " ++ prettyString bound ++ " is of type " ++ prettyString t addLoopVar i it body' <- collect body emit $ Imp.Op $ Imp.ForEach i (Imp.ValueExp $ blankPrimValue $ Imp.IntType Imp.Int64) bound body' sForVectorized :: String -> Imp.TExp t -> (Imp.TExp t -> MulticoreGen ()) -> MulticoreGen () sForVectorized i bound body = do i' <- newVName i sForVectorized' i' (untyped bound) $ body $ TPrimExp $ Imp.var i' $ primExpType $ untyped bound -- | Like sLoopNest, but puts a vectorized loop at the innermost layer. sLoopNestVectorized :: Shape -> ([Imp.TExp Int64] -> MulticoreGen ()) -> MulticoreGen () sLoopNestVectorized = sLoopNest' [] . 
shapeDims where sLoopNest' is [] f = f $ reverse is sLoopNest' is [d] f = sForVectorized "nest_i" (pe64 d) $ \i -> sLoopNest' (i : is) [] f sLoopNest' is (d : ds) f = sFor "nest_i" (pe64 d) $ \i -> sLoopNest' (i : is) ds f ------------------------------- ------- SegHist helpers ------- ------------------------------- renameHistOpLambda :: [HistOp MCMem] -> MulticoreGen [HistOp MCMem] renameHistOpLambda hist_ops = forM hist_ops $ \(HistOp w rf dest neutral shape lam) -> do lam' <- renameLambda lam pure $ HistOp w rf dest neutral shape lam' -- | Locking strategy used for an atomic update. data Locking = Locking { -- | Array containing the lock. lockingArray :: VName, -- | Value for us to consider the lock free. lockingIsUnlocked :: Imp.TExp Int32, -- | What to write when we lock it. lockingToLock :: Imp.TExp Int32, -- | What to write when we unlock it. lockingToUnlock :: Imp.TExp Int32, -- | A transformation from the logical lock index to the -- physical position in the array. This can also be used -- to make the lock array smaller. lockingMapping :: [Imp.TExp Int64] -> [Imp.TExp Int64] } -- | A function for generating code for an atomic update. Assumes -- that the bucket is in-bounds. type DoAtomicUpdate rep r = [VName] -> [Imp.TExp Int64] -> MulticoreGen () -- | The mechanism that will be used for performing the atomic update. -- Approximates how efficient it will be. Ordered from most to least -- efficient. data AtomicUpdate rep r = AtomicPrim (DoAtomicUpdate rep r) | -- | Can be done by efficient swaps. AtomicCAS (DoAtomicUpdate rep r) | -- | Requires explicit locking. 
AtomicLocking (Locking -> DoAtomicUpdate rep r) atomicUpdateLocking :: AtomicBinOp -> Lambda MCMem -> AtomicUpdate MCMem () atomicUpdateLocking atomicBinOp lam | Just ops_and_ts <- lamIsBinOp lam, all (\(_, t, _, _) -> supportedPrims $ primBitSize t) ops_and_ts = primOrCas ops_and_ts $ \arrs bucket -> -- If the operator is a vectorised binary operator on 32-bit values, -- we can use a particularly efficient implementation. If the -- operator has an atomic implementation we use that, otherwise it -- is still a binary operator which can be implemented by atomic -- compare-and-swap if 32 bits. forM_ (zip arrs ops_and_ts) $ \(a, (op, t, x, y)) -> do -- Common variables. old <- dPrimS "old" t (arr', _a_space, bucket_offset) <- fullyIndexArray a bucket case opHasAtomicSupport old arr' (sExt32 <$> bucket_offset) op of Just f -> sOp $ f $ Imp.var y t Nothing -> atomicUpdateCAS t a old bucket x $ x <~~ Imp.BinOpExp op (Imp.var x t) (Imp.var y t) where opHasAtomicSupport old arr' bucket' bop = do let atomic f = Imp.Atomic . f old arr' bucket' atomic <$> atomicBinOp bop primOrCas ops | all isPrim ops = AtomicPrim | otherwise = AtomicCAS isPrim (op, _, _, _) = isJust $ atomicBinOp op atomicUpdateLocking _ op | [Prim t] <- lambdaReturnType op, [xp, _] <- lambdaParams op, supportedPrims (primBitSize t) = AtomicCAS $ \[arr] bucket -> do old <- dPrimS "old" t atomicUpdateCAS t arr old bucket (paramName xp) $ compileBody' [xp] $ lambdaBody op atomicUpdateLocking _ op = AtomicLocking $ \locking arrs bucket -> do old <- dPrim "old" continue <- dPrimVol "continue" int32 (0 :: Imp.TExp Int32) -- Correctly index into locks. (locks', _locks_space, locks_offset) <- fullyIndexArray (lockingArray locking) $ lockingMapping locking bucket -- Critical section let try_acquire_lock = do old <-- (0 :: Imp.TExp Int32) sOp . 
Imp.Atomic $ Imp.AtomicCmpXchg int32 (tvVar old) locks' (sExt32 <$> locks_offset) (tvVar continue) (untyped (lockingToLock locking)) lock_acquired = tvExp continue -- Even the releasing is done with an atomic rather than a -- simple write, for memory coherency reasons. release_lock = do old <-- lockingToLock locking sOp . Imp.Atomic $ Imp.AtomicCmpXchg int32 (tvVar old) locks' (sExt32 <$> locks_offset) (tvVar continue) (untyped (lockingToUnlock locking)) -- Preparing parameters. It is assumed that the caller has already -- filled the arr_params. We copy the current value to the -- accumulator parameters. let (acc_params, _arr_params) = splitAt (length arrs) $ lambdaParams op bind_acc_params = everythingVolatile $ sComment "bind lhs" $ forM_ (zip acc_params arrs) $ \(acc_p, arr) -> copyDWIMFix (paramName acc_p) [] (Var arr) bucket let op_body = sComment "execute operation" $ compileBody' acc_params $ lambdaBody op do_hist = everythingVolatile $ sComment "update global result" $ zipWithM_ (writeArray bucket) arrs $ map (Var . paramName) acc_params -- While-loop: Try to insert your value sWhile (tvExp continue .==. 0) $ do try_acquire_lock sUnless (lock_acquired .==. 
0) $ do dLParams acc_params bind_acc_params op_body do_hist release_lock where writeArray bucket arr val = copyDWIMFix arr bucket val [] atomicUpdateCAS :: PrimType -> VName -> VName -> [Imp.TExp Int64] -> VName -> MulticoreGen () -> MulticoreGen () atomicUpdateCAS t arr old bucket x do_op = do run_loop <- dPrimV "run_loop" (0 :: Imp.TExp Int32) (arr', _a_space, bucket_offset) <- fullyIndexArray arr bucket bytes <- toIntegral $ primBitSize t let (toBits, fromBits) = case t of FloatType Float16 -> ( \v -> Imp.FunExp "to_bits16" [v] int16, \v -> Imp.FunExp "from_bits16" [v] t ) FloatType Float32 -> ( \v -> Imp.FunExp "to_bits32" [v] int32, \v -> Imp.FunExp "from_bits32" [v] t ) FloatType Float64 -> ( \v -> Imp.FunExp "to_bits64" [v] int64, \v -> Imp.FunExp "from_bits64" [v] t ) _ -> (id, id) int | primBitSize t == 16 = int16 | primBitSize t == 32 = int32 | otherwise = int64 everythingVolatile $ copyDWIMFix old [] (Var arr) bucket old_bits_v <- dPrimS "old_bits" int old_bits_v <~~ toBits (Imp.var old t) let old_bits = Imp.var old_bits_v int -- While-loop: Try to insert your value sWhile (tvExp run_loop .==. 0) $ do x <~~ Imp.var old t do_op -- Writes result into x sOp . 
Imp.Atomic $ Imp.AtomicCmpXchg bytes old_bits_v arr' (sExt32 <$> bucket_offset) (tvVar run_loop) (toBits (Imp.var x t)) old <~~ fromBits old_bits supportedPrims :: Int -> Bool supportedPrims 8 = True supportedPrims 16 = True supportedPrims 32 = True supportedPrims 64 = True supportedPrims _ = False -- Supported bytes lengths by GCC (and clang) compiler toIntegral :: Int -> MulticoreGen PrimType toIntegral 8 = pure int8 toIntegral 16 = pure int16 toIntegral 32 = pure int32 toIntegral 64 = pure int64 toIntegral b = error $ "number of bytes is not supported for CAS - " ++ prettyString b futhark-0.25.27/src/Futhark/CodeGen/ImpGen/Multicore/SegHist.hs000066400000000000000000000370541475065116200240760ustar00rootroot00000000000000module Futhark.CodeGen.ImpGen.Multicore.SegHist ( compileSegHist, ) where import Control.Monad import Data.List (zip4) import Data.Maybe (listToMaybe) import Futhark.CodeGen.ImpCode.Multicore qualified as Imp import Futhark.CodeGen.ImpGen import Futhark.CodeGen.ImpGen.Multicore.Base import Futhark.CodeGen.ImpGen.Multicore.SegRed (compileSegRed') import Futhark.IR.MCMem import Futhark.Transform.Rename (renameLambda) import Futhark.Util (chunks, splitFromEnd, takeLast) import Futhark.Util.IntegralExp (rem) import Prelude hiding (quot, rem) compileSegHist :: Pat LetDecMem -> SegSpace -> [HistOp MCMem] -> KernelBody MCMem -> TV Int32 -> MulticoreGen Imp.MCCode compileSegHist pat space histops kbody nsubtasks | [_] <- unSegSpace space = nonsegmentedHist pat space histops kbody nsubtasks | otherwise = segmentedHist pat space histops kbody -- | Split some list into chunks equal to the number of values -- returned by each 'SegBinOp' segHistOpChunks :: [HistOp rep] -> [a] -> [[a]] segHistOpChunks = chunks . map (length . histNeutral) histSize :: HistOp MCMem -> Imp.TExp Int64 histSize = product . map pe64 . shapeDims . 
histShape genHistOpParams :: HistOp MCMem -> MulticoreGen () genHistOpParams histops = dScope Nothing $ scopeOfLParams $ lambdaParams $ histOp histops renameHistop :: HistOp MCMem -> MulticoreGen (HistOp MCMem) renameHistop histop = do let op = histOp histop lambda' <- renameLambda op pure histop {histOp = lambda'} nonsegmentedHist :: Pat LetDecMem -> SegSpace -> [HistOp MCMem] -> KernelBody MCMem -> TV Int32 -> MulticoreGen Imp.MCCode nonsegmentedHist pat space histops kbody num_histos = do let ns = map snd $ unSegSpace space ns_64 = map pe64 ns num_histos' = tvExp num_histos hist_width = maybe 0 histSize $ listToMaybe histops use_subhistogram = sExt64 num_histos' * hist_width .<=. product ns_64 histops' <- renameHistOpLambda histops -- Only do something if there is actually input. collect $ sUnless (product ns_64 .==. 0) $ do sIf use_subhistogram (subHistogram pat space histops num_histos kbody) (atomicHistogram pat space histops' kbody) -- | -- Atomic Histogram approach -- The implementation has three sub-strategies depending on the -- type of the operator -- 1. If values are integral scalars, a direct-supported atomic update is used. -- 2. If values are on one memory location, e.g. a float, then a -- CAS operation is used to perform the update, where the float is -- casted to an integral scalar. -- 1. and 2. currently only works for 32-bit and 64-bit types, -- but GCC has support for 8-, 16- and 128- bit types as well. -- 3. 
Otherwise a locking based approach is used onOpAtomic :: HistOp MCMem -> MulticoreGen ([VName] -> [Imp.TExp Int64] -> MulticoreGen ()) onOpAtomic op = do atomics <- hostAtomics <$> askEnv let lambda = histOp op do_op = atomicUpdateLocking atomics lambda case do_op of AtomicPrim f -> pure f AtomicCAS f -> pure f AtomicLocking f -> do -- Allocate a static array of locks -- as in the GPU backend let num_locks = 100151 -- This number is taken from the GPU backend dims = map pe64 $ shapeDims (histOpShape op <> histShape op) locks <- sStaticArray "hist_locks" int32 $ Imp.ArrayZeros num_locks let l' = Locking locks 0 1 0 (pure . (`rem` fromIntegral num_locks) . flattenIndex dims) pure $ f l' atomicHistogram :: Pat LetDecMem -> SegSpace -> [HistOp MCMem] -> KernelBody MCMem -> MulticoreGen () atomicHistogram pat space histops kbody = do let (is, ns) = unzip $ unSegSpace space ns_64 = map pe64 ns let num_red_res = length histops + sum (map (length . histNeutral) histops) (all_red_pes, map_pes) = splitAt num_red_res $ patElems pat atomicOps <- mapM onOpAtomic histops body <- collect $ do dPrim_ (segFlat space) int64 sOp $ Imp.GetTaskId (segFlat space) generateChunkLoop "SegHist" Scalar $ \flat_idx -> do zipWithM_ dPrimV_ is $ unflattenIndex ns_64 flat_idx compileStms mempty (kernelBodyStms kbody) $ do let (red_res, map_res) = splitFromEnd (length map_pes) $ kernelBodyResult kbody red_res_split = splitHistResults histops $ map kernelResultSubExp red_res let pes_per_op = chunks (map (length . 
histDest) histops) all_red_pes forM_ (zip4 histops red_res_split atomicOps pes_per_op) $ \(HistOp dest_shape _ _ _ shape lam, (bucket, vs'), do_op, dest_res) -> do let (_is_params, vs_params) = splitAt (length vs') $ lambdaParams lam dest_shape' = map pe64 $ shapeDims dest_shape bucket' = map pe64 bucket bucket_in_bounds = inBounds (Slice (map DimFix bucket')) dest_shape' sComment "save map-out results" $ forM_ (zip map_pes map_res) $ \(pe, res) -> copyDWIMFix (patElemName pe) (map Imp.le64 is) (kernelResultSubExp res) [] sComment "perform updates" $ sWhen bucket_in_bounds $ do let bucket_is = map Imp.le64 (init is) ++ bucket' dLParams $ lambdaParams lam sLoopNest shape $ \is' -> do forM_ (zip vs_params vs') $ \(p, res) -> copyDWIMFix (paramName p) [] res is' do_op (map patElemName dest_res) (bucket_is ++ is') free_params <- freeParams body emit $ Imp.Op $ Imp.ParLoop "atomic_seg_hist" body free_params updateHisto :: HistOp MCMem -> [VName] -> [Imp.TExp Int64] -> Imp.TExp Int64 -> [Param LParamMem] -> MulticoreGen () updateHisto op arrs bucket j uni_acc = do let bind_acc_params = forM_ (zip uni_acc arrs) $ \(acc_u, arr) -> do copyDWIMFix (paramName acc_u) [] (Var arr) bucket op_body = compileBody' [] $ lambdaBody $ histOp op writeArray arr val = extractVectorLane j $ collect $ copyDWIMFix arr bucket val [] do_hist = zipWithM_ writeArray arrs $ map resSubExp $ bodyResult $ lambdaBody $ histOp op sComment "Start of body" $ do bind_acc_params op_body do_hist -- Generates num_histos sub-histograms of the size -- of the destination histogram -- Then for each chunk of the input each subhistogram -- is computed and finally combined through a segmented reduction -- across the histogram indicies. 
-- This is expected to be fast if len(histDest) is small subHistogram :: Pat LetDecMem -> SegSpace -> [HistOp MCMem] -> TV Int32 -> KernelBody MCMem -> MulticoreGen () subHistogram pat space histops num_histos kbody = do emit $ Imp.DebugPrint "subHistogram segHist" Nothing let (is, ns) = unzip $ unSegSpace space ns_64 = map pe64 ns let pes = patElems pat num_red_res = length histops + sum (map (length . histNeutral) histops) map_pes = drop num_red_res pes per_red_pes = segHistOpChunks histops $ patElems pat -- Allocate array of subhistograms in the calling thread. Each -- tasks will work in its own private allocations (to avoid false -- sharing), but this is where they will ultimately copy their -- results. global_subhistograms <- forM histops $ \histop -> forM (histType histop) $ \t -> do let shape = Shape [tvSize num_histos] <> arrayShape t sAllocArray "subhistogram" (elemType t) shape DefaultSpace let tid' = Imp.le64 $ segFlat space -- Generate loop body of parallel function body <- collect $ do dPrim_ (segFlat space) int64 sOp $ Imp.GetTaskId (segFlat space) local_subhistograms <- forM (zip per_red_pes histops) $ \(pes', histop) -> do op_local_subhistograms <- forM (histType histop) $ \t -> sAllocArray "subhistogram" (elemType t) (arrayShape t) DefaultSpace forM_ (zip3 pes' op_local_subhistograms (histNeutral histop)) $ \(pe, hist, ne) -> -- First thread initializes histogram with dest vals. Others -- initialize with neutral element sIf (tid' .==. 
0) (copyDWIMFix hist [] (Var $ patElemName pe) []) ( sLoopNest (histShape histop) $ \shape_is -> sLoopNest (histOpShape histop) $ \vec_is -> copyDWIMFix hist (shape_is <> vec_is) ne [] ) pure op_local_subhistograms inISPC $ generateChunkLoop "SegRed" Vectorized $ \i -> do zipWithM_ dPrimV_ is $ unflattenIndex ns_64 i compileStms mempty (kernelBodyStms kbody) $ do let (red_res, map_res) = splitFromEnd (length map_pes) $ map kernelResultSubExp $ kernelBodyResult kbody sComment "save map-out results" $ forM_ (zip map_pes map_res) $ \(pe, res) -> copyDWIMFix (patElemName pe) (map Imp.le64 is) res [] forM_ (zip3 histops local_subhistograms (splitHistResults histops red_res)) $ \( histop@(HistOp dest_shape _ _ _ shape _), histop_subhistograms, (bucket, vs') ) -> do histop' <- renameHistop histop let bucket' = map pe64 bucket dest_shape' = map pe64 $ shapeDims dest_shape acc_params' = (lambdaParams . histOp) histop' vs_params' = takeLast (length vs') $ lambdaParams $ histOp histop' generateUniformizeLoop $ \j -> sComment "perform updates" $ do -- Create new set of uniform buckets -- That is extract each bucket from a SIMD vector lane extract_buckets <- mapM (dPrimSV "extract_bucket" . (primExpType . untyped)) bucket' forM_ (zip extract_buckets bucket') $ \(x, y) -> emit $ Imp.Op $ Imp.ExtractLane (tvVar x) (untyped y) (untyped j) let bucket'' = map tvExp extract_buckets bucket_in_bounds = inBounds (Slice (map DimFix bucket'')) dest_shape' sWhen bucket_in_bounds $ do genHistOpParams histop' sLoopNest shape $ \is' -> do -- read values vs and perform lambda writing result back to is forM_ (zip vs_params' vs') $ \(p, res) -> ifPrimType (paramType p) $ \pt -> do -- Hack to copy varying load into uniform result variable tmp <- dPrimS "tmp" pt copyDWIMFix tmp [] res is' extractVectorLane j . 
pure $ Imp.SetScalar (paramName p) (toExp' pt tmp) updateHisto histop' histop_subhistograms (bucket'' ++ is') j acc_params' -- Copy the task-local subhistograms to the global subhistograms, -- where they will be combined. forM_ (zip (concat global_subhistograms) (concat local_subhistograms)) $ \(global, local) -> copyDWIMFix global [tid'] (Var local) [] free_params <- freeParams body emit $ Imp.Op $ Imp.ParLoop "seghist_stage_1" body free_params -- Perform a segmented reduction over the subhistograms forM_ (zip3 per_red_pes global_subhistograms histops) $ \(red_pes, hists, op) -> do bucket_ids <- replicateM (shapeRank (histShape op)) (newVName "bucket_id") subhistogram_id <- newVName "subhistogram_id" let segred_space = SegSpace (segFlat space) $ segment_dims ++ zip bucket_ids (shapeDims (histShape op)) ++ [(subhistogram_id, tvSize num_histos)] segred_op = SegBinOp Noncommutative (histOp op) (histNeutral op) (histOpShape op) red_code <- collect $ do nsubtasks <- dPrim "nsubtasks" sOp $ Imp.GetNumTasks $ tvVar nsubtasks emit <=< compileSegRed' (Pat red_pes) segred_space [segred_op] nsubtasks $ \red_cont -> red_cont $ segBinOpChunks [segred_op] $ flip map hists $ \subhisto -> ( Var subhisto, map Imp.le64 $ map fst segment_dims ++ [subhistogram_id] ++ bucket_ids ) let ns_red = map (pe64 . snd) $ unSegSpace segred_space iterations = product $ init ns_red -- The segmented reduction is sequential over the inner most dimension scheduler_info = Imp.SchedulerInfo (untyped iterations) Imp.Static red_task = Imp.ParallelTask red_code free_params_red <- freeParams red_code emit $ Imp.Op $ Imp.SegOp "seghist_red" free_params_red red_task Nothing mempty scheduler_info where segment_dims = init $ unSegSpace space ifPrimType (Prim pt) f = f pt ifPrimType _ _ = pure () -- Note: This isn't currently used anywhere. -- This implementation for a Segmented Hist only -- parallelize over the segments, -- where each segment is updated sequentially. 
segmentedHist :: Pat LetDecMem -> SegSpace -> [HistOp MCMem] -> KernelBody MCMem -> MulticoreGen Imp.MCCode segmentedHist pat space histops kbody = do emit $ Imp.DebugPrint "Segmented segHist" Nothing collect $ do body <- compileSegHistBody pat space histops kbody free_params <- freeParams body emit $ Imp.Op $ Imp.ParLoop "segmented_hist" body free_params compileSegHistBody :: Pat LetDecMem -> SegSpace -> [HistOp MCMem] -> KernelBody MCMem -> MulticoreGen Imp.MCCode compileSegHistBody pat space histops kbody = collect $ do let (is, ns) = unzip $ unSegSpace space ns_64 = map pe64 ns let num_red_res = length histops + sum (map (length . histNeutral) histops) map_pes = drop num_red_res $ patElems pat per_red_pes = segHistOpChunks histops $ patElems pat dPrim_ (segFlat space) int64 sOp $ Imp.GetTaskId (segFlat space) generateChunkLoop "SegHist" Scalar $ \idx -> do let inner_bound = last ns_64 sFor "i" inner_bound $ \i -> do zipWithM_ dPrimV_ (init is) $ unflattenIndex (init ns_64) idx dPrimV_ (last is) i compileStms mempty (kernelBodyStms kbody) $ do let (red_res, map_res) = splitFromEnd (length map_pes) $ map kernelResultSubExp $ kernelBodyResult kbody forM_ (zip3 per_red_pes histops (splitHistResults histops red_res)) $ \(red_pes, HistOp dest_shape _ _ _ shape lam, (bucket, vs')) -> do let (is_params, vs_params) = splitAt (length vs') $ lambdaParams lam bucket' = map pe64 bucket dest_shape' = map pe64 $ shapeDims dest_shape bucket_in_bounds = inBounds (Slice (map DimFix bucket')) dest_shape' sComment "save map-out results" $ forM_ (zip map_pes map_res) $ \(pe, res) -> copyDWIMFix (patElemName pe) (map Imp.le64 is) res [] sComment "perform updates" $ sWhen bucket_in_bounds $ do dLParams $ lambdaParams lam sLoopNest shape $ \vec_is -> do -- Index forM_ (zip red_pes is_params) $ \(pe, p) -> copyDWIMFix (paramName p) [] (Var $ patElemName pe) (map Imp.le64 (init is) ++ bucket' ++ vec_is) -- Value at index forM_ (zip vs_params vs') $ \(p, v) -> copyDWIMFix (paramName p) 
[] v vec_is compileStms mempty (bodyStms $ lambdaBody lam) $ forM_ (zip red_pes $ map resSubExp $ bodyResult $ lambdaBody lam) $ \(pe, se) -> copyDWIMFix (patElemName pe) (map Imp.le64 (init is) ++ bucket' ++ vec_is) se [] futhark-0.25.27/src/Futhark/CodeGen/ImpGen/Multicore/SegMap.hs000066400000000000000000000032601475065116200236740ustar00rootroot00000000000000-- | Multicore code generation for 'SegMap'. module Futhark.CodeGen.ImpGen.Multicore.SegMap ( compileSegMap, ) where import Control.Monad import Futhark.CodeGen.ImpCode.Multicore qualified as Imp import Futhark.CodeGen.ImpGen import Futhark.CodeGen.ImpGen.Multicore.Base import Futhark.IR.MCMem import Futhark.Transform.Rename writeResult :: [VName] -> PatElem dec -> KernelResult -> MulticoreGen () writeResult is pe (Returns _ _ se) = copyDWIMFix (patElemName pe) (map Imp.le64 is) se [] writeResult _ pe (WriteReturns _ arr idx_vals) = do arr_t <- lookupType arr let (iss, vs) = unzip idx_vals rws' = map pe64 $ arrayDims arr_t forM_ (zip iss vs) $ \(slice, v) -> do let slice' = fmap pe64 slice sWhen (inBounds slice' rws') $ copyDWIM (patElemName pe) (unSlice slice') v [] writeResult _ _ res = error $ "writeResult: cannot handle " ++ prettyString res compileSegMapBody :: Pat LetDecMem -> SegSpace -> KernelBody MCMem -> MulticoreGen Imp.MCCode compileSegMapBody pat space (KernelBody _ kstms kres) = collect $ do let (is, ns) = unzip $ unSegSpace space ns' = map pe64 ns dPrim_ (segFlat space) int64 sOp $ Imp.GetTaskId (segFlat space) kstms' <- mapM renameStm kstms inISPC $ generateChunkLoop "SegMap" Vectorized $ \i -> do dIndexSpace (zip is ns') i compileStms (freeIn kres) kstms' $ zipWithM_ (writeResult is) (patElems pat) kres compileSegMap :: Pat LetDecMem -> SegSpace -> KernelBody MCMem -> MulticoreGen Imp.MCCode compileSegMap pat space kbody = collect $ do body <- compileSegMapBody pat space kbody free_params <- freeParams body emit $ Imp.Op $ Imp.ParLoop "segmap" body free_params 
futhark-0.25.27/src/Futhark/CodeGen/ImpGen/Multicore/SegRed.hs000066400000000000000000000406001475065116200236700ustar00rootroot00000000000000module Futhark.CodeGen.ImpGen.Multicore.SegRed ( compileSegRed, compileSegRed', DoSegBody, ) where import Control.Monad import Futhark.CodeGen.ImpCode.Multicore qualified as Imp import Futhark.CodeGen.ImpGen import Futhark.CodeGen.ImpGen.Multicore.Base import Futhark.IR.MCMem import Futhark.Transform.Rename (renameLambda) import Prelude hiding (quot, rem) type DoSegBody = (([[(SubExp, [Imp.TExp Int64])]] -> MulticoreGen ()) -> MulticoreGen ()) -- | Generate code for a SegRed construct compileSegRed :: Pat LetDecMem -> SegSpace -> [SegBinOp MCMem] -> KernelBody MCMem -> TV Int32 -> MulticoreGen Imp.MCCode compileSegRed pat space reds kbody nsubtasks = compileSegRed' pat space reds nsubtasks $ \red_cont -> compileStms mempty (kernelBodyStms kbody) $ do let (red_res, map_res) = splitAt (segBinOpResults reds) $ kernelBodyResult kbody sComment "save map-out results" $ do let map_arrs = drop (segBinOpResults reds) $ patElems pat zipWithM_ (compileThreadResult space) map_arrs map_res red_cont $ segBinOpChunks reds $ map ((,[]) . kernelResultSubExp) red_res -- | Like 'compileSegRed', but where the body is a monadic action. compileSegRed' :: Pat LetDecMem -> SegSpace -> [SegBinOp MCMem] -> TV Int32 -> DoSegBody -> MulticoreGen Imp.MCCode compileSegRed' pat space reds nsubtasks kbody | [_] <- unSegSpace space = nonsegmentedReduction pat space reds nsubtasks kbody | otherwise = segmentedReduction pat space reds kbody -- | A SegBinOp with auxiliary information. data SegBinOpSlug = SegBinOpSlug { slugOp :: SegBinOp MCMem, -- | The array in which we write the intermediate results, indexed -- by the flat/physical thread ID. slugResArrs :: [VName] } slugBody :: SegBinOpSlug -> Body MCMem slugBody = lambdaBody . segBinOpLambda . slugOp slugParams :: SegBinOpSlug -> [LParam MCMem] slugParams = lambdaParams . segBinOpLambda . 
slugOp slugNeutral :: SegBinOpSlug -> [SubExp] slugNeutral = segBinOpNeutral . slugOp slugShape :: SegBinOpSlug -> Shape slugShape = segBinOpShape . slugOp accParams, nextParams :: SegBinOpSlug -> [LParam MCMem] accParams slug = take (length (slugNeutral slug)) $ slugParams slug nextParams slug = drop (length (slugNeutral slug)) $ slugParams slug renameSlug :: SegBinOpSlug -> MulticoreGen SegBinOpSlug renameSlug slug = do let op = slugOp slug let lambda = segBinOpLambda op lambda' <- renameLambda lambda let op' = op {segBinOpLambda = lambda'} pure slug {slugOp = op'} -- | Arrays for storing group results shared between threads groupResultArrays :: String -> SubExp -> [SegBinOp MCMem] -> MulticoreGen [[VName]] groupResultArrays s num_threads reds = forM reds $ \(SegBinOp _ lam _ shape) -> forM (lambdaReturnType lam) $ \t -> do let full_shape = Shape [num_threads] <> shape <> arrayShape t sAllocArray s (elemType t) full_shape DefaultSpace nonsegmentedReduction :: Pat LetDecMem -> SegSpace -> [SegBinOp MCMem] -> TV Int32 -> DoSegBody -> MulticoreGen Imp.MCCode nonsegmentedReduction pat space reds nsubtasks kbody = collect $ do thread_res_arrs <- groupResultArrays "reduce_stage_1_tid_res_arr" (tvSize nsubtasks) reds let slugs1 = zipWith SegBinOpSlug reds thread_res_arrs nsubtasks' = tvExp nsubtasks -- Are all the operators commutative? let comm = all ((== Commutative) . segBinOpComm) reds let dims = map (shapeDims . slugShape) slugs1 let isScalar x = case x of MemPrim _ -> True; _ -> False -- Are we only working on scalar arrays? let scalars = all (all (isScalar . paramDec) . slugParams) slugs1 && all (== []) dims -- Are we working with vectorized inner maps? 
let inner_map = [] `notElem` dims let path | comm && scalars = reductionStage1CommScalar | inner_map = reductionStage1Array | scalars = reductionStage1NonCommScalar | otherwise = reductionStage1Fallback path space slugs1 kbody reds2 <- renameSegBinOp reds let slugs2 = zipWith SegBinOpSlug reds2 thread_res_arrs reductionStage2 pat space nsubtasks' slugs2 -- Generate code that declares the params for the binop genBinOpParams :: [SegBinOpSlug] -> MulticoreGen () genBinOpParams slugs = dScope Nothing $ scopeOfLParams $ concatMap slugParams slugs -- Generate code that declares accumulators, return a list of these genAccumulators :: [SegBinOpSlug] -> MulticoreGen [[VName]] genAccumulators slugs = forM slugs $ \slug -> do let shape = segBinOpShape $ slugOp slug forM (zip (accParams slug) (slugNeutral slug)) $ \(p, ne) -> do -- Declare accumulator variable. acc <- case paramType p of Prim pt | shape == mempty -> dPrimS "local_acc" pt | otherwise -> sAllocArray "local_acc" pt shape DefaultSpace _ -> pure $ paramName p -- Now neutral-initialise the accumulator. sLoopNest (slugShape slug) $ \vec_is -> copyDWIMFix acc vec_is ne [] pure acc -- Datatype to represent all the different ways we can generate -- code for a reduction. data RedLoopType = RedSeq -- Fully sequential | RedComm -- Commutative scalar | RedNonComm -- Noncommutative scalar | RedNested -- Nested vectorized operator | RedUniformize -- Uniformize over scalar acc -- Given a type of reduction and the loop index, should we wrap -- the loop body in some extra code? getRedLoop :: RedLoopType -> Imp.TExp Int64 -> (Imp.TExp Int64 -> MulticoreGen ()) -> MulticoreGen () getRedLoop RedNonComm _ = generateUniformizeLoop getRedLoop RedUniformize uni = \body -> body uni getRedLoop _ _ = \body -> body 0 -- Given a type of reduction, should we perform extracts on -- the accumulator? 
getExtract :: RedLoopType -> Imp.TExp Int64 -> MulticoreGen Imp.MCCode -> MulticoreGen () getExtract RedNonComm = extractVectorLane getExtract RedUniformize = extractVectorLane getExtract _ = \_ body -> body >>= emit -- Given a type of reduction, should we vectorize the inner -- map, if it exists? getNestLoop :: RedLoopType -> Shape -> ([Imp.TExp Int64] -> MulticoreGen ()) -> MulticoreGen () getNestLoop RedNested = sLoopNestVectorized getNestLoop _ = sLoopNest -- Given a list of accumulators, use them as the source -- data for reduction. redSourceAccs :: [[VName]] -> DoSegBody redSourceAccs slug_local_accs m = m $ map (map (\x -> (Var x, []))) slug_local_accs -- Generate a reduction loop for uniformizing vectors genPostbodyReductionLoop :: [[VName]] -> [SegBinOpSlug] -> [[VName]] -> SegSpace -> Imp.TExp Int64 -> MulticoreGen () genPostbodyReductionLoop accs = genReductionLoop RedUniformize (redSourceAccs accs) -- Generate a potentially vectorized body of code that performs reduction -- when put inside a chunked loop. 
genReductionLoop :: RedLoopType -> DoSegBody -> [SegBinOpSlug] -> [[VName]] -> SegSpace -> Imp.TExp Int64 -> MulticoreGen () genReductionLoop typ kbodymap slugs slug_local_accs space i = do let (is, ns) = unzip $ unSegSpace space ns' = map pe64 ns zipWithM_ dPrimV_ is $ unflattenIndex ns' i kbodymap $ \all_red_res' -> do forM_ (zip3 all_red_res' slugs slug_local_accs) $ \(red_res, slug, local_accs) -> getNestLoop typ (slugShape slug) $ \vec_is -> do let lamtypes = lambdaReturnType $ segBinOpLambda $ slugOp slug getRedLoop typ i $ \uni -> do sComment "Load accum params" $ forM_ (zip3 (accParams slug) local_accs lamtypes) $ \(p, local_acc, t) -> when (primType t) $ do copyDWIMFix (paramName p) [] (Var local_acc) vec_is sComment "Load next params" $ forM_ (zip (nextParams slug) red_res) $ \(p, (res, res_is)) -> do getExtract typ uni $ collect $ copyDWIMFix (paramName p) [] res (res_is ++ vec_is) sComment "SegRed body" $ compileStms mempty (bodyStms $ slugBody slug) $ forM_ (zip local_accs $ map resSubExp $ bodyResult $ slugBody slug) $ \(local_acc, se) -> copyDWIMFix local_acc vec_is se [] -- Generate code to write back results from the accumulators genWriteBack :: [SegBinOpSlug] -> [[VName]] -> SegSpace -> MulticoreGen () genWriteBack slugs slug_local_accs space = forM_ (zip slugs slug_local_accs) $ \(slug, local_accs) -> forM (zip (slugResArrs slug) local_accs) $ \(acc, local_acc) -> copyDWIMFix acc [Imp.le64 $ segFlat space] (Var local_acc) [] type ReductionStage1 = SegSpace -> [SegBinOpSlug] -> DoSegBody -> MulticoreGen () -- Pure sequential codegen with no fancy vectorization reductionStage1Fallback :: ReductionStage1 reductionStage1Fallback space slugs kbody = do fbody <- collect $ do dPrim_ (segFlat space) int64 sOp $ Imp.GetTaskId (segFlat space) -- Declare params genBinOpParams slugs slug_local_accs <- genAccumulators slugs -- Generate main reduction loop generateChunkLoop "SegRed" Scalar $ genReductionLoop RedSeq kbody slugs slug_local_accs space -- Write 
back results genWriteBack slugs slug_local_accs space free_params <- freeParams fbody emit $ Imp.Op $ Imp.ParLoop "segred_stage_1" fbody free_params -- Codegen for noncommutative scalar reduction. We vectorize the -- kernel body, and do the reduction sequentially. reductionStage1NonCommScalar :: ReductionStage1 reductionStage1NonCommScalar space slugs kbody = do fbody <- collect $ do dPrim_ (segFlat space) int64 sOp $ Imp.GetTaskId (segFlat space) inISPC $ do -- Declare params genBinOpParams slugs slug_local_accs <- genAccumulators slugs -- Generate main reduction loop generateChunkLoop "SegRed" Vectorized $ genReductionLoop RedNonComm kbody slugs slug_local_accs space -- Write back results genWriteBack slugs slug_local_accs space free_params <- freeParams fbody emit $ Imp.Op $ Imp.ParLoop "segred_stage_1" fbody free_params -- Codegen for a commutative reduction on scalar arrays -- In this case, we can generate an efficient interleaved reduction reductionStage1CommScalar :: ReductionStage1 reductionStage1CommScalar space slugs kbody = do fbody <- collect $ do dPrim_ (segFlat space) int64 sOp $ Imp.GetTaskId (segFlat space) -- Rename lambda params in slugs to get a new set of them slugs' <- mapM renameSlug slugs inISPC $ do -- Declare one set of params uniform genBinOpParams slugs' slug_local_accs_uni <- genAccumulators slugs' -- Declare the other varying genBinOpParams slugs slug_local_accs <- genAccumulators slugs -- Generate the main reduction loop over vectors generateChunkLoop "SegRed" Vectorized $ genReductionLoop RedComm kbody slugs slug_local_accs space -- Now reduce over those vector accumulators to get scalar results generateUniformizeLoop $ genPostbodyReductionLoop slug_local_accs slugs' slug_local_accs_uni space -- And write back the results genWriteBack slugs slug_local_accs_uni space free_params <- freeParams fbody emit $ Imp.Op $ Imp.ParLoop "segred_stage_1" fbody free_params -- Codegen for a reduction on arrays, where the body is a perfect nested 
map. -- We vectorize just the inner map. reductionStage1Array :: ReductionStage1 reductionStage1Array space slugs kbody = do fbody <- collect $ do dPrim_ (segFlat space) int64 sOp $ Imp.GetTaskId (segFlat space) -- Declare params lparams <- collect $ genBinOpParams slugs (slug_local_accs, uniform_prebody) <- collect' $ genAccumulators slugs -- Put the accumulators outside of the kernel, so they are forced uniform emit uniform_prebody inISPC $ do -- Put the lambda params inside the kernel so they are varying emit lparams -- Generate the main reduction loop generateChunkLoop "SegRed" Scalar $ genReductionLoop RedNested kbody slugs slug_local_accs space -- Write back results genWriteBack slugs slug_local_accs space free_params <- freeParams fbody emit $ Imp.Op $ Imp.ParLoop "segred_stage_1" fbody free_params reductionStage2 :: Pat LetDecMem -> SegSpace -> Imp.TExp Int32 -> [SegBinOpSlug] -> MulticoreGen () reductionStage2 pat space nsubtasks slugs = do let per_red_pes = segBinOpChunks (map slugOp slugs) $ patElems pat phys_id = Imp.le64 (segFlat space) sComment "neutral-initialise the output" $ forM_ (zip (map slugOp slugs) per_red_pes) $ \(red, red_res) -> forM_ (zip red_res $ segBinOpNeutral red) $ \(pe, ne) -> sLoopNest (segBinOpShape red) $ \vec_is -> copyDWIMFix (patElemName pe) vec_is ne [] dScope Nothing $ scopeOfLParams $ concatMap slugParams slugs sFor "i" nsubtasks $ \i' -> do mkTV (segFlat space) <-- i' sComment "Apply main thread reduction" $ forM_ (zip slugs per_red_pes) $ \(slug, red_res) -> sLoopNest (slugShape slug) $ \vec_is -> do sComment "load acc params" $ forM_ (zip (accParams slug) red_res) $ \(p, pe) -> copyDWIMFix (paramName p) [] (Var $ patElemName pe) vec_is sComment "load next params" $ forM_ (zip (nextParams slug) (slugResArrs slug)) $ \(p, acc) -> copyDWIMFix (paramName p) [] (Var acc) (phys_id : vec_is) sComment "red body" $ compileStms mempty (bodyStms $ slugBody slug) $ forM_ (zip red_res $ map resSubExp $ bodyResult $ slugBody slug) $ 
\(pe, se') -> copyDWIMFix (patElemName pe) vec_is se' [] -- Each thread reduces over the number of segments -- each of which is done sequentially -- Maybe we should select the work of the inner loop -- based on n_segments and dimensions etc. segmentedReduction :: Pat LetDecMem -> SegSpace -> [SegBinOp MCMem] -> DoSegBody -> MulticoreGen Imp.MCCode segmentedReduction pat space reds kbody = collect $ do body <- compileSegRedBody pat space reds kbody free_params <- freeParams body emit $ Imp.Op $ Imp.ParLoop "segmented_segred" body free_params -- Currently, this is only used as part of SegHist calculations, never alone. compileSegRedBody :: Pat LetDecMem -> SegSpace -> [SegBinOp MCMem] -> DoSegBody -> MulticoreGen Imp.MCCode compileSegRedBody pat space reds kbody = do let (is, ns) = unzip $ unSegSpace space ns_64 = map pe64 ns inner_bound = last ns_64 dPrim_ (segFlat space) int64 sOp $ Imp.GetTaskId (segFlat space) let per_red_pes = segBinOpChunks reds $ patElems pat -- Perform sequential reduce on inner most dimension collect . inISPC $ generateChunkLoop "SegRed" Vectorized $ \n_segments -> do flat_idx <- dPrimVE "flat_idx" $ n_segments * inner_bound zipWithM_ dPrimV_ is $ unflattenIndex ns_64 flat_idx sComment "neutral-initialise the accumulators" $ forM_ (zip per_red_pes reds) $ \(pes, red) -> forM_ (zip pes (segBinOpNeutral red)) $ \(pe, ne) -> sLoopNest (segBinOpShape red) $ \vec_is -> copyDWIMFix (patElemName pe) (map Imp.le64 (init is) ++ vec_is) ne [] sComment "main body" $ do dScope Nothing $ scopeOfLParams $ concatMap (lambdaParams . segBinOpLambda) reds sFor "i" inner_bound $ \i -> do zipWithM_ (<--) (map mkTV $ init is) (unflattenIndex (init ns_64) (sExt64 n_segments)) dPrimV_ (last is) i kbody $ \red_res' -> do forM_ (zip3 per_red_pes reds red_res') $ \(pes, red, res') -> sLoopNest (segBinOpShape red) $ \vec_is -> do sComment "load accum" $ do let acc_params = take (length (segBinOpNeutral red)) $ (lambdaParams . 
segBinOpLambda) red forM_ (zip acc_params pes) $ \(p, pe) -> copyDWIMFix (paramName p) [] (Var $ patElemName pe) (map Imp.le64 (init is) ++ vec_is) sComment "load new val" $ do let next_params = drop (length (segBinOpNeutral red)) $ (lambdaParams . segBinOpLambda) red forM_ (zip next_params res') $ \(p, (res, res_is)) -> copyDWIMFix (paramName p) [] res (res_is ++ vec_is) sComment "apply reduction" $ do let lbody = (lambdaBody . segBinOpLambda) red compileStms mempty (bodyStms lbody) $ sComment "write back to res" $ forM_ (zip pes $ map resSubExp $ bodyResult lbody) $ \(pe, se') -> copyDWIMFix (patElemName pe) (map Imp.le64 (init is) ++ vec_is) se' [] futhark-0.25.27/src/Futhark/CodeGen/ImpGen/Multicore/SegScan.hs000066400000000000000000000432051475065116200240460ustar00rootroot00000000000000module Futhark.CodeGen.ImpGen.Multicore.SegScan ( compileSegScan, ) where import Control.Monad import Data.List (zip4) import Futhark.CodeGen.ImpCode.Multicore qualified as Imp import Futhark.CodeGen.ImpGen import Futhark.CodeGen.ImpGen.Multicore.Base import Futhark.IR.MCMem import Futhark.Util.IntegralExp (quot, rem) import Prelude hiding (quot, rem) -- Compile a SegScan construct compileSegScan :: Pat LetDecMem -> SegSpace -> [SegBinOp MCMem] -> KernelBody MCMem -> TV Int32 -> MulticoreGen Imp.MCCode compileSegScan pat space reds kbody nsubtasks | [_] <- unSegSpace space = nonsegmentedScan pat space reds kbody nsubtasks | otherwise = segmentedScan pat space reds kbody xParams, yParams :: SegBinOp MCMem -> [LParam MCMem] xParams scan = take (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan)) yParams scan = drop (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan)) lamBody :: SegBinOp MCMem -> Body MCMem lamBody = lambdaBody . segBinOpLambda -- Arrays for storing worker results. 
carryArrays :: String -> TV Int32 -> [SegBinOp MCMem] -> MulticoreGen [[VName]] carryArrays s nsubtasks segops = forM segops $ \(SegBinOp _ lam _ shape) -> forM (lambdaReturnType lam) $ \t -> do let pt = elemType t full_shape = Shape [Var (tvVar nsubtasks)] <> shape <> arrayShape t sAllocArray s pt full_shape DefaultSpace nonsegmentedScan :: Pat LetDecMem -> SegSpace -> [SegBinOp MCMem] -> KernelBody MCMem -> TV Int32 -> MulticoreGen Imp.MCCode nonsegmentedScan pat space scan_ops kbody nsubtasks = do emit $ Imp.DebugPrint "nonsegmented segScan" Nothing collect $ do -- Are we working with nested arrays let dims = map (shapeDims . segBinOpShape) scan_ops -- Are we only working on scalars let scalars = all (all (primType . typeOf . paramDec) . (lambdaParams . segBinOpLambda)) scan_ops && all null dims -- Do we have nested vector operations let vectorize = [] `notElem` dims let param_types = concatMap (map paramType . (lambdaParams . segBinOpLambda)) scan_ops let no_array_param = all primType param_types let (scanStage1, scanStage3) | scalars = (scanStage1Scalar, scanStage3Scalar) | vectorize && no_array_param = (scanStage1Nested, scanStage3Nested) | otherwise = (scanStage1Fallback, scanStage3Fallback) emit $ Imp.DebugPrint "Scan stage 1" Nothing scanStage1 pat space kbody scan_ops let nsubtasks' = tvExp nsubtasks sWhen (nsubtasks' .>. 
1) $ do scan_ops2 <- renameSegBinOp scan_ops emit $ Imp.DebugPrint "Scan stage 2" Nothing carries <- scanStage2 pat nsubtasks space scan_ops2 scan_ops3 <- renameSegBinOp scan_ops emit $ Imp.DebugPrint "Scan stage 3" Nothing scanStage3 pat space scan_ops3 carries -- Different ways to generate code for a scan loop data ScanLoopType = ScanSeq -- Fully sequential | ScanNested -- Nested vectorized map | ScanScalar -- Vectorized scan over scalars -- Given a scan type, return a function to inject into the loop body getScanLoop :: ScanLoopType -> (Imp.TExp Int64 -> MulticoreGen ()) -> MulticoreGen () getScanLoop ScanScalar = generateUniformizeLoop getScanLoop _ = \body -> body 0 -- Given a scan type, return a function to extract a scalar from a vector getExtract :: ScanLoopType -> Imp.TExp Int64 -> MulticoreGen Imp.MCCode -> MulticoreGen () getExtract ScanSeq = \_ body -> body >>= emit getExtract _ = extractVectorLane genBinOpParams :: [SegBinOp MCMem] -> MulticoreGen () genBinOpParams scan_ops = dScope Nothing $ scopeOfLParams $ concatMap (lambdaParams . segBinOpLambda) scan_ops genLocalAccsStage1 :: [SegBinOp MCMem] -> MulticoreGen [[VName]] genLocalAccsStage1 scan_ops = do forM scan_ops $ \scan_op -> do let shape = segBinOpShape scan_op ts = lambdaReturnType $ segBinOpLambda scan_op forM (zip3 (xParams scan_op) (segBinOpNeutral scan_op) ts) $ \(p, ne, t) -> do acc <- -- update accumulator to have type decoration case shapeDims shape of [] -> pure $ paramName p _ -> do let pt = elemType t sAllocArray "local_acc" pt (shape <> arrayShape t) DefaultSpace -- Now neutral-initialise the accumulator. 
sLoopNest (segBinOpShape scan_op) $ \vec_is -> copyDWIMFix acc vec_is ne [] pure acc getNestLoop :: ScanLoopType -> Shape -> ([Imp.TExp Int64] -> MulticoreGen ()) -> MulticoreGen () getNestLoop ScanNested = sLoopNestVectorized getNestLoop _ = sLoopNest applyScanOps :: ScanLoopType -> Pat LetDecMem -> SegSpace -> [SubExp] -> [SegBinOp MCMem] -> [[VName]] -> ImpM MCMem HostEnv Imp.Multicore () applyScanOps typ pat space all_scan_res scan_ops local_accs = do let per_scan_res = segBinOpChunks scan_ops all_scan_res per_scan_pes = segBinOpChunks scan_ops $ patElems pat let (is, _) = unzip $ unSegSpace space -- Potential vector load and then do sequential scan getScanLoop typ $ \j -> forM_ (zip4 per_scan_pes scan_ops per_scan_res local_accs) $ \(pes, scan_op, scan_res, acc) -> getNestLoop typ (segBinOpShape scan_op) $ \vec_is -> do sComment "Read accumulator" $ forM_ (zip (xParams scan_op) acc) $ \(p, acc') -> do copyDWIMFix (paramName p) [] (Var acc') vec_is sComment "Read next values" $ forM_ (zip (yParams scan_op) scan_res) $ \(p, se) -> getExtract typ j $ collect $ copyDWIMFix (paramName p) [] se vec_is -- Scan body sComment "Scan op body" $ compileStms mempty (bodyStms $ lamBody scan_op) $ forM_ (zip3 acc pes $ map resSubExp $ bodyResult $ lamBody scan_op) $ \(acc', pe, se) -> do copyDWIMFix (patElemName pe) (map Imp.le64 is ++ vec_is) se [] copyDWIMFix acc' vec_is se [] -- Generate a loop which performs a potentially vectorized scan on the -- result of a kernel body. 
genScanLoop :: ScanLoopType -> Pat LetDecMem -> SegSpace -> KernelBody MCMem -> [SegBinOp MCMem] -> [[VName]] -> Imp.TExp Int64 -> ImpM MCMem HostEnv Imp.Multicore () genScanLoop typ pat space kbody scan_ops local_accs i = do let (all_scan_res, map_res) = splitAt (segBinOpResults scan_ops) $ kernelBodyResult kbody let (is, ns) = unzip $ unSegSpace space ns' = map pe64 ns zipWithM_ dPrimV_ is $ unflattenIndex ns' i compileStms mempty (kernelBodyStms kbody) $ do let map_arrs = drop (segBinOpResults scan_ops) $ patElems pat sComment "write mapped values results to memory" $ zipWithM_ (compileThreadResult space) map_arrs map_res sComment "Apply scan op" $ applyScanOps typ pat space (map kernelResultSubExp all_scan_res) scan_ops local_accs scanStage1Scalar :: Pat LetDecMem -> SegSpace -> KernelBody MCMem -> [SegBinOp MCMem] -> MulticoreGen () scanStage1Scalar pat space kbody scan_ops = do fbody <- collect $ do dPrim_ (segFlat space) int64 sOp $ Imp.GetTaskId (segFlat space) genBinOpParams scan_ops local_accs <- genLocalAccsStage1 scan_ops inISPC $ generateChunkLoop "SegScan" Vectorized $ genScanLoop ScanScalar pat space kbody scan_ops local_accs free_params <- freeParams fbody emit $ Imp.Op $ Imp.ParLoop "scan_stage_1" fbody free_params scanStage1Nested :: Pat LetDecMem -> SegSpace -> KernelBody MCMem -> [SegBinOp MCMem] -> MulticoreGen () scanStage1Nested pat space kbody scan_ops = do fbody <- collect $ do dPrim_ (segFlat space) int64 sOp $ Imp.GetTaskId (segFlat space) local_accs <- genLocalAccsStage1 scan_ops inISPC $ do genBinOpParams scan_ops generateChunkLoop "SegScan" Scalar $ \i -> do genScanLoop ScanNested pat space kbody scan_ops local_accs i free_params <- freeParams fbody emit $ Imp.Op $ Imp.ParLoop "scan_stage_1" fbody free_params scanStage1Fallback :: Pat LetDecMem -> SegSpace -> KernelBody MCMem -> [SegBinOp MCMem] -> MulticoreGen () scanStage1Fallback pat space kbody scan_ops = do -- Stage 1 : each thread partially scans a chunk of the input -- Writes 
directly to the resulting array fbody <- collect $ do dPrim_ (segFlat space) int64 sOp $ Imp.GetTaskId (segFlat space) genBinOpParams scan_ops local_accs <- genLocalAccsStage1 scan_ops generateChunkLoop "SegScan" Scalar $ genScanLoop ScanSeq pat space kbody scan_ops local_accs free_params <- freeParams fbody emit $ Imp.Op $ Imp.ParLoop "scan_stage_1" fbody free_params scanStage2 :: Pat LetDecMem -> TV Int32 -> SegSpace -> [SegBinOp MCMem] -> MulticoreGen [[VName]] scanStage2 pat nsubtasks space scan_ops = do let (is, ns) = unzip $ unSegSpace space ns_64 = map pe64 ns per_scan_pes = segBinOpChunks scan_ops $ patElems pat nsubtasks' = sExt64 $ tvExp nsubtasks dScope Nothing $ scopeOfLParams $ concatMap (lambdaParams . segBinOpLambda) scan_ops offset <- dPrimV "offset" (0 :: Imp.TExp Int64) let offset' = tvExp offset offset_index <- dPrimV "offset_index" (0 :: Imp.TExp Int64) let offset_index' = tvExp offset_index -- Parameters used to find the chunk sizes -- Perhaps get this information from ``scheduling information`` -- instead of computing it manually here. let iter_pr_subtask = product ns_64 `quot` nsubtasks' remainder = product ns_64 `rem` nsubtasks' carries <- carryArrays "scan_stage_2_carry" nsubtasks scan_ops sComment "carry-in for first chunk is neutral" $ forM_ (zip scan_ops carries) $ \(scan_op, carry) -> sLoopNest (segBinOpShape scan_op) $ \vec_is -> forM_ (zip carry $ segBinOpNeutral scan_op) $ \(carry', ne) -> copyDWIMFix carry' (0 : vec_is) ne [] -- Perform sequential scan over the last element of each chunk sComment "scan carries" $ sFor "i" (nsubtasks' - 1) $ \i -> do offset <-- iter_pr_subtask sWhen (sExt64 i .<. 
remainder) (offset <-- offset' + 1) offset_index <-- offset_index' + offset' zipWithM_ dPrimV_ is $ unflattenIndex ns_64 $ sExt64 offset_index' forM_ (zip3 per_scan_pes scan_ops carries) $ \(pes, scan_op, carry) -> sLoopNest (segBinOpShape scan_op) $ \vec_is -> do sComment "Read carry" $ forM_ (zip (xParams scan_op) carry) $ \(p, carry') -> copyDWIMFix (paramName p) [] (Var carry') (i : vec_is) sComment "Read next values" $ forM_ (zip (yParams scan_op) pes) $ \(p, pe) -> copyDWIMFix (paramName p) [] (Var $ patElemName pe) ((offset_index' - 1) : vec_is) compileStms mempty (bodyStms $ lamBody scan_op) $ forM_ (zip carry $ map resSubExp $ bodyResult $ lamBody scan_op) $ \(carry', se) -> do copyDWIMFix carry' ((i + 1) : vec_is) se [] -- Return the array of carries for each chunk. pure carries scanStage3Scalar :: Pat LetDecMem -> SegSpace -> [SegBinOp MCMem] -> [[VName]] -> MulticoreGen () scanStage3Scalar pat space scan_ops per_scan_carries = do let per_scan_pes = segBinOpChunks scan_ops $ patElems pat (is, ns) = unzip $ unSegSpace space ns' = map pe64 ns body <- collect $ do dPrim_ (segFlat space) int64 sOp $ Imp.GetTaskId $ segFlat space inISPC $ do genBinOpParams scan_ops sComment "load carry-in" $ forM_ (zip per_scan_carries scan_ops) $ \(op_carries, scan_op) -> forM_ (zip (xParams scan_op) op_carries) $ \(p, carries) -> copyDWIMFix (paramName p) [] (Var carries) [le64 (segFlat space)] generateChunkLoop "SegScan" Vectorized $ \i -> do zipWithM_ dPrimV_ is $ unflattenIndex ns' i sComment "load partial result" $ forM_ (zip per_scan_pes scan_ops) $ \(scan_pes, scan_op) -> forM_ (zip (yParams scan_op) scan_pes) $ \(p, pe) -> copyDWIMFix (paramName p) [] (Var (patElemName pe)) (map le64 is) sComment "combine carry with partial result" $ forM_ (zip per_scan_pes scan_ops) $ \(scan_pes, scan_op) -> compileStms mempty (bodyStms $ lamBody scan_op) $ forM_ (zip scan_pes $ map resSubExp $ bodyResult $ lamBody scan_op) $ \(pe, se) -> copyDWIMFix (patElemName pe) (map Imp.le64 
is) se [] free_params <- freeParams body emit $ Imp.Op $ Imp.ParLoop "scan_stage_3" body free_params scanStage3Nested :: Pat LetDecMem -> SegSpace -> [SegBinOp MCMem] -> [[VName]] -> MulticoreGen () scanStage3Nested pat space scan_ops per_scan_carries = do let per_scan_pes = segBinOpChunks scan_ops $ patElems pat (is, ns) = unzip $ unSegSpace space ns' = map pe64 ns body <- collect $ do dPrim_ (segFlat space) int64 sOp $ Imp.GetTaskId (segFlat space) generateChunkLoop "SegScan" Scalar $ \i -> do genBinOpParams scan_ops zipWithM_ dPrimV_ is $ unflattenIndex ns' i forM_ (zip3 per_scan_pes per_scan_carries scan_ops) $ \(scan_pes, op_carries, scan_op) -> do sLoopNest (segBinOpShape scan_op) $ \vec_is -> do sComment "load carry-in" $ forM_ (zip (xParams scan_op) op_carries) $ \(p, carries) -> copyDWIMFix (paramName p) [] (Var carries) (le64 (segFlat space) : vec_is) sComment "load partial result" $ forM_ (zip (yParams scan_op) scan_pes) $ \(p, pe) -> copyDWIMFix (paramName p) [] (Var (patElemName pe)) (map le64 is ++ vec_is) sComment "combine carry with partial result" $ compileStms mempty (bodyStms $ lamBody scan_op) $ forM_ (zip scan_pes $ map resSubExp $ bodyResult $ lamBody scan_op) $ \(pe, se) -> copyDWIMFix (patElemName pe) (map Imp.le64 is ++ vec_is) se [] free_params <- freeParams body emit $ Imp.Op $ Imp.ParLoop "scan_stage_3" body free_params scanStage3Fallback :: Pat LetDecMem -> SegSpace -> [SegBinOp MCMem] -> [[VName]] -> MulticoreGen () scanStage3Fallback pat space scan_ops per_scan_carries = do let per_scan_pes = segBinOpChunks scan_ops $ patElems pat (is, ns) = unzip $ unSegSpace space ns' = map pe64 ns body <- collect $ do dPrim_ (segFlat space) int64 sOp $ Imp.GetTaskId (segFlat space) genBinOpParams scan_ops generateChunkLoop "SegScan" Scalar $ \i -> do zipWithM_ dPrimV_ is $ unflattenIndex ns' i forM_ (zip3 per_scan_pes per_scan_carries scan_ops) $ \(scan_pes, op_carries, scan_op) -> do sLoopNest (segBinOpShape scan_op) $ \vec_is -> do sComment "load 
carry-in" $ forM_ (zip (xParams scan_op) op_carries) $ \(p, carries) -> copyDWIMFix (paramName p) [] (Var carries) (le64 (segFlat space) : vec_is) sComment "load partial result" $ forM_ (zip (yParams scan_op) scan_pes) $ \(p, pe) -> copyDWIMFix (paramName p) [] (Var (patElemName pe)) (map le64 is ++ vec_is) sComment "combine carry with partial result" $ compileStms mempty (bodyStms $ lamBody scan_op) $ forM_ (zip scan_pes $ map resSubExp $ bodyResult $ lamBody scan_op) $ \(pe, se) -> copyDWIMFix (patElemName pe) (map Imp.le64 is ++ vec_is) se [] free_params <- freeParams body emit $ Imp.Op $ Imp.ParLoop "scan_stage_3" body free_params -- Note: This isn't currently used anywhere. -- This implementation for a Segmented scan only -- parallelize over the segments and each segment is -- scanned sequentially. segmentedScan :: Pat LetDecMem -> SegSpace -> [SegBinOp MCMem] -> KernelBody MCMem -> MulticoreGen Imp.MCCode segmentedScan pat space scan_ops kbody = do emit $ Imp.DebugPrint "segmented segScan" Nothing collect $ do body <- compileSegScanBody pat space scan_ops kbody free_params <- freeParams body emit $ Imp.Op $ Imp.ParLoop "seg_scan" body free_params compileSegScanBody :: Pat LetDecMem -> SegSpace -> [SegBinOp MCMem] -> KernelBody MCMem -> MulticoreGen Imp.MCCode compileSegScanBody pat space scan_ops kbody = collect $ do let (is, ns) = unzip $ unSegSpace space ns_64 = map pe64 ns dPrim_ (segFlat space) int64 sOp $ Imp.GetTaskId (segFlat space) let per_scan_pes = segBinOpChunks scan_ops $ patElems pat generateChunkLoop "SegScan" Scalar $ \segment_i -> do forM_ (zip scan_ops per_scan_pes) $ \(scan_op, scan_pes) -> do dScope Nothing $ scopeOfLParams $ lambdaParams $ segBinOpLambda scan_op let (scan_x_params, scan_y_params) = splitAt (length $ segBinOpNeutral scan_op) $ (lambdaParams . 
segBinOpLambda) scan_op forM_ (zip scan_x_params $ segBinOpNeutral scan_op) $ \(p, ne) -> copyDWIMFix (paramName p) [] ne [] let inner_bound = last ns_64 -- Perform a sequential scan over the segment ``segment_i`` sFor "i" inner_bound $ \i -> do zipWithM_ dPrimV_ (init is) $ unflattenIndex (init ns_64) segment_i dPrimV_ (last is) i compileStms mempty (kernelBodyStms kbody) $ do let (scan_res, map_res) = splitAt (length $ segBinOpNeutral scan_op) $ kernelBodyResult kbody sComment "write to-scan values to parameters" $ forM_ (zip scan_y_params scan_res) $ \(p, se) -> copyDWIMFix (paramName p) [] (kernelResultSubExp se) [] sComment "write mapped values results to memory" $ forM_ (zip (drop (length $ segBinOpNeutral scan_op) $ patElems pat) map_res) $ \(pe, se) -> copyDWIMFix (patElemName pe) (map Imp.le64 is) (kernelResultSubExp se) [] sComment "combine with carry and write to memory" $ compileStms mempty (bodyStms $ lambdaBody $ segBinOpLambda scan_op) $ forM_ (zip3 scan_x_params scan_pes $ map resSubExp $ bodyResult $ lambdaBody $ segBinOpLambda scan_op) $ \(p, pe, se) -> do copyDWIMFix (patElemName pe) (map Imp.le64 is) se [] copyDWIMFix (paramName p) [] se [] futhark-0.25.27/src/Futhark/CodeGen/ImpGen/OpenCL.hs000066400000000000000000000010641475065116200216750ustar00rootroot00000000000000-- | Code generation for ImpCode with OpenCL kernels. module Futhark.CodeGen.ImpGen.OpenCL ( compileProg, Warnings, ) where import Data.Bifunctor (second) import Futhark.CodeGen.ImpCode.OpenCL qualified as OpenCL import Futhark.CodeGen.ImpGen.GPU import Futhark.CodeGen.ImpGen.GPU.ToOpenCL import Futhark.IR.GPUMem import Futhark.MonadFreshNames -- | Compile the program to ImpCode with OpenCL kernels. 
compileProg :: (MonadFreshNames m) => Prog GPUMem -> m (Warnings, OpenCL.Program) compileProg prog = second kernelsToOpenCL <$> compileProgOpenCL prog futhark-0.25.27/src/Futhark/CodeGen/ImpGen/Sequential.hs000066400000000000000000000013141475065116200226650ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | Compile Futhark to sequential imperative code. module Futhark.CodeGen.ImpGen.Sequential ( compileProg, ImpGen.Warnings, ) where import Futhark.CodeGen.ImpCode.Sequential qualified as Imp import Futhark.CodeGen.ImpGen qualified as ImpGen import Futhark.IR.SeqMem import Futhark.MonadFreshNames -- | Compile a 'SeqMem' program to sequential imperative code. compileProg :: (MonadFreshNames m) => Prog SeqMem -> m (ImpGen.Warnings, Imp.Program) compileProg = ImpGen.compileProg () ops Imp.DefaultSpace where ops = ImpGen.defaultOperations opCompiler opCompiler dest (Alloc e space) = ImpGen.compileAlloc dest e space opCompiler _ (Inner NoOp) = pure () futhark-0.25.27/src/Futhark/CodeGen/OpenCL/000077500000000000000000000000001475065116200201615ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/CodeGen/OpenCL/Heuristics.hs000066400000000000000000000054041475065116200226420ustar00rootroot00000000000000-- | Some GPU platforms have a SIMD/warp/wavefront-based execution -- model that execute blocks of threads in lockstep, permitting us to -- perform cross-thread synchronisation within each such block without -- the use of barriers. Unfortunately, there seems to be no reliable -- way to query these sizes at runtime. Instead, we use builtin tables -- to figure out which size we should use for a specific platform and -- device. If nothing matches here, the wave size should be set to -- one. -- -- We also use this to select reasonable default block sizes and block -- counts. 
module Futhark.CodeGen.OpenCL.Heuristics
  ( SizeHeuristic (..),
    DeviceType (..),
    WhichSize (..),
    DeviceInfo (..),
    sizeHeuristicsTable,
  )
where

import Futhark.Analysis.PrimExp
import Futhark.Util.Pretty

-- | The type of OpenCL device that this heuristic applies to.
data DeviceType = DeviceCPU | DeviceGPU

-- | The value supplied by a heuristic can depend on some device
-- information.  This will be translated into a call to
-- @clGetDeviceInfo()@.  Make sure to only request info that can be
-- cast to a scalar type.
newtype DeviceInfo = DeviceInfo String

-- Pretty-printed as @device_info(NAME)@, where NAME is the wrapped
-- string.
instance Pretty DeviceInfo where
  pretty (DeviceInfo s) = "device_info" <> parens (pretty s)

-- | A size that can be assigned a default.
data WhichSize
  = LockstepWidth
  | NumBlocks
  | BlockSize
  | TileSize
  | RegTileSize
  | Threshold

-- | A heuristic for setting the default value for something.
data SizeHeuristic = SizeHeuristic
  { -- | Platform name this heuristic applies to; the empty string is
    -- used for the catch-all entries below.
    platformName :: String,
    deviceType :: DeviceType,
    heuristicSize :: WhichSize,
    -- | The default value, possibly depending on 'DeviceInfo' queries.
    heuristicValue :: TPrimExp Int32 DeviceInfo
  }

-- | All of our heuristics.
--
-- NOTE(review): specific platforms are listed before the catch-all
-- @""@ entries, so this presumably relies on the consumer picking the
-- first matching entry (with @""@ matching any platform).  The
-- consumer is not in this module; confirm before reordering entries.
sizeHeuristicsTable :: [SizeHeuristic]
sizeHeuristicsTable =
  [ SizeHeuristic "NVIDIA CUDA" DeviceGPU LockstepWidth 32,
    SizeHeuristic "AMD Accelerated Parallel Processing" DeviceGPU LockstepWidth 32,
    SizeHeuristic "rusticl" DeviceGPU LockstepWidth 32,
    SizeHeuristic "" DeviceGPU LockstepWidth 1,
    -- We calculate the number of blocks to aim for 1024 threads per
    -- compute unit if we also use the default block size.  This seems
    -- to perform well in practice.
    SizeHeuristic "" DeviceGPU NumBlocks $ 4 * max_compute_units,
    SizeHeuristic "" DeviceGPU BlockSize 256,
    SizeHeuristic "" DeviceGPU TileSize 32,
    SizeHeuristic "" DeviceGPU RegTileSize 2,
    SizeHeuristic "" DeviceGPU Threshold $ 32 * 1024,
    --
    SizeHeuristic "" DeviceCPU LockstepWidth 1,
    SizeHeuristic "" DeviceCPU NumBlocks max_compute_units,
    SizeHeuristic "" DeviceCPU BlockSize 32,
    SizeHeuristic "" DeviceCPU TileSize 4,
    SizeHeuristic "" DeviceCPU RegTileSize 1,
    SizeHeuristic "" DeviceCPU Threshold max_compute_units
  ]
  where
    -- A 32-bit device query; the name presumably maps to
    -- CL_DEVICE_MAX_COMPUTE_UNITS (TODO confirm in the code generator).
    max_compute_units =
      TPrimExp $ LeafExp (DeviceInfo "MAX_COMPUTE_UNITS") $ IntType Int32
futhark-0.25.27/src/Futhark/CodeGen/RTS/000077500000000000000000000000001475065116200175115ustar00rootroot00000000000000
futhark-0.25.27/src/Futhark/CodeGen/RTS/C.hs000066400000000000000000000077021475065116200202350ustar00rootroot00000000000000
{-# LANGUAGE TemplateHaskell #-}

-- | Code snippets used by the C backends.
module Futhark.CodeGen.RTS.C
  ( atomicsH,
    contextH,
    contextPrototypesH,
    copyH,
    freeListH,
    eventListH,
    gpuH,
    gpuPrototypesH,
    halfH,
    lockH,
    scalarF16H,
    scalarH,
    schedulerH,
    serverH,
    timingH,
    tuningH,
    utilH,
    valuesH,
    errorsH,
    cacheH,
    uniformH,
    ispcUtilH,
    backendsOpenclH,
    backendsCudaH,
    backendsHipH,
    backendsCH,
    backendsMulticoreH,
  )
where

import Data.FileEmbed
import Data.Text qualified as T

-- We mark everything here NOINLINE so that the dependent modules
-- don't have to be recompiled just because we change the RTS files.
-- Each binding below holds the text of the named runtime-system file,
-- embedded into the compiler binary at compile time via the Template
-- Haskell splice 'embedStringFile'.

-- | @rts/c/atomics.h@
atomicsH :: T.Text
atomicsH = $(embedStringFile "rts/c/atomics.h")
{-# NOINLINE atomicsH #-}

-- | @rts/c/uniform.h@
uniformH :: T.Text
uniformH = $(embedStringFile "rts/c/uniform.h")
{-# NOINLINE uniformH #-}

-- | @rts/c/free_list.h@
freeListH :: T.Text
freeListH = $(embedStringFile "rts/c/free_list.h")
{-# NOINLINE freeListH #-}

-- | @rts/c/event_list.h@
eventListH :: T.Text
eventListH = $(embedStringFile "rts/c/event_list.h")
{-# NOINLINE eventListH #-}

-- | @rts/c/gpu.h@
gpuH :: T.Text
gpuH = $(embedStringFile "rts/c/gpu.h")
{-# NOINLINE gpuH #-}

-- | @rts/c/gpu_prototypes.h@
gpuPrototypesH :: T.Text
gpuPrototypesH = $(embedStringFile "rts/c/gpu_prototypes.h")
{-# NOINLINE gpuPrototypesH #-}

-- | @rts/c/half.h@
halfH :: T.Text
halfH = $(embedStringFile "rts/c/half.h")
{-# NOINLINE halfH #-}

-- | @rts/c/lock.h@
lockH :: T.Text
lockH = $(embedStringFile "rts/c/lock.h")
{-# NOINLINE lockH #-}

-- | @rts/c/scalar_f16.h@
scalarF16H :: T.Text
scalarF16H = $(embedStringFile "rts/c/scalar_f16.h")
{-# NOINLINE scalarF16H #-}

-- | @rts/c/scalar.h@
scalarH :: T.Text
scalarH = $(embedStringFile "rts/c/scalar.h")
{-# NOINLINE scalarH #-}

-- | @rts/c/scheduler.h@
schedulerH :: T.Text
schedulerH = $(embedStringFile "rts/c/scheduler.h")
{-# NOINLINE schedulerH #-}

-- | @rts/c/server.h@
serverH :: T.Text
serverH = $(embedStringFile "rts/c/server.h")
{-# NOINLINE serverH #-}

-- | @rts/c/timing.h@
timingH :: T.Text
timingH = $(embedStringFile "rts/c/timing.h")
{-# NOINLINE timingH #-}

-- | @rts/c/tuning.h@
tuningH :: T.Text
tuningH = $(embedStringFile "rts/c/tuning.h")
{-# NOINLINE tuningH #-}

-- | @rts/c/util.h@
utilH :: T.Text
utilH = $(embedStringFile "rts/c/util.h")
{-# NOINLINE utilH #-}

-- | @rts/c/values.h@
valuesH :: T.Text
valuesH = $(embedStringFile "rts/c/values.h")
{-# NOINLINE valuesH #-}

-- | @rts/c/errors.h@
errorsH :: T.Text
errorsH = $(embedStringFile "rts/c/errors.h")
{-# NOINLINE errorsH #-}

-- | @rts/c/ispc_util.h@
ispcUtilH :: T.Text
ispcUtilH = $(embedStringFile "rts/c/ispc_util.h")
{-# NOINLINE ispcUtilH #-}

-- | @rts/c/cache.h@
cacheH :: T.Text
cacheH = $(embedStringFile "rts/c/cache.h")
{-# NOINLINE cacheH #-}

-- | @rts/c/context.h@
contextH :: T.Text
contextH = $(embedStringFile "rts/c/context.h")
{-# NOINLINE contextH #-}

-- | @rts/c/context_prototypes.h@
contextPrototypesH :: T.Text
contextPrototypesH = $(embedStringFile "rts/c/context_prototypes.h")
{-# NOINLINE contextPrototypesH #-}

-- | @rts/c/backends/opencl.h@
backendsOpenclH :: T.Text
backendsOpenclH = $(embedStringFile "rts/c/backends/opencl.h")
{-# NOINLINE backendsOpenclH #-}

-- | @rts/c/backends/cuda.h@
backendsCudaH :: T.Text
backendsCudaH = $(embedStringFile "rts/c/backends/cuda.h")
{-# NOINLINE backendsCudaH #-}

-- | @rts/c/backends/hip.h@
backendsHipH :: T.Text
backendsHipH = $(embedStringFile "rts/c/backends/hip.h")
{-# NOINLINE backendsHipH #-}

-- | @rts/c/backends/c.h@
backendsCH :: T.Text
backendsCH = $(embedStringFile "rts/c/backends/c.h")
{-# NOINLINE backendsCH #-}

-- | @rts/c/backends/multicore.h@
backendsMulticoreH :: T.Text
backendsMulticoreH = $(embedStringFile "rts/c/backends/multicore.h")
{-# NOINLINE backendsMulticoreH #-}

-- | @rts/c/copy.h@
copyH :: T.Text
copyH = $(embedStringFile "rts/c/copy.h")
{-# NOINLINE copyH #-}
futhark-0.25.27/src/Futhark/CodeGen/RTS/CUDA.hs000066400000000000000000000004701475065116200205620ustar00rootroot00000000000000
{-# LANGUAGE TemplateHaskell #-}

-- | Code snippets used by the CUDA backend.
module Futhark.CodeGen.RTS.CUDA (preludeCU) where

import Data.FileEmbed
import Data.Text qualified as T

-- | @rts/cuda/prelude.cu@
preludeCU :: T.Text
preludeCU = $(embedStringFile "rts/cuda/prelude.cu")
-- NOINLINE, so that dependent modules need not be recompiled when the
-- embedded file changes.
{-# NOINLINE preludeCU #-}
futhark-0.25.27/src/Futhark/CodeGen/RTS/JavaScript.hs000066400000000000000000000011161475065116200221120ustar00rootroot00000000000000
{-# LANGUAGE TemplateHaskell #-}

-- | Code snippets used by the JS backends.
module Futhark.CodeGen.RTS.JavaScript
  ( serverJs,
    valuesJs,
    wrapperclassesJs,
  )
where

import Data.FileEmbed
import Data.Text qualified as T

-- Each binding holds the text of the named file, embedded at compile
-- time by 'embedStringFile'.  As in the other RTS modules, everything
-- is marked NOINLINE so that dependent modules don't have to be
-- recompiled just because we change the RTS files.

-- | @rts/javascript/server.js@
serverJs :: T.Text
serverJs = $(embedStringFile "rts/javascript/server.js")
{-# NOINLINE serverJs #-}

-- | @rts/javascript/values.js@
valuesJs :: T.Text
valuesJs = $(embedStringFile "rts/javascript/values.js")
{-# NOINLINE valuesJs #-}

-- | @rts/javascript/wrapperclasses.js@
wrapperclassesJs :: T.Text
wrapperclassesJs = $(embedStringFile "rts/javascript/wrapperclasses.js")
{-# NOINLINE wrapperclassesJs #-}
futhark-0.25.27/src/Futhark/CodeGen/RTS/OpenCL.hs000066400000000000000000000011611475065116200211640ustar00rootroot00000000000000
{-# LANGUAGE TemplateHaskell #-}

-- | Code snippets used by the OpenCL and CUDA backends.
module Futhark.CodeGen.RTS.OpenCL
  ( transposeCL,
    preludeCL,
    copyCL,
  )
where

import Data.FileEmbed
import Data.Text qualified as T

-- | @rts/opencl/transpose.cl@
transposeCL :: T.Text
transposeCL = $(embedStringFile "rts/opencl/transpose.cl")
{-# NOINLINE transposeCL #-}

-- | @rts/opencl/prelude.cl@
preludeCL :: T.Text
preludeCL = $(embedStringFile "rts/opencl/prelude.cl")
{-# NOINLINE preludeCL #-}

-- | @rts/opencl/copy.cl@
copyCL :: T.Text
copyCL = $(embedStringFile "rts/opencl/copy.cl")
{-# NOINLINE copyCL #-}
futhark-0.25.27/src/Futhark/CodeGen/RTS/Python.hs000066400000000000000000000017251475065116200213330ustar00rootroot00000000000000
{-# LANGUAGE TemplateHaskell #-}

-- | Code snippets used by the Python backends.
module Futhark.CodeGen.RTS.Python
  ( memoryPy,
    openclPy,
    panicPy,
    scalarPy,
    serverPy,
    tuningPy,
    valuesPy,
  )
where

import Data.FileEmbed
import Data.Text qualified as T

-- Each binding holds the text of the named file, embedded at compile
-- time by 'embedStringFile'.  As in the other RTS modules, everything
-- is marked NOINLINE so that dependent modules don't have to be
-- recompiled just because we change the RTS files.

-- | @rts/python/memory.py@
memoryPy :: T.Text
memoryPy = $(embedStringFile "rts/python/memory.py")
{-# NOINLINE memoryPy #-}

-- | @rts/python/opencl.py@
openclPy :: T.Text
openclPy = $(embedStringFile "rts/python/opencl.py")
{-# NOINLINE openclPy #-}

-- | @rts/python/panic.py@
panicPy :: T.Text
panicPy = $(embedStringFile "rts/python/panic.py")
{-# NOINLINE panicPy #-}

-- | @rts/python/scalar.py@
scalarPy :: T.Text
scalarPy = $(embedStringFile "rts/python/scalar.py")
{-# NOINLINE scalarPy #-}

-- | @rts/python/server.py@
serverPy :: T.Text
serverPy = $(embedStringFile "rts/python/server.py")
{-# NOINLINE serverPy #-}

-- | @rts/python/tuning.py@
tuningPy :: T.Text
tuningPy = $(embedStringFile "rts/python/tuning.py")
{-# NOINLINE tuningPy #-}

-- | @rts/python/values.py@
valuesPy :: T.Text
valuesPy = $(embedStringFile "rts/python/values.py")
{-# NOINLINE valuesPy #-}
futhark-0.25.27/src/Futhark/Compiler.hs000066400000000000000000000165241475065116200176530ustar00rootroot00000000000000
{-# LANGUAGE Strict #-}

-- | High-level API for invoking the Futhark compiler.
module Futhark.Compiler
  ( runPipelineOnProgram,
    runCompilerOnProgram,
    dumpError,
    handleWarnings,
    prettyProgErrors,
    module Futhark.Compiler.Program,
    module Futhark.Compiler.Config,
    readProgramFile,
    readProgramFiles,
    readProgramOrDie,
    readUntypedProgram,
    readUntypedProgramOrDie,
  )
where

import Control.Monad
import Control.Monad.Except (MonadError)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Bifunctor (first)
import Data.List (sortOn)
import Data.List.NonEmpty qualified as NE
import Data.Loc (Loc (..), posCoff, posFile)
import Data.Text.IO qualified as T
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Compiler.Config
import Futhark.Compiler.Program
import Futhark.IR
import Futhark.IR.SOACS qualified as I
import Futhark.IR.TypeCheck qualified as I
import Futhark.Internalise
import Futhark.MonadFreshNames
import Futhark.Pipeline
import Futhark.Util.Log
import Futhark.Util.Pretty
import Language.Futhark qualified as E
import Language.Futhark.Semantic (includeToString)
import Language.Futhark.Warnings
import System.Exit (ExitCode (..), exitWith)
import System.IO

-- | Print a compiler error to stderr.  The 'FutharkConfig' controls
-- to which degree auxiliary information (e.g. the failing program) is
-- also printed: for internal errors, when the verbosity is above
-- 'NotVerbose', the auxiliary information is written to the file
-- given in 'futharkVerbose', or to stderr if no file was given.
dumpError :: FutharkConfig -> CompilerError -> IO ()
dumpError config err =
  case err of
    ExternalError s -> do
      hPutDoc stderr s
      T.hPutStrLn stderr ""
      T.hPutStrLn stderr "If you find this error message confusing, uninformative, or wrong, please open an issue:"
      T.hPutStrLn stderr " https://github.com/diku-dk/futhark/issues"
    InternalError s info CompilerBug -> do
      T.hPutStrLn stderr "Internal compiler error. Please report this:"
      T.hPutStrLn stderr " https://github.com/diku-dk/futhark/issues"
      report s info
    InternalError s info CompilerLimitation -> do
      T.hPutStrLn stderr "Known compiler limitation encountered. Sorry."
      T.hPutStrLn stderr "Revise your program or try a different Futhark compiler."
      report s info
  where
    -- Always print the message; only dump the (potentially large)
    -- auxiliary info when running verbosely.
    report s info = do
      T.hPutStrLn stderr s
      when (fst (futharkVerbose config) > NotVerbose) $
        maybe (T.hPutStr stderr) T.writeFile (snd (futharkVerbose config)) $
          info <> "\n"

-- | Read a program from the given 'FilePath', run the given
-- 'Pipeline', and finish up with the given 'Action'.  On failure the
-- error is reported with 'dumpError' and the process exits with
-- status 2.
runCompilerOnProgram ::
  FutharkConfig ->
  Pipeline I.SOACS rep ->
  Action rep ->
  FilePath ->
  IO ()
runCompilerOnProgram config pipeline action file = do
  res <- runFutharkM compile $ fst $ futharkVerbose config
  case res of
    Left err -> do
      dumpError config err
      exitWith $ ExitFailure 2
    Right () -> pure ()
  where
    verbose = fst (futharkVerbose config) > NotVerbose

    compile = do
      prog <- runPipelineOnProgram config pipeline file
      when verbose $
        logMsg $
          "Running action " ++ actionName action
      actionProcedure action prog
      when verbose $
        logMsg ("Done." :: String)

-- | Read a program from the given 'FilePath', run the given
-- 'Pipeline', and return it.
runPipelineOnProgram ::
  FutharkConfig ->
  Pipeline I.SOACS torep ->
  FilePath ->
  FutharkM (Prog torep)
runPipelineOnProgram config pipeline file = do
  when (pipelineVerbose pipeline_config) $
    logMsg ("Reading and type-checking source program" :: String)
  (prog_imports, namesrc) <-
    handleWarnings config $
      (\(a, b, c) -> (a, (b, c)))
        <$> readProgramFile (futharkEntryPoints config) file

  putNameSource namesrc
  int_prog <- internaliseProg config prog_imports
  when (pipelineVerbose pipeline_config) $
    logMsg ("Type-checking internalised program" :: String)
  typeCheckInternalProgram int_prog
  runPipeline pipeline pipeline_config int_prog
  where
    pipeline_config =
      PipelineConfig
        { pipelineVerbose = fst (futharkVerbose config) > NotVerbose,
          pipelineValidate = futharkTypeCheck config
        }

-- Type-check the result of internalisation; a failure here is a
-- compiler bug, so it is reported as an internal error together with
-- the (alias-annotated) program.
typeCheckInternalProgram :: I.Prog I.SOACS -> FutharkM ()
typeCheckInternalProgram prog =
  case I.checkProg prog' of
    Left err -> internalErrorS ("After internalisation:\n" ++ show err) (pretty prog')
    Right () -> pure ()
  where
    prog' = Alias.aliasAnalysis prog

-- | Prettyprint program errors as suitable for showing on a text console.
-- Errors are sorted by file and character offset; errors without a
-- location sort first.
prettyProgErrors :: NE.NonEmpty ProgError -> Doc AnsiStyle
prettyProgErrors = stack . punctuate line . map onError . sortOn (rep . locOf) . NE.toList
  where
    rep NoLoc = ("", 0)
    rep (Loc p _) = (posFile p, posCoff p)
    onError (ProgError NoLoc msg) =
      unAnnotate msg
    onError (ProgError loc msg) =
      annotate (color Red) ("Error at " <> pretty (locText (srclocOf loc))) <> ":" </> unAnnotate msg
    onError (ProgWarning NoLoc msg) =
      unAnnotate msg
    onError (ProgWarning loc msg) =
      annotate (color Yellow) $
        "Warning at " <> pretty (locText (srclocOf loc)) <> ":" </> unAnnotate msg

-- | Throw an exception formatted with 'prettyProgErrors' if there's
-- an error.
throwOnProgError ::
  (MonadError CompilerError m) =>
  Either (NE.NonEmpty ProgError) a ->
  m a
throwOnProgError =
  either (externalError . prettyProgErrors) pure

-- | Read and type-check a Futhark program, comprising a single file,
-- including all imports.
readProgramFile ::
  (MonadError CompilerError m, MonadIO m) =>
  [I.Name] ->
  FilePath ->
  m (Warnings, Imports, VNameSource)
readProgramFile extra_eps =
  readProgramFiles extra_eps . pure

-- | Read and type-check a Futhark library, comprising multiple files,
-- including all imports.
readProgramFiles ::
  (MonadError CompilerError m, MonadIO m) =>
  [I.Name] ->
  [FilePath] ->
  m (Warnings, Imports, VNameSource)
readProgramFiles extra_eps =
  throwOnProgError <=< liftIO . readLibrary extra_eps

-- | Read and parse (but do not type-check) a Futhark program,
-- including all imports.
readUntypedProgram ::
  (MonadError CompilerError m, MonadIO m) =>
  FilePath ->
  m [(String, E.UncheckedProg)]
readUntypedProgram =
  fmap (map (first includeToString)) . throwOnProgError
    <=< liftIO . readUntypedLibrary . pure

-- Run a 'FutharkM' action non-verbosely; on error, report it with the
-- default configuration and exit the process with status 2.
orDie :: (MonadIO m) => FutharkM a -> m a
orDie m = liftIO $ do
  res <- runFutharkM m NotVerbose
  case res of
    Left err -> do
      dumpError newFutharkConfig err
      exitWith $ ExitFailure 2
    Right res' -> pure res'

-- | Not verbose, and terminates process on error.
readProgramOrDie :: (MonadIO m) => FilePath -> m (Warnings, Imports, VNameSource)
readProgramOrDie file = orDie $ readProgramFile mempty file

-- | Not verbose, and terminates process on error.
readUntypedProgramOrDie :: (MonadIO m) => FilePath -> m [(String, E.UncheckedProg)]
readUntypedProgramOrDie file = orDie $ readUntypedProgram file

-- | Run an operation that produces warnings, and handle them
-- appropriately, yielding the non-warning return value.  "Proper
-- handling" means e.g. to print them to the screen, as directed by
-- the compiler configuration.  Note that '--Werror' only takes effect
-- when warnings are enabled and there is at least one warning.
handleWarnings :: FutharkConfig -> FutharkM (Warnings, a) -> FutharkM a
handleWarnings config m = do
  (ws, a) <- m

  when (futharkWarn config && anyWarnings ws) $ do
    liftIO $ hPutDoc stderr $ prettyWarnings ws
    when (futharkWerror config) $
      externalErrorS "Treating above warnings as errors due to --Werror."

  pure a
futhark-0.25.27/src/Futhark/Compiler/000077500000000000000000000000001475065116200173075ustar00rootroot00000000000000
futhark-0.25.27/src/Futhark/Compiler/CLI.hs000066400000000000000000000130141475065116200202510ustar00rootroot00000000000000
-- | Convenient common interface for command line Futhark compilers.
-- Using this module ensures that all compilers take the same options.
-- A small amount of flexibility is provided for backend-specific
-- options.
module Futhark.Compiler.CLI
  ( compilerMain,
    CompilerOption,
    CompilerMode (..),
    module Futhark.Pipeline,
    module Futhark.Compiler,
  )
where

import Control.Monad
import Data.Maybe
import Futhark.Compiler
import Futhark.IR (Name, Prog, nameFromString)
import Futhark.IR.SOACS (SOACS)
import Futhark.Pipeline
import Futhark.Util.Options
import System.FilePath

-- | Run a parameterised Futhark compiler, where @cfg@ is a user-given
-- configuration type.  Call this from @main@.
--
-- Exactly one non-option argument (the source program) is accepted;
-- for any other number, 'inspectNonOptions' yields 'Nothing', which
-- 'mainWithOptions' presumably reports as a usage error.
compilerMain ::
  -- | Initial configuration.
  cfg ->
  -- | Options that affect the configuration.
  [CompilerOption cfg] ->
  -- | The short action name (e.g. "compile to C").
  String ->
  -- | The longer action description.
  String ->
  -- | The pipeline to use.
  Pipeline SOACS rep ->
  -- | The action to take on the result of the pipeline.  The
  -- 'FilePath' is the output path (see 'outputFilePath').
  ( FutharkConfig ->
    cfg ->
    CompilerMode ->
    FilePath ->
    Prog rep ->
    FutharkM ()
  ) ->
  -- | Program name.
  String ->
  -- | Command line arguments.
  [String] ->
  IO ()
compilerMain cfg cfg_opts name desc pipeline doIt =
  mainWithOptions
    (newCompilerConfig cfg)
    (commandLineOptions ++ map wrapOption cfg_opts)
    -- The usage line must mention the single required positional
    -- argument demanded by 'inspectNonOptions'.
    "options... <program.fut>"
    inspectNonOptions
  where
    inspectNonOptions [file] config = Just $ compile config file
    inspectNonOptions _ _ = Nothing

    compile config filepath =
      runCompilerOnProgram
        (futharkConfig config)
        pipeline
        (action config filepath)
        filepath

    action config filepath =
      Action
        { actionName = name,
          actionDescription = desc,
          actionProcedure =
            doIt
              (futharkConfig config)
              (compilerConfig config)
              (compilerMode config)
              (outputFilePath filepath config)
        }

-- | An option that modifies the configuration of type @cfg@.
type CompilerOption cfg = OptDescr (Either (IO ()) (cfg -> cfg))

-- An option that modifies the whole 'CompilerConfig', i.e. both the
-- backend-independent fields and the embedded backend configuration.
type CoreCompilerOption cfg =
  OptDescr
    ( Either
        (IO ())
        (CompilerConfig cfg -> CompilerConfig cfg)
    )

-- The backend-independent options accepted by every compiler built
-- with 'compilerMain'.  Backend-specific options are appended after
-- these (see 'compilerMain').
commandLineOptions :: [CoreCompilerOption cfg]
commandLineOptions =
  [ Option
      "o"
      []
      ( ReqArg
          (\filename -> Right $ \config -> config {compilerOutput = Just filename})
          "FILE"
      )
      "Name of the compiled binary.",
    Option
      "v"
      ["verbose"]
      (OptArg (Right . incVerbosity) "FILE")
      "Print verbose output on standard error; wrong program to FILE.",
    Option
      []
      ["library"]
      (NoArg $ Right $ \config -> config {compilerMode = ToLibrary})
      "Generate a library instead of an executable.",
    Option
      []
      ["executable"]
      (NoArg $ Right $ \config -> config {compilerMode = ToExecutable})
      "Generate an executable instead of a library (set by default).",
    Option
      []
      ["server"]
      (NoArg $ Right $ \config -> config {compilerMode = ToServer})
      "Generate a server executable instead of a library.",
    Option
      "w"
      []
      (NoArg $ Right $ \config -> config {compilerWarn = False})
      "Disable all warnings.",
    Option
      []
      ["Werror"]
      (NoArg $ Right $ \config -> config {compilerWerror = True})
      "Treat warnings as errors.",
    Option
      []
      ["safe"]
      (NoArg $ Right $ \config -> config {compilerSafe = True})
      "Ignore 'unsafe' in code.",
    Option
      []
      ["entry-point"]
      ( ReqArg
          ( \arg -> Right $ \config ->
              config
                { compilerEntryPoints = nameFromString arg : compilerEntryPoints config
                }
          )
          "NAME"
      )
      "Treat this function as an additional entry point."
  ]

-- Lift a backend-specific option so that it updates the
-- 'compilerConfig' field of the full 'CompilerConfig'.  A 'Left' (an
-- IO action) is passed through unchanged by the 'Either' monad.
wrapOption :: CompilerOption cfg -> CoreCompilerOption cfg
wrapOption = fmap wrap
  where
    wrap f = do
      g <- f
      pure $ \cfg -> cfg {compilerConfig = g (compilerConfig cfg)}

-- Bump the verbosity by one level (saturating at 'VeryVerbose').  A
-- given file replaces any previously set one; if no file is given,
-- the previous one is kept.
incVerbosity :: Maybe FilePath -> CompilerConfig cfg -> CompilerConfig cfg
incVerbosity file cfg =
  cfg {compilerVerbose = (v, file `mplus` snd (compilerVerbose cfg))}
  where
    v = case fst $ compilerVerbose cfg of
      NotVerbose -> Verbose
      Verbose -> VeryVerbose
      VeryVerbose -> VeryVerbose

-- Full command-line configuration: the backend-independent settings
-- plus the backend-specific configuration @cfg@.
data CompilerConfig cfg = CompilerConfig
  { -- Set by @-o@.
    compilerOutput :: Maybe FilePath,
    -- Set by (repeated) @-v@.
    compilerVerbose :: (Verbosity, Maybe FilePath),
    -- Set by @--library@, @--executable@, @--server@.
    compilerMode :: CompilerMode,
    -- Set by @--Werror@.
    compilerWerror :: Bool,
    -- Set by @--safe@.
    compilerSafe :: Bool,
    -- Cleared by @-w@.
    compilerWarn :: Bool,
    compilerConfig :: cfg,
    -- Accumulated from @--entry-point@.
    compilerEntryPoints :: [Name]
  }

-- | The default command-line configuration, wrapping the given
-- backend-specific configuration: not verbose, executable mode,
-- warnings enabled, and no extra entry points.
newCompilerConfig :: cfg -> CompilerConfig cfg
newCompilerConfig x =
  CompilerConfig
    { compilerOutput = Nothing,
      compilerVerbose = (NotVerbose, Nothing),
      compilerMode = ToExecutable,
      compilerWerror = False,
      compilerSafe = False,
      compilerWarn = True,
      compilerConfig = x,
      compilerEntryPoints = mempty
    }

-- The output path: the @-o@ argument if given, otherwise the source
-- file with its extension removed.
outputFilePath :: FilePath -> CompilerConfig cfg -> FilePath
outputFilePath srcfile =
  fromMaybe (srcfile `replaceExtension` "") . compilerOutput

-- Extract the backend-independent 'FutharkConfig'.  Fields not set
-- here (e.g. 'futharkTypeCheck') keep their 'newFutharkConfig' values.
futharkConfig :: CompilerConfig cfg -> FutharkConfig
futharkConfig config =
  newFutharkConfig
    { futharkVerbose = compilerVerbose config,
      futharkWerror = compilerWerror config,
      futharkSafe = compilerSafe config,
      futharkWarn = compilerWarn config,
      futharkEntryPoints = compilerEntryPoints config
    }
futhark-0.25.27/src/Futhark/Compiler/Config.hs000066400000000000000000000031521475065116200210510ustar00rootroot00000000000000
-- | Configuration of compiler behaviour that is universal to all backends.
module Futhark.Compiler.Config
  ( FutharkConfig (..),
    newFutharkConfig,
    Verbosity (..),
    CompilerMode (..),
  )
where

import Futhark.IR.Syntax.Core (Name)

-- | Are we compiling a library, an executable, or a server executable?
data CompilerMode
  = -- | Generate a library (@--library@).
    ToLibrary
  | -- | Generate an executable (@--executable@; the default).
    ToExecutable
  | -- | Generate a server executable (@--server@).
    ToServer
  deriving (Eq, Ord, Show)

-- | How much information to print to stderr while the compiler is running.
-- Constructors are listed in increasing order of verbosity, and the
-- derived 'Ord' instance is relied upon by comparisons such as
-- @> NotVerbose@, so do not reorder them.
data Verbosity
  = -- | Silence is golden.
    NotVerbose
  | -- | Print messages about which pass is running.
    Verbose
  | -- | Also print logs from individual passes.
    VeryVerbose
  deriving (Eq, Ord)

-- | The compiler configuration.  This only contains options related
-- to core compiler functionality, such as reading the initial program
-- and running passes.  Options related to code generation are handled
-- elsewhere.
data FutharkConfig = FutharkConfig
  { -- | The verbosity level, and an optional file to which auxiliary
    -- output (such as the program attached to an internal compiler
    -- error) is written instead of stderr.
    futharkVerbose :: (Verbosity, Maybe FilePath),
    -- | Warn if True.
    futharkWarn :: Bool,
    -- | If true, error on any (reported) warnings.
    futharkWerror :: Bool,
    -- | If True, ignore @unsafe@.
    futharkSafe :: Bool,
    -- | Additional functions that should be exposed as entry points.
    futharkEntryPoints :: [Name],
    -- | If false, disable type-checking between passes.
    futharkTypeCheck :: Bool
  }

-- | The default compiler configuration: not verbose, warnings
-- enabled but not fatal, and type-checking enabled.
newFutharkConfig :: FutharkConfig
newFutharkConfig =
  FutharkConfig
    { futharkVerbose = (NotVerbose, Nothing),
      futharkWarn = True,
      futharkWerror = False,
      futharkSafe = False,
      futharkEntryPoints = [],
      futharkTypeCheck = True
    }
futhark-0.25.27/src/Futhark/Compiler/Program.hs000066400000000000000000000401761475065116200212620ustar00rootroot00000000000000
{-# LANGUAGE LambdaCase #-}

-- | Low-level compilation parts.  Look at "Futhark.Compiler" for a
-- more high-level API.
module Futhark.Compiler.Program
  ( readLibrary,
    readUntypedLibrary,
    Imports,
    FileModule (..),
    E.Warnings,
    prettyWarnings,
    ProgError (..),
    LoadedProg (lpNameSource),
    noLoadedProg,
    lpImports,
    lpWarnings,
    lpFilePaths,
    reloadProg,
    extendProg,
    VFS,
  )
where

import Control.Concurrent (forkIO)
import Control.Concurrent.MVar
  ( MVar,
    modifyMVar,
    newEmptyMVar,
    newMVar,
    putMVar,
    readMVar,
  )
import Control.Monad
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.State (execStateT, gets, modify)
import Data.Bifunctor (first)
import Data.List (intercalate, sort)
import Data.List.NonEmpty qualified as NE
import Data.Loc (Loc (..), Located, locOf)
import Data.Map qualified as M
import Data.Maybe (mapMaybe)
import Data.Text qualified as T
import Data.Text.IO qualified as T
import Data.Time.Clock (UTCTime, getCurrentTime)
import Futhark.FreshNames
import Futhark.Util (interactWithFileSafely, nubOrd, startupTime)
import Futhark.Util.Pretty (Doc, align, pretty)
import Language.Futhark qualified as E
import Language.Futhark.Parser (SyntaxError (..), parseFuthark)
import Language.Futhark.Prelude
import Language.Futhark.Prop (isBuiltin)
import Language.Futhark.Semantic
import Language.Futhark.TypeChecker qualified as E
import Language.Futhark.Warnings
import System.Directory (getModificationTime)
import System.FilePath (normalise, takeExtension)
import System.FilePath.Posix qualified as Posix

-- A file together with its path, import name, and modification time;
-- @fm@ is the file payload (raw text, parsed, or type-checked).
data LoadedFile fm = LoadedFile
  { lfPath :: FilePath,
    lfImportName :: ImportName,
    lfMod :: fm,
    -- | Modification time of the underlying file.
    lfModTime :: UTCTime
  }
  deriving (Eq, Ord, Show)

-- | Note that the location may be 'NoLoc'.  This essentially only
-- happens when the problem is that a root file cannot be found.
data ProgError
  = ProgError Loc (Doc ())
  | -- | Not actually an error, but we want them reported
    -- with errors.
    ProgWarning Loc (Doc ())

type WithErrors = Either (NE.NonEmpty ProgError)

instance Located ProgError where
  locOf (ProgError l _) = l
  locOf (ProgWarning l _) = l

-- | A mapping from absolute pathnames to file contents, representing a
-- virtual file system.  Before loading a file from the file system,
-- this mapping is consulted.  If the desired pathname has an entry
-- here, the corresponding text is used instead of loading the file
-- from disk.
type VFS = M.Map FilePath T.Text

-- The result of reading and parsing one file, along with the MVars of
-- the (asynchronously loaded) files it imports.
newtype UncheckedImport = UncheckedImport
  { unChecked ::
      WithErrors (LoadedFile E.UncheckedProg, [((ImportName, Loc), MVar UncheckedImport)])
  }

-- | If mapped to Nothing, treat it as present.  This is used when
-- reloading programs.
type ReaderState = MVar (M.Map ImportName (Maybe (MVar UncheckedImport)))

newState :: [ImportName] -> IO ReaderState
newState known = newMVar $ M.fromList $ map (,Nothing) known

-- Walk the import graph depth-first, producing files in dependency
-- order (imports before importers) and reporting import cycles.
orderedImports ::
  [((ImportName, Loc), MVar UncheckedImport)] ->
  IO [(ImportName, WithErrors (LoadedFile E.UncheckedProg))]
orderedImports = fmap reverse . flip execStateT [] . mapM_ (spelunk [])
  where
    spelunk steps ((include, loc), mvar)
      | include `elem` steps = do
          let problem =
                ProgError loc . pretty $
                  "Import cycle: "
                    <> intercalate
                      " -> "
                      (map includeToString $ reverse $ include : steps)
          modify ((include, Left (NE.singleton problem)) :)
      | otherwise = do
          prev <- gets $ lookup include
          case prev of
            Just _ -> pure ()
            Nothing -> do
              res <- unChecked <$> liftIO (readMVar mvar)
              case res of
                Left errors ->
                  modify ((include, Left errors) :)
                Right (file, more_imports) -> do
                  mapM_ (spelunk (include : steps)) more_imports
                  modify ((include, Right file) :)

-- Succeed only if every file loaded; otherwise collect all errors.
errorsToTop ::
  [(ImportName, WithErrors (LoadedFile E.UncheckedProg))] ->
  WithErrors [(ImportName, LoadedFile E.UncheckedProg)]
errorsToTop [] = Right []
errorsToTop ((_, Left x) : rest) =
  either (Left . (x <>)) (const (Left x)) (errorsToTop rest)
errorsToTop ((name, Right x) : rest) =
  fmap ((name, x) :) (errorsToTop rest)

-- Run the given action in a new thread, returning an MVar that will
-- be filled with its result.
newImportMVar :: IO UncheckedImport -> IO (Maybe (MVar UncheckedImport))
newImportMVar m = do
  mvar <- newEmptyMVar
  void $ forkIO $ putMVar mvar =<< m
  pure $ Just mvar

-- | Read the contents and modification time of a file.  The VFS is
-- consulted first; a file found there is given the current time as
-- its modification time.  Otherwise the file is read from disk.
contentsAndModTime :: FilePath -> VFS -> IO (Maybe (Either String (T.Text, UTCTime)))
contentsAndModTime filepath vfs = do
  case M.lookup filepath vfs of
    Nothing -> interactWithFileSafely $ (,) <$> T.readFile filepath <*> getModificationTime filepath
    Just file_contents -> do
      now <- getCurrentTime
      pure $ Just $ Right (file_contents, now)

readImportFile :: ImportName -> Loc -> VFS -> IO (Either ProgError (LoadedFile T.Text))
readImportFile include loc vfs = do
  -- First we try to find a file of the given name in the search path,
  -- then we look at the builtin library if we have to.  For the
  -- builtins, we don't use the search path.
  r <- contentsAndModTime filepath vfs
  case (r, lookup prelude_str prelude) of
    (Just (Right (s, mod_time)), _) ->
      pure $ Right $ loaded filepath s mod_time
    (Just (Left e), _) ->
      pure $ Left $ ProgError loc $ pretty e
    (Nothing, Just s) ->
      pure $ Right $ loaded prelude_str s startupTime
    (Nothing, Nothing) ->
      pure $ Left $ ProgError loc $ pretty not_found
  where
    filepath = includeToFilePath include

    prelude_str = "/" Posix.</> includeToString include Posix.<.> "fut"

    loaded path s mod_time =
      LoadedFile
        { lfImportName = include,
          lfPath = path,
          lfMod = s,
          lfModTime = mod_time
        }

    not_found =
      "Could not read file " <> E.quote (T.pack filepath) <> "."

-- Parse a loaded file and start (asynchronously) reading its imports.
handleFile :: ReaderState -> VFS -> LoadedFile T.Text -> IO UncheckedImport
handleFile state_mvar vfs (LoadedFile file_name import_name file_contents mod_time) = do
  case parseFuthark file_name file_contents of
    Left (SyntaxError loc err) ->
      pure . UncheckedImport . Left . NE.singleton $ ProgError loc $ pretty err
    Right prog -> do
      let imports = map (first $ mkImportFrom import_name) $ E.progImports prog
      mvars <-
        mapMaybe sequenceA . zip imports
          <$> mapM (uncurry $ readImport state_mvar vfs) imports
      let file =
            LoadedFile
              { lfPath = file_name,
                lfImportName = import_name,
                lfModTime = mod_time,
                lfMod = prog
              }
      pure $ UncheckedImport $ Right (file, mvars)

-- Look up an import in the reader state, starting to load it if it
-- has not been seen before.  'Nothing' means the import is already
-- known (see 'ReaderState').
readImport :: ReaderState -> VFS -> ImportName -> Loc -> IO (Maybe (MVar UncheckedImport))
readImport state_mvar vfs include loc =
  modifyMVar state_mvar $ \state ->
    case M.lookup include state of
      Just x -> pure (state, x)
      Nothing -> do
        prog_mvar <- newImportMVar $ do
          readImportFile include loc vfs >>= \case
            Left e -> pure $ UncheckedImport $ Left $ NE.singleton e
            Right file -> handleFile state_mvar vfs file
        pure (M.insert include prog_mvar state, prog_mvar)

readUntypedLibraryExceptKnown ::
  [ImportName] ->
  VFS ->
  [FilePath] ->
  IO (Either (NE.NonEmpty ProgError) [LoadedFile E.UncheckedProg])
readUntypedLibraryExceptKnown known vfs fps = do
  state_mvar <- liftIO $ newState known
  let prelude_import = mkInitialImport "/prelude/prelude"
  prelude_mvar <- liftIO $ readImport state_mvar vfs prelude_import mempty
  fps_mvars <- liftIO (mapM (onFile state_mvar) fps)
  let unknown_mvars = onlyUnknown (((prelude_import, mempty), prelude_mvar) : fps_mvars)
  fmap (map snd) . errorsToTop <$> orderedImports unknown_mvars
  where
    onlyUnknown = mapMaybe sequenceA
    onFile state_mvar fp =
      modifyMVar state_mvar $ \state -> do
        case M.lookup include state of
          Just prog_mvar -> pure (state, ((include, mempty), prog_mvar))
          Nothing -> do
            prog_mvar <- newImportMVar $ do
              if takeExtension fp /= ".fut"
                then
                  pure . UncheckedImport . Left . NE.singleton $
                    ProgError NoLoc $
                      pretty fp <> ": source files must have a .fut extension."
                else do
                  r <- contentsAndModTime fp vfs
                  case r of
                    Just (Right (fs, mod_time)) -> do
                      handleFile state_mvar vfs $
                        LoadedFile
                          { lfImportName = include,
                            lfMod = fs,
                            lfModTime = mod_time,
                            lfPath = fp
                          }
                    Just (Left e) ->
                      -- 'e' is already a message; render it directly (as
                      -- 'readImportFile' does) rather than via 'show',
                      -- which would quote and escape it.
                      pure . UncheckedImport . Left . NE.singleton $
                        ProgError NoLoc $
                          pretty e
                    Nothing ->
                      pure . UncheckedImport . Left . NE.singleton $
                        ProgError NoLoc $
                          pretty fp <> ": file not found."
            pure (M.insert include prog_mvar state, ((include, mempty), prog_mvar))
      where
        include = mkInitialImport fp_name
        (fp_name, _) = Posix.splitExtension fp

-- | A type-checked file.
data CheckedFile = CheckedFile
  { -- | The name generation state after checking this file.
    cfNameSource :: VNameSource,
    -- | The warnings that were issued from checking this file.
    cfWarnings :: Warnings,
    -- | The type-checked file.
    cfMod :: FileModule
  }

asImports :: [LoadedFile CheckedFile] -> Imports
asImports = map f
  where
    f lf = (lfImportName lf, cfMod $ lfMod lf)

-- Type-check files in order, each against all previously checked
-- ones.  Non-builtin files implicitly open the prelude.
typeCheckProg ::
  [LoadedFile CheckedFile] ->
  VNameSource ->
  [LoadedFile E.UncheckedProg] ->
  WithErrors ([LoadedFile CheckedFile], VNameSource)
typeCheckProg orig_imports orig_src =
  foldM f (orig_imports, orig_src)
  where
    roots = ["/prelude/prelude"]

    f (imports, src) (LoadedFile path import_name prog mod_time) = do
      let prog'
            | isBuiltin (includeToFilePath import_name) = prog
            | otherwise = prependRoots roots prog
      case E.checkProg (asImports imports) src import_name prog' of
        (prog_ws, Left (E.TypeError loc notes msg)) -> do
          let err' = msg <> pretty notes
              warningToError (wloc, wmsg) = ProgWarning (locOf wloc) wmsg
          Left $
            ProgError (locOf loc) err'
              NE.:| map warningToError (listWarnings prog_ws)
        (prog_ws, Right (m, src')) ->
          let warnHole (loc, t) =
                singleWarning (E.locOf loc) $ "Hole of type: " <> align (pretty t)
              prog_ws' = prog_ws <> foldMap warnHole (E.progHoles (fileProg m))
           in Right
                ( imports ++ [LoadedFile path import_name (CheckedFile src prog_ws' m) mod_time],
                  src'
                )

-- Mark the named value bindings as entry points, but only in the
-- given root files.
setEntryPoints ::
  [E.Name] ->
  [FilePath] ->
  [LoadedFile E.UncheckedProg] ->
  [LoadedFile E.UncheckedProg]
setEntryPoints extra_eps fps = map onFile
  where
    fps' = map normalise fps
    onFile lf
      | includeToFilePath (lfImportName lf) `elem` fps' =
          lf {lfMod = prog {E.progDecs = map onDec (E.progDecs prog)}}
      | otherwise =
          lf
      where
        prog = lfMod lf

    onDec (E.ValDec vb)
      | E.valBindName vb `elem` extra_eps =
          E.ValDec vb {E.valBindEntryPoint = Just E.NoInfo}
    onDec dec = dec

prependRoots :: [FilePath] -> E.UncheckedProg -> E.UncheckedProg
prependRoots roots (E.Prog doc ds) =
  E.Prog doc $ map mkImport roots ++ ds
  where
    mkImport fp =
      -- We do not use ImportDec here, because we do not want the
      -- type checker to issue a warning about a redundant import.
      E.LocalDec (E.OpenDec (E.ModImport fp E.NoInfo mempty) mempty) mempty

-- | A loaded, type-checked program.  This can be used to extract
-- information about the program, but also to speed up subsequent
-- reloads.
data LoadedProg = LoadedProg
  { lpRoots :: [FilePath],
    -- | The 'VNameSource' is the name source just *before* the module
    -- was type checked.
    lpFiles :: [LoadedFile CheckedFile],
    -- | Final name source.
    lpNameSource :: VNameSource
  }

-- | The 'Imports' of a 'LoadedProg', as expected by e.g. type
-- checking functions.
lpImports :: LoadedProg -> Imports
lpImports = asImports . lpFiles

-- | All warnings of a 'LoadedProg'.
lpWarnings :: LoadedProg -> Warnings
lpWarnings = foldMap (cfWarnings . lfMod) . lpFiles

-- | The absolute paths of the files that are part of this program.
lpFilePaths :: LoadedProg -> [FilePath]
lpFilePaths = map lfPath . lpFiles

-- Keep the longest prefix of previously checked files that are
-- unchanged on disk (builtins are always kept; files present in the
-- VFS are always treated as changed).  Returns the kept files and the
-- name source to resume from.
unchangedImports ::
  (MonadIO m) =>
  VNameSource ->
  VFS ->
  [LoadedFile CheckedFile] ->
  m ([LoadedFile CheckedFile], VNameSource)
unchangedImports src _ [] = pure ([], src)
unchangedImports src vfs (f : fs)
  | isBuiltin (includeToFilePath (lfImportName f)) =
      first (f :) <$> unchangedImports src vfs fs
  | otherwise = do
      let file_path = lfPath f
      if M.member file_path vfs
        then pure ([], cfNameSource $ lfMod f)
        else do
          changed <-
            maybe True (either (const True) (> lfModTime f))
              <$> liftIO (interactWithFileSafely (getModificationTime file_path))
          if changed
            then pure ([], cfNameSource $ lfMod f)
            else first (f :) <$> unchangedImports src vfs fs

-- | A "loaded program" containing no actual files.  Use this as a
-- starting point for 'reloadProg'.
noLoadedProg :: LoadedProg
noLoadedProg =
  LoadedProg
    { lpRoots = [],
      lpFiles = mempty,
      lpNameSource = newNameSource $ E.maxIntrinsicTag + 1
    }

-- | Find out how many of the old imports can be used.  Here we are
-- forced to be overly conservative, because our type checker
-- enforces a linear ordering.
usableLoadedProg :: (MonadIO m) => LoadedProg -> VFS -> [FilePath] -> m LoadedProg
usableLoadedProg (LoadedProg roots imports src) vfs new_roots
  | sort roots == sort new_roots = do
      (imports', src') <- unchangedImports src vfs imports
      pure $ LoadedProg [] imports' src'
  | otherwise =
      pure noLoadedProg

-- | Extend a loaded program with (possibly new) files.
extendProg ::
  LoadedProg ->
  [FilePath] ->
  VFS ->
  IO (Either (NE.NonEmpty ProgError) LoadedProg)
extendProg lp new_roots vfs = do
  new_imports_untyped <-
    readUntypedLibraryExceptKnown (map lfImportName $ lpFiles lp) vfs new_roots
  pure $ do
    (imports, src') <-
      typeCheckProg (lpFiles lp) (lpNameSource lp) =<< new_imports_untyped
    Right (LoadedProg (nubOrd (lpRoots lp ++ new_roots)) imports src')

-- | Load some new files, reusing as much of the previously loaded
-- program as possible.
This does not *extend* the currently loaded -- program the way 'extendProg' does it, so it is always correct (if -- less efficient) to pass 'noLoadedProg'. reloadProg :: LoadedProg -> [FilePath] -> VFS -> IO (Either (NE.NonEmpty ProgError) LoadedProg) reloadProg lp new_roots vfs = do lp' <- usableLoadedProg lp vfs new_roots extendProg lp' new_roots vfs -- | Read and type-check some Futhark files. readLibrary :: -- | Extra functions that should be marked as entry points; only -- applies to the immediate files, not any imports imported. [E.Name] -> -- | The files to read. [FilePath] -> IO (Either (NE.NonEmpty ProgError) (E.Warnings, Imports, VNameSource)) readLibrary extra_eps fps = ( fmap frob . typeCheckProg mempty (lpNameSource noLoadedProg) <=< fmap (setEntryPoints (E.defaultEntryPoint : extra_eps) fps) ) <$> readUntypedLibraryExceptKnown [] M.empty fps where frob (y, z) = (foldMap (cfWarnings . lfMod) y, asImports y, z) -- | Read (and parse) all source files (including the builtin prelude) -- corresponding to a set of root files. readUntypedLibrary :: [FilePath] -> IO (Either (NE.NonEmpty ProgError) [(ImportName, E.UncheckedProg)]) readUntypedLibrary = fmap (fmap (map f)) . readUntypedLibraryExceptKnown [] M.empty where f lf = (lfImportName lf, lfMod lf) futhark-0.25.27/src/Futhark/Construct.hs000066400000000000000000000525371475065116200200710ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | = Constructing Futhark ASTs -- -- This module re-exports and defines a bunch of building blocks for -- constructing fragments of Futhark ASTs. More importantly, it also -- contains a basic introduction on how to use them. -- -- The "Futhark.IR.Syntax" module contains the core -- AST definition. One important invariant is that all bound names in -- a Futhark program must be /globally/ unique. 
-- In principle, you
-- could use the facilities from "Futhark.MonadFreshNames" (or your
-- own bespoke source of unique names) to manually construct
-- expressions, statements, and entire ASTs.  In practice, this would
-- be very tedious.  Instead, we have defined a collection of building
-- blocks (centered around the 'MonadBuilder' type class) that permits
-- a more abstract way of generating code.
--
-- Constructing ASTs with these building blocks requires you to ensure
-- that all free variables are in scope.  See
-- "Futhark.IR.Prop.Scope".
--
-- == 'MonadBuilder'
--
-- A monad that implements 'MonadBuilder' tracks the statements added
-- so far, the current names in scope, and allows you to add
-- additional statements with 'addStm'.  Any monad that implements
-- 'MonadBuilder' also implements the t'Rep' type family, which
-- indicates which rep it works with.  Inside a 'MonadBuilder' we can
-- use 'collectStms' to gather up the 'Stms' added with 'addStm' in
-- some nested computation.
--
-- The 'BuilderT' monad (and its convenient 'Builder' version) provides
-- the simplest implementation of 'MonadBuilder'.
--
-- == Higher-level building blocks
--
-- On top of the raw facilities provided by 'MonadBuilder', we have
-- more convenient facilities.  For example, 'letExp' lets us
-- conveniently create a 'Stm' for an 'Exp' that produces a /single/
-- value, and returns the (fresh) name for the resulting variable:
--
-- @
-- z <- letExp "z" $ BasicOp $ BinOp (Add Int32 OverflowUndef) (Var x) (Var y)
-- @
--
-- == Monadic expression builders
--
-- This module also contains "monadic expression" functions that let
-- us build nested expressions in a "direct" style, rather than using
-- 'letExp' and friends to bind every sub-part first.  See functions
-- such as 'eIf' and 'eBody' for example.  See also
-- "Futhark.Analysis.PrimExp" and the 'ToExp' type class.
--
-- == Examples
--
-- The "Futhark.Transform.FirstOrderTransform" module is a
-- (relatively) simple example of how to use these components.  As are
-- some of the high-level building blocks in this very module.
module Futhark.Construct
  ( -- * Basic building blocks
    module Futhark.Builder,
    letSubExp,
    letExp,
    letTupExp,
    letTupExp',
    letInPlace,

    -- * Monadic expression builders
    eSubExp,
    eParam,
    eMatch',
    eMatch,
    eIf,
    eIf',
    eBinOp,
    eUnOp,
    eCmpOp,
    eConvOp,
    eSignum,
    eCopy,
    eBody,
    eLambda,
    eBlank,
    eAll,
    eAny,
    eDimInBounds,
    eOutOfBounds,
    eIndex,
    eLast,

    -- * Other building blocks
    asIntZ,
    asIntS,
    resultBody,
    resultBodyM,
    insertStmsM,
    buildBody,
    buildBody_,
    mapResult,
    foldBinOp,
    binOpLambda,
    cmpOpLambda,
    mkLambda,
    sliceDim,
    fullSlice,
    fullSliceNum,
    isFullSlice,
    sliceAt,

    -- * Result types
    instantiateShapes,
    instantiateShapes',
    removeExistentials,

    -- * Convenience
    simpleMkLetNames,
    ToExp (..),
    toSubExp,
  )
where

import Control.Monad
import Control.Monad.Identity
import Control.Monad.State
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Futhark.Builder
import Futhark.IR
import Futhark.Util (maybeNth)

-- | @letSubExp desc e@ binds the expression @e@, which must produce a
-- single value.  Returns a t'SubExp' corresponding to the resulting
-- value.  For expressions that produce multiple values, see
-- 'letTupExp'.
letSubExp ::
  (MonadBuilder m) =>
  String ->
  Exp (Rep m) ->
  m SubExp
-- A bare subexpression needs no binding; return it directly.
letSubExp _ (BasicOp (SubExp se)) = pure se
letSubExp desc e = Var <$> letExp desc e

-- | Like 'letSubExp', but returns a name rather than a t'SubExp'.
-- Fails with 'error' if the expression does not produce exactly one
-- value.
letExp ::
  (MonadBuilder m) =>
  String ->
  Exp (Rep m) ->
  m VName
letExp _ (BasicOp (SubExp (Var v))) =
  pure v
letExp desc e = do
  n <- length <$> expExtType e
  vs <- replicateM n $ newVName desc
  letBindNames vs e
  case vs of
    [v] -> pure v
    _ -> error $ "letExp: tuple-typed expression given:\n" ++ prettyString e

-- | Like 'letExp', but the 'VName' and 'Slice' denote an array that
-- is 'Update'd with the result of the expression.  The name of the
-- updated array is returned.
letInPlace ::
  (MonadBuilder m) =>
  String ->
  VName ->
  Slice SubExp ->
  Exp (Rep m) ->
  m VName
letInPlace desc src slice e = do
  tmp <- letSubExp (desc ++ "_tmp") e
  letExp desc $ BasicOp $ Update Unsafe src slice tmp

-- | Like 'letExp', but the expression may return multiple values.
letTupExp ::
  (MonadBuilder m) =>
  String ->
  Exp (Rep m) ->
  m [VName]
letTupExp _ (BasicOp (SubExp (Var v))) =
  pure [v]
letTupExp name e = do
  e_t <- expExtType e
  names <- replicateM (length e_t) $ newVName name
  letBindNames names e
  pure names

-- | Like 'letTupExp', but returns t'SubExp's instead of 'VName's.
letTupExp' ::
  (MonadBuilder m) =>
  String ->
  Exp (Rep m) ->
  m [SubExp]
letTupExp' _ (BasicOp (SubExp se)) = pure [se]
letTupExp' name ses = map Var <$> letTupExp name ses

-- | Turn a subexpression into a monad expression.  Does not actually
-- lead to any code generation.  This is supposed to be used alongside
-- the other monadic expression functions, such as 'eIf'.
eSubExp ::
  (MonadBuilder m) =>
  SubExp ->
  m (Exp (Rep m))
eSubExp = pure . BasicOp . SubExp

-- | Treat a parameter as a monadic expression.
eParam ::
  (MonadBuilder m) =>
  Param t ->
  m (Exp (Rep m))
eParam = eSubExp . Var . paramName

-- | Drop every scrutinee that no case inspects, i.e. whose pattern
-- position is 'Nothing' in all cases, and drop that position from
-- every case pattern.  The list of cases is never shortened: each
-- input case yields exactly one output case with the same body.  In
-- particular, if every pattern is all-wildcard, the cases are kept
-- with empty patterns instead of being lost.
removeRedundantScrutinees :: [SubExp] -> [Case b] -> ([SubExp], [Case b])
removeRedundantScrutinees ses cases =
  (prune ses, map prunePat cases)
  where
    -- One flag per scrutinee: is it matched against a concrete value
    -- in at least one case?
    keep = map (any (/= Nothing)) $ L.transpose $ map casePat cases
    prune xs = [x | (x, True) <- zip xs keep]
    prunePat c = Case (prune (casePat c)) (caseBody c)

-- | As 'eMatch', but a 'MatchSort' can be given.
eMatch' ::
  (MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
  [SubExp] ->
  [Case (m (Body (Rep m)))] ->
  m (Body (Rep m)) ->
  MatchSort ->
  m (Exp (Rep m))
eMatch' ses cases_m defbody_m sort = do
  cases <- mapM (traverse insertStmsM) cases_m
  defbody <- insertStmsM defbody_m
  -- Generalise the branch types so that sizes differing between
  -- branches become existential.
  ts <-
    L.foldl' generaliseExtTypes
      <$> bodyExtType defbody
      <*> mapM (bodyExtType . caseBody) cases
  cases' <- mapM (traverse $ addContextForBranch ts) cases
  defbody' <- addContextForBranch ts defbody
  let ts' = replicate (length (shapeContext ts)) (Prim int64) ++ ts
      (ses', cases'') = removeRedundantScrutinees ses cases'
  pure $ Match ses' cases'' defbody' $ MatchDec ts' sort
  where
    -- Prepend the existential sizes (ordered by existential index) to
    -- the branch result.
    addContextForBranch ts (Body _ stms val_res) = do
      body_ts <- extendedScope (traverse subExpResType val_res) stmsscope
      let ctx_res =
            map snd $ L.sortOn fst $ M.toList $ shapeExtMapping ts body_ts
      mkBodyM stms $ subExpsRes ctx_res ++ val_res
      where
        stmsscope = scopeOf stms

-- | Construct a 'Match' expression.  The main convenience here is
-- that the existential context of the return type is automatically
-- deduced, and the necessary elements added to the branches.
eMatch ::
  (MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
  [SubExp] ->
  [Case (m (Body (Rep m)))] ->
  m (Body (Rep m)) ->
  m (Exp (Rep m))
eMatch ses cases_m defbody_m = eMatch' ses cases_m defbody_m MatchNormal

-- | Construct a 'Match' modelling an if-expression from a monadic
-- condition and monadic branches.  'eBody' might be convenient for
-- constructing the branches.
eIf ::
  (MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
  m (Exp (Rep m)) ->
  m (Body (Rep m)) ->
  m (Body (Rep m)) ->
  m (Exp (Rep m))
eIf ce te fe = eIf' ce te fe MatchNormal

-- | As 'eIf', but a 'MatchSort' can be given.
eIf' ::
  (MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
  m (Exp (Rep m)) ->
  m (Body (Rep m)) ->
  m (Body (Rep m)) ->
  MatchSort ->
  m (Exp (Rep m))
eIf' ce te fe if_sort = do
  ce' <- letSubExp "cond" =<< ce
  eMatch' [ce'] [Case [Just $ BoolValue True] te] fe if_sort

-- The type of a body.  Watch out: this only works for the degenerate
-- case where the body does not already return its context.
bodyExtType :: (HasScope rep m, Monad m) => Body rep -> m [ExtType]
bodyExtType (Body _ stms res) =
  existentialiseExtTypes (M.keys stmsscope) . staticShapes
    <$> extendedScope (traverse subExpResType res) stmsscope
  where
    stmsscope = scopeOf stms

-- | Construct a v'BinOp' expression with the given operator.
eBinOp ::
  (MonadBuilder m) =>
  BinOp ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eBinOp op x y = do
  x' <- letSubExp "x" =<< x
  y' <- letSubExp "y" =<< y
  pure $ BasicOp $ BinOp op x' y'

-- | Construct a v'UnOp' expression with the given operator.
eUnOp ::
  (MonadBuilder m) =>
  UnOp ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eUnOp op x = BasicOp . UnOp op <$> (letSubExp "x" =<< x)

-- | Construct a v'CmpOp' expression with the given comparison.
eCmpOp ::
  (MonadBuilder m) =>
  CmpOp ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eCmpOp op x y = do
  x' <- letSubExp "x" =<< x
  y' <- letSubExp "y" =<< y
  pure $ BasicOp $ CmpOp op x' y'

-- | Construct a v'ConvOp' expression with the given conversion.
eConvOp ::
  (MonadBuilder m) =>
  ConvOp ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eConvOp op x = do
  x' <- letSubExp "x" =<< x
  pure $ BasicOp $ ConvOp op x'

-- | Construct a 'SSignum' expression.  Fails if the provided
-- expression is not of integer type.
eSignum :: (MonadBuilder m) => m (Exp (Rep m)) -> m (Exp (Rep m))
eSignum em = do
  e <- em
  e' <- letSubExp "signum_arg" e
  t <- subExpType e'
  case t of
    Prim (IntType int_t) ->
      pure $ BasicOp $ UnOp (SSignum int_t) e'
    _ ->
      error $ "eSignum: operand " ++ prettyString e ++ " has invalid type."

-- | Copy a value.
eCopy :: (MonadBuilder m) => m (Exp (Rep m)) -> m (Exp (Rep m))
eCopy e = BasicOp . Replicate mempty <$> (letSubExp "copy_arg" =<< e)

-- | Construct a body from expressions.  If multiple expressions are
-- provided, their results will be concatenated in order and returned
-- as the result.
--
-- /Beware/: this will not produce correct code if the type of the
-- body would be existential.  That is, the type of the results being
-- returned should be invariant to the body.
eBody ::
  (MonadBuilder m) =>
  [m (Exp (Rep m))] ->
  m (Body (Rep m))
eBody es = buildBody_ $ do
  es' <- sequence es
  xs <- mapM (letTupExp "x") es'
  pure $ varsRes $ concat xs

-- | Bind each lambda parameter to the result of an expression, then
-- bind the body of the lambda.  The expressions must produce only a
-- single value each.
eLambda ::
  (MonadBuilder m) =>
  Lambda (Rep m) ->
  [m (Exp (Rep m))] ->
  m [SubExpRes]
eLambda lam args = do
  zipWithM_ bindParam (lambdaParams lam) args
  bodyBind $ lambdaBody lam
  where
    bindParam param arg = letBindNames [paramName param] =<< arg

-- | @eDimInBounds w i@ produces @0 <= i < w@ (signed 64-bit
-- comparisons).  The action @i@ is run only once, so any statements it
-- generates are not duplicated.
eDimInBounds ::
  (MonadBuilder m) =>
  m (Exp (Rep m)) ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eDimInBounds w i = do
  -- The index is needed by both comparisons; bind it once.
  i' <- letSubExp "i" =<< i
  eBinOp
    LogAnd
    (eCmpOp (CmpSle Int64) (eSubExp (intConst Int64 0)) (eSubExp i'))
    (eCmpOp (CmpSlt Int64) (eSubExp i') w)

-- | Are these indexes out-of-bounds for the array?  Produces a
-- boolean expression that is true if any index @i@ falls outside
-- @[0, w)@ for its dimension @w@.  Indexes are paired with the array
-- dimensions from the outermost one; surplus indexes or dimensions are
-- ignored.
eOutOfBounds ::
  (MonadBuilder m) =>
  VName ->
  [m (Exp (Rep m))] ->
  m (Exp (Rep m))
eOutOfBounds arr is = do
  arr_t <- lookupType arr
  let ws = arrayDims arr_t
  is' <- mapM (letSubExp "write_i") =<< sequence is
  let checkDim w i = do
        less_than_zero <-
          letSubExp "less_than_zero" $
            BasicOp $
              CmpOp (CmpSlt Int64) i (constant (0 :: Int64))
        greater_than_size <-
          letSubExp "greater_than_size" $
            BasicOp $
              CmpOp (CmpSle Int64) w i
        letSubExp "outside_bounds_dim" $
          BasicOp $
            BinOp LogOr less_than_zero greater_than_size
  foldBinOp LogOr (constant False) =<< zipWithM checkDim ws is'

-- | The array element at this index.  Returns array unmodified if
-- indexes are null (does not even need to be an array in that case).
-- Missing trailing dimensions are sliced in full (see 'fullSlice').
eIndex :: (MonadBuilder m) => VName -> [m (Exp (Rep m))] -> m (Exp (Rep m))
eIndex arr [] = eSubExp $ Var arr
eIndex arr is = do
  is' <- mapM (letSubExp "i" =<<) is
  arr_t <- lookupType arr
  pure $ BasicOp $ Index arr $ fullSlice arr_t $ map DimFix is'

-- | The last element of the given array.
eLast :: (MonadBuilder m) => VName -> m (Exp (Rep m))
eLast arr = do
  -- Size of the outermost dimension.
  n <- arraySize 0 <$> lookupType arr
  nm1 <-
    letSubExp "nm1" . BasicOp $
      BinOp (Sub Int64 OverflowUndef) n (intConst Int64 1)
  eIndex arr [eSubExp nm1]

-- | Construct an unspecified value of the given type.  Fails with
-- 'error' for accumulator and memory types.
eBlank :: (MonadBuilder m) => Type -> m (Exp (Rep m))
eBlank (Prim t) = pure $ BasicOp $ SubExp $ Constant $ blankPrimValue t
eBlank (Array t shape _) = pure $ BasicOp $ Scratch t $ shapeDims shape
eBlank Acc {} = error "eBlank: cannot create blank accumulator"
eBlank Mem {} = error "eBlank: cannot create blank memory"

-- | Sign-extend to the given integer type.
asIntS :: (MonadBuilder m) => IntType -> SubExp -> m SubExp
asIntS = asInt SExt

-- | Zero-extend to the given integer type.
asIntZ :: (MonadBuilder m) => IntType -> SubExp -> m SubExp
asIntZ = asInt ZExt

-- Convert an integer subexpression to @to_it@ with the given
-- conversion constructor.  No statement is added if the type already
-- matches; fails with 'error' on non-integer operands.
asInt ::
  (MonadBuilder m) =>
  (IntType -> IntType -> ConvOp) ->
  IntType ->
  SubExp ->
  m SubExp
asInt ext to_it e = do
  e_t <- subExpType e
  case e_t of
    Prim (IntType from_it)
      | to_it == from_it -> pure e
      | otherwise -> letSubExp s $ BasicOp $ ConvOp (ext from_it to_it) e
    _ -> error "asInt: wrong type"
  where
    -- Reuse the variable's base name for the converted value when
    -- there is one.
    s = case e of
      Var v -> baseString v
      _ -> "to_" ++ prettyString to_it

-- | Apply a binary operator to several subexpressions.  This is a
-- /right/ fold: @foldBinOp op ne [a, b, c]@ produces
-- @a op (b op (c op ne))@, and @foldBinOp op ne []@ produces @ne@.
foldBinOp ::
  (MonadBuilder m) =>
  BinOp ->
  SubExp ->
  [SubExp] ->
  m (Exp (Rep m))
foldBinOp _ ne [] =
  pure $ BasicOp $ SubExp ne
foldBinOp bop ne (e : es) =
  eBinOp bop (pure $ BasicOp $ SubExp e) (foldBinOp bop ne es)

-- | True if all operands are true.
eAll :: (MonadBuilder m) => [SubExp] -> m (Exp (Rep m))
eAll [] = pure $ BasicOp $ SubExp $ constant True
eAll [x] = eSubExp x
eAll (x : xs) = foldBinOp LogAnd x xs

-- | True if any operand is true.
eAny :: (MonadBuilder m) => [SubExp] -> m (Exp (Rep m))
eAny operands =
  case operands of
    [] -> pure $ BasicOp $ SubExp $ constant False
    [only] -> eSubExp only
    hd : tl -> foldBinOp LogOr hd tl

-- | Build a two-parameter lambda that applies the given binary
-- operator to its parameters.  Both parameters and the single result
-- are given the type @t@, which is only appropriate for operators
-- whose result type equals their operand type.
binOpLambda ::
  (MonadBuilder m, Buildable (Rep m)) =>
  BinOp ->
  PrimType ->
  m (Lambda (Rep m))
binOpLambda op pt = binLambda (BinOp op) pt pt

-- | Build a two-parameter lambda that compares its parameters with
-- the given comparison.  Parameters get the comparison's operand type;
-- the result is 'Bool'.
cmpOpLambda ::
  (MonadBuilder m, Buildable (Rep m)) =>
  CmpOp ->
  m (Lambda (Rep m))
cmpOpLambda op = binLambda (CmpOp op) (cmpOpType op) Bool

-- Shared worker: a lambda with two parameters of type @arg_t@ whose
-- body binds @mkOp x y@ and returns it as a single @ret_t@ result.
binLambda ::
  (MonadBuilder m, Buildable (Rep m)) =>
  (SubExp -> SubExp -> BasicOp) ->
  PrimType ->
  PrimType ->
  m (Lambda (Rep m))
binLambda mkOp arg_t ret_t = do
  lhs <- newVName "x"
  rhs <- newVName "y"
  let param v = Param mempty v (Prim arg_t)
  res_body <- buildBody_ $ do
    r <- letSubExp "binlam_res" $ BasicOp $ mkOp (Var lhs) (Var rhs)
    pure [subExpRes r]
  pure $ Lambda [param lhs, param rhs] [Prim ret_t] res_body

-- | Build a t'Lambda' with the given parameters.  The action runs with
-- the parameters in scope, and the lambda's return types are computed
-- from the types of the result it produces.  See also
-- 'runLambdaBuilder'.
mkLambda ::
  (MonadBuilder m) =>
  [LParam (Rep m)] ->
  m Result ->
  m (Lambda (Rep m))
mkLambda params m = do
  (body, ret_ts) <-
    buildBody $
      localScope (scopeOfLParams params) $ do
        res <- m
        ts <- traverse subExpResType res
        pure (res, ts)
  pure
    Lambda
      { lambdaParams = params,
        lambdaReturnType = ret_ts,
        lambdaBody = body
      }

-- | A 'DimSlice' covering an entire dimension of size @w@: it starts
-- at 0 and has stride 1.
sliceDim :: SubExp -> DimIndex SubExp
sliceDim w = DimSlice zero w one
  where
    zero = constant (0 :: Int64)
    one = constant (1 :: Int64)

-- | @fullSlice t slice@ returns @slice@, but with 'DimSlice's of
-- entire dimensions appended to the full dimensionality of @t@.  This
-- function is used to turn incomplete indexing complete, as required
-- by 'Index'.
fullSlice :: Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice t slice =
  Slice $ slice ++ map sliceDim (drop (length slice) $ arrayDims t)

-- | @sliceAt t n slice@ returns @slice@ but with 'DimSlice's of the
-- outer @n@ dimensions prepended, and as many appended as to make it
-- a full slice.  This is a generalisation of 'fullSlice'.
sliceAt :: Type -> Int -> [DimIndex SubExp] -> Slice SubExp
sliceAt t n slice =
  fullSlice t $ map sliceDim (take n $ arrayDims t) ++ slice

-- | Like 'fullSlice', but the dimensions are simply numeric.
fullSliceNum :: (Num d) => [d] -> [DimIndex d] -> Slice d
fullSliceNum dims slice =
  Slice $ slice ++ map (\d -> DimSlice 0 d 1) (drop (length slice) dims)

-- | Does the slice describe the full size of the array?  The most
-- obvious such slice is one that 'DimSlice's the full span of every
-- dimension, but also one that fixes all unit dimensions.
--
-- Only sizes are checked: the offset and stride of a 'DimSlice' are
-- not inspected.  Dimensions are compared pairwise, so if the shape
-- and the slice differ in length, only their common prefix is
-- considered.
isFullSlice :: Shape -> Slice SubExp -> Bool
isFullSlice shape slice = and $ zipWith allOfIt (shapeDims shape) (unSlice slice)
  where
    -- Fixing an index in a dimension of (constant) size one keeps all
    -- of it.
    allOfIt (Constant v) DimFix {} = oneIsh v
    allOfIt d (DimSlice _ n _) = d == n
    allOfIt _ _ = False

-- | Conveniently construct a body that contains no bindings.
resultBody :: (Buildable rep) => [SubExp] -> Body rep
resultBody = mkBody mempty . subExpsRes

-- | Conveniently construct a body that contains no bindings - but
-- this time, monadically!
resultBodyM :: (MonadBuilder m) => [SubExp] -> m (Body (Rep m))
resultBodyM = mkBodyM mempty . subExpsRes

-- | Evaluate the action, producing a body, then wrap it in all the
-- bindings it created using 'addStm'.  Statements added via 'addStm'
-- come before the statements already inside the produced body.
insertStmsM ::
  (MonadBuilder m) =>
  m (Body (Rep m)) ->
  m (Body (Rep m))
insertStmsM m = do
  (Body _ stms res, otherstms) <- collectStms m
  mkBodyM (otherstms <> stms) res

-- | Evaluate an action that produces a 'Result' and an auxiliary
-- value, then return the body constructed from the 'Result' and any
-- statements added during the action, along with the auxiliary value.
buildBody ::
  (MonadBuilder m) =>
  m (Result, a) ->
  m (Body (Rep m), a)
buildBody m = do
  ((res, v), stms) <- collectStms m
  body <- mkBodyM stms res
  pure (body, v)

-- | As 'buildBody', but there is no auxiliary value.
buildBody_ ::
  (MonadBuilder m) =>
  m Result ->
  m (Body (Rep m))
buildBody_ m = fst <$> buildBody ((,()) <$> m)

-- | Replace the result of a body.  @mapResult f body@ applies @f@ to
-- the result of @body@; the new body consists of the statements of
-- @body@ followed by the statements of the body returned by @f@, and
-- its result is the result of that returned body.
mapResult ::
  (Buildable rep) =>
  (Result -> Body rep) ->
  Body rep ->
  Body rep
mapResult f (Body _ stms res) =
  let Body _ stms2 newres = f res
   in mkBody (stms <> stms2) newres

-- | Instantiate all existential dimensions of the given types, using a
-- monadic action to create the necessary t'SubExp's.  The action is
-- invoked at most once per distinct existential index; further
-- occurrences of the same index reuse the first result.  You should
-- call this function within some monad that allows you to collect the
-- actions performed (say, 'State').
instantiateShapes ::
  (Monad m) =>
  (Int -> m SubExp) ->
  [TypeBase ExtShape u] ->
  m [TypeBase Shape u]
instantiateShapes f ts = evalStateT (mapM instantiate ts) M.empty
  where
    instantiate t = do
      shape <- mapM instantiate' $ shapeDims $ arrayShape t
      pure $ t `setArrayShape` Shape shape
    -- The state memoises the SubExp chosen for each existential index.
    instantiate' (Ext x) = do
      m <- get
      case M.lookup x m of
        Just se -> pure se
        Nothing -> do
          se <- lift $ f x
          put $ M.insert x se m
          pure se
    instantiate' (Free se) = pure se

-- | Like 'instantiateShapes', but obtains names from the provided
-- list.  If an 'Ext' is out of bounds of this list, the function
-- fails with 'error'.
instantiateShapes' :: [VName] -> [TypeBase ExtShape u] -> [TypeBase Shape u]
instantiateShapes' names ts =
  -- Carefully ensure that the order of idents we produce corresponds
  -- to their existential index.
  runIdentity $ instantiateShapes instantiate ts
  where
    instantiate x =
      case maybeNth x names of
        Nothing -> error $ "instantiateShapes': " ++ prettyString names ++ ", " ++ show x
        Just name -> pure $ Var name

-- | Remove existentials by imposing sizes from another type where
-- needed.
removeExistentials :: ExtType -> Type -> Type removeExistentials t1 t2 = t1 `setArrayDims` zipWith nonExistential (shapeDims $ arrayShape t1) (arrayDims t2) where nonExistential (Ext _) dim = dim nonExistential (Free dim) _ = dim -- | Can be used as the definition of 'mkLetNames' for a 'Buildable' -- instance for simple representations. simpleMkLetNames :: ( ExpDec rep ~ (), LetDec rep ~ Type, MonadFreshNames m, TypedOp (OpC rep), HasScope rep m ) => [VName] -> Exp rep -> m (Stm rep) simpleMkLetNames names e = do et <- expExtType e let ts = instantiateShapes' names et pure $ Let (Pat $ zipWith PatElem names ts) (defAux ()) e -- | Instances of this class can be converted to Futhark expressions -- within a 'MonadBuilder'. class ToExp a where toExp :: (MonadBuilder m) => a -> m (Exp (Rep m)) instance ToExp SubExp where toExp = pure . BasicOp . SubExp instance ToExp VName where toExp = pure . BasicOp . SubExp . Var -- | A convenient composition of 'letSubExp' and 'toExp'. toSubExp :: (MonadBuilder m, ToExp a) => String -> a -> m SubExp toSubExp s e = letSubExp s =<< toExp e futhark-0.25.27/src/Futhark/Doc/000077500000000000000000000000001475065116200162425ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Doc/Generator.hs000066400000000000000000000775041475065116200205410ustar00rootroot00000000000000-- | The core logic of @futhark doc@. 
module Futhark.Doc.Generator (renderFiles) where

import CMarkGFM qualified as GFM
import Control.Arrow ((***))
import Control.Monad
import Control.Monad.Reader
import Control.Monad.Writer (Writer, WriterT, runWriter, runWriterT, tell)
import Data.Bifunctor (second)
import Data.Char (isAlpha, isSpace, toUpper)
import Data.List (find, groupBy, inits, intersperse, isPrefixOf, partition, sort, sortOn, tails)
import Data.Map qualified as M
import Data.Maybe
import Data.Ord
import Data.Set qualified as S
import Data.String (fromString)
import Data.Text qualified as T
import Data.Version
import Futhark.Util.Pretty (Doc, docText, pretty)
import Futhark.Version
import Language.Futhark
import Language.Futhark.Semantic
import Language.Futhark.Warnings
import System.FilePath (makeRelative, splitPath, (-<.>), (</>))
import Text.Blaze.Html5 (AttributeValue, Html, toHtml, (!))
import Text.Blaze.Html5 qualified as H
import Text.Blaze.Html5.Attributes qualified as A
import Prelude hiding (abs, mod)

-- | Render a pretty-printer document as HTML, via its 'docText'
-- rendering.
docToHtml :: Doc a -> Html
docToHtml = toHtml . docText

-- | Render a primitive type as HTML.
primTypeHtml :: PrimType -> Html
primTypeHtml = docToHtml . pretty

-- | Render a uniqueness annotation as HTML.
prettyU :: Uniqueness -> Html
prettyU = docToHtml . pretty

-- | Render a name as HTML.
renderName :: Name -> Html
renderName name = docToHtml (pretty name)

-- | Concatenate HTML fragments, placing the separator between each
-- adjacent pair.  An empty list yields 'mempty'.
joinBy :: Html -> [Html] -> Html
joinBy _ [] = mempty
joinBy _ [x] = x
joinBy sep (x : xs) = x <> foldMap (sep <>) xs

-- | Comma-separated fragments.
commas :: [Html] -> Html
commas = joinBy ", "

-- | Wrap a fragment in parentheses.
parens :: Html -> Html
parens x = "(" <> x <> ")"

-- | Wrap a fragment in curly braces.
braces :: Html -> Html
braces x = "{" <> x <> "}"

-- | Wrap a fragment in square brackets.
brackets :: Html -> Html
brackets x = "[" <> x <> "]"

-- | Fragments separated by @|@ (as in sum-type constructors).
pipes :: [Html] -> Html
pipes = joinBy " | "

-- | A set of names that we should not generate links to, because they
-- are uninteresting.  These are for example type parameters.
type NoLink = S.Set VName

-- | A mapping from unique names to the file in which they exist, and
-- the anchor (not necessarily globally unique) for linking to that
-- name.
type FileMap = M.Map VName (FilePath, String)

-- | Compute the 'FileMap' for all imported files.  Every name in a
-- file's type, module, module type and name tables is mapped to the
-- file (made relative to the root) and an anchor of the form
-- @namespace:prefix.name@, where the prefix is the dotted path of the
-- enclosing modules.
vnameToFileMap :: Imports -> FileMap
vnameToFileMap = mconcat . map forFile
  where
    forFile (file, FileModule _ file_env _prog _) =
      forEnv "" file_env
      where
        file' = makeRelative "/" $ includeToFilePath file

        vname prefix ns v =
          M.singleton
            (qualLeaf v)
            (file', canon ns $ prefix <> baseString (qualLeaf v))
        vname' prefix ((ns, _), v) = vname prefix ns v

        canon Term s = "term:" <> s
        canon Type s = "type:" <> s
        canon Signature s = "modtype:" <> s

        forEnv prefix env =
          mconcat (map (forType prefix) $ M.toList $ envTypeTable env)
            <> mconcat (map (forMod prefix) $ M.toList $ envModTable env)
            <> mconcat (map (forMty prefix) $ M.toList $ envModTypeTable env)
            <> mconcat (map (vname' prefix) $ M.toList $ envNameMap env)

        -- Names inside a module get the module name as an extra
        -- prefix; parametric modules contribute nothing.
        forMod prefix (name, ModEnv env) =
          forEnv (prefix <> baseString name <> ".") env
        forMod _ (_, ModFun {}) = mempty

        -- A module type contributes its module contents plus its
        -- abstract types.
        forMty prefix (name, MTy abs mod) =
          forMod prefix (name, mod)
            <> mconcat (map (vname prefix Type) (M.keys abs))

        forType prefix = vname prefix Type . qualName . fst

-- | Read-only information available while documenting one file.
data Context = Context
  { ctxCurrent :: String,
    ctxFileMod :: FileModule,
    ctxImports :: Imports,
    ctxNoLink :: NoLink,
    ctxFileMap :: FileMap,
    -- | Local module types that show up in the interface.  These
    -- should be documented, but clearly marked local.
    ctxVisibleMTys :: S.Set VName
  }

-- | We keep a mapping of the names we have actually documented, so we
-- can generate an index.
type Documented = M.Map VName IndexWhat

-- | The documentation monad: reads a 'Context', records documented
-- names, and collects warnings.
type DocM = ReaderT Context (WriterT Documented (Writer Warnings))

-- | How a documented name is described in the index.
data IndexWhat
  = IndexValue
  | IndexFunction
  | IndexModule
  | IndexModuleType
  | IndexType

-- | Record a warning at the given location.
warn :: Loc -> Doc () -> DocM ()
warn loc s = lift $ lift $ tell $ singleWarning loc s

-- | Record that a name has been documented, for the index.
document :: VName -> IndexWhat -> DocM ()
document v what = tell $ M.singleton v what

-- | Run an action in which the given names are never turned into links.
noLink :: [VName] -> DocM a -> DocM a
noLink names = local $ \ctx ->
  ctx {ctxNoLink = S.fromList names <> ctxNoLink ctx}

-- | An anchor with the given id that links to itself.
selfLink :: AttributeValue -> Html -> Html
selfLink s = H.a ! A.id s ! A.href ("#" <> s) ! A.class_ "self_link"

-- | A table row with a single cell spanning all three columns.
fullRow :: Html -> Html
fullRow = H.tr . (H.td ! A.colspan "3")

-- | A table row of three empty cells.
emptyRow :: Html
emptyRow = H.tr $ H.td mempty <> H.td mempty <> H.td mempty

-- | A three-column specification row (left-hand side, separator,
-- right-hand side).
specRow :: Html -> Html -> Html -> Html
specRow a b c =
  H.tr $
    (H.td ! A.class_ "spec_lhs") a
      <> (H.td ! A.class_ "spec_eql") b
      <> (H.td ! A.class_ "spec_rhs") c

-- | The header documentation (which need not be present) can contain
-- an abstract and further sections.  Returns the first paragraph of
-- the abstract, the full abstract (with heading), and the remaining
-- sections, which start at the first line beginning with @##@.
headerDoc :: Prog -> DocM (Html, Html, Html)
headerDoc prog =
  case progDoc prog of
    Just (DocComment doc loc) -> do
      let (abstract, more_sections) = splitHeaderDoc $ T.unpack doc
      first_paragraph <- docHtml $ Just $ DocComment (firstParagraph abstract) loc
      abstract' <- docHtml $ Just $ DocComment (T.pack abstract) loc
      more_sections' <- docHtml $ Just $ DocComment (T.pack more_sections) loc
      pure
        ( first_paragraph,
          selfLink "abstract" (H.h2 "Abstract") <> abstract',
          more_sections'
        )
    _ -> pure mempty
  where
    splitHeaderDoc s =
      fromMaybe (s, mempty) $
        find (("\n##" `isPrefixOf`) . snd) $
          zip (inits s) (tails s)
    firstParagraph = T.pack . unlines . takeWhile (not . paragraphSeparator) . lines
    paragraphSeparator = all isSpace

-- | The @index.html@ page: important imports under "Main libraries",
-- the rest under "Supporting libraries", each with its abstract.
contentsPage :: [ImportName] -> [(ImportName, Html)] -> Html
contentsPage important_imports pages =
  H.docTypeHtml $
    addBoilerplate "index.html" "Futhark Library Documentation" $
      H.main $
        ( if null important_pages
            then mempty
            else H.h2 "Main libraries" <> fileList important_pages
        )
          <> ( if null unimportant_pages
                 then mempty
                 else H.h2 "Supporting libraries" <> fileList unimportant_pages
             )
  where
    (important_pages, unimportant_pages) =
      partition ((`elem` important_imports) . fst) pages

    fileList pages' =
      H.dl ! A.class_ "file_list" $
        mconcat $
          map linkTo $
            sortOn fst pages'

    linkTo (name, maybe_abstract) =
      H.div ! A.class_ "file_desc" $
        (H.dt ! A.class_ "desc_header") (importLink "index.html" name)
          <> (H.dd ! A.class_ "desc_doc") maybe_abstract

-- | A link to the documentation page of an import (under @doc/@),
-- relative to the page @current@.
importLink :: FilePath -> ImportName -> Html
importLink current name =
  let file =
        relativise
          ("doc" </> makeRelative "/" (includeToFilePath name) -<.> "html")
          current
   in (H.a ! A.href (fromString file) $ fromString (includeToString name))

-- | The @doc-index.html@ page: every documented name, grouped by
-- case-insensitive initial letter; names not starting with a letter
-- are collected under "Symbols".
indexPage :: [ImportName] -> Imports -> Documented -> FileMap -> Html
indexPage important_imports imports documented fm =
  H.docTypeHtml $
    addBoilerplateWithNav important_imports imports "doc-index.html" "Index" $
      H.main $
        ( H.ul ! A.id "doc_index_list" $
            mconcat $
              map initialListEntry $
                letter_group_links ++ [symbol_group_link]
        )
          <> ( H.table ! A.id "doc_index" $
                 H.thead (H.tr $ H.td "Who" <> H.td "What" <> H.td "Where")
                   <> mconcat (letter_groups ++ [symbol_group])
             )
  where
    (letter_names, sym_names) =
      partition (isLetterName . baseString . fst) $
        sortOn (map toUpper . baseString . fst) $
          mapMaybe isDocumented $
            M.toList fm

    -- Only names recorded in 'documented' appear in the index.
    isDocumented (k, (file, k_id)) = do
      what <- M.lookup k documented
      Just (k, (file, (what, k_id)))

    (letter_groups, letter_group_links) =
      unzip $ map tbodyForNames $ groupBy sameInitial letter_names
    (symbol_group, symbol_group_link) =
      tbodyForInitial "Symbols" sym_names

    isLetterName [] = False
    isLetterName (c : _) = isAlpha c

    sameInitial (x, _) (y, _) =
      case (baseString x, baseString y) of
        (x' : _, y' : _) -> toUpper x' == toUpper y'
        _ -> False

    tbodyForNames names@((s, _) : _) =
      tbodyForInitial (map toUpper $ take 1 $ baseString s) names
    tbodyForNames _ = mempty

    tbodyForInitial initial names =
      ( H.tbody $ mconcat $ initial' : map linkTo names,
        initial
      )
      where
        initial' =
          H.tr $
            H.td ! A.colspan "2" ! A.class_ "doc_index_initial" $
              H.a ! A.id (fromString initial) ! A.href (fromString $ '#' : initial) $
                fromString initial

    initialListEntry initial =
      H.li $ H.a ! A.href (fromString $ '#' : initial) $ fromString initial

    linkTo (name, (file, (what, name_id))) =
      let file' = makeRelative "/" file
          name_link = vnameLink' name_id "" file'
          link =
            (H.a ! A.href (fromString (makeRelative "/" $ "doc" </> name_link))) $
              fromString $
                baseString name
          what' = case what of
            IndexValue -> "value"
            IndexFunction -> "function"
            IndexType -> "type"
            IndexModuleType -> "module type"
            IndexModule -> "module"
          html_file = "doc" </> file' -<.> "html"
       in H.tr $
            (H.td ! A.class_ "doc_index_name" $ link)
              <> (H.td ! A.class_ "doc_index_namespace" $ what')
              <> ( H.td ! A.class_ "doc_index_file" $
                     (H.a ! A.href (fromString html_file) $ fromString file)
                 )

-- | Wrap page content in the common HTML skeleton: a head with the
-- stylesheet link, a header with the title and navigation links, and
-- a footer naming the generator version.  @current@ is the page's
-- own path, used to make links relative.
addBoilerplate :: String -> String -> Html -> Html
addBoilerplate current titleText content =
  let headHtml =
        H.head $
          (H.meta ! A.charset "utf-8")
            <> H.title (fromString titleText)
            <> ( H.link
                   ! A.href (fromString $ relativise "style.css" current)
                   ! A.rel "stylesheet"
                   ! A.type_ "text/css"
               )
      navigation =
        H.ul ! A.id "navigation" $
          H.li (H.a ! A.href (fromString $ relativise "index.html" current) $ "Contents")
            <> H.li (H.a ! A.href (fromString $ relativise "doc-index.html" current) $ "Index")
      madeByHtml =
        "Generated by "
          <> (H.a ! A.href futhark_doc_url) "futhark-doc"
          <> " "
          <> fromString (showVersion version)
   in headHtml
        <> H.body
          ( (H.div ! A.id "header") (H.h1 (toHtml titleText) <> navigation)
              <> (H.div ! A.id "content") content
              <> (H.div ! A.id "footer") madeByHtml
          )
  where
    futhark_doc_url =
      "https://futhark.readthedocs.io/en/latest/man/futhark-doc.html"

-- | Like 'addBoilerplate', but also adds a sorted file navigation list
-- of those important imports that are present in @imports@.
addBoilerplateWithNav :: [ImportName] -> Imports -> String -> String -> Html -> Html
addBoilerplateWithNav important_imports imports current titleText content =
  addBoilerplate current titleText $
    (H.nav ! A.id "filenav" $ files) <> content
  where
    files = H.ul $ mconcat $ map pp $ sort $ filter visible important_imports
    pp name = H.li $ importLink current name
    visible = (`elem` map fst imports)

-- | The synopsis table for a list of declarations.
synopsisDecs :: [Dec] -> DocM Html
synopsisDecs decs = do
  visible <- asks ctxVisibleMTys
  fm <- asks ctxFileMod
  -- We add an empty row to avoid generating invalid HTML in cases
  -- where all rows are otherwise colspan=2.
  (H.table ! A.class_ "specs") . (emptyRow <>) . mconcat
    <$> sequence (mapMaybe (synopsisDec visible fm) decs)

-- | The synopsis row for a declaration, or 'Nothing' for declarations
-- that are not shown (imports, and local declarations other than
-- visible local module types).
synopsisDec :: S.Set VName -> FileModule -> Dec -> Maybe (DocM Html)
synopsisDec visible fm dec = case dec of
  ModTypeDec s -> synopsisModType mempty s
  ModDec m -> synopsisMod fm m
  ValDec v -> synopsisValBind v
  TypeDec t -> synopsisType t
  OpenDec x _
    | Just opened <- synopsisOpened x -> Just $ do
        opened' <- opened
        pure $ fullRow $ keyword "open " <> opened'
    | otherwise ->
        Just $ pure $ fullRow $ keyword "open" <> fromString (" <" <> prettyString x <> ">")
  LocalDec (ModTypeDec s) _
    | modTypeName s `S.member` visible ->
        synopsisModType (keyword "local" <> " ") s
  LocalDec {} -> Nothing
  ImportDec {} -> Nothing

-- | Render the target of an @open@, for the forms we know how to show.
synopsisOpened :: ModExp -> Maybe (DocM Html)
synopsisOpened (ModVar qn _) = Just $ qualNameHtml qn
synopsisOpened (ModParens me _) = do
  me' <- synopsisOpened me
  Just $ parens <$> me'
synopsisOpened (ModImport _ (Info file) _) = Just $ do
  current <- asks ctxCurrent
  let dest = fromString $ relativise (includeToFilePath file) current -<.> "html"
  pure $
    keyword "import "
      <> (H.a ! A.href dest) (fromString (show (includeToString file)))
synopsisOpened (ModAscript _ se _ _) = Just $ do
  se' <- synopsisModTypeExp se
  pure $ "... : " <> se'
synopsisOpened _ = Nothing

-- | A name as defined in the synopsis: a span with a
-- @synopsis:@-prefixed id that links to the name's description anchor.
vnameSynopsisDef :: VName -> DocM Html
vnameSynopsisDef vname = do
  (_, vname_id) <- vnameId vname
  pure $
    H.span ! A.id (fromString ("synopsis:" <> vname_id)) $
      H.a ! A.href (fromString ("#" ++ vname_id)) $
        renderName (baseName vname)

-- | The synopsis row for a value binding.
synopsisValBind :: ValBind -> Maybe (DocM Html)
synopsisValBind vb = Just $ do
  name' <- vnameSynopsisDef $ valBindName vb
  (lhs, mhs, rhs) <- valBindHtml name' vb
  pure $ specRow lhs (mhs <> " : ") rhs

-- | The parts of a @val@ signature: keyword and name, type parameters,
-- and the parameter and return types joined by @->@.  Type parameters
-- and names bound by the parameters are not linked.
valBindHtml :: Html -> ValBind -> DocM (Html, Html, Html)
valBindHtml name (ValBind _ _ retdecl (Info rettype) tparams params _ _ _ _) = do
  tparams' <- mconcat <$> mapM (fmap (" " <>) . typeParamHtml) tparams
  let noLink' =
        noLink $
          map typeParamName tparams
            <> foldMap patNames params
  -- Prefer the user's return type annotation when there is one.
  rettype' <- noLink' $ maybe (retTypeHtml rettype) typeExpHtml retdecl
  params' <- noLink' $ mapM paramHtml params
  pure
    ( keyword "val " <> (H.span ! A.class_ "decl_name") name,
      tparams',
      mconcat (intersperse " -> " $ params' ++ [rettype'])
    )

-- | The synopsis row for a module type binding, with an optional
-- prefix (e.g. @local@).
synopsisModType :: Html -> ModTypeBind -> Maybe (DocM Html)
synopsisModType prefix sb = Just $ do
  name' <- vnameSynopsisDef $ modTypeName sb
  fullRow <$> do
    se' <- synopsisModTypeExp $ modTypeExp sb
    pure $ prefix <> keyword "module type " <> name' <> " = " <> se'

-- | The synopsis row for a module binding.  Without an explicit
-- signature, the module's environment from the file's module table is
-- rendered instead ('Nothing' if the module is not found there).
synopsisMod :: FileModule -> ModBind -> Maybe (DocM Html)
synopsisMod fm (ModBind name ps sig _ _ _) =
  case sig of
    Nothing -> (proceed <=< envModType) <$> M.lookup name modtable
    Just (s, _) -> Just $ proceed =<< synopsisModTypeExp s
  where
    proceed sig' = do
      name' <- vnameSynopsisDef name
      ps' <- modParamHtml ps
      pure $ specRow (keyword "module " <> name') ": " (ps' <> sig')

    FileModule _abs Env {envModTable = modtable} _ _ = fm

    envModType (ModEnv e) = renderEnv e
    envModType (ModFun (FunModType _ _ (MTy _ m))) = envModType m

-- | The synopsis row for a type binding.
synopsisType :: TypeBind -> Maybe (DocM Html)
synopsisType tb = Just $ do
  name' <- vnameSynopsisDef $ typeAlias tb
  fullRow <$> typeBindHtml name' tb

-- | Render @type name params = t@, without linking the type parameters.
typeBindHtml :: Html -> TypeBind -> DocM Html
typeBindHtml name' (TypeBind _ l tparams t _ _ _) = do
  t' <- noLink (map typeParamName tparams) $ typeExpHtml t
  abbrev <- typeAbbrevHtml l name' tparams
  pure $ abbrev <> " = " <> t'

-- | Render an environment as a brace-enclosed sequence of its types,
-- values, module types and modules, in that order.
renderEnv :: Env -> DocM Html
renderEnv (Env vtable ttable sigtable modtable _) = do
  typeBinds <- mapM renderTypeBind (M.toList ttable)
  valBinds <- mapM renderValBind (M.toList vtable)
  sigBinds <- mapM renderModType (M.toList sigtable)
  modBinds <- mapM renderMod (M.toList modtable)
  pure $ braces $ mconcat $ typeBinds ++ valBinds ++ sigBinds ++ modBinds

-- | Render a module type entry of an environment by name only.
renderModType :: (VName, MTy) -> DocM Html
renderModType (name, _sig) =
  (keyword "module type " <>) <$> qualNameHtml (qualName name)
-- | Render a module entry of an environment by name only.
renderMod :: (VName, Mod) -> DocM Html
renderMod (name, _mod) =
  (keyword "module " <>) <$> qualNameHtml (qualName name)

-- | Render a value entry of an environment inside its own @div@.
renderValBind :: (VName, BoundV) -> DocM Html
renderValBind = fmap H.div . synopsisValBindBind

-- | Render a type abbreviation entry of an environment as
-- @type name params = t@ inside a @div@.
renderTypeBind :: (VName, TypeBinding) -> DocM Html
renderTypeBind (name, TypeAbbr l tps tp) = do
  tp' <- retTypeHtml $ toResRet Nonunique tp
  name' <- vnameHtml name
  abbrev <- typeAbbrevHtml l name' tps
  pure $ H.div $ abbrev <> " = " <> tp'

-- | Render a bound value as @val name tparams: type@.
synopsisValBindBind :: (VName, BoundV) -> DocM Html
synopsisValBindBind (name, BoundV tps t) = do
  tps' <- mapM typeParamHtml tps
  t' <- typeHtml $ second (const Nonunique) t
  name' <- vnameHtml name
  pure $
    keyword "val "
      <> name'
      <> mconcat (map (" " <>) tps')
      <> ": "
      <> t'

-- | @*@ for a consumed parameter, nothing for an observed one.
dietHtml :: Diet -> Html
dietHtml Consume = "*"
dietHtml Observe = ""

-- | Render a type.  Records whose fields form a tuple are shown as
-- tuples; sum type constructors are ordered by 'sortConstrs'.
typeHtml :: TypeBase Size Uniqueness -> DocM Html
typeHtml t = case t of
  Array u shape et -> do
    shape' <- prettyShape shape
    et' <- typeHtml $ Scalar $ second (const Nonunique) et
    pure $ prettyU u <> shape' <> et'
  Scalar (Prim et) -> pure $ primTypeHtml et
  Scalar (Record fs)
    | Just ts <- areTupleFields fs ->
        parens . commas <$> mapM typeHtml ts
    | otherwise ->
        braces . commas <$> mapM ppField (M.toList fs)
    where
      ppField (name, tp) = do
        tp' <- typeHtml tp
        pure $ toHtml (nameToString name) <> ": " <> tp'
  Scalar (TypeVar u et targs) -> do
    targs' <- mapM typeArgHtml targs
    et' <- qualNameHtml et
    pure $ prettyU u <> et' <> mconcat (map (" " <>) targs')
  Scalar (Arrow _ pname d t1 t2) -> do
    t1' <- typeHtml $ second (const Nonunique) t1
    t2' <- retTypeHtml t2
    case pname of
      -- A named parameter is shown as (name: type).
      Named v -> do
        v' <- vnameHtml v
        pure $ parens (v' <> ": " <> dietHtml d <> t1') <> " -> " <> t2'
      Unnamed -> pure $ dietHtml d <> t1' <> " -> " <> t2'
  Scalar (Sum cs) -> pipes <$> mapM ppClause (sortConstrs cs)
    where
      ppClause (n, ts) = joinBy " " . (ppConstr n :) <$> mapM typeHtml ts
      ppConstr name = "#" <> toHtml (nameToString name)

-- | Render a return type; existential sizes, if any, are shown as a
-- @?[d1][d2].@ prefix.
retTypeHtml :: ResRetType -> DocM Html
retTypeHtml (RetType [] t) = typeHtml t
retTypeHtml (RetType dims t) = do
  t' <- typeHtml t
  dims' <- mapM vnameHtml dims
  pure $ "?" <> mconcat (map brackets dims') <> "." <> t'

-- | Render each dimension of a shape in brackets.
prettyShape :: Shape Size -> DocM Html
prettyShape (Shape ds) =
  mconcat <$> mapM dimDeclHtml ds

-- | Render a type argument (size or type).
typeArgHtml :: TypeArg Size -> DocM Html
typeArgHtml (TypeArgDim d) = dimDeclHtml d
typeArgHtml (TypeArgType t) = typeHtml $ second (const Nonunique) t

-- | Render module parameters as a chain of @(p: sig) -> @.
modParamHtml :: [ModParamBase Info VName] -> DocM Html
modParamHtml [] = pure mempty
modParamHtml (ModParam pname psig _ _ : mps) = do
  pname' <- vnameHtml pname
  psig' <- synopsisModTypeExp psig
  mps' <- modParamHtml mps
  pure $ "(" <> pname' <> ": " <> psig' <> ") -> " <> mps'

-- | Render a module type expression for the synopsis.
synopsisModTypeExp :: ModTypeExpBase Info VName -> DocM Html
synopsisModTypeExp e = case e of
  ModTypeVar v _ _ -> qualNameHtml v
  ModTypeParens e' _ -> parens <$> synopsisModTypeExp e'
  ModTypeSpecs ss _ -> braces . (H.table ! A.class_ "specs") . mconcat <$> mapM synopsisSpec ss
  ModTypeWith s (TypeRef v ps t _) _ -> do
    s' <- synopsisModTypeExp s
    t' <- typeExpHtml t
    v' <- qualNameHtml v
    ps' <- mconcat <$> mapM (fmap (" " <>) . typeParamHtml) ps
    pure $ s' <> keyword " with " <> v' <> ps' <> " = " <> t'
  ModTypeArrow Nothing e1 e2 _ ->
    liftM2 f (synopsisModTypeExp e1) (synopsisModTypeExp e2)
    where
      f e1' e2' = e1' <> " -> " <> e2'
  ModTypeArrow (Just v) e1 e2 _ ->
    do
      name <- vnameHtml v
      e1' <- synopsisModTypeExp e1
      -- The parameter name is not linked inside the result type.
      e2' <- noLink [v] $ synopsisModTypeExp e2
      pure $ "(" <> name <> ": " <> e1' <> ") -> " <> e2'

-- | Wrap text in a @span@ with class @keyword@.
keyword :: String -> Html
keyword = (H.span ! A.class_ "keyword") . fromString

-- | Render a name in a @span@ whose id is the name's anchor.
vnameHtml :: VName -> DocM Html
vnameHtml vname = do
  (_, vname_id) <- vnameId vname
  pure $ H.span ! A.id (fromString vname_id) $ renderName $ baseName vname

-- | The canonical (in-file) anchor ID for a VName, along with the
-- file in which it is defined.
vnameId :: VName -> DocM (FilePath, String)
vnameId vname = do
  current <- asks ctxCurrent
  -- Names absent from the file map fall back to the current file and
  -- their numeric tag.
  asks $ fromMaybe (current, show (baseTag vname)) . M.lookup vname . ctxFileMap

-- | Record the name as documented (for the index) and render it as an
-- anchor carrying its canonical id.
vnameDescDef :: VName -> IndexWhat -> DocM Html
vnameDescDef v what = do
  document v what
  (_, v_id) <- vnameId v
  pure $ H.a ! A.id (fromString v_id) $ renderName (baseName v)

-- | An up-arrow link back to the name's entry in the synopsis.
vnameSynopsisRef :: VName -> DocM Html
vnameSynopsisRef v = do
  (_, v_id) <- vnameId v
  pure $
    H.a ! A.class_ "synopsis_link" ! A.href (fromString ("#" <> "synopsis:" <> v_id)) $
      "↑"

-- | The synopsis row for one specification inside a module type.
synopsisSpec :: SpecBase Info VName -> DocM Html
synopsisSpec spec = case spec of
  TypeAbbrSpec tpsig -> do
    def <- vnameSynopsisDef $ typeAlias tpsig
    fullRow <$> typeBindHtml def tpsig
  TypeSpec l name ps _ _ -> do
    name' <- vnameSynopsisDef name
    ps' <- mconcat <$> mapM (fmap (" " <>) . typeParamHtml) ps
    pure $ fullRow $ keyword l' <> name' <> ps'
    where
      -- The keyword records the liftedness of the abstract type.
      l' = case l of
        Unlifted -> "type "
        SizeLifted -> "type~ "
        Lifted -> "type^ "
  ValSpec name tparams rettype _ _ _ -> do
    tparams' <- mapM typeParamHtml tparams
    rettype' <- noLink (map typeParamName tparams) $ typeExpHtml rettype
    name' <- vnameSynopsisDef name
    pure $
      specRow
        (keyword "val " <> name')
        (mconcat (map (" " <>) tparams') <> ": ")
        rettype'
  ModSpec name sig _ _ -> do
    name' <- vnameSynopsisDef name
    specRow (keyword "module " <> name') ": " <$> synopsisModTypeExp sig
  IncludeSpec e _ -> fullRow . (keyword "include " <>) <$> synopsisModTypeExp e

-- | Render a type expression as written in the source.
typeExpHtml :: TypeExp Exp VName -> DocM Html
typeExpHtml e = case e of
  TEUnique t _ -> ("*" <>) <$> typeExpHtml t
  TEArray d at _ -> do
    at' <- typeExpHtml at
    d' <- dimExpHtml d
    pure $ d' <> at'
  TETuple ts _ -> parens . commas <$> mapM typeExpHtml ts
  TERecord fs _ -> braces . commas <$> mapM ppField fs
    where
      ppField (L _ name, t) = do
        t' <- typeExpHtml t
        pure $ toHtml (nameToString name) <> ": " <> t'
  TEVar name _ -> qualNameHtml name
  TEParens te _ -> parens <$> typeExpHtml te
  TEApply t arg _ -> do
    t' <- typeExpHtml t
    arg' <- typeArgExpHtml arg
    pure $ t' <> " " <> arg'
  TEArrow pname t1 t2 _ -> do
    -- A function-typed argument needs parentheses.
    t1' <- case t1 of
      TEArrow {} -> parens <$> typeExpHtml t1
      _ -> typeExpHtml t1
    t2' <- typeExpHtml t2
    case pname of
      Just v -> do
        v' <- vnameHtml v
        pure $ parens (v' <> ": " <> t1') <> " -> " <> t2'
      Nothing -> pure $ t1' <> " -> " <> t2'
  TESum cs _ -> pipes <$> mapM ppClause cs
    where
      ppClause (n, ts) = joinBy " " . (ppConstr n :) <$> mapM typeExpHtml ts
      ppConstr name = "#" <> toHtml (nameToString name)
  TEDim dims t _ -> do
    t' <- typeExpHtml t
    pure $ "?" <> mconcat (map (brackets . renderName . baseName) dims) <> "." <> t'

-- | Render a qualified name, qualifiers first.  Intrinsic names (tag at
-- most 'maxIntrinsicTag') are rendered bare; names in 'ctxNoLink' are
-- rendered with their qualifiers but without a link.
qualNameHtml :: QualName VName -> DocM Html
qualNameHtml (QualName names vname@(VName name tag)) =
  if tag <= maxIntrinsicTag
    then pure $ renderName name
    else f <$> ref
  where
    prefix :: Html
    prefix = mapM_ ((<> ".") . renderName . baseName) names
    f (Just s) = H.a ! A.href (fromString s) $ prefix <> renderName name
    f Nothing = prefix <> renderName name

    ref = do
      boring <- asks $ S.member vname . ctxNoLink
      if boring
        then pure Nothing
        else Just <$> vnameLink vname

-- | The link for a VName.
vnameLink :: VName -> DocM String
vnameLink vname = do
  current <- asks ctxCurrent
  (file, tag) <- vnameId vname
  pure $ vnameLink' tag current file

-- | @vnameLink' tag current file@: an in-page link when @file@ is the
-- current file, otherwise a link to @file@'s HTML page relative to
-- @current@.
vnameLink' :: String -> FilePath -> FilePath -> String
vnameLink' tag current file =
  if file == current
    then "#" ++ tag
    else relativise file current -<.> ".html#" ++ tag

-- | Render a function parameter; named parameters are shown as
-- @(name: type)@.
paramHtml :: Pat ParamType -> DocM Html
paramHtml pat = do
  let (pat_param, d, t) = patternParam pat
  t' <- typeHtml $ second (const Nonunique) t
  case pat_param of
    Named v -> do
      v' <- vnameHtml v
      pure $ parens $ v' <> ": " <> dietHtml d <> t'
    Unnamed -> pure t'

-- | Make @dest@ relative to the page @src@ by prepending one @../@
-- per directory component of @src@.
relativise :: FilePath -> FilePath -> FilePath
relativise dest src =
  concat (replicate (length (splitPath src) - 1) "../") ++ makeRelative "/" dest

-- | Render a size in brackets.
dimDeclHtml :: Size -> DocM Html
dimDeclHtml = pure . brackets . toHtml . prettyString

-- | Render a size expression in brackets; an anonymous size is @[]@.
dimExpHtml :: SizeExp Exp -> DocM Html
dimExpHtml (SizeExpAny _) = pure $ brackets mempty
dimExpHtml (SizeExp e _) = pure $ brackets $ toHtml $ prettyString e

-- | Render a type argument expression (size or type).
typeArgExpHtml :: TypeArgExp Exp VName -> DocM Html
typeArgExpHtml (TypeArgExpSize d) = dimExpHtml d
typeArgExpHtml (TypeArgExpType d) = typeExpHtml d

-- | Render a type parameter: @[n]@ for a size parameter, or @'@ plus
-- the pretty-printed liftedness and the name for a type parameter.
typeParamHtml :: TypeParam -> DocM Html
typeParamHtml (TypeParamDim name _) = do
  name' <- vnameHtml name
  pure $ brackets name'
typeParamHtml (TypeParamType l name _) = do
  name' <- vnameHtml name
  pure $ "'" <> fromString (prettyString l) <> name'

-- | Render the left-hand side of a type abbreviation:
-- @type@ plus liftedness, the name, and the parameters.
typeAbbrevHtml :: Liftedness -> Html -> [TypeParam] -> DocM Html
typeAbbrevHtml l name params = do
  params' <- mconcat <$> mapM (fmap (" " <>) . typeParamHtml) params
  pure $ what <> name <> params'
  where
    what = keyword $ "type" <> prettyString l <> " "

-- | Render a doc comment: identifier references are first rewritten
-- into Markdown links by 'identifierLinks', then the result is
-- converted with CommonMark (autolink extension enabled).
docHtml :: Maybe DocComment -> DocM Html
docHtml (Just (DocComment doc loc)) =
  H.preEscapedText
    . GFM.commonmarkToHtml [] [GFM.extAutolink]
    . T.pack
    <$> identifierLinks (locOf loc) (T.unpack doc)
docHtml Nothing = pure mempty

-- | Rewrite identifier references (as recognised by
-- 'identifierReference') into Markdown links.  Unknown namespaces and
-- names that cannot be found are emitted as inline code and a warning
-- is recorded at the given location.
identifierLinks :: Loc -> String -> DocM String
identifierLinks _ [] = pure []
identifierLinks loc s
  | Just ((name, namespace, file), s') <- identifierReference s = do
      let proceed x = (x <>) <$> identifierLinks loc s'
          unknown = proceed $ "`" <> name <> "`"
      case knownNamespace namespace of
        Just namespace' -> do
          maybe_v <- lookupName (namespace', name, file)
          case maybe_v of
            Nothing -> do
              warn loc $
                "Identifier '"
                  <> fromString name
                  <> "' not found in namespace '"
                  <> fromString namespace
                  <> "'"
                  <> fromString (maybe "" (" in file " <>) file)
                  <> "."
              unknown
            Just v' -> do
              link <- vnameLink v'
              proceed $ "[`" <> name <> "`](" <> link <> ")"
        _ -> do
          warn loc $ "Unknown namespace '" <> fromString namespace <> "'."
          unknown
  where
    knownNamespace "term" = Just Term
    knownNamespace "mtype" = Just Signature
    knownNamespace "type" = Just Type
    knownNamespace _ = Nothing
identifierLinks loc (c : s') = (c :) <$> identifierLinks loc s'

-- | Look up a name in a namespace, either in the current file's
-- environment or, if a file is given, in that import's environment
-- (the path is resolved relative to the current file).
lookupName :: (Namespace, String, Maybe FilePath) -> DocM (Maybe VName)
lookupName (namespace, name, file) = do
  current <- asks ctxCurrent
  let file' = mkImportFrom (mkInitialImport current) <$> file
  env <- lookupEnvForFile file'
  case M.lookup (namespace, nameFromString name) . envNameMap =<< env of
    Nothing -> pure Nothing
    Just qn -> pure $ Just $ qualLeaf qn

-- | The environment of the current file ('Nothing' argument) or of the
-- given import, if it is among 'ctxImports'.
lookupEnvForFile :: Maybe ImportName -> DocM (Maybe Env)
lookupEnvForFile Nothing = asks $ Just . fileEnv . ctxFileMod
lookupEnvForFile (Just file) = asks $ fmap fileEnv . lookup file . ctxImports

-- | A description list entry for a declaration: a header with a link
-- back to the synopsis and the rendered declaration, followed by its
-- documentation.  The name is recorded for the index.
describeGeneric ::
  VName ->
  IndexWhat ->
  Maybe DocComment ->
  (Html -> DocM Html) ->
  DocM Html
describeGeneric name what doc f = do
  name' <- H.span ! A.class_ "decl_name" <$> vnameDescDef name what
  decl_type <- f name'
  doc' <- docHtml doc
  ref <- vnameSynopsisRef name
  let decl_doc = H.dd ! A.class_ "desc_doc" $ doc'
      decl_header =
        (H.dt ! A.class_ "desc_header") (ref <> decl_type)
  pure $ decl_header <> decl_doc

-- | Like 'describeGeneric', but for module-like declarations: when the
-- module type is a literal list of specs, those specs are described
-- after the documentation.
describeGenericMod ::
  VName ->
  IndexWhat ->
  ModTypeExp ->
  Maybe DocComment ->
  (Html -> DocM Html) ->
  DocM Html
describeGenericMod name what se doc f = do
  name' <- H.span ! A.class_ "decl_name" <$> vnameDescDef name what

  decl_type <- f name'

  doc' <- case se of
    ModTypeSpecs specs _ -> (<>) <$> docHtml doc <*> describeSpecs specs
    _ -> docHtml doc

  ref <- vnameSynopsisRef name

  let decl_doc = H.dd ! A.class_ "desc_doc" $ doc'
      decl_header =
        (H.dt ! A.class_ "desc_header") $
          ref <> decl_type
  pure $ decl_header <> decl_doc

-- | The description list for a sequence of declarations; each entry is
-- wrapped in a @decl_description@ div.
describeDecs :: [Dec] -> DocM Html
describeDecs decs = do
  visible <- asks ctxVisibleMTys
  H.dl . mconcat
    <$> mapM
      (fmap $ H.div ! A.class_ "decl_description")
      (mapMaybe (describeDec visible) decs)

-- | The description of a declaration, or 'Nothing' for declarations
-- that are not described (opens, imports, and local declarations other
-- than visible local module types).
describeDec :: S.Set VName -> Dec -> Maybe (DocM Html)
describeDec _ (ValDec vb) =
  Just $
    describeGeneric (valBindName vb) (valBindWhat vb) (valBindDoc vb) $ \name -> do
      (lhs, mhs, rhs) <- valBindHtml name vb
      pure $ lhs <> mhs <> ": " <> rhs
describeDec _ (TypeDec vb) =
  Just $
    describeGeneric (typeAlias vb) IndexType (typeDoc vb) (`typeBindHtml` vb)
describeDec _ (ModTypeDec (ModTypeBind name se doc _)) =
  Just $
    describeGenericMod name IndexModuleType se doc $ \name' ->
      pure $ keyword "module type " <> name'
describeDec _ (ModDec mb) =
  Just $
    describeGeneric (modName mb) IndexModule (modDoc mb) $ \name' ->
      pure $ keyword "module " <> name'
describeDec _ OpenDec {} = Nothing
describeDec visible (LocalDec (ModTypeDec (ModTypeBind name se doc _)) _)
  | name `S.member` visible =
      Just $
        describeGenericMod name IndexModuleType se doc $ \name' ->
          pure $ keyword "local module type " <> name'
describeDec _ LocalDec {} = Nothing
describeDec _ ImportDec {} = Nothing

-- | A binding without parameters whose return type is of order zero is
-- indexed as a value; anything else as a function.
valBindWhat :: ValBind -> IndexWhat
valBindWhat vb
  | null (valBindParams vb),
    RetType _ t <- unInfo $ valBindRetType vb,
    orderZero t =
      IndexValue
  | otherwise =
      IndexFunction

-- | The description list for the specs of a module type.
describeSpecs :: [Spec] -> DocM Html
describeSpecs specs =
  H.dl . mconcat <$> mapM describeSpec specs

-- | The description of a single spec.  A value spec whose type is
-- written as an arrow is indexed as a function.
describeSpec :: Spec -> DocM Html
describeSpec (ValSpec name tparams t _ doc _) =
  describeGeneric name what doc $ \name' -> do
    tparams' <- mconcat <$> mapM (fmap (" " <>) . typeParamHtml) tparams
    t' <- noLink (map typeParamName tparams) $ typeExpHtml t
    pure $ keyword "val " <> name' <> tparams' <> ": " <> t'
  where
    what =
      case t of
        TEArrow {} -> IndexFunction
        _ -> IndexValue
describeSpec (TypeAbbrSpec vb) =
  describeGeneric (typeAlias vb) IndexType (typeDoc vb) (`typeBindHtml` vb)
describeSpec (TypeSpec l name tparams doc _) =
  describeGeneric name IndexType doc $ \name' -> typeAbbrevHtml l name' tparams
describeSpec (ModSpec name se doc _) =
  describeGenericMod name IndexModule se doc $ \name' ->
    case se of
      ModTypeSpecs {} -> pure $ keyword "module " <> name'
      _ -> do
        se' <- synopsisModTypeExp se
        pure $ keyword "module " <> name' <> ": " <> se'
describeSpec (IncludeSpec sig _) = do
  sig' <- synopsisModTypeExp sig
  -- Include specs carry no doc comment, so this is always empty.
  doc' <- docHtml Nothing
  let decl_header =
        (H.dt ! A.class_ "desc_header") $
          (H.span ! A.class_ "synopsis_link") mempty
            <> keyword "include "
            <> sig'
      decl_doc = H.dd ! A.class_ "desc_doc" $ doc'
  pure $ decl_header <> decl_doc

-- | @renderFiles important_imports imports@ produces HTML files
-- documenting the type-checked program @imports@, with the files in
-- @important_imports@ considered most important.  The HTML files must
-- be written to the specific locations indicated in the return value,
-- or the relative links will be wrong.
renderFiles :: [ImportName] -> Imports -> ([(FilePath, Html)], Warnings)
renderFiles important_imports imports = runWriter $ do
  -- One page per import; the outer writer collects the documented
  -- names across all files so the index can be built afterwards.
  (import_pages, documented) <- runWriterT $
    forM imports $ \(current, fm) -> do
      let ctx =
            Context
              { ctxCurrent = makeRelative "/" $ includeToFilePath current,
                ctxFileMod = fm,
                ctxImports = imports,
                ctxNoLink = mempty,
                ctxFileMap = file_map,
                ctxVisibleMTys = progModuleTypes $ fileProg fm
              }
      flip runReaderT ctx $ do
        (first_paragraph, maybe_abstract, maybe_sections) <-
          headerDoc $ fileProg fm

        synopsis <- (H.div ! A.id "module") <$> synopsisDecs (progDecs $ fileProg fm)

        description <- describeDecs $ progDecs $ fileProg fm

        let page =
              addBoilerplateWithNav
                important_imports
                imports
                ("doc" </> includeToFilePath current)
                (includeToString current)
                $ H.main
                $ maybe_abstract
                  <> selfLink "synopsis" (H.h2 "Synopsis")
                  <> (H.div ! A.id "overview") synopsis
                  <> selfLink "description" (H.h2 "Description")
                  <> description
                  <> maybe_sections

        -- The first paragraph is kept for the contents page.
        pure (current, (H.docTypeHtml ! A.lang "en" $ page, first_paragraph))

  pure $
    [ ("index.html", contentsPage important_imports $ map (fmap snd) import_pages),
      ("doc-index.html", indexPage important_imports imports documented file_map)
    ]
      ++ map (importHtml *** fst) import_pages
  where
    file_map = vnameToFileMap imports
    importHtml import_name =
      "doc" </> makeRelative "/" (fromString (includeToString import_name)) -<.> "html"
futhark-0.25.27/src/Futhark/Error.hs000066400000000000000000000053631475065116200171710ustar00rootroot00000000000000-- | Futhark error definitions.
module Futhark.Error
  ( CompilerError (..),
    prettyCompilerError,
    ErrorClass (..),
    externalError,
    externalErrorS,
    InternalError (..),
    compilerBug,
    compilerBugS,
    compilerLimitation,
    compilerLimitationS,
    internalErrorS,
  )
where

import Control.Exception
import Control.Monad.Error.Class
import Data.Text qualified as T
import Futhark.Util.Pretty
import Prettyprinter.Render.Text (renderStrict)

-- | There are two classes of internal errors: actual bugs, and
-- implementation limitations.  The latter are already known and need
-- not be reported.
data ErrorClass
  = CompilerBug
  | CompilerLimitation
  deriving (Eq, Ord, Show)

-- | A compiler error.
data CompilerError
  = -- | An error that happened due to something the user did, such as
    -- provide incorrect code or options.
    ExternalError (Doc AnsiStyle)
  | -- | An internal compiler error.  The second text is extra data
    -- for debugging, which can be written to a file.
    InternalError T.Text T.Text ErrorClass

-- | Print an error intended for human consumption.  For internal
-- errors only the first (message) text is shown.
prettyCompilerError :: CompilerError -> Doc AnsiStyle
prettyCompilerError (ExternalError e) = e
prettyCompilerError (InternalError s _ _) = pretty s

-- | Raise an 'ExternalError' based on a prettyprinting result.
externalError :: (MonadError CompilerError m) => Doc AnsiStyle -> m a
externalError = throwError . ExternalError

-- | Raise an 'ExternalError' based on a string.
externalErrorS :: (MonadError CompilerError m) => String -> m a
externalErrorS = externalError . pretty

-- | Raise an v'InternalError' based on a prettyprinting result.  The
-- document is rendered (without annotations) as the debugging data,
-- and the error is classified as a 'CompilerBug'.
internalErrorS :: (MonadError CompilerError m) => String -> Doc AnsiStyle -> m a
internalErrorS s d =
  throwError $ InternalError (T.pack s) (p d) CompilerBug
  where
    p = renderStrict . layoutSmart defaultLayoutOptions

-- | An error that is not the user's fault, but a bug (or limitation)
-- in the compiler.  Compiler passes should only ever report this
-- error - any problems after the type checker are *our* fault, not
-- the user's.  These are generally thrown as IO exceptions, and caught
-- at the top level.
data InternalError = Error ErrorClass T.Text
  deriving (Show)

instance Exception InternalError

-- | Throw an t'InternalError' that is a 'CompilerBug'.
compilerBug :: T.Text -> a
compilerBug = throw . Error CompilerBug

-- | Throw an t'InternalError' that is a 'CompilerLimitation'.
compilerLimitation :: T.Text -> a
compilerLimitation = throw . Error CompilerLimitation

-- | Like 'compilerBug', but with a 'String'.
compilerBugS :: String -> a
compilerBugS = compilerBug . T.pack

-- | Like 'compilerLimitation', but with a 'String'.
compilerLimitationS :: String -> a
compilerLimitationS = compilerLimitation . T.pack
futhark-0.25.27/src/Futhark/Fmt/000077500000000000000000000000001475065116200162635ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Fmt/Monad.hs000066400000000000000000000325611475065116200176640ustar00rootroot00000000000000module Futhark.Fmt.Monad
  ( Fmt,
    -- functions for building fmt
    nil,
    nest,
    stdNest,
    text,
    space,
    hardline,
    line,
    sep,
    brackets,
    braces,
    parens,
    (<|>),
    (<+>),
    (</>),
    (<:/>),
    hardIndent,
    indent,
    hardStdIndent,
    stdIndent,
    FmtM,
    popComments,
    runFormat,
    align,
    fmtCopyLoc,
    comment,
    sepArgs,
    localLayout,
    localLayoutList,
    sepDecs,
    fmtByLayout,
    addComments,
    sepComments,
    sepLineComments,
    sepLine,

    -- * Formatting styles
    commentStyle,
    constantStyle,
    keywordStyle,
    bindingStyle,
    infixStyle,
  )
where

import Control.Monad (liftM2)
import Control.Monad.Reader
  ( MonadReader (..),
    ReaderT (..),
  )
import Control.Monad.State
  ( MonadState (..),
    State,
    evalState,
    gets,
    modify,
  )
import Data.ByteString qualified as BS
import Data.List.NonEmpty qualified as NE
import Data.Loc (Loc (..), Located (..), locStart, posCoff, posLine)
import Data.Maybe (fromMaybe)
import Data.String
import Data.Text qualified as T
import Data.Text.Encoding qualified as T
import Language.Futhark.Parser.Monad (Comment (..))
import Prettyprinter qualified as P
import Prettyprinter.Render.Terminal
  ( AnsiStyle,
    Color (..),
    bold,
    color,
    colorDull,
    italicized,
  )

-- These are right associative since we want to evaluate the monadic
-- computation from left to right: the leftmost expression is printed
-- first, and our monad checks whether a comment should be printed
-- before it.
infixr 6 <:/>

infixr 6 <+>

infixr 6 </>

infixr 4 <|>

-- | A formatting computation that produces an annotated document.
type Fmt = FmtM (P.Doc AnsiStyle)

-- | Runs both computations left to right and concatenates their
-- documents.
instance Semigroup Fmt where
  (<>) = liftM2 (<>)

instance Monoid Fmt where
  mempty = nil

-- | String literals become text; literals that appear in the keyword
-- list below are rendered with 'keywordStyle'.
instance IsString Fmt where
  fromString s = text style s'
    where
      s' = fromString s
      style = if s' `elem` keywords then keywordStyle else mempty
      keywords =
        [ "true",
          "false",
          "if",
          "then",
          "else",
          "def",
          "let",
          "loop",
          "in",
          "val",
          "for",
          "do",
          "with",
          "local",
          "open",
          "include",
          "import",
          "type",
          "entry",
          "module",
          "while",
          "assert",
          "match",
          "case"
        ]

commentStyle, keywordStyle, constantStyle, bindingStyle, infixStyle :: AnsiStyle
commentStyle = italicized
keywordStyle = color Magenta <> bold
constantStyle = color Green
bindingStyle = colorDull Blue
infixStyle = colorDull Cyan

-- | Choose a format based on the layout of @a@: if @a@ spans a single
-- line, use format @s@; if it spans multiple lines, use format @m@.
-- In single-line layout @s@ is always used.
fmtByLayout :: (Located a) => a -> Fmt -> Fmt -> Fmt
fmtByLayout a s m =
  s
    <|> ( case lineLayout a of
            Just SingleLine -> s
            _any -> m
        )

-- | Determine the layout of @a@ and run the formatter in that layout.
-- It determines this by checking whether the location of @a@ spans two
-- or more lines.  If @a@ has no location, the current layout is kept.
localLayout :: (Located a) => a -> FmtM b -> FmtM b
localLayout a = local (\lo -> fromMaybe lo $ lineLayout a)

-- | Determine the layout of @[a]@ and run the formatter in it, but only
-- when the current layout is multi-line; in single-line layout the
-- formatter runs unchanged.  The list is single-line if the start and
-- end lines of all its elements coincide.
localLayoutList :: (Located a) => [a] -> FmtM b -> FmtM b
localLayoutList a m = do
  lo <- ask
  case lo of
    MultiLine -> local (const $ fromMaybe lo $ lineLayoutList a) m
    SingleLine -> m

-- | This function uses the location of @a@ and prepends all pending
-- comments whose location is less than the location of @a@.  It formats
-- @b@ according to whether @a@ is single-line or multi-line, using
-- 'localLayout'.  It currently does not handle trailing comments
-- perfectly.  See tests/fmt/traillingComments*.fut.
addComments :: (Located a) => a -> Fmt -> Fmt
addComments a b = localLayout a $ do
  c <- fmtComments a
  f <- b
  pure $ c <> f

-- | Format @a@ with @fmt@, preceded by every pending comment located
-- before @floc a@.  Each comment is preceded by a hardline unless the
-- last output was already a line.
prependComments :: (a -> Loc) -> (a -> Fmt) -> a -> Fmt
prependComments floc fmt a = do
  fmcs <- fcs
  f <- fmt a
  pure $ fromMaybe mempty fmcs <> f
  where
    fcs = do
      s <- get
      case comments s of
        c : cs | floc a /= NoLoc && floc a > locOf c -> do
          put $ s {comments = cs}
          mcs <- fcs
          pre' <- pre
          pure $ Just $ pre' <> fmtNoLine c <> maybe mempty (P.line <>) mcs
        _any -> pure Nothing
    fmtNoLine = P.pretty . commentText
    pre = do
      lastO <- gets lastOutput
      case lastO of
        Nothing -> nil
        Just Line -> nil
        Just _ -> modify (\s -> s {lastOutput = Just Line}) >> hardline

-- | The internal state of the formatter monad 'FmtM'.
data FmtState = FmtState
  { -- | The comments that will be inserted, ordered by increasing location.
    comments :: [Comment],
    -- | The original source file that is being formatted.
    file :: BS.ByteString,
    -- | Keeps track of what type the last output was.
    lastOutput :: !(Maybe LastOutput)
  }
  deriving (Show, Eq, Ord)

-- | A data type to describe the last output used during formatting.
data LastOutput = Line | Space | Text | Comm deriving (Show, Eq, Ord)

-- | A data type to describe the layout the formatter is currently using.
data Layout = MultiLine | SingleLine deriving (Show, Eq)

-- | The format monad used to keep track of comments and layout.  It is a
-- combination of a reader and a state monad.  The state monad handles
-- the pending comments and reading from the input file, while the
-- reader monad propagates the current layout.
type FmtM a = ReaderT Layout (State FmtState) a

-- | Format a single comment from its text (via 'comment').
fmtComment :: Comment -> Fmt
fmtComment c = comment $ commentText c

-- | Format a list of comments. Comments that were on consecutive lines
-- are emitted back to back. Between comments that were not on
-- consecutive lines, an extra 'hardline' keeps the gap.
fmtCommentList :: [Comment] -> Fmt
fmtCommentList [] = nil
fmtCommentList (c : cs) =
  fst $ foldl f (fmtComment c, locOf c) cs
  where
    -- The accumulator carries the location of the previous comment, so
    -- each comment is compared against its predecessor.
    f (acc, loc) c' =
      if consecutive loc (locOf c')
        then (acc <> fmtComment c', locOf c')
        else (acc <> hardline <> fmtComment c', locOf c')

-- | Is there at least one pending comment located before @a@? Always
-- 'False' if @a@ has no location.
hasComment :: (Located a) => a -> FmtM Bool
hasComment a =
  gets $ not . null . takeWhile relevant . comments
  where
    relevant c = locOf a /= NoLoc && locOf a > locOf c

-- | Prepends comments: emit, and remove from the state, every pending
-- comment located before @a@. If the comments were not on the line(s)
-- directly preceding @a@, an extra 'hardline' preserves the gap.
fmtComments :: (Located a) => a -> Fmt
fmtComments a = do
  (here, later) <- gets $ span relevant . comments
  if null here
    then pure mempty
    else do
      modify $ \s -> s {comments = later}
      fmtCommentList here
        <> if consecutive (locOf here) (locOf a) then nil else hardline
  where
    relevant c = locOf a /= NoLoc && locOf a > locOf c

-- | Determines the layout of @a@ by checking whether it spans a single
-- line or two or more lines. Returns 'Nothing' if @a@ has no location.
lineLayout :: (Located a) => a -> Maybe Layout
lineLayout a =
  case locOf a of
    Loc start end ->
      if posLine start == posLine end
        then Just SingleLine
        else Just MultiLine
    NoLoc -> Nothing -- error "Formatting term without location."

-- | Determines the layout of @[a]@ by checking whether all start and end
-- line numbers of its elements coincide. Elements without a location are
-- ignored; 'Nothing' is returned if no element has one.
lineLayoutList :: (Located a) => [a] -> Maybe Layout
lineLayoutList as =
  case concatMap auxiliary as of
    [] -> Nothing
    (t : ts) | any (/= t) ts -> Just MultiLine
    _ -> Just SingleLine
  where
    auxiliary a =
      case locOf a of
        Loc start end -> [posLine start, posLine end]
        NoLoc -> [] -- error "Formatting term without location"

-- | Emit all remaining comments and clear them from the state. This is
-- intended for the end of the file; a line break separates them from
-- earlier output, if there was any.
popComments :: Fmt
popComments = do
  cs <- gets comments
  modify (\s -> s {comments = []})
  lastO <- gets lastOutput
  case lastO of
    Nothing -> fmtCommentList cs -- Happens when file has only comments.
    _
      | not $ null cs -> hardline <> fmtCommentList cs
      | otherwise -> nil

-- | Using the location of @a@, copy the corresponding segment of text
-- from the original file verbatim, with the given style. Calls 'error'
-- if @a@ has no location or the segment is not valid UTF-8.
fmtCopyLoc :: (Located a) => AnsiStyle -> a -> Fmt
fmtCopyLoc style a = do
  f <- gets file
  case locOf a of
    Loc sPos ePos ->
      let sOff = posCoff sPos
          eOff = posCoff ePos
       in case T.decodeUtf8' $ BS.take (eOff - sOff) $ BS.drop sOff f of
            Left err -> error $ show err
            Right lit -> text style lit
    NoLoc -> error "Formatting term without location"

-- | Given a formatter @FmtM a@, a sequence of comments ordered by
-- increasing location, and the original file contents, run the
-- formatter and produce @a@. Formatting starts in the multi-line layout.
runFormat :: FmtM a -> [Comment] -> T.Text -> a
runFormat format cs file = evalState (runReaderT format e) s
  where
    s =
      FmtState
        { comments = cs,
          file = T.encodeUtf8 file,
          lastOutput = Nothing
        }
    e = MultiLine

-- | An empty output.
nil :: Fmt
nil = pure mempty

-- | In a multi-line layout, nest the document by @i@. Lines following a
-- line break are indented, as with 'P.nest'. In a single-line layout the
-- document is left unchanged.
nest :: Int -> Fmt -> Fmt
nest i a = a <|> (P.nest i <$> a)

-- | A space.
space :: Fmt
space = modify (\s -> s {lastOutput = Just Space}) >> pure P.space

-- | Forces a line to be used regardless of layout; this should
-- ideally not be used.
hardline :: Fmt
hardline = do
  modify $ \s -> s {lastOutput = Just Line}
  pure P.line

-- | A space in a single-line layout, a line break in a multi-line layout.
line :: Fmt
line = space <|> hardline

-- | In a single-line layout, separates elements by @s@ followed by a
-- space. In a multi-line layout, separates them by a line break followed
-- by @s@.
sepLine :: Fmt -> [Fmt] -> Fmt
sepLine s = sep (s <> space <|> hardline <> s)

-- | A comment, in 'commentStyle', with trailing whitespace stripped and
-- followed by a line break.
comment :: T.Text -> Fmt
comment txt = do
  modify $ \st -> st {lastOutput = Just Line}
  let body = P.pretty $ T.stripEnd txt
  pure $ P.annotate commentStyle body <> P.line

-- | Join the formatters with @s@ between each adjacent pair. The empty
-- list yields 'nil'.
sep :: Fmt -> [Fmt] -> Fmt
sep _ [] = nil
sep s (a : as) = foldl (\acc x -> acc <> s <> x) a as

-- | Like 'sep', but before each element after the first, any pending
-- comments located before @floc@ of that element are emitted (ahead of
-- the separator); see 'prependComments'.
sepComments :: (a -> Loc) -> (a -> Fmt) -> Fmt -> [a] -> Fmt
sepComments _ _ _ [] = nil
sepComments floc fmt s (a : as) = foldl step (fmt a) as
  where
    step acc x = acc <> prependComments floc ((s <>) . fmt) x

-- | 'sepComments' with the separator placement of 'sepLine': @s@ then a
-- space on one line, or a line break then @s@ in a multi-line layout.
sepLineComments :: (a -> Loc) -> (a -> Fmt) -> Fmt -> [a] -> Fmt
sepLineComments floc fmt s = sepComments floc fmt separator
  where
    separator = (s <> space) <|> (hardline <> s)

-- | Formats function arguments. Arguments are separated by lines in a
-- multi-line layout and by spaces in a single-line layout. Only the
-- start positions of the arguments decide whether they are aligned. The
-- common case is thus handled specially: all arguments start on one
-- line, but the last one may continue onto following lines.
sepArgs :: (Located a) => (a -> Fmt) -> NE.NonEmpty a -> Fmt
sepArgs fmt ls = localLayout starts $ alignIfMulti $ sep line args
  where
    elems = NE.toList ls
    starts = map (locStart . locOf) elems
    args = [localLayout x (fmt x) | x <- elems]
    alignIfMulti
      | lineLayout starts == Just SingleLine = id
      | otherwise = align

-- | 'nest' with the standard width of two spaces.
stdNest :: Fmt -> Fmt
stdNest = nest 2

-- | Aligns subsequent lines with the current column.
align :: Fmt -> Fmt
align a = do
  -- XXX? It is unclear whether marking the last output as a line is right.
  modify $ \st -> st {lastOutput = Just Line}
  fmap P.align a

-- | Indents everything by @i@ regardless of layout; should never be used.
hardIndent :: Int -> Fmt -> Fmt
hardIndent i = fmap (P.indent i)

-- | Indents by @i@ in a multi-line layout; does nothing in a single-line
-- layout.
indent :: Int -> Fmt -> Fmt
indent i doc = doc <|> hardIndent i doc

-- | 'hardIndent' with the standard width of two.
hardStdIndent :: Fmt -> Fmt
hardStdIndent = hardIndent 2

-- | 'indent' with the standard width of two.
stdIndent :: Fmt -> Fmt
stdIndent = indent 2

-- | Creates a piece of styled text; it should not contain any newlines.
text :: AnsiStyle -> T.Text -> Fmt
text style t = do
  modify (\s -> s {lastOutput = Just Text})
  pure $ P.annotate style $ P.pretty t

-- | Adds brackets.
brackets :: Fmt -> Fmt
brackets a = "[" <> a <> "]"

-- | Adds braces.
braces :: Fmt -> Fmt
braces a = "{" <> a <> "}"

-- | Adds parentheses.
parens :: Fmt -> Fmt
parens a = "(" <> a <> ")"

-- | In a single-line layout, concatenate directly; in a multi-line
-- layout, concatenate with a line break in between.
(<:/>) :: Fmt -> Fmt -> Fmt
a <:/> b = a <> (nil <|> hardline) <> b

-- | Concatenate with a space in between.
(<+>) :: Fmt -> Fmt -> Fmt
a <+> b = a <> space <> b

-- | Concatenate with a space in a single-line layout, and with a line
-- break in a multi-line layout.
(</>) :: Fmt -> Fmt -> Fmt
a </> b = a <> line <> b

-- | In a single-line layout choose @a@; in a multi-line layout choose @b@.
(<|>) :: Fmt -> Fmt -> Fmt
a <|> b = do
  lo <- ask
  if lo == SingleLine then a else b

-- | Does the second location start on the line right after the first
-- one ends? 'False' if either location is missing.
consecutive :: Loc -> Loc -> Bool
consecutive (Loc _ end) (Loc beg _) = posLine end + 1 == posLine beg
consecutive _ _ = False

-- | In a single-line layout, separate by spaces. In a multi-line layout,
-- separate two neighbouring elements by a single line if both are
-- single-line, were on consecutive lines, and have no comment between
-- them. Otherwise separate them by two lines.
sepDecs :: (Located a) => (a -> Fmt) -> [a] -> Fmt
sepDecs _ [] = nil
sepDecs fmt decs@(x : xs) =
  sep space (map fmt decs) <|> (fmt x <> auxiliary x xs)
  where
    auxiliary _ [] = nil
    auxiliary prev (y : ys) = p <> fmt y <> auxiliary y ys
      where
        -- Choose one or two line breaks between @prev@ and @y@.
        p = do
          commented <- hasComment y
          case (commented, lineLayout y, lineLayout prev) of
            (False, Just SingleLine, Just SingleLine)
              | consecutive (locOf prev) (locOf y) -> hardline
            _any -> hardline <> hardline
futhark-0.25.27/src/Futhark/Fmt/Printer.hs000066400000000000000000000525761475065116200202610ustar00rootroot00000000000000module Futhark.Fmt.Printer
  ( fmtToText,
    fmtToDoc,
  )
where

import Data.Bifunctor (second)
import Data.Foldable
import Data.Loc (locStart)
import Data.Text qualified as T
import Futhark.Fmt.Monad
import Futhark.Util (showText)
import Futhark.Util.Pretty
  ( AnsiStyle,
    Doc,
    Pretty,
    docText,
  )
import Language.Futhark
import Language.Futhark.Parser
  ( SyntaxError (..),
    parseFutharkWithComments,
  )

-- | Format @a@ followed by @b@. If @l@ spans a single line, they are
-- separated by a space. Otherwise @b@ goes on the next line, indented
-- by two and aligned.
lineIndent :: (Located a) => a -> Fmt -> Fmt -> Fmt
lineIndent l a b = fmtByLayout l (a <+> b) (a </> hardStdIndent (align b))

-- | Format a name as text in the given style.
fmtName :: AnsiStyle -> Name -> Fmt
fmtName style = text style . nameToText

-- | Format a name at its binding site; operator names are parenthesised.
fmtBoundName :: Name -> Fmt
fmtBoundName name
  | operatorName name = parens $ fmtName bindingStyle name
  | otherwise = fmtName bindingStyle name

-- | Format anything with a 'Pretty' instance as unstyled text.
fmtPretty :: (Pretty a) => a -> Fmt
fmtPretty = text mempty . prettyText

-- | Syntax that can be formatted.
class Format a where
  fmt :: a -> Fmt

-- | A documentation comment: the first line is prefixed with @-- |@ and
-- the remaining lines with @--@.
instance Format DocComment where
  fmt (DocComment x loc) =
    addComments loc $ sep nil $ prefixes (T.lines x)
    where
      prefixes [] = []
      prefixes (l : ls) = comment (prefix "-- |" l) : map (comment . prefix "--") ls
      prefix p s = if T.null s then p else p <> " " <> s -- Avoid trailing whitespace.
instance Format (Maybe DocComment) where
  fmt = maybe nil fmt

-- | Format a parameter type of a function type; a named parameter is
-- written as @(n: t)@.
fmtParamType :: Maybe Name -> UncheckedTypeExp -> Fmt
fmtParamType (Just n) te =
  parens $ fmtName mempty n <> ":" <+> fmt te
fmtParamType Nothing te = fmt te

-- | Format a single sum type constructor with its payload types.
fmtSumTypeConstr :: (Name, [UncheckedTypeExp]) -> Fmt
fmtSumTypeConstr (name, []) = "#" <> fmtName mempty name
fmtSumTypeConstr (name, fs) = "#" <> fmtName mempty name <+> sep space (map fmt fs)

instance Format Name where
  fmt = fmtName mempty

-- Format a tuple-like thing (expression, pattern, type).
fmtTuple :: (Located a) => [Fmt] -> a -> Fmt
fmtTuple xs loc =
  addComments loc $ fmtByLayout loc singleLine multiLine
  where
    singleLine = parens $ sep ", " xs
    multiLine = align $ "(" <+> sep (line <> "," <> space) xs </> ")"

-- Format a record-like thing (expression, pattern, type).
fmtRecord :: (Located a) => [Fmt] -> a -> Fmt
fmtRecord xs loc =
  addComments loc $ fmtByLayout loc singleLine multiLine
  where
    singleLine = braces $ sep ", " xs
    multiLine = align $ "{" <+> sep (line <> "," <> space) xs </> "}"

-- Format an array-like thing.
fmtArray :: (Located a) => [Fmt] -> a -> Fmt
fmtArray xs loc =
  addComments loc $ fmtByLayout loc singleLine multiLine
  where
    singleLine = brackets $ sep ", " xs
    multiLine = align $ "[" <+> sep (line <> "," <> space) xs </> "]"

instance Format UncheckedTypeExp where
  fmt (TEVar v loc) = addComments loc $ fmtQualName v
  fmt (TETuple ts loc) =
    fmtTuple (map (align . fmt) ts) loc
  fmt (TEParens te loc) = addComments loc $ parens $ fmt te
  fmt (TERecord fs loc) =
    fmtRecord (map fmtFieldType fs) loc
    where
      fmtFieldType (L _ name', t) =
        fmtName mempty name' <> ":" <+> align (fmt t)
  fmt (TEArray se te loc) = addComments loc $ fmt se <> fmt te
  fmt (TEUnique te loc) = addComments loc $ "*" <> fmt te
  fmt (TEApply te tArgE loc) = addComments loc $ fmt te <+> fmt tArgE
  fmt (TEArrow name te0 te1 loc) =
    addComments loc $
      fmtParamType name te0 </> "->" <+> case te1 of
        TEArrow {} -> fmt te1
        _ -> align (fmt te1)
  fmt (TESum tes loc) =
    -- Comments can not be inserted correctly here because names do not
    -- have a location.
    addComments loc $ fmtByLayout loc singleLine multiLine
    where
      singleLine = sep " | " $ map fmtSumTypeConstr tes
      multiLine = sep line $ zipWith prefix [0 :: Int ..] tes
      -- The first constructor is padded so it aligns with the "| " of
      -- the following ones.
      prefix 0 te = "  " <> fmtSumTypeConstr te
      prefix _ te = "| " <> fmtSumTypeConstr te
  fmt (TEDim dims te loc) =
    addComments loc $ "?" <> dims' <> "." <> fmt te
    where
      dims' = sep nil $ map (brackets . fmt) dims

instance Format (TypeArgExp UncheckedExp Name) where
  fmt (TypeArgExpSize se) = fmt se
  fmt (TypeArgExpType te) = fmt te

instance Format UncheckedTypeBind where
  fmt (TypeBind name l ps e NoInfo dc loc) =
    addComments loc $
      fmt dc
        <> "type"
        <> fmt l
        <+> fmtName bindingStyle name
        <> (if null ps then nil else space)
        <> localLayoutList ps (align $ sep line $ map fmt ps)
        <+> "="
        </> stdIndent (fmt e)

instance Format (AttrAtom a) where
  fmt (AtomName name) = fmt name
  fmt (AtomInt int) = text constantStyle $ prettyText int

instance Format (AttrInfo a) where
  fmt attr = "#" <> brackets (fmtAttrInfo attr)
    where
      fmtAttrInfo (AttrAtom attr' loc) = addComments loc $ fmt attr'
      fmtAttrInfo (AttrComp name attrs loc) =
        addComments loc $
          fmt name <> parens (sep "," $ map fmtAttrInfo attrs)

instance Format Liftedness where
  fmt Unlifted = nil
  fmt SizeLifted = "~"
  fmt Lifted = "^"

instance Format UncheckedTypeParam where
  fmt (TypeParamDim name loc) =
    addComments loc $ brackets $ fmtName bindingStyle name
  fmt (TypeParamType l name loc) =
    addComments loc $ "'" <> fmt l <> fmtName bindingStyle name

instance Format (UncheckedPat t) where
  fmt (TuplePat pats loc) =
    fmtTuple (map fmt pats) loc
  fmt (RecordPat pats loc) =
    fmtRecord (map fmtFieldPat pats) loc
    where
      -- We detect the implicit form by whether the name and the 't'
      -- have the same location.
      fmtFieldPat (L nameloc name, t)
        | locOf nameloc == locOf t = fmt name
        | otherwise =
            lineIndent [nameloc, locOf t] (fmt name <+> "=") (fmt t)
  fmt (PatParens pat loc) =
    addComments loc $ "(" <> align (fmt pat) <:/> ")"
  fmt (Id name _ loc) = addComments loc $ fmtBoundName name
  fmt (Wildcard _t loc) = addComments loc "_"
  fmt (PatAscription pat t loc) = addComments loc $ fmt pat <> ":" <+> fmt t
  fmt (PatLit _e _ loc) = addComments loc $ fmtCopyLoc constantStyle loc
  fmt (PatConstr n _ [] loc) = addComments loc $ "#" <> fmt n
  fmt (PatConstr n _ pats loc) =
    addComments loc $ "#" <> fmt n </> align (sep line (map fmt pats))
  fmt (PatAttr attr pat loc) = addComments loc $ fmt attr <+> fmt pat

instance Format (FieldBase NoInfo Name) where
  fmt (RecordFieldExplicit (L nameloc name) e loc) =
    addComments loc $
      lineIndent [nameloc, locOf e] (fmt name <+> "=") (stdIndent (fmt e))
  fmt (RecordFieldImplicit (L _ name) _ loc) = addComments loc $ fmt name

instance Format UncheckedDimIndex where
  fmt (DimFix e) = fmt e
  fmt (DimSlice i j (Just s)) =
    maybe nil fmt i <> ":" <> maybe nil fmt j <> ":" <> fmt s
  fmt (DimSlice i (Just j) s) =
    maybe nil fmt i <> ":" <> fmt j <> maybe nil ((":" <>) . fmt) s
  fmt (DimSlice i Nothing Nothing) =
    maybe nil fmt i <> ":"

-- | Is this the name of an operator, i.e. does it start with an operator
-- character? The empty name is not an operator. This uses 'T.uncons'
-- rather than the partial 'T.head', so it never crashes.
operatorName :: Name -> Bool
operatorName = maybe False ((`elem` opchars) . fst) . T.uncons . nameToText
  where
    opchars :: String
    opchars = "+-*/%=!><|&^."
instance Format PrimValue where
  fmt pv =
    text constantStyle $ case pv of
      UnsignedValue (Int8Value v) ->
        showText (fromIntegral v :: Word8) <> "u8"
      UnsignedValue (Int16Value v) ->
        showText (fromIntegral v :: Word16) <> "u16"
      UnsignedValue (Int32Value v) ->
        showText (fromIntegral v :: Word32) <> "u32"
      UnsignedValue (Int64Value v) ->
        showText (fromIntegral v :: Word64) <> "u64"
      SignedValue v -> prettyText v
      BoolValue True -> "true"
      BoolValue False -> "false"
      FloatValue v -> prettyText v

-- | Peel off a chain of 'RecordUpdate'/'Update' expressions. Returns the
-- innermost source expression and the formatted (path, value) pairs,
-- innermost update first.
updates ::
  UncheckedExp ->
  (UncheckedExp, [(Fmt, Fmt)])
updates (RecordUpdate src fs ve _ _) = second (++ [(fs', ve')]) $ updates src
  where
    fs' = sep "." $ fmt <$> fs
    ve' = fmt ve
updates (Update src is ve _) = second (++ [(is', ve')]) $ updates src
  where
    is' = brackets $ sep ("," <> space) $ map fmt is
    ve' = fmt ve
updates e = (e, [])

-- | Format an 'Update' or 'RecordUpdate' as its root expression followed
-- by aligned @with path = value@ clauses.
fmtUpdate :: UncheckedExp -> Fmt
fmtUpdate e =
  -- Special case multiple chained Updates/RecordUpdates.
  let (root, us) = updates e
      loc = srclocOf e
   in addComments loc . localLayout loc $
        fmt root <+> align (sep line (map fmtWith us))
  where
    fmtWith (fs', v) = "with" <+> fs' <+> "=" <+> v

instance Format UncheckedExp where
  fmt (Var name _ loc) = addComments loc $ fmtQualName name
  fmt (Hole _ loc) = addComments loc "???"
  fmt (Parens e loc) =
    addComments loc $ "(" <> align (fmt e) <> ")"
  fmt (QualParens (v, _qLoc) e loc) =
    addComments loc $
      fmtQualName v <> "." <> "(" <> align (fmt e) <> ")"
  fmt (Ascript e t loc) = addComments loc $ fmt e </> ":" <+> align (fmt t)
  fmt (Coerce e t _ loc) = addComments loc $ fmt e </> ":>" <+> align (fmt t)
  fmt (Literal _v loc) = addComments loc $ fmtCopyLoc constantStyle loc
  fmt (IntLit _v _ loc) = addComments loc $ fmtCopyLoc constantStyle loc
  fmt (FloatLit _v _ loc) = addComments loc $ fmtCopyLoc constantStyle loc
  fmt (TupLit es loc) = fmtTuple (map (align . fmt) es) loc
  fmt (RecordLit fs loc) = fmtRecord (map fmt fs) loc
  fmt (ArrayLit es _ loc) = fmtArray (map (align . fmt) es) loc
  fmt (StringLit _s loc) = addComments loc $ fmtCopyLoc constantStyle loc
  fmt (Project k e _ loc) = addComments loc $ fmt e <> "." <> fmt k
  fmt (Negate e loc) = addComments loc $ "-" <> fmt e
  fmt (Not e loc) = addComments loc $ "!" <> fmt e
  fmt e@Update {} = fmtUpdate e
  fmt e@RecordUpdate {} = fmtUpdate e
  fmt (Assert e1 e2 _ loc) =
    addComments loc $ "assert" <+> fmt e1 <+> fmt e2
  fmt (Lambda params body rettype _ loc) =
    addComments loc $
      "\\"
        <> sep space (map fmt params)
        <> maybe nil (((space <> ":") <+>) . fmt) rettype
        <+> stdNest ("->" </> fmt body)
  fmt (OpSection binop _ loc) =
    addComments loc $
      if operatorName (qualLeaf binop)
        then fmtQualName binop
        else parens $ "`" <> fmtQualName binop <> "`"
  fmt (OpSectionLeft binop _ x _ _ loc) =
    addComments loc $ parens $ fmt x <+> fmtBinOp binop
  fmt (OpSectionRight binop _ x _ _ loc) =
    addComments loc $ parens $ fmtBinOp binop <+> fmt x
  fmt (ProjectSection fields _ loc) =
    addComments loc $ parens $ "." <> sep "." (fmt <$> fields)
  fmt (IndexSection idxs _ loc) =
    addComments loc $ parens ("." <> idxs')
    where
      idxs' = brackets $ sep ("," <> space) $ map fmt idxs
  fmt (Constr n [] _ loc) = addComments loc $ "#" <> fmt n
  fmt (Constr n cs _ loc) =
    addComments loc $ "#" <> fmt n <+> align (sep line $ map fmt cs)
  fmt (Attr attr e loc) = addComments loc $ align $ fmt attr </> fmt e
  fmt (AppExp e _) = fmt e
  fmt (ArrayVal vs _ loc) = addComments loc $ fmtArray (map fmt vs) loc

-- | Format a possibly qualified name; operator names are parenthesised.
fmtQualName :: QualName Name -> Fmt
fmtQualName (QualName names name)
  | operatorName name = parens $ pre <> fmt name
  | otherwise = pre <> fmt name
  where
    pre =
      if null names
        then nil
        else sep "." (map fmt names) <> "."
instance Format UncheckedCase where
  fmt (CasePat p e loc) =
    addComments loc $ "case" <+> fmt p <+> "->" </> stdIndent (fmt e)

instance Format (AppExpBase NoInfo Name) where
  fmt (BinOp (bop, _) _ (x, _) (y, _) loc) =
    addComments loc $
      align (fmt x) </> fmtBinOp bop <+> align (fmt y)
  fmt (Match e cs loc) =
    addComments loc $
      "match" <+> fmt e </> sep line (map fmt $ toList cs)
  -- Need some way to omit the initial value expression when it is trivial.
  fmt (Loop sizeparams pat (LoopInitImplicit NoInfo) form loopbody loc) =
    addComments loc $
      ("loop" `op` sizeparams')
        <+> localLayout [locOf pat, formloc] (fmt pat </> fmt form <+> "do")
        </> stdIndent (fmt loopbody)
    where
      formloc = case form of
        For i _ -> locOf i
        ForIn fpat _ -> locOf fpat
        While e -> locOf e
      op = if null sizeparams then (<>) else (<+>)
      sizeparams' = sep nil $ brackets . fmtName bindingStyle . toName <$> sizeparams
  fmt (Loop sizeparams pat (LoopInitExplicit initexp) form loopbody loc) =
    addComments loc $
      ("loop" `op` sizeparams')
        <+> align
          ( lineIndent
              [locOf pat, locOf initexp]
              (fmt pat <+> "=")
              (align $ fmt initexp)
          )
        </> fmt form
        <+> "do"
        </> stdIndent (fmt loopbody)
    where
      op = if null sizeparams then (<>) else (<+>)
      sizeparams' = sep nil $ brackets . fmtName bindingStyle . toName <$> sizeparams
  fmt (Index e idxs loc) =
    addComments loc $ (fmt e <>) $ brackets $ sepLine "," $ map fmt idxs
  fmt (LetPat sizes pat e body loc) =
    addComments loc $
      lineIndent [locOf pat, locOf e] ("let" <+> sub <+> "=") (fmt e)
        </> letBody body
    where
      sizes' = sep nil $ map fmt sizes
      sub
        | null sizes = fmt pat
        | otherwise = sizes' <+> fmt pat
  fmt (LetFun fname (tparams, params, retdecl, _, e) body loc) =
    addComments loc $
      lineIndent
        e
        ("let" <+> fmtName bindingStyle fname <> sub <> retdecl' <> "=")
        (fmt e)
        </> letBody body
    where
      tparams' = sep space $ map fmt tparams
      params' = sep space $ map fmt params
      retdecl' =
        case retdecl of
          Just a -> ":" <+> fmt a <> space
          Nothing -> space
      sub
        | null tparams && null params = nil
        | null tparams = space <> params'
        | null params = space <> tparams'
        | otherwise = space <> tparams' <+> params'
  fmt (LetWith dest src idxs ve body loc)
    | dest == src =
        addComments loc $
          lineIndent ve ("let" <+> fmt dest <> idxs' <+> "=") (fmt ve)
            </> letBody body
    | otherwise =
        addComments loc $
          lineIndent
            ve
            ("let" <+> fmt dest <+> "=" <+> fmt src <+> "with" <+> idxs')
            (fmt ve)
            </> letBody body
    where
      idxs' = brackets $ sep ", " $ map fmt idxs
  fmt (Range start maybe_step end loc) =
    addComments loc $ fmt start <> step <> end'
    where
      end' =
        case end of
          DownToExclusive e -> "..>" <> fmt e
          ToInclusive e -> "..." <> fmt e
          UpToExclusive e -> "..<" <> fmt e
      step = maybe nil ((".." <>) . fmt) maybe_step
  fmt (If c t f loc) =
    addComments loc $
      "if"
        <+> fmt c
        </> "then"
        <+> align (fmt t)
        </> "else"
        <> case f of
          AppExp If {} _ -> space <> fmt f
          _ -> space <> align (fmt f)
  fmt (Apply f args loc) =
    addComments loc $ fmt f <+> fmt_args
    where
      fmt_args = sepArgs fmt $ fmap snd args

-- | Format the body of a let-binding. Nested lets are formatted directly,
-- so no @in@ is emitted between consecutive bindings; anything else is
-- prefixed with @in@.
letBody :: UncheckedExp -> Fmt
letBody body@(AppExp LetPat {} _) = fmt body
letBody body@(AppExp LetFun {} _) = fmt body
letBody body@(AppExp LetWith {} _) = fmt body
letBody body = addComments body $ "in" <+> align (fmt body)

instance Format (SizeBinder Name) where
  fmt (SizeBinder v loc) = addComments loc $ brackets $ fmtName bindingStyle v

instance Format (IdentBase NoInfo Name t) where
  fmt = fmtName bindingStyle . identName

instance Format (LoopFormBase NoInfo Name) where
  fmt (For i ubound) = "for" <+> fmt i <+> "<" <+> fmt ubound
  fmt (ForIn x e) = "for" <+> fmt x <+> "in" <+> fmt e
  fmt (While cond) = "while" <+> fmt cond

-- | Format a binary operator; backtick operators keep their backticks.
-- This should always be simplified by location.
fmtBinOp :: QualName Name -> Fmt
fmtBinOp bop =
  case leading of
    Backtick -> "`" <> fmtQualName bop <> "`"
    _ -> text infixStyle (prettyText bop)
  where
    leading = leadingOperator $ toName $ qualLeaf bop

instance Format UncheckedValBind where
  fmt (ValBind entry name retdecl _rettype tparams args body docs attrs loc) =
    addComments loc $
      fmt docs
        <> attrs'
        <> (fun <+> fmtBoundName name)
        <> sub
        <> retdecl'
        <> "="
        </> stdIndent (fmt body)
    where
      attrs' = if null attrs then nil else sep space (map fmt attrs) <> hardline
      tparams' = localLayoutList tparams $ align $ sep line $ map fmt tparams
      args' = localLayoutList args $ align $ sep line $ map fmt args
      retdecl' =
        case retdecl of
          Just a -> space <> ":" <+> fmt a <> space
          Nothing -> space
      sub
        | null tparams && null args = nil
        | null tparams = space <> args'
        | null args = space <> tparams'
        | otherwise =
            localLayout [locOf tparams, locOf args] $
              space <> align (tparams' </> args')
      fun =
        case entry of
          Just _ -> "entry"
          _any -> "def"

instance Format (SizeExp UncheckedExp) where
  fmt (SizeExp d loc) = addComments loc $ brackets $ fmt d
  fmt (SizeExpAny loc) = addComments loc $ brackets nil

instance Format UncheckedSpec where
  fmt (TypeAbbrSpec tpsig) = fmt tpsig
  fmt (TypeSpec l name ps doc loc) =
    addComments loc $ fmt doc <> "type" <> fmt l <+> sub
    where
      sub
        | null ps = fmtName bindingStyle name
        | otherwise = fmtName bindingStyle name </> align (sep line $ map fmt ps)
  fmt (ValSpec name ps te _ doc loc) =
    addComments loc $
      fmt doc
        <> "val"
        <+> sub
        <+> ":"
        </> stdIndent (fmt te)
    where
      sub
        | null ps = fmtName bindingStyle name
        | otherwise = fmtName bindingStyle name <+> align (sep space $ map fmt ps)
  fmt (ModSpec name mte doc loc) =
    addComments loc $
      fmt doc <> "module" <+> fmtName bindingStyle name <> ":" <+> fmt mte
  fmt (IncludeSpec mte loc) = addComments loc $ "include" <+> fmt mte

-- | Peel off a chain of 'ModTypeWith's. Returns the root module type and
-- the type refinements, outermost first.
typeWiths ::
  UncheckedModTypeExp ->
  (UncheckedModTypeExp, [TypeRefBase NoInfo Name])
typeWiths (ModTypeWith mte tr _) = second (tr :) $ typeWiths mte
typeWiths mte = (mte, [])

instance Format UncheckedModTypeExp where
  fmt (ModTypeVar v _ loc) = addComments loc $ fmtPretty v
  fmt (ModTypeParens mte loc) =
    addComments loc $ "(" <> align (fmt mte) <> ")"
  fmt (ModTypeSpecs sbs loc) =
    addComments loc $ "{" <:/> stdIndent (sepDecs fmt sbs) <:/> "}"
  fmt (ModTypeWith mte tr loc) =
    -- Special case multiple chained ModTypeWiths.
    let (root, withs) = typeWiths mte
     in addComments loc . localLayout loc $
          fmt root </> sep line (map fmtWith (reverse $ tr : withs))
    where
      fmtWith (TypeRef v ps td _) =
        "with"
          <+> fmtPretty v
          `ps_op` sep space (map fmt ps)
          <+> "="
          <+> fmt td
        where
          ps_op = if null ps then (<>) else (<+>)
  fmt (ModTypeArrow (Just v) te0 te1 loc) =
    addComments loc $
      parens (fmtName bindingStyle v <> ":" <+> fmt te0) <+> align ("->" </> fmt te1)
  fmt (ModTypeArrow Nothing te0 te1 loc) =
    addComments loc $ fmt te0 <+> "->" <+> fmt te1

instance Format UncheckedModTypeBind where
  fmt (ModTypeBind pName pSig doc loc) =
    addComments loc $
      fmt doc
        <> "module"
        <+> "type"
        <+> fmtName bindingStyle pName
        <+> "="
        <> case pSig of
          ModTypeSpecs {} -> space <> fmt pSig
          _ -> line <> stdIndent (fmt pSig)

instance Format (ModParamBase NoInfo Name) where
  fmt (ModParam pName pSig _f loc) =
    addComments loc $ parens $ fmtName bindingStyle pName <> ":" <+> fmt pSig

instance Format UncheckedModBind where
  fmt (ModBind name ps sig me doc loc) =
    addComments loc $
      fmt doc
        <> "module"
        <+> localLayout [locStart (locOf loc), locOf ps] (fmtName bindingStyle name <> ps')
        <> fmtSig sig
        <> "="
        <> me'
    where
      me' = fmtByLayout me (line <> stdIndent (fmt me)) (space <> fmt me)
      fmtSig Nothing = space
      fmtSig (Just (s', _f)) =
        localLayout (map locOf ps ++ [locOf s']) $
          line <> stdIndent (":" <+> align (fmt s') <> space)
      ps' =
        case ps of
          [] -> nil
          _any -> line <> stdIndent (localLayoutList ps (align $ sep line $ map fmt ps))

-- All of these should probably be "extra" indented
instance Format UncheckedModExp where
  fmt (ModVar v loc) = addComments loc $ fmtQualName v
  fmt (ModParens f loc) =
    addComments loc $ "(" <:/> stdIndent (fmt f) <> ")"
  fmt (ModImport path _f loc) =
    addComments loc $ "import" <+> "\"" <> fmtPretty path <> "\""
  fmt (ModDecs decs loc) =
    addComments loc $ "{" <:/> stdIndent (sepDecs fmt decs) <:/> "}"
  fmt (ModApply f a _f0 _f1 loc) = addComments loc $ fmt f <+> fmt a
  fmt (ModAscript me se _f loc) =
    addComments loc $ align (fmt me <> ":" </> fmt se)
  fmt (ModLambda param maybe_sig body loc) =
    addComments loc $
      "\\" <> fmt param <> sig <+> "->" </> stdIndent (fmt body)
    where
      sig =
        case maybe_sig of
          Nothing -> nil
          Just (sig', _) -> ":" <+> parens (fmt sig')

instance Format UncheckedDec where
  fmt (ValDec t) = fmt t
  fmt (TypeDec tb) = fmt tb
  fmt (ModTypeDec tb) = fmt tb
  fmt (ModDec tb) = fmt tb
  fmt (OpenDec tb loc) = addComments loc $ "open" <+> fmt tb
  fmt (LocalDec tb loc) = addComments loc $ "local" </> fmt tb
  fmt (ImportDec path _tb loc) =
    addComments loc $ "import" <+> "\"" <> fmtPretty path <> "\""

instance Format UncheckedProg where
  fmt (Prog Nothing []) = popComments
  fmt (Prog Nothing decs) = sepDecs fmt decs </> popComments
  fmt (Prog (Just dc) decs) = fmt dc </> sepDecs fmt decs </> popComments

-- | Given a filename and a futhark program, formats the program.
fmtToDoc :: String -> T.Text -> Either SyntaxError (Doc AnsiStyle)
fmtToDoc fname fcontent = do
  (prog, cs) <- parseFutharkWithComments fname fcontent
  pure $ runFormat (fmt prog) cs fcontent

-- | Given a filename and a futhark program, formats the program as
-- text.
fmtToText :: String -> T.Text -> Either SyntaxError T.Text
fmtToText fname fcontent = docText <$> fmtToDoc fname fcontent
futhark-0.25.27/src/Futhark/Format.hs000066400000000000000000000013161475065116200173220ustar00rootroot00000000000000-- | Parsing of format strings.
module Futhark.Format (parseFormatString) where

import Data.Bifunctor
import Data.Text qualified as T
import Data.Void
import Text.Megaparsec

-- | A sequence of literal text runs and @{...}@ interpolations, covering
-- the whole input.
pFormatString :: Parsec Void T.Text [Either T.Text T.Text]
pFormatString =
  many (choice [Left <$> pLiteral, Right <$> pInterpolation]) <* eof
  where
    pInterpolation = "{" *> takeWhileP Nothing (`notElem` braces) <* "}"
    pLiteral = takeWhile1P Nothing (`notElem` braces)
    braces = "{}" :: String

-- | The Lefts are pure text; the Rights are the contents of
-- interpolations.
parseFormatString :: T.Text -> Either T.Text [Either T.Text T.Text]
parseFormatString =
  first (T.pack . errorBundlePretty) . runParser pFormatString ""
futhark-0.25.27/src/Futhark/FreshNames.hs000066400000000000000000000023611475065116200201260ustar00rootroot00000000000000-- | This module provides facilities for generating unique names.
module Futhark.FreshNames
  ( VNameSource,
    blankNameSource,
    newNameSource,
    newName,
  )
where

import Language.Futhark.Core
import Language.Haskell.TH.Syntax (Lift)

-- | A name source is conceptually an infinite sequence of names with
-- no repeating entries. In practice, when asked for a name, the name
-- source will return the name along with a new name source, which
-- should then be used in place of the original.
--
-- The 'Ord' instance is based on how many names have been extracted
-- from the name source.
newtype VNameSource = VNameSource Int
  deriving (Lift, Eq, Ord)

instance Semigroup VNameSource where
  VNameSource x <> VNameSource y = VNameSource (x `max` y)

instance Monoid VNameSource where
  mempty = blankNameSource

-- | Produce a fresh name, using the given name as a template.
newName :: VNameSource -> VName -> (VName, VNameSource)
newName (VNameSource i) k =
  i' `seq` (VName (baseName k) i, VNameSource i')
  where
    i' = i + 1

-- | A blank name source.
blankNameSource :: VNameSource
blankNameSource = newNameSource 0

-- | A new name source that starts counting from the given number.
newNameSource :: Int -> VNameSource
newNameSource = VNameSource
futhark-0.25.27/src/Futhark/IR.hs000066400000000000000000000004571475065116200164110ustar00rootroot00000000000000-- | A convenient re-export of basic AST modules.
module Futhark.IR ( module Futhark.IR.Prop, module Futhark.IR.Traversals, module Futhark.IR.Pretty, module Futhark.IR.Syntax, ) where import Futhark.IR.Pretty import Futhark.IR.Prop import Futhark.IR.Syntax import Futhark.IR.Traversals futhark-0.25.27/src/Futhark/IR/000077500000000000000000000000001475065116200160475ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/IR/Aliases.hs000066400000000000000000000312651475065116200177730ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} -- | A representation where all patterns are annotated with aliasing -- information. It also records consumption of variables in bodies. -- -- Note that this module is mostly not concerned with actually -- /computing/ the aliasing information; only with shuffling it around -- and providing some basic building blocks. See modules such as -- "Futhark.Analysis.Alias" for computing the aliases in the first -- place. module Futhark.IR.Aliases ( -- * The representation definition Aliases, AliasDec (..), VarAliases, ConsumedInExp, BodyAliasing, module Futhark.IR.Prop.Aliases, -- * Module re-exports module Futhark.IR.Prop, module Futhark.IR.Traversals, module Futhark.IR.Pretty, module Futhark.IR.Syntax, -- * Adding aliases mkAliasedBody, mkAliasedPat, mkBodyAliasing, CanBeAliased (..), AliasableRep, -- * Removing aliases removeProgAliases, removeFunDefAliases, removeExpAliases, removeStmAliases, removeBodyAliases, removeLambdaAliases, removePatAliases, removeScopeAliases, -- * Tracking aliases AliasesAndConsumed, trackAliases, mkStmsAliases, consumedInStms, ) where import Control.Monad.Identity import Control.Monad.Reader import Data.Kind qualified import Data.Map.Strict qualified as M import Data.Maybe import Futhark.Builder import Futhark.IR.Pretty import Futhark.IR.Prop import Futhark.IR.Prop.Aliases import Futhark.IR.Syntax import Futhark.IR.Traversals import Futhark.Transform.Rename import Futhark.Transform.Substitute import 
Futhark.Util.Pretty qualified as PP -- | The rep for the basic representation. data Aliases (rep :: Data.Kind.Type) -- | A wrapper around 'AliasDec' to get around the fact that we need an -- 'Ord' instance, which 'AliasDec does not have. newtype AliasDec = AliasDec {unAliases :: Names} deriving (Show) instance Semigroup AliasDec where x <> y = AliasDec $ unAliases x <> unAliases y instance Monoid AliasDec where mempty = AliasDec mempty instance Eq AliasDec where _ == _ = True instance Ord AliasDec where _ `compare` _ = EQ instance Rename AliasDec where rename (AliasDec names) = AliasDec <$> rename names instance Substitute AliasDec where substituteNames substs (AliasDec names) = AliasDec $ substituteNames substs names instance FreeIn AliasDec where freeIn' = const mempty instance PP.Pretty AliasDec where pretty = PP.braces . PP.commasep . map PP.pretty . namesToList . unAliases -- | The aliases of the let-bound variable. type VarAliases = AliasDec -- | Everything consumed in the expression. type ConsumedInExp = AliasDec -- | The aliases of what is returned by the t'Body', and what is -- consumed inside of it. type BodyAliasing = ([VarAliases], ConsumedInExp) instance (RepTypes rep, ASTConstraints (OpC rep (Aliases rep))) => RepTypes (Aliases rep) where type LetDec (Aliases rep) = (VarAliases, LetDec rep) type ExpDec (Aliases rep) = (ConsumedInExp, ExpDec rep) type BodyDec (Aliases rep) = (BodyAliasing, BodyDec rep) type FParamInfo (Aliases rep) = FParamInfo rep type LParamInfo (Aliases rep) = LParamInfo rep type RetType (Aliases rep) = RetType rep type BranchType (Aliases rep) = BranchType rep type OpC (Aliases rep) = OpC rep instance AliasesOf (VarAliases, dec) where aliasesOf = unAliases . 
fst instance FreeDec AliasDec withoutAliases :: (HasScope (Aliases rep) m, Monad m) => ReaderT (Scope rep) m a -> m a withoutAliases m = do scope <- asksScope removeScopeAliases runReaderT m scope instance ( ASTRep rep, AliasedOp (OpC rep), ASTConstraints (OpC rep (Aliases rep)) ) => ASTRep (Aliases rep) where expTypesFromPat = withoutAliases . expTypesFromPat . removePatAliases instance ( ASTRep rep, AliasedOp (OpC rep), ASTConstraints (OpC rep (Aliases rep)) ) => Aliased (Aliases rep) where bodyAliases = map unAliases . fst . fst . bodyDec consumedInBody = unAliases . snd . fst . bodyDec instance ( ASTRep rep, AliasedOp (OpC rep), ASTConstraints (OpC rep (Aliases rep)) ) => PrettyRep (Aliases rep) where ppExpDec (consumed, inner) e = maybeComment . catMaybes $ [exp_dec, merge_dec, ppExpDec inner $ removeExpAliases e] where merge_dec = case e of Loop merge _ body -> let mergeParamAliases fparam als | primType (paramType fparam) = Nothing | otherwise = resultAliasComment (paramName fparam) als in maybeComment . catMaybes $ zipWith mergeParamAliases (map fst merge) $ bodyAliases body _ -> Nothing exp_dec = case namesToList $ unAliases consumed of [] -> Nothing als -> Just $ PP.oneLine $ "-- Consumes " <> PP.commasep (map PP.pretty als) maybeComment :: [PP.Doc a] -> Maybe (PP.Doc a) maybeComment [] = Nothing maybeComment cs = Just $ PP.stack cs resultAliasComment :: (PP.Pretty a) => a -> Names -> Maybe (PP.Doc ann) resultAliasComment name als = case namesToList als of [] -> Nothing als' -> Just $ PP.oneLine $ "-- Result for " <> PP.pretty name <> " aliases " <> PP.commasep (map PP.pretty als') removeAliases :: (RephraseOp (OpC rep)) => Rephraser Identity (Aliases rep) rep removeAliases = Rephraser { rephraseExpDec = pure . snd, rephraseLetBoundDec = pure . snd, rephraseBodyDec = pure . 
snd, rephraseFParamDec = pure, rephraseLParamDec = pure, rephraseRetType = pure, rephraseBranchType = pure, rephraseOp = rephraseInOp removeAliases } -- | Remove alias information from an aliased scope. removeScopeAliases :: Scope (Aliases rep) -> Scope rep removeScopeAliases = M.map unAlias where unAlias (LetName (_, dec)) = LetName dec unAlias (FParamName dec) = FParamName dec unAlias (LParamName dec) = LParamName dec unAlias (IndexName it) = IndexName it -- | Remove alias information from a program. removeProgAliases :: (RephraseOp (OpC rep)) => Prog (Aliases rep) -> Prog rep removeProgAliases = runIdentity . rephraseProg removeAliases -- | Remove alias information from a function. removeFunDefAliases :: (RephraseOp (OpC rep)) => FunDef (Aliases rep) -> FunDef rep removeFunDefAliases = runIdentity . rephraseFunDef removeAliases -- | Remove alias information from an expression. removeExpAliases :: (RephraseOp (OpC rep)) => Exp (Aliases rep) -> Exp rep removeExpAliases = runIdentity . rephraseExp removeAliases -- | Remove alias information from statements. removeStmAliases :: (RephraseOp (OpC rep)) => Stm (Aliases rep) -> Stm rep removeStmAliases = runIdentity . rephraseStm removeAliases -- | Remove alias information from body. removeBodyAliases :: (RephraseOp (OpC rep)) => Body (Aliases rep) -> Body rep removeBodyAliases = runIdentity . rephraseBody removeAliases -- | Remove alias information from lambda. removeLambdaAliases :: (RephraseOp (OpC rep)) => Lambda (Aliases rep) -> Lambda rep removeLambdaAliases = runIdentity . rephraseLambda removeAliases -- | Remove alias information from pattern. removePatAliases :: Pat (AliasDec, a) -> Pat a removePatAliases = runIdentity . rephrasePat (pure . snd) -- | Augment a body decoration with aliasing information provided by -- the statements and result of that body. 
mkAliasedBody :: (ASTRep rep, AliasedOp (OpC rep), ASTConstraints (OpC rep (Aliases rep))) => BodyDec rep -> Stms (Aliases rep) -> Result -> Body (Aliases rep) mkAliasedBody dec stms res = Body (mkBodyAliasing stms res, dec) stms res -- | Augment a pattern with aliasing information provided by the -- expression the pattern is bound to. mkAliasedPat :: (Aliased rep, Typed dec) => Pat dec -> Exp rep -> Pat (VarAliases, dec) mkAliasedPat (Pat pes) e = Pat $ zipWith annotate pes $ expAliases pes e where annotate (PatElem v dec) names = PatElem v (AliasDec names', dec) where names' = case typeOf dec of Array {} -> names Mem _ -> names _ -> mempty -- | Given statements (with aliasing information) and a body result, -- produce aliasing information for the corresponding body as a whole. -- The aliasing includes names bound in the body, i.e. which are not -- in scope outside of it. Note that this does *not* include aliases -- of results that are not bound in the statements! mkBodyAliasing :: (Aliased rep) => Stms rep -> Result -> BodyAliasing mkBodyAliasing stms res = -- We need to remove the names that are bound in stms from the alias -- and consumption sets. We do this by computing the transitive -- closure of the alias map (within stms), then removing anything -- bound in stms. let (aliases, consumed) = mkStmsAliases stms res boundNames = foldMap (namesFromList . patNames . stmPat) stms consumed' = consumed `namesSubtract` boundNames in (map AliasDec aliases, AliasDec consumed') -- | The aliases of the result and everything consumed in the given -- statements. mkStmsAliases :: (Aliased rep) => Stms rep -> Result -> ([Names], Names) mkStmsAliases stms res = delve mempty $ stmsToList stms where delve (aliasmap, consumed) [] = ( map (aliasClosure aliasmap . subExpAliases . 
resSubExp) res, consumed ) delve (aliasmap, consumed) (stm : stms') = delve (trackAliases (aliasmap, consumed) stm) stms' aliasClosure aliasmap names = names <> mconcat (map look $ namesToList names) where look k = M.findWithDefault mempty k aliasmap -- | A tuple of a mapping from variable names to their aliases, and -- the names of consumed variables. type AliasesAndConsumed = ( M.Map VName Names, Names ) -- | The variables consumed in these statements. consumedInStms :: (Aliased rep) => Stms rep -> Names consumedInStms = snd . flip mkStmsAliases [] -- | A helper function for computing the aliases of a sequence of -- statements. You'd use this while recursing down the statements -- from first to last. The 'AliasesAndConsumed' parameter is the -- current "state" of aliasing, and the function then returns a new -- state. The main thing this function provides is proper handling of -- transitivity and "reverse" aliases. trackAliases :: (Aliased rep) => AliasesAndConsumed -> Stm rep -> AliasesAndConsumed trackAliases (aliasmap, consumed) stm = let pat = stmPat stm pe_als = zip (patNames pat) $ map addAliasesOfAliases $ patAliases pat als = M.fromList pe_als rev_als = foldMap revAls pe_als revAls (v, v_als) = M.fromList $ map (,oneName v) $ namesToList v_als comb = M.unionWith (<>) aliasmap' = rev_als `comb` als `comb` aliasmap consumed' = consumed <> addAliasesOfAliases (consumedInStm stm) in (aliasmap', consumed') where addAliasesOfAliases names = names <> aliasesOfAliases names aliasesOfAliases = mconcat . map look . 
namesToList look k = M.findWithDefault mempty k aliasmap mkAliasedStm :: (ASTRep rep, AliasedOp (OpC rep), ASTConstraints (OpC rep (Aliases rep))) => Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp (Aliases rep) -> Stm (Aliases rep) mkAliasedStm pat (StmAux cs attrs dec) e = Let (mkAliasedPat pat e) (StmAux cs attrs (AliasDec $ consumedInExp e, dec)) e instance ( Buildable rep, AliasedOp (OpC rep), ASTConstraints (OpC rep (Aliases rep)) ) => Buildable (Aliases rep) where mkExpDec pat e = let dec = mkExpDec (removePatAliases pat) $ removeExpAliases e in (AliasDec $ consumedInExp e, dec) mkExpPat ids e = mkAliasedPat (mkExpPat ids $ removeExpAliases e) e mkLetNames names e = do env <- asksScope removeScopeAliases flip runReaderT env $ do Let pat dec _ <- mkLetNames names $ removeExpAliases e pure $ mkAliasedStm pat dec e mkBody stms res = let Body bodyrep _ _ = mkBody (fmap removeStmAliases stms) res in mkAliasedBody bodyrep stms res instance ( ASTRep rep, AliasedOp (OpC rep), Buildable (Aliases rep) ) => BuilderOps (Aliases rep) -- | What we require of an aliasable representation. type AliasableRep rep = ( ASTRep rep, RephraseOp (OpC rep), CanBeAliased (OpC rep), AliasedOp (OpC rep), ASTConstraints (OpC rep (Aliases rep)) ) -- | The class of operations that can be given aliasing information. -- This is a somewhat subtle concept that is only used in the -- simplifier and when using "rep adapters". class CanBeAliased op where -- | Add aliases to this op. addOpAliases :: (AliasableRep rep) => AliasTable -> op rep -> op (Aliases rep) instance CanBeAliased NoOp where addOpAliases _ NoOp = NoOp futhark-0.25.27/src/Futhark/IR/GPU.hs000066400000000000000000000075401475065116200170440ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | A representation with flat parallelism via GPU-oriented kernels. 
module Futhark.IR.GPU ( GPU, -- * Module re-exports module Futhark.IR.Prop, module Futhark.IR.Traversals, module Futhark.IR.Pretty, module Futhark.IR.Syntax, module Futhark.IR.GPU.Op, module Futhark.IR.GPU.Sizes, module Futhark.IR.SOACS.SOAC, ) where import Futhark.Builder import Futhark.Construct import Futhark.IR.Aliases (Aliases) import Futhark.IR.GPU.Op import Futhark.IR.GPU.Sizes import Futhark.IR.Pretty import Futhark.IR.Prop import Futhark.IR.SOACS.SOAC hiding (HistOp (..)) import Futhark.IR.Syntax import Futhark.IR.Traversals import Futhark.IR.TypeCheck qualified as TC -- | The phantom data type for the kernels representation. data GPU instance RepTypes GPU where type OpC GPU = HostOp SOAC instance ASTRep GPU where expTypesFromPat = pure . expExtTypesFromPat instance TC.Checkable GPU where checkOp = typeCheckGPUOp Nothing where -- GHC 9.2 goes into an infinite loop without the type annotation. typeCheckGPUOp :: Maybe SegLevel -> HostOp SOAC (Aliases GPU) -> TC.TypeM GPU () typeCheckGPUOp lvl = typeCheckHostOp (typeCheckGPUOp . Just) lvl typeCheckSOAC instance Buildable GPU where mkBody = Body () mkExpPat idents _ = basicPat idents mkExpDec _ _ = () mkLetNames = simpleMkLetNames instance BuilderOps GPU instance PrettyRep GPU instance HasSegOp GPU where type SegOpLevel GPU = SegLevel asSegOp (SegOp op) = Just op asSegOp _ = Nothing segOp = SegOp -- Note [GPU Terminology] -- -- For lack of a better spot to put it, this Note summarises the -- terminology used for GPU concepts in the Futhark compiler. The -- terminology is based on CUDA terminology, and tries to match it as -- closely as possible. However, this was not always the case (issue -- #2062), so you may find some code that uses e.g. OpenCL -- terminology. In most cases there is no ambiguity, but there are a -- few instances where the same term is used for different things. -- Please fix any instances you find. 
-- -- The terminology is as follows: -- -- Host: Essentially the CPU; whatever is controlling the GPU. -- -- Kernel: A GPU program that can be launched from the host. -- -- Grid: The geometry of the thread blocks launched for a kernel. The -- size of a grid is always in terms of the number of thread blocks -- ("grid size"). A grid can have up to 3 dimensions, although we do -- not make much use of it - and not at all prior to code generation. -- -- Thread block: Just as in CUDA. "Workgroup" in OpenCL. Abbretiation: -- tblock. Never just call this "block"; there are too many things -- called "block". Must match the dimensionality of the grid. -- -- Thread: Just as in CUDA. "Workitem" in OpenCL. -- -- Global thread identifier: A globally unique number for a thread -- along one dimension. Abbreviation: gtid. We also use this term for -- the identifiers bound by SegOps. In OpenCL, corresponds to -- get_global_id(). (Except when we virtualise the thread space.) -- -- Local thread identifier: A locally unique number (within the thread -- block) for each thread. Abbreviation: ltid. In OpenCL, corresponds -- to get_local_id(). In CUDA, corresponds to threadIdx. -- -- Thread block identifier: A number unique to each thread block in a -- single dimension. In CUDA, corresponds to blockIdx. -- -- Local memory: Thread-local private memory. In CUDA, this is -- sometimes put in registers (if you are very careful in how you use -- it). In OpenCL, this is called "private memory", and "local memory" -- is something else entirely. -- -- Shared memory: Just as in CUDA. Fast scratchpad memory accessible -- to all threads within the same thread block. In OpenCL, this is -- "local memory". -- -- Device memory: Sometimes also called "global memory"; this is the -- big-but-slow memory on the GPU. 
futhark-0.25.27/src/Futhark/IR/GPU/000077500000000000000000000000001475065116200165025ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/IR/GPU/Op.hs000066400000000000000000000353161475065116200174240ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} module Futhark.IR.GPU.Op ( -- * Size operations SizeOp (..), -- * Host operations HostOp (..), traverseHostOpStms, typeCheckHostOp, -- * SegOp refinements SegLevel (..), segVirt, SegVirt (..), SegSeqDims (..), KernelGrid (..), -- * Reexports module Futhark.IR.GPU.Sizes, module Futhark.IR.SegOp, ) where import Control.Monad import Data.Sequence qualified as SQ import Data.Text qualified as T import Futhark.Analysis.Alias qualified as Alias import Futhark.Analysis.Metrics import Futhark.Analysis.SymbolTable qualified as ST import Futhark.IR import Futhark.IR.Aliases (Aliases, CanBeAliased (..)) import Futhark.IR.GPU.Sizes import Futhark.IR.Mem (OpReturns (..), extReturns) import Futhark.IR.Prop.Aliases import Futhark.IR.SegOp import Futhark.IR.TypeCheck qualified as TC import Futhark.Optimise.Simplify.Engine qualified as Engine import Futhark.Optimise.Simplify.Rep import Futhark.Transform.Rename import Futhark.Transform.Substitute import Futhark.Util.Pretty ( commasep, parens, ppTuple', pretty, (<+>), ) import Futhark.Util.Pretty qualified as PP -- | These dimensions (indexed from 0, outermost) of the corresponding -- 'SegSpace' should not be parallelised, but instead iterated -- sequentially. For example, with a 'SegSeqDims' of @[0]@ and a -- 'SegSpace' with dimensions @[n][m]@, there will be an outer loop -- with @n@ iterations, while the @m@ dimension will be parallelised. -- -- Semantically, this has no effect, but it may allow reductions in -- memory usage or other low-level optimisations. Operationally, the -- guarantee is that for a SegSeqDims of e.g. 
@[i,j,k]@, threads -- running at any given moment will always have the same indexes along -- the dimensions specified by @[i,j,k]@. -- -- At the moment, this is only supported for 'SegNoVirtFull' -- intra-block parallelism in GPU code, as we have not yet found it -- useful anywhere else. newtype SegSeqDims = SegSeqDims {segSeqDims :: [Int]} deriving (Eq, Ord, Show) -- | Do we need block-virtualisation when generating code for the -- segmented operation? In most cases, we do, but for some simple -- kernels, we compute the full number of blocks in advance, and then -- virtualisation is an unnecessary (but generally very small) -- overhead. This only really matters for fairly trivial but very -- wide @map@ kernels where each thread performs constant-time work on -- scalars. data SegVirt = SegVirt | SegNoVirt | -- | Not only do we not need virtualisation, but we _guarantee_ -- that all physical threads participate in the work. This can -- save some checks in code generation. SegNoVirtFull SegSeqDims deriving (Eq, Ord, Show) -- | The actual, physical grid dimensions used for the GPU kernel -- running this 'SegOp'. data KernelGrid = KernelGrid { gridNumBlocks :: Count NumBlocks SubExp, gridBlockSize :: Count BlockSize SubExp } deriving (Eq, Ord, Show) -- | At which level the *body* of a t'SegOp' executes. data SegLevel = SegThread SegVirt (Maybe KernelGrid) | SegBlock SegVirt (Maybe KernelGrid) | SegThreadInBlock SegVirt deriving (Eq, Ord, Show) -- | The 'SegVirt' of the 'SegLevel'. 
segVirt :: SegLevel -> SegVirt segVirt (SegThread v _) = v segVirt (SegBlock v _) = v segVirt (SegThreadInBlock v) = v instance PP.Pretty SegVirt where pretty SegNoVirt = mempty pretty (SegNoVirtFull dims) = "full" <+> pretty (segSeqDims dims) pretty SegVirt = "virtualise" instance PP.Pretty KernelGrid where pretty (KernelGrid num_tblocks tblock_size) = "grid=" <> pretty num_tblocks <> PP.semi <+> "blocksize=" <> pretty tblock_size instance PP.Pretty SegLevel where pretty (SegThread virt grid) = PP.parens ("thread" <> PP.semi <+> pretty virt <> PP.semi <+> pretty grid) pretty (SegBlock virt grid) = PP.parens ("block" <> PP.semi <+> pretty virt <> PP.semi <+> pretty grid) pretty (SegThreadInBlock virt) = PP.parens ("inblock" <> PP.semi <+> pretty virt) instance Engine.Simplifiable KernelGrid where simplify (KernelGrid num_tblocks tblock_size) = KernelGrid <$> traverse Engine.simplify num_tblocks <*> traverse Engine.simplify tblock_size instance Engine.Simplifiable SegLevel where simplify (SegThread virt grid) = SegThread virt <$> Engine.simplify grid simplify (SegBlock virt grid) = SegBlock virt <$> Engine.simplify grid simplify (SegThreadInBlock virt) = pure $ SegThreadInBlock virt instance Substitute KernelGrid where substituteNames substs (KernelGrid num_tblocks tblock_size) = KernelGrid (substituteNames substs num_tblocks) (substituteNames substs tblock_size) instance Substitute SegLevel where substituteNames substs (SegThread virt grid) = SegThread virt (substituteNames substs grid) substituteNames substs (SegBlock virt grid) = SegBlock virt (substituteNames substs grid) substituteNames _ (SegThreadInBlock virt) = SegThreadInBlock virt instance Rename SegLevel where rename = substituteRename instance FreeIn KernelGrid where freeIn' (KernelGrid num_tblocks tblock_size) = freeIn' (num_tblocks, tblock_size) instance FreeIn SegLevel where freeIn' (SegThread _virt grid) = freeIn' grid freeIn' (SegBlock _virt grid) = freeIn' grid freeIn' (SegThreadInBlock _virt) = 
mempty -- | A simple size-level query or computation. data SizeOp = -- | Produce some runtime-configurable size. GetSize Name SizeClass | -- | The maximum size of some class. GetSizeMax SizeClass | -- | Compare size (likely a threshold) with some integer value. CmpSizeLe Name SizeClass SubExp | -- | @CalcNumBlocks w max_num_tblocks tblock_size@ calculates the -- number of GPU threadblocks to use for an input of the given size. -- The @Name@ is a size name. Note that @w@ is an i64 to avoid -- overflow issues. CalcNumBlocks SubExp Name SubExp deriving (Eq, Ord, Show) instance Substitute SizeOp where substituteNames substs (CmpSizeLe name sclass x) = CmpSizeLe name sclass (substituteNames substs x) substituteNames substs (CalcNumBlocks w max_num_tblocks tblock_size) = CalcNumBlocks (substituteNames substs w) max_num_tblocks (substituteNames substs tblock_size) substituteNames _ op = op instance Rename SizeOp where rename (CmpSizeLe name sclass x) = CmpSizeLe name sclass <$> rename x rename (CalcNumBlocks w max_num_tblocks tblock_size) = CalcNumBlocks <$> rename w <*> pure max_num_tblocks <*> rename tblock_size rename x = pure x instance FreeIn SizeOp where freeIn' (CmpSizeLe _ _ x) = freeIn' x freeIn' (CalcNumBlocks w _ tblock_size) = freeIn' w <> freeIn' tblock_size freeIn' _ = mempty instance PP.Pretty SizeOp where pretty (GetSize name size_class) = "get_size" <> parens (commasep [pretty name, pretty size_class]) pretty (GetSizeMax size_class) = "get_size_max" <> parens (commasep [pretty size_class]) pretty (CmpSizeLe name size_class x) = "cmp_size" <> parens (commasep [pretty name, pretty size_class]) <+> "<=" <+> pretty x pretty (CalcNumBlocks w max_num_tblocks tblock_size) = "calc_num_tblocks" <> parens (commasep [pretty w, pretty max_num_tblocks, pretty tblock_size]) instance OpMetrics SizeOp where opMetrics GetSize {} = seen "GetSize" opMetrics GetSizeMax {} = seen "GetSizeMax" opMetrics CmpSizeLe {} = seen "CmpSizeLe" opMetrics CalcNumBlocks {} = seen 
"CalcNumBlocks" typeCheckSizeOp :: (TC.Checkable rep) => SizeOp -> TC.TypeM rep () typeCheckSizeOp GetSize {} = pure () typeCheckSizeOp GetSizeMax {} = pure () typeCheckSizeOp (CmpSizeLe _ _ x) = TC.require [Prim int64] x typeCheckSizeOp (CalcNumBlocks w _ tblock_size) = do TC.require [Prim int64] w TC.require [Prim int64] tblock_size -- | A host-level operation; parameterised by what else it can do. data HostOp op rep = -- | A segmented operation. SegOp (SegOp SegLevel rep) | SizeOp SizeOp | OtherOp (op rep) | -- | Code to run sequentially on the GPU, -- in a single thread. GPUBody [Type] (Body rep) deriving (Eq, Ord, Show) -- | A helper for defining 'TraverseOpStms'. traverseHostOpStms :: (Monad m) => OpStmsTraverser m (op rep) rep -> OpStmsTraverser m (HostOp op rep) rep traverseHostOpStms _ f (SegOp segop) = SegOp <$> traverseSegOpStms f segop traverseHostOpStms _ _ (SizeOp sizeop) = pure $ SizeOp sizeop traverseHostOpStms onOtherOp f (OtherOp other) = OtherOp <$> onOtherOp f other traverseHostOpStms _ f (GPUBody ts body) = do stms <- f mempty $ bodyStms body pure $ GPUBody ts $ body {bodyStms = stms} instance (ASTRep rep, Substitute (op rep)) => Substitute (HostOp op rep) where substituteNames substs (SegOp op) = SegOp $ substituteNames substs op substituteNames substs (OtherOp op) = OtherOp $ substituteNames substs op substituteNames substs (SizeOp op) = SizeOp $ substituteNames substs op substituteNames substs (GPUBody ts body) = GPUBody (substituteNames substs ts) (substituteNames substs body) instance (ASTRep rep, Rename (op rep)) => Rename (HostOp op rep) where rename (SegOp op) = SegOp <$> rename op rename (OtherOp op) = OtherOp <$> rename op rename (SizeOp op) = SizeOp <$> rename op rename (GPUBody ts body) = GPUBody <$> rename ts <*> rename body instance (IsOp op) => IsOp (HostOp op) where safeOp (SegOp op) = safeOp op safeOp (OtherOp op) = safeOp op safeOp (SizeOp _) = True safeOp (GPUBody _ body) = all (safeExp . 
stmExp) $ bodyStms body cheapOp (SegOp op) = cheapOp op cheapOp (OtherOp op) = cheapOp op cheapOp (SizeOp _) = True cheapOp (GPUBody types body) = -- Current GPUBody usage only benefits from hoisting kernels that -- transfer scalars to device. SQ.null (bodyStms body) && all ((== 0) . arrayRank) types opDependencies (SegOp op) = opDependencies op opDependencies (OtherOp op) = opDependencies op opDependencies (SizeOp op) = [freeIn op] opDependencies (GPUBody _ body) = replicate (length . bodyResult $ body) (freeIn body) instance (TypedOp op) => TypedOp (HostOp op) where opType (SegOp op) = opType op opType (OtherOp op) = opType op opType (SizeOp (GetSize _ _)) = pure [Prim int64] opType (SizeOp (GetSizeMax _)) = pure [Prim int64] opType (SizeOp CmpSizeLe {}) = pure [Prim Bool] opType (SizeOp (CalcNumBlocks {})) = pure [Prim int64] opType (GPUBody ts _) = pure $ staticShapes $ map (`arrayOfRow` intConst Int64 1) ts instance (AliasedOp op) => AliasedOp (HostOp op) where opAliases (SegOp op) = opAliases op opAliases (OtherOp op) = opAliases op opAliases (SizeOp _) = [mempty] opAliases (GPUBody ts _) = map (const mempty) ts consumedInOp (SegOp op) = consumedInOp op consumedInOp (OtherOp op) = consumedInOp op consumedInOp (SizeOp _) = mempty consumedInOp (GPUBody _ body) = consumedInBody body instance (ASTRep rep, FreeIn (op rep)) => FreeIn (HostOp op rep) where freeIn' (SegOp op) = freeIn' op freeIn' (OtherOp op) = freeIn' op freeIn' (SizeOp op) = freeIn' op freeIn' (GPUBody ts body) = freeIn' ts <> freeIn' body instance (CanBeAliased op) => CanBeAliased (HostOp op) where addOpAliases aliases (SegOp op) = SegOp $ addOpAliases aliases op addOpAliases aliases (GPUBody ts body) = GPUBody ts $ Alias.analyseBody aliases body addOpAliases aliases (OtherOp op) = OtherOp $ addOpAliases aliases op addOpAliases _ (SizeOp op) = SizeOp op instance (CanBeWise op) => CanBeWise (HostOp op) where addOpWisdom (SegOp op) = SegOp $ addOpWisdom op addOpWisdom (OtherOp op) = OtherOp $ 
addOpWisdom op addOpWisdom (SizeOp op) = SizeOp op addOpWisdom (GPUBody ts body) = GPUBody ts $ informBody body instance OpReturns (HostOp NoOp) where opReturns (SegOp op) = segOpReturns op opReturns k = extReturns <$> opType k instance (ASTRep rep, ST.IndexOp (op rep)) => ST.IndexOp (HostOp op rep) where indexOp vtable k (SegOp op) is = ST.indexOp vtable k op is indexOp vtable k (OtherOp op) is = ST.indexOp vtable k op is indexOp _ _ _ _ = Nothing instance (PrettyRep rep, PP.Pretty (op rep)) => PP.Pretty (HostOp op rep) where pretty (SegOp op) = pretty op pretty (OtherOp op) = pretty op pretty (SizeOp op) = pretty op pretty (GPUBody ts body) = "gpu" <+> PP.colon <+> ppTuple' (map pretty ts) <+> PP.nestedBlock "{" "}" (pretty body) instance (OpMetrics (Op rep), OpMetrics (op rep)) => OpMetrics (HostOp op rep) where opMetrics (SegOp op) = opMetrics op opMetrics (OtherOp op) = opMetrics op opMetrics (SizeOp op) = opMetrics op opMetrics (GPUBody _ body) = inside "GPUBody" $ bodyMetrics body instance (RephraseOp op) => RephraseOp (HostOp op) where rephraseInOp r (SegOp op) = SegOp <$> rephraseInOp r op rephraseInOp r (OtherOp op) = OtherOp <$> rephraseInOp r op rephraseInOp _ (SizeOp op) = pure $ SizeOp op rephraseInOp r (GPUBody ts body) = GPUBody ts <$> rephraseBody r body checkGrid :: (TC.Checkable rep) => KernelGrid -> TC.TypeM rep () checkGrid grid = do TC.require [Prim int64] $ unCount $ gridNumBlocks grid TC.require [Prim int64] $ unCount $ gridBlockSize grid checkSegLevel :: (TC.Checkable rep) => Maybe SegLevel -> SegLevel -> TC.TypeM rep () checkSegLevel (Just SegBlock {}) (SegThreadInBlock _virt) = pure () checkSegLevel _ (SegThreadInBlock _virt) = TC.bad $ TC.TypeError "inblock SegOp not in block SegOp." checkSegLevel (Just SegThread {}) _ = TC.bad $ TC.TypeError "SegOps cannot occur when already at thread level." checkSegLevel (Just SegThreadInBlock {}) _ = TC.bad $ TC.TypeError "SegOps cannot occur when already at inblock level." 
checkSegLevel _ (SegThread _virt Nothing) = pure () checkSegLevel (Just _) SegThread {} = TC.bad $ TC.TypeError "thread-level SegOp cannot be nested" checkSegLevel Nothing (SegThread _virt grid) = mapM_ checkGrid grid checkSegLevel (Just _) SegBlock {} = TC.bad $ TC.TypeError "block-level SegOp cannot be nested" checkSegLevel Nothing (SegBlock _virt grid) = mapM_ checkGrid grid typeCheckHostOp :: (TC.Checkable rep) => (SegLevel -> Op (Aliases rep) -> TC.TypeM rep ()) -> Maybe SegLevel -> (op (Aliases rep) -> TC.TypeM rep ()) -> HostOp op (Aliases rep) -> TC.TypeM rep () typeCheckHostOp checker lvl _ (SegOp op) = TC.checkOpWith (checker $ segLevel op) $ typeCheckSegOp (checkSegLevel lvl) op typeCheckHostOp _ Just {} _ GPUBody {} = TC.bad $ TC.TypeError "GPUBody may not be nested in SegOps." typeCheckHostOp _ _ f (OtherOp op) = f op typeCheckHostOp _ _ _ (SizeOp op) = typeCheckSizeOp op typeCheckHostOp _ Nothing _ (GPUBody ts body) = do mapM_ TC.checkType ts void $ TC.checkBody body body_ts <- extendedScope (traverse subExpResType (bodyResult body)) (scopeOf (bodyStms body)) unless (body_ts == ts) . TC.bad . TC.TypeError . 
T.unlines $ [ "Expected type: " <> prettyTuple ts, "Got body type: " <> prettyTuple body_ts ] futhark-0.25.27/src/Futhark/IR/GPU/Simplify.hs000066400000000000000000000125701475065116200206370ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} {-# OPTIONS_GHC -fno-warn-orphans #-} module Futhark.IR.GPU.Simplify ( simplifyGPU, simplifyLambda, GPU, -- * Building blocks simplifyKernelOp, ) where import Futhark.Analysis.SymbolTable qualified as ST import Futhark.Analysis.UsageTable qualified as UT import Futhark.IR.GPU import Futhark.IR.SOACS.Simplify qualified as SOAC import Futhark.MonadFreshNames import Futhark.Optimise.Simplify qualified as Simplify import Futhark.Optimise.Simplify.Engine qualified as Engine import Futhark.Optimise.Simplify.Rep import Futhark.Optimise.Simplify.Rule import Futhark.Optimise.Simplify.Rules import Futhark.Pass import Futhark.Tools import Futhark.Util (focusNth) simpleGPU :: Simplify.SimpleOps GPU simpleGPU = Simplify.bindableSimpleOps $ simplifyKernelOp SOAC.simplifySOAC simplifyGPU :: Prog GPU -> PassM (Prog GPU) simplifyGPU = Simplify.simplifyProg simpleGPU kernelRules Simplify.noExtraHoistBlockers simplifyLambda :: (HasScope GPU m, MonadFreshNames m) => Lambda GPU -> m (Lambda GPU) simplifyLambda = Simplify.simplifyLambda simpleGPU kernelRules Engine.noExtraHoistBlockers simplifyKernelOp :: ( Engine.SimplifiableRep rep, BodyDec rep ~ () ) => Simplify.SimplifyOp rep (op (Wise rep)) -> HostOp op (Wise rep) -> Engine.SimpleM rep (HostOp op (Wise rep), Stms (Wise rep)) simplifyKernelOp f (OtherOp op) = do (op', stms) <- f op pure (OtherOp op', stms) simplifyKernelOp _ (SegOp op) = do (op', hoisted) <- simplifySegOp op pure (SegOp op', hoisted) simplifyKernelOp _ (SizeOp (GetSize key size_class)) = pure (SizeOp $ GetSize key size_class, mempty) simplifyKernelOp _ (SizeOp (GetSizeMax size_class)) = pure (SizeOp $ GetSizeMax size_class, mempty) simplifyKernelOp _ (SizeOp (CmpSizeLe key size_class x)) = do x' <- Engine.simplify x pure 
(SizeOp $ CmpSizeLe key size_class x', mempty) simplifyKernelOp _ (SizeOp (CalcNumBlocks w max_num_tblocks tblock_size)) = do w' <- Engine.simplify w pure (SizeOp $ CalcNumBlocks w' max_num_tblocks tblock_size, mempty) simplifyKernelOp _ (GPUBody ts body) = do ts' <- Engine.simplify ts (hoisted, body') <- Engine.simplifyBody keepOnGPU mempty (map (const mempty) ts) body pure (GPUBody ts' body', hoisted) where keepOnGPU _ _ = keepExpOnGPU . stmExp keepExpOnGPU (BasicOp Index {}) = True keepExpOnGPU (BasicOp (ArrayLit _ t)) | primType t = True keepExpOnGPU Loop {} = True keepExpOnGPU _ = False instance TraverseOpStms (Wise GPU) where traverseOpStms = traverseHostOpStms traverseSOACStms instance BuilderOps (Wise GPU) instance HasSegOp (Wise GPU) where type SegOpLevel (Wise GPU) = SegLevel asSegOp (SegOp op) = Just op asSegOp _ = Nothing segOp = SegOp instance SOAC.HasSOAC (Wise GPU) where asSOAC (OtherOp soac) = Just soac asSOAC _ = Nothing soacOp = OtherOp kernelRules :: RuleBook (Wise GPU) kernelRules = standardRules <> segOpRules <> ruleBook [ RuleOp SOAC.simplifyKnownIterationSOAC, RuleOp SOAC.removeReplicateMapping, RuleOp SOAC.liftIdentityMapping, RuleOp SOAC.simplifyMapIota, RuleOp SOAC.removeUnusedSOACInput, RuleBasicOp removeScalarCopy ] [ RuleBasicOp removeUnnecessaryCopy, RuleOp removeDeadGPUBodyResult ] -- | Remove the unused return values of a GPUBody. removeDeadGPUBodyResult :: BottomUpRuleOp (Wise GPU) removeDeadGPUBodyResult (_, used) pat aux (GPUBody types body) | -- Figure out which of the names in 'pat' are used... pat_used <- map (`UT.isUsedDirectly` used) $ patNames pat, -- If they are not all used, then this rule applies. not (and pat_used) = -- Remove the parts of the GPUBody results that correspond to dead -- return value bindings. Note that this leaves dead code in the -- kernel, but that will be removed later. let pick :: [a] -> [a] pick = map snd . filter fst . 
zip pat_used pat' = pick (patElems pat) types' = pick types body' = body {bodyResult = pick (bodyResult body)} in Simplify $ auxing aux $ letBind (Pat pat') $ Op $ GPUBody types' body' | otherwise = Skip removeDeadGPUBodyResult _ _ _ _ = Skip -- If we see an Update with a scalar where the value to be written is -- the result of indexing some other array, then we convert it into an -- Update with a slice of that array. This matters when the arrays -- are far away (on the GPU, say), because it avoids a copy of the -- scalar to and from the host. removeScalarCopy :: (BuilderOps rep) => TopDownRuleBasicOp rep removeScalarCopy vtable pat aux (Update safety arr_x (Slice slice_x) (Var v)) | Just _ <- sliceIndices (Slice slice_x), Just (Index arr_y (Slice slice_y), cs_y) <- ST.lookupBasicOp v vtable, ST.available arr_y vtable, not $ ST.aliases arr_x arr_y vtable, Just (slice_x_bef, DimFix i, []) <- focusNth (length slice_x - 1) slice_x, Just (slice_y_bef, DimFix j, []) <- focusNth (length slice_y - 1) slice_y = Simplify $ do let slice_x' = Slice $ slice_x_bef ++ [DimSlice i (intConst Int64 1) (intConst Int64 1)] slice_y' = Slice $ slice_y_bef ++ [DimSlice j (intConst Int64 1) (intConst Int64 1)] v' <- letExp (baseString v ++ "_slice") $ BasicOp $ Index arr_y slice_y' certifying cs_y . auxing aux $ letBind pat $ BasicOp $ Update safety arr_x slice_x' $ Var v' removeScalarCopy _ _ _ _ = Skip futhark-0.25.27/src/Futhark/IR/GPU/Sizes.hs000066400000000000000000000054401475065116200201360ustar00rootroot00000000000000-- | In the context of this module, a "size" is any kind of tunable -- (run-time) constant. 
module Futhark.IR.GPU.Sizes ( SizeClass (..), sizeDefault, KernelPath, Count (..), NumBlocks, BlockSize, NumThreads, ) where import Data.Int (Int64) import Data.Traversable import Futhark.IR.Prop.Names (FreeIn) import Futhark.Transform.Substitute import Futhark.Util.IntegralExp (IntegralExp) import Futhark.Util.Pretty import Language.Futhark.Core (Name) import Prelude hiding (id, (.)) -- | An indication of which comparisons have been performed to get to -- this point, as well as the result of each comparison. type KernelPath = [(Name, Bool)] -- | The class of some kind of configurable size. Each class may -- impose constraints on the valid values. data SizeClass = -- | A threshold with an optional default. SizeThreshold KernelPath (Maybe Int64) | SizeThreadBlock | -- | The number of thread blocks. SizeGrid | SizeTile | SizeRegTile | -- | Likely not useful on its own, but querying the -- maximum can be handy. SizeSharedMemory | -- | A bespoke size with a default. SizeBespoke Name Int64 | -- | Amount of registers available per threadblock. Mostly -- meaningful for querying the maximum. SizeRegisters | -- | Amount of L2 cache memory, in bytes. Mostly meaningful for -- querying the maximum. SizeCache deriving (Eq, Ord, Show) instance Pretty SizeClass where pretty (SizeThreshold path def) = "threshold" <> parens (def' <> comma <+> hsep (map pStep path)) where pStep (v, True) = pretty v pStep (v, False) = "!" <> pretty v def' = maybe "def" pretty def pretty SizeThreadBlock = "thread_block_size" pretty SizeGrid = "grid_size" pretty SizeTile = "tile_size" pretty SizeRegTile = "reg_tile_size" pretty SizeSharedMemory = "shared_memory" pretty (SizeBespoke k def) = "bespoke" <> parens (pretty k <> comma <+> pretty def) pretty SizeRegisters = "registers" pretty SizeCache = "cache" -- | The default value for the size. If 'Nothing', that means the backend gets to decide. 
sizeDefault :: SizeClass -> Maybe Int64 sizeDefault (SizeThreshold _ x) = x sizeDefault (SizeBespoke _ x) = Just x sizeDefault _ = Nothing -- | A wrapper supporting a phantom type for indicating what we are counting. newtype Count u e = Count {unCount :: e} deriving (Eq, Ord, Show, Num, IntegralExp, FreeIn, Pretty, Substitute) instance Functor (Count u) where fmap = fmapDefault instance Foldable (Count u) where foldMap = foldMapDefault instance Traversable (Count u) where traverse f (Count x) = Count <$> f x -- | Phantom type for the number of blocks of some kernel. data NumBlocks -- | Phantom type for the block size of some kernel. data BlockSize -- | Phantom type for number of threads. data NumThreads futhark-0.25.27/src/Futhark/IR/GPUMem.hs000066400000000000000000000067561475065116200175130ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} module Futhark.IR.GPUMem ( GPUMem, -- * Simplification simplifyProg, simplifyStms, simpleGPUMem, -- * Module re-exports module Futhark.IR.Mem, module Futhark.IR.GPU.Op, ) where import Futhark.Analysis.PrimExp.Convert import Futhark.Analysis.UsageTable qualified as UT import Futhark.IR.Aliases (Aliases) import Futhark.IR.GPU.Op import Futhark.IR.GPU.Simplify (simplifyKernelOp) import Futhark.IR.Mem import Futhark.IR.Mem.Simplify import Futhark.IR.TypeCheck qualified as TC import Futhark.MonadFreshNames import Futhark.Optimise.Simplify.Engine qualified as Engine import Futhark.Pass import Futhark.Pass.ExplicitAllocations (BuilderOps (..), mkLetNamesB', mkLetNamesB'') data GPUMem instance RepTypes GPUMem where type LetDec GPUMem = LetDecMem type FParamInfo GPUMem = FParamMem type LParamInfo GPUMem = LParamMem type RetType GPUMem = RetTypeMem type BranchType GPUMem = BranchTypeMem type OpC GPUMem = MemOp (HostOp NoOp) instance ASTRep GPUMem where expTypesFromPat = pure . map snd . 
bodyReturnsFromPat instance PrettyRep GPUMem instance TC.Checkable GPUMem where checkOp = typeCheckMemoryOp Nothing where -- GHC 9.2 goes into an infinite loop without the type annotation. typeCheckMemoryOp :: Maybe SegLevel -> MemOp (HostOp NoOp) (Aliases GPUMem) -> TC.TypeM GPUMem () typeCheckMemoryOp _ (Alloc size _) = TC.require [Prim int64] size typeCheckMemoryOp lvl (Inner op) = typeCheckHostOp (typeCheckMemoryOp . Just) lvl (const $ pure ()) op checkFParamDec = checkMemInfo checkLParamDec = checkMemInfo checkLetBoundDec = checkMemInfo checkRetType = mapM_ $ TC.checkExtType . declExtTypeOf primFParam name t = pure $ Param mempty name (MemPrim t) matchPat = matchPatToExp matchReturnType = matchFunctionReturnType matchBranchType = matchBranchReturnType matchLoopResult = matchLoopResultMem instance BuilderOps GPUMem where mkExpDecB _ _ = pure () mkBodyB stms res = pure $ Body () stms res mkLetNamesB = mkLetNamesB' (Space "device") () instance BuilderOps (Engine.Wise GPUMem) where mkExpDecB pat e = pure $ Engine.mkWiseExpDec pat () e mkBodyB stms res = pure $ Engine.mkWiseBody () stms res mkLetNamesB = mkLetNamesB'' (Space "device") instance TraverseOpStms (Engine.Wise GPUMem) where traverseOpStms = traverseMemOpStms (traverseHostOpStms (const pure)) simplifyProg :: Prog GPUMem -> PassM (Prog GPUMem) simplifyProg = simplifyProgGeneric memRuleBook simpleGPUMem simplifyStms :: (HasScope GPUMem m, MonadFreshNames m) => Stms GPUMem -> m (Stms GPUMem) simplifyStms = simplifyStmsGeneric memRuleBook simpleGPUMem simpleGPUMem :: Engine.SimpleOps GPUMem simpleGPUMem = simpleGeneric usage $ simplifyKernelOp $ const $ pure (NoOp, mempty) where -- Slightly hackily and very inefficiently, we look at the inside -- of SegOps to figure out the sizes of shared memory allocations, -- and add usages for those sizes. This is necessary so the -- simplifier will hoist those sizes out as far as possible (most -- importantly, past the versioning If, but see also #1569). 
usage (SegOp (SegMap _ _ _ kbody)) = localAllocs kbody usage _ = mempty localAllocs = foldMap stmSharedAlloc . kernelBodyStms stmSharedAlloc = expSharedAlloc . stmExp expSharedAlloc (Op (Alloc (Var v) _)) = UT.sizeUsage v expSharedAlloc (Op (Inner (SegOp (SegMap _ _ _ kbody)))) = localAllocs kbody expSharedAlloc _ = mempty futhark-0.25.27/src/Futhark/IR/MC.hs000066400000000000000000000037311475065116200167060ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | A representation for multicore CPU parallelism. module Futhark.IR.MC ( MC, -- * Simplification simplifyProg, -- * Module re-exports module Futhark.IR.Prop, module Futhark.IR.Traversals, module Futhark.IR.Pretty, module Futhark.IR.Syntax, module Futhark.IR.SegOp, module Futhark.IR.SOACS.SOAC, module Futhark.IR.MC.Op, ) where import Futhark.Builder import Futhark.Construct import Futhark.IR.MC.Op import Futhark.IR.Pretty import Futhark.IR.Prop import Futhark.IR.SOACS.SOAC hiding (HistOp (..)) import Futhark.IR.SOACS.Simplify qualified as SOAC import Futhark.IR.SegOp import Futhark.IR.Syntax import Futhark.IR.Traversals import Futhark.IR.TypeCheck qualified as TypeCheck import Futhark.Optimise.Simplify qualified as Simplify import Futhark.Optimise.Simplify.Engine qualified as Engine import Futhark.Optimise.Simplify.Rules import Futhark.Pass data MC instance RepTypes MC where type OpC MC = MCOp SOAC instance ASTRep MC where expTypesFromPat = pure . 
expExtTypesFromPat instance TypeCheck.Checkable MC where checkOp = typeCheckMCOp typeCheckSOAC instance Buildable MC where mkBody = Body () mkExpPat idents _ = basicPat idents mkExpDec _ _ = () mkLetNames = simpleMkLetNames instance BuilderOps MC instance BuilderOps (Engine.Wise MC) instance PrettyRep MC instance TraverseOpStms (Engine.Wise MC) where traverseOpStms = traverseMCOpStms traverseSOACStms simpleMC :: Simplify.SimpleOps MC simpleMC = Simplify.bindableSimpleOps $ simplifyMCOp SOAC.simplifySOAC simplifyProg :: Prog MC -> PassM (Prog MC) simplifyProg = Simplify.simplifyProg simpleMC rules blockers where blockers = Engine.noExtraHoistBlockers rules = standardRules <> segOpRules instance HasSegOp MC where type SegOpLevel MC = () asSegOp = const Nothing segOp = ParOp Nothing instance HasSegOp (Engine.Wise MC) where type SegOpLevel (Engine.Wise MC) = () asSegOp = const Nothing segOp = ParOp Nothing futhark-0.25.27/src/Futhark/IR/MC/000077500000000000000000000000001475065116200163465ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/IR/MC/Op.hs000066400000000000000000000126501475065116200172640ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} -- | Definitions for multicore operations. -- -- Most of the interesting stuff is in "Futhark.IR.SegOp", which is -- also re-exported from here. 
module Futhark.IR.MC.Op ( MCOp (..), traverseMCOpStms, typeCheckMCOp, simplifyMCOp, module Futhark.IR.SegOp, ) where import Data.Bifunctor (first) import Futhark.Analysis.Metrics import Futhark.Analysis.SymbolTable qualified as ST import Futhark.IR import Futhark.IR.Aliases (Aliases, CanBeAliased (..)) import Futhark.IR.Mem (OpReturns (..)) import Futhark.IR.Prop.Aliases import Futhark.IR.SegOp import Futhark.IR.TypeCheck qualified as TC import Futhark.Optimise.Simplify qualified as Simplify import Futhark.Optimise.Simplify.Engine qualified as Engine import Futhark.Optimise.Simplify.Rep import Futhark.Transform.Rename import Futhark.Transform.Substitute import Futhark.Util.Pretty ( nestedBlock, pretty, (<+>), (), ) import Prelude hiding (id, (.)) -- | An operation for the multicore representation. Feel free to -- extend this on an ad hoc basis as needed. Parameterised with some -- other operation. data MCOp op rep = -- | The first 'SegOp' (if it exists) contains nested parallelism, -- while the second one has a fully sequential body. They are -- semantically fully equivalent. ParOp (Maybe (SegOp () rep)) (SegOp () rep) | -- | Something else (in practice often a SOAC). 
OtherOp (op rep) deriving (Eq, Ord, Show) traverseMCOpStms :: (Monad m) => OpStmsTraverser m (op rep) rep -> OpStmsTraverser m (MCOp op rep) rep traverseMCOpStms _ f (ParOp par_op op) = ParOp <$> traverse (traverseSegOpStms f) par_op <*> traverseSegOpStms f op traverseMCOpStms onInner f (OtherOp op) = OtherOp <$> onInner f op instance (ASTRep rep, Substitute (op rep)) => Substitute (MCOp op rep) where substituteNames substs (ParOp par_op op) = ParOp (substituteNames substs <$> par_op) (substituteNames substs op) substituteNames substs (OtherOp op) = OtherOp $ substituteNames substs op instance (ASTRep rep, Rename (op rep)) => Rename (MCOp op rep) where rename (ParOp par_op op) = ParOp <$> rename par_op <*> rename op rename (OtherOp op) = OtherOp <$> rename op instance (ASTRep rep, FreeIn (op rep)) => FreeIn (MCOp op rep) where freeIn' (ParOp par_op op) = freeIn' par_op <> freeIn' op freeIn' (OtherOp op) = freeIn' op instance (IsOp op) => IsOp (MCOp op) where safeOp (ParOp _ op) = safeOp op safeOp (OtherOp op) = safeOp op cheapOp (ParOp _ op) = cheapOp op cheapOp (OtherOp op) = cheapOp op opDependencies (ParOp _ op) = opDependencies op opDependencies (OtherOp op) = opDependencies op instance (TypedOp op) => TypedOp (MCOp op) where opType (ParOp _ op) = opType op opType (OtherOp op) = opType op instance (AliasedOp op) => AliasedOp (MCOp op) where opAliases (ParOp _ op) = opAliases op opAliases (OtherOp op) = opAliases op consumedInOp (ParOp _ op) = consumedInOp op consumedInOp (OtherOp op) = consumedInOp op instance (CanBeAliased op) => CanBeAliased (MCOp op) where addOpAliases aliases (ParOp par_op op) = ParOp (addOpAliases aliases <$> par_op) (addOpAliases aliases op) addOpAliases aliases (OtherOp op) = OtherOp $ addOpAliases aliases op instance (CanBeWise op) => CanBeWise (MCOp op) where addOpWisdom (ParOp par_op op) = ParOp (addOpWisdom <$> par_op) (addOpWisdom op) addOpWisdom (OtherOp op) = OtherOp $ addOpWisdom op instance (ASTRep rep, ST.IndexOp (op rep)) => 
ST.IndexOp (MCOp op rep) where indexOp vtable k (ParOp _ op) is = ST.indexOp vtable k op is indexOp vtable k (OtherOp op) is = ST.indexOp vtable k op is instance OpReturns (MCOp NoOp) where opReturns (ParOp _ op) = segOpReturns op opReturns (OtherOp NoOp) = pure [] instance (PrettyRep rep, Pretty (op rep)) => Pretty (MCOp op rep) where pretty (ParOp Nothing op) = pretty op pretty (ParOp (Just par_op) op) = "par" <+> nestedBlock "{" "}" (pretty par_op) "seq" <+> nestedBlock "{" "}" (pretty op) pretty (OtherOp op) = pretty op instance (OpMetrics (Op rep), OpMetrics (op rep)) => OpMetrics (MCOp op rep) where opMetrics (ParOp par_op op) = opMetrics par_op >> opMetrics op opMetrics (OtherOp op) = opMetrics op instance (RephraseOp op) => RephraseOp (MCOp op) where rephraseInOp r (ParOp par_op op) = ParOp <$> traverse (rephraseInOp r) par_op <*> rephraseInOp r op rephraseInOp r (OtherOp op) = OtherOp <$> rephraseInOp r op typeCheckMCOp :: (TC.Checkable rep) => (op (Aliases rep) -> TC.TypeM rep ()) -> MCOp op (Aliases rep) -> TC.TypeM rep () typeCheckMCOp _ (ParOp (Just par_op) op) = do -- It is valid for the same array to be consumed in both par_op and op. 
_ <- typeCheckSegOp pure par_op `TC.alternative` typeCheckSegOp pure op pure () typeCheckMCOp _ (ParOp Nothing op) = typeCheckSegOp pure op typeCheckMCOp f (OtherOp op) = f op simplifyMCOp :: ( Engine.SimplifiableRep rep, BodyDec rep ~ () ) => Simplify.SimplifyOp rep (op (Wise rep)) -> MCOp op (Wise rep) -> Engine.SimpleM rep (MCOp op (Wise rep), Stms (Wise rep)) simplifyMCOp f (OtherOp op) = do (op', stms) <- f op pure (OtherOp op', stms) simplifyMCOp _ (ParOp par_op op) = do (par_op', par_op_hoisted) <- case par_op of Nothing -> pure (Nothing, mempty) Just x -> first Just <$> simplifySegOp x (op', op_hoisted) <- simplifySegOp op pure (ParOp par_op' op', par_op_hoisted <> op_hoisted) futhark-0.25.27/src/Futhark/IR/MCMem.hs000066400000000000000000000043301475065116200173410ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} module Futhark.IR.MCMem ( MCMem, -- * Simplification simplifyProg, -- * Module re-exports module Futhark.IR.Mem, module Futhark.IR.SegOp, module Futhark.IR.MC.Op, ) where import Futhark.Analysis.PrimExp.Convert import Futhark.IR.MC.Op import Futhark.IR.Mem import Futhark.IR.Mem.Simplify import Futhark.IR.SegOp import Futhark.IR.TypeCheck qualified as TC import Futhark.Optimise.Simplify.Engine qualified as Engine import Futhark.Pass import Futhark.Pass.ExplicitAllocations (BuilderOps (..), mkLetNamesB', mkLetNamesB'') data MCMem instance RepTypes MCMem where type LetDec MCMem = LetDecMem type FParamInfo MCMem = FParamMem type LParamInfo MCMem = LParamMem type RetType MCMem = RetTypeMem type BranchType MCMem = BranchTypeMem type OpC MCMem = MemOp (MCOp NoOp) instance ASTRep MCMem where expTypesFromPat = pure . map snd . 
bodyReturnsFromPat instance PrettyRep MCMem instance TC.Checkable MCMem where checkOp = typeCheckMemoryOp where typeCheckMemoryOp (Alloc size _) = TC.require [Prim int64] size typeCheckMemoryOp (Inner op) = typeCheckMCOp (const $ pure ()) op checkFParamDec = checkMemInfo checkLParamDec = checkMemInfo checkLetBoundDec = checkMemInfo checkRetType = mapM_ (TC.checkExtType . declExtTypeOf) primFParam name t = pure $ Param mempty name (MemPrim t) matchPat = matchPatToExp matchReturnType = matchFunctionReturnType matchBranchType = matchBranchReturnType matchLoopResult = matchLoopResultMem instance BuilderOps MCMem where mkExpDecB _ _ = pure () mkBodyB stms res = pure $ Body () stms res mkLetNamesB = mkLetNamesB' DefaultSpace () instance BuilderOps (Engine.Wise MCMem) where mkExpDecB pat e = pure $ Engine.mkWiseExpDec pat () e mkBodyB stms res = pure $ Engine.mkWiseBody () stms res mkLetNamesB = mkLetNamesB'' DefaultSpace instance TraverseOpStms (Engine.Wise MCMem) where traverseOpStms = traverseMemOpStms (traverseMCOpStms (const pure)) simplifyProg :: Prog MCMem -> PassM (Prog MCMem) simplifyProg = simplifyProgGeneric memRuleBook simpleMCMem simpleMCMem :: Engine.SimpleOps MCMem simpleMCMem = simpleGeneric (const mempty) $ simplifyMCOp $ const $ pure (NoOp, mempty) futhark-0.25.27/src/Futhark/IR/Mem.hs000066400000000000000000001171431475065116200171300ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} -- | Building blocks for defining representations where every array -- is given information about which memory block is it based in, and -- how array elements map to memory block offsets. -- -- There are two primary concepts you will need to understand: -- -- 1. Memory blocks, which are Futhark values of type v'Mem' -- (parametrized with their size). These correspond to arbitrary -- blocks of memory, and are created using the 'Alloc' operation. -- -- 2. 
Index functions, which describe a mapping from the index space -- of an array (eg. a two-dimensional space for an array of type -- @[[int]]@) to a one-dimensional offset into a memory block. -- Thus, index functions describe how arbitrary-dimensional arrays -- are mapped to the single-dimensional world of memory. -- -- At a conceptual level, imagine that we have a two-dimensional array -- @a@ of 32-bit integers, consisting of @n@ rows of @m@ elements -- each. This array could be represented in classic row-major format -- with an index function like the following: -- -- @ -- f(i,j) = i * m + j -- @ -- -- When we want to know the location of element @a[2,3]@, we simply -- call the index function as @f(2,3)@ and obtain @2*m+3@. We could -- also have chosen another index function, one that represents the -- array in column-major (or "transposed") format: -- -- @ -- f(i,j) = j * n + i -- @ -- -- Index functions are not Futhark-level functions, but a special -- construct that the final code generator will eventually use to -- generate concrete access code. By modifying the index functions we -- can change how an array is represented in memory, which can permit -- memory access pattern optimisations. -- -- Every time we bind an array, whether in a @let@-binding, @loop@ -- merge parameter, or @lambda@ parameter, we have an annotation -- specifying a memory block and an index function. In some cases, -- such as @let@-bindings for many expressions, we are free to specify -- an arbitrary index function and memory block - for example, we get -- to decide where 'Copy' stores its result - but in other cases the -- type rules of the expression chooses for us. For example, 'Index' -- always produces an array in the same memory block as its input, and -- with the same index function, except with some indices fixed. 
module Futhark.IR.Mem ( LetDecMem, FParamMem, LParamMem, RetTypeMem, BranchTypeMem, MemOp (..), traverseMemOpStms, MemInfo (..), MemBound, MemBind (..), MemReturn (..), LMAD, ExtLMAD, isStaticLMAD, ExpReturns, BodyReturns, FunReturns, noUniquenessReturns, bodyReturnsToExpReturns, Mem, HasLetDecMem (..), OpReturns (..), varReturns, expReturns, extReturns, nameInfoToMemInfo, lookupMemInfo, subExpMemInfo, lookupArraySummary, lookupMemSpace, existentialiseLMAD, -- * Type checking parts matchBranchReturnType, matchPatToExp, matchFunctionReturnType, matchLoopResultMem, bodyReturnsFromPat, checkMemInfo, -- * Module re-exports module Futhark.IR.Prop, module Futhark.IR.Traversals, module Futhark.IR.Pretty, module Futhark.IR.Syntax, module Futhark.Analysis.PrimExp.Convert, ) where import Control.Category import Control.Monad import Control.Monad.Except import Control.Monad.Reader import Control.Monad.State import Data.Foldable (traverse_) import Data.Function ((&)) import Data.Kind qualified import Data.List (elemIndex, find) import Data.Map.Strict qualified as M import Data.Maybe import Data.Text qualified as T import Futhark.Analysis.Metrics import Futhark.Analysis.PrimExp.Convert import Futhark.Analysis.PrimExp.Simplify import Futhark.Analysis.SymbolTable qualified as ST import Futhark.IR.Aliases ( Aliases, CanBeAliased (..), removeExpAliases, removePatAliases, removeScopeAliases, ) import Futhark.IR.Mem.LMAD qualified as LMAD import Futhark.IR.Pretty import Futhark.IR.Prop import Futhark.IR.Prop.Aliases import Futhark.IR.Syntax import Futhark.IR.Traversals import Futhark.IR.TypeCheck qualified as TC import Futhark.Optimise.Simplify.Engine qualified as Engine import Futhark.Optimise.Simplify.Rep import Futhark.Transform.Rename import Futhark.Transform.Substitute import Futhark.Util import Futhark.Util.Pretty (docText, indent, ppTupleLines', pretty, (<+>), ()) import Futhark.Util.Pretty qualified as PP import Prelude hiding (id, (.)) type LetDecMem = MemInfo SubExp 
NoUniqueness MemBind type FParamMem = MemInfo SubExp Uniqueness MemBind type LParamMem = MemInfo SubExp NoUniqueness MemBind type RetTypeMem = FunReturns type BranchTypeMem = BodyReturns -- | The class of pattern element decorators that contain memory -- information. class HasLetDecMem t where letDecMem :: t -> LetDecMem instance HasLetDecMem LetDecMem where letDecMem = id instance (HasLetDecMem b) => HasLetDecMem (a, b) where letDecMem = letDecMem . snd type Mem rep inner = ( FParamInfo rep ~ FParamMem, LParamInfo rep ~ LParamMem, HasLetDecMem (LetDec rep), RetType rep ~ RetTypeMem, BranchType rep ~ BranchTypeMem, ASTRep rep, OpReturns inner, RephraseOp inner, ASTConstraints (inner rep), FreeIn (inner rep), OpC rep ~ MemOp inner ) instance IsRetType FunReturns where primRetType = MemPrim applyRetType = applyFunReturns instance IsBodyType BodyReturns where primBodyType = MemPrim data MemOp (inner :: Data.Kind.Type -> Data.Kind.Type) (rep :: Data.Kind.Type) = -- | Allocate a memory block. Alloc SubExp Space | Inner (inner rep) deriving (Eq, Ord, Show) -- | A helper for defining 'TraverseOpStms'. 
traverseMemOpStms :: (Monad m) => OpStmsTraverser m (inner rep) rep -> OpStmsTraverser m (MemOp inner rep) rep traverseMemOpStms _ _ op@Alloc {} = pure op traverseMemOpStms onInner f (Inner inner) = Inner <$> onInner f inner instance (RephraseOp inner) => RephraseOp (MemOp inner) where rephraseInOp _ (Alloc e space) = pure (Alloc e space) rephraseInOp r (Inner x) = Inner <$> rephraseInOp r x instance (FreeIn (inner rep)) => FreeIn (MemOp inner rep) where freeIn' (Alloc size _) = freeIn' size freeIn' (Inner k) = freeIn' k instance (TypedOp inner) => TypedOp (MemOp inner) where opType (Alloc _ space) = pure [Mem space] opType (Inner k) = opType k instance (AliasedOp inner) => AliasedOp (MemOp inner) where opAliases Alloc {} = [mempty] opAliases (Inner k) = opAliases k consumedInOp Alloc {} = mempty consumedInOp (Inner k) = consumedInOp k instance (CanBeAliased inner) => CanBeAliased (MemOp inner) where addOpAliases _ (Alloc se space) = Alloc se space addOpAliases aliases (Inner k) = Inner $ addOpAliases aliases k instance (Rename (inner rep)) => Rename (MemOp inner rep) where rename (Alloc size space) = Alloc <$> rename size <*> pure space rename (Inner k) = Inner <$> rename k instance (Substitute (inner rep)) => Substitute (MemOp inner rep) where substituteNames subst (Alloc size space) = Alloc (substituteNames subst size) space substituteNames subst (Inner k) = Inner $ substituteNames subst k instance (PP.Pretty (inner rep)) => PP.Pretty (MemOp inner rep) where pretty (Alloc e DefaultSpace) = "alloc" <> PP.apply [PP.pretty e] pretty (Alloc e s) = "alloc" <> PP.apply [PP.pretty e, PP.pretty s] pretty (Inner k) = PP.pretty k instance (OpMetrics (inner rep)) => OpMetrics (MemOp inner rep) where opMetrics Alloc {} = seen "Alloc" opMetrics (Inner k) = opMetrics k instance (IsOp inner) => IsOp (MemOp inner) where safeOp (Alloc (Constant (IntValue (Int64Value k))) _) = k >= 0 safeOp Alloc {} = False safeOp (Inner k) = safeOp k cheapOp (Inner k) = cheapOp k cheapOp Alloc 
{} = True opDependencies (Alloc _ e) = [freeIn e] opDependencies (Inner op) = opDependencies op instance (CanBeWise inner) => CanBeWise (MemOp inner) where addOpWisdom (Alloc size space) = Alloc size space addOpWisdom (Inner k) = Inner $ addOpWisdom k instance (ST.IndexOp (inner rep)) => ST.IndexOp (MemOp inner rep) where indexOp vtable k (Inner op) is = ST.indexOp vtable k op is indexOp _ _ _ _ = Nothing -- | The LMAD representation used for memory annotations. type LMAD = LMAD.LMAD (TPrimExp Int64 VName) -- | An index function that may contain existential variables. type ExtLMAD = LMAD.LMAD (TPrimExp Int64 (Ext VName)) -- | A summary of the memory information for every let-bound -- identifier, function parameter, and return value. Parameterisered -- over uniqueness, dimension, and auxiliary array information. data MemInfo d u ret = -- | A primitive value. MemPrim PrimType | -- | A memory block. MemMem Space | -- | The array is stored in the named memory block, and with the -- given index function. The index function maps indices in the -- array to /element/ offset, /not/ byte offsets! To translate to -- byte offsets, multiply the offset with the size of the array -- element type. MemArray PrimType (ShapeBase d) u ret | -- | An accumulator, which is not stored anywhere. MemAcc VName Shape [Type] u deriving (Eq, Show, Ord) --- XXX Ord? type MemBound u = MemInfo SubExp u MemBind instance (FixExt ret) => DeclExtTyped (MemInfo ExtSize Uniqueness ret) where declExtTypeOf (MemPrim pt) = Prim pt declExtTypeOf (MemMem space) = Mem space declExtTypeOf (MemArray pt shape u _) = Array pt shape u declExtTypeOf (MemAcc acc ispace ts u) = Acc acc ispace ts u instance (FixExt ret) => ExtTyped (MemInfo ExtSize Uniqueness ret) where extTypeOf = fromDecl . 
declExtTypeOf

instance (FixExt ret) => ExtTyped (MemInfo ExtSize NoUniqueness ret) where
  extTypeOf (MemPrim pt) = Prim pt
  extTypeOf (MemMem space) = Mem space
  extTypeOf (MemArray pt shape u _) = Array pt shape u
  extTypeOf (MemAcc acc ispace ts u) = Acc acc ispace ts u

instance (FixExt ret) => FixExt (MemInfo ExtSize u ret) where
  fixExt _ _ (MemPrim pt) = MemPrim pt
  fixExt _ _ (MemMem space) = MemMem space
  fixExt _ _ (MemAcc acc ispace ts u) = MemAcc acc ispace ts u
  fixExt i se (MemArray pt shape u ret) =
    MemArray pt (fixExt i se shape) u (fixExt i se ret)

  mapExt _ (MemPrim pt) = MemPrim pt
  mapExt _ (MemMem space) = MemMem space
  mapExt _ (MemAcc acc ispace ts u) = MemAcc acc ispace ts u
  mapExt f (MemArray pt shape u ret) =
    MemArray pt (mapExt f shape) u (mapExt f ret)

instance Typed (MemInfo SubExp Uniqueness ret) where
  typeOf = fromDecl . declTypeOf

instance Typed (MemInfo SubExp NoUniqueness ret) where
  typeOf (MemPrim pt) = Prim pt
  typeOf (MemMem space) = Mem space
  typeOf (MemArray bt shape u _) = Array bt shape u
  typeOf (MemAcc acc ispace ts u) = Acc acc ispace ts u

instance DeclTyped (MemInfo SubExp Uniqueness ret) where
  declTypeOf (MemPrim bt) = Prim bt
  declTypeOf (MemMem space) = Mem space
  declTypeOf (MemArray bt shape u _) = Array bt shape u
  declTypeOf (MemAcc acc ispace ts u) = Acc acc ispace ts u

instance (FreeIn d, FreeIn ret) => FreeIn (MemInfo d u ret) where
  freeIn' (MemArray _ shape _ ret) = freeIn' shape <> freeIn' ret
  freeIn' (MemMem s) = freeIn' s
  freeIn' MemPrim {} = mempty
  freeIn' (MemAcc acc ispace ts _) = freeIn' (acc, ispace, ts)

instance (Substitute d, Substitute ret) => Substitute (MemInfo d u ret) where
  substituteNames subst (MemArray bt shape u ret) =
    MemArray
      bt
      (substituteNames subst shape)
      u
      (substituteNames subst ret)
  substituteNames substs (MemAcc acc ispace ts u) =
    MemAcc
      (substituteNames substs acc)
      (substituteNames substs ispace)
      (substituteNames substs ts)
      u
  substituteNames _ (MemMem space) =
    MemMem space
  substituteNames _ (MemPrim bt) =
    MemPrim bt

instance (Substitute d, Substitute ret) => Rename (MemInfo d u ret) where
  rename = substituteRename

-- | Simplify every expression of an 'LMAD' with 'simplifyPrimExp'.
simplifyLMAD ::
  (Engine.SimplifiableRep rep) =>
  LMAD ->
  Engine.SimpleM rep LMAD
simplifyLMAD = traverse $ fmap isInt64 . simplifyPrimExp . untyped

-- | Simplify every expression of an 'ExtLMAD' with 'simplifyExtPrimExp'.
simplifyExtLMAD ::
  (Engine.SimplifiableRep rep) =>
  ExtLMAD ->
  Engine.SimpleM rep ExtLMAD
simplifyExtLMAD = traverse $ fmap isInt64 . simplifyExtPrimExp . untyped

-- | Convert an 'ExtLMAD' to an 'LMAD' if it mentions no existentials
-- ('Ext'); otherwise 'Nothing'.
isStaticLMAD :: ExtLMAD -> Maybe LMAD
isStaticLMAD = traverse $ traverse inst
  where
    inst Ext {} = Nothing
    inst (Free x) = Just x

instance
  (Engine.Simplifiable d, Engine.Simplifiable ret) =>
  Engine.Simplifiable (MemInfo d u ret)
  where
  simplify (MemPrim bt) =
    pure $ MemPrim bt
  simplify (MemMem space) =
    pure $ MemMem space
  simplify (MemArray bt shape u ret) =
    MemArray bt
      <$> Engine.simplify shape
      <*> pure u
      <*> Engine.simplify ret
  simplify (MemAcc acc ispace ts u) =
    MemAcc
      <$> Engine.simplify acc
      <*> Engine.simplify ispace
      <*> Engine.simplify ts
      <*> pure u

instance
  ( PP.Pretty (ShapeBase d),
    PP.Pretty (TypeBase (ShapeBase d) u),
    PP.Pretty d,
    PP.Pretty u,
    PP.Pretty ret
  ) =>
  PP.Pretty (MemInfo d u ret)
  where
  pretty (MemPrim bt) = PP.pretty bt
  pretty (MemMem DefaultSpace) = "mem"
  pretty (MemMem s) = "mem" <> PP.pretty s
  pretty (MemArray bt shape u ret) =
    PP.pretty (Array bt shape u) <+> "@" <+> PP.pretty ret
  pretty (MemAcc acc ispace ts u) =
    PP.pretty u <> PP.pretty (Acc acc ispace ts NoUniqueness :: Type)

-- | Memory information for an array bound somewhere in the program.
data MemBind
  = -- | Located in this memory block with this index
    -- function.
    ArrayIn VName LMAD
  deriving (Show)

-- All 'MemBind's compare equal.  NOTE(review): presumably this is so
-- memory annotations never influence equality/ordering of the
-- structures they decorate; confirm intent.
instance Eq MemBind where
  _ == _ = True

instance Ord MemBind where
  _ `compare` _ = EQ

instance Rename MemBind where
  rename = substituteRename

instance Substitute MemBind where
  substituteNames substs (ArrayIn ident lmad) =
    ArrayIn (substituteNames substs ident) (substituteNames substs lmad)

instance PP.Pretty MemBind where
  pretty (ArrayIn mem lmad) =
    PP.pretty mem <+> "->" PP.</> PP.pretty lmad

instance FreeIn MemBind where
  freeIn' (ArrayIn mem lmad) = freeIn' mem <> freeIn' lmad

-- | A description of the memory properties of an array being returned
-- by an operation.
data MemReturn
  = -- | The array is located in a memory block that is
    -- already in scope.
    ReturnsInBlock VName ExtLMAD
  | -- | The operation returns a new (existential) memory
    -- block.
    ReturnsNewBlock Space Int ExtLMAD
  deriving (Show)

-- Like 'MemBind', all 'MemReturn's compare equal.
instance Eq MemReturn where
  _ == _ = True

instance Ord MemReturn where
  _ `compare` _ = EQ

instance Rename MemReturn where
  rename = substituteRename

instance Substitute MemReturn where
  substituteNames substs (ReturnsInBlock ident lmad) =
    ReturnsInBlock (substituteNames substs ident) (substituteNames substs lmad)
  substituteNames substs (ReturnsNewBlock space i lmad) =
    ReturnsNewBlock space i (substituteNames substs lmad)

instance FixExt MemReturn where
  -- Fixing the existential that *is* the new memory block to a variable
  -- turns the return into one located in that (now known) block.
  fixExt i (Var v) (ReturnsNewBlock _ j lmad)
    | j == i =
        ReturnsInBlock v $
          fixExtLMAD i (primExpFromSubExp int64 (Var v)) lmad
  fixExt i se (ReturnsNewBlock space j lmad) =
    ReturnsNewBlock space j' (fixExtLMAD i (primExpFromSubExp int64 se) lmad)
    where
      -- Existentials after the fixed one shift down by one.
      j'
        | i < j = j - 1
        | otherwise = j
  fixExt i se (ReturnsInBlock mem lmad) =
    ReturnsInBlock mem (fixExtLMAD i (primExpFromSubExp int64 se) lmad)

  -- NOTE(review): for 'ReturnsNewBlock' only the block index is
  -- remapped; existentials inside the LMAD are left as-is, unlike the
  -- 'ReturnsInBlock' case.  Confirm this asymmetry is intended.
  mapExt f (ReturnsNewBlock space i lmad) =
    ReturnsNewBlock space (f i) lmad
  mapExt f (ReturnsInBlock mem lmad) =
    ReturnsInBlock mem (fmap (fmap f') lmad)
    where
      f' (Ext i) = Ext $ f i
      f' v = v

-- | Replace existential @i@ in the LMAD with the given expression, and
-- decrement every existential index greater than @i@.
fixExtLMAD :: Int -> PrimExp VName -> ExtLMAD -> ExtLMAD
fixExtLMAD i e = fmap $ isInt64 . replaceInPrimExp update . untyped
  where
    update (Ext j) t
      | j > i = LeafExp (Ext $ j - 1) t
      | j == i = fmap Free e
      | otherwise = LeafExp (Ext j) t
    update (Free x) t = LeafExp (Free x) t

-- | An @int64@ leaf expression referring to existential @i@.
leafExp :: Int -> TPrimExp Int64 (Ext a)
leafExp i = isInt64 $ LeafExp (Ext i) int64

-- | Turn an 'LMAD' into an 'ExtLMAD', replacing each name in @ctx@ by an
-- existential referring to its position in @ctx@.  If a name occurs
-- more than once, the last position is used ('M.fromList' keeps the
-- last binding).
existentialiseLMAD :: [VName] -> LMAD -> ExtLMAD
existentialiseLMAD ctx = LMAD.substitute ctx' . fmap (fmap Free)
  where
    ctx' = M.map leafExp $ M.fromList $ zip (map Free ctx) [0 ..]

instance PP.Pretty MemReturn where
  pretty (ReturnsInBlock v lmad) =
    pretty v <+> "->" PP.</> PP.pretty lmad
  pretty (ReturnsNewBlock space i lmad) =
    "?" <> pretty i <> PP.pretty space <+> "->" PP.</> PP.pretty lmad

instance FreeIn MemReturn where
  freeIn' (ReturnsInBlock v lmad) = freeIn' v <> freeIn' lmad
  freeIn' (ReturnsNewBlock space _ lmad) = freeIn' space <> freeIn' lmad

instance Engine.Simplifiable MemReturn where
  simplify (ReturnsNewBlock space i lmad) =
    ReturnsNewBlock space i <$> simplifyExtLMAD lmad
  simplify (ReturnsInBlock v lmad) =
    ReturnsInBlock <$> Engine.simplify v <*> simplifyExtLMAD lmad

instance Engine.Simplifiable MemBind where
  simplify (ArrayIn mem lmad) =
    ArrayIn <$> Engine.simplify mem <*> simplifyLMAD lmad

instance Engine.Simplifiable [FunReturns] where
  simplify = mapM Engine.simplify

-- | The memory return of an expression.  An array is annotated with
-- @Maybe MemReturn@, which can be interpreted as the expression
-- either dictating exactly where the array is located when it is
-- returned (if 'Just'), or able to put it wherever the binding
-- prefers (if 'Nothing').
--
-- This is necessary to capture the difference between an expression
-- that is just an array-typed variable, in which the array being
-- "returned" is located where it already is, and a @copy@ expression,
-- whose entire purpose is to store an existing array in some
-- arbitrary location.  This is a consequence of the design decision
-- never to have implicit memory copies.
type ExpReturns = MemInfo ExtSize NoUniqueness (Maybe MemReturn)

-- | The return of a body, which must always indicate where
-- returned arrays are located.
type BodyReturns = MemInfo ExtSize NoUniqueness MemReturn

-- | The memory return of a function, which must always indicate where
-- returned arrays are located.
type FunReturns = MemInfo ExtSize Uniqueness MemReturn

-- | Wrap the return annotation of an array in 'Just'; other cases are
-- passed through unchanged.
maybeReturns :: MemInfo d u r -> MemInfo d u (Maybe r)
maybeReturns (MemArray bt shape u ret) = MemArray bt shape u $ Just ret
maybeReturns (MemPrim bt) = MemPrim bt
maybeReturns (MemMem space) = MemMem space
maybeReturns (MemAcc acc ispace ts u) = MemAcc acc ispace ts u

-- | Forget the uniqueness information of arrays and accumulators.
noUniquenessReturns :: MemInfo d u r -> MemInfo d NoUniqueness r
noUniquenessReturns (MemArray bt shape _ r) = MemArray bt shape NoUniqueness r
noUniquenessReturns (MemPrim bt) = MemPrim bt
noUniquenessReturns (MemMem space) = MemMem space
noUniquenessReturns (MemAcc acc ispace ts _) = MemAcc acc ispace ts NoUniqueness

funReturnsToExpReturns :: FunReturns -> ExpReturns
funReturnsToExpReturns = noUniquenessReturns . maybeReturns

bodyReturnsToExpReturns :: BodyReturns -> ExpReturns
bodyReturnsToExpReturns = noUniquenessReturns . maybeReturns

-- | The 'ExpReturns' for a variable with the given memory information:
-- an array is reported as living in its existing block, with its LMAD
-- lifted to an 'ExtLMAD' that mentions no existentials.
varInfoToExpReturns :: MemInfo SubExp NoUniqueness MemBind -> ExpReturns
varInfoToExpReturns (MemArray et shape u (ArrayIn mem lmad)) =
  MemArray et (fmap Free shape) u $
    Just $
      ReturnsInBlock mem $
        existentialiseLMAD [] lmad
varInfoToExpReturns (MemPrim pt) = MemPrim pt
varInfoToExpReturns (MemAcc acc ispace ts u) = MemAcc acc ispace ts u
varInfoToExpReturns (MemMem space) = MemMem space

-- | Check that the given results match the return types, using the
-- memory information of each result as found in the current scope
-- (with aliases removed).
matchRetTypeToResult ::
  (Mem rep inner, TC.Checkable rep) =>
  [FunReturns] ->
  Result ->
  TC.TypeM rep ()
matchRetTypeToResult rettype result = do
  scope <- askScope
  result_ts <-
    runReaderT (mapM (subExpMemInfo . resSubExp) result) $
      removeScopeAliases scope
  matchReturnType rettype (map resSubExp result) result_ts

-- | Like 'matchRetTypeToResult', but additionally requires every array
-- returned from a function to have a direct ('LMAD.isDirect') index
-- function.
matchFunctionReturnType ::
  (Mem rep inner, TC.Checkable rep) =>
  [FunReturns] ->
  Result ->
  TC.TypeM rep ()
matchFunctionReturnType rettype result = do
  matchRetTypeToResult rettype result
  mapM_ (checkResultSubExp . resSubExp) result
  where
    checkResultSubExp Constant {} =
      pure ()
    checkResultSubExp (Var v) = do
      dec <- varMemInfo v
      case dec of
        MemPrim _ -> pure ()
        MemMem {} -> pure ()
        MemAcc {} -> pure ()
        MemArray _ _ _ (ArrayIn _ lmad)
          | LMAD.isDirect lmad ->
              pure ()
          | otherwise ->
              TC.bad . TC.TypeError $
                "Array "
                  <> prettyText v
                  <> " returned by function, but has nontrivial index function:\n"
                  <> prettyText lmad

-- | Check the result of a loop body against the loop parameters, by
-- treating the parameters as the return type of a function.  Memory
-- blocks and sizes that are themselves loop parameters become
-- existentials referring to the parameter's position.
matchLoopResultMem ::
  (Mem rep inner, TC.Checkable rep) =>
  [FParam (Aliases rep)] ->
  Result ->
  TC.TypeM rep ()
matchLoopResultMem params = matchRetTypeToResult rettype
  where
    param_names = map paramName params

    -- Invent a ReturnType so we can pretend that the loop body is
    -- actually returning from a function.
    rettype = map (toRet . paramDec) params

    toExtV v
      | Just i <- v `elemIndex` param_names = Ext i
      | otherwise = Free v

    toExtSE (Var v) = Var <$> toExtV v
    toExtSE (Constant v) = Free $ Constant v

    toRet (MemPrim t) = MemPrim t
    toRet (MemMem space) = MemMem space
    toRet (MemAcc acc ispace ts u) = MemAcc acc ispace ts u
    toRet (MemArray pt shape u (ArrayIn mem lmad))
      | Just i <- mem `elemIndex` param_names,
        Param _ _ (MemMem space) : _ <- drop i params =
          MemArray pt shape' u $ ReturnsNewBlock space i lmad'
      | otherwise =
          MemArray pt shape' u $ ReturnsInBlock mem lmad'
      where
        shape' = fmap toExtSE shape
        lmad' = existentialiseLMAD param_names lmad

-- | Check that the results of a branch body match the branch return
-- types.  The body's own statements are in scope for the results.
matchBranchReturnType ::
  (Mem rep inner, TC.Checkable rep) =>
  [BodyReturns] ->
  Body (Aliases rep) ->
  TC.TypeM rep ()
matchBranchReturnType rettype (Body _ stms res) = do
  scope <- askScope
  ts <-
    runReaderT (mapM (subExpMemInfo . resSubExp) res) $
      removeScopeAliases (scope <> scopeOf stms)
  matchReturnType rettype (map resSubExp res) ts

-- | Helper function for index function unification.
--
-- The first return value maps a VName (wrapped in 'Free') to its Int
-- (wrapped in 'Ext').  In case of duplicates, it is mapped to the
-- *first* Int that occurs.
--
-- The second return value maps each Int (wrapped in an 'Ext') to a
-- 'LeafExp' 'Ext' with the Int at which its associated VName first
-- occurs.
getExtMaps ::
  [(VName, Int)] ->
  ( M.Map (Ext VName) (TPrimExp Int64 (Ext VName)),
    M.Map (Ext VName) (TPrimExp Int64 (Ext VName))
  )
getExtMaps ctx_lst_ids =
  ( M.map leafExp $ M.mapKeys Free $ M.fromListWith (const id) ctx_lst_ids,
    M.fromList $
      mapMaybe
        ( traverse
            ( fmap (\i -> isInt64 $ LeafExp (Ext i) int64)
                . (`lookup` ctx_lst_ids)
            )
            . uncurry (flip (,))
            . fmap Ext
        )
        ctx_lst_ids
  )

-- | Check that each result (given as the 'SubExp' and its memory
-- information) matches the corresponding return type.  Existential
-- sizes and memory blocks in the return type refer to positions in the
-- results.  Reports a type error on any mismatch, including a
-- differing number of return types and results.
matchReturnType ::
  (PP.Pretty u) =>
  [MemInfo ExtSize u MemReturn] ->
  [SubExp] ->
  [MemInfo SubExp NoUniqueness MemBind] ->
  TC.TypeM rep ()
matchReturnType rettype res ts = do
  let existentialiseLMAD0 :: LMAD -> ExtLMAD
      existentialiseLMAD0 = fmap $ fmap Free

      -- The result (and its memory info) that existential #i refers to.
      fetchCtx i =
        case maybeNth i $ zip res ts of
          Nothing ->
            throwError $
              "Cannot find variable #" <> prettyText i <> " in results: " <> prettyText res
          Just (se, t) -> pure (se, t)

      checkReturn (MemPrim x) (MemPrim y)
        | x == y = pure ()
      checkReturn (MemMem x) (MemMem y)
        | x == y = pure ()
      checkReturn (MemAcc xacc xispace xts _) (MemAcc yacc yispace yts _)
        | (xacc, xispace, xts) == (yacc, yispace, yts) =
            pure ()
      checkReturn (MemArray x_pt x_shape _ x_ret) (MemArray y_pt y_shape _ y_ret)
        | x_pt == y_pt,
          shapeRank x_shape == shapeRank y_shape = do
            zipWithM_ checkDim (shapeDims x_shape) (shapeDims y_shape)
            checkMemReturn x_ret y_ret
      checkReturn x y =
        throwError $ T.unwords ["Expected", prettyText x, "but got", prettyText y]

      checkDim (Free x) y
        | x == y = pure ()
        | otherwise =
            throwError $ T.unwords ["Expected dim", prettyText x, "but got", prettyText y]
      checkDim (Ext i) y = do
        (x, _) <- fetchCtx i
        unless (x == y) . throwError . T.unwords $
          ["Expected ext dim", prettyText i, "=>", prettyText x, "but got", prettyText y]

      checkMemReturn (ReturnsInBlock x_mem x_lmad) (ArrayIn y_mem y_lmad)
        | x_mem == y_mem =
            unless (LMAD.closeEnough x_lmad $ existentialiseLMAD0 y_lmad) $
              throwError . T.unwords $
                [ "Index function unification failed (ReturnsInBlock)",
                  "\nlmad of body result: ",
                  prettyText y_lmad,
                  "\nlmad of return type: ",
                  prettyText x_lmad
                ]
      checkMemReturn (ReturnsNewBlock x_space x_ext x_lmad) (ArrayIn y_mem y_lmad) = do
        (x_mem, x_mem_type) <- fetchCtx x_ext
        unless (LMAD.closeEnough x_lmad $ existentialiseLMAD0 y_lmad) $
          throwError . docText $
            "Index function unification failed (ReturnsNewBlock)"
              </> "Lmad of body result:"
              </> indent 2 (pretty y_lmad)
              </> "Lmad of return type:"
              </> indent 2 (pretty x_lmad)
        case x_mem_type of
          MemMem y_space ->
            unless (x_space == y_space) . throwError . T.unwords $
              [ "Expected memory",
                prettyText y_mem,
                "in space",
                prettyText x_space,
                "but actually in space",
                prettyText y_space
              ]
          t ->
            throwError . T.unwords $
              [ "Expected memory",
                prettyText x_ext,
                "=>",
                prettyText x_mem,
                "but has type",
                prettyText t
              ]
      checkMemReturn x y =
        throwError . docText $
          "Expected array in"
            </> indent 2 (pretty x)
            </> "but array returned in"
            </> indent 2 (pretty y)

      bad s =
        TC.bad . TC.TypeError . docText $
          "Return type"
            </> indent 2 (ppTupleLines' $ map pretty rettype)
            </> "cannot match returns of results"
            </> indent 2 (ppTupleLines' $ map pretty ts)
            </> pretty s

  unless (length rettype == length ts) $
    TC.bad . TC.TypeError . docText $
      "Return type"
        </> indent 2 (ppTupleLines' $ map pretty rettype)
        </> "does not have same number of elements as results"
        </> indent 2 (ppTupleLines' $ map pretty ts)

  either bad pure =<< runExceptT (zipWithM_ checkReturn rettype ts)

-- | Check that the pattern's memory annotations are compatible with what
-- the expression returns ('expReturns').  Reports a type error if the
-- expression's returns cannot be computed or do not match.
matchPatToExp ::
  (Mem rep inner, LetDec rep ~ LetDecMem, TC.Checkable rep) =>
  Pat (LetDec (Aliases rep)) ->
  Exp (Aliases rep) ->
  TC.TypeM rep ()
matchPatToExp pat e = do
  scope <- asksScope removeScopeAliases
  rt <- maybe illformed pure $ runReader (expReturns $ removeExpAliases e) scope

  let (ctx_ids, val_ts) = unzip $ bodyReturnsFromPat $ removePatAliases pat
      -- NOTE(review): only the first two pattern elements are numbered
      -- here.  If every pattern element is meant to be mapped, this
      -- should be @[0 ..]@; confirm before changing.
      (ctx_map_ids, ctx_map_exts) =
        getExtMaps $ zip ctx_ids [0 .. 1]
      ok =
        length val_ts == length rt
          && and (zipWith (matches ctx_map_ids ctx_map_exts) val_ts rt)

  unless ok . TC.bad . TC.TypeError . docText $
    "Expression type:"
      </> indent 2 (ppTupleLines' $ map pretty rt)
      </> "cannot match pattern type:"
      </> indent 2 (ppTupleLines' $ map pretty val_ts)
  where
    illformed =
      TC.bad $
        TC.TypeError . docText $
          "Expression"
            </> indent 2 (pretty e)
            </> "cannot be assigned an index function."

    matches _ _ (MemPrim x) (MemPrim y) = x == y
    matches _ _ (MemMem x_space) (MemMem y_space) =
      x_space == y_space
    matches _ _ (MemAcc x_accs x_ispace x_ts _) (MemAcc y_accs y_ispace y_ts _) =
      (x_accs, x_ispace, x_ts) == (y_accs, y_ispace, y_ts)
    matches ctxids ctxexts (MemArray x_pt x_shape _ x_ret) (MemArray y_pt y_shape _ y_ret) =
      x_pt == y_pt
        && x_shape == y_shape
        && case (x_ret, y_ret) of
          (ReturnsInBlock _ x_lmad, Just (ReturnsInBlock _ y_lmad)) ->
            lmadsMatch x_lmad y_lmad
          (ReturnsInBlock _ x_lmad, Just (ReturnsNewBlock _ _ y_lmad)) ->
            lmadsMatch x_lmad y_lmad
          (ReturnsNewBlock _ x_i x_lmad, Just (ReturnsNewBlock _ y_i y_lmad)) ->
            x_i == y_i && lmadsMatch x_lmad y_lmad
          -- The expression lets the binding decide where the array goes.
          (_, Nothing) -> True
          _ -> False
      where
        -- Normalise pattern names and expression existentials to the
        -- same existential numbering before comparing.
        lmadsMatch x_lmad y_lmad =
          LMAD.closeEnough
            (LMAD.substitute ctxids x_lmad)
            (LMAD.substitute ctxexts y_lmad)
    matches _ _ _ _ = False

-- | The memory information of a variable as seen by the type checker.
varMemInfo ::
  (Mem rep inner) =>
  VName ->
  TC.TypeM rep (MemInfo SubExp NoUniqueness MemBind)
varMemInfo name = do
  dec <- TC.lookupVar name

  case dec of
    LetName (_, summary) -> pure $ letDecMem summary
    FParamName summary -> pure $ noUniquenessReturns summary
    LParamName summary -> pure summary
    IndexName it -> pure $ MemPrim $ IntType it

-- | Turn info into memory information.
nameInfoToMemInfo :: (Mem rep inner) => NameInfo rep -> MemBound NoUniqueness
nameInfoToMemInfo info =
  case info of
    FParamName summary -> noUniquenessReturns summary
    LParamName summary -> summary
    LetName summary -> letDecMem summary
    IndexName it -> MemPrim $ IntType it

-- | Look up the memory information of the variable with this name.
lookupMemInfo ::
  (HasScope rep m, Mem rep inner) =>
  VName ->
  m (MemInfo SubExp NoUniqueness MemBind)
lookupMemInfo = fmap nameInfoToMemInfo . lookupInfo

-- | The memory information of a 'SubExp'; constants are primitive.
subExpMemInfo ::
  (HasScope rep m, Mem rep inner) =>
  SubExp ->
  m (MemInfo SubExp NoUniqueness MemBind)
subExpMemInfo (Var v) = lookupMemInfo v
subExpMemInfo (Constant v) = pure $ MemPrim $ primValueType v

-- | The memory block and index function of an array variable.  Calls
-- 'error' if the variable is not an array.
lookupArraySummary ::
  (Mem rep inner, HasScope rep m, Monad m) =>
  VName ->
  m (VName, LMAD.LMAD (TPrimExp Int64 VName))
lookupArraySummary name = do
  summary <- lookupMemInfo name
  case summary of
    MemArray _ _ _ (ArrayIn mem lmad) ->
      pure (mem, lmad)
    _ ->
      error . T.unpack $
        "Expected "
          <> prettyText name
          <> " to be array but bound to:\n"
          <> prettyText summary

-- | The space of a memory block variable.  Calls 'error' if the
-- variable is not a memory block.
lookupMemSpace ::
  (Mem rep inner, HasScope rep m, Monad m) =>
  VName ->
  m Space
lookupMemSpace name = do
  summary <- lookupMemInfo name
  case summary of
    MemMem space ->
      pure space
    _ ->
      error . T.unpack $
        "Expected "
          <> prettyText name
          <> " to be memory but bound to:\n"
          <> prettyText summary

-- | Type check a memory annotation.  For an array, the memory block must
-- have memory type, and the index function must contain only @int64@
-- expressions and have the same shape as the array.
checkMemInfo ::
  (TC.Checkable rep) =>
  VName ->
  MemInfo SubExp u MemBind ->
  TC.TypeM rep ()
checkMemInfo _ (MemPrim _) = pure ()
checkMemInfo _ (MemMem (ScalarSpace d _)) = mapM_ (TC.require [Prim int64]) d
checkMemInfo _ (MemMem _) = pure ()
checkMemInfo _ (MemAcc acc ispace ts u) =
  TC.checkType $ Acc acc ispace ts u
checkMemInfo name (MemArray _ shape _ (ArrayIn v lmad)) = do
  t <- lookupType v
  case t of
    Mem {} ->
      pure ()
    _ ->
      TC.bad $
        TC.TypeError $
          "Variable "
            <> prettyText v
            <> " used as memory block, but is of type "
            <> prettyText t
            <> "."

  TC.context ("in index function " <> prettyText lmad) $ do
    traverse_ (TC.requirePrimExp int64 . untyped) lmad
    unless (LMAD.shape lmad == map pe64 (shapeDims shape)) $
      TC.bad $
        TC.TypeError $
          "Shape of index function ("
            <> prettyText (LMAD.shape lmad)
            <> ") does not match shape of array "
            <> prettyText name
            <> " ("
            <> prettyText shape
            <> ")"

-- | Describe each pattern element as a 'BodyReturns'.  Sizes and memory
-- blocks that are themselves bound by the pattern become existentials
-- referring to their position in the pattern.
bodyReturnsFromPat ::
  Pat (MemBound NoUniqueness) -> [(VName, BodyReturns)]
bodyReturnsFromPat pat =
  map asReturns $ patElems pat
  where
    ctx = patElems pat

    ext (Var v)
      | Just (i, _) <- find ((== v) . patElemName . snd) $ zip [0 ..] ctx =
          Ext i
    ext se = Free se

    asReturns pe =
      ( patElemName pe,
        case patElemDec pe of
          MemPrim pt -> MemPrim pt
          MemMem space -> MemMem space
          MemArray pt shape u (ArrayIn mem lmad) ->
            MemArray pt (Shape $ map ext $ shapeDims shape) u $
              case find ((== mem) . patElemName . snd) $ zip [0 ..] ctx of
                Just (i, PatElem _ (MemMem space)) ->
                  ReturnsNewBlock space i $
                    existentialiseLMAD (map patElemName ctx) lmad
                _ -> ReturnsInBlock mem $ existentialiseLMAD [] lmad
          MemAcc acc ispace ts u -> MemAcc acc ispace ts u
      )

-- | Default memory returns for the given types.  Each array with an
-- existential shape gets a fresh 'ReturnsNewBlock' in 'DefaultSpace'
-- (numbered consecutively from 0) with an @LMAD.iota 0@ index function
-- over its shape; other arrays get 'Nothing'.
extReturns :: [ExtType] -> [ExpReturns]
extReturns ets =
  evalState (mapM addDec ets) 0
  where
    addDec (Prim bt) =
      pure $ MemPrim bt
    addDec (Mem space) =
      pure $ MemMem space
    addDec t@(Array bt shape u)
      | existential t = do
          i <- get <* modify (+ 1)
          pure . MemArray bt shape u . Just $
            ReturnsNewBlock DefaultSpace i $
              LMAD.iota 0 (map convert $ shapeDims shape)
      | otherwise =
          pure $ MemArray bt shape u Nothing
    addDec (Acc acc ispace ts u) =
      pure $ MemAcc acc ispace ts u
    convert (Ext i) = le64 (Ext i)
    convert (Free v) = Free <$> pe64 v

-- | Element type, shape, memory block and index function of an array
-- variable.  Calls 'error' if the variable is not an array.
arrayVarReturns ::
  (HasScope rep m, Monad m, Mem rep inner) =>
  VName ->
  m (PrimType, Shape, VName, LMAD)
arrayVarReturns v = do
  summary <- lookupMemInfo v
  case summary of
    MemArray et shape _ (ArrayIn mem lmad) ->
      pure (et, Shape $ shapeDims shape, mem, lmad)
    _ ->
      error . T.unpack $ "arrayVarReturns: " <> prettyText v <> " is not an array."
varReturns ::
  (HasScope rep m, Monad m, Mem rep inner) =>
  VName ->
  m ExpReturns
varReturns v = do
  summary <- lookupMemInfo v

  case summary of
    MemPrim bt ->
      pure $ MemPrim bt
    MemArray et shape _ (ArrayIn mem lmad) ->
      pure $
        MemArray et (fmap Free shape) NoUniqueness $
          Just $
            ReturnsInBlock mem $
              existentialiseLMAD [] lmad
    MemMem space ->
      pure $ MemMem space
    MemAcc acc ispace ts u ->
      pure $ MemAcc acc ispace ts u

-- | The 'ExpReturns' of a 'SubExp'; constants are primitive.
subExpReturns :: (HasScope rep m, Monad m, Mem rep inner) => SubExp -> m ExpReturns
subExpReturns (Var v) =
  varReturns v
subExpReturns (Constant v) =
  pure $ MemPrim $ primValueType v

-- | The return information of an expression.  This can be seen as the
-- "return type with memory annotations" of the expression.
--
-- This can produce Nothing, which signifies that the result is an
-- array layout that is not expressible as an index function.
expReturns ::
  (LocalScope rep m, Mem rep inner) =>
  Exp rep ->
  m (Maybe [ExpReturns])
expReturns (BasicOp (SubExp se)) =
  Just . pure <$> subExpReturns se
expReturns (BasicOp (Opaque _ (Var v))) =
  Just . pure <$> varReturns v
expReturns (BasicOp (Reshape k newshape v)) = do
  (et, _, mem, lmad) <- arrayVarReturns v
  case reshaper k lmad $ map pe64 $ shapeDims newshape of
    Just lmad' ->
      pure . Just $
        [ MemArray et (fmap Free newshape) NoUniqueness . Just $
            ReturnsInBlock mem (existentialiseLMAD [] lmad')
        ]
    -- The reshaped layout is not expressible as an LMAD.
    Nothing -> pure Nothing
  where
    reshaper ReshapeArbitrary lmad = LMAD.reshape lmad
    reshaper ReshapeCoerce lmad = Just . LMAD.coerce lmad
expReturns (BasicOp (Rearrange perm v)) = do
  (et, Shape dims, mem, lmad) <- arrayVarReturns v
  let lmad' = LMAD.permute lmad perm
      dims' = rearrangeShape perm dims
  pure $
    Just
      [ MemArray et (Shape $ map Free dims') NoUniqueness $
          Just $
            ReturnsInBlock mem $
              existentialiseLMAD [] lmad'
      ]
expReturns (BasicOp (Index v slice)) =
  Just . pure . varInfoToExpReturns <$> sliceInfo v slice
expReturns (BasicOp (Update _ v _ _)) =
  Just . pure <$> varReturns v
expReturns (BasicOp (FlatIndex v slice)) =
  Just . pure . varInfoToExpReturns <$> flatSliceInfo v slice
expReturns (BasicOp (FlatUpdate v _ _)) =
  Just . pure <$> varReturns v
expReturns (BasicOp op) =
  Just . extReturns . staticShapes <$> basicOpType op
expReturns e@(Loop merge _ _) = do
  t <- expExtType e
  Just <$> zipWithM typeWithDec t (map fst merge)
  where
    typeWithDec t p =
      case (t, paramDec p) of
        (Array pt shape u, MemArray _ _ _ (ArrayIn mem lmad))
          | Just (i, mem_p) <- isLoopVar mem,
            Mem space <- paramType mem_p ->
              pure $ MemArray pt shape u $ Just $ ReturnsNewBlock space i lmad'
          | otherwise ->
              pure $ MemArray pt shape u $ Just $ ReturnsInBlock mem lmad'
          where
            lmad' = existentialiseLMAD (map paramName mergevars) lmad
        (Array {}, _) ->
          error "expReturns: Array return type but not array merge variable."
        (Acc acc ispace ts u, _) ->
          pure $ MemAcc acc ispace ts u
        (Prim pt, _) ->
          pure $ MemPrim pt
        (Mem space, _) ->
          pure $ MemMem space
    isLoopVar v = find ((== v) . paramName . snd) $ zip [0 ..] mergevars
    mergevars = map fst merge
expReturns (Apply _ _ ret _) =
  pure $ Just $ map (funReturnsToExpReturns . fst) ret
expReturns (Match _ _ _ (MatchDec ret _)) =
  pure $ Just $ map bodyReturnsToExpReturns ret
expReturns (Op op) =
  Just <$> opReturns op
expReturns (WithAcc inputs lam) =
  Just
    <$> ( (<>)
            <$> (concat <$> mapM inputReturns inputs)
            <*>
            -- XXX: this is a bit dubious because it enforces extra copies.  I
            -- think WithAcc should perhaps have a return annotation like If.
            pure (extReturns $ staticShapes $ drop num_accs $ lambdaReturnType lam)
        )
  where
    inputReturns (_, arrs, _) = mapM varReturns arrs
    num_accs = length inputs

-- | Memory information for the result of slicing an array variable; a
-- slice with no remaining dimensions yields a primitive.
sliceInfo ::
  (Monad m, HasScope rep m, Mem rep inner) =>
  VName ->
  Slice SubExp ->
  m (MemInfo SubExp NoUniqueness MemBind)
sliceInfo v slice = do
  (et, _, mem, lmad) <- arrayVarReturns v
  case sliceDims slice of
    [] -> pure $ MemPrim et
    dims ->
      pure . MemArray et (Shape dims) NoUniqueness . ArrayIn mem $
        LMAD.slice lmad (fmap pe64 slice)

-- | Memory information for the result of a flat slice of an array
-- variable.
flatSliceInfo ::
  (Monad m, HasScope rep m, Mem rep inner) =>
  VName ->
  FlatSlice SubExp ->
  m (MemInfo SubExp NoUniqueness MemBind)
flatSliceInfo v slice@(FlatSlice offset idxs) = do
  (et, _, mem, lmad) <- arrayVarReturns v
  map (fmap pe64) idxs
    & FlatSlice (pe64 offset)
    & LMAD.flatSlice lmad
    & MemArray et (Shape (flatSliceDims slice)) NoUniqueness . ArrayIn mem
    & pure

-- | Operations that can report their memory returns.  The default uses
-- 'extReturns' on the operation's type.
class (IsOp op) => OpReturns op where
  opReturns :: (Mem rep inner, Monad m, HasScope rep m) => op rep -> m [ExpReturns]
  opReturns op = extReturns <$> opType op

instance (OpReturns inner) => OpReturns (MemOp inner) where
  opReturns (Alloc _ space) = pure [MemMem space]
  opReturns (Inner op) = opReturns op

instance OpReturns NoOp where
  opReturns NoOp = pure []

-- | Instantiate function return types with the actual arguments,
-- replacing parameter names in sizes and in 'ReturnsInBlock' memory
-- blocks.  'Nothing' if 'applyRetType' rejects the arguments.
applyFunReturns ::
  (Typed dec) =>
  [FunReturns] ->
  [Param dec] ->
  [(SubExp, Type)] ->
  Maybe [FunReturns]
applyFunReturns rets params args
  | Just _ <- applyRetType rettype params args =
      Just $ map correctDims rets
  | otherwise =
      Nothing
  where
    rettype = map declExtTypeOf rets
    parammap :: M.Map VName (SubExp, Type)
    parammap =
      M.fromList $
        zip (map paramName params) args

    substSubExp (Var v)
      | Just (se, _) <- M.lookup v parammap = se
    substSubExp se = se

    correctDims (MemPrim t) =
      MemPrim t
    correctDims (MemMem space) =
      MemMem space
    correctDims (MemArray et shape u memsummary) =
      MemArray et (correctShape shape) u $
        correctSummary memsummary
    correctDims (MemAcc acc ispace ts u) =
      MemAcc acc ispace ts u

    correctShape = Shape . map correctDim . shapeDims
    correctDim (Ext i) = Ext i
    correctDim (Free se) = Free $ substSubExp se

    correctSummary (ReturnsNewBlock space i lmad) =
      ReturnsNewBlock space i lmad
    correctSummary (ReturnsInBlock mem lmad) =
      -- FIXME: we should also do a replacement in lmad here.
      ReturnsInBlock mem' lmad
      where
        mem' = case M.lookup mem parammap of
          Just (Var v, _) -> v
          _ -> mem
futhark-0.25.27/src/Futhark/IR/Mem/000077500000000000000000000000001475065116200165655ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/IR/Mem/Interval.hs000066400000000000000000000137221475065116200207120ustar00rootroot00000000000000
{-# LANGUAGE OverloadedStrings #-}

module Futhark.IR.Mem.Interval
  ( Interval (..),
    distributeOffset,
    expandOffset,
    intervalOverlap,
    selfOverlap,
    primBool,
    intervalPairs,
    justLeafExp,
  )
where

import Data.Function (on)
import Data.List (maximumBy, minimumBy, (\\))
import Futhark.Analysis.AlgSimplify qualified as AlgSimplify
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Prop
import Futhark.IR.Syntax hiding (Result)
import Futhark.Util

-- | An interval described by a lower bound, a number of elements and a
-- stride.
data Interval = Interval
  { lowerBound :: TPrimExp Int64 VName,
    numElements :: TPrimExp Int64 VName,
    stride :: TPrimExp Int64 VName
  }
  deriving (Show, Eq)

instance FreeIn Interval where
  freeIn' (Interval lb ne st) = freeIn' lb <> freeIn' ne <> freeIn' st

-- | Distribute the terms of an offset (a sum of products) over the
-- intervals.  A term is divided by an interval's stride (when
-- 'AlgSimplify.maybeDivide' succeeds) and the quotient is added to that
-- interval's lower bound.  If the last remaining interval has stride 1,
-- all remaining terms are added to it.  Fails if terms remain when no
-- intervals are left.
distributeOffset :: (MonadFail m) => AlgSimplify.SofP -> [Interval] -> m [Interval]
distributeOffset [] interval = pure interval
distributeOffset offset [] =
  fail $ "Cannot distribute offset " <> show offset <> " across empty interval"
distributeOffset offset [Interval lb ne 1] =
  pure [Interval (lb + TPrimExp (AlgSimplify.sumToExp offset)) ne 1]
distributeOffset offset (Interval lb ne st0 : is)
  | st <- AlgSimplify.Prod False [untyped st0],
    Just (before, quotient, after) <- focusMaybe (`AlgSimplify.maybeDivide` st) offset =
      distributeOffset (before <> after) $
        Interval (lb + TPrimExp (AlgSimplify.sumToExp [quotient])) ne st0 : is
  | [st] <- AlgSimplify.simplify0 $ untyped st0,
    Just (before, quotient, after) <- focusMaybe (`AlgSimplify.maybeDivide` st) offset =
      distributeOffset (before <> after) $
        Interval (lb + TPrimExp (AlgSimplify.sumToExp [quotient])) ne st0 : is
  | otherwise = do
      rest <- distributeOffset offset is
      pure $ Interval lb ne st0 : rest

-- | The product with the most atoms, and the remaining products.
-- Partial: requires a non-empty sum.
findMostComplexTerm :: AlgSimplify.SofP -> (AlgSimplify.Prod, AlgSimplify.SofP)
findMostComplexTerm prods =
  let max_prod = maximumBy (compare `on` (length . AlgSimplify.atoms)) prods
   in (max_prod, prods \\ [max_prod])

findClosestStride :: [PrimExp VName] -> [Interval] -> (PrimExp VName, [PrimExp VName])
findClosestStride offset_term is =
  let strides = map (untyped . stride) is
      p =
        minimumBy
          ( compare
              `on` ( termDifferenceLength
                       . minimumBy (compare `on` \s -> length (offset_term \\ AlgSimplify.atoms s))
                       . AlgSimplify.simplify0
                   )
          )
          strides
   in ( p,
        (offset_term \\) $
          AlgSimplify.atoms $
            minimumBy (compare `on` \s -> length (offset_term \\ AlgSimplify.atoms s)) $
              AlgSimplify.simplify0 p
      )
  where
    termDifferenceLength (AlgSimplify.Prod _ xs) = length (offset_term \\ xs)

expandOffset :: AlgSimplify.SofP -> [Interval] -> Maybe AlgSimplify.SofP
expandOffset [] _ = Nothing
expandOffset offset i1
  | (AlgSimplify.Prod b term_to_add, offset_rest) <- findMostComplexTerm offset, -- Find gnb
    (closest_stride, first_term_divisor) <- findClosestStride term_to_add i1, -- find (nb-b, g)
    target <- [AlgSimplify.Prod b $ closest_stride : first_term_divisor], -- g(nb-b)
    diff <-
      AlgSimplify.sumOfProducts $
        AlgSimplify.sumToExp $
          AlgSimplify.Prod b term_to_add : map AlgSimplify.negate target, -- gnb - gnb + gb = gnb - g(nb-b)
    replacement <- target <> diff -- gnb = g(nb-b) + gnb - gnb + gb
    =
      Just (replacement <> offset_rest)

-- | Conservatively check whether two intervals may overlap.  Returns
-- 'False' only when both have the same stride and one can be shown
-- (via 'AlgSimplify.lessThanish') to end before the other begins.
intervalOverlap :: [(VName, PrimExp VName)] -> Names -> Interval -> Interval -> Bool
intervalOverlap less_thans non_negatives (Interval lb1 ne1 st1) (Interval lb2 ne2 st2)
  | st1 == st2,
    AlgSimplify.lessThanish less_thans non_negatives lb1 lb2,
    AlgSimplify.lessThanish less_thans non_negatives (lb1 + ne1 - 1) lb2 =
      False
  | st1 == st2,
    AlgSimplify.lessThanish less_thans non_negatives lb2 lb1,
    AlgSimplify.lessThanish less_thans non_negatives (lb2 + ne2 - 1) lb1 =
      False
  | otherwise = True

-- | Evaluate a boolean expression if it can be evaluated without
-- looking up any variable.
primBool :: TPrimExp Bool VName -> Maybe Bool
primBool p
  | Just (BoolValue b) <- evalPrimExp (const Nothing) $ untyped p = Just b
  | otherwise = Nothing

-- | Pair up intervals from two lists by matching strides.  An interval
-- without a partner of equal stride is paired with a single-element
-- interval of the same lower bound and stride.  On a stride mismatch
-- both choices are explored and the shorter result is kept.
intervalPairs :: [Interval] -> [Interval] -> [(Interval, Interval)]
intervalPairs = intervalPairs' []
  where
    intervalPairs' :: [(Interval, Interval)] -> [Interval] -> [Interval] -> [(Interval, Interval)]
    intervalPairs' acc [] [] = reverse acc
    intervalPairs' acc (i@(Interval lb _ st) : is) [] =
      intervalPairs' ((i, Interval lb 1 st) : acc) is []
    intervalPairs' acc [] (i@(Interval lb _ st) : is) =
      intervalPairs' ((Interval lb 1 st, i) : acc) [] is
    intervalPairs' acc (i1@(Interval lb1 _ st1) : is1) (i2@(Interval lb2 _ st2) : is2)
      | st1 == st2 = intervalPairs' ((i1, i2) : acc) is1 is2
      | otherwise =
          let res1 = intervalPairs' ((i1, Interval lb1 1 st1) : acc) is1 (i2 : is2)
              res2 = intervalPairs' ((Interval lb2 1 st2, i2) : acc) (i1 : is1) is2
           in if length res1 <= length res2
                then res1
                else res2

-- | Check whether the intervals may be self-overlapping.  The intervals
-- are examined starting from the last one in the list.  For each
-- interval, its stride must provably ('AlgSimplify.lessThanish') exceed
-- the aggregate span of the intervals examined before it.
--
-- Returns 'Nothing' when this holds for every interval, or when there is
-- at most one interval.  Otherwise returns 'Just' the first interval
-- for which it could not be shown.  If some of the non-negative
-- expressions are not plain variables, the first interval is returned
-- conservatively.
selfOverlap ::
  scope ->
  asserts ->
  [(VName, PrimExp VName)] ->
  [PrimExp VName] ->
  [Interval] ->
  Maybe Interval
selfOverlap _ _ _ _ [_] = Nothing
selfOverlap _ _ less_thans non_negatives' is
  | Just non_negatives <- namesFromList <$> mapM justLeafExp non_negatives' =
      -- TODO: Do we need to do something clever using some ranges of known values?
      let selfOverlap' acc (x : xs) =
            let interval_span = (lowerBound x + numElements x - 1) * stride x
                res =
                  AlgSimplify.lessThanish
                    less_thans
                    non_negatives
                    (AlgSimplify.simplify' acc)
                    (AlgSimplify.simplify' $ stride x)
             in if res then selfOverlap' (acc + interval_span) xs else Just x
          selfOverlap' _ [] = Nothing
       in selfOverlap' 0 $ reverse is
selfOverlap _ _ _ _ (x : _) = Just x
selfOverlap _ _ _ _ [] = Nothing

-- | The variable of a leaf expression, if it is one.
justLeafExp :: PrimExp VName -> Maybe VName
justLeafExp (LeafExp v _) = Just v
justLeafExp _ = Nothing
futhark-0.25.27/src/Futhark/IR/Mem/LMAD.hs000066400000000000000000000556031475065116200176470ustar00rootroot00000000000000
-- | This module contains a representation of linear-memory accessor
-- descriptors (LMAD); see work by Zhu, Hoeflinger and David.
--
-- This module is designed to be used as a qualified import, as the
-- exported names are quite generic.
module Futhark.IR.Mem.LMAD
  ( -- * Core
    Shape,
    Indices,
    LMAD (..),
    LMADDim (..),
    Permutation,
    index,
    slice,
    flatSlice,
    reshape,
    coerce,
    permute,
    shape,
    substitute,
    iota,
    equivalent,
    range,

    -- * Exotic
    expand,
    isDirect,
    disjoint,
    disjoint2,
    disjoint3,
    dynamicEqualsLMAD,
    mkExistential,
    closeEnough,
    existentialize,
    existentialized,
  )
where

import Control.Category
import Control.Monad
import Control.Monad.State
import Data.Function (on, (&))
import Data.List (elemIndex, partition, sortBy)
import Data.Map.Strict qualified as M
import Data.Maybe (fromJust, isNothing)
import Data.Traversable
import Futhark.Analysis.AlgSimplify qualified as AlgSimplify
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Mem.Interval
import Futhark.IR.Prop
import Futhark.IR.Syntax
  ( DimIndex (..),
    Ext (..),
    FlatDimIndex (..),
    FlatSlice (..),
    Slice (..),
    Type,
    unitSlice,
  )
import Futhark.IR.Syntax.Core (VName (..))
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Prelude hiding (gcd, id, mod, (.))

-- | The shape of an index function.
type Shape num = [num]

-- | Indices passed to an LMAD.  Must always match the rank of the LMAD.
type Indices num = [num]

-- | A complete permutation.
type Permutation = [Int]

-- | A single dimension in an 'LMAD'.
data LMADDim num = LMADDim
  { ldStride :: num,
    ldShape :: num
  }
  deriving (Show, Eq, Ord)

-- | An LMAD consists of a general offset and, for each dimension, a
-- stride and a number of elements (its shape).
--
-- LMAD algebra is closed under composition w.r.t. operators such as
-- permute, index and slice.  However, other operations, such as
-- reshape, cannot always be represented inside the LMAD algebra.
--
-- By definition, the LMAD \( \sigma + \{ (n_1, s_1), \ldots, (n_k, s_k) \} \),
-- where \(n\) and \(s\) denote the shape and stride of each dimension, denotes
-- the set of points:
--
-- \[
--    \{ ~ \sigma + i_1 * s_1 + \ldots + i_k * s_k ~ | ~ 0 \leq i_1 < n_1, \ldots, 0 \leq i_k < n_k ~ \}
-- \]
data LMAD num = LMAD
  { offset :: num,
    dims :: [LMADDim num]
  }
  deriving (Show, Eq, Ord)

instance (Pretty num) => Pretty (LMAD num) where pretty (LMAD offset dims) = braces .
semistack $ [ "offset:" <+> group (pretty offset), "strides:" <+> p ldStride, "shape:" <+> p ldShape ] where p f = group $ brackets $ align $ commasep $ map (pretty . f) dims instance (Substitute num) => Substitute (LMAD num) where substituteNames substs = fmap $ substituteNames substs instance (Substitute num) => Rename (LMAD num) where rename = substituteRename instance (FreeIn num) => FreeIn (LMAD num) where freeIn' = foldMap freeIn' instance (FreeIn num) => FreeIn (LMADDim num) where freeIn' (LMADDim s n) = freeIn' s <> freeIn' n instance Functor LMAD where fmap = fmapDefault instance Foldable LMAD where foldMap = foldMapDefault instance Traversable LMAD where traverse f (LMAD offset dims) = LMAD <$> f offset <*> traverse f' dims where f' (LMADDim s n) = LMADDim <$> f s <*> f n flatOneDim :: (Eq num, IntegralExp num) => num -> num -> num flatOneDim s i | s == 0 = 0 | otherwise = i * s index :: (IntegralExp num, Eq num) => LMAD num -> Indices num -> num index (LMAD off dims) inds = off + sum prods where prods = zipWith flatOneDim (map ldStride dims) inds -- | Handle the case where a slice can stay within a single LMAD. 
slice :: (Eq num, IntegralExp num) => LMAD num -> Slice num -> LMAD num
slice lmad@(LMAD _ ldims) (Slice is) =
  -- Walk the slice and the dimensions in lockstep, accumulating the new
  -- offset and the surviving dimensions.  A 'DimFix' only moves the
  -- offset; every 'DimSlice' produces exactly one new dimension.
  foldl sliceOne (LMAD (offset lmad) []) $ zip is ldims
  where
    sliceOne ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      (DimIndex num, LMADDim num) ->
      LMAD num
    sliceOne (LMAD off dims) (DimFix i, LMADDim s _x) =
      LMAD (off + flatOneDim s i) dims
    -- A zero-stride dimension stays zero-stride whatever the slice is.
    sliceOne (LMAD off dims) (DimSlice _ ne _, LMADDim 0 _) =
      LMAD off (dims ++ [LMADDim 0 ne])
    -- The full slice leaves the dimension untouched.
    sliceOne (LMAD off dims) (dmind, dim@(LMADDim _ n))
      | dmind == unitSlice 0 n = LMAD off (dims ++ [dim])
    -- The full reversing slice: start at the last element and negate
    -- the stride.
    sliceOne (LMAD off dims) (dmind, LMADDim s n)
      | dmind == DimSlice (n - 1) n (-1) =
          let off' = off + flatOneDim s (n - 1)
           in LMAD off' (dims ++ [LMADDim (s * (-1)) n])
    -- A zero-step slice repeats the element at the start index.
    sliceOne (LMAD off dims) (DimSlice b ne 0, LMADDim s _) =
      LMAD (off + flatOneDim s b) (dims ++ [LMADDim 0 ne])
    -- General case: shift by the start, scale the stride by the step.
    sliceOne (LMAD off dims) (DimSlice bs ns ss, LMADDim s _) =
      LMAD (off + s * bs) (dims ++ [LMADDim (ss * s) ns])

-- | Flat-slice an LMAD.
--
-- The flat slice replaces the outermost dimension: its offset and strides
-- are scaled by the outermost stride, and the inner dimensions are kept.
-- An LMAD without dimensions is returned unchanged.
flatSlice :: (IntegralExp num) => LMAD num -> FlatSlice num -> LMAD num
flatSlice (LMAD offset (dim : dims)) (FlatSlice new_offset is) =
  LMAD
    (offset + new_offset * ldStride dim)
    (map (helper $ ldStride dim) is <> dims)
  where
    helper s0 (FlatDimIndex n s) = LMADDim (s0 * s) n
flatSlice (LMAD offset []) _ = LMAD offset []

-- | Reshape an LMAD.
--
-- Returns 'Nothing' when the result cannot be expressed as a single LMAD.
-- Two cases are handled:
--
--   (1) The new shape keeps every existing dimension (in order) and only
--       inserts dimensions written as the literal @1@; these get stride 0.
--   (2) The LMAD has at least one dimension, no zero strides, and exactly
--       the strides of a row-major layout whose innermost stride is that of
--       its last dimension.  The result is then the row-major LMAD of the
--       new shape with the same offset and innermost stride.
reshape ::
  (Eq num, IntegralExp num) =>
  LMAD num ->
  Shape num ->
  Maybe (LMAD num)
--
-- First a special case for when we are merely injecting unit
-- dimensions into an LMAD.
reshape (LMAD off dims) newshape
  | Just dims' <- addingVacuous newshape dims =
      Just $ LMAD off dims'
  where
    addingVacuous (dnew : dnews) (dold : dolds)
      | dnew == ldShape dold =
          (dold :) <$> addingVacuous dnews dolds
    addingVacuous (1 : dnews) dolds =
      (LMADDim 0 1 :) <$> addingVacuous dnews dolds
    addingVacuous [] [] = Just []
    addingVacuous _ _ = Nothing

-- Then the general case.
reshape lmad@(LMAD off dims) newshape = do
  -- 'last' below is partial, so a rank-0 LMAD is rejected up front
  -- instead of crashing when the new shape forces the base stride.
  guard $ not $ null dims
  let base_stride = ldStride (last dims)
      no_zero_stride = all (\ld -> ldStride ld /= 0) dims
      strides_as_expected = lmad == iotaStrided off base_stride (shape lmad)

  guard $ no_zero_stride && strides_as_expected

  Just $ iotaStrided off base_stride newshape
{-# NOINLINE reshape #-}

-- | Coerce an LMAD to look like it has a new shape.  The offset and
-- strides are kept; only the shapes are replaced.  Dynamically the shape
-- must be the same.
coerce :: LMAD num -> Shape num -> LMAD num
coerce (LMAD offset dims) new_shape =
  LMAD offset $ zipWith onDim dims new_shape
  where
    onDim ld d = ld {ldShape = d}
{-# NOINLINE coerce #-}

-- | Substitute a name with a PrimExp in an LMAD.
substitute ::
  (Ord a) =>
  M.Map a (TPrimExp t a) ->
  LMAD (TPrimExp t a) ->
  LMAD (TPrimExp t a)
substitute tab (LMAD offset dims) =
  LMAD (sub offset) $ map (\(LMADDim s n) -> LMADDim (sub s) (sub n)) dims
  where
    tab' = fmap untyped tab
    sub = TPrimExp . substituteInPrimExp tab' . untyped

-- | Shape of an LMAD.
shape :: LMAD num -> Shape num
shape = map ldShape . dims

-- | A row-major LMAD with the given offset, whose innermost stride is the
-- given base stride; every outer stride is the base stride multiplied by
-- the shapes of all inner dimensions.
iotaStrided ::
  (IntegralExp num) =>
  -- | Offset
  num ->
  -- | Base Stride
  num ->
  -- | Shape
  [num] ->
  LMAD num
iotaStrided off s ns =
  let ss = tail $ reverse $ scanl (*) s $ reverse ns
   in LMAD off $ zipWith LMADDim ss ns

-- | Generalised iota with user-specified offset.
iota ::
  (IntegralExp num) =>
  -- | Offset
  num ->
  -- | Shape
  [num] ->
  LMAD num
iota off = iotaStrided off 1
{-# NOINLINE iota #-}

-- | Create an LMAD that is existential in everything except shape.  The
-- offset is @Ext start@ and the stride of dimension @i@ is
-- @Ext (start + 1 + i)@.
mkExistential :: Shape (Ext a) -> Int -> LMAD (Ext a)
mkExistential shp start =
  LMAD (Ext start) $ zipWith onDim shp [0 .. r - 1]
  where
    r = length shp
    onDim d i = LMADDim {ldStride = Ext (start + 1 + i), ldShape = d}

-- | Permute dimensions (stride and shape move together); the offset is
-- unchanged.
permute :: LMAD num -> Permutation -> LMAD num
permute lmad perm =
  lmad {dims = rearrangeShape perm $ dims lmad}

-- | The flat distance from the offset of an 'LMAD' to the element reached
-- by taking the last index in every dimension, i.e. the sum of
-- @stride * (shape - 1)@ over all dimensions.  For non-negative strides
-- this is the largest flat value reachable, relative to the offset.
flatSpan :: LMAD (TPrimExp Int64 VName) -> TPrimExp Int64 VName
flatSpan (LMAD _ dims) =
  foldr
    ( \dim upper ->
        let spn = ldStride dim * (ldShape dim - 1)
         in spn + upper
    )
    0
    dims

-- | Conservatively flatten a list of LMAD dimensions.
--
-- Since not all LMADs can actually be flattened, we try to overestimate the
-- flattened array instead. This means that any "holes" in between dimensions
-- will get filled out.  The result has a single dimension whose stride is
-- the GCD of all strides; 'Nothing' is returned when 'gcd' cannot determine
-- that GCD.
--
-- NOTE(review): the flattened LMAD only extends upwards from the offset,
-- so it does not cover points reached through negative strides — confirm
-- that callers never pass such LMADs.
conservativeFlatten :: LMAD (TPrimExp Int64 VName) -> Maybe (LMAD (TPrimExp Int64 VName))
conservativeFlatten (LMAD offset []) =
  pure $ LMAD offset [LMADDim 1 1]
conservativeFlatten l@(LMAD _ [_]) =
  pure l
conservativeFlatten l@(LMAD offset dims@(dim0 : _)) = do
  strd <-
    foldM gcd (ldStride dim0) $
      map ldStride dims
  pure $ LMAD offset [LMADDim strd (shp + 1)]
  where
    shp = flatSpan l

-- | Very conservative GCD calculation. Returns 'Nothing' if the result cannot
-- be immediately determined.
-- Does not recurse at all: besides equal operands, only the cases where one
-- operand is 0 or 1 are recognised (after taking absolute values).
gcd :: TPrimExp Int64 VName -> TPrimExp Int64 VName -> Maybe (TPrimExp Int64 VName)
gcd x y = gcd' (abs x) (abs y)
  where
    gcd' a b | a == b = Just a
    gcd' 1 _ = Just 1
    gcd' _ 1 = Just 1
    gcd' a 0 = Just a
    -- Mirror image of the previous case: gcd(0, b) = b.
    gcd' 0 b = Just b
    gcd' _ _ = Nothing -- gcd' b (a `Futhark.Util.IntegralExp.rem` b)

-- | Returns @True@ if the two 'LMAD's could be proven disjoint.
--
-- Uses some best-approximation heuristics to determine disjointness. For two
-- 1-dimensional arrays, we can guarantee whether or not they are disjoint, but
-- as soon as more than one dimension is involved, things get more
-- tricky. Currently, we try to 'conservativeFlatten' any LMAD with more than
-- one dimension, and give up (return @False@) if that fails.
--
-- In the one-dimensional case the LMADs are reported disjoint if the
-- difference of their offsets is not a multiple of the GCD of their strides,
-- or if one of them ends before the other begins.
--
-- NOTE(review): the "ends before" checks take @offset + (n - 1) * stride@ as
-- the last element, which is only the maximum for non-negative strides.
-- Confirm that negative strides cannot reach this function.
disjoint :: [(VName, PrimExp VName)] -> Names -> LMAD (TPrimExp Int64 VName) -> LMAD (TPrimExp Int64 VName) -> Bool
disjoint less_thans non_negatives (LMAD offset1 [dim1]) (LMAD offset2 [dim2]) =
  doesNotDivide (gcd (ldStride dim1) (ldStride dim2)) (offset1 - offset2)
    || AlgSimplify.lessThanish
      less_thans
      non_negatives
      (offset2 + (ldShape dim2 - 1) * ldStride dim2)
      offset1
    || AlgSimplify.lessThanish
      less_thans
      non_negatives
      (offset1 + (ldShape dim1 - 1) * ldStride dim1)
      offset2
  where
    -- True only if @y mod x@ constant-folds to a known non-zero value.
    doesNotDivide :: Maybe (TPrimExp Int64 VName) -> TPrimExp Int64 VName -> Bool
    doesNotDivide (Just x) y =
      Futhark.Util.IntegralExp.mod y x
        & untyped
        & constFoldPrimExp
        & TPrimExp
        & (.==.) (0 :: TPrimExp Int64 VName)
        & primBool
        & maybe False not
    doesNotDivide _ _ = False
disjoint less_thans non_negatives lmad1 lmad2 =
  case (conservativeFlatten lmad1, conservativeFlatten lmad2) of
    (Just lmad1', Just lmad2') -> disjoint less_thans non_negatives lmad1' lmad2'
    _ -> False

-- | Try to prove two LMADs disjoint using intervals.  Both LMADs are turned
-- into intervals, dimensions are paired up by stride, and the difference of
-- the offsets is distributed over the lower bounds (positive terms on the
-- first LMAD, negated negative terms on the second).  The LMADs are reported
-- disjoint when neither interval list self-overlaps and at least one pair of
-- matched intervals does not overlap.  The first two arguments are ignored.
disjoint2 :: scope -> asserts -> [(VName, PrimExp VName)] -> Names -> LMAD (TPrimExp Int64 VName) -> LMAD (TPrimExp Int64 VName) -> Bool
disjoint2 _ _ less_thans non_negatives lmad1 lmad2 =
  let (offset1, interval1) = lmadToIntervals lmad1
      (offset2, interval2) = lmadToIntervals lmad2
      (neg_offset, pos_offset) =
        partition AlgSimplify.negated $
          offset1 `AlgSimplify.sub` offset2
      (interval1', interval2') =
        unzip $
          sortBy (flip AlgSimplify.compareComplexity `on` (AlgSimplify.simplify0 . untyped . stride . fst)) $
            intervalPairs interval1 interval2
   in case ( distributeOffset pos_offset interval1',
             distributeOffset (map AlgSimplify.negate neg_offset) interval2'
           ) of
        (Just interval1'', Just interval2'') ->
          isNothing
            ( selfOverlap () () less_thans (map (flip LeafExp $ IntType Int64) $ namesToList non_negatives) interval1''
            )
            && isNothing
              ( selfOverlap () () less_thans (map (flip LeafExp $ IntType Int64) $ namesToList non_negatives) interval2''
              )
            && not
              ( all
                  (uncurry (intervalOverlap less_thans non_negatives))
                  (zip interval1'' interval2'')
              )
        _ ->
          False

-- | Like 'disjoint2', but first normalises each interval list with
-- 'joinDims' and 'mergeDims' (to a fixed point).  When an interval list
-- self-overlaps, it either splits the offending dimension or expands the
-- offset, and retries; at most 4 levels of retries are attempted before
-- giving up with @False@.  @scope@ and @asserts@ are forwarded to
-- 'selfOverlap'.
disjoint3 :: M.Map VName Type -> [PrimExp VName] -> [(VName, PrimExp VName)] -> [PrimExp VName] -> LMAD (TPrimExp Int64 VName) -> LMAD (TPrimExp Int64 VName) -> Bool
disjoint3 scope asserts less_thans non_negatives lmad1 lmad2 =
  let (offset1, interval1) = lmadToIntervals lmad1
      (offset2, interval2) = lmadToIntervals lmad2
      interval1' =
        fixPoint (mergeDims . joinDims) $
          sortBy (flip AlgSimplify.compareComplexity `on` (AlgSimplify.simplify0 . untyped . stride)) interval1
      interval2' =
        fixPoint (mergeDims . joinDims) $
          sortBy (flip AlgSimplify.compareComplexity `on` (AlgSimplify.simplify0 . untyped . stride)) interval2
      (interval1'', interval2'') =
        unzip $
          sortBy (flip AlgSimplify.compareComplexity `on` (AlgSimplify.simplify0 . untyped . stride . fst)) $
            intervalPairs interval1' interval2'
   in disjointHelper 4 interval1'' interval2'' $ offset1 `AlgSimplify.sub` offset2
  where
    disjointHelper :: Int -> [Interval] -> [Interval] -> AlgSimplify.SofP -> Bool
    disjointHelper 0 _ _ _ = False
    disjointHelper i is10 is20 offset =
      let (is1, is2) =
            unzip $
              sortBy (flip AlgSimplify.compareComplexity `on` (AlgSimplify.simplify0 . untyped . stride . fst)) $
                intervalPairs is10 is20
          (neg_offset, pos_offset) = partition AlgSimplify.negated offset
       in case ( distributeOffset pos_offset is1,
                 distributeOffset (map AlgSimplify.negate neg_offset) is2
               ) of
            (Just is1', Just is2') ->
              let overlap1 = selfOverlap scope asserts less_thans non_negatives is1'
                  overlap2 = selfOverlap scope asserts less_thans non_negatives is2'
               in case (overlap1, overlap2) of
                    (Nothing, Nothing) ->
                      case namesFromList <$> mapM justLeafExp non_negatives of
                        Just non_negatives' ->
                          -- Compare the intervals *after* the offset has been
                          -- distributed into their lower bounds; comparing
                          -- is1/is2 would silently drop the offset.
                          not $
                            all
                              (uncurry (intervalOverlap less_thans non_negatives'))
                              (zip is1' is2')
                        _ -> False
                    (Just overlapping_dim, _) ->
                      let expanded_offset = AlgSimplify.simplifySofP' <$> expandOffset offset is1
                          splits = splitDim overlapping_dim is1'
                       in all (\(new_offset, new_is1) -> disjointHelper (i - 1) (joinDims new_is1) (joinDims is2') new_offset) splits
                            || maybe False (disjointHelper (i - 1) is1 is2) expanded_offset
                    (_, Just overlapping_dim) ->
                      let expanded_offset = AlgSimplify.simplifySofP' <$> expandOffset offset is2
                          splits = splitDim overlapping_dim is2'
                       in all
                            ( \(new_offset, new_is2) ->
                                -- The split point belongs to the second LMAD,
                                -- so it enters the difference negated.
                                disjointHelper (i - 1) (joinDims is1') (joinDims new_is2) $
                                  map AlgSimplify.negate new_offset
                            )
                            splits
                            || maybe False (disjointHelper (i - 1) is1 is2) expanded_offset
            _ -> False

    -- Merge adjacent intervals with equal strides and zero lower bounds
    -- into one interval whose element count is the product of theirs.
    joinDims :: [Interval] -> [Interval]
    joinDims = helper []
      where
        helper acc [] = reverse acc
        helper acc [x] = reverse $ x : acc
        helper acc (x : y : rest) =
          if stride x == stride y && lowerBound x == 0 && lowerBound y == 0
            then helper acc $ x {numElements = numElements x * numElements y} : rest
            else helper (x : acc) (y : rest)

    -- Scanning from the end, merge an interval into its predecessor when
    -- the predecessor's stride equals the interval's stride times its
    -- element count (both lower bounds being zero).
    mergeDims :: [Interval] -> [Interval]
    mergeDims = helper [] . reverse
      where
        helper acc [] = acc
        helper acc [x] = x : acc
        helper acc (x : y : rest) =
          if stride x * numElements x == stride y && lowerBound x == 0 && lowerBound y == 0
            then helper acc $ x {numElements = numElements x * numElements y} : rest
            else helper (x : acc) (y : rest)

    -- Split the self-overlapping interval @overlapping_dim0@ of @is@ into
    -- sub-problems, each an extra offset plus a new interval list.
    -- NOTE(review): 'fromJust' assumes @overlapping_dim0@ is not the last
    -- element of @is@ — confirm this holds for everything 'selfOverlap'
    -- can return.
    splitDim :: Interval -> [Interval] -> [(AlgSimplify.SofP, [Interval])]
    splitDim overlapping_dim0 is
      | [st] <- AlgSimplify.simplify0 $ untyped $ stride overlapping_dim0,
        [st1] <- AlgSimplify.simplify0 $ untyped $ stride overlapping_dim,
        [spn] <- AlgSimplify.simplify0 $ untyped $ stride overlapping_dim * numElements overlapping_dim,
        lowerBound overlapping_dim == 0,
        Just big_dim_elems <- AlgSimplify.maybeDivide spn st,
        Just small_dim_elems <- AlgSimplify.maybeDivide st st1 =
          [ ( [],
              init before
                <> [ Interval 0 (isInt64 $ AlgSimplify.prodToExp big_dim_elems) (stride overlapping_dim0),
                     Interval 0 (isInt64 $ AlgSimplify.prodToExp small_dim_elems) (stride overlapping_dim)
                   ]
                <> after
            )
          ]
      | otherwise =
          -- Peel off the last element of the next dimension as a single
          -- point offset, and keep the remainder as a shrunk dimension.
          let shrunk_dim = overlapping_dim {numElements = numElements overlapping_dim - 1}
              point_offset = AlgSimplify.simplify0 $ untyped $ (numElements overlapping_dim - 1 + lowerBound overlapping_dim) * stride overlapping_dim
           in [ (point_offset, before <> after),
                ([], before <> [shrunk_dim] <> after)
              ]
      where
        (before, overlapping_dim, after) =
          fromJust $
            elemIndex overlapping_dim0 is
              >>= (flip focusNth is . (+ 1))

-- | Convert an LMAD to its simplified offset and one 'Interval' per
-- dimension, each with lower bound 0.  An LMAD without dimensions becomes
-- a single interval of one element with stride 1.
lmadToIntervals :: LMAD (TPrimExp Int64 VName) -> (AlgSimplify.SofP, [Interval])
lmadToIntervals (LMAD offset []) = (AlgSimplify.simplify0 $ untyped offset, [Interval 0 1 1])
lmadToIntervals (LMAD offset dims0) =
  (offset', map helper dims0)
  where
    offset' = AlgSimplify.simplify0 $ untyped offset

    helper :: LMADDim (TPrimExp Int64 VName) -> Interval
    helper (LMADDim strd shp) =
      Interval 0 (AlgSimplify.simplify' shp) (AlgSimplify.simplify' strd)

-- | Dynamically determine if two 'LMADDim' are equal.
--
-- True if the dynamic values of their constituents are equal.
dynamicEqualsLMADDim :: (Eq num) => LMADDim (TPrimExp t num) -> LMADDim (TPrimExp t num) -> TPrimExp Bool num
dynamicEqualsLMADDim dim1 dim2 =
  ldStride dim1 .==. ldStride dim2
    .&&. ldShape dim1 .==. ldShape dim2

-- | Dynamically determine if two 'LMAD' are equal.
--
-- True if the ranks match and the offsets and constituent 'LMADDim's are
-- equal.  LMADs of different rank are statically unequal.
dynamicEqualsLMAD :: (Eq num) => LMAD (TPrimExp t num) -> LMAD (TPrimExp t num) -> TPrimExp Bool num
dynamicEqualsLMAD lmad1 lmad2
  -- 'zip' below would otherwise ignore any trailing dimensions.
  | length (dims lmad1) /= length (dims lmad2) = false
  | otherwise =
      offset lmad1 .==. offset lmad2
        .&&. foldr
          ((.&&.) . uncurry dynamicEqualsLMADDim)
          true
          (zip (dims lmad1) (dims lmad2))
{-# NOINLINE dynamicEqualsLMAD #-}

-- | Returns true if two 'LMAD's are equivalent.
--
-- Equivalence in this case is matching in offsets and strides.
equivalent :: (Eq num) => LMAD num -> LMAD num -> Bool
equivalent lmad1 lmad2 =
  offset lmad1 == offset lmad2
    && map ldStride (dims lmad1) == map ldStride (dims lmad2)
{-# NOINLINE equivalent #-}

-- | Is this a row-major array with zero offset?
isDirect :: (Eq num, IntegralExp num) => LMAD num -> Bool
isDirect lmad = lmad == iota 0 (map ldShape $ dims lmad)
{-# NOINLINE isDirect #-}

-- | The largest possible linear address reachable by this LMAD, not
-- counting the offset. If you add one to this number (and multiply it
-- with the element size), you get the amount of bytes you need to
-- allocate for an array with this LMAD (assuming zero offset).
range ::
  (Pretty num) =>
  LMAD (TPrimExp Int64 num) ->
  TPrimExp Int64 num
range lmad =
  -- The idea is that the largest possible offset must be the sum of
  -- the maximum offsets reachable in each dimension, which must be at
  -- either the minimum or maximum index.
  sum (map dimRange $ dims lmad)
  where
    -- Index 0 contributes 0 and the last index contributes
    -- (shape - 1) * stride; clamping at 0 covers negative strides and
    -- empty dimensions.
    dimRange LMADDim {ldStride, ldShape} =
      0 `sMax64` ((0 `sMax64` (ldShape - 1)) * ldStride)
{-# NOINLINE range #-}

-- | When comparing LMADs as part of the type check in GPUMem, we
-- may run into problems caused by the simplifier. As index functions
-- can be generalized over if-then-else expressions, the simplifier
-- might hoist some of the code from inside the if-then-else
-- (computing the offset of an array, for instance), but now the type
-- checker cannot verify that the generalized index function is valid,
-- because some of the existentials are computed somewhere else. To
-- work around this, we've had to relax the GPUMem type-checker a
-- bit, specifically, we've introduced this function to verify whether
-- two index functions are "close enough" that we can assume that they
-- match. We use this instead of `lmad1 == lmad2` and hope that it's
-- good enough.
--
-- Currently only the ranks of the two LMADs are compared.
closeEnough :: LMAD num -> LMAD num -> Bool
closeEnough lmad1 lmad2 =
  length (dims lmad1) == length (dims lmad2)
{-# NOINLINE closeEnough #-}

-- | Turn all the leaves of the LMAD into 'Ext's, except for the shape,
-- whose leaves are simply made 'Free'.  The existentials are numbered
-- consecutively from @start@: first the offset, then the stride of each
-- dimension in order.
existentialize ::
  Int ->
  LMAD (TPrimExp Int64 a) ->
  LMAD (TPrimExp Int64 (Ext a))
existentialize start lmad = evalState lmad' start
  where
    -- Hand out the next existential index.
    mkExt = do
      i <- get
      put $ i + 1
      pure $ TPrimExp $ LeafExp (Ext i) int64
    lmad' = LMAD <$> mkExt <*> mapM onDim (dims lmad)
    onDim ld = LMADDim <$> mkExt <*> pure (fmap Free (ldShape ld))

-- | Retrieve those elements that 'existentialize' changes. That is,
-- everything except the shape (and in the same order as
-- 'existentialize' existentializes them).
existentialized :: LMAD a -> [a]
existentialized (LMAD off ds) =
  -- The offset first, then one stride per dimension; shapes are left out.
  off : map ldStride ds

-- | Conceptually expand an LMAD to be a particular slice of another one by
-- adjusting the offset and strides.  Used for memory expansion.
--
-- The new offset is @o + p * offset@ and every stride is scaled by @p@;
-- the shapes are unchanged.
expand :: (IntegralExp num) => num -> num -> LMAD num -> LMAD num
expand o p lmad =
  LMAD
    (o + p * offset lmad)
    [d {ldStride = p * ldStride d} | d <- dims lmad]
futhark-0.25.27/src/Futhark/IR/Mem/Simplify.hs000066400000000000000000000071701475065116200207220ustar0000000000000000000000
{-# LANGUAGE TypeFamilies #-}

module Futhark.IR.Mem.Simplify
  ( simplifyProgGeneric,
    simplifyStmsGeneric,
    simpleGeneric,
    SimplifyMemory,
    memRuleBook,
  )
where

import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.Construct
import Futhark.IR.Mem
import Futhark.IR.Prop.Aliases (AliasedOp)
import Futhark.Optimise.Simplify qualified as Simplify
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (simplifiable)

-- | Some constraints that must hold for the simplification rules to work.
type SimplifyMemory rep inner =
  ( Simplify.SimplifiableRep rep,
    LetDec rep ~ LetDecMem,
    ExpDec rep ~ (),
    BodyDec rep ~ (),
    CanBeWise (OpC rep),
    BuilderOps (Wise rep),
    OpReturns inner,
    ST.IndexOp (inner (Wise rep)),
    AliasedOp inner,
    Mem rep inner,
    CanBeWise inner,
    RephraseOp inner,
    ASTConstraints (inner (Engine.Wise rep))
  )

-- | The simplifier operations for a memory representation: 'simplifiable'
-- specialised to the 'SimplifyMemory' constraints.
simpleGeneric ::
  (SimplifyMemory rep inner) =>
  (inner (Wise rep) -> UT.UsageTable) ->
  Simplify.SimplifyOp rep (inner (Wise rep)) ->
  Simplify.SimpleOps rep
simpleGeneric = simplifiable

-- | Simplify a whole program, using the memory-aware hoisting blockers
-- extended with a rule that keeps allocations and array-producing
-- statements inside branches.
simplifyProgGeneric ::
  (SimplifyMemory rep inner) =>
  RuleBook (Wise rep) ->
  Simplify.SimpleOps rep ->
  Prog rep ->
  PassM (Prog rep)
simplifyProgGeneric rules ops =
  Simplify.simplifyProg ops rules branchBlockers
  where
    branchBlockers = blockers {Engine.blockHoistBranch = blockAllocs}

    -- Allocations may only leave a branch when the symbol table says
    -- memory simplification is enabled.  Any other statement is kept in
    -- the branch if it produces a non-primitive value: several arrays
    -- may share a memory block in GPUMem, so hoisting the creation of
    -- one out of a branch could corrupt memory.  By this point in the
    -- compiler the array creations that matter have most likely been
    -- moved already.
    blockAllocs vtable _ (Let pat _ e) =
      case e of
        Op Alloc {} -> not $ ST.simplifyMemory vtable
        _ -> not $ all primType $ patTypes pat

-- | Simplify a sequence of statements in the scope of the surrounding
-- monad, using the memory-aware hoisting blockers.
simplifyStmsGeneric ::
  ( HasScope rep m,
    MonadFreshNames m,
    SimplifyMemory rep inner
  ) =>
  RuleBook (Wise rep) ->
  Simplify.SimpleOps rep ->
  Stms rep ->
  m (Stms rep)
simplifyStmsGeneric rules ops stms =
  askScope >>= \scope ->
    Simplify.simplifyStms ops rules blockers scope stms

-- Holds for a single-element pattern bound by an allocation whose name
-- is part of the result.
isResultAlloc :: (OpC rep ~ MemOp op) => Engine.BlockPred rep
isResultAlloc _ usage stm
  | Let (Pat [pe]) _ (Op Alloc {}) <- stm = UT.isInResult (patElemName pe) usage
  | otherwise = False

-- Holds for any statement whose expression is an allocation.
isAlloc :: (OpC rep ~ MemOp op) => Engine.BlockPred rep
isAlloc _ _ stm = case stm of
  Let _ _ (Op Alloc {}) -> True
  _ -> False

-- The default hoisting blockers, extended so that allocations are never
-- hoisted out of parallel constructs and allocations bound in a result
-- are not hoisted out of sequential ones.
blockers ::
  (OpC rep ~ MemOp inner) =>
  Simplify.HoistBlockers rep
blockers =
  Engine.noExtraHoistBlockers
    { Engine.isAllocation = isAlloc mempty mempty,
      Engine.blockHoistSeq = isResultAlloc,
      Engine.blockHoistPar = isAlloc
    }

-- | Standard collection of simplification rules for representations
-- with memory.
memRuleBook :: (SimplifyMemory rep inner) => RuleBook (Wise rep)
memRuleBook = standardRules <> ruleBook [RuleOp decertifySafeAlloc] []

-- Strip the certificates from a memory-producing operation that 'safeOp'
-- considers safe, so it is no longer pinned inside loops or branches by
-- them.  Statements without certificates are left alone.
decertifySafeAlloc :: (SimplifyMemory rep inner) => TopDownRuleOp (Wise rep)
decertifySafeAlloc _ pat (StmAux cs attrs _) op
  | cs == mempty = Skip
  | [Mem _] <- patTypes pat,
    safeOp op =
      Simplify $ attributing attrs $ letBind pat $ Op op
decertifySafeAlloc _ _ _ _ = Skip
futhark-0.25.27/src/Futhark/IR/Parse.hs000066400000000000000000000767621475065116200174770ustar0000000000000000000000
-- | Parser for the Futhark core language.
module Futhark.IR.Parse
  ( -- * Programs
    parseSOACS,
    parseGPU,
    parseGPUMem,
    parseMC,
    parseMCMem,
    parseSeq,
    parseSeqMem,

    -- * Fragments
    parseType,
    parseDeclExtType,
    parseDeclType,
    parseVName,
    parseSubExp,
    parseSubExpRes,
    parseBodyGPU,
    parseBodyMC,
    parseStmGPU,
    parseStmMC,
  )
where

import Data.Char (isAlpha)
import Data.Functor
import Data.List (singleton)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Maybe
import Data.Set qualified as S
import Data.Text qualified as T
import Data.Void
import Futhark.Analysis.PrimExp.Parse
import Futhark.IR
import Futhark.IR.GPU (GPU)
import Futhark.IR.GPU.Op qualified as GPU
import Futhark.IR.GPUMem (GPUMem)
import Futhark.IR.MC (MC)
import Futhark.IR.MC.Op qualified as MC
import Futhark.IR.MCMem (MCMem)
import Futhark.IR.Mem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.IR.SOACS (SOACS)
import Futhark.IR.SOACS.SOAC qualified as SOAC
import Futhark.IR.SegOp qualified as SegOp
import Futhark.IR.Seq (Seq)
import Futhark.IR.SeqMem (SeqMem)
import Language.Futhark.Primitive.Parse
import Text.Megaparsec
import Text.Megaparsec.Char hiding (space)
import Text.Megaparsec.Char.Lexer qualified as L

-- | Parsers over strict 'T.Text' with no custom error component.
type Parser = Parsec Void T.Text

-- | A double-quoted string literal; escapes are decoded by 'L.charLiteral'.
pStringLiteral :: Parser T.Text
pStringLiteral =
  lexeme . fmap T.pack $ char '"' >> manyTill L.charLiteral (char '"')

-- | A plain 'Name': one letter or operator-like character followed by any
-- number of characters accepted by 'constituent'.
pName :: Parser Name
pName =
  lexeme . fmap nameFromString $
    (:) <$> satisfy leading <*> many (satisfy constituent)
  where
    leading c = isAlpha c || c `elem` ("_+-*/%=!<>|&^.#" :: String)

-- | A 'VName' written as @base_tag@.  The base runs up to an @_N@ suffix
-- that is not followed by another name constituent, and @N@ becomes the
-- tag.  The base may contain @<{...}>@ boxes, which are kept verbatim.
pVName :: Parser VName
pVName = lexeme $ do
  (s, tag) <-
    choice [exprBox, singleton <$> satisfy constituent]
      `manyTill_` try pTag
      <?> "variable name"
  pure $ VName (nameFromString $ concat s) tag
  where
    pTag =
      "_" *> L.decimal <* notFollowedBy (satisfy constituent)
    exprBox =
      ("<{" <>) . (<> "}>")
        <$> (chunk "<{" *> manyTill anySingle (chunk "}>"))

-- | A decimal 'Int'.
pInt :: Parser Int
pInt = lexeme L.decimal

-- | A decimal 'Int64'.
pInt64 :: Parser Int64
pInt64 = lexeme L.decimal

-- | Wrap a parser in curly braces, square brackets, or parentheses.
braces, brackets, parens :: Parser a -> Parser a
braces = between (lexeme "{") (lexeme "}")
brackets = between (lexeme "[") (lexeme "]")
parens = between (lexeme "(") (lexeme ")")

-- | Punctuation tokens whose value is discarded.
pComma, pColon, pSemi, pEqual, pSlash, pAsterisk, pArrow :: Parser ()
pComma = void $ lexeme ","
pColon = void $ lexeme ":"
pSemi = void $ lexeme ";"
pEqual = void $ lexeme "="
pSlash = void $ lexeme "/"
pAsterisk = void $ lexeme "*"
pArrow = void $ lexeme "->"

-- | A non-array type: a primitive type or @acc(v, shape, {types})@.
pNonArray :: Parser (TypeBase shape NoUniqueness)
pNonArray =
  choice
    [ Prim <$> pPrimType,
      "acc"
        *> parens
          ( Acc
              <$> pVName
              <* pComma
              <*> pShape
              <* pComma
              <*> pTypes
              <*> pure NoUniqueness
          )
    ]

-- | A type: uniqueness, then shape, then the element type.
pTypeBase ::
  (ArrayShape shape) =>
  Parser shape ->
  Parser u ->
  Parser (TypeBase shape u)
pTypeBase ps pu = do
  u <- pu
  shape <- ps
  arrayOf <$> pNonArray <*> pure shape <*> pure u

-- | Zero or more bracketed dimensions.
pShape :: Parser Shape
pShape = Shape <$> many (brackets pSubExp)

-- | @?n@ is existential number @n@; anything else is parsed as free.
pExt :: Parser a -> Parser (Ext a)
pExt p =
  choice
    [ lexeme $ "?" $> Ext <*> L.decimal,
      Free <$> p
    ]

-- | A possibly existential dimension size.
pExtSize :: Parser ExtSize
pExtSize = pExt pSubExp

-- | Zero or more bracketed, possibly existential, dimensions.
pExtShape :: Parser ExtShape
pExtShape = Shape <$> many (brackets pExtSize)

-- | A type without uniqueness annotation.
pType :: Parser Type
pType = pTypeBase pShape (pure NoUniqueness)

-- | A brace-enclosed, comma-separated list of types.
pTypes :: Parser [Type]
pTypes = braces $ pType `sepBy` pComma

-- | A type with possibly existential shape.
pExtType :: Parser ExtType
pExtType = pTypeBase pExtShape (pure NoUniqueness)

-- | A rank, written as a number of empty @[]@ pairs.
pRank :: Parser Rank
pRank = Rank . length <$> many (lexeme "[" *> lexeme "]")

-- | @*@ for unique; otherwise nonunique.
pUniqueness :: Parser Uniqueness
pUniqueness = choice [pAsterisk $> Unique, pure Nonunique]

-- | Prefix a type parser with an optional uniqueness annotation.
pDeclBase ::
  Parser (TypeBase shape NoUniqueness) ->
  Parser (TypeBase shape Uniqueness)
pDeclBase p = flip toDecl <$> pUniqueness <*> p

-- | A type with uniqueness.
pDeclType :: Parser DeclType
pDeclType = pDeclBase pType

-- | A type with uniqueness and possibly existential shape.
pDeclExtType :: Parser DeclExtType
pDeclExtType = pDeclBase pExtType

-- | A variable or a constant.
pSubExp :: Parser SubExp
pSubExp = Var <$> pVName <|> Constant <$> pPrimValue

-- | A brace-enclosed, comma-separated list of subexpressions.
pSubExps :: Parser [SubExp]
pSubExps = braces (pSubExp `sepBy` pComma)

-- | A brace-enclosed, comma-separated list of variable names.
pVNames :: Parser [VName]
pVNames = braces (pVName `sepBy` pComma)

-- | A conversion written @s t1 x to t2@.
pConvOp ::
  T.Text -> (t1 -> t2 -> ConvOp) -> Parser t1 -> Parser t2 -> Parser BasicOp
pConvOp s op t1 t2 =
  keyword s $> op' <*> t1 <*> pSubExp <*> (keyword "to" *> t2)
  where
    op' f se t = ConvOp (op f t) se

-- | A binary operator applied as @op(x, y)@.
pBinOp :: Parser BasicOp
pBinOp = choice (map p allBinOps) <?> "binary op"
  where
    p bop =
      keyword (prettyText bop)
        *> parens (BinOp bop <$> pSubExp <* pComma <*> pSubExp)

-- | A comparison operator applied as @op(x, y)@.
pCmpOp :: Parser BasicOp
pCmpOp = choice (map p allCmpOps) <?> "comparison op"
  where
    p op =
      keyword (prettyText op)
        *> parens (CmpOp op <$> pSubExp <* pComma <*> pSubExp)

-- | A unary operator applied as @op x@.
pUnOp :: Parser BasicOp
pUnOp = choice (map p allUnOps) <?> "unary op"
  where
    p bop = keyword (prettyText bop) $> UnOp bop <*> pSubExp

-- | @i :+ n * s@ is a slice; anything else is a fixed index.
pDimIndex :: Parser (DimIndex SubExp)
pDimIndex =
  choice
    [ try $
        DimSlice
          <$> pSubExp
          <* lexeme ":+"
          <*> pSubExp
          <* lexeme "*"
          <*> pSubExp,
      DimFix <$> pSubExp
    ]

-- | A bracketed, comma-separated list of dimension indices.
pSlice :: Parser (Slice SubExp)
pSlice = Slice <$> brackets (pDimIndex `sepBy` pComma)

-- | @v[slice]@.
pIndex :: Parser BasicOp
pIndex = try $ Index <$> pVName <*> pSlice

-- | @n : s@ in a flat slice.
pFlatDimIndex :: Parser (FlatDimIndex SubExp)
pFlatDimIndex =
  FlatDimIndex <$> pSubExp <* lexeme ":" <*> pSubExp

-- | @[offset; n : s, ...]@.
pFlatSlice :: Parser (FlatSlice SubExp)
pFlatSlice =
  brackets $ FlatSlice <$> pSubExp <* pSemi <*> (pFlatDimIndex `sepBy` pComma)

-- | @v[flat slice]@.
pFlatIndex :: Parser BasicOp
pFlatIndex = try $ FlatIndex <$> pVName <*> pFlatSlice

-- | A string literal, or a value written @x : type@.
pErrorMsgPart :: Parser (ErrorMsgPart SubExp)
pErrorMsgPart =
  choice
    [ ErrorString <$> pStringLiteral,
      flip ErrorVal <$> (pSubExp <* pColon) <*> pPrimType
    ]

-- | A brace-enclosed, comma-separated list of error message parts.
pErrorMsg :: Parser (ErrorMsg SubExp)
pErrorMsg = ErrorMsg <$> braces (pErrorMsgPart `sepBy` pComma)

-- | A source location; the string is parsed but currently discarded.
pSrcLoc :: Parser SrcLoc
pSrcLoc = pStringLiteral $> mempty -- FIXME

-- | An error location with an empty list of extra locations.
pErrorLoc :: Parser (SrcLoc, [SrcLoc])
pErrorLoc = (,mempty) <$> pSrcLoc

-- | @iotaN(n, x, s)@, where @N@ is the bit size of the integer type.
pIota :: Parser BasicOp
pIota =
  choice $ map p allIntTypes
  where
    p t =
      keyword ("iota" <> prettyText (primBitSize (IntType t)))
        *> parens
          ( Iota
              <$> pSubExp
              <* pComma
              <*> pSubExp
              <* pComma
              <*> pSubExp
              <*> pure t
          )

-- | Any 'BasicOp'.  Alternatives are tried in order; those wrapped in
-- 'try' may consume input (such as a variable name) before failing and
-- therefore need to backtrack.
pBasicOp :: Parser BasicOp
pBasicOp =
  choice
    [ keyword "opaque" $> Opaque OpaqueNil <*> parens pSubExp,
      keyword "trace"
        $> uncurry (Opaque . OpaqueTrace)
        <*> parens ((,) <$> pStringLiteral <* pComma <*> pSubExp),
      keyword "copy" $> Replicate mempty . Var <*> parens pVName,
      keyword "assert"
        *> parens
          ( Assert
              <$> pSubExp
              <* pComma
              <*> pErrorMsg
              <* pComma
              <*> pErrorLoc
          ),
      keyword "replicate"
        *> parens (Replicate <$> pShape <* pComma <*> pSubExp),
      keyword "reshape"
        *> parens (Reshape ReshapeArbitrary <$> pShape <* pComma <*> pVName),
      keyword "coerce"
        *> parens (Reshape ReshapeCoerce <$> pShape <* pComma <*> pVName),
      keyword "scratch"
        *> parens (Scratch <$> pPrimType <*> many (pComma *> pSubExp)),
      keyword "rearrange"
        *> parens
          (Rearrange <$> parens (pInt `sepBy` pComma) <* pComma <*> pVName),
      keyword "manifest"
        *> parens
          (Manifest <$> parens (pInt `sepBy` pComma) <* pComma <*> pVName),
      keyword "concat" *> do
        d <- "@" *> L.decimal
        parens $ do
          w <- pSubExp <* pComma
          x <- pVName
          ys <- many (pComma *> pVName)
          pure $ Concat d (x :| ys) w,
      pIota,
      try $
        flip Update
          <$> pVName
          <* keyword "with"
          <*> choice [lexeme "?" $> Safe, pure Unsafe]
          <*> pSlice
          <* lexeme "="
          <*> pSubExp,
      try $
        FlatUpdate
          <$> pVName
          <* keyword "with"
          <*> pFlatSlice
          <* lexeme "="
          <*> pVName,
      try $
        ArrayVal
          <$> brackets (pPrimValue `sepBy` pComma)
          <*> (lexeme ":" *> "[]" *> pPrimType),
      ArrayLit
        <$> brackets (pSubExp `sepBy` pComma)
        <*> (lexeme ":" *> "[]" *> pType),
      do
        safety <-
          choice
            [ keyword "update_acc_unsafe" $> Unsafe,
              keyword "update_acc" $> Safe
            ]
        parens
          ( UpdateAcc safety
              <$> pVName
              <* pComma
              <*> pSubExps
              <* pComma
              <*> pSubExps
          ),
      --
      pConvOp "sext" SExt pIntType pIntType,
      pConvOp "zext" ZExt pIntType pIntType,
      pConvOp "fpconv" FPConv pFloatType pFloatType,
      pConvOp "fptoui" FPToUI pFloatType pIntType,
      pConvOp "fptosi" FPToSI pFloatType pIntType,
      pConvOp "uitofp" UIToFP pIntType pFloatType,
      pConvOp "sitofp" SIToFP pIntType pFloatType,
      pConvOp "itob" (const . IToB) pIntType (keyword "bool"),
      pConvOp "btoi" (const BToI) (keyword "bool") pIntType,
      pConvOp "ftob" (const . FToB) pFloatType (keyword "bool"),
      pConvOp "btof" (const BToF) (keyword "bool") pFloatType,
      --
      pIndex,
      pFlatIndex,
      pBinOp,
      pCmpOp,
      pUnOp,
      SubExp <$> pSubExp
    ]

-- | An attribute: an integer, or a name optionally applied to a
-- parenthesised, comma-separated list of attributes.
pAttr :: Parser Attr
pAttr =
  choice
    [ AttrInt . toInteger <$> pInt,
      do
        v <- pName
        choice
          [ AttrComp v <$> parens (pAttr `sepBy` pComma),
            pure $ AttrName v
          ]
    ]

-- | Zero or more attributes, each written @#[attr]@.
pAttrs :: Parser Attrs
pAttrs = Attrs . S.fromList <$> many pAttr'
  where
    pAttr' = lexeme "#[" *> pAttr <* lexeme "]"

-- | The keyword @commutative@, or nothing for noncommutative.
pComm :: Parser Commutativity
pComm =
  choice
    [ keyword "commutative" $> Commutative,
      pure Noncommutative
    ]

-- | This record contains parsers for all the representation-specific
-- bits.  Essentially a manually passed-around type class dictionary,
-- because ambiguities make it impossible to write this with actual
-- type classes.
data PR rep = PR { pRetType :: Parser (RetType rep), pBranchType :: Parser (BranchType rep), pFParamInfo :: Parser (FParamInfo rep), pLParamInfo :: Parser (LParamInfo rep), pLetDec :: Parser (LetDec rep), pOp :: Parser (Op rep), pBodyDec :: BodyDec rep, pExpDec :: ExpDec rep } pRetAls :: Parser RetAls pRetAls = fromMaybe (RetAls mempty mempty) <$> optional p where p = lexeme "#" *> parens (RetAls <$> pInts <* pComma <*> pInts) pInts = brackets $ pInt `sepBy` pComma pRetTypes :: PR rep -> Parser [(RetType rep, RetAls)] pRetTypes pr = braces $ ((,) <$> pRetType pr <*> pRetAls) `sepBy` pComma pBranchTypes :: PR rep -> Parser [BranchType rep] pBranchTypes pr = braces $ pBranchType pr `sepBy` pComma pParam :: Parser t -> Parser (Param t) pParam p = Param <$> pAttrs <*> pVName <*> (pColon *> p) pFParam :: PR rep -> Parser (FParam rep) pFParam = pParam . pFParamInfo pFParams :: PR rep -> Parser [FParam rep] pFParams pr = parens $ pFParam pr `sepBy` pComma pLParam :: PR rep -> Parser (LParam rep) pLParam = pParam . 
pLParamInfo pLParams :: PR rep -> Parser [LParam rep] pLParams pr = braces $ pLParam pr `sepBy` pComma pPatElem :: PR rep -> Parser (PatElem (LetDec rep)) pPatElem pr = (PatElem <$> pVName <*> (pColon *> pLetDec pr)) "pattern element" pPat :: PR rep -> Parser (Pat (LetDec rep)) pPat pr = Pat <$> braces (pPatElem pr `sepBy` pComma) pResult :: Parser Result pResult = braces $ pSubExpRes `sepBy` pComma pMatchSort :: Parser MatchSort pMatchSort = choice [ lexeme "" $> MatchFallback, lexeme "" $> MatchEquiv, pure MatchNormal ] pBranchBody :: PR rep -> Parser (Body rep) pBranchBody pr = choice [ try $ Body (pBodyDec pr) mempty <$> pResult, braces (pBody pr) ] pIf :: PR rep -> Parser (Exp rep) pIf pr = keyword "if" $> f <*> pMatchSort <*> pSubExp <*> (keyword "then" *> pBranchBody pr) <*> (keyword "else" *> pBranchBody pr) <*> (lexeme ":" *> pBranchTypes pr) where f sort cond tbranch fbranch t = Match [cond] [Case [Just $ BoolValue True] tbranch] fbranch $ MatchDec t sort pMatch :: PR rep -> Parser (Exp rep) pMatch pr = keyword "match" $> f <*> pMatchSort <*> braces (pSubExp `sepBy` pComma) <*> many pCase <*> (keyword "default" *> lexeme "->" *> pBranchBody pr) <*> (lexeme ":" *> pBranchTypes pr) where f sort cond cases defbody t = Match cond cases defbody $ MatchDec t sort pCase = keyword "case" $> Case <*> braces (pMaybeValue `sepBy` pComma) <* lexeme "->" <*> pBranchBody pr pMaybeValue = choice [lexeme "_" $> Nothing, Just <$> pPrimValue] pApply :: PR rep -> Parser (Exp rep) pApply pr = keyword "apply" *> (p =<< choice [lexeme "" $> Unsafe, pure Safe]) where p safety = Apply <$> pName <*> parens (pArg `sepBy` pComma) <* pColon <*> pRetTypes pr <*> pure (safety, mempty, mempty) pArg = choice [ lexeme "*" $> (,Consume) <*> pSubExp, (,Observe) <$> pSubExp ] pLoop :: PR rep -> Parser (Exp rep) pLoop pr = keyword "loop" $> Loop <*> pLoopParams <*> pLoopForm <* keyword "do" <*> braces (pBody pr) where pLoopParams = do params <- braces $ pFParam pr `sepBy` pComma void $ 
lexeme "=" args <- braces (pSubExp `sepBy` pComma) pure (zip params args) pLoopForm = choice [ keyword "for" $> ForLoop <*> pVName <* lexeme ":" <*> pIntType <* lexeme "<" <*> pSubExp, keyword "while" $> WhileLoop <*> pVName ] pLambda :: PR rep -> Parser (Lambda rep) pLambda pr = choice [ lexeme "\\" $> Lambda <*> pLParams pr <* pColon <*> pTypes <* pArrow <*> pBody pr, keyword "nilFn" $> Lambda mempty [] (Body (pBodyDec pr) mempty []) ] pReduce :: PR rep -> Parser (SOAC.Reduce rep) pReduce pr = SOAC.Reduce <$> pComm <*> pLambda pr <* pComma <*> braces (pSubExp `sepBy` pComma) pScan :: PR rep -> Parser (SOAC.Scan rep) pScan pr = SOAC.Scan <$> pLambda pr <* pComma <*> braces (pSubExp `sepBy` pComma) pWithAcc :: PR rep -> Parser (Exp rep) pWithAcc pr = keyword "with_acc" *> parens (WithAcc <$> braces (pInput `sepBy` pComma) <* pComma <*> pLambda pr) where pInput = parens ( (,,) <$> pShape <* pComma <*> pVNames <*> optional (pComma *> pCombFun) ) pCombFun = parens ((,) <$> pLambda pr <* pComma <*> pSubExps) pExp :: PR rep -> Parser (Exp rep) pExp pr = choice [ pIf pr, pMatch pr, pApply pr, pLoop pr, pWithAcc pr, Op <$> pOp pr, BasicOp <$> pBasicOp ] pCerts :: Parser Certs pCerts = choice [ lexeme "#" *> braces (Certs <$> pVName `sepBy` pComma) "certificates", pure mempty ] pSubExpRes :: Parser SubExpRes pSubExpRes = SubExpRes <$> pCerts <*> pSubExp pStm :: PR rep -> Parser (Stm rep) pStm pr = keyword "let" $> Let <*> pPat pr <* pEqual <*> pStmAux <*> pExp pr where pStmAux = flip StmAux <$> pAttrs <*> pCerts <*> pure (pExpDec pr) pStms :: PR rep -> Parser (Stms rep) pStms pr = stmsFromList <$> many (pStm pr) pBody :: PR rep -> Parser (Body rep) pBody pr = choice [ Body (pBodyDec pr) <$> pStms pr <* keyword "in" <*> pResult, Body (pBodyDec pr) mempty <$> pResult ] pValueType :: Parser ValueType pValueType = comb <$> pRank <*> pSignedType where comb r (s, t) = ValueType s r t pSignedType = choice [ keyword "u8" $> (Unsigned, IntType Int8), keyword "u16" $> (Unsigned, 
IntType Int16), keyword "u32" $> (Unsigned, IntType Int32), keyword "u64" $> (Unsigned, IntType Int64), (Signed,) <$> pPrimType ] pEntryPointType :: Parser EntryPointType pEntryPointType = choice [ keyword "opaque" $> TypeOpaque . nameFromText <*> pStringLiteral, TypeTransparent <$> pValueType ] pEntry :: Parser EntryPoint pEntry = parens $ (,,) <$> (nameFromText <$> pStringLiteral) <* pComma <*> pEntryPointInputs <* pComma <*> pEntryPointResults where pEntryPointInputs = braces (pEntryPointInput `sepBy` pComma) pEntryPointResults = braces (pEntryPointResult `sepBy` pComma) pEntryPointInput = EntryParam <$> pName <* pColon <*> pUniqueness <*> pEntryPointType pEntryPointResult = EntryResult <$> pUniqueness <*> pEntryPointType pFunDef :: PR rep -> Parser (FunDef rep) pFunDef pr = do attrs <- pAttrs entry <- choice [ keyword "entry" $> Just <*> pEntry, keyword "fun" $> Nothing ] fname <- pName fparams <- pFParams pr <* pColon ret <- pRetTypes pr FunDef entry attrs fname ret fparams <$> (pEqual *> braces (pBody pr)) pOpaqueType :: Parser (Name, OpaqueType) pOpaqueType = (,) <$> (keyword "type" *> (nameFromText <$> pStringLiteral) <* pEqual) <*> choice [pRecord, pSum, pOpaque, pRecordArray, pOpaqueArray] where pFieldName = choice [pName, nameFromString . 
show <$> pInt] pField = (,) <$> pFieldName <* pColon <*> pEntryPointType pRecord = keyword "record" $> OpaqueRecord <*> braces (many pField) pConstructor = "#" *> pName pPayload = parens $ (,) <$> (pEntryPointType <* pComma) <*> brackets (pInt `sepBy` pComma) pVariant = (,) <$> pConstructor <*> many pPayload pSum = keyword "sum" *> braces ( OpaqueSum <$> brackets (pValueType `sepBy` pComma) <*> many pVariant ) pOpaque = keyword "opaque" $> OpaqueType <*> braces (many pValueType) pRecordArray = keyword "record_array" $> OpaqueRecordArray <*> (pInt <* lexeme "d") <*> (nameFromText <$> pStringLiteral) <*> braces (many pField) pOpaqueArray = keyword "array" $> OpaqueArray <*> (pInt <* lexeme "d") <*> (nameFromText <$> pStringLiteral) <*> braces (many pValueType) pOpaqueTypes :: Parser OpaqueTypes pOpaqueTypes = keyword "types" $> OpaqueTypes <*> braces (many pOpaqueType) pProg :: PR rep -> Parser (Prog rep) pProg pr = Prog <$> (fromMaybe noTypes <$> optional pOpaqueTypes) <*> pStms pr <*> many (pFunDef pr) where noTypes = OpaqueTypes mempty pSOAC :: PR rep -> Parser (SOAC.SOAC rep) pSOAC pr = choice [ keyword "map" *> pScrema pMapForm, keyword "redomap" *> pScrema pRedomapForm, keyword "scanomap" *> pScrema pScanomapForm, keyword "screma" *> pScrema pScremaForm, keyword "vjp" *> pVJP, keyword "jvp" *> pJVP, pScatter, pHist, pStream ] where pScrema p = parens $ SOAC.Screma <$> pSubExp <* pComma <*> braces (pVName `sepBy` pComma) <* pComma <*> p pScremaForm = SOAC.ScremaForm <$> pLambda pr <* pComma <*> braces (pScan pr `sepBy` pComma) <* pComma <*> braces (pReduce pr `sepBy` pComma) pRedomapForm = SOAC.ScremaForm <$> pLambda pr <*> pure [] <* pComma <*> braces (pReduce pr `sepBy` pComma) pScanomapForm = SOAC.ScremaForm <$> pLambda pr <* pComma <*> braces (pScan pr `sepBy` pComma) <*> pure [] pMapForm = SOAC.ScremaForm <$> pLambda pr <*> pure mempty <*> pure mempty pScatter = keyword "scatter" *> parens ( SOAC.Scatter <$> pSubExp <* pComma <*> braces (pVName `sepBy` 
pComma) <* pComma <*> many (pDest <* pComma) <*> pLambda pr ) where pDest = parens $ (,,) <$> pShape <* pComma <*> pInt <* pComma <*> pVName pHist = keyword "hist" *> parens ( SOAC.Hist <$> pSubExp <* pComma <*> braces (pVName `sepBy` pComma) <* pComma <*> braces (pHistOp `sepBy` pComma) <* pComma <*> pLambda pr ) where pHistOp = SOAC.HistOp <$> pShape <* pComma <*> pSubExp <* pComma <*> braces (pVName `sepBy` pComma) <* pComma <*> braces (pSubExp `sepBy` pComma) <* pComma <*> pLambda pr pStream = keyword "streamSeq" *> pStreamSeq pStreamSeq = parens $ SOAC.Stream <$> pSubExp <* pComma <*> braces (pVName `sepBy` pComma) <* pComma <*> braces (pSubExp `sepBy` pComma) <* pComma <*> pLambda pr pVJP = parens $ SOAC.VJP <$> braces (pSubExp `sepBy` pComma) <* pComma <*> braces (pSubExp `sepBy` pComma) <* pComma <*> pLambda pr pJVP = parens $ SOAC.JVP <$> braces (pSubExp `sepBy` pComma) <* pComma <*> braces (pSubExp `sepBy` pComma) <* pComma <*> pLambda pr pSizeClass :: Parser GPU.SizeClass pSizeClass = choice [ keyword "thread_block_size" $> GPU.SizeThreadBlock, keyword "grid_size" $> GPU.SizeGrid, keyword "tile_size" $> GPU.SizeTile, keyword "reg_tile_size" $> GPU.SizeRegTile, keyword "shared_memory" $> GPU.SizeSharedMemory, keyword "threshold" *> parens ( flip GPU.SizeThreshold <$> choice [Just <$> pInt64, "def" $> Nothing] <* pComma <*> pKernelPath ), keyword "bespoke" *> parens (GPU.SizeBespoke <$> pName <* pComma <*> pInt64) ] where pKernelPath = many pStep pStep = choice [ lexeme "!" 
$> (,) <*> pName <*> pure False, (,) <$> pName <*> pure True ] pSizeOp :: Parser GPU.SizeOp pSizeOp = choice [ keyword "get_size" *> parens (GPU.GetSize <$> pName <* pComma <*> pSizeClass), keyword "get_size_max" *> parens (GPU.GetSizeMax <$> pSizeClass), keyword "cmp_size" *> ( parens (GPU.CmpSizeLe <$> pName <* pComma <*> pSizeClass) <*> (lexeme "<=" *> pSubExp) ), keyword "calc_num_tblocks" *> parens ( GPU.CalcNumBlocks <$> pSubExp <* pComma <*> pName <* pComma <*> pSubExp ) ] pSegSpace :: Parser SegOp.SegSpace pSegSpace = flip SegOp.SegSpace <$> parens (pDim `sepBy` pComma) <*> parens (lexeme "~" *> pVName) where pDim = (,) <$> pVName <* lexeme "<" <*> pSubExp pKernelResult :: Parser SegOp.KernelResult pKernelResult = do cs <- pCerts choice [ keyword "returns" $> SegOp.Returns <*> choice [ keyword "(manifest)" $> SegOp.ResultNoSimplify, keyword "(private)" $> SegOp.ResultPrivate, pure SegOp.ResultMaySimplify ] <*> pure cs <*> pSubExp, try $ SegOp.WriteReturns cs <$> pVName <* keyword "with" <*> parens (pWrite `sepBy` pComma), try "tile" *> parens (SegOp.TileReturns cs <$> (pTile `sepBy` pComma)) <*> pVName, try "blkreg_tile" *> parens (SegOp.RegTileReturns cs <$> (pRegTile `sepBy` pComma)) <*> pVName ] where pTile = (,) <$> pSubExp <* pSlash <*> pSubExp pRegTile = do dim <- pSubExp <* pSlash parens $ do blk_tile <- pSubExp <* pAsterisk reg_tile <- pSubExp pure (dim, blk_tile, reg_tile) pWrite = (,) <$> pSlice <* pEqual <*> pSubExp pKernelBody :: PR rep -> Parser (SegOp.KernelBody rep) pKernelBody pr = SegOp.KernelBody (pBodyDec pr) <$> pStms pr <* keyword "return" <*> braces (pKernelResult `sepBy` pComma) pSegOp :: PR rep -> Parser lvl -> Parser (SegOp.SegOp lvl rep) pSegOp pr pLvl = choice [ keyword "segmap" *> pSegMap, keyword "segred" *> pSegRed, keyword "segscan" *> pSegScan, keyword "seghist" *> pSegHist ] where pSegMap = SegOp.SegMap <$> pLvl <*> pSegSpace <* pColon <*> pTypes <*> braces (pKernelBody pr) pSegOp' f p = f <$> pLvl <*> pSegSpace <*> parens 
(p `sepBy` pComma) <* pColon <*> pTypes <*> braces (pKernelBody pr) pSegBinOp = do nes <- braces (pSubExp `sepBy` pComma) <* pComma shape <- pShape <* pComma comm <- pComm lam <- pLambda pr pure $ SegOp.SegBinOp comm lam nes shape pHistOp = SegOp.HistOp <$> pShape <* pComma <*> pSubExp <* pComma <*> braces (pVName `sepBy` pComma) <* pComma <*> braces (pSubExp `sepBy` pComma) <* pComma <*> pShape <* pComma <*> pLambda pr pSegRed = pSegOp' SegOp.SegRed pSegBinOp pSegScan = pSegOp' SegOp.SegScan pSegBinOp pSegHist = pSegOp' SegOp.SegHist pHistOp pSegLevel :: Parser GPU.SegLevel pSegLevel = parens . choice $ [ "thread" $> GPU.SegThread <* pSemi <*> pSegVirt <* pSemi <*> optional pKernelGrid, "block" $> GPU.SegBlock <* pSemi <*> pSegVirt <* pSemi <*> optional pKernelGrid, "inblock" $> GPU.SegThreadInBlock <* pSemi <*> pSegVirt ] where pSegVirt = choice [ choice [ keyword "full" $> GPU.SegNoVirtFull <*> (GPU.SegSeqDims <$> brackets (pInt `sepBy` pComma)), keyword "virtualise" $> GPU.SegVirt ], pure GPU.SegNoVirt ] pKernelGrid = GPU.KernelGrid <$> (lexeme "grid=" $> GPU.Count <*> pSubExp <* pSemi) <*> (lexeme "blocksize=" $> GPU.Count <*> pSubExp) pHostOp :: PR rep -> Parser (op rep) -> Parser (GPU.HostOp op rep) pHostOp pr pOther = choice [ GPU.SegOp <$> pSegOp pr pSegLevel, GPU.SizeOp <$> pSizeOp, GPU.OtherOp <$> pOther, keyword "gpu" $> GPU.GPUBody <*> (pColon *> pTypes) <*> braces (pBody pr) ] pMCOp :: PR rep -> Parser (op rep) -> Parser (MC.MCOp op rep) pMCOp pr pOther = choice [ MC.ParOp . 
Just <$> (keyword "par" *> braces pMCSegOp) <*> (keyword "seq" *> braces pMCSegOp), MC.ParOp Nothing <$> pMCSegOp, MC.OtherOp <$> pOther ] where pMCSegOp = pSegOp pr (void $ lexeme "()") pLMADBase :: Parser a -> Parser (LMAD.LMAD a) pLMADBase pNum = braces $ do offset <- pLab "offset" pNum <* pSemi strides <- pLab "strides" $ brackets (pNum `sepBy` pComma) <* pSemi shape <- pLab "shape" $ brackets (pNum `sepBy` pComma) pure $ LMAD.LMAD offset $ zipWith LMAD.LMADDim strides shape where pLab s m = keyword s *> pColon *> m pPrimExpLeaf :: Parser VName pPrimExpLeaf = pVName pExtPrimExpLeaf :: Parser (Ext VName) pExtPrimExpLeaf = pExt pVName pLMAD :: Parser LMAD pLMAD = pLMADBase $ isInt64 <$> pPrimExp int64 pPrimExpLeaf pExtLMAD :: Parser ExtLMAD pExtLMAD = pLMADBase $ isInt64 <$> pPrimExp int64 pExtPrimExpLeaf pMemInfo :: Parser d -> Parser u -> Parser ret -> Parser (MemInfo d u ret) pMemInfo pd pu pret = choice [ MemPrim <$> pPrimType, keyword "mem" $> MemMem <*> choice [pSpace, pure DefaultSpace], pArrayOrAcc ] where pArrayOrAcc = do u <- pu shape <- Shape <$> many (brackets pd) choice [pArray u shape, pAcc u] pArray u shape = do pt <- pPrimType MemArray pt shape u <$> (lexeme "@" *> pret) pAcc u = keyword "acc" *> parens ( MemAcc <$> pVName <* pComma <*> pShape <* pComma <*> pTypes <*> pure u ) pSpace :: Parser Space pSpace = lexeme "@" *> choice [ Space . nameToString <$> pName, ScalarSpace <$> (shapeDims <$> pShape) <*> pPrimType ] pMemBind :: Parser MemBind pMemBind = ArrayIn <$> pVName <* lexeme "->" <*> pLMAD pMemReturn :: Parser MemReturn pMemReturn = choice [ ReturnsInBlock <$> pVName <* lexeme "->" <*> pExtLMAD, do i <- "?" 
*> pInt space <- choice [pSpace, pure DefaultSpace] <* lexeme "->" ReturnsNewBlock space i <$> pExtLMAD ] pRetTypeMem :: Parser RetTypeMem pRetTypeMem = pMemInfo pExtSize pUniqueness pMemReturn pBranchTypeMem :: Parser BranchTypeMem pBranchTypeMem = pMemInfo pExtSize (pure NoUniqueness) pMemReturn pFParamMem :: Parser FParamMem pFParamMem = pMemInfo pSubExp pUniqueness pMemBind pLParamMem :: Parser LParamMem pLParamMem = pMemInfo pSubExp (pure NoUniqueness) pMemBind pLetDecMem :: Parser LetDecMem pLetDecMem = pMemInfo pSubExp (pure NoUniqueness) pMemBind pMemOp :: Parser (inner rep) -> Parser (MemOp inner rep) pMemOp pInner = choice [ keyword "alloc" *> parens (Alloc <$> pSubExp <*> choice [pComma *> pSpace, pure DefaultSpace]), Inner <$> pInner ] prSOACS :: PR SOACS prSOACS = PR pDeclExtType pExtType pDeclType pType pType (pSOAC prSOACS) () () prSeq :: PR Seq prSeq = PR pDeclExtType pExtType pDeclType pType pType empty () () prSeqMem :: PR SeqMem prSeqMem = PR pRetTypeMem pBranchTypeMem pFParamMem pLParamMem pLetDecMem op () () where op = pMemOp empty prGPU :: PR GPU prGPU = PR pDeclExtType pExtType pDeclType pType pType op () () where op = pHostOp prGPU (pSOAC prGPU) prGPUMem :: PR GPUMem prGPUMem = PR pRetTypeMem pBranchTypeMem pFParamMem pLParamMem pLetDecMem op () () where op = pMemOp $ pHostOp prGPUMem empty prMC :: PR MC prMC = PR pDeclExtType pExtType pDeclType pType pType op () () where op = pMCOp prMC (pSOAC prMC) prMCMem :: PR MCMem prMCMem = PR pRetTypeMem pBranchTypeMem pFParamMem pLParamMem pLetDecMem op () () where op = pMemOp $ pMCOp prMCMem empty parseFull :: Parser a -> FilePath -> T.Text -> Either T.Text a parseFull p fname s = either (Left . T.pack . errorBundlePretty) Right $ parse (whitespace *> p <* eof) fname s parseRep :: PR rep -> FilePath -> T.Text -> Either T.Text (Prog rep) parseRep = parseFull . 
pProg parseSOACS :: FilePath -> T.Text -> Either T.Text (Prog SOACS) parseSOACS = parseRep prSOACS parseSeq :: FilePath -> T.Text -> Either T.Text (Prog Seq) parseSeq = parseRep prSeq parseSeqMem :: FilePath -> T.Text -> Either T.Text (Prog SeqMem) parseSeqMem = parseRep prSeqMem parseGPU :: FilePath -> T.Text -> Either T.Text (Prog GPU) parseGPU = parseRep prGPU parseGPUMem :: FilePath -> T.Text -> Either T.Text (Prog GPUMem) parseGPUMem = parseRep prGPUMem parseMC :: FilePath -> T.Text -> Either T.Text (Prog MC) parseMC = parseRep prMC parseMCMem :: FilePath -> T.Text -> Either T.Text (Prog MCMem) parseMCMem = parseRep prMCMem --- Fragment parsers parseType :: FilePath -> T.Text -> Either T.Text Type parseType = parseFull pType parseDeclExtType :: FilePath -> T.Text -> Either T.Text DeclExtType parseDeclExtType = parseFull pDeclExtType parseDeclType :: FilePath -> T.Text -> Either T.Text DeclType parseDeclType = parseFull pDeclType parseVName :: FilePath -> T.Text -> Either T.Text VName parseVName = parseFull pVName parseSubExp :: FilePath -> T.Text -> Either T.Text SubExp parseSubExp = parseFull pSubExp parseSubExpRes :: FilePath -> T.Text -> Either T.Text SubExpRes parseSubExpRes = parseFull pSubExpRes parseBodyGPU :: FilePath -> T.Text -> Either T.Text (Body GPU) parseBodyGPU = parseFull $ pBody prGPU parseStmGPU :: FilePath -> T.Text -> Either T.Text (Stm GPU) parseStmGPU = parseFull $ pStm prGPU parseBodyMC :: FilePath -> T.Text -> Either T.Text (Body MC) parseBodyMC = parseFull $ pBody prMC parseStmMC :: FilePath -> T.Text -> Either T.Text (Stm MC) parseStmMC = parseFull $ pStm prMC futhark-0.25.27/src/Futhark/IR/Pretty.hs000066400000000000000000000337011475065116200176760ustar00rootroot00000000000000{-# OPTIONS_GHC -fno-warn-orphans #-} -- | Futhark prettyprinter. This module defines 'Pretty' instances -- for the AST defined in "Futhark.IR.Syntax", -- but also a number of convenience functions if you don't want to use -- the interface from 'Pretty'. 
module Futhark.IR.Pretty ( prettyTuple, prettyTupleLines, prettyString, PrettyRep (..), ) where import Data.Foldable (toList) import Data.List.NonEmpty (NonEmpty (..)) import Data.Maybe import Futhark.IR.Syntax import Futhark.Util.Pretty -- | The class of representations whose annotations can be prettyprinted. class ( RepTypes rep, Pretty (RetType rep), Pretty (BranchType rep), Pretty (FParamInfo rep), Pretty (LParamInfo rep), Pretty (LetDec rep), Pretty (Op rep) ) => PrettyRep rep where ppExpDec :: ExpDec rep -> Exp rep -> Maybe (Doc a) ppExpDec _ _ = Nothing instance Pretty (NoOp rep) where pretty NoOp = "noop" instance Pretty VName where pretty (VName vn i) = pretty vn <> "_" <> pretty (show i) instance Pretty Commutativity where pretty Commutative = "commutative" pretty Noncommutative = "noncommutative" instance Pretty Shape where pretty = mconcat . map (brackets . pretty) . shapeDims instance Pretty Rank where pretty (Rank r) = mconcat $ replicate r "[]" instance (Pretty a) => Pretty (Ext a) where pretty (Free e) = pretty e pretty (Ext x) = "?" <> pretty (show x) instance Pretty ExtShape where pretty = mconcat . map (brackets . pretty) . shapeDims instance Pretty Space where pretty DefaultSpace = mempty pretty (Space s) = "@" <> pretty s pretty (ScalarSpace d t) = "@" <> mconcat (map (brackets . pretty) d) <> pretty t instance (Pretty u) => Pretty (TypeBase Shape u) where pretty (Prim t) = pretty t pretty (Acc acc ispace ts u) = pretty u <> "acc" <> apply [ pretty acc, pretty ispace, ppTuple' $ map pretty ts ] pretty (Array et (Shape ds) u) = pretty u <> mconcat (map (brackets . pretty) ds) <> pretty et pretty (Mem s) = "mem" <> pretty s instance (Pretty u) => Pretty (TypeBase ExtShape u) where pretty (Prim t) = pretty t pretty (Acc acc ispace ts u) = pretty u <> "acc" <> apply [ pretty acc, pretty ispace, ppTuple' $ map pretty ts ] pretty (Array et (Shape ds) u) = pretty u <> mconcat (map (brackets . 
pretty) ds) <> pretty et pretty (Mem s) = "mem" <> pretty s instance (Pretty u) => Pretty (TypeBase Rank u) where pretty (Prim t) = pretty t pretty (Acc acc ispace ts u) = pretty u <> "acc" <> apply [ pretty acc, pretty ispace, ppTuple' $ map pretty ts ] pretty (Array et (Rank n) u) = pretty u <> mconcat (replicate n $ brackets mempty) <> pretty et pretty (Mem s) = "mem" <> pretty s instance Pretty Ident where pretty ident = pretty (identType ident) <+> pretty (identName ident) instance Pretty SubExp where pretty (Var v) = pretty v pretty (Constant v) = pretty v instance Pretty Certs where pretty (Certs []) = mempty pretty (Certs cs) = "#" <> braces (commasep (map pretty cs)) instance (PrettyRep rep) => Pretty (Stms rep) where pretty = stack . map pretty . stmsToList instance Pretty SubExpRes where pretty (SubExpRes cs se) = hsep $ certAnnots cs ++ [pretty se] instance (PrettyRep rep) => Pretty (Body rep) where pretty (Body _ stms res) | null stms = braces (commasep $ map pretty res) | otherwise = stack (map pretty $ stmsToList stms) "in" <+> braces (commasep $ map pretty res) instance Pretty Attr where pretty (AttrName v) = pretty v pretty (AttrInt x) = pretty x pretty (AttrComp f attrs) = pretty f <> parens (commasep $ map pretty attrs) attrAnnots :: Attrs -> [Doc a] attrAnnots = map f . toList . unAttrs where f v = "#[" <> pretty v <> "]" stmAttrAnnots :: Stm rep -> [Doc a] stmAttrAnnots = attrAnnots . stmAuxAttrs . stmAux certAnnots :: Certs -> [Doc a] certAnnots cs | cs == mempty = [] | otherwise = [pretty cs] stmCertAnnots :: Stm rep -> [Doc a] stmCertAnnots = certAnnots . stmAuxCerts . stmAux instance Pretty Attrs where pretty = hsep . 
attrAnnots instance (Pretty t) => Pretty (Pat t) where pretty (Pat xs) = braces $ commastack $ map pretty xs instance (Pretty t) => Pretty (PatElem t) where pretty (PatElem name t) = pretty name <+> colon <+> align (pretty t) instance (Pretty t) => Pretty (Param t) where pretty (Param attrs name t) = annot (attrAnnots attrs) $ pretty name <+> colon <+> align (pretty t) instance (PrettyRep rep) => Pretty (Stm rep) where pretty stm@(Let pat aux e) = align . hang 2 $ "let" <+> align (pretty pat) <+> case stmannot of [] -> equals pretty e _ -> equals (stack stmannot pretty e) where stmannot = concat [ maybeToList (ppExpDec (stmAuxDec aux) e), stmAttrAnnots stm, stmCertAnnots stm ] instance (Pretty a) => Pretty (Slice a) where pretty (Slice xs) = brackets (commasep (map pretty xs)) instance (Pretty d) => Pretty (FlatDimIndex d) where pretty (FlatDimIndex n s) = pretty n <+> ":" <+> pretty s instance (Pretty a) => Pretty (FlatSlice a) where pretty (FlatSlice offset xs) = brackets (pretty offset <> ";" <+> commasep (map pretty xs)) instance Pretty BasicOp where pretty (SubExp se) = pretty se pretty (Opaque OpaqueNil e) = "opaque" <> apply [pretty e] pretty (Opaque (OpaqueTrace s) e) = "trace" <> apply [pretty (show s), pretty e] pretty (ArrayLit es rt) = case rt of Array {} -> brackets $ commastack $ map pretty es _ -> brackets $ commasep $ map pretty es <+> colon <+> "[]" <> pretty rt pretty (ArrayVal vs t) = brackets (commasep $ map pretty vs) <+> colon <+> "[]" <> pretty t pretty (BinOp bop x y) = pretty bop <> parens (pretty x <> comma <+> pretty y) pretty (CmpOp op x y) = pretty op <> parens (pretty x <> comma <+> pretty y) pretty (ConvOp conv x) = pretty (convOpFun conv) <+> pretty fromtype <+> pretty x <+> "to" <+> pretty totype where (fromtype, totype) = convOpType conv pretty (UnOp op e) = pretty op <+> pretty e pretty (Index v slice) = pretty v <> pretty slice pretty (Update safety src slice se) = pretty src <+> with <+> pretty slice <+> "=" <+> pretty se where 
with = case safety of Unsafe -> "with" Safe -> "with?" pretty (FlatIndex v slice) = pretty v <> pretty slice pretty (FlatUpdate src slice se) = pretty src <+> "with" <+> pretty slice <+> "=" <+> pretty se pretty (Iota e x s et) = "iota" <> et' <> apply [pretty e, pretty x, pretty s] where et' = pretty $ show $ primBitSize $ IntType et pretty (Replicate (Shape []) e) = "copy" <> parens (pretty e) pretty (Replicate ne ve) = "replicate" <> apply [pretty ne, align (pretty ve)] pretty (Scratch t shape) = "scratch" <> apply (pretty t : map pretty shape) pretty (Reshape ReshapeArbitrary shape e) = "reshape" <> apply [pretty shape, pretty e] pretty (Reshape ReshapeCoerce shape e) = "coerce" <> apply [pretty shape, pretty e] pretty (Rearrange perm e) = "rearrange" <> apply [apply (map pretty perm), pretty e] pretty (Concat i (x :| xs) w) = "concat" <> "@" <> pretty i <> apply (pretty w : pretty x : map pretty xs) pretty (Manifest perm e) = "manifest" <> apply [apply (map pretty perm), pretty e] pretty (Assert e msg (loc, _)) = "assert" <> apply [pretty e, pretty msg, pretty $ show $ locStr loc] pretty (UpdateAcc safety acc is v) = update_acc_str <> apply [ pretty acc, ppTuple' $ map pretty is, ppTuple' $ map pretty v ] where update_acc_str = case safety of Safe -> "update_acc" Unsafe -> "update_acc_unsafe" instance (Pretty a) => Pretty (ErrorMsg a) where pretty (ErrorMsg parts) = braces $ align $ commasep $ map p parts where p (ErrorString s) = pretty $ show s p (ErrorVal t x) = pretty x <+> colon <+> pretty t maybeNest :: (PrettyRep rep) => Body rep -> Doc a maybeNest b | null $ bodyStms b = pretty b | otherwise = nestedBlock "{" "}" $ pretty b instance (PrettyRep rep) => Pretty (Case (Body rep)) where pretty (Case vs b) = "case" <+> ppTuple' (map (maybe "_" pretty) vs) <+> "->" <+> maybeNest b prettyRet :: (Pretty t) => (t, RetAls) -> Doc a prettyRet (t, RetAls pals rals) | pals == mempty, rals == mempty = pretty t | otherwise = pretty t <> "#" <> parens (pl pals <> comma 
<+> pl rals) where pl = brackets . commasep . map pretty instance (PrettyRep rep) => Pretty (Exp rep) where pretty (Match [c] [Case [Just (BoolValue True)] t] f (MatchDec ret ifsort)) = "if" <> info' <+> pretty c "then" <+> maybeNest t <+> "else" <+> maybeNest f colon <+> ppTupleLines' (map pretty ret) where info' = case ifsort of MatchNormal -> mempty MatchFallback -> " " MatchEquiv -> " " pretty (Match ses cs defb (MatchDec ret ifsort)) = ("match" <+> info' <+> ppTuple' (map pretty ses)) stack (map pretty cs) "default" <+> "->" <+> maybeNest defb colon <+> ppTupleLines' (map pretty ret) where info' = case ifsort of MatchNormal -> mempty MatchFallback -> " " MatchEquiv -> " " pretty (BasicOp op) = pretty op pretty (Apply fname args ret (safety, _, _)) = applykw <+> pretty (nameToString fname) <> apply (map (align . prettyArg) args) colon <+> braces (commasep $ map prettyRet ret) where prettyArg (arg, Consume) = "*" <> pretty arg prettyArg (arg, _) = pretty arg applykw = case safety of Unsafe -> "apply " Safe -> "apply" pretty (Op op) = pretty op pretty (Loop merge form loopbody) = "loop" <+> braces (commastack $ map pretty params) <+> equals <+> ppTuple' (map pretty args) ( case form of ForLoop i it bound -> "for" <+> align ( pretty i <> ":" <> pretty it <+> "<" <+> align (pretty bound) ) WhileLoop cond -> "while" <+> pretty cond ) <+> "do" <+> nestedBlock "{" "}" (pretty loopbody) where (params, args) = unzip merge pretty (WithAcc inputs lam) = "with_acc" <> parens (braces (commastack $ map ppInput inputs) <> comma pretty lam) where ppInput (shape, arrs, op) = parens ( pretty shape <> comma <+> ppTuple' (map pretty arrs) <> case op of Nothing -> mempty Just (op', nes) -> comma parens (pretty op' <> comma ppTuple' (map pretty nes)) ) instance (PrettyRep rep) => Pretty (Lambda rep) where pretty (Lambda [] [] (Body _ stms [])) | stms == mempty = "nilFn" pretty (Lambda params rettype body) = "\\" <+> braces (commastack $ map pretty params) indent 2 (colon <+> 
ppTupleLines' (map pretty rettype) <+> "->") indent 2 (pretty body) instance Pretty Signedness where pretty Signed = "signed" pretty Unsigned = "unsigned" instance Pretty ValueType where pretty (ValueType s (Rank r) t) = mconcat (replicate r "[]") <> pretty (prettySigned (s == Unsigned) t) instance Pretty EntryPointType where pretty (TypeTransparent t) = pretty t pretty (TypeOpaque desc) = "opaque" <+> dquotes (pretty desc) instance Pretty EntryParam where pretty (EntryParam name u t) = pretty name <> colon <+> pretty u <> pretty t instance Pretty EntryResult where pretty (EntryResult u t) = pretty u <> pretty t instance (PrettyRep rep) => Pretty (FunDef rep) where pretty (FunDef entry attrs name rettype fparams body) = annot (attrAnnots attrs) $ fun indent 2 (pretty (nameToString name)) <+> parens (commastack $ map pretty fparams) indent 2 (colon <+> align (ppTupleLines' $ map prettyRet rettype)) <+> equals <+> nestedBlock "{" "}" (pretty body) where fun = case entry of Nothing -> "fun" Just (p_name, p_entry, ret_entry) -> "entry" <> (parens . 
align) ( "\"" <> pretty p_name <> "\"" <> comma ppTupleLines' (map pretty p_entry) <> comma ppTupleLines' (map pretty ret_entry) ) instance Pretty OpaqueType where pretty (OpaqueType ts) = "opaque" <+> nestedBlock "{" "}" (stack $ map pretty ts) pretty (OpaqueRecord fs) = "record" <+> nestedBlock "{" "}" (stack $ map p fs) where p (f, et) = pretty f <> ":" <+> pretty et pretty (OpaqueSum ts cs) = "sum" <+> nestedBlock "{" "}" (stack $ pretty ts : map p cs) where p (c, ets) = hsep $ "#" <> pretty c : map pretty ets pretty (OpaqueArray r v ts) = "array" <+> pretty r <> "d" <+> dquotes (pretty v) <+> nestedBlock "{" "}" (stack $ map pretty ts) pretty (OpaqueRecordArray r v fs) = "record_array" <+> pretty r <> "d" <+> dquotes (pretty v) <+> nestedBlock "{" "}" (stack $ map p fs) where p (f, et) = pretty f <> ":" <+> pretty et instance Pretty OpaqueTypes where pretty (OpaqueTypes ts) = "types" <+> nestedBlock "{" "}" (stack $ map p ts) where p (name, t) = "type" <+> dquotes (pretty name) <+> equals <+> pretty t instance (PrettyRep rep) => Pretty (Prog rep) where pretty (Prog types consts funs) = stack $ punctuate line $ pretty types : pretty consts : map pretty funs instance (Pretty d) => Pretty (DimIndex d) where pretty (DimFix i) = pretty i pretty (DimSlice i n s) = pretty i <+> ":+" <+> pretty n <+> "*" <+> pretty s futhark-0.25.27/src/Futhark/IR/Prop.hs000066400000000000000000000220001475065116200173150ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | This module provides various simple ways to query and manipulate -- fundamental Futhark terms, such as types and values. The intent is -- to keep "Futhark.IR.Syntax" simple, and put whatever embellishments -- we need here. This is an internal, desugared representation. 
module Futhark.IR.Prop
  ( module Futhark.IR.Prop.Reshape,
    module Futhark.IR.Prop.Rearrange,
    module Futhark.IR.Prop.Types,
    module Futhark.IR.Prop.Constants,
    module Futhark.IR.Prop.TypeOf,
    module Futhark.IR.Prop.Pat,
    module Futhark.IR.Prop.Names,
    module Futhark.IR.RetType,
    module Futhark.IR.Rephrase,

    -- * Built-in functions
    isBuiltInFunction,
    builtInFunctions,

    -- * Extra tools
    asBasicOp,
    safeExp,
    subExpVars,
    subExpVar,
    commutativeLambda,
    defAux,
    stmCerts,
    certify,
    expExtTypesFromPat,
    attrsForAssert,
    lamIsBinOp,
    ASTConstraints,
    IsOp (..),
    ASTRep (..),
  )
where

import Control.Monad
import Data.List (elemIndex, find)
import Data.Map.Strict qualified as M
import Data.Maybe (isJust, mapMaybe)
import Data.Set qualified as S
import Futhark.IR.Pretty
import Futhark.IR.Prop.Constants
import Futhark.IR.Prop.Names
import Futhark.IR.Prop.Pat
import Futhark.IR.Prop.Rearrange
import Futhark.IR.Prop.Reshape
import Futhark.IR.Prop.TypeOf
import Futhark.IR.Prop.Types
import Futhark.IR.Rephrase
import Futhark.IR.RetType
import Futhark.IR.Syntax
import Futhark.Transform.Rename (Rename, Renameable)
import Futhark.Transform.Substitute (Substitutable, Substitute)
import Futhark.Util (maybeNth)

-- | @isBuiltInFunction k@ holds exactly when @k@ is a key of
-- 'builtInFunctions'.
isBuiltInFunction :: Name -> Bool
isBuiltInFunction = (`M.member` builtInFunctions)

-- | Every built-in (primitive) function, keyed by name and mapped to
-- its return type paired with its parameter types.
builtInFunctions :: M.Map Name (PrimType, [PrimType])
builtInFunctions =
  M.fromList
    [ (nameFromText fname, (ret, paramts))
      | (fname, (paramts, ret, _)) <- M.toList primFuns
    ]

-- | Extract the t'BasicOp' from an expression, or 'Nothing' if the
-- expression is something else.
asBasicOp :: Exp rep -> Maybe BasicOp
asBasicOp e = case e of
  BasicOp op -> Just op
  _ -> Nothing

-- | An expression is safe if it is always well-defined (assuming that
-- any required certificates have been checked) in any context.  For
-- example, array indexing is not safe, as the index may be out of
-- bounds.  On the other hand, adding two numbers cannot fail.
safeExp :: (ASTRep rep) => Exp rep -> Bool
safeExp e = case e of
  BasicOp op -> safeBasicOp op
  Loop _ _ body -> safeBody body
  Apply fname _ _ _ -> isBuiltInFunction fname
  Match _ cases def_case _ ->
    all (safeBody . caseBody) cases && safeBody def_case
  WithAcc {} -> True -- Although unlikely to matter.
  Op op -> safeOp op
  where
    safeBasicOp bop = case bop of
      BinOp op _ y -> safeBinOp op y
      ArrayLit {} -> True
      SubExp {} -> True
      UnOp {} -> True
      CmpOp {} -> True
      ConvOp {} -> True
      Scratch {} -> True
      Concat {} -> True
      Reshape {} -> True
      Rearrange {} -> True
      Manifest {} -> True
      Iota {} -> True
      Replicate {} -> True
      _ -> False

    -- Division-like operators are safe when marked 'Safe', or when the
    -- divisor is a constant that is not zero.  'Pow' is safe only for a
    -- constant exponent that is not negative.  Every other binary
    -- operator is always safe.
    safeBinOp op y =
      case op of
        SDiv _ s -> divisorOk s
        SDivUp _ s -> divisorOk s
        SQuot _ s -> divisorOk s
        UDiv _ s -> divisorOk s
        UDivUp _ s -> divisorOk s
        SMod _ s -> divisorOk s
        SRem _ s -> divisorOk s
        UMod _ s -> divisorOk s
        Pow {} -> constantSatisfies (not . negativeIsh)
        _ -> True
      where
        divisorOk Safe = True
        divisorOk _ = constantSatisfies (not . zeroIsh)
        constantSatisfies p = case y of
          Constant v -> p v
          _ -> False

-- A body is safe when every statement in it has a safe expression.
safeBody :: (ASTRep rep) => Body rep -> Bool
safeBody body = all (safeExp . stmExp) (bodyStms body)

-- | The variable names of the 'Var' subexpressions, in order.  May
-- contain duplicates.
subExpVars :: [SubExp] -> [VName]
subExpVars ses = mapMaybe subExpVar ses

-- | The variable name of a t'SubExp', provided it is a 'Var'.
subExpVar :: SubExp -> Maybe VName
subExpVar se = case se of
  Var v -> Just v
  Constant {} -> Nothing

-- | Does the given lambda represent a known commutative function?
-- This is purely syntactic: the parameters must split evenly into an
-- @x@ half and a @y@ half, and every result must be bound by a
-- commutative 'BinOp' applied to the matching @x@ and @y@ parameters
-- (in either order).  Don't expect anything clever here.
commutativeLambda :: Lambda rep -> Bool
commutativeLambda lam =
  even_arity
    && half == length results
    && all componentOk (zip3 xs ys results)
  where
    params = lambdaParams lam
    results = bodyResult $ lambdaBody lam
    stms = bodyStms $ lambdaBody lam
    half = length params `div` 2
    even_arity = half * 2 == length params
    (xs, ys) = splitAt half params

    componentOk c = isJust (find (matchesBinOp c) stms)

    matchesBinOp
      (xp, yp, SubExpRes _ (Var r))
      (Let (Pat [pe]) _ (BasicOp (BinOp op (Var x) (Var y)))) =
        patElemName pe == r && commutativeBinOp op && (sameOrder || swapped)
        where
          sameOrder = x == paramName xp && y == paramName yp
          swapped = y == paramName xp && x == paramName yp
    matchesBinOp _ _ = False

-- | A 'StmAux' carrying the given decoration, with no certificates and
-- no attributes.
defAux :: dec -> StmAux dec
defAux dec = StmAux mempty mempty dec

-- | The certificates attached to a statement.
stmCerts :: Stm rep -> Certs
stmCerts stm = stmAuxCerts (stmAux stm)

-- | Append extra certificates to those already on a statement.
certify :: Certs -> Stm rep -> Stm rep
certify extra (Let pat (StmAux old attrs dec) e) =
  Let pat (StmAux (old <> extra) attrs dec) e

-- | A handy shorthand for properties that we usually want for things
-- we stuff into ASTs.
type ASTConstraints a =
  (Eq a, Ord a, Show a, Rename a, Substitute a, FreeIn a, Pretty a)

-- | A type class for operations.
class (TypedOp op) => IsOp op where
  -- | Like 'safeExp', but for arbitrary ops.
  safeOp :: (ASTRep rep) => op rep -> Bool

  -- | Should we try to hoist this out of branches?
  cheapOp :: (ASTRep rep) => op rep -> Bool

  -- | Compute the data dependencies of an operation.
  opDependencies :: (ASTRep rep) => op rep -> [Names]

-- The empty operation is trivially safe and cheap, and reports no
-- dependencies.
instance IsOp NoOp where
  safeOp NoOp = True
  cheapOp NoOp = True
  opDependencies NoOp = []

-- | Representation-specific attributes; also means the rep supports
-- some basic facilities.
class
  ( RepTypes rep,
    PrettyRep rep,
    Renameable rep,
    Substitutable rep,
    FreeDec (ExpDec rep),
    FreeIn (LetDec rep),
    FreeDec (BodyDec rep),
    FreeIn (FParamInfo rep),
    FreeIn (LParamInfo rep),
    FreeIn (RetType rep),
    FreeIn (BranchType rep),
    ASTConstraints (OpC rep rep),
    IsOp (OpC rep),
    RephraseOp (OpC rep)
  ) =>
  ASTRep rep
  where
  -- | Given a pattern, construct the type of a body that would match
  -- it.  An implementation for many representations would be
  -- 'expExtTypesFromPat'.
  expTypesFromPat ::
    (HasScope rep m, Monad m) =>
    Pat (LetDec rep) ->
    m [BranchType rep]

-- | Construct the type of an expression that would match the pattern.
-- The types of the pattern elements are made existential in the names
-- the pattern itself binds (via 'existentialiseExtTypes').
expExtTypesFromPat :: (Typed dec) => Pat dec -> [ExtType]
expExtTypesFromPat pat =
  existentialiseExtTypes (patNames pat) $
    staticShapes $
      map patElemType $
        patElems pat

-- | Keep only those attributes that are relevant for 'Assert'
-- expressions; currently only @warn(safety_checks)@ is kept.
attrsForAssert :: Attrs -> Attrs
attrsForAssert (Attrs attrs) =
  Attrs $ S.filter attrForAssert attrs
  where
    attrForAssert = (== AttrComp "warn" ["safety_checks"])

-- | Horizontally fission a lambda that models a binary operator.
--
-- For each lambda result, this checks that it has no certificates and
-- is a variable bound by a 'BinOp' on two variables, which must be
-- exactly the /i/th and /(n+i)/th lambda parameters (where /i/ is the
-- position of the result, as found by 'elemIndex', and /n/ is the
-- number of return types), and that the result type is primitive.
-- Returns 'Nothing' if any result fails these checks.
lamIsBinOp :: (ASTRep rep) => Lambda rep -> Maybe [(BinOp, PrimType, VName, VName)]
lamIsBinOp lam = mapM splitStm $ bodyResult $ lambdaBody lam
  where
    n = length $ lambdaReturnType lam
    splitStm (SubExpRes cs (Var res)) = do
      guard $ cs == mempty
      Let (Pat [pe]) _ (BasicOp (BinOp op (Var x) (Var y))) <-
        find (([res] ==) . patNames . stmPat) $
          stmsToList $
            bodyStms $
              lambdaBody lam
      i <- Var res `elemIndex` map resSubExp (bodyResult (lambdaBody lam))
      xp <- maybeNth i $ lambdaParams lam
      yp <- maybeNth (n + i) $ lambdaParams lam
      guard $ paramName xp == x
      guard $ paramName yp == y
      Prim t <- Just $ patElemType pe
      pure (op, t, paramName xp, paramName yp)
    splitStm _ = Nothing
futhark-0.25.27/src/Futhark/IR/Prop/000077500000000000000000000000001475065116200167675ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/IR/Prop/Aliases.hs000066400000000000000000000207051475065116200207100ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-}

-- | The IR tracks aliases, mostly to ensure the soundness of in-place
-- updates, but it can also be used for other things (such as memory
-- optimisations).  This module contains the raw building blocks for
-- determining the aliases of the values produced by expressions.  It
-- also contains some building blocks for inspecting consumption.
--
-- One important caveat is that all aliases computed here are /local/.
-- Thus, they do not take aliases-of-aliases into account.  See
-- "Futhark.Analysis.Alias" if this is not what you want.
module Futhark.IR.Prop.Aliases
  ( subExpAliases,
    expAliases,
    patAliases,
    lookupAliases,
    Aliased (..),
    AliasesOf (..),

    -- * Consumption
    consumedInStm,
    consumedInExp,
    consumedByLambda,

    -- * Extensibility
    AliasTable,
    AliasedOp (..),
  )
where

import Data.Bifunctor (first, second)
import Data.List (find, transpose)
import Data.Map.Strict qualified as M
import Data.Maybe (mapMaybe)
import Futhark.IR.Prop (ASTRep, IsOp, NameInfo (..), Scope)
import Futhark.IR.Prop.Names
import Futhark.IR.Prop.Pat
import Futhark.IR.Prop.Types
import Futhark.IR.Syntax

-- | The class of representations that contain aliasing information.
class (ASTRep rep, AliasedOp (OpC rep), AliasesOf (LetDec rep)) => Aliased rep where
  -- | The aliases of the body results.  Note that this includes names
  -- bound in the body!
  bodyAliases :: Body rep -> [Names]

  -- | The variables consumed in the body.
  consumedInBody :: Body rep -> Names

-- A variable's local aliases are just the variable itself.
vnameAliases :: VName -> Names
vnameAliases = oneName

-- | The aliases of a subexpression.
subExpAliases :: SubExp -> Names
subExpAliases Constant {} = mempty
subExpAliases (Var v) = vnameAliases v

-- One alias set per result of the operation.  Operations that pass
-- through or view an existing value (SubExp, Opaque, Index, FlatIndex,
-- Reshape, Rearrange) alias their operand; all others produce a value
-- with no aliases.
basicOpAliases :: BasicOp -> [Names]
basicOpAliases (SubExp se) = [subExpAliases se]
basicOpAliases (Opaque _ se) = [subExpAliases se]
basicOpAliases (ArrayVal _ _) = [mempty]
basicOpAliases (ArrayLit _ _) = [mempty]
basicOpAliases BinOp {} = [mempty]
basicOpAliases ConvOp {} = [mempty]
basicOpAliases CmpOp {} = [mempty]
basicOpAliases UnOp {} = [mempty]
basicOpAliases (Index ident _) = [vnameAliases ident]
basicOpAliases Update {} = [mempty]
basicOpAliases (FlatIndex ident _) = [vnameAliases ident]
basicOpAliases FlatUpdate {} = [mempty]
basicOpAliases Iota {} = [mempty]
basicOpAliases Replicate {} = [mempty]
basicOpAliases Scratch {} = [mempty]
basicOpAliases (Reshape _ _ e) = [vnameAliases e]
basicOpAliases (Rearrange _ e) = [vnameAliases e]
basicOpAliases Concat {} = [mempty]
basicOpAliases Manifest {} = [mempty]
basicOpAliases Assert {} = [mempty]
basicOpAliases UpdateAcc {} = [mempty]

-- Combine per-branch (result aliases, consumed names) pairs: each
-- result position gets the union of its aliases across all branches,
-- minus every name consumed in any branch.
matchAliases :: [([Names], Names)] -> [Names]
matchAliases l =
  map ((`namesSubtract` mconcat conses) . mconcat) $ transpose alses
  where
    (alses, conses) = unzip l

-- Aliases of a function call's results, one per return type.  Each
-- 'RetAls' lists (by position) the arguments and the pattern elements
-- that result may alias; only variable arguments passed with the
-- 'Observe' diet are eligible argument aliases.
funcallAliases ::
  [PatElem dec] ->
  [(SubExp, Diet)] ->
  [(TypeBase shape Uniqueness, RetAls)] ->
  [Names]
funcallAliases pes args = map onType
  where
    -- We assume that the pals/rals lists are sorted, as this allows
    -- us to compute the intersections much more efficiently.
    argAls (i, (Var v, Observe)) = Just (i, v)
    argAls _ = Nothing
    arg_als = mapMaybe argAls $ zip [0 ..] args
    res_als = zip [0 ..] $ map patElemName pes
    -- Merge-style walk over two ascending position lists, keeping the
    -- names whose position appears in both.
    pick (i : is) ((j, v) : jvs)
      | i == j = v : pick is jvs
      | i > j = pick (i : is) jvs
      | otherwise = pick is ((j, v) : jvs)
    pick _ _ = []
    getAls als is = namesFromList $ pick is als
    onType (_t, RetAls pals rals) =
      getAls arg_als pals <> getAls res_als rals

-- Given names @bound@ in an inner scope, extend each pattern element's
-- aliases with the other pattern elements whose aliases share a bound
-- name with it, then remove the bound names themselves (they are not
-- visible outside the inner scope).
mutualAliases :: Names -> [PatElem dec] -> [Names] -> [Names]
mutualAliases bound pes als = zipWith grow (map patElemName pes) als
  where
    bound_als = map (`namesIntersection` bound) als
    grow v names = (names <> pe_names) `namesSubtract` bound
      where
        pe_names =
          namesFromList
            . filter (/= v)
            . map (patElemName . fst)
            . filter (namesIntersect names . snd)
            $ zip pes bound_als

-- | The aliases of an expression, one for each pattern element.
--
-- The pattern is important because some aliasing might be through
-- variables that are no longer in scope (consider the aliases for a
-- body that returns the same value multiple times).
expAliases :: (Aliased rep) => [PatElem dec] -> Exp rep -> [Names]
expAliases pes (Match _ cases defbody _) =
  -- Repeat mempty in case the pattern has more elements (this
  -- implies a type error).
  mutualAliases bound pes $ als ++ repeat mempty
  where
    als = matchAliases $ onBody defbody : map (onBody . caseBody) cases
    onBody body = (bodyAliases body, consumedInBody body)
    bound = foldMap boundInBody $ defbody : map caseBody cases
expAliases _ (BasicOp op) = basicOpAliases op
-- A loop result aliases its initial argument and the corresponding body
-- result, closed transitively over the other loop parameters; results
-- whose loop parameter is unique get no aliases.
expAliases pes (Loop merge _ loopbody) =
  mutualAliases (bound <> param_names) pes $ do
    (p, als) <-
      transitive . zip params $
        zipWith (<>) arg_aliases (bodyAliases loopbody)
    if unique $ paramDeclType p
      then pure mempty
      else pure als
  where
    bound = boundInBody loopbody
    arg_aliases = map (subExpAliases . snd) merge
    params = map fst merge
    param_names = namesFromList $ map paramName params
    -- Iterate alias expansion until it reaches a fixed point.
    transitive merge_and_als =
      let merge_and_als' = map (second expand) merge_and_als
       in if merge_and_als' == merge_and_als
            then merge_and_als
            else transitive merge_and_als'
      where
        look v = maybe mempty snd $ find ((== v) . paramName . fst) merge_and_als
        expand als = als <> foldMap look (namesToList als)
expAliases pes (Apply _ args t _) =
  funcallAliases pes args $ map (first declExtTypeOf) t
-- Results corresponding to accumulator input arrays have no aliases;
-- the remaining lambda results keep their body aliases, minus names
-- bound inside the body.
expAliases _ (WithAcc inputs lam) =
  concatMap inputAliases inputs
    ++ drop num_accs (map (`namesSubtract` boundInBody body) $ bodyAliases body)
  where
    body = lambdaBody lam
    inputAliases (_, arrs, _) = replicate (length arrs) mempty
    num_accs = length inputs
expAliases _ (Op op) = opAliases op

-- | The variables consumed in this statement.
consumedInStm :: (Aliased rep) => Stm rep -> Names
consumedInStm = consumedInExp . stmExp

-- | The variables consumed in this expression.
consumedInExp :: (Aliased rep) => Exp rep -> Names
consumedInExp (Apply _ args _ _) =
  mconcat (map (consumeArg . first subExpAliases) args)
  where
    consumeArg (als, Consume) = als
    consumeArg _ = mempty
consumedInExp (Match _ cases defbody _) =
  foldMap (consumedInBody . caseBody) cases
    <> consumedInBody defbody
-- Initial values passed to unique loop parameters are consumed.
consumedInExp (Loop merge _ _) =
  mconcat
    ( map (subExpAliases . snd) $
        filter (unique . paramDeclType . fst) merge
    )
-- The accumulator input arrays are consumed, as is whatever the lambda
-- consumes apart from its own parameters.
consumedInExp (WithAcc inputs lam) =
  mconcat (map inputConsumed inputs)
    <> ( consumedByLambda lam
           `namesSubtract` namesFromList (map paramName (lambdaParams lam))
       )
  where
    inputConsumed (_, arrs, _) = namesFromList arrs
consumedInExp (BasicOp (Update _ src _ _)) = oneName src
consumedInExp (BasicOp (FlatUpdate src _ _)) = oneName src
consumedInExp (BasicOp (UpdateAcc _ acc _ _)) = oneName acc
consumedInExp (BasicOp _) = mempty
consumedInExp (Op op) = consumedInOp op

-- | The variables consumed by this lambda.
consumedByLambda :: (Aliased rep) => Lambda rep -> Names
consumedByLambda = consumedInBody . lambdaBody

-- | The aliases of each pattern element.
patAliases :: (AliasesOf dec) => Pat dec -> [Names]
patAliases = map aliasesOf . patElems

-- | Something that contains alias information.
class AliasesOf a where
  -- | The alias of the argument element.
  aliasesOf :: a -> Names

instance AliasesOf Names where
  aliasesOf = id

instance (AliasesOf dec) => AliasesOf (PatElem dec) where
  aliasesOf = aliasesOf . patElemDec

-- | Also includes the name itself.  Transitively follows the aliases
-- recorded on 'LetName' entries in the given scope; names without such
-- an entry contribute only themselves.
lookupAliases :: (AliasesOf (LetDec rep)) => VName -> Scope rep -> Names
lookupAliases root scope =
  -- We must be careful to handle circular aliasing properly (this
  -- can happen due to Match and Loop).
  expand mempty root
  where
    expand prev v =
      case M.lookup v scope of
        Just (LetName dec) ->
          oneName v
            <> foldMap
              (expand (oneName v <> prev))
              (filter (`notNameIn` prev) (namesToList (aliasesOf dec)))
        _ -> oneName v

-- | The class of operations that can produce aliasing and consumption
-- information.
class (IsOp op) => AliasedOp op where
  opAliases :: (Aliased rep) => op rep -> [Names]
  consumedInOp :: (Aliased rep) => op rep -> Names

instance AliasedOp NoOp where
  opAliases NoOp = []
  consumedInOp NoOp = mempty

-- | Pre-existing aliases for variables.  Used to add transitive
-- aliases.
type AliasTable = M.Map VName Names
futhark-0.25.27/src/Futhark/IR/Prop/Constants.hs000066400000000000000000000035461475065116200213070ustar00rootroot00000000000000-- | Possibly convenient facilities for constructing constants.
module Futhark.IR.Prop.Constants
  ( IsValue (..),
    constant,
    intConst,
    floatConst,
  )
where

import Futhark.IR.Syntax.Core (SubExp (..))
import Language.Futhark.Primitive

-- | If a Haskell type is an instance of 'IsValue', it means that a
-- value of that type can be converted to a Futhark 'PrimValue'.
-- This is intended to cut down on boilerplate when writing compiler
-- code - for example, you'll quickly grow tired of writing @Constant
-- (BoolValue True)@.
class IsValue a where
  value :: a -> PrimValue

instance IsValue Int8 where
  value = IntValue . Int8Value

instance IsValue Int16 where
  value = IntValue . Int16Value

instance IsValue Int32 where
  value = IntValue . Int32Value

instance IsValue Int64 where
  value = IntValue . Int64Value

-- Unsigned Haskell words are stored in the integer value of the same
-- width via 'fromIntegral', which wraps values above the signed
-- maximum.
instance IsValue Word8 where
  value = IntValue . Int8Value . fromIntegral

instance IsValue Word16 where
  value = IntValue . Int16Value . fromIntegral

instance IsValue Word32 where
  value = IntValue . Int32Value . fromIntegral

instance IsValue Word64 where
  value = IntValue . Int64Value . fromIntegral

instance IsValue Double where
  value = FloatValue . Float64Value

instance IsValue Float where
  value = FloatValue . Float32Value

instance IsValue Bool where
  value = BoolValue

instance IsValue PrimValue where
  value = id

instance IsValue IntValue where
  value = IntValue

instance IsValue FloatValue where
  value = FloatValue

-- | Create a 'Constant' 'SubExp' containing the given value.
constant :: (IsValue v) => v -> SubExp
constant = Constant . value

-- | Utility definition for reasons of type ambiguity.
intConst :: IntType -> Integer -> SubExp
intConst t v = constant $ intValue t v

-- | Utility definition for reasons of type ambiguity.
floatConst :: FloatType -> Double -> SubExp
floatConst t v = constant $ floatValue t v
futhark-0.25.27/src/Futhark/IR/Prop/Names.hs000066400000000000000000000275321475065116200203770ustar00rootroot00000000000000{-# LANGUAGE UndecidableInstances #-}

-- | Facilities for determining which names are used in some syntactic
-- construct.  The most important interface is the 'FreeIn' class and
-- its instances, but for reasons related to the Haskell type system,
-- some constructs have specialised functions.
module Futhark.IR.Prop.Names
  ( -- * Free names
    Names,
    namesIntMap,
    namesIntSet,
    nameIn,
    notNameIn,
    oneName,
    namesFromList,
    namesToList,
    namesIntersection,
    namesIntersect,
    namesSubtract,
    mapNames,

    -- * Class
    FreeIn (..),
    freeIn,

    -- * Specialised Functions
    freeInStmsAndRes,

    -- * Bound Names
    boundInBody,
    boundByStm,
    boundByStms,
    boundByLambda,

    -- * Efficient computation
    FreeDec (..),
    FV,
    fvBind,
    fvName,
    fvNames,
  )
where

import Control.Category
import Control.Monad.State.Strict
import Data.Foldable
import Data.IntMap.Strict qualified as IM
import Data.IntSet qualified as IS
import Data.Map.Strict qualified as M
import Data.Set qualified as S
import Futhark.IR.Prop.Pat
import Futhark.IR.Syntax
import Futhark.IR.Traversals
import Futhark.Util.Pretty
import Prelude hiding (id, (.))

-- | A set of names, keyed by each name's 'baseTag'.  Note that the
-- 'Ord' instance is a dummy that treats everything as 'EQ' if '==',
-- and otherwise 'LT'.
newtype Names = Names (IM.IntMap VName)
  deriving (Eq, Show)

-- | Retrieve the data structure underlying the names representation.
namesIntMap :: Names -> IM.IntMap VName
namesIntMap (Names m) = m

-- | Retrieve the set of tags in the names set.
namesIntSet :: Names -> IS.IntSet
namesIntSet = IM.keysSet . namesIntMap

-- Dummy ordering (see the 'Names' documentation): only equality is
-- meaningful, and all unequal sets compare as 'LT'.
instance Ord Names where
  compare x y
    | x == y = EQ
    | otherwise = LT

-- Union of the underlying tag maps.
instance Semigroup Names where
  Names xs <> Names ys = Names (xs <> ys)

instance Monoid Names where
  mempty = Names IM.empty

instance Pretty Names where
  pretty names = pretty (namesToList names)

-- | Is this name a member of the set?
nameIn :: VName -> Names -> Bool
nameIn v = IM.member (baseTag v) . namesIntMap

-- | Is this name absent from the set?
notNameIn :: VName -> Names -> Bool
notNameIn v = IM.notMember (baseTag v) . namesIntMap

-- | Build a name set from a list.  Slow.
namesFromList :: [VName] -> Names
namesFromList vs = Names $ IM.fromList [(baseTag v, v) | v <- vs]

-- | Turn a name set into a list of names.  Slow.
namesToList :: Names -> [VName]
namesToList = IM.elems . namesIntMap

-- | Construct a name set from a single name.
oneName :: VName -> Names
oneName v = Names $ IM.singleton (baseTag v) v

-- | The intersection of two name sets.
namesIntersection :: Names -> Names -> Names
namesIntersection (Names vs1) (Names vs2) = Names $ IM.intersection vs1 vs2

-- | Do the two name sets intersect?
namesIntersect :: Names -> Names -> Bool
namesIntersect vs1 vs2 = not $ IM.disjoint (namesIntMap vs1) (namesIntMap vs2)

-- | Subtract the latter name set from the former.
namesSubtract :: Names -> Names -> Names
namesSubtract (Names vs1) (Names vs2) = Names $ IM.difference vs1 vs2

-- | Map over the names in a set.  The set is rebuilt, so names that
-- the function maps to the same tag are merged.
mapNames :: (VName -> VName) -> Names -> Names
mapNames f vs = namesFromList $ map f $ namesToList vs

-- | A computation to build a free variable set.
newtype FV = FV {unFV :: Names}

-- Right now the variable set is just stored explicitly, without the
-- fancy functional representation that GHC uses.  Turns out it's
-- faster this way.

instance Monoid FV where
  mempty = FV mempty

instance Semigroup FV where
  FV fv1 <> FV fv2 = FV $ fv1 <> fv2

-- | Consider a variable to be bound in the given 'FV' computation.
fvBind :: Names -> FV -> FV
fvBind vs (FV fv) = FV $ fv `namesSubtract` vs

-- | Take note of a variable reference.
fvName :: VName -> FV
fvName v = FV $ oneName v

-- | Take note of a set of variable references.
fvNames :: Names -> FV
fvNames = FV

-- A traversal that adds the free variables of every sub-component of
-- an expression to the accumulated 'FV' state.
freeWalker ::
  ( FreeDec (ExpDec rep),
    FreeDec (BodyDec rep),
    FreeIn (FParamInfo rep),
    FreeIn (LParamInfo rep),
    FreeIn (LetDec rep),
    FreeIn (RetType rep),
    FreeIn (BranchType rep),
    FreeIn (Op rep)
  ) =>
  Walker rep (State FV)
freeWalker =
  Walker
    { walkOnSubExp = modify . (<>) . freeIn',
      -- Note: the scope's names are subtracted from everything
      -- accumulated so far, not only from this body's contribution.
      walkOnBody = \scope body -> do
        modify $ (<>) $ freeIn' body
        modify $ fvBind (namesFromList (M.keys scope)),
      walkOnVName = modify . (<>) . fvName,
      walkOnOp = modify . (<>) . freeIn',
      walkOnFParam = modify . (<>) . freeIn',
      walkOnLParam = modify . (<>) . freeIn',
      walkOnRetType = modify . (<>) . freeIn',
      walkOnBranchType = modify . (<>) . freeIn'
    }

-- | Return the set of variable names that are free in the given
-- statements and result.  Filters away the names that are bound by
-- the statements.
freeInStmsAndRes ::
  ( FreeIn (Op rep),
    FreeIn (LetDec rep),
    FreeIn (LParamInfo rep),
    FreeIn (FParamInfo rep),
    FreeDec (BodyDec rep),
    FreeIn (RetType rep),
    FreeIn (BranchType rep),
    FreeDec (ExpDec rep)
  ) =>
  Stms rep ->
  Result ->
  FV
freeInStmsAndRes stms res =
  fvBind (boundByStms stms) $ foldMap freeIn' stms <> freeIn' res

-- | A class indicating that we can obtain free variable information
-- from values of this type.
class FreeIn a where
  -- The default is defined via 'freeIn', which is itself defined via
  -- 'freeIn'', so every instance must override this method.
  freeIn' :: a -> FV
  freeIn' = fvNames . freeIn

-- | The free variables of some syntactic construct.
freeIn :: (FreeIn a) => a -> Names
freeIn = unFV . freeIn'

instance FreeIn FV where
  freeIn' = id

instance FreeIn () where
  freeIn' () = mempty

instance FreeIn Int where
  freeIn' = const mempty

instance (FreeIn a, FreeIn b) => FreeIn (a, b) where
  freeIn' (a, b) = freeIn' a <> freeIn' b

instance (FreeIn a, FreeIn b, FreeIn c) => FreeIn (a, b, c) where
  freeIn' (a, b, c) = freeIn' a <> freeIn' b <> freeIn' c

instance (FreeIn a, FreeIn b, FreeIn c, FreeIn d) => FreeIn (a, b, c, d) where
  freeIn' (a, b, c, d) = freeIn' a <> freeIn' b <> freeIn' c <> freeIn' d

instance (FreeIn a, FreeIn b) => FreeIn (Either a b) where
  freeIn' = either freeIn' freeIn'

instance (FreeIn a) => FreeIn [a] where
  freeIn' = foldMap freeIn'

instance (FreeIn a) => FreeIn (S.Set a) where
  freeIn' = foldMap freeIn'

instance FreeIn (NoOp rep) where
  freeIn' NoOp = mempty

-- Function parameters are bound within the return types, the parameter
-- decorations and the body.
instance
  ( FreeDec (ExpDec rep),
    FreeDec (BodyDec rep),
    FreeIn (FParamInfo rep),
    FreeIn (LParamInfo rep),
    FreeIn (LetDec rep),
    FreeIn (RetType rep),
    FreeIn (BranchType rep),
    FreeIn (Op rep)
  ) =>
  FreeIn (FunDef rep)
  where
  freeIn' (FunDef _ _ _ rettype params body) =
    fvBind (namesFromList $ map paramName params) $
      foldMap (freeIn' . fst) rettype <> freeIn' params <> freeIn' body

-- Lambda parameters are bound within the return type, the parameter
-- decorations and the body.
instance
  ( FreeDec (ExpDec rep),
    FreeDec (BodyDec rep),
    FreeIn (FParamInfo rep),
    FreeIn (LParamInfo rep),
    FreeIn (LetDec rep),
    FreeIn (RetType rep),
    FreeIn (BranchType rep),
    FreeIn (Op rep)
  ) =>
  FreeIn (Lambda rep)
  where
  freeIn' (Lambda params body rettype) =
    fvBind (namesFromList $ map paramName params) $
      freeIn' rettype <> freeIn' params <> freeIn' body

-- A body's decoration may supply precomputed free names (see
-- 'FreeDec'); otherwise they are computed from the statements and
-- result.
instance
  ( FreeDec (ExpDec rep),
    FreeDec (BodyDec rep),
    FreeIn (FParamInfo rep),
    FreeIn (LParamInfo rep),
    FreeIn (LetDec rep),
    FreeIn (RetType rep),
    FreeIn (BranchType rep),
    FreeIn (Op rep)
  ) =>
  FreeIn (Body rep)
  where
  freeIn' (Body dec stms res) =
    precomputed dec $ freeIn' dec <> freeInStmsAndRes stms res

-- Loops and WithAcc are handled explicitly; every other expression is
-- traversed with 'freeWalker'.
instance
  ( FreeDec (ExpDec rep),
    FreeDec (BodyDec rep),
    FreeIn (FParamInfo rep),
    FreeIn (LParamInfo rep),
    FreeIn (LetDec rep),
    FreeIn (RetType rep),
    FreeIn (BranchType rep),
    FreeIn (Op rep)
  ) =>
  FreeIn (Exp rep)
  where
  -- The loop parameters, plus the index variable of a for-loop, are
  -- bound by the loop.
  freeIn' (Loop merge form loopbody) =
    let (params, args) = unzip merge
        bound_here = case form of
          WhileLoop {} -> namesFromList $ map paramName params
          ForLoop i _ _ -> namesFromList $ i : map paramName params
     in fvBind bound_here $
          freeIn' args <> freeIn' form <> freeIn' params <> freeIn' loopbody
  freeIn' (WithAcc inputs lam) =
    freeIn' inputs <> freeIn' lam
  freeIn' e = execState (walkExpM freeWalker e) mempty

-- Certificates and attributes are always inspected; the rest may be
-- replaced by names precomputed in the expression decoration.
instance
  ( FreeDec (ExpDec rep),
    FreeDec (BodyDec rep),
    FreeIn (FParamInfo rep),
    FreeIn (LParamInfo rep),
    FreeIn (LetDec rep),
    FreeIn (RetType rep),
    FreeIn (BranchType rep),
    FreeIn (Op rep)
  ) =>
  FreeIn (Stm rep)
  where
  freeIn' (Let pat (StmAux cs attrs dec) e) =
    freeIn' cs
      <> freeIn' attrs
      <> precomputed dec (freeIn' dec <> freeIn' e <> freeIn' pat)

instance (FreeIn (Stm rep)) => FreeIn (Stms rep) where
  freeIn' = foldMap freeIn'

instance (FreeIn body) => FreeIn (Case body) where
  freeIn' = freeIn' . caseBody

instance FreeIn Names where
  freeIn' = fvNames

instance FreeIn Bool where
  freeIn' _ = mempty

instance (FreeIn a) => FreeIn (Maybe a) where
  freeIn' = maybe mempty freeIn'

instance FreeIn VName where
  freeIn' = fvName

-- Only the type is inspected; the identifier's own name is not counted.
instance FreeIn Ident where
  freeIn' = freeIn' . identType

instance FreeIn SubExp where
  freeIn' (Var v) = freeIn' v
  freeIn' Constant {} = mempty

instance FreeIn Space where
  freeIn' (ScalarSpace d _) = freeIn' d
  freeIn' DefaultSpace = mempty
  freeIn' (Space _) = mempty

instance (FreeIn d) => FreeIn (ShapeBase d) where
  freeIn' = freeIn' . shapeDims

-- Existential sizes ('Ext') refer to no variable.
instance (FreeIn d) => FreeIn (Ext d) where
  freeIn' (Free x) = freeIn' x
  freeIn' (Ext _) = mempty

instance FreeIn PrimType where
  freeIn' _ = mempty

instance (FreeIn shape) => FreeIn (TypeBase shape u) where
  freeIn' (Array t shape _) = freeIn' t <> freeIn' shape
  freeIn' (Mem s) = freeIn' s
  freeIn' Prim {} = mempty
  freeIn' (Acc acc ispace ts _) = freeIn' (acc, ispace, ts)

instance (FreeIn dec) => FreeIn (Param dec) where
  freeIn' (Param attrs _ dec) = freeIn' attrs <> freeIn' dec

instance (FreeIn dec) => FreeIn (PatElem dec) where
  freeIn' (PatElem _ dec) = freeIn' dec

instance FreeIn LoopForm where
  freeIn' (ForLoop _ _ bound) = freeIn' bound
  freeIn' (WhileLoop cond) = freeIn' cond

instance (FreeIn d) => FreeIn (DimIndex d) where
  freeIn' = Data.Foldable.foldMap freeIn'

instance (FreeIn d) => FreeIn (Slice d) where
  freeIn' = Data.Foldable.foldMap freeIn'

instance (FreeIn d) => FreeIn (FlatDimIndex d) where
  freeIn' = Data.Foldable.foldMap freeIn'

instance (FreeIn d) => FreeIn (FlatSlice d) where
  freeIn' = Data.Foldable.foldMap freeIn'

instance FreeIn SubExpRes where
  freeIn' (SubExpRes cs se) = freeIn' cs <> freeIn' se

-- Names bound by the pattern itself are removed from the free names of
-- its element decorations.
instance (FreeIn dec) => FreeIn (Pat dec) where
  freeIn' (Pat xs) =
    fvBind bound_here $ freeIn' xs
    where
      bound_here = namesFromList $ map patElemName xs

instance FreeIn Certs where
  freeIn' (Certs cs) = freeIn' cs

instance FreeIn Attrs where
  freeIn' (Attrs _) = mempty

instance (FreeIn dec) => FreeIn (StmAux dec) where
  freeIn' (StmAux cs attrs dec) = freeIn' cs <> freeIn' attrs <> freeIn' dec

instance (FreeIn a) => FreeIn (MatchDec a) where
  freeIn' (MatchDec r _) = freeIn' r

-- | Either return precomputed free names stored in the attribute, or
-- the freshly computed names.  Relies on lazy evaluation to avoid the
-- work.
class (FreeIn dec) => FreeDec dec where
  precomputed :: dec -> FV -> FV
  precomputed _ = id

instance FreeDec ()

instance (FreeDec a, FreeIn b) => FreeDec (a, b) where
  precomputed (a, _) = precomputed a

instance (FreeDec a) => FreeDec [a] where
  precomputed [] = id
  precomputed (a : _) = precomputed a

instance (FreeDec a) => FreeDec (Maybe a) where
  precomputed Nothing = id
  precomputed (Just a) = precomputed a

instance FreeDec Names where
  precomputed _ fv = fv

-- | The names bound by the bindings immediately in a t'Body'.
boundInBody :: Body rep -> Names
boundInBody = boundByStms . bodyStms

-- | The names bound by a binding.
boundByStm :: Stm rep -> Names
boundByStm = namesFromList . patNames . stmPat

-- | The names bound by the bindings.
boundByStms :: Stms rep -> Names
boundByStms = foldMap boundByStm

-- | The names of the lambda parameters.
boundByLambda :: Lambda rep -> [VName]
boundByLambda lam = map paramName (lambdaParams lam)
futhark-0.25.27/src/Futhark/IR/Prop/Pat.hs000066400000000000000000000037001475065116200200470ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-}

-- | Inspecting and modifying t'Pat's, function parameters and
-- pattern elements.
module Futhark.IR.Prop.Pat
  ( -- * Function parameters
    paramIdent,
    paramType,
    paramDeclType,

    -- * Pat elements
    patElemIdent,
    patElemType,
    setPatElemDec,
    patIdents,
    patNames,
    patTypes,
    patSize,

    -- * Pat construction
    basicPat,
  )
where

import Futhark.IR.Prop.Types (DeclTyped (..), Typed (..))
import Futhark.IR.Syntax

-- | The 'Type' of a parameter.
paramType :: (Typed dec) => Param dec -> Type
paramType = typeOf

-- | The 'DeclType' of a parameter.
paramDeclType :: (DeclTyped dec) => Param dec -> DeclType
paramDeclType p = declTypeOf p

-- | The 'Ident' (name and type) of a parameter.
paramIdent :: (Typed dec) => Param dec -> Ident
paramIdent = Ident <$> paramName <*> typeOf

-- | The 'Ident' (name and type) of a pattern element.
patElemIdent :: (Typed dec) => PatElem dec -> Ident
patElemIdent = Ident <$> patElemName <*> typeOf

-- | The type of a name bound by a t'PatElem'.
patElemType :: (Typed dec) => PatElem dec -> Type
patElemType pe = typeOf pe

-- | Replace the decoration of a t'PatElem', keeping its name.
setPatElemDec :: PatElem oldattr -> newattr -> PatElem newattr
setPatElemDec pe x = x <$ pe

-- | The 'Ident's bound by the t'Pat', in order.
patIdents :: (Typed dec) => Pat dec -> [Ident]
patIdents (Pat pes) = map patElemIdent pes

-- | The 'VName's bound by the t'Pat', in order.
patNames :: Pat dec -> [VName]
patNames (Pat pes) = [patElemName pe | pe <- pes]

-- | The types bound by the pattern, in order.
patTypes :: (Typed dec) => Pat dec -> [Type]
patTypes (Pat pes) = map patElemType pes

-- | The number of names bound by the pattern.
patSize :: Pat dec -> Int
patSize = length . patElems

-- | Build a pattern whose decorations are the 'Type's of the given
-- identifiers.
basicPat :: [Ident] -> Pat Type
basicPat idents = Pat $ map (\(Ident name t) -> PatElem name t) idents
futhark-0.25.27/src/Futhark/IR/Prop/Rearrange.hs000066400000000000000000000072071475065116200212370ustar00rootroot00000000000000-- | A rearrangement is a generalisation of transposition, where the
-- dimensions are arbitrarily permuted.
module Futhark.IR.Prop.Rearrange
  ( rearrangeShape,
    rearrangeInverse,
    rearrangeReach,
    rearrangeCompose,
    isPermutationOf,
    transposeIndex,
    isMapTranspose,
  )
where

import Data.List (sortOn, tails)
import Futhark.Util

-- | Calculate the given permutation of the list.  It is an error if
-- the permutation goes out of bounds.
rearrangeShape :: [Int] -> [a] -> [a]
rearrangeShape perm xs = [at i | i <- perm]
  where
    len = length xs
    at i
      | i < 0 || i >= len =
          error $ show perm ++ " is not a valid permutation for input."
      | otherwise = xs !! i

-- | The inverse permutation: position @j@ holds the index at which
-- @j@ occurs in the input.
rearrangeInverse :: [Int] -> [Int]
rearrangeInverse perm = [src | (_, src) <- sortOn fst (zip perm [0 ..])]

-- | Return the first dimension not affected by the permutation.  For
-- example, the permutation @[1,0,2]@ would return @2@.  This is the
-- length of the longest prefix after which the permutation's suffix
-- equals the identity's suffix.
rearrangeReach :: [Int] -> Int
rearrangeReach perm =
  length . takeWhile (uncurry (/=)) $
    zip (tails perm) (tails [0 .. length perm - 1])

-- | Compose two permutations, with the second given permutation being
-- applied first.
rearrangeCompose :: [Int] -> [Int] -> [Int]
rearrangeCompose outer inner = rearrangeShape outer inner

-- | Check whether the first list is a permutation of the second, and
-- if so, return the permutation.  This will also find identity
-- permutations (i.e. the lists are the same).  The implementation is
-- naive and slow.
isPermutationOf :: (Eq a) => [a] -> [a] -> Maybe [Int]
isPermutationOf l1 l2 =
  case mapAccumLM (claim 0) (map Just l2) l1 of
    Just (remaining, perm)
      | all isEmptySlot remaining -> Just perm
    _ -> Nothing
  where
    isEmptySlot Nothing = True
    isEmptySlot (Just _) = False

    -- Find the first unclaimed slot holding @y@, mark it as claimed,
    -- and report its index.
    claim :: (Eq a) => Int -> [Maybe a] -> a -> Maybe ([Maybe a], Int)
    claim _ [] _ = Nothing
    claim i (slot : rest) y
      | Just y == slot = Just (Nothing : rest, i)
      | otherwise =
          fmap (\(rest', j) -> (slot : rest', j)) (claim (i + 1) rest y)

-- | If @l@ is an index into the array @a@, then @transposeIndex k n
-- l@ is an index to the same element in the array @transposeArray k n
-- a@.
transposeIndex :: Int -> Int -> [a] -> [a]
transposeIndex k n l
  -- An empty index has nothing to move.  Without this guard the
  -- normalisation below would evaluate `mod` 0 and throw.
  | null l = l
  -- Normalise an offset that runs past the end of the index.
  | k + n >= length l =
      let n' = ((k + n) `mod` length l) - k
       in transposeIndex k n' l
  -- Negative offset: move the element at position k backwards by |n|.
  | n < 0,
    (pre, needle : end) <- splitAt k l,
    (beg, mid) <- splitAt (length pre + n) pre =
      beg ++ [needle] ++ mid ++ end
  -- Non-negative offset: move the element at position k forwards by n.
  | (beg, needle : post) <- splitAt k l,
    (mid, end) <- splitAt n post =
      beg ++ mid ++ [needle] ++ end
  | otherwise = l

-- | If @perm@ is conceptually a map of a transposition,
-- @isMapTranspose perm@ returns the number of dimensions being mapped
-- and the number of dimensions in each of the two groups being
-- transposed.  For example, we can consider the permutation
-- @[0,1,4,5,2,3]@ as a map of a transpose, by considering dimensions
-- @[0,1]@, @[4,5]@, and @[2,3]@ as single dimensions each, which gives
-- @Just (2, 2, 2)@.
--
-- If the input is not a valid permutation, then the result is
-- undefined.
isMapTranspose :: [Int] -> Maybe (Int, Int, Int)
isMapTranspose perm
  | posttrans == [length mapped .. length mapped + length posttrans - 1],
    not $ null pretrans,
    not $ null posttrans =
      Just (length mapped, length pretrans, length posttrans)
  | otherwise =
      Nothing
  where
    (mapped, notmapped) = findIncreasingFrom 0 perm
    (pretrans, posttrans) = findTransposed notmapped

    -- Split off the longest prefix of the form [x, x+1, x+2, ...].
    findIncreasingFrom x (i : is)
      | i == x =
          let (js, ps) = findIncreasingFrom (x + 1) is
           in (i : js, ps)
    findIncreasingFrom _ is =
      ([], is)

    findTransposed [] =
      ([], [])
    findTransposed (i : is) =
      findIncreasingFrom i (i : is)
futhark-0.25.27/src/Futhark/IR/Prop/Reshape.hs000066400000000000000000000060311475065116200207120ustar00rootroot00000000000000-- | Facilities for creating, inspecting, and simplifying reshape and
-- coercion operations.
module Futhark.IR.Prop.Reshape
  ( -- * Construction
    shapeCoerce,

    -- * Execution
    reshapeOuter,
    reshapeInner,

    -- * Simplification

    -- * Shape calculations
    reshapeIndex,
    flattenIndex,
    unflattenIndex,
    sliceSizes,
  )
where

import Data.Foldable
import Futhark.IR.Syntax
import Futhark.Util.IntegralExp
import Prelude hiding (product, quot, sum)

-- | Construct a 'Reshape' that is a 'ReshapeCoerce'.
shapeCoerce :: [SubExp] -> VName -> Exp rep
shapeCoerce newdims arr =
  BasicOp $ Reshape ReshapeCoerce (Shape newdims) arr

-- | @reshapeOuter newshape n oldshape@ returns a shape that replaces
-- the outer @n@ dimensions of @oldshape@ with @newshape@, keeping the
-- remaining inner dimensions of @oldshape@.
reshapeOuter :: Shape -> Int -> Shape -> Shape
reshapeOuter newshape n oldshape =
  newshape <> Shape (drop n (shapeDims oldshape))

-- | @reshapeInner newshape n oldshape@ returns a shape that keeps the
-- outer @n@ dimensions of @oldshape@ and replaces the remaining inner
-- @m-n@ dimensions (where @m@ is the rank of @oldshape@) with
-- @newshape@.
reshapeInner :: Shape -> Int -> Shape -> Shape
reshapeInner newshape n oldshape =
  Shape (take n (shapeDims oldshape)) <> newshape

-- | @reshapeIndex to_dims from_dims is@ transforms the index list
-- @is@ (which is into an array of shape @from_dims@) into an index
-- list @is'@, which is into an array of shape @to_dims@. @is@ must
-- have the same length as @from_dims@, and @is'@ will have the same
-- length as @to_dims@.
reshapeIndex ::
  (IntegralExp num) =>
  [num] ->
  [num] ->
  [num] ->
  [num]
reshapeIndex to_dims from_dims is =
  unflattenIndex to_dims $ flattenIndex from_dims is

-- | @unflattenIndex dims i@ computes a list of indices into an array
-- with dimension @dims@ given the flat index @i@. The resulting list
-- will have the same size as @dims@.
unflattenIndex :: (IntegralExp num) => [num] -> num -> [num]
unflattenIndex = unflattenIndexFromSlices . drop 1 . sliceSizes

-- | Like 'unflattenIndex', but given the slice sizes of the inner
-- dimensions (i.e. 'sliceSizes' without its first element) rather than
-- the dimensions themselves.
unflattenIndexFromSlices :: (IntegralExp num) => [num] -> num -> [num]
unflattenIndexFromSlices [] _ = []
unflattenIndexFromSlices (size : slices) i =
  -- The second argument is the remainder of @i@ by @size@, written in
  -- terms of 'quot'.
  (i `quot` size) : unflattenIndexFromSlices slices (i - (i `quot` size) * size)

-- | @flattenIndex dims is@ computes the flat index of @is@ into an
-- array with dimensions @dims@. The length of @dims@ and @is@ must
-- be the same; otherwise this calls 'error'.
flattenIndex ::
  (IntegralExp num) =>
  [num] ->
  [num] ->
  num
flattenIndex dims is
  | length is /= length slicesizes = error "flattenIndex: length mismatch"
  | otherwise = sum $ zipWith (*) is slicesizes
  where
    slicesizes = drop 1 $ sliceSizes dims

-- | Given a length @n@ list of dimensions @dims@, @sliceSizes dims@
-- will compute a length @n+1@ list of the size of each possible array
-- slice. The first element of this list will be the product of
-- @dims@, and the last element will be 1.
sliceSizes :: (IntegralExp num) => [num] -> [num]
sliceSizes [] = [1]
sliceSizes (n : ns) =
  product (n : ns) : sliceSizes ns
-- NOTE(review): presumably written with explicit 'product' calls rather
-- than 'scanr' so that the shape of symbolic expressions stays as-is;
-- hlint's suggestion is suppressed below - confirm before changing.

{- HLINT ignore sliceSizes -}
futhark-0.25.27/src/Futhark/IR/Prop/Scope.hs000066400000000000000000000173031475065116200204000ustar00rootroot00000000000000
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | The core Futhark AST does not contain type information when we
-- use a variable. Therefore, most transformations expect to be able
-- to access some kind of symbol table that maps names to their types.
--
-- This module defines the concept of a type environment as a mapping
-- from variable names to 'NameInfo's. Convenience facilities are
-- also provided to communicate that some monad or applicative functor
-- maintains type information.
--
-- A simple example of a monad that maintains such an environment is
-- 'Reader'. Indeed, 'HasScope' and 'LocalScope' instances for this
-- monad are already defined.
module Futhark.IR.Prop.Scope
  ( HasScope (..),
    NameInfo (..),
    LocalScope (..),
    Scope,
    Scoped (..),
    inScopeOf,
    scopeOfLParams,
    scopeOfFParams,
    scopeOfLoopForm,
    scopeOfPat,
    scopeOfPatElem,
    SameScope,
    castScope,

    -- * Extended type environment
    ExtendedScope,
    extendedScope,
  )
where

import Control.Monad.Except
import Control.Monad.RWS.Lazy qualified
import Control.Monad.RWS.Strict qualified
import Control.Monad.Reader
import Data.Map.Strict qualified as M
import Futhark.IR.Pretty ()
import Futhark.IR.Prop.Types
import Futhark.IR.Rep
import Futhark.IR.Syntax

-- | How some name in scope was bound.
data NameInfo rep
  = -- | Bound by a pattern element; see 'scopeOfPatElem'.
    LetName (LetDec rep)
  | -- | Bound as a function or loop parameter; see 'scopeOfFParams'.
    FParamName (FParamInfo rep)
  | -- | Bound as a lambda parameter; see 'scopeOfLParams'.
    LParamName (LParamInfo rep)
  | -- | Bound as the index of a for-loop; see 'scopeOfLoopForm'.
    IndexName IntType

deriving instance (RepTypes rep) => Show (NameInfo rep)

-- | The type of a bound name is the type of its decoration; loop
-- indices are scalars of their integer type.
instance (RepTypes rep) => Typed (NameInfo rep) where
  typeOf (LetName dec) = typeOf dec
  typeOf (FParamName dec) = typeOf dec
  typeOf (LParamName dec) = typeOf dec
  typeOf (IndexName it) = Prim $ IntType it

-- | A scope is a mapping from variable names to information about
-- that name.
type Scope rep = M.Map VName (NameInfo rep)

-- | The class of applicative functors (or, more commonly in practice,
-- monads) that permit the lookup of variable types. A default method
-- for 'lookupType' exists, which is sufficient (if not always
-- maximally efficient, and using 'error' to fail) when 'askScope'
-- is defined.
class (Applicative m, RepTypes rep) => HasScope rep m | m -> rep where
  -- | Return the type of the given variable, or fail if it is not in
  -- the type environment.
  lookupType :: VName -> m Type
  lookupType = fmap typeOf . lookupInfo

  -- | Return the info of the given variable, or fail if it is not in
  -- the type environment.
  lookupInfo :: VName -> m (NameInfo rep)
  lookupInfo name =
    asksScope (M.findWithDefault notFound name)
    where
      -- Only forced if the name is actually missing from the scope.
      notFound =
        error $
          "Scope.lookupInfo: Name "
            ++ prettyString name
            ++ " not found in type environment."

  -- | Return the type environment contained in the applicative
  -- functor.
  askScope :: m (Scope rep)

  -- | Return the result of applying some function to the type
  -- environment.
  asksScope :: (Scope rep -> a) -> m a
  asksScope f = f <$> askScope

instance
  (Monad m, RepTypes rep) =>
  HasScope rep (ReaderT (Scope rep) m)
  where
  askScope = ask

-- | Delegates to the underlying monad.
instance (Monad m, HasScope rep m) => HasScope rep (ExceptT e m) where
  askScope = lift askScope

instance
  (Monad m, Monoid w, RepTypes rep) =>
  HasScope rep (Control.Monad.RWS.Strict.RWST (Scope rep) w s m)
  where
  askScope = ask

instance
  (Monad m, Monoid w, RepTypes rep) =>
  HasScope rep (Control.Monad.RWS.Lazy.RWST (Scope rep) w s m)
  where
  askScope = ask

-- | The class of monads that not only provide a 'Scope', but also
-- the ability to locally extend it. A 'Reader' containing a
-- 'Scope' is the prototypical example of such a monad.
class (HasScope rep m, Monad m) => LocalScope rep m where
  -- | Run a computation with an extended type environment. Note that
  -- this is intended to *add* to the current type environment, it
  -- does not replace it.
  localScope :: Scope rep -> m a -> m a

-- | Extends the scope of the underlying monad.
instance (LocalScope rep m) => LocalScope rep (ExceptT e m) where
  localScope = mapExceptT . localScope

-- In the reader-based instances below, 'M.union' is left-biased, so
-- bindings in the new scope shadow existing bindings of the same name.
instance
  (Monad m, RepTypes rep) =>
  LocalScope rep (ReaderT (Scope rep) m)
  where
  localScope = local . M.union

instance
  (Monad m, Monoid w, RepTypes rep) =>
  LocalScope rep (Control.Monad.RWS.Strict.RWST (Scope rep) w s m)
  where
  localScope = local . M.union

instance
  (Monad m, Monoid w, RepTypes rep) =>
  LocalScope rep (Control.Monad.RWS.Lazy.RWST (Scope rep) w s m)
  where
  localScope = local . M.union

-- | The class of things that can provide a scope. There is no
-- overarching rule for what this means. For a 'Stm', it is the
-- corresponding pattern. For a t'Lambda', it is its parameters.
class Scoped rep a | a -> rep where
  scopeOf :: a -> Scope rep

-- | Extend the monadic scope with the 'scopeOf' the given value.
inScopeOf :: (Scoped rep a, LocalScope rep m) => a -> m b -> m b
inScopeOf = localScope . scopeOf

-- | The union of the scopes of the elements. Because 'mconcat' on maps
-- is a left-biased union, earlier elements win on duplicate names.
instance (Scoped rep a) => Scoped rep [a] where
  scopeOf = mconcat . map scopeOf

instance Scoped rep (Stms rep) where
  scopeOf = foldMap scopeOf

instance Scoped rep (Stm rep) where
  scopeOf = scopeOfPat . stmPat

instance Scoped rep (FunDef rep) where
  scopeOf = scopeOfFParams . funDefParams

instance Scoped rep (VName, NameInfo rep) where
  scopeOf = uncurry M.singleton

-- | The scope of a loop form. A while-loop binds nothing; a for-loop
-- binds its index variable.
scopeOfLoopForm :: LoopForm -> Scope rep
scopeOfLoopForm (WhileLoop _) = mempty
scopeOfLoopForm (ForLoop i it _) = M.singleton i $ IndexName it

-- | The scope of a pattern.
scopeOfPat :: (LetDec rep ~ dec) => Pat dec -> Scope rep
scopeOfPat =
  mconcat . map scopeOfPatElem . patElems

-- | The scope of a pattern element.
scopeOfPatElem :: (LetDec rep ~ dec) => PatElem dec -> Scope rep
scopeOfPatElem (PatElem name dec) = M.singleton name $ LetName dec

-- | The scope of some lambda parameters.
scopeOfLParams ::
  (LParamInfo rep ~ dec) =>
  [Param dec] ->
  Scope rep
scopeOfLParams = M.fromList . map f
  where
    f param = (paramName param, LParamName $ paramDec param)

-- | The scope of some function or loop parameters.
scopeOfFParams ::
  (FParamInfo rep ~ dec) =>
  [Param dec] ->
  Scope rep
scopeOfFParams = M.fromList . map f
  where
    f param = (paramName param, FParamName $ paramDec param)

instance Scoped rep (Lambda rep) where
  scopeOf lam = scopeOfLParams $ lambdaParams lam

-- | A constraint that indicates two representations have the same 'NameInfo'
-- representation.
type SameScope rep1 rep2 =
  ( LetDec rep1 ~ LetDec rep2,
    FParamInfo rep1 ~ FParamInfo rep2,
    LParamInfo rep1 ~ LParamInfo rep2
  )

-- | If two scopes are really the same (see 'SameScope'), then you can
-- convert one to the other.
castScope ::
  (SameScope fromrep torep) =>
  Scope fromrep ->
  Scope torep
castScope scope = M.map castNameInfo scope

-- | Re-tag a single 'NameInfo' at another representation with identical
-- decorations. Every constructor is rebuilt with the same payload.
castNameInfo ::
  (SameScope fromrep torep) =>
  NameInfo fromrep ->
  NameInfo torep
castNameInfo info =
  case info of
    LetName dec -> LetName dec
    FParamName dec -> FParamName dec
    LParamName dec -> LParamName dec
    IndexName it -> IndexName it

-- | A monad transformer layering an additional 'Scope' over another
-- 'HasScope' monad. Type lookups consult the added 'Scope' first and
-- fall back to the 'lookupType' of the underlying monad.
newtype ExtendedScope rep m a = ExtendedScope (ReaderT (Scope rep) m a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader (Scope rep)
    )

instance
  (HasScope rep m, Monad m) =>
  HasScope rep (ExtendedScope rep m)
  where
  lookupType name = do
    local_hit <- asks (M.lookup name)
    case local_hit of
      Just info -> pure (typeOf info)
      Nothing -> ExtendedScope (lift (lookupType name))

  -- The added scope is the left operand of 'M.union', so it shadows the
  -- underlying scope on duplicate names.
  askScope = M.union <$> ask <*> ExtendedScope (lift askScope)

-- | Run a computation in the extended type environment, using the
-- given 'Scope' as the extension.
extendedScope ::
  ExtendedScope rep m a ->
  Scope rep ->
  m a
extendedScope (ExtendedScope m) scope = runReaderT m scope
futhark-0.25.27/src/Futhark/IR/Prop/TypeOf.hs000066400000000000000000000121201475065116200205250ustar00rootroot00000000000000
{-# LANGUAGE TypeFamilies #-}

-- | This module provides facilities for obtaining the types of
-- various Futhark constructs. Typically, you will need to execute
-- these in a context where type information is available as a
-- 'Scope'; usually by using a monad that is an instance of
-- 'HasScope'. The information is returned as a list of 'ExtType'
-- values - one for each of the values the Futhark construct returns.
-- Some constructs (such as subexpressions) can produce only a single
-- value, and their typing functions hence do not return a list.
--
-- Some representations may have more specialised facilities enabling
-- even more information - for example,
-- "Futhark.IR.Mem" exposes functionality for
-- also obtaining information about the storage location of results.
module Futhark.IR.Prop.TypeOf
  ( expExtType,
    subExpType,
    subExpResType,
    basicOpType,
    mapType,

    -- * Return type
    module Futhark.IR.RetType,

    -- * Type environment
    module Futhark.IR.Prop.Scope,

    -- * Extensibility
    TypedOp (..),
  )
where

import Data.List.NonEmpty (NonEmpty (..))
import Futhark.IR.Prop.Constants
import Futhark.IR.Prop.Scope
import Futhark.IR.Prop.Types
import Futhark.IR.RetType
import Futhark.IR.Syntax

-- | The type of a subexpression. A constant has the primitive type of
-- its value; a variable is looked up in the scope.
subExpType :: (HasScope t m) => SubExp -> m Type
subExpType se =
  case se of
    Constant val -> pure (Prim (primValueType val))
    Var name -> lookupType name

-- | The type of a 'SubExpRes'. Note that this type might refer to
-- names bound in the body containing the result.
subExpResType :: (HasScope t m) => SubExpRes -> m Type
subExpResType res = subExpType (resSubExp res)

-- | @mapType outersize f@ wraps each type in the return type of @f@ in
-- an array whose outermost dimension is @outersize@ - i.e. the result
-- types of mapping @f@ over arrays of that size.
mapType :: SubExp -> Lambda rep -> [Type]
mapType outersize f =
  map wrap (lambdaReturnType f)
  where
    wrap t = arrayOf t (Shape [outersize]) NoUniqueness

-- | The type of a primitive operation.
basicOpType :: (HasScope rep m) => BasicOp -> m [Type]
-- Every 'BasicOp' produces exactly one value, so each equation yields a
-- singleton list. Only 'fmap' is used, since 'HasScope' only guarantees
-- an 'Applicative'.
basicOpType (SubExp se) = (: []) <$> subExpType se
basicOpType (Opaque _ se) = (: []) <$> subExpType se
basicOpType (ArrayVal vs t) =
  pure [arrayOf (Prim t) (Shape [intConst Int64 (toInteger (length vs))]) NoUniqueness]
basicOpType (ArrayLit es rt) =
  pure [arrayOf rt (Shape [intConst Int64 (toInteger (length es))]) NoUniqueness]
basicOpType (BinOp bop _ _) = pure [Prim (binOpType bop)]
basicOpType (UnOp _ x) = (: []) <$> subExpType x
basicOpType CmpOp {} = pure [Prim Bool]
basicOpType (ConvOp conv _) = pure [Prim (snd (convOpType conv))]
basicOpType (Index ident slice) =
  (\t -> [Prim (elemType t) `arrayOfShape` Shape (sliceDims slice)])
    <$> lookupType ident
basicOpType (Update _ src _ _) = (: []) <$> lookupType src
basicOpType (FlatIndex ident slice) =
  (\t -> [Prim (elemType t) `arrayOfShape` Shape (flatSliceDims slice)])
    <$> lookupType ident
basicOpType (FlatUpdate src _ _) = (: []) <$> lookupType src
basicOpType (Iota n _ _ et) =
  pure [arrayOf (Prim (IntType et)) (Shape [n]) NoUniqueness]
basicOpType (Replicate (Shape []) e) = (: []) <$> subExpType e
basicOpType (Replicate shape e) =
  (\t -> [t `arrayOfShape` shape]) <$> subExpType e
basicOpType (Scratch t shape) =
  pure [arrayOf (Prim t) (Shape shape) NoUniqueness]
basicOpType (Reshape _ (Shape []) e) =
  (\t -> [Prim (elemType t)]) <$> lookupType e
basicOpType (Reshape _ shape e) =
  (\t -> [t `setArrayShape` shape]) <$> lookupType e
basicOpType (Rearrange perm e) =
  (\t -> [rearrangeType perm t]) <$> lookupType e
basicOpType (Concat i (x :| _) ressize) =
  -- The result has the type of the first input, with dimension @i@
  -- replaced by the total size.
  (\xt -> [setDimSize i xt ressize]) <$> lookupType x
basicOpType (Manifest _ v) = (: []) <$> lookupType v
basicOpType Assert {} = pure [Prim Unit]
basicOpType (UpdateAcc _ v _ _) = (: []) <$> lookupType v

-- | The type of an expression.
expExtType ::
  (HasScope rep m, TypedOp (OpC rep)) =>
  Exp rep ->
  m [ExtType]
expExtType e =
  case e of
    Apply _ _ rt _ ->
      pure [fromDecl (declExtTypeOf (fst r)) | r <- rt]
    Match _ _ _ rt ->
      pure (extTypeOf <$> matchReturns rt)
    Loop merge _ _ ->
      pure (loopExtType (fst <$> merge))
    BasicOp op ->
      staticShapes <$> basicOpType op
    WithAcc inputs lam ->
      -- The first results are the arrays underlying the accumulators;
      -- the remaining ones are the non-accumulator results of the lambda.
      let arrTypes (_, arrs, _) = traverse lookupType arrs
          trailing = drop (length inputs) (lambdaReturnType lam)
       in staticShapes . (<> trailing) . concat <$> traverse arrTypes inputs
    Op op ->
      opType op

-- | Given the parameters of a loop, produce the return type. Any
-- dimension that refers to a loop parameter is turned into an
-- existential, since those names are not visible after the loop.
loopExtType :: (Typed dec) => [Param dec] -> [ExtType]
loopExtType params =
  let names = map paramName params
      types = staticShapes (map typeOf params)
   in existentialiseExtTypes names types

-- | Any operation must define an instance of this class, which
-- describes the type of the operation (at the value level).
class TypedOp op where
  opType :: (HasScope rep m) => op rep -> m [ExtType]

-- | The placeholder operation produces no values.
instance TypedOp NoOp where
  opType NoOp = pure []
futhark-0.25.27/src/Futhark/IR/Prop/Types.hs000066400000000000000000000447341475065116200204430ustar00rootroot00000000000000
-- | Functions for inspecting and constructing various types.
module Futhark.IR.Prop.Types
  ( rankShaped,
    arrayRank,
    arrayShape,
    setArrayShape,
    isEmptyArray,
    existential,
    uniqueness,
    unique,
    staticShapes,
    staticShapes1,
    primType,
    isAcc,
    arrayOf,
    arrayOfRow,
    arrayOfShape,
    setOuterSize,
    setDimSize,
    setOuterDim,
    setOuterDims,
    setDim,
    setArrayDims,
    peelArray,
    stripArray,
    arrayDims,
    arrayExtDims,
    shapeSize,
    arraySize,
    arraysSize,
    elemType,
    rowType,
    transposeType,
    rearrangeType,
    mapOnExtType,
    mapOnType,
    diet,
    subtypeOf,
    subtypesOf,
    toDecl,
    fromDecl,
    isExt,
    isFree,
    extractShapeContext,
    shapeContext,
    hasStaticShape,
    generaliseExtTypes,
    existentialiseExtTypes,
    shapeExtMapping,

    -- * Abbreviations
    int8,
    int16,
    int32,
    int64,
    float32,
    float64,

    -- * The Typed typeclass
    Typed (..),
    DeclTyped (..),
    ExtTyped (..),
    DeclExtTyped (..),
    FixExt (..),
  )
where

import Control.Monad
import Control.Monad.State
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.IR.Prop.Constants
import Futhark.IR.Prop.Rearrange
import Futhark.IR.Syntax.Core

-- | Remove shape information from a type, keeping only the rank of
-- array types.
rankShaped :: (ArrayShape shape) => TypeBase shape u -> TypeBase Rank u
rankShaped t =
  case t of
    Array et sz u -> Array et (Rank (shapeRank sz)) u
    Prim pt -> Prim pt
    Acc acc ispace ts u -> Acc acc ispace ts u
    Mem space -> Mem space

-- | The number of array dimensions of a type: zero for non-arrays, one
-- for a one-dimensional array, two for a two-dimensional array, and so
-- forth.
arrayRank :: (ArrayShape shape) => TypeBase shape u -> Int
arrayRank t = shapeRank (arrayShape t)

-- | The shape of a type; 'mempty' for anything that is not an array.
arrayShape :: (ArrayShape shape) => TypeBase shape u -> shape
arrayShape t =
  case t of
    Array _ ds _ -> ds
    _ -> mempty

-- | Modify the shape of an array - for non-arrays, this does nothing.
modifyArrayShape ::
  (ArrayShape newshape) =>
  (oldshape -> newshape) ->
  TypeBase oldshape u ->
  TypeBase newshape u
-- If the new shape has rank zero, the array degenerates to its
-- primitive element type.
modifyArrayShape f t =
  case t of
    Array pt ds u ->
      let ds' = f ds
       in if shapeRank ds' == 0 then Prim pt else Array pt ds' u
    Prim pt -> Prim pt
    Acc acc ispace ts u -> Acc acc ispace ts u
    Mem space -> Mem space

-- | Set the shape of an array. If the given type is not an array, it
-- is returned unchanged. Setting a rank-zero shape on an array yields
-- its primitive element type.
setArrayShape ::
  (ArrayShape newshape) =>
  TypeBase oldshape u ->
  newshape ->
  TypeBase newshape u
setArrayShape t ds = modifyArrayShape (\_ -> ds) t

-- | If the array is statically an empty array (meaning any dimension
-- is a static zero), return the element type and the shape.
isEmptyArray :: Type -> Maybe (PrimType, Shape)
isEmptyArray t =
  case t of
    Array pt shape@(Shape ds) _
      | intConst Int64 0 `elem` ds -> Just (pt, shape)
    _ -> Nothing

-- | True if the given type has a dimension that is existentially sized.
existential :: ExtType -> Bool
existential t = any isExtDim (shapeDims (arrayShape t))
  where
    isExtDim d =
      case d of
        Ext _ -> True
        Free _ -> False

-- | The uniqueness of a type. Arrays and accumulators carry their own;
-- everything else is 'Nonunique'.
uniqueness :: TypeBase shape Uniqueness -> Uniqueness
uniqueness t =
  case t of
    Array _ _ u -> u
    Acc _ _ _ u -> u
    _ -> Nonunique

-- | @unique t@ is 'True' if the type of the argument is unique.
unique :: TypeBase shape Uniqueness -> Bool
unique t = uniqueness t == Unique

-- | Convert types with non-existential shapes to types with
-- existential shapes. Only the representation is changed, so all
-- the shapes will be 'Free'.
staticShapes :: [TypeBase Shape u] -> [TypeBase ExtShape u]
staticShapes ts = [staticShapes1 t | t <- ts]

-- | As 'staticShapes', but on a single type.
staticShapes1 :: TypeBase Shape u -> TypeBase ExtShape u
staticShapes1 t =
  case t of
    Prim pt -> Prim pt
    Acc acc ispace ts u -> Acc acc ispace ts u
    Array bt (Shape dims) u -> Array bt (Shape (map Free dims)) u
    Mem space -> Mem space

-- | @arrayOf t s u@ constructs an array type. The convenience
-- compared to using the 'Array' constructor directly is that @t@ can
-- itself be an array. If @t@ is an @n@-dimensional array, and @s@ is
-- a shape of rank @m@, the resulting type has @n+m@ dimensions, with
-- the dimensions of @s@ outermost. The uniqueness of the new array
-- will be @u@, no matter the uniqueness of @t@. If the shape @s@ has
-- rank 0, then @t@ will be returned, although if it is an array, with
-- the uniqueness changed to @u@. For an accumulator the shape is
-- ignored and only the uniqueness is replaced; a memory block is an
-- error.
arrayOf ::
  (ArrayShape shape) =>
  TypeBase shape u_unused ->
  shape ->
  u ->
  TypeBase shape u
arrayOf t shape u =
  case t of
    Array et inner _ -> Array et (shape <> inner) u
    Prim pt
      | shapeRank shape == 0 -> Prim pt
      | otherwise -> Array pt shape u
    Acc acc ispace ts _ -> Acc acc ispace ts u
    Mem {} -> error "arrayOf Mem"

-- | Construct a non-unique array whose rows are the given type and
-- whose outermost dimension has the given size. A convenience wrapper
-- around 'arrayOf'.
arrayOfRow ::
  (ArrayShape (ShapeBase d)) =>
  TypeBase (ShapeBase d) NoUniqueness ->
  d ->
  TypeBase (ShapeBase d) NoUniqueness
arrayOfRow t size = arrayOf t (Shape [size]) NoUniqueness

-- | Construct a non-unique array whose elements are the given type and
-- whose outer dimensions are the given t'Shape'. A convenience
-- wrapper around 'arrayOf'.
arrayOfShape :: Type -> Shape -> Type
arrayOfShape t shape = arrayOf t shape NoUniqueness

-- | Set the dimensions of an array. If the given type is not an
-- array, return the type unchanged.
setArrayDims :: TypeBase oldshape u -> [SubExp] -> TypeBase Shape u
setArrayDims t dims = setArrayShape t (Shape dims)

-- | Replace the size of the outermost dimension of an array. If the
-- given type is not an array, it is returned unchanged.
setOuterSize ::
  (ArrayShape (ShapeBase d)) =>
  TypeBase (ShapeBase d) u ->
  d ->
  TypeBase (ShapeBase d) u
setOuterSize t d = setDimSize 0 t d

-- | Replace the size of the given dimension of an array. If the
-- given type is not an array, it is returned unchanged.
setDimSize ::
  (ArrayShape (ShapeBase d)) =>
  Int ->
  TypeBase (ShapeBase d) u ->
  d ->
  TypeBase (ShapeBase d) u
setDimSize i t e = setArrayShape t (setDim i (arrayShape t) e)

-- | Replace the outermost dimension of an array shape.
setOuterDim :: ShapeBase d -> d -> ShapeBase d
setOuterDim shape d = setDim 0 shape d

-- | Replace some outermost dimensions of an array shape: the first @k@
-- dimensions of @old@ are dropped and @new@ is put in front.
setOuterDims :: ShapeBase d -> Int -> ShapeBase d -> ShapeBase d
setOuterDims old k new = new <> stripDims k old

-- | Replace the specified dimension of an array shape.
setDim :: Int -> ShapeBase d -> d -> ShapeBase d
setDim i (Shape ds) e =
  let before = take i ds
      after = drop (i + 1) ds
   in Shape (before ++ [e] ++ after)

-- | @peelArray n t@ returns the type resulting from peeling the first
-- @n@ array dimensions from @t@. Returns @Nothing@ if @t@ has less
-- than @n@ dimensions.
peelArray :: Int -> TypeBase Shape u -> Maybe (TypeBase Shape u)
peelArray n t
  | n == 0 = Just t
  | Array et shape u <- t =
      case compare (shapeRank shape) n of
        EQ -> Just (Prim et)
        GT -> Just (Array et (stripDims n shape) u)
        LT -> Nothing
  | otherwise = Nothing

-- | @stripArray n t@ removes the @n@ outermost layers of the array.
-- Essentially, it is the type of indexing an array of type @t@ with
-- @n@ indexes. Stripping all (or more) dimensions gives the element
-- type; non-arrays are returned unchanged.
stripArray :: Int -> TypeBase Shape u -> TypeBase Shape u
stripArray n t =
  case t of
    Array et shape u
      | n < shapeRank shape -> Array et (stripDims n shape) u
      | otherwise -> Prim et
    _ -> t

-- | Return the size of the given dimension. If the dimension does
-- not exist, the zero constant is returned.
shapeSize :: Int -> Shape -> SubExp
shapeSize i shape =
  fromMaybe (constant (0 :: Int64)) (listToMaybe (drop i (shapeDims shape)))

-- | The dimensions of a type; the empty list for non-arrays.
arrayDims :: TypeBase Shape u -> [SubExp]
arrayDims t = shapeDims (arrayShape t)

-- | The possibly-existential dimensions of a type; the empty list for
-- non-arrays.
arrayExtDims :: TypeBase ExtShape u -> [ExtSize]
arrayExtDims t = shapeDims (arrayShape t)

-- | Return the size of the given dimension. If the dimension does
-- not exist, the zero constant is returned.
arraySize :: Int -> TypeBase Shape u -> SubExp
arraySize i t = shapeSize i (arrayShape t)

-- | Return the size of the given dimension in the first element of
-- the given type list. If the dimension does not exist, or no types
-- are given, the zero constant is returned.
arraysSize :: Int -> [TypeBase Shape u] -> SubExp
arraysSize i ts = maybe (constant (0 :: Int64)) (arraySize i) (listToMaybe ts)

-- | Return the immediate row-type of an array. For @[[int]]@, this
-- would be @[int]@.
rowType :: TypeBase Shape u -> TypeBase Shape u
rowType t = stripArray 1 t

-- | True only for 'Prim' types; arrays, accumulators and memory blocks
-- are not primitive.
primType :: TypeBase shape u -> Bool
primType t =
  case t of
    Prim {} -> True
    _ -> False

-- | Is this an accumulator?
isAcc :: TypeBase shape u -> Bool
isAcc t =
  case t of
    Acc {} -> True
    _ -> False

-- | Returns the bottommost type of an array. For @[][]i32@, this
-- would be @i32@. A primitive type is returned as-is. Accumulators and
-- memory blocks have no element type and cause an 'error'.
elemType :: TypeBase shape u -> PrimType
elemType t =
  case t of
    Array pt _ _ -> pt
    Prim pt -> pt
    Acc {} -> error "elemType Acc"
    Mem {} -> error "elemType Mem"

-- | Swap the two outer dimensions of the type.
transposeType :: Type -> Type
transposeType t = rearrangeType [1, 0] t

-- | Rearrange the dimensions of the type. If the length of the
-- permutation does not match the rank of the type, the permutation
-- will be extended with identity.
rearrangeType :: [Int] -> Type -> Type
rearrangeType perm t =
  let rank = arrayRank t
      -- Dimensions not mentioned by @perm@ stay in place.
      fullPerm = perm ++ [length perm .. rank - 1]
   in setArrayShape t (Shape (rearrangeShape fullPerm (arrayDims t)))

-- | Apply a monadic transformation to every t'SubExp' in the type:
-- the accumulator name and its index space and element types for an
-- accumulator, and every known ('Free') dimension of an array.
mapOnExtType ::
  (Monad m) =>
  (SubExp -> m SubExp) ->
  TypeBase ExtShape u ->
  m (TypeBase ExtShape u)
mapOnExtType f t =
  case t of
    Prim pt -> pure (Prim pt)
    Mem space -> pure (Mem space)
    Acc acc ispace ts u ->
      Acc
        <$> renameAcc acc
        <*> traverse f ispace
        <*> mapM (mapOnType f) ts
        <*> pure u
    Array pt shape u ->
      Array pt . Shape <$> mapM (traverse f) (shapeDims shape) <*> pure u
  where
    -- The accumulator name is only replaced if @f@ maps it to another
    -- variable; a constant result keeps the old name.
    renameAcc v = do
      x <- f (Var v)
      case x of
        Var v' -> pure v'
        Constant {} -> pure v

-- | Apply a monadic transformation to every t'SubExp' in the type:
-- the accumulator name and its index space and element types for an
-- accumulator, and every dimension of an array.
mapOnType ::
  (Monad m) =>
  (SubExp -> m SubExp) ->
  TypeBase Shape u ->
  m (TypeBase Shape u)
mapOnType f t =
  case t of
    Prim pt -> pure (Prim pt)
    Mem space -> pure (Mem space)
    Acc acc ispace ts u ->
      Acc
        <$> renameAcc acc
        <*> traverse f ispace
        <*> mapM (mapOnType f) ts
        <*> pure u
    Array pt shape u ->
      Array pt . Shape <$> mapM f (shapeDims shape) <*> pure u
  where
    -- As in 'mapOnExtType': keep the old name if @f@ yields a constant.
    renameAcc v = do
      x <- f (Var v)
      case x of
        Var v' -> pure v'
        Constant {} -> pure v

-- | @diet t@ returns a description of how a function parameter of
-- type @t@ might consume its argument: unique arrays and accumulators
-- are consumed, other non-primitive types are observed.
diet :: TypeBase shape Uniqueness -> Diet
diet t =
  case t of
    Prim {} -> ObservePrim
    Mem {} -> Observe
    Acc _ _ _ u -> byUniqueness u
    Array _ _ u -> byUniqueness u
  where
    byUniqueness Unique = Consume
    byUniqueness Nonunique = Observe

-- | @x \`subtypeOf\` y@ is true if @x@ is a subtype of @y@ (or equal to
-- @y@), meaning @x@ is valid whenever @y@ is.
subtypeOf ::
  (Ord u, ArrayShape shape) =>
  TypeBase shape u ->
  TypeBase shape u ->
  Bool
subtypeOf x y =
  case (x, y) of
    -- Arrays: same element type, a sub-shape, and at least as unique.
    (Array pt1 s1 u1, Array pt2 s2 u2) ->
      and [u2 <= u1, pt1 == pt2, s1 `subShapeOf` s2]
    _ -> x == y

-- | @xs \`subtypesOf\` ys@ is true if @xs@ is the same size as @ys@,
-- and each element in @xs@ is a subtype of the corresponding element
-- in @ys@.
subtypesOf ::
  (Ord u, ArrayShape shape) =>
  [TypeBase shape u] ->
  [TypeBase shape u] ->
  Bool
subtypesOf xs ys
  | length xs /= length ys = False
  | otherwise = and (zipWith subtypeOf xs ys)

-- | Add the given uniqueness information to the types. Only arrays and
-- accumulators carry uniqueness; other types are unchanged.
toDecl ::
  TypeBase shape NoUniqueness ->
  Uniqueness ->
  TypeBase shape Uniqueness
toDecl t u =
  case t of
    Prim pt -> Prim pt
    Acc acc ispace ts _ -> Acc acc ispace ts u
    Array et shape _ -> Array et shape u
    Mem space -> Mem space

-- | Remove uniqueness information from the type.
fromDecl ::
  TypeBase shape Uniqueness ->
  TypeBase shape NoUniqueness
fromDecl t =
  case t of
    Prim pt -> Prim pt
    Acc acc ispace ts _ -> Acc acc ispace ts NoUniqueness
    Array et shape _ -> Array et shape NoUniqueness
    Mem space -> Mem space

-- | If an existential, then return its existential index.
isExt :: Ext a -> Maybe Int
isExt e =
  case e of
    Ext i -> Just i
    _ -> Nothing

-- | If a known size, then return that size.
isFree :: Ext a -> Maybe a
isFree e =
  case e of
    Free d -> Just d
    _ -> Nothing

-- | Given the existential return type of a function, and the shapes
-- of the values returned by the function, return the existential
-- shape context. That is, those sizes that are existential in the
-- return type.
extractShapeContext :: [TypeBase ExtShape u] -> [[a]] -> [a]
extractShapeContext ts shapes =
  evalState (concat <$> zipWithM extract ts shapes) S.empty
  where
    -- Pair each dimension of the type with the corresponding concrete
    -- size, keeping only the sizes of existential dimensions.
    extract t shape =
      catMaybes <$> zipWithM extract' (shapeDims $ arrayShape t) shape
    -- The state is the set of 'Ext' indices already emitted, so an
    -- existential occurring several times contributes only the size
    -- found at its first occurrence.
    extract' (Ext x) v = do
      seen <- gets $ S.member x
      if seen
        then pure Nothing
        else do
          modify $ S.insert x
          pure $ Just v
    extract' (Free _) _ = pure Nothing

-- | The 'Ext' integers used for existential sizes in the given types.
shapeContext :: [TypeBase ExtShape u] -> S.Set Int
shapeContext = S.fromList . concatMap (mapMaybe isExt . shapeDims . arrayShape)

-- | If all dimensions of the given 'ExtShape' are statically known,
-- change to the corresponding t'Shape'. Non-array types always
-- succeed; an array with any 'Ext' dimension gives 'Nothing'.
hasStaticShape :: TypeBase ExtShape u -> Maybe (TypeBase Shape u)
hasStaticShape (Prim bt) = Just $ Prim bt
hasStaticShape (Acc acc ispace ts u) = Just $ Acc acc ispace ts u
hasStaticShape (Mem space) = Just $ Mem space
hasStaticShape (Array bt (Shape shape) u) =
  Array bt <$> (Shape <$> mapM isFree shape) <*> pure u

-- | Given two lists of 'ExtType's of the same length, return a list
-- of 'ExtType's that generalises both operands: a known dimension is
-- kept only where both operands agree on it, and every other dimension
-- becomes an existential with a freshly numbered index.
generaliseExtTypes ::
  [TypeBase ExtShape u] ->
  [TypeBase ExtShape u] ->
  [TypeBase ExtShape u]
generaliseExtTypes rt1 rt2 =
  -- State: the next fresh existential index, and a mapping from
  -- existential indices of the inputs to the indices used in the result.
  evalState (zipWithM unifyExtShapes rt1 rt2) (0, M.empty)
  where
    -- Everything but the shape is taken from @t1@.
    unifyExtShapes t1 t2 =
      setArrayShape t1 . Shape
        <$> zipWithM
          unifyExtDims
          (shapeDims $ arrayShape t1)
          (shapeDims $ arrayShape t2)
    unifyExtDims (Free se1) (Free se2)
      | se1 == se2 = pure $ Free se1 -- Arbitrary
      -- Differing known sizes get a fresh existential that is not
      -- recorded in the mapping.
      | otherwise = do
          (n, m) <- get
          put (n + 1, m)
          pure $ Ext n
    -- The same input existential on both sides maps to a single result
    -- index, reusing an earlier assignment if there is one.
    unifyExtDims (Ext x) (Ext y)
      | x == y = Ext <$> (maybe (new x) pure =<< gets (M.lookup x . snd))
    unifyExtDims (Ext x) _ = Ext <$> new x
    unifyExtDims _ (Ext x) = Ext <$> new x
    -- Allocate the next fresh result index for input existential @x@
    -- and record the assignment.
    new x = do
      (n, m) <- get
      put (n + 1, M.insert x n m)
      pure n

-- | Given a list of 'ExtType's and a list of "forbidden" names,
-- modify the dimensions of the 'ExtType's such that they are 'Ext'
-- where they were previously 'Free' with a variable in the set of
-- forbidden names.
existentialiseExtTypes :: [VName] -> [ExtType] -> [ExtType]
existentialiseExtTypes inaccessible = map makeBoundShapesFree
  where
    makeBoundShapesFree =
      modifyArrayShape $ fmap checkDim
    -- A dimension given by a forbidden name becomes @Ext i@, where @i@
    -- is the (first) position of that name in @inaccessible@.
    checkDim (Free (Var v))
      | Just i <- v `L.elemIndex` inaccessible =
          Ext i
    checkDim d = d

-- | Produce a mapping for the dimensions context: each 'Ext' index in
-- the first list of types is mapped to the corresponding dimension in
-- the second list. If an index occurs several times, its first
-- occurrence wins, since 'mappend' on maps is a left-biased union.
shapeExtMapping :: [TypeBase ExtShape u] -> [TypeBase Shape u1] -> M.Map Int SubExp
shapeExtMapping = dimMapping arrayExtDims arrayDims match mappend
  where
    match Free {} _ = mempty
    match (Ext i) dim = M.singleton i dim

-- Zip the dimensions of two lists of types pairwise, apply @f@ to each
-- pair of dimensions, and combine the results left to right with
-- @comb@, starting from 'mempty'.
dimMapping ::
  (Monoid res) =>
  (t1 -> [dim1]) ->
  (t2 -> [dim2]) ->
  (dim1 -> dim2 -> res) ->
  (res -> res -> res) ->
  [t1] ->
  [t2] ->
  res
dimMapping getDims1 getDims2 f comb ts1 ts2 =
  L.foldl' comb mempty $ concat $ zipWith (zipWith f) (map getDims1 ts1) (map getDims2 ts2)

-- | @IntType Int8@
int8 :: PrimType
int8 = IntType Int8

-- | @IntType Int16@
int16 :: PrimType
int16 = IntType Int16

-- | @IntType Int32@
int32 :: PrimType
int32 = IntType Int32

-- | @IntType Int64@
int64 :: PrimType
int64 = IntType Int64

-- | @FloatType Float32@
float32 :: PrimType
float32 = FloatType Float32

-- | @FloatType Float64@
float64 :: PrimType
float64 = FloatType Float64

-- | Typeclass for things that contain 'Type's.
class Typed t where
  typeOf :: t -> Type

instance Typed Type where
  typeOf = id

instance Typed DeclType where
  typeOf = fromDecl

instance Typed Ident where
  typeOf = identType

instance (Typed dec) => Typed (Param dec) where
  typeOf = typeOf . paramDec

instance (Typed dec) => Typed (PatElem dec) where
  typeOf = typeOf . patElemDec

-- | The type of a pair is the type of its second component.
instance (Typed b) => Typed (a, b) where
  typeOf = typeOf . snd

-- | Typeclass for things that contain 'DeclType's.
class DeclTyped t where
  declTypeOf :: t -> DeclType

instance DeclTyped DeclType where
  declTypeOf = id

instance (DeclTyped dec) => DeclTyped (Param dec) where
  declTypeOf = declTypeOf . paramDec

-- | Typeclass for things that contain 'ExtType's.
class (FixExt t) => ExtTyped t where
  extTypeOf :: t -> ExtType

instance ExtTyped ExtType where
  extTypeOf = id

instance ExtTyped DeclExtType where
  extTypeOf = fromDecl . declExtTypeOf

-- | Typeclass for things that contain 'DeclExtType's.
class (FixExt t) => DeclExtTyped t where
  declExtTypeOf :: t -> DeclExtType

instance DeclExtTyped DeclExtType where
  declExtTypeOf = id

-- | Something with an existential context that can be (partially)
-- fixed.
class FixExt t where
  -- | Fix the given existential variable to the indicated free
  -- value.
  fixExt :: Int -> SubExp -> t -> t

  -- | Map a function onto any existential.
  mapExt :: (Int -> Int) -> t -> t

instance (FixExt shape, ArrayShape shape) => FixExt (TypeBase shape u) where
  fixExt i se = modifyArrayShape $ fixExt i se
  mapExt f = modifyArrayShape $ mapExt f

instance (FixExt d) => FixExt (ShapeBase d) where
  fixExt i se = fmap $ fixExt i se
  mapExt f = fmap $ mapExt f

instance (FixExt a) => FixExt [a] where
  fixExt i se = fmap $ fixExt i se
  mapExt f = fmap $ mapExt f

instance FixExt ExtSize where
  -- Index @i@ is removed from the existential context, so every higher
  -- index shifts down by one.
  fixExt i se (Ext j)
    | j > i = Ext $ j - 1
    | j == i = Free se
    | otherwise = Ext j
  fixExt _ _ (Free x) = Free x

  mapExt f (Ext i) = Ext $ f i
  mapExt _ (Free x) = Free x

instance FixExt () where
  fixExt _ _ () = ()
  mapExt _ () = ()
futhark-0.25.27/src/Futhark/IR/Rep.hs000066400000000000000000000052021475065116200171300ustar00rootroot00000000000000
{-# LANGUAGE TypeFamilies #-}

-- | The core Futhark AST is parameterised by a @rep@ type parameter,
-- which is then used to invoke the type families defined here.
module Futhark.IR.Rep ( RepTypes (..), Op, NoOp (..), module Futhark.IR.RetType, ) where import Data.Kind qualified import Futhark.IR.Prop.Types import Futhark.IR.RetType import Futhark.IR.Syntax.Core (DeclExtType, DeclType, ExtType, Type) -- | Returns nothing and does nothing. Placeholder for when we don't -- really want an operation. data NoOp rep = NoOp deriving (Eq, Ord, Show) -- | A collection of type families giving various common types for a -- representation, along with constraints specifying that the types -- they map to should satisfy some minimal requirements. class ( Show (LetDec l), Show (ExpDec l), Show (BodyDec l), Show (FParamInfo l), Show (LParamInfo l), Show (RetType l), Show (BranchType l), Show (Op l), Eq (LetDec l), Eq (ExpDec l), Eq (BodyDec l), Eq (FParamInfo l), Eq (LParamInfo l), Eq (RetType l), Eq (BranchType l), Eq (Op l), Ord (LetDec l), Ord (ExpDec l), Ord (BodyDec l), Ord (FParamInfo l), Ord (LParamInfo l), Ord (RetType l), Ord (BranchType l), Ord (Op l), IsRetType (RetType l), IsBodyType (BranchType l), Typed (FParamInfo l), Typed (LParamInfo l), Typed (LetDec l), DeclTyped (FParamInfo l) ) => RepTypes l where -- | Decoration for every let-pattern element. type LetDec l :: Data.Kind.Type type LetDec l = Type -- | Decoration for every expression. type ExpDec l :: Data.Kind.Type type ExpDec l = () -- | Decoration for every body. type BodyDec l :: Data.Kind.Type type BodyDec l = () -- | Decoration for every (non-lambda) function parameter. type FParamInfo l :: Data.Kind.Type type FParamInfo l = DeclType -- | Decoration for every lambda function parameter. type LParamInfo l :: Data.Kind.Type type LParamInfo l = Type -- | The return type decoration of function calls. type RetType l :: Data.Kind.Type type RetType l = DeclExtType -- | The return type decoration of branches. type BranchType l :: Data.Kind.Type type BranchType l = ExtType -- | Type constructor for the extensible operation. 
The somewhat -- funky definition is to ensure that we can change the "inner" -- representation in a generic way (e.g. add aliasing information) -- In most code, you will use the 'Op' alias instead. type OpC l :: Data.Kind.Type -> Data.Kind.Type type OpC l = NoOp -- | Apply the 'OpC' constructor of a representation to that -- representation. type Op l = OpC l l futhark-0.25.27/src/Futhark/IR/Rephrase.hs000066400000000000000000000103201475065116200201500ustar00rootroot00000000000000-- | Facilities for changing the representation of some fragment, -- within a monadic context. We call this "rephrasing", for no deep -- reason. module Futhark.IR.Rephrase ( rephraseProg, rephraseFunDef, rephraseExp, rephraseBody, rephraseStm, rephraseLambda, rephrasePat, rephrasePatElem, Rephraser (..), RephraseOp (..), ) where import Data.Bitraversable import Futhark.IR.Syntax import Futhark.IR.Traversals -- | A collection of functions that together allow us to rephrase some -- IR fragment, in some monad @m@. If we let @m@ be the 'Maybe' -- monad, we can conveniently do rephrasing that might fail. This is -- useful if you want to see if some IR in e.g. the @Kernels@ rep -- actually uses any @Kernels@-specific operations. data Rephraser m from to = Rephraser { rephraseExpDec :: ExpDec from -> m (ExpDec to), rephraseLetBoundDec :: LetDec from -> m (LetDec to), rephraseFParamDec :: FParamInfo from -> m (FParamInfo to), rephraseLParamDec :: LParamInfo from -> m (LParamInfo to), rephraseBodyDec :: BodyDec from -> m (BodyDec to), rephraseRetType :: RetType from -> m (RetType to), rephraseBranchType :: BranchType from -> m (BranchType to), rephraseOp :: Op from -> m (Op to) } -- | Rephrase an entire program. 
rephraseProg :: (Monad m) => Rephraser m from to -> Prog from -> m (Prog to) rephraseProg rephraser prog = do consts <- mapM (rephraseStm rephraser) (progConsts prog) funs <- mapM (rephraseFunDef rephraser) (progFuns prog) pure $ prog {progConsts = consts, progFuns = funs} -- | Rephrase a function definition. rephraseFunDef :: (Monad m) => Rephraser m from to -> FunDef from -> m (FunDef to) rephraseFunDef rephraser fundec = do body' <- rephraseBody rephraser $ funDefBody fundec params' <- mapM (rephraseParam $ rephraseFParamDec rephraser) $ funDefParams fundec rettype' <- mapM (bitraverse (rephraseRetType rephraser) pure) $ funDefRetType fundec pure fundec {funDefBody = body', funDefParams = params', funDefRetType = rettype'} -- | Rephrase an expression. rephraseExp :: (Monad m) => Rephraser m from to -> Exp from -> m (Exp to) rephraseExp = mapExpM . mapper -- | Rephrase a statement. rephraseStm :: (Monad m) => Rephraser m from to -> Stm from -> m (Stm to) rephraseStm rephraser (Let pat (StmAux cs attrs dec) e) = Let <$> rephrasePat (rephraseLetBoundDec rephraser) pat <*> (StmAux cs attrs <$> rephraseExpDec rephraser dec) <*> rephraseExp rephraser e -- | Rephrase a pattern. rephrasePat :: (Monad m) => (from -> m to) -> Pat from -> m (Pat to) rephrasePat = traverse -- | Rephrase a pattern element. rephrasePatElem :: (Monad m) => (from -> m to) -> PatElem from -> m (PatElem to) rephrasePatElem rephraser (PatElem ident from) = PatElem ident <$> rephraser from -- | Rephrase a parameter. rephraseParam :: (Monad m) => (from -> m to) -> Param from -> m (Param to) rephraseParam rephraser (Param attrs name from) = Param attrs name <$> rephraser from -- | Rephrase a body. rephraseBody :: (Monad m) => Rephraser m from to -> Body from -> m (Body to) rephraseBody rephraser (Body rep stms res) = Body <$> rephraseBodyDec rephraser rep <*> (stmsFromList <$> mapM (rephraseStm rephraser) (stmsToList stms)) <*> pure res -- | Rephrase a lambda. 
rephraseLambda :: (Monad m) => Rephraser m from to -> Lambda from -> m (Lambda to) rephraseLambda rephraser lam = do body' <- rephraseBody rephraser $ lambdaBody lam params' <- mapM (rephraseParam $ rephraseLParamDec rephraser) $ lambdaParams lam pure lam {lambdaBody = body', lambdaParams = params'} mapper :: (Monad m) => Rephraser m from to -> Mapper from to m mapper rephraser = identityMapper { mapOnBody = const $ rephraseBody rephraser, mapOnRetType = rephraseRetType rephraser, mapOnBranchType = rephraseBranchType rephraser, mapOnFParam = rephraseParam (rephraseFParamDec rephraser), mapOnLParam = rephraseParam (rephraseLParamDec rephraser), mapOnOp = rephraseOp rephraser } -- | Rephrasing any fragments inside an Op from one representation to -- another. class RephraseOp op where rephraseInOp :: (Monad m) => Rephraser m from to -> op from -> m (op to) instance RephraseOp NoOp where rephraseInOp _ NoOp = pure NoOp futhark-0.25.27/src/Futhark/IR/RetType.hs000066400000000000000000000050041475065116200177760ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | This module exports a type class covering representations of -- function return types. module Futhark.IR.RetType ( IsBodyType (..), IsRetType (..), expectedTypes, ) where import Control.Monad.Identity import Data.Map.Strict qualified as M import Futhark.IR.Prop.Types import Futhark.IR.Syntax.Core -- | A type representing the return type of a body. It should contain -- at least the information contained in a list of 'ExtType's, but may -- have more, notably an existential context. class (Show rt, Eq rt, Ord rt, ExtTyped rt) => IsBodyType rt where -- | Construct a body type from a primitive type. primBodyType :: PrimType -> rt instance IsBodyType ExtType where primBodyType = Prim -- | A type representing the return type of a function. In practice, -- a list of these will be used. 
It should contain at least the -- information contained in an 'ExtType', but may have more, notably -- an existential context. class (Show rt, Eq rt, Ord rt, ExtTyped rt, DeclExtTyped rt) => IsRetType rt where -- | Contruct a return type from a primitive type. primRetType :: PrimType -> rt -- | Given a function return type, the parameters of the function, -- and the arguments for a concrete call, return the instantiated -- return type for the concrete call, if valid. applyRetType :: (Typed dec) => [rt] -> [Param dec] -> [(SubExp, Type)] -> Maybe [rt] -- | Given shape parameter names and types, produce the types of -- arguments accepted. expectedTypes :: (Typed t) => [VName] -> [t] -> [SubExp] -> [Type] expectedTypes shapes value_ts args = map (correctDims . typeOf) value_ts where parammap :: M.Map VName SubExp parammap = M.fromList $ zip shapes args correctDims = runIdentity . mapOnType (pure . f) where f (Var v) | Just se <- M.lookup v parammap = se f se = se instance IsRetType DeclExtType where primRetType = Prim applyRetType extret params args = if length args == length params && and ( zipWith subtypeOf argtypes $ expectedTypes (map paramName params) params $ map fst args ) then Just $ map correctExtDims extret else Nothing where argtypes = map snd args parammap :: M.Map VName SubExp parammap = M.fromList $ zip (map paramName params) (map fst args) correctExtDims = runIdentity . mapOnExtType (pure . f) where f (Var v) | Just se <- M.lookup v parammap = se f se = se futhark-0.25.27/src/Futhark/IR/SOACS.hs000066400000000000000000000036751475065116200172660ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | A simple representation with SOACs and nested parallelism. 
module Futhark.IR.SOACS ( SOACS, usesAD, -- * Module re-exports module Futhark.IR.Prop, module Futhark.IR.Traversals, module Futhark.IR.Pretty, module Futhark.IR.Syntax, module Futhark.IR.SOACS.SOAC, ) where import Futhark.Builder import Futhark.Construct import Futhark.IR.Pretty import Futhark.IR.Prop import Futhark.IR.SOACS.SOAC import Futhark.IR.Syntax import Futhark.IR.Traversals import Futhark.IR.TypeCheck qualified as TC -- | The rep for the basic representation. data SOACS instance RepTypes SOACS where type OpC SOACS = SOAC instance ASTRep SOACS where expTypesFromPat = pure . expExtTypesFromPat instance TC.Checkable SOACS where checkOp = typeCheckSOAC instance Buildable SOACS where mkBody = Body () mkExpPat merge _ = basicPat merge mkExpDec _ _ = () mkLetNames = simpleMkLetNames instance BuilderOps SOACS instance PrettyRep SOACS usesAD :: Prog SOACS -> Bool usesAD prog = any stmUsesAD (progConsts prog) || any funUsesAD (progFuns prog) where funUsesAD = bodyUsesAD . funDefBody bodyUsesAD = any stmUsesAD . bodyStms stmUsesAD = expUsesAD . stmExp lamUsesAD = bodyUsesAD . lambdaBody expUsesAD (Op JVP {}) = True expUsesAD (Op VJP {}) = True expUsesAD (Op (Stream _ _ _ lam)) = lamUsesAD lam expUsesAD (Op (Screma _ _ (ScremaForm lam scans reds))) = lamUsesAD lam || any (lamUsesAD . scanLambda) scans || any (lamUsesAD . redLambda) reds expUsesAD (Op (Hist _ _ ops lam)) = lamUsesAD lam || any (lamUsesAD . histOp) ops expUsesAD (Op (Scatter _ _ _ lam)) = lamUsesAD lam expUsesAD (Match _ cases def_case _) = any (bodyUsesAD . 
caseBody) cases || bodyUsesAD def_case expUsesAD (Loop _ _ body) = bodyUsesAD body expUsesAD (WithAcc _ lam) = lamUsesAD lam expUsesAD BasicOp {} = False expUsesAD Apply {} = False futhark-0.25.27/src/Futhark/IR/SOACS/000077500000000000000000000000001475065116200167175ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/IR/SOACS/SOAC.hs000066400000000000000000001121061475065116200200010ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} -- | Definition of /Second-Order Array Combinators/ (SOACs), which are -- the main form of parallelism in the early stages of the compiler. module Futhark.IR.SOACS.SOAC ( SOAC (..), ScremaForm (..), ScatterSpec, HistOp (..), Scan (..), scanResults, singleScan, Reduce (..), redResults, singleReduce, -- * Utility scremaType, soacType, typeCheckSOAC, mkIdentityLambda, isIdentityLambda, nilFn, scanomapSOAC, redomapSOAC, scanSOAC, reduceSOAC, mapSOAC, isScanomapSOAC, isRedomapSOAC, isScanSOAC, isReduceSOAC, isMapSOAC, ppScrema, ppHist, ppStream, ppScatter, groupScatterResults, groupScatterResults', splitScatterResults, -- * Generic traversal SOACMapper (..), identitySOACMapper, mapSOACM, traverseSOACStms, ) where import Control.Category import Control.Monad import Control.Monad.Identity import Control.Monad.State.Strict import Control.Monad.Writer import Data.Function ((&)) import Data.List (intersperse) import Data.Map.Strict qualified as M import Data.Maybe import Futhark.Analysis.Alias qualified as Alias import Futhark.Analysis.DataDependencies import Futhark.Analysis.Metrics import Futhark.Analysis.PrimExp.Convert import Futhark.Analysis.SymbolTable qualified as ST import Futhark.Construct import Futhark.IR import Futhark.IR.Aliases (Aliases, CanBeAliased (..)) import Futhark.IR.Prop.Aliases import Futhark.IR.TypeCheck qualified as TC import Futhark.Optimise.Simplify.Rep import Futhark.Transform.Rename import Futhark.Transform.Substitute import Futhark.Util (chunks, maybeNth, splitAt3) 
import Futhark.Util.Pretty (Doc, align, comma, commasep, docText, parens, ppTuple', pretty, (<+>), ()) import Futhark.Util.Pretty qualified as PP import Prelude hiding (id, (.)) -- | How the results of a scatter operation should be written. Each -- element of the list consists of a @v@ (often a `VName`) specifying -- which array to scatter to, a `Shape` describing the shape of that -- array, and an `Int` describing how many elements should be written -- to that array for each invocation of the scatter. type ScatterSpec v = [(Shape, Int, v)] -- | A second-order array combinator (SOAC). data SOAC rep = Stream SubExp [VName] [SubExp] (Lambda rep) | -- | @Scatter @ -- -- Scatter maps values from a set of input arrays to indices and values of a -- set of output arrays. It is able to write multiple values to multiple -- outputs each of which may have multiple dimensions. -- -- is a list of input arrays, all having size , elements of -- which are applied to the function. For instance, if there are -- two arrays, will get two values as input, one from each array. -- -- specifies the result of the and which arrays to -- write to. -- -- is a function that takes inputs from and -- returns values according to . It returns values in the -- following manner: -- -- [index_0, index_1, ..., index_n, value_0, value_1, ..., value_m] -- -- For each output in , returns * index -- values and output values, where is the number of -- dimensions (rank) of the given output, and is the number of -- output values written to the given output. -- -- For example, given the following scatter specification: -- -- [([x1, y1, z1], 2, arr1), ([x2, y2], 1, arr2)] -- -- will produce 6 (3 * 2) index values and 2 output values for -- , and 2 (2 * 1) index values and 1 output value for -- arr2. Additionally, the results are grouped, so the first 6 index values -- will correspond to the first two output values, and so on. 
For this -- example, should return a total of 11 values, 8 index values and -- 3 output values. See also 'splitScatterResults'. Scatter SubExp [VName] (ScatterSpec VName) (Lambda rep) | -- | @Hist @ -- -- The final lambda produces indexes and values for the 'HistOp's. Hist SubExp [VName] [HistOp rep] (Lambda rep) | -- FIXME: this should not be here JVP [SubExp] [SubExp] (Lambda rep) | -- FIXME: this should not be here VJP [SubExp] [SubExp] (Lambda rep) | -- | A combination of scan, reduction, and map. The first -- t'SubExp' is the size of the input arrays. Screma SubExp [VName] (ScremaForm rep) deriving (Eq, Ord, Show) -- | Information about computing a single histogram. data HistOp rep = HistOp { histShape :: Shape, -- | Race factor @RF@ means that only @1/RF@ -- bins are used. histRaceFactor :: SubExp, histDest :: [VName], histNeutral :: [SubExp], histOp :: Lambda rep } deriving (Eq, Ord, Show) -- | The essential parts of a 'Screma' factored out (everything -- except the input arrays). data ScremaForm rep = ScremaForm { -- | The "main" lambda of the Screma. For a map, this is -- equivalent to 'isMapSOAC'. Note that the meaning of the return -- value of this lambda depends crucially on exactly which Screma -- this is. The parameters will correspond exactly to elements of -- the input arrays, however. scremaLambda :: Lambda rep, scremaScans :: [Scan rep], scremaReduces :: [Reduce rep] } deriving (Eq, Ord, Show) singleBinOp :: (Buildable rep) => [Lambda rep] -> Lambda rep singleBinOp lams = Lambda { lambdaParams = concatMap xParams lams ++ concatMap yParams lams, lambdaReturnType = concatMap lambdaReturnType lams, lambdaBody = mkBody (mconcat (map (bodyStms . lambdaBody) lams)) (concatMap (bodyResult . lambdaBody) lams) } where xParams lam = take (length (lambdaReturnType lam)) (lambdaParams lam) yParams lam = drop (length (lambdaReturnType lam)) (lambdaParams lam) -- | How to compute a single scan result. 
data Scan rep = Scan { scanLambda :: Lambda rep, scanNeutral :: [SubExp] } deriving (Eq, Ord, Show) -- | What are the sizes of reduction results produced by these 'Scan's? scanSizes :: [Scan rep] -> [Int] scanSizes = map (length . scanNeutral) -- | How many reduction results are produced by these 'Scan's? scanResults :: [Scan rep] -> Int scanResults = sum . scanSizes -- | Combine multiple scan operators to a single operator. singleScan :: (Buildable rep) => [Scan rep] -> Scan rep singleScan scans = let scan_nes = concatMap scanNeutral scans scan_lam = singleBinOp $ map scanLambda scans in Scan scan_lam scan_nes -- | How to compute a single reduction result. data Reduce rep = Reduce { redComm :: Commutativity, redLambda :: Lambda rep, redNeutral :: [SubExp] } deriving (Eq, Ord, Show) -- | What are the sizes of reduction results produced by these 'Reduce's? redSizes :: [Reduce rep] -> [Int] redSizes = map (length . redNeutral) -- | How many reduction results are produced by these 'Reduce's? redResults :: [Reduce rep] -> Int redResults = sum . redSizes -- | Combine multiple reduction operators to a single operator. singleReduce :: (Buildable rep) => [Reduce rep] -> Reduce rep singleReduce reds = let red_nes = concatMap redNeutral reds red_lam = singleBinOp $ map redLambda reds in Reduce (mconcat (map redComm reds)) red_lam red_nes -- | The types produced by a single 'Screma', given the size of the -- input array. scremaType :: SubExp -> ScremaForm rep -> [Type] scremaType w (ScremaForm map_lam scans reds) = scan_tps ++ red_tps ++ map (`arrayOfRow` w) map_tps where scan_tps = map (`arrayOfRow` w) $ concatMap (lambdaReturnType . scanLambda) scans red_tps = concatMap (lambdaReturnType . redLambda) reds map_tps = drop (length scan_tps + length red_tps) $ lambdaReturnType map_lam -- | Construct a lambda that takes parameters of the given types and -- simply returns them unchanged. 
mkIdentityLambda :: (Buildable rep, MonadFreshNames m) => [Type] -> m (Lambda rep) mkIdentityLambda ts = do params <- mapM (newParam "x") ts pure Lambda { lambdaParams = params, lambdaBody = mkBody mempty $ varsRes $ map paramName params, lambdaReturnType = ts } -- | Is the given lambda an identity lambda? isIdentityLambda :: Lambda rep -> Bool isIdentityLambda lam = map resSubExp (bodyResult (lambdaBody lam)) == map (Var . paramName) (lambdaParams lam) -- | A lambda with no parameters that returns no values. nilFn :: (Buildable rep) => Lambda rep nilFn = Lambda mempty mempty (mkBody mempty mempty) -- | Construct a Screma with possibly multiple scans, and -- the given map function. scanomapSOAC :: [Scan rep] -> Lambda rep -> ScremaForm rep scanomapSOAC scans lam = ScremaForm lam scans [] -- | Construct a Screma with possibly multiple reductions, and -- the given map function. redomapSOAC :: [Reduce rep] -> Lambda rep -> ScremaForm rep redomapSOAC reds lam = ScremaForm lam [] reds -- | Construct a Screma with possibly multiple scans, and identity map -- function. scanSOAC :: (Buildable rep, MonadFreshNames m) => [Scan rep] -> m (ScremaForm rep) scanSOAC scans = scanomapSOAC scans <$> mkIdentityLambda ts where ts = concatMap (lambdaReturnType . scanLambda) scans -- | Construct a Screma with possibly multiple reductions, and -- identity map function. reduceSOAC :: (Buildable rep, MonadFreshNames m) => [Reduce rep] -> m (ScremaForm rep) reduceSOAC reds = redomapSOAC reds <$> mkIdentityLambda ts where ts = concatMap (lambdaReturnType . redLambda) reds -- | Construct a Screma corresponding to a map. mapSOAC :: Lambda rep -> ScremaForm rep mapSOAC lam = ScremaForm lam [] [] -- | Does this Screma correspond to a scan-map composition? isScanomapSOAC :: ScremaForm rep -> Maybe ([Scan rep], Lambda rep) isScanomapSOAC (ScremaForm map_lam scans reds) = do guard $ null reds guard $ not $ null scans pure (scans, map_lam) -- | Does this Screma correspond to pure scan? 
isScanSOAC :: ScremaForm rep -> Maybe [Scan rep] isScanSOAC form = do (scans, map_lam) <- isScanomapSOAC form guard $ isIdentityLambda map_lam pure scans -- | Does this Screma correspond to a reduce-map composition? isRedomapSOAC :: ScremaForm rep -> Maybe ([Reduce rep], Lambda rep) isRedomapSOAC (ScremaForm map_lam scans reds) = do guard $ null scans guard $ not $ null reds pure (reds, map_lam) -- | Does this Screma correspond to a pure reduce? isReduceSOAC :: ScremaForm rep -> Maybe [Reduce rep] isReduceSOAC form = do (reds, map_lam) <- isRedomapSOAC form guard $ isIdentityLambda map_lam pure reds -- | Does this Screma correspond to a simple map, without any -- reduction or scan results? isMapSOAC :: ScremaForm rep -> Maybe (Lambda rep) isMapSOAC (ScremaForm map_lam scans reds) = do guard $ null scans guard $ null reds pure map_lam -- | @splitScatterResults @ -- -- Splits the results array into indices and values according to the -- specification. -- -- See 'groupScatterResults' for more information. splitScatterResults :: [(Shape, Int, array)] -> [a] -> ([a], [a]) splitScatterResults output_spec results = let (shapes, ns, _) = unzip3 output_spec num_indices = sum $ zipWith (*) ns $ map length shapes in splitAt num_indices results -- | @groupScatterResults' @ -- -- Blocks the index values and result values of according to -- the specification. This is the simpler version of -- @groupScatterResults@, which doesn't return any information about -- shapes or output arrays. -- -- See 'groupScatterResults' for more information, groupScatterResults' :: [(Shape, Int, array)] -> [a] -> [([a], a)] groupScatterResults' output_spec results = let (indices, values) = splitScatterResults output_spec results (shapes, ns, _) = unzip3 output_spec chunk_sizes = concat $ zipWith (\shp n -> replicate n $ length shp) shapes ns in zip (chunks chunk_sizes indices) values -- | @groupScatterResults @ -- -- Blocks the index values and result values of according to the -- . 
-- -- This function is used for extracting and grouping the results of a -- scatter. In the SOACS representation, the lambda inside a 'Scatter' returns -- all indices and values as one big list. This function groups each value with -- its corresponding indices (as determined by the t'Shape' of the output array). -- -- The elements of the resulting list correspond to the shape and name of the -- output parameters, in addition to a list of values written to that output -- parameter, along with the array indices marking where to write them to. -- -- See 'Scatter' for more information. groupScatterResults :: ScatterSpec array -> [a] -> [(Shape, array, [([a], a)])] groupScatterResults output_spec results = let (shapes, ns, arrays) = unzip3 output_spec in groupScatterResults' output_spec results & chunks ns & zip3 shapes arrays -- | Like 'Mapper', but just for 'SOAC's. data SOACMapper frep trep m = SOACMapper { mapOnSOACSubExp :: SubExp -> m SubExp, mapOnSOACLambda :: Lambda frep -> m (Lambda trep), mapOnSOACVName :: VName -> m VName } -- | A mapper that simply returns the SOAC verbatim. identitySOACMapper :: forall rep m. (Monad m) => SOACMapper rep rep m identitySOACMapper = SOACMapper { mapOnSOACSubExp = pure, mapOnSOACLambda = pure, mapOnSOACVName = pure } -- | Map a monadic action across the immediate children of a -- SOAC. The mapping does not descend recursively into subexpressions -- and is done left-to-right. 
mapSOACM :: (Monad m) => SOACMapper frep trep m -> SOAC frep -> m (SOAC trep) mapSOACM tv (JVP args vec lam) = JVP <$> mapM (mapOnSOACSubExp tv) args <*> mapM (mapOnSOACSubExp tv) vec <*> mapOnSOACLambda tv lam mapSOACM tv (VJP args vec lam) = VJP <$> mapM (mapOnSOACSubExp tv) args <*> mapM (mapOnSOACSubExp tv) vec <*> mapOnSOACLambda tv lam mapSOACM tv (Stream size arrs accs lam) = Stream <$> mapOnSOACSubExp tv size <*> mapM (mapOnSOACVName tv) arrs <*> mapM (mapOnSOACSubExp tv) accs <*> mapOnSOACLambda tv lam mapSOACM tv (Scatter w ivs as lam) = Scatter <$> mapOnSOACSubExp tv w <*> mapM (mapOnSOACVName tv) ivs <*> mapM ( \(aw, an, a) -> (,,) <$> mapM (mapOnSOACSubExp tv) aw <*> pure an <*> mapOnSOACVName tv a ) as <*> mapOnSOACLambda tv lam mapSOACM tv (Hist w arrs ops bucket_fun) = Hist <$> mapOnSOACSubExp tv w <*> mapM (mapOnSOACVName tv) arrs <*> mapM ( \(HistOp shape rf op_arrs nes op) -> HistOp <$> mapM (mapOnSOACSubExp tv) shape <*> mapOnSOACSubExp tv rf <*> mapM (mapOnSOACVName tv) op_arrs <*> mapM (mapOnSOACSubExp tv) nes <*> mapOnSOACLambda tv op ) ops <*> mapOnSOACLambda tv bucket_fun mapSOACM tv (Screma w arrs (ScremaForm map_lam scans reds)) = Screma <$> mapOnSOACSubExp tv w <*> mapM (mapOnSOACVName tv) arrs <*> ( ScremaForm <$> mapOnSOACLambda tv map_lam <*> forM scans ( \(Scan red_lam red_nes) -> Scan <$> mapOnSOACLambda tv red_lam <*> mapM (mapOnSOACSubExp tv) red_nes ) <*> forM reds ( \(Reduce comm red_lam red_nes) -> Reduce comm <$> mapOnSOACLambda tv red_lam <*> mapM (mapOnSOACSubExp tv) red_nes ) ) -- | A helper for defining 'TraverseOpStms'. 
traverseSOACStms :: (Monad m) => OpStmsTraverser m (SOAC rep) rep traverseSOACStms f = mapSOACM mapper where mapper = identitySOACMapper {mapOnSOACLambda = traverseLambdaStms f} instance (ASTRep rep) => FreeIn (Scan rep) where freeIn' (Scan lam ne) = freeIn' lam <> freeIn' ne instance (ASTRep rep) => FreeIn (Reduce rep) where freeIn' (Reduce _ lam ne) = freeIn' lam <> freeIn' ne instance (ASTRep rep) => FreeIn (ScremaForm rep) where freeIn' (ScremaForm scans reds lam) = freeIn' scans <> freeIn' reds <> freeIn' lam instance (ASTRep rep) => FreeIn (HistOp rep) where freeIn' (HistOp w rf dests nes lam) = freeIn' w <> freeIn' rf <> freeIn' dests <> freeIn' nes <> freeIn' lam instance (ASTRep rep) => FreeIn (SOAC rep) where freeIn' = flip execState mempty . mapSOACM free where walk f x = modify (<> f x) >> pure x free = SOACMapper { mapOnSOACSubExp = walk freeIn', mapOnSOACLambda = walk freeIn', mapOnSOACVName = walk freeIn' } instance (ASTRep rep) => Substitute (SOAC rep) where substituteNames subst = runIdentity . mapSOACM substitute where substitute = SOACMapper { mapOnSOACSubExp = pure . substituteNames subst, mapOnSOACLambda = pure . substituteNames subst, mapOnSOACVName = pure . substituteNames subst } instance (ASTRep rep) => Rename (SOAC rep) where rename = mapSOACM renamer where renamer = SOACMapper rename rename rename -- | The type of a SOAC. soacType :: (Typed (LParamInfo rep)) => SOAC rep -> [Type] soacType (JVP _ _ lam) = lambdaReturnType lam ++ lambdaReturnType lam soacType (VJP _ _ lam) = lambdaReturnType lam ++ map paramType (lambdaParams lam) soacType (Stream outersize _ accs lam) = map (substNamesInType substs) rtp where nms = map paramName $ take (1 + length accs) params substs = M.fromList $ zip nms (outersize : accs) Lambda params rtp _ = lam soacType (Scatter _w _ivs dests lam) = zipWith arrayOfShape (map (snd . 
head) rets) shapes where (shapes, _, rets) = unzip3 $ groupScatterResults dests $ lambdaReturnType lam soacType (Hist _ _ ops _bucket_fun) = do op <- ops map (`arrayOfShape` histShape op) (lambdaReturnType $ histOp op) soacType (Screma w _arrs form) = scremaType w form instance TypedOp SOAC where opType = pure . staticShapes . soacType instance AliasedOp SOAC where opAliases = map (const mempty) . soacType consumedInOp JVP {} = mempty consumedInOp VJP {} = mempty -- Only map functions can consume anything. The operands to scan -- and reduce functions are always considered "fresh". consumedInOp (Screma _ arrs (ScremaForm map_lam _ _)) = mapNames consumedArray $ consumedByLambda map_lam where consumedArray v = fromMaybe v $ lookup v params_to_arrs params_to_arrs = zip (map paramName $ lambdaParams map_lam) arrs consumedInOp (Stream _ arrs accs lam) = namesFromList $ subExpVars $ map consumedArray $ namesToList $ consumedByLambda lam where consumedArray v = fromMaybe (Var v) $ lookup v paramsToInput -- Drop the chunk parameter, which cannot alias anything. 
paramsToInput = zip (map paramName $ drop 1 $ lambdaParams lam) (accs ++ map Var arrs) consumedInOp (Scatter _ _ spec _) = namesFromList $ map (\(_, _, a) -> a) spec consumedInOp (Hist _ _ ops _) = namesFromList $ concatMap histDest ops mapHistOp :: (Lambda frep -> Lambda trep) -> HistOp frep -> HistOp trep mapHistOp f (HistOp w rf dests nes lam) = HistOp w rf dests nes $ f lam instance CanBeAliased SOAC where addOpAliases aliases (JVP args vec lam) = JVP args vec (Alias.analyseLambda aliases lam) addOpAliases aliases (VJP args vec lam) = VJP args vec (Alias.analyseLambda aliases lam) addOpAliases aliases (Stream size arr accs lam) = Stream size arr accs $ Alias.analyseLambda aliases lam addOpAliases aliases (Scatter len arrs dests lam) = Scatter len arrs dests (Alias.analyseLambda aliases lam) addOpAliases aliases (Hist w arrs ops bucket_fun) = Hist w arrs (map (mapHistOp (Alias.analyseLambda aliases)) ops) (Alias.analyseLambda aliases bucket_fun) addOpAliases aliases (Screma w arrs (ScremaForm map_lam scans reds)) = Screma w arrs $ ScremaForm (Alias.analyseLambda aliases map_lam) (map onScan scans) (map onRed reds) where onRed red = red {redLambda = Alias.analyseLambda aliases $ redLambda red} onScan scan = scan {scanLambda = Alias.analyseLambda aliases $ scanLambda scan} instance IsOp SOAC where safeOp _ = False cheapOp _ = False opDependencies (Stream w arrs accs lam) = let accs_deps = map depsOf' accs arrs_deps = depsOfArrays w arrs in lambdaDependencies mempty lam (arrs_deps <> accs_deps) opDependencies (Hist w arrs ops lam) = let bucket_fun_deps' = lambdaDependencies mempty lam (depsOfArrays w arrs) -- Bucket function results are indices followed by values. -- Reshape this to align with list of histogram operations. ranks = map (shapeRank . histShape) ops value_lengths = map (length . 
histNeutral) ops (indices, values) = splitAt (sum ranks) bucket_fun_deps' bucket_fun_deps = zipWith concatIndicesToEachValue (chunks ranks indices) (chunks value_lengths values) in mconcat $ zipWith (zipWith (<>)) bucket_fun_deps (map depsOfHistOp ops) where depsOfHistOp (HistOp dest_shape rf dests nes op) = let shape_deps = depsOfShape dest_shape in_deps = map (\vn -> oneName vn <> shape_deps <> depsOf' rf) dests in reductionDependencies mempty op nes in_deps -- A histogram operation may use the same index for multiple values. concatIndicesToEachValue is vs = let is_flat = mconcat is in map (is_flat <>) vs opDependencies (Scatter w arrs outputs lam) = let deps = lambdaDependencies mempty lam (depsOfArrays w arrs) in map flattenBlocks (groupScatterResults outputs deps) where flattenBlocks (_, arr, ivs) = oneName arr <> mconcat (map (mconcat . fst) ivs) <> mconcat (map snd ivs) opDependencies (JVP args vec lam) = mconcat $ replicate 2 $ lambdaDependencies mempty lam $ zipWith (<>) (map depsOf' args) (map depsOf' vec) opDependencies (VJP args vec lam) = lambdaDependencies mempty lam (zipWith (<>) (map depsOf' args) (map depsOf' vec)) <> map (const $ freeIn args <> freeIn lam) (lambdaParams lam) opDependencies (Screma w arrs (ScremaForm map_lam scans reds)) = let (scans_in, reds_in, map_deps) = splitAt3 (scanResults scans) (redResults reds) $ lambdaDependencies mempty map_lam (depsOfArrays w arrs) scans_deps = concatMap depsOfScan (zip scans $ chunks (scanSizes scans) scans_in) reds_deps = concatMap depsOfRed (zip reds $ chunks (redSizes reds) reds_in) in scans_deps <> reds_deps <> map_deps where depsOfScan (Scan lam nes, deps_in) = reductionDependencies mempty lam nes deps_in depsOfRed (Reduce _ lam nes, deps_in) = reductionDependencies mempty lam nes deps_in substNamesInType :: M.Map VName SubExp -> Type -> Type substNamesInType _ t@Prim {} = t substNamesInType _ t@Acc {} = t substNamesInType _ (Mem space) = Mem space substNamesInType subs (Array btp shp u) = let 
shp' = Shape $ map (substNamesInSubExp subs) (shapeDims shp) in Array btp shp' u substNamesInSubExp :: M.Map VName SubExp -> SubExp -> SubExp substNamesInSubExp _ e@(Constant _) = e substNamesInSubExp subs (Var idd) = M.findWithDefault (Var idd) idd subs instance CanBeWise SOAC where addOpWisdom = runIdentity . mapSOACM (SOACMapper pure (pure . informLambda) pure) instance (RepTypes rep) => ST.IndexOp (SOAC rep) where indexOp vtable k soac [i] = do (lam, se, arr_params, arrs) <- lambdaAndSubExp soac let arr_indexes = M.fromList $ catMaybes $ zipWith arrIndex arr_params arrs arr_indexes' = foldl expandPrimExpTable arr_indexes $ bodyStms $ lambdaBody lam case se of SubExpRes _ (Var v) -> uncurry (flip ST.Indexed) <$> M.lookup v arr_indexes' _ -> Nothing where lambdaAndSubExp (Screma _ arrs (ScremaForm map_lam scans reds)) = nthMapOut (scanResults scans + redResults reds) map_lam arrs lambdaAndSubExp _ = Nothing nthMapOut num_accs lam arrs = do se <- maybeNth (num_accs + k) $ bodyResult $ lambdaBody lam pure (lam, se, drop num_accs $ lambdaParams lam, arrs) arrIndex p arr = do ST.Indexed cs pe <- ST.index' arr [i] vtable pure (paramName p, (pe, cs)) expandPrimExpTable table stm | [v] <- patNames $ stmPat stm, Just (pe, cs) <- runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm, all (`ST.elem` vtable) (unCerts $ stmCerts stm) = M.insert v (pe, stmCerts stm <> cs) table | otherwise = table asPrimExp table v | Just (e, cs) <- M.lookup v table = tell cs >> pure e | Just (Prim pt) <- ST.lookupType v vtable = pure $ LeafExp v pt | otherwise = lift Nothing indexOp _ _ _ _ = Nothing -- | Type-check a SOAC. typeCheckSOAC :: (TC.Checkable rep) => SOAC (Aliases rep) -> TC.TypeM rep () typeCheckSOAC (VJP args vec lam) = do args' <- mapM TC.checkArg args TC.checkLambda lam $ map TC.noArgAliases args' vec_ts <- mapM TC.checkSubExp vec unless (vec_ts == lambdaReturnType lam) $ TC.bad . TC.TypeError . 
docText $ "Return type" PP.indent 2 (pretty (lambdaReturnType lam)) "does not match type of seed vector" PP.indent 2 (pretty vec_ts) typeCheckSOAC (JVP args vec lam) = do args' <- mapM TC.checkArg args TC.checkLambda lam $ map TC.noArgAliases args' vec_ts <- mapM TC.checkSubExp vec unless (vec_ts == map TC.argType args') $ TC.bad . TC.TypeError . docText $ "Parameter type" PP.indent 2 (pretty $ map TC.argType args') "does not match type of seed vector" PP.indent 2 (pretty vec_ts) typeCheckSOAC (Stream size arrexps accexps lam) = do TC.require [Prim int64] size accargs <- mapM TC.checkArg accexps arrargs <- mapM lookupType arrexps _ <- TC.checkSOACArrayArgs size arrexps chunk <- case lambdaParams lam of chunk : _ -> pure chunk [] -> TC.bad $ TC.TypeError "Stream lambda without parameters." let asArg t = (t, mempty) inttp = Prim int64 lamarrs' = map (`setOuterSize` Var (paramName chunk)) arrargs acc_len = length accexps lamrtp = take acc_len $ lambdaReturnType lam unless (map TC.argType accargs == lamrtp) $ TC.bad . TC.TypeError $ "Stream with inconsistent accumulator type in lambda." -- just get the dflow of lambda on the fakearg, which does not alias -- arr, so we can later check that aliases of arr are not used inside lam. let fake_lamarrs' = map asArg lamarrs' TC.checkLambda lam $ asArg inttp : accargs ++ fake_lamarrs' typeCheckSOAC (Scatter w arrs as lam) = do -- Requirements: -- -- 0. @lambdaReturnType@ of @lam@ must be a list -- [index types..., value types, ...]. -- -- 1. The number of index types and value types must be equal to the number -- of return values from @lam@. -- -- 2. Each index type must have the type i64. -- -- 3. Each array in @as@ and the value types must have the same type -- -- 4. Each array in @as@ is consumed. This is not really a check, but more -- of a requirement, so that e.g. the source is not hoisted out of a -- loop, which will mean it cannot be consumed. -- -- 5. 
Each of arrs must be an array matching a corresponding lambda -- parameters. -- -- Code: -- First check the input size. TC.require [Prim int64] w -- 0. let (as_ws, as_ns, _as_vs) = unzip3 as indexes = sum $ zipWith (*) as_ns $ map length as_ws rts = lambdaReturnType lam rtsI = take indexes rts rtsV = drop indexes rts -- 1. unless (length rts == sum as_ns + sum (zipWith (*) as_ns $ map length as_ws)) $ TC.bad $ TC.TypeError "Scatter: number of index types, value types and array outputs do not match." -- 2. forM_ rtsI $ \rtI -> unless (Prim int64 == rtI) $ TC.bad $ TC.TypeError "Scatter: Index return type must be i64." forM_ (zip (chunks as_ns rtsV) as) $ \(rtVs, (aw, _, a)) -> do -- All lengths must have type i64. mapM_ (TC.require [Prim int64]) aw -- 3. forM_ rtVs $ \rtV -> TC.requireI [arrayOfShape rtV aw] a -- 4. TC.consume =<< TC.lookupAliases a -- 5. arrargs <- TC.checkSOACArrayArgs w arrs TC.checkLambda lam arrargs typeCheckSOAC (Hist w arrs ops bucket_fun) = do TC.require [Prim int64] w -- Check the operators. forM_ ops $ \(HistOp dest_shape rf dests nes op) -> do nes' <- mapM TC.checkArg nes mapM_ (TC.require [Prim int64]) dest_shape TC.require [Prim int64] rf -- Operator type must match the type of neutral elements. TC.checkLambda op $ map TC.noArgAliases $ nes' ++ nes' let nes_t = map TC.argType nes' unless (nes_t == lambdaReturnType op) $ TC.bad . TC.TypeError $ "Operator has return type " <> prettyTuple (lambdaReturnType op) <> " but neutral element has type " <> prettyTuple nes_t -- Arrays must have proper type. forM_ (zip nes_t dests) $ \(t, dest) -> do TC.requireI [t `arrayOfShape` dest_shape] dest TC.consume =<< TC.lookupAliases dest -- Types of input arrays must equal parameter types for bucket function. img' <- TC.checkSOACArrayArgs w arrs TC.checkLambda bucket_fun img' -- Return type of bucket function must be an index for each -- operation followed by the values to write. nes_ts <- concat <$> mapM (mapM subExpType . 
histNeutral) ops let bucket_ret_t = concatMap ((`replicate` Prim int64) . shapeRank . histShape) ops ++ nes_ts unless (bucket_ret_t == lambdaReturnType bucket_fun) $ TC.bad . TC.TypeError $ "Bucket function has return type " <> prettyTuple (lambdaReturnType bucket_fun) <> " but should have type " <> prettyTuple bucket_ret_t typeCheckSOAC (Screma w arrs (ScremaForm map_lam scans reds)) = do TC.require [Prim int64] w arrs' <- TC.checkSOACArrayArgs w arrs TC.checkLambda map_lam arrs' scan_nes' <- fmap concat $ forM scans $ \(Scan scan_lam scan_nes) -> do scan_nes' <- mapM TC.checkArg scan_nes let scan_t = map TC.argType scan_nes' TC.checkLambda scan_lam $ map TC.noArgAliases $ scan_nes' ++ scan_nes' unless (scan_t == lambdaReturnType scan_lam) $ TC.bad . TC.TypeError $ "Scan function returns type " <> prettyTuple (lambdaReturnType scan_lam) <> " but neutral element has type " <> prettyTuple scan_t pure scan_nes' red_nes' <- fmap concat $ forM reds $ \(Reduce _ red_lam red_nes) -> do red_nes' <- mapM TC.checkArg red_nes let red_t = map TC.argType red_nes' TC.checkLambda red_lam $ map TC.noArgAliases $ red_nes' ++ red_nes' unless (red_t == lambdaReturnType red_lam) $ TC.bad . TC.TypeError $ "Reduce function returns type " <> prettyTuple (lambdaReturnType red_lam) <> " but neutral element has type " <> prettyTuple red_t pure red_nes' let map_lam_ts = lambdaReturnType map_lam unless ( take (length scan_nes' + length red_nes') map_lam_ts == map TC.argType (scan_nes' ++ red_nes') ) . TC.bad . TC.TypeError $ "Map function return type " <> prettyTuple map_lam_ts <> " wrong for given scan and reduction functions." 
-- Rephrasing a SOAC passes each nested lambda through 'rephraseLambda';
-- every other operand (sizes, arrays, neutral elements, ...) is copied
-- unchanged.
instance RephraseOp SOAC where
  rephraseInOp r (VJP args vec lam) =
    VJP args vec <$> rephraseLambda r lam
  rephraseInOp r (JVP args vec lam) =
    JVP args vec <$> rephraseLambda r lam
  rephraseInOp r (Stream w arrs acc lam) =
    Stream w arrs acc <$> rephraseLambda r lam
  rephraseInOp r (Scatter w arrs dests lam) =
    Scatter w arrs dests <$> rephraseLambda r lam
  rephraseInOp r (Hist w arrs ops lam) =
    Hist w arrs <$> mapM onOp ops <*> rephraseLambda r lam
    where
      -- Only the operator lambda of a histogram operation is rephrased.
      onOp (HistOp dest_shape rf dests nes op) =
        HistOp dest_shape rf dests nes <$> rephraseLambda r op
  rephraseInOp r (Screma w arrs (ScremaForm lam scans red)) =
    Screma w arrs
      <$> ( ScremaForm
              <$> rephraseLambda r lam
              <*> mapM onScan scans
              <*> mapM onRed red
          )
    where
      onScan (Scan op nes) = Scan <$> rephraseLambda r op <*> pure nes
      onRed (Reduce comm op nes) = Reduce comm <$> rephraseLambda r op <*> pure nes

-- Metrics of a SOAC are the metrics of its nested lambdas, collected
-- under 'inside' with the name of the SOAC constructor.
instance (OpMetrics (Op rep)) => OpMetrics (SOAC rep) where
  opMetrics (VJP _ _ lam) =
    inside "VJP" $ lambdaMetrics lam
  opMetrics (JVP _ _ lam) =
    inside "JVP" $ lambdaMetrics lam
  opMetrics (Stream _ _ _ lam) =
    inside "Stream" $ lambdaMetrics lam
  opMetrics (Scatter _len _ _ lam) =
    inside "Scatter" $ lambdaMetrics lam
  opMetrics (Hist _ _ ops bucket_fun) =
    inside "Hist" $ mapM_ (lambdaMetrics . histOp) ops >> lambdaMetrics bucket_fun
  opMetrics (Screma _ _ (ScremaForm map_lam scans reds)) =
    inside "Screma" $ do
      lambdaMetrics map_lam
      mapM_ (lambdaMetrics . scanLambda) scans
      mapM_ (lambdaMetrics . redLambda) reds

-- Components of a SOAC are separated by a comma followed by 'PP.</>'
-- (a line break), so long operands can wrap.  A 'Screma' with neither
-- scans nor reductions is printed as @map@, one with only reductions as
-- @redomap@, one with only scans as @scanomap@; any other 'Screma' uses
-- the general form produced by 'ppScrema'.
instance (PrettyRep rep) => PP.Pretty (SOAC rep) where
  pretty (VJP args vec lam) =
    "vjp"
      <> parens
        ( PP.align $
            PP.braces (commasep $ map pretty args)
              <> comma
              PP.</> PP.braces (commasep $ map pretty vec)
              <> comma
              PP.</> pretty lam
        )
  pretty (JVP args vec lam) =
    "jvp"
      <> parens
        ( PP.align $
            PP.braces (commasep $ map pretty args)
              <> comma
              PP.</> PP.braces (commasep $ map pretty vec)
              <> comma
              PP.</> pretty lam
        )
  pretty (Stream size arrs acc lam) =
    ppStream size arrs acc lam
  pretty (Scatter w arrs dests lam) =
    ppScatter w arrs dests lam
  pretty (Hist w arrs ops bucket_fun) =
    ppHist w arrs ops bucket_fun
  pretty (Screma w arrs (ScremaForm map_lam scans reds))
    | null scans,
      null reds =
        "map"
          <> (parens . align)
            ( pretty w
                <> comma
                PP.</> ppTuple' (map pretty arrs)
                <> comma
                PP.</> pretty map_lam
            )
    | null scans =
        "redomap"
          <> (parens . align)
            ( pretty w
                <> comma
                PP.</> ppTuple' (map pretty arrs)
                <> comma
                PP.</> pretty map_lam
                <> comma
                PP.</> PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty reds)
            )
    | null reds =
        "scanomap"
          <> (parens . align)
            ( pretty w
                <> comma
                PP.</> ppTuple' (map pretty arrs)
                <> comma
                PP.</> pretty map_lam
                <> comma
                PP.</> PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty scans)
            )
  -- Both scans and reductions are present: use the general form.
  pretty (Screma w arrs form) = ppScrema w arrs form

-- | Prettyprint the given Screma in its general @screma@ form: width,
-- input arrays, map lambda, then the braced scans and reductions.
ppScrema ::
  (PrettyRep rep, Pretty inp) => SubExp -> [inp] -> ScremaForm rep -> Doc ann
ppScrema w arrs (ScremaForm map_lam scans reds) =
  "screma"
    <> (parens . align)
      ( pretty w
          <> comma
          PP.</> ppTuple' (map pretty arrs)
          <> comma
          PP.</> pretty map_lam
          <> comma
          PP.</> PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty scans)
          <> comma
          PP.</> PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty reds)
      )

-- | Prettyprint the given Stream as @streamSeq@: size, input arrays,
-- accumulators, then the lambda.
ppStream ::
  (PrettyRep rep, Pretty inp) => SubExp -> [inp] -> [SubExp] -> Lambda rep -> Doc ann
ppStream size arrs acc lam =
  "streamSeq"
    <> (parens . align)
      ( pretty size
          <> comma
          PP.</> ppTuple' (map pretty arrs)
          <> comma
          PP.</> ppTuple' (map pretty acc)
          <> comma
          PP.</> pretty lam
      )

-- | Prettyprint the given Scatter.
ppScatter ::
  (PrettyRep rep, Pretty inp) =>
  SubExp ->
  [inp] ->
  [(Shape, Int, VName)] ->
  Lambda rep ->
  Doc ann
ppScatter w arrs dests lam =
  -- Width, input arrays, destination triples, then the lambda, each
  -- separated by a comma and a line break.
  "scatter"
    <> (parens . align)
      ( pretty w
          <> comma
          PP.</> ppTuple' (map pretty arrs)
          <> comma
          PP.</> commasep (map pretty dests)
          <> comma
          PP.</> pretty lam
      )

-- A scan prints as its operator lambda followed by its braced neutral
-- elements.
instance (PrettyRep rep) => Pretty (Scan rep) where
  pretty (Scan scan_lam scan_nes) =
    pretty scan_lam
      <> comma
      PP.</> PP.braces (commasep $ map pretty scan_nes)

-- | Prefix printed before a reduction operator: nothing for
-- 'Noncommutative', @"commutative "@ for 'Commutative'.
ppComm :: Commutativity -> Doc ann
ppComm Noncommutative = mempty
ppComm Commutative = "commutative "

-- A reduction prints as its commutativity prefix, its operator lambda,
-- and its braced neutral elements.
instance (PrettyRep rep) => Pretty (Reduce rep) where
  pretty (Reduce comm red_lam red_nes) =
    ppComm comm
      <> pretty red_lam
      <> comma
      PP.</> PP.braces (commasep $ map pretty red_nes)

-- | Prettyprint the given histogram operation.
ppHist ::
  (PrettyRep rep, Pretty inp) =>
  SubExp ->
  [inp] ->
  [HistOp rep] ->
  Lambda rep ->
  Doc ann
ppHist w arrs ops bucket_fun =
  "hist"
    <> parens
      ( pretty w
          <> comma
          PP.</> ppTuple' (map pretty arrs)
          <> comma
          PP.</> PP.braces (mconcat $ intersperse (comma <> PP.line) $ map ppOp ops)
          <> comma
          PP.</> pretty bucket_fun
      )
  where
    -- The destination shape, @rf@ and the destination arrays are joined
    -- with spaces ('<+>'); the neutral elements and the operator lambda
    -- follow after line breaks.
    ppOp (HistOp dest_w rf dests nes op) =
      pretty dest_w
        <> comma
        <+> pretty rf
        <> comma
        <+> PP.braces (commasep $ map pretty dests)
        <> comma
        PP.</> ppTuple' (map pretty nes)
        <> comma
        PP.</> pretty op
futhark-0.25.27/src/Futhark/IR/SOACS/Simplify.hs000066400000000000000000001125201475065116200210500ustar0000000000000000
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

-- | Simplification entry points and rewrite rules for the 'SOACS'
-- representation.
module Futhark.IR.SOACS.Simplify
  ( simplifySOACS,
    simplifyLambda,
    simplifyFun,
    simplifyStms,
    simplifyConsts,
    simpleSOACS,
    simplifySOAC,
    soacRules,
    HasSOAC (..),
    simplifyKnownIterationSOAC,
    removeReplicateMapping,
    removeUnusedSOACInput,
    liftIdentityMapping,
    simplifyMapIota,
    SOACS,
  )
where

import Control.Monad
import Control.Monad.Identity
import Control.Monad.State
import Control.Monad.Writer
import Data.Either
import Data.Foldable
import Data.List (partition, transpose, unzip6, zip6)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.Analysis.DataDependencies
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify qualified as Simplify
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules
import Futhark.Optimise.Simplify.Rules.ClosedForm
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util

-- | Simplifier operations for 'SOACS', built from 'simplifySOAC'.
simpleSOACS :: Simplify.SimpleOps SOACS
simpleSOACS = Simplify.bindableSimpleOps simplifySOAC

-- | Simplify a whole program with 'soacRules'.
simplifySOACS :: Prog SOACS -> PassM (Prog SOACS)
simplifySOACS =
  Simplify.simplifyProg simpleSOACS soacRules Engine.noExtraHoistBlockers

-- | Simplify a single function definition, given a symbol table for
-- its surrounding context.
simplifyFun ::
  (MonadFreshNames m) =>
  ST.SymbolTable (Wise SOACS) ->
  FunDef SOACS ->
  m (FunDef SOACS)
simplifyFun =
  Simplify.simplifyFun simpleSOACS soacRules Engine.noExtraHoistBlockers

-- | Simplify a lambda in the current scope.
simplifyLambda ::
  (HasScope SOACS m, MonadFreshNames m) =>
  Lambda SOACS ->
  m (Lambda SOACS)
simplifyLambda =
  Simplify.simplifyLambda simpleSOACS soacRules Engine.noExtraHoistBlockers

-- | Simplify a sequence of statements in the scope returned by
-- 'askScope'.
simplifyStms ::
  (HasScope SOACS m, MonadFreshNames m) =>
  Stms SOACS ->
  m (Stms SOACS)
simplifyStms stms = do
  scope <- askScope
  Simplify.simplifyStms
    simpleSOACS
    soacRules
    Engine.noExtraHoistBlockers
    scope
    stms

-- | Simplify statements starting from an empty scope (e.g. top-level
-- constants).
simplifyConsts ::
  (MonadFreshNames m) =>
  Stms SOACS ->
  m (Stms SOACS)
simplifyConsts =
  Simplify.simplifyStms simpleSOACS soacRules Engine.noExtraHoistBlockers mempty

-- | Simplify the operands and nested lambdas of a SOAC.  Returns the
-- simplified SOAC together with whatever was hoisted out of its
-- lambdas.  The map, stream, scatter and histogram lambdas are
-- simplified under 'Engine.enterLoop'; VJP/JVP lambdas and scan/reduce
-- operators are not.
simplifySOAC ::
  (Simplify.SimplifiableRep rep) =>
  Simplify.SimplifyOp rep (SOAC (Wise rep))
simplifySOAC (VJP arr vec lam) = do
  (lam', hoisted) <- Engine.simplifyLambda mempty lam
  arr' <- mapM Engine.simplify arr
  vec' <- mapM Engine.simplify vec
  pure (VJP arr' vec' lam', hoisted)
simplifySOAC (JVP arr vec lam) = do
  (lam', hoisted) <- Engine.simplifyLambda mempty lam
  arr' <- mapM Engine.simplify arr
  vec' <- mapM Engine.simplify vec
  pure (JVP arr' vec' lam', hoisted)
simplifySOAC (Stream outerdim arr nes lam) = do
  outerdim' <- Engine.simplify outerdim
  nes' <- mapM Engine.simplify nes
  arr' <- mapM Engine.simplify arr
  (lam', lam_hoisted) <- Engine.enterLoop $ Engine.simplifyLambda mempty lam
  pure (Stream outerdim' arr' nes' lam', lam_hoisted)
simplifySOAC (Scatter w ivs as lam) = do
  w' <- Engine.simplify w
  (lam', hoisted) <- Engine.enterLoop $ Engine.simplifyLambda mempty lam
  ivs' <- mapM Engine.simplify ivs
  as' <- mapM Engine.simplify as
  pure (Scatter w' ivs' as' lam', hoisted)
simplifySOAC (Hist w imgs ops bfun) = do
  w' <- Engine.simplify w
  (ops', hoisted) <- fmap unzip $
    forM ops $ \(HistOp dests_w rf dests nes op) -> do
      dests_w' <- Engine.simplify dests_w
      rf' <- Engine.simplify rf
      dests' <- Engine.simplify dests
      nes' <- mapM Engine.simplify nes
      (op', hoisted) <- Engine.enterLoop $ Engine.simplifyLambda mempty op
      pure (HistOp dests_w' rf' dests' nes' op', hoisted)
  imgs' <- mapM Engine.simplify imgs
  (bfun', bfun_hoisted) <- Engine.enterLoop $ Engine.simplifyLambda mempty bfun
  pure (Hist w' imgs' ops' bfun', mconcat hoisted <> bfun_hoisted)
simplifySOAC (Screma w arrs (ScremaForm map_lam scans reds)) = do
  (scans', scans_hoisted) <- fmap unzip $
    forM scans $ \(Scan lam nes) -> do
      (lam', hoisted) <- Engine.simplifyLambda mempty lam
      nes' <- Engine.simplify nes
      pure (Scan lam' nes', hoisted)

  (reds', reds_hoisted) <- fmap unzip $
    forM reds $ \(Reduce comm lam nes) -> do
      (lam', hoisted) <- Engine.simplifyLambda mempty lam
      nes' <- Engine.simplify nes
      pure (Reduce comm lam' nes', hoisted)

  (map_lam', map_lam_hoisted) <- Engine.enterLoop $ Engine.simplifyLambda mempty map_lam

  (,)
    <$> ( Screma
            <$> Engine.simplify w
            <*> Engine.simplify arrs
            <*> pure (ScremaForm map_lam' scans' reds')
        )
    <*> pure (mconcat scans_hoisted <> mconcat reds_hoisted <> map_lam_hoisted)

instance BuilderOps (Wise SOACS)

instance TraverseOpStms (Wise SOACS) where
  traverseOpStms = traverseSOACStms

-- | For each parameter paired with @Just x@, bind the parameter to @x@
-- at the start of the body and drop it from the parameter list.
-- Parameters paired with 'Nothing', or beyond the end of the list, are
-- kept.
fixLambdaParams ::
  (MonadBuilder m, Buildable (Rep m), BuilderOps (Rep m)) =>
  Lambda (Rep m) ->
  [Maybe SubExp] ->
  m (Lambda (Rep m))
fixLambdaParams lam fixes = do
  body <- runBodyBuilder $
    localScope (scopeOfLParams $ lambdaParams lam) $ do
      zipWithM_ maybeFix (lambdaParams lam) fixes'
      bodyBind $ lambdaBody lam
  pure
    lam
      { lambdaBody = body,
        lambdaParams = map fst $ filter (isNothing . snd) $ zip (lambdaParams lam) fixes'
      }
  where
    -- Pad with 'Nothing' so a short list leaves the remaining
    -- parameters alone.
    fixes' = fixes ++ repeat Nothing
    maybeFix p (Just x) = letBindNames [paramName p] $ BasicOp $ SubExp x
    maybeFix _ Nothing = pure ()

-- | Keep only the body results (and matching return types) whose entry
-- in the mask is 'True'.  Results beyond the end of the mask are kept.
removeLambdaResults :: [Bool] -> Lambda rep -> Lambda rep
removeLambdaResults keep lam =
  lam
    { lambdaBody = lam_body',
      lambdaReturnType = ret
    }
  where
    keep' :: [a] -> [a]
    keep' = map snd . filter fst . zip (keep ++ repeat True)
    lam_body = lambdaBody lam
    lam_body' = lam_body {bodyResult = keep' $ bodyResult lam_body}
    ret = keep' $ lambdaReturnType lam

-- | The standard rules extended with the SOAC-specific top-down and
-- bottom-up rules.
soacRules :: RuleBook (Wise SOACS)
soacRules = standardRules <> ruleBook topDownRules bottomUpRules

-- | Does this rep contain 'SOAC's in its t'Op's?  A rep must be an
-- instance of this class for the simplification rules to work.
class HasSOAC rep where
  -- | View an operation as a 'SOAC', if it is one.
  asSOAC :: Op rep -> Maybe (SOAC rep)

  -- | Embed a 'SOAC' as an operation of the representation.
  soacOp :: SOAC rep -> Op rep

-- In 'SOACS' every operation already is a 'SOAC', so both conversions
-- are trivial.
instance HasSOAC (Wise SOACS) where
  asSOAC = Just
  soacOp = id

-- SOAC-specific top-down rules.  'soacRules' combines these with
-- 'bottomUpRules' and 'standardRules'.
topDownRules :: [TopDownRule (Wise SOACS)]
topDownRules =
  [ RuleOp hoistCerts,
    RuleOp removeReplicateMapping,
    RuleOp removeReplicateWrite,
    RuleOp removeUnusedSOACInput,
    RuleOp simplifyClosedFormReduce,
    RuleOp simplifyKnownIterationSOAC,
    RuleOp liftIdentityMapping,
    RuleOp removeDuplicateMapOutput,
    RuleOp fuseConcatScatter,
    RuleOp simplifyMapIota,
    RuleOp moveTransformToInput
  ]

-- SOAC-specific bottom-up rules.
bottomUpRules :: [BottomUpRule (Wise SOACS)]
bottomUpRules =
  [ RuleOp removeDeadMapping,
    RuleOp removeDeadReduction,
    RuleOp removeDeadWrite,
    RuleBasicOp removeUnnecessaryCopy,
    RuleOp liftIdentityStreaming,
    RuleOp mapOpToOp
  ]

-- Any certificates attached to a trivial Stm in the body might as
-- well be applied to the SOAC itself.
--
-- A "trivial" statement here is one whose expression is a plain
-- 'SubExp'.  Only the top-level statements of each lambda visited by
-- 'mapSOACM' are examined.  Only certificates that are present in
-- @vtable@ are moved; the others stay on the statement.  The rule
-- fires only if at least one certificate was moved.
hoistCerts :: TopDownRuleOp (Wise SOACS)
hoistCerts vtable pat aux soac
  | (soac', hoisted) <- runState (mapSOACM mapper soac) mempty,
    hoisted /= mempty =
      Simplify $ auxing aux $ certifying hoisted $ letBind pat $ Op soac'
  where
    mapper = identitySOACMapper {mapOnSOACLambda = onLambda}
    onLambda lam = do
      stms' <- mapM onStm $ bodyStms $ lambdaBody lam
      pure
        lam
          { lambdaBody =
              mkBody stms' $ bodyResult $ lambdaBody lam
          }
    onStm (Let se_pat se_aux (BasicOp (SubExp se))) = do
      -- Split the certificates into those known in the outer symbol
      -- table (moved to the SOAC via the State accumulator) and the
      -- rest (kept on the statement).
      let (invariant, variant) =
            partition (`ST.elem` vtable) $
              unCerts $
                stmAuxCerts se_aux
          se_aux' = se_aux {stmAuxCerts = Certs variant}
      modify (Certs invariant <>)
      pure $ Let se_pat se_aux' $ BasicOp $ SubExp se
    onStm stm = pure stm
hoistCerts _ _ _ _ = Skip

-- | For a @map@, move out the results that do not need the map:
--
-- * a result that is directly one of the lambda parameters is bound
--   from the corresponding input array (via @Replicate mempty@, or
--   directly when the result type is an 'Acc');
--
-- * a result that is a constant, or a variable free in the lambda
--   body, becomes a 'Replicate' of width @w@.
--
-- All other results stay in the map.  Skips when nothing can be moved.
liftIdentityMapping ::
  forall rep.
  (Buildable rep, BuilderOps rep, HasSOAC rep) =>
  TopDownRuleOp rep
liftIdentityMapping _ pat aux op
  | Just (Screma w arrs form :: SOAC rep) <- asSOAC op,
    Just fun <- isMapSOAC form = do
      let inputMap = M.fromList $ zip (map paramName $ lambdaParams fun) arrs
          free = freeIn $ lambdaBody fun
          rettype = lambdaReturnType fun
          ses = bodyResult $ lambdaBody fun

          freeOrConst (Var v) = v `nameIn` free
          freeOrConst Constant {} = True

          -- Folded right-to-left over (pattern element, result, type),
          -- accumulating (lifted bindings, kept results, kept types).
          checkInvariance (outId, SubExpRes _ (Var v), _) (invariant, mapresult, rettype')
            | Just inp <- M.lookup v inputMap =
                ( (Pat [outId], e inp) : invariant,
                  mapresult,
                  rettype'
                )
            where
              e inp = case patElemType outId of
                Acc {} -> BasicOp $ SubExp $ Var inp
                _ -> BasicOp (Replicate mempty (Var inp))
          -- Reached when the first clause's pattern or guard fails.
          checkInvariance (outId, SubExpRes _ e, t) (invariant, mapresult, rettype')
            | freeOrConst e =
                ( (Pat [outId], BasicOp $ Replicate (Shape [w]) e) : invariant,
                  mapresult,
                  rettype'
                )
            | otherwise =
                ( invariant,
                  (outId, e) : mapresult,
                  t : rettype'
                )

      case foldr checkInvariance ([], [], []) $ zip3 (patElems pat) ses rettype of
        ([], _, _) -> Skip
        (invariant, mapresult, rettype') -> Simplify $ do
          let (pat', ses') = unzip mapresult
              fun' =
                fun
                  { lambdaBody = (lambdaBody fun) {bodyResult = subExpsRes ses'},
                    lambdaReturnType = rettype'
                  }
          mapM_ (uncurry letBind) invariant
          auxing aux $
            letBindNames (map patElemName pat') $
              Op $
                soacOp $
                  Screma w arrs (mapSOAC fun')
liftIdentityMapping _ _ _ _ = Skip

-- | In a 'Stream', a non-accumulator result that is simply one of the
-- array-chunk parameters is removed from the stream.  Its pattern
-- element is instead bound with @Replicate mempty@ of the corresponding
-- input array.  Accumulator results are never touched.
liftIdentityStreaming :: BottomUpRuleOp (Wise SOACS)
liftIdentityStreaming _ (Pat pes) aux (Stream w arrs nes lam)
  | (variant_map, invariant_map) <-
      partitionEithers $ map isInvariantRes $ zip3 map_ts map_pes map_res,
    not $ null invariant_map =
      Simplify $ do
        forM_ invariant_map $ \(pe, arr) ->
          letBind (Pat [pe]) $ BasicOp $ Replicate mempty $ Var arr

        let (variant_map_ts, variant_map_pes, variant_map_res) = unzip3 variant_map
            lam' =
              lam
                { lambdaBody = (lambdaBody lam) {bodyResult = fold_res ++ variant_map_res},
                  lambdaReturnType = fold_ts ++ variant_map_ts
                }

        auxing aux
          . letBind (Pat $ fold_pes ++ variant_map_pes)
          . Op
          $ Stream w arrs nes lam'
  where
    num_folds = length nes
    (fold_pes, map_pes) = splitAt num_folds pes
    (fold_ts, map_ts) = splitAt num_folds $ lambdaReturnType lam
    lam_res = bodyResult $ lambdaBody lam
    (fold_res, map_res) = splitAt num_folds lam_res
    -- Skips one parameter plus one per accumulator.  This assumes the
    -- lambda parameters are laid out as chunk size, accumulators, then
    -- the array chunks (matching @arrs@).
    params_to_arrs = zip (map paramName $ drop (1 + num_folds) $ lambdaParams lam) arrs

    isInvariantRes (_, pe, SubExpRes _ (Var v))
      | Just arr <- lookup v params_to_arrs =
          Right (pe, arr)
    isInvariantRes x = Left x
liftIdentityStreaming _ _ _ _ = Skip

-- | Remove all arguments to the map that are simply replicates.
-- These can be turned into free variables instead.
removeReplicateMapping ::
  (Aliased rep, BuilderOps rep, HasSOAC rep) =>
  TopDownRuleOp rep
removeReplicateMapping vtable pat aux op
  | Just (Screma w arrs form) <- asSOAC op,
    Just fun <- isMapSOAC form,
    Just (stms, fun', arrs') <- removeReplicateInput vtable fun arrs = Simplify $ do
      -- Bind the replacement values first, each under the certificates
      -- returned with it by 'removeReplicateInput'.
      forM_ stms $ \(vs, cs, e) -> certifying cs $ letBindNames vs e
      auxing aux $ letBind pat $ Op $ soacOp $ Screma w arrs' $ mapSOAC fun'
removeReplicateMapping _ _ _ _ = Skip

-- | Like 'removeReplicateMapping', but for 'Scatter'.
removeReplicateWrite :: TopDownRuleOp (Wise SOACS) removeReplicateWrite vtable pat aux (Scatter w ivs as lam) | Just (stms, lam', ivs') <- removeReplicateInput vtable lam ivs = Simplify $ do forM_ stms $ \(vs, cs, e) -> certifying cs $ letBindNames vs e auxing aux $ letBind pat $ Op $ Scatter w ivs' as lam' removeReplicateWrite _ _ _ _ = Skip removeReplicateInput :: (Aliased rep) => ST.SymbolTable rep -> Lambda rep -> [VName] -> Maybe ( [([VName], Certs, Exp rep)], Lambda rep, [VName] ) removeReplicateInput vtable fun arrs | not $ null parameterBnds = do let (arr_params', arrs') = unzip params_and_arrs fun' = fun {lambdaParams = acc_params <> arr_params'} pure (parameterBnds, fun', arrs') | otherwise = Nothing where params = lambdaParams fun (acc_params, arr_params) = splitAt (length params - length arrs) params (params_and_arrs, parameterBnds) = partitionEithers $ zipWith isReplicateAndNotConsumed arr_params arrs isReplicateAndNotConsumed p v | Just (BasicOp (Replicate (Shape (_ : ds)) e), v_cs) <- ST.lookupExp v vtable, paramName p `notNameIn` consumedByLambda fun = Right ( [paramName p], v_cs, case ds of [] -> BasicOp $ SubExp e _ -> BasicOp $ Replicate (Shape ds) e ) | otherwise = Left (p, v) -- | Remove inputs that are not used inside the SOAC. removeUnusedSOACInput :: forall rep. (Aliased rep, Buildable rep, BuilderOps rep, HasSOAC rep) => TopDownRuleOp rep removeUnusedSOACInput _ pat aux op | Just (Screma w arrs form :: SOAC rep) <- asSOAC op, ScremaForm map_lam scan reduce <- form, Just (used_arrs, map_lam') <- remove map_lam arrs = Simplify . auxing aux . letBind pat . Op $ soacOp (Screma w used_arrs (ScremaForm map_lam' scan reduce)) | Just (Scatter w arrs dests map_lam :: SOAC rep) <- asSOAC op, Just (used_arrs, map_lam') <- remove map_lam arrs = Simplify . auxing aux . letBind pat . 
Op $ soacOp (Scatter w used_arrs dests map_lam') where used_in_body map_lam = freeIn $ lambdaBody map_lam usedInput map_lam (param, _) = paramName param `nameIn` used_in_body map_lam remove map_lam arrs = let (used, unused) = partition (usedInput map_lam) (zip (lambdaParams map_lam) arrs) (used_params, used_arrs) = unzip used map_lam' = map_lam {lambdaParams = used_params} in if null unused then Nothing else Just (used_arrs, map_lam') removeUnusedSOACInput _ _ _ _ = Skip removeDeadMapping :: BottomUpRuleOp (Wise SOACS) removeDeadMapping (_, used) (Pat pes) aux (Screma w arrs (ScremaForm lam scans reds)) | (nonmap_pes, map_pes) <- splitAt num_nonmap_res pes, not $ null map_pes = let (nonmap_res, map_res) = splitAt num_nonmap_res $ bodyResult $ lambdaBody lam (nonmap_ts, map_ts) = splitAt num_nonmap_res $ lambdaReturnType lam isUsed (bindee, _, _) = (`UT.used` used) $ patElemName bindee (map_pes', map_res', map_ts') = unzip3 $ filter isUsed $ zip3 map_pes map_res map_ts lam' = lam { lambdaBody = (lambdaBody lam) {bodyResult = nonmap_res <> map_res'}, lambdaReturnType = nonmap_ts <> map_ts' } in if map_pes /= map_pes' then Simplify . auxing aux $ letBind (Pat $ nonmap_pes <> map_pes') . 
Op $ Screma w arrs (ScremaForm lam' scans reds) else Skip where num_nonmap_res = scanResults scans + redResults reds removeDeadMapping _ _ _ _ = Skip removeDuplicateMapOutput :: TopDownRuleOp (Wise SOACS) removeDuplicateMapOutput _ (Pat pes) aux (Screma w arrs form) | Just fun <- isMapSOAC form = let ses = bodyResult $ lambdaBody fun ts = lambdaReturnType fun ses_ts_pes = zip3 ses ts pes (ses_ts_pes', copies) = foldl checkForDuplicates (mempty, mempty) ses_ts_pes in if null copies then Skip else Simplify $ do let (ses', ts', pes') = unzip3 ses_ts_pes' fun' = fun { lambdaBody = (lambdaBody fun) {bodyResult = ses'}, lambdaReturnType = ts' } auxing aux $ letBind (Pat pes') $ Op $ Screma w arrs $ mapSOAC fun' forM_ copies $ \(from, to) -> letBind (Pat [to]) $ BasicOp $ Replicate mempty $ Var $ patElemName from where checkForDuplicates (ses_ts_pes', copies) (se, t, pe) | Just (_, _, pe') <- find (\(x, _, _) -> resSubExp x == resSubExp se) ses_ts_pes' = -- This result has been returned before, producing the -- array pe'. (ses_ts_pes', (pe', pe) : copies) | otherwise = (ses_ts_pes' ++ [(se, t, pe)], copies) removeDuplicateMapOutput _ _ _ _ = Skip -- Mapping some operations becomes an extension of that operation. mapOpToOp :: BottomUpRuleOp (Wise SOACS) mapOpToOp (_, used) pat aux1 e | Just (map_pe, cs, w, BasicOp (Reshape k newshape reshape_arr), [p], [arr]) <- isMapWithOp pat e, paramName p == reshape_arr, not $ UT.isConsumed (patElemName map_pe) used = Simplify $ do certifying (stmAuxCerts aux1 <> cs) . letBind pat . BasicOp $ Reshape k (Shape [w] <> newshape) arr | Just (_, cs, _, BasicOp (Concat d (arr :| arrs) dw), ps, outer_arr : outer_arrs) <- isMapWithOp pat e, (arr : arrs) == map paramName ps = Simplify . certifying (stmAuxCerts aux1 <> cs) . letBind pat . 
BasicOp $ Concat (d + 1) (outer_arr :| outer_arrs) dw | Just (map_pe, cs, _, BasicOp (Rearrange perm rearrange_arr), [p], [arr]) <- isMapWithOp pat e, paramName p == rearrange_arr, not $ UT.isConsumed (patElemName map_pe) used = Simplify . certifying (stmAuxCerts aux1 <> cs) . letBind pat . BasicOp $ Rearrange (0 : map (1 +) perm) arr mapOpToOp _ _ _ _ = Skip isMapWithOp :: Pat dec -> SOAC (Wise SOACS) -> Maybe ( PatElem dec, Certs, SubExp, Exp (Wise SOACS), [Param Type], [VName] ) isMapWithOp pat e | Pat [map_pe] <- pat, Screma w arrs form <- e, Just map_lam <- isMapSOAC form, [Let (Pat [pe]) aux2 e'] <- stmsToList $ bodyStms $ lambdaBody map_lam, [SubExpRes _ (Var r)] <- bodyResult $ lambdaBody map_lam, r == patElemName pe = Just (map_pe, stmAuxCerts aux2, w, e', lambdaParams map_lam, arrs) | otherwise = Nothing -- | Some of the results of a reduction (or really: Redomap) may be -- dead. We remove them here. The trick is that we need to look at -- the data dependencies to see that the "dead" result is not -- actually used for computing one of the live ones. removeDeadReduction :: BottomUpRuleOp (Wise SOACS) removeDeadReduction (_, used) pat aux (Screma w arrs form) = case isRedomapSOAC form of Just ([Reduce comm redlam rednes], maplam) -> let mkOp lam nes' = redomapSOAC [Reduce comm lam nes'] in removeDeadReduction' redlam rednes maplam mkOp _ -> case isScanomapSOAC form of Just ([Scan scanlam nes], maplam) -> let mkOp lam nes' = scanomapSOAC [Scan lam nes'] in removeDeadReduction' scanlam nes maplam mkOp _ -> Skip where removeDeadReduction' redlam nes maplam mkOp | not $ all (`UT.used` used) $ patNames pat, -- Quick/cheap check let (red_pes, map_pes) = splitAt (length nes) $ patElems pat, let redlam_deps = dataDependencies $ lambdaBody redlam, let redlam_res = bodyResult $ lambdaBody redlam, let redlam_params = lambdaParams redlam, let (redlam_xparams, redlam_yparams) = splitAt (length nes) redlam_params, let used_after = map snd . filter ((`UT.used` used) . 
patElemName . fst) $ zip (red_pes <> red_pes) redlam_params, let necessary = findNecessaryForReturned (`elem` used_after) (zip redlam_params $ map resSubExp $ redlam_res <> redlam_res) redlam_deps, let alive_mask = zipWith (||) (map ((`nameIn` necessary) . paramName) redlam_xparams) (map ((`nameIn` necessary) . paramName) redlam_yparams), not $ and alive_mask = Simplify $ do let fixDeadToNeutral lives ne = if lives then Nothing else Just ne dead_fix = zipWith fixDeadToNeutral alive_mask nes (used_red_pes, used_nes) = unzip . map snd . filter fst $ zip alive_mask $ zip red_pes nes when (used_nes == nes) cannotSimplify let maplam' = removeLambdaResults alive_mask maplam redlam' <- removeLambdaResults alive_mask <$> fixLambdaParams redlam (dead_fix ++ dead_fix) auxing aux . letBind (Pat $ used_red_pes ++ map_pes) . Op $ Screma w arrs (mkOp redlam' used_nes maplam') removeDeadReduction' _ _ _ _ = Skip removeDeadReduction _ _ _ _ = Skip -- | If we are writing to an array that is never used, get rid of it. removeDeadWrite :: BottomUpRuleOp (Wise SOACS) removeDeadWrite (_, used) pat aux (Scatter w arrs dests fun) = let (i_ses, v_ses) = unzip $ groupScatterResults' dests $ bodyResult $ lambdaBody fun (i_ts, v_ts) = unzip $ groupScatterResults' dests $ lambdaReturnType fun isUsed (bindee, _, _, _, _, _) = (`UT.used` used) $ patElemName bindee (pat', i_ses', v_ses', i_ts', v_ts', dests') = unzip6 $ filter isUsed $ zip6 (patElems pat) i_ses v_ses i_ts v_ts dests fun' = fun { lambdaBody = mkBody (bodyStms (lambdaBody fun)) (concat i_ses' ++ v_ses'), lambdaReturnType = concat i_ts' ++ v_ts' } in if pat /= Pat pat' then Simplify . auxing aux . 
letBind (Pat pat') $ Op (Scatter w arrs dests' fun') else Skip removeDeadWrite _ _ _ _ = Skip -- handles now concatenation of more than two arrays fuseConcatScatter :: TopDownRuleOp (Wise SOACS) fuseConcatScatter vtable pat _ (Scatter _ arrs dests fun) | Just (ws@(w' : _), xss, css) <- unzip3 <$> mapM isConcat arrs, xivs <- transpose xss, all (w' ==) ws = Simplify $ do let r = length xivs fun2s <- replicateM (r - 1) (renameLambda fun) let (fun_is, fun_vs) = unzip . map (splitScatterResults dests . bodyResult . lambdaBody) $ fun : fun2s (its, vts) = unzip . replicate r . splitScatterResults dests $ lambdaReturnType fun new_stmts = mconcat $ map (bodyStms . lambdaBody) (fun : fun2s) let fun' = Lambda { lambdaParams = mconcat $ map lambdaParams (fun : fun2s), lambdaBody = mkBody new_stmts $ mix fun_is <> mix fun_vs, lambdaReturnType = mix its <> mix vts } certifying (mconcat css) . letBind pat . Op $ Scatter w' (concat xivs) (map (incWrites r) dests) fun' where sizeOf :: VName -> Maybe SubExp sizeOf x = arraySize 0 . typeOf <$> ST.lookup x vtable mix = concat . transpose incWrites r (w, n, a) = (w, n * r, a) -- ToDO: is it (n*r) or (n+r-1)?? isConcat v = case ST.lookupExp v vtable of Just (BasicOp (Concat 0 (x :| ys) _), cs) -> do x_w <- sizeOf x y_ws <- mapM sizeOf ys guard $ all (x_w ==) y_ws pure (x_w, x : ys, cs) Just (BasicOp (Reshape ReshapeCoerce _ arr), cs) -> do (a, b, cs') <- isConcat arr pure (a, b, cs <> cs') _ -> Nothing fuseConcatScatter _ _ _ _ = Skip simplifyClosedFormReduce :: TopDownRuleOp (Wise SOACS) simplifyClosedFormReduce _ pat _ (Screma (Constant w) _ form) | Just nes <- concatMap redNeutral . fst <$> isRedomapSOAC form, zeroIsh w = Simplify . 
forM_ (zip (patNames pat) nes) $ \(v, ne) -> letBindNames [v] $ BasicOp $ SubExp ne simplifyClosedFormReduce vtable pat _ (Screma _ arrs form) | Just [Reduce _ red_fun nes] <- isReduceSOAC form = Simplify $ foldClosedForm (`ST.lookupExp` vtable) pat red_fun nes arrs simplifyClosedFormReduce _ _ _ _ = Skip -- For now we just remove singleton SOACs and those with unroll attributes. simplifyKnownIterationSOAC :: (Buildable rep, BuilderOps rep, HasSOAC rep) => TopDownRuleOp rep simplifyKnownIterationSOAC _ pat _ op | Just (Screma (Constant k) arrs (ScremaForm map_lam scans reds)) <- asSOAC op, oneIsh k = Simplify $ do let (Reduce _ red_lam red_nes) = singleReduce reds (Scan scan_lam scan_nes) = singleScan scans (scan_pes, red_pes, map_pes) = splitAt3 (length scan_nes) (length red_nes) $ patElems pat bindMapParam p a = do a_t <- lookupType a letBindNames [paramName p] $ BasicOp $ Index a $ fullSlice a_t [DimFix $ constant (0 :: Int64)] bindArrayResult pe (SubExpRes cs se) = certifying cs . letBindNames [patElemName pe] $ BasicOp $ ArrayLit [se] $ rowType $ patElemType pe bindResult pe (SubExpRes cs se) = certifying cs $ letBindNames [patElemName pe] $ BasicOp $ SubExp se zipWithM_ bindMapParam (lambdaParams map_lam) arrs (to_scan, to_red, map_res) <- splitAt3 (length scan_nes) (length red_nes) <$> bodyBind (lambdaBody map_lam) scan_res <- eLambda scan_lam $ map eSubExp $ scan_nes ++ map resSubExp to_scan red_res <- eLambda red_lam $ map eSubExp $ red_nes ++ map resSubExp to_red zipWithM_ bindArrayResult scan_pes scan_res zipWithM_ bindResult red_pes red_res zipWithM_ bindArrayResult map_pes map_res simplifyKnownIterationSOAC _ pat _ op | Just (Stream (Constant k) arrs nes fold_lam) <- asSOAC op, oneIsh k = Simplify $ do let (chunk_param, acc_params, slice_params) = partitionChunkedFoldParameters (length nes) (lambdaParams fold_lam) letBindNames [paramName chunk_param] $ BasicOp $ SubExp $ intConst Int64 1 forM_ (zip acc_params nes) $ \(p, ne) -> letBindNames [paramName 
p] $ BasicOp $ SubExp ne forM_ (zip slice_params arrs) $ \(p, arr) -> letBindNames [paramName p] $ BasicOp $ SubExp $ Var arr res <- bodyBind $ lambdaBody fold_lam forM_ (zip (patNames pat) res) $ \(v, SubExpRes cs se) -> certifying cs $ letBindNames [v] $ BasicOp $ SubExp se -- simplifyKnownIterationSOAC _ pat aux op | Just (Screma (Constant (IntValue (Int64Value k))) arrs (ScremaForm map_lam [] [])) <- asSOAC op, "unroll" `inAttrs` stmAuxAttrs aux = Simplify $ do arrs_elems <- fmap transpose . forM [0 .. k - 1] $ \i -> do map_lam' <- renameLambda map_lam eLambda map_lam' $ map (`eIndex` [eSubExp (constant i)]) arrs forM_ (zip3 (patNames pat) arrs_elems (lambdaReturnType map_lam)) $ \(v, arr_elems, t) -> certifying (mconcat (map resCerts arr_elems)) $ letBindNames [v] . BasicOp $ ArrayLit (map resSubExp arr_elems) t -- simplifyKnownIterationSOAC _ _ _ _ = Skip data ArrayOp = ArrayIndexing Certs VName (Slice SubExp) | ArrayRearrange Certs VName [Int] | ArrayReshape Certs VName ReshapeKind Shape | ArrayCopy Certs VName | -- | Never constructed. 
ArrayVar Certs VName deriving (Eq, Ord, Show) arrayOpArr :: ArrayOp -> VName arrayOpArr (ArrayIndexing _ arr _) = arr arrayOpArr (ArrayRearrange _ arr _) = arr arrayOpArr (ArrayReshape _ arr _ _) = arr arrayOpArr (ArrayCopy _ arr) = arr arrayOpArr (ArrayVar _ arr) = arr arrayOpCerts :: ArrayOp -> Certs arrayOpCerts (ArrayIndexing cs _ _) = cs arrayOpCerts (ArrayRearrange cs _ _) = cs arrayOpCerts (ArrayReshape cs _ _ _) = cs arrayOpCerts (ArrayCopy cs _) = cs arrayOpCerts (ArrayVar cs _) = cs isArrayOp :: Certs -> Exp rep -> Maybe ArrayOp isArrayOp cs (BasicOp (Index arr slice)) = Just $ ArrayIndexing cs arr slice isArrayOp cs (BasicOp (Rearrange perm arr)) = Just $ ArrayRearrange cs arr perm isArrayOp cs (BasicOp (Reshape k new_shape arr)) = Just $ ArrayReshape cs arr k new_shape isArrayOp cs (BasicOp (Replicate (Shape []) (Var arr))) = Just $ ArrayCopy cs arr isArrayOp _ _ = Nothing fromArrayOp :: ArrayOp -> (Certs, Exp rep) fromArrayOp (ArrayIndexing cs arr slice) = (cs, BasicOp $ Index arr slice) fromArrayOp (ArrayRearrange cs arr perm) = (cs, BasicOp $ Rearrange perm arr) fromArrayOp (ArrayReshape cs arr k new_shape) = (cs, BasicOp $ Reshape k new_shape arr) fromArrayOp (ArrayCopy cs arr) = (cs, BasicOp $ Replicate mempty $ Var arr) fromArrayOp (ArrayVar cs arr) = (cs, BasicOp $ SubExp $ Var arr) arrayOps :: forall rep. (Buildable rep, HasSOAC rep) => Certs -> Body rep -> S.Set (Pat (LetDec rep), ArrayOp) arrayOps cs = mconcat . map onStm . stmsToList . bodyStms where -- It is not safe to move everything out of branches (#1874) or -- loops (#2015); probably we need to put some more intelligence -- in here somehow. onStm (Let _ _ Match {}) = mempty onStm (Let _ _ Loop {}) = mempty onStm (Let pat aux e) = case isArrayOp (cs <> stmAuxCerts aux) e of Just op -> S.singleton (pat, op) Nothing -> execState (walkExpM (walker (stmAuxCerts aux)) e) mempty onOp more_cs op | Just soac <- asSOAC op = -- Copies are not safe to move out of nested ops (#1753). 
S.filter (notCopy . snd) $ execWriter $ mapSOACM identitySOACMapper {mapOnSOACLambda = onLambda more_cs} (soac :: SOAC rep) | otherwise = mempty onLambda more_cs lam = do tell $ arrayOps (cs <> more_cs) $ lambdaBody lam pure lam walker more_cs = (identityWalker @rep) { walkOnBody = const $ modify . (<>) . arrayOps (cs <> more_cs), walkOnOp = modify . (<>) . onOp more_cs } notCopy (ArrayCopy {}) = False notCopy _ = True replaceArrayOps :: forall rep. (Buildable rep, BuilderOps rep, HasSOAC rep) => M.Map (Pat (LetDec rep)) ArrayOp -> Body rep -> Body rep replaceArrayOps substs (Body _ stms res) = mkBody (fmap onStm stms) res where onStm (Let pat aux e) = let (cs', e') = maybe (mempty, mapExp mapper e) fromArrayOp $ M.lookup pat substs in certify cs' $ mkLet' (patIdents pat) aux e' mapper = (identityMapper @rep) { mapOnBody = const $ pure . replaceArrayOps substs, mapOnOp = pure . onOp } onOp op | Just (soac :: SOAC rep) <- asSOAC op = soacOp . runIdentity $ mapSOACM identitySOACMapper {mapOnSOACLambda = pure . onLambda} soac | otherwise = op onLambda lam = lam {lambdaBody = replaceArrayOps substs $ lambdaBody lam} -- Turn -- -- map (\i -> ... xs[i] ...) (iota n) -- -- into -- -- map (\i x -> ... x ...) (iota n) xs -- -- This is not because we want to encourage the map-iota pattern, but -- it may be present in generated code. This is an unfortunately -- expensive simplification rule, since it requires multiple passes -- over the entire lambda body. It only handles the very simplest -- case - if you find yourself planning to extend it to handle more -- complex situations (rotate or whatnot), consider turning it into a -- separate compiler pass instead. simplifyMapIota :: forall rep. 
(Buildable rep, BuilderOps rep, HasSOAC rep) => TopDownRuleOp rep simplifyMapIota vtable screma_pat aux op | Just (Screma w arrs (ScremaForm map_lam scan reduce) :: SOAC rep) <- asSOAC op, Just (p, _) <- find isIota (zip (lambdaParams map_lam) arrs), indexings <- mapMaybe (indexesWith (paramName p)) . S.toList $ arrayOps mempty $ lambdaBody map_lam, not $ null indexings = Simplify $ do -- For each indexing with iota, add the corresponding array to -- the Screma, and construct a new lambda parameter. (more_arrs, more_params, replacements) <- unzip3 . catMaybes <$> mapM (mapOverArr w) indexings let substs = M.fromList replacements map_lam' = map_lam { lambdaParams = lambdaParams map_lam <> more_params, lambdaBody = replaceArrayOps substs $ lambdaBody map_lam } auxing aux . letBind screma_pat . Op . soacOp $ Screma w (arrs <> more_arrs) (ScremaForm map_lam' scan reduce) where isIota (_, arr) = case ST.lookupBasicOp arr vtable of Just (Iota _ (Constant o) (Constant s) _, _) -> zeroIsh o && oneIsh s _ -> False -- Find a 'DimFix i', optionally preceded by other DimFixes, and -- if so return those DimFixes. fixWith i (DimFix j : slice) | Var i == j = Just [] | otherwise = (j :) <$> fixWith i slice fixWith _ _ = Nothing indexesWith v (pat, idx@(ArrayIndexing cs arr (Slice js))) | arr `ST.elem` vtable, all (`ST.elem` vtable) $ unCerts cs, Just js' <- fixWith v js, all (`ST.elem` vtable) $ namesToList $ freeIn js' = Just (pat, js', idx) indexesWith _ _ = Nothing properArr [] arr = pure arr properArr js arr = do arr_t <- lookupType arr letExp (baseString arr) $ BasicOp $ Index arr $ fullSlice arr_t $ map DimFix js mapOverArr w (pat, js, ArrayIndexing cs arr slice) = do arr' <- properArr js arr arr_t <- lookupType arr' arr'' <- if arraySize 0 arr_t == w then pure arr' else certifying cs . letExp (baseString arr ++ "_prefix") . BasicOp . 
Index arr' $ fullSlice arr_t [DimSlice (intConst Int64 0) w (intConst Int64 1)] arr_elem_param <- newParam (baseString arr ++ "_elem") (rowType arr_t) pure $ Just ( arr'', arr_elem_param, ( pat, ArrayIndexing cs (paramName arr_elem_param) (Slice (drop (length js + 1) (unSlice slice))) ) ) mapOverArr _ _ = pure Nothing simplifyMapIota _ _ _ _ = Skip -- If a Screma's map function contains a transformation -- (e.g. transpose) on a parameter, create a new parameter -- corresponding to that transformation performed on the rows of the -- full array. moveTransformToInput :: TopDownRuleOp (Wise SOACS) moveTransformToInput vtable screma_pat aux soac@(Screma w arrs (ScremaForm map_lam scan reduce)) | ops <- filter arrayIsMapParam $ S.toList $ arrayOps mempty $ lambdaBody map_lam, not $ null ops = Simplify $ do (more_arrs, more_params, replacements) <- unzip3 . catMaybes <$> mapM mapOverArr ops when (null more_arrs) cannotSimplify let map_lam' = map_lam { lambdaParams = lambdaParams map_lam <> more_params, lambdaBody = replaceArrayOps (M.fromList replacements) $ lambdaBody map_lam } auxing aux . letBind screma_pat . Op $ Screma w (arrs <> more_arrs) (ScremaForm map_lam' scan reduce) where -- It is not safe to move the transform if the root array is being -- consumed by the Screma. This is a bit too conservative - it's -- actually safe if we completely replace the original input, but -- this rule is not that precise. consumed = consumedInOp soac map_param_names = map paramName (lambdaParams map_lam) topLevelPat = (`elem` fmap stmPat (bodyStms (lambdaBody map_lam))) onlyUsedOnce arr = case filter ((arr `nameIn`) . freeIn) $ stmsToList $ bodyStms $ lambdaBody map_lam of _ : _ : _ -> False _ -> True -- It's not just about whether the array is a parameter; -- everything else must be map-invariant. 
arrayIsMapParam (pat', ArrayIndexing cs arr slice) = arr `elem` map_param_names && all (`ST.elem` vtable) (namesToList $ freeIn cs <> freeIn slice) && not (null slice) && (not (null $ sliceDims slice) || (topLevelPat pat' && onlyUsedOnce arr)) arrayIsMapParam (_, ArrayRearrange cs arr perm) = arr `elem` map_param_names && all (`ST.elem` vtable) (namesToList $ freeIn cs) && not (null perm) arrayIsMapParam (_, ArrayReshape cs arr _ new_shape) = arr `elem` map_param_names && all (`ST.elem` vtable) (namesToList $ freeIn cs <> freeIn new_shape) arrayIsMapParam (_, ArrayCopy cs arr) = arr `elem` map_param_names && all (`ST.elem` vtable) (namesToList $ freeIn cs) arrayIsMapParam (_, ArrayVar {}) = False mapOverArr (pat, op) | Just (_, arr) <- find ((== arrayOpArr op) . fst) (zip map_param_names arrs), arr `notNameIn` consumed = do arr_t <- lookupType arr let whole_dim = DimSlice (intConst Int64 0) (arraySize 0 arr_t) (intConst Int64 1) arr_transformed <- certifying (arrayOpCerts op) $ letExp (baseString arr ++ "_transformed") $ case op of ArrayIndexing _ _ (Slice slice) -> BasicOp $ Index arr $ Slice $ whole_dim : slice ArrayRearrange _ _ perm -> BasicOp $ Rearrange (0 : map (+ 1) perm) arr ArrayReshape _ _ k new_shape -> BasicOp $ Reshape k (Shape [w] <> new_shape) arr ArrayCopy {} -> BasicOp $ Replicate mempty $ Var arr ArrayVar {} -> BasicOp $ SubExp $ Var arr arr_transformed_t <- lookupType arr_transformed arr_transformed_row <- newVName $ baseString arr ++ "_transformed_row" pure $ Just ( arr_transformed, Param mempty arr_transformed_row (rowType arr_transformed_t), (pat, ArrayVar mempty arr_transformed_row) ) mapOverArr _ = pure Nothing moveTransformToInput _ _ _ _ = Skip futhark-0.25.27/src/Futhark/IR/SegOp.hs000066400000000000000000001376301475065116200174320ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} -- | Segmented operations. 
These correspond to perfect @map@ nests on -- top of /something/, except that the @map@s are conceptually only -- over @iota@s (so there will be explicit indexing inside them). module Futhark.IR.SegOp ( SegOp (..), segLevel, segBody, segSpace, typeCheckSegOp, SegSpace (..), scopeOfSegSpace, segSpaceDims, -- * Details HistOp (..), histType, splitHistResults, SegBinOp (..), segBinOpResults, segBinOpChunks, KernelBody (..), aliasAnalyseKernelBody, consumedInKernelBody, ResultManifest (..), KernelResult (..), kernelResultCerts, kernelResultSubExp, -- ** Generic traversal SegOpMapper (..), identitySegOpMapper, mapSegOpM, traverseSegOpStms, -- * Simplification simplifySegOp, HasSegOp (..), segOpRules, -- * Memory segOpReturns, ) where import Control.Category import Control.Monad import Control.Monad.Identity import Control.Monad.Reader import Control.Monad.State.Strict import Control.Monad.Writer import Data.Bifunctor (first) import Data.Bitraversable import Data.List ( elemIndex, foldl', groupBy, intersperse, isPrefixOf, partition, unzip4, zip4, ) import Data.Map.Strict qualified as M import Data.Maybe import Futhark.Analysis.Alias qualified as Alias import Futhark.Analysis.Metrics import Futhark.Analysis.PrimExp.Convert import Futhark.Analysis.SymbolTable qualified as ST import Futhark.Analysis.UsageTable qualified as UT import Futhark.IR import Futhark.IR.Aliases ( Aliases, CanBeAliased (..), ) import Futhark.IR.Mem import Futhark.IR.Prop.Aliases import Futhark.IR.TypeCheck qualified as TC import Futhark.Optimise.Simplify.Engine qualified as Engine import Futhark.Optimise.Simplify.Rep import Futhark.Optimise.Simplify.Rule import Futhark.Tools import Futhark.Transform.Rename import Futhark.Transform.Substitute import Futhark.Util (chunks, maybeNth) import Futhark.Util.Pretty ( Doc, apply, hsep, parens, ppTuple', pretty, (<+>), (), ) import Futhark.Util.Pretty qualified as PP import Prelude hiding (id, (.)) -- | An operator for 'SegHist'. 
data HistOp rep = HistOp
  { histShape :: Shape,
    histRaceFactor :: SubExp,
    histDest :: [VName],
    histNeutral :: [SubExp],
    -- | In case this operator is semantically a vectorised
    -- operator (corresponding to a perfect map nest in the
    -- SOACS representation), these are the logical
    -- "dimensions".  This is used to generate more efficient
    -- code.
    histOpShape :: Shape,
    histOp :: Lambda rep
  }
  deriving (Eq, Ord, Show)

-- | The type of a histogram produced by a 'HistOp'.  Every return
-- type of the operator is wrapped in the histogram shape followed by
-- the vectorised operator shape.  This can differ from the type of
-- the 'histDest's when we are dealing with a segmented histogram.
histType :: HistOp rep -> [Type]
histType op =
  [elem_t `arrayOfShape` full_shape | elem_t <- lambdaReturnType (histOp op)]
  where
    full_shape = histShape op <> histOpShape op

-- | Split reduction results returned by a 'KernelBody' into those
-- that correspond to indexes for the 'HistOp's, and those that
-- correspond to values.  All indexes come first (one per dimension of
-- each operator's 'histShape'), followed by all values (one per
-- 'histDest' of each operator).
splitHistResults :: [HistOp rep] -> [SubExp] -> [([SubExp], [SubExp])]
splitHistResults ops res =
  zip (chunks idx_counts idx_part) (chunks val_counts val_part)
  where
    idx_counts = map (shapeRank . histShape) ops
    val_counts = map (length . histDest) ops
    (idx_part, val_part) = splitAt (sum idx_counts) res

-- | An operator for 'SegScan' and 'SegRed'.
data SegBinOp rep = SegBinOp
  { segBinOpComm :: Commutativity,
    segBinOpLambda :: Lambda rep,
    segBinOpNeutral :: [SubExp],
    -- | In case this operator is semantically a vectorised
    -- operator (corresponding to a perfect map nest in the
    -- SOACS representation), these are the logical
    -- "dimensions".  This is used to generate more efficient
    -- code.
    segBinOpShape :: Shape
  }
  deriving (Eq, Ord, Show)

-- | How many reduction results are produced by these 'SegBinOp's?
-- This is the total number of neutral elements across all operators.
segBinOpResults :: [SegBinOp rep] -> Int
segBinOpResults ops = sum [length (segBinOpNeutral op) | op <- ops]

-- | Split some list into chunks equal to the number of values
-- returned by each 'SegBinOp'.
segBinOpChunks :: [SegBinOp rep] -> [a] -> [[a]]
segBinOpChunks ops = chunks [length (segBinOpNeutral op) | op <- ops]

-- | The body of a 'SegOp'.
data KernelBody rep = KernelBody
  { -- | The body decoration ('BodyDec') of the representation.
    kernelBodyDec :: BodyDec rep,
    -- | The statements of the body.
    kernelBodyStms :: Stms rep,
    -- | One entry per value returned by the body; unlike an
    -- ordinary 'Body', these are 'KernelResult's rather than a
    -- 'Result'.
    kernelBodyResult :: [KernelResult]
  }

deriving instance (RepTypes rep) => Ord (KernelBody rep)

deriving instance (RepTypes rep) => Show (KernelBody rep)

deriving instance (RepTypes rep) => Eq (KernelBody rep)

-- | Metadata about whether there is a subtle point to this
-- 'KernelResult'.  This is used to protect things like tiling, which
-- might otherwise be removed by the simplifier because they're
-- semantically redundant.  This has no semantic effect and can be
-- ignored at code generation.
data ResultManifest
  = -- | Don't simplify this one!
    ResultNoSimplify
  | -- | No special protection: the simplifier may transform or
    -- remove this result as usual.
    ResultMaySimplify
  | -- | The results produced are only used within the
    -- same physical thread later on, and can thus be
    -- kept in registers.
    ResultPrivate
  deriving (Eq, Show, Ord)

-- | A 'KernelBody' does not return an ordinary 'Result'.  Instead, it
-- returns a list of these.
data KernelResult
  = -- | Each "worker" in the kernel returns this.
    -- Whether this is a result-per-thread or a
    -- result-per-block depends on where the 'SegOp' occurs.
    Returns ResultManifest Certs SubExp
  | -- | In-place writes into an existing array, which is consumed.
    -- Each pair is a slice of the destination and the value stored
    -- into that slice.  The type annotation for this result is the
    -- full type of the destination array.
    WriteReturns
      Certs
      VName -- Destination array
      [(Slice SubExp, SubExp)]
  | TileReturns
      Certs
      [(SubExp, SubExp)] -- Total/tile for each dimension
      VName -- Tile written by this worker.
      -- The TileReturns must not expect more than one
      -- result to be written per physical thread.
  | RegTileReturns
      Certs
      -- For each dim of result:
      [ ( SubExp, -- size of this dim.
          SubExp, -- block tile size for this dim.
          SubExp -- reg tile size for this dim.
        )
      ]
      VName -- Tile returned by this thread/block.
  deriving (Eq, Show, Ord)

-- | Get the certs for this 'KernelResult'.
kernelResultCerts :: KernelResult -> Certs
kernelResultCerts (Returns _ cs _) = cs
kernelResultCerts (WriteReturns cs _ _) = cs
kernelResultCerts (TileReturns cs _ _) = cs
kernelResultCerts (RegTileReturns cs _ _) = cs

-- | Get the root t'SubExp' corresponding values for a 'KernelResult'.
-- For every variant except 'Returns' this is the array variable.
kernelResultSubExp :: KernelResult -> SubExp
kernelResultSubExp (Returns _ _ se) = se
kernelResultSubExp (WriteReturns _ arr _) = Var arr
kernelResultSubExp (TileReturns _ _ v) = Var v
kernelResultSubExp (RegTileReturns _ _ v) = Var v

instance FreeIn KernelResult where
  freeIn' (Returns _ cs what) = freeIn' cs <> freeIn' what
  freeIn' (WriteReturns cs arr res) = freeIn' cs <> freeIn' arr <> freeIn' res
  freeIn' (TileReturns cs dims v) = freeIn' cs <> freeIn' dims <> freeIn' v
  freeIn' (RegTileReturns cs dims_n_tiles v) =
    freeIn' cs <> freeIn' dims_n_tiles <> freeIn' v

instance (ASTRep rep) => FreeIn (KernelBody rep) where
  -- Names bound by the body's own statements are not free.
  freeIn' (KernelBody dec stms res) =
    fvBind bound_in_stms $ freeIn' dec <> freeIn' stms <> freeIn' res
    where
      bound_in_stms = foldMap boundByStm stms

instance (ASTRep rep) => Substitute (KernelBody rep) where
  substituteNames subst (KernelBody dec stms res) =
    KernelBody
      (substituteNames subst dec)
      (substituteNames subst stms)
      (substituteNames subst res)

instance Substitute KernelResult where
  substituteNames subst (Returns manifest cs se) =
    Returns manifest (substituteNames subst cs) (substituteNames subst se)
  substituteNames subst (WriteReturns cs arr res) =
    WriteReturns
      (substituteNames subst cs)
      (substituteNames subst arr)
      (substituteNames subst res)
  substituteNames subst (TileReturns cs dims v) =
    TileReturns
      (substituteNames subst cs)
      (substituteNames subst dims)
      (substituteNames subst v)
  substituteNames subst (RegTileReturns cs dims_n_tiles v) =
    RegTileReturns
      (substituteNames subst cs)
      (substituteNames subst dims_n_tiles)
      (substituteNames subst v)

instance (ASTRep rep) => Rename (KernelBody rep) where
  -- The results are renamed in the scope of the renamed statements,
  -- since they may refer to names bound by them.
  rename (KernelBody dec stms res) = do
    dec' <- rename dec
    renamingStms stms $ \stms' ->
      KernelBody dec' stms' <$> rename res

instance Rename KernelResult where
  rename = substituteRename

-- | Perform alias analysis on a 'KernelBody'.
aliasAnalyseKernelBody ::
  (Alias.AliasableRep rep) =>
  AliasTable ->
  KernelBody rep ->
  KernelBody (Aliases rep)
aliasAnalyseKernelBody aliases (KernelBody dec stms res) =
  -- Reuse the ordinary body analysis with an empty result list, then
  -- reattach the kernel results unchanged.
  let Body dec' stms' _ = Alias.analyseBody aliases $ Body dec stms []
   in KernelBody dec' stms' res

-- | The variables consumed in the kernel body.  This includes the
-- destination array of every 'WriteReturns' result.
consumedInKernelBody ::
  (Aliased rep) =>
  KernelBody rep ->
  Names
consumedInKernelBody (KernelBody dec stms res) =
  consumedInBody (Body dec stms []) <> mconcat (map consumedByReturn res)
  where
    consumedByReturn (WriteReturns _ a _) = oneName a
    consumedByReturn _ = mempty

-- | Type check a 'KernelBody' against the given per-result types.
-- The number of results must equal the number of types.  For
-- 'WriteReturns' the type is the full type of the destination array.
checkKernelBody ::
  (TC.Checkable rep) =>
  [Type] ->
  KernelBody (Aliases rep) ->
  TC.TypeM rep ()
checkKernelBody ts (KernelBody (_, dec) stms kres) = do
  TC.checkBodyDec dec
  -- We consume the kernel results (when applicable) before
  -- type-checking the stms, so we will get an error if a statement
  -- uses an array that is written to in a result.
  mapM_ consumeKernelResult kres
  TC.checkStms stms $ do
    unless (length ts == length kres) $
      TC.bad . TC.TypeError $
        "Kernel return type is "
          <> prettyTuple ts
          <> ", but body returns "
          <> prettyText (length kres)
          <> " values."
    zipWithM_ checkKernelResult kres ts
  where
    consumeKernelResult (WriteReturns _ arr _) =
      TC.consume =<< TC.lookupAliases arr
    consumeKernelResult _ =
      pure ()

    checkKernelResult (Returns _ cs what) t = do
      TC.checkCerts cs
      TC.require [t] what
    checkKernelResult (WriteReturns cs arr res) t = do
      TC.checkCerts cs
      arr_t <- lookupType arr
      unless (arr_t == t) $
        TC.bad . TC.TypeError $
          "WriteReturns result type annotation for "
            <> prettyText arr
            <> " is "
            <> prettyText t
            <> ", but inferred as "
            <> prettyText arr_t
      -- Each written value must match the shape of its slice.
      forM_ res $ \(slice, e) -> do
        TC.checkSlice arr_t slice
        TC.require [t `setArrayShape` sliceShape slice] e
    checkKernelResult (TileReturns cs dims v) t = do
      TC.checkCerts cs
      forM_ dims $ \(dim, tile) -> do
        TC.require [Prim int64] dim
        TC.require [Prim int64] tile
      -- The tile itself has the tile sizes as its shape.
      vt <- lookupType v
      unless (vt == t `arrayOfShape` Shape (map snd dims)) $
        TC.bad $
          TC.TypeError $
            "Invalid type for TileReturns " <> prettyText v
    checkKernelResult (RegTileReturns cs dims_n_tiles arr) t = do
      TC.checkCerts cs
      mapM_ (TC.require [Prim int64]) dims
      mapM_ (TC.require [Prim int64]) blk_tiles
      mapM_ (TC.require [Prim int64]) reg_tiles

      -- assert that arr is of element type t and shape
      -- (blk_tiles ++ reg_tiles)
      arr_t <- lookupType arr
      unless (arr_t == expected) $
        TC.bad . TC.TypeError $
          "Invalid type for RegTileReturns. Expected:\n "
            <> prettyText expected
            <> ",\ngot:\n "
            <> prettyText arr_t
      where
        (dims, blk_tiles, reg_tiles) = unzip3 dims_n_tiles
        expected = t `arrayOfShape` Shape (blk_tiles <> reg_tiles)

kernelBodyMetrics :: (OpMetrics (Op rep)) => KernelBody rep -> MetricsM ()
kernelBodyMetrics = mapM_ stmMetrics . kernelBodyStms

instance (PrettyRep rep) => Pretty (KernelBody rep) where
  pretty (KernelBody _ stms res) =
    PP.stack (map pretty (stmsToList stms))
      PP.</> "return"
      <+> PP.braces (PP.commastack $ map pretty res)

-- | Render certificates as an annotation list; empty certificates
-- produce no annotation at all.
certAnnots :: Certs -> [Doc ann]
certAnnots cs
  | cs == mempty = []
  | otherwise = [pretty cs]

instance Pretty KernelResult where
  pretty (Returns ResultNoSimplify cs what) =
    hsep $ certAnnots cs <> ["returns (manifest)" <+> pretty what]
  pretty (Returns ResultPrivate cs what) =
    hsep $ certAnnots cs <> ["returns (private)" <+> pretty what]
  pretty (Returns ResultMaySimplify cs what) =
    hsep $ certAnnots cs <> ["returns" <+> pretty what]
  pretty (WriteReturns cs arr res) =
    hsep $ certAnnots cs <> [pretty arr PP.</> "with" <+> PP.apply (map ppRes res)]
    where
      ppRes (slice, e) = pretty slice <+> "=" <+> pretty e
  pretty (TileReturns cs dims v) =
    hsep $ certAnnots cs <> ["tile" <> apply (map onDim dims) <+> pretty v]
    where
      onDim (dim, tile) = pretty dim <+> "/" <+> pretty tile
  pretty (RegTileReturns cs dims_n_tiles v) =
    hsep $ certAnnots cs <> ["blkreg_tile" <> apply (map onDim dims_n_tiles) <+> pretty v]
    where
      onDim (dim, blk_tile, reg_tile) =
        pretty dim <+> "/" <+> parens (pretty blk_tile <+> "*" <+> pretty reg_tile)

-- | Index space of a 'SegOp'.
data SegSpace = SegSpace
  { -- | Flat physical index corresponding to the
    -- dimensions (at code generation used for a
    -- thread ID or similar).
    segFlat :: VName,
    -- | Each logical index variable paired with the size of its
    -- dimension, outermost first.
    unSegSpace :: [(VName, SubExp)]
  }
  deriving (Eq, Ord, Show)

-- | The sizes spanned by the indexes of the 'SegSpace'.
segSpaceDims :: SegSpace -> [SubExp]
segSpaceDims (SegSpace _ space) = map snd space

-- | A 'Scope' containing all the identifiers brought into scope by
-- this 'SegSpace'.  All of them, including the flat index, are
-- 64-bit index names.
scopeOfSegSpace :: SegSpace -> Scope rep
scopeOfSegSpace (SegSpace phys space) =
  M.fromList $ map (,IndexName Int64) (phys : map fst space)

-- | Check that every dimension size of the space is an @i64@.
checkSegSpace :: (TC.Checkable rep) => SegSpace -> TC.TypeM rep ()
checkSegSpace (SegSpace _ dims) =
  mapM_ (TC.require [Prim int64] . snd) dims

-- | A 'SegOp' is semantically a perfectly nested stack of maps, on
-- top of some bottommost computation (scalar computation, reduction,
-- scan, or histogram).  The 'SegSpace' encodes the original map
-- structure.
--
-- All 'SegOp's are parameterised by the representation of their body,
-- as well as a *level*.  The *level* is a representation-specific bit
-- of information.  For example, in GPU backends, it is used to
-- indicate whether the 'SegOp' is expected to run at the thread-level
-- or the block-level.
--
-- The type list is usually the type of the element returned by a
-- single thread. The result of the SegOp is then an array of that
-- type, with the shape of the 'SegSpace' prepended. One exception is
-- for 'WriteReturns', where the type annotation is the /full/ type of
-- the result.
data SegOp lvl rep
  = SegMap lvl SegSpace [Type] (KernelBody rep)
  | -- | The reduction is over the innermost dimension of the
    -- 'SegSpace'; the remaining dimensions are segment dimensions and
    -- form the outer shape of the reduction results (see
    -- 'segOpType').
    -- NOTE(review): an earlier comment stated the space must always
    -- have at least two dimensions, making the result an array;
    -- nothing in this module enforces that - confirm.
    SegRed lvl SegSpace [SegBinOp rep] [Type] (KernelBody rep)
  | SegScan lvl SegSpace [SegBinOp rep] [Type] (KernelBody rep)
  | SegHist lvl SegSpace [HistOp rep] [Type] (KernelBody rep)
  deriving (Eq, Ord, Show)

-- | The level of a 'SegOp'.
segLevel :: SegOp lvl rep -> lvl
segLevel (SegMap lvl _ _ _) = lvl
segLevel (SegRed lvl _ _ _ _) = lvl
segLevel (SegScan lvl _ _ _ _) = lvl
segLevel (SegHist lvl _ _ _ _) = lvl

-- | The space of a 'SegOp'.
segSpace :: SegOp lvl rep -> SegSpace
segSpace (SegMap _ space _ _) = space
segSpace (SegRed _ space _ _ _) = space
segSpace (SegScan _ space _ _ _) = space
segSpace (SegHist _ space _ _ _) = space

-- | The body of a 'SegOp'.
segBody :: SegOp lvl rep -> KernelBody rep
segBody (SegMap _ _ _ kbody) = kbody
segBody (SegRed _ _ _ _ kbody) = kbody
segBody (SegScan _ _ _ _ kbody) = kbody
segBody (SegHist _ _ _ _ kbody) = kbody

-- | The type of one non-operator result of a 'SegOp', given the
-- per-thread type annotation and the corresponding 'KernelResult'.
-- 'WriteReturns' already carries the full type; 'Returns' gets the
-- whole space prepended; tiles get their total sizes prepended.
segResultShape :: SegSpace -> Type -> KernelResult -> Type
segResultShape _ elem_t (WriteReturns {}) = elem_t
segResultShape space elem_t Returns {} =
  foldr (\d acc -> acc `arrayOfRow` d) elem_t (segSpaceDims space)
segResultShape _ elem_t (TileReturns _ dims _) =
  elem_t `arrayOfShape` Shape [total | (total, _) <- dims]
segResultShape _ elem_t (RegTileReturns _ dims_n_tiles _) =
  elem_t `arrayOfShape` Shape [total | (total, _, _) <- dims_n_tiles]

-- | The return type of a 'SegOp'.  For reductions and scans, the
-- operator results come first, followed by the mapped results.
segOpType :: SegOp lvl rep -> [Type]
segOpType (SegMap _ space ts kbody) =
  zipWith (segResultShape space) ts (kernelBodyResult kbody)
segOpType (SegRed _ space reds ts kbody) =
  let segment_shape = Shape (init (segSpaceDims space))
      red_ts =
        [ t `arrayOfShape` (segment_shape <> segBinOpShape op)
        | op <- reds,
          t <- lambdaReturnType (segBinOpLambda op)
        ]
      n_red = length red_ts
      mapped_ts =
        zipWith
          (segResultShape space)
          (drop n_red ts)
          (drop n_red (kernelBodyResult kbody))
   in red_ts ++ mapped_ts
segOpType (SegScan _ space scans ts kbody) =
  let space_shape = Shape (segSpaceDims space)
      scan_ts =
        [ t `arrayOfShape` (space_shape <> segBinOpShape op)
        | op <- scans,
          t <- lambdaReturnType (segBinOpLambda op)
        ]
      n_scan = length scan_ts
      mapped_ts =
        zipWith
          (segResultShape space)
          (drop n_scan ts)
          (drop n_scan (kernelBodyResult kbody))
   in scan_ts ++ mapped_ts
segOpType (SegHist _ space ops _ _) =
  [ t `arrayOfShape` (segment_shape <> histShape op <> histOpShape op)
  | op <- ops,
    t <- lambdaReturnType (histOp op)
  ]
  where
    segment_shape = Shape (init (segSpaceDims space))

instance TypedOp (SegOp lvl) where
  opType op = pure (staticShapes (segOpType op))

instance (ASTConstraints lvl) => AliasedOp (SegOp lvl) where
  -- No result aliases anything.
  opAliases op = [mempty | _ <- segOpType op]

  -- Histogram destinations are updated in place and hence consumed,
  -- in addition to whatever the body itself consumes.
  consumedInOp op =
    hist_dests <> consumedInKernelBody (segBody op)
    where
      hist_dests = case op of
        SegHist _ _ ops _ _ -> namesFromList (concatMap histDest ops)
        _ -> mempty

-- | Type check a 'SegOp', given a checker for its level.
typeCheckSegOp ::
  (TC.Checkable rep) =>
  (lvl -> TC.TypeM rep ()) ->
  SegOp lvl (Aliases rep) ->
  TC.TypeM rep ()
typeCheckSegOp checkLvl (SegMap lvl space ts kbody) =
  checkLvl lvl >> checkScanRed space [] ts kbody
typeCheckSegOp checkLvl (SegRed lvl space reds ts body) =
  checkLvl lvl >> checkScanRed space (map binOpParts reds) ts body
  where
    binOpParts op = (segBinOpLambda op, segBinOpNeutral op, segBinOpShape op)
typeCheckSegOp checkLvl (SegScan lvl space scans ts body) =
  checkLvl lvl >> checkScanRed space (map binOpParts scans) ts body
  where
    binOpParts op = (segBinOpLambda op, segBinOpNeutral op, segBinOpShape op)
typeCheckSegOp checkLvl (SegHist lvl space ops ts kbody) = do
  checkLvl lvl
  checkSegSpace space
  mapM_ TC.checkType ts

  TC.binding (scopeOfSegSpace space) $ do
    hist_elem_ts <- forM ops $ \(HistOp hist_shape race_factor dests nes vec_shape op) -> do
      mapM_ (TC.require [Prim int64]) hist_shape
      TC.require [Prim int64] race_factor
      ne_args <- mapM TC.checkArg nes
      mapM_ (TC.require [Prim int64]) $ shapeDims vec_shape

      -- Operator type must match the type of neutral elements.
      let unvec = stripArray $ shapeRank vec_shape
          ne_ts = map TC.argType ne_args
      TC.checkLambda op $ map (TC.noArgAliases . first unvec) $ ne_args ++ ne_args
      unless (ne_ts == lambdaReturnType op) $
        TC.bad $
          TC.TypeError $
            "SegHist operator has return type "
              <> prettyTuple (lambdaReturnType op)
              <> " but neutral element has type "
              <> prettyTuple ne_ts

      -- Arrays must have proper type.
      let full_dest_shape = Shape segment_dims <> hist_shape <> vec_shape
      forM_ (zip ne_ts dests) $ \(ne_t, dest) -> do
        TC.requireI [ne_t `arrayOfShape` full_dest_shape] dest
        TC.consume =<< TC.lookupAliases dest

      pure [ne_t `arrayOfShape` vec_shape | ne_t <- ne_ts]

    checkKernelBody ts kbody

    -- Return type of bucket function must be an index for each
    -- operation followed by the values to write.
    let index_ts = concat [replicate (shapeRank (histShape op)) (Prim int64) | op <- ops]
        expected_ret = index_ts ++ concat hist_elem_ts
    unless (expected_ret == ts) $
      TC.bad $
        TC.TypeError $
          "SegHist body has return type "
            <> prettyTuple ts
            <> " but should have type "
            <> prettyTuple expected_ret
  where
    segment_dims = init $ segSpaceDims space

-- | Shared checker for 'SegMap', 'SegRed' and 'SegScan': each operator
-- is given as its lambda, neutral elements and vectorised shape.  The
-- leading body result types must match the operators' neutral
-- elements (extended by the vectorised shape).
checkScanRed ::
  (TC.Checkable rep) =>
  SegSpace ->
  [(Lambda (Aliases rep), [SubExp], Shape)] ->
  [Type] ->
  KernelBody (Aliases rep) ->
  TC.TypeM rep ()
checkScanRed space ops ts kbody = do
  checkSegSpace space
  mapM_ TC.checkType ts

  TC.binding (scopeOfSegSpace space) $ do
    op_ts <- forM ops $ \(lam, nes, vec_shape) -> do
      mapM_ (TC.require [Prim int64]) $ shapeDims vec_shape
      ne_args <- mapM TC.checkArg nes

      -- Operator type must match the type of neutral elements.
      TC.checkLambda lam $ map TC.noArgAliases $ ne_args ++ ne_args
      let ne_ts = map TC.argType ne_args

      unless (lambdaReturnType lam == ne_ts) $
        TC.bad $
          TC.TypeError "wrong type for operator or neutral elements."

      pure [ne_t `arrayOfShape` vec_shape | ne_t <- ne_ts]

    let expecting = concat op_ts
        got = take (length expecting) ts

    unless (expecting == got) $
      TC.bad $
        TC.TypeError $
          "Wrong return for body (does not match neutral elements; expected "
            <> prettyText expecting
            <> "; found "
            <> prettyText got
            <> ")"

    checkKernelBody ts kbody

-- | Like 'Mapper', but just for 'SegOp's.
data SegOpMapper lvl frep trep m = SegOpMapper
  { mapOnSegOpSubExp :: SubExp -> m SubExp,
    mapOnSegOpLambda :: Lambda frep -> m (Lambda trep),
    mapOnSegOpBody :: KernelBody frep -> m (KernelBody trep),
    mapOnSegOpVName :: VName -> m VName,
    mapOnSegOpLevel :: lvl -> m lvl
  }

-- | A mapper that simply returns the 'SegOp' verbatim.
identitySegOpMapper :: (Monad m) => SegOpMapper lvl rep rep m
identitySegOpMapper = SegOpMapper pure pure pure pure pure

-- Map over a segment space: the flat thread identifier first, then each
-- (index variable, dimension size) pair in order.
mapOnSegSpace ::
  (Monad f) =>
  SegOpMapper lvl frep trep f ->
  SegSpace ->
  f SegSpace
mapOnSegSpace mapper (SegSpace flat_id gtids_and_dims) = do
  flat_id' <- mapOnSegOpVName mapper flat_id
  gtids_and_dims' <-
    forM gtids_and_dims $ \(gtid, dim) -> do
      gtid' <- mapOnSegOpVName mapper gtid
      dim' <- mapOnSegOpSubExp mapper dim
      pure (gtid', dim')
  pure $ SegSpace flat_id' gtids_and_dims'

-- Map over a scan/reduction operator: its lambda, neutral elements and
-- vector shape, in that order.  Commutativity is kept as-is.
mapSegBinOp ::
  (Monad m) =>
  SegOpMapper lvl frep trep m ->
  SegBinOp frep ->
  m (SegBinOp trep)
mapSegBinOp mapper (SegBinOp comm red_op nes shape) = do
  red_op' <- mapOnSegOpLambda mapper red_op
  nes' <- mapM (mapOnSegOpSubExp mapper) nes
  dims' <- mapM (mapOnSegOpSubExp mapper) $ shapeDims shape
  pure $ SegBinOp comm red_op' nes' $ Shape dims'

-- | Apply a 'SegOpMapper' to the given 'SegOp'.
--
-- Components are visited in declaration order: level, segment space,
-- operators (if any), result types, and finally the kernel body.
mapSegOpM ::
  (Monad m) =>
  SegOpMapper lvl frep trep m ->
  SegOp lvl frep ->
  m (SegOp lvl trep)
mapSegOpM mapper segop =
  case segop of
    SegMap lvl space ts body -> do
      lvl' <- onLevel lvl
      space' <- onSpace space
      -- Map results may be accumulators, whose names are also mapped.
      ts' <- mapM (mapOnSegOpType mapper) ts
      body' <- onBody body
      pure $ SegMap lvl' space' ts' body'
    SegRed lvl space reds ts body -> do
      lvl' <- onLevel lvl
      space' <- onSpace space
      reds' <- mapM (mapSegBinOp mapper) reds
      ts' <- mapM onPlainType ts
      body' <- onBody body
      pure $ SegRed lvl' space' reds' ts' body'
    SegScan lvl space scans ts body -> do
      lvl' <- onLevel lvl
      space' <- onSpace space
      scans' <- mapM (mapSegBinOp mapper) scans
      ts' <- mapM onPlainType ts
      body' <- onBody body
      pure $ SegScan lvl' space' scans' ts' body'
    SegHist lvl space ops ts body -> do
      lvl' <- onLevel lvl
      space' <- onSpace space
      ops' <- mapM onHistOp ops
      ts' <- mapM onPlainType ts
      body' <- onBody body
      pure $ SegHist lvl' space' ops' ts' body'
  where
    onLevel = mapOnSegOpLevel mapper
    onSpace = mapOnSegSpace mapper
    onBody = mapOnSegOpBody mapper
    onSubExp = mapOnSegOpSubExp mapper
    -- Result types of the operator-carrying SegOps go through the
    -- generic 'mapOnType' driven by the subexpression mapping.
    onPlainType = mapOnType onSubExp
    onHistOp (HistOp w rf arrs nes shape op) = do
      w' <- mapM onSubExp w
      rf' <- onSubExp rf
      arrs' <- mapM (mapOnSegOpVName mapper) arrs
      nes' <- mapM onSubExp nes
      dims' <- mapM onSubExp $ shapeDims shape
      op' <- mapOnSegOpLambda mapper op
      pure $ HistOp w' rf' arrs' nes' (Shape dims') op'

-- Map over the names and sizes occurring in a type.
mapOnSegOpType ::
  (Monad m) =>
  SegOpMapper lvl frep trep m ->
  Type ->
  m Type
mapOnSegOpType mapper t =
  case t of
    Prim {} -> pure t
    Acc acc ispace ts u -> do
      acc' <- mapOnSegOpVName mapper acc
      ispace' <- mapM (mapOnSegOpSubExp mapper) ispace
      ts' <- mapM (bitraverse (mapM (mapOnSegOpSubExp mapper)) pure) ts
      pure $ Acc acc' ispace' ts' u
    Array et shape u -> do
      shape' <- mapM (mapOnSegOpSubExp mapper) shape
      pure $ Array et shape' u
    Mem s -> pure $ Mem s

-- Rephrase an operator: only the lambda changes representation.
rephraseBinOp ::
  (Monad f) =>
  Rephraser f from rep ->
  SegBinOp from ->
  f (SegBinOp rep)
rephraseBinOp r (SegBinOp comm lam nes shape) = do
  lam' <- rephraseLambda r lam
  pure $ SegBinOp comm lam' nes shape

-- Rephrase a kernel body: its decoration and statements; the results
-- are representation-independent.
rephraseKernelBody ::
  (Monad f) =>
  Rephraser f from rep ->
  KernelBody from ->
  f (KernelBody rep)
rephraseKernelBody r (KernelBody dec stms res) = do
  dec' <- rephraseBodyDec r dec
  stms' <- mapM (rephraseStm r) stms
  pure $ KernelBody dec' stms' res

instance RephraseOp (SegOp lvl) where
  rephraseInOp r segop =
    case segop of
      SegMap lvl space ts body ->
        SegMap lvl space ts <$> rephraseKernelBody r body
      SegRed lvl space reds ts body -> do
        reds' <- mapM (rephraseBinOp r) reds
        body' <- rephraseKernelBody r body
        pure $ SegRed lvl space reds' ts body'
      SegScan lvl space scans ts body -> do
        scans' <- mapM (rephraseBinOp r) scans
        body' <- rephraseKernelBody r body
        pure $ SegScan lvl space scans' ts body'
      SegHist lvl space hists ts body -> do
        hists' <- forM hists $ \(HistOp w rf arrs nes shape op) ->
          HistOp w rf arrs nes shape <$> rephraseLambda r op
        body' <- rephraseKernelBody r body
        pure $ SegHist lvl space hists' ts body'

-- | A helper for defining 'TraverseOpStms'.
traverseSegOpStms :: (Monad m) => OpStmsTraverser m (SegOp lvl rep) rep
traverseSegOpStms f segop = mapSegOpM mapper segop
  where
    -- Names bound by the segment space are in scope everywhere inside.
    seg_scope = scopeOfSegSpace (segSpace segop)
    -- Lambdas see the segment-space scope plus their own parameters.
    f' scope = f (seg_scope <> scope)
    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = traverseLambdaStms f',
          mapOnSegOpBody = onBody
        }
    onBody (KernelBody dec stms res) =
      KernelBody dec <$> f seg_scope stms <*> pure res

-- Substitution applies 'substituteNames' to every component of the SegOp.
instance
  (ASTRep rep, Substitute lvl) =>
  Substitute (SegOp lvl rep)
  where
  substituteNames subst = runIdentity . mapSegOpM substitute
    where
      substitute =
        SegOpMapper
          { mapOnSegOpSubExp = pure . substituteNames subst,
            mapOnSegOpLambda = pure . substituteNames subst,
            mapOnSegOpBody = pure . substituteNames subst,
            mapOnSegOpVName = pure . substituteNames subst,
            mapOnSegOpLevel = pure . substituteNames subst
          }

-- Renaming treats the names of the segment space as bound by the SegOp.
instance (ASTRep rep, ASTConstraints lvl) => Rename (SegOp lvl rep) where
  rename op =
    renameBound (M.keys (scopeOfSegSpace (segSpace op))) $ mapSegOpM renamer op
    where
      renamer = SegOpMapper rename rename rename rename rename

-- The free variables are those of every component, minus the names
-- bound by the segment space.
instance (ASTRep rep, FreeIn lvl) => FreeIn (SegOp lvl rep) where
  freeIn' e =
    fvBind (namesFromList $ M.keys $ scopeOfSegSpace (segSpace e)) $
      flip execState mempty $
        mapSegOpM free e
    where
      -- Accumulate the free variables of each visited component in the
      -- state, leaving the component itself unchanged.
      walk f x = modify (<> f x) >> pure x
      free =
        SegOpMapper
          { mapOnSegOpSubExp = walk freeIn',
            mapOnSegOpLambda = walk freeIn',
            mapOnSegOpBody = walk freeIn',
            mapOnSegOpVName = walk freeIn',
            mapOnSegOpLevel = walk freeIn'
          }

instance (OpMetrics (Op rep)) => OpMetrics (SegOp lvl rep) where
  opMetrics (SegMap _ _ _ body) =
    inside "SegMap" $ kernelBodyMetrics body
  opMetrics (SegRed _ _ reds _ body) =
    inside "SegRed" $ do
      mapM_ (inside "SegBinOp" . lambdaMetrics . segBinOpLambda) reds
      kernelBodyMetrics body
  opMetrics (SegScan _ _ scans _ body) =
    inside "SegScan" $ do
      mapM_ (inside "SegBinOp" . lambdaMetrics . segBinOpLambda) scans
      kernelBodyMetrics body
  opMetrics (SegHist _ _ ops _ body) =
    inside "SegHist" $ do
      mapM_ (lambdaMetrics . histOp) ops
      kernelBodyMetrics body

-- Each dimension is printed as @gtid < size@, followed by the flat
-- thread identifier as @(~phys)@.
instance Pretty SegSpace where
  pretty (SegSpace phys dims) =
    apply
      ( do
          (i, d) <- dims
          pure $ pretty i <+> "<" <+> pretty d
      )
      <+> parens ("~" <> pretty phys)

-- Neutral elements, vector shape, an optional "commutative" marker, and
-- the operator lambda.
instance (PrettyRep rep) => Pretty (SegBinOp rep) where
  pretty (SegBinOp comm lam nes shape) =
    PP.braces (PP.commasep $ map pretty nes)
      <> PP.comma
      PP.</> pretty shape
      <> PP.comma
      PP.</> comm'
      <> pretty lam
    where
      comm' = case comm of
        Commutative -> "commutative "
        Noncommutative -> mempty

instance (PrettyRep rep, PP.Pretty lvl) => PP.Pretty (SegOp lvl rep) where
  pretty (SegMap lvl space ts body) =
    "segmap"
      <> pretty lvl
      PP.</> PP.align (pretty space)
      <+> PP.colon
      <+> ppTuple' (map pretty ts)
      <+> PP.nestedBlock "{" "}" (pretty body)
  pretty (SegRed lvl space reds ts body) =
    "segred"
      <> pretty lvl
      PP.</> PP.align (pretty space)
      PP.</> PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map pretty reds)
      PP.</> PP.colon
      <+> ppTuple' (map pretty ts)
      <+> PP.nestedBlock "{" "}" (pretty body)
  pretty (SegScan lvl space scans ts body) =
    "segscan"
      <> pretty lvl
      PP.</> PP.align (pretty space)
      PP.</> PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map pretty scans)
      PP.</> PP.colon
      <+> ppTuple' (map pretty ts)
      <+> PP.nestedBlock "{" "}" (pretty body)
  pretty (SegHist lvl space ops ts body) =
    "seghist"
      <> pretty lvl
      PP.</> PP.align (pretty space)
      PP.</> PP.parens (mconcat $ intersperse (PP.comma <> PP.line) $ map ppOp ops)
      PP.</> PP.colon
      <+> ppTuple' (map pretty ts)
      <+> PP.nestedBlock "{" "}" (pretty body)
    where
      -- Histogram shape, race factor, destinations, neutral elements,
      -- vector shape, and operator.
      ppOp (HistOp w rf dests nes shape op) =
        pretty w
          <> PP.comma
          <+> pretty rf
          <> PP.comma
          PP.</> PP.braces (PP.commasep $ map pretty dests)
          <> PP.comma
          PP.</> PP.braces (PP.commasep $ map pretty nes)
          <> PP.comma
          PP.</> pretty shape
          <> PP.comma
          PP.</> pretty op

-- Alias analysis is applied to operator lambdas and to the kernel body;
-- everything else is left unchanged.
instance CanBeAliased (SegOp lvl) where
  addOpAliases aliases = runIdentity . mapSegOpM alias
    where
      alias =
        SegOpMapper
          pure
          (pure . Alias.analyseLambda aliases)
          (pure . aliasAnalyseKernelBody aliases)
          pure
          pure

-- | Attach simplifier wisdom to a kernel body and all its statements.
informKernelBody :: (Informing rep) => KernelBody rep -> KernelBody (Wise rep)
informKernelBody (KernelBody dec stms res) =
  mkWiseKernelBody dec (informStms stms) res

instance CanBeWise (SegOp lvl) where
  addOpWisdom = runIdentity . mapSegOpM add
    where
      add =
        SegOpMapper
          pure
          (pure . informLambda)
          (pure . informKernelBody)
          pure
          pure

-- Lets the symbol table see through indexing into result @k@ of a
-- 'SegMap': when that result is a simplifiable 'Returns' of a variable,
-- its value at indices @is@ is expressed either as a primitive
-- expression of the indices, or as an index into an array that is
-- available outside the SegMap.  Other SegOps are opaque.
instance (ASTRep rep) => ST.IndexOp (SegOp lvl rep) where
  indexOp vtable k (SegMap _ space _ kbody) is = do
    Returns ResultMaySimplify _ se <- maybeNth k $ kernelBodyResult kbody
    guard $ length gtids <= length is
    -- Bind each segment index variable to the corresponding index.
    let idx_table = M.fromList $ zip gtids $ map (ST.Indexed mempty . untyped) is
        idx_table' = foldl' expandIndexedTable idx_table $ kernelBodyStms kbody
    case se of
      Var v -> M.lookup v idx_table'
      _ -> Nothing
    where
      (gtids, _) = unzip $ unSegSpace space
      -- Indexes in excess of what is used to index through the
      -- segment dimensions.
      excess_is = drop (length gtids) is

      -- Extend the table with any single-name statement that is either
      -- a primitive expression of known entries, or an index into an
      -- outside-available array whose remaining dimensions match the
      -- excess indices.
      expandIndexedTable table stm
        | [v] <- patNames $ stmPat stm,
          Just (pe, cs) <-
            runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm =
            M.insert v (ST.Indexed (stmCerts stm <> cs) pe) table
        | [v] <- patNames $ stmPat stm,
          BasicOp (Index arr slice) <- stmExp stm,
          length (sliceDims slice) == length excess_is,
          arr `ST.available` vtable,
          Just (slice', cs) <- asPrimExpSlice table slice =
            let idx =
                  ST.IndexedArray
                    (stmCerts stm <> cs)
                    arr
                    (fixSlice (fmap isInt64 slice') excess_is)
             in M.insert v idx table
        | otherwise =
            table

      asPrimExpSlice table =
        runWriterT . traverse (primExpFromSubExpM (asPrimExp table))

      -- Known table entries contribute their certificates; otherwise a
      -- primitive variable from the outer symbol table is a leaf.
      asPrimExp table v
        | Just (ST.Indexed cs e) <- M.lookup v table = tell cs >> pure e
        | Just (Prim pt) <- ST.lookupType v vtable =
            pure $ LeafExp v pt
        | otherwise = lift Nothing
  indexOp _ _ _ _ = Nothing

instance (ASTConstraints lvl) => IsOp (SegOp lvl) where
  cheapOp _ = False
  safeOp _ = True
  -- Every result conservatively depends on all free variables of the op.
  opDependencies op = replicate (length (segOpType op)) (freeIn op)

--- Simplification

instance Engine.Simplifiable SegSpace where
  simplify (SegSpace phys dims) =
    SegSpace phys <$> mapM (traverse Engine.simplify) dims

instance Engine.Simplifiable KernelResult where
  simplify (Returns manifest cs what) =
    Returns manifest <$> Engine.simplify cs <*> Engine.simplify what
  simplify (WriteReturns cs a res) =
    WriteReturns
      <$> Engine.simplify cs
      <*> Engine.simplify a
      <*> Engine.simplify res
  simplify (TileReturns cs dims what) =
    TileReturns <$> Engine.simplify cs <*> Engine.simplify dims <*> Engine.simplify what
  simplify (RegTileReturns cs dims_n_tiles what) =
    RegTileReturns
      <$> Engine.simplify cs
      <*> Engine.simplify dims_n_tiles
      <*> Engine.simplify what

-- | Build a wisdom-annotated kernel body, computing its decoration from
-- the subexpressions of the kernel results.
mkWiseKernelBody ::
  (Informing rep) =>
  BodyDec rep ->
  Stms (Wise rep) ->
  [KernelResult] ->
  KernelBody (Wise rep)
mkWiseKernelBody dec stms res =
  let Body dec' _ _ = mkWiseBody dec stms $ subExpsRes res_vs
   in KernelBody dec' stms res
  where
    res_vs = map kernelResultSubExp res

-- | Build a kernel body in a builder monad, computing its decoration as
-- 'mkBodyM' would for the subexpressions of the kernel results.
mkKernelBodyM ::
  (MonadBuilder m) =>
  Stms (Rep m) ->
  [KernelResult] ->
  m (KernelBody (Rep m))
mkKernelBodyM stms kres = do
  Body dec' _ _ <- mkBodyM stms $ subExpsRes res_ses
  pure $ KernelBody dec' stms kres
  where
    res_ses = map kernelResultSubExp kres

-- | Simplify a kernel body, returning the simplified body and the
-- statements hoisted out of it.  Statements are not hoisted if they
-- mention segment-space names, are ops, are blocked by the engine's
-- parallel blocker, consume or are consumed, or are device-migrated.
simplifyKernelBody ::
  (Engine.SimplifiableRep rep, BodyDec rep ~ ()) =>
  SegSpace ->
  KernelBody (Wise rep) ->
  Engine.SimpleM rep (KernelBody (Wise rep), Stms (Wise rep))
simplifyKernelBody space (KernelBody _ stms res) = do
  par_blocker <- Engine.asksEngineEnv $ Engine.blockHoistPar . Engine.envHoistBlockers

  let blocker =
        Engine.hasFree bound_here
          `Engine.orIf` Engine.isOp
          `Engine.orIf` par_blocker
          `Engine.orIf` Engine.isConsumed
          `Engine.orIf` Engine.isConsuming
          `Engine.orIf` Engine.isDeviceMigrated

  -- Ensure we do not try to use anything that is consumed in the result.
  (body_res, body_stms, hoisted) <-
    Engine.localVtable (flip (foldl' (flip ST.consume)) (foldMap consumedInResult res))
      . Engine.localVtable (<> scope_vtable)
      . Engine.localVtable (\vtable -> vtable {ST.simplifyMemory = True})
      . Engine.enterLoop
      $ Engine.blockIf blocker stms
      $ do
        res' <-
          Engine.localVtable (ST.hideCertified $ namesFromList $ M.keys $ scopeOf stms) $
            mapM Engine.simplify res
        pure (res', UT.usages $ freeIn res')

  pure (mkWiseKernelBody () body_stms body_res, hoisted)
  where
    scope_vtable = segSpaceSymbolTable space
    bound_here = namesFromList $ M.keys $ scopeOfSegSpace space

    -- Arrays written in place by the results are consumed by them.
    consumedInResult (WriteReturns _ arr _) = [arr]
    consumedInResult _ = []

-- | Simplify a lambda, additionally blocking hoisting of migrated
-- statements.
simplifyLambda ::
  (Engine.SimplifiableRep rep) =>
  Names ->
  Lambda (Wise rep) ->
  Engine.SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
simplifyLambda bound = Engine.blockMigrated . Engine.simplifyLambda bound

-- | A symbol table containing the flat thread identifier (as an @i64@
-- index) and every segment index variable as a loop variable bounded by
-- its dimension.
segSpaceSymbolTable :: (ASTRep rep) => SegSpace -> ST.SymbolTable rep
segSpaceSymbolTable (SegSpace flat gtids_and_dims) =
  foldl' f (ST.fromScope $ M.singleton flat $ IndexName Int64) gtids_and_dims
  where
    f vtable (gtid, dim) = ST.insertLoopVar gtid Int64 dim vtable

-- | Simplify a scan/reduction operator: its lambda (treating the flat
-- thread identifier as bound), its vector shape and neutral elements.
simplifySegBinOp ::
  (Engine.SimplifiableRep rep) =>
  VName ->
  SegBinOp (Wise rep) ->
  Engine.SimpleM rep (SegBinOp (Wise rep), Stms (Wise rep))
simplifySegBinOp phys_id (SegBinOp comm lam nes shape) = do
  (lam', hoisted) <-
    Engine.localVtable (\vtable -> vtable {ST.simplifyMemory = True}) $
      simplifyLambda (oneName phys_id) lam
  shape' <- Engine.simplify shape
  nes' <- mapM Engine.simplify nes
  pure (SegBinOp comm lam' nes' shape', hoisted)

-- | Simplify the given 'SegOp'.
simplifySegOp ::
  ( Engine.SimplifiableRep rep,
    BodyDec rep ~ (),
    Engine.Simplifiable lvl
  ) =>
  SegOp lvl (Wise rep) ->
  Engine.SimpleM rep (SegOp lvl (Wise rep), Stms (Wise rep))
simplifySegOp (SegMap lvl space ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (kbody', body_hoisted) <- simplifyKernelBody space kbody
  pure
    ( SegMap lvl' space' ts' kbody',
      body_hoisted
    )
simplifySegOp (SegRed lvl space reds ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  -- Operators are simplified with the segment-space names in scope.
  (reds', reds_hoisted) <-
    Engine.localVtable (<> scope_vtable) $
      mapAndUnzipM (simplifySegBinOp (segFlat space)) reds
  (kbody', body_hoisted) <- simplifyKernelBody space kbody

  -- Statements hoisted from the operators precede those from the body.
  pure
    ( SegRed lvl' space' reds' ts' kbody',
      mconcat reds_hoisted <> body_hoisted
    )
  where
    scope = scopeOfSegSpace space
    scope_vtable = ST.fromScope scope
simplifySegOp (SegScan lvl space scans ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (scans', scans_hoisted) <-
    Engine.localVtable (<> scope_vtable) $
      mapAndUnzipM (simplifySegBinOp (segFlat space)) scans
  (kbody', body_hoisted) <- simplifyKernelBody space kbody

  pure
    ( SegScan lvl' space' scans' ts' kbody',
      mconcat scans_hoisted <> body_hoisted
    )
  where
    scope = scopeOfSegSpace space
    scope_vtable = ST.fromScope scope
simplifySegOp (SegHist lvl space ops ts kbody) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)

  -- The histogram destination arrays are marked as consumed while the
  -- operators and body are simplified.
  Engine.localVtable (flip (foldr ST.consume) $ concatMap histDest ops) $ do
    (ops', ops_hoisted) <- fmap unzip . forM ops $ \(HistOp w rf arrs nes dims lam) -> do
      w' <- Engine.simplify w
      rf' <- Engine.simplify rf
      arrs' <- Engine.simplify arrs
      nes' <- Engine.simplify nes
      dims' <- Engine.simplify dims
      (lam', op_hoisted) <-
        Engine.localVtable (<> scope_vtable) $
          Engine.localVtable (\vtable -> vtable {ST.simplifyMemory = True}) $
            simplifyLambda (oneName (segFlat space)) lam
      pure
        ( HistOp w' rf' arrs' nes' dims' lam',
          op_hoisted
        )

    (kbody', body_hoisted) <- simplifyKernelBody space kbody

    pure
      ( SegHist lvl' space' ops' ts' kbody',
        mconcat ops_hoisted <> body_hoisted
      )
  where
    scope = scopeOfSegSpace space
    scope_vtable = ST.fromScope scope

-- | Does this rep contain 'SegOp's in its t'Op's?  A rep must be an
-- instance of this class for the simplification rules to work.
class HasSegOp rep where
  -- | The level type of the rep's 'SegOp's.
  type SegOpLevel rep

  -- | View an op as a 'SegOp', if it is one.
  asSegOp :: Op rep -> Maybe (SegOp (SegOpLevel rep) rep)

  -- | Embed a 'SegOp' as an op of the rep.
  segOp :: SegOp (SegOpLevel rep) rep -> Op rep

-- | Simplification rules for simplifying 'SegOp's.
segOpRules ::
  (HasSegOp rep, BuilderOps rep, Buildable rep, Aliased rep) =>
  RuleBook rep
segOpRules =
  ruleBook [RuleOp segOpRuleTopDown] [RuleOp segOpRuleBottomUp]

-- | Run 'topDownSegOp' if the op is a 'SegOp'; otherwise skip.
segOpRuleTopDown ::
  (HasSegOp rep, BuilderOps rep, Buildable rep) =>
  TopDownRuleOp rep
segOpRuleTopDown vtable pat dec op
  | Just op' <- asSegOp op =
      topDownSegOp vtable pat dec op'
  | otherwise =
      Skip

-- | Run 'bottomUpSegOp' if the op is a 'SegOp'; otherwise skip.
segOpRuleBottomUp ::
  (HasSegOp rep, BuilderOps rep, Aliased rep) =>
  BottomUpRuleOp rep
segOpRuleBottomUp vtable pat dec op
  | Just op' <- asSegOp op =
      bottomUpSegOp vtable pat dec op'
  | otherwise =
      Skip

topDownSegOp ::
  (HasSegOp rep, BuilderOps rep, Buildable rep) =>
  ST.SymbolTable rep ->
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  SegOp (SegOpLevel rep) rep ->
  Rule rep
-- If a SegOp produces something invariant to the SegOp, turn it
-- into a replicate.
--
-- A result counts as invariant when it is a certificate-free,
-- simplifiable 'Returns' of a constant or of a variable already in the
-- outer symbol table.  Such a result is bound outside as a 'Replicate'
-- over the segment dimensions and dropped from the SegMap.
topDownSegOp vtable (Pat kpes) dec (SegMap lvl space ts (KernelBody _ kstms kres)) = Simplify $ do
  (ts', kpes', kres') <-
    unzip3 <$> filterM checkForInvarianceResult (zip3 ts kpes kres)

  -- Check if we did anything at all.
  when (kres == kres') cannotSimplify

  kbody <- mkKernelBodyM kstms kres'
  addStm $
    Let (Pat kpes') dec $
      Op $
        segOp $
          SegMap lvl space ts' kbody
  where
    isInvariant Constant {} = True
    isInvariant (Var v) = isJust $ ST.lookup v vtable

    -- Returns False (drop the result) after binding the replicate.
    checkForInvarianceResult (_, pe, Returns rm cs se)
      | cs == mempty,
        rm == ResultMaySimplify,
        isInvariant se = do
          letBindNames [patElemName pe] $
            BasicOp $
              Replicate (Shape $ segSpaceDims space) se
          pure False
    checkForInvarianceResult _ =
      pure True

-- If a SegRed contains two reduction operations that have the same
-- vector shape, merge them together.  This saves on communication
-- overhead, but can in principle lead to more shared memory usage.
--
-- Only adjacent operators are grouped ('groupBy'), and only when their
-- shared vector shape has nonzero rank.
topDownSegOp _ (Pat pes) _ (SegRed lvl space ops ts kbody)
  | length ops > 1,
    op_groupings <-
      groupBy sameShape $
        zip ops $
          chunks (map (length . segBinOpNeutral) ops) $
            zip3 red_pes red_ts red_res,
    any ((> 1) . length) op_groupings = Simplify $ do
      let (ops', aux) = unzip $ mapMaybe combineOps op_groupings
          (red_pes', red_ts', red_res') = unzip3 $ concat aux
          pes' = red_pes' ++ map_pes
          ts' = red_ts' ++ map_ts
          kbody' = kbody {kernelBodyResult = red_res' ++ map_res}
      letBind (Pat pes') $ Op $ segOp $ SegRed lvl space ops' ts' kbody'
  where
    -- Split patterns, types and results into reduction and map parts.
    (red_pes, map_pes) = splitAt (segBinOpResults ops) pes
    (red_ts, map_ts) = splitAt (segBinOpResults ops) ts
    (red_res, map_res) = splitAt (segBinOpResults ops) $ kernelBodyResult kbody

    sameShape (op1, _) (op2, _) =
      segBinOpShape op1 == segBinOpShape op2
        && shapeRank (segBinOpShape op1) > 0

    combineOps [] = Nothing
    combineOps (x : xs) = Just $ foldl' combine x xs

    -- The merged lambda takes all x-parameters (op1's then op2's),
    -- followed by all y-parameters, and returns both results in order.
    combine (op1, op1_aux) (op2, op2_aux) =
      let lam1 = segBinOpLambda op1
          lam2 = segBinOpLambda op2
          (op1_xparams, op1_yparams) =
            splitAt (length (segBinOpNeutral op1)) $ lambdaParams lam1
          (op2_xparams, op2_yparams) =
            splitAt (length (segBinOpNeutral op2)) $ lambdaParams lam2
          lam =
            Lambda
              { lambdaParams =
                  op1_xparams ++ op2_xparams ++ op1_yparams ++ op2_yparams,
                lambdaReturnType = lambdaReturnType lam1 ++ lambdaReturnType lam2,
                lambdaBody =
                  mkBody (bodyStms (lambdaBody lam1) <> bodyStms (lambdaBody lam2)) $
                    bodyResult (lambdaBody lam1) <> bodyResult (lambdaBody lam2)
              }
       in ( SegBinOp
              { -- Commutativity of the merged operator combines both.
                segBinOpComm = segBinOpComm op1 <> segBinOpComm op2,
                segBinOpLambda = lam,
                segBinOpNeutral = segBinOpNeutral op1 ++ segBinOpNeutral op2,
                segBinOpShape = segBinOpShape op1 -- Same as shape of op2 due to the grouping.
              },
            op1_aux ++ op2_aux
          )
topDownSegOp _ _ _ _ = Skip

-- A convenient way of operating on the type and body of a SegOp,
-- without worrying about exactly what kind it is.
-- The components are: the result types, the kernel body, the number of
-- leading non-map (operator) results, and a function that rebuilds the
-- same kind of SegOp from new result types and a new body.
segOpGuts ::
  SegOp (SegOpLevel rep) rep ->
  ( [Type],
    KernelBody rep,
    Int,
    [Type] -> KernelBody rep -> SegOp (SegOpLevel rep) rep
  )
segOpGuts (SegMap lvl space kts body) =
  (kts, body, 0, SegMap lvl space)
segOpGuts (SegScan lvl space ops kts body) =
  (kts, body, segBinOpResults ops, SegScan lvl space ops)
segOpGuts (SegRed lvl space ops kts body) =
  (kts, body, segBinOpResults ops, SegRed lvl space ops)
segOpGuts (SegHist lvl space ops kts body) =
  (kts, body, sum $ map (length . histDest) ops, SegHist lvl space ops)

-- | Bottom-up simplification of a 'SegOp', given the symbol table and
-- the usage table of the enclosing code.
bottomUpSegOp ::
  (Aliased rep, HasSegOp rep, BuilderOps rep) =>
  (ST.SymbolTable rep, UT.UsageTable) ->
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  SegOp (SegOpLevel rep) rep ->
  Rule rep
-- Some SegOp results can be moved outside the SegOp, which can
-- simplify further analysis.
bottomUpSegOp (_vtable, used) (Pat kpes) dec segop
  -- Remove dead results.  This is a bit tricky to do with scan/red
  -- results, so we only deal with map results for now.
  | (_, kpes', kts', kres') <- unzip4 $ filter keep $ zip4 [0 ..] kpes kts kres,
    kpes' /= kpes = Simplify $ do
      kbody' <- localScope (scopeOfSegSpace space) $ mkKernelBodyM kstms kres'
      addStm $ Let (Pat kpes') dec $ Op $ segOp $ mk_segop kts' kbody'
  where
    space = segSpace segop
    (kts, KernelBody _ kstms kres, num_nonmap_results, mk_segop) =
      segOpGuts segop
    -- Operator results are always kept; map results only if used.
    keep (i, pe, _, _) =
      i < num_nonmap_results || patElemName pe `UT.used` used
bottomUpSegOp (vtable, _used) (Pat kpes) dec segop = Simplify $ do
  -- Iterate through the bindings.  For each, we check whether it is
  -- in kres and can be moved outside.  If so, we remove it from kres
  -- and kpes and make it a binding outside.  We have to be careful
  -- not to remove anything that is passed on to a scan/map/histogram
  -- operation.  Fortunately, these are always first in the result
  -- list.
  (kpes', kts', kres', kstms') <-
    localScope (scopeOfSegSpace space) $
      foldM distribute (kpes, kts, kres, mempty) kstms

  when (kpes' == kpes) cannotSimplify

  kbody' <- localScope (scopeOfSegSpace space) $ mkKernelBodyM kstms' kres'

  addStm $ Let (Pat kpes') dec $ Op $ segOp $ mk_segop kts' kbody'
  where
    (kts, KernelBody _ kstms kres, num_nonmap_results, mk_segop) =
      segOpGuts segop
    free_in_kstms = foldMap freeIn kstms
    space = segSpace segop

    -- Recognise @arr[gtid_1, ..., gtid_n, rest...]@ where the slice
    -- starts with exactly the segment index variables, and where the
    -- array, the remaining slice and the certificates only mention names
    -- available in the outer symbol table.
    sliceWithGtidsFixed stm
      | Let _ aux (BasicOp (Index arr slice)) <- stm,
        space_slice <- map (DimFix . Var . fst) $ unSegSpace space,
        space_slice `isPrefixOf` unSlice slice,
        remaining_slice <- Slice $ drop (length space_slice) (unSlice slice),
        all (isJust . flip ST.lookup vtable) $
          namesToList $
            freeIn arr <> freeIn remaining_slice <> freeIn (stmAuxCerts aux) =
          Just (remaining_slice, arr)
      | otherwise =
          Nothing

    -- If such an index is directly returned as a map result, compute
    -- the whole slice outside the SegOp instead (full outer dimensions,
    -- then the remaining slice), and drop the result.  The statement is
    -- kept inside only if other kernel statements still use it.
    distribute (kpes', kts', kres', kstms') stm
      | Let (Pat [pe]) _ _ <- stm,
        Just (Slice remaining_slice, arr) <- sliceWithGtidsFixed stm,
        Just (kpe, kpes'', kts'', kres'') <- isResult kpes' kts' kres' pe = do
          let outer_slice =
                map
                  ( \d ->
                      DimSlice
                        (constant (0 :: Int64))
                        d
                        (constant (1 :: Int64))
                  )
                  $ segSpaceDims space
              index kpe' =
                letBindNames [patElemName kpe'] . BasicOp . Index arr $
                  Slice $
                    outer_slice <> remaining_slice
          precopy <- newVName $ baseString (patElemName kpe) <> "_precopy"
          index kpe {patElemName = precopy}
          -- NOTE(review): a 'Replicate' with an empty shape appears to be
          -- used to obtain a fresh copy of the sliced array — confirm.
          letBindNames [patElemName kpe] $ BasicOp $ Replicate mempty $ Var precopy
          pure
            ( kpes'',
              kts'',
              kres'',
              if patElemName pe `nameIn` free_in_kstms
                then kstms' <> oneStm stm
                else kstms'
            )
    distribute (kpes', kts', kres', kstms') stm =
      pure (kpes', kts', kres', kstms' <> oneStm stm)

    -- Find the unique map result that returns @pe@ directly; operator
    -- results (the first @num_nonmap_results@) are never selected.
    isResult kpes' kts' kres' pe =
      case partition matches $ zip3 kpes' kts' kres' of
        ([(kpe, _, _)], kpes_and_kres)
          | Just i <- elemIndex kpe kpes,
            i >= num_nonmap_results,
            (kpes'', kts'', kres'') <- unzip3 kpes_and_kres ->
              Just (kpe, kpes'', kts'', kres'')
        _ -> Nothing
      where
        matches (_, _, Returns _ _ (Var v)) = v == patElemName pe
        matches _ = False

--- Memory

-- | Adjust the given per-result 'ExpReturns' for a kernel body: results
-- produced by 'WriteReturns' return the array that was written into;
-- all others are kept as given.
kernelBodyReturns ::
  (Mem rep inner, HasScope rep m, Monad m) =>
  KernelBody somerep ->
  [ExpReturns] ->
  m [ExpReturns]
kernelBodyReturns = zipWithM correct . kernelBodyResult
  where
    correct (WriteReturns _ arr _) _ = varReturns arr
    correct _ ret = pure ret

-- | Like 'segOpType', but for memory representations.
segOpReturns ::
  (Mem rep inner, Monad m, HasScope rep m) =>
  SegOp lvl rep ->
  m [ExpReturns]
segOpReturns k@(SegMap _ _ _ kbody) =
  kernelBodyReturns kbody . extReturns =<< opType k
segOpReturns k@(SegRed _ _ _ _ kbody) =
  kernelBodyReturns kbody . extReturns =<< opType k
segOpReturns k@(SegScan _ _ _ _ kbody) =
  kernelBodyReturns kbody . extReturns =<< opType k
-- A SegHist returns its destination arrays.
segOpReturns (SegHist _ _ ops _ _) =
  concat <$> mapM (mapM varReturns . histDest) ops

-- NOTE(review): the next line looks like tar archive header residue,
-- not Haskell source — confirm and strip when extracting.
futhark-0.25.27/src/Futhark/IR/Seq.hs000066400000000000000000000030221475065116200171300ustar00rootroot00000000000000
{-# LANGUAGE TypeFamilies #-}

-- | A sequential representation.
module Futhark.IR.Seq
  ( Seq,

    -- * Simplification
    simplifyProg,

    -- * Module re-exports
    module Futhark.IR.Prop,
    module Futhark.IR.Traversals,
    module Futhark.IR.Pretty,
    module Futhark.IR.Syntax,
  )
where

import Futhark.Builder
import Futhark.Construct
import Futhark.IR.Pretty
import Futhark.IR.Prop
import Futhark.IR.Syntax
import Futhark.IR.Traversals
import Futhark.IR.TypeCheck qualified as TC
import Futhark.Optimise.Simplify qualified as Simplify
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rules
import Futhark.Pass

-- | The phantom type for the Seq representation.
data Seq

-- All decorations use the defaults of 'RepTypes'.
instance RepTypes Seq

instance ASTRep Seq where
  expTypesFromPat pat = pure (expExtTypesFromPat pat)

-- Seq has no ops beyond 'NoOp', so there is nothing to check.
instance TC.Checkable Seq where
  checkOp NoOp = pure ()

instance Buildable Seq where
  mkBody stms res = Body () stms res
  mkExpPat ids _e = basicPat ids
  mkExpDec _pat _e = ()
  mkLetNames = simpleMkLetNames

instance BuilderOps Seq

-- With no ops there are no nested statements to traverse.
instance TraverseOpStms Seq where
  traverseOpStms _ op = pure op

instance PrettyRep Seq

instance BuilderOps (Engine.Wise Seq)

instance TraverseOpStms (Engine.Wise Seq) where
  traverseOpStms _ op = pure op

-- Simplifier operations for Seq: the (only) op simplifies to 'NoOp'
-- with nothing hoisted.
simpleSeq :: Simplify.SimpleOps Seq
simpleSeq = Simplify.bindableSimpleOps $ \_ -> pure (NoOp, mempty)

-- | Simplify a sequential program.
simplifyProg :: Prog Seq -> PassM (Prog Seq)
simplifyProg = Simplify.simplifyProg simpleSeq standardRules blockers
  where
    -- Presumably: no Seq-specific restrictions on hoisting (the engine's
    -- defaults only) — confirm against Engine.noExtraHoistBlockers.
    blockers = Engine.noExtraHoistBlockers

-- NOTE(review): the next line looks like tar archive header residue,
-- not Haskell source — confirm and strip when extracting.
futhark-0.25.27/src/Futhark/IR/SeqMem.hs000066400000000000000000000040501475065116200175710ustar00rootroot00000000000000
{-# LANGUAGE TypeFamilies #-}

module Futhark.IR.SeqMem
  ( SeqMem,

    -- * Simplification
    simplifyProg,
    simpleSeqMem,

    -- * Module re-exports
    module Futhark.IR.Mem,
  )
where

import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Mem
import Futhark.IR.Mem.Simplify
import Futhark.IR.TypeCheck qualified as TC
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (BuilderOps (..), mkLetNamesB', mkLetNamesB'')

-- | The sequential representation with memory information; its
-- decorations are given by the 'RepTypes' instance below.
data SeqMem

instance RepTypes SeqMem where
  type LetDec SeqMem = LetDecMem
  type FParamInfo SeqMem = FParamMem
  type LParamInfo SeqMem = LParamMem
  type RetType SeqMem = RetTypeMem
  type BranchType SeqMem = BranchTypeMem
  -- The only ops are memory ops ('Alloc') and the empty 'NoOp'.
  type OpC SeqMem = MemOp NoOp

instance ASTRep SeqMem where
  expTypesFromPat = pure . map snd . bodyReturnsFromPat

instance PrettyRep SeqMem

instance TC.Checkable SeqMem where
  -- Allocation sizes must be i64.
  checkOp (Alloc size _) =
    TC.require [Prim int64] size
  checkOp (Inner NoOp) =
    pure ()
  checkFParamDec = checkMemInfo
  checkLParamDec = checkMemInfo
  checkLetBoundDec = checkMemInfo
  checkRetType = mapM_ (TC.checkExtType . declExtTypeOf)
  primFParam name t = pure $ Param mempty name (MemPrim t)
  matchPat = matchPatToExp
  matchReturnType = matchFunctionReturnType
  matchBranchType = matchBranchReturnType
  matchLoopResult = matchLoopResultMem

-- Let-bindings created by the builder are placed in 'DefaultSpace'.
instance BuilderOps SeqMem where
  mkExpDecB _ _ = pure ()
  mkBodyB stms res = pure $ Body () stms res
  mkLetNamesB = mkLetNamesB' DefaultSpace ()

instance TraverseOpStms SeqMem where
  traverseOpStms _ = pure

instance BuilderOps (Engine.Wise SeqMem) where
  mkExpDecB pat e = pure $ Engine.mkWiseExpDec pat () e
  mkBodyB stms res = pure $ Engine.mkWiseBody () stms res
  mkLetNamesB = mkLetNamesB'' DefaultSpace

instance TraverseOpStms (Engine.Wise SeqMem) where
  traverseOpStms _ = pure

-- | Simplify a 'SeqMem' program using the generic memory simplifier and
-- 'memRuleBook'.
simplifyProg :: Prog SeqMem -> PassM (Prog SeqMem)
simplifyProg = simplifyProgGeneric memRuleBook simpleSeqMem

-- | Simplifier operations for 'SeqMem': the inner op simplifies to
-- 'NoOp' with nothing hoisted.
simpleSeqMem :: Engine.SimpleOps SeqMem
simpleSeqMem =
  simpleGeneric (const mempty) $ const $ pure (NoOp, mempty)

-- NOTE(review): the next line looks like tar archive header residue,
-- not Haskell source — confirm and strip when extracting.
futhark-0.25.27/src/Futhark/IR/Syntax.hs000066400000000000000000000463611475065116200177030ustar00rootroot00000000000000
{-# LANGUAGE Strict #-}
{-# LANGUAGE TypeFamilies #-}

-- | = Definition of the Futhark core language IR
--
-- For actually /constructing/ ASTs, see "Futhark.Construct".
--
-- == Types and values
--
-- The core language type system is much more restricted than the source
-- language.  This is a theme that repeats often.  The only types that
-- are supported in the core language are various primitive types
-- t'PrimType' which can be combined in arrays (ignore v'Mem' and
-- v'Acc' for now).  Types are represented as t'TypeBase', which is
-- parameterised by the shape of the array and whether we keep
-- uniqueness information.  The t'Type' alias, which is the most
-- commonly used, uses t'Shape' and t'NoUniqueness'.
--
-- This means that the records, tuples, and sum types of the source
-- language are represented merely as collections of primitives and
-- arrays.
This is implemented in "Futhark.Internalise", but the -- specifics are not important for writing passes on the core -- language. What /is/ important is that many constructs that -- conceptually return tuples instead return /multiple values/. This -- is not merely syntactic sugar for a tuple: each of those values is -- eventually bound to a distinct variable. The prettyprinter for the -- IR will typically print such collections of values or types in -- curly braces. -- -- The system of primitive types is interesting in itself. See -- "Language.Futhark.Primitive". -- -- == Overall AST design -- -- Internally, the Futhark compiler's core intermediate representation -- resembles that of a traditional compiler for an imperative language -- more than that of, say, a Haskell or ML compiler. All functions -- are monomorphic (except for sizes), first-order, and defined at the -- top level. Notably, the IR does /not/ use continuation-passing -- style (CPS) at any time. Instead it uses Administrative Normal -- Form (ANF), where all subexpressions t'SubExp' are either -- constants 'PrimValue' or variables 'VName'. Variables are -- represented as a human-readable t'Name' (which doesn't matter to -- the compiler) as well as a numeric /tag/, which is what the -- compiler actually looks at. When prettyprinted, all variable names -- are of the form @foo_123@. Function names are just t'Name's, -- though. -- -- The body of a function ('FunDef') is a t'Body', which consists of -- a sequence of statements ('Stms') and a t'Result'. Execution of a -- t'Body' consists of executing all of the statements, then returning -- the values of the variables indicated by the result. -- -- A statement ('Stm') consists of a t'Pat' alongside an -- expression 'Exp'. A pattern is a sequence of name/type pairs.
-- -- For example, the source language expression @let z = x + y - 1 in -- z@ would in the core language be represented (in prettyprinted -- form) as something like: -- -- @ -- let {a_12} = x_10 + y_11 -- let {b_13} = a_12 - 1 -- in {b_13} -- @ -- -- == Representations -- -- Most AST types ('Stm', 'Exp', t'Prog', etc.) are parameterised by a -- type parameter @rep@. The representation specifies how to fill out -- various polymorphic parts of the AST. For example, 'Exp' has a -- constructor v'Op' whose payload depends on @rep@, via the use of a -- type family called t'Op' (a kind of type-level function) which is -- applied to the @rep@. The SOACS representation -- ("Futhark.IR.SOACS") thus uses a rep called @SOACS@, and defines -- that @Op SOACS@ is a SOAC, while the Kernels representation -- ("Futhark.IR.Kernels") defines @Op Kernels@ as some kind of kernel -- construct. Similarly, various other decorations (e.g. what -- information we store in a t'PatElem') are also type families. -- -- The full list of possible decorations is defined as part of the -- type class 'RepTypes' (although other type families are also -- used elsewhere in the compiler on an ad hoc basis). -- -- Essentially, the @rep@ type parameter functions as a kind of -- proxy, saving us from having to parameterise the AST type with all -- the different forms of decorations that we desire (it would easily -- become a type with a dozen type parameters). -- -- Some AST elements (such as 'Pat') do not take a @rep@ type -- parameter, but are instead directly parameterised by the single type -- of decoration that they contain. We only use the more complicated -- machinery when needed. -- -- Defining a new representation (or /rep/) thus requires you to -- define an empty datatype and implement a handful of type class -- instances for it. See the source of "Futhark.IR.Seq" -- for what is likely the simplest example.
module Futhark.IR.Syntax ( module Language.Futhark.Core, prettyString, prettyText, Pretty, module Futhark.IR.Rep, module Futhark.IR.Syntax.Core, -- * Types Uniqueness (..), NoUniqueness (..), Rank (..), ArrayShape (..), Space (..), TypeBase (..), Diet (..), -- * Abstract syntax tree Ident (..), SubExp (..), PatElem (..), Pat (..), StmAux (..), Stm (..), Stms, SubExpRes (..), Result, Body (..), BasicOp (..), UnOp (..), BinOp (..), CmpOp (..), ConvOp (..), OpaqueOp (..), ReshapeKind (..), WithAccInput, Exp (..), Case (..), LoopForm (..), MatchDec (..), MatchSort (..), Safety (..), Lambda (..), RetAls (..), -- * Definitions Param (..), FParam, LParam, FunDef (..), EntryParam (..), EntryResult (..), EntryPoint, Prog (..), -- * Utils oneStm, stmsFromList, stmsToList, stmsHead, stmsLast, subExpRes, subExpsRes, varRes, varsRes, subExpResVName, ) where import Control.Category import Data.Foldable import Data.List.NonEmpty (NonEmpty (..)) import Data.Sequence qualified as Seq import Data.Text qualified as T import Data.Traversable (fmapDefault, foldMapDefault) import Futhark.IR.Rep import Futhark.IR.Syntax.Core import Futhark.Util.Pretty (Pretty, prettyString, prettyText) import Language.Futhark.Core import Prelude hiding (id, (.)) -- | A pattern is conceptually just a list of names and their types. newtype Pat dec = Pat {patElems :: [PatElem dec]} deriving (Ord, Show, Eq) instance Semigroup (Pat dec) where Pat xs <> Pat ys = Pat (xs <> ys) instance Monoid (Pat dec) where mempty = Pat mempty instance Functor Pat where fmap = fmapDefault instance Foldable Pat where foldMap = foldMapDefault instance Traversable Pat where traverse f (Pat xs) = Pat <$> traverse (traverse f) xs -- | Auxilliary Information associated with a statement. 
data StmAux dec = StmAux { stmAuxCerts :: !Certs, stmAuxAttrs :: Attrs, stmAuxDec :: dec } deriving (Ord, Show, Eq) instance (Semigroup dec) => Semigroup (StmAux dec) where StmAux cs1 attrs1 dec1 <> StmAux cs2 attrs2 dec2 = StmAux (cs1 <> cs2) (attrs1 <> attrs2) (dec1 <> dec2) -- | A local variable binding. data Stm rep = Let { -- | Pat. stmPat :: Pat (LetDec rep), -- | Auxiliary information statement. stmAux :: StmAux (ExpDec rep), -- | Expression. stmExp :: Exp rep } deriving instance (RepTypes rep) => Ord (Stm rep) deriving instance (RepTypes rep) => Show (Stm rep) deriving instance (RepTypes rep) => Eq (Stm rep) -- | A sequence of statements. type Stms rep = Seq.Seq (Stm rep) -- | A single statement. oneStm :: Stm rep -> Stms rep oneStm = Seq.singleton -- | Convert a statement list to a statement sequence. stmsFromList :: [Stm rep] -> Stms rep stmsFromList = Seq.fromList -- | Convert a statement sequence to a statement list. stmsToList :: Stms rep -> [Stm rep] stmsToList = toList -- | The first statement in the sequence, if any. stmsHead :: Stms rep -> Maybe (Stm rep, Stms rep) stmsHead stms = case Seq.viewl stms of stm Seq.:< stms' -> Just (stm, stms') Seq.EmptyL -> Nothing -- | The last statement in the sequence, if any. stmsLast :: Stms lore -> Maybe (Stms lore, Stm lore) stmsLast stms = case Seq.viewr stms of stms' Seq.:> stm -> Just (stms', stm) Seq.EmptyR -> Nothing -- | A pairing of a subexpression and some certificates. data SubExpRes = SubExpRes { resCerts :: Certs, resSubExp :: SubExp } deriving (Eq, Ord, Show) -- | Construct a 'SubExpRes' with no certificates. subExpRes :: SubExp -> SubExpRes subExpRes = SubExpRes mempty -- | Construct a 'SubExpRes' from a variable name. varRes :: VName -> SubExpRes varRes = subExpRes . Var -- | Construct a 'Result' from subexpressions. subExpsRes :: [SubExp] -> Result subExpsRes = map subExpRes -- | Construct a 'Result' from variable names. 
varsRes :: [VName] -> Result varsRes = map varRes -- | The 'VName' of a 'SubExpRes', if it exists. subExpResVName :: SubExpRes -> Maybe VName subExpResVName (SubExpRes _ (Var v)) = Just v subExpResVName _ = Nothing -- | The result of a body is a sequence of subexpressions. type Result = [SubExpRes] -- | A body consists of a sequence of statements, terminating in a -- list of result values. data Body rep = Body { bodyDec :: BodyDec rep, bodyStms :: Stms rep, bodyResult :: Result } deriving instance (RepTypes rep) => Ord (Body rep) deriving instance (RepTypes rep) => Show (Body rep) deriving instance (RepTypes rep) => Eq (Body rep) -- | Apart from being Opaque, what else is going on here? data OpaqueOp = -- | No special operation. OpaqueNil | -- | Print the argument, prefixed by this string. OpaqueTrace T.Text deriving (Eq, Ord, Show) -- | Which kind of reshape is this? data ReshapeKind = -- | New shape is dynamically same as original. ReshapeCoerce | -- | Any kind of reshaping. ReshapeArbitrary deriving (Eq, Ord, Show) -- | A primitive operation that returns something of known size and -- does not itself contain any bindings. data BasicOp = -- | A variable or constant. SubExp SubExp | -- | Semantically and operationally just identity, but is -- invisible/impenetrable to optimisations (hopefully). This -- partially a hack to avoid optimisation (so, to work around -- compiler limitations), but is also used to implement tracing -- and other operations that are semantically invisible, but have -- some sort of effect (brrr). Opaque OpaqueOp SubExp | -- | Array literals, e.g., @[ [1+x, 3], [2, 1+4] ]@. -- Second arg is the element type of the rows of the array. ArrayLit [SubExp] Type | -- | A one-dimensional array literal that contains only constants. -- This is a fast-path for representing very large array literals -- that show up in some programs. The key rule for processing this -- in compiler passes is that you should never need to look at the -- individual elements. 
Has exactly the same semantics as an -- 'ArrayLit'. ArrayVal [PrimValue] PrimType | -- | Unary operation. UnOp UnOp SubExp | -- | Binary operation. BinOp BinOp SubExp SubExp | -- | Comparison - result type is always boolean. CmpOp CmpOp SubExp SubExp | -- | Conversion "casting". ConvOp ConvOp SubExp | -- | Turn a boolean into a certificate, halting the program with the -- given error message if the boolean is false. Assert SubExp (ErrorMsg SubExp) (SrcLoc, [SrcLoc]) | -- | The certificates for bounds-checking are part of the 'Stm'. Index VName (Slice SubExp) | -- | An in-place update of the given array at the given position. -- Consumes the array. If 'Safe', perform a run-time bounds check -- and ignore the write if out of bounds (like @Scatter@). Update Safety VName (Slice SubExp) SubExp | FlatIndex VName (FlatSlice SubExp) | FlatUpdate VName (FlatSlice SubExp) VName | -- | @concat(0, [1] :| [[2, 3, 4], [5, 6]], 6) = [1, 2, 3, 4, 5, 6]@ -- -- Concatenates the non-empty list of 'VName' resulting in an -- array of length t'SubExp'. The 'Int' argument is used to -- specify the dimension along which the arrays are -- concatenated. For instance: -- -- @concat(1, [[1,2], [3, 4]] :| [[[5,6]], [[7, 8]]], 4) = [[1, 2, 5, 6], [3, 4, 7, 8]]@ Concat Int (NonEmpty VName) SubExp | -- | Manifest an array with dimensions represented in the given -- order. The result will not alias anything. Manifest [Int] VName | -- Array construction. -- | @iota(n, x, s) = [x,x+s,..,x+(n-1)*s]@. -- -- The t'IntType' indicates the type of the array returned and the -- offset/stride arguments, but not the length argument. Iota SubExp SubExp SubExp IntType | -- | @replicate([3][2],1) = [[1,1], [1,1], [1,1]]@. The result -- has no aliases. Copy a value by passing an empty shape. Replicate Shape SubExp | -- | Create array of given type and shape, with undefined elements. Scratch PrimType [SubExp] | -- | 1st arg is the new shape, 2nd arg is the input array. 
Reshape ReshapeKind Shape VName | -- | Permute the dimensions of the input array. The list -- of integers is a list of dimensions (0-indexed), which -- must be a permutation of @[0,n-1]@, where @n@ is the -- number of dimensions in the input array. Rearrange [Int] VName | -- | Update an accumulator at the given index with the given -- value. Consumes the accumulator and produces a new one. If -- 'Safe', perform a run-time bounds check and ignore the write if -- out of bounds (like @Scatter@). UpdateAcc Safety VName [SubExp] [SubExp] deriving (Eq, Ord, Show) -- | The input to a 'WithAcc' construct. Comprises the index space of -- the accumulator, the underlying arrays, and possibly a combining -- function. type WithAccInput rep = (Shape, [VName], Maybe (Lambda rep, [SubExp])) -- | A non-default case in a 'Match' statement. The number of -- elements in the pattern must match the number of scrutinees. A -- 'Nothing' value indicates that we don't care about it (i.e. a -- wildcard). data Case body = Case {casePat :: [Maybe PrimValue], caseBody :: body} deriving (Eq, Ord, Show) instance Functor Case where fmap = fmapDefault instance Foldable Case where foldMap = foldMapDefault instance Traversable Case where traverse f (Case vs b) = Case vs <$> f b -- | Information about the possible aliases of a function result. data RetAls = RetAls { -- | Which of the parameters may be aliased, numbered from zero. -- Must be sorted in increasing order. paramAls :: [Int], -- | Which of the other results may be aliased, numbered from -- zero. This must be a reflexive relation. Must be sorted in -- increasing order. otherAls :: [Int] } deriving (Eq, Ord, Show) instance Monoid RetAls where mempty = RetAls mempty mempty instance Semigroup RetAls where RetAls pals1 rals1 <> RetAls pals2 rals2 = RetAls (pals1 <> pals2) (rals1 <> rals2) -- | The root Futhark expression type. The v'Op' constructor contains -- a rep-specific operation. Do-loops, branches and function calls -- are special. 
Everything else is a simple t'BasicOp'. data Exp rep = -- | A simple (non-recursive) operation. BasicOp BasicOp | Apply Name [(SubExp, Diet)] [(RetType rep, RetAls)] (Safety, SrcLoc, [SrcLoc]) | -- | A match statement picks a branch by comparing the given -- subexpressions (called the /scrutinee/) with the pattern in -- each of the cases. If none of the cases match, the /default -- body/ is picked. Match [SubExp] [Case (Body rep)] (Body rep) (MatchDec (BranchType rep)) | -- | @loop {a} = {v} (for i < n|while b) do b@. Loop [(FParam rep, SubExp)] LoopForm (Body rep) | -- | Create accumulators backed by the given arrays (which are -- consumed) and pass them to the lambda, which must return the -- updated accumulators and possibly some extra values. The -- accumulators are turned back into arrays. In the lambda, the result -- accumulators come first, and are ordered in a manner consistent with -- that of the input (accumulator) arguments. The t'Shape' is the -- write index space. The corresponding arrays must all have this -- shape outermost. This construct is not part of t'BasicOp' -- because we need the @rep@ parameter. WithAcc [WithAccInput rep] (Lambda rep) | Op (Op rep) deriving instance (RepTypes rep) => Eq (Exp rep) deriving instance (RepTypes rep) => Show (Exp rep) deriving instance (RepTypes rep) => Ord (Exp rep) -- | For-loop or while-loop? data LoopForm = ForLoop -- | The loop iterator var VName -- | The type of the loop iterator var IntType -- | The number of iterations. SubExp | WhileLoop VName deriving (Eq, Ord, Show) -- | Data associated with a branch. data MatchDec rt = MatchDec { matchReturns :: [rt], matchSort :: MatchSort } deriving (Eq, Show, Ord) -- | What kind of branch is this? This has no semantic meaning, but -- provides hints to simplifications. data MatchSort = -- | An ordinary branch. 
MatchNormal | -- | A branch where the "true" case is what we are -- actually interested in, and the "false" case is only -- present as a fallback for when the true case cannot -- be safely evaluated. The compiler is permitted to -- optimise away the branch if the true case contains -- only safe statements. MatchFallback | -- | Both of these branches are semantically equivalent, -- and it is fine to eliminate one if it turns out to -- have problems (e.g. contain things we cannot generate -- code for). MatchEquiv deriving (Eq, Show, Ord) -- | Anonymous function for use in a SOAC. data Lambda rep = Lambda { lambdaParams :: [LParam rep], lambdaReturnType :: [Type], lambdaBody :: Body rep } deriving instance (RepTypes rep) => Eq (Lambda rep) deriving instance (RepTypes rep) => Show (Lambda rep) deriving instance (RepTypes rep) => Ord (Lambda rep) -- | A function and loop parameter. type FParam rep = Param (FParamInfo rep) -- | A lambda parameter. type LParam rep = Param (LParamInfo rep) -- | Function definitions. data FunDef rep = FunDef { -- | Contains a value if this function is -- an entry point. funDefEntryPoint :: Maybe EntryPoint, funDefAttrs :: Attrs, funDefName :: Name, funDefRetType :: [(RetType rep, RetAls)], funDefParams :: [FParam rep], funDefBody :: Body rep } deriving instance (RepTypes rep) => Eq (FunDef rep) deriving instance (RepTypes rep) => Show (FunDef rep) deriving instance (RepTypes rep) => Ord (FunDef rep) -- | An entry point parameter, comprising its name and original type. data EntryParam = EntryParam { entryParamName :: Name, entryParamUniqueness :: Uniqueness, entryParamType :: EntryPointType } deriving (Eq, Show, Ord) -- | An entry point result type. data EntryResult = EntryResult { entryResultUniqueness :: Uniqueness, entryResultType :: EntryPointType } deriving (Eq, Show, Ord) -- | Information about the inputs and outputs (return value) of an entry -- point. 
type EntryPoint = (Name, [EntryParam], [EntryResult]) -- | An entire Futhark program. data Prog rep = Prog { -- | The opaque types used in entry points. This information is -- used to generate extra API functions for -- construction and deconstruction of values of these types. progTypes :: OpaqueTypes, -- | Top-level constants that are computed at program startup, and -- which are in scope inside all functions. progConsts :: Stms rep, -- | The functions comprising the program. All functions are also -- available in scope in the definitions of the constants, so be -- careful not to introduce circular dependencies (not currently -- checked). progFuns :: [FunDef rep] } deriving (Eq, Ord, Show) futhark-0.25.27/src/Futhark/IR/Syntax/000077500000000000000000000000001475065116200173355ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/IR/Syntax/Core.hs000066400000000000000000000446061475065116200205730ustar00rootroot00000000000000{-# LANGUAGE Strict #-} -- | The most primitive ("core") aspects of the AST. Split out of -- "Futhark.IR.Syntax" in order for -- "Futhark.IR.Rep" to use these definitions. This -- module is re-exported from "Futhark.IR.Syntax" and -- there should be no reason to include it explicitly. 
module Futhark.IR.Syntax.Core ( module Language.Futhark.Core, module Language.Futhark.Primitive, -- * Types Commutativity (..), Uniqueness (..), ShapeBase (..), Shape, stripDims, Ext (..), ExtSize, ExtShape, Rank (..), ArrayShape (..), Space (..), SpaceId, TypeBase (..), Type, ExtType, DeclType, DeclExtType, Diet (..), ErrorMsg (..), ErrorMsgPart (..), errorMsgArgTypes, -- * Entry point information ValueType (..), OpaqueType (..), OpaqueTypes (..), Signedness (..), EntryPointType (..), -- * Attributes Attr (..), Attrs (..), oneAttr, inAttrs, withoutAttrs, mapAttrs, -- * Values PrimValue (..), -- * Abstract syntax tree Ident (..), Certs (..), SubExp (..), Param (..), DimIndex (..), Slice (..), dimFix, sliceIndices, sliceDims, sliceShape, unitSlice, fixSlice, sliceSlice, PatElem (..), -- * Flat (LMAD) slices FlatSlice (..), FlatDimIndex (..), flatSliceDims, flatSliceStrides, ) where import Control.Category import Control.Monad import Control.Monad.State import Data.Bifoldable import Data.Bifunctor import Data.Bitraversable import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S import Data.String import Data.Text qualified as T import Data.Traversable (fmapDefault, foldMapDefault) import Language.Futhark.Core import Language.Futhark.Primitive import Prelude hiding (id, (.)) -- | Whether some operator is commutative or not. The 'Monoid' -- instance returns the least commutative of its arguments. data Commutativity = Noncommutative | Commutative deriving (Eq, Ord, Show) instance Semigroup Commutativity where (<>) = min instance Monoid Commutativity where mempty = Commutative -- | The size of an array type as a list of its dimension sizes, with -- the type of sizes being parametric. newtype ShapeBase d = Shape {shapeDims :: [d]} deriving (Eq, Ord, Show) instance Functor ShapeBase where fmap = fmapDefault instance Foldable ShapeBase where foldMap = foldMapDefault instance Traversable ShapeBase where traverse f = fmap Shape . traverse f . 
shapeDims instance Semigroup (ShapeBase d) where Shape l1 <> Shape l2 = Shape $ l1 `mappend` l2 instance Monoid (ShapeBase d) where mempty = Shape mempty -- | @stripDims n shape@ strips the outer @n@ dimensions from -- @shape@. stripDims :: Int -> ShapeBase d -> ShapeBase d stripDims n (Shape dims) = Shape $ drop n dims -- | The size of an array as a list of subexpressions. If a variable, -- that variable must be in scope where this array is used. type Shape = ShapeBase SubExp -- | Something that may be existential. data Ext a = Ext Int | Free a deriving (Eq, Ord, Show) instance Functor Ext where fmap = fmapDefault instance Foldable Ext where foldMap = foldMapDefault instance Traversable Ext where traverse _ (Ext i) = pure $ Ext i traverse f (Free v) = Free <$> f v -- | The size of this dimension. type ExtSize = Ext SubExp -- | Like t'Shape' but some of its elements may be bound in a local -- environment instead. These are denoted with integral indices. type ExtShape = ShapeBase ExtSize -- | The size of an array type as merely the number of dimensions, -- with no further information. newtype Rank = Rank Int deriving (Show, Eq, Ord) -- | A class encompassing types containing array shape information. class (Monoid a, Eq a, Ord a) => ArrayShape a where -- | Return the rank of an array with the given size. shapeRank :: a -> Int -- | Check whether one shape if a subset of another shape. subShapeOf :: a -> a -> Bool instance ArrayShape (ShapeBase SubExp) where shapeRank (Shape l) = length l subShapeOf = (==) instance ArrayShape (ShapeBase ExtSize) where shapeRank (Shape l) = length l subShapeOf (Shape ds1) (Shape ds2) = -- Must agree on Free dimensions, and ds1 may not be existential -- where ds2 is Free. Existentials must also be congruent. 
length ds1 == length ds2 && evalState (and <$> zipWithM subDimOf ds1 ds2) M.empty where subDimOf (Free se1) (Free se2) = pure $ se1 == se2 subDimOf (Ext _) (Free _) = pure False subDimOf (Free _) (Ext _) = pure True subDimOf (Ext x) (Ext y) = do extmap <- get case M.lookup y extmap of Just ywas | ywas == x -> pure True | otherwise -> pure False Nothing -> do put $ M.insert y x extmap pure True instance Semigroup Rank where Rank x <> Rank y = Rank $ x + y instance Monoid Rank where mempty = Rank 0 instance ArrayShape Rank where shapeRank (Rank x) = x subShapeOf = (==) -- | The memory space of a block. If 'DefaultSpace', this is the "default" -- space, whatever that is. The exact meaning of the 'SpaceId' -- depends on the backend used. In GPU kernels, for example, this is -- used to distinguish between constant, global and shared memory -- spaces. In GPU-enabled host code, it is used to distinguish -- between host memory ('DefaultSpace') and GPU space. data Space = DefaultSpace | Space SpaceId | -- | A special kind of memory that is a statically sized -- array of some primitive type. Used for private memory -- on GPUs. ScalarSpace [SubExp] PrimType deriving (Show, Eq, Ord) -- | A string representing a specific non-default memory space. type SpaceId = String -- | The type of a value. When comparing types for equality with -- '==', shapes must match. data TypeBase shape u = Prim PrimType | -- | Token, index space, element type, and uniqueness. 
Acc VName Shape [Type] u | Array PrimType shape u | Mem Space deriving (Show, Eq, Ord) instance Bitraversable TypeBase where bitraverse f g (Array t shape u) = Array t <$> f shape <*> g u bitraverse _ _ (Prim pt) = pure $ Prim pt bitraverse _ g (Acc arrs ispace ts u) = Acc arrs ispace ts <$> g u bitraverse _ _ (Mem s) = pure $ Mem s instance Functor (TypeBase shape) where fmap = fmapDefault instance Foldable (TypeBase shape) where foldMap = foldMapDefault instance Traversable (TypeBase shape) where traverse = bitraverse pure instance Bifunctor TypeBase where bimap = bimapDefault instance Bifoldable TypeBase where bifoldMap = bifoldMapDefault -- | A type with shape information, used for describing the type of -- variables. type Type = TypeBase Shape NoUniqueness -- | A type with existentially quantified shapes - used as part of -- function (and function-like) return types. Generally only makes -- sense when used in a list. type ExtType = TypeBase ExtShape NoUniqueness -- | A type with shape and uniqueness information, used declaring -- return- and parameters types. type DeclType = TypeBase Shape Uniqueness -- | An 'ExtType' with uniqueness information, used for function -- return types. type DeclExtType = TypeBase ExtShape Uniqueness -- | Information about which parts of a value/type are consumed. For -- example, we might say that a function taking three arguments of -- types @([int], *[int], [int])@ has diet @[Observe, Consume, -- Observe]@. data Diet = -- | Consumes this value. Consume | -- | Only observes value in this position, does -- not consume. A result may alias this. Observe | -- | As 'Observe', but the result will not -- alias, because the parameter does not carry -- aliases. ObservePrim deriving (Eq, Ord, Show) -- | An identifier consists of its name and the type of the value -- bound to the identifier. 
data Ident = Ident { identName :: VName, identType :: Type } deriving (Show) instance Eq Ident where x == y = identName x == identName y instance Ord Ident where x `compare` y = identName x `compare` identName y -- | A list of names used for certificates in some expressions. newtype Certs = Certs {unCerts :: [VName]} deriving (Eq, Ord, Show) instance Semigroup Certs where Certs x <> Certs y = Certs (x <> filter (`notElem` x) y) instance Monoid Certs where mempty = Certs mempty -- | A subexpression is either a scalar constant or a variable. One -- important property is that evaluation of a subexpression is -- guaranteed to complete in constant time. data SubExp = Constant PrimValue | Var VName deriving (Show, Eq, Ord) -- | A function or lambda parameter. data Param dec = Param { -- | Attributes of the parameter. When constructing a parameter, -- feel free to just pass 'mempty'. paramAttrs :: Attrs, -- | Name of the parameter. paramName :: VName, -- | Function parameter decoration. paramDec :: dec } deriving (Ord, Show, Eq) instance Foldable Param where foldMap = foldMapDefault instance Functor Param where fmap = fmapDefault instance Traversable Param where traverse f (Param attr name dec) = Param attr name <$> f dec -- | How to index a single dimension of an array. data DimIndex d = -- | Fix index in this dimension. DimFix d | -- | @DimSlice start_offset num_elems stride@. DimSlice d d d deriving (Eq, Ord, Show) instance Functor DimIndex where fmap f (DimFix i) = DimFix $ f i fmap f (DimSlice i j s) = DimSlice (f i) (f j) (f s) instance Foldable DimIndex where foldMap f (DimFix d) = f d foldMap f (DimSlice i j s) = f i <> f j <> f s instance Traversable DimIndex where traverse f (DimFix d) = DimFix <$> f d traverse f (DimSlice i j s) = DimSlice <$> f i <*> f j <*> f s -- | A list of 'DimIndex's, indicating how an array should be sliced. -- Whenever a function accepts a 'Slice', that slice should be total, -- i.e, cover all dimensions of the array. 
Deviators should be -- indicated by taking a list of 'DimIndex'es instead. newtype Slice d = Slice {unSlice :: [DimIndex d]} deriving (Eq, Ord, Show) instance Traversable Slice where traverse f = fmap Slice . traverse (traverse f) . unSlice instance Functor Slice where fmap = fmapDefault instance Foldable Slice where foldMap = foldMapDefault -- | If the argument is a 'DimFix', return its component. dimFix :: DimIndex d -> Maybe d dimFix (DimFix d) = Just d dimFix _ = Nothing -- | If the slice is all 'DimFix's, return the components. sliceIndices :: Slice d -> Maybe [d] sliceIndices = mapM dimFix . unSlice -- | The dimensions of the array produced by this slice. sliceDims :: Slice d -> [d] sliceDims = mapMaybe dimSlice . unSlice where dimSlice (DimSlice _ d _) = Just d dimSlice DimFix {} = Nothing -- | The shape of the array produced by this slice. sliceShape :: Slice d -> ShapeBase d sliceShape = Shape . sliceDims -- | A slice with a stride of one. unitSlice :: (Num d) => d -> d -> DimIndex d unitSlice offset n = DimSlice offset n 1 -- | Fix the 'DimSlice's of a slice. The number of indexes must equal -- the length of 'sliceDims' for the slice. fixSlice :: (Num d) => Slice d -> [d] -> [d] fixSlice = fixSlice' . unSlice where fixSlice' (DimFix j : mis') is' = j : fixSlice' mis' is' fixSlice' (DimSlice orig_k _ orig_s : mis') (i : is') = (orig_k + i * orig_s) : fixSlice' mis' is' fixSlice' _ _ = [] -- | Further slice the 'DimSlice's of a slice. The number of slices -- must equal the length of 'sliceDims' for the slice. 
sliceSlice :: (Num d) => Slice d -> Slice d -> Slice d sliceSlice (Slice jslice) (Slice islice) = Slice $ sliceSlice' jslice islice where sliceSlice' (DimFix j : js') is' = DimFix j : sliceSlice' js' is' sliceSlice' (DimSlice j _ s : js') (DimFix i : is') = DimFix (j + (i * s)) : sliceSlice' js' is' sliceSlice' (DimSlice j _ s0 : js') (DimSlice i n s1 : is') = DimSlice (j + (s0 * i)) n (s0 * s1) : sliceSlice' js' is' sliceSlice' _ _ = [] -- | A dimension in a 'FlatSlice'. data FlatDimIndex d = FlatDimIndex -- | Number of elements in dimension d -- | Stride of dimension d deriving (Eq, Ord, Show) instance Traversable FlatDimIndex where traverse f (FlatDimIndex n s) = FlatDimIndex <$> f n <*> f s instance Functor FlatDimIndex where fmap = fmapDefault instance Foldable FlatDimIndex where foldMap = foldMapDefault -- | A flat slice is a way of viewing a one-dimensional array as a -- multi-dimensional array, using a more compressed mechanism than -- reshaping and using 'Slice'. The initial @d@ is an offset, and the -- list then specifies the shape of the resulting array. data FlatSlice d = FlatSlice d [FlatDimIndex d] deriving (Eq, Ord, Show) instance Traversable FlatSlice where traverse f (FlatSlice offset is) = FlatSlice <$> f offset <*> traverse (traverse f) is instance Functor FlatSlice where fmap = fmapDefault instance Foldable FlatSlice where foldMap = foldMapDefault -- | The dimensions (shape) of the view produced by a flat slice. flatSliceDims :: FlatSlice d -> [d] flatSliceDims (FlatSlice _ ds) = map dimSlice ds where dimSlice (FlatDimIndex n _) = n -- | The strides of each dimension produced by a flat slice. flatSliceStrides :: FlatSlice d -> [d] flatSliceStrides (FlatSlice _ ds) = map dimStride ds where dimStride (FlatDimIndex _ s) = s -- | An element of a pattern - consisting of a name and an addditional -- parametric decoration. This decoration is what is expected to -- contain the type of the resulting variable. 
data PatElem dec = PatElem { -- | The name being bound. patElemName :: VName, -- | Pat element decoration. patElemDec :: dec } deriving (Ord, Show, Eq) instance Functor PatElem where fmap = fmapDefault instance Foldable PatElem where foldMap = foldMapDefault instance Traversable PatElem where traverse f (PatElem name dec) = PatElem name <$> f dec -- | An error message is a list of error parts, which are concatenated -- to form the final message. newtype ErrorMsg a = ErrorMsg [ErrorMsgPart a] deriving (Eq, Ord, Show) instance IsString (ErrorMsg a) where fromString = ErrorMsg . pure . fromString instance Monoid (ErrorMsg a) where mempty = ErrorMsg mempty instance Semigroup (ErrorMsg a) where ErrorMsg x <> ErrorMsg y = ErrorMsg $ x <> y -- | A part of an error message. data ErrorMsgPart a = -- | A literal string. ErrorString T.Text | -- | A run-time value. ErrorVal PrimType a deriving (Eq, Ord, Show) instance IsString (ErrorMsgPart a) where fromString = ErrorString . T.pack instance Functor ErrorMsg where fmap f (ErrorMsg parts) = ErrorMsg $ map (fmap f) parts instance Foldable ErrorMsg where foldMap f (ErrorMsg parts) = foldMap (foldMap f) parts instance Traversable ErrorMsg where traverse f (ErrorMsg parts) = ErrorMsg <$> traverse (traverse f) parts instance Functor ErrorMsgPart where fmap = fmapDefault instance Foldable ErrorMsgPart where foldMap = foldMapDefault instance Traversable ErrorMsgPart where traverse _ (ErrorString s) = pure $ ErrorString s traverse f (ErrorVal t a) = ErrorVal t <$> f a -- | How many non-constant parts does the error message have, and what -- is their type? errorMsgArgTypes :: ErrorMsg a -> [PrimType] errorMsgArgTypes (ErrorMsg parts) = mapMaybe onPart parts where onPart ErrorString {} = Nothing onPart (ErrorVal t _) = Just t -- | A single attribute. data Attr = AttrName Name | AttrInt Integer | AttrComp Name [Attr] deriving (Ord, Show, Eq) instance IsString Attr where fromString = AttrName . 
fromString -- | Every statement is associated with a set of attributes, which can -- have various effects throughout the compiler. newtype Attrs = Attrs {unAttrs :: S.Set Attr} deriving (Ord, Show, Eq, Monoid, Semigroup) -- | Construct 'Attrs' from a single 'Attr'. oneAttr :: Attr -> Attrs oneAttr = Attrs . S.singleton -- | Is the given attribute to be found in the attribute set? inAttrs :: Attr -> Attrs -> Bool inAttrs attr (Attrs attrs) = attr `S.member` attrs -- | @x `withoutAttrs` y@ gives @x@ except for any attributes also in @y@. withoutAttrs :: Attrs -> Attrs -> Attrs withoutAttrs (Attrs x) (Attrs y) = Attrs $ x `S.difference` y -- | Map a function over an attribute set. mapAttrs :: (Attr -> a) -> Attrs -> [a] mapAttrs f (Attrs attrs) = map f $ S.toList attrs -- | Since the core language does not care for signedness, but the -- source language does, entry point input/output information has -- metadata for integer types (and arrays containing these) that -- indicate whether they are really unsigned integers. This doesn't -- matter for non-integer types. data Signedness = Unsigned | Signed deriving (Eq, Ord, Show) -- | An actual non-opaque type that can be passed to and from Futhark -- programs, or serve as the contents of opaque types. Scalars are -- represented with zero rank. data ValueType = ValueType Signedness Rank PrimType deriving (Eq, Ord, Show) -- | Every entry point argument and return value has an annotation -- indicating how it maps to the original source program type. data EntryPointType = -- | An opaque type of this name. TypeOpaque Name | -- | A transparent type, which is scalar if the rank is zero. TypeTransparent ValueType deriving (Eq, Show, Ord) -- | The representation of an opaque type. data OpaqueType = OpaqueType [ValueType] | -- | Note that the field ordering here denote the actual -- representation - make sure it is preserved. 
OpaqueRecord [(Name, EntryPointType)] | -- | Constructor ordering also denotes representation, in that the -- index of the constructor is the identifying number. -- -- The total values used to represent a sum values is the -- 'ValueType' list. The 'Int's associated with each -- 'EntryPointType' are the indexes of the values used to -- represent that constructor payload. This is necessary because -- we deduplicate payloads across constructors. OpaqueSum [ValueType] [(Name, [(EntryPointType, [Int])])] | -- | An array with this rank and named opaque element type. OpaqueArray Int Name [ValueType] | -- | An array with known rank and where the elements are this -- record type. OpaqueRecordArray Int Name [(Name, EntryPointType)] deriving (Eq, Ord, Show) -- | Names of opaque types and their representation. newtype OpaqueTypes = OpaqueTypes [(Name, OpaqueType)] deriving (Eq, Ord, Show) instance Monoid OpaqueTypes where mempty = OpaqueTypes mempty instance Semigroup OpaqueTypes where OpaqueTypes x <> OpaqueTypes y = OpaqueTypes $ x <> filter ((`notElem` map fst x) . fst) y futhark-0.25.27/src/Futhark/IR/Traversals.hs000066400000000000000000000306751475065116200205440ustar00rootroot00000000000000-- | -- -- Functions for generic traversals across Futhark syntax trees. The -- motivation for this module came from dissatisfaction with rewriting -- the same trivial tree recursions for every module. A possible -- alternative would be to use normal \"Scrap your -- boilerplate\"-techniques, but these are rejected for two reasons: -- -- * They are too slow. -- -- * More importantly, they do not tell you whether you have missed -- some cases. -- -- Instead, this module defines various traversals of the Futhark syntax -- tree. The implementation is rather tedious, but the interface is -- easy to use. -- -- A traversal of the Futhark syntax tree is expressed as a record of -- functions expressing the operations to be performed on the various -- types of nodes. 
--
-- The "Futhark.Transform.Rename" module is a simple example of how to
-- use this facility.
module Futhark.IR.Traversals
  ( -- * Mapping
    Mapper (..),
    identityMapper,
    mapExpM,
    mapExp,

    -- * Walking
    Walker (..),
    identityWalker,
    walkExpM,

    -- * Ops
    TraverseOpStms (..),
    OpStmsTraverser,
    traverseLambdaStms,
  )
where

import Control.Monad
import Control.Monad.Identity
import Data.Bitraversable
import Data.Foldable (traverse_)
import Data.List.NonEmpty (NonEmpty (..))
import Futhark.IR.Prop.Scope
import Futhark.IR.Prop.Types (mapOnType)
import Futhark.IR.Syntax

-- | A record of monadic functions, one per kind of child node, used to
-- rewrite the immediate children of a syntax node.  The source
-- representation is @frep@ and the target representation is @trep@.
data Mapper frep trep m = Mapper
  { mapOnSubExp :: SubExp -> m SubExp,
    -- | Most bodies are enclosed in a scope, which is passed along
    -- for convenience.
    mapOnBody :: Scope trep -> Body frep -> m (Body trep),
    mapOnVName :: VName -> m VName,
    mapOnRetType :: RetType frep -> m (RetType trep),
    mapOnBranchType :: BranchType frep -> m (BranchType trep),
    mapOnFParam :: FParam frep -> m (FParam trep),
    mapOnLParam :: LParam frep -> m (LParam trep),
    mapOnOp :: Op frep -> m (Op trep)
  }

-- | The mapper that hands every child back unchanged; useful as a
-- starting point that is then overridden field by field.
identityMapper :: forall rep m. (Monad m) => Mapper rep rep m
identityMapper =
  Mapper
    { mapOnBody = \_scope body -> pure body,
      mapOnBranchType = pure,
      mapOnFParam = pure,
      mapOnLParam = pure,
      mapOnOp = pure,
      mapOnRetType = pure,
      mapOnSubExp = pure,
      mapOnVName = pure
    }

-- | Apply a 'Mapper' to the immediate children of an expression, in
-- left-to-right order.  The traversal is shallow: it does not recurse
-- into subexpressions on its own.
mapExpM :: (Monad m) => Mapper frep trep m -> Exp frep -> m (Exp trep)
mapExpM tv (BasicOp (SubExp se)) =
  BasicOp <$> (SubExp <$> mapOnSubExp tv se)
-- An 'ArrayVal' is rebuilt as-is: none of the mapper functions are applied.
mapExpM _ (BasicOp (ArrayVal vs t)) =
  pure $ BasicOp $ ArrayVal vs t
mapExpM tv (BasicOp (ArrayLit els rowt)) =
  BasicOp
    <$> ( ArrayLit <$> mapM (mapOnSubExp tv) els
            <*> mapOnType (mapOnSubExp tv) rowt
        )
mapExpM tv (BasicOp (BinOp bop x y)) =
  BasicOp <$> (BinOp bop <$> mapOnSubExp tv x <*> mapOnSubExp tv y)
mapExpM tv (BasicOp (CmpOp op x y)) =
  BasicOp <$> (CmpOp op <$> mapOnSubExp tv x <*> mapOnSubExp tv y)
mapExpM tv (BasicOp (ConvOp conv x)) =
  BasicOp <$> (ConvOp conv <$> mapOnSubExp tv x)
mapExpM tv (BasicOp (UnOp op x)) =
  BasicOp <$> (UnOp op <$> mapOnSubExp tv x)
-- Case bodies and the default body are mapped with an empty scope.
mapExpM tv (Match ses cases defbody (MatchDec ts s)) =
  Match
    <$> mapM (mapOnSubExp tv) ses
    <*> mapM mapOnCase cases
    <*> mapOnBody tv mempty defbody
    <*> (MatchDec <$> mapM (mapOnBranchType tv) ts <*> pure s)
  where
    mapOnCase (Case vs body) = Case vs <$> mapOnBody tv mempty body
-- Only the argument SubExps and the first component of each return
-- type pair are mapped; argument annotations @d@, the second component
-- of each @ret@ pair and @loc@ are kept.
mapExpM tv (Apply fname args ret loc) = do
  args' <- forM args $ \(arg, d) ->
    (,) <$> mapOnSubExp tv arg <*> pure d
  Apply fname args'
    <$> mapM (bitraverse (mapOnRetType tv) pure) ret
    <*> pure loc
mapExpM tv (BasicOp (Index arr slice)) =
  BasicOp
    <$> ( Index
            <$> mapOnVName tv arr
            <*> traverse (mapOnSubExp tv) slice
        )
mapExpM tv (BasicOp (Update safety arr slice se)) =
  BasicOp
    <$> ( Update safety
            <$> mapOnVName tv arr
            <*> traverse (mapOnSubExp tv) slice
            <*> mapOnSubExp tv se
        )
mapExpM tv (BasicOp (FlatIndex arr slice)) =
  BasicOp
    <$> ( FlatIndex
            <$> mapOnVName tv arr
            <*> traverse (mapOnSubExp tv) slice
        )
-- Unlike 'Update', the value written by 'FlatUpdate' is a variable
-- name, so it goes through 'mapOnVName'.
mapExpM tv (BasicOp (FlatUpdate arr slice se)) =
  BasicOp
    <$> ( FlatUpdate
            <$> mapOnVName tv arr
            <*> traverse (mapOnSubExp tv) slice
            <*> mapOnVName tv se
        )
mapExpM tv (BasicOp (Iota n x s et)) =
  BasicOp <$> (Iota <$> mapOnSubExp tv n <*> mapOnSubExp tv x <*> mapOnSubExp tv s <*> pure et)
mapExpM tv (BasicOp (Replicate shape vexp)) =
  BasicOp <$> (Replicate <$> mapOnShape tv shape <*> mapOnSubExp tv vexp)
mapExpM tv (BasicOp (Scratch t shape)) =
  BasicOp <$> (Scratch t <$> mapM (mapOnSubExp tv) shape)
mapExpM tv (BasicOp (Reshape kind shape arrexp)) =
  BasicOp
    <$> ( Reshape kind
            <$> mapM (mapOnSubExp tv) shape
            <*> mapOnVName tv arrexp
        )
mapExpM tv (BasicOp (Rearrange perm e)) =
  BasicOp <$> (Rearrange perm <$> mapOnVName tv e)
-- Written with explicit binds so the order is: first array, remaining
-- arrays, then the size.
mapExpM tv (BasicOp (Concat i (x :| ys) size)) = do
  x' <- mapOnVName tv x
  ys' <- mapM (mapOnVName tv) ys
  size' <- mapOnSubExp tv size
  pure $ BasicOp $ Concat i (x' :| ys') size'
mapExpM tv (BasicOp (Manifest perm e)) =
  BasicOp <$> (Manifest perm <$> mapOnVName tv e)
mapExpM tv (BasicOp (Assert e msg loc)) =
  BasicOp <$> (Assert <$> mapOnSubExp tv e <*> traverse (mapOnSubExp tv) msg <*> pure loc)
mapExpM tv (BasicOp (Opaque op e)) =
  BasicOp <$> (Opaque op <$> mapOnSubExp tv e)
mapExpM tv (BasicOp (UpdateAcc safety v is ses)) =
  BasicOp
    <$> ( UpdateAcc safety
            <$> mapOnVName tv v
            <*> mapM (mapOnSubExp tv) is
            <*> mapM (mapOnSubExp tv) ses
        )
-- Each input is (shape, arrays, op); the lambda and SubExps carried by
-- @op@ are mapped via 'bitraverse'.
mapExpM tv (WithAcc inputs lam) =
  WithAcc <$> mapM onInput inputs <*> mapOnLambda tv lam
  where
    onInput (shape, vs, op) =
      (,,)
        <$> mapOnShape tv shape
        <*> mapM (mapOnVName tv) vs
        <*> traverse (bitraverse (mapOnLambda tv) (mapM (mapOnSubExp tv))) op
-- The parameters and loop form are mapped first, so that the scope
-- passed to 'mapOnBody' is built from their mapped versions.
mapExpM tv (Loop merge form loopbody) = do
  params' <- mapM (mapOnFParam tv) params
  form' <- mapOnLoopForm tv form
  let scope = scopeOfLoopForm form' <> scopeOfFParams params'
  Loop
    <$> (zip params' <$> mapM (mapOnSubExp tv) args)
    <*> pure form'
    <*> mapOnBody tv scope loopbody
  where
    (params, args) = unzip merge
mapExpM tv (Op op) = Op <$> mapOnOp tv op

mapOnShape :: (Monad m) => Mapper frep trep m -> Shape -> m Shape
mapOnShape tv (Shape ds) = Shape <$> mapM (mapOnSubExp tv) ds

mapOnLoopForm :: (Monad m) => Mapper frep trep m -> LoopForm -> m LoopForm
mapOnLoopForm tv (ForLoop i it bound) =
  ForLoop <$> mapOnVName tv i <*> pure it <*> mapOnSubExp tv bound
mapOnLoopForm tv (WhileLoop cond) =
  WhileLoop <$> mapOnVName tv cond

-- Order: parameters, return types, body.  The body's scope is built
-- from the already-mapped parameters.
mapOnLambda ::
  (Monad m) =>
  Mapper frep trep m ->
  Lambda frep ->
  m (Lambda trep)
mapOnLambda tv (Lambda params ret body) = do
  params' <- mapM (mapOnLParam tv) params
  Lambda params'
    <$> mapM (mapOnType (mapOnSubExp tv)) ret
    <*> mapOnBody tv (scopeOfLParams params') body

-- | Like 'mapExpM', but in the 'Identity' monad.
mapExp :: Mapper frep trep Identity -> Exp frep -> Exp trep
mapExp m = runIdentity . mapExpM m

-- | Express a monad expression on a syntax node. Each element of
-- this structure expresses the action to be performed on a given
-- child.
data Walker rep m = Walker
  { walkOnSubExp :: SubExp -> m (),
    walkOnBody :: Scope rep -> Body rep -> m (),
    walkOnVName :: VName -> m (),
    walkOnRetType :: RetType rep -> m (),
    walkOnBranchType :: BranchType rep -> m (),
    walkOnFParam :: FParam rep -> m (),
    walkOnLParam :: LParam rep -> m (),
    walkOnOp :: Op rep -> m ()
  }

-- | A no-op traversal.
identityWalker :: forall rep m. (Monad m) => Walker rep m
identityWalker =
  Walker
    { walkOnSubExp = const $ pure (),
      walkOnBody = const $ const $ pure (),
      walkOnVName = const $ pure (),
      walkOnRetType = const $ pure (),
      walkOnBranchType = const $ pure (),
      walkOnFParam = const $ pure (),
      walkOnLParam = const $ pure (),
      walkOnOp = const $ pure ()
    }

walkOnShape :: (Monad m) => Walker rep m -> Shape -> m ()
walkOnShape tv (Shape ds) = mapM_ (walkOnSubExp tv) ds

-- Visits the names and sizes mentioned in a type: the accumulator
-- name, index space and element types of an 'Acc', and the shape of
-- an 'Array'.  'Prim' and 'Mem' types contain nothing to visit.
walkOnType :: (Monad m) => Walker rep m -> Type -> m ()
walkOnType _ Prim {} = pure ()
walkOnType tv (Acc acc ispace ts _) = do
  walkOnVName tv acc
  traverse_ (walkOnSubExp tv) ispace
  mapM_ (walkOnType tv) ts
walkOnType _ Mem {} = pure ()
walkOnType tv (Array _ shape _) = walkOnShape tv shape

walkOnLoopForm :: (Monad m) => Walker rep m -> LoopForm -> m ()
walkOnLoopForm tv (ForLoop i _ bound) = do
  walkOnVName tv i
  walkOnSubExp tv bound
walkOnLoopForm tv (WhileLoop cond) =
  walkOnVName tv cond

-- Order: parameters, body, then return types (note: 'mapOnLambda'
-- handles the return types before the body).
walkOnLambda :: (Monad m) => Walker rep m -> Lambda rep -> m ()
walkOnLambda tv (Lambda params ret body) = do
  mapM_ (walkOnLParam tv) params
  walkOnBody tv (scopeOfLParams params) body
  mapM_ (walkOnType tv) ret

-- | As 'mapExpM', but do not construct a result AST.
walkExpM :: (Monad m) => Walker rep m -> Exp rep -> m ()
walkExpM tv (BasicOp (SubExp se)) =
  walkOnSubExp tv se
walkExpM _ (BasicOp ArrayVal {}) =
  pure ()
walkExpM tv (BasicOp (ArrayLit els rowt)) =
  mapM_ (walkOnSubExp tv) els >> walkOnType tv rowt
walkExpM tv (BasicOp (BinOp _ x y)) =
  walkOnSubExp tv x >> walkOnSubExp tv y
walkExpM tv (BasicOp (CmpOp _ x y)) =
  walkOnSubExp tv x >> walkOnSubExp tv y
walkExpM tv (BasicOp (ConvOp _ x)) =
  walkOnSubExp tv x
walkExpM tv (BasicOp (UnOp _ x)) =
  walkOnSubExp tv x
walkExpM tv (Match ses cases defbody (MatchDec ts _)) = do
  mapM_ (walkOnSubExp tv) ses
  mapM_ (walkOnBody tv mempty . caseBody) cases
  walkOnBody tv mempty defbody
  mapM_ (walkOnBranchType tv) ts
walkExpM tv (Apply _ args ret _) = do
  mapM_ (walkOnSubExp tv . fst) args
  mapM_ (walkOnRetType tv . fst) ret
walkExpM tv (BasicOp (Index arr slice)) =
  walkOnVName tv arr >> traverse_ (walkOnSubExp tv) slice
walkExpM tv (BasicOp (Update _ arr slice se)) =
  walkOnVName tv arr >> traverse_ (walkOnSubExp tv) slice >> walkOnSubExp tv se
walkExpM tv (BasicOp (FlatIndex arr slice)) =
  walkOnVName tv arr >> traverse_ (walkOnSubExp tv) slice
walkExpM tv (BasicOp (FlatUpdate arr slice se)) =
  walkOnVName tv arr >> traverse_ (walkOnSubExp tv) slice >> walkOnVName tv se
walkExpM tv (BasicOp (Iota n x s _)) =
  walkOnSubExp tv n >> walkOnSubExp tv x >> walkOnSubExp tv s
walkExpM tv (BasicOp (Replicate shape vexp)) =
  walkOnShape tv shape >> walkOnSubExp tv vexp
walkExpM tv (BasicOp (Scratch _ shape)) =
  mapM_ (walkOnSubExp tv) shape
walkExpM tv (BasicOp (Reshape _ shape arrexp)) =
  mapM_ (walkOnSubExp tv) shape >> walkOnVName tv arrexp
walkExpM tv (BasicOp (Rearrange _ e)) =
  walkOnVName tv e
walkExpM tv (BasicOp (Concat _ (x :| ys) size)) =
  walkOnVName tv x >> mapM_ (walkOnVName tv) ys >> walkOnSubExp tv size
walkExpM tv (BasicOp (Manifest _ e)) =
  walkOnVName tv e
walkExpM tv (BasicOp (Assert e msg _)) =
  walkOnSubExp tv e >> traverse_ (walkOnSubExp tv) msg
walkExpM tv (BasicOp (Opaque _ e)) =
  walkOnSubExp tv e
walkExpM tv (BasicOp (UpdateAcc _ v is ses)) = do
  walkOnVName tv v
  mapM_ (walkOnSubExp tv) is
  mapM_ (walkOnSubExp tv) ses
walkExpM tv (WithAcc inputs lam) = do
  forM_ inputs $ \(shape, vs, op) -> do
    walkOnShape tv shape
    mapM_ (walkOnVName tv) vs
    traverse_ (bitraverse (walkOnLambda tv) (mapM (walkOnSubExp tv))) op
  walkOnLambda tv lam
-- NOTE(review): the scope union here puts the parameters first, while
-- 'mapExpM' puts the loop form first; this only matters if a name is
-- bound by both, which presumably never happens.
walkExpM tv (Loop merge form loopbody) = do
  mapM_ (walkOnFParam tv) params
  walkOnLoopForm tv form
  mapM_ (walkOnSubExp tv) args
  let scope = scopeOfFParams params <> scopeOfLoopForm form
  walkOnBody tv scope loopbody
  where
    (params, args) = unzip merge
walkExpM tv (Op op) =
  walkOnOp tv op

-- | A function for monadically traversing any sub-statements of the
-- given op for some representation.
type OpStmsTraverser m op rep =
  (Scope rep -> Stms rep -> m (Stms rep)) -> op -> m op

-- | This representation supports an 'OpStmsTraverser' for its t'Op'.
-- This is used for some simplification rules.
class TraverseOpStms rep where
  -- | Transform every sub-'Stms' of this op.
  traverseOpStms :: (Monad m) => OpStmsTraverser m (Op rep) rep

-- | A helper for defining 'traverseOpStms'.  Only the statements of the
-- lambda body are passed to the traverser, in the scope of the lambda
-- parameters; parameters, return types, body decoration and result
-- are kept unchanged.
traverseLambdaStms :: (Monad m) => OpStmsTraverser m (Lambda rep) rep
traverseLambdaStms f (Lambda ps ret (Body dec stms res)) =
  Lambda ps ret <$> (Body dec <$> f (scopeOfLParams ps) stms <*> pure res)
futhark-0.25.27/src/Futhark/IR/TypeCheck.hs000066400000000000000000001343141475065116200202700ustar00rootroot00000000000000
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE TypeFamilies #-}

-- | The type checker checks whether the program is type-consistent.
module Futhark.IR.TypeCheck
  ( -- * Interface
    checkProg,
    TypeError (..),
    ErrorCase (..),

    -- * Extensionality
    TypeM,
    bad,
    context,
    Checkable (..),
    lookupVar,
    lookupAliases,
    checkOpWith,

    -- * Checkers
    require,
    requireI,
    requirePrimExp,
    checkSubExp,
    checkCerts,
    checkExp,
    checkStms,
    checkStm,
    checkSlice,
    checkType,
    checkExtType,
    matchExtPat,
    matchExtBranchType,
    argType,
    noArgAliases,
    checkArg,
    checkSOACArrayArgs,
    checkLambda,
    checkBody,
    consume,
    binding,
    alternative,
  )
where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State.Strict
import Control.Parallel.Strategies
import Data.Bifunctor (first)
import Data.List (find, intercalate, isPrefixOf, sort)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Text qualified as T
import Futhark.Analysis.Alias
import Futhark.Analysis.PrimExp
import Futhark.Construct (instantiateShapes)
import Futhark.IR.Aliases hiding (lookupAliases)
import Futhark.Util
import Futhark.Util.Pretty (align, docText, indent, ppTuple', pretty, (<+>), (</>))

-- | Information about an error during type checking. The 'Show'
-- instance for this type produces a human-readable description.
data ErrorCase rep
  = -- | A free-form error message.
    TypeError T.Text
  | -- | An expression, the type it was found to have, and the types it
    -- was required to have one of.
    UnexpectedType (Exp rep) Type [Type]
  | -- | Function name, declared return type, and type of the body.
    ReturnTypeError Name [ExtType] [ExtType]
  | -- | A function defined more than once.
    DupDefinitionError Name
  | -- | Function name, and a parameter name occurring more than once.
    DupParamError Name VName
  | -- | A variable bound twice in the same pattern.
    DupPatError VName
  | -- | A pattern, the types it failed to match, and an optional
    -- explanation.
    InvalidPatError (Pat (LetDec (Aliases rep))) [ExtType] (Maybe String)
  | UnknownVariableError VName
  | UnknownFunctionError Name
  | -- | Called function ('Nothing' for an anonymous function),
    -- expected parameter types, and actual argument types.
    ParameterMismatch (Maybe Name) [Type] [Type]
  | -- | Rank of the indexed array, and number of indices given.
    SlicingError Int Int
  | -- | What was annotated, the annotated type, and the derived type.
    BadAnnotation String Type Type
  | -- | Function name, and the non-consumed name that a unique return
    -- value is aliased to.
    ReturnAliased Name VName
  | UniqueReturnAliased Name
  | -- | A variable required to be an array, and its actual type.
    NotAnArray VName Type
  | -- | The permutation, the array rank, and optionally the array name.
    PermutationError [Int] Int (Maybe VName)

instance (Checkable rep) => Show (ErrorCase rep) where
  show (TypeError msg) =
    "Type error:\n" ++ T.unpack msg
  -- An empty list of acceptable types gets its own message.
  show (UnexpectedType e _ []) =
    "Type of expression\n"
      ++ T.unpack (docText $ indent 2 $ pretty e)
      ++ "\ncannot have any type - possibly a bug in the type checker."
  show (UnexpectedType e t ts) =
    "Type of expression\n"
      ++ T.unpack (docText $ indent 2 $ pretty e)
      ++ "\nmust be one of "
      ++ intercalate ", " (map prettyString ts)
      ++ ", but is "
      ++ prettyString t
      ++ "."
  show (ReturnTypeError fname rettype bodytype) =
    "Declaration of function "
      ++ nameToString fname
      ++ " declares return type\n "
      ++ T.unpack (prettyTuple rettype)
      ++ "\nBut body has type\n "
      ++ T.unpack (prettyTuple bodytype)
  show (DupDefinitionError name) =
    "Duplicate definition of function " ++ nameToString name
  show (DupParamError funname paramname) =
    "Parameter "
      ++ prettyString paramname
      ++ " mentioned multiple times in argument list of function "
      ++ nameToString funname
      ++ "."
  show (DupPatError name) =
    "Variable " ++ prettyString name ++ " bound twice in pattern."
  show (InvalidPatError pat t desc) =
    "Pat\n"
      ++ prettyString pat
      ++ "\ncannot match value of type\n"
      ++ T.unpack (prettyTupleLines t)
      ++ end
    where
      end = case desc of
        Nothing -> "."
        Just desc' -> ":\n" ++ desc'
  show (UnknownVariableError name) =
    "Use of unknown variable " ++ prettyString name ++ "."
  show (UnknownFunctionError fname) =
    "Call of unknown function " ++ nameToString fname ++ "."
  show (ParameterMismatch fname expected got) =
    "In call of "
      ++ fname'
      ++ ":\n"
      ++ "expecting "
      ++ show nexpected
      ++ " arguments of type(s)\n"
      ++ intercalate ", " (map prettyString expected)
      ++ "\nGot "
      ++ show ngot
      ++ " arguments of types\n"
      ++ intercalate ", " (map prettyString got)
    where
      nexpected = length expected
      ngot = length got
      fname' = maybe "anonymous function" (("function " ++) . nameToString) fname
  show (SlicingError dims got) =
    show got ++ " indices given, but type of indexee has " ++ show dims ++ " dimension(s)."
  show (BadAnnotation desc expected got) =
    "Annotation of \""
      ++ desc
      ++ "\" type of expression is "
      ++ prettyString expected
      ++ ", but derived to be "
      ++ prettyString got
      ++ "."
  show (ReturnAliased fname name) =
    "Unique return value of function "
      ++ nameToString fname
      ++ " is aliased to "
      ++ prettyString name
      ++ ", which is not consumed."
  show (UniqueReturnAliased fname) =
    "A unique tuple element of return value of function "
      ++ nameToString fname
      ++ " is aliased to some other tuple component."
  show (NotAnArray e t) =
    "The expression "
      ++ prettyString e
      ++ " is expected to be an array, but is "
      ++ prettyString t
      ++ "."
  show (PermutationError perm rank name) =
    "The permutation ("
      ++ intercalate ", " (map show perm)
      ++ ") is not valid for array "
      ++ name'
      ++ "of rank "
      ++ show rank
      ++ "."
    where
      -- Includes its own trailing space, so an absent name leaves no gap.
      name' = maybe "" ((++ " ") . prettyString) name

-- | A type error: the context messages (outermost first) together
-- with the underlying 'ErrorCase'.
data TypeError rep = Error [T.Text] (ErrorCase rep)

instance (Checkable rep) => Show (TypeError rep) where
  show (Error [] err) =
    show err
  show (Error msgs err) =
    intercalate "\n" (map T.unpack msgs) ++ "\n" ++ show err

-- | The entry for a function in the function table: its return types,
-- each paired with its 'RetAls' alias information, and its parameters.
type FunBinding rep = ([(RetType (Aliases rep), RetAls)], [FParam (Aliases rep)])

-- | What the variable table records for each name in scope.
type VarBinding rep = NameInfo (Aliases rep)

data Usage
  = Consumed
  | Observed
  deriving (Eq, Ord, Show)

-- | The names observed (read) and consumed by some piece of code.
data Occurence = Occurence
  { observed :: Names,
    consumed :: Names
  }
  deriving (Eq, Show)

-- | An occurrence that only observes the given names.
observation :: Names -> Occurence
observation = flip Occurence mempty

-- | An occurrence that only consumes the given names.
consumption :: Names -> Occurence
consumption = Occurence mempty

-- | True if the occurrence neither observes nor consumes anything.
nullOccurence :: Occurence -> Bool
nullOccurence occ = observed occ == mempty && consumed occ == mempty

type Occurences = [Occurence]

-- | Every name consumed by any of the occurrences.
allConsumed :: Occurences -> Names
allConsumed = mconcat . map consumed

-- | Sequential composition: @occurs1@ happens before @occurs2@.
-- Observations in @occurs1@ of names consumed by @occurs2@ are
-- dropped, and occurrences that become empty are removed.
seqOccurences :: Occurences -> Occurences -> Occurences
seqOccurences occurs1 occurs2 =
  filter (not . nullOccurence) (map filt occurs1) ++ occurs2
  where
    filt occ =
      occ {observed = observed occ `namesSubtract` postcons}
    postcons = allConsumed occurs2

-- | Composition of two mutually exclusive branches.  Names consumed in
-- @occurs2@ are removed from both the observations and consumptions of
-- @occurs1@, so consuming the same array in both branches is allowed.
altOccurences :: Occurences -> Occurences -> Occurences
altOccurences occurs1 occurs2 =
  filter (not . nullOccurence) (map filt occurs1) ++ occurs2
  where
    filt occ =
      occ
        { consumed = consumed occ `namesSubtract` postcons,
          observed = observed occ `namesSubtract` postcons
        }
    postcons = allConsumed occurs2

-- | Forget all occurrences of the given names (used when they go out
-- of scope); occurrences that become empty are removed.
unOccur :: Names -> Occurences -> Occurences
unOccur to_be_removed = filter (not . nullOccurence) . map unOccur'
  where
    unOccur' occ =
      occ
        { observed = observed occ `namesSubtract` to_be_removed,
          consumed = consumed occ `namesSubtract` to_be_removed
        }

-- | The 'Consumption' data structure is used to keep track of which
-- variables have been consumed, as well as whether a violation has been detected.
data Consumption
  = ConsumptionError T.Text
  | Consumption Occurences
  deriving (Show)

-- Sequential combination: anything consumed on the left may not be
-- observed or consumed on the right.  The first error is propagated.
instance Semigroup Consumption where
  ConsumptionError e <> _ = ConsumptionError e
  _ <> ConsumptionError e = ConsumptionError e
  Consumption o1 <> Consumption o2
    | v : _ <- namesToList $ consumed_in_o1 `namesIntersection` used_in_o2 =
        ConsumptionError $ "Variable " <> prettyText v <> " referenced after being consumed."
    | otherwise =
        Consumption $ o1 `seqOccurences` o2
    where
      consumed_in_o1 = mconcat $ map consumed o1
      used_in_o2 = mconcat $ map consumed o2 <> map observed o2

instance Monoid Consumption where
  mempty = Consumption mempty

-- | The environment contains a variable table and a function table.
-- Type checking happens with access to this environment.  The
-- function table is only initialised at the very beginning, but the
-- variable table will be extended during type-checking when
-- let-expressions are encountered.
data Env rep = Env
  { envVtable :: M.Map VName (VarBinding rep),
    envFtable :: M.Map Name (FunBinding rep),
    envCheckOp :: Op (Aliases rep) -> TypeM rep (),
    -- | Context messages, innermost first.
    envContext :: [T.Text]
  }

-- | Mutable checker state: every name bound so far, and the
-- consumption information accumulated so far.
data TState = TState
  { stateNames :: Names,
    stateCons :: Consumption
  }

-- | The type checker runs in this monad.
newtype TypeM rep a
  = TypeM
      ( ReaderT
          (Env rep)
          (StateT TState (Either (TypeError rep)))
          a
      )
  deriving
    ( Monad,
      Functor,
      Applicative,
      MonadReader (Env rep),
      MonadState TState
    )

instance (Checkable rep) => HasScope (Aliases rep) (TypeM rep) where
  lookupType = fmap typeOf . lookupVar

  -- The variable table already is a 'Scope' for 'Aliases rep', so it
  -- is returned directly rather than being rebuilt through a list
  -- round-trip ('M.toList' / 'M.fromList'), which was an O(n log n)
  -- identity.
  askScope = asks envVtable

-- | Run a checker action from an empty state (no bound names and no
-- consumption).
runTypeM :: Env rep -> TypeM rep a -> Either (TypeError rep) a
runTypeM env (TypeM m) =
  evalStateT (runReaderT m env) (TState mempty mempty)

-- | Signal a type error.
bad :: ErrorCase rep -> TypeM rep a
bad e = do
  messages <- asks envContext
  -- Context is stored innermost-first; report it outermost-first.
  TypeM $ lift $ lift $ Left $ Error (reverse messages) e

-- | Append to the accumulated consumption state.
tell :: Consumption -> TypeM rep ()
tell cons = modify $ \s -> s {stateCons = stateCons s <> cons}

-- | Add information about what is being type-checked to the current
-- context.  Liberal use of this combinator makes it easier to track
-- type errors, as the strings are added to type errors signalled via
-- 'bad'.
context :: T.Text -> TypeM rep a -> TypeM rep a
context s = local pushContext
  where
    pushContext env = env {envContext = s : envContext env}

-- | Render a label followed by an aligned pretty-printed value.
message :: (Pretty a) => T.Text -> a -> T.Text
message s x = docText (pretty s <+> align (pretty x))

-- | Record that a name is now bound.  It is a type error if the same
-- name was already bound anywhere earlier in the program.
bound :: VName -> TypeM rep ()
bound name = do
  seenBefore <- gets (nameIn name . stateNames)
  when seenBefore $ bad (TypeError dupMsg)
  modify (\st -> st {stateNames = oneName name <> stateNames st})
  where
    dupMsg = "Name " <> prettyText name <> " bound twice"

-- | Record a list of occurrences, discarding those that are empty.
occur :: Occurences -> TypeM rep ()
occur occs = tell (Consumption [o | o <- occs, not (nullOccurence o)])

-- | Record a read-only use of a variable, together with everything it
-- aliases.  Variables of primitive type are ignored.
observe :: (Checkable rep) => VName -> TypeM rep ()
observe name = do
  info <- lookupVar name
  let isPrim = primType (typeOf info)
  unless isPrim $
    occur [observation (oneName name <> aliases info)]

-- | Record that the given names are written to.  Only names that are in
-- scope with a non-primitive type are recorded.
consume :: (Checkable rep) => Names -> TypeM rep ()
consume als = do
  scope <- askScope
  let nonPrim v = case M.lookup v scope of
        Just info -> not (primType (typeOf info))
        Nothing -> False
  occur [consumption (namesFromList (filter nonPrim (namesToList als)))]

-- | Run an action against a fresh consumption state and return what it
-- produced.  The previous consumption state is restored afterwards, and
-- a consumption violation inside the action becomes a type error.
collectOccurences :: TypeM rep a -> TypeM rep (a, Occurences)
collectOccurences m = do
  saved <- gets stateCons
  modify (\st -> st {stateCons = mempty})
  result <- m
  produced <- gets stateCons
  modify (\st -> st {stateCons = saved})
  occs <- checkConsumption produced
  pure (result, occs)

-- | Run an action with a different checker for 'Op's.
checkOpWith ::
  (Op (Aliases rep) -> TypeM rep ()) ->
  TypeM rep a ->
  TypeM rep a
checkOpWith checker = local (\env -> env {envCheckOp = checker})

-- | Report a consumption violation as a type error; otherwise return
-- the recorded occurrences.
checkConsumption :: Consumption -> TypeM rep Occurences
checkConsumption c = case c of
  ConsumptionError e -> bad (TypeError e)
  Consumption os -> pure os

-- | Type check two mutually exclusive control flow branches. Think
-- @if@. This interacts with consumption checking, as it is OK for an
-- array to be consumed in both branches.
alternative :: TypeM rep a -> TypeM rep b -> TypeM rep (a, b)
alternative m1 m2 = do
  -- Each branch is checked against its own fresh consumption state;
  -- the results are merged with 'altOccurences' rather than in sequence.
  (x, os1) <- collectOccurences m1
  (y, os2) <- collectOccurences m2
  tell $ Consumption $ os1 `altOccurences` os2
  pure (x, y)

-- | Check any number of mutually exclusive branches, by nesting
-- 'alternative'.
alternatives :: [TypeM rep ()] -> TypeM rep ()
alternatives [] = pure ()
alternatives (x : xs) = void $ x `alternative` alternatives xs

-- | Permit consumption of only the specified names.  If one of these
-- names is consumed, the consumption will be rewritten to be a
-- consumption of the corresponding alias set.  Consumption of
-- anything else will result in a type error.
consumeOnlyParams :: [(VName, Names)] -> TypeM rep a -> TypeM rep a
consumeOnlyParams consumable m = do
  (x, os) <- collectOccurences m
  tell . Consumption =<< mapM inspect os
  pure x
  where
    inspect o = do
      new_consumed <- mconcat <$> mapM wasConsumed (namesToList $ consumed o)
      pure o {consumed = new_consumed}
    wasConsumed v
      | Just als <- lookup v consumable = pure als
      | otherwise =
          bad . TypeError . T.unlines $
            [ prettyText v <> " was invalidly consumed.",
              what <> " can be consumed here."
            ]
    what
      | null consumable = "Nothing"
      | otherwise = "Only " <> T.intercalate ", " (map (prettyText . fst) consumable)

-- | Given the immediate aliases, compute the full transitive alias
-- set (including the immediate aliases).
--
-- Only one lookup level is performed here.  This suffices because
-- 'binding' stores already-expanded alias sets for let-bound names.
expandAliases :: Names -> Env rep -> Names
expandAliases names env = names <> aliasesOfAliases
  where
    aliasesOfAliases = mconcat . map look . namesToList $ names
    look k = case M.lookup k $ envVtable env of
      Just (LetName (als, _)) -> unAliases als
      _ -> mempty

-- | Run an action with the given bindings added to the variable table.
-- The names are marked as bound, and their occurrences are dropped
-- from the consumption state once the action finishes.
binding ::
  (Checkable rep) =>
  Scope (Aliases rep) ->
  TypeM rep a ->
  TypeM rep a
binding stms = check . local (`bindVars` stms)
  where
    -- Alias sets are expanded against the environment as it was
    -- before any of these bindings were added.
    bindVars orig_env = M.foldlWithKey' (bindVar orig_env) orig_env
    boundnames = M.keys stms

    -- Let-bound names of primitive type get no aliases; otherwise the
    -- immediate alias set is expanded transitively.
    bindVar orig_env env name (LetName (AliasDec als, dec)) =
      let als'
            | primType (typeOf dec) = mempty
            | otherwise = expandAliases als orig_env
       in env
            { envVtable =
                M.insert name (LetName (AliasDec als', dec)) $ envVtable env
            }
    bindVar _ env name dec =
      env {envVtable = M.insert name dec $ envVtable env}

    -- Check whether the bound variables have been used correctly
    -- within their scope.
    check m = do
      mapM_ bound $ M.keys stms
      (a, os) <- collectOccurences m
      tell $ Consumption $ unOccur (namesFromList boundnames) os
      pure a

-- | Look up a variable, signalling 'UnknownVariableError' if it is not
-- in scope.
lookupVar :: VName -> TypeM rep (NameInfo (Aliases rep))
lookupVar name = do
  stm <- asks $ M.lookup name . envVtable
  case stm of
    Nothing -> bad $ UnknownVariableError name
    Just dec -> pure dec

-- | The name itself plus its recorded aliases; empty for variables of
-- primitive type.
lookupAliases :: (Checkable rep) => VName -> TypeM rep Names
lookupAliases name = do
  info <- lookupVar name
  pure $
    if primType $ typeOf info
      then mempty
      else oneName name <> aliases info

-- | Only let-bound names carry recorded aliases.
aliases :: NameInfo (Aliases rep) -> Names
aliases (LetName (als, _)) = unAliases als
aliases _ = mempty

subExpAliasesM :: (Checkable rep) => SubExp -> TypeM rep Names
subExpAliasesM Constant {} = pure mempty
subExpAliasesM (Var v) = lookupAliases v

-- | Look up a function and instantiate its return type for the given
-- arguments, signalling an error if the function is unknown or the
-- arguments do not fit its parameters.
lookupFun ::
  (Checkable rep) =>
  Name ->
  [SubExp] ->
  TypeM rep ([(RetType rep, RetAls)], [DeclType])
lookupFun fname args = do
  stm <- asks $ M.lookup fname . envFtable
  case stm of
    Nothing -> bad $ UnknownFunctionError fname
    Just (ftype, params) -> do
      argts <- mapM subExpType args
      case applyRetType (map fst ftype) params $ zip args argts of
        Nothing ->
          bad $ ParameterMismatch (Just fname) (map paramType params) argts
        Just rt ->
          pure (zip rt $ map snd ftype, map paramDeclType params)

-- | @checkAnnotation desc t1 t2@ checks if @t2@ is equal to @t1@.  If
-- not, a 'BadAnnotation' is raised, with @desc@ describing what was
-- annotated.
checkAnnotation ::
  String ->
  Type ->
  Type ->
  TypeM rep ()
checkAnnotation desc t1 t2
  | t2 == t1 = pure ()
  | otherwise = bad $ BadAnnotation desc t1 t2

-- | @require ts se@ signals an 'UnexpectedType' error unless the type
-- of @se@ is exactly one of the types in @ts@.  This is plain equality
-- with no subtyping.  Checking @se@ through 'checkSubExp' also records
-- a variable as observed.
require :: (Checkable rep) => [Type] -> SubExp -> TypeM rep ()
require ts se = do
  t <- checkSubExp se
  unless (t `elem` ts) $
    bad $
      UnexpectedType (BasicOp $ SubExp se) t ts

-- | Variant of 'require' working on variable names.
requireI :: (Checkable rep) => [Type] -> VName -> TypeM rep ()
requireI ts ident = require ts $ Var ident

-- Return the shape and element type of an array variable, or signal
-- 'NotAnArray'.
checkArrIdent ::
  (Checkable rep) =>
  VName ->
  TypeM rep (Shape, PrimType)
checkArrIdent v = do
  t <- lookupType v
  case t of
    Array pt shape _ -> pure (shape, pt)
    _ -> bad $ NotAnArray v t

-- Return the index space and element types of an accumulator variable.
checkAccIdent ::
  (Checkable rep) =>
  VName ->
  TypeM rep (Shape, [Type])
checkAccIdent v = do
  t <- lookupType v
  case t of
    Acc _ ispace ts _ ->
      pure (ispace, ts)
    _ ->
      bad . TypeError $
        prettyText v
          <> " should be an accumulator but is of type "
          <> prettyText t

-- Every opaque type may only refer to opaque types that appear before
-- it in the list.
checkOpaques :: OpaqueTypes -> Either (TypeError rep) ()
checkOpaques (OpaqueTypes types) = descend [] types
  where
    descend _ [] = pure ()
    descend known ((name, t) : ts) = do
      check known t
      descend (name : known) ts
    check known (OpaqueRecord fs) =
      mapM_ (checkEntryPointType known . snd) fs
    check known (OpaqueSum _ cs) =
      mapM_ (mapM_ (checkEntryPointType known . fst) . snd) cs
    check known (OpaqueArray _ v _) =
      checkEntryPointType known (TypeOpaque v)
    check known (OpaqueRecordArray _ v fs) = do
      checkEntryPointType known (TypeOpaque v)
      mapM_ (checkEntryPointType known . snd) fs
    check _ (OpaqueType _) =
      pure ()
    checkEntryPointType known (TypeOpaque s) =
      unless (s `elem` known) $
        Left . Error [] . TypeError $
          "Opaque not defined before first use: " <> nameToText s
    checkEntryPointType _ (TypeTransparent _) = pure ()

-- | Type check a program containing arbitrary type information,
-- yielding either a type error or unit.
--
-- The opaque types are checked first, then the top-level constants
-- (once, with the full function table in scope).  Each function is then
-- checked independently (in parallel via 'parMap'), starting from a
-- state in which only the names bound by the constants are marked
-- bound.
checkProg ::
  (Checkable rep) =>
  Prog (Aliases rep) ->
  Either (TypeError rep) ()
checkProg (Prog opaques consts funs) = do
  checkOpaques opaques
  let typeenv =
        Env
          { envVtable = M.empty,
            envFtable = mempty,
            envContext = [],
            envCheckOp = checkOp
          }
  let const_names = foldMap (patNames . stmPat) consts
      onFunction ftable vtable fun = runTypeM typeenv $ do
        modify $ \s -> s {stateNames = namesFromList const_names}
        local (\env -> env {envFtable = ftable, envVtable = vtable}) $
          checkFun fun

  ftable <-
    runTypeM typeenv buildFtable
  vtable <-
    runTypeM typeenv {envFtable = ftable} $ checkStms consts $ asks envVtable

  sequence_ $ parMap rpar (onFunction ftable vtable) funs
  where
    buildFtable = do
      table <- initialFtable
      foldM expand table funs
    expand ftable (FunDef _ _ name ret params _)
      | M.member name ftable =
          bad $ DupDefinitionError name
      | otherwise =
          pure $ M.insert name (ret, params) ftable

-- The function table pre-populated with the built-in functions.  All
-- of their parameters share the same dummy name.
initialFtable ::
  (Checkable rep) =>
  TypeM rep (M.Map Name (FunBinding rep))
initialFtable = fmap M.fromList $ mapM addBuiltin $ M.toList builtInFunctions
  where
    addBuiltin (fname, (t, ts)) = do
      ps <- mapM (primFParam name) ts
      pure (fname, ([(primRetType t, RetAls mempty mempty)], ps))
    name = VName (nameFromString "x") 0

-- Check a function definition.  Only its unique parameters may be
-- consumed by the body.
checkFun ::
  (Checkable rep) =>
  FunDef (Aliases rep) ->
  TypeM rep ()
checkFun (FunDef _ _ fname rettype params body) =
  context ("In function " <> nameToText fname)
    $ checkFun'
      ( fname,
        map (first declExtTypeOf) rettype,
        funParamsToNameInfos params
      )
      (Just consumable)
    $ do
      checkFunParams params
      checkRetType $ map fst rettype
      context "When checking function body" $ checkFunBody rettype body
  where
    consumable =
      [ (paramName param, mempty)
      | param <- params,
        unique $ paramDeclType param
      ]

funParamsToNameInfos ::
  [FParam rep] ->
  [(VName, NameInfo (Aliases rep))]
funParamsToNameInfos = map nameTypeAndDec
  where
    nameTypeAndDec fparam =
      ( paramName fparam,
        FParamName $ paramDec fparam
      )

checkFunParams ::
  (Checkable rep) =>
  [FParam rep] ->
  TypeM rep ()
checkFunParams = mapM_ $ \param ->
  context ("In parameter " <> prettyText param) $
    checkFParamDec (paramName param) (paramDec param)

checkLambdaParams ::
  (Checkable rep) =>
  [LParam rep] ->
  TypeM rep ()
checkLambdaParams = mapM_ $ \param ->
  context ("In parameter " <> prettyText param) $
    checkLParamDec (paramName param) (paramDec param)

-- Signal 'DupParamError' for the first parameter name that repeats.
checkNoDuplicateParams :: Name -> [VName] -> TypeM rep ()
checkNoDuplicateParams fname = foldM_ expand []
  where
    expand seen pname
      | Just _ <- find (== pname) seen =
          bad $ DupParamError fname pname
      | otherwise =
          pure $ pname : seen

-- Shared checker for function-like definitions.  The parameters are
-- bound, only the names in @consumable@ (if given) may be consumed, and
-- the aliases returned by @check@ (one set per result) are validated
-- against the 'RetAls' of the declared return types.
checkFun' ::
  (Checkable rep) =>
  ( Name,
    [(DeclExtType, RetAls)],
    [(VName, NameInfo (Aliases rep))]
  ) ->
  Maybe [(VName, Names)] ->
  TypeM rep [Names] ->
  TypeM rep ()
checkFun' (fname, rettype, params) consumable check = do
  checkNoDuplicateParams fname param_names
  binding (M.fromList params) $
    maybe id consumeOnlyParams consumable $ do
      body_aliases <- check
      context
        ( "When checking the body aliases: "
            <> prettyText (map namesToList body_aliases)
        )
        $ checkReturnAlias body_aliases
  where
    param_names = map fst params

    isParam = (`elem` param_names)

    unique_names = namesFromList $ do
      (v, FParamName t) <- params
      guard $ unique $ declTypeOf t
      pure v

    -- Arguments a result may alias: those listed by index in its
    -- 'RetAls', plus every unique parameter.
    allowedArgAliases pals =
      namesFromList (map (param_names !!) pals) <> unique_names

    checkReturnAlias retals =
      zipWithM_ checkRet (zip [(0 :: Int) ..] rettype) retals
      where
        comrades = zip3 [0 ..] retals $ map (otherAls . snd) rettype

        -- Only array results are checked.
        checkRet (i, (Array {}, RetAls pals rals)) als
          | als'' <- filter isParam $ namesToList als',
            not $ null als'' =
              bad . TypeError . T.unlines $
                [ T.unwords ["Result", prettyText i, "aliases", prettyText als''],
                  T.unwords ["but is only allowed to alias arguments", prettyText allowed_args]
                ]
          | ((j, _, _) : _) <- filter (isProblem i als' rals) comrades =
              bad . TypeError . T.unlines $
                [ T.unwords ["Results", prettyText i, "and", prettyText j, "alias each other"],
                  T.unwords ["but result", prettyText i, "only allowed to alias results", prettyText rals],
                  prettyText retals
                ]
          where
            allowed_args = allowedArgAliases pals
            als' = als `namesSubtract` allowed_args
        checkRet _ _ = pure ()

        -- Results i and j overlap although neither declares the other.
        isProblem i als rals (j, jals, j_rals) =
          i /= j && j `notElem` rals && i `notElem` j_rals && namesIntersect als jals

checkSubExp :: (Checkable rep) => SubExp -> TypeM rep Type
checkSubExp (Constant val) =
  pure $ Prim $ primValueType val
checkSubExp (Var ident) = context ("In subexp " <> prettyText ident) $ do
  observe ident
  lookupType ident

-- | Every certificate must be a variable of type @unit@.
checkCerts :: (Checkable rep) => Certs -> TypeM rep ()
checkCerts (Certs cs) = mapM_ (requireI [Prim Unit]) cs

checkSubExpRes :: (Checkable rep) => SubExpRes -> TypeM rep Type
checkSubExpRes (SubExpRes cs se) = do
  checkCerts cs
  checkSubExp se

-- | Check each statement in order, with the variables bound by earlier
-- statements in scope for later ones, and finally run the given action
-- in the scope of all of them.
checkStms ::
  (Checkable rep) =>
  Stms (Aliases rep) ->
  TypeM rep a ->
  TypeM rep a
checkStms origstms m = delve $ stmsToList origstms
  where
    delve (stm@(Let pat _ e) : stms) = do
      -- The pattern is printed on its own indented line below the
      -- heading; the '</>' separator was missing here, which did not
      -- type-check.
      context (docText $ "In expression of statement" </> indent 2 (pretty pat)) $
        checkExp e
      checkStm stm $
        delve stms
    delve [] =
      m

checkResult ::
  (Checkable rep) =>
  Result ->
  TypeM rep ()
checkResult = mapM_ checkSubExpRes

-- Returns the alias set of each result.
checkFunBody ::
  (Checkable rep) =>
  [(RetType rep, RetAls)] ->
  Body (Aliases rep) ->
  TypeM rep [Names]
checkFunBody rt (Body (_, rep) stms res) = do
  checkBodyDec rep
  checkStms stms $ do
    context "When checking body result" $ checkResult res
    context "When matching declared return type to result of body" $
      matchReturnType (map fst rt) res
    mapM (subExpAliasesM . resSubExp) res

checkLambdaBody ::
  (Checkable rep) =>
  [Type] ->
  Body (Aliases rep) ->
  TypeM rep ()
checkLambdaBody ret (Body (_, rep) stms res) = do
  checkBodyDec rep
  checkStms stms $ checkLambdaResult ret res

-- Result count and each result's type must match the lambda's
-- declared return types exactly.
checkLambdaResult ::
  (Checkable rep) =>
  [Type] ->
  Result ->
  TypeM rep ()
checkLambdaResult ts es
  | length ts /= length es =
      bad . TypeError $
        "Lambda has return type "
          <> prettyTuple ts
          <> " describing "
          <> prettyText (length ts)
          <> " values, but body returns "
          <> prettyText (length es)
          <> " values: "
          <> prettyTuple es
  | otherwise = forM_ (zip ts es) $ \(t, e) -> do
      et <- checkSubExpRes e
      unless (et == t) . bad . TypeError $
        "Subexpression "
          <> prettyText e
          <> " has type "
          <> prettyText et
          <> " but expected "
          <> prettyText t

-- | Check a body and return the alias set of each result, excluding
-- names bound inside the body itself.
checkBody ::
  (Checkable rep) =>
  Body (Aliases rep) ->
  TypeM rep [Names]
checkBody (Body (_, rep) stms res) = do
  checkBodyDec rep
  checkStms stms $ do
    checkResult res
    map (`namesSubtract` bound_here) <$> mapM (subExpAliasesM . resSubExp) res
  where
    bound_here = namesFromList $ M.keys $ scopeOf stms

-- | Check a slicing operation of an array of the provided type.
checkSlice :: (Checkable rep) => Type -> Slice SubExp -> TypeM rep ()
checkSlice vt (Slice idxes) = do
  when (arrayRank vt /= length idxes) . bad $
    SlicingError (arrayRank vt) (length idxes)
  mapM_ (traverse $ require [Prim int64]) idxes

checkBasicOp :: (Checkable rep) => BasicOp -> TypeM rep ()
checkBasicOp (SubExp es) =
  void $ checkSubExp es
checkBasicOp (Opaque _ es) =
  void $ checkSubExp es
checkBasicOp ArrayVal {} =
  -- We assume this is never changed, so no need to check it.
  pure ()
checkBasicOp (ArrayLit [] _) =
  pure ()
checkBasicOp (ArrayLit (e : es') t) = do
  let check elemt eleme = do
        elemet <- checkSubExp eleme
        unless (elemet == elemt) . bad . TypeError $
          prettyText elemet
            <> " is not of expected type "
            <> prettyText elemt
            <> "."
  et <- checkSubExp e

  -- Compare that type with the one given for the array literal.
  checkAnnotation "array-element" t et

  mapM_ (check et) es'
checkBasicOp (UnOp op e) = require [Prim $ unOpType op] e
checkBasicOp (BinOp op e1 e2) = checkBinOpArgs (binOpType op) e1 e2
checkBasicOp (CmpOp op e1 e2) = checkCmpOp op e1 e2
checkBasicOp (ConvOp op e) = require [Prim $ fst $ convOpType op] e
checkBasicOp (Index ident slice) = do
  vt <- lookupType ident
  observe ident
  checkSlice vt slice
checkBasicOp (Update _ src slice se) = do
  (src_shape, src_pt) <- checkArrIdent src

  se_aliases <- subExpAliasesM se
  when (src `nameIn` se_aliases) $
    bad $
      TypeError "The target of an Update must not alias the value to be written."

  checkSlice (arrayOf (Prim src_pt) src_shape NoUniqueness) slice
  require [arrayOf (Prim src_pt) (sliceShape slice) NoUniqueness] se
  consume =<< lookupAliases src
checkBasicOp (FlatIndex ident slice) = do
  vt <- lookupType ident
  observe ident
  when (arrayRank vt /= 1) $
    bad $
      SlicingError (arrayRank vt) 1
  checkFlatSlice slice
checkBasicOp (FlatUpdate src slice v) = do
  (src_shape, src_pt) <- checkArrIdent src
  when (shapeRank src_shape /= 1) $
    bad $
      SlicingError (shapeRank src_shape) 1

  v_aliases <- lookupAliases v
  when (src `nameIn` v_aliases) $
    bad $
      TypeError "The target of an Update must not alias the value to be written."

  checkFlatSlice slice
  requireI [arrayOf (Prim src_pt) (Shape (flatSliceDims slice)) NoUniqueness] v
  consume =<< lookupAliases src
checkBasicOp (Iota e x s et) = do
  require [Prim int64] e
  require [Prim $ IntType et] x
  require [Prim $ IntType et] s
checkBasicOp (Replicate (Shape dims) valexp) = do
  mapM_ (require [Prim int64]) dims
  void $ checkSubExp valexp
checkBasicOp (Scratch _ shape) =
  mapM_ checkSubExp shape
checkBasicOp (Reshape k newshape arrexp) = do
  rank <- shapeRank . fst <$> checkArrIdent arrexp
  mapM_ (require [Prim int64]) $ shapeDims newshape
  case k of
    ReshapeCoerce ->
      when (shapeRank newshape /= rank) . bad $
        TypeError "Coercion changes rank of array."
ReshapeArbitrary -> pure () checkBasicOp (Rearrange perm arr) = do arrt <- lookupType arr let rank = arrayRank arrt when (length perm /= rank || sort perm /= [0 .. rank - 1]) $ bad $ PermutationError perm rank $ Just arr checkBasicOp (Concat i (arr1exp :| arr2exps) ressize) = do arr1_dims <- shapeDims . fst <$> checkArrIdent arr1exp arr2s_dims <- map (shapeDims . fst) <$> mapM checkArrIdent arr2exps unless (all ((== dropAt i 1 arr1_dims) . dropAt i 1) arr2s_dims) $ bad $ TypeError "Types of arguments to concat do not match." require [Prim int64] ressize checkBasicOp (Manifest perm arr) = checkBasicOp $ Rearrange perm arr -- Basically same thing! checkBasicOp (Assert e (ErrorMsg parts) _) = do require [Prim Bool] e mapM_ checkPart parts where checkPart ErrorString {} = pure () checkPart (ErrorVal t x) = require [Prim t] x checkBasicOp (UpdateAcc _ acc is ses) = do (shape, ts) <- checkAccIdent acc unless (length ses == length ts) . bad . TypeError $ "Accumulator requires " <> prettyText (length ts) <> " values, but " <> prettyText (length ses) <> " provided." unless (length is == shapeRank shape) $ bad . TypeError $ "Accumulator requires " <> prettyText (shapeRank shape) <> " indices, but " <> prettyText (length is) <> " provided." zipWithM_ require (map pure ts) ses consume =<< lookupAliases acc matchLoopResultExt :: (Checkable rep) => [Param DeclType] -> Result -> TypeM rep () matchLoopResultExt merge loopres = do let rettype_ext = existentialiseExtTypes (map paramName merge) $ staticShapes $ map typeOf merge bodyt <- mapM subExpResType loopres case instantiateShapes (fmap resSubExp . (`maybeNth` loopres)) rettype_ext of Nothing -> bad $ ReturnTypeError (nameFromString "") rettype_ext (staticShapes bodyt) Just rettype' -> unless (bodyt `subtypesOf` rettype') . bad $ ReturnTypeError (nameFromString "") (staticShapes rettype') (staticShapes bodyt) allowAllAliases :: Int -> Int -> RetAls allowAllAliases n m = RetAls [0 .. n - 1] [0 .. 
m - 1] checkExp :: (Checkable rep) => Exp (Aliases rep) -> TypeM rep () checkExp (BasicOp op) = checkBasicOp op checkExp (Match ses cases def_case info) = do ses_ts <- mapM checkSubExp ses alternatives $ context "in body of last case" (checkCaseBody def_case) : map (checkCase ses_ts) cases where checkVal t (Just v) = Prim (primValueType v) == t checkVal _ Nothing = True checkCase ses_ts (Case vs body) = do let ok = length vs == length ses_ts && and (zipWith checkVal ses_ts vs) unless ok . bad . TypeError . docText $ "Scrutinee" indent 2 (ppTuple' $ map pretty ses) "cannot match pattern" indent 2 (ppTuple' $ map pretty vs) context ("in body of case " <> prettyTuple vs) $ checkCaseBody body checkCaseBody body = do void $ checkBody body matchBranchType (matchReturns info) body checkExp (Apply fname args rettype_annot _) = do (rettype_derived, paramtypes) <- lookupFun fname $ map fst args argflows <- mapM (checkArg . fst) args when (rettype_derived /= rettype_annot) $ bad . TypeError . docText $ "Expected apply result type:" indent 2 (pretty $ map fst rettype_derived) "But annotation is:" indent 2 (pretty $ map fst rettype_annot) consumeArgs paramtypes argflows checkExp (Loop merge form loopbody) = do let (mergepat, mergeexps) = unzip merge mergeargs <- mapM checkArg mergeexps checkLoopArgs binding (scopeOfLoopForm form) $ do form_consumable <- checkForm mergeargs form let rettype = map paramDeclType mergepat consumable = [ (paramName param, mempty) | param <- mergepat, unique $ paramDeclType param ] ++ form_consumable context "Inside the loop body" $ checkFun' ( nameFromString "", map (,allowAllAliases (length merge) (length merge)) (staticShapes rettype), funParamsToNameInfos mergepat ) (Just consumable) $ do checkFunParams mergepat checkBodyDec $ snd $ bodyDec loopbody checkStms (bodyStms loopbody) $ do context "In loop body result" $ checkResult $ bodyResult loopbody context "When matching result of body with loop parameters" $ matchLoopResult (map fst merge) $ 
bodyResult loopbody let bound_here = namesFromList $ M.keys $ scopeOf $ bodyStms loopbody map (`namesSubtract` bound_here) <$> mapM (subExpAliasesM . resSubExp) (bodyResult loopbody) where checkForm mergeargs (ForLoop loopvar it boundexp) = do iparam <- primFParam loopvar $ IntType it let mergepat = map fst merge funparams = iparam : mergepat paramts = map paramDeclType funparams boundarg <- checkArg boundexp checkFuncall Nothing paramts $ boundarg : mergeargs pure mempty checkForm mergeargs (WhileLoop cond) = do case find ((== cond) . paramName . fst) merge of Just (condparam, _) -> unless (paramType condparam == Prim Bool) $ bad . TypeError $ "Conditional '" <> prettyText cond <> "' of while-loop is not boolean, but " <> prettyText (paramType condparam) <> "." Nothing -> -- Implies infinite loop, but that's OK. pure () let mergepat = map fst merge funparams = mergepat paramts = map paramDeclType funparams checkFuncall Nothing paramts mergeargs pure mempty checkLoopArgs = do let (params, args) = unzip merge argtypes <- mapM subExpType args let expected = expectedTypes (map paramName params) params args unless (expected == argtypes) . bad . TypeError . docText $ "Loop parameters" indent 2 (ppTuple' $ map pretty params) "cannot accept initial values" indent 2 (ppTuple' $ map pretty args) "of types" indent 2 (ppTuple' $ map pretty argtypes) checkExp (WithAcc inputs lam) = do unless (length (lambdaParams lam) == 2 * num_accs) . bad . TypeError $ prettyText (length (lambdaParams lam)) <> " parameters, but " <> prettyText num_accs <> " accumulators." let cert_params = take num_accs $ lambdaParams lam acc_args <- forM (zip inputs cert_params) $ \((shape, arrs, op), p) -> do mapM_ (require [Prim int64]) (shapeDims shape) elem_ts <- forM arrs $ \arr -> do arr_t <- lookupType arr unless (shapeDims shape `isPrefixOf` arrayDims arr_t) $ bad . 
TypeError $ prettyText arr <> " is not an array of outer shape " <> prettyText shape consume =<< lookupAliases arr pure $ stripArray (shapeRank shape) arr_t case op of Just (op_lam, nes) -> do let mkArrArg t = (t, mempty) nes_ts <- mapM checkSubExp nes unless (nes_ts == lambdaReturnType op_lam) $ bad . TypeError . T.unlines $ [ "Accumulator operator return type: " <> prettyText (lambdaReturnType op_lam), "Type of neutral elements: " <> prettyText nes_ts ] checkLambda op_lam $ replicate (shapeRank shape) (Prim int64, mempty) ++ map mkArrArg (elem_ts ++ elem_ts) Nothing -> pure () pure (Acc (paramName p) shape elem_ts NoUniqueness, mempty) checkAnyLambda False lam $ replicate num_accs (Prim Unit, mempty) ++ acc_args where num_accs = length inputs checkExp (Op op) = do checker <- asks envCheckOp checker op checkSOACArrayArgs :: (Checkable rep) => SubExp -> [VName] -> TypeM rep [Arg] checkSOACArrayArgs width = mapM checkSOACArrayArg where checkSOACArrayArg v = do (t, als) <- checkArg $ Var v case t of Acc {} -> pure (t, als) Array {} -> do let argSize = arraySize 0 t unless (argSize == width) . bad . TypeError $ "SOAC argument " <> prettyText v <> " has outer size " <> prettyText argSize <> ", but width of SOAC is " <> prettyText width pure (rowType t, als) _ -> bad . TypeError $ "SOAC argument " <> prettyText v <> " is not an array" checkType :: (Checkable rep) => TypeBase Shape u -> TypeM rep () checkType (Mem (ScalarSpace d _)) = mapM_ (require [Prim int64]) d checkType (Acc cert shape ts _) = do requireI [Prim Unit] cert mapM_ (require [Prim int64]) $ shapeDims shape mapM_ checkType ts checkType t = mapM_ checkSubExp $ arrayDims t checkExtType :: (Checkable rep) => TypeBase ExtShape u -> TypeM rep () checkExtType = mapM_ checkExtDim . shapeDims . 
arrayShape where checkExtDim (Free se) = void $ checkSubExp se checkExtDim (Ext _) = pure () checkCmpOp :: (Checkable rep) => CmpOp -> SubExp -> SubExp -> TypeM rep () checkCmpOp (CmpEq t) x y = do require [Prim t] x require [Prim t] y checkCmpOp (CmpUlt t) x y = checkBinOpArgs (IntType t) x y checkCmpOp (CmpUle t) x y = checkBinOpArgs (IntType t) x y checkCmpOp (CmpSlt t) x y = checkBinOpArgs (IntType t) x y checkCmpOp (CmpSle t) x y = checkBinOpArgs (IntType t) x y checkCmpOp (FCmpLt t) x y = checkBinOpArgs (FloatType t) x y checkCmpOp (FCmpLe t) x y = checkBinOpArgs (FloatType t) x y checkCmpOp CmpLlt x y = checkBinOpArgs Bool x y checkCmpOp CmpLle x y = checkBinOpArgs Bool x y checkBinOpArgs :: (Checkable rep) => PrimType -> SubExp -> SubExp -> TypeM rep () checkBinOpArgs t e1 e2 = do require [Prim t] e1 require [Prim t] e2 checkPatElem :: (Checkable rep) => PatElem (LetDec rep) -> TypeM rep () checkPatElem (PatElem name dec) = context ("When checking pattern element " <> prettyText name) $ checkLetBoundDec name dec checkFlatDimIndex :: (Checkable rep) => FlatDimIndex SubExp -> TypeM rep () checkFlatDimIndex (FlatDimIndex n s) = mapM_ (require [Prim int64]) [n, s] checkFlatSlice :: (Checkable rep) => FlatSlice SubExp -> TypeM rep () checkFlatSlice (FlatSlice offset idxs) = do require [Prim int64] offset mapM_ checkFlatDimIndex idxs checkStm :: (Checkable rep) => Stm (Aliases rep) -> TypeM rep a -> TypeM rep a checkStm stm@(Let pat (StmAux (Certs cs) _ (_, dec)) e) m = do context "When checking certificates" $ mapM_ (requireI [Prim Unit]) cs context "When checking expression annotation" $ checkExpDec dec context ("When matching\n" <> message " " pat <> "\nwith\n" <> message " " e) $ matchPat pat e binding (scopeOf stm) $ do mapM_ checkPatElem (patElems $ removePatAliases pat) m matchExtPat :: (Checkable rep) => Pat (LetDec (Aliases rep)) -> [ExtType] -> TypeM rep () matchExtPat pat ts = unless (expExtTypesFromPat pat == ts) $ bad $ InvalidPatError pat ts Nothing 
matchExtReturnType :: (Checkable rep) => [ExtType] -> Result -> TypeM rep () matchExtReturnType rettype res = do ts <- mapM subExpResType res matchExtReturns rettype res ts matchExtBranchType :: (Checkable rep) => [ExtType] -> Body (Aliases rep) -> TypeM rep () matchExtBranchType rettype (Body _ stms res) = do ts <- extendedScope (traverse subExpResType res) stmscope matchExtReturns rettype res ts where stmscope = scopeOf stms matchExtReturns :: [ExtType] -> Result -> [Type] -> TypeM rep () matchExtReturns rettype res ts = do let problem :: TypeM rep a problem = bad . TypeError . T.unlines $ [ "Type annotation is", " " <> prettyTuple rettype, "But result returns type", " " <> prettyTuple ts ] unless (length res == length rettype) problem let ctx_vals = zip res ts instantiateExt i = case maybeNth i ctx_vals of Just (SubExpRes _ se, Prim (IntType Int64)) -> pure se _ -> problem rettype' <- instantiateShapes instantiateExt rettype unless (rettype' == ts) problem validApply :: (ArrayShape shape) => [TypeBase shape Uniqueness] -> [TypeBase shape NoUniqueness] -> Bool validApply expected got = length got == length expected && and ( zipWith subtypeOf (map rankShaped got) (map (fromDecl . rankShaped) expected) ) type Arg = (Type, Names) argType :: Arg -> Type argType (t, _) = t -- | Remove all aliases from the 'Arg'. 
argAliases :: Arg -> Names argAliases (_, als) = als noArgAliases :: Arg -> Arg noArgAliases (t, _) = (t, mempty) checkArg :: (Checkable rep) => SubExp -> TypeM rep Arg checkArg arg = do argt <- checkSubExp arg als <- subExpAliasesM arg pure (argt, als) checkFuncall :: Maybe Name -> [DeclType] -> [Arg] -> TypeM rep () checkFuncall fname paramts args = do let argts = map argType args unless (validApply paramts argts) $ bad $ ParameterMismatch fname (map fromDecl paramts) $ map argType args consumeArgs paramts args consumeArgs :: [DeclType] -> [Arg] -> TypeM rep () consumeArgs paramts args = forM_ (zip (map diet paramts) args) $ \(d, (_, als)) -> occur [consumption (consumeArg als d)] where consumeArg als Consume = als consumeArg _ _ = mempty -- The boolean indicates whether we only allow consumption of -- parameters. checkAnyLambda :: (Checkable rep) => Bool -> Lambda (Aliases rep) -> [Arg] -> TypeM rep () checkAnyLambda soac (Lambda params rettype body) args = do let fname = nameFromString "" if length params == length args then do -- Consumption for this is done explicitly elsewhere. checkFuncall Nothing (map ((`toDecl` Nonunique) . paramType) params) $ map noArgAliases args let consumable = if soac then Just $ zip (map paramName params) (map argAliases args) else Nothing params' = [(paramName param, LParamName $ paramDec param) | param <- params] checkNoDuplicateParams fname $ map paramName params binding (M.fromList params') $ maybe id consumeOnlyParams consumable $ do checkLambdaParams params mapM_ checkType rettype checkLambdaBody rettype body else bad . TypeError $ "Anonymous function defined with " <> prettyText (length params) <> " parameters:\n" <> prettyText params <> "\nbut expected to take " <> prettyText (length args) <> " arguments." 
checkLambda :: (Checkable rep) => Lambda (Aliases rep) -> [Arg] -> TypeM rep () checkLambda = checkAnyLambda True checkPrimExp :: (Checkable rep) => PrimExp VName -> TypeM rep () checkPrimExp ValueExp {} = pure () checkPrimExp (LeafExp v pt) = requireI [Prim pt] v checkPrimExp (BinOpExp op x y) = do requirePrimExp (binOpType op) x requirePrimExp (binOpType op) y checkPrimExp (CmpOpExp op x y) = do requirePrimExp (cmpOpType op) x requirePrimExp (cmpOpType op) y checkPrimExp (UnOpExp op x) = requirePrimExp (unOpType op) x checkPrimExp (ConvOpExp op x) = requirePrimExp (fst $ convOpType op) x checkPrimExp (FunExp h args t) = do (h_ts, h_ret, _) <- maybe (bad $ TypeError $ "Unknown function: " <> h) pure $ M.lookup h primFuns when (length h_ts /= length args) . bad . TypeError $ "Function expects " <> prettyText (length h_ts) <> " parameters, but given " <> prettyText (length args) <> " arguments." when (h_ret /= t) . bad . TypeError $ "Function return annotation is " <> prettyText t <> ", but expected " <> prettyText h_ret zipWithM_ requirePrimExp h_ts args requirePrimExp :: (Checkable rep) => PrimType -> PrimExp VName -> TypeM rep () requirePrimExp t e = context ("in PrimExp " <> prettyText e) $ do checkPrimExp e unless (primExpType e == t) . bad . TypeError $ prettyText e <> " must have type " <> prettyText t -- | The class of representations that can be type-checked. 
class (AliasableRep rep, TypedOp (OpC rep)) => Checkable rep where checkExpDec :: ExpDec rep -> TypeM rep () checkBodyDec :: BodyDec rep -> TypeM rep () checkFParamDec :: VName -> FParamInfo rep -> TypeM rep () checkLParamDec :: VName -> LParamInfo rep -> TypeM rep () checkLetBoundDec :: VName -> LetDec rep -> TypeM rep () checkRetType :: [RetType rep] -> TypeM rep () matchPat :: Pat (LetDec (Aliases rep)) -> Exp (Aliases rep) -> TypeM rep () primFParam :: VName -> PrimType -> TypeM rep (FParam (Aliases rep)) matchReturnType :: [RetType rep] -> Result -> TypeM rep () matchBranchType :: [BranchType rep] -> Body (Aliases rep) -> TypeM rep () matchLoopResult :: [FParam (Aliases rep)] -> Result -> TypeM rep () -- | Used at top level; can be locally changed with 'checkOpWith'. checkOp :: Op (Aliases rep) -> TypeM rep () default checkExpDec :: (ExpDec rep ~ ()) => ExpDec rep -> TypeM rep () checkExpDec = pure default checkBodyDec :: (BodyDec rep ~ ()) => BodyDec rep -> TypeM rep () checkBodyDec = pure default checkFParamDec :: (FParamInfo rep ~ DeclType) => VName -> FParamInfo rep -> TypeM rep () checkFParamDec _ = checkType default checkLParamDec :: (LParamInfo rep ~ Type) => VName -> LParamInfo rep -> TypeM rep () checkLParamDec _ = checkType default checkLetBoundDec :: (LetDec rep ~ Type) => VName -> LetDec rep -> TypeM rep () checkLetBoundDec _ = checkType default checkRetType :: (RetType rep ~ DeclExtType) => [RetType rep] -> TypeM rep () checkRetType = mapM_ $ checkExtType . declExtTypeOf default matchPat :: Pat (LetDec (Aliases rep)) -> Exp (Aliases rep) -> TypeM rep () matchPat pat = matchExtPat pat <=< expExtType default primFParam :: (FParamInfo rep ~ DeclType) => VName -> PrimType -> TypeM rep (FParam (Aliases rep)) primFParam name t = pure $ Param mempty name (Prim t) default matchReturnType :: (RetType rep ~ DeclExtType) => [RetType rep] -> Result -> TypeM rep () matchReturnType = matchExtReturnType . 
map fromDecl default matchBranchType :: (BranchType rep ~ ExtType) => [BranchType rep] -> Body (Aliases rep) -> TypeM rep () matchBranchType = matchExtBranchType default matchLoopResult :: (FParamInfo rep ~ DeclType) => [FParam (Aliases rep)] -> Result -> TypeM rep () matchLoopResult = matchLoopResultExt futhark-0.25.27/src/Futhark/Internalise.hs000066400000000000000000000073361475065116200203570ustar00rootroot00000000000000-- | -- -- This module implements a transformation from source to core -- Futhark. -- -- The source and core language is similar in spirit, but the core -- language is much more regular (and mostly much simpler) in order to -- make it easier to write program transformations. -- -- * "Language.Futhark.Syntax" contains the source language definition. -- -- * "Futhark.IR.Syntax" contains the core IR definition. -- -- Specifically, internalisation generates the SOACS dialect of the -- core IR ("Futhark.IR.SOACS"). This is then initially used by the -- compiler middle-end. The main differences between the source and -- core IR are as follows: -- -- * The core IR has no modules. These are removed in -- "Futhark.Internalise.Defunctorise". -- -- * The core IR has no type abbreviations. These are removed in -- "Futhark.Internalise.ApplyTypeAbbrs". -- -- * The core IR has little syntactic niceties. A lot of syntactic -- sugar is removed in "Futhark.Internalise.FullNormalise". -- -- * Lambda lifting is performed by "Futhark.Internalise.LiftLambdas", -- * mostly to make the job of later passes simpler. -- -- * The core IR is monomorphic. Polymorphic functions are monomorphised in -- "Futhark.Internalise.Monomorphise" -- -- * The core IR is first-order. "Futhark.Internalise.Defunctionalise" -- removes higher-order functions. -- -- * The core IR is in [ANF](https://en.wikipedia.org/wiki/A-normal_form). -- -- * The core IR does not have arrays of tuples (or tuples or records -- at all, really). Arrays of tuples are turned into multiple -- arrays. 
For example, a source language transposition of an array -- of pairs becomes a core IR that contains two transpositions of -- two distinct arrays. The guts of this transformation is in -- "Futhark.Internalise.Exps". -- -- * For the above reason, SOACs also accept multiple input arrays. -- The available primitive operations are also somewhat different -- than in the source language. See 'Futhark.IR.SOACS.SOAC.SOAC'. module Futhark.Internalise (internaliseProg) where import Data.Text qualified as T import Futhark.Compiler.Config import Futhark.IR.SOACS as I hiding (stmPat) import Futhark.Internalise.ApplyTypeAbbrs as ApplyTypeAbbrs import Futhark.Internalise.Defunctionalise as Defunctionalise import Futhark.Internalise.Defunctorise as Defunctorise import Futhark.Internalise.Entry (visibleTypes) import Futhark.Internalise.Exps qualified as Exps import Futhark.Internalise.FullNormalise qualified as FullNormalise import Futhark.Internalise.LiftLambdas as LiftLambdas import Futhark.Internalise.Monad as I import Futhark.Internalise.Monomorphise as Monomorphise import Futhark.Internalise.ReplaceRecords as ReplaceRecords import Futhark.Util.Log import Language.Futhark.Semantic (Imports) -- | Convert a program in source Futhark to a program in the Futhark -- core language. 
internaliseProg :: (MonadFreshNames m, MonadLogger m) => FutharkConfig -> Imports -> m (I.Prog SOACS) internaliseProg config prog = do maybeLog "Defunctorising" prog_decs0 <- ApplyTypeAbbrs.transformProg =<< Defunctorise.transformProg prog maybeLog "Full Normalising" prog_decs1 <- FullNormalise.transformProg prog_decs0 maybeLog "Replacing records" prog_decs2 <- ReplaceRecords.transformProg prog_decs1 maybeLog "Lifting lambdas" prog_decs3 <- LiftLambdas.transformProg prog_decs2 maybeLog "Monomorphising" prog_decs4 <- Monomorphise.transformProg prog_decs3 maybeLog "Defunctionalising" prog_decs5 <- Defunctionalise.transformProg prog_decs4 maybeLog "Converting to core IR" Exps.transformProg (futharkSafe config) (visibleTypes prog) prog_decs5 where verbose = fst (futharkVerbose config) > NotVerbose maybeLog s | verbose = logMsg (s :: T.Text) | otherwise = pure () futhark-0.25.27/src/Futhark/Internalise/000077500000000000000000000000001475065116200200125ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Internalise/AccurateSizes.hs000066400000000000000000000103071475065116200231140ustar00rootroot00000000000000module Futhark.Internalise.AccurateSizes ( argShapes, ensureResultShape, ensureResultExtShape, ensureExtShape, ensureShape, ensureArgShapes, ) where import Control.Monad import Data.Map.Strict qualified as M import Data.Maybe import Futhark.Construct import Futhark.IR.SOACS import Futhark.Internalise.Monad import Futhark.Util (takeLast) shapeMapping :: (HasScope SOACS m, Monad m) => [FParam SOACS] -> [Type] -> m (M.Map VName SubExp) shapeMapping all_params value_arg_types = mconcat <$> zipWithM f (map paramType value_params) value_arg_types where value_params = takeLast (length value_arg_types) all_params f t1@Array {} t2@Array {} = pure $ M.fromList $ mapMaybe match $ zip (arrayDims t1) (arrayDims t2) f (Acc acc1 ispace1 ts1 _) (Acc acc2 ispace2 ts2 _) = do let ispace_m = M.fromList . 
mapMaybe match $ zip (shapeDims ispace1) (shapeDims ispace2) arr_sizes_m <- mconcat <$> zipWithM f ts1 ts2 pure $ M.singleton acc1 (Var acc2) <> ispace_m <> arr_sizes_m f _ _ = pure mempty match (Var v, se) = Just (v, se) match _ = Nothing argShapes :: [VName] -> [FParam SOACS] -> [Type] -> InternaliseM [SubExp] argShapes shapes all_params valargts = do mapping <- shapeMapping all_params valargts let addShape name = case M.lookup name mapping of Just se -> se _ -> error $ "argShapes: " ++ prettyString name pure $ map addShape shapes ensureResultShape :: ErrorMsg SubExp -> SrcLoc -> [Type] -> Result -> InternaliseM Result ensureResultShape msg loc = ensureResultExtShape msg loc . staticShapes ensureResultExtShape :: ErrorMsg SubExp -> SrcLoc -> [ExtType] -> Result -> InternaliseM Result ensureResultExtShape msg loc rettype res = do res' <- ensureResultExtShapeNoCtx msg loc rettype res ts <- mapM subExpResType res' let ctx = extractShapeContext rettype $ map arrayDims ts pure $ subExpsRes ctx ++ res' ensureResultExtShapeNoCtx :: ErrorMsg SubExp -> SrcLoc -> [ExtType] -> Result -> InternaliseM Result ensureResultExtShapeNoCtx msg loc rettype es = do es_ts <- mapM subExpResType es let ext_mapping = shapeExtMapping rettype es_ts rettype' = foldr (uncurry fixExt) rettype $ M.toList ext_mapping assertProperShape t (SubExpRes cs se) = let name = "result_proper_shape" in SubExpRes cs <$> ensureExtShape msg loc t name se zipWithM assertProperShape rettype' es ensureExtShape :: ErrorMsg SubExp -> SrcLoc -> ExtType -> String -> SubExp -> InternaliseM SubExp ensureExtShape msg loc t name orig | Array {} <- t, Var v <- orig = Var <$> ensureShapeVar msg loc t name v | otherwise = pure orig ensureShape :: ErrorMsg SubExp -> SrcLoc -> Type -> String -> SubExp -> InternaliseM SubExp ensureShape msg loc = ensureExtShape msg loc . staticShapes1 -- | Reshape the arguments to a function so that they fit the expected -- shape declarations. Not used to change rank of arguments. 
Assumes -- everything is otherwise type-correct. ensureArgShapes :: (Typed (TypeBase Shape u)) => ErrorMsg SubExp -> SrcLoc -> [VName] -> [TypeBase Shape u] -> [SubExp] -> InternaliseM [SubExp] ensureArgShapes msg loc shapes paramts args = zipWithM ensureArgShape (expectedTypes shapes paramts args) args where ensureArgShape _ (Constant v) = pure $ Constant v ensureArgShape t (Var v) | arrayRank t < 1 = pure $ Var v | otherwise = ensureShape msg loc t (baseString v) $ Var v ensureShapeVar :: ErrorMsg SubExp -> SrcLoc -> ExtType -> String -> VName -> InternaliseM VName ensureShapeVar msg loc t name v | Array {} <- t = do newdims <- arrayDims . removeExistentials t <$> lookupType v olddims <- arrayDims <$> lookupType v if newdims == olddims then pure v else do matches <- zipWithM checkDim newdims olddims all_match <- letSubExp "match" =<< eAll matches cs <- assert "empty_or_match_cert" all_match msg loc certifying cs $ letExp name $ shapeCoerce newdims v | otherwise = pure v where checkDim desired has = letSubExp "dim_match" $ BasicOp $ CmpOp (CmpEq int64) desired has futhark-0.25.27/src/Futhark/Internalise/ApplyTypeAbbrs.hs000066400000000000000000000064041475065116200232530ustar00rootroot00000000000000-- | A minor cleanup pass that runs after defunctorisation and applies -- any type abbreviations. After this, the program consists entirely -- value bindings. 
module Futhark.Internalise.ApplyTypeAbbrs (transformProg) where import Control.Monad.Identity import Data.Map.Strict qualified as M import Data.Maybe (mapMaybe) import Language.Futhark import Language.Futhark.Semantic (TypeBinding (..)) import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Types type Types = M.Map VName (Subst StructRetType) getTypes :: Types -> [Dec] -> Types getTypes types [] = types getTypes types (TypeDec typebind : ds) = do let (TypeBind name l tparams _ (Info (RetType dims t)) _ _) = typebind tbinding = TypeAbbr l tparams $ RetType dims $ applySubst (`M.lookup` types) t types' = M.insert name (substFromAbbr tbinding) types getTypes types' ds getTypes types (_ : ds) = getTypes types ds -- Perform a given substitution on the types in a pattern. substPat :: (t -> t) -> Pat t -> Pat t substPat f pat = case pat of TuplePat pats loc -> TuplePat (map (substPat f) pats) loc RecordPat fs loc -> RecordPat (map substField fs) loc where substField (n, p) = (n, substPat f p) PatParens p loc -> PatParens (substPat f p) loc PatAttr attr p loc -> PatAttr attr (substPat f p) loc Id vn (Info tp) loc -> Id vn (Info $ f tp) loc Wildcard (Info tp) loc -> Wildcard (Info $ f tp) loc PatAscription p _ _ -> substPat f p PatLit e (Info tp) loc -> PatLit e (Info $ f tp) loc PatConstr n (Info tp) ps loc -> PatConstr n (Info $ f tp) ps loc removeTypeVariablesInType :: Types -> StructType -> StructType removeTypeVariablesInType types = applySubst (`M.lookup` types) substEntry :: Types -> EntryPoint -> EntryPoint substEntry types (EntryPoint params ret) = EntryPoint (map onEntryParam params) (onEntryType ret) where onEntryParam (EntryParam v t) = EntryParam v $ onEntryType t onEntryType (EntryType t te) = EntryType (removeTypeVariablesInType types t) te -- Remove all type variables and type abbreviations from a value binding. 
removeTypeVariables :: Types -> ValBind -> ValBind removeTypeVariables types valbind = do let (ValBind entry _ _ (Info (RetType dims rettype)) _ pats body _ _ _) = valbind mapper = ASTMapper { mapOnExp = onExp, mapOnName = pure, mapOnStructType = pure . applySubst (`M.lookup` types), mapOnParamType = pure . applySubst (`M.lookup` types), mapOnResRetType = pure . applySubst (`M.lookup` types) } onExp = astMap mapper let body' = runIdentity $ onExp body valbind { valBindRetType = Info (applySubst (`M.lookup` types) $ RetType dims rettype), valBindParams = map (substPat $ applySubst (`M.lookup` types)) pats, valBindEntryPoint = fmap (substEntry types) <$> entry, valBindBody = body' } -- | Apply type abbreviations from a list of top-level declarations. A -- module-free input program is expected, so only value declarations -- and type declaration are accepted. transformProg :: (Monad m) => [Dec] -> m [ValBind] transformProg decs = let types = getTypes mempty decs onDec (ValDec valbind) = Just $ removeTypeVariables types valbind onDec _ = Nothing in pure $ mapMaybe onDec decs futhark-0.25.27/src/Futhark/Internalise/Bindings.hs000066400000000000000000000153641475065116200221140ustar00rootroot00000000000000{-# LANGUAGE Strict #-} -- | Internalising bindings. 
module Futhark.Internalise.Bindings
  ( internaliseAttrs,
    internaliseAttr,
    bindingFParams,
    bindingLoopParams,
    bindingLambdaParams,
    stmPat,
  )
where

import Control.Monad
import Control.Monad.Free (Free (..))
import Control.Monad.Reader
import Data.Bifunctor
import Data.Foldable (toList)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.IR.SOACS qualified as I
import Futhark.Internalise.Monad
import Futhark.Internalise.TypesValues
import Futhark.Util
import Language.Futhark as E hiding (matchDims)

-- | Convert a source-language attribute into its core-IR counterpart,
-- recursing into the arguments of compound attributes.
internaliseAttr :: E.AttrInfo VName -> InternaliseM I.Attr
internaliseAttr (E.AttrAtom (E.AtomName v) _) =
  pure $ I.AttrName v
internaliseAttr (E.AttrAtom (E.AtomInt x) _) =
  pure $ I.AttrInt x
internaliseAttr (E.AttrComp f attrs _) =
  I.AttrComp f <$> mapM internaliseAttr attrs

-- | Internalise a list of attributes and combine them into a single
-- attribute set.
internaliseAttrs :: [E.AttrInfo VName] -> InternaliseM I.Attrs
internaliseAttrs = fmap (mconcat . map I.oneAttr) . mapM internaliseAttr

-- Rebuild the shape of the given tree, using the list elements (in
-- order) as its leaves.  Each subtree receives as many elements as it
-- has leaves.  A leaf that does not receive exactly one element is an
-- internal error.
treeLike :: Tree a -> [b] -> Tree b
treeLike (Pure _) [b] = Pure b
treeLike (Pure _) _ = error "treeLike: invalid input"
treeLike (Free ls) bs = Free $ zipWith treeLike ls (chunks (map length ls) bs)

-- | Bind the parameters of a function, then run the continuation in a
-- scope containing them.
--
-- Every size type parameter becomes an @i64@ parameter, and any
-- reference to it is substituted with that parameter.  Every
-- accumulator-typed value parameter also gives rise to an extra
-- @Unit@-typed parameter named after the accumulator.
--
-- The continuation receives these extra parameters (accumulator
-- params first, then size params).  It also receives the value
-- parameters, grouped per source-level parameter and shaped as trees
-- that follow the internalised types.
bindingFParams ::
  [E.TypeParam] ->
  [E.Pat E.ParamType] ->
  ([I.FParam I.SOACS] -> [[Tree (I.FParam I.SOACS)]] -> InternaliseM a) ->
  InternaliseM a
bindingFParams tparams params m = do
  flattened_params <- mapM flattenPat params
  let params_idents = concat flattened_params
  params_ts <-
    internaliseParamTypes $
      map (E.unInfo . E.identType . fst) params_idents
  -- Number of flat identifiers contributed by each source parameter;
  -- used below to regroup the flat list per parameter.
  let num_param_idents = map length flattened_params

  let shape_params = [I.Param mempty v $ I.Prim I.int64 | E.TypeParamDim v _ <- tparams]
      shape_subst = M.fromList [(I.paramName p, [I.Var $ I.paramName p]) | p <- shape_params]
  bindingFlatPat params_idents (concatMap (concatMap toList) params_ts) $ \valueparams -> do
    let (certparams, valueparams') =
          first concat $ unzip $ map fixAccParams valueparams
        all_params = certparams ++ shape_params ++ concat valueparams'
    I.localScope (I.scopeOfFParams all_params) $
      substitutingVars shape_subst $ do
        let values_grouped_by_params = chunks num_param_idents valueparams'
            types_grouped_by_params = chunks num_param_idents params_ts
        m (certparams ++ shape_params) $
          zipWith chunkValues types_grouped_by_params values_grouped_by_params
  where
    -- Split off the extra Unit-typed parameters generated for
    -- accumulator parameters.
    fixAccParams ps = first catMaybes $ unzip $ map fixAccParam ps
    fixAccParam (I.Param attrs pv (I.Acc acc ispace ts u)) =
      ( Just (I.Param attrs acc $ I.Prim I.Unit),
        I.Param attrs pv (I.Acc acc ispace ts u)
      )
    fixAccParam p = (Nothing, p)

    -- Reshape a flat list of parameters into trees that mirror the
    -- internalised types of a single source-level parameter.
    chunkValues ::
      [[Tree (I.TypeBase I.Shape Uniqueness)]] ->
      [[I.FParam I.SOACS]] ->
      [Tree (I.FParam I.SOACS)]
    chunkValues tss vss =
      concat $ zipWith f tss vss
      where
        f ts vs = zipWith treeLike ts (chunks (map length ts) vs)

-- | Bind the merge parameters of a loop whose pattern has the given
-- internal types.  Size type parameters are bound as @i64@
-- parameters, as in 'bindingFParams'.  The continuation receives the
-- size parameters and the flattened value parameters.
bindingLoopParams ::
  [E.TypeParam] ->
  E.Pat E.ParamType ->
  [I.Type] ->
  ([I.FParam I.SOACS] -> [I.FParam I.SOACS] -> InternaliseM a) ->
  InternaliseM a
bindingLoopParams tparams pat ts m = do
  pat_idents <- flattenPat pat
  pat_ts <- internaliseLoopParamType (E.patternType pat) ts

  let shape_params = [I.Param mempty v $ I.Prim I.int64 | E.TypeParamDim v _ <- tparams]
      shape_subst = M.fromList [(I.paramName p, [I.Var $ I.paramName p]) | p <- shape_params]

  bindingFlatPat pat_idents pat_ts $ \valueparams ->
    I.localScope (I.scopeOfFParams $ shape_params ++ concat valueparams) $
      substitutingVars shape_subst $
        m shape_params (concat valueparams)

-- | Bind the parameters of a lambda to the given internal types, and
-- pass the flattened lambda parameters to the continuation.
bindingLambdaParams ::
  [E.Pat E.ParamType] ->
  [I.Type] ->
  ([I.LParam I.SOACS] -> InternaliseM a) ->
  InternaliseM a
bindingLambdaParams params ts m = do
  params_idents <- concat <$> mapM flattenPat params

  bindingFlatPat params_idents ts $ \params' ->
    I.localScope (I.scopeOfLParams $ concat params') $ m (concat params')

type Params t = [I.Param t]

-- Map every flat source identifier to one or more internal parameters,
-- consuming types from the given list in order.  Also return the
-- substitution from each source name to the internal variables that
-- now represent it.
processFlatPat ::
  (Show t) =>
  [(E.Ident ParamType, [E.AttrInfo VName])] ->
  [t] ->
  InternaliseM ([Params t], VarSubsts)
processFlatPat x y = processFlatPat' [] x y
  where
    -- 'pat' accumulates results in reverse order; they are reversed
    -- once all identifiers are processed.
    processFlatPat' pat [] _ = do
      let (vs, substs) = unzip pat
      pure (reverse vs, M.fromList substs)
    processFlatPat' pat ((p, attrs) : rest) ts = do
      attrs' <- internaliseAttrs attrs
      (ps, rest_ts) <- handleMapping attrs' ts <$> internaliseBindee p
      processFlatPat' ((ps, (E.identName p, map (I.Var . I.paramName) ps)) : pat) rest rest_ts

    -- Pair each internal name with the next type.  Return the created
    -- parameters and the unconsumed types.  Running out of types is an
    -- internal error.
    handleMapping _ ts [] =
      ([], ts)
    handleMapping attrs (t : ts) (r : rs) =
      let (ps, ts') = handleMapping attrs ts rs
       in (I.Param attrs r t : ps, ts')
    handleMapping _ [] _ =
      error $ "handleMapping: insufficient identifiers in pattern.\n" ++ show (x, y)

    -- A source identifier whose type internalises to a single value
    -- keeps its name.  Otherwise one fresh name (based on the
    -- original) is created per internal value.
    internaliseBindee :: E.Ident E.ParamType -> InternaliseM [VName]
    internaliseBindee bindee = do
      let name = E.identName bindee
      case internalisedTypeSize $ E.unInfo $ E.identType bindee of
        1 -> pure [name]
        n -> replicateM n $ newVName $ baseString name

-- Run the continuation with the flat identifiers bound to internal
-- parameters.  The source-name substitutions are added to the
-- environment; on a clash they take precedence over existing ones,
-- because 'M.union' is left-biased.
bindingFlatPat ::
  (Show t) =>
  [(E.Ident E.ParamType, [E.AttrInfo VName])] ->
  [t] ->
  ([Params t] -> InternaliseM a) ->
  InternaliseM a
bindingFlatPat idents ts m = do
  (ps, substs) <- processFlatPat idents ts
  local (\env -> env {envSubsts = substs `M.union` envSubsts env}) $
    m ps

-- | Flatten a pattern.  Returns a list of identifiers.
flattenPat :: (MonadFreshNames m) => E.Pat (TypeBase Size u) -> m [(E.Ident (TypeBase Size u), [E.AttrInfo VName])]
flattenPat = flattenPat'
  where
    -- Each identifier is paired with the attributes of the enclosing
    -- PatAttr nodes; outer attributes are consed in front.
    flattenPat' (E.PatParens p _) =
      flattenPat' p
    flattenPat' (E.PatAttr attr p _) =
      map (second (attr :)) <$> flattenPat' p
    -- Wildcards still produce a binding, under a fresh "nameless" name.
    flattenPat' (E.Wildcard t loc) = do
      name <- newVName "nameless"
      flattenPat' $ E.Id name t loc
    flattenPat' (E.Id v (Info t) loc) =
      pure [(E.Ident v (Info t) loc, mempty)]
    -- Empty tuples/records are treated as a wildcard of the empty
    -- record type, so they still yield one identifier.
    flattenPat' (E.TuplePat [] loc) =
      flattenPat' (E.Wildcard (Info $ E.Scalar $ E.Record mempty) loc)
    flattenPat' (E.RecordPat [] loc) =
      flattenPat' (E.Wildcard (Info $ E.Scalar $ E.Record mempty) loc)
    flattenPat' (E.TuplePat pats _) =
      concat <$> mapM flattenPat' pats
    -- Record fields are put in the order given by 'sortFields' and
    -- then handled like a tuple.
    flattenPat' (E.RecordPat fs loc) =
      flattenPat' $ E.TuplePat (map snd $ sortFields $ M.fromList $ map (first unLoc) fs) loc
    flattenPat' (E.PatAscription p _ _) =
      flattenPat' p
    -- Literal patterns bind nothing meaningful; treat as a wildcard.
    flattenPat' (E.PatLit _ t loc) =
      flattenPat' $ E.Wildcard t loc
    -- Constructor patterns contribute only their payload patterns.
    flattenPat' (E.PatConstr _ _ ps _) =
      concat <$> mapM flattenPat' ps

-- | Bind a let-style pattern to values of the given internal types.
-- The continuation receives the names of the flattened internal
-- parameters.
stmPat ::
  E.Pat E.ParamType ->
  [I.Type] ->
  ([VName] -> InternaliseM a) ->
  InternaliseM a
stmPat pat ts m = do
  pat' <- flattenPat pat
  bindingFlatPat pat' ts $ m . map I.paramName . concat
futhark-0.25.27/src/Futhark/Internalise/Defunctionalise.hs000066400000000000000000001364721475065116200234770ustar00rootroot00000000000000-- | Defunctionalization of typed, monomorphic Futhark programs without modules.
module Futhark.Internalise.Defunctionalise (transformProg) where

import Control.Monad
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifoldable (bifoldMap)
import Data.Bifunctor
import Data.Bitraversable
import Data.Foldable
import Data.List (partition, sortOn)
import Data.List.NonEmpty qualified as NE
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.IR.Pretty ()
import Futhark.MonadFreshNames
import Futhark.Util (mapAccumLM, nubOrd)
import Language.Futhark
import Language.Futhark.Traversals
import Language.Futhark.TypeChecker.Types (Subst (..), applySubst)

-- | A static value stores additional information about the result of
-- defunctionalization of an expression, aside from the residual expression.
data StaticVal
  = -- | A value with no static information beyond its type; the
    -- residual expression itself represents it (e.g. literals and
    -- array values).
    Dynamic ParamType
  | -- | The Env is the lexical closure of the lambda.
    LambdaSV (Pat ParamType) ResRetType Exp Env
  | -- | Static values of the fields of a record (or tuple).
    RecordSV [(Name, StaticVal)]
  | -- | The constructor that is actually present, plus
    -- the others that are not.
    SumSV Name [StaticVal] [(Name, [ParamType])]
  | -- | The pair is the residual expression (the closure record) and
    -- the StaticVal of this function as a whole, while the second
    -- StaticVal is its body. (Don't trust this too much, my
    -- understanding may have holes.)
    DynamicFun (Exp, StaticVal) StaticVal
  | -- | A reference to an intrinsic function.
    IntrinsicSV
  | -- | A typed hole, with its type and source location.
    HoleSV StructType SrcLoc
  deriving (Show)

data Binding = Binding
  { -- | Just if this is a polymorphic binding that must be
    -- instantiated.
    bindingType :: Maybe ([VName], StructType),
    -- | The static value bound to the name.
    bindingSV :: StaticVal
  }
  deriving (Show)

-- | Environment mapping variable names to their associated static
-- value.
type Env = M.Map VName Binding

-- | Run an action with the given bindings added to the current
-- environment.  The new bindings shadow existing ones with the same
-- name, because the Map semigroup is a left-biased union.
localEnv :: Env -> DefM a -> DefM a
localEnv env = local $ second (env <>)

-- Even when using a "new" environment (for evaluating closures) we
-- still ram the global environment of DynamicFuns in there.
localNewEnv :: Env -> DefM a -> DefM a
localNewEnv env = local $ \(globals, old_env) ->
  -- Keep only the global entries of the old environment.  Because the
  -- Map union is left-biased, those global entries win over same-named
  -- entries in 'env'.
  (globals, M.filterWithKey (\k _ -> k `S.member` globals) old_env <> env)

-- Current (non-global) part of the reader environment.
askEnv :: DefM Env
askEnv = asks snd

-- Run an action with the given names added to the set of globals.
areGlobal :: [VName] -> DefM a -> DefM a
areGlobal vs = local $ first (S.fromList vs <>)

-- Rewrite the size variables in a type according to the substitution:
-- a variable becomes another named size or an integer constant.
-- Sizes that are not variables, or are not in the map, are unchanged.
replaceTypeSizes ::
  M.Map VName SizeSubst ->
  TypeBase Size als ->
  TypeBase Size als
replaceTypeSizes substs = first onDim
  where
    onDim (Var v typ loc) =
      case M.lookup (qualLeaf v) substs of
        Just (SubstNamed v') -> Var v' typ loc
        Just (SubstConst d) -> sizeFromInteger (toInteger d) loc
        Nothing -> Var v typ loc
    onDim d = d

-- Apply a size substitution throughout a static value, including the
-- residual expressions and closure environments it contains.  Does
-- nothing if the substitution is empty.
--
-- NOTE(review): in the visible code 'globals' is only threaded through
-- the recursive calls and never consulted directly.
replaceStaticValSizes ::
  S.Set VName ->
  M.Map VName SizeSubst ->
  StaticVal ->
  StaticVal
replaceStaticValSizes globals orig_substs sv =
  case sv of
    _
      | M.null orig_substs -> sv
    -- Names bound in the closure environment are removed from the
    -- substitution used for the parameter, return type and body.  The
    -- closure environment itself still gets the full substitution.
    LambdaSV param (RetType t_dims t) e closure_env ->
      let substs =
            foldl' (flip M.delete) orig_substs $
              S.fromList (M.keys closure_env)
       in LambdaSV
            (fmap (replaceTypeSizes substs) param)
            (RetType t_dims (replaceTypeSizes substs t))
            (onExp substs e)
            (onEnv orig_substs closure_env) -- intentional
    Dynamic t ->
      Dynamic $ replaceTypeSizes orig_substs t
    RecordSV fs ->
      RecordSV $ map (fmap (replaceStaticValSizes globals orig_substs)) fs
    SumSV c svs ts ->
      SumSV c (map (replaceStaticValSizes globals orig_substs) svs) $
        map (fmap $ map $ replaceTypeSizes orig_substs) ts
    DynamicFun (e, sv1) sv2 ->
      DynamicFun (onExp orig_substs e, replaceStaticValSizes globals orig_substs sv1) $
        replaceStaticValSizes globals orig_substs sv2
    IntrinsicSV ->
      IntrinsicSV
    HoleSV t loc ->
      HoleSV t loc
  where
    -- Generic traversal used for expressions without a special case in
    -- 'onExp'.  Note that ResRetTypes are left untouched here.
    tv substs =
      ASTMapper
        { mapOnStructType = pure . replaceTypeSizes substs,
          mapOnParamType = pure . replaceTypeSizes substs,
          mapOnResRetType = pure,
          mapOnExp = pure . onExp substs,
          mapOnName = pure . fmap (onName substs)
        }

    onName substs v =
      case M.lookup v substs of
        Just (SubstNamed v') -> qualLeaf v'
        _ -> v

    -- A variable that is itself a substituted size becomes the new
    -- name, or an i64 literal for constant substitutions.
    onExp substs (Var v t loc) =
      case M.lookup (qualLeaf v) substs of
        Just (SubstNamed v') ->
          Var v' t loc
        Just (SubstConst d) ->
          Literal (SignedValue (Int64Value (fromIntegral d))) loc
        Nothing ->
          Var v (replaceTypeSizes substs <$> t) loc
    onExp substs (Coerce e te t loc) =
      Coerce (onExp substs e) te (replaceTypeSizes substs <$> t) loc
    onExp substs (Lambda params e ret (Info (RetType t_dims t)) loc) =
      Lambda
        (map (fmap $ replaceTypeSizes substs) params)
        (onExp substs e)
        ret
        (Info (RetType t_dims (replaceTypeSizes substs t)))
        loc
    onExp substs e = runIdentity $ astMap (tv substs) e

    onEnv substs =
      M.fromList
        . map (second (onBinding substs))
        . M.toList

    onBinding substs (Binding t bsv) =
      Binding
        (second (replaceTypeSizes substs) <$> t)
        (replaceStaticValSizes globals substs bsv)

-- | Returns the defunctionalization environment restricted
-- to the given set of variable names.
restrictEnvTo :: FV -> DefM Env
restrictEnvTo fv = asks restrict
  where
    -- Drop global names and names not free in 'fv'.
    restrict (globals, env) = M.mapMaybeWithKey keep env
      where
        keep k (Binding t sv) = do
          guard $ not (k `S.member` globals) && S.member k (fvVars fv)
          Just $ Binding t $ restrict' sv
    -- NOTE(review): as written, restrict' only rebuilds the static
    -- value structurally (nested closure environments are not filtered).
    restrict' (Dynamic t) =
      Dynamic t
    restrict' (LambdaSV pat t e env) =
      LambdaSV pat t e $ M.map restrict'' env
    restrict' (RecordSV fields) =
      RecordSV $ map (fmap restrict') fields
    restrict' (SumSV c svs fields) =
      SumSV c (map restrict' svs) fields
    restrict' (DynamicFun (e, sv1) sv2) =
      DynamicFun (e, restrict' sv1) $ restrict' sv2
    restrict' IntrinsicSV = IntrinsicSV
    restrict' (HoleSV t loc) = HoleSV t loc
    restrict'' (Binding t sv) = Binding t $ restrict' sv

-- | Defunctionalization monad.  The Reader environment tracks both
-- the current Env as well as the set of globally defined dynamic
-- functions.  This is used to avoid unnecessarily large closure
-- environments.
newtype DefM a
  = DefM (ReaderT (S.Set VName, Env) (State ([ValBind], VNameSource)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader (S.Set VName, Env),
      MonadState ([ValBind], VNameSource)
    )

-- The name source lives in the second component of the state.
instance MonadFreshNames DefM where
  getNameSource = gets snd
  putNameSource src = modify $ second $ const src

-- | Execute a defunctionalisation computation from an empty reader
-- environment.  Yields the computation's result, the updated name
-- source, and the lifted function declarations in the order in which
-- they were added.
runDefM :: VNameSource -> DefM a -> (a, VNameSource, [ValBind])
runDefM src (DefM m) = (result, src', reverse lifted)
  where
    (result, (lifted, src')) = runState (runReaderT m mempty) (mempty, src)

-- | Record a lifted top-level binding.  Bindings are stored newest
-- first; 'runDefM' reverses them.
addValBind :: ValBind -> DefM ()
addValBind vb = modify $ \(vbs, src) -> (vb : vbs, src)

-- | Emit a new top-level value declaration with the given function
-- name, return type, size parameters, value parameters and body.
liftValDec :: VName -> ResRetType -> [VName] -> [Pat ParamType] -> Exp -> DefM ()
liftValDec fname (RetType ret_dims ret) dims pats body =
  addValBind $
    ValBind
      { valBindEntryPoint = Nothing,
        valBindName = fname,
        valBindRetDecl = Nothing,
        valBindRetType = Info $ RetType (unbound_ret_sizes ++ ret_dims) ret,
        valBindTypeParams = map (`TypeParamDim` mempty) dims,
        valBindParams = pats,
        valBindBody = body,
        valBindDoc = Nothing,
        valBindAttrs = mempty,
        valBindLocation = mempty
      }
  where
    -- FIXME: this pass is still not correctly size-preserving, so
    -- forget those return sizes that we forgot to propagate along
    -- the way.  Hopefully the internaliser is conservative and
    -- will insert reshapes...
    bound_here = S.fromList $ dims <> foldMap patNames pats
    unbound_ret_sizes =
      filter (`S.notMember` bound_here) $ S.toList $ fvVars $ freeInType ret

-- | Looks up the associated static value for a given name in the environment.
lookupVar :: StructType -> VName -> DefM StaticVal
lookupVar t x = do
  env <- askEnv
  case M.lookup x env of
    -- Polymorphic binding: instantiate its static value at type 't'.
    Just (Binding (Just (dims, sv_t)) sv) -> do
      globals <- asks fst
      instStaticVal globals dims t sv_t sv
    Just (Binding Nothing sv) ->
      pure sv
    -- If the variable is unknown, it may refer to the 'intrinsics'
    -- module, which we will have to treat specially.
    Nothing
      | baseTag x <= maxIntrinsicTag -> pure IntrinsicSV
      | otherwise ->
          -- Anything not in scope is going to be an existential size.
          pure $ Dynamic $ Scalar $ Prim $ Signed Int64

-- Like freeInPat, but ignores sizes that are only found in
-- function types.
arraySizes :: StructType -> S.Set VName
arraySizes (Scalar Arrow {}) = mempty
arraySizes (Scalar (Record fields)) = foldMap arraySizes fields
arraySizes (Scalar (Sum cs)) = foldMap (foldMap arraySizes) cs
arraySizes (Scalar (TypeVar _ _ targs)) =
  mconcat $ map f targs
  where
    f (TypeArgDim (Var d _ _)) = S.singleton $ qualLeaf d
    f TypeArgDim {} = mempty
    f (TypeArgType t) = arraySizes t
arraySizes (Scalar Prim {}) = mempty
arraySizes (Array _ shape t) =
  arraySizes (Scalar t) <> foldMap dimName (shapeDims shape)
  where
    dimName :: Size -> S.Set VName
    dimName (Var qn _ _) = S.singleton $ qualLeaf qn
    dimName _ = mempty

-- 'arraySizes' of the type of a pattern.
patternArraySizes :: Pat ParamType -> S.Set VName
patternArraySizes = arraySizes . patternStructType

-- What a size variable is replaced by: another named size, or an
-- integer constant.
data SizeSubst
  = SubstNamed (QualName VName)
  | SubstConst Int64
  deriving (Eq, Ord, Show)

-- Match two types of the same shape and record, for each size
-- variable in the first type, the size found at the same position in
-- the second (a name or an integer literal).  Positions where the
-- second size is a variable bound by the matching context are skipped.
dimMapping ::
  (Monoid a) =>
  TypeBase Size a ->
  TypeBase Size a ->
  M.Map VName SizeSubst
dimMapping t1 t2 = execState (matchDims f t1 t2) mempty
  where
    f bound d1 (Var d2 _ _)
      | qualLeaf d2 `elem` bound = pure d1
    f _ (Var d1 typ loc) (Var d2 _ _) = do
      modify $ M.insert (qualLeaf d1) $ SubstNamed d2
      pure $ Var d1 typ loc
    f _ (Var d1 typ loc) (IntLit d2 _ _) = do
      modify $ M.insert (qualLeaf d1) $ SubstConst $ fromInteger d2
      pure $ Var d1 typ loc
    f _ d _ = pure d

-- Like 'dimMapping', but keeps only name-to-name entries.
dimMapping' ::
  (Monoid a) =>
  TypeBase Size a ->
  TypeBase Size a ->
  M.Map VName VName
dimMapping' t1 t2 = M.mapMaybe f $ dimMapping t1 t2
  where
    f (SubstNamed d) = Just $ qualLeaf d
    f _ = Nothing

-- Size names in a static value that must be freshened when the value
-- is instantiated (see 'instStaticVal').  Only lambda parameters
-- contribute names directly.
sizesToRename :: StaticVal -> S.Set VName
sizesToRename (DynamicFun (_, sv1) sv2) =
  sizesToRename sv1 <> sizesToRename sv2
sizesToRename IntrinsicSV =
  mempty
sizesToRename HoleSV {} =
  mempty
sizesToRename Dynamic {} =
  mempty
sizesToRename (RecordSV fs) =
  foldMap (sizesToRename . snd) fs
sizesToRename (SumSV _ svs _) =
  foldMap sizesToRename svs
sizesToRename (LambdaSV param _ _ _) =
  -- We used to rename parameters here, but I don't understand why
  -- that was necessary and it caused some problems.
  fvVars (freeInPat param)

-- | Combine the shape information of types as much as possible. The first
-- argument is the original type and the second is the type of the transformed
-- expression. This is necessary since the original type may contain additional
-- information (e.g., shape restrictions) from the user given annotation.
combineTypeShapes ::
  (Monoid as) =>
  TypeBase Size as ->
  TypeBase Size as ->
  TypeBase Size as
-- Records and sums are combined field-wise, but only when both sides
-- have exactly the same field/constructor names.
combineTypeShapes (Scalar (Record ts1)) (Scalar (Record ts2))
  | M.keys ts1 == M.keys ts2 =
      Scalar $
        Record $
          M.map
            (uncurry combineTypeShapes)
            (M.intersectionWith (,) ts1 ts2)
combineTypeShapes (Scalar (Sum cs1)) (Scalar (Sum cs2))
  | M.keys cs1 == M.keys cs2 =
      Scalar $
        Sum $
          M.map
            (uncurry $ zipWith combineTypeShapes)
            (M.intersectionWith (,) cs1 cs2)
-- For functions, the parameter name, diet and return existentials of
-- the first type are kept; aliases are merged.
combineTypeShapes (Scalar (Arrow als1 p1 d1 a1 (RetType dims1 b1))) (Scalar (Arrow als2 _p2 _d2 a2 (RetType _ b2))) =
  Scalar $
    Arrow
      (als1 <> als2)
      p1
      d1
      (combineTypeShapes a1 a2)
      (RetType dims1 (combineTypeShapes b1 b2))
combineTypeShapes (Scalar (TypeVar u v targs1)) (Scalar (TypeVar _ _ targs2)) =
  Scalar $ TypeVar u v $ zipWith f targs1 targs2
  where
    f (TypeArgType t1) (TypeArgType t2) =
      TypeArgType (combineTypeShapes t1 t2)
    f targ _ = targ
-- Arrays keep the shape of the first (original) type.
combineTypeShapes (Array u shape1 et1) (Array _ _shape2 et2) =
  arrayOfWithAliases
    u
    shape1
    (combineTypeShapes (setUniqueness (Scalar et1) u) (setUniqueness (Scalar et2) u))
-- In every other case the transformed type wins.
combineTypeShapes _ t = t

-- When we instantiate a polymorphic StaticVal, we rename all the
-- sizes to avoid name conflicts later on.  This is a bit of a hack...
--
-- Concretely: the size parameters 'dims' and the names from
-- 'sizesToRename' that are not global get fresh names.  The renamed
-- size parameters are then matched against the instantiation type
-- 't', and any size they map to replaces the fresh name.
instStaticVal ::
  (MonadFreshNames m) =>
  S.Set VName ->
  [VName] ->
  StructType ->
  StructType ->
  StaticVal ->
  m StaticVal
instStaticVal globals dims t sv_t sv = do
  fresh_substs <-
    mkSubsts . filter (`S.notMember` globals) . S.toList $
      S.fromList dims <> sizesToRename sv

  let dims' = map (onName fresh_substs) dims
      isDim k _ = k `elem` dims'
      dim_substs =
        M.filterWithKey isDim $ dimMapping (replaceTypeSizes fresh_substs sv_t) t
      replace (SubstNamed k) = fromMaybe (SubstNamed k) $ M.lookup (qualLeaf k) dim_substs
      replace k = k
      substs = M.map replace fresh_substs <> dim_substs

  pure $ replaceStaticValSizes globals substs sv
  where
    -- Map each name to a freshly generated copy of itself.
    mkSubsts names =
      M.fromList . zip names . map (SubstNamed . qualName)
        <$> mapM newName names

    onName substs v =
      case M.lookup v substs of
        Just (SubstNamed v') -> qualLeaf v'
        _ -> v

-- Defunctionalise a (possibly multi-parameter) function.  The residual
-- expression is a record literal holding the closure.  The static
-- value is a 'LambdaSV' over the first parameter; any remaining
-- parameters are pushed into a nested lambda body.
defuncFun ::
  [VName] ->
  [Pat ParamType] ->
  Exp ->
  ResRetType ->
  SrcLoc ->
  DefM (Exp, StaticVal)
defuncFun tparams pats e0 ret loc = do
  -- Extract the first parameter of the lambda and "push" the
  -- remaining ones (if there are any) into the body of the lambda.
  let (pat, ret', e0') = case pats of
        [] -> error "Received a lambda with no parameters."
        [pat'] -> (pat', ret, e0)
        (pat' : pats') ->
          ( pat',
            RetType [] $ second (const Nonunique) $ funType pats' ret,
            Lambda pats' e0 Nothing (Info ret) loc
          )

  -- Construct a record literal that closes over the environment of
  -- the lambda.  Closed-over 'DynamicFun's are converted to their
  -- closure representation.
  let used =
        freeInExp (Lambda pats e0 Nothing (Info ret) loc)
          `freeWithout` S.fromList tparams
  used_env <- restrictEnvTo used

  -- The closure parts that are sizes are proactively turned into size
  -- parameters.
  let sizes_of_arrays =
        foldMap (arraySizes . structTypeFromSV . bindingSV) used_env
          <> patternArraySizes pat
      notSize = not . (`S.member` sizes_of_arrays)
      (fields, env) =
        second M.fromList
          . unzip
          . map closureFromDynamicFun
          . filter (notSize . fst)
          $ M.toList used_env

  pure
    ( RecordLit fields loc,
      LambdaSV pat ret' e0' env
    )
  where
    -- A closed-over dynamic function contributes its own closure
    -- record.  Any other value contributes a variable reference typed
    -- from its static value.
    closureFromDynamicFun (vn, Binding _ (DynamicFun (clsr_env, sv) _)) =
      let name = nameFromString $ prettyString vn
       in ( RecordFieldExplicit (L noLoc name) clsr_env mempty,
            (vn, Binding Nothing sv)
          )
    closureFromDynamicFun (vn, Binding _ sv) =
      let name = nameFromString $ prettyString vn
          tp' = structTypeFromSV sv
       in ( RecordFieldExplicit
              (L noLoc name)
              (Var (qualName vn) (Info tp') mempty)
              mempty,
            (vn, Binding Nothing sv)
          )

-- | Defunctionalization of an expression. Returns the residual expression and
-- the associated static value in the defunctionalization monad.
defuncExp :: Exp -> DefM (Exp, StaticVal)
defuncExp e@Literal {} =
  pure (e, Dynamic $ toParam Observe $ typeOf e)
defuncExp e@IntLit {} =
  pure (e, Dynamic $ toParam Observe $ typeOf e)
defuncExp e@FloatLit {} =
  pure (e, Dynamic $ toParam Observe $ typeOf e)
defuncExp e@StringLit {} =
  pure (e, Dynamic $ toParam Observe $ typeOf e)
defuncExp (Parens e loc) = do
  (e', sv) <- defuncExp e
  pure (Parens e' loc, sv)
defuncExp (QualParens qn e loc) = do
  (e', sv) <- defuncExp e
  pure (QualParens qn e' loc, sv)
defuncExp (TupLit es loc) = do
  (es', svs) <- mapAndUnzipM defuncExp es
  pure (TupLit es' loc, RecordSV $ zip tupleFieldNames svs)
defuncExp (RecordLit fs loc) = do
  (fs', names_svs) <- mapAndUnzipM defuncField fs
  pure (RecordLit fs' loc, RecordSV names_svs)
  where
    defuncField (RecordFieldExplicit vn e loc') = do
      (e', sv) <- defuncExp e
      pure (RecordFieldExplicit vn e' loc', (unLoc vn, sv))
    defuncField (RecordFieldImplicit (L _ vn) (Info t) loc') = do
      sv <- lookupVar (toStruct t) vn
      case sv of
        -- If the implicit field refers to a dynamic function, we
        -- convert it to an explicit field with a record closing over
        -- the environment and bind the corresponding static value.
        DynamicFun (e, sv') _ ->
          let vn' = baseName vn
           in pure
                ( RecordFieldExplicit (L noLoc vn') e loc',
                  (vn', sv')
                )
        -- The field may refer to a functional expression, so we get the
        -- type from the static value and not the one from the AST.
        _ ->
          let tp = Info $ structTypeFromSV sv
           in pure
                ( RecordFieldImplicit (L noLoc vn) tp loc',
                  (baseName vn, sv)
                )
defuncExp e@(ArrayVal vs t loc) =
  pure (ArrayVal vs t loc, Dynamic $ toParam Observe $ typeOf e)
defuncExp (ArrayLit es t@(Info t') loc) = do
  es' <- mapM defuncExp' es
  pure (ArrayLit es' t loc, Dynamic $ toParam Observe t')
defuncExp (AppExp (Range e1 me incl loc) res) = do
  e1' <- defuncExp' e1
  me' <- mapM defuncExp' me
  incl' <- mapM defuncExp' incl
  pure
    ( AppExp (Range e1' me' incl' loc) res,
      Dynamic $ toParam Observe $ appResType $ unInfo res
    )
defuncExp e@(Var qn (Info t) loc) = do
  sv <- lookupVar (toStruct t) (qualLeaf qn)
  case sv of
    -- If the variable refers to a dynamic function, we eta-expand it
    -- so that we do not have to duplicate its definition.
    DynamicFun {} -> do
      (params, body, ret) <- etaExpand (RetType [] $ toRes Nonunique t) e
      defuncFun [] params body ret mempty
    -- Intrinsic functions used as variables are eta-expanded, so we
    -- can get rid of them.
    IntrinsicSV -> do
      (pats, body, tp) <- etaExpand (RetType [] $ toRes Nonunique t) e
      defuncExp $ Lambda pats body Nothing (Info tp) mempty
    HoleSV _ hole_loc ->
      pure (Hole (Info t) hole_loc, sv)
    _ ->
      pure (Var qn (Info (structTypeFromSV sv)) loc, sv)
defuncExp (Hole (Info t) loc) =
  pure (Hole (Info t) loc, HoleSV t loc)
-- Ascriptions and coercions are kept only around first-order
-- expressions; for functional ones they are dropped.
defuncExp (Ascript e0 tydecl loc)
  | orderZero (typeOf e0) = do
      (e0', sv) <- defuncExp e0
      pure (Ascript e0' tydecl loc, sv)
  | otherwise = defuncExp e0
defuncExp (Coerce e0 tydecl t loc)
  | orderZero (typeOf e0) = do
      (e0', sv) <- defuncExp e0
      pure (Coerce e0' tydecl t loc, sv)
  | otherwise = defuncExp e0
defuncExp (AppExp (LetPat sizes pat e1 e2 loc) (Info (AppRes t retext))) = do
  (e1', sv1) <- defuncExp e1
  let env = alwaysMatchPatSV (fmap (toParam Observe) pat) sv1
      pat' = updatePat (fmap (toParam Observe) pat) sv1
  (e2', sv2) <- localEnv env $ defuncExp e2
  -- To maintain any sizes going out of scope, we need to compute the
  -- old size substitution induced by retext and also apply it to the
  -- newly computed body type.
  let mapping = dimMapping' (typeOf e2) t
      subst v = ExpSubst . flip sizeFromName mempty . qualName <$> M.lookup v mapping
      t' = applySubst subst $ typeOf e2'
  pure (AppExp (LetPat sizes (fmap toStruct pat') e1' e2' loc) (Info (AppRes t' retext)), sv2)
defuncExp (AppExp (LetFun vn _ _ _) _) =
  error $ "defuncExp: Unexpected LetFun: " ++ show vn
-- The static value of a conditional is taken from the 'then' branch.
defuncExp (AppExp (If e1 e2 e3 loc) res) = do
  (e1', _) <- defuncExp e1
  (e2', sv) <- defuncExp e2
  (e3', _) <- defuncExp e3
  pure (AppExp (If e1' e2' e3' loc) res, sv)
defuncExp (AppExp (Apply f args loc) (Info appres)) =
  defuncApply f (fmap (first unInfo) args) appres loc
defuncExp (Negate e0 loc) = do
  (e0', sv) <- defuncExp e0
  pure (Negate e0' loc, sv)
defuncExp (Not e0 loc) = do
  (e0', sv) <- defuncExp e0
  pure (Not e0' loc, sv)
defuncExp (Lambda pats e0 _ (Info ret) loc) =
  defuncFun [] pats e0 ret loc
-- Operator sections are expected to be converted to lambda-expressions
-- by the monomorphizer, so they should no longer occur at this point.
defuncExp OpSection {} = error "defuncExp: unexpected operator section."
defuncExp OpSectionLeft {} = error "defuncExp: unexpected operator section."
defuncExp OpSectionRight {} = error "defuncExp: unexpected operator section."
defuncExp ProjectSection {} = error "defuncExp: unexpected projection section."
-- Previously reported as a "projection section" (copy-paste of the
-- clause above), which misidentified the offending construct.
defuncExp IndexSection {} = error "defuncExp: unexpected index section."
defuncExp (AppExp (Loop sparams pat loopinit form e3 loc) res) = do
  (e1', sv1) <- defuncExp $ loopInitExp loopinit
  let env1 = alwaysMatchPatSV pat sv1
  -- Only a 'while' condition sees the loop parameters; 'for' bounds
  -- and 'for-in' arrays are evaluated outside of them.
  (form', env2) <- case form of
    For v e2 -> do
      e2' <- defuncExp' e2
      pure (For v e2', envFromIdent v)
    ForIn pat2 e2 -> do
      e2' <- defuncExp' e2
      pure (ForIn pat2 e2', envFromPat $ fmap (toParam Observe) pat2)
    While e2 -> do
      e2' <- localEnv env1 $ defuncExp' e2
      pure (While e2', mempty)
  (e3', sv) <- localEnv (env1 <> env2) $ defuncExp e3
  pure (AppExp (Loop sparams pat (LoopInitExplicit e1') form' e3' loc) res, sv)
  where
    envFromIdent (Ident vn (Info tp) _) =
      M.singleton vn $ Binding Nothing $ Dynamic $ toParam Observe tp
defuncExp e@(AppExp BinOp {} _) =
  error $ "defuncExp: unexpected binary operator: " ++ prettyString e
defuncExp (Project vn e0 tp@(Info tp') loc) = do
  (e0', sv0) <- defuncExp e0
  case sv0 of
    RecordSV svs -> case lookup vn svs of
      Just sv -> pure (Project vn e0' (Info $ structTypeFromSV sv) loc, sv)
      Nothing -> error "Invalid record projection."
    Dynamic _ -> pure (Project vn e0' tp loc, Dynamic $ toParam Observe tp')
    HoleSV _ hloc -> pure (Project vn e0' tp loc, HoleSV tp' hloc)
    _ -> error $ "Projection of an expression with static value " ++ show sv0
defuncExp (AppExp LetWith {} _) =
  error "defuncExp: unexpected LetWith"
defuncExp expr@(AppExp (Index e0 idxs loc) res) = do
  e0' <- defuncExp' e0
  idxs' <- mapM defuncDimIndex idxs
  pure
    ( AppExp (Index e0' idxs' loc) res,
      Dynamic $ toParam Observe $ typeOf expr
    )
defuncExp (Update e1 idxs e2 loc) = do
  (e1', sv) <- defuncExp e1
  idxs' <- mapM defuncDimIndex idxs
  e2' <- defuncExp' e2
  pure (Update e1' idxs' e2' loc, sv)

-- Note that we might change the type of the record field here.  This
-- is not permitted in the type checker due to problems with type
-- inference, but it actually works fine.
defuncExp (RecordUpdate e1 fs e2 _ loc) = do
  (e1', sv1) <- defuncExp e1
  (e2', sv2) <- defuncExp e2
  let sv = staticField sv1 sv2 fs
  pure
    ( RecordUpdate e1' fs e2' (Info $ structTypeFromSV sv1) loc,
      sv
    )
  where
    -- Replace the static value at the given field path with 'sv2'.
    staticField (RecordSV svs) sv2 (f : fs') =
      case lookup f svs of
        Just sv ->
          RecordSV $
            (f, staticField sv sv2 fs') : filter ((/= f) . fst) svs
        Nothing -> error "Invalid record projection."
    staticField (Dynamic t@(Scalar Record {})) sv2 fs'@(_ : _) =
      staticField (svFromType t) sv2 fs'
    staticField _ sv2 _ = sv2
defuncExp (Assert e1 e2 desc loc) = do
  (e1', _) <- defuncExp e1
  (e2', sv) <- defuncExp e2
  pure (Assert e1' e2' desc loc, sv)
defuncExp (Constr name es (Info sum_t@(Scalar (Sum all_fs))) loc) = do
  (es', svs) <- mapAndUnzipM defuncExp es
  let sv =
        SumSV name svs $
          M.toList $
            name `M.delete` M.map (map (toParam Observe . defuncType)) all_fs
      sum_t' = combineTypeShapes sum_t (structTypeFromSV sv)
  pure (Constr name es' (Info sum_t') loc, sv)
  where
    -- Types of the constructors that are not present: function types
    -- are replaced by the empty record.
    defuncType ::
      (Monoid als) =>
      TypeBase Size als ->
      TypeBase Size als
    defuncType (Array u shape t) = Array u shape (defuncScalar t)
    defuncType (Scalar t) = Scalar $ defuncScalar t

    defuncScalar ::
      (Monoid als) =>
      ScalarTypeBase Size als ->
      ScalarTypeBase Size als
    defuncScalar (Record fs) = Record $ M.map defuncType fs
    defuncScalar Arrow {} = Record mempty
    defuncScalar (Sum fs) = Sum $ M.map (map defuncType) fs
    defuncScalar (Prim t) = Prim t
    defuncScalar (TypeVar u tn targs) = TypeVar u tn targs
defuncExp (Constr name _ (Info t) loc) =
  error $
    "Constructor "
      ++ prettyString name
      ++ " given type "
      ++ prettyString t
      ++ " at "
      ++ locStr loc
-- Only cases whose pattern can match the scrutinee's static value are
-- kept; the static value of the result comes from the first of them.
defuncExp (AppExp (Match e cs loc) res) = do
  (e', sv) <- defuncExp e
  let bad = error $ "No case matches StaticVal\n" <> show sv
  csPairs <-
    fromMaybe bad . NE.nonEmpty . catMaybes
      <$> mapM (defuncCase sv) (NE.toList cs)
  let cs' = fmap fst csPairs
      sv' = snd $ NE.head csPairs
  pure (AppExp (Match e' cs' loc) res, sv')
defuncExp (Attr info e loc) = do
  (e', sv) <- defuncExp e
  pure (Attr info e' loc, sv)

-- | Same as 'defuncExp', except it ignores the static value.
defuncExp' :: Exp -> DefM Exp
defuncExp' = fmap fst . defuncExp

-- | Defunctionalise one case of a match, given the static value of the
-- scrutinee.  Returns 'Nothing' if the pattern cannot match it.
defuncCase :: StaticVal -> Case -> DefM (Maybe (Case, StaticVal))
defuncCase sv (CasePat p e loc) = do
  let p' = updatePat (fmap (toParam Observe) p) sv
  case matchPatSV (fmap (toParam Observe) p) sv of
    Just env -> do
      (e', sv') <- localEnv env $ defuncExp e
      pure $ Just (CasePat (fmap toStruct p') e' loc, sv')
    Nothing ->
      pure Nothing

-- | Defunctionalize the function argument to a SOAC by eta-expanding if
-- necessary and then defunctionalizing the body of the introduced lambda.
defuncSoacExp :: Exp -> DefM Exp
defuncSoacExp e@OpSection {} = pure e
defuncSoacExp e@OpSectionLeft {} = pure e
defuncSoacExp e@OpSectionRight {} = pure e
defuncSoacExp e@ProjectSection {} = pure e
defuncSoacExp (Parens e loc) =
  Parens <$> defuncSoacExp e <*> pure loc
defuncSoacExp (Lambda params e0 decl tp loc) = do
  let env = foldMap envFromPat params
  e0' <- localEnv env $ defuncSoacExp e0
  pure $ Lambda params e0' decl tp loc
defuncSoacExp e
  | Scalar Arrow {} <- typeOf e = do
      (pats, body, tp) <- etaExpand (RetType [] $ toRes Nonunique $ typeOf e) e
      let env = foldMap envFromPat pats
      body' <- localEnv env $ defuncExp' body
      pure $ Lambda pats body' Nothing (Info tp) mempty
  | otherwise = defuncExp' e

-- | Eta-expand an expression of the given function type.  Returns one
-- parameter per arrow, the application of the expression to those
-- parameters, and the final return type.  Named arrow parameters
-- reuse their names unless already used; otherwise fresh @eta_p@
-- names are generated.
etaExpand :: ResRetType -> Exp -> DefM ([Pat ParamType], Exp, ResRetType)
etaExpand e_t e = do
  let (ps, ret) = getType e_t
  -- Some careful hackery to avoid duplicate names.
  (_, (params, vars)) <- second unzip <$> mapAccumLM f [] ps
  -- Important that we synthesize new existential names and substitute
  -- them into the (body) return type.
  ext' <- mapM newName $ retDims ret
  let extsubst =
        M.fromList . zip (retDims ret) $
          map (ExpSubst . flip sizeFromName mempty . qualName) ext'
      ret' = applySubst (`M.lookup` extsubst) ret
      e' = mkApply e (map (Nothing,) vars) $ AppRes (toStruct $ retType ret') ext'
  pure (params, e', ret)
  where
    getType (RetType _ (Scalar (Arrow _ p d t1 t2))) =
      let (ps, r) = getType t2
       in ((p, (d, t1)) : ps, r)
    getType t = ([], t)

    f prev (p, (d, t)) = do
      let t' = second (const d) t
      x <- case p of
        Named x | x `notElem` prev -> pure x
        _ -> newNameFromString "eta_p"
      pure
        ( x : prev,
          ( Id x (Info t') mempty,
            Var (qualName x) (Info $ toStruct t') mempty
          )
        )

-- | Defunctionalize an indexing of a single array dimension.
defuncDimIndex :: DimIndexBase Info VName -> DefM (DimIndexBase Info VName)
defuncDimIndex (DimFix e1) = DimFix . fst <$> defuncExp e1
defuncDimIndex (DimSlice me1 me2 me3) =
  DimSlice <$> defunc' me1 <*> defunc' me2 <*> defunc' me3
  where
    defunc' = mapM defuncExp'

-- Bind each of the given size names as a dynamic i64 value.
envFromDimNames :: [VName] -> Env
envFromDimNames = M.fromList . flip zip (repeat d)
  where
    d = Binding Nothing $ Dynamic $ Scalar $ Prim $ Signed Int64

-- | Defunctionalize a let-bound function, while preserving parameters
-- that have order 0 types (i.e., non-functional).
defuncLet ::
  [VName] ->
  [Pat ParamType] ->
  Exp ->
  ResRetType ->
  DefM ([VName], [Pat ParamType], Exp, StaticVal, ResType)
defuncLet dims ps@(pat : pats) body (RetType ret_dims rettype)
  | patternOrderZero pat = do
      let bound_by_pat = (`S.member` fvVars (freeInPat pat))
          -- Take care to not include more size parameters than necessary.
          (pat_dims, rest_dims) = partition bound_by_pat dims
          env = envFromPat pat <> envFromDimNames pat_dims
      (rest_dims', pats', body', sv, sv_t) <-
        localEnv env $ defuncLet rest_dims pats body $ RetType ret_dims rettype
      closure <- defuncFun dims ps body (RetType ret_dims rettype) mempty
      pure
        ( pat_dims ++ rest_dims',
          pat : pats',
          body',
          DynamicFun closure sv,
          sv_t
        )
  | otherwise = do
      (e, sv) <- defuncFun dims ps body (RetType ret_dims rettype) mempty
      pure ([], [], e, sv, resTypeFromSV sv)
defuncLet _ [] body (RetType _ rettype) = do
  (body', sv) <- defuncExp body
  pure
    ( [],
      [],
      body',
      imposeType sv $ resToParam rettype,
      resTypeFromSV sv
    )
  where
    -- Use the declared return type for dynamic parts of the result.
    imposeType Dynamic {} t =
      Dynamic t
    imposeType (RecordSV fs1) (Scalar (Record fs2)) =
      RecordSV $ M.toList $ M.intersectionWith imposeType (M.fromList fs1) fs2
    imposeType sv _ = sv

-- Replace every 'anySize' in the parameters' types with a fresh size
-- variable.
instAnySizes :: (MonadFreshNames m) => [Pat ParamType] -> m [Pat ParamType]
instAnySizes = traverse $ traverse $ bitraverse onDim pure
  where
    onDim d
      | d == anySize = do
          v <- newVName "size"
          pure $ sizeFromName (qualName v) mempty
    onDim d = pure d

-- Size variables in the parameter types that are neither in
-- 'bound_sizes' nor bound by the parameter patterns, without
-- duplicates.
unboundSizes :: S.Set VName -> [Pat ParamType] -> [VName]
unboundSizes bound_sizes params = nubOrd $ execState (f params) []
  where
    f = traverse $ traverse $ bitraverse onDim pure
    bound = bound_sizes <> S.fromList (foldMap patNames params)
    onDim (Var d typ loc) = do
      unless (qualLeaf d `S.member` bound) $ modify (qualLeaf d :)
      pure $ Var d typ loc
    onDim d = pure d

-- Convert a return type to an 'AppRes', giving its existential sizes
-- fresh names.
unRetType :: ResRetType -> DefM AppRes
unRetType (RetType [] t) = pure $ AppRes (toStruct t) []
unRetType (RetType ext t) = do
  ext' <- mapM newName ext
  let extsubst =
        M.fromList . zip ext $
          map (ExpSubst . flip sizeFromName mempty . qualName) ext'
  pure $ AppRes (applySubst (`M.lookup` extsubst) $ toStruct t) ext'

-- Defunctionalise the function part of an application with 'num_args'
-- arguments.  For a variable bound to a dynamic function:
--   * if it is fully applied, keep the variable and update its type;
--   * if all its argument and result types are first-order,
--     eta-expand it;
--   * otherwise lift a new top-level function named @dyn_<name>@.
-- Intrinsics are returned unchanged; non-variables go through
-- 'defuncExp'.
defuncApplyFunction :: Exp -> Int -> DefM (Exp, StaticVal)
defuncApplyFunction e@(Var qn (Info t) loc) num_args = do
  let (argtypes, rettype) = unfoldFunType t
  sv <- lookupVar (toStruct t) (qualLeaf qn)

  case sv of
    DynamicFun _ _
      | fullyApplied sv num_args -> do
          -- We still need to update the types in case the dynamic
          -- function returns a higher-order term.
          let (argtypes', rettype') = dynamicFunType sv argtypes
          pure (Var qn (Info (foldFunType argtypes' $ RetType [] rettype')) loc, sv)
      | all orderZero argtypes,
        orderZero rettype -> do
          (params, body, ret) <- etaExpand (RetType [] $ toRes Nonunique t) e
          defuncFun [] params body ret mempty
      | otherwise -> do
          fname <- newVName $ "dyn_" <> baseString (qualLeaf qn)
          let (pats, e0, sv') = liftDynFun (prettyString qn) sv num_args
              (argtypes', rettype') = dynamicFunType sv' argtypes
              dims' = mempty

          -- Ensure that no parameter sizes are AnySize.  The internaliser
          -- expects this.  This is easy, because they are all
          -- first-order.
          globals <- asks fst
          let bound_sizes = S.fromList dims' <> globals
          pats' <- instAnySizes pats

          liftValDec fname (RetType [] rettype') (dims' ++ unboundSizes bound_sizes pats') pats' e0
          pure
            ( Var
                (qualName fname)
                (Info (foldFunType argtypes' $ RetType [] rettype'))
                loc,
              sv'
            )
    IntrinsicSV -> pure (e, IntrinsicSV)
    _ -> pure (Var qn (Info (structTypeFromSV sv)) loc, sv)
defuncApplyFunction e _ = defuncExp e

-- Embed some information about the original function
-- into the name of the lifted function, to make the
-- result slightly more human-readable.
liftedName :: Int -> Exp -> String liftedName i (Var f _ _) = "defunc_" ++ show i ++ "_" ++ baseString (qualLeaf f) liftedName i (AppExp (Apply f _ _) _) = liftedName (i + 1) f liftedName _ _ = "defunc" defuncApplyArg :: String -> (Exp, StaticVal) -> ((Maybe VName, Exp), [ParamType]) -> DefM (Exp, StaticVal) defuncApplyArg fname_s (f', LambdaSV pat lam_e_t lam_e closure_env) ((argext, arg), _) = do (arg', arg_sv) <- defuncExp arg let env' = alwaysMatchPatSV pat arg_sv dims = mempty (lam_e', sv) <- localNewEnv (env' <> closure_env) $ defuncExp lam_e let closure_pat = buildEnvPat dims closure_env pat' = updatePat pat arg_sv globals <- asks fst -- Lift lambda to top-level function definition. We put in -- a lot of effort to try to infer the uniqueness attributes -- of the lifted function, but this is ultimately all a sham -- and a hack. There is some piece we're missing. let params = [closure_pat, pat'] lifted_rettype = RetType (retDims lam_e_t) $ combineTypeShapes (retType lam_e_t) (resTypeFromSV sv) already_bound = globals <> S.fromList (dims <> foldMap patNames params) more_dims = S.toList $ S.filter (`S.notMember` already_bound) $ foldMap patternArraySizes params -- Ensure that no parameter sizes are AnySize. The internaliser -- expects this. This is easy, because they are all -- first-order. 
let bound_sizes = S.fromList (dims <> more_dims) <> globals params' <- instAnySizes params fname <- newNameFromString fname_s liftValDec fname lifted_rettype (dims ++ more_dims ++ unboundSizes bound_sizes params') params' lam_e' let f_t = toStruct $ typeOf f' arg_t = toStruct $ typeOf arg' fname_t = foldFunType [toParam Observe f_t, toParam (diet (patternType pat)) arg_t] lifted_rettype fname' = Var (qualName fname) (Info fname_t) (srclocOf arg) callret <- unRetType lifted_rettype pure ( mkApply fname' [(Nothing, f'), (argext, arg')] callret, sv ) -- If 'f' is a dynamic function, we just leave the application in -- place, but we update the types since it may be partially -- applied or return a higher-order value. defuncApplyArg _ (f', DynamicFun _ sv) ((argext, arg), argtypes) = do (arg', _) <- defuncExp arg let (argtypes', rettype) = dynamicFunType sv argtypes restype = foldFunType argtypes' (RetType [] rettype) callret = AppRes restype [] apply_e = mkApply f' [(argext, arg')] callret pure (apply_e, sv) -- defuncApplyArg fname_s (_, sv) ((_, arg), _) = error $ "defuncApplyArg: cannot apply StaticVal\n" <> show sv <> "\nFunction name: " <> prettyString fname_s <> "\nArgument: " <> prettyString arg updateReturn :: AppRes -> Exp -> Exp updateReturn (AppRes ret1 ext1) (AppExp apply (Info (AppRes ret2 ext2))) = AppExp apply $ Info $ AppRes (combineTypeShapes ret1 ret2) (ext1 <> ext2) updateReturn _ e = e defuncApply :: Exp -> NE.NonEmpty (Maybe VName, Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal) defuncApply f args appres loc = do (f', f_sv) <- defuncApplyFunction f (length args) case f_sv of IntrinsicSV -> do args' <- fmap (first Info) <$> traverse (traverse defuncSoacExp) args let e' = AppExp (Apply f' args' loc) (Info appres) intrinsicOrHole e' HoleSV {} -> do args' <- fmap (first Info) <$> traverse (traverse $ fmap fst . 
defuncExp) args let e' = AppExp (Apply f' args' loc) (Info appres) intrinsicOrHole e' _ -> do let fname = liftedName 0 f (argtypes, _) = unfoldFunType $ typeOf f fmap (first $ updateReturn appres) $ foldM (defuncApplyArg fname) (f', f_sv) $ NE.zip args $ NE.tails argtypes where intrinsicOrHole e' = do -- If the intrinsic is fully applied, then we are done. -- Otherwise we need to eta-expand it and recursively -- defunctionalise. XXX: might it be better to simply eta-expand -- immediately any time we encounter a non-fully-applied -- intrinsic? if null $ fst $ unfoldFunType $ appResType appres then pure (e', Dynamic $ toParam Observe $ appResType appres) else do (pats, body, tp) <- etaExpand (RetType [] $ toRes Nonunique $ typeOf e') e' defuncExp $ Lambda pats body Nothing (Info tp) mempty -- | Check if a 'StaticVal' and a given application depth corresponds -- to a fully applied dynamic function. fullyApplied :: StaticVal -> Int -> Bool fullyApplied (DynamicFun _ sv) depth | depth == 0 = False | depth > 0 = fullyApplied sv (depth - 1) fullyApplied _ _ = True -- | Converts a dynamic function 'StaticVal' into a list of -- dimensions, a list of parameters, a function body, and the -- appropriate static value for applying the function at the given -- depth of partial application. liftDynFun :: String -> StaticVal -> Int -> ([Pat ParamType], Exp, StaticVal) liftDynFun _ (DynamicFun (e, sv) _) 0 = ([], e, sv) liftDynFun s (DynamicFun clsr@(_, LambdaSV pat _ _ _) sv) d | d > 0 = let (pats, e', sv') = liftDynFun s sv (d - 1) in (pat : pats, e', DynamicFun clsr sv') liftDynFun s sv d = error $ s ++ " Tried to lift a StaticVal " ++ take 100 (show sv) ++ ", but expected a dynamic function.\n" ++ prettyString d -- | Converts a pattern to an environment that binds the individual names of the -- pattern to their corresponding types wrapped in a 'Dynamic' static value. 
envFromPat :: Pat ParamType -> Env envFromPat pat = case pat of TuplePat ps _ -> foldMap envFromPat ps RecordPat fs _ -> foldMap (envFromPat . snd) fs PatParens p _ -> envFromPat p PatAttr _ p _ -> envFromPat p Id vn (Info t) _ -> M.singleton vn $ Binding Nothing $ Dynamic t Wildcard _ _ -> mempty PatAscription p _ _ -> envFromPat p PatLit {} -> mempty PatConstr _ _ ps _ -> foldMap envFromPat ps -- | Given a closure environment, construct a record pattern that -- binds the closed over variables. Insert wildcard for any patterns -- that would otherwise clash with size parameters. buildEnvPat :: [VName] -> Env -> Pat ParamType buildEnvPat sizes env = RecordPat (map buildField $ M.toList env) mempty where buildField (vn, Binding _ sv) = ( L noLoc $ nameFromText (prettyText vn), if vn `elem` sizes then Wildcard (Info $ paramTypeFromSV sv) mempty else Id vn (Info $ paramTypeFromSV sv) mempty ) -- | Compute the corresponding type for the *representation* of a -- given static value (not the original possibly higher-order value). typeFromSV :: StaticVal -> ParamType typeFromSV (Dynamic tp) = tp typeFromSV (LambdaSV _ _ _ env) = Scalar . Record . M.fromList $ map (bimap (nameFromString . prettyString) (typeFromSV . bindingSV)) $ M.toList env typeFromSV (RecordSV ls) = let ts = map (fmap typeFromSV) ls in Scalar $ Record $ M.fromList ts typeFromSV (DynamicFun (_, sv) _) = typeFromSV sv typeFromSV (SumSV name svs fields) = let svs' = map typeFromSV svs in Scalar $ Sum $ M.insert name svs' $ M.fromList fields typeFromSV (HoleSV t _) = toParam Observe t typeFromSV IntrinsicSV = error "Tried to get the type from the static value of an intrinsic." resTypeFromSV :: StaticVal -> ResType resTypeFromSV = paramToRes . typeFromSV structTypeFromSV :: StaticVal -> StructType structTypeFromSV = toStruct . 
typeFromSV paramTypeFromSV :: StaticVal -> ParamType paramTypeFromSV = typeFromSV -- | Construct the type for a fully-applied dynamic function from its -- static value and the original types of its arguments. dynamicFunType :: StaticVal -> [ParamType] -> ([ParamType], ResType) dynamicFunType (DynamicFun _ sv) (p : ps) = let (ps', ret) = dynamicFunType sv ps in (p : ps', ret) dynamicFunType sv _ = ([], resTypeFromSV sv) -- | Match a pattern with its static value. Returns an environment -- with the identifier components of the pattern mapped to the -- corresponding subcomponents of the static value. If this function -- returns 'Nothing', then it corresponds to an unmatchable case. -- These should only occur for 'Match' expressions. matchPatSV :: Pat ParamType -> StaticVal -> Maybe Env matchPatSV (TuplePat ps _) (RecordSV ls) = mconcat <$> zipWithM (\p (_, sv) -> matchPatSV p sv) ps ls matchPatSV (RecordPat ps _) (RecordSV ls) | ps' <- sortOn fst $ map (first unLoc) ps, ls' <- sortOn fst ls, map fst ps' == map fst ls' = mconcat <$> zipWithM (\(_, p) (_, sv) -> matchPatSV p sv) ps' ls' matchPatSV (PatParens pat _) sv = matchPatSV pat sv matchPatSV (PatAttr _ pat _) sv = matchPatSV pat sv matchPatSV (Id vn (Info t) _) sv = -- When matching a zero-order pattern with a StaticVal, the type of -- the pattern wins out. This is important for propagating sizes -- (but probably reveals a flaw in our bookkeeping). pure $ if orderZero t then dim_env <> M.singleton vn (Binding Nothing $ Dynamic t) else dim_env <> M.singleton vn (Binding Nothing sv) where -- Extract all sizes that are potentially bound here. This is -- different from all free variables (see #2040). 
dim_env = bifoldMap onDim (const mempty) t onDim (Var v _ _) = M.singleton (qualLeaf v) i64 onDim _ = mempty i64 = Binding Nothing $ Dynamic $ Scalar $ Prim $ Signed Int64 matchPatSV (Wildcard _ _) _ = pure mempty matchPatSV (PatAscription pat _ _) sv = matchPatSV pat sv matchPatSV PatLit {} _ = pure mempty matchPatSV (PatConstr c1 _ ps _) (SumSV c2 ls fs) | c1 == c2 = mconcat <$> zipWithM matchPatSV ps ls | Just _ <- lookup c1 fs = Nothing | otherwise = error $ "matchPatSV: missing constructor in type: " ++ prettyString c1 matchPatSV (PatConstr c1 _ ps _) (Dynamic (Scalar (Sum fs))) | Just ts <- M.lookup c1 fs = -- A higher-order pattern can only match an appropriate SumSV. if all orderZero ts then mconcat <$> zipWithM matchPatSV ps (map svFromType ts) else Nothing | otherwise = error $ "matchPatSV: missing constructor in type: " ++ prettyString c1 matchPatSV pat (Dynamic t) = matchPatSV pat $ svFromType t matchPatSV pat (HoleSV t _) = matchPatSV pat $ svFromType $ toParam Observe t matchPatSV pat sv = error $ "Tried to match pattern\n" ++ prettyString pat ++ "\n with static value\n" ++ show sv alwaysMatchPatSV :: Pat ParamType -> StaticVal -> Env alwaysMatchPatSV pat sv = fromMaybe bad $ matchPatSV pat sv where bad = error $ unlines [prettyString pat, "cannot match StaticVal", show sv] -- | Given a pattern and the static value for the defunctionalized argument, -- update the pattern to reflect the changes in the types. 
updatePat :: Pat ParamType -> StaticVal -> Pat ParamType updatePat (TuplePat ps loc) (RecordSV svs) = TuplePat (zipWith updatePat ps $ map snd svs) loc updatePat (RecordPat ps loc) (RecordSV svs) | ps' <- sortOn fst ps, svs' <- sortOn fst svs = RecordPat (zipWith (\(n, p) (_, sv) -> (n, updatePat p sv)) ps' svs') loc updatePat (PatParens pat loc) sv = PatParens (updatePat pat sv) loc updatePat (PatAttr attr pat loc) sv = PatAttr attr (updatePat pat sv) loc updatePat (Id vn (Info tp) loc) sv = Id vn (Info $ comb tp $ paramTypeFromSV sv) loc where -- Preserve any original zeroth-order types. comb (Scalar Arrow {}) t2 = t2 comb (Scalar (Record m1)) (Scalar (Record m2)) = Scalar $ Record $ M.intersectionWith comb m1 m2 comb (Scalar (Sum m1)) (Scalar (Sum m2)) = Scalar $ Sum $ M.intersectionWith (zipWith comb) m1 m2 comb t1 _ = t1 -- t1 must be array or prim. updatePat pat@(Wildcard (Info tp) loc) sv | orderZero tp = pat | otherwise = Wildcard (Info $ paramTypeFromSV sv) loc updatePat (PatAscription pat _ _) sv = updatePat pat sv updatePat p@PatLit {} _ = p updatePat pat@(PatConstr c1 (Info t) ps loc) sv@(SumSV _ svs _) | orderZero t = pat | otherwise = PatConstr c1 (Info $ toParam Observe t') ps' loc where t' = resTypeFromSV sv ps' = zipWith updatePat ps svs updatePat (PatConstr c1 _ ps loc) (Dynamic t) = PatConstr c1 (Info $ toParam Observe t) ps loc updatePat pat (Dynamic t) = updatePat pat (svFromType t) updatePat pat (HoleSV t _) = updatePat pat (svFromType $ toParam Observe t) updatePat pat sv = error $ "Tried to update pattern\n" ++ prettyString pat ++ "\nto reflect the static value\n" ++ show sv -- | Convert a record (or tuple) type to a record static value. This -- is used for "unwrapping" tuples and records that are nested in -- 'Dynamic' static values. svFromType :: ParamType -> StaticVal svFromType (Scalar (Record fs)) = RecordSV . M.toList $ M.map svFromType fs svFromType t = Dynamic t -- | Defunctionalize a top-level value binding. 
Returns the -- transformed result as well as an environment that binds the name of -- the value binding to the static value of the transformed body. The -- boolean is true if the function is a 'DynamicFun'. defuncValBind :: ValBind -> DefM (ValBind, Env) -- Eta-expand entry points with a functional return type. defuncValBind (ValBind entry name _ (Info rettype) tparams params body _ attrs loc) | Scalar Arrow {} <- retType rettype = do (body_pats, body', rettype') <- etaExpand (second (const mempty) rettype) body defuncValBind $ ValBind entry name Nothing (Info rettype') tparams (params <> body_pats) body' Nothing attrs loc defuncValBind valbind@(ValBind _ name retdecl (Info (RetType ret_dims rettype)) tparams params body _ _ _) = do when (any isTypeParam tparams) $ error $ show name ++ " has type parameters, " ++ "but the defunctionaliser expects a monomorphic input program." (tparams', params', body', sv, sv_t) <- defuncLet (map typeParamName tparams) params body $ RetType ret_dims rettype globals <- asks fst let bound_sizes = S.fromList (foldMap patNames params') <> S.fromList tparams' <> globals params'' <- instAnySizes params' let rettype' = combineTypeShapes rettype sv_t tparams'' = tparams' ++ unboundSizes bound_sizes params'' ret_dims' = filter (`notElem` bound_sizes) $ S.toList $ fvVars $ freeInType rettype' pure ( valbind { valBindRetDecl = retdecl, valBindRetType = Info $ if null params' then RetType ret_dims' $ rettype' `setUniqueness` Nonunique else RetType ret_dims' rettype', valBindTypeParams = map (`TypeParamDim` mempty) tparams'', valBindParams = params'', valBindBody = body' }, M.singleton name $ Binding (Just (first (map typeParamName) (valBindTypeScheme valbind))) sv ) -- | Defunctionalize a list of top-level declarations. 
defuncVals :: [ValBind] -> DefM () defuncVals [] = pure () defuncVals (valbind : ds) = do (valbind', env) <- defuncValBind valbind addValBind valbind' let globals = valBindBound valbind' localEnv env $ areGlobal globals $ defuncVals ds {-# NOINLINE transformProg #-} -- | Transform a list of top-level value bindings. May produce new -- lifted function definitions, which are placed in front of the -- resulting list of declarations. transformProg :: (MonadFreshNames m) => [ValBind] -> m [ValBind] transformProg decs = modifyNameSource $ \namesrc -> let ((), namesrc', decs') = runDefM namesrc $ defuncVals decs in (decs', namesrc') futhark-0.25.27/src/Futhark/Internalise/Defunctorise.hs000066400000000000000000000320241475065116200230010ustar00rootroot00000000000000-- | Partially evaluate all modules away from a source Futhark -- program. This is implemented as a source-to-source transformation. module Futhark.Internalise.Defunctorise (transformProg) where import Control.Monad.Identity import Control.Monad.RWS.Strict import Data.DList qualified as DL import Data.Map qualified as M import Data.Maybe import Data.Set qualified as S import Futhark.MonadFreshNames import Language.Futhark import Language.Futhark.Semantic (FileModule (..), Imports, includeToString) import Language.Futhark.Traversals import Prelude hiding (abs, mod) -- | A substitution from names in the original program to names in the -- generated/residual program. type Substitutions = M.Map VName VName lookupSubst :: VName -> Substitutions -> VName lookupSubst v substs = case M.lookup v substs of Just v' | v' /= v -> lookupSubst v' substs _ -> v data Mod = -- | A pairing of a lexical closure and a module function. ModFun TySet Scope ModParam ModExp | -- | A non-parametric module. 
ModMod Scope deriving (Show) modScope :: Mod -> Scope modScope (ModMod scope) = scope modScope ModFun {} = mempty data Scope = Scope { scopeSubsts :: Substitutions, scopeMods :: M.Map VName Mod } deriving (Show) lookupSubstInScope :: QualName VName -> Scope -> (QualName VName, Scope) lookupSubstInScope qn@(QualName quals name) scope@(Scope substs mods) = case quals of [] -> (qualName $ lookupSubst name substs, scope) q : qs -> let q' = lookupSubst q substs in case M.lookup q' mods of Just (ModMod mod_scope) -> lookupSubstInScope (QualName qs name) mod_scope _ -> (qn, scope) instance Semigroup Scope where Scope ss1 mt1 <> Scope ss2 mt2 = Scope (ss1 <> ss2) (mt1 <> mt2) instance Monoid Scope where mempty = Scope mempty mempty type TySet = S.Set VName data Env = Env { envScope :: Scope, envGenerating :: Bool, envImports :: M.Map ImportName Scope, envAbs :: TySet } newtype TransformM a = TransformM (RWS Env (DL.DList Dec) VNameSource a) deriving ( Applicative, Functor, Monad, MonadFreshNames, MonadReader Env, MonadWriter (DL.DList Dec) ) emit :: Dec -> TransformM () emit = tell . 
DL.singleton askScope :: TransformM Scope askScope = asks envScope localScope :: (Scope -> Scope) -> TransformM a -> TransformM a localScope f = local $ \env -> env {envScope = f $ envScope env} extendScope :: Scope -> TransformM a -> TransformM a extendScope (Scope substs mods) = localScope $ \scope -> scope { scopeSubsts = M.map (forward (scopeSubsts scope)) substs <> scopeSubsts scope, scopeMods = mods <> scopeMods scope } where forward old_substs v = fromMaybe v $ M.lookup v old_substs substituting :: Substitutions -> TransformM a -> TransformM a substituting substs = extendScope mempty {scopeSubsts = substs} boundName :: VName -> TransformM VName boundName v = do g <- asks envGenerating if g then newName v else pure v bindingNames :: [VName] -> TransformM Scope -> TransformM Scope bindingNames names m = do names' <- mapM boundName names let substs = M.fromList (zip names names') substituting substs $ mappend <$> m <*> pure (Scope substs mempty) generating :: TransformM a -> TransformM a generating = local $ \env -> env {envGenerating = True} bindingImport :: ImportName -> Scope -> TransformM a -> TransformM a bindingImport name scope = local $ \env -> env {envImports = M.insert name scope $ envImports env} bindingAbs :: TySet -> TransformM a -> TransformM a bindingAbs abs = local $ \env -> env {envAbs = abs <> envAbs env} lookupImport :: ImportName -> TransformM Scope lookupImport name = maybe bad pure =<< asks (M.lookup name . envImports) where bad = error $ "Defunctorise: unknown import: " ++ includeToString name lookupMod' :: QualName VName -> Scope -> Either String Mod lookupMod' mname scope = let (mname', scope') = lookupSubstInScope mname scope in maybe (Left $ bad mname') (Right . extend) $ M.lookup (qualLeaf mname') $ scopeMods scope' where bad mname' = "Unknown module: " ++ prettyString mname ++ " (" ++ prettyString mname' ++ ")" extend (ModMod (Scope inner_scope inner_mods)) = -- XXX: perhaps hacky fix for #1653. 
We need to impose the -- substitutions of abstract types from outside, because the -- inner module may have some incorrect substitutions in some -- cases. Our treatment of abstract types is completely whack -- and should be fixed. ModMod $ Scope (scopeSubsts scope <> inner_scope) inner_mods extend m = m lookupMod :: QualName VName -> TransformM Mod lookupMod mname = either error pure . lookupMod' mname =<< askScope runTransformM :: VNameSource -> TransformM a -> (a, VNameSource, DL.DList Dec) runTransformM src (TransformM m) = runRWS m env src where env = Env mempty False mempty mempty maybeAscript :: SrcLoc -> Maybe (ModTypeExp, Info (M.Map VName VName)) -> ModExp -> ModExp maybeAscript loc (Just (mtye, substs)) me = ModAscript me mtye substs loc maybeAscript _ Nothing me = me substituteInMod :: Substitutions -> Mod -> Mod substituteInMod substs (ModMod (Scope mod_substs mod_mods)) = -- Forward all substitutions. ModMod $ Scope substs' $ M.map (substituteInMod substs) mod_mods where forward v = lookupSubst v $ mod_substs <> substs substs' = M.map forward substs substituteInMod substs (ModFun abs (Scope mod_substs mod_mods) mparam mbody) = ModFun abs (Scope (substs' <> mod_substs) mod_mods) mparam mbody where forward v = lookupSubst v mod_substs substs' = M.map forward substs extendAbsTypes :: Substitutions -> TransformM a -> TransformM a extendAbsTypes ascript_substs m = do abs <- asks envAbs -- Some abstract types may have a different name on the inside, and -- we need to make them visible, because substitutions involving -- abstract types must be lifted out in transformModBind. let subst_abs = S.fromList . map snd . filter ((`S.member` abs) . 
fst) $ M.toList ascript_substs bindingAbs subst_abs m evalModExp :: ModExp -> TransformM Mod evalModExp (ModVar qn _) = lookupMod qn evalModExp (ModParens e _) = evalModExp e evalModExp (ModDecs decs _) = ModMod <$> transformDecs decs evalModExp (ModImport _ (Info fpath) _) = ModMod <$> lookupImport fpath evalModExp (ModAscript me _ (Info ascript_substs) _) = extendAbsTypes ascript_substs $ substituteInMod ascript_substs <$> evalModExp me evalModExp (ModApply f arg (Info p_substs) (Info b_substs) loc) = do f_mod <- evalModExp f arg_mod <- evalModExp arg case f_mod of ModMod _ -> error $ "Cannot apply non-parametric module at " ++ locStr loc ModFun f_abs f_closure f_p f_body -> bindingAbs (f_abs <> S.fromList (unInfo (modParamAbs f_p))) . extendAbsTypes b_substs . localScope (const f_closure) -- Start afresh. . generating $ do abs <- asks envAbs let keep k _ = k `M.member` p_substs || k `S.member` abs abs_substs = M.filterWithKey keep $ M.map (`lookupSubst` scopeSubsts (modScope arg_mod)) p_substs <> scopeSubsts f_closure <> scopeSubsts (modScope arg_mod) extendScope ( Scope abs_substs ( M.singleton (modParamName f_p) $ substituteInMod p_substs arg_mod ) ) $ do substs <- scopeSubsts <$> askScope x <- evalModExp f_body pure $ addSubsts abs abs_substs $ -- The next one is dubious, but is necessary to -- propagate substitutions from the argument (see -- modules/functor24.fut). 
addSubstsModMod (scopeSubsts $ modScope arg_mod) $ substituteInMod (b_substs <> substs) x where addSubsts abs substs (ModFun mabs (Scope msubsts mods) mp me) = ModFun (abs <> mabs) (Scope (substs <> msubsts) mods) mp me addSubsts _ substs (ModMod (Scope msubsts mods)) = ModMod $ Scope (substs <> msubsts) mods addSubstsModMod substs (ModMod (Scope msubsts mods)) = ModMod $ Scope (substs <> msubsts) mods addSubstsModMod _ m = m evalModExp (ModLambda p ascript e loc) = do scope <- askScope abs <- asks envAbs pure $ ModFun abs scope p $ maybeAscript loc ascript e transformName :: VName -> TransformM VName transformName v = lookupSubst v . scopeSubsts <$> askScope -- | A general-purpose substitution of names. transformNames :: (ASTMappable x) => x -> TransformM x transformNames x = do scope <- askScope pure $ runIdentity $ astMap (substituter scope) x where substituter scope = ASTMapper { mapOnExp = onExp scope, mapOnName = \v -> pure $ fst $ lookupSubstInScope v {qualQuals = []} scope, mapOnStructType = astMap (substituter scope), mapOnParamType = astMap (substituter scope), mapOnResRetType = astMap (substituter scope) } onExp scope e = -- One expression is tricky, because it interacts with scoping rules. 
case e of QualParens (mn, _) e' _ -> case lookupMod' mn scope of Left err -> error err Right mod -> astMap (substituter $ modScope mod <> scope) e' _ -> astMap (substituter scope) e transformTypeExp :: TypeExp Exp VName -> TransformM (TypeExp Exp VName) transformTypeExp = transformNames transformStructType :: StructType -> TransformM StructType transformStructType = transformNames transformResType :: ResType -> TransformM ResType transformResType = transformNames transformExp :: Exp -> TransformM Exp transformExp = transformNames transformEntry :: EntryPoint -> TransformM EntryPoint transformEntry (EntryPoint params ret) = EntryPoint <$> mapM onEntryParam params <*> onEntryType ret where onEntryParam (EntryParam v t) = EntryParam v <$> onEntryType t onEntryType (EntryType t te) = EntryType <$> transformStructType t <*> pure te transformValBind :: ValBind -> TransformM () transformValBind (ValBind entry name tdecl (Info (RetType dims t)) tparams params e doc attrs loc) = do entry' <- traverse (traverse transformEntry) entry name' <- transformName name tdecl' <- traverse transformTypeExp tdecl t' <- transformResType t e' <- transformExp e params' <- traverse transformNames params emit $ ValDec $ ValBind entry' name' tdecl' (Info (RetType dims t')) tparams params' e' doc attrs loc transformTypeBind :: TypeBind -> TransformM () transformTypeBind (TypeBind name l tparams te (Info (RetType dims t)) doc loc) = do name' <- transformName name emit . TypeDec =<< ( TypeBind name' l tparams <$> transformTypeExp te <*> (Info . 
RetType dims <$> transformStructType t) <*> pure doc <*> pure loc ) transformModBind :: ModBind -> TransformM Scope transformModBind mb = do let addParam p me = ModLambda p Nothing me $ srclocOf me mod <- evalModExp $ foldr addParam (maybeAscript (srclocOf mb) (modType mb) $ modExp mb) $ modParams mb mname <- transformName $ modName mb pure $ Scope (scopeSubsts $ modScope mod) $ M.singleton mname mod transformDecs :: [Dec] -> TransformM Scope transformDecs ds = case ds of [] -> pure mempty LocalDec d _ : ds' -> transformDecs $ d : ds' ValDec fdec : ds' -> bindingNames [valBindName fdec] $ do transformValBind fdec transformDecs ds' TypeDec tb : ds' -> bindingNames [typeAlias tb] $ do transformTypeBind tb transformDecs ds' ModTypeDec {} : ds' -> transformDecs ds' ModDec mb : ds' -> bindingNames [modName mb] $ do mod_scope <- transformModBind mb extendScope mod_scope $ mappend <$> transformDecs ds' <*> pure mod_scope OpenDec e _ : ds' -> do scope <- modScope <$> evalModExp e extendScope scope $ mappend <$> transformDecs ds' <*> pure scope ImportDec name name' loc : ds' -> let d = LocalDec (OpenDec (ModImport name name' loc) loc) loc in transformDecs $ d : ds' transformImports :: Imports -> TransformM () transformImports [] = pure () transformImports ((name, imp) : imps) = do let abs = S.fromList $ map qualLeaf $ M.keys $ fileAbs imp scope <- censor (fmap maybeHideEntryPoint) $ bindingAbs abs $ transformDecs $ progDecs $ fileProg imp bindingAbs abs $ bindingImport name scope $ transformImports imps where -- Only the "main" file (last import) is allowed to have entry points. permit_entry_points = null imps maybeHideEntryPoint (ValDec vdec) = ValDec vdec { valBindEntryPoint = if permit_entry_points then valBindEntryPoint vdec else Nothing } maybeHideEntryPoint d = d -- | Perform defunctorisation. 
transformProg :: (MonadFreshNames m) => Imports -> m [Dec]
transformProg prog = modifyNameSource $ \namesrc ->
  let ((), namesrc', prog') = runTransformM namesrc $ transformImports prog
   in (DL.toList prog', namesrc')
futhark-0.25.27/src/Futhark/Internalise/Entry.hs000066400000000000000000000237501475065116200214560ustar00rootroot00000000000000-- | Generating metadata so that programs can run at all.
module Futhark.Internalise.Entry
  ( entryPoint,
    VisibleTypes,
    visibleTypes,
  )
where

import Control.Monad
import Control.Monad.State.Strict
import Data.Bifunctor (first)
import Data.List (find, intersperse)
import Data.Map qualified as M
import Futhark.IR qualified as I
import Futhark.Internalise.TypesValues (internaliseSumTypeRep, internalisedTypeSize)
import Futhark.Util (chunks)
import Futhark.Util.Pretty (prettyTextOneLine)
import Language.Futhark qualified as E hiding (TypeArg)
import Language.Futhark.Core (L (..), Name, Uniqueness (..), VName, nameFromText, unLoc)
import Language.Futhark.Semantic qualified as E

-- | The types that are visible to the outside world.
newtype VisibleTypes = VisibleTypes [E.TypeBind]

-- | Retrieve those type bindings that should be visible to the
-- outside world.  Currently that is every type binding declared at
-- the top level of each of the given files.  Type bindings with
-- parameters are not filtered out here.
visibleTypes :: E.Imports -> VisibleTypes
visibleTypes = VisibleTypes . foldMap (modTypes . snd)
  where
    modTypes = progTypes . E.fileProg
    progTypes = foldMap decTypes . E.progDecs
    decTypes (E.TypeDec tb) = [tb]
    decTypes _ = []

-- | Look up the right-hand side of a visible type binding by name.
findType :: VName -> VisibleTypes -> Maybe (E.TypeExp E.Exp VName)
findType v (VisibleTypes ts) = E.typeExp <$> find ((== v) . E.typeAlias) ts

-- | The value type corresponding to an internal type.  Signedness is
-- always reported as 'I.Signed' here.  Accumulators and memory blocks
-- are not expected and cause an internal error.
valueType :: I.TypeBase I.Rank Uniqueness -> I.ValueType
valueType (I.Prim pt) = I.ValueType I.Signed (I.Rank 0) pt
valueType (I.Array pt rank _) = I.ValueType I.Signed rank pt
valueType I.Acc {} = error "valueType Acc"
valueType I.Mem {} = error "valueType Mem"

-- | Peel off all directly nested array layers of a type expression,
-- returning how many were removed together with what remains.
withoutDims :: E.TypeExp E.Exp VName -> (Int, E.TypeExp E.Exp VName)
withoutDims (E.TEArray _ te _) =
  let (d, te') = withoutDims te
   in (d + 1, te')
withoutDims te = (0 :: Int, te)

-- | Remove exactly the given number of directly nested array layers
-- from a type expression.  Returns 'Nothing' if the type expression
-- does not syntactically have that many array layers (e.g. when a
-- dimension is hidden behind a type abbreviation).
peelArrayTypeExp :: Int -> E.TypeExp E.Exp VName -> Maybe (E.TypeExp E.Exp VName)
peelArrayTypeExp 0 te = Just te
peelArrayTypeExp k (E.TEArray _ te _) = peelArrayTypeExp (k - 1) te
peelArrayTypeExp _ _ = Nothing

-- | Strip size applications, uniqueness annotations, 'E.TEDim' binders
-- and parentheses from the outside of a type expression.
rootType :: E.TypeExp E.Exp VName -> E.TypeExp E.Exp VName
rootType (E.TEApply te E.TypeArgExpSize {} _) = rootType te
rootType (E.TEUnique te _) = rootType te
rootType (E.TEDim _ te _) = rootType te
rootType (E.TEParens te _) = rootType te
rootType te = te

-- | Construct the name of an opaque type from the type expression
-- that it was ascribed in the source program.
typeExpOpaqueName :: E.TypeExp E.Exp VName -> Name
typeExpOpaqueName = nameFromText . f
  where
    -- Every recursive call goes through 'f', and 'rootType' removes
    -- any outer parentheses, so 'g' never sees an 'E.TEParens' node.
    f = g . rootType
    g (E.TEArray _ te _) =
      let (d, te') = withoutDims te
       in mconcat (replicate (1 + d) "[]") <> f te'
    g (E.TETuple tes _) =
      "(" <> mconcat (intersperse ", " (map f tes)) <> ")"
    g (E.TERecord tes _) =
      "{" <> mconcat (intersperse ", " (map onField tes)) <> "}"
      where
        onField (L _ k, te) = E.nameToText k <> ":" <> f te
    g (E.TESum cs _) =
      mconcat (intersperse " | " (map onConstr cs))
      where
        onConstr (k, tes) =
          E.nameToText k <> ":" <> mconcat (intersperse " " (map f tes))
    g te = prettyTextOneLine te

-- | A computation that accumulates the opaque types needed to describe
-- entry points.
type GenOpaque = State I.OpaqueTypes

runGenOpaque :: GenOpaque a -> (a, I.OpaqueTypes)
runGenOpaque = flip runState mempty

-- | Register an opaque type under the given name.  Registering a name
-- that already exists with a different definition aborts with an
-- internal error.
addType :: Name -> I.OpaqueType -> GenOpaque ()
addType name t = modify $ \(I.OpaqueTypes ts) ->
  case find ((== name) . fst) ts of
    Just (_, t')
      | t /= t' ->
          error . unlines $
            [ "Duplicate definition of entry point type " <> E.prettyString name,
              show t,
              show t'
            ]
    _ -> I.OpaqueTypes ts <> I.OpaqueTypes [(name, t)]

-- | If the type expression is a record or tuple (possibly through a
-- visible type binding), return its fields.
isRecord :: VisibleTypes -> E.TypeExp E.Exp VName -> Maybe (M.Map Name (E.TypeExp E.Exp VName))
isRecord _ (E.TERecord fs _) = Just $ M.fromList $ map (first unLoc) fs
isRecord _ (E.TETuple fs _) = Just $ E.tupleFields fs
isRecord types (E.TEVar v _) = isRecord types =<< findType (E.qualLeaf v) types
isRecord _ _ = Nothing

-- | Pair each field of a record type with the type expression it was
-- ascribed in the source program, when that is known.  Fields are
-- returned in 'E.sortFields' order.
recordFields ::
  VisibleTypes ->
  M.Map Name E.StructType ->
  Maybe (E.TypeExp E.Exp VName) ->
  [(Name, E.EntryType)]
recordFields types fs t =
  case isRecord types . rootType =<< t of
    Just e_fs ->
      zipWith f (E.sortFields fs) (E.sortFields e_fs)
      where
        f (k, f_t) (_, e_f_t) = (k, E.EntryType f_t $ Just e_f_t)
    Nothing ->
      map (fmap (`E.EntryType` Nothing)) $ E.sortFields fs

-- | Compute the entry point type of each record field, consuming the
-- corresponding internal types from the list in order.
opaqueRecord ::
  VisibleTypes ->
  [(Name, E.EntryType)] ->
  [I.TypeBase I.Rank Uniqueness] ->
  GenOpaque [(Name, I.EntryPointType)]
opaqueRecord _ [] _ = pure []
opaqueRecord types ((f, t) : fs) ts = do
  let (f_ts, ts') = splitAt (internalisedTypeSize $ E.entryType t) ts
  f' <- opaqueField t f_ts
  ((f, f') :) <$> opaqueRecord types fs ts'
  where
    opaqueField e_t i_ts = snd <$> entryPointType types e_t i_ts

-- | Like 'opaqueRecord', but for an array of records of the given
-- rank.  Each field is described as an array of that rank, and any
-- source ascription of the field is dropped.
opaqueRecordArray ::
  VisibleTypes ->
  Int ->
  [(Name, E.EntryType)] ->
  [I.TypeBase I.Rank Uniqueness] ->
  GenOpaque [(Name, I.EntryPointType)]
opaqueRecordArray _ _ [] _ = pure []
opaqueRecordArray types rank ((f, t) : fs) ts = do
  let (f_ts, ts') = splitAt (internalisedTypeSize $ E.entryType t) ts
  f' <- opaqueField t f_ts
  ((f, f') :) <$> opaqueRecordArray types rank fs ts'
  where
    opaqueField (E.EntryType e_t _) i_ts =
      snd <$> entryPointType types (E.EntryType e_t' Nothing) i_ts
      where
        e_t' = E.arrayOf (E.Shape (replicate rank E.anySize)) e_t

-- | If the type expression is a sum type (possibly through a visible
-- type binding), return its constructors.
isSum :: VisibleTypes -> E.TypeExp E.Exp VName -> Maybe (M.Map Name [E.TypeExp E.Exp VName])
isSum _ (E.TESum cs _) = Just $ M.fromList cs
isSum types (E.TEVar v _) = isSum types =<< findType (E.qualLeaf v) types
isSum _ _ = Nothing

-- | Pair each constructor payload of a sum type with the type
-- expressions it was ascribed, when known.  Constructors are returned
-- in 'E.sortConstrs' order.
sumConstrs ::
  VisibleTypes ->
  M.Map Name [E.StructType] ->
  Maybe (E.TypeExp E.Exp VName) ->
  [(Name, [E.EntryType])]
sumConstrs types cs t =
  case isSum types . rootType =<< t of
    Just e_cs ->
      zipWith f (E.sortConstrs cs) (E.sortConstrs e_cs)
      where
        f (k, c_ts) (_, e_c_ts) = (k, zipWith E.EntryType c_ts $ map Just e_c_ts)
    Nothing ->
      map (fmap (map (`E.EntryType` Nothing))) $ E.sortConstrs cs

-- | Compute the entry point types of the payload of each constructor.
-- The index lists select from the given internal types, which exclude
-- the tag.  The returned indexes are shifted by one so that they
-- account for the tag.
opaqueSum ::
  VisibleTypes ->
  [(Name, ([E.EntryType], [Int]))] ->
  [I.TypeBase I.Rank Uniqueness] ->
  GenOpaque [(Name, [(I.EntryPointType, [Int])])]
opaqueSum types cs ts = mapM (traverse f) cs
  where
    f (ets, is) = do
      let ns = map (internalisedTypeSize . E.entryType) ets
          is' = chunks ns is
      ets' <- map snd <$> zipWithM (entryPointType types) ets (map (map (ts !!)) is')
      pure $ zip ets' $ map (map (+ 1)) is' -- Adjust for tag.

-- | The name of an opaque entry point type.  Calling this on a
-- transparent type is an internal error.
entryPointTypeName :: I.EntryPointType -> Name
entryPointTypeName (I.TypeOpaque v) = v
entryPointTypeName (I.TypeTransparent {}) = error "entryPointTypeName: TypeTransparent"

-- | Compute the external type of one entry point parameter or result,
-- given its source type (with ascription, if any) and the internal
-- types it was internalised to.
--
-- * Primitives and arrays of primitives become transparent.
-- * Everything else becomes an opaque type, which is registered with
--   'addType'.
--
-- The returned uniqueness is the maximum uniqueness of the internal
-- types.
entryPointType ::
  VisibleTypes ->
  E.EntryType ->
  [I.TypeBase I.Rank Uniqueness] ->
  GenOpaque (Uniqueness, I.EntryPointType)
entryPointType types t ts
  | E.Scalar (E.Prim E.Unsigned {}) <- E.entryType t,
    [I.Prim ts0] <- ts =
      pure (u, I.TypeTransparent $ I.ValueType I.Unsigned (I.Rank 0) ts0)
  | E.Array _ _ (E.Prim E.Unsigned {}) <- E.entryType t,
    [I.Array ts0 r _] <- ts =
      pure (u, I.TypeTransparent $ I.ValueType I.Unsigned r ts0)
  | E.Scalar E.Prim {} <- E.entryType t,
    [I.Prim ts0] <- ts =
      pure (u, I.TypeTransparent $ I.ValueType I.Signed (I.Rank 0) ts0)
  | E.Array _ _ E.Prim {} <- E.entryType t,
    [I.Array ts0 r _] <- ts =
      pure (u, I.TypeTransparent $ I.ValueType I.Signed r ts0)
  | otherwise = do
      case E.entryType t of
        E.Scalar (E.Record fs)
          | not $ null fs -> do
              let fs' = recordFields types fs $ E.entryAscribed t
              addType desc . I.OpaqueRecord =<< opaqueRecord types fs' ts
        E.Scalar (E.Sum cs) -> do
          let (_, places) = internaliseSumTypeRep cs
              cs' = sumConstrs types cs $ E.entryAscribed t
              cs'' = zip (map fst cs') (zip (map snd cs') (map snd places))
          addType desc . I.OpaqueSum (map valueType ts)
            =<< opaqueSum types cs'' (drop 1 ts)
        E.Array _ shape (E.Record fs)
          | not $ null fs -> do
              let fs' = recordFields types fs $ E.entryAscribed t
                  rank = E.shapeRank shape
                  ts' = map (strip rank) ts
                  record_t = E.Scalar (E.Record fs)
                  -- Remove all 'rank' array layers, not just the
                  -- outermost one.  Otherwise the element of a
                  -- multidimensional array would be named like a
                  -- lower-rank array type and could clash with it.
                  record_te = peelArrayTypeExp rank =<< E.entryAscribed t
              ept <- snd <$> entryPointType types (E.EntryType record_t record_te) ts'
              addType desc . I.OpaqueRecordArray rank (entryPointTypeName ept)
                =<< opaqueRecordArray types rank fs' ts
        E.Array _ shape et -> do
          let rank = E.shapeRank shape
              ts' = map (strip rank) ts
              -- As above: peel exactly 'rank' array layers.
              elem_te = peelArrayTypeExp rank =<< E.entryAscribed t
          ept <- snd <$> entryPointType types (E.EntryType (E.Scalar et) elem_te) ts'
          addType desc . I.OpaqueArray rank (entryPointTypeName ept) $
            map valueType ts
        _ -> addType desc $ I.OpaqueType $ map valueType ts
      pure (u, I.TypeOpaque desc)
  where
    u = foldl max Nonunique $ map I.uniqueness ts
    -- Prefer the name the programmer wrote; otherwise pretty-print
    -- the size-less, non-unique source type.
    desc =
      maybe (nameFromText $ prettyTextOneLine t') typeExpOpaqueName $
        E.entryAscribed t
    t' = E.noSizes (E.entryType t) `E.setUniqueness` Nonunique
    -- Remove the outer 'k' dimensions of an internal array type.
    strip k (I.Array pt (I.Rank r) t_u) = I.arrayOf (I.Prim pt) (I.Rank (r - k)) t_u
    strip _ ts_t = ts_t

-- | Compute the external signature of an entry point, together with
-- the opaque types needed to describe it.  The inputs are the source
-- and internal types of its parameters and of its result.  When the
-- result is a tuple, each component becomes a separate entry point
-- result.
entryPoint ::
  VisibleTypes ->
  Name ->
  [(E.EntryParam, [I.Param I.DeclType])] ->
  ( E.EntryType,
    [[I.TypeBase I.Rank I.Uniqueness]]
  ) ->
  (I.EntryPoint, I.OpaqueTypes)
entryPoint types name params (eret, crets) =
  runGenOpaque $
    (name,,)
      <$> mapM onParam params
      <*> ( map (uncurry I.EntryResult)
              <$> case ( E.isTupleRecord $ E.entryType eret,
                         E.entryAscribed eret
                       ) of
                (Just ts, Just (E.TETuple e_ts _)) ->
                  zipWithM
                    (entryPointType types)
                    (zipWith E.EntryType ts (map Just e_ts))
                    crets
                (Just ts, Nothing) ->
                  zipWithM
                    (entryPointType types)
                    (map (`E.EntryType` Nothing) ts)
                    crets
                _ ->
                  pure <$> entryPointType types eret (concat crets)
          )
  where
    onParam (E.EntryParam e_p e_t, ps) =
      uncurry (I.EntryParam e_p)
        <$> entryPointType types e_t (map (I.rankShaped . I.paramDeclType) ps)
futhark-0.25.27/src/Futhark/Internalise/Exps.hs000066400000000000000000002540141475065116200212730ustar00rootroot00000000000000{-# LANGUAGE Strict #-}
{-# LANGUAGE TypeFamilies #-}

-- | Conversion of a monomorphic, first-order, defunctorised source
-- program to a core Futhark program.
module Futhark.Internalise.Exps (transformProg) where import Control.Monad import Control.Monad.Reader import Data.Bifunctor import Data.Foldable (toList) import Data.List (elemIndex, find, intercalate, intersperse, transpose) import Data.List.NonEmpty (NonEmpty (..)) import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Set qualified as S import Data.Text qualified as T import Futhark.IR.SOACS as I hiding (stmPat) import Futhark.Internalise.AccurateSizes import Futhark.Internalise.Bindings import Futhark.Internalise.Entry import Futhark.Internalise.Lambdas import Futhark.Internalise.Monad as I import Futhark.Internalise.TypesValues import Futhark.Transform.Rename as I import Futhark.Util (lookupWithIndex, splitAt3) import Futhark.Util.Pretty (align, docText, pretty) import Language.Futhark as E hiding (TypeArg) import Language.Futhark.TypeChecker.Types qualified as E -- | Convert a program in source Futhark to a program in the Futhark -- core language. transformProg :: (MonadFreshNames m) => Bool -> VisibleTypes -> [E.ValBind] -> m (I.Prog SOACS) transformProg always_safe types vbinds = do (opaques, consts, funs) <- runInternaliseM always_safe (internaliseValBinds types vbinds) I.renameProg $ I.Prog opaques consts funs internaliseValBinds :: VisibleTypes -> [E.ValBind] -> InternaliseM () internaliseValBinds types = mapM_ $ internaliseValBind types internaliseFunName :: VName -> Name internaliseFunName = nameFromString . 
prettyString shiftRetAls :: Int -> RetAls -> RetAls shiftRetAls d (RetAls pals rals) = RetAls pals $ map (+ d) rals internaliseValBind :: VisibleTypes -> E.ValBind -> InternaliseM () internaliseValBind types fb@(E.ValBind entry fname _ (Info rettype) tparams params body _ attrs loc) = do bindingFParams tparams params $ \shapeparams params' -> do let shapenames = map I.paramName shapeparams all_params = map pure shapeparams ++ concat params' msg = errorMsg ["Function return value does not match shape of declared return type."] (body', rettype') <- buildBody $ do body_res <- internaliseExp (baseString fname <> "_res") body (rettype', retals) <- first zeroExts . unzip . internaliseReturnType (map (fmap paramDeclType) all_params) rettype <$> mapM subExpType body_res when (null params') $ bindExtSizes (E.AppRes (E.toStruct $ E.retType rettype) (E.retDims rettype)) body_res body_res' <- ensureResultExtShape msg loc (map I.fromDecl rettype') $ subExpsRes body_res let num_ctx = length (shapeContext rettype') pure ( body_res', replicate num_ctx (I.Prim int64, mempty) ++ zip rettype' (map (shiftRetAls num_ctx) retals) ) attrs' <- internaliseAttrs attrs let fd = I.FunDef Nothing attrs' (internaliseFunName fname) rettype' (foldMap toList all_params) body' if null params' then bindConstant fname fd else bindFunction fname fd ( shapenames, map declTypeOf $ foldMap (foldMap toList) params', foldMap toList all_params, fmap (`zip` map snd rettype') . 
applyRetType (map fst rettype') (foldMap toList all_params) ) case entry of Just (Info entry') -> generateEntryPoint types entry' fb Nothing -> pure () where zeroExts ts = generaliseExtTypes ts ts generateEntryPoint :: VisibleTypes -> E.EntryPoint -> E.ValBind -> InternaliseM () generateEntryPoint types (E.EntryPoint e_params e_rettype) vb = do let (E.ValBind _ ofname _ (Info rettype) tparams params _ _ attrs loc) = vb bindingFParams tparams params $ \shapeparams params' -> do let all_params = map pure shapeparams ++ concat params' (entry_rettype, retals) = unzip $ map unzip $ internaliseEntryReturnType (map (fmap paramDeclType) all_params) rettype (entry', opaques) = entryPoint types (baseName ofname) (zip e_params $ map (foldMap toList) params') (e_rettype, map (map I.rankShaped) entry_rettype) args = map (I.Var . I.paramName) $ foldMap (foldMap toList) params' addOpaques opaques (entry_body, ctx_ts) <- buildBody $ do -- Special case the (rare) situation where the entry point is -- not a function. maybe_const <- lookupConst ofname vals <- case maybe_const of Just ses -> pure ses Nothing -> funcall "entry_result" (E.qualName ofname) args loc ctx <- extractShapeContext (zeroExts $ concat entry_rettype) <$> mapM (fmap I.arrayDims . 
subExpType) vals pure (subExpsRes $ ctx ++ vals, map (const (I.Prim int64, mempty)) ctx) attrs' <- internaliseAttrs attrs let num_ctx = length ctx_ts addFunDef $ I.FunDef (Just entry') attrs' ("entry_" <> baseName ofname) ( ctx_ts ++ zip (zeroExts (concat entry_rettype)) (map (shiftRetAls num_ctx) $ concat retals) ) (shapeparams ++ foldMap (foldMap toList) params') entry_body where zeroExts ts = generaliseExtTypes ts ts internaliseBody :: String -> E.Exp -> InternaliseM (Body SOACS) internaliseBody desc e = buildBody_ $ subExpsRes <$> internaliseExp (desc <> "_res") e bodyFromStms :: InternaliseM (Result, a) -> InternaliseM (Body SOACS, a) bodyFromStms m = do ((res, a), stms) <- collectStms m (,a) <$> mkBodyM stms res -- | Only returns those pattern names that are not used in the pattern -- itself (the "non-existential" part, you could say). letValExp :: String -> I.Exp SOACS -> InternaliseM [VName] letValExp name e = do e_t <- expExtType e names <- replicateM (length e_t) $ newVName name letBindNames names e let ctx = shapeContext e_t pure $ map fst $ filter ((`S.notMember` ctx) . snd) $ zip names [0 ..] letValExp' :: String -> I.Exp SOACS -> InternaliseM [SubExp] letValExp' _ (BasicOp (SubExp se)) = pure [se] letValExp' name ses = map I.Var <$> letValExp name ses internaliseAppExp :: String -> E.AppRes -> E.AppExp -> InternaliseM [I.SubExp] internaliseAppExp desc _ (E.Index e idxs loc) = do vs <- internaliseExpToVars "indexed" e dims <- case vs of [] -> pure [] -- Will this happen? 
v : _ -> I.arrayDims <$> lookupType v (idxs', cs) <- internaliseSlice loc dims idxs let index v = do v_t <- lookupType v pure $ I.BasicOp $ I.Index v $ fullSlice v_t idxs' certifying cs $ mapM (letSubExp desc <=< index) vs internaliseAppExp desc _ (E.Range start maybe_second end loc) = do start' <- internaliseExp1 "range_start" start end' <- internaliseExp1 "range_end" $ case end of DownToExclusive e -> e ToInclusive e -> e UpToExclusive e -> e maybe_second' <- traverse (internaliseExp1 "range_second") maybe_second -- Construct an error message in case the range is invalid. let conv = case E.typeOf start of E.Scalar (E.Prim (E.Unsigned _)) -> asIntZ Int64 _ -> asIntS Int64 start'_i64 <- conv start' end'_i64 <- conv end' maybe_second'_i64 <- traverse conv maybe_second' let errmsg = errorMsg $ ["Range "] ++ [ErrorVal int64 start'_i64] ++ ( case maybe_second'_i64 of Nothing -> [] Just second_i64 -> ["..", ErrorVal int64 second_i64] ) ++ ( case end of DownToExclusive {} -> ["..>"] ToInclusive {} -> ["..."] UpToExclusive {} -> ["..<"] ) ++ [ErrorVal int64 end'_i64, " is invalid."] (it, lt_op) <- case E.typeOf start of E.Scalar (E.Prim (E.Signed it)) -> pure (it, CmpSlt it) E.Scalar (E.Prim (E.Unsigned it)) -> pure (it, CmpUlt it) start_t -> error $ "Start value in range has type " ++ prettyString start_t let one = intConst it 1 negone = intConst it (-1) default_step = case end of DownToExclusive {} -> negone ToInclusive {} -> one UpToExclusive {} -> one (step, step_zero) <- case maybe_second' of Just second' -> do subtracted_step <- letSubExp "subtracted_step" $ I.BasicOp $ I.BinOp (I.Sub it I.OverflowWrap) second' start' step_zero <- letSubExp "step_zero" $ I.BasicOp $ I.CmpOp (I.CmpEq $ IntType it) start' second' pure (subtracted_step, step_zero) Nothing -> pure (default_step, constant False) step_sign <- letSubExp "s_sign" $ BasicOp $ I.UnOp (I.SSignum it) step step_sign_i64 <- asIntS Int64 step_sign bounds_invalid_downwards <- letSubExp "bounds_invalid_downwards" $ 
I.BasicOp $ I.CmpOp lt_op start' end' bounds_invalid_upwards <- letSubExp "bounds_invalid_upwards" $ I.BasicOp $ I.CmpOp lt_op end' start' (distance, step_wrong_dir, bounds_invalid) <- case end of DownToExclusive {} -> do step_wrong_dir <- letSubExp "step_wrong_dir" $ I.BasicOp $ I.CmpOp (I.CmpEq $ IntType it) step_sign one distance <- letSubExp "distance" $ I.BasicOp $ I.BinOp (Sub it I.OverflowWrap) start' end' distance_i64 <- asIntS Int64 distance pure (distance_i64, step_wrong_dir, bounds_invalid_downwards) UpToExclusive {} -> do step_wrong_dir <- letSubExp "step_wrong_dir" $ I.BasicOp $ I.CmpOp (I.CmpEq $ IntType it) step_sign negone distance <- letSubExp "distance" $ I.BasicOp $ I.BinOp (Sub it I.OverflowWrap) end' start' distance_i64 <- asIntS Int64 distance pure (distance_i64, step_wrong_dir, bounds_invalid_upwards) ToInclusive {} -> do downwards <- letSubExp "downwards" $ I.BasicOp $ I.CmpOp (I.CmpEq $ IntType it) step_sign negone distance_downwards_exclusive <- letSubExp "distance_downwards_exclusive" $ I.BasicOp $ I.BinOp (Sub it I.OverflowWrap) start' end' distance_upwards_exclusive <- letSubExp "distance_upwards_exclusive" $ I.BasicOp $ I.BinOp (Sub it I.OverflowWrap) end' start' bounds_invalid <- letSubExp "bounds_invalid" =<< eIf (eSubExp downwards) (resultBodyM [bounds_invalid_downwards]) (resultBodyM [bounds_invalid_upwards]) distance_exclusive <- letSubExp "distance_exclusive" =<< eIf (eSubExp downwards) (resultBodyM [distance_downwards_exclusive]) (resultBodyM [distance_upwards_exclusive]) distance_exclusive_i64 <- asIntS Int64 distance_exclusive distance <- letSubExp "distance" $ I.BasicOp $ I.BinOp (Add Int64 I.OverflowWrap) distance_exclusive_i64 (intConst Int64 1) pure (distance, constant False, bounds_invalid) step_invalid <- letSubExp "step_invalid" $ I.BasicOp $ I.BinOp I.LogOr step_wrong_dir step_zero invalid <- letSubExp "range_invalid" $ I.BasicOp $ I.BinOp I.LogOr step_invalid bounds_invalid valid <- letSubExp "valid" $ I.BasicOp $ 
I.UnOp (I.Neg I.Bool) invalid cs <- assert "range_valid_c" valid errmsg loc step_i64 <- asIntS Int64 step pos_step <- letSubExp "pos_step" $ I.BasicOp $ I.BinOp (Mul Int64 I.OverflowWrap) step_i64 step_sign_i64 num_elems <- certifying cs $ letSubExp "num_elems" $ I.BasicOp $ I.BinOp (SDivUp Int64 I.Unsafe) distance pos_step se <- letSubExp desc (I.BasicOp $ I.Iota num_elems start' step it) pure [se] internaliseAppExp desc (E.AppRes et ext) e@E.Apply {} = case findFuncall e of (FunctionHole loc, _args) -> do -- The function we are supposed to call doesn't exist, but we -- have to synthesize some fake values of the right type. The -- easy way to do this is to just ignore the arguments and -- create a hole whose type is the type of the entire -- application. One caveat is that we need to replace any -- existential sizes, too (with zeroes, because they don't -- matter). let subst = map (,E.ExpSubst (E.sizeFromInteger 0 mempty)) ext et' = E.applySubst (`lookup` subst) et internaliseExp desc (E.Hole (Info et') loc) (FunctionName qfname, args) -> do -- Argument evaluation is outermost-in so that any existential sizes -- created by function applications can be brought into scope. let fname = nameFromString $ prettyString $ baseName $ qualLeaf qfname loc = srclocOf e arg_desc = nameToString fname ++ "_arg" -- Some functions are magical (overloaded) and we handle that here. case () of () -- Short-circuiting operators are magical. 
| baseTag (qualLeaf qfname) <= maxIntrinsicTag, baseString (qualLeaf qfname) == "&&", [(x, _), (y, _)] <- args -> internaliseExp desc $ E.AppExp (E.If x y (E.Literal (E.BoolValue False) mempty) mempty) (Info $ AppRes (E.Scalar $ E.Prim E.Bool) []) | baseTag (qualLeaf qfname) <= maxIntrinsicTag, baseString (qualLeaf qfname) == "||", [(x, _), (y, _)] <- args -> internaliseExp desc $ E.AppExp (E.If x (E.Literal (E.BoolValue True) mempty) y mempty) (Info $ AppRes (E.Scalar $ E.Prim E.Bool) []) -- Overloaded and intrinsic functions never take array -- arguments (except equality, but those cannot be -- existential), so we can safely ignore the existential -- dimensions. | Just internalise <- isOverloadedFunction qfname desc loc -> do let prepareArg (arg, _) = (E.toStruct (E.typeOf arg),) <$> internaliseExp "arg" arg internalise =<< mapM prepareArg args | Just internalise <- isIntrinsicFunction qfname (map fst args) loc -> internalise desc | baseTag (qualLeaf qfname) <= maxIntrinsicTag, Just (rettype, _) <- M.lookup fname I.builtInFunctions -> do let tag ses = [(se, I.Observe) | se <- ses] args' <- reverse <$> mapM (internaliseArg arg_desc) (reverse args) let args'' = concatMap tag args' letValExp' desc $ I.Apply fname args'' [(I.Prim rettype, mempty)] (Safe, loc, []) | otherwise -> do args' <- concat . 
reverse <$> mapM (internaliseArg arg_desc) (reverse args) funcall desc qfname args' loc internaliseAppExp desc _ (E.LetPat sizes pat e body _) = internalisePat desc sizes pat e $ internaliseExp desc body internaliseAppExp _ _ (E.LetFun ofname _ _ _) = error $ "Unexpected LetFun " ++ prettyString ofname internaliseAppExp desc _ (E.Loop sparams mergepat loopinit form loopbody loc) = do ses <- internaliseExp "loop_init" $ loopInitExp loopinit ((loopbody', (form', shapepat, mergepat', mergeinit')), initstms) <- collectStms $ handleForm ses form addStms initstms mergeinit_ts' <- mapM subExpType mergeinit' ctxinit <- argShapes (map I.paramName shapepat) mergepat' mergeinit_ts' -- Ensure that the initial loop values match the shapes of the loop -- parameters. XXX: Ideally they should already match (by the -- source language type rules), but some of our transformations -- (esp. defunctionalisation) strips out some size information. For -- a type-correct source program, these reshapes should simplify -- away. let args = ctxinit ++ mergeinit' args' <- ensureArgShapes "initial loop values have right shape" loc (map I.paramName shapepat) (map paramType $ shapepat ++ mergepat') args let dropCond = case form of E.While {} -> drop 1 _ -> id -- As above, ensure that the result has the right shape. let merge = zip (shapepat ++ mergepat') args' merge_ts = map (I.paramType . fst) merge loopbody'' <- localScope (scopeOfFParams (map fst merge) <> scopeOfLoopForm form') . buildBody_ $ fmap subExpsRes . ensureArgShapes "shape of loop result does not match shapes in loop parameter" loc (map (I.paramName . fst) merge) merge_ts . map resSubExp =<< bodyBind loopbody' attrs <- asks envAttrs map I.Var . dropCond <$> attributing attrs (letValExp desc (I.Loop merge form' loopbody'')) where sparams' = map (`TypeParamDim` mempty) sparams -- Attributes that apply to loops. loopAttrs = oneAttr "unroll" -- Remove those attributes from the attribute set that apply to -- the loop itself. 
noLoopAttrs env = env {envAttrs = envAttrs env `withoutAttrs` loopAttrs} loopBody = local noLoopAttrs $ internaliseExp "loopres" loopbody forLoop mergepat' shapepat mergeinit i loopvars form' = bodyFromStms . localScope (scopeOfLoopForm form') $ do forM_ loopvars $ \(p, arr) -> letBindNames [I.paramName p] =<< eIndex arr [eSubExp (I.Var i)] ses <- loopBody sets <- mapM subExpType ses shapeargs <- argShapes (map I.paramName shapepat) mergepat' sets pure ( subExpsRes $ shapeargs ++ ses, ( form', shapepat, mergepat', mergeinit ) ) handleForm mergeinit (E.ForIn x arr) = do arr' <- internaliseExpToVars "for_in_arr" arr arr_ts <- mapM lookupType arr' let w = arraysSize 0 arr_ts i <- newVName "i" ts <- mapM subExpType mergeinit bindingLoopParams sparams' mergepat ts $ \shapepat mergepat' -> bindingLambdaParams [toParam E.Observe <$> x] (map rowType arr_ts) $ \x_params -> do let loopvars = zip x_params arr' forLoop mergepat' shapepat mergeinit i loopvars $ I.ForLoop i Int64 w handleForm mergeinit (E.For i num_iterations) = do num_iterations' <- internaliseExp1 "upper_bound" num_iterations num_iterations_t <- I.subExpType num_iterations' it <- case num_iterations_t of I.Prim (IntType it) -> pure it _ -> error "internaliseExp Loop: invalid type" ts <- mapM subExpType mergeinit bindingLoopParams sparams' mergepat ts $ \shapepat mergepat' -> forLoop mergepat' shapepat mergeinit (E.identName i) [] $ I.ForLoop (E.identName i) it num_iterations' handleForm mergeinit (E.While cond) = do ts <- mapM subExpType mergeinit bindingLoopParams sparams' mergepat ts $ \shapepat mergepat' -> do mergeinit_ts <- mapM subExpType mergeinit -- We need to insert 'cond' twice - once for the initial -- condition (do we enter the loop at all?), and once with the -- result values of the loop (do we continue into the next -- iteration?). This is safe, as the type rules for the -- external language guarantees that 'cond' does not consume -- anything. 
shapeinit <- argShapes (map I.paramName shapepat) mergepat' mergeinit_ts (loop_initial_cond, init_loop_cond_stms) <- collectStms $ do forM_ (zip shapepat shapeinit) $ \(p, se) -> letBindNames [I.paramName p] $ BasicOp $ SubExp se forM_ (zip mergepat' mergeinit) $ \(p, se) -> unless (se == I.Var (I.paramName p)) $ letBindNames [I.paramName p] $ BasicOp $ case se of I.Var v | not $ primType $ paramType p -> Reshape I.ReshapeCoerce (I.arrayShape $ paramType p) v _ -> SubExp se -- As the condition expression is inserted twice, we have to -- avoid shadowing (#1935). (cond_stms, cond') <- uncurry (flip renameStmsWith) =<< collectStms (internaliseExp1 "loop_cond" cond) addStms cond_stms pure cond' addStms init_loop_cond_stms bodyFromStms $ do ses <- loopBody sets <- mapM subExpType ses loop_while <- newParam "loop_while" $ I.Prim I.Bool shapeargs <- argShapes (map I.paramName shapepat) mergepat' sets -- Careful not to clobber anything. loop_end_cond_body <- renameBody <=< buildBody_ $ do forM_ (zip shapepat shapeargs) $ \(p, se) -> unless (se == I.Var (I.paramName p)) $ letBindNames [I.paramName p] $ BasicOp $ SubExp se forM_ (zip mergepat' ses) $ \(p, se) -> unless (se == I.Var (I.paramName p)) $ letBindNames [I.paramName p] $ BasicOp $ case se of I.Var v | not $ primType $ paramType p -> Reshape I.ReshapeCoerce (I.arrayShape $ paramType p) v _ -> SubExp se subExpsRes <$> internaliseExp "loop_cond" cond loop_end_cond <- bodyBind loop_end_cond_body pure ( subExpsRes shapeargs ++ loop_end_cond ++ subExpsRes ses, ( I.WhileLoop $ I.paramName loop_while, shapepat, loop_while : mergepat', loop_initial_cond : mergeinit ) ) internaliseAppExp desc _ (E.LetWith name src idxs ve body loc) = do let pat = E.Id (E.identName name) (E.identType name) loc src_t = E.identType src e = E.Update (E.Var (E.qualName $ E.identName src) src_t loc) idxs ve loc internaliseExp desc $ E.AppExp (E.LetPat [] pat e body loc) (Info (AppRes (E.typeOf body) mempty)) internaliseAppExp desc _ (E.Match e 
orig_cs _) = do ses <- internaliseExp (desc ++ "_scrutinee") e cs <- mapM (onCase ses) orig_cs case NE.uncons cs of (I.Case _ body, Nothing) -> fmap (map resSubExp) $ bodyBind =<< body _ -> do letValExp' desc =<< eMatch ses (NE.init cs) (I.caseBody $ NE.last cs) where onCase ses (E.CasePat p case_e _) = do (cmps, pertinent) <- generateCond p ses pure . I.Case cmps $ internalisePat' [] p pertinent $ internaliseBody "case" case_e internaliseAppExp desc _ (E.If ce te fe _) = letValExp' desc =<< eIf (BasicOp . SubExp <$> internaliseExp1 "cond" ce) (internaliseBody (desc <> "_t") te) (internaliseBody (desc <> "_f") fe) internaliseAppExp _ _ e@E.BinOp {} = error $ "internaliseAppExp: Unexpected BinOp " ++ prettyString e internaliseExp :: String -> E.Exp -> InternaliseM [I.SubExp] internaliseExp desc (E.Parens e _) = internaliseExp desc e internaliseExp desc (E.Hole (Info t) loc) = do let msg = docText $ "Reached hole of type: " <> align (pretty t) ts = foldMap toList $ internaliseType (E.toStruct t) c <- assert "hole_c" (constant False) (errorMsg [ErrorString msg]) loc case mapM hasStaticShape ts of Nothing -> error $ "Hole at " <> locStr loc <> " has existential type:\n" <> show ts Just ts' -> -- Make sure we always generate a binding, even for primitives. certifying c $ mapM (fmap I.Var . letExp desc <=< eBlank . I.fromDecl) ts' internaliseExp desc (E.QualParens _ e _) = internaliseExp desc e internaliseExp desc (E.StringLit vs _) = fmap pure . 
letSubExp desc $ I.BasicOp $ I.ArrayLit (map constant vs) $ I.Prim int8 internaliseExp _ (E.Var (E.QualName _ name) _ _) = do subst <- lookupSubst name case subst of Just substs -> pure substs Nothing -> pure [I.Var name] internaliseExp desc (E.AppExp e (Info appres)) = do ses <- internaliseAppExp desc appres e bindExtSizes appres ses pure ses internaliseExp _ (E.TupLit [] _) = pure [constant UnitValue] internaliseExp _ (E.RecordLit [] _) = pure [constant UnitValue] internaliseExp desc (E.TupLit es _) = concat <$> mapM (internaliseExp desc) es internaliseExp desc (E.RecordLit orig_fields _) = concatMap snd . sortFields . M.unions <$> mapM internaliseField orig_fields where internaliseField (E.RecordFieldExplicit (L _ name) e _) = M.singleton name <$> internaliseExp desc e internaliseField (E.RecordFieldImplicit (L _ name) t loc) = internaliseField $ E.RecordFieldExplicit (L noLoc (baseName name)) (E.Var (E.qualName name) t loc) loc internaliseExp desc (E.ArrayVal vs t _) = fmap pure . letSubExp desc . I.BasicOp $ I.ArrayVal (map internalisePrimValue vs) (internalisePrimType t) internaliseExp desc (E.ArrayLit es (Info arr_t) loc) -- If this is a multidimensional array literal of primitives, we -- treat it specially by flattening it out followed by a reshape. -- This cuts down on the amount of statements that are produced, and -- thus allows us to efficiently handle huge array literals - a -- corner case, but an important one. | Just ((eshape, e') : es') <- mapM isArrayLiteral es, not $ null eshape, all ((eshape ==) . fst) es', Just basetype <- E.peelArray (length eshape) arr_t = do let flat_lit = E.ArrayLit (e' ++ concatMap snd es') (Info basetype) loc new_shape = length es : eshape flat_arrs <- internaliseExpToVars "flat_literal" flat_lit forM flat_arrs $ \flat_arr -> do flat_arr_t <- lookupType flat_arr let new_shape' = reshapeOuter (I.Shape $ map (intConst Int64 . 
toInteger) new_shape) 1 $ I.arrayShape flat_arr_t letSubExp desc $ I.BasicOp $ I.Reshape I.ReshapeArbitrary new_shape' flat_arr | otherwise = do es' <- mapM (internaliseExp "arr_elem") es let arr_t_ext = foldMap toList $ internaliseType $ E.toStruct arr_t rowtypes <- case mapM (fmap rowType . hasStaticShape . I.fromDecl) arr_t_ext of Just ts -> pure ts Nothing -> -- XXX: the monomorphiser may create single-element array -- literals with an unknown row type. In those cases we -- need to look at the types of the actual elements. -- Fixing this in the monomorphiser is a lot more tricky -- than just working around it here. case es' of [] -> error $ "internaliseExp ArrayLit: existential type: " ++ prettyString arr_t e' : _ -> mapM subExpType e' let arraylit ks rt = do ks' <- mapM ( ensureShape "shape of element differs from shape of first element" loc rt "elem_reshaped" ) ks pure $ I.BasicOp $ I.ArrayLit ks' rt mapM (letSubExp desc) =<< if null es' then mapM (arraylit []) rowtypes else zipWithM arraylit (transpose es') rowtypes where isArrayLiteral :: E.Exp -> Maybe ([Int], [E.Exp]) isArrayLiteral (E.ArrayLit inner_es _ _) = do (eshape, e) : inner_es' <- mapM isArrayLiteral inner_es guard $ all ((eshape ==) . 
fst) inner_es' pure (length inner_es : eshape, e ++ concatMap snd inner_es') isArrayLiteral e = Just ([], [e]) internaliseExp desc (E.Ascript e _ _) = internaliseExp desc e internaliseExp desc (E.Coerce e _ (Info et) loc) = do ses <- internaliseExp desc e ts <- internaliseCoerceType (E.toStruct et) <$> mapM subExpType ses dt' <- typeExpForError $ toStruct et forM (zip ses ts) $ \(e', t') -> do dims <- arrayDims <$> subExpType e' let parts = ["Value of (desugared) shape ["] ++ intersperse "][" (map (ErrorVal int64) dims) ++ ["] cannot match shape of type `"] ++ dt' ++ ["`."] ensureExtShape (errorMsg parts) loc (I.fromDecl t') desc e' internaliseExp desc (E.Negate e _) = do e' <- internaliseExp1 "negate_arg" e et <- subExpType e' case et of I.Prim pt -> letTupExp' desc $ I.BasicOp $ I.UnOp (I.Neg pt) e' _ -> error "Futhark.Internalise.internaliseExp: non-primitive type in Negate" internaliseExp desc (E.Not e _) = do e' <- internaliseExp1 "not_arg" e et <- subExpType e' case et of I.Prim (I.IntType t) -> letTupExp' desc $ I.BasicOp $ I.UnOp (I.Complement t) e' I.Prim I.Bool -> letTupExp' desc $ I.BasicOp $ I.UnOp (I.Neg I.Bool) e' _ -> error "Futhark.Internalise.internaliseExp: non-int/bool type in Not" internaliseExp desc (E.Update src slice ve loc) = do ves <- internaliseExp "lw_val" ve srcs <- internaliseExpToVars "src" src (src_dims, ve_dims) <- case (srcs, ves) of (src_v : _, ve_v : _) -> (,) <$> (I.arrayDims <$> lookupType src_v) <*> (I.arrayDims <$> subExpType ve_v) _ -> pure ([], []) -- Will this happen? (idxs', cs) <- internaliseSlice loc src_dims slice let src_dims' = sliceDims (Slice idxs') rank = length src_dims' errormsg = "Shape " <> errorShape src_dims' <> " of slice does not match shape " <> errorShape (take rank ve_dims) <> " of value." 
let comb sname ve' = do sname_t <- lookupType sname let full_slice = fullSlice sname_t idxs' rowtype = sname_t `setArrayDims` sliceDims full_slice ve'' <- ensureShape errormsg loc rowtype "lw_val_correct_shape" ve' letInPlace desc sname full_slice $ BasicOp $ SubExp ve'' certifying cs $ map I.Var <$> zipWithM comb srcs ves internaliseExp desc (E.RecordUpdate src fields ve _ _) = do src' <- internaliseExp desc src ve' <- internaliseExp desc ve replace (E.typeOf src) fields ve' src' where replace (E.Scalar (E.Record m)) (f : fs) ve' src' | Just t <- M.lookup f m = do let i = sum . map (internalisedTypeSize . snd) $ takeWhile ((/= f) . fst) . sortFields $ m k = internalisedTypeSize t (bef, to_update, aft) = splitAt3 i k src' src'' <- replace t fs ve' to_update pure $ bef ++ src'' ++ aft replace _ _ ve' _ = pure ve' internaliseExp desc (E.Attr attr e loc) = do attr' <- internaliseAttr attr e' <- local (f attr') $ internaliseExp desc e case attr' of "trace" -> traceRes (T.pack $ locStr loc) e' I.AttrComp "trace" [I.AttrName tag] -> traceRes (nameToText tag) e' "opaque" -> mapM (letSubExp desc . BasicOp . Opaque OpaqueNil) e' _ -> pure e' where traceRes tag' = mapM (letSubExp desc . BasicOp . Opaque (OpaqueTrace tag')) f attr' env | attr' == "unsafe", not $ envSafe env = env {envDoBoundsChecks = False} | otherwise = env {envAttrs = envAttrs env <> oneAttr attr'} internaliseExp desc (E.Assert e1 e2 (Info check) loc) = do e1' <- internaliseExp1 "assert_cond" e1 c <- assert "assert_c" e1' (errorMsg [ErrorString $ "Assertion is false: " <> check]) loc -- Make sure there are some bindings to certify. 
certifying c $ mapM rebind =<< internaliseExp desc e2 where rebind v = do v' <- newVName "assert_res" letBindNames [v'] $ I.BasicOp $ I.SubExp v pure $ I.Var v' internaliseExp _ (E.Constr c es (Info (E.Scalar (E.Sum fs))) _) = do (ts, constr_map) <- internaliseSumType $ M.map (map E.toStruct) fs es' <- concat <$> mapM (internaliseExp "payload") es let noExt _ = pure $ intConst Int64 0 ts' <- instantiateShapes noExt $ map fromDecl ts case lookupWithIndex c constr_map of Just (i, js) -> (intConst Int8 (toInteger i) :) <$> clauses 0 ts' (zip js es') Nothing -> error "internaliseExp Constr: missing constructor" where clauses j (t : ts) js_to_es | Just e <- j `lookup` js_to_es = (e :) <$> clauses (j + 1) ts js_to_es | otherwise = do blank <- -- Cannot use eBlank here for arrays, because when doing -- equality comparisons on sum types, we end up looking at -- the array elements. (#2081) This is a bit of an edge -- case, but arrays in sum types are known to be -- inefficient. letSubExp "zero" =<< case t of I.Array {} -> pure $ BasicOp $ Replicate (I.arrayShape t) $ I.Constant $ blankPrimValue $ elemType t _ -> eBlank t (blank :) <$> clauses (j + 1) ts js_to_es clauses _ [] _ = pure [] internaliseExp _ (E.Constr _ _ (Info t) loc) = error $ "internaliseExp: constructor with type " ++ prettyString t ++ " at " ++ locStr loc -- The "interesting" cases are over, now it's mostly boilerplate. 
-- Primitive literals become a single constant subexpression.
internaliseExp _ (E.Literal v _) =
  pure [I.Constant $ internalisePrimValue v]
-- An integer literal may have been given a signed, unsigned or
-- floating-point type by the type checker; any other type is an
-- internal error.
internaliseExp _ (E.IntLit v (Info t) _) =
  case t of
    E.Scalar (E.Prim (E.Signed it)) ->
      pure [I.Constant $ I.IntValue $ intValue it v]
    E.Scalar (E.Prim (E.Unsigned it)) ->
      pure [I.Constant $ I.IntValue $ intValue it v]
    E.Scalar (E.Prim (E.FloatType ft)) ->
      pure [I.Constant $ I.FloatValue $ floatValue ft v]
    _ -> error $ "internaliseExp: nonsensical type for integer literal: " ++ prettyString t
-- A float literal must have a floating-point type.
internaliseExp _ (E.FloatLit v (Info t) _) =
  case t of
    E.Scalar (E.Prim (E.FloatType ft)) ->
      pure [I.Constant $ I.FloatValue $ floatValue ft v]
    _ -> error $ "internaliseExp: nonsensical type for float literal: " ++ prettyString t
-- Record field projection.  A record internalises to the concatenation
-- of its fields' internalised values (in sorted-field order), so the
-- field starts at the sum of the internalised sizes of all fields
-- sorted before 'k', and spans 'internalisedTypeSize rt' values.
internaliseExp desc (E.Project k e (Info rt) _) = do
  let i' = sum . map internalisedTypeSize $
        case E.typeOf e of
          E.Scalar (Record fs) ->
            map snd $ takeWhile ((/= k) . fst) $ sortFields fs
          t -> [t]
  take (internalisedTypeSize rt) . drop i' <$> internaliseExp desc e
-- The following constructs are reported as unexpected if they reach
-- this function (presumably they are removed by an earlier pass --
-- not visible here).
internaliseExp _ e@E.Lambda {} =
  error $ "internaliseExp: Unexpected lambda at " ++ locStr (srclocOf e)
internaliseExp _ e@E.OpSection {} =
  error $ "internaliseExp: Unexpected operator section at " ++ locStr (srclocOf e)
internaliseExp _ e@E.OpSectionLeft {} =
  error $ "internaliseExp: Unexpected left operator section at " ++ locStr (srclocOf e)
internaliseExp _ e@E.OpSectionRight {} =
  error $ "internaliseExp: Unexpected right operator section at " ++ locStr (srclocOf e)
internaliseExp _ e@E.ProjectSection {} =
  error $ "internaliseExp: Unexpected projection section at " ++ locStr (srclocOf e)
internaliseExp _ e@E.IndexSection {} =
  error $ "internaliseExp: Unexpected index section at " ++ locStr (srclocOf e)

-- | Internalise a function argument.  If the argument is annotated
-- with a size name that is already in scope, that variable is returned
-- directly and the argument expression is not internalised at all.
-- Otherwise the expression is internalised, and if it produced exactly
-- one subexpression and a size name was given, that name is bound to
-- it.
internaliseArg :: String -> (E.Exp, Maybe VName) -> InternaliseM [SubExp]
internaliseArg desc (arg, argdim) = do
  exists <- askScope
  case argdim of
    Just d | d `M.member` exists -> pure [I.Var d]
    _ -> do
      arg' <- internaliseExp desc arg
      case (arg', argdim) of
        ([se], Just d) -> do
          letBindNames [d] $ BasicOp $ SubExp se
        _ -> pure ()
      pure arg'

-- | Convert a pattern literal to a core primitive value.  Integer and
-- float pattern literals take their concrete type from the scrutinee
-- type; any mismatched combination is an internal error.
internalisePatLit :: E.PatLit -> E.StructType -> I.PrimValue
internalisePatLit (E.PatLitPrim v) _ =
  internalisePrimValue v
internalisePatLit (E.PatLitInt x) (E.Scalar (E.Prim (E.Signed it))) =
  I.IntValue $ intValue it x
internalisePatLit (E.PatLitInt x) (E.Scalar (E.Prim (E.Unsigned it))) =
  I.IntValue $ intValue it x
internalisePatLit (E.PatLitFloat x) (E.Scalar (E.Prim (E.FloatType ft))) =
  I.FloatValue $ floatValue ft x
internalisePatLit l t =
  error $ "Nonsensical pattern and type: " ++ show (l, t)

-- | Compute what a pattern requires of the internalised values it is
-- matched against.  Returns one entry per consumed value: 'Just v' if
-- the value must equal 'v' for the pattern to match, or 'Nothing' if
-- it is unconstrained.  Also returns the consumed subexpressions
-- ("pertinent" values).  Sum-type constructors are compared on their
-- 'Int8' tag plus whatever their payload patterns require.
generateCond ::
  E.Pat StructType ->
  [I.SubExp] ->
  InternaliseM ([Maybe I.PrimValue], [I.SubExp])
generateCond orig_p orig_ses = do
  (cmps, pertinent, _) <- compares orig_p orig_ses
  pure (cmps, pertinent)
  where
    -- Each case returns (comparisons, consumed values, remaining values).
    compares (E.PatLit l (Info t) _) (se : ses) =
      pure ([Just $ internalisePatLit l t], [se], ses)
    compares (E.PatConstr c (Info (E.Scalar (E.Sum fs))) pats _) (_ : ses) = do
      (payload_ts, m) <- internaliseSumType $ M.map (map toStruct) fs
      case lookupWithIndex c m of
        Just (tag, payload_is) -> do
          -- All constructors share the payload slots; pick out the ones
          -- belonging to this constructor.
          let (payload_ses, ses') = splitAt (length payload_ts) ses
          (cmps, pertinent, _) <-
            comparesMany pats $ map (payload_ses !!) payload_is
          -- Slots not used by this constructor are unconstrained.
          let missingCmps i _ =
                case i `elemIndex` payload_is of
                  Just j -> cmps !! j
                  Nothing -> Nothing
          pure
            ( Just (I.IntValue $ intValue Int8 $ toInteger tag)
                : zipWith missingCmps [0 ..] payload_ses,
              pertinent,
              ses'
            )
        Nothing ->
          error "generateCond: missing constructor"
    compares (E.PatConstr _ (Info t) _ _) _ =
      error $ "generateCond: PatConstr has nonsensical type: " ++ prettyString t
    compares (E.Id _ t loc) ses =
      compares (E.Wildcard t loc) ses
    -- A wildcard consumes as many values as its type internalises to,
    -- imposing no constraint on any of them.
    compares (E.Wildcard (Info t) _) ses = do
      let (id_ses, rest_ses) = splitAt (internalisedTypeSize $ E.toStruct t) ses
      pure (map (const Nothing) id_ses, id_ses, rest_ses)
    compares (E.PatParens pat _) ses =
      compares pat ses
    compares (E.PatAttr _ pat _) ses =
      compares pat ses
    -- Empty tuple/record patterns are treated as wildcards of the empty
    -- record type.
    compares (E.TuplePat [] loc) ses =
      compares (E.Wildcard (Info $ E.Scalar $ E.Record mempty) loc) ses
    compares (E.RecordPat [] loc) ses =
      compares (E.Wildcard (Info $ E.Scalar $ E.Record mempty) loc) ses
    compares (E.TuplePat pats _) ses =
      comparesMany pats ses
    -- Record fields are matched in sorted-field order, which is the
    -- order in which record values are internalised.
    compares (E.RecordPat fs _) ses =
      comparesMany (map snd $ E.sortFields $ M.fromList $ map (first unLoc) fs) ses
    compares (E.PatAscription pat _ _) ses =
      compares pat ses
    compares pat [] =
      error $ "generateCond: No values left for pattern " ++ prettyString pat

    -- Match a sequence of patterns left to right, threading the
    -- remaining values through.
    comparesMany [] ses = pure ([], [], ses)
    comparesMany (pat : pats) ses = do
      (cmps1, pertinent1, ses') <- compares pat ses
      (cmps2, pertinent2, ses'') <- comparesMany pats ses'
      pure
        ( cmps1 <> cmps2,
          pertinent1 <> pertinent2,
          ses''
        )

-- | Internalise 'e' and bind its values to the pattern 'p' (and the
-- given size binders) in the scope of 'm'.  If the pattern binds
-- exactly one identifier, its name is used as the description for the
-- internalised expression instead of 'desc'.
internalisePat ::
  String ->
  [E.SizeBinder VName] ->
  E.Pat StructType ->
  E.Exp ->
  InternaliseM a ->
  InternaliseM a
internalisePat desc sizes p e m = do
  ses <- internaliseExp desc' e
  internalisePat' sizes p ses m
  where
    desc' = case E.patIdents p of
      [v] -> baseString $ E.identName v
      _ -> desc
-- | Bind already-internalised values 'ses' to the pattern 'p' and the
-- size binders 'sizes', then run 'm' in that scope.
internalisePat' ::
  [E.SizeBinder VName] ->
  E.Pat StructType ->
  [I.SubExp] ->
  InternaliseM a ->
  InternaliseM a
internalisePat' sizes p ses m = do
  ses_ts <- mapM subExpType ses
  stmPat (toParam E.Observe <$> p) ses_ts $ \pat_names -> do
    bindExtSizes (AppRes (E.patternType p) (map E.sizeName sizes)) ses
    forM_ (zip pat_names ses) $ \(v, se) ->
      letBindNames [v] $ I.BasicOp $ I.SubExp se
    m

-- | Internalise a slice of an array whose dimensions are 'dims'.
-- Every index is bounds-checked against the corresponding dimension.
-- The individual checks are combined into a single assertion, and its
-- certificates are returned alongside the core slice.
internaliseSlice ::
  SrcLoc ->
  [SubExp] ->
  [E.DimIndex] ->
  InternaliseM ([I.DimIndex SubExp], Certs)
internaliseSlice loc dims idxs = do
  (idxs', oks, parts) <- unzip3 <$> zipWithM internaliseDimIndex dims idxs
  ok <- letSubExp "index_ok" =<< eAll oks
  let msg =
        errorMsg $
          ["Index ["]
            ++ intercalate [", "] parts
            ++ ["] out of bounds for array of shape ["]
            ++ intersperse "][" (map (ErrorVal int64) $ take (length idxs) dims)
            ++ ["]."]
  c <- assert "index_certs" ok msg loc
  pure (idxs', c)

-- | Internalise one dimension index against a dimension of size 'w'.
-- Returns the core index, a boolean that is true iff the index is in
-- bounds, and the parts used to render this index in error messages.
internaliseDimIndex ::
  SubExp ->
  E.DimIndex ->
  InternaliseM (I.DimIndex SubExp, SubExp, [ErrorMsgPart SubExp])
internaliseDimIndex w (E.DimFix i) = do
  (i', _) <- internaliseSizeExp "i" i
  -- A fixed index must satisfy 0 <= i < w.
  let lowerBound =
        I.BasicOp $ I.CmpOp (I.CmpSle I.Int64) (I.constant (0 :: I.Int64)) i'
      upperBound =
        I.BasicOp $ I.CmpOp (I.CmpSlt I.Int64) i' w
  ok <- letSubExp "bounds_check" =<< eBinOp I.LogAnd (pure lowerBound) (pure upperBound)
  pure (I.DimFix i', ok, [ErrorVal int64 i'])
-- Special-case an important common case that otherwise leads to
-- horrible code: the full reversal 'a[::-1]', which is always in bounds.
internaliseDimIndex w (E.DimSlice Nothing Nothing (Just (E.Negate (E.IntLit 1 _ _) _))) = do
  w_minus_1 <- letSubExp "w_minus_1" $ BasicOp $ I.BinOp (Sub Int64 I.OverflowWrap) w one
  pure (I.DimSlice w_minus_1 w $ intConst Int64 (-1), constant True, mempty)
  where
    one = constant (1 :: Int64)
internaliseDimIndex w (E.DimSlice i j s) = do
  s' <- maybe (pure one) (fmap fst . internaliseSizeExp "s") s
  s_sign <- letSubExp "s_sign" $ BasicOp $ I.UnOp (I.SSignum Int64) s'
  backwards <- letSubExp "backwards" $ I.BasicOp $ I.CmpOp (I.CmpEq int64) s_sign negone
  w_minus_1 <- letSubExp "w_minus_1" $ BasicOp $ I.BinOp (Sub Int64 I.OverflowWrap) w one
  -- Default start/end depend on the direction: forwards is [0, w),
  -- backwards is [w-1, -1).
  let i_def =
        letSubExp "i_def"
          =<< eIf
            (eSubExp backwards)
            (resultBodyM [w_minus_1])
            (resultBodyM [zero])
      j_def =
        letSubExp "j_def"
          =<< eIf
            (eSubExp backwards)
            (resultBodyM [negone])
            (resultBodyM [w])
  i' <- maybe i_def (fmap fst . internaliseSizeExp "i") i
  j' <- maybe j_def (fmap fst . internaliseSizeExp "j") j
  j_m_i <- letSubExp "j_m_i" $ BasicOp $ I.BinOp (Sub Int64 I.OverflowWrap) j' i'
  -- Something like a division-rounding-up, but accommodating negative
  -- operands.
  let divRounding x y =
        eBinOp
          (SQuot Int64 Safe)
          ( eBinOp
              (Add Int64 I.OverflowWrap)
              x
              (eBinOp (Sub Int64 I.OverflowWrap) y (eSignum y))
          )
          y
  n <- letSubExp "n" =<< divRounding (toExp j_m_i) (toExp s')

  zero_stride <- letSubExp "zero_stride" $ I.BasicOp $ I.CmpOp (CmpEq int64) s_sign zero
  nonzero_stride <- letSubExp "nonzero_stride" $ I.BasicOp $ I.UnOp (I.Neg I.Bool) zero_stride

  -- Bounds checks depend on whether we are slicing forwards or
  -- backwards.  The slice accesses i, i+s, ..., i+(n-1)*s.
  --
  -- Forwards (s > 0): the smallest accessed index is i and the largest
  -- is i+(n-1)*s, so we check '0 <= i && i <= j && 0 <= i+(n-1)*s &&
  -- i+(n-1)*s < w'.
  --
  -- Backwards (s < 0): the largest accessed index is i and the
  -- smallest is i+(n-1)*s, so we check '-1 <= j && j <= i &&
  -- 0 <= i+(n-1)*s && i < w'.
  --
  -- We only check if the slice is nonempty.
  empty_slice <- letSubExp "empty_slice" $ I.BasicOp $ I.CmpOp (CmpEq int64) n zero

  m <- letSubExp "m" $ I.BasicOp $ I.BinOp (Sub Int64 I.OverflowWrap) n one
  m_t_s <- letSubExp "m_t_s" $ I.BasicOp $ I.BinOp (Mul Int64 I.OverflowWrap) m s'
  i_p_m_t_s <- letSubExp "i_p_m_t_s" $ I.BasicOp $ I.BinOp (Add Int64 I.OverflowWrap) i' m_t_s
  zero_leq_i_p_m_t_s <-
    letSubExp "zero_leq_i_p_m_t_s" $
      I.BasicOp $
        I.CmpOp (I.CmpSle Int64) zero i_p_m_t_s
  i_p_m_t_s_lth_w <-
    letSubExp "i_p_m_t_s_lth_w" $
      I.BasicOp $
        I.CmpOp (I.CmpSlt Int64) i_p_m_t_s w
  -- Upper bound for backwards slices: the start index is the largest
  -- accessed index and must be strictly below 'w'.
  i_lth_w <- letSubExp "i_lth_w" $ I.BasicOp $ I.CmpOp (I.CmpSlt Int64) i' w

  zero_lte_i <- letSubExp "zero_lte_i" $ I.BasicOp $ I.CmpOp (I.CmpSle Int64) zero i'
  i_lte_j <- letSubExp "i_lte_j" $ I.BasicOp $ I.CmpOp (I.CmpSle Int64) i' j'
  forwards_ok <-
    letSubExp "forwards_ok"
      =<< eAll [zero_lte_i, i_lte_j, zero_leq_i_p_m_t_s, i_p_m_t_s_lth_w]

  negone_lte_j <- letSubExp "negone_lte_j" $ I.BasicOp $ I.CmpOp (I.CmpSle Int64) negone j'
  j_lte_i <- letSubExp "j_lte_i" $ I.BasicOp $ I.CmpOp (I.CmpSle Int64) j' i'
  backwards_ok <-
    letSubExp "backwards_ok"
      =<< eAll [negone_lte_j, j_lte_i, zero_leq_i_p_m_t_s, i_lth_w]

  slice_ok <-
    letSubExp "slice_ok"
      =<< eIf
        (eSubExp backwards)
        (resultBodyM [backwards_ok])
        (resultBodyM [forwards_ok])
  ok_or_empty <-
    letSubExp "ok_or_empty" $
      I.BasicOp $
        I.BinOp I.LogOr empty_slice slice_ok
  -- A zero stride is always rejected, even for an otherwise empty slice.
  acceptable <-
    letSubExp "slice_acceptable" $
      I.BasicOp $
        I.BinOp I.LogAnd nonzero_stride ok_or_empty

  let parts = case (i, j, s) of
        (_, _, Just {}) ->
          [ maybe "" (const $ ErrorVal int64 i') i,
            ":",
            maybe "" (const $ ErrorVal int64 j') j,
            ":",
            ErrorVal int64 s'
          ]
        (_, Just {}, _) ->
          [ maybe "" (const $ ErrorVal int64 i') i,
            ":",
            ErrorVal int64 j'
          ]
            ++ maybe mempty (const [":", ErrorVal int64 s']) s
        (_, Nothing, Nothing) ->
          [ErrorVal int64 i', ":"]
  pure (I.DimSlice i' n s', acceptable, parts)
  where
    zero = constant (0 :: Int64)
    negone = constant (-1 :: Int64)
    one = constant (1 :: Int64)

-- | Shared implementation of scan and reduce.  The neutral elements
-- are reshaped to the row shape of the input arrays (with a runtime
-- check).  The operator is internalised as a fold lambda, and the SOAC
-- built by 'f' is bound under 'desc'.
internaliseScanOrReduce ::
  String ->
  String ->
  (SubExp -> I.Lambda SOACS -> [SubExp] -> [VName] -> InternaliseM (SOAC SOACS)) ->
  (E.Exp, E.Exp, E.Exp, SrcLoc) ->
  InternaliseM [SubExp]
internaliseScanOrReduce desc what f (lam, ne, arr, loc) = do
  arrs <- internaliseExpToVars (what ++ "_arr") arr
  nes <- internaliseExp (what ++ "_ne") ne
  nes' <- forM (zip nes arrs) $ \(ne', arr') -> do
    rowtype <- I.stripArray 1 <$> lookupType arr'
    ensureShape
      "Row shape of input array does not match shape of neutral element"
      loc
      rowtype
      (what ++ "_ne_right_shape")
      ne'
  nests <- mapM I.subExpType nes'
  arrts <- mapM lookupType arrs
  lam' <- internaliseFoldLambda internaliseLambda lam nests arrts
  w <- arraysSize 0 <$> mapM lookupType arrs
  letValExp' desc . I.Op =<< f w lam' nes' arrs

-- | Internalise a 'dim'-dimensional histogram.  The arguments are the
-- race factor 'rf', destination 'hist', operator 'op', neutral element
-- 'ne', bucket indices 'buckets' and values 'img'.
internaliseHist ::
  Int ->
  String ->
  E.Exp ->
  E.Exp ->
  E.Exp ->
  E.Exp ->
  E.Exp ->
  E.Exp ->
  SrcLoc ->
  InternaliseM [SubExp]
internaliseHist dim desc rf hist op ne buckets img loc = do
  rf' <- internaliseExp1 "hist_rf" rf
  ne' <- internaliseExp "hist_ne" ne
  hist' <- internaliseExpToVars "hist_hist" hist
  buckets' <- internaliseExpToVars "hist_buckets" buckets
  img' <- internaliseExpToVars "hist_img" img

  -- reshape neutral element to have same size as the destination array
  ne_shp <- forM (zip ne' hist') $ \(n, h) -> do
    rowtype <- I.stripArray 1 <$> lookupType h
    ensureShape
      "Row shape of destination array does not match shape of neutral element"
      loc
      rowtype
      "hist_ne_right_shape"
      n
  ne_ts <- mapM I.subExpType ne_shp
  his_ts <- mapM (fmap (I.stripArray (dim - 1)) . lookupType) hist'
  op' <- internaliseFoldLambda internaliseLambda op ne_ts his_ts

  -- reshape return type of bucket function to have same size as neutral element
  -- (modulo the index)
  bucket_params <- replicateM dim (newParam "bucket_p" $ I.Prim int64)
  img_params <- mapM (newParam "img_p" . rowType) =<< mapM lookupType img'
  let params = bucket_params ++ img_params
      rettype = replicate dim (I.Prim int64) ++ ne_ts
      body = mkBody mempty $ varsRes $ map I.paramName params
  lam' <-
    mkLambda params $
      ensureResultShape
        "Row shape of value array does not match row shape of hist target"
        (srclocOf img)
        rettype
        =<< bodyBind body

  -- get sizes of histogram and image arrays
  shape_hist <- I.Shape . take dim . I.arrayDims <$> lookupType (head hist')
  w_img <- I.arraySize 0 <$> lookupType (head img')

  letValExp' desc . I.Op $
    I.Hist w_img (buckets' ++ img') [HistOp shape_hist rf' hist' ne_shp op'] lam'

-- | Internalise a stream over an accumulator ('scatter_stream' and
-- 'hist_stream').  The destination arrays become a 'WithAcc'
-- accumulator.  The user lambda is mapped over 'bs' with the
-- accumulator as an extra first input.  When an (operator, neutral
-- element) pair is given, it becomes the accumulator's combining
-- operator, with an extra leading 'i64' index parameter.
internaliseStreamAcc ::
  String ->
  E.Exp ->
  Maybe (E.Exp, E.Exp) ->
  E.Exp ->
  E.Exp ->
  InternaliseM [SubExp]
internaliseStreamAcc desc dest op lam bs = do
  dest' <- internaliseExpToVars "scatter_dest" dest
  bs' <- internaliseExpToVars "scatter_input" bs

  acc_cert_v <- newVName "acc_cert"
  dest_ts <- mapM lookupType dest'
  let dest_w = arraysSize 0 dest_ts
      acc_t = Acc acc_cert_v (I.Shape [dest_w]) (map rowType dest_ts) NoUniqueness
  acc_p <- newParam "acc_p" acc_t
  withacc_lam <- mkLambda [Param mempty acc_cert_v (I.Prim I.Unit), acc_p] $ do
    bs_ts <- mapM lookupType bs'
    lam' <- internaliseLambdaCoerce lam $ map rowType $ paramType acc_p : bs_ts
    let w = arraysSize 0 bs_ts
    fmap subExpsRes . letValExp' "acc_res" $
      I.Op $
        I.Screma w (I.paramName acc_p : bs') (I.mapSOAC lam')

  op' <-
    case op of
      Just (op_lam, ne) -> do
        ne' <- internaliseExp "hist_ne" ne
        ne_ts <- mapM I.subExpType ne'
        (lam_params, lam_body, lam_rettype) <-
          internaliseLambda op_lam $ ne_ts ++ ne_ts
        idxp <- newParam "idx" $ I.Prim int64
        let op_lam' = I.Lambda (idxp : lam_params) lam_rettype lam_body
        pure $ Just (op_lam', ne')
      Nothing ->
        pure Nothing

  -- 'dest_w' is the outer size of the destination arrays, computed above.
  fmap (map I.Var) $
    letTupExp desc $
      WithAcc [(I.Shape [dest_w], dest', op')] withacc_lam

-- | Internalise an expression that must produce exactly one
-- subexpression; anything else is an internal error.
internaliseExp1 :: String -> E.Exp -> InternaliseM I.SubExp
internaliseExp1 desc e = do
  vs <- internaliseExp desc e
  case vs of
    [se] -> pure se
    _ -> error "Internalise.internaliseExp1: was passed not just a single subexpression"

-- | Promote to dimension type as appropriate for the original type.
-- Also return original type.  Only signed integer types are accepted.
internaliseSizeExp :: String -> E.Exp -> InternaliseM (I.SubExp, IntType)
internaliseSizeExp s e = do
  e' <- internaliseExp1 s e
  case E.typeOf e of
    E.Scalar (E.Prim (E.Signed it)) -> (,it) <$> asIntS Int64 e'
    _ -> error "internaliseSizeExp: bad type"

-- | Internalise an expression and make sure every result is a
-- variable, binding constants to fresh names where necessary.
internaliseExpToVars :: String -> E.Exp -> InternaliseM [I.VName]
internaliseExpToVars desc e =
  mapM asIdent =<< internaliseExp desc e
  where
    asIdent (I.Var v) = pure v
    asIdent se = letExp desc $ I.BasicOp $ I.SubExp se

-- | Apply a per-variable basic operation to every value that 'e'
-- internalises to.
internaliseOperation ::
  String ->
  E.Exp ->
  (I.VName -> InternaliseM I.BasicOp) ->
  InternaliseM [I.SubExp]
internaliseOperation s e op = do
  vs <- internaliseExpToVars s e
  mapM (letSubExp s . I.BasicOp <=< op) vs

-- | Run 'm' under an assertion that the integer 'x' (of type 't') is
-- not zero, failing with "division by zero" otherwise.
certifyingNonzero ::
  SrcLoc ->
  IntType ->
  SubExp ->
  InternaliseM a ->
  InternaliseM a
certifyingNonzero loc t x m = do
  zero <-
    letSubExp "zero" $
      I.BasicOp $
        CmpOp (CmpEq (IntType t)) x (intConst t 0)
  nonzero <- letSubExp "nonzero" $ I.BasicOp $ UnOp (I.Neg I.Bool) zero
  c <- assert "nonzero_cert" nonzero "division by zero" loc
  certifying c m

-- | Run 'm' under an assertion that the signed integer 'x' (of type
-- 't') is nonnegative, failing with "negative exponent" otherwise.
certifyingNonnegative ::
  SrcLoc ->
  IntType ->
  SubExp ->
  InternaliseM a ->
  InternaliseM a
certifyingNonnegative loc t x m = do
  nonnegative <-
    letSubExp "nonnegative" . I.BasicOp $
      CmpOp (CmpSle t) (intConst t 0) x
  c <- assert "nonnegative_cert" nonnegative "negative exponent" loc
  certifying c m

-- | Translate a source-language binary operator applied to 'x' and 'y'
-- (with operand types 't1'/'t2') into core operations.  Integer
-- division, modulo, quotient and remainder are guarded by a nonzero
-- check on the divisor.  Signed integer exponentiation is guarded by a
-- nonnegative check on the exponent.
internaliseBinOp ::
  SrcLoc ->
  String ->
  E.BinOp ->
  I.SubExp ->
  I.SubExp ->
  E.PrimType ->
  E.PrimType ->
  InternaliseM [I.SubExp]
internaliseBinOp _ desc E.LogAnd x y E.Bool _ =
  simpleBinOp desc I.LogAnd x y
internaliseBinOp _ desc E.LogOr x y E.Bool _ =
  simpleBinOp desc I.LogOr x y
internaliseBinOp _ desc E.Plus x y (E.Signed t) _ =
  simpleBinOp desc (I.Add t I.OverflowWrap) x y
internaliseBinOp _ desc E.Plus x y (E.Unsigned t) _ =
  simpleBinOp desc (I.Add t I.OverflowWrap) x y
internaliseBinOp _ desc E.Plus x y (E.FloatType t) _ =
  simpleBinOp desc (I.FAdd t) x y
internaliseBinOp _ desc E.Minus x y (E.Signed t) _ =
  simpleBinOp desc (I.Sub t I.OverflowWrap) x y
internaliseBinOp _ desc E.Minus x y (E.Unsigned t) _ =
  simpleBinOp desc (I.Sub t I.OverflowWrap) x y
internaliseBinOp _ desc E.Minus x y (E.FloatType t) _ =
  simpleBinOp desc (I.FSub t) x y
internaliseBinOp _ desc E.Times x y (E.Signed t) _ =
  simpleBinOp desc (I.Mul t I.OverflowWrap) x y
internaliseBinOp _ desc E.Times x y (E.Unsigned t) _ =
  simpleBinOp desc (I.Mul t I.OverflowWrap) x y
internaliseBinOp _ desc E.Times x y (E.FloatType t) _ =
  simpleBinOp desc (I.FMul t) x y
internaliseBinOp loc desc E.Divide x y (E.Signed t) _ =
  certifyingNonzero loc t y $
    simpleBinOp desc (I.SDiv t I.Unsafe) x y
internaliseBinOp loc desc E.Divide x y (E.Unsigned t) _ =
  certifyingNonzero loc t y $
    simpleBinOp desc (I.UDiv t I.Unsafe) x y
internaliseBinOp _ desc E.Divide x y (E.FloatType t) _ =
  simpleBinOp desc (I.FDiv t) x y
internaliseBinOp _ desc E.Pow x y (E.FloatType t) _ =
  simpleBinOp desc (I.FPow t) x y
internaliseBinOp loc desc E.Pow x y (E.Signed t) _ =
  certifyingNonnegative loc t y $
    simpleBinOp desc (I.Pow t) x y
internaliseBinOp _ desc E.Pow x y (E.Unsigned t) _ =
  simpleBinOp desc (I.Pow t) x y
internaliseBinOp loc desc E.Mod x y (E.Signed t) _ =
  certifyingNonzero loc t y $
    simpleBinOp desc (I.SMod t I.Unsafe) x y
internaliseBinOp loc desc E.Mod x y (E.Unsigned t) _ =
  certifyingNonzero loc t y $
    simpleBinOp desc (I.UMod t I.Unsafe) x y
internaliseBinOp _ desc E.Mod x y (E.FloatType t) _ =
  simpleBinOp desc (I.FMod t) x y
internaliseBinOp loc desc E.Quot x y (E.Signed t) _ =
  certifyingNonzero loc t y $
    simpleBinOp desc (I.SQuot t I.Unsafe) x y
internaliseBinOp loc desc E.Quot x y (E.Unsigned t) _ =
  certifyingNonzero loc t y $
    simpleBinOp desc (I.UDiv t I.Unsafe) x y
internaliseBinOp loc desc E.Rem x y (E.Signed t) _ =
  certifyingNonzero loc t y $
    simpleBinOp desc (I.SRem t I.Unsafe) x y
internaliseBinOp loc desc E.Rem x y (E.Unsigned t) _ =
  certifyingNonzero loc t y $
    simpleBinOp desc (I.UMod t I.Unsafe) x y
internaliseBinOp _ desc E.ShiftR x y (E.Signed t) _ =
  simpleBinOp desc (I.AShr t) x y
internaliseBinOp _ desc E.ShiftR x y (E.Unsigned t) _ =
  simpleBinOp desc (I.LShr t) x y
internaliseBinOp _ desc E.ShiftL x y (E.Signed t) _ =
  simpleBinOp desc (I.Shl t) x y
internaliseBinOp _ desc E.ShiftL x y (E.Unsigned t) _ =
  simpleBinOp desc (I.Shl t) x y
internaliseBinOp _ desc E.Band x y (E.Signed t) _ =
  simpleBinOp desc (I.And t) x y
internaliseBinOp _ desc E.Band x y (E.Unsigned t) _ =
  simpleBinOp desc (I.And t) x y
internaliseBinOp _ desc E.Xor x y (E.Signed t) _ =
  simpleBinOp desc (I.Xor t) x y
internaliseBinOp _ desc E.Xor x y (E.Unsigned t) _ =
  simpleBinOp desc (I.Xor t) x y
internaliseBinOp _ desc E.Bor x y (E.Signed t) _ =
  simpleBinOp desc (I.Or t) x y
internaliseBinOp _ desc E.Bor x y (E.Unsigned t) _ =
  simpleBinOp desc (I.Or t) x y
internaliseBinOp _ desc E.Equal x y t _ =
  simpleCmpOp desc (I.CmpEq $ internalisePrimType t) x y
internaliseBinOp _ desc E.NotEqual x y t _ = do
  eq <- letSubExp (desc ++ "true") $ I.BasicOp $ I.CmpOp (I.CmpEq $ internalisePrimType t) x y
  fmap pure $ letSubExp desc $ I.BasicOp $ I.UnOp (I.Neg I.Bool) eq
internaliseBinOp _ desc E.Less x y (E.Signed t) _ =
  simpleCmpOp desc (I.CmpSlt t) x y
internaliseBinOp _ desc E.Less x y (E.Unsigned t) _ =
  simpleCmpOp desc (I.CmpUlt t) x y
internaliseBinOp _ desc E.Leq x y (E.Signed t) _ =
  simpleCmpOp desc (I.CmpSle t) x y
internaliseBinOp _ desc E.Leq x y (E.Unsigned t) _ =
  simpleCmpOp desc (I.CmpUle t) x y
internaliseBinOp _ desc E.Greater x y (E.Signed t) _ =
  simpleCmpOp desc (I.CmpSlt t) y x -- Note the swapped x and y
internaliseBinOp _ desc E.Greater x y (E.Unsigned t) _ =
  simpleCmpOp desc (I.CmpUlt t) y x -- Note the swapped x and y
internaliseBinOp _ desc E.Geq x y (E.Signed t) _ =
  simpleCmpOp desc (I.CmpSle t) y x -- Note the swapped x and y
internaliseBinOp _ desc E.Geq x y (E.Unsigned t) _ =
  simpleCmpOp desc (I.CmpUle t) y x -- Note the swapped x and y
internaliseBinOp _ desc E.Less x y (E.FloatType t) _ =
  simpleCmpOp desc (I.FCmpLt t) x y
internaliseBinOp _ desc E.Leq x y (E.FloatType t) _ =
  simpleCmpOp desc (I.FCmpLe t) x y
internaliseBinOp _ desc E.Greater x y (E.FloatType t) _ =
  simpleCmpOp desc (I.FCmpLt t) y x -- Note the swapped x and y
internaliseBinOp _ desc E.Geq x y (E.FloatType t) _ =
  simpleCmpOp desc (I.FCmpLe t) y x -- Note the swapped x and y

-- Relational operators for booleans.
internaliseBinOp _ desc E.Less x y E.Bool _ = simpleCmpOp desc I.CmpLlt x y internaliseBinOp _ desc E.Leq x y E.Bool _ = simpleCmpOp desc I.CmpLle x y internaliseBinOp _ desc E.Greater x y E.Bool _ = simpleCmpOp desc I.CmpLlt y x -- Note the swapped x and y internaliseBinOp _ desc E.Geq x y E.Bool _ = simpleCmpOp desc I.CmpLle y x -- Note the swapped x and y internaliseBinOp _ _ op _ _ t1 t2 = error $ "Invalid binary operator " ++ prettyString op ++ " with operand types " ++ prettyString t1 ++ ", " ++ prettyString t2 simpleBinOp :: String -> I.BinOp -> I.SubExp -> I.SubExp -> InternaliseM [I.SubExp] simpleBinOp desc bop x y = letTupExp' desc $ I.BasicOp $ I.BinOp bop x y simpleCmpOp :: String -> I.CmpOp -> I.SubExp -> I.SubExp -> InternaliseM [I.SubExp] simpleCmpOp desc op x y = letTupExp' desc $ I.BasicOp $ I.CmpOp op x y data Function = FunctionName (E.QualName VName) | FunctionHole SrcLoc deriving (Show) findFuncall :: E.AppExp -> (Function, [(E.Exp, Maybe VName)]) findFuncall (E.Apply f args _) | E.Var fname _ _ <- f = (FunctionName fname, map onArg $ NE.toList args) | E.Hole (Info _) loc <- f = (FunctionHole loc, map onArg $ NE.toList args) where onArg (Info argext, e) = (e, argext) findFuncall e = error $ "Invalid function expression in application:\n" ++ prettyString e -- The type of a body. Watch out: this only works for the degenerate -- case where the body does not already return its context. bodyExtType :: Body SOACS -> InternaliseM [ExtType] bodyExtType (Body _ stms res) = existentialiseExtTypes (M.keys stmsscope) . 
staticShapes <$> extendedScope (traverse subExpResType res) stmsscope where stmsscope = scopeOf stms internaliseLambda :: InternaliseLambda internaliseLambda (E.Parens e _) rowtypes = internaliseLambda e rowtypes internaliseLambda (E.Lambda params body _ (Info (RetType _ rettype)) _) rowtypes = bindingLambdaParams params rowtypes $ \params' -> do body' <- internaliseBody "lam" body rettype' <- internaliseLambdaReturnType rettype =<< bodyExtType body' pure (params', body', rettype') internaliseLambda e _ = error $ "internaliseLambda: unexpected expression:\n" ++ prettyString e internaliseLambdaCoerce :: E.Exp -> [Type] -> InternaliseM (I.Lambda SOACS) internaliseLambdaCoerce lam argtypes = do (params, body, rettype) <- internaliseLambda lam argtypes mkLambda params $ ensureResultShape (ErrorMsg [ErrorString "unexpected lambda result size"]) (srclocOf lam) rettype =<< bodyBind body -- | Overloaded operators are treated here. isOverloadedFunction :: E.QualName VName -> String -> SrcLoc -> Maybe ([(E.StructType, [SubExp])] -> InternaliseM [SubExp]) isOverloadedFunction qname desc loc = do guard $ baseTag (qualLeaf qname) <= maxIntrinsicTag handle $ baseString $ qualLeaf qname where -- Handle equality and inequality specially, to treat the case of -- arrays. 
handle op | Just cmp_f <- isEqlOp op = Just $ \[(_, xe'), (_, ye')] -> do rs <- zipWithM doComparison xe' ye' cmp_f =<< letSubExp "eq" =<< eAll rs where isEqlOp "!=" = Just $ \eq -> letTupExp' desc $ I.BasicOp $ I.UnOp (I.Neg I.Bool) eq isEqlOp "==" = Just $ \eq -> pure [eq] isEqlOp _ = Nothing doComparison x y = do x_t <- I.subExpType x y_t <- I.subExpType y case x_t of I.Prim t -> letSubExp desc $ I.BasicOp $ I.CmpOp (I.CmpEq t) x y _ -> do let x_dims = I.arrayDims x_t y_dims = I.arrayDims y_t dims_match <- forM (zip x_dims y_dims) $ \(x_dim, y_dim) -> letSubExp "dim_eq" $ I.BasicOp $ I.CmpOp (I.CmpEq int64) x_dim y_dim shapes_match <- letSubExp "shapes_match" =<< eAll dims_match let compare_elems_body = runBodyBuilder $ do -- Flatten both x and y. x_num_elems <- letSubExp "x_num_elems" =<< foldBinOp (I.Mul Int64 I.OverflowUndef) (constant (1 :: Int64)) x_dims x' <- letExp "x" $ I.BasicOp $ I.SubExp x y' <- letExp "x" $ I.BasicOp $ I.SubExp y x_flat <- letExp "x_flat" $ I.BasicOp $ I.Reshape I.ReshapeArbitrary (I.Shape [x_num_elems]) x' y_flat <- letExp "y_flat" $ I.BasicOp $ I.Reshape I.ReshapeArbitrary (I.Shape [x_num_elems]) y' -- Compare the elements. cmp_lam <- cmpOpLambda $ I.CmpEq (elemType x_t) cmps <- letExp "cmps" $ I.Op $ I.Screma x_num_elems [x_flat, y_flat] (I.mapSOAC cmp_lam) -- Check that all were equal. and_lam <- binOpLambda I.LogAnd I.Bool reduce <- I.reduceSOAC [Reduce Commutative and_lam [constant True]] all_equal <- letSubExp "all_equal" $ I.Op $ I.Screma x_num_elems [cmps] reduce pure $ subExpsRes [all_equal] letSubExp "arrays_equal" =<< eIf (eSubExp shapes_match) compare_elems_body (resultBodyM [constant False]) handle name | Just bop <- find ((name ==) . prettyString) [minBound .. maxBound :: E.BinOp] = Just $ \[(x_t, [x']), (y_t, [y'])] -> case (x_t, y_t) of (E.Scalar (E.Prim t1), E.Scalar (E.Prim t2)) -> internaliseBinOp loc desc bop x' y' t1 t2 _ -> error "Futhark.Internalise.internaliseExp: non-primitive type in BinOp." 
handle _ = Nothing -- | Handle intrinsic functions. These are only allowed to be called -- in the prelude, and their internalisation may involve inspecting -- the AST. isIntrinsicFunction :: E.QualName VName -> [E.Exp] -> SrcLoc -> Maybe (String -> InternaliseM [SubExp]) isIntrinsicFunction qname args loc = do guard $ baseTag (qualLeaf qname) <= maxIntrinsicTag let handlers = [ handleSign, handleOps, handleSOACs, handleAccs, handleAD, handleRest ] msum [h args $ baseString $ qualLeaf qname | h <- handlers] where handleSign [x] "sign_i8" = Just $ toSigned I.Int8 x handleSign [x] "sign_i16" = Just $ toSigned I.Int16 x handleSign [x] "sign_i32" = Just $ toSigned I.Int32 x handleSign [x] "sign_i64" = Just $ toSigned I.Int64 x handleSign [x] "unsign_i8" = Just $ toUnsigned I.Int8 x handleSign [x] "unsign_i16" = Just $ toUnsigned I.Int16 x handleSign [x] "unsign_i32" = Just $ toUnsigned I.Int32 x handleSign [x] "unsign_i64" = Just $ toUnsigned I.Int64 x handleSign _ _ = Nothing handleOps [x] s | Just unop <- find ((== s) . prettyString) allUnOps = Just $ \desc -> do x' <- internaliseExp1 "x" x fmap pure $ letSubExp desc $ I.BasicOp $ I.UnOp unop x' handleOps [TupLit [x, y] _] s | Just bop <- find ((== s) . prettyString) allBinOps = Just $ \desc -> do x' <- internaliseExp1 "x" x y' <- internaliseExp1 "y" y fmap pure $ letSubExp desc $ I.BasicOp $ I.BinOp bop x' y' | Just cmp <- find ((== s) . prettyString) allCmpOps = Just $ \desc -> do x' <- internaliseExp1 "x" x y' <- internaliseExp1 "y" y fmap pure $ letSubExp desc $ I.BasicOp $ I.CmpOp cmp x' y' handleOps [x] s | Just conv <- find ((== s) . 
prettyString) allConvOps = Just $ \desc -> do x' <- internaliseExp1 "x" x fmap pure $ letSubExp desc $ I.BasicOp $ I.ConvOp conv x' handleOps _ _ = Nothing handleSOACs [lam, arr] "map" = Just $ \desc -> do arr' <- internaliseExpToVars "map_arr" arr arr_ts <- mapM lookupType arr' lam' <- internaliseLambdaCoerce lam $ map rowType arr_ts let w = arraysSize 0 arr_ts letTupExp' desc $ I.Op $ I.Screma w arr' (I.mapSOAC lam') handleSOACs [k, lam, arr] "partition" = do k' <- fromIntegral <$> fromInt32 k Just $ \_desc -> do arrs <- internaliseExpToVars "partition_input" arr lam' <- internalisePartitionLambda internaliseLambda k' lam $ map I.Var arrs uncurry (++) <$> partitionWithSOACS (fromIntegral k') lam' arrs where fromInt32 (Literal (SignedValue (Int32Value k')) _) = Just k' fromInt32 (IntLit k' (Info (E.Scalar (E.Prim (E.Signed Int32)))) _) = Just $ fromInteger k' fromInt32 _ = Nothing handleSOACs [lam, ne, arr] "reduce" = Just $ \desc -> internaliseScanOrReduce desc "reduce" reduce (lam, ne, arr, loc) where reduce w red_lam nes arrs = I.Screma w arrs <$> I.reduceSOAC [Reduce Noncommutative red_lam nes] handleSOACs [lam, ne, arr] "reduce_comm" = Just $ \desc -> internaliseScanOrReduce desc "reduce" reduce (lam, ne, arr, loc) where reduce w red_lam nes arrs = I.Screma w arrs <$> I.reduceSOAC [Reduce Commutative red_lam nes] handleSOACs [lam, ne, arr] "scan" = Just $ \desc -> internaliseScanOrReduce desc "scan" reduce (lam, ne, arr, loc) where reduce w scan_lam nes arrs = I.Screma w arrs <$> I.scanSOAC [Scan scan_lam nes] handleSOACs [rf, dest, op, ne, buckets, img] "hist_1d" = Just $ \desc -> internaliseHist 1 desc rf dest op ne buckets img loc handleSOACs [rf, dest, op, ne, buckets, img] "hist_2d" = Just $ \desc -> internaliseHist 2 desc rf dest op ne buckets img loc handleSOACs [rf, dest, op, ne, buckets, img] "hist_3d" = Just $ \desc -> internaliseHist 3 desc rf dest op ne buckets img loc handleSOACs _ _ = Nothing handleAccs [dest, f, bs] "scatter_stream" = Just $ 
\desc -> internaliseStreamAcc desc dest Nothing f bs handleAccs [dest, op, ne, f, bs] "hist_stream" = Just $ \desc -> internaliseStreamAcc desc dest (Just (op, ne)) f bs handleAccs [acc, i, v] "acc_write" = Just $ \desc -> do acc' <- head <$> internaliseExpToVars "acc" acc i' <- internaliseExp1 "acc_i" i vs <- internaliseExp "acc_v" v fmap pure $ letSubExp desc $ BasicOp $ UpdateAcc Safe acc' [i'] vs handleAccs _ _ = Nothing handleAD [f, x, v] fname | fname `elem` ["jvp2", "vjp2"] = Just $ \desc -> do x' <- internaliseExp "ad_x" x v' <- internaliseExp "ad_v" v lam <- internaliseLambdaCoerce f =<< mapM subExpType x' fmap (map I.Var) . letTupExp desc . Op $ case fname of "jvp2" -> JVP x' v' lam _ -> VJP x' v' lam handleAD _ _ = Nothing handleRest [a, si, v] "scatter" = Just $ scatterF 1 a si v handleRest [a, si, v] "scatter_2d" = Just $ scatterF 2 a si v handleRest [a, si, v] "scatter_3d" = Just $ scatterF 3 a si v handleRest [n, m, arr] "unflatten" = Just $ \desc -> do arrs <- internaliseExpToVars "unflatten_arr" arr n' <- internaliseExp1 "n" n m' <- internaliseExp1 "m" m -- Each dimension must be nonnegative, and the unflattened -- dimension needs to have the same number of elements as the -- original dimension. old_dim <- I.arraysSize 0 <$> mapM lookupType arrs dim_ok <- letSubExp "dim_ok" <=< toExp $ pe64 old_dim .==. pe64 n' * pe64 m' .&&. pe64 n' .>=. 0 .&&. pe64 m' .>=. 0 dim_ok_cert <- assert "dim_ok_cert" dim_ok ( ErrorMsg [ "Cannot unflatten array of shape [", ErrorVal int64 old_dim, "] to array of shape [", ErrorVal int64 n', "][", ErrorVal int64 m', "]" ] ) loc certifying dim_ok_cert $ forM arrs $ \arr' -> do arr_t <- lookupType arr' letSubExp desc . 
I.BasicOp $ I.Reshape I.ReshapeArbitrary (reshapeOuter (I.Shape [n', m']) 1 $ I.arrayShape arr_t) arr' handleRest [arr] "manifest" = Just $ \desc -> do arrs <- internaliseExpToVars "flatten_arr" arr forM arrs $ \arr' -> do r <- I.arrayRank <$> lookupType arr' if r == 0 then pure $ I.Var arr' else letSubExp desc $ I.BasicOp $ I.Manifest [0 .. r - 1] arr' handleRest [arr] "flatten" = Just $ \desc -> do arrs <- internaliseExpToVars "flatten_arr" arr forM arrs $ \arr' -> do arr_t <- lookupType arr' let n = arraySize 0 arr_t m = arraySize 1 arr_t k <- letSubExp "flat_dim" $ I.BasicOp $ I.BinOp (Mul Int64 I.OverflowUndef) n m letSubExp desc . I.BasicOp $ I.Reshape I.ReshapeArbitrary (reshapeOuter (I.Shape [k]) 2 $ I.arrayShape arr_t) arr' handleRest [x, y] "concat" = Just $ \desc -> do xs <- internaliseExpToVars "concat_x" x ys <- internaliseExpToVars "concat_y" y outer_size <- arraysSize 0 <$> mapM lookupType xs let sumdims xsize ysize = letSubExp "conc_tmp" $ I.BasicOp $ I.BinOp (I.Add I.Int64 I.OverflowUndef) xsize ysize ressize <- foldM sumdims outer_size =<< mapM (fmap (arraysSize 0) . mapM lookupType) [ys] let conc xarr yarr = I.BasicOp $ I.Concat 0 (xarr :| [yarr]) ressize mapM (letSubExp desc) $ zipWith conc xs ys handleRest [e] "transpose" = Just $ \desc -> internaliseOperation desc e $ \v -> do r <- I.arrayRank <$> lookupType v pure $ I.Rearrange ([1, 0] ++ [2 .. r - 1]) v handleRest [x, y] "zip" = Just $ \desc -> mapM (letSubExp "zip_copy" . BasicOp . Replicate mempty . I.Var) =<< ( (++) <$> internaliseExpToVars (desc ++ "_zip_x") x <*> internaliseExpToVars (desc ++ "_zip_y") y ) handleRest [x] "unzip" = Just $ \desc -> mapM (letSubExp desc . BasicOp . Replicate mempty . 
I.Var) =<< internaliseExpToVars desc x handleRest [arr, offset, n1, s1, n2, s2] "flat_index_2d" = Just $ \desc -> do flatIndexHelper desc loc arr offset [(n1, s1), (n2, s2)] handleRest [arr1, offset, s1, s2, arr2] "flat_update_2d" = Just $ \desc -> do flatUpdateHelper desc loc arr1 offset [s1, s2] arr2 handleRest [arr, offset, n1, s1, n2, s2, n3, s3] "flat_index_3d" = Just $ \desc -> do flatIndexHelper desc loc arr offset [(n1, s1), (n2, s2), (n3, s3)] handleRest [arr1, offset, s1, s2, s3, arr2] "flat_update_3d" = Just $ \desc -> do flatUpdateHelper desc loc arr1 offset [s1, s2, s3] arr2 handleRest [arr, offset, n1, s1, n2, s2, n3, s3, n4, s4] "flat_index_4d" = Just $ \desc -> do flatIndexHelper desc loc arr offset [(n1, s1), (n2, s2), (n3, s3), (n4, s4)] handleRest [arr1, offset, s1, s2, s3, s4, arr2] "flat_update_4d" = Just $ \desc -> do flatUpdateHelper desc loc arr1 offset [s1, s2, s3, s4] arr2 handleRest _ _ = Nothing toSigned int_to e desc = do e' <- internaliseExp1 "trunc_arg" e case E.typeOf e of E.Scalar (E.Prim E.Bool) -> letTupExp' desc =<< eIf (eSubExp e') (resultBodyM [intConst int_to 1]) (resultBodyM [intConst int_to 0]) E.Scalar (E.Prim (E.Signed int_from)) -> letTupExp' desc $ I.BasicOp $ I.ConvOp (I.SExt int_from int_to) e' E.Scalar (E.Prim (E.Unsigned int_from)) -> letTupExp' desc $ I.BasicOp $ I.ConvOp (I.ZExt int_from int_to) e' E.Scalar (E.Prim (E.FloatType float_from)) -> letTupExp' desc $ I.BasicOp $ I.ConvOp (I.FPToSI float_from int_to) e' _ -> error "Futhark.Internalise: non-numeric type in ToSigned" toUnsigned int_to e desc = do e' <- internaliseExp1 "trunc_arg" e case E.typeOf e of E.Scalar (E.Prim E.Bool) -> letTupExp' desc =<< eIf (eSubExp e') (resultBodyM [intConst int_to 1]) (resultBodyM [intConst int_to 0]) E.Scalar (E.Prim (E.Signed int_from)) -> letTupExp' desc $ I.BasicOp $ I.ConvOp (I.ZExt int_from int_to) e' E.Scalar (E.Prim (E.Unsigned int_from)) -> letTupExp' desc $ I.BasicOp $ I.ConvOp (I.ZExt int_from int_to) e' E.Scalar 
(E.Prim (E.FloatType float_from)) -> letTupExp' desc $ I.BasicOp $ I.ConvOp (I.FPToUI float_from int_to) e' _ -> error "Futhark.Internalise.internaliseExp: non-numeric type in ToUnsigned" scatterF dim a si v desc = do si' <- internaliseExpToVars "write_arg_i" si svs <- internaliseExpToVars "write_arg_v" v sas <- internaliseExpToVars "write_arg_a" a si_w <- I.arraysSize 0 <$> mapM lookupType si' sv_ts <- mapM lookupType svs svs' <- forM (zip svs sv_ts) $ \(sv, sv_t) -> do let sv_shape = I.arrayShape sv_t sv_w = arraySize 0 sv_t -- Generate an assertion and reshapes to ensure that sv and si' are the same -- size. cmp <- letSubExp "write_cmp" $ I.BasicOp $ I.CmpOp (I.CmpEq I.int64) si_w sv_w c <- assert "write_cert" cmp "length of index and value array does not match" loc certifying c $ letExp (baseString sv ++ "_write_sv") . I.BasicOp $ I.Reshape I.ReshapeCoerce (reshapeOuter (I.Shape [si_w]) 1 sv_shape) sv indexType <- fmap rowType <$> mapM lookupType si' indexName <- mapM (\_ -> newVName "write_index") indexType valueNames <- replicateM (length sv_ts) $ newVName "write_value" sa_ts <- mapM lookupType sas let bodyTypes = concat (replicate (length sv_ts) indexType) ++ map (I.stripArray dim) sa_ts paramTypes = indexType <> map rowType sv_ts bodyNames = indexName <> valueNames bodyParams = zipWith (I.Param mempty) bodyNames paramTypes -- This body is boring right now, as every input is exactly the output. -- But it can get funky later on if fused with something else. body <- localScope (scopeOfLParams bodyParams) . buildBody_ $ do let outs = concat (replicate (length valueNames) indexName) ++ valueNames results <- forM outs $ \name -> letSubExp "write_res" $ I.BasicOp $ I.SubExp $ I.Var name ensureResultShape "scatter value has wrong size" loc bodyTypes (subExpsRes results) let lam = I.Lambda { I.lambdaParams = bodyParams, I.lambdaReturnType = bodyTypes, I.lambdaBody = body } sivs = si' <> svs' let sa_ws = map (I.Shape . take dim . 
arrayDims) sa_ts spec = zip3 sa_ws (repeat 1) sas letTupExp' desc $ I.Op $ I.Scatter si_w sivs spec lam flatIndexHelper :: String -> SrcLoc -> E.Exp -> E.Exp -> [(E.Exp, E.Exp)] -> InternaliseM [SubExp] flatIndexHelper desc loc arr offset slices = do arrs <- internaliseExpToVars "arr" arr offset' <- internaliseExp1 "offset" offset old_dim <- I.arraysSize 0 <$> mapM lookupType arrs offset_inbounds_down <- letSubExp "offset_inbounds_down" $ I.BasicOp $ I.CmpOp (I.CmpUle Int64) (intConst Int64 0) offset' offset_inbounds_up <- letSubExp "offset_inbounds_up" $ I.BasicOp $ I.CmpOp (I.CmpUlt Int64) offset' old_dim slices' <- mapM ( \(n, s) -> do n' <- internaliseExp1 "n" n s' <- internaliseExp1 "s" s pure (n', s') ) slices (min_bound, max_bound) <- foldM ( \(lower, upper) (n, s) -> do n_m1 <- letSubExp "span" $ I.BasicOp $ I.BinOp (I.Sub Int64 I.OverflowUndef) n (intConst Int64 1) spn <- letSubExp "span" $ I.BasicOp $ I.BinOp (I.Mul Int64 I.OverflowUndef) n_m1 s span_and_lower <- letSubExp "span_and_lower" $ I.BasicOp $ I.BinOp (I.Add Int64 I.OverflowUndef) spn lower span_and_upper <- letSubExp "span_and_upper" $ I.BasicOp $ I.BinOp (I.Add Int64 I.OverflowUndef) spn upper lower' <- letSubExp "minimum" $ I.BasicOp $ I.BinOp (I.UMin Int64) span_and_lower lower upper' <- letSubExp "maximum" $ I.BasicOp $ I.BinOp (I.UMax Int64) span_and_upper upper pure (lower', upper') ) (offset', offset') slices' min_in_bounds <- letSubExp "min_in_bounds" $ I.BasicOp $ I.CmpOp (I.CmpUle Int64) (intConst Int64 0) min_bound max_in_bounds <- letSubExp "max_in_bounds" $ I.BasicOp $ I.CmpOp (I.CmpUlt Int64) max_bound old_dim all_bounds <- foldM (\x y -> letSubExp "inBounds" $ I.BasicOp $ I.BinOp I.LogAnd x y) offset_inbounds_down [offset_inbounds_up, min_in_bounds, max_in_bounds] c <- assert "bounds_cert" all_bounds (ErrorMsg [ErrorString $ "Flat slice out of bounds: " <> prettyText old_dim <> " and " <> prettyText slices']) loc let slice = I.FlatSlice offset' $ map (uncurry FlatDimIndex) 
slices' certifying c $ forM arrs $ \arr' -> letSubExp desc $ I.BasicOp $ I.FlatIndex arr' slice flatUpdateHelper :: String -> SrcLoc -> E.Exp -> E.Exp -> [E.Exp] -> E.Exp -> InternaliseM [SubExp] flatUpdateHelper desc loc arr1 offset slices arr2 = do arrs1 <- internaliseExpToVars "arr" arr1 offset' <- internaliseExp1 "offset" offset old_dim <- I.arraysSize 0 <$> mapM lookupType arrs1 offset_inbounds_down <- letSubExp "offset_inbounds_down" $ I.BasicOp $ I.CmpOp (I.CmpUle Int64) (intConst Int64 0) offset' offset_inbounds_up <- letSubExp "offset_inbounds_up" $ I.BasicOp $ I.CmpOp (I.CmpUlt Int64) offset' old_dim arrs2 <- internaliseExpToVars "arr" arr2 ts <- mapM lookupType arrs2 slices' <- mapM ( \(s, i) -> do s' <- internaliseExp1 "s" s let n = arraysSize i ts pure (n, s') ) $ zip slices [0 ..] (min_bound, max_bound) <- foldM ( \(lower, upper) (n, s) -> do n_m1 <- letSubExp "span" $ I.BasicOp $ I.BinOp (I.Sub Int64 I.OverflowUndef) n (intConst Int64 1) spn <- letSubExp "span" $ I.BasicOp $ I.BinOp (I.Mul Int64 I.OverflowUndef) n_m1 s span_and_lower <- letSubExp "span_and_lower" $ I.BasicOp $ I.BinOp (I.Add Int64 I.OverflowUndef) spn lower span_and_upper <- letSubExp "span_and_upper" $ I.BasicOp $ I.BinOp (I.Add Int64 I.OverflowUndef) spn upper lower' <- letSubExp "minimum" $ I.BasicOp $ I.BinOp (I.UMin Int64) span_and_lower lower upper' <- letSubExp "maximum" $ I.BasicOp $ I.BinOp (I.UMax Int64) span_and_upper upper pure (lower', upper') ) (offset', offset') slices' min_in_bounds <- letSubExp "min_in_bounds" $ I.BasicOp $ I.CmpOp (I.CmpUle Int64) (intConst Int64 0) min_bound max_in_bounds <- letSubExp "max_in_bounds" $ I.BasicOp $ I.CmpOp (I.CmpUlt Int64) max_bound old_dim all_bounds <- foldM (\x y -> letSubExp "inBounds" $ I.BasicOp $ I.BinOp I.LogAnd x y) offset_inbounds_down [offset_inbounds_up, min_in_bounds, max_in_bounds] c <- assert "bounds_cert" all_bounds (ErrorMsg [ErrorString $ "Flat slice out of bounds: " <> prettyText old_dim <> " and " <> prettyText 
slices']) loc let slice = I.FlatSlice offset' $ map (uncurry FlatDimIndex) slices' certifying c $ forM (zip arrs1 arrs2) $ \(arr1', arr2') -> letSubExp desc $ I.BasicOp $ I.FlatUpdate arr1' slice arr2' funcall :: String -> QualName VName -> [SubExp] -> SrcLoc -> InternaliseM [SubExp] funcall desc (QualName _ fname) args loc = do (shapes, value_paramts, fun_params, rettype_fun) <- lookupFunction fname argts <- mapM subExpType args shapeargs <- argShapes shapes fun_params argts let diets = replicate (length shapeargs) I.ObservePrim ++ map I.diet value_paramts args' <- ensureArgShapes "function arguments of wrong shape" loc (map I.paramName fun_params) (map I.paramType fun_params) (shapeargs ++ args) argts' <- mapM subExpType args' case rettype_fun $ zip args' argts' of Nothing -> error . unlines $ [ "Cannot apply " <> prettyString fname <> " to " <> show (length args') <> " arguments", " " <> prettyString args', "of types", " " <> prettyString argts', "Function has " <> show (length fun_params) <> " parameters", " " <> prettyString fun_params ] Just ts -> do safety <- askSafety attrs <- asks envAttrs attributing attrs . letValExp' desc $ I.Apply (internaliseFunName fname) (zip args' diets) ts (safety, loc, mempty) -- Bind existential names defined by an expression, based on the -- concrete values that expression evaluated to. This most -- importantly should be done after function calls, but also -- everything else that can produce existentials in the source -- language. 
bindExtSizes :: AppRes -> [SubExp] -> InternaliseM () bindExtSizes (AppRes ret retext) ses = do let ts = foldMap toList $ internaliseType $ E.toStruct ret ses_ts <- mapM subExpType ses let combine t1 t2 = mconcat $ zipWith combine' (arrayExtDims t1) (arrayDims t2) combine' (I.Free (I.Var v)) se | v `elem` retext = M.singleton v se combine' _ _ = mempty forM_ (M.toList $ mconcat $ zipWith combine ts ses_ts) $ \(v, se) -> letBindNames [v] $ BasicOp $ SubExp se askSafety :: InternaliseM Safety askSafety = do check <- asks envDoBoundsChecks pure $ if check then I.Safe else I.Unsafe -- Implement partitioning using maps, scans and writes. partitionWithSOACS :: Int -> I.Lambda SOACS -> [I.VName] -> InternaliseM ([I.SubExp], [I.SubExp]) partitionWithSOACS k lam arrs = do arr_ts <- mapM lookupType arrs let w = arraysSize 0 arr_ts classes_and_increments <- letTupExp "increments" $ I.Op $ I.Screma w arrs (mapSOAC lam) (classes, increments) <- case classes_and_increments of classes : increments -> pure (classes, take k increments) _ -> error "partitionWithSOACS" add_lam_x_params <- replicateM k $ newParam "x" (I.Prim int64) add_lam_y_params <- replicateM k $ newParam "y" (I.Prim int64) add_lam_body <- runBodyBuilder $ localScope (scopeOfLParams $ add_lam_x_params ++ add_lam_y_params) $ fmap subExpsRes $ forM (zip add_lam_x_params add_lam_y_params) $ \(x, y) -> letSubExp "z" $ I.BasicOp $ I.BinOp (I.Add Int64 I.OverflowUndef) (I.Var $ I.paramName x) (I.Var $ I.paramName y) let add_lam = I.Lambda { I.lambdaBody = add_lam_body, I.lambdaParams = add_lam_x_params ++ add_lam_y_params, I.lambdaReturnType = replicate k $ I.Prim int64 } nes = replicate (length increments) $ intConst Int64 0 scan <- I.scanSOAC [I.Scan add_lam nes] all_offsets <- letTupExp "offsets" $ I.Op $ I.Screma w increments scan -- We have the offsets for each of the partitions, but we also need -- the total sizes, which are the last elements in the offests. We -- just have to be careful in case the array is empty. 
last_index <- letSubExp "last_index" $ I.BasicOp $ I.BinOp (I.Sub Int64 OverflowUndef) w $ constant (1 :: Int64) let nonempty_body = runBodyBuilder $ fmap subExpsRes $ forM all_offsets $ \offset_array -> letSubExp "last_offset" $ I.BasicOp $ I.Index offset_array $ Slice [I.DimFix last_index] empty_body = resultBodyM $ replicate k $ constant (0 :: Int64) is_empty <- letSubExp "is_empty" $ I.BasicOp $ I.CmpOp (CmpEq int64) w $ constant (0 :: Int64) sizes <- letTupExp "partition_size" =<< eIf (eSubExp is_empty) empty_body nonempty_body -- The total size of all partitions must necessarily be equal to the -- size of the input array. -- Create scratch arrays for the result. blanks <- forM arr_ts $ \arr_t -> letExp "partition_dest" $ I.BasicOp $ Scratch (I.elemType arr_t) (w : drop 1 (I.arrayDims arr_t)) -- Now write into the result. write_lam <- do c_param <- newParam "c" (I.Prim int64) offset_params <- replicateM k $ newParam "offset" (I.Prim int64) value_params <- mapM (newParam "v" . I.rowType) arr_ts (offset, offset_stms) <- collectStms $ mkOffsetLambdaBody (map I.Var sizes) (I.Var $ I.paramName c_param) 0 offset_params pure I.Lambda { I.lambdaParams = c_param : offset_params ++ value_params, I.lambdaReturnType = replicate (length arr_ts) (I.Prim int64) ++ map I.rowType arr_ts, I.lambdaBody = mkBody offset_stms $ replicate (length arr_ts) (subExpRes offset) ++ I.varsRes (map I.paramName value_params) } let spec = zip3 (repeat $ I.Shape [w]) (repeat 1) blanks results <- letTupExp "partition_res" . 
I.Op $ I.Scatter w (classes : all_offsets ++ arrs) spec write_lam sizes' <- letSubExp "partition_sizes" $ I.BasicOp $ I.ArrayLit (map I.Var sizes) $ I.Prim int64 pure (map I.Var results, [sizes']) where mkOffsetLambdaBody :: [SubExp] -> SubExp -> Int -> [I.LParam SOACS] -> InternaliseM SubExp mkOffsetLambdaBody _ _ _ [] = pure $ constant (-1 :: Int64) mkOffsetLambdaBody sizes c i (p : ps) = do is_this_one <- letSubExp "is_this_one" $ I.BasicOp $ I.CmpOp (CmpEq int64) c $ intConst Int64 $ toInteger i next_one <- mkOffsetLambdaBody sizes c (i + 1) ps this_one <- letSubExp "this_offset" =<< foldBinOp (Add Int64 OverflowUndef) (constant (-1 :: Int64)) (I.Var (I.paramName p) : take i sizes) letSubExp "total_res" =<< eIf (eSubExp is_this_one) (resultBodyM [this_one]) (resultBodyM [next_one]) sizeExpForError :: E.Size -> InternaliseM [ErrorMsgPart SubExp] sizeExpForError e = do e' <- internaliseExp1 "size" e pure ["[", ErrorVal int64 e', "]"] typeExpForError :: E.TypeBase Size u -> InternaliseM [ErrorMsgPart SubExp] typeExpForError (E.Scalar (E.Prim t)) = pure [ErrorString $ prettyText t] typeExpForError (E.Scalar (E.TypeVar _ v args)) = do args' <- concat <$> mapM onArg args pure $ intersperse " " $ ErrorString (prettyText v) : args' where onArg (TypeArgDim d) = sizeExpForError d onArg (TypeArgType t) = typeExpForError t typeExpForError (E.Scalar (E.Record fs)) | Just ts <- E.areTupleFields fs = do ts' <- mapM typeExpForError ts pure $ ["("] ++ intercalate [", "] ts' ++ [")"] | otherwise = do fs' <- mapM onField $ M.toList fs pure $ ["{"] ++ intercalate [", "] fs' ++ ["}"] where onField (k, te) = (ErrorString (prettyText k <> ": ") :) <$> typeExpForError te typeExpForError (E.Array _ shape et) = do shape' <- mconcat <$> mapM sizeExpForError (E.shapeDims shape) et' <- typeExpForError $ Scalar et pure $ shape' ++ et' typeExpForError (E.Scalar (E.Sum cs)) = do cs' <- mapM onConstructor $ M.toList cs pure $ intercalate [" | "] cs' where onConstructor (c, ts) = do ts' <- mapM 
typeExpForError ts pure $ ErrorString ("#" <> prettyText c <> " ") : intercalate [" "] ts' typeExpForError (E.Scalar Arrow {}) = pure ["#"] -- A smart constructor that compacts neighbouring literals for easier -- reading in the IR. errorMsg :: [ErrorMsgPart a] -> ErrorMsg a errorMsg = ErrorMsg . compact where compact [] = [] compact (ErrorString x : ErrorString y : parts) = compact (ErrorString (x <> y) : parts) compact (x : y) = x : compact y errorShape :: [a] -> ErrorMsg a errorShape dims = "[" <> mconcat (intersperse "][" $ map (ErrorMsg . pure . ErrorVal int64) dims) <> "]" futhark-0.25.27/src/Futhark/Internalise/FullNormalise.hs000066400000000000000000000351251475065116200231300ustar00rootroot00000000000000-- | This full normalisation module converts a well-typed, polymorphic, -- module-free Futhark program into an equivalent with only simple expresssions. -- Notably, all non-trivial expression are converted into a list of -- let-bindings to make them simpler, with no nested apply, nested lets... -- This module only performs syntactic operations. -- -- Also, it performs various kinds of desugaring: -- -- * Turns operator sections into explicit lambdas. -- -- * Rewrites BinOp nodes to Apply nodes (&& and || are converted to conditionals). -- -- * Turns `let x [i] = e1` into `let x = x with [i] = e1`. -- -- * Binds all implicit sizes. -- -- * Turns implicit record fields into explicit record fields. -- -- This is currently not done for expressions inside sizes, this processing -- still needed in monomorphisation for now. 
module Futhark.Internalise.FullNormalise (transformProg) where

import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor
import Data.List.NonEmpty qualified as NE
import Data.Map qualified as M
import Data.Text qualified as T
import Futhark.MonadFreshNames
import Language.Futhark
import Language.Futhark.Traversals
import Language.Futhark.TypeChecker.Types

-- | A modifier applied to the bindings introduced while it is active.
-- Attributes ('Att') are propagated to every such pattern binding,
-- while assertions ('Ass') are moved onto the next pattern binding
-- that is introduced.
data BindModifier
  = Ass Exp (Info T.Text) SrcLoc
  | Att (AttrInfo VName)

-- | Wrap an expression in a list of modifiers.  The head of the list
-- ends up outermost.  Assertions are consumed, since they need only be
-- checked once.  Attributes are returned so that they keep applying to
-- later bindings.
applyModifiers :: Exp -> [BindModifier] -> (Exp, [BindModifier])
applyModifiers =
  foldr f . (,[])
  where
    f (Ass ass txt loc) (body, modifs) =
      (Assert ass body txt loc, modifs)
    f (Att attr) (body, modifs) =
      (Attr attr body mempty, Att attr : modifs)

-- | A binding that occurs in the calculation flow.
data Binding
  = PatBind [SizeBinder VName] (Pat StructType) Exp
  | FunBind VName ([TypeParam], [Pat ParamType], Maybe (TypeExp Exp VName), Info ResRetType, Exp)

type NormState = (([Binding], [BindModifier]), VNameSource)

-- | Main monad of this module.  The state has three parts.  First, the
-- 'VNameSource' used to produce new names.  Second, the list of
-- 'Binding's, which accumulates the result (most recent binding first)
-- and behaves a bit like a writer.  Third, the list of 'BindModifier's
-- currently applied to newly introduced bindings.  This list behaves
-- like a reader for attributes and like a state for assertions.  Both
-- kinds live in the same list so that their relative order is kept.
--
-- The reader environment is the base name used for fresh variables
-- (see 'naming').
--
-- A 'MonadState' instance is derived, but the state should be
-- manipulated with care, preferably only through 'addBind',
-- 'addModifier' and 'rmModifier'.
newtype OrderingM a = OrderingM (StateT NormState (Reader String) a)
  deriving
    (Functor, Applicative, Monad, MonadReader String, MonadState NormState)

instance MonadFreshNames OrderingM where
  getNameSource = OrderingM $ gets snd
  putNameSource = OrderingM . modify . second . const

-- | Push a modifier so that it applies to subsequently added bindings.
addModifier :: BindModifier -> OrderingM ()
addModifier = OrderingM . modify . first . second . (:)

-- | Pop the most recently pushed modifier.  Must only be called when
-- the modifier list is non-empty (i.e. after a matching 'addModifier').
rmModifier :: OrderingM ()
rmModifier = OrderingM $ modify $ first $ second tail

-- | Record a binding.
--
-- A pattern binding is wrapped in the currently active modifiers, which
-- consumes any pending assertions.  It also binds the existential sizes
-- reported in the 'AppRes' of its expression.  Function bindings are
-- recorded unchanged.
addBind :: Binding -> OrderingM ()
addBind (PatBind s p e) = do
  modifs <- gets $ snd . fst
  let (e', modifs') = applyModifiers e modifs
  modify $ first $ bimap (PatBind (s <> implicit) p e' :) (const modifs')
  where
    -- Look at the original expression, not the modifier-wrapped one.
    implicit = case e of
      (AppExp _ (Info (AppRes _ ext))) -> map (`SizeBinder` mempty) ext
      _ -> []
addBind b@FunBind {} =
  OrderingM $ modify $ first $ first (b :)

-- | Run an ordering computation in a fresh, empty state with "tmp" as
-- the default name.  Returns the result together with the bindings it
-- introduced, most recent first.  It is an internal error for any
-- modifier to still be active at the end.
runOrdering :: (MonadFreshNames m) => OrderingM a -> m (a, [Binding])
runOrdering (OrderingM m) =
  modifyNameSource $ mod_tup . flip runReader "tmp" . runStateT m . (([], []),)
  where
    mod_tup (a, ((binds, modifs), src)) =
      if null modifs
        then ((a, binds), src)
        else error "not all bind modifiers were freed"

-- | Use the given base name for variables created by the computation.
naming :: String -> OrderingM a -> OrderingM a
naming s = local (const s)

-- From here on, an expression is said to be "final" if it is going to
-- be stored in a let-binding, or is at the end of a body (i.e. after
-- all the lets).

-- | Replace a non-final expression by a let-bound variable whose name
-- is derived from the current 'naming'.  Final expressions are returned
-- unchanged.
nameExp :: Bool -> Exp -> OrderingM Exp
nameExp True e = pure e
nameExp False e = do
  name <- newNameFromString =<< ask
  let ty = typeOf e
      loc = srclocOf e
      pat = Id name (Info ty) loc
  addBind $ PatBind [] pat e
  pure $ Var (qualName name) (Info ty) loc

-- | An evocative name to use when naming subexpressions of the
-- expression bound to this pattern.  Looks through ascriptions,
-- parentheses and attributes, and falls back to "tmp".
patRepName :: Pat t -> String
patRepName (PatAscription p _ _) = patRepName p
patRepName (PatParens p _) = patRepName p
patRepName (PatAttr _ p _) = patRepName p
patRepName (Id v _ _) = baseString v
patRepName _ = "tmp"

-- | A readable name for an expression.  A variable gives its own name;
-- any other expression gives a rendering of the bare expression.
expRepName :: Exp -> String
expRepName (Var v _ _) = prettyString v
expRepName e = "d<{" ++ prettyString (bareExp e) ++ "}>"

-- An evocative name to use when naming arguments to an application.
argRepName :: Exp -> Int -> String
argRepName e i = expRepName e <> "_arg" <> show i

-- Normalise an expression as described in the module introduction,
-- recording the let-bindings it needs in the OrderingM state.  The Bool
-- says whether the expression is final.  Compound expressions that are
-- not final are bound to a fresh variable (via nameExp) and replaced by
-- a reference to it.
getOrdering :: Bool -> Exp -> OrderingM Exp
getOrdering final (Assert ass e txt loc) = do
  ass' <- getOrdering False ass
  l_prev <- OrderingM $ gets $ length . snd . fst
  addModifier $ Ass ass' txt loc
  e' <- getOrdering final e
  l_after <- OrderingM $ gets $ length . snd . fst
  -- If the modifier list is no longer than before, the assertion has
  -- been attached to a binding introduced while normalising the body.
  -- Adding a pattern binding consumes all pending assertions.
  -- Otherwise no such binding was introduced, so we pop the modifier
  -- and wrap the result in the assertion ourselves.
  if l_after <= l_prev
    then pure e'
    else do
      rmModifier
      pure $ Assert ass' e' txt loc
getOrdering final (Attr attr e loc) = do
  -- Propagate the attribute to every binding introduced while
  -- normalising the body, and keep it on the result as well.
  addModifier $ Att attr
  e' <- getOrdering final e
  rmModifier
  pure $ Attr attr e' loc
getOrdering _ e@Literal {} = pure e
getOrdering _ e@IntLit {} = pure e
getOrdering _ e@FloatLit {} = pure e
getOrdering _ e@StringLit {} = pure e
getOrdering _ e@Hole {} = pure e -- NOTE(review): unclear whether holes can still occur here.
getOrdering _ e@Var {} = pure e
-- Parentheses are dropped, and finality passes through them.
getOrdering final (Parens e _) = getOrdering final e
getOrdering final (QualParens _ e _) = getOrdering final e
-- Tuple and record literals are not named themselves, but their
-- components are.
getOrdering _ (TupLit es loc) = do
  es' <- mapM (getOrdering False) es
  pure $ TupLit es' loc
getOrdering _ (RecordLit fs loc) = do
  fs' <- mapM f fs
  pure $ RecordLit fs' loc
  where
    f (RecordFieldExplicit n e floc) = do
      e' <- getOrdering False e
      pure $ RecordFieldExplicit n e' floc
    -- An implicit field {x} is turned into the explicit field {x = x}.
    f (RecordFieldImplicit (L vloc v) t _) =
      f $ RecordFieldExplicit (L vloc (baseName v)) (Var (qualName v) t loc) loc
getOrdering _ (ArrayVal vs t loc) = pure $ ArrayVal vs t loc
-- A one-dimensional array literal whose elements are all primitive
-- literals becomes an ArrayVal.  Otherwise each element is normalised.
getOrdering _ (ArrayLit es ty loc)
  | Just vs <- mapM isLiteral es,
    Info (Array _ (Shape [_]) (Prim t)) <- ty =
      pure $ ArrayVal vs t loc
  | otherwise = do
      es' <- mapM (getOrdering False) es
      pure $ ArrayLit es' ty loc
  where
    isLiteral (Literal v _) = Just v
    isLiteral _ = Nothing
getOrdering _ (Project n e ty loc) = do
  e' <- getOrdering False e
  pure $ Project n e' ty loc
getOrdering _ (Negate e loc) = do
  e' <- getOrdering False e
  pure $ Negate e' loc
getOrdering _ (Not e loc) = do
  e' <- getOrdering False e
  pure $ Not e' loc
getOrdering final (Constr n es ty loc) = do
  es' <- mapM (getOrdering False) es
  nameExp final $ Constr n es' ty loc
-- The order of the statements matters.  Bindings for the new value are
-- introduced first, then those for the slice, and finally those for the
-- updated array.
getOrdering final (Update eb slice eu loc) = do
  eu' <- getOrdering False eu
  slice' <- astMap mapper slice
  eb' <- getOrdering False eb
  nameExp final $ Update eb' slice' eu' loc
  where
    mapper = identityMapper {mapOnExp = getOrdering False}
getOrdering final (RecordUpdate eb ns eu ty loc) = do
  eb' <- getOrdering False eb
  eu' <- getOrdering False eu
  nameExp final $ RecordUpdate eb' ns eu' ty loc
-- A lambda body is normalised in its own, separate state.
getOrdering final (Lambda params body mte ret loc) = do
  body' <- transformBody body
  nameExp final $ Lambda params body' mte ret loc
-- A bare operator section becomes a variable reference.  Left and right
-- sections become explicit lambdas over the missing operand.
getOrdering _ (OpSection qn ty loc) =
  pure $ Var qn ty loc
getOrdering final (OpSectionLeft op ty e (Info (xp, _, xext), Info (yp, yty)) (Info (RetType dims ret), Info exts) loc) = do
  x <- getOrdering False e
  yn <- newNameFromString "y"
  let y = Var (qualName yn) (Info $ toStruct yty) mempty
      ret' = applySubst (pSubst x y) ret
      body = mkApply (Var op ty mempty) [(xext, x), (Nothing, y)] $ AppRes (toStruct ret') exts
  nameExp final $ Lambda [Id yn (Info yty) mempty] body Nothing (Info (RetType dims ret')) loc
  where
    -- Maps the operator's named parameters to the actual operand
    -- expressions, for substitution into the return type.
    pSubst x y vn
      | Named p <- xp, p == vn = Just $ ExpSubst x
      | Named p <- yp, p == vn = Just $ ExpSubst y
      | otherwise = Nothing
getOrdering final (OpSectionRight op ty e (Info (xp, xty), Info (yp, _, yext)) (Info (RetType dims ret)) loc) = do
  xn <- newNameFromString "x"
  y <- getOrdering False e
  let x = Var (qualName xn) (Info $ toStruct xty) mempty
      ret' = applySubst (pSubst x y) ret
      body = mkApply (Var op ty mempty) [(Nothing, x), (yext, y)] $ AppRes (toStruct ret') []
  nameExp final $ Lambda [Id xn (Info xty) mempty] body Nothing (Info (RetType dims ret')) loc
  where
    -- Same parameter-to-operand mapping as for left sections.
    pSubst x y vn
      | Named p <- xp, p == vn = Just $ ExpSubst x
      | Named p <- yp, p == vn = Just $ ExpSubst y
      | otherwise = Nothing
-- (.a.b) becomes \x -> x.a.b, with the types read from the section's
-- function type.
getOrdering final (ProjectSection names (Info ty) loc) = do
  xn <- newNameFromString "x"
  let (xty, RetType dims ret) = case ty of
        Scalar (Arrow _ _ d xty' ret') -> (toParam d xty', ret')
        _ -> error $ "not a function type for project section: " ++ prettyString ty
      x = Var (qualName xn) (Info $ toStruct xty) mempty
      body = foldl project x names
  nameExp final $ Lambda [Id xn (Info xty) mempty] body Nothing (Info (RetType dims ret)) loc
  where
    project e field =
      case typeOf e of
        Scalar (Record fs)
          | Just t <- M.lookup field fs ->
              Project field e (Info t) mempty
        t ->
          error $
            "desugar ProjectSection: type "
              ++ prettyString t
              ++ " does not have field "
              ++ prettyString field
-- (.[slice]) becomes \x -> x[slice].  The slice is normalised in the
-- enclosing state, before the lambda is built.
getOrdering final (IndexSection slice (Info ty) loc) = do
  slice' <- astMap mapper slice
  xn <- newNameFromString "x"
  let (xty, RetType dims ret) = case ty of
        Scalar (Arrow _ _ d xty' ret') -> (toParam d xty', ret')
        _ -> error $ "not a function type for index section: " ++ prettyString ty
      x = Var (qualName xn) (Info $ toStruct xty) mempty
      body = AppExp (Index x slice' loc) (Info (AppRes (toStruct ret) []))
  nameExp final $ Lambda [Id xn (Info xty) mempty] body Nothing (Info (RetType dims ret)) loc
  where
    mapper = identityMapper {mapOnExp = getOrdering False}
-- Type ascriptions are dropped.  NOTE(review): unlike Parens, the inner
-- expression is always treated as non-final here.  As a result, a final
-- ascribed expression still gets bound to a fresh variable.
getOrdering _ (Ascript e _ _) = getOrdering False e
-- Arguments are normalised from last to first, and then the function
-- itself.  Each argument is named after the function and its position.
getOrdering final (AppExp (Apply f args loc) resT) = do
  args' <- NE.reverse <$> mapM onArg (NE.reverse (NE.zip args (NE.fromList [0 ..])))
  f' <- getOrdering False f
  nameExp final $ AppExp (Apply f' args' loc) resT
  where
    onArg ((d, e), i) =
      naming (argRepName f i) $ (d,) <$> getOrdering False e
getOrdering final (Coerce e te t loc) = do
  e' <- getOrdering False e
  nameExp final $ Coerce e' te t loc
getOrdering final (AppExp (Range start stride end loc) resT) = do
  start' <- getOrdering False start
  stride' <- mapM (getOrdering False) stride
  end' <- mapM (getOrdering False) end
  nameExp final $ AppExp (Range start' stride' end' loc) resT
-- A let-binding is flattened.  Its expression is normalised as final,
-- with subexpressions named after the pattern, and recorded as a
-- binding.  Normalisation then continues with the body.
getOrdering final (AppExp (LetPat sizes pat expr body _) _) = do
  expr' <- naming (patRepName pat) $ getOrdering True expr
  addBind $ PatBind sizes pat expr'
  getOrdering final body
-- A local function is recorded as a binding, with its body normalised
-- separately.
getOrdering final (AppExp (LetFun vn (tparams, params, mrettype, rettype, body) e _) _) = do
  body' <- transformBody body
  addBind $ FunBind vn (tparams, params, mrettype, rettype, body')
  getOrdering final e
-- The condition is normalised as final, so it is not bound to a
-- variable of its own.  Each branch becomes a separate body.
getOrdering final (AppExp (If cond et ef loc) resT) = do
  cond' <- getOrdering True cond
  et' <- transformBody et
  ef' <- transformBody ef
  nameExp final $ AppExp (If cond' et' ef' loc) resT
-- The initial value is normalised in the enclosing state, and the for
-- bound and the iterated array are normalised as final.  A while
-- condition and the loop body are normalised as separate bodies.
getOrdering final (AppExp (Loop sizes pat einit form body loc) resT) = do
  einit' <- getOrdering False $ loopInitExp einit
  form' <- case form of
    For ident e -> For ident <$> getOrdering True e
    ForIn fpat e -> ForIn fpat <$> getOrdering True e
    While e -> While <$> transformBody e
  body' <- transformBody body
  nameExp final $ AppExp (Loop sizes pat (LoopInitExplicit einit') form' body' loc) resT
-- Binary operators become applications.  The exceptions are || and &&,
-- which become conditionals whose right operand is a separate body, so
-- it is only evaluated in the branch that needs it.
getOrdering final (AppExp (BinOp (op, oloc) opT (el, Info elp) (er, Info erp) loc) (Info resT)) = do
  expr' <- case (isOr, isAnd) of
    (True, _) -> do
      el' <- naming "or_lhs" $ getOrdering True el
      -- NOTE(review): transformBody runs with its own default name
      -- ("tmp", set in runOrdering), so this naming has no effect.
      er' <- naming "or_rhs" $ transformBody er
      pure $ AppExp (If el' (Literal (BoolValue True) mempty) er' loc) (Info resT)
    (_, True) -> do
      el' <- naming "and_lhs" $ getOrdering True el
      -- NOTE(review): as above, this naming does not reach transformBody.
      er' <- naming "and_rhs" $ transformBody er
      pure $ AppExp (If el' er' (Literal (BoolValue False) mempty) loc) (Info resT)
    (False, False) -> do
      el' <- naming (prettyString op <> "_lhs") $ getOrdering False el
      er' <- naming (prettyString op <> "_rhs") $ getOrdering False er
      pure $ mkApply (Var op opT oloc) [(elp, el'), (erp, er')] resT
  nameExp final expr'
  where
    isOr = baseName (qualLeaf op) == "||"
    isAnd = baseName (qualLeaf op) == "&&"
-- "let dest = src with [slice] = e in body" becomes a binding of dest
-- to an Update of src.
getOrdering final (AppExp (LetWith (Ident dest dty dloc) (Ident src sty sloc) slice e body loc) _) = do
  e' <- getOrdering False e
  slice' <- astMap mapper slice
  addBind $ PatBind [] (Id dest dty dloc) (Update (Var (qualName src) sty sloc) slice' e' loc)
  getOrdering final body
  where
    mapper = identityMapper {mapOnExp = getOrdering False}
getOrdering final (AppExp (Index e slice loc) resT) = do
  e' <- getOrdering False e
  slice' <- astMap mapper slice
  nameExp final $ AppExp (Index e' slice' loc) resT
  where
    mapper = identityMapper {mapOnExp = getOrdering False}
-- The scrutinee is normalised in the enclosing state, and each case
-- body is normalised separately.
getOrdering final (AppExp (Match expr cs loc) resT) = do
  expr' <- getOrdering False expr
  cs' <- mapM f cs
  nameExp final $ AppExp (Match expr' cs' loc) resT
  where
    f (CasePat pat body cloc) = do
      body' <- transformBody body
      pure (CasePat pat body' cloc)

-- Transform a body, e.g. the expression of a valbind, or the branches
-- of an if/match.  This does not run in the caller's OrderingM state.
-- The body is normalised in a completely separate state, and the
-- bindings it introduces are turned back into nested lets.
transformBody :: (MonadFreshNames m) => Exp -> m Exp transformBody e = do (e', pre_eval) <- runOrdering (getOrdering True e) pure $ foldl f e' pre_eval where appRes = case e of (AppExp _ r) -> r _ -> Info $ AppRes (typeOf e) [] f body (PatBind sizes p expr) = AppExp (LetPat sizes p expr body mempty) appRes f body (FunBind vn infos) = AppExp (LetFun vn infos body mempty) appRes transformValBind :: (MonadFreshNames m) => ValBind -> m ValBind transformValBind valbind = do body' <- transformBody $ valBindBody valbind pure $ valbind {valBindBody = body'} -- | Fully normalise top level bindings. transformProg :: (MonadFreshNames m) => [ValBind] -> m [ValBind] transformProg = mapM transformValBind futhark-0.25.27/src/Futhark/Internalise/Lambdas.hs000066400000000000000000000051451475065116200217160ustar00rootroot00000000000000module Futhark.Internalise.Lambdas ( InternaliseLambda, internaliseFoldLambda, internalisePartitionLambda, ) where import Data.Maybe (listToMaybe) import Futhark.IR.SOACS as I import Futhark.Internalise.AccurateSizes import Futhark.Internalise.Monad import Language.Futhark as E -- | A function for internalising lambdas. type InternaliseLambda = E.Exp -> [I.Type] -> InternaliseM ([I.LParam SOACS], I.Body SOACS, [I.Type]) internaliseFoldLambda :: InternaliseLambda -> E.Exp -> [I.Type] -> [I.Type] -> InternaliseM (I.Lambda SOACS) internaliseFoldLambda internaliseLambda lam acctypes arrtypes = do let rowtypes = map I.rowType arrtypes (params, body, rettype) <- internaliseLambda lam $ acctypes ++ rowtypes let rettype' = [ t `I.setArrayShape` I.arrayShape shape | (t, shape) <- zip rettype acctypes ] -- The result of the body must have the exact same shape as the -- initial accumulator. mkLambda params $ ensureResultShape (ErrorMsg [ErrorString "shape of result does not match shape of initial value"]) (srclocOf lam) rettype' =<< bodyBind body -- Given @k@ lambdas, this will return a lambda that returns an -- (k+2)-element tuple of integers. 
The first element is the -- equivalence class ID in the range [0,k]. The remaining are all zero -- except for possibly one element. internalisePartitionLambda :: InternaliseLambda -> Int -> E.Exp -> [I.SubExp] -> InternaliseM (I.Lambda SOACS) internalisePartitionLambda internaliseLambda k lam args = do argtypes <- mapM I.subExpType args let rowtypes = map I.rowType argtypes (params, body, _) <- internaliseLambda lam rowtypes body' <- localScope (scopeOfLParams params) $ lambdaWithIncrement body pure $ I.Lambda params rettype body' where rettype = replicate (k + 2) $ I.Prim int64 result i = map constant $ fromIntegral i : (replicate i 0 ++ [1 :: Int64] ++ replicate (k - i) 0) mkResult _ i | i >= k = pure $ result i mkResult eq_class i = do is_i <- letSubExp "is_i" $ BasicOp $ CmpOp (CmpEq int64) eq_class $ intConst Int64 $ toInteger i letTupExp' "part_res" =<< eIf (eSubExp is_i) (pure $ resultBody $ result i) (resultBody <$> mkResult eq_class (i + 1)) lambdaWithIncrement :: I.Body SOACS -> InternaliseM (I.Body SOACS) lambdaWithIncrement lam_body = runBodyBuilder $ do eq_class <- maybe (intConst Int64 0) resSubExp . listToMaybe <$> bodyBind lam_body subExpsRes <$> mkResult eq_class 0 futhark-0.25.27/src/Futhark/Internalise/LiftLambdas.hs000066400000000000000000000164401475065116200225350ustar00rootroot00000000000000-- | Lambda-lifting of typed, monomorphic Futhark programs without -- modules. After this pass, the program will no longer contain any -- 'LetFun's or 'Lambda's. 
module Futhark.Internalise.LiftLambdas (transformProg) where

import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor
import Data.Bitraversable
import Data.Foldable
import Data.List (partition)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.IR.Pretty ()
import Futhark.MonadFreshNames
import Futhark.Util (nubOrd)
import Language.Futhark
import Language.Futhark.Traversals

-- | Reader environment of the lifting pass.
data Env = Env
  { -- | For each local function that has already been lifted: given the
    -- type at a use site, build the expression that replaces a
    -- reference to it (see 'liftFunction' and the 'Var' case of
    -- 'transformExp').
    envReplace :: M.Map VName (StructType -> Exp),
    -- | Types of the variables (including size variables, typed as i64)
    -- bound by enclosing parameters and patterns; used to find the free
    -- variables of a function being lifted.
    envVtable :: M.Map VName StructType
  }

initialEnv :: Env
initialEnv = Env mempty mempty

-- | State of the lifting pass.
data LiftState = State
  { stateNameSource :: VNameSource,
    -- | Top-level bindings produced so far, most recent first.
    stateValBinds :: [ValBind],
    -- | Names bound at the top level so far; starts out as the
    -- intrinsics and is extended by 'addValBind'.
    stateGlobal :: S.Set VName
  }

initialState :: VNameSource -> LiftState
initialState src = State src mempty $ S.fromList $ M.keys intrinsics

newtype LiftM a = LiftM (ReaderT Env (State LiftState) a)
  deriving (Functor, Applicative, Monad, MonadReader Env, MonadState LiftState)

instance MonadFreshNames LiftM where
  putNameSource src = modify $ \s -> s {stateNameSource = src}
  getNameSource = gets stateNameSource

-- | Run the pass, returning the produced bindings in the order in which
-- they were added, together with the final name source.
runLiftM :: VNameSource -> LiftM () -> ([ValBind], VNameSource)
runLiftM src (LiftM m) =
  let s = execState (runReaderT m initialEnv) (initialState src)
   in (reverse (stateValBinds s), stateNameSource s)

-- | Emit a top-level binding and record the names it binds as global.
addValBind :: ValBind -> LiftM ()
addValBind vb = modify $ \s ->
  s
    { stateValBinds = vb : stateValBinds s,
      stateGlobal = foldl' (flip S.insert) (stateGlobal s) (valBindBound vb)
    }

-- | Within the given action, references to @v@ are rewritten with @e@.
replacing :: VName -> (StructType -> Exp) -> LiftM a -> LiftM a
replacing v e = local $ \env ->
  env {envReplace = M.insert v e $ envReplace env}

-- | Bring function parameters, plus the given size names (as i64), into
-- the variable table.
bindingParams :: [VName] -> [Pat ParamType] -> LiftM a -> LiftM a
bindingParams sizes params = local $ \env ->
  env
    { envVtable =
        M.fromList (map (second toStruct) (foldMap patternMap params) <> map (,i64) sizes)
          <> envVtable env
    }
  where
    i64 = Scalar $ Prim $ Signed Int64

-- | Like 'bindingParams', but for a single let-bound pattern.
bindingLetPat :: [VName] -> Pat StructType -> LiftM a -> LiftM a
bindingLetPat sizes pat = local $ \env ->
  env
    { envVtable =
        M.fromList (map (second toStruct) (patternMap pat) <> map (,i64) sizes)
          <> envVtable env
    }
  where
    i64 = Scalar $ Prim $ Signed Int64

-- | Bring the variables bound by a loop form (the @for@ index or the
-- @for-in@ pattern) into the variable table; @while@ binds nothing.
bindingForm :: LoopFormBase Info VName -> LiftM a -> LiftM a
bindingForm (For i _) = bindingLetPat [] (Id (identName i) (identType i) mempty)
bindingForm (ForIn p _) = bindingLetPat [] p
bindingForm While {} = id

-- | Make a type usable as a return type by marking it 'Nonunique'.
toRet :: TypeBase Size u -> TypeBase Size Uniqueness
toRet = second (const Nonunique)

-- | Emit a new top-level function @fname@ whose parameters are the free
-- variables of the local function (those that are sizes of other types
-- first), followed by its original parameters.  Returns a function
-- that, given the type at a use site, builds the application of the
-- lifted function to those free variables.
liftFunction :: VName -> [TypeParam] -> [Pat ParamType] -> ResRetType -> Exp -> LiftM (StructType -> Exp)
liftFunction fname tparams params (RetType dims ret) funbody = do
  -- Find free variables
  vtable <- asks envVtable
  -- Only variables present in the variable table are considered free;
  -- everything else is global.
  let isFree v = (v,) <$> M.lookup v vtable
      withTypes = mapMaybe isFree . S.toList . fvVars

  let free =
        let immediate_free = withTypes $ freeInExp funbody
            -- Sizes mentioned in the types of free variables, parameters
            -- and the return type must also be passed along.
            sizes_in_free = foldMap (freeInType . snd) immediate_free
            sizes =
              withTypes $
                sizes_in_free
                  <> foldMap freeInPat params
                  <> freeInType ret
         in nubOrd $ immediate_free <> sizes

      -- Those parameters that correspond to sizes must come first.
      sizes_in_types =
        foldMap freeInType (toStruct ret : map snd free ++ map patternStructType params)
      isSize (v, _) = v `S.member` fvVars sizes_in_types
      (free_dims, free_nondims) = partition isSize free

      free_ts = map (second (`setUniqueness` Nonunique)) $ free_dims ++ free_nondims

  addValBind $
    ValBind
      { valBindName = fname,
        valBindTypeParams = tparams,
        valBindParams = map mkParam free_ts ++ params,
        valBindRetDecl = Nothing,
        valBindRetType = Info (RetType dims ret),
        valBindBody = funbody,
        valBindDoc = Nothing,
        valBindAttrs = mempty,
        valBindLocation = mempty,
        valBindEntryPoint = Nothing
      }

  pure $ \orig_type ->
    apply
      orig_type
      (Var (qualName fname) (Info (augType free_ts orig_type)) mempty)
      $ free_dims ++ free_nondims
  where
    mkParam (v, t) = Id v (Info (toParam Observe t)) mempty
    freeVar (v, t) = Var (qualName v) (Info t) mempty
    -- Type of the lifted function once the given free variables are
    -- still to be supplied, ending in the original function type.
    augType rem_free orig_type =
      funType (map mkParam rem_free) $ RetType [] $ toRet orig_type

    -- Apply the free variables one at a time, annotating each partial
    -- application with the type of what remains.
    apply :: StructType -> Exp -> [(VName, StructType)] -> Exp
    apply _ f [] = f
    apply orig_type f (p : rem_ps) =
      let inner_ret = AppRes (augType rem_ps orig_type) mempty
          inner = mkApply f [(Nothing, freeVar p)] inner_ret
       in apply orig_type inner rem_ps

transformSubExps :: ASTMapper LiftM
transformSubExps = identityMapper {mapOnExp = transformExp}

-- | Transform the size expressions inside a type.
transformType :: TypeBase Exp u -> LiftM (TypeBase Exp u)
transformType = bitraverse transformExp pure

transformPat :: PatBase Info VName (TypeBase Exp u) -> LiftM (PatBase Info VName (TypeBase Exp u))
transformPat = traverse transformType

transformExp :: Exp -> LiftM Exp
-- A local function: transform its body, lift it under a fresh
-- "lifted_" name, and replace references to it in the let body.
transformExp (AppExp (LetFun fname (tparams, params, _, Info ret, funbody) body _) _) = do
  funbody' <- bindingParams (map typeParamName tparams) params $ transformExp funbody
  fname' <- newVName $ "lifted_" ++ baseString fname
  lifted_call <- liftFunction fname' tparams params ret funbody'
  replacing fname lifted_call $ transformExp body
-- An anonymous function: lift it and use the call to the lifted
-- function (at the lambda's own type) in its place.
transformExp e@(Lambda params body _ (Info ret) _) = do
  body' <- bindingParams [] params $ transformExp body
  fname <- newVName "lifted_lambda"
  liftFunction fname [] params ret body' <*> pure (typeOf e)
transformExp (AppExp (LetPat sizes pat e body loc) appres) = do
  e' <- transformExp e
  pat' <- transformPat pat
  body' <- bindingLetPat (map sizeName sizes) pat' $ transformExp body
  pure $ AppExp (LetPat sizes pat' e' body' loc) appres
transformExp (AppExp (Match e cases loc) appres) = do
  e' <- transformExp e
  cases' <- mapM transformCase cases
  pure $ AppExp (Match e' cases' loc) appres
  where
    transformCase (CasePat case_pat case_e case_loc) =
      CasePat case_pat
        <$> bindingLetPat [] case_pat (transformExp case_e)
        <*> pure case_loc
transformExp (AppExp (Loop sizes pat args form body loc) appres) = do
  -- The initial value is outside the scope of the loop parameters.
  args' <- transformExp $ loopInitExp args
  bindingParams sizes [pat] $ do
    form' <- astMap transformSubExps form
    body' <- bindingForm form' $ transformExp body
    pure $ AppExp (Loop sizes pat (LoopInitExplicit args') form' body' loc) appres
transformExp (Var v (Info t) loc) = do
  t' <- transformType t
  -- Note that function-typed variables can only occur in expressions,
  -- not in other places where VNames/QualNames can occur.
  asks $ maybe (Var v (Info t') loc) ($ t') . M.lookup (qualLeaf v) . envReplace
transformExp e = astMap transformSubExps e

-- | Transform one top-level binding and emit it after any functions
-- lifted out of it.
transformValBind :: ValBind -> LiftM ()
transformValBind vb = do
  e <-
    bindingParams (map typeParamName $ valBindTypeParams vb) (valBindParams vb) $
      transformExp (valBindBody vb)
  addValBind $ vb {valBindBody = e}

{-# NOINLINE transformProg #-}

-- | Perform the transformation.
-- Every top-level binding is transformed in order; functions lifted out
-- of a binding are emitted before that binding.
transformProg :: (MonadFreshNames m) => [ValBind] -> m [ValBind]
transformProg vbinds =
  modifyNameSource $ \namesrc ->
    runLiftM namesrc $ mapM_ transformValBind vbinds
futhark-0.25.27/src/Futhark/Internalise/Monad.hs000066400000000000000000000146571475065116200214210ustar00rootroot00000000000000
{-# LANGUAGE TypeFamilies #-}

-- | The monad in which source programs are internalised to SOACS,
-- together with its environment and state.
module Futhark.Internalise.Monad
  ( InternaliseM,
    runInternaliseM,
    throwError,
    VarSubsts,
    InternaliseEnv (..),
    FunInfo,
    substitutingVars,
    lookupSubst,
    addOpaques,
    addFunDef,
    lookupFunction,
    lookupConst,
    bindFunction,
    bindConstant,
    assert,

    -- * Convenient reexports
    module Futhark.Tools,
  )
where

import Control.Monad
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.List (find)
import Data.Map.Strict qualified as M
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Tools

-- | What is recorded for an internalised function by 'bindFunction'.
-- The last component computes the return types (with aliasing
-- information) from the actual arguments and their types; it returns
-- 'Nothing' presumably when the arguments are not acceptable (TODO
-- confirm at the call sites).
type FunInfo =
  ( [VName],
    [DeclType],
    [FParam SOACS],
    [(SubExp, Type)] -> Maybe [(DeclExtType, RetAls)]
  )

type FunTable = M.Map VName FunInfo

-- | A mapping from external variable names to the corresponding
-- internalised subexpressions.
type VarSubsts = M.Map VName [SubExp]

data InternaliseEnv = InternaliseEnv
  { -- | Substitutions for source variables (see 'substitutingVars').
    envSubsts :: VarSubsts,
    -- | When false, assertions built with 'assert' are skipped.
    envDoBoundsChecks :: Bool,
    -- | The flag passed to 'runInternaliseM'.
    envSafe :: Bool,
    -- | Attributes in effect; 'assert' derives the attributes of the
    -- assertions it builds from these.
    envAttrs :: Attrs
  }

data InternaliseState = InternaliseState
  { stateNameSource :: VNameSource,
    -- | Functions registered with 'bindFunction'.
    stateFunTable :: FunTable,
    -- | Substitutions for constants with several values, recorded by
    -- 'bindConstant'.
    stateConstSubsts :: VarSubsts,
    -- | Function definitions added so far, most recent first.
    stateFuns :: [FunDef SOACS],
    -- | Opaque types registered with 'addOpaques'.
    stateTypes :: OpaqueTypes
  }

newtype InternaliseM a
  = InternaliseM
      (BuilderT SOACS (ReaderT InternaliseEnv (State InternaliseState)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader InternaliseEnv,
      MonadState InternaliseState,
      MonadFreshNames,
      HasScope SOACS
    )

-- Internalisation has to deal with the risk of multiple binding of
-- the same variable (although always of the same type) in the
-- program; in particular this might imply shadowing a constant.  The
-- LocalScope instance for BuilderT does not handle this properly (and
-- doing so would make it slower).  So we remove already-known
So we remove already-known -- variables before passing the scope on. instance LocalScope SOACS InternaliseM where localScope scope (InternaliseM m) = do old_scope <- askScope InternaliseM $ localScope (scope `M.difference` old_scope) m instance MonadFreshNames (State InternaliseState) where getNameSource = gets stateNameSource putNameSource src = modify $ \s -> s {stateNameSource = src} instance MonadBuilder InternaliseM where type Rep InternaliseM = SOACS mkExpDecM pat e = InternaliseM $ mkExpDecM pat e mkBodyM stms res = InternaliseM $ mkBodyM stms res mkLetNamesM pat e = InternaliseM $ mkLetNamesM pat e addStms = InternaliseM . addStms collectStms (InternaliseM m) = InternaliseM $ collectStms m runInternaliseM :: (MonadFreshNames m) => Bool -> InternaliseM () -> m (OpaqueTypes, Stms SOACS, [FunDef SOACS]) runInternaliseM safe (InternaliseM m) = modifyNameSource $ \src -> let ((_, consts), s) = runState (runReaderT (runBuilderT m mempty) newEnv) (newState src) in ( (stateTypes s, consts, reverse $ stateFuns s), stateNameSource s ) where newEnv = InternaliseEnv { envSubsts = mempty, envDoBoundsChecks = True, envSafe = safe, envAttrs = mempty } newState src = InternaliseState { stateNameSource = src, stateFunTable = mempty, stateConstSubsts = mempty, stateFuns = mempty, stateTypes = mempty } substitutingVars :: VarSubsts -> InternaliseM a -> InternaliseM a substitutingVars substs = local $ \env -> env {envSubsts = substs <> envSubsts env} lookupSubst :: VName -> InternaliseM (Maybe [SubExp]) lookupSubst v = do env_substs <- asks $ M.lookup v . envSubsts const_substs <- gets $ M.lookup v . stateConstSubsts pure $ env_substs `mplus` const_substs -- | Add opaque types. If the types are already known, they will not -- be added. 
addOpaques :: OpaqueTypes -> InternaliseM ()
addOpaques ts@(OpaqueTypes ts') = modify register
  where
    register s =
      -- TODO: handle this better (#1960)
      case find (conflictsWith (stateTypes s)) ts' of
        Nothing -> s {stateTypes = stateTypes s <> ts}
        Just (x, _) -> error $ "addOpaques: multiple incompatible definitions of type " <> nameToString x
    -- A new definition conflicts if an existing entry has the same name
    -- but a different definition.
    conflictsWith (OpaqueTypes known) (v, def) =
      or [v == v_old && def /= v_def | (v_old, v_def) <- known]

-- | Add a function definition to the program being constructed.
addFunDef :: FunDef SOACS -> InternaliseM ()
addFunDef fd = modify (\s -> s {stateFuns = fd : stateFuns s})

-- | Retrieve the information registered for a function with
-- 'bindFunction'; it is an internal error if there is none.
lookupFunction :: VName -> InternaliseM FunInfo
lookupFunction fname = do
  found <- gets $ M.lookup fname . stateFunTable
  case found of
    Just info -> pure info
    Nothing -> error $ "Internalise.lookupFunction: Function '" ++ prettyString fname ++ "' not found."

-- | How to refer to a constant: through its substitution if it has one,
-- otherwise as a plain variable if it is in scope, and 'Nothing' if
-- neither applies.
lookupConst :: VName -> InternaliseM (Maybe [SubExp])
lookupConst fname = do
  in_scope <- asksScope (fname `M.member`)
  subst <- lookupSubst fname
  pure $ case subst of
    Just ses -> Just ses
    Nothing
      | in_scope -> Just [Var fname]
      | otherwise -> Nothing

-- | Add a function definition and register its information.
bindFunction :: VName -> FunDef SOACS -> FunInfo -> InternaliseM ()
bindFunction fname fd info =
  addFunDef fd
    >> modify (\s -> s {stateFunTable = M.insert fname info (stateFunTable s)})

-- | Inline the body of a constant's definition at the top level.  A
-- single result is bound directly to the constant's name; otherwise the
-- results (minus the leading shape context of the return type) are
-- recorded as a substitution for it.
bindConstant :: VName -> FunDef SOACS -> InternaliseM ()
bindConstant cname fd = do
  let body = funDefBody fd
  addStms $ bodyStms body
  case map resSubExp (bodyResult body) of
    [se] -> letBindNames [cname] $ BasicOp $ SubExp se
    ses -> do
      let num_ctx = length $ shapeContext $ map fst $ funDefRetType fd
      modify $ \s ->
        s {stateConstSubsts = M.insert cname (drop num_ctx ses) (stateConstSubsts s)}

-- | Construct an 'Assert' statement, but taking attributes into
-- account.  Always use this function, and never construct 'Assert'
-- directly in the internaliser!
assert ::
  String ->
  SubExp ->
  ErrorMsg SubExp ->
  SrcLoc ->
  InternaliseM Certs
assert desc se msg loc = assertingOne emitAssert
  where
    emitAssert = do
      attrs <- asks (attrsForAssert . envAttrs)
      attributing attrs . letExp desc . BasicOp $ Assert se msg (loc, mempty)

-- | Run the given action when 'envDoBoundsChecks' is set; otherwise
-- produce no certificates ('mempty') without running it.
asserting ::
  InternaliseM Certs ->
  InternaliseM Certs
asserting m =
  asks envDoBoundsChecks >>= \checks ->
    if checks then m else pure mempty

-- | Like 'asserting', for an action producing a single certificate
-- variable.
assertingOne ::
  InternaliseM VName ->
  InternaliseM Certs
assertingOne m = asserting $ fmap (\v -> Certs [v]) m
futhark-0.25.27/src/Futhark/Internalise/Monomorphise.hs000066400000000000000000001251671475065116200230410ustar00rootroot00000000000000
-- | This monomorphization module converts a well-typed, polymorphic,
-- module-free Futhark program into an equivalent monomorphic program.
--
-- This pass also does a few other simplifications to make the job of
-- subsequent passes easier.  Specifically, it does the following:
--
-- * Turns operator sections into explicit lambdas.
--
-- * Converts applications of intrinsic SOACs into SOAC AST nodes
--   (Map, Reduce, etc).
--
-- * Elides functions that are not reachable from an entry point (this
--   is a side effect of the monomorphisation algorithm, which uses
--   the entry points as roots).
--
-- * Rewrites BinOp nodes to Apply nodes.
--
-- * Replaces all size expressions by constants or variables;
--   complex expressions replaced by variables are calculated in
--   let bindings, or replaced by size parameters if they occur in
--   arguments.
--
-- Note that these changes are unfortunately not visible in the AST
-- representation.
module Futhark.Internalise.Monomorphise (transformProg) where

import Control.Monad
import Control.Monad.Identity
import Control.Monad.RWS (MonadReader (..), MonadWriter (..), RWST, asks, runRWST)
import Control.Monad.State
import Control.Monad.Writer (Writer, runWriter, runWriterT)
import Data.Bifunctor
import Data.Bitraversable
import Data.Foldable
import Data.List (partition)
import Data.List.NonEmpty qualified as NE
import Data.Map.Strict qualified as M
import Data.Maybe (isJust, isNothing)
import Data.Sequence qualified as Seq
import Data.Set qualified as S
import Futhark.MonadFreshNames
import Futhark.Util (nubOrd, topologicalSort)
import Futhark.Util.Pretty
import Language.Futhark
import Language.Futhark.Traversals
import Language.Futhark.TypeChecker.Types

-- | The type @i64@, used for sizes and size arguments.
i64 :: TypeBase dim als
i64 = Scalar (Prim (Signed Int64))

-- The monomorphization monad reads 'PolyBinding's and writes
-- 'ValBind's.  The 'TypeParam's in the 'ValBind's can only be size
-- parameters.
newtype PolyBinding
  = PolyBinding
      ( Maybe EntryPoint,
        VName,
        [TypeParam],
        [Pat ParamType],
        ResRetType,
        Exp,
        [AttrInfo VName],
        SrcLoc
      )

-- | To deduplicate size expressions, we want a looser notion of
-- equality than the strict syntactical equality provided by the Eq
-- instance on Exp.  This newtype wrapper provides such a looser notion
-- of equality.
newtype ReplacedExp = ReplacedExp {unReplaced :: Exp}
  deriving (Show)

instance Pretty ReplacedExp where
  pretty = pretty . unReplaced

-- | Two expressions are equal when 'similarExps' relates them and every
-- pair of subexpressions it returns is again equal in this sense.
instance Eq ReplacedExp where
  ReplacedExp e1 == ReplacedExp e2 =
    case similarExps e1 e2 of
      Just es -> and [ReplacedExp a == ReplacedExp b | (a, b) <- es]
      Nothing -> False

-- | Size expressions paired with the name of the variable replacing them.
type ExpReplacements = [(ReplacedExp, VName)]

-- | Keep only the replacements whose free variables (ignoring
-- intrinsics) are all contained in the given scope.
canCalculate :: S.Set VName -> ExpReplacements -> ExpReplacements
canCalculate scope mapping =
  [ entry
    | entry@(ReplacedExp e, _) <- mapping,
      S.filter (\vn -> baseTag vn > maxIntrinsicTag) (fvVars (freeInExp e))
        `S.isSubsetOf` scope
  ]

-- Replace some expressions by a parameter.
-- Any expression that is 'ReplacedExp'-equal to a key of the mapping is
-- replaced by a variable of the same type; the whole expression is
-- tried before its subexpressions.
expReplace :: ExpReplacements -> Exp -> Exp
expReplace mapping e
  | Just vn <- lookup (ReplacedExp e) mapping =
      Var (qualName vn) (Info $ typeOf e) (srclocOf e)
expReplace mapping e = runIdentity $ astMap mapper e
  where
    mapper = identityMapper {mapOnExp = pure . expReplace mapping}

-- Construct an Assert expression that checks that the names (values)
-- in the mapping have the same value as the expression they
-- represent.  This is injected into entry points, where we cannot
-- otherwise trust the input.  XXX: the error message generated from
-- this is not great; we should rework it eventually.
entryAssert :: ExpReplacements -> Exp -> Exp
entryAssert [] body = body
entryAssert (x : xs) body =
  Assert (foldl logAnd (cmpExp x) $ map cmpExp xs) body errmsg (srclocOf body)
  where
    errmsg = Info "entry point arguments have invalid sizes."
    bool = Scalar $ Prim Bool
    -- '&&' combines two booleans, whereas '==' is applied to two i64
    -- sizes, so each operator gets its own function type.
    andType = foldFunType [bool, bool] $ RetType [] bool
    eqType = foldFunType [i64, i64] $ RetType [] bool
    andop = Var (qualName (intrinsicVar "&&")) (Info andType) mempty
    eqop = Var (qualName (intrinsicVar "==")) (Info eqType) mempty
    logAnd x' y = mkApply andop [(Nothing, x'), (Nothing, y)] $ AppRes bool []
    -- Compare a replaced size expression with the variable standing for it.
    cmpExp (ReplacedExp x', y) =
      mkApply eqop [(Nothing, x'), (Nothing, y')] $ AppRes bool []
      where
        y' = Var (qualName y) (Info i64) mempty

-- Monomorphization environment mapping names of polymorphic functions
-- to a representation of their corresponding function bindings.
data Env = Env
  { -- | Functions that may still be instantiated at specific types;
    -- consulted by 'lookupFun' and pruned by 'withMono'.
    envPolyBindings :: M.Map VName PolyBinding,
    -- | Names brought into scope with 'withArgs'.
    envScope :: S.Set VName,
    -- | Names in scope globally; part of what 'askScope' returns.
    envGlobalScope :: S.Set VName,
    -- | Size expressions that have been turned into parameters; checked
    -- by 'replaceExp' before inventing a new name.
    envParametrized :: ExpReplacements
  }

-- | Field-wise combination; the left operand takes precedence
-- (left-biased map union, list prepending).
instance Semigroup Env where
  Env pb1 sc1 gs1 pr1 <> Env pb2 sc2 gs2 pr2 =
    Env (pb1 <> pb2) (sc1 <> sc2) (gs1 <> gs2) (pr1 <> pr2)

instance Monoid Env where
  mempty = Env mempty mempty mempty mempty

-- | Run an action with the given environment combined in front of the
-- current one.
localEnv :: Env -> MonoM a -> MonoM a
localEnv env = local (env <>)

-- | Run an action with no pending size replacements and with
-- 'envScope' and 'envParametrized' cleared.  Afterwards the previously
-- pending replacements are restored; any produced by the action are
-- discarded.
isolateNormalisation :: MonoM a -> MonoM a
isolateNormalisation m = do
  prevRepl <- get
  put mempty
  ret <- local (\env -> env {envScope = mempty, envParametrized = mempty}) m
  put prevRepl
  pure ret

-- | These now have monomorphic types in the given action. This is
-- used to handle shadowing.
withMono :: [VName] -> MonoM a -> MonoM a
withMono [] = id
withMono vs = local $ \env ->
  env {envPolyBindings = M.filterWithKey keep (envPolyBindings env)}
  where
    keep v _ = v `notElem` vs

-- | Bring names into 'envScope' for the duration of an action.
withArgs :: S.Set VName -> MonoM a -> MonoM a
withArgs args = localEnv $ mempty {envScope = args}

-- | Make size replacements available as parameters for an action.
withParams :: ExpReplacements -> MonoM a -> MonoM a
withParams params = localEnv $ mempty {envParametrized = params}

-- The monomorphization monad.  It reads an 'Env', writes the
-- instantiated bindings (paired with the name of the function they
-- instantiate), threads the pending size replacements together with the
-- name source, and keeps the cache of instantiations ('Lifts') in the
-- inner state monad.
newtype MonoM a
  = MonoM
      ( RWST
          Env
          (Seq.Seq (VName, ValBind))
          (ExpReplacements, VNameSource)
          (State Lifts)
          a
      )
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader Env,
      MonadWriter (Seq.Seq (VName, ValBind))
    )

instance MonadFreshNames MonoM where
  getNameSource = MonoM $ gets snd
  putNameSource = MonoM . modify . second . const

-- | The monad's 'MonadState' view is the list of pending size
-- replacements.
instance MonadState ExpReplacements MonoM where
  get = MonoM $ gets fst
  put = MonoM . modify . first . const

runMonoM :: VNameSource -> MonoM a -> ((a, Seq.Seq (VName, ValBind)), VNameSource)
runMonoM src (MonoM m) = ((a, defs), src')
  where
    (a, (_, src'), defs) = evalState (runRWST m mempty (mempty, src)) mempty

-- | Find the polymorphic binding of a function, if there is one.
lookupFun :: VName -> MonoM (Maybe PolyBinding)
lookupFun vn = asks $ M.lookup vn . envPolyBindings

-- | All names currently in scope: local arguments, global names,
-- polymorphic functions, and names of instantiated functions.
askScope :: MonoM (S.Set VName)
askScope = do
  scope <- asks envScope
  scope' <- asks $ S.union scope . envGlobalScope
  scope'' <- asks $ S.union scope' . M.keysSet . envPolyBindings
  S.union scope'' . S.fromList . map (fst . snd) <$> getLifts

-- | The names in the given set (ignoring intrinsics) that are not
-- currently in scope, i.e. the names the set introduces.
askIntros :: S.Set VName -> MonoM (S.Set VName)
askIntros argset =
  (S.filter notIntrisic argset `S.difference`) <$> askScope
  where
    notIntrisic vn = baseTag vn > maxIntrinsicTag

-- | Gets and removes expressions that could not be calculated when
-- the arguments set will be unscoped.
-- This should be called without argset in scope, for good detection of intros.
parametrizing :: S.Set VName -> MonoM ExpReplacements
parametrizing argset = do
  intros <- askIntros argset
  let usesIntros = not . S.disjoint intros . fvVars . freeInExp
  (params, nxtBind) <- gets $ partition (usesIntros . unReplaced . fst)
  put nxtBind
  pure params

-- | Wrap @body@ in let-bindings that compute each replaced size from the
-- expression it stands for, with the replacements ordered by
-- 'topologicalSort' on their dependencies.  Each new let records the
-- bound size as existential (under a fresh name) in its result type.
calculateDims :: Exp -> ExpReplacements -> MonoM Exp
calculateDims body repl =
  foldCalc top_repl $ expReplace top_repl body
  where
    depends (a, _) (b, _) = unReplaced b `elem` subExps (unReplaced a)
    top_repl = topologicalSort depends repl

    -- Insert one let-binding per replacement.
    foldCalc [] body' = pure body'
    foldCalc ((dim, vn) : repls) body' = do
      reName <- newName vn
      let expr = expReplace repls $ unReplaced dim
          subst vn' =
            if vn' == vn
              then Just $ ExpSubst $ sizeFromName (qualName reName) mempty
              else Nothing
          appRes = case body' of
            (AppExp _ (Info (AppRes ty ext))) -> Info $ AppRes (applySubst subst ty) (reName : ext)
            e -> Info $ AppRes (applySubst subst $ typeOf e) [reName]
      foldCalc repls $
        AppExp
          ( LetPat
              []
              (Id vn (Info i64) (srclocOf expr))
              expr
              body'
              mempty
          )
          appRes

-- | Bind, around @body@, the pending sizes that depend on names in
-- @argset@ and that can be computed in the current scope plus @argset@.
unscoping :: S.Set VName -> Exp -> MonoM Exp
unscoping argset body = do
  localDims <- parametrizing argset
  scope <- S.union argset <$> askScope
  calculateDims body $ canCalculate scope localDims

-- | Run an action with @argset@ in scope, then apply 'unscoping' to its
-- result.
scoping :: S.Set VName -> MonoM Exp -> MonoM Exp
scoping argset m =
  withArgs argset m >>= unscoping argset

-- Given instantiated type of function, produce size arguments.
type InferSizeArgs = StructType -> MonoM [Exp]

-- | The integer encodes an equivalence class, so we can keep
-- track of sizes that are statically identical.
data MonoSize
  = MonoKnown Int
  | MonoAnon Int
  deriving (Eq, Show)

instance Pretty MonoSize where
  pretty (MonoKnown i) = "?" <> pretty i
  pretty (MonoAnon i) = "??" <> pretty i

instance Pretty (Shape MonoSize) where
  pretty (Shape ds) = mconcat (map (brackets . pretty) ds)

-- The kind of type relative to which we monomorphise.  What is most
-- important to us is not the specific dimensions, but merely whether
-- they are known or anonymous/local.
type MonoType = TypeBase MonoSize NoUniqueness

-- | Abstract a type into a 'MonoType': equal size expressions share a
-- class number, and sizes mentioning locally bound names are anonymous.
monoType :: TypeBase Size als -> MonoType
monoType = noExts . (`evalState` (0, mempty)) . traverseDims onDim . toStruct
  where
    -- Remove exts from return types because we don't use them anymore.
    noExts :: TypeBase MonoSize u -> TypeBase MonoSize u
    noExts (Array u shape t) = Array u shape $ noExtsScalar t
    noExts (Scalar t) = Scalar $ noExtsScalar t
    noExtsScalar (Record fs) = Record $ M.map noExts fs
    noExtsScalar (Sum fs) = Sum $ M.map (map noExts) fs
    noExtsScalar (Arrow as p d t1 (RetType _ t2)) =
      Arrow as p d (noExts t1) (RetType [] (noExts t2))
    noExtsScalar t = t
    onDim bound _ d
      -- A locally bound size.
      | any (`S.member` bound) $ fvVars $ freeInExp d = do
          (i, m) <- get
          case M.lookup d m of
            Just prev ->
              pure $ MonoAnon prev
            Nothing -> do
              put (i + 1, M.insert d i m)
              pure $ MonoAnon i
    onDim _ _ d = do
      (i, m) <- get
      case M.lookup d m of
        Just prev ->
          pure $ MonoKnown prev
        Nothing -> do
          put (i + 1, M.insert d i m)
          pure $ MonoKnown i

-- Mapping from function name and instance list to a new function name in case
-- the function has already been instantiated with those concrete types.
type Lifts = [((VName, MonoType), (VName, InferSizeArgs))]

getLifts :: MonoM Lifts
getLifts = MonoM $ lift get

modifyLifts :: (Lifts -> Lifts) -> MonoM ()
modifyLifts = MonoM . lift . modify

addLifted :: VName -> MonoType -> (VName, InferSizeArgs) -> MonoM ()
addLifted fname il liftf =
  modifyLifts (((fname, il), liftf) :)

lookupLifted :: VName -> MonoType -> MonoM (Maybe (VName, InferSizeArgs))
lookupLifted fname t = lookup (fname, t) <$> getLifts

-- | The base name used for a variable replacing a size expression.
sizeVarName :: Exp -> String
sizeVarName e = "d<{" <> prettyString (bareExp e) <> "}>"

-- | Creates a new expression replacement if needed, this always produces normalised sizes.
-- (e.g. single variable or constant)
replaceExp :: Exp -> MonoM Exp
replaceExp e =
  case maybeNormalisedSize e of
    Just e' -> pure e'
    Nothing -> do
      let e' = ReplacedExp e
      prev <- gets $ lookup e'
      prev_param <- asks $ lookup e' . envParametrized
      -- An existing parameter takes precedence over a pending
      -- replacement; only if neither exists is a new name invented.
      case prev_param `mplus` prev of
        Just vn -> pure $ sizeFromName (qualName vn) (srclocOf e)
        Nothing -> do
          vn <- newNameFromString $ sizeVarName e
          modify ((e', vn) :)
          pure $ sizeFromName (qualName vn) (srclocOf e)
  where
    -- Avoid replacing some 'already normalised' sizes that are just surrounded by parentheses.
    maybeNormalisedSize e'
      | Just e'' <- stripExp e' = maybeNormalisedSize e''
    maybeNormalisedSize (Var qn _ loc) = Just $ sizeFromName qn loc
    maybeNormalisedSize (IntLit v _ loc) = Just $ IntLit v (Info i64) loc
    maybeNormalisedSize _ = Nothing

-- | Transform a reference to a function, instantiating (and caching) a
-- monomorphic copy of a polymorphic function at the given type, and
-- applying the resulting function to its inferred size arguments.
transformFName :: SrcLoc -> QualName VName -> StructType -> MonoM Exp
transformFName loc fname ft = do
  t' <- transformType ft
  let mono_t = monoType ft
  if baseTag (qualLeaf fname) <= maxIntrinsicTag
    then pure $ var fname t'
    else do
      maybe_fname <- lookupLifted (qualLeaf fname) mono_t
      maybe_funbind <- lookupFun $ qualLeaf fname
      case (maybe_fname, maybe_funbind) of
        -- The function has already been monomorphised.
        (Just (fname', infer), _) ->
          applySizeArgs fname' (toRes Nonunique t') <$> infer t'
        -- Not a known polymorphic binding (e.g. a locally bound name, or
        -- one shadowed via 'withMono'); intrinsics were handled above.
        (Nothing, Nothing) -> pure $ var fname t'
        -- A polymorphic function.
        (Nothing, Just funbind) -> do
          (fname', infer, funbind') <- monomorphiseBinding funbind mono_t
          tell $ Seq.singleton (qualLeaf fname, funbind')
          addLifted (qualLeaf fname) mono_t (fname', infer)
          applySizeArgs fname' (toRes Nonunique t') <$> infer t'
  where
    var fname' t' = Var fname' (Info t') loc

    -- Apply one size argument; @i@ is the number of size arguments that
    -- remain after this one.
    applySizeArg t (i, f) size_arg =
      ( i - 1,
        mkApply
          f
          [(Nothing, size_arg)]
          (AppRes (foldFunType (replicate i i64) (RetType [] t)) [])
      )

    applySizeArgs fname' t size_args =
      snd $
        foldl'
          (applySizeArg t)
          ( length size_args - 1,
            Var
              (qualName fname')
              (Info (foldFunType (map (const i64) size_args) (RetType [] t)))
              loc
          )
          size_args

-- | Transform the sizes in a type, replacing non-trivial size
-- expressions by variables (see 'replaceExp').
transformType :: TypeBase Size u -> MonoM (TypeBase Size u)
transformType typ =
  case typ of
    Scalar scalar -> Scalar <$> transformScalarSizes scalar
    Array u shape scalar -> Array u <$> mapM onDim shape <*> transformScalarSizes scalar
  where
    transformScalarSizes :: ScalarTypeBase Size u -> MonoM (ScalarTypeBase Size u)
    transformScalarSizes (Record fs) =
      Record <$> traverse transformType fs
    transformScalarSizes (Sum cs) =
      Sum <$> (traverse . traverse) transformType cs
    transformScalarSizes (Arrow as argName d argT retT) =
      Arrow as argName d
        <$> transformType argT
        <*> transformRetTypeSizes argset retT
      where
        argset = case argName of
          Unnamed -> mempty
          Named vn -> S.singleton vn
    transformScalarSizes (TypeVar u qn args) =
      TypeVar u qn <$> mapM onArg args
      where
        onArg (TypeArgDim dim) = TypeArgDim <$> onDim dim
        onArg (TypeArgType ty) = TypeArgType <$> transformType ty
    transformScalarSizes ty@Prim {} = pure ty

    onDim e
      | e == anySize = pure e
      | otherwise = replaceExp =<< transformExp e

-- | Transform a return type; sizes that depend on @argset@ become
-- additional existential sizes of the return type.
transformRetTypeSizes :: S.Set VName -> RetTypeBase Size as -> MonoM (RetTypeBase Size as)
transformRetTypeSizes argset (RetType dims ty) = do
  ty' <- withArgs argset $ withMono dims $ transformType ty
  rl <- parametrizing argset
  let dims' = dims <> map snd rl
  pure $ RetType dims' ty'

-- | Replace every 'anySize' in a pattern with a fresh size variable,
-- returning the new names alongside the pattern.
sizesForPat :: (MonadFreshNames m) => Pat ParamType -> m ([VName], Pat ParamType)
sizesForPat pat = do
  (params', sizes) <- runStateT (traverse (bitraverse onDim pure) pat) []
  pure (sizes, params')
  where
    onDim d
      | d == anySize = do
          v <- lift $ newVName "size"
          modify (v :)
          pure $ sizeFromName (qualName v) mempty
      | otherwise = pure d

transformAppRes :: AppRes -> MonoM AppRes
transformAppRes (AppRes t ext) =
  AppRes <$> transformType t <*> pure ext

transformAppExp :: AppExp -> AppRes -> MonoM Exp
transformAppExp (Range e1 me incl loc) res = do
  e1' <- transformExp e1
  me' <- mapM transformExp me
  incl' <- mapM transformExp incl
  res' <- transformAppRes res
  pure $ AppExp (Range e1' me' incl' loc) (Info res')
transformAppExp (LetPat sizes pat e body loc) res = do
  e' <- transformExp e
  let dimArgs = S.fromList (map sizeName sizes)
  implicitDims <- withArgs dimArgs $ askIntros $ fvVars $ freeInPat pat
  let dimArgs' = dimArgs <> implicitDims
      letArgs = S.fromList $ patNames pat
      argset = dimArgs' `S.union` letArgs
  pat' <- withArgs dimArgs' $ transformPat pat
  params <- parametrizing dimArgs'
  let sizes' = sizes <> map (`SizeBinder` mempty) (map snd params <> S.toList implicitDims)
  body' <- withParams params $ scoping argset $ transformExp body
  res' <- transformAppRes res
  pure $ AppExp (LetPat sizes' pat' e' body' loc) (Info res')
transformAppExp LetFun {} _ =
  error "transformAppExp: LetFun is not supposed to occur"
transformAppExp (If e1 e2 e3 loc) res =
  AppExp <$> (If <$> transformExp e1 <*> transformExp e2 <*> transformExp e3 <*> pure loc) <*> (Info <$> transformAppRes res)
transformAppExp (Apply fe args _) res =
  mkApply
    <$> transformExp fe
    <*> mapM onArg (NE.toList args)
    <*> transformAppRes res
  where
    onArg (Info ext, e) = (ext,) <$> transformExp e
transformAppExp (Loop sparams pat loopinit form body loc) res = do
  e1' <- transformExp $ loopInitExp loopinit

  let dimArgs = S.fromList sparams
  pat' <- withArgs dimArgs $ transformPat pat
  params <- parametrizing dimArgs
  let sparams' = sparams <> map snd params
      mergeArgs = dimArgs `S.union` S.fromList (patNames pat)

  (form', formArgs) <- case form of
    For ident e2 -> (,S.singleton $ identName ident) . For ident <$> transformExp e2
    ForIn pat2 e2 -> do
      pat2' <- transformPat pat2
      (,S.fromList (patNames pat2)) . ForIn pat2' <$> transformExp e2
    While e2 ->
      fmap ((,mempty) . While) $
        withParams params $
          scoping mergeArgs $
            transformExp e2
  let argset = mergeArgs `S.union` formArgs

  body' <- withParams params $ scoping argset $ transformExp body
  -- Maybe monomorphisation introduced new arrays to the loop, and
  -- maybe they have AnySize sizes.  This is not allowed.  Invent some
  -- sizes for them.
  (pat_sizes, pat'') <- sizesForPat pat'
  res' <- transformAppRes res
  pure $ AppExp (Loop (sparams' ++ pat_sizes) pat'' (LoopInitExplicit e1') form' body' loc) (Info res')
transformAppExp (BinOp (fname, _) (Info t) (e1, d1) (e2, d2) loc) res = do
  (AppRes ret ext) <- transformAppRes res
  fname' <- transformFName loc fname (toStruct t)
  e1' <- transformExp e1
  e2' <- transformExp e2
  if orderZero (typeOf e1') && orderZero (typeOf e2')
    then pure $ applyOp ret ext fname' e1' e2'
    else do
      -- We have to flip the arguments to the function, because
      -- operator application is left-to-right, while function
      -- application is outside-in.  This matters when the arguments
      -- produce existential sizes.  There are later places in the
      -- compiler where we transform BinOp to Apply, but anything that
      -- involves existential sizes will necessarily go through here.
      (x_param_e, x_param) <- makeVarParam e1'
      (y_param_e, y_param) <- makeVarParam e2'
      -- XXX: the type annotations here are wrong, but hopefully it
      -- doesn't matter as there will be an outer AppExp to handle
      -- them.
      pure $
        AppExp
          ( LetPat
              []
              x_param
              e1'
              ( AppExp
                  (LetPat [] y_param e2' (applyOp ret ext fname' x_param_e y_param_e) loc)
                  (Info $ AppRes ret mempty)
              )
              mempty
          )
          (Info (AppRes ret mempty))
  where
    applyOp ret ext fname' x y =
      mkApply
        (mkApply fname' [(unInfo d1, x)] (AppRes ret mempty))
        [(unInfo d2, y)]
        (AppRes ret ext)

    makeVarParam arg = do
      let argtype = typeOf arg
      x <- newNameFromString "binop_p"
      pure
        ( Var (qualName x) (Info argtype) mempty,
          Id x (Info argtype) mempty
        )
transformAppExp LetWith {} _ =
  error "transformAppExp: LetWith is not supposed to occur"
transformAppExp (Index e0 idxs loc) res =
  AppExp
    <$> (Index <$> transformExp e0 <*> mapM transformDimIndex idxs <*> pure loc)
    <*> (Info <$> transformAppRes res)
transformAppExp (Match e cs loc) res = do
  implicitDims <- askIntros $ fvVars $ freeInType $ typeOf e
  e' <- transformExp e
  cs' <- mapM (transformCase implicitDims) cs
  res' <- transformAppRes res
  if S.null implicitDims
    then pure $ AppExp (Match e' cs' loc) (Info res')
    else do
      -- Sizes introduced by the scrutinee's type are bound by a let
      -- around the match, so that the cases can refer to them.
      tmpVar <- newNameFromString "matched_variable"
      pure $
        AppExp
          ( LetPat
              (map (`SizeBinder` mempty) $ S.toList implicitDims)
              (Id tmpVar (Info $ typeOf e') mempty)
              e'
              ( AppExp
                  (Match (Var (qualName tmpVar) (Info $ typeOf e') mempty) cs' loc)
                  -- Use the transformed result, like every other case.
                  (Info res')
              )
              mempty
          )
          (Info res')

-- Monomorphization of expressions.
-- Expressions are rewritten structurally: variable references are
-- resolved through 'transformFName', operator/projection/index
-- sections are desugared into lambdas, and embedded types are passed
-- through 'transformType'.
transformExp :: Exp -> MonoM Exp
transformExp lit@Literal {} = pure lit
transformExp lit@IntLit {} = pure lit
transformExp lit@FloatLit {} = pure lit
transformExp lit@StringLit {} = pure lit
transformExp (Parens inner loc) = do
  inner' <- transformExp inner
  pure $ Parens inner' loc
transformExp (QualParens qn inner loc) = do
  inner' <- transformExp inner
  pure $ QualParens qn inner' loc
transformExp (TupLit elems loc) =
  flip TupLit loc <$> traverse transformExp elems
transformExp (RecordLit fields loc) =
  flip RecordLit loc <$> traverse onField fields
  where
    onField (RecordFieldExplicit name fe floc) = do
      fe' <- transformExp fe
      pure $ RecordFieldExplicit name fe' floc
    -- An implicit field `{x}` is expanded to `{x = x}` and then
    -- handled like any explicit field.
    onField (RecordFieldImplicit (L vloc v) vt _) = do
      vt' <- traverse transformType vt
      onField $ RecordFieldExplicit (L vloc (baseName v)) (Var (qualName v) vt' loc) loc
transformExp (ArrayVal vals t loc) = pure $ ArrayVal vals t loc
transformExp (ArrayLit elems t loc) = do
  elems' <- traverse transformExp elems
  t' <- traverse transformType t
  pure $ ArrayLit elems' t' loc
transformExp (AppExp app (Info res)) = transformAppExp app res
transformExp (Var fname (Info t) loc) = transformFName loc fname $ toStruct t
transformExp (Hole t loc) = (`Hole` loc) <$> traverse transformType t
transformExp (Ascript inner tp loc) = do
  inner' <- transformExp inner
  pure $ Ascript inner' tp loc
transformExp (Coerce inner te t loc) = do
  inner' <- transformExp inner
  t' <- traverse transformType t
  pure $ Coerce inner' te t' loc
transformExp (Negate inner loc) = (`Negate` loc) <$> transformExp inner
transformExp (Not inner loc) = (`Not` loc) <$> transformExp inner
transformExp Lambda {} = error "transformExp: Lambda is not supposed to occur"
transformExp (OpSection qn t loc) = transformExp (Var qn t loc)
transformExp (OpSectionLeft fname (Info t) operand arg (Info rettype, Info retext) loc) = do
  let (Info (xp, xtype, xargext), Info (yp, ytype)) = arg
  operand' <- transformExp operand
  desugarBinOpSection
    fname
    (Just operand')
    Nothing
    t
    (xp, xtype, xargext)
    (yp, ytype, Nothing)
    (rettype, retext)
    loc
transformExp (OpSectionRight fname (Info t) operand arg (Info rettype) loc) = do
  let (Info (xp, xtype), Info (yp, ytype, yargext)) = arg
  operand' <- transformExp operand
  desugarBinOpSection
    fname
    Nothing
    (Just operand')
    t
    (xp, xtype, Nothing)
    (yp, ytype, yargext)
    (rettype, [])
    loc
transformExp (ProjectSection fields (Info t) loc) =
  transformType t >>= \t' -> desugarProjectSection fields t' loc
transformExp (IndexSection idxs (Info t) loc) = do
  idxs' <- traverse transformDimIndex idxs
  desugarIndexSection idxs' t loc
transformExp (Project n inner tp loc) = do
  tp' <- traverse transformType tp
  inner' <- transformExp inner
  pure $ Project n inner' tp' loc
transformExp (Update dst idxs src loc) = do
  dst' <- transformExp dst
  idxs' <- traverse transformDimIndex idxs
  src' <- transformExp src
  pure $ Update dst' idxs' src' loc
transformExp (RecordUpdate dst fs src t loc) = do
  dst' <- transformExp dst
  src' <- transformExp src
  t' <- traverse transformType t
  pure $ RecordUpdate dst' fs src' t' loc
transformExp (Assert cond body desc loc) = do
  cond' <- transformExp cond
  body' <- transformExp body
  pure $ Assert cond' body' desc loc
transformExp (Constr name payload t loc) = do
  payload' <- traverse transformExp payload
  t' <- traverse transformType t
  pure $ Constr name payload' t' loc
transformExp (Attr info inner loc) = do
  inner' <- transformExp inner
  pure $ Attr info inner' loc

-- Monomorphise one match case; the pattern's names and the implicit
-- size names of the scrutinee are in scope in the case body.
transformCase :: S.Set VName -> Case -> MonoM Case
transformCase implicitDims (CasePat p e loc) = do
  p' <- transformPat p
  let bound = S.fromList (patNames p) `S.union` implicitDims
  e' <- scoping bound $ transformExp e
  pure $ CasePat p' e' loc

-- Monomorphise every expression inside an index or slice.
transformDimIndex :: DimIndexBase Info VName -> MonoM (DimIndexBase Info VName)
transformDimIndex (DimFix e) = DimFix <$> transformExp e
transformDimIndex (DimSlice from to step) = do
  from' <- traverse transformExp from
  to' <- traverse transformExp to
  step' <- traverse transformExp step
  pure $ DimSlice from' to' step'

-- Transform an operator section into a lambda.
-- The section's missing operand(s) become lambda parameters; a supplied
-- operand is let-bound outside the lambda so it is evaluated once.
desugarBinOpSection ::
  QualName VName ->
  Maybe Exp ->
  Maybe Exp ->
  StructType ->
  (PName, ParamType, Maybe VName) ->
  (PName, ParamType, Maybe VName) ->
  (ResRetType, [VName]) ->
  SrcLoc ->
  MonoM Exp
desugarBinOpSection fname e_left e_right t (xp, xtype, xext) (yp, ytype, yext) (RetType dims rettype, retext) loc = do
  t' <- transformType t
  op <- transformFName loc fname $ toStruct t
  (v1, wrap_left, e1, p1) <- makeVarParam e_left =<< transformType xtype
  (v2, wrap_right, e2, p2) <- makeVarParam e_right =<< transformType ytype
  let apply_left =
        mkApply
          op
          [(xext, e1)]
          (AppRes (Scalar $ Arrow mempty yp (diet ytype) (toStruct ytype) (RetType [] $ toRes Nonunique t')) [])
      -- Sizes in the return type that refer to the operator's named
      -- parameters must now refer to the variables bound here.
      onDim (Var d typ _)
        | Named p <- xp, qualLeaf d == p = Var (qualName v1) typ loc
        | Named p <- yp, qualLeaf d == p = Var (qualName v2) typ loc
      onDim d = d
      rettype' = first onDim rettype
  body <-
    scoping (S.fromList [v1, v2]) $
      mkApply apply_left [(yext, e2)]
        <$> transformAppRes (AppRes (toStruct rettype') retext)
  rettype'' <- transformRetTypeSizes (S.fromList [v1, v2]) $ RetType dims rettype'
  pure . wrap_left . wrap_right $
    Lambda (p1 ++ p2) body Nothing (Info rettype'') loc
  where
    patAndVar argtype = do
      x <- newNameFromString "x"
      pure
        ( x,
          Id x (Info argtype) mempty,
          Var (qualName x) (Info (toStruct argtype)) mempty
        )

    -- A supplied operand is let-bound around the lambda (no parameter);
    -- a missing one becomes a lambda parameter.
    makeVarParam (Just e) argtype = do
      (v, pat, var_e) <- patAndVar argtype
      let wrap body =
            AppExp (LetPat [] (fmap toStruct pat) e body mempty) (Info $ AppRes (typeOf body) mempty)
      pure (v, wrap, var_e, [])
    makeVarParam Nothing argtype = do
      (v, pat, var_e) <- patAndVar argtype
      pure (v, id, var_e, [pat])

-- Turn a projection section `(.a.b)` into a lambda that projects the
-- given fields in order from its single argument.
desugarProjectSection :: [Name] -> StructType -> SrcLoc -> MonoM Exp
desugarProjectSection fields (Scalar (Arrow _ _ _ t1 (RetType dims t2))) loc = do
  p <- newVName "project_p"
  let body = foldl project (Var (qualName p) (Info t1) mempty) fields
  pure $ Lambda [Id p (Info $ toParam Observe t1) mempty] body Nothing (Info (RetType dims t2)) loc
  where
    project e field =
      case typeOf e of
        Scalar (Record fs)
          | Just t <- M.lookup field fs ->
              Project field e (Info t) mempty
        t ->
          error $
            "desugarProjectSection: type "
              ++ prettyString t
              ++ " does not have field "
              ++ prettyString field
desugarProjectSection _ t _ = error $ "desugarProjectSection: not a function type: " ++ prettyString t

-- Turn an index section `(.[i])` into a lambda that indexes its single
-- argument with the given (already transformed) indexes.
desugarIndexSection :: [DimIndex] -> StructType -> SrcLoc -> MonoM Exp
desugarIndexSection idxs (Scalar (Arrow _ _ _ t1 (RetType dims t2))) loc = do
  p <- newVName "index_i"
  t1' <- transformType t1
  t2' <- transformType t2
  let body = AppExp (Index (Var (qualName p) (Info t1') loc) idxs loc) (Info (AppRes (toStruct t2') []))
  pure $ Lambda [Id p (Info $ toParam Observe t1') mempty] body Nothing (Info (RetType dims t2')) loc
desugarIndexSection _ t _ = error $ "desugarIndexSection: not a function type: " ++ prettyString t

-- Transform every type annotation inside a pattern.
transformPat :: Pat (TypeBase Size u) -> MonoM (Pat (TypeBase Size u))
transformPat = traverse transformType

-- Instantiation of size parameters: size name to the expression it is
-- bound to.
type DimInst = M.Map VName Size

-- Match the sizes of two types against each other and record, for each
-- size variable in the first type, the corresponding expression in the
-- second.  The replacement lists let sizes that were named by the
-- monomorphiser be looked through to the expression they stand for.
dimMapping ::
  (Monoid a) =>
  TypeBase Size a ->
  TypeBase Size a ->
  ExpReplacements ->
  ExpReplacements ->
  DimInst
dimMapping t1 t2 r1 r2 = execState (matchDims onDims t1 t2) mempty
  where
    revMap = map (\(k, v) -> (v, k))
    named1 = revMap r1
    named2 = revMap r2

    onDims bound e1 e2 = do
      onExps bound e1 e2
      pure e1

    onExps bound (Var v _ _) e = do
      -- Only record mappings that do not mention locally bound names.
      unless (any (`elem` bound) $ freeVarsInExp e) $
        modify (M.insert (qualLeaf v) e)
      case lookup (qualLeaf v) named1 of
        Just rexp -> onExps bound (unReplaced rexp) e
        Nothing -> pure ()
    onExps bound e (Var v _ _)
      | Just rexp <- lookup (qualLeaf v) named2 = onExps bound e (unReplaced rexp)
    onExps bound e1 e2
      | Just es <- similarExps e1 e2 =
          mapM_ (uncurry $ onExps bound) es
    onExps _ _ _ = pure mempty

    freeVarsInExp = fvVars . freeInExp

-- Compute the explicit size arguments for a monomorphised function by
-- matching its binding type against the type at the use site.  Size
-- parameters that cannot be determined default to 0.
inferSizeArgs :: [TypeParam] -> StructType -> ExpReplacements -> StructType -> MonoM [Exp]
inferSizeArgs tparams bind_t bind_r t = do
  r <- gets (<>) <*> asks envParametrized
  let dinst = dimMapping bind_t t bind_r r
  mapM (tparamArg dinst) tparams
  where
    tparamArg dinst tp =
      case M.lookup (typeParamName tp) dinst of
        Just e -> replaceExp e
        Nothing -> pure $ sizeFromInteger 0 mempty

-- Monomorphising higher-order functions can result in function types
-- where the same named parameter occurs in multiple spots.  When
-- monomorphising we don't really need those parameter names anymore,
-- and the defunctionaliser can be confused if there are duplicates
-- (it doesn't handle shadowing), so let's just remove all parameter
-- names here.  This is safe because a MonoType does not contain sizes
-- anyway.
noNamedParams :: MonoType -> MonoType
noNamedParams = f
  where
    f :: TypeBase MonoSize u -> TypeBase MonoSize u
    f (Array u shape t) = Array u shape (f' t)
    f (Scalar t) = Scalar $ f' t

    f' :: ScalarTypeBase MonoSize u -> ScalarTypeBase MonoSize u
    f' (Record fs) = Record $ fmap f fs
    f' (Sum cs) = Sum $ fmap (map f) cs
    f' (Arrow u _ d1 t1 (RetType dims t2)) =
      Arrow u Unnamed d1 (f t1) (RetType dims (f t2))
    f' t = t

-- | arrowArg takes a return type and returns it
-- with the existentials bound moved at the right of arrows.
-- It also gives the new set of parameters to consider.
arrowArg ::
  S.Set VName -> -- scope
  S.Set VName -> -- set of argument
  [VName] -> -- size parameters
  RetTypeBase Size as ->
  (RetTypeBase Size as, S.Set VName)
arrowArg scope argset args_params rety =
  let (rety', (funArgs, _)) = runWriter (arrowArgRetType (scope, mempty) argset rety)
      new_params = funArgs `S.union` S.fromList args_params
   in (arrowCleanRetType new_params rety', new_params)
  where
    -- \| takes a type (or return type) and returns it
    -- with the existentials bound moved at the right of arrows.
    -- It also gives (through writer monad) size variables used in arrow arguments
    -- and variables that are constructively used.
    -- The returned type should be cleaned, as too many existentials are introduced.
    arrowArgRetType ::
      (S.Set VName, [VName]) ->
      S.Set VName ->
      RetTypeBase Size as' ->
      Writer (S.Set VName, S.Set VName) (RetTypeBase Size as')
    arrowArgRetType (scope', dimsToPush) argset' (RetType dims ty) = pass $ do
      let dims' = dims <> dimsToPush
      (ty', (_, canExt)) <- listen $ arrowArgType (argset' `S.union` scope', dims') ty
      pure (RetType (filter (`S.member` canExt) dims') ty', first (`S.difference` canExt))

    arrowArgScalar env (Record fs) =
      Record <$> traverse (arrowArgType env) fs
    arrowArgScalar env (Sum cs) =
      Sum <$> (traverse . traverse) (arrowArgType env) cs
    arrowArgScalar (scope', dimsToPush) (Arrow as argName d argT retT) =
      pass $ do
        let intros = S.filter notIntrisic argset' `S.difference` scope'
        retT' <- arrowArgRetType (scope', filter (`S.notMember` intros) dimsToPush) fullArgset retT
        pure (Arrow as argName d argT retT', bimap (intros `S.union`) (const mempty))
      where
        notIntrisic vn = baseTag vn > maxIntrinsicTag
        argset' = fvVars $ freeInType argT
        fullArgset =
          argset'
            <> case argName of
              Unnamed -> mempty
              Named vn -> S.singleton vn
    arrowArgScalar env (TypeVar u qn args) =
      TypeVar u qn <$> mapM arrowArgArg args
      where
        arrowArgArg (TypeArgDim dim) = TypeArgDim <$> arrowArgSize dim
        arrowArgArg (TypeArgType ty) = TypeArgType <$> arrowArgType env ty
    arrowArgScalar _ ty = pure ty

    arrowArgType ::
      (S.Set VName, [VName]) ->
      TypeBase Size as' ->
      Writer (S.Set VName, S.Set VName) (TypeBase Size as')
    arrowArgType env (Array u shape scalar) =
      Array u <$> traverse arrowArgSize shape <*> arrowArgScalar env scalar
    arrowArgType env (Scalar ty) =
      Scalar <$> arrowArgScalar env ty

    arrowArgSize s@(Var qn _ _) = writer (s, (mempty, S.singleton $ qualLeaf qn))
    arrowArgSize s = pure s

    -- \| arrowClean cleans the mess in the type
    arrowCleanRetType :: S.Set VName -> RetTypeBase Size as -> RetTypeBase Size as
    arrowCleanRetType paramed (RetType dims ty) =
      RetType (nubOrd $ filter (`S.notMember` paramed) dims) (arrowCleanType (paramed `S.union` S.fromList dims) ty)

    arrowCleanScalar :: S.Set VName -> ScalarTypeBase Size as -> ScalarTypeBase Size as
    arrowCleanScalar paramed (Record fs) =
      Record $ M.map (arrowCleanType paramed) fs
    arrowCleanScalar paramed (Sum cs) =
      Sum $ (M.map . map) (arrowCleanType paramed) cs
    arrowCleanScalar paramed (Arrow as argName d argT retT) =
      Arrow as argName d argT (arrowCleanRetType paramed retT)
    arrowCleanScalar paramed (TypeVar u qn args) =
      TypeVar u qn $ map arrowCleanArg args
      where
        arrowCleanArg (TypeArgDim dim) = TypeArgDim dim
        arrowCleanArg (TypeArgType ty) = TypeArgType $ arrowCleanType paramed ty
    arrowCleanScalar _ ty = ty

    arrowCleanType :: S.Set VName -> TypeBase Size as -> TypeBase Size as
    arrowCleanType paramed (Array u shape scalar) =
      Array u shape $ arrowCleanScalar paramed scalar
    arrowCleanType paramed (Scalar ty) =
      Scalar $ arrowCleanScalar paramed ty

-- Monomorphise a polymorphic function at the types given in the instance
-- list. Monomorphises the body of the function as well. Returns the fresh name
-- of the generated monomorphic function and its 'ValBind' representation.
monomorphiseBinding ::
  PolyBinding ->
  MonoType ->
  MonoM (VName, InferSizeArgs, ValBind)
monomorphiseBinding (PolyBinding (entry, name, tparams, params, rettype, body, attrs, loc)) inst_t =
  isolateNormalisation $ do
    let bind_t = funType params rettype
    (substs, t_shape_params) <-
      typeSubstsM loc bind_t $ noNamedParams inst_t
    let shape_names = S.fromList $ map typeParamName $ shape_params ++ t_shape_params
        substs' = M.map (Subst []) substs
        substStructType =
          substTypesAny (fmap (fmap (second (const mempty))) . (`M.lookup` substs'))
        params' = map (substPat substStructType) params
    params'' <- withArgs shape_names $ mapM transformPat params'
    exp_naming <- paramGetClean

    let args = S.fromList $ foldMap patNames params
        arg_params = map snd exp_naming

    rettype' <-
      withParams exp_naming $
        withArgs (args <> shape_names) $
          hardTransformRetType (applySubst (`M.lookup` substs') rettype)
    extNaming <- paramGetClean
    scope <- S.union shape_names <$> askScope'
    let (rettype'', new_params) = arrowArg scope args arg_params rettype'
        bind_t' = substTypesAny (`M.lookup` substs') bind_t
        mkExplicit =
          flip
            S.member
            (mustBeExplicitInBinding bind_t'' <> mustBeExplicitInBinding bind_t')
        (shape_params_explicit, shape_params_implicit) =
          partition (mkExplicit . typeParamName) $
            shape_params ++ t_shape_params ++ map (`TypeParamDim` mempty) (S.toList new_params)
        exp_naming' = filter ((`S.member` new_params) . snd) (extNaming <> exp_naming)

        bind_t'' = funType params'' rettype''
        bind_r = exp_naming <> extNaming
    body' <- updateExpTypes (`M.lookup` substs') body
    body'' <- withParams exp_naming' $ withArgs (shape_names <> args) $ transformExp body'
    scope' <- S.union (shape_names <> args) <$> askScope'
    body''' <-
      expReplace exp_naming' <$> (calculateDims body'' . canCalculate scope' =<< get)

    -- Reuse the original name only for a non-polymorphic, non-entry
    -- function that has not been lifted before.
    seen_before <- elem name . map (fst . fst) <$> getLifts
    name' <-
      if null tparams && isNothing entry && not seen_before
        then pure name
        else newName name

    pure
      ( name',
        -- If the function is an entry point, then it cannot possibly
        -- need any explicit size arguments (checked by type checker).
        if isJust entry
          then const $ pure []
          else inferSizeArgs shape_params_explicit bind_t'' bind_r,
        if isJust entry
          then toValBinding name' (shape_params_explicit ++ shape_params_implicit) params'' rettype'' (entryAssert exp_naming body''')
          else toValBinding name' shape_params_implicit (map shapeParam shape_params_explicit ++ params'') rettype'' body'''
      )
  where
    askScope' = S.filter (`notElem` retDims rettype) <$> askScope

    shape_params = filter (not . isTypeParam) tparams

    updateExpTypes substs = astMap (mapper substs)

    paramGetClean = do
      ret <- get
      put mempty
      pure ret

    hardTransformRetType (RetType dims ty) = do
      ty' <- transformType ty
      unbounded <- askIntros $ fvVars $ freeInType ty'
      let dims' = S.toList unbounded
      pure $ RetType (dims' <> dims) ty'

    mapper substs =
      ASTMapper
        { mapOnExp = updateExpTypes substs,
          mapOnName = pure,
          mapOnStructType = pure . applySubst substs,
          mapOnParamType = pure . applySubst substs,
          mapOnResRetType = pure . applySubst substs
        }

    shapeParam tp = Id (typeParamName tp) (Info i64) $ srclocOf tp

    toValBinding name' tparams' params'' rettype' body'' =
      ValBind
        { valBindEntryPoint = Info <$> entry,
          valBindName = name',
          valBindRetType = Info rettype',
          valBindRetDecl = Nothing,
          valBindTypeParams = tparams',
          valBindParams = params'',
          valBindBody = body'',
          valBindDoc = Nothing,
          valBindAttrs = attrs,
          valBindLocation = loc
        }

-- Match a polymorphic binding type against a concrete instance type.
-- Returns the substitution for the type variables, plus fresh size
-- parameters invented for known-but-unnamed sizes in the instance.
-- Types of different shape are an internal error.
typeSubstsM ::
  (MonadFreshNames m) =>
  SrcLoc ->
  StructType ->
  MonoType ->
  m (M.Map VName StructRetType, [TypeParam])
typeSubstsM loc orig_t1 orig_t2 =
  runWriterT $ fst <$> execStateT (sub orig_t1 orig_t2) (mempty, mempty)
  where
    subRet (Scalar (TypeVar _ v _)) rt =
      unless (baseTag (qualLeaf v) <= maxIntrinsicTag) $
        addSubst v rt
    subRet t1 (RetType _ t2) =
      sub t1 t2

    sub t1@(Array _ (Shape (d1 : _)) _) t2@(Array _ (Shape (d2 : _)) _) = do
      case d2 of
        MonoAnon i -> do
          (ts, sizes) <- get
          put (ts, M.insert i d1 sizes)
        _ -> pure ()
      sub (stripArray 1 t1) (stripArray 1 t2)
    sub (Scalar (TypeVar _ v _)) t =
      unless (baseTag (qualLeaf v) <= maxIntrinsicTag) $
        addSubst v $
          RetType [] t
    sub (Scalar (Record fields1)) (Scalar (Record fields2)) =
      zipWithM_ sub (map snd $ sortFields fields1) (map snd $ sortFields fields2)
    sub (Scalar Prim {}) (Scalar Prim {}) = pure ()
    sub (Scalar (Arrow _ _ _ t1a (RetType _ t1b))) (Scalar (Arrow _ _ _ t2a t2b)) = do
      sub t1a t2a
      subRet (toStruct t1b) (second (const NoUniqueness) t2b)
    sub (Scalar (Sum cs1)) (Scalar (Sum cs2)) =
      zipWithM_ typeSubstClause (sortConstrs cs1) (sortConstrs cs2)
      where
        typeSubstClause (_, ts1) (_, ts2) = zipWithM sub ts1 ts2
    -- A sum type paired with a non-sum type is a mismatch like any
    -- other.  (Previously two clauses here re-invoked 'sub' with the
    -- same arguments, which looped forever instead of reporting.)
    sub t1 t2 = error $ unlines ["typeSubstsM: mismatched types:", prettyString t1, prettyString t2]

    addSubst (QualName _ v) (RetType ext t) = do
      (ts, sizes) <- get
      unless (v `M.member` ts) $ do
        t' <- bitraverse onDim pure t
        put (M.insert v (RetType ext t') ts, sizes)

    onDim (MonoKnown i) = do
      (ts, sizes) <- get
      case M.lookup i sizes of
        Nothing -> do
          d <- lift $ lift $ newVName "d"
          tell [TypeParamDim d loc]
          put (ts, M.insert i (sizeFromName (qualName d) mempty) sizes)
          pure $ sizeFromName (qualName d) mempty
        Just d -> pure d
    onDim (MonoAnon i) = do
      (_, sizes) <- get
      case M.lookup i sizes of
        Nothing -> pure anySize
        Just d -> pure d

-- Perform a given substitution on the types in a pattern.
-- Every type annotation in the pattern, including those nested inside
-- tuple, record, parenthesised, attributed and constructor patterns,
-- is passed through the substitution.  Ascriptions are dropped.
substPat :: (t -> t) -> Pat t -> Pat t
substPat f pat = case pat of
  TuplePat pats loc -> TuplePat (map (substPat f) pats) loc
  RecordPat fs loc -> RecordPat (map substField fs) loc
    where
      substField (n, p) = (n, substPat f p)
  PatParens p loc -> PatParens (substPat f p) loc
  PatAttr attr p loc -> PatAttr attr (substPat f p) loc
  Id vn (Info tp) loc -> Id vn (Info $ f tp) loc
  Wildcard (Info tp) loc -> Wildcard (Info $ f tp) loc
  PatAscription p _ _ -> substPat f p
  PatLit e (Info tp) loc -> PatLit e (Info $ f tp) loc
  -- Recurse into the payload patterns too; otherwise their types would
  -- keep the unsubstituted type variables.
  PatConstr n (Info tp) ps loc -> PatConstr n (Info $ f tp) (map (substPat f) ps) loc

-- Repackage a value binding in the form used by 'monomorphiseBinding'.
toPolyBinding :: ValBind -> PolyBinding
toPolyBinding (ValBind entry name _ (Info rettype) tparams params body _ attrs loc) =
  PolyBinding (unInfo <$> entry, name, tparams, params, rettype, body, attrs, loc)

-- Register a top-level binding as polymorphic; entry points are
-- monomorphised immediately at their declared type.  Parameterless
-- bindings also contribute their return-type sizes to the global scope.
transformValBind :: ValBind -> MonoM Env
transformValBind valbind = do
  let valbind' = toPolyBinding valbind

  when (isJust $ valBindEntryPoint valbind) $ do
    let t =
          funType (valBindParams valbind) $
            unInfo $
              valBindRetType valbind
    (name, infer, valbind'') <- monomorphiseBinding valbind' $ monoType t
    tell $ Seq.singleton (name, valbind'')
    addLifted (valBindName valbind) (monoType t) (name, infer)

  pure
    mempty
      { envPolyBindings = M.singleton (valBindName valbind) valbind',
        envGlobalScope =
          if null (valBindParams valbind)
            then S.fromList $ retDims $ unInfo $ valBindRetType valbind
            else mempty
      }

-- Process bindings in order, each one in the environment extended by
-- all earlier ones.
transformValBinds :: [ValBind] -> MonoM ()
transformValBinds [] = pure ()
transformValBinds (valbind : ds) = do
  env <- transformValBind valbind
  localEnv env $ transformValBinds ds

-- | Monomorphise a list of top-level value bindings.
transformProg :: (MonadFreshNames m) => [ValBind] -> m [ValBind]
transformProg decs =
  fmap (toList . fmap snd . snd) $
    modifyNameSource $ \namesrc ->
      runMonoM namesrc $ transformValBinds decs
futhark-0.25.27/src/Futhark/Internalise/ReplaceRecords.hs000066400000000000000000000164541475065116200232550ustar0000000000000000
-- | Converts identifiers of record type into record patterns (and
-- similarly for tuples).  This is to ensure that the closures
-- produced in lambda lifting and defunctionalisation do not carry
-- around huge records of which only a tiny part is needed.
module Futhark.Internalise.ReplaceRecords (transformProg) where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor
import Data.Bitraversable
import Data.Map.Strict qualified as M
import Futhark.MonadFreshNames
import Language.Futhark
import Language.Futhark.Traversals

-- Mapping from record names to the variable names that contain the
-- fields, as well as an expression for the entire record.  This is
-- used because the monomorphiser also expands all record patterns.
type RecordReplacements = M.Map VName RecordReplacement

type RecordReplacement = (M.Map Name (VName, StructType), Exp)

-- Read-only environment: the record replacements currently in scope.
newtype Env = Env
  { envRecordReplacements :: RecordReplacements
  }

-- Mutable state: the fresh-name source and two memo tables for type
-- transformations.
data S = S
  { stateNameSource :: VNameSource,
    stateStructTypeMemo :: M.Map StructType StructType,
    stateParamTypeMemo :: M.Map ParamType ParamType
  }

-- The record-replacement monad.
newtype RecordM a = RecordM (ReaderT Env (State S) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader Env,
      MonadState S
    )

instance MonadFreshNames RecordM where
  getNameSource = RecordM (gets stateNameSource)
  putNameSource src =
    RecordM . modify $ \s -> s {stateNameSource = src}

-- Run a computation with no replacements in scope and empty memo
-- tables; yields the result together with the final name source.
runRecordM :: VNameSource -> RecordM a -> (a, VNameSource)
runRecordM src (RecordM m) =
  let initial = S src mempty mempty
      (result, final) = runState (runReaderT m (Env mempty)) initial
   in (result, stateNameSource final)

-- Bring extra replacements into scope; new entries take precedence
-- over existing ones for the same name.
withRecordReplacements :: RecordReplacements -> RecordM a -> RecordM a
withRecordReplacements rr = local extend
  where
    extend env = env {envRecordReplacements = rr <> envRecordReplacements env}

-- Find the replacement (if any) for a record-typed name.
lookupRecordReplacement :: VName -> RecordM (Maybe RecordReplacement)
lookupRecordReplacement v = M.lookup v <$> asks envRecordReplacements

-- A wildcard pattern; at record type it is spelled out as a record
-- pattern with one wildcard per field.
wildcard :: TypeBase Size u -> SrcLoc -> Pat (TypeBase Size u)
wildcard t loc =
  case t of
    Scalar (Record fs) ->
      RecordPat [(L noLoc k, Wildcard (Info ft) loc) | (k, ft) <- M.toList fs] loc
    _ -> Wildcard (Info t) loc

-- Consult one of the memo tables in 'S'; on a miss, run the
-- computation and store its result in that table.
memoLookupOr ::
  (Ord k) =>
  (S -> M.Map k k) ->
  (M.Map k k -> S -> S) ->
  k ->
  RecordM k ->
  RecordM k
memoLookupOr getTable setTable key compute = do
  cached <- gets (M.lookup key . getTable)
  case cached of
    Just hit -> pure hit
    Nothing -> do
      fresh <- compute
      modify $ \s -> setTable (M.insert key fresh (getTable s)) s
      pure fresh

memoParamType :: ParamType -> RecordM ParamType -> RecordM ParamType
memoParamType =
  memoLookupOr stateParamTypeMemo $ \tbl s -> s {stateParamTypeMemo = tbl}

memoStructType :: StructType -> RecordM StructType -> RecordM StructType
memoStructType =
  memoLookupOr stateStructTypeMemo $ \tbl s -> s {stateStructTypeMemo = tbl}

-- No need to keep memoisation cache between top level functions.
-- Reset both type memo tables ('onValBind' calls this after each
-- top-level binding).
memoClear :: RecordM ()
memoClear = modify $ \s ->
  s
    { stateStructTypeMemo = mempty,
      stateParamTypeMemo = mempty
    }

-- Transform a pattern, using the given function on the types of
-- identifiers and wildcards.  An identifier of record type is replaced
-- by a record pattern binding a fresh variable per field; the returned
-- replacements map the original name to those variables and to a record
-- literal that rebuilds the whole value.  Ascriptions are dropped.
transformPat ::
  (TypeBase Size u -> RecordM (TypeBase Size u)) ->
  Pat (TypeBase Size u) ->
  RecordM (Pat (TypeBase Size u), RecordReplacements)
transformPat _ (Id v (Info (Scalar (Record fs))) loc) = do
  let fs' = M.toList fs
  (fs_ks, fs_ts) <- fmap unzip $
    forM fs' $ \(f, ft) ->
      (,) <$> newVName (nameToString f) <*> pure ft
  pure
    ( RecordPat
        (zip (map (L noLoc . fst) fs') (zipWith3 Id fs_ks (map Info fs_ts) $ repeat loc))
        loc,
      M.singleton
        v
        ( M.fromList $ zip (map fst fs') $ zip fs_ks $ map toStruct fs_ts,
          RecordLit (zipWith3 toField (map fst fs') fs_ks fs_ts) loc
        )
    )
  where
    toField f f_v f_t =
      let f_v' = Var (qualName f_v) (Info $ toStruct f_t) loc
       in RecordFieldExplicit (L noLoc f) f_v' loc
transformPat onType (Id v t loc) = do
  t' <- traverse onType t
  pure (Id v t' loc, mempty)
transformPat onType (TuplePat pats loc) = do
  (pats', rrs) <- mapAndUnzipM (transformPat onType) pats
  pure (TuplePat pats' loc, mconcat rrs)
transformPat onType (RecordPat fields loc) = do
  let (field_names, field_pats) = unzip fields
  (field_pats', rrs) <- mapAndUnzipM (transformPat onType) field_pats
  pure (RecordPat (zip field_names field_pats') loc, mconcat rrs)
transformPat onType (PatParens pat loc) = do
  (pat', rr) <- transformPat onType pat
  pure (PatParens pat' loc, rr)
transformPat onType (PatAttr attr pat loc) = do
  (pat', rr) <- transformPat onType pat
  pure (PatAttr attr pat' loc, rr)
transformPat onType (Wildcard (Info t) loc) = do
  t' <- onType t
  pure (wildcard t' loc, mempty)
transformPat onType (PatAscription pat _ _) =
  transformPat onType pat
transformPat _ (PatLit e t loc) = pure (PatLit e t loc, mempty)
transformPat onType (PatConstr name t all_ps loc) = do
  (all_ps', rrs) <- mapAndUnzipM (transformPat onType) all_ps
  pure (PatConstr name t all_ps' loc, mconcat rrs)

-- Apply 'transformExp' to the size expressions inside a type (memoised).
transformParamType :: ParamType -> RecordM ParamType
transformParamType t =
  memoParamType t $ bitraverse transformExp pure t

-- As 'transformParamType', for struct types.
transformStructType :: StructType -> RecordM StructType
transformStructType t =
  memoStructType t $ bitraverse transformExp pure t

-- Rewrite an expression under the current replacements.  A projection
-- from a replaced record becomes the field's variable, and any other
-- reference to a replaced record becomes the rebuilding record literal.
-- Binding forms extend the replacements for their bodies.
transformExp :: Exp -> RecordM Exp
transformExp (Project n e t loc) = do
  maybe_fs <- case e of
    Var qn _ _ -> lookupRecordReplacement (qualLeaf qn)
    _ -> pure Nothing
  case maybe_fs of
    Just (m, _)
      | Just (v, _) <- M.lookup n m ->
          pure $ Var (qualName v) t loc
    _ -> do
      e' <- transformExp e
      pure $ Project n e' t loc
transformExp (Var fname t loc) = do
  maybe_fs <- lookupRecordReplacement $ qualLeaf fname
  case maybe_fs of
    Just (_, e) -> pure e
    Nothing ->
      Var fname <$> traverse transformStructType t <*> pure loc
transformExp (AppExp (LetPat sizes pat e body loc) res) = do
  e' <- transformExp e
  (pat', rr) <- transformPat transformStructType pat
  body' <- withRecordReplacements rr $ transformExp body
  pure $ AppExp (LetPat sizes pat' e' body' loc) res
transformExp (AppExp (LetFun fname (tparams, params, retdecl, ret, funbody) body loc) res) = do
  (params', rrs) <- mapAndUnzipM (transformPat transformParamType) params
  funbody' <- withRecordReplacements (mconcat rrs) $ transformExp funbody
  body' <- transformExp body
  pure $ AppExp (LetFun fname (tparams, params', retdecl, ret, funbody') body' loc) res
transformExp (Lambda params body retdecl ret loc) = do
  (params', rrs) <- mapAndUnzipM (transformPat transformParamType) params
  body' <- withRecordReplacements (mconcat rrs) $ transformExp body
  pure $ Lambda params' body' retdecl ret loc
transformExp e = astMap m e
  where
    m = identityMapper {mapOnExp = transformExp}

-- Transform one top-level binding: expand record parameters, rewrite
-- the body and the size expressions of the return type, then reset
-- the memo tables.
onValBind :: ValBind -> RecordM ValBind
onValBind vb = do
  (params', rrs) <- mapAndUnzipM (transformPat transformParamType) $ valBindParams vb
  e' <- withRecordReplacements (mconcat rrs) $ transformExp $ valBindBody vb
  ret <- traverse (bitraverse transformExp pure) $ valBindRetType vb
  memoClear
  pure $
    vb
      { valBindBody = e',
        valBindParams = params',
        valBindRetType = ret
      }

-- | Replace record-typed identifiers with per-field variables in a list
-- of top-level value bindings.  A module-free input program is expected,
-- so only value declarations are accepted.
transformProg :: (MonadFreshNames m) => [ValBind] -> m [ValBind]
transformProg vbs =
  modifyNameSource $ \namesrc ->
    runRecordM namesrc $ mapM onValBind vbs
futhark-0.25.27/src/Futhark/Internalise/TypesValues.hs000066400000000000000000000333261475065116200226410ustar0000000000000000
module Futhark.Internalise.TypesValues
  ( -- * Internalising types
    internaliseReturnType,
    internaliseCoerceType,
    internaliseLambdaReturnType,
    internaliseEntryReturnType,
    internaliseType,
    internaliseParamTypes,
    internaliseLoopParamType,
    internalisePrimType,
    internalisedTypeSize,
    internaliseSumTypeRep,
    internaliseSumType,
    Tree,

    -- * Internalising values
    internalisePrimValue,

    -- * For internal testing
    inferAliases,
    internaliseConstructors,
  )
where

import Control.Monad
import Control.Monad.Free (Free (..))
import Control.Monad.State
import Data.Bifunctor
import Data.Bitraversable (bitraverse)
import Data.Foldable (toList)
import Data.List (delete, find, foldl')
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.IR.SOACS hiding (Free)
import Futhark.IR.SOACS qualified as I
import Futhark.Internalise.Monad
import Futhark.Util (chunkLike)
import Language.Futhark qualified as E

-- Map source-language uniqueness onto core-language uniqueness.
internaliseUniqueness :: E.Uniqueness -> I.Uniqueness
internaliseUniqueness E.Nonunique = I.Nonunique
internaliseUniqueness E.Unique = I.Unique

-- Counter used to number fresh existential sizes.
newtype TypeState = TypeState {typeCounter :: Int}

newtype InternaliseTypeM a
  = InternaliseTypeM (State TypeState a)
  deriving (Functor, Applicative, Monad, MonadState TypeState)

runInternaliseTypeM :: InternaliseTypeM a -> a
runInternaliseTypeM = runInternaliseTypeM' mempty

-- The counter starts at the number of named existentials, presumably so
-- that fresh ones do not collide with the indexes 0..n-1 the callers
-- assign to them (see 'internaliseReturnType').
runInternaliseTypeM' :: [VName] -> InternaliseTypeM a -> a
runInternaliseTypeM' exts (InternaliseTypeM m) = evalState m $ TypeState (length exts)

-- Internalise parameter types.  Every size must be static, and
-- placeholder accumulator certificates are replaced by fresh names.
internaliseParamTypes ::
  [E.ParamType] ->
  InternaliseM [[Tree (I.TypeBase Shape Uniqueness)]]
internaliseParamTypes ts =
  mapM (mapM (mapM mkAccCerts)) . runInternaliseTypeM $
    mapM (fmap (map (fmap onType)) . internaliseTypeM mempty . E.paramToRes) ts
  where
    onType = fromMaybe bad . hasStaticShape
    bad = error $ "internaliseParamTypes: " ++ prettyString ts

-- We need to fix up the arrays for any Acc return values or loop
-- parameters.  We look at the concrete types for this, since the Acc
-- parameter name in the second list will just be something we made up.
-- NOTE(review): 'zipWith' truncates to the shorter list, so callers are
-- presumably expected to pass lists of equal length.
fixupKnownTypes ::
  [TypeBase shape1 u1] ->
  [(TypeBase shape2 u2, b)] ->
  [(TypeBase shape2 u2, b)]
fixupKnownTypes = zipWith fixup
  where
    fixup (Acc acc ispace ts _) (Acc _ _ _ u2, b) = (Acc acc ispace ts u2, b)
    fixup _ t = t

-- Generate proper certificates for the placeholder accumulator
-- certificates produced by internaliseType (identified with tag 0).
-- Only needed when we cannot use 'fixupKnownTypes'.
mkAccCerts :: TypeBase shape u -> InternaliseM (TypeBase shape u)
mkAccCerts (Array pt shape u) = pure $ Array pt shape u
mkAccCerts (Acc c shape ts u) =
  Acc <$> c' <*> pure shape <*> pure ts <*> pure u
  where
    c'
      | baseTag c == 0 = newVName "acc_cert"
      | otherwise = pure c
mkAccCerts t = pure t

-- Internalise a loop parameter type, flattening the trees and taking
-- accumulator certificates from the known types where they line up.
internaliseLoopParamType ::
  E.ParamType ->
  [TypeBase shape u] ->
  InternaliseM [I.TypeBase Shape Uniqueness]
internaliseLoopParamType et ts =
  map fst . fixupKnownTypes ts . map (,()) . concatMap (concatMap toList)
    <$> internaliseParamTypes [et]

-- Tag every sublist with its offset in corresponding flattened list.
withOffsets :: (Foldable a) => [a b] -> [(a b, Int)]
withOffsets xs = zip xs (scanl (+) 0 $ map length xs)

-- Number the leaves of a tree consecutively, starting from the given offset.
numberFrom :: Int -> Tree a -> Tree (a, Int)
numberFrom o = flip evalState o . f
  where
    f (Pure x) = state $ \i -> (Pure (x, i), i + 1)
    f (Free xs) = Free <$> traverse f xs

-- Number the leaves of a list of trees by their position in the
-- flattened list of all leaves.
numberTrees :: [Tree a] -> [Tree (a, Int)]
numberTrees = map (uncurry $ flip numberFrom) . withOffsets

-- Is this an array type that is not unique?
nonuniqueArray :: TypeBase shape Uniqueness -> Bool
nonuniqueArray t@Array {} = not $ unique t
nonuniqueArray _ = False

-- Pair up two trees of identical shape, or fail.
matchTrees :: Tree a -> Tree b -> Maybe (Tree (a, b))
matchTrees (Pure a) (Pure b) = Just $ Pure (a, b)
matchTrees (Free as) (Free bs)
  | length as == length bs =
      Free <$> zipWithM matchTrees as bs
matchTrees _ _ = Nothing

-- All subtrees of the second tree (itself included) with the same shape
-- as the first, paired with it.  The search does not descend below a
-- subtree that already matched.
subtreesMatching :: Tree a -> Tree b -> [Tree (a, b)]
subtreesMatching as bs =
  case matchTrees as bs of
    Just m -> [m]
    Nothing -> case bs of
      Pure _ -> []
      Free bs' -> foldMap (subtreesMatching as) bs'

-- See Note [Alias Inference].
inferAliases ::
  [Tree (I.TypeBase Shape Uniqueness)] ->
  [Tree (I.TypeBase ExtShape Uniqueness)] ->
  [[(I.TypeBase ExtShape Uniqueness, RetAls)]]
inferAliases all_param_ts all_res_ts =
  map onRes all_res_ts
  where
    all_res_ts' = numberTrees all_res_ts
    all_param_ts' = numberTrees all_param_ts
    aliasable_param_ts = filter (all $ nonuniqueArray . fst) all_param_ts'
    aliasable_res_ts = filter (all $ nonuniqueArray . fst) all_res_ts'
    onRes (Pure res_t) =
      -- Necessarily a non-array.
      [(res_t, RetAls mempty mempty)]
    onRes (Free res_ts) =
      [ if nonuniqueArray res_t
          then (res_t, RetAls pals rals)
          else (res_t, mempty)
      | (res_t, pals, rals) <- zip3 (toList (Free res_ts)) palss ralss
      ]
      where
        reorder [] = replicate (length (Free res_ts)) []
        reorder xs = L.transpose xs
        infer ts =
          reorder . map (toList . fmap (snd . snd)) $
            foldMap (subtreesMatching (Free res_ts)) ts
        palss = infer aliasable_param_ts
        ralss = infer aliasable_res_ts

-- Internalise a function return type.  Existential sizes are numbered
-- by their position in the 'RetType', aliases are inferred against the
-- given parameter types, and accumulator types are fixed up from 'ts'.
internaliseReturnType ::
  [Tree (I.TypeBase Shape Uniqueness)] ->
  E.ResRetType ->
  [TypeBase shape u] ->
  [(I.TypeBase ExtShape Uniqueness, RetAls)]
internaliseReturnType paramts (E.RetType dims et) ts =
  fixupKnownTypes ts . concat . inferAliases paramts $
    runInternaliseTypeM' dims (internaliseTypeM exts et)
  where
    exts = M.fromList $ zip dims [0 ..]

-- | As 'internaliseReturnType', but returns components of a top-level
-- tuple type piecemeal.
internaliseEntryReturnType ::
  [Tree (I.TypeBase Shape Uniqueness)] ->
  E.ResRetType ->
  [[(I.TypeBase ExtShape Uniqueness, RetAls)]]
internaliseEntryReturnType paramts (E.RetType dims et) =
  map concat (chunkLike components aliased)
  where
    exts = M.fromList $ zip dims [0 ..]
    -- A non-empty top-level tuple is split into its components; any
    -- other type is treated as a single component.
    parts = case E.isTupleRecord et of
      Just ets | not (null ets) -> ets
      _ -> [et]
    components = runInternaliseTypeM' dims $ traverse (internaliseTypeM exts) parts
    aliased = inferAliases paramts (concat components)

-- Internalise the target type of a coercion (no parameters, so no
-- aliases), dropping the alias information.
internaliseCoerceType ::
  E.StructType ->
  [TypeBase shape u] ->
  [I.TypeBase ExtShape Uniqueness]
internaliseCoerceType et ts =
  fst <$> internaliseReturnType [] ret ts
  where
    ret = E.RetType [] $ E.toRes E.Nonunique et

-- Internalise a lambda's return type, with uniqueness erased.
internaliseLambdaReturnType ::
  E.ResType ->
  [TypeBase shape u] ->
  InternaliseM [I.TypeBase Shape NoUniqueness]
internaliseLambdaReturnType et ts = do
  decl_ts <- internaliseLoopParamType (E.resToParam et) ts
  pure $ map fromDecl decl_ts

-- Internalise a type, treating it as nonunique.
internaliseType ::
  E.TypeBase E.Size NoUniqueness ->
  [Tree (I.TypeBase I.ExtShape Uniqueness)]
internaliseType t =
  runInternaliseTypeM $ internaliseTypeM mempty $ E.toRes E.Nonunique t

-- Hand out the next unused existential index.
newId :: InternaliseTypeM Int
newId = state $ \s -> (typeCounter s, s {typeCounter = typeCounter s + 1})

-- Internalise one dimension.  An unknown size becomes a fresh
-- existential; a name bound in the map becomes that existential;
-- any other name or literal becomes a free size.
internaliseDim ::
  M.Map VName Int ->
  E.Size ->
  InternaliseTypeM ExtSize
internaliseDim exts d
  | d == E.anySize = Ext <$> newId
  | otherwise =
      case d of
        E.IntLit n _ _ -> pure $ I.Free $ intConst I.Int64 n
        E.Var (E.QualName _ name) _ _ ->
          pure $ maybe (I.Free $ I.Var name) I.Ext $ M.lookup name exts
        _ -> error $ "Unexpected size expression: " ++ prettyString d

-- | A tree is just an instantiation of the free monad with a list
-- monad.
--
-- The important thing is that we use it to represent the original
-- structure of arrays, as this matters for aliasing.  Each 'Free'
-- constructor corresponds to an array dimension.  Only non-arrays
-- have a 'Pure' at the top level.  See Note [Alias Inference].
type Tree = Free []

-- | Internalise a source type into core types, keeping a tree per
-- component that records how the original arrays were nested.  The
-- map gives the existential index of each size name bound by an
-- enclosing return type (see 'internaliseDim').
internaliseTypeM ::
  M.Map VName Int ->
  E.ResType ->
  InternaliseTypeM [Tree (I.TypeBase ExtShape Uniqueness)]
internaliseTypeM exts orig_t =
  case orig_t of
    E.Array u shape et -> do
      dims <- internaliseShape shape
      -- Internalise the element type, then add this array's
      -- dimensions to every leaf.
      ets <- internaliseTypeM exts $ E.toRes E.Nonunique $ E.Scalar et
      let f et' = I.arrayOf et' (Shape dims) $ internaliseUniqueness u
      pure [array $ map (fmap f) ets]
    E.Scalar (E.Prim bt) -> pure [Pure $ I.Prim $ internalisePrimType bt]
    E.Scalar (E.Record ets)
      -- We map empty records to units, because otherwise arrays of
      -- unit will lose their sizes.
      | null ets -> pure [Pure $ I.Prim I.Unit]
      | otherwise ->
          concat <$> mapM (internaliseTypeM exts . snd) (E.sortFields ets)
    E.Scalar (E.TypeVar u tn [E.TypeArgType arr_t])
      -- The intrinsic accumulator type.
      | baseTag (E.qualLeaf tn) <= E.maxIntrinsicTag,
        baseString (E.qualLeaf tn) == "acc" -> do
          ts <-
            foldMap (toList . fmap (fromDecl . onAccType))
              <$> internaliseTypeM exts (E.toRes Nonunique arr_t)
          let acc_param = VName "PLACEHOLDER" 0 -- See mkAccCerts.
              acc_shape = Shape [arraysSize 0 ts]
              u' = internaliseUniqueness u
              acc_t = Acc acc_param acc_shape (map rowType ts) u'
          pure [Pure acc_t]
    E.Scalar E.TypeVar {} ->
      error $ "internaliseTypeM: cannot handle type variable: " ++ prettyString orig_t
    E.Scalar E.Arrow {} ->
      error $ "internaliseTypeM: cannot handle function type: " ++ prettyString orig_t
    E.Scalar (E.Sum cs) -> do
      (ts, _) <-
        internaliseConstructors
          <$> traverse (fmap concat . mapM (internaliseTypeM exts)) cs
      -- The leading Int8 component presumably holds the constructor tag.
      pure $ Pure (I.Prim (I.IntType I.Int8)) : ts
  where
    internaliseShape = mapM (internaliseDim exts) . E.shapeDims
    -- If the element is represented by exactly one 'Free' tree, reuse
    -- that node rather than nesting it under another 'Free'; the outer
    -- dimensions have already been added to each leaf by @f@.
    array [Free ts] = Free ts
    array ts = Free ts
    -- Accumulator element types must have a static shape.
    onAccType = fromMaybe bad . hasStaticShape
    bad = error $ "internaliseTypeM Acc: " ++ prettyString orig_t

-- | Only exposed for testing purposes.
-- Lays out the payloads of sum-type constructors.  Returns the
-- combined list of payload trees, and for each constructor the
-- indices of the payload components it uses.
internaliseConstructors ::
  M.Map Name [Tree (I.TypeBase ExtShape Uniqueness)] ->
  ( [Tree (I.TypeBase ExtShape Uniqueness)],
    [(Name, [Int])]
  )
internaliseConstructors cs =
  L.mapAccumL onConstructor mempty $ E.sortConstrs cs
  where
    -- 'ts' holds the payload trees allocated so far by earlier
    -- constructors (in 'E.sortConstrs' order).
    onConstructor ts (c, c_ts) =
      let (_, js, new_ts) =
            foldl' f (withOffsets (map (fmap fromDecl) ts), mempty, mempty) c_ts
       in (ts ++ new_ts, (c, js))
      where
        -- Total number of leaves in a list of trees.
        size = sum . map length
        -- A component made only of primitive types reuses an existing
        -- slot with an equal type that this constructor has not yet
        -- claimed (the slot is removed from the candidates; the paired
        -- Int from withOffsets is used as its index).  Anything else
        -- is given new indices after all existing and newly allocated
        -- slots.
        f (ts', js, new_ts) t
          | all primType t,
            Just (_, j) <- find ((== fmap fromDecl t) . fst) ts' =
              ( delete (fmap fromDecl t, j) ts',
                js ++ take (length t) [j ..],
                new_ts
              )
          | otherwise =
              ( ts',
                js ++ take (length t) [size ts + size new_ts ..],
                new_ts ++ [t]
              )

-- | Internalise the constructors of a sum type, returning the
-- flattened payload types and the component indices of each
-- constructor.
internaliseSumTypeRep ::
  M.Map Name [E.StructType] ->
  ( [I.TypeBase ExtShape Uniqueness],
    [(Name, [Int])]
  )
internaliseSumTypeRep cs =
  first (foldMap toList) . runInternaliseTypeM $
    internaliseConstructors
      <$> traverse (fmap concat . mapM (internaliseTypeM mempty . E.toRes E.Nonunique)) cs

-- | As 'internaliseSumTypeRep', but in 'InternaliseM', with the payload
-- types additionally passed through 'mkAccCerts'.
internaliseSumType ::
  M.Map Name [E.StructType] ->
  InternaliseM
    ( [I.TypeBase ExtShape Uniqueness],
      [(Name, [Int])]
    )
internaliseSumType =
  bitraverse (mapM mkAccCerts) pure . internaliseSumTypeRep

-- | How many core language values are needed to represent one source
-- language value of the given type?
internalisedTypeSize :: E.TypeBase E.Size als -> Int
-- A few special cases for performance.
internalisedTypeSize (E.Scalar (E.Prim _)) = 1
internalisedTypeSize (E.Array _ _ (E.Prim _)) = 1
internalisedTypeSize t = sum $ map length $ internaliseType $ E.toStruct t

-- | Convert an external primitive to an internal primitive.
-- Signed and unsigned integers share the same core integer type.
internalisePrimType :: E.PrimType -> I.PrimType
internalisePrimType (E.Signed t) = I.IntType t
internalisePrimType (E.Unsigned t) = I.IntType t
internalisePrimType (E.FloatType t) = I.FloatType t
internalisePrimType E.Bool = I.Bool

-- | Convert an external primitive value to an internal primitive value.
internalisePrimValue :: E.PrimValue -> I.PrimValue
internalisePrimValue pv =
  case pv of
    -- Signed and unsigned integers share the core integer value type.
    E.SignedValue v -> I.IntValue v
    E.UnsignedValue v -> I.IntValue v
    E.FloatValue v -> I.FloatValue v
    E.BoolValue b -> I.BoolValue b

-- Note [Alias Inference]
--
-- The core language requires us to precisely indicate the aliasing of
-- function results (the RetAls type). This is a problem when coming
-- from the source language, where it is implicit: a non-unique
-- function return value aliases every function argument. The problem
-- now occurs because the core language uses a different value
-- representation than the source language - in particular, we do not
-- have arrays of tuples. E.g. @([]i32,[]i32)@ and @[](i32,i32)@ both
-- have the same core representation, but their implications for
-- aliasing are different.
--
-- To understand why this is a problem, consider a source program
--
-- def id (x: [](i32,i32)) = x
--
-- def f n =
--   let x = replicate n (0,0)
--   let x' = id x
--   let x'' = x' with [0] = (1,1)
--   in x''
--
-- With the core language value representation, it will be this:
--
-- def id (x1: []i32) (x2: []i32) = (x1,x2)
--
-- def f n =
--   let x1 = replicate n 0
--   let x2 = replicate n 0
--   let (x1', x2') = id x1 x2
--   let x1'' = x1' with [0] = 1
--   let x2'' = x2' with [0] = 1
--   in (x1'', x2'')
--
-- The results of 'id' alias *both* of the arguments, so x1' aliases
-- x1 and x2, and x2' also aliases x1 and x2. This means that the
-- first with-expression will consume all of x1/x2/x1'/x2', and then
-- the second with-expression is a type error, as it references a
-- consumed variable.
--
-- Our solution is to deduce the possible aliasing such that
-- components that originally constituted the same array-of-tuples are
-- not aliased. The main complexity is that we have to keep
-- information on the original (source) type structure around for a
-- while. This is done with the Tree type.
futhark-0.25.27/src/Futhark/LSP/000077500000000000000000000000001475065116200161735ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/LSP/Compile.hs000066400000000000000000000117611475065116200201250ustar00rootroot00000000000000-- | Building blocks for "recompiling" (actually just type-checking)
-- the Futhark program managed by the language server. The challenge
-- here is that if the program becomes type-invalid, we want to keep
-- the old state around.
module Futhark.LSP.Compile (tryTakeStateFromIORef, tryReCompile) where

import Colog.Core (logStringStderr, (<&))
import Control.Lens.Getter (view)
import Control.Monad.IO.Class (MonadIO (liftIO))
import Data.IORef (IORef, readIORef, writeIORef)
import Data.Map qualified as M
import Data.Maybe (fromMaybe)
import Data.Text qualified as T
import Futhark.Compiler.Program (LoadedProg, lpFilePaths, lpWarnings, noLoadedProg, reloadProg)
import Futhark.LSP.Diagnostic (diagnosticSource, maxDiagnostic, publishErrorDiagnostics, publishWarningDiagnostics)
import Futhark.LSP.State (State (..), emptyState, updateStaleContent, updateStaleMapping)
import Futhark.LSP.Tool (computeMapping)
import Language.Futhark.Warnings (listWarnings)
import Language.LSP.Protocol.Types
  ( filePathToUri,
    fromNormalizedFilePath,
    toNormalizedUri,
    uriToNormalizedFilePath,
  )
import Language.LSP.Server (LspT, flushDiagnosticsBySource, getVirtualFile, getVirtualFiles)
import Language.LSP.VFS (VFS, vfsMap, virtualFileText)

-- | Try to take state from IORef, if it's empty, try to compile.
--
-- If the state already holds a program, but the requested file is not
-- part of it, the program is reloaded with that file as its root.
-- Should that reload fail, the previous state is kept (as promised in
-- the module documentation) instead of being replaced by an empty one.
-- The resulting state is written back to the 'IORef' and returned.
tryTakeStateFromIORef :: IORef State -> Maybe FilePath -> LspT () IO State
tryTakeStateFromIORef state_mvar file_path = do
  old_state <- liftIO $ readIORef state_mvar
  case stateProgram old_state of
    Nothing -> do
      new_state <- tryCompile old_state file_path noLoadedProg
      liftIO $ writeIORef state_mvar new_state
      pure new_state
    Just prog -> do
      -- If this is in the context of some file that is not part of
      -- the program, try to reload the program from that file.
      let files = lpFilePaths prog
      state <- case file_path of
        Just file_path'
          | file_path' `notElem` files -> do
              logStringStderr <& ("File not part of program: " <> show file_path')
              logStringStderr <& ("Program contains: " <> show files)
              new_state <- tryCompile old_state file_path noLoadedProg
              -- A failed compilation yields a state without a program;
              -- in that case hold on to the previous, valid state.
              pure $ case stateProgram new_state of
                Nothing -> old_state
                Just _ -> new_state
        _ -> pure old_state
      liftIO $ writeIORef state_mvar state
      pure state

-- | Try to (re)-compile, replace old state if successful.
--
-- On failure the old state is kept, and the position mapping between
-- the last successfully compiled content of the file and its current
-- content is recorded in it.
tryReCompile :: IORef State -> Maybe FilePath -> LspT () IO ()
tryReCompile state_mvar file_path = do
  logStringStderr <& "(Re)-compiling ..."
  old_state <- liftIO $ readIORef state_mvar
  let loaded_prog = getLoadedProg old_state
  new_state <- tryCompile old_state file_path loaded_prog
  case stateProgram new_state of
    Nothing -> do
      logStringStderr <& "Failed to (re)-compile, using old state or Nothing"
      logStringStderr <& "Computing PositionMapping for: " <> show file_path
      mapping <- computeMapping old_state file_path
      liftIO $ writeIORef state_mvar $ updateStaleMapping file_path mapping old_state
    Just _ -> do
      logStringStderr <& "(Re)-compile successful"
      liftIO $ writeIORef state_mvar new_state

-- | Try to compile, publish diagnostics on warnings and errors, return newly compiled state.
-- Single point where the compilation is done, and shouldn't be exported.
tryCompile :: State -> Maybe FilePath -> LoadedProg -> LspT () IO State
tryCompile _ Nothing _ = pure emptyState
tryCompile state (Just path) old_loaded_prog = do
  logStringStderr <& "Reloading program from " <> show path
  -- NOTE: the VFS only keeps track of currently opened files.
  files <- transformVFS <$> getVirtualFiles
  res <- liftIO $ reloadProg old_loaded_prog [path] files
  flushDiagnosticsBySource maxDiagnostic diagnosticSource
  either onFailure onSuccess res
  where
    onSuccess new_loaded_prog = do
      publishWarningDiagnostics . listWarnings $ lpWarnings new_loaded_prog
      maybe_virtual_file <- getVirtualFile . toNormalizedUri $ filePathToUri path
      -- Preserving the files that have been opened should be enough.
      -- TODO: the re-compile logic could avoid discarding all state,
      -- and compile from the root file when there is a dependency
      -- relation, to improve performance and provide more diagnostics.
      pure $ case maybe_virtual_file of
        Nothing -> State (Just new_loaded_prog) (staleData state) -- should never happen
        Just virtual_file -> updateStaleContent path virtual_file new_loaded_prog state
    onFailure prog_error = do
      logStringStderr <& "Compilation failed, publishing diagnostics"
      publishErrorDiagnostics prog_error
      pure emptyState

-- | Transform VFS to a map of file paths to file contents.
-- This is used to pass the file contents to the compiler.
-- Entries whose URI is not a file path are skipped.
transformVFS :: VFS -> M.Map FilePath T.Text
transformVFS vfs = M.foldrWithKey addFile M.empty (view vfsMap vfs)
  where
    addFile file_uri virtual_file acc =
      maybe
        acc
        (\file_path -> M.insert (fromNormalizedFilePath file_path) (virtualFileText virtual_file) acc)
        (uriToNormalizedFilePath file_uri)

-- | The program held by the state, or 'noLoadedProg' if there is none.
getLoadedProg :: State -> LoadedProg
getLoadedProg = fromMaybe noLoadedProg . stateProgram
futhark-0.25.27/src/Futhark/LSP/Diagnostic.hs000066400000000000000000000056761475065116200206310ustar00rootroot00000000000000-- | Handling of diagnostics in the language server - things like
-- warnings and errors.
module Futhark.LSP.Diagnostic
  ( publishWarningDiagnostics,
    publishErrorDiagnostics,
    diagnosticSource,
    maxDiagnostic,
  )
where

import Colog.Core (logStringStderr, (<&))
import Control.Lens ((^.))
import Data.Foldable (for_)
import Data.List.NonEmpty qualified as NE
import Data.Map qualified as M
import Data.Text qualified as T
import Futhark.Compiler.Program (ProgError (..))
import Futhark.LSP.Tool (posToUri, rangeFromLoc)
import Futhark.Util.Loc (Loc (..))
import Futhark.Util.Pretty (Doc, docText)
import Language.LSP.Diagnostics (partitionBySource)
import Language.LSP.Protocol.Lens (HasVersion (version))
import Language.LSP.Protocol.Types
import Language.LSP.Server (LspT, getVersionedTextDoc, publishDiagnostics)

-- | Build a diagnostic with the given range, severity and message,
-- attributed to 'diagnosticSource'; the remaining fields are 'Nothing'.
mkDiagnostic :: Range -> DiagnosticSeverity -> T.Text -> Diagnostic
mkDiagnostic r sev txt =
  Diagnostic r (Just sev) Nothing Nothing diagnosticSource txt Nothing Nothing Nothing

-- | Publish diagnostics from a Uri to Diagnostics mapping, tagging each
-- file's diagnostics with the document version known to the server.
publish :: [(Uri, [Diagnostic])] -> LspT () IO ()
publish uri_diags_map = for_ uri_diags_map publishFile
  where
    publishFile (file_uri, diags) = do
      doc <- getVersionedTextDoc $ TextDocumentIdentifier file_uri
      let doc_version = doc ^. version
      logStringStderr <& ("Publishing diagnostics for " ++ show file_uri ++ " Version: " ++ show doc_version)
      publishDiagnostics maxDiagnostic (toNormalizedUri file_uri) (Just doc_version) (partitionBySource diags)

-- | Send warning diagnostics to the client.  Warnings without a
-- location are dropped.
publishWarningDiagnostics :: [(Loc, Doc a)] -> LspT () IO ()
publishWarningDiagnostics warnings =
  publish . M.assocs . M.unionsWith (++) $ map onWarn warnings
  where
    onWarn (NoLoc, _) = mempty
    onWarn (loc@(Loc pos _), msg) =
      M.singleton
        (posToUri pos)
        [mkDiagnostic (rangeFromLoc loc) DiagnosticSeverity_Warning (docText msg)]

-- | Send error diagnostics to the client.
publishErrorDiagnostics :: NE.NonEmpty ProgError -> LspT () IO ()
publishErrorDiagnostics errors =
  publish $ M.assocs $ M.unionsWith (++) $ map onDiag $ NE.toList errors
  where
    -- 'ProgWarning's are reported alongside errors, but they are not
    -- errors, so give them warning severity in the client.
    onDiag (ProgError loc msg) = diagAt loc DiagnosticSeverity_Error msg
    onDiag (ProgWarning loc msg) = diagAt loc DiagnosticSeverity_Warning msg
    -- Diagnostics without a location cannot be attached to a file.
    diagAt NoLoc _ _ = mempty
    diagAt loc@(Loc pos _) severity msg =
      M.singleton
        (posToUri pos)
        [mkDiagnostic (rangeFromLoc loc) severity (docText msg)]

-- | The maximum number of diagnostics to report.
maxDiagnostic :: Int
maxDiagnostic = 100

-- | The source of the diagnostics.  (That is, the Futhark compiler,
-- but apparently the client must be told such things...)
diagnosticSource :: Maybe T.Text
diagnosticSource = Just "futhark"
futhark-0.25.27/src/Futhark/LSP/Handlers.hs000066400000000000000000000106761475065116200203010ustar00rootroot00000000000000{-# LANGUAGE DataKinds #-}

-- | The handlers exposed by the language server.
module Futhark.LSP.Handlers (handlers) where

import Colog.Core (logStringStderr, (<&))
import Control.Lens ((^.))
import Data.Aeson.Types (Value (Array, String))
import Data.IORef
import Data.Proxy (Proxy (..))
import Data.Vector qualified as V
import Futhark.LSP.Compile (tryReCompile, tryTakeStateFromIORef)
import Futhark.LSP.State (State (..))
import Futhark.LSP.Tool (findDefinitionRange, getHoverInfoFromState)
import Language.LSP.Protocol.Lens (HasUri (uri))
import Language.LSP.Protocol.Message
import Language.LSP.Protocol.Types
import Language.LSP.Server (Handlers, LspM, notificationHandler, requestHandler)

onInitializeHandler :: Handlers (LspM ())
onInitializeHandler = notificationHandler SMethod_Initialized $ \_msg -> logStringStderr <& "Initialized"

onHoverHandler :: IORef State -> Handlers (LspM ())
onHoverHandler state_mvar = requestHandler SMethod_TextDocumentHover $ \req responder -> do
  let TRequestMessage _ _ _ (HoverParams doc pos _workDone) = req
      Position l c = pos
      file_path = uriToFilePath $ doc ^. uri
  logStringStderr <& ("Got hover request: " <> show (file_path, pos))
  state <- tryTakeStateFromIORef state_mvar file_path
  responder $
    Right $
      maybe (InR Null) InL $
        getHoverInfoFromState state file_path (fromEnum l + 1) (fromEnum c + 1)

-- | Handle the custom @custom/onFocusTextDocument@ notification.  The
-- client is expected to send a JSON array whose first element is the
-- URI of the focused document.  Anything else is logged and ignored,
-- rather than crashing the handler on a failed pattern match.
onDocumentFocusHandler :: IORef State -> Handlers (LspM ())
onDocumentFocusHandler state_mvar = notificationHandler (SMethod_CustomMethod (Proxy @"custom/onFocusTextDocument")) $ \msg -> do
  logStringStderr <& "Got custom request: onFocusTextDocument"
  let TNotificationMessage _ _ focus_params = msg
  case focus_params of
    Array vector_param
      -- only one parameter passed from the client
      | Just (String focused_uri) <- vector_param V.!? 0 ->
          tryReCompile state_mvar (uriToFilePath (Uri focused_uri))
    _ ->
      logStringStderr <& ("Ignoring malformed onFocusTextDocument parameters: " <> show focus_params)

goToDefinitionHandler :: IORef State -> Handlers (LspM ())
goToDefinitionHandler state_mvar = requestHandler SMethod_TextDocumentDefinition $ \req responder -> do
  let TRequestMessage _ _ _ (DefinitionParams doc pos _workDone _partial) = req
      Position l c = pos
      file_path = uriToFilePath $ doc ^. uri
  logStringStderr <& ("Got goto definition: " <> show (file_path, pos))
  state <- tryTakeStateFromIORef state_mvar file_path
  case findDefinitionRange state file_path (fromEnum l + 1) (fromEnum c + 1) of
    Nothing -> responder $ Right $ InR $ InR Null
    Just loc -> responder $ Right $ InL $ Definition $ InL loc

onDocumentSaveHandler :: IORef State -> Handlers (LspM ())
onDocumentSaveHandler state_mvar = notificationHandler SMethod_TextDocumentDidSave $ \msg -> do
  let TNotificationMessage _ _ (DidSaveTextDocumentParams doc _text) = msg
      file_path = uriToFilePath $ doc ^. uri
  logStringStderr <& ("Saved document: " ++ show doc)
  tryReCompile state_mvar file_path

onDocumentChangeHandler :: IORef State -> Handlers (LspM ())
onDocumentChangeHandler state_mvar = notificationHandler SMethod_TextDocumentDidChange $ \msg -> do
  let TNotificationMessage _ _ (DidChangeTextDocumentParams doc _content) = msg
      file_path = uriToFilePath $ doc ^. uri
  tryReCompile state_mvar file_path

-- Some clients (Eglot) send open/close events whether we want them
-- or not, so we better be prepared to ignore them.
onDocumentOpenHandler :: Handlers (LspM ())
onDocumentOpenHandler = notificationHandler SMethod_TextDocumentDidOpen $ \_ -> pure ()

onDocumentCloseHandler :: Handlers (LspM ())
onDocumentCloseHandler = notificationHandler SMethod_TextDocumentDidClose $ \_msg -> pure ()

-- Sent by Eglot when first connecting - not sure when else it might
-- be sent.
onWorkspaceDidChangeConfiguration :: IORef State -> Handlers (LspM ())
onWorkspaceDidChangeConfiguration _state_mvar =
  notificationHandler SMethod_WorkspaceDidChangeConfiguration $ \_ ->
    logStringStderr <& "WorkspaceDidChangeConfiguration"

-- | Given an 'IORef' tracking the state, produce a set of handlers.
-- When we want to add more features to the language server, this is
-- the thing to change.
handlers :: IORef State -> ClientCapabilities -> Handlers (LspM ())
handlers state_mvar _ =
  mconcat
    [ onInitializeHandler,
      onDocumentOpenHandler,
      onDocumentCloseHandler,
      onDocumentSaveHandler state_mvar,
      onDocumentChangeHandler state_mvar,
      onDocumentFocusHandler state_mvar,
      goToDefinitionHandler state_mvar,
      onHoverHandler state_mvar,
      onWorkspaceDidChangeConfiguration state_mvar
    ]
futhark-0.25.27/src/Futhark/LSP/PositionMapping.hs000066400000000000000000000063711475065116200216560ustar00rootroot00000000000000-- | Provide mapping between position in stale content and current.
module Futhark.LSP.PositionMapping
  ( mappingFromDiff,
    PositionMapping,
    toStalePos,
    toCurrentLoc,
    StaleFile (..),
  )
where

import Data.Algorithm.Diff (Diff, PolyDiff (Both, First, Second), getDiff)
import Data.Bifunctor (Bifunctor (bimap, first, second))
import Data.Text qualified as T
import Data.Vector qualified as V
import Futhark.Util.Loc (Loc (Loc), Pos (Pos))
import Language.LSP.VFS (VirtualFile)

-- | A mapping between current file content and the stale (last successful compiled) file content.
-- Currently, only supports entire line mapping,
-- more detailed mapping might be achieved via referring to haskell-language-server@efb4b94
data PositionMapping = PositionMapping
  { -- | The mapping from stale position to current.
    -- e.g. staleToCurrent[2] = 4 means "line 2" in the stale file, corresponds to "line 4" in the current file.
    -- Lines with no counterpart in the other text are mapped to -1.
    staleToCurrent :: V.Vector Int,
    -- | The mapping from current position to stale.
    currentToStale :: V.Vector Int
  }
  deriving (Show)

-- | Stale pretty document stored in state.
data StaleFile = StaleFile
  { -- | The last successfully compiled file content.
    -- Using VirtualFile for convenience, we can use anything with {version, content}
    staleContent :: VirtualFile,
    -- | PositionMapping between current and stale file content.
    -- Nothing if last type-check is successful.
    staleMapping :: Maybe PositionMapping
  }
  deriving (Show)

-- | Compute PositionMapping using the diff between two texts.
mappingFromDiff :: [T.Text] -> [T.Text] -> PositionMapping
mappingFromDiff stale current =
  PositionMapping (V.fromList stale_to_current) (V.fromList current_to_stale)
  where
    (stale_to_current, current_to_stale) = rawMapping (getDiff stale current) 0 0
    -- Walk the diff while tracking the 0-based line number in the old
    -- (stale) and new (current) text.  A line present in only one of
    -- the texts is mapped to -1.
    rawMapping :: [Diff T.Text] -> Int -> Int -> ([Int], [Int])
    rawMapping [] _ _ = ([], [])
    rawMapping (Both _ _ : xs) lold lnew = bimap (lnew :) (lold :) $ rawMapping xs (lold + 1) (lnew + 1)
    rawMapping (First _ : xs) lold lnew = first (-1 :) $ rawMapping xs (lold + 1) lnew
    rawMapping (Second _ : xs) lold lnew = second (-1 :) $ rawMapping xs lold (lnew + 1)

-- | Translate the line of a 'Pos' through a line mapping.  'Pos' lines
-- are 1-based, while the mapping is indexed from 0.  Lines outside
-- the mapping, and lines without a counterpart (stored as -1), give
-- 'Nothing' instead of an invalid line 0 or an out-of-bounds read.
mapPosLine :: V.Vector Int -> Pos -> Maybe Pos
mapPosLine line_mapping (Pos file l c o) = do
  target_line <- line_mapping V.!? (l - 1)
  if target_line < 0
    then Nothing
    else Just $ Pos file (target_line + 1) c o

-- | Transform current Pos to the stale pos for query
-- Note: line and col in Pos is larger by one
toStalePos :: Maybe PositionMapping -> Pos -> Maybe Pos
toStalePos (Just (PositionMapping _ current_to_stale)) pos = mapPosLine current_to_stale pos
toStalePos Nothing pos = Just pos

-- | Transform stale Pos to the current pos; the counterpart of 'toStalePos'.
toCurrentPos :: Maybe PositionMapping -> Pos -> Maybe Pos
toCurrentPos (Just (PositionMapping stale_to_current _)) pos = mapPosLine stale_to_current pos
toCurrentPos Nothing pos = Just pos

-- | Transform stale Loc gotten from stale AST to current Loc.
-- A 'Loc' without positions cannot be translated and gives 'Nothing'.
toCurrentLoc :: Maybe PositionMapping -> Loc -> Maybe Loc
toCurrentLoc mapping loc =
  case loc of
    Loc start end ->
      Loc <$> toCurrentPos mapping start <*> toCurrentPos mapping end
    _ -> Nothing
futhark-0.25.27/src/Futhark/LSP/State.hs000066400000000000000000000045061475065116200176140ustar00rootroot00000000000000-- | The language server state definition.
module Futhark.LSP.State
  ( State (..),
    emptyState,
    getStaleContent,
    getStaleMapping,
    updateStaleContent,
    updateStaleMapping,
  )
where

import Data.Map qualified as M
import Futhark.Compiler.Program (LoadedProg)
import Futhark.LSP.PositionMapping (PositionMapping, StaleFile (..))
import Language.LSP.VFS (VirtualFile)

-- | The state of the language server.
data State = State
  { -- | The loaded program.
    stateProgram :: Maybe LoadedProg,
    -- | The stale data, stored to provide PositionMapping when requested.
    -- All files that have been opened have an entry.
    staleData :: M.Map FilePath StaleFile
  }

-- | Initial state.
emptyState :: State
emptyState = State Nothing M.empty

-- | Get the contents of a stale (last successfully compiled) file.
getStaleContent :: State -> FilePath -> Maybe VirtualFile
getStaleContent state file_path = staleContent <$> M.lookup file_path (staleData state)

-- | Get the PositionMapping for a file.
getStaleMapping :: State -> FilePath -> Maybe PositionMapping
getStaleMapping state file_path = staleMapping =<< M.lookup file_path (staleData state)

-- | Update the state with another pair of file_path and contents.
-- Could do a clean up, because there is no need to store files that are not in lpFilePaths prog.
updateStaleContent :: FilePath -> VirtualFile -> LoadedProg -> State -> State
updateStaleContent file_path file_content loadedProg state =
  -- NOTE: M.insert replaces any previous entry for the file.
  -- updateStaleContent is only called after a successful type-check,
  -- so the PositionMapping is Nothing here; it is computed after a
  -- failed type-check.
  state
    { stateProgram = Just loadedProg,
      staleData = M.insert file_path fresh_entry (staleData state)
    }
  where
    fresh_entry = StaleFile file_content Nothing

-- | Update the state with another pair of file_path and PositionMapping.
-- A file without a stale entry (one that has never been successfully
-- type-checked) is left untouched, as is the whole state when no path
-- is given.
updateStaleMapping :: Maybe FilePath -> Maybe PositionMapping -> State -> State
updateStaleMapping Nothing _ state = state
updateStaleMapping (Just file_path) mapping state =
  state {staleData = M.adjust withMapping file_path (staleData state)}
  where
    withMapping stale_file = stale_file {staleMapping = mapping}
futhark-0.25.27/src/Futhark/LSP/Tool.hs000066400000000000000000000123021475065116200174420ustar00rootroot00000000000000-- | Generally useful definitions used in various places in the
-- language server implementation.
module Futhark.LSP.Tool
  ( getHoverInfoFromState,
    findDefinitionRange,
    rangeFromLoc,
    posToUri,
    computeMapping,
  )
where

import Data.Text qualified as T
import Futhark.Compiler.Program (lpImports)
import Futhark.LSP.PositionMapping
  ( PositionMapping,
    mappingFromDiff,
    toCurrentLoc,
    toStalePos,
  )
import Futhark.LSP.State (State (..), getStaleContent, getStaleMapping)
import Futhark.Util.Loc (Loc (Loc, NoLoc), Pos (Pos))
import Futhark.Util.Pretty (prettyText)
import Language.Futhark.Prop (isBuiltinLoc)
import Language.Futhark.Query
  ( AtPos (AtName),
    BoundTo (..),
    atPos,
    boundLoc,
  )
import Language.LSP.Protocol.Types
import Language.LSP.Server (LspM, getVirtualFile)
import Language.LSP.VFS (VirtualFile, virtualFileText, virtualFileVersion)

-- | Retrieve hover info for the definition referenced at the given
-- file at the given line and column number (the two 'Int's).
getHoverInfoFromState :: State -> Maybe FilePath -> Int -> Int -> Maybe Hover
getHoverInfoFromState state (Just path) l c = do
  AtName _ (Just def) loc <- queryAtPos state $ Pos path l c 0
  let contents = MarkupContent MarkupKind_PlainText $ describe def
  Just $ Hover (InL contents) (Just $ rangeFromLoc loc)
  where
    -- For terms, show the pretty-printed payload (presumably their
    -- type); for other bindings, just the kind of binding.
    describe (BoundTerm t _) = prettyText t
    describe BoundModule {} = "module"
    describe BoundModuleType {} = "module type"
    describe BoundType {} = "type"
getHoverInfoFromState _ _ _ _ = Nothing

-- | Find the location of the definition referenced at the given file
-- at the given line and column number (the two 'Int's).
findDefinitionRange :: State -> Maybe FilePath -> Int -> Int -> Maybe Location
findDefinitionRange state (Just path) l c = do
  -- some unnecessary operations inside `queryAtPos` for this function
  -- but shouldn't affect performance much since "Go to definition" is called less frequently
  AtName _qn (Just bound) _loc <- queryAtPos state $ Pos path l c 0
  let loc = boundLoc bound
  -- Builtin definitions, and definitions without a source location,
  -- have nowhere to jump to.
  case loc of
    Loc (Pos file_path _ _ _) _
      | not (isBuiltinLoc loc) ->
          Just $ Location (filePathToUri file_path) (rangeFromLoc loc)
    _ -> Nothing
findDefinitionRange _ _ _ _ = Nothing

-- | Query the AST for information at certain Pos.
--
-- The position is first translated to the last successfully compiled
-- version of its file, and the locations in the answer are translated
-- back to the current content.
queryAtPos :: State -> Pos -> Maybe AtPos
queryAtPos state pos = do
  let Pos path _ _ _ = pos
      mapping = getStaleMapping state path
  loaded_prog <- stateProgram state
  stale_pos <- toStalePos mapping pos
  query_result <- atPos (lpImports loaded_prog) stale_pos
  updateAtPos mapping query_result
  where
    -- Update the 'AtPos' with the current mapping.
    updateAtPos :: Maybe PositionMapping -> AtPos -> Maybe AtPos
    updateAtPos mapping (AtName qn (Just def) loc) = do
      current_loc <- toCurrentLoc mapping loc
      current_def <- case boundLoc def of
        -- There is nothing to translate for a definition without a
        -- location, so keep it as it is.
        NoLoc -> Just def
        def_loc@(Loc (Pos def_file _ _ _) _) ->
          -- The definition may live in another file, so use the
          -- PositionMapping of the file it is defined in.
          updateBoundLoc def <$> toCurrentLoc (getStaleMapping state def_file) def_loc
      Just $ AtName qn (Just current_def) current_loc
    updateAtPos _ _ = Nothing

    updateBoundLoc :: BoundTo -> Loc -> BoundTo
    updateBoundLoc (BoundTerm t _loc) current_loc = BoundTerm t current_loc
    updateBoundLoc (BoundModule _loc) current_loc = BoundModule current_loc
    updateBoundLoc (BoundModuleType _loc) current_loc = BoundModuleType current_loc
    updateBoundLoc (BoundType _loc) current_loc = BoundType current_loc

-- | Entry point for computing PositionMapping.
computeMapping :: State -> Maybe FilePath -> LspM () (Maybe PositionMapping)
computeMapping state (Just file_path) = do
  virtual_file <- getVirtualFile $ toNormalizedUri $ filePathToUri file_path
  pure $ getMapping (getStaleContent state file_path) virtual_file
  where
    getMapping :: Maybe VirtualFile -> Maybe VirtualFile -> Maybe PositionMapping
    getMapping (Just stale_file) (Just current_file) =
      if virtualFileVersion stale_file == virtualFileVersion current_file
        then Nothing -- Happens when other files (e.g. dependencies) fail to type-check.
        else Just $ mappingFromDiff (T.lines $ virtualFileText stale_file) (T.lines $ virtualFileText current_file)
    getMapping _ _ = Nothing
computeMapping _ _ = pure Nothing

-- | Convert a Futhark 'Pos' to an LSP 'Uri'.
posToUri :: Pos -> Uri
posToUri (Pos file _ _ _) = filePathToUri file

-- Futhark's parser has a slightly different notion of locations than
-- LSP; so we tweak the positions here.
getStartPos :: Pos -> Position
getStartPos (Pos _ l c _) = Position (toEnum l - 1) (toEnum c - 1)

getEndPos :: Pos -> Position
getEndPos (Pos _ l c _) = Position (toEnum l - 1) (toEnum c)

-- | Create an LSP 'Range' from a Futhark 'Loc'.
rangeFromLoc :: Loc -> Range
rangeFromLoc loc = case loc of
  Loc start end -> Range (getStartPos start) (getEndPos end)
  NoLoc -> Range (Position 0 0) (Position 0 5) -- only when file not found, throw error after moving to vfs
futhark-0.25.27/src/Futhark/MonadFreshNames.hs000066400000000000000000000115571475065116200211140ustar00rootroot00000000000000{-# LANGUAGE UndecidableInstances #-}

-- | This module provides a monadic facility similar to (and built on
-- top of) "Futhark.FreshNames".  This removes the need for a (small)
-- amount of boilerplate, at the cost of using some GHC extensions.
-- The idea is that if your compiler pass runs in a monad that is an
-- instance of 'MonadFreshNames', you can automatically use the name
-- generation functions exported by this module.
module Futhark.MonadFreshNames
  ( MonadFreshNames (..),
    modifyNameSource,
    newName,
    newNameFromString,
    newVName,
    newIdent,
    newIdent',
    newParam,
    module Futhark.FreshNames,
  )
where

import Control.Monad.Except
import Control.Monad.RWS.Lazy qualified
import Control.Monad.RWS.Strict qualified
import Control.Monad.Reader
import Control.Monad.State.Lazy qualified
import Control.Monad.State.Strict qualified
import Control.Monad.Trans.Maybe qualified
import Control.Monad.Writer.Lazy qualified
import Control.Monad.Writer.Strict qualified
import Futhark.FreshNames hiding (newName)
import Futhark.FreshNames qualified as FreshNames
import Futhark.IR.Syntax

-- | A monad that stores a name source.  The following is a good
-- instance for a monad in which the only state is a 'VNameSource':
--
-- @
-- instance MonadFreshNames MyMonad where
--   getNameSource = get
--   putNameSource = put
-- @
class (Monad m) => MonadFreshNames m where
  -- | Retrieve the current name source.
  getNameSource :: m VNameSource

  -- | Replace the current name source.
  putNameSource :: VNameSource -> m ()

instance (Monad im) => MonadFreshNames (Control.Monad.State.Lazy.StateT VNameSource im) where
  getNameSource = Control.Monad.State.Lazy.get
  putNameSource = Control.Monad.State.Lazy.put

instance (Monad im) => MonadFreshNames (Control.Monad.State.Strict.StateT VNameSource im) where
  getNameSource = Control.Monad.State.Strict.get
  putNameSource = Control.Monad.State.Strict.put

instance
  (Monad im, Monoid w) =>
  MonadFreshNames (Control.Monad.RWS.Lazy.RWST r w VNameSource im)
  where
  getNameSource = Control.Monad.RWS.Lazy.get
  putNameSource = Control.Monad.RWS.Lazy.put

instance
  (Monad im, Monoid w) =>
  MonadFreshNames (Control.Monad.RWS.Strict.RWST r w VNameSource im)
  where
  getNameSource = Control.Monad.RWS.Strict.get
  putNameSource = Control.Monad.RWS.Strict.put

-- | Run a computation needing a fresh name source and returning a new
-- one, using 'getNameSource' and 'putNameSource' before and after the
-- computation.  The new name source is forced before it is stored.
modifyNameSource :: (MonadFreshNames m) => (VNameSource -> (a, VNameSource)) -> m a
modifyNameSource f = do
  (x, src') <- f <$> getNameSource
  src' `seq` putNameSource src'
  pure x

-- | Produce a fresh name, using the given name as a template.
newName :: (MonadFreshNames m) => VName -> m VName
newName template = modifyNameSource $ \src -> FreshNames.newName src template

-- | As @newName@, but takes a 'String' for the name template.
newNameFromString :: (MonadFreshNames m) => String -> m VName
newNameFromString = newID . nameFromString

-- | Produce a fresh 'VName', using the given base 'Name' as a template.
newID :: (MonadFreshNames m) => Name -> m VName
newID = newName . flip VName 0

-- | Produce a fresh 'VName', using the given base name as a template.
newVName :: (MonadFreshNames m) => String -> m VName
newVName = newID . nameFromString

-- | Produce a fresh 'Ident', using the given name as a template.
newIdent :: (MonadFreshNames m) => String -> Type -> m Ident
newIdent s t = flip Ident t <$> newID (nameFromString s)

-- | Produce a fresh 'Ident', using the given 'Ident' as a template,
-- but possibly modifying the name.
newIdent' :: (MonadFreshNames m) => (String -> String) -> Ident -> m Ident
newIdent' f ident = newIdent name' (identType ident)
  where
    name' = f . nameToString . baseName $ identName ident

-- | Produce a fresh 'Param', using the given name as a template.
newParam :: (MonadFreshNames m) => String -> dec -> m (Param dec)
newParam s t = (\v -> Param mempty v t) <$> newID (nameFromString s)

-- Utility instance definitions for MTL classes.  This requires
-- UndecidableInstances, but saves on typing elsewhere.

instance (MonadFreshNames m) => MonadFreshNames (ReaderT s m) where
  getNameSource = lift getNameSource
  putNameSource = lift . putNameSource

instance
  (MonadFreshNames m, Monoid s) =>
  MonadFreshNames (Control.Monad.Writer.Lazy.WriterT s m)
  where
  getNameSource = lift getNameSource
  putNameSource = lift . putNameSource

instance
  (MonadFreshNames m, Monoid s) =>
  MonadFreshNames (Control.Monad.Writer.Strict.WriterT s m)
  where
  getNameSource = lift getNameSource
  putNameSource = lift . putNameSource

instance
  (MonadFreshNames m) =>
  MonadFreshNames (Control.Monad.Trans.Maybe.MaybeT m)
  where
  getNameSource = lift getNameSource
  putNameSource = lift . putNameSource

instance (MonadFreshNames m) => MonadFreshNames (ExceptT e m) where
  getNameSource = lift getNameSource
  putNameSource = lift . putNameSource
futhark-0.25.27/src/Futhark/Optimise/000077500000000000000000000000001475065116200173265ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Optimise/ArrayLayout.hs000066400000000000000000000031041475065116200221340ustar00rootroot00000000000000module Futhark.Optimise.ArrayLayout
  ( optimiseArrayLayoutGPU,
    optimiseArrayLayoutMC,
  )
where

import Control.Monad.State.Strict
import Futhark.Analysis.AccessPattern (Analyse, analyseDimAccesses)
import Futhark.Analysis.PrimExp.Table (primExpTable)
import Futhark.Builder
import Futhark.IR.GPU (GPU)
import Futhark.IR.MC (MC)
import Futhark.Optimise.ArrayLayout.Layout (layoutTableFromIndexTable)
import Futhark.Optimise.ArrayLayout.Transform (Transform, transformStms)
import Futhark.Pass

-- | Build the array layout pass for some representation; the string
-- is appended to the pass name.
optimiseArrayLayout :: (Analyse rep, Transform rep, BuilderOps rep) => String -> Pass rep rep
optimiseArrayLayout s =
  Pass
    ("optimise array layout " <> s)
    "Transform array layout for locality optimisations."
    $ \prog ->
      let -- Analyse the program
          index_table = analyseDimAccesses prog
          -- Compute primExps for all variables
          table = primExpTable prog
          -- Compute permutations to achieve coalescence for all arrays
          permutation_table = layoutTableFromIndexTable table index_table
       in -- Insert permutations in the AST
          intraproceduralTransformation (onStms permutation_table) prog
  where
    onStms layout_table scope stms =
      fmap fst . modifyNameSource . runState $
        runBuilderT (transformStms layout_table mempty stms) scope

-- | The optimisation performed on the GPU representation.
optimiseArrayLayoutGPU :: Pass GPU GPU
optimiseArrayLayoutGPU = optimiseArrayLayout "gpu"

-- | The optimisation performed on the MC representation.
optimiseArrayLayoutMC :: Pass MC MC
optimiseArrayLayoutMC = optimiseArrayLayout "mc"
futhark-0.25.27/src/Futhark/Optimise/ArrayLayout/000077500000000000000000000000001475065116200216025ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Optimise/ArrayLayout/Layout.hs000066400000000000000000000227441475065116200234240ustar00rootroot00000000000000module Futhark.Optimise.ArrayLayout.Layout
  ( layoutTableFromIndexTable,
    Layout,
    Permutation,
    LayoutTable,

    -- * Exposed for testing
    commonPermutationEliminators,
  )
where

import Control.Monad (join)
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.AccessPattern
import Futhark.Analysis.PrimExp.Table (PrimExpTable)
import Futhark.IR.Aliases
import Futhark.IR.GPU
import Futhark.IR.MC
import Futhark.IR.MCMem
import Futhark.Util (mininum)

-- | A permutation of array dimensions, given as a list of dimension indices.
type Permutation = [Int]

-- | For every SegOp, for every array indexed inside it, the permutation
-- chosen for each index expression.
type LayoutTable =
  M.Map
    SegOpName
    ( M.Map
        ArrayName
        (M.Map IndexExprName Permutation)
    )

class Layout rep where
  -- | Produce a coalescing permutation that will be used to create a
  -- manifest of the array. Returns Nothing if the array is already in
  -- the optimal layout or if the array access is too complex to
  -- confidently determine the optimal layout. Map each list of
  -- 'DimAccess' in the IndexTable to a permutation in a generic way
  -- that can be handled uniquely by each backend.
  permutationFromDimAccess ::
    PrimExpTable ->
    SegOpName ->
    ArrayName ->
    IndexExprName ->
    [DimAccess rep] ->
    Maybe Permutation

-- | Whether an expression contains anything other than variables,
-- constants, and unary or binary operators applied to those.
isInscrutableExp :: PrimExp VName -> Bool
isInscrutableExp (LeafExp _ _) = False
isInscrutableExp (ValueExp _) = False
isInscrutableExp (BinOpExp _ a b) = isInscrutableExp a || isInscrutableExp b
isInscrutableExp (UnOpExp _ a) = isInscrutableExp a
isInscrutableExp _ = True

-- | Largest stride magnitude for which a counter-dependent index is still
-- considered simple enough to reason about.  Might need tuning.
maxStride :: Int
maxStride = 8

-- | Whether an index expression is too complex to reason about.  The flag
-- says whether the index depends on a counter (loop counter or thread ID).
-- For such binary expressions the stride is computed with
-- 'reduceStrideAndOffset', and the expression counts as inscrutable when
-- the stride's magnitude exceeds 'maxStride'.  All other expressions are
-- judged by 'isInscrutableExp'.
isInscrutable :: PrimExp VName -> Bool -> Bool
isInscrutable op@(BinOpExp {}) True =
  -- Calculate stride and offset for loop-counters and thread-IDs.
  case reduceStrideAndOffset op of
    -- A negative stride is as bad for locality as the corresponding
    -- positive one, so compare magnitudes.
    Just (s, _) -> abs s > maxStride
    Nothing -> isInscrutableExp op
isInscrutable op _ = isInscrutableExp op

-- | Approximate an expression as @stride * x + offset@ for a single leaf
-- @x@.  A binary operation is understood only when one operand is an
-- integer constant and the operator is addition, subtraction or
-- multiplication.  The other operand may be a leaf, another such binary
-- operation, or a conversion / most unary operations of one; those
-- wrappers are looked through as if they were the identity.  Anything
-- else yields 'Nothing'.
reduceStrideAndOffset :: PrimExp l -> Maybe (Int, Int)
reduceStrideAndOffset (LeafExp _ _) = Just (1, 0)
reduceStrideAndOffset (BinOpExp oper a b) =
  case (a, b) of
    (ValueExp (IntValue v), _) -> combine True (valueIntegral v) =<< inner b
    (_, ValueExp (IntValue v)) -> combine False (valueIntegral v) =<< inner a
    _ -> Nothing
  where
    -- Stride and offset of the non-constant operand.
    inner (LeafExp _ _) = Just (1, 0)
    inner op@(BinOpExp {}) = reduceStrideAndOffset op
    inner (UnOpExp (Neg Bool) _) = Nothing
    inner (UnOpExp (Complement _) _) = Nothing
    inner (UnOpExp (Abs _) _) = Nothing
    inner (UnOpExp _ sub_op) = reduceStrideAndOffset sub_op
    inner (ConvOpExp _ sub_op) = reduceStrideAndOffset sub_op
    inner _ = Nothing

    -- Apply 'oper' with the constant @c@ to @s * x + o@.  The constant's
    -- side matters for subtraction: @c - (s*x + o) = (-s)*x + (c - o)@.
    combine const_on_left c (s, o) =
      case oper of
        Add _ _ -> Just (s, o + c)
        Sub _ _
          | const_on_left -> Just (negate s, c - o)
          | otherwise -> Just (s, o - c)
        Mul _ _ -> Just (s * c, o * c)
        _ -> Nothing
reduceStrideAndOffset _ = Nothing

-- | Reasons common to all backends to not manifest an array.
commonPermutationEliminators :: [Int] -> [BodyType] -> Bool
commonPermutationEliminators perm nest =
  -- The reasons are checked in order.  For the empty permutation
  -- 'is_identity' holds, so 'last' in 'static_last_idx' is never reached
  -- with an empty list.
  is_invalid_perm
    || is_identity
    || inefficient_transpose
    || static_last_idx
    || inside_undesired
  where
    -- Don't manifest if the permutation is invalid (not a permutation of
    -- [0 .. n-1]).
    is_invalid_perm = not (L.sort perm `L.isPrefixOf` [0 ..])
    -- Don't manifest if the permutation is the identity permutation,
    is_identity = perm `L.isPrefixOf` [0 ..]
    -- or is not a transpose,
    inefficient_transpose = isNothing $ isMapTranspose perm
    -- or if the last idx remains last.
    static_last_idx = last perm == length perm - 1
    -- Don't manifest if the array is defined inside a segOp.
    inside_undesired = any undesired nest

    undesired :: BodyType -> Bool
    undesired bodyType = case bodyType of
      SegOpName _ -> True
      _ -> False

-- | Sort (dimension, access) pairs by the strongest dependency of each
-- access.  An access without dependencies sorts before one with
-- dependencies.  When exactly one of the two strongest dependencies is on
-- a 'ThreadID', the flag decides the order: with @threads_last@ the
-- thread-dependent access sorts after the other, otherwise before it.
-- Otherwise the dependencies are compared by level, then by the original
-- dimension index.
sortByDependencies :: Bool -> [(Int, DimAccess rep)] -> [(Int, DimAccess rep)]
sortByDependencies threads_last = L.sortBy cmp
  where
    cmp (ia, a) (ib, b) = cmpIdxPat (strongest ia a) (strongest ib b)

    -- The greatest dependency of an access according to 'cmpIdxPat',
    -- tagged with the access's original dimension.  On ties the earlier
    -- dependency is kept.
    strongest og access =
      foldl
        max'
        Nothing
        [Just (var_type, dep_lvl, og) | Dependency dep_lvl var_type <- M.elems (dependencies access)]

    cmpIdxPat Nothing Nothing = EQ
    cmpIdxPat (Just _) Nothing = GT
    cmpIdxPat Nothing (Just _) = LT
    cmpIdxPat (Just (iterL, lvlL, original_lvl_L)) (Just (iterR, lvlR, original_lvl_R)) =
      case (iterL, iterR) of
        (ThreadID, ThreadID) -> (lvlL, original_lvl_L) `compare` (lvlR, original_lvl_R)
        (ThreadID, _) -> if threads_last then GT else LT
        (_, ThreadID) -> if threads_last then LT else GT
        _ -> (lvlL, original_lvl_L) `compare` (lvlR, original_lvl_R)

    max' lhs rhs = case cmpIdxPat lhs rhs of
      LT -> rhs
      _ -> lhs

-- | Dimension ordering for the multicore backend: thread-dependent
-- accesses sort before the others.
sortMC :: [(Int, DimAccess rep)] -> [(Int, DimAccess rep)]
sortMC = sortByDependencies False

-- | Dimension ordering for the GPU backend: thread-dependent accesses
-- sort after the others.
sortGPU :: [(Int, DimAccess rep)] -> [(Int, DimAccess rep)]
sortGPU = sortByDependencies True

-- | Whether any index expression among the accesses is too complex to
-- reason about.  Only accesses with an 'originalVar' are examined.  If
-- one of those variables has no 'PrimExp' in the table, the result is
-- 'True'.  Otherwise each expression is checked with 'isInscrutable',
-- together with whether that same access depends on a counter.
anyInscrutable :: PrimExpTable -> [DimAccess rep] -> Bool
anyInscrutable primExpTable dimAccesses =
  case mapM (join . (`M.lookup` primExpTable)) deps of
    Nothing -> True
    Just primExps -> or $ zipWith isInscrutable primExps counters
  where
    dimAccesses' = filter (isJust . originalVar) dimAccesses
    deps = mapMaybe originalVar dimAccesses'
    -- One flag per access (not per dependency), so that it lines up with
    -- the corresponding entry of 'deps'.
    counters = map (any (isCounter . varType) . M.elems . dependencies) dimAccesses'

multicorePermutation :: PrimExpTable -> SegOpName -> ArrayName -> IndexExprName -> [DimAccess rep] -> Maybe Permutation
multicorePermutation primExpTable _segOpName (_arr_name, nest, arr_layout) _idx_name dimAccesses
  -- Without any indexed dimension there is nothing to permute (and
  -- 'last' below would fail).
  | null dimAccesses = Nothing
  -- Check if we want to manifest this array with the permutation.
  | lastIdxIsInvariant
      || anyInscrutable primExpTable dimAccesses
      || commonPermutationEliminators perm nest =
      Nothing
  | otherwise = Just perm
  where
    -- Don't accept indices where the last index is invariant.
    lastIdxIsInvariant = isInvariant $ last dimAccesses
    -- Create a candidate permutation.
    perm = map fst $ sortMC (zip arr_layout dimAccesses)

instance Layout MC where
  permutationFromDimAccess = multicorePermutation

gpuPermutation :: PrimExpTable -> SegOpName -> ArrayName -> IndexExprName -> [DimAccess rep] -> Maybe Permutation
gpuPermutation primExpTable _segOpName (_arr_name, nest, arr_layout) _idx_name dimAccesses
  -- Without any indexed dimension there is nothing to permute (and
  -- 'last' below would fail).
  | null dimAccesses = Nothing
  -- Check if we want to manifest this array with the permutation.
  | lastIdxIsInvariant
      || anyIsConstant
      || anyInscrutable primExpTable dimAccesses
      || commonPermutationEliminators perm nest =
      Nothing
  | otherwise = Just perm
  where
    -- Find the outermost parallel level. XXX: this is a bit hacky. Why
    -- don't we simply know at this point the nest in which this index
    -- occurs?
    outermost_par = mininum $ foldMap (map lvl . parDeps) dimAccesses
    invariantToPar = (< outermost_par) . lvl
    -- Do nothing if last index is invariant to segop.
    lastIdxIsInvariant = all invariantToPar $ dependencies $ last dimAccesses
    -- Do nothing if any index is constant, because otherwise we can end
    -- up transposing a too-large array.
    anyIsConstant = any (null . dependencies) dimAccesses
    -- Create a candidate permutation.
    perm = map fst $ sortGPU (zip arr_layout dimAccesses)
    parDeps = filter ((== ThreadID) . varType) . M.elems . dependencies

instance Layout GPU where
  permutationFromDimAccess = gpuPermutation

-- | Like 'M.mapMaybeWithKey', but works on three levels of nested maps.
-- Inner maps that become empty are dropped as well, so no "dangling"
-- rows are left behind.
tableMapMaybe ::
  (k0 -> k1 -> k2 -> a -> Maybe b) ->
  M.Map k0 (M.Map k1 (M.Map k2 a)) ->
  M.Map k0 (M.Map k1 (M.Map k2 b))
tableMapMaybe f =
  M.mapMaybeWithKey $ \key0 -> mapToMaybe $ mapToMaybe . f key0
  where
    maybeMap :: M.Map k a -> Maybe (M.Map k a)
    maybeMap val = if null val then Nothing else Just val

    mapToMaybe g = maybeMap . M.mapMaybeWithKey g

-- | Given a 'PrimExpTable' and an 'IndexTable', compute a 'LayoutTable' by
-- asking 'permutationFromDimAccess' for a permutation for every index
-- expression.  Entries without a permutation, and maps left empty by
-- that, are removed.
layoutTableFromIndexTable :: (Layout rep) => PrimExpTable -> IndexTable rep -> LayoutTable
layoutTableFromIndexTable = tableMapMaybe . permutationFromDimAccess
futhark-0.25.27/src/Futhark/Optimise/ArrayLayout/Transform.hs000066400000000000000000000233321475065116200241140ustar00rootroot00000000000000-- | Do various kernel optimisations - mostly related to coalescing.
module Futhark.Optimise.ArrayLayout.Transform
  ( Transform,
    transformStms,
  )
where

import Control.Monad
import Control.Monad.State.Strict
import Data.Map.Strict qualified as M
import Futhark.Analysis.AccessPattern (IndexExprName, SegOpName (..))
import Futhark.Analysis.PrimExp.Table (PrimExpAnalysis)
import Futhark.Builder
import Futhark.Construct
import Futhark.IR.Aliases
import Futhark.IR.GPU
import Futhark.IR.MC
import Futhark.Optimise.ArrayLayout.Layout (Layout, LayoutTable, Permutation)

-- | Representations whose statements the layout transformation can rewrite.
class (Layout rep, PrimExpAnalysis rep) => Transform rep where
  -- | Apply a SOAC mapper to the SOAC inside an operation, if there is one.
  -- Other operations are returned unchanged.
  onOp :: (Monad m) => SOACMapper rep rep m -> Op rep -> m (Op rep)

  -- | Transform a statement whose expression is the given operation,
  -- adding the result to the builder and returning the updated tables.
  transformOp :: LayoutTable -> ExpMap rep -> Stm rep -> Op rep -> TransformM rep (LayoutTable, ExpMap rep)

-- | The monad in which the transformation runs.
type TransformM rep = Builder rep

-- | A map from the name of an expression to the expression that defines it.
type ExpMap rep = M.Map VName (Stm rep)

instance Transform GPU where
  onOp soac_mapper gpu_op =
    case gpu_op of
      Futhark.IR.GPU.OtherOp soac ->
        Futhark.IR.GPU.OtherOp <$> mapSOACM soac_mapper soac
      _ -> pure gpu_op

  transformOp perm_table expmap stm gpuOp =
    case gpuOp of
      -- TODO: handle non-segthread cases. This requires some care to
      -- avoid doing huge manifests at the block level.
      SegOp op
        | SegThread {} <- segLevel op ->
            transformSegOpGPU perm_table expmap stm op
      _ -> transformRestOp perm_table expmap stm

instance Transform MC where
  onOp soac_mapper mc_op =
    case mc_op of
      Futhark.IR.MC.OtherOp soac ->
        Futhark.IR.MC.OtherOp <$> mapSOACM soac_mapper soac
      _ -> pure mc_op

  transformOp perm_table expmap stm mcOp =
    case mcOp of
      ParOp maybe_par_segop seqSegOp ->
        transformSegOpMC perm_table expmap stm maybe_par_segop seqSegOp
      _ -> transformRestOp perm_table expmap stm

-- | Transform a statement whose expression is a GPU 'SegOp'.  The kernel
-- body is only traversed when the statement's first pattern name belongs
-- to a SegOp in the layout table; otherwise the statement is added
-- unchanged.  Either way the added statement is recorded in the 'ExpMap'
-- under all of its pattern names.
transformSegOpGPU ::
  LayoutTable ->
  ExpMap GPU ->
  Stm GPU ->
  SegOp SegLevel GPU ->
  TransformM GPU (LayoutTable, ExpMap GPU)
transformSegOpGPU perm_table expmap stm@(Let pat aux _) op
  -- Optimization: Only traverse the body of the SegOp if it is
  -- represented in the layout table.
  | patternName `M.member` M.mapKeys vnameFromSegOp perm_table = do
      op' <- mapSegOpM mapper op
      addAndRecord $ Let pat aux $ Op $ SegOp op'
  | otherwise = addAndRecord stm
  where
    patternName = patElemName . head $ patElems pat

    mapper =
      identitySegOpMapper
        { mapOnSegOpBody =
            case segLevel op of
              SegBlock {} -> transformSegGroupKernelBody perm_table expmap
              _ -> transformSegThreadKernelBody perm_table patternName
        }

    addAndRecord new_stm = do
      addStm new_stm
      pure (perm_table, M.fromList [(name, new_stm) | name <- patNames pat] <> expmap)

-- | Transform a multicore 'ParOp' statement.  The sequential SegOp is
-- always traversed; the parallel one, if present, only when the
-- statement's first pattern name belongs to a SegOp in the layout table.
transformSegOpMC ::
  LayoutTable ->
  ExpMap MC ->
  Stm MC ->
  Maybe (SegOp () MC) ->
  SegOp () MC ->
  TransformM MC (LayoutTable, ExpMap MC)
transformSegOpMC perm_table expmap (Let pat aux _) maybe_par_segop seqSegOp =
  case maybe_par_segop of
    Nothing -> rebuild Nothing
    Just par_segop
      -- Optimization: Only traverse the body of the SegOp if it is
      -- represented in the layout table.
      | patternName `M.member` M.mapKeys vnameFromSegOp perm_table -> do
          par_segop' <- mapSegOpM mapper par_segop
          rebuild $ Just par_segop'
      | otherwise -> rebuild $ Just par_segop
  where
    rebuild maybe_par_segop' = do
      -- Map the sequential part of the ParOp.
      seqSegOp' <- mapSegOpM mapper seqSegOp
      let stm' = Let pat aux $ Op $ ParOp maybe_par_segop' seqSegOp'
      addStm stm'
      pure (perm_table, M.fromList [(name, stm') | name <- patNames pat] <> expmap)

    mapper = identitySegOpMapper {mapOnSegOpBody = transformKernelBody perm_table expmap patternName}

    patternName = patElemName . head $ patElems pat

-- | Transform any other statement by recursing into the bodies of its
-- expression, then add it and record it under all of its pattern names.
transformRestOp ::
  (Transform rep, BuilderOps rep) =>
  LayoutTable ->
  ExpMap rep ->
  Stm rep ->
  TransformM rep (LayoutTable, ExpMap rep)
transformRestOp perm_table expmap (Let pat aux e) = do
  stm' <- Let pat aux <$> mapExpM (transform perm_table expmap) e
  addStm stm'
  pure (perm_table, M.fromList (zip (patNames pat) (repeat stm')) <> expmap)

-- | An expression mapper that transforms every nested body in its scope.
transform ::
  (Transform rep, BuilderOps rep) =>
  LayoutTable ->
  ExpMap rep ->
  Mapper rep rep (TransformM rep)
transform perm_table expmap =
  identityMapper
    { mapOnBody = \scope body -> localScope scope (transformBody perm_table expmap body)
    }

-- | Recursively transform the statements in a body.
transformBody ::
  (Transform rep, BuilderOps rep) =>
  LayoutTable ->
  ExpMap rep ->
  Body rep ->
  TransformM rep (Body rep)
transformBody perm_table expmap (Body b stms res) = do
  stms' <- transformStms perm_table expmap stms
  pure $ Body b stms' res

-- | Recursively transform the statements in the body of a SegGroup kernel.
transformSegGroupKernelBody ::
  (Transform rep, BuilderOps rep) =>
  LayoutTable ->
  ExpMap rep ->
  KernelBody rep ->
  TransformM rep (KernelBody rep)
transformSegGroupKernelBody perm_table expmap (KernelBody b stms res) = do
  stms' <- transformStms perm_table expmap stms
  pure $ KernelBody b stms' res

-- | Transform the array indexing in the body of a SegThread kernel.
transformSegThreadKernelBody ::
  (Transform rep, BuilderOps rep) =>
  LayoutTable ->
  VName ->
  KernelBody rep ->
  TransformM rep (KernelBody rep)
transformSegThreadKernelBody perm_table seg_name kbody =
  flip evalStateT mempty $
    traverseKernelBodyArrayIndexes seg_name (ensureTransformedAccess perm_table) kbody

-- | First transform the statements of a kernel body recursively, then
-- rewrite its array indexes exactly as for a SegThread kernel body.
transformKernelBody ::
  (Transform rep, BuilderOps rep) =>
  LayoutTable ->
  ExpMap rep ->
  VName ->
  KernelBody rep ->
  TransformM rep (KernelBody rep)
transformKernelBody perm_table expmap seg_name (KernelBody b stms res) = do
  stms' <- transformStms perm_table expmap stms
  transformSegThreadKernelBody perm_table seg_name (KernelBody b stms' res)

-- | Apply an 'ArrayIndexTransform' to every 'Index' expression of a kernel
-- body.  This includes indexes nested in the bodies of other expressions
-- and in SOAC lambdas reached through 'onOp'.  When the transform returns
-- 'Nothing', the index is kept as it was.
traverseKernelBodyArrayIndexes ::
  forall m rep.
  (Monad m, Transform rep) =>
  VName -> -- seg_name
  ArrayIndexTransform m ->
  KernelBody rep ->
  m (KernelBody rep)
traverseKernelBodyArrayIndexes seg_name coalesce (KernelBody b kstms kres) = do
  kstms' <- onStms kstms
  pure $ KernelBody b kstms' kres
  where
    onStms stms = stmsFromList <$> mapM onStm (stmsToList stms)

    onBody (Body bdec stms bres) = do
      stms' <- onStms stms
      pure $ Body bdec stms' bres

    onLambda lam = do
      body' <- onBody (lambdaBody lam)
      pure lam {lambdaBody = body'}

    onStm (Let pat dec e) =
      case e of
        BasicOp (Index arr is) -> do
          replacement <- coalesce seg_name (patElemName . head $ patElems pat) arr is
          pure . Let pat dec . BasicOp $ case replacement of
            Nothing -> Index arr is
            Just (arr', is') -> Index arr' is'
        _ -> Let pat dec <$> mapExpM mapper e

    soac_mapper = identitySOACMapper {mapOnSOACLambda = onLambda}

    mapper =
      (identityMapper @rep)
        { mapOnBody = const onBody,
          mapOnOp = onOp soac_mapper
        }

-- | Manifests created so far, keyed by source array and permutation, so
-- that each such pair is only copied once.
type Replacements = M.Map (VName, Permutation) VName

-- | Rewrite of a single array index.  The result is the array and slice
-- to index instead, or 'Nothing' to keep the original index.
type ArrayIndexTransform m =
  VName -> -- seg_name (name of the SegThread expression's pattern)
  VName -> -- idx_name (name of the Index expression's pattern)
  VName -> -- arr (name of the array)
  Slice SubExp -> -- slice
  m (Maybe (VName, Slice SubExp))

-- | An 'ArrayIndexTransform' that makes an index read from an array with
-- the layout recorded in the layout table.  When a permutation is found
-- for the index, the array is replaced by a manifest of it with that
-- permutation.  At most one manifest is created per array and
-- permutation; this is tracked in the 'Replacements' state.  The slice is
-- returned unchanged.
ensureTransformedAccess ::
  (MonadBuilder m) =>
  LayoutTable ->
  ArrayIndexTransform (StateT Replacements m)
ensureTransformedAccess perm_table seg_name idx_name arr slice = do
  -- Check if the array has the optimal layout in memory.
  -- If it does not, replace it with a manifest to allocate
  -- it with the optimal layout
  case lookupPermutation perm_table seg_name idx_name arr of
    Nothing -> pure $ Just (arr, slice)
    Just perm -> do
      seen <- gets $ M.lookup (arr, perm)
      case seen of
        -- Already created a manifest for this array + permutation.
        -- So, just replace the name and don't make a new manifest.
        Just arr' -> pure $ Just (arr', slice)
        Nothing -> replace perm =<< lift (manifest perm arr)
  where
    replace perm arr' = do
      -- Store the fact that we have seen this array + permutation
      -- so we don't make duplicate manifests
      modify $ M.insert (arr, perm) arr'
      -- Return the new manifest
      pure $ Just (arr', slice)

    manifest perm array =
      letExp (baseString array ++ "_coalesced") $
        BasicOp (Manifest perm array)

-- | Find the permutation recorded for an index expression of an array
-- inside a SegOp, if there is one.  SegOps and arrays are matched by
-- their 'VName' only.
lookupPermutation :: LayoutTable -> VName -> IndexExprName -> VName -> Maybe Permutation
lookupPermutation perm_table seg_name idx_name arr_name = do
  arrayNameMap <- M.lookup seg_name (M.mapKeys vnameFromSegOp perm_table)
  -- Look for the current array.
  idxs <- M.lookup arr_name (M.mapKeys (\(n, _, _) -> n) arrayNameMap)
  M.lookup idx_name idxs

-- | Transform a single statement and add the result to the builder.
-- Statements whose expression is an 'Op' go to the representation's
-- 'transformOp'; all others go to 'transformRestOp'.
transformStm ::
  (Transform rep, BuilderOps rep) =>
  (LayoutTable, ExpMap rep) ->
  Stm rep ->
  TransformM rep (LayoutTable, ExpMap rep)
transformStm (perm_table, expmap) stm@(Let _ _ (Op op)) =
  transformOp perm_table expmap stm op
transformStm (perm_table, expmap) stm =
  transformRestOp perm_table expmap stm

-- | Transform a sequence of statements in order, threading the layout
-- table and the map of definitions through them.  The statements they add
-- are collected with 'collectStms_'.
transformStms ::
  (Transform rep, BuilderOps rep) =>
  LayoutTable ->
  ExpMap rep ->
  Stms rep ->
  TransformM rep (Stms rep)
transformStms perm_table expmap stms =
  collectStms_ $ foldM_ transformStm (perm_table, expmap) stms
futhark-0.25.27/src/Futhark/Optimise/ArrayShortCircuiting.hs000066400000000000000000000201121475065116200237750ustar00rootroot00000000000000{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE TypeFamilies #-}

-- | Perform array short circuiting
module Futhark.Optimise.ArrayShortCircuiting
  ( optimiseSeqMem,
    optimiseGPUMem,
    optimiseMCMem,
  )
where

import Control.Monad
import Control.Monad.Reader
import Data.Function ((&))
import Data.List qualified as L
import Data.Map qualified as M
import Data.Maybe (fromMaybe)
import Futhark.Analysis.Alias qualified as AnlAls
import Futhark.IR.Aliases
import Futhark.IR.GPUMem
import Futhark.IR.MCMem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.IR.SeqMem
import Futhark.Optimise.ArrayShortCircuiting.ArrayCoalescing
import Futhark.Optimise.ArrayShortCircuiting.DataStructs
import Futhark.Pass (Pass (..))
import Futhark.Pass qualified as Pass
import Futhark.Util

-- | Read-only environment of the rewriting phase.
data Env inner = Env
  { -- | The coalescing table that drives the rewrite.
    envCoalesceTab :: CoalsTab,
    -- | How to rewrite inner (backend-specific) operations.
    onInner :: inner -> UpdateM inner inner,
    -- | Names whose defining statements are to be removed.
    memAllocsToRemove :: Names
  }

type UpdateM inner a = Reader (Env inner) a

-- | Array short-circuiting for the sequential memory representation.
optimiseSeqMem :: Pass SeqMem SeqMem
optimiseSeqMem = pass "short-circuit" "Array Short-Circuiting" mkCoalsTab pure replaceInParams

-- | Array short-circuiting for the GPU memory representation.
optimiseGPUMem :: Pass GPUMem GPUMem
optimiseGPUMem = pass "short-circuit-gpu" "Array Short-Circuiting (GPU)" mkCoalsTabGPU replaceInHostOp replaceInParams

-- | Array short-circuiting for the multicore memory representation.
optimiseMCMem :: Pass MCMem MCMem
optimiseMCMem = pass "short-circuit-mc" "Array Short-Circuiting (MC)" mkCoalsTabMC replaceInMCOp replaceInParams

-- | Rewrite function parameters according to the coalescing table.
-- A memory parameter with an entry is renamed to the entry's destination
-- memory, and that memory is reported as an allocation to remove.  An
-- array parameter whose memory has an entry is moved to the destination
-- memory.  The parameters are returned in their original order.
replaceInParams :: CoalsTab -> [Param FParamMem] -> (Names, [Param FParamMem])
replaceInParams coalstab fparams =
  let (mem_allocs_to_remove, fparams') =
        foldl replaceInParam (mempty, mempty) fparams
   in (mem_allocs_to_remove, reverse fparams')
  where
    -- Parameters are consed onto the accumulator, hence the 'reverse' above.
    replaceInParam (to_remove, acc) (Param attrs name dec) =
      case dec of
        MemMem _
          | Just entry <- M.lookup name coalstab ->
              (oneName (dstmem entry) <> to_remove, Param attrs (dstmem entry) dec : acc)
        MemArray pt shp u (ArrayIn m ixf)
          | Just entry <- M.lookup m coalstab ->
              (to_remove, Param attrs name (MemArray pt shp u $ ArrayIn (dstmem entry) ixf) : acc)
        _ -> (to_remove, Param attrs name dec : acc)

-- | Drop the statements whose first pattern name is in
-- 'memAllocsToRemove'.  Statements with an empty pattern are kept.
removeAllocsInStms :: Stms rep -> UpdateM inner (Stms rep)
removeAllocsInStms stms = do
  to_remove <- asks memAllocsToRemove
  let removed stm = case patNames (stmPat stm) of
        v : _ -> v `nameIn` to_remove
        [] -> False
  pure $ stmsFromList $ filter (not . removed) $ stmsToList stms

-- | Build a short-circuiting pass.  The coalescing tables are computed
-- per function from the alias-annotated program.  Each function's
-- parameters and body are then rewritten with its table.
pass ::
  (Mem rep inner, LetDec rep ~ LetDecMem, AliasableRep rep) =>
  String ->
  String ->
  (Prog (Aliases rep) -> Pass.PassM (M.Map Name CoalsTab)) ->
  (inner rep -> UpdateM (inner rep) (inner rep)) ->
  (CoalsTab -> [FParam (Aliases rep)] -> (Names, [FParam (Aliases rep)])) ->
  Pass rep rep
pass flag desc mk on_inner on_fparams =
  Pass flag desc $ \prog -> do
    coaltabs <- mk $ AnlAls.aliasAnalysis prog
    Pass.intraproceduralTransformationWithConsts pure (onFun coaltabs) prog
  where
    onFun coaltabs _ f = do
      let coaltab = coaltabs M.! funDefName f
      let (mem_allocs_to_remove, new_fparams) = on_fparams coaltab $ funDefParams f
      pure $
        f
          { funDefBody = onBody coaltab mem_allocs_to_remove $ funDefBody f,
            funDefParams = new_fparams
          }

    onBody coaltab mem_allocs_to_remove body =
      body
        { bodyStms =
            runReader
              (updateStms $ bodyStms body)
              (Env coaltab on_inner mem_allocs_to_remove),
          bodyResult = map (replaceResMem coaltab) $ bodyResult body
        }

-- | Replace a result variable that has an entry in the coalescing table
-- by the entry's destination memory.
replaceResMem :: CoalsTab -> SubExpRes -> SubExpRes
replaceResMem coaltab res =
  case flip M.lookup coaltab =<< subExpResVName res of
    Just entry -> res {resSubExp = Var $ dstmem entry}
    Nothing -> res

-- | Rewrite every statement, then remove the statements listed in
-- 'memAllocsToRemove'.
updateStms ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  Stms rep ->
  UpdateM (inner rep) (Stms rep)
updateStms stms = do
  stms' <- mapM replaceInStm stms
  removeAllocsInStms stms'

replaceInStm ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  Stm rep ->
  UpdateM (inner rep) (Stm rep)
replaceInStm (Let (Pat elems) (StmAux c a d) e) = do
  elems' <- mapM replaceInPatElem elems
  e' <- replaceInExp elems' e
  entries <- asks (M.elems . envCoalesceTab)
  -- A statement binding a variable of some coalescing entry also gets
  -- that entry's certificates.
  let c' = case filter (\entry -> (map patElemName elems `L.intersect` M.keys (vartab entry)) /= []) entries of
        [] -> c
        entries' -> c <> foldMap certs entries'
  pure $ Let (Pat elems') (StmAux c' a d) e'
  where
    replaceInPatElem :: PatElem LetDecMem -> UpdateM inner (PatElem LetDecMem)
    replaceInPatElem p@(PatElem vname (MemArray _ _ u _)) =
      fromMaybe p <$> lookupAndReplace vname PatElem u
    replaceInPatElem p = pure p

replaceInExp ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  [PatElem LetDecMem] ->
  Exp rep ->
  UpdateM (inner rep) (Exp rep)
replaceInExp _ e@(BasicOp _) = pure e
replaceInExp pat_elems (Match cond_ses cases defbody dec) = do
  defbody' <- replaceInIfBody defbody
  cases' <- mapM (\(Case p b) -> Case p <$> replaceInIfBody b) cases
  case_rets <- zipWithM (generalizeIxfun pat_elems) pat_elems $ matchReturns dec
  let dec' = dec {matchReturns = case_rets}
  pure $ Match cond_ses cases' defbody' dec'
replaceInExp _ (Loop loop_inits loop_form (Body dec stms res)) = do
  loop_inits' <- mapM (replaceInFParam . fst) loop_inits
  stms' <- updateStms stms
  coalstab <- asks envCoalesceTab
  let res' = map (replaceResMem coalstab) res
  pure $ Loop (zip loop_inits' $ map snd loop_inits) loop_form $ Body dec stms' res'
replaceInExp _ (Op op) =
  case op of
    Inner i -> do
      on_op <- asks onInner
      Op . Inner <$> on_op i
    _ -> pure $ Op op
replaceInExp _ e@WithAcc {} = pure e
replaceInExp _ e@Apply {} = pure e

-- | Rewrite the statements of a kernel body.
replaceInKernelBody ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  KernelBody rep ->
  UpdateM (inner rep) (KernelBody rep)
replaceInKernelBody body = do
  stms <- updateStms $ kernelBodyStms body
  pure $ body {kernelBodyStms = stms}

-- | Rewrite the kernel body of any kind of 'SegOp'.
replaceInSegOp ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  SegOp lvl rep ->
  UpdateM (inner rep) (SegOp lvl rep)
replaceInSegOp (SegMap lvl sp tps body) =
  SegMap lvl sp tps <$> replaceInKernelBody body
replaceInSegOp (SegRed lvl sp binops tps body) =
  SegRed lvl sp binops tps <$> replaceInKernelBody body
replaceInSegOp (SegScan lvl sp binops tps body) =
  SegScan lvl sp binops tps <$> replaceInKernelBody body
replaceInSegOp (SegHist lvl sp hist_ops tps body) =
  SegHist lvl sp hist_ops tps <$> replaceInKernelBody body

replaceInHostOp :: HostOp NoOp GPUMem -> UpdateM (HostOp NoOp GPUMem) (HostOp NoOp GPUMem)
replaceInHostOp (SegOp op) = SegOp <$> replaceInSegOp op
replaceInHostOp op = pure op

replaceInMCOp :: MCOp NoOp MCMem -> UpdateM (MCOp NoOp MCMem) (MCOp NoOp MCMem)
replaceInMCOp (ParOp par_op op) =
  ParOp <$> traverse replaceInSegOp par_op <*> replaceInSegOp op
replaceInMCOp op = pure op

-- | If the pattern element's variable takes part in some coalescing,
-- make the branch return type's index function existential in the
-- pattern's names.  Otherwise the return type is kept.
generalizeIxfun :: [PatElem dec] -> PatElem LetDecMem -> BodyReturns -> UpdateM inner BodyReturns
generalizeIxfun pat_elems (PatElem vname (MemArray _ _ _ (ArrayIn mem ixf))) m@(MemArray pt shp u _) = do
  coaltab <- asks envCoalesceTab
  if any (M.member vname . vartab) coaltab
    then
      existentialiseLMAD (map patElemName pat_elems) ixf
        & ReturnsInBlock mem
        & MemArray pt shp u
        & pure
    else pure m
generalizeIxfun _ _ m = pure m

replaceInIfBody :: (Mem rep inner, LetDec rep ~ LetDecMem) => Body rep -> UpdateM (inner rep) (Body rep)
replaceInIfBody b@(Body _ stms res) = do
  coaltab <- asks envCoalesceTab
  stms' <- updateStms stms
  pure $ b {bodyStms = stms', bodyResult = map (replaceResMem coaltab) res}

replaceInFParam :: Param FParamMem -> UpdateM inner (Param FParamMem)
replaceInFParam p@(Param _ vname (MemArray _ _ u _)) =
  fromMaybe p <$> lookupAndReplace vname (Param mempty) u
replaceInFParam p = pure p

-- | If the variable has been coalesced, rebuild its memory binding in the
-- destination memory, with the index function substituted to a fixpoint,
-- and wrap it with the given constructor.
lookupAndReplace ::
  VName ->
  (VName -> MemBound u -> a) ->
  u ->
  UpdateM inner (Maybe a)
lookupAndReplace vname f u = do
  coaltab <- asks envCoalesceTab
  case M.lookup vname $ foldMap vartab coaltab of
    Just (Coalesced _ (MemBlock pt shp mem ixf) subs) ->
      ixf
        & fixPoint (LMAD.substitute subs)
        & ArrayIn mem
        & MemArray pt shp u
        & f vname
        & Just
        & pure
    Nothing -> pure Nothing
futhark-0.25.27/src/Futhark/Optimise/ArrayShortCircuiting/000077500000000000000000000000001475065116200234455ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Optimise/ArrayShortCircuiting/ArrayCoalescing.hs000066400000000000000000002311541475065116200270550ustar00rootroot00000000000000{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}

-- | The bulk of the short-circuiting implementation.
module Futhark.Optimise.ArrayShortCircuiting.ArrayCoalescing
  ( mkCoalsTab,
    CoalsTab,
    mkCoalsTabGPU,
    mkCoalsTabMC,
  )
where

import Control.Exception.Base qualified as Exc
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Function ((&))
import Data.List qualified as L
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Sequence (Seq (..))
import Data.Set qualified as S
import Futhark.Analysis.LastUse
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Aliases
import Futhark.IR.GPUMem as GPU
import Futhark.IR.MCMem as MC
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.IR.SeqMem
import Futhark.MonadFreshNames
import Futhark.Optimise.ArrayShortCircuiting.DataStructs
import Futhark.Optimise.ArrayShortCircuiting.MemRefAggreg
import Futhark.Optimise.ArrayShortCircuiting.TopdownAnalysis
import Futhark.Util

-- | A helper type describing representations that can be short-circuited.
type Coalesceable rep inner =
  ( Mem rep inner,
    ASTRep rep,
    CanBeAliased inner,
    AliasableRep rep,
    Op rep ~ MemOp inner rep,
    HasMemBlock (Aliases rep),
    LetDec rep ~ LetDecMem,
    TopDownHelper (inner (Aliases rep))
  )

-- | The shape of a function that computes a scalar table for some @op@: given
-- the current scope and the op, it produces (inside 'ScalarTableM') a map from
-- names to the 'PrimExp's they are bound to.
type ComputeScalarTable rep op =
  ScopeTab rep -> op -> ScalarTableM rep (M.Map VName (PrimExp VName))

-- Helper type for computing scalar tables on ops.
newtype ComputeScalarTableOnOp rep = ComputeScalarTableOnOp
  { -- | Scalar-table computation for the representation's 'Op's.
    scalarTableOnOp :: ComputeScalarTable rep (Op (Aliases rep))
  }

-- | Reader monad in which scalar tables are computed; the environment carries
-- the representation-specific handler for 'Op's.
type ScalarTableM rep a = Reader (ComputeScalarTableOnOp rep) a

-- | The representation-specific parts of the short-circuiting analysis.
data ShortCircuitReader rep = ShortCircuitReader
  { -- | Bottom-up handler for a statement whose expression is an 'Op'
    -- (e.g. 'shortCircuitSeqMem' for 'SeqMem').
    onOp ::
      LUTabFun ->
      Pat (VarAliases, LetDecMem) ->
      Certs ->
      Op (Aliases rep) ->
      TopdownEnv rep ->
      BotUpEnv ->
      ShortCircuitM rep BotUpEnv,
    -- | Computes 'SSPointInfo's for an 'Op', if it yields any
    -- (e.g. 'genSSPointInfoSeqMem' for 'SeqMem').
    ssPointFromOp ::
      LUTabFun ->
      TopdownEnv rep ->
      ScopeTab rep ->
      Pat (VarAliases, LetDecMem) ->
      Certs ->
      Op (Aliases rep) ->
      Maybe [SSPointInfo]
  }

-- | The short-circuiting monad: read access to the 'ShortCircuitReader'
-- handlers, on top of a 'State' holding the name source used for fresh names.
newtype ShortCircuitM rep a = ShortCircuitM (ReaderT (ShortCircuitReader rep) (State VNameSource) a)
  deriving (Functor, Applicative, Monad, MonadReader (ShortCircuitReader rep), MonadState VNameSource)

instance MonadFreshNames (ShortCircuitM rep) where
  putNameSource = put
  getNameSource = get

-- | A 'TopdownEnv' in which every field is empty.
emptyTopdownEnv :: TopdownEnv rep
emptyTopdownEnv =
  TopdownEnv
    { alloc = mempty,
      scope = mempty,
      inhibited = mempty,
      v_alias = mempty,
      m_alias = mempty,
      nonNegatives = mempty,
      scalarTable = mempty,
      knownLessThan = mempty,
      td_asserts = mempty
    }

-- | A 'BotUpEnv' with no scalars and empty active, successful and inhibited
-- coalescing tables.
emptyBotUpEnv :: BotUpEnv
emptyBotUpEnv =
  BotUpEnv
    { scals = mempty,
      activeCoals = mempty,
      successCoals = mempty,
      inhibit = mempty
    }

--------------------------------------------------------------------------------
--- Main Coalescing Transformation computes a successful coalescing table    ---
--------------------------------------------------------------------------------

-- | Given a 'Prog' in 'SeqMem' representation, compute the coalescing table
-- by folding over each function.
--
-- 'SeqMem' operations contribute no entries to the scalar table.
mkCoalsTab :: (MonadFreshNames m) => Prog (Aliases SeqMem) -> m (M.Map Name CoalsTab)
mkCoalsTab prog =
  mkCoalsTabProg
    (lastUseSeqMem prog)
    (ShortCircuitReader shortCircuitSeqMem genSSPointInfoSeqMem)
    (ComputeScalarTableOnOp $ const $ const $ pure mempty)
    prog

-- | Given a 'Prog' in 'GPUMem' representation, compute the coalescing table
-- by folding over each function.
mkCoalsTabGPU :: (MonadFreshNames m) => Prog (Aliases GPUMem) -> m (M.Map Name CoalsTab)
mkCoalsTabGPU prog =
  mkCoalsTabProg
    (lastUseGPUMem prog)
    (ShortCircuitReader shortCircuitGPUMem genSSPointInfoGPUMem)
    (ComputeScalarTableOnOp (computeScalarTableMemOp computeScalarTableGPUMem))
    prog

-- | Given a 'Prog' in 'MCMem' representation, compute the coalescing table
-- by folding over each function.
mkCoalsTabMC :: (MonadFreshNames m) => Prog (Aliases MCMem) -> m (M.Map Name CoalsTab)
mkCoalsTabMC prog =
  mkCoalsTabProg
    (lastUseMCMem prog)
    (ShortCircuitReader shortCircuitMCMem genSSPointInfoMCMem)
    (ComputeScalarTableOnOp (computeScalarTableMemOp computeScalarTableMCMem))
    prog

-- | Compute the coalescing table of every function in a program.
--
-- The caller supplies the program's last-use tables, the representation's
-- short-circuiting handlers and its scalar-table computation for 'Op's. Each
-- function is analysed with 'fixPointCoalesce', and the resulting tables are
-- keyed by function name.
mkCoalsTabProg ::
  (MonadFreshNames m, Coalesceable rep inner) =>
  LUTabProg ->
  ShortCircuitReader rep ->
  ComputeScalarTableOnOp rep ->
  Prog (Aliases rep) ->
  m (M.Map Name CoalsTab)
mkCoalsTabProg (_, lutab_prog) r computeScalarOnOp prog =
  M.fromList <$> mapM onFun (progFuns prog)
  where
    consts_scope = scopeOf (progConsts prog)
    onFun fun@(FunDef _ _ fname _ fpars body) = do
      -- Last-use information was computed for the whole program up front;
      -- here we only look up this function's table.
      let unique_mems = getUniqueMemFParam fpars
          lutab = lutab_prog M.! fname
          -- Scalar table of the body statements, computed in a scope holding
          -- the program constants, the function and its body statements.
          scalar_table =
            runReader
              ( concatMapM
                  ( computeScalarTable $
                      consts_scope <> scopeOf fun <> scopeOf (bodyStms body)
                  )
                  (stmsToList $ bodyStms body)
              )
              computeScalarOnOp
          topenv =
            emptyTopdownEnv
              { scope = consts_scope <> scopeOfFParams fpars,
                alloc = unique_mems,
                scalarTable = scalar_table,
                -- Names used in parameter shapes are recorded as non-negative.
                nonNegatives = foldMap paramSizes fpars
              }
          ShortCircuitM m = fixPointCoalesce lutab fpars body topenv
      -- Run the analysis against the caller's name source.
      (fname,) <$> modifyNameSource (runState (runReaderT m r))

-- | The names free in the shape of an array parameter; empty for any other
-- kind of parameter.
paramSizes :: Param FParamMem -> Names
paramSizes (Param _ _ (MemArray _ shp _ _)) = freeIn shp
paramSizes _ = mempty

-- | Short-circuit handler for a 'SeqMem' 'Op'.
--
-- Because 'SeqMem' has no special operations, simply return the input
-- 'BotUpEnv'.
shortCircuitSeqMem ::
  LUTabFun ->
  Pat (VarAliases, LetDecMem) ->
  Certs ->
  Op (Aliases SeqMem) ->
  TopdownEnv SeqMem ->
  BotUpEnv ->
  ShortCircuitM SeqMem BotUpEnv
shortCircuitSeqMem _ _ _ _ _ = pure

-- | Mark as failed (via 'markFailedCoal', updating both the active and the
-- inhibit tables) every active coalescing whose variable table ('vartab')
-- mentions one of the given names.
failCoalsUsedIn :: Names -> BotUpEnv -> BotUpEnv
failCoalsUsedIn used bu_env =
  bu_env {activeCoals = active, inhibit = inh}
  where
    mentionsUsed entry = namesFromList (M.keys $ vartab entry) `namesIntersect` used
    to_fail = M.filter mentionsUsed $ activeCoals bu_env
    (active, inh) =
      foldl markFailedCoal (activeCoals bu_env, inhibit bu_env) $ M.keys to_fail

-- | Short-circuit handler for SegOp.
--
-- Every kind of 'SegOp' is processed by 'shortCircuitSegOpHelper'. For
-- 'SegRed', 'SegScan' and 'SegHist', active coalescings that involve variables
-- free in the operator lambdas are failed beforehand.
shortCircuitSegOp ::
  (Coalesceable rep inner) =>
  (lvl -> Bool) ->
  LUTabFun ->
  Pat (VarAliases, LetDecMem) ->
  Certs ->
  SegOp lvl (Aliases rep) ->
  TopdownEnv rep ->
  BotUpEnv ->
  ShortCircuitM rep BotUpEnv
shortCircuitSegOp lvlOK lutab pat pat_certs (SegMap lvl space _ kernel_body) td_env bu_env =
  -- No special handling necessary for 'SegMap'. Just call the helper-function.
  shortCircuitSegOpHelper 0 lvlOK lvl lutab pat pat_certs space kernel_body td_env bu_env
shortCircuitSegOp lvlOK lutab pat pat_certs (SegRed lvl space binops _ kernel_body) td_env bu_env =
  -- When handling 'SegRed', we first invalidate all active coalesce-entries
  -- where any of the variables in 'vartab' are also free in the list of
  -- 'SegBinOp'. In other words, anything that is used as part of the reduction
  -- step should probably not be coalesced.
  shortCircuitSegOpHelper num_reds lvlOK lvl lutab pat pat_certs space kernel_body td_env bu_env'
  where
    bu_env' = failCoalsUsedIn (foldMap (freeIn . segBinOpLambda) binops) bu_env
    -- One reduction result per return value of each operator; the helper
    -- drops the last segment dimension for these leading results.
    num_reds = length red_ts
    segment_dims = init $ segSpaceDims space
    red_ts = do
      op <- binops
      let shp = Shape segment_dims <> segBinOpShape op
      map (`arrayOfShape` shp) (lambdaReturnType $ segBinOpLambda op)
shortCircuitSegOp lvlOK lutab pat pat_certs (SegScan lvl space binops _ kernel_body) td_env bu_env =
  -- Like in the handling of 'SegRed', we do not want to coalesce anything that
  -- is used in the 'SegBinOp'
  shortCircuitSegOpHelper 0 lvlOK lvl lutab pat pat_certs space kernel_body td_env $
    failCoalsUsedIn (foldMap (freeIn . segBinOpLambda) binops) bu_env
shortCircuitSegOp lvlOK lutab pat pat_certs (SegHist lvl space histops _ kernel_body) td_env bu_env = do
  -- As for 'SegRed', fail coalescings involving the histogram operators.
  let bu_env' = failCoalsUsedIn (foldMap (freeIn . histOp) histops) bu_env
  bu_env'' <- shortCircuitSegOpHelper 0 lvlOK lvl lutab pat pat_certs space kernel_body td_env bu_env'
  -- Need to take zipped patterns and histDest (flattened) and insert transitive coalesces
  pure $ foldl insertHistCoals bu_env'' $ zip (patElems pat) $ concatMap histDest histops
  where
    insertHistCoals acc (PatElem p _, hist_dest) =
      case ( getScopeMemInfo p $ scope td_env,
             getScopeMemInfo hist_dest $ scope td_env
           ) of
        (Just (MemBlock _ _ p_mem _), Just (MemBlock _ _ dest_mem _)) ->
          case M.lookup p_mem $ successCoals acc of
            Just entry ->
              -- Update this entry with an optdep for the memory block of hist_dest
              -- NOTE(review): the optdep inserted is (p, p_mem), and the
              -- active entry stored under dest_mem is the un-updated 'entry'
              -- rather than entry'; confirm this is intended.
              let entry' = entry {optdeps = M.insert p p_mem $ optdeps entry}
               in acc
                    { successCoals = M.insert p_mem entry' $ successCoals acc,
                      activeCoals = M.insert dest_mem entry $ activeCoals acc
                    }
            Nothing -> acc
        _ -> acc

-- | Short-circuit handler for 'GPUMem' 'Op'.
--
-- When the 'Op' is a 'SegOp' or a 'GPUBody', we handle it accordingly,
-- otherwise we do nothing.
shortCircuitGPUMem ::
  LUTabFun ->
  Pat (VarAliases, LetDecMem) ->
  Certs ->
  Op (Aliases GPUMem) ->
  TopdownEnv GPUMem ->
  BotUpEnv ->
  ShortCircuitM GPUMem BotUpEnv
shortCircuitGPUMem lutab pat certs gpu_op td_env bu_env =
  case gpu_op of
    Alloc _ _ -> pure bu_env
    Inner (GPU.SizeOp _) -> pure bu_env
    Inner (GPU.OtherOp NoOp) -> pure bu_env
    Inner (GPU.SegOp segop) ->
      shortCircuitSegOp isSegThread lutab pat certs segop td_env bu_env
    Inner (GPU.GPUBody _ body) -> do
      -- A 'GPUBody' is analysed as a 'SegOp' over a one-element segment
      -- space, run by a single thread in a grid of one group of size one.
      flat_name <- newNameFromString "gpubody"
      thread_idx <- newNameFromString "gpubody"
      let one = Constant $ IntValue $ Int64Value 1
          single_thread =
            GPU.SegThread GPU.SegNoVirt $
              Just $
                GPU.KernelGrid (GPU.Count one) (GPU.Count one)
          unit_space = SegSpace flat_name [(thread_idx, one)]
      shortCircuitSegOpHelper
        0
        isSegThread
        single_thread
        lutab
        pat
        certs
        unit_space
        (bodyToKernelBody body)
        td_env
        bu_env

-- | Short-circuit handler for 'MCMem' 'Op'.
--
-- Allocations and the empty 'OtherOp' leave the environment unchanged. For a
-- 'ParOp', the optional 'SegOp' (when present) is processed first, and the
-- resulting environment is then used for the other 'SegOp'; every level is
-- accepted.
shortCircuitMCMem ::
  LUTabFun ->
  Pat (VarAliases, LetDecMem) ->
  Certs ->
  Op (Aliases MCMem) ->
  TopdownEnv MCMem ->
  BotUpEnv ->
  ShortCircuitM MCMem BotUpEnv
shortCircuitMCMem lutab pat certs mc_op td_env bu_env =
  case mc_op of
    Alloc _ _ -> pure bu_env
    Inner (MC.OtherOp NoOp) -> pure bu_env
    Inner (MC.ParOp mb_par_op segop) -> do
      bu_env' <- maybe (pure bu_env) (analyseWith bu_env) mb_par_op
      analyseWith bu_env' segop
  where
    analyseWith env op = shortCircuitSegOp (const True) lutab pat certs op td_env env

-- | Remove the innermost (last) dimension of a segment space.
dropLastSegSpace :: SegSpace -> SegSpace
dropLastSegSpace (SegSpace flat dims) = SegSpace flat (init dims)

-- | Does this level denote thread-level parallelism ('GPU.SegThread')?
isSegThread :: GPU.SegLevel -> Bool
isSegThread lvl =
  case lvl of
    GPU.SegThread {} -> True
    _ -> False

-- | Computes the slice written at the end of a thread in a 'SegOp'.
--
-- For a 'Returns' result, each thread writes the single element indexed by
-- all of the segment space's indices. For a 'RegTileReturns' result, each
-- dimension is a 'DimSlice' of length @block_tile_size * reg_tile_size@
-- starting at @x * block_tile_size * reg_tile_size@. Any other kind of result
-- gives 'Nothing'.
threadSlice :: SegSpace -> KernelResult -> Maybe (Slice (TPrimExp Int64 VName))
threadSlice space Returns {} =
  Just $
    Slice $
      map (DimFix . TPrimExp . flip LeafExp (IntType Int64) . fst) $
        unSegSpace space
threadSlice space (RegTileReturns _ dims _) =
  Just $
    Slice $
      zipWith
        ( \(_, block_tile_size0, reg_tile_size0) (x0, _) ->
            let x = pe64 $ Var x0
                block_tile_size = pe64 block_tile_size0
                reg_tile_size = pe64 reg_tile_size0
             in DimSlice (x * block_tile_size * reg_tile_size) (block_tile_size * reg_tile_size) 1
        )
        dims
        $ unSegSpace space
threadSlice _ _ = Nothing

-- | Wrap a plain 'Body' as a 'KernelBody': every body result becomes a
-- 'Returns' with 'ResultNoSimplify', keeping its certificates. This lets a
-- 'GPUBody' be analysed with the 'SegOp' machinery.
bodyToKernelBody :: Body (Aliases GPUMem) -> KernelBody (Aliases GPUMem)
bodyToKernelBody (Body dec stms res) =
  KernelBody dec stms $ map (\(SubExpRes cert subexps) -> Returns ResultNoSimplify cert subexps) res

-- | A helper for all the different kinds of 'SegOp'.
--
-- Consists of four parts:
--
-- 1. Create coalescing relations between the pattern elements and the kernel
-- body results using 'makeSegMapCoals'. This is only done when there are no
-- reduction results (@num_reds == 0@).
--
-- 2. Process the statements of the 'KernelBody'.
--
-- 3. Check the overlap between the different threads.
--
-- 4. Mark active coalescings as finished, since a 'SegOp' is an array creation
-- point. Each pattern element with a direct-aliased index function in an active
-- coalescing is checked against the earlier destination uses. It is promoted to
-- a successful coalescing if there is no overlap and its index function's free
-- variables can be substituted; otherwise the coalescing fails.
shortCircuitSegOpHelper ::
  (Coalesceable rep inner) =>
  -- | The number of returns for which we should drop the last seg space
  Int ->
  -- | Whether we should look at a segop with this lvl.
  (lvl -> Bool) ->
  lvl ->
  LUTabFun ->
  Pat (VarAliases, LetDecMem) ->
  Certs ->
  SegSpace ->
  KernelBody (Aliases rep) ->
  TopdownEnv rep ->
  BotUpEnv ->
  ShortCircuitM rep BotUpEnv
shortCircuitSegOpHelper num_reds lvlOK lvl lutab pat@(Pat ps0) pat_certs space0 kernel_body td_env bu_env = do
  -- We need to drop the last element of the 'SegSpace' for pattern elements
  -- that correspond to reductions.
  let ps_space_and_res =
        zip3 ps0 (replicate num_reds (dropLastSegSpace space0) <> repeat space0) $
          kernelBodyResult kernel_body
  -- Create coalescing relations between pattern elements and kernel body
  -- results
  let (actv0, inhibit0) =
        filterSafetyCond2and5
          (activeCoals bu_env)
          (inhibit bu_env)
          (scals bu_env)
          td_env
          (patElems pat)
      (actv_return, inhibit_return) =
        if num_reds > 0
          then (actv0, inhibit0)
          else foldl (makeSegMapCoals lvlOK lvl td_env kernel_body pat_certs) (actv0, inhibit0) ps_space_and_res

  -- Start from empty references, we'll update with aggregates later.
  let actv0' = M.map (\etry -> etry {memrefs = mempty}) $ actv0 <> actv_return
  -- Process kernel body statements
  bu_env' <-
    mkCoalsTabStms lutab (kernelBodyStms kernel_body) td_env $
      bu_env {activeCoals = actv0', inhibit = inhibit_return}

  -- Add back the references each entry had before the body was processed
  -- (they were cleared in actv0').
  let actv_coals_after =
        M.mapWithKey
          ( \k etry ->
              etry
                { memrefs = memrefs etry <> maybe mempty memrefs (M.lookup k $ actv0 <> actv_return)
                }
          )
          $ activeCoals bu_env'

  -- Check partial overlap.
  --
  -- For each active coalescing, the destination uses that were not already
  -- recorded in @bu_env@ are aggregated over the segment space with
  -- 'aggSummaryMapPartial'. The coalescing fails if these may overlap the source
  -- writes (including each thread's own result slice, from 'threadSlice'), or
  -- if the destination memory is not in the top-down scope.
  let checkPartialOverlap bu_env_f (k, entry) = do
        let sliceThreadAccess (p, space, res) =
              case M.lookup (patElemName p) $ vartab entry of
                Just (Coalesced _ (MemBlock _ _ _ ixf) _) ->
                  maybe
                    Undeterminable
                    ( ixfunToAccessSummary
                        . LMAD.slice ixf
                        . fullSlice (LMAD.shape ixf)
                    )
                    $ threadSlice space res
                Nothing -> mempty
            thread_writes = foldMap sliceThreadAccess ps_space_and_res
            source_writes = srcwrts (memrefs entry) <> thread_writes
        destination_uses <-
          case dstrefs (memrefs entry) `accessSubtract` dstrefs (maybe mempty memrefs $ M.lookup k $ activeCoals bu_env) of
            Set s ->
              concatMapM
                (aggSummaryMapPartial (scalarTable td_env) $ unSegSpace space0)
                (S.toList s)
            Undeterminable -> pure Undeterminable
        -- Do not allow short-circuiting from a segop-shared memory
        -- block (not in the topdown scope) to an outer memory block.
        if dstmem entry `M.member` scope td_env
          && noMemOverlap td_env destination_uses source_writes
          then pure bu_env_f
          else do
            let (ac, inh) = markFailedCoal (activeCoals bu_env_f, inhibit bu_env_f) k
            pure $ bu_env_f {activeCoals = ac, inhibit = inh}

  bu_env'' <- foldM checkPartialOverlap (bu_env' {activeCoals = actv_coals_after}) $ M.toList actv_coals_after

  -- Aggregate each surviving entry's source writes and destination uses over
  -- the whole segment space ('aggSummaryMapTotal').
  let updateMemRefs entry = do
        wrts <- aggSummaryMapTotal (scalarTable td_env) (unSegSpace space0) $ srcwrts $ memrefs entry
        uses <- aggSummaryMapTotal (scalarTable td_env) (unSegSpace space0) $ dstrefs $ memrefs entry

        -- Add destination uses from the pattern
        let uses' =
              foldMap
                ( \case
                    PatElem _ (_, MemArray _ _ _ (ArrayIn p_mem p_ixf))
                      | p_mem `nameIn` alsmem entry ->
                          ixfunToAccessSummary p_ixf
                    _ -> mempty
                )
                ps0

        pure $ entry {memrefs = MemRefs (uses <> uses') wrts}

  actv <- mapM updateMemRefs $ activeCoals bu_env''

  let bu_env''' = bu_env'' {activeCoals = actv}

  -- Process pattern and return values
  let mergee_writes =
        mapMaybe
          ( \(p, _, _) ->
              fmap (p,) $
                getDirAliasedIxfn' td_env (activeCoals bu_env''') $
                  patElemName p
          )
          ps_space_and_res

  -- Now, for each mergee write, we need to check that it doesn't overlap with any previous uses of the destination.
  let checkMergeeOverlap bu_env_f (p, (m_b, _, ixf)) =
        let as = ixfunToAccessSummary ixf
         in -- Should be @bu_env@ here, because we need to check overlap
            -- against previous uses.
            case M.lookup m_b $ activeCoals bu_env of
              Just coal_entry -> do
                let mrefs =
                      memrefs coal_entry
                    res = noMemOverlap td_env as $ dstrefs mrefs
                    fail_res =
                      let (ac, inh) = markFailedCoal (activeCoals bu_env_f, inhibit bu_env_f) m_b
                       in bu_env_f {activeCoals = ac, inhibit = inh}

                if res
                  then case M.lookup (patElemName p) $ vartab coal_entry of
                    Nothing -> pure bu_env_f
                    Just (Coalesced knd mbd@(MemBlock _ _ _ ixfn) _) -> pure $
                      -- Promotion to success requires a substitution for the
                      -- index function's free variables.
                      case freeVarSubstitutions (scope td_env) (scalarTable td_env) ixfn of
                        Just fv_subst ->
                          let entry =
                                coal_entry
                                  { vartab =
                                      M.insert
                                        (patElemName p)
                                        (Coalesced knd mbd fv_subst)
                                        (vartab coal_entry)
                                  }
                              (ac, suc) =
                                markSuccessCoal (activeCoals bu_env_f, successCoals bu_env_f) m_b entry
                           in bu_env_f {activeCoals = ac, successCoals = suc}
                        Nothing ->
                          fail_res
                  else pure fail_res
              _ -> pure bu_env_f

  foldM checkMergeeOverlap bu_env''' mergee_writes

-- | Given a pattern element and the corresponding kernel result, try to put the
-- kernel result directly in the memory block of pattern element
--
-- The first clause applies when the result is a 'Returns' of a variable, the
-- pattern element is a memory array, the level is accepted by @lvlOK@, and both
-- memory blocks live in the same memory space. It creates an in-place
-- coalescing, or extends an existing (transitive) one for the pattern's memory.
-- If any of those guards fails, evaluation falls through to the later clauses.
-- A 'WriteReturns' fails the coalescing of the written array's memory, and any
-- other result fails the coalescing of every memory block of the variables free
-- in it.
makeSegMapCoals ::
  (Coalesceable rep inner) =>
  (lvl -> Bool) ->
  lvl ->
  TopdownEnv rep ->
  KernelBody (Aliases rep) ->
  Certs ->
  (CoalsTab, InhibitTab) ->
  (PatElem (VarAliases, LetDecMem), SegSpace, KernelResult) ->
  (CoalsTab, InhibitTab)
makeSegMapCoals lvlOK lvl td_env kernel_body pat_certs (active, inhb) (PatElem pat_name (_, MemArray _ _ _ (ArrayIn pat_mem pat_ixf)), space, Returns _ _ (Var return_name))
  | Just (MemBlock tp return_shp return_mem _) <-
      getScopeMemInfo return_name $
        scope td_env <> scopeOf (kernelBodyStms kernel_body),
    lvlOK lvl,
    MemMem pat_space <- runReader (lookupMemInfo pat_mem) $ removeScopeAliases $ scope td_env,
    MemMem return_space <-
      scope td_env <> scopeOf (kernelBodyStms kernel_body) <> scopeOfSegSpace space
        & removeScopeAliases
        & runReader (lookupMemInfo return_mem),
    pat_space == return_space =
      case M.lookup pat_mem active of
        Nothing ->
          -- We are not in a transitive case
          case ( maybe False (pat_mem `nameIn`) (M.lookup return_mem inhb),
                 Coalesced InPlaceCoal (MemBlock tp return_shp pat_mem $ resultSlice pat_ixf) mempty
                   & M.singleton return_name
                   & flip (addInvAliasesVarTab td_env) return_name
               ) of
            -- Not inhibited, and the variable table extends to the aliases.
            (False, Just vtab) ->
              ( active
                  <> M.singleton
                    return_mem
                    (CoalsEntry pat_mem pat_ixf (oneName pat_mem) vtab mempty mempty pat_certs),
                inhb
              )
            _ -> (active, inhb)
        Just trans ->
          -- The pattern's memory is itself being coalesced: coalesce the
          -- result directly into that entry's destination memory.
          case ( maybe False (dstmem trans `nameIn`) $ M.lookup return_mem inhb,
                 let Coalesced _ (MemBlock _ _ trans_mem trans_ixf) _ =
                       fromMaybe (error "Impossible") $ M.lookup pat_name $ vartab trans
                  in Coalesced TransitiveCoal (MemBlock tp return_shp trans_mem $ resultSlice trans_ixf) mempty
                       & M.singleton return_name
                       & flip (addInvAliasesVarTab td_env) return_name
               ) of
            (False, Just vtab) ->
              let opts =
                    if dstmem trans == pat_mem
                      then mempty
                      else M.insert pat_name pat_mem $ optdeps trans
               in ( M.insert
                      return_mem
                      ( CoalsEntry
                          (dstmem trans)
                          (dstind trans)
                          (oneName pat_mem <> alsmem trans)
                          vtab
                          opts
                          mempty
                          (certs trans <> pat_certs)
                      )
                      active,
                    inhb
                  )
            _ -> (active, inhb)
  where
    -- Each thread writes the element indexed by all segment-space indices.
    thread_slice =
      unSegSpace space
        & map (DimFix . TPrimExp . flip LeafExp (IntType Int64) . fst)
        & Slice
    resultSlice ixf = LMAD.slice ixf $ fullSlice (LMAD.shape ixf) thread_slice
makeSegMapCoals _ _ td_env _ _ x (_, _, WriteReturns _ return_name _) =
  case getScopeMemInfo return_name $ scope td_env of
    Just (MemBlock _ _ return_mem _) -> markFailedCoal x return_mem
    Nothing -> error "Should not happen?"
makeSegMapCoals _ _ td_env _ _ x (_, _, result) =
  freeIn result
    & namesToList
    & mapMaybe (flip getScopeMemInfo $ scope td_env)
    & foldr (flip markFailedCoal . memName) x

-- | Extend a slice with a full 'DimSlice' (offset 0, stride 1) for each
-- trailing dimension of the shape that the slice does not cover.
fullSlice :: [TPrimExp Int64 VName] -> Slice (TPrimExp Int64 VName) -> Slice (TPrimExp Int64 VName)
fullSlice shp (Slice slc) =
  Slice $ slc ++ map (\d -> DimSlice 0 d 1) (drop (length slc) shp)

-- | Run the bottom-up analysis over a function body and return its successful
-- coalescing table.
--
-- After the body is processed, the memory blocks of the function parameters
-- are resolved. A block may only succeed for a 'Unique' parameter whose index
-- function matches the destination's and which has no recorded destination
-- uses; otherwise it fails. Successful entries whose optimistic dependencies
-- failed are then removed. If this produced inhibitions not already in
-- @topenv@, the whole analysis is rerun with the enlarged inhibition table, so
-- the result is a fixed point. The active table must be empty at the end;
-- otherwise this is an internal error.
fixPointCoalesce ::
  (Coalesceable rep inner) =>
  LUTabFun ->
  [Param FParamMem] ->
  Body (Aliases rep) ->
  TopdownEnv rep ->
  ShortCircuitM rep CoalsTab
fixPointCoalesce lutab fpar bdy topenv = do
  buenv <- mkCoalsTabStms lutab (bodyStms bdy) topenv (emptyBotUpEnv {inhibit = inhibited topenv})
  let succ_tab = successCoals buenv
      actv_tab = activeCoals buenv
      inhb_tab = inhibit buenv
      -- Allow short-circuiting function parameters that are unique and have
      -- matching index functions, otherwise mark as failed
      handleFunctionParams (a, i, s) (_, u, MemBlock _ _ m ixf) =
        case (u, M.lookup m a) of
          (Unique, Just entry)
            | dstind entry == ixf,
              Set dst_uses <- dstrefs (memrefs entry),
              dst_uses == mempty ->
                let (a', s') = markSuccessCoal (a, s) m entry
                 in (a', i, s')
          _ ->
            let (a', i') = markFailedCoal (a, i) m
             in (a', i', s)
      (actv_tab', inhb_tab', succ_tab') =
        foldl handleFunctionParams (actv_tab, inhb_tab, succ_tab) $
          getArrMemAssocFParam fpar

      (succ_tab'', failed_optdeps) = fixPointFilterDeps succ_tab' M.empty
      inhb_tab'' = M.unionWith (<>) failed_optdeps inhb_tab'
  if not $ M.null actv_tab'
    then error ("COALESCING ROOT: BROKEN INV, active not empty: " ++ show (M.keys actv_tab'))
    else
      if M.null $ inhb_tab'' `M.difference` inhibited topenv
        then pure succ_tab''
        else fixPointCoalesce lutab fpar bdy (topenv {inhibited = inhb_tab''})
  where
    -- Repeatedly drop successful coalescings whose optimistic dependencies
    -- have failed, until the number of entries stops changing.
    fixPointFilterDeps :: CoalsTab -> InhibitTab -> (CoalsTab, InhibitTab)
    fixPointFilterDeps coaltab inhbtab =
      let (coaltab', inhbtab') = foldl filterDeps (coaltab, inhbtab) (M.keys coaltab)
       in if length (M.keys coaltab) == length (M.keys coaltab')
            then (coaltab', inhbtab')
            else fixPointFilterDeps coaltab' inhbtab'

    filterDeps (coal, inhb) mb
      | not (M.member mb coal) = (coal, inhb)
    filterDeps (coal, inhb) mb
      | Just coal_etry <- M.lookup mb coal =
          let failed = M.filterWithKey (failedOptDep coal) (optdeps coal_etry)
           in if M.null failed
                then (coal, inhb) -- all ok
                else -- optimistic dependencies failed for the current
                -- memblock; extend inhibited mem-block mergings.
                  markFailedCoal (coal, inhb) mb
    filterDeps _ _ = error "In ArrayCoalescing.hs, fun filterDeps, impossible case reached!"
    -- An optimistic dependency @(r, mr)@ has failed if @mr@ is no longer a
    -- successful coalescing, or if @r@ is not in that entry's variable table.
    failedOptDep coal _ mr
      | not (mr `M.member` coal) = True
    failedOptDep coal r mr
      | Just coal_etry <- M.lookup mr coal = not $ r `M.member` vartab coal_etry
    failedOptDep _ _ _ = error "In ArrayCoalescing.hs, fun failedOptDep, impossible case reached!"

-- | Perform short-circuiting on 'Stms'.
--
-- The top-down environment is threaded forward through the statements, while
-- the bottom-up environment is computed backwards: the remaining statements
-- are processed before 'mkCoalsTabStm' is applied to the current one. The
-- non-negative names bound by all the patterns of @stms0@ are added to the
-- top-down environment of every statement.
mkCoalsTabStms ::
  (Coalesceable rep inner) =>
  LUTabFun ->
  Stms (Aliases rep) ->
  TopdownEnv rep ->
  BotUpEnv ->
  ShortCircuitM rep BotUpEnv
mkCoalsTabStms lutab stms0 = traverseStms stms0
  where
    non_negs_in_pats = foldMap (nonNegativesInPat . stmPat) stms0
    traverseStms Empty _ bu_env = pure bu_env
    traverseStms (stm :<| stms) td_env bu_env = do
      -- Compute @td_env@ top down
      let td_env' = updateTopdownEnv td_env stm
      -- Compute @bu_env@ bottom up
      bu_env' <- traverseStms stms td_env' bu_env
      mkCoalsTabStm lutab stm (td_env' {nonNegatives = nonNegatives td_env' <> non_negs_in_pats}) bu_env'

-- | Array (register) coalescing can have one of three shapes:
--      a) @let y = copy(b^{lu})@
--      b) @let y = concat(a, b^{lu})@
--      c) @let y[i] = b^{lu}@
-- The intent is to use the memory block of the left-hand side
--   for the right-hand side variable, meaning to store @b@ in
--   @m_y@ (rather than @m_b@).
-- The following five safety conditions are necessary:
--      1. the right-hand side is lastly-used in the current statement
--      2. the allocation of @m_y@ dominates the creation of @b@
--         ^ relax it by hoisting the allocation of @m_y@
--      3. there is no use of the left-hand side memory block @m_y@
--           during the liveness of @b@, i.e., in between its last use
--           and its creation.
--         ^ relax it by pointwise/interval-based checking
--      4.
@b@ is a newly created array, i.e., does not aliases anything -- ^ relax it to support exitential memory blocks for if-then-else -- 5. the new index function of @b@ corresponding to memory block @m_y@ -- can be translated at the definition of @b@, and the -- same for all variables aliasing @b@. -- Observation: during the live range of @b@, @m_b@ can only be used by -- variables aliased with @b@, because @b@ is newly created. -- relax it: in case @m_b@ is existential due to an if-then-else -- then the checks should be extended to the actual -- array-creation points. mkCoalsTabStm :: (Coalesceable rep inner) => LUTabFun -> Stm (Aliases rep) -> TopdownEnv rep -> BotUpEnv -> ShortCircuitM rep BotUpEnv mkCoalsTabStm _ (Let (Pat [pe]) _ e) td_env bu_env | Just primexp <- primExpFromExp (vnameToPrimExp (scope td_env) (scals bu_env)) e = pure $ bu_env {scals = M.insert (patElemName pe) primexp (scals bu_env)} mkCoalsTabStm lutab (Let patt _ (Match _ cases defbody _)) td_env bu_env = do let pat_val_elms = patElems patt -- ToDo: 1. we need to record existential memory blocks in alias table on the top-down pass. -- 2. need to extend the scope table -- i) Filter @activeCoals@ by the 2ND AND 5th safety conditions: (activeCoals0, inhibit0) = filterSafetyCond2and5 (activeCoals bu_env) (inhibit bu_env) (scals bu_env) td_env pat_val_elms -- ii) extend @activeCoals@ by transfering the pattern-elements bindings existent -- in @activeCoals@ to the body results of the then and else branches, but only -- if the current pattern element can be potentially coalesced and also -- if the current pattern element satisfies safety conditions 2 & 5. res_mem_def = findMemBodyResult activeCoals0 (scope td_env) pat_val_elms defbody res_mem_cases = map (findMemBodyResult activeCoals0 (scope td_env) pat_val_elms . caseBody) cases subs_def = mkSubsTab patt $ map resSubExp $ bodyResult defbody subs_cases = map (mkSubsTab patt . map resSubExp . bodyResult . 
caseBody) cases actv_def_i = foldl (transferCoalsToBody subs_def) activeCoals0 res_mem_def actv_cases_i = zipWith (\subs res -> foldl (transferCoalsToBody subs) activeCoals0 res) subs_cases res_mem_cases -- eliminate the original pattern binding of the if statement, -- @let x = if y[0,0] > 0 then map (+y[0,0]) a else map (+1) b@ -- @let y[0] = x@ -- should succeed because @m_y@ is used before @x@ is created. aux ac (MemBodyResult m_b _ _ m_r) = if m_b == m_r then ac else M.delete m_b ac actv_def = foldl aux actv_def_i res_mem_def actv_cases = zipWith (foldl aux) actv_cases_i res_mem_cases -- iii) process the then and else bodies res_def <- mkCoalsTabStms lutab (bodyStms defbody) td_env (bu_env {activeCoals = actv_def}) res_cases <- zipWithM (\c a -> mkCoalsTabStms lutab (bodyStms $ caseBody c) td_env (bu_env {activeCoals = a})) cases actv_cases let (actv_def0, succ_def0, inhb_def0) = (activeCoals res_def, successCoals res_def, inhibit res_def) -- iv) optimistically mark the pattern succesful: ((activeCoals1, inhibit1), successCoals1) = foldl ( foldfun ( (actv_def0, succ_def0) : zip (map activeCoals res_cases) (map successCoals res_cases) ) ) ((activeCoals0, inhibit0), successCoals bu_env) (L.transpose $ res_mem_def : res_mem_cases) -- v) unify coalescing results of all branches by taking the union -- of all entries in the current/then/else success tables. actv_res = foldr (M.intersectionWith unionCoalsEntry) activeCoals1 $ actv_def0 : map activeCoals res_cases succ_res = foldr (M.unionWith unionCoalsEntry) successCoals1 $ succ_def0 : map successCoals res_cases -- vi) The step of filtering by 3rd safety condition is not -- necessary, because we perform index analysis of the -- source/destination uses, and they should have been -- filtered during the analysis of the then/else bodies. 
inhibit_res = M.unionsWith (<>) ( inhibit1 : zipWith ( \actv inhb -> let failed = M.difference actv $ M.intersectionWith unionCoalsEntry actv activeCoals0 in snd $ foldl markFailedCoal (failed, inhb) (M.keys failed) ) (actv_def0 : map activeCoals res_cases) (inhb_def0 : map inhibit res_cases) ) pure bu_env { activeCoals = actv_res, successCoals = succ_res, inhibit = inhibit_res } where foldfun _ _ [] = error "Imposible Case 1!!!" foldfun _ ((act, _), _) mem_body_results | Nothing <- M.lookup (patMem $ head mem_body_results) act = error "Imposible Case 2!!!" foldfun acc ((act, inhb), succc) mem_body_results@(MemBodyResult m_b _ _ _ : _) | Just info <- M.lookup m_b act, Just _ <- zipWithM (M.lookup . bodyMem) mem_body_results $ map snd acc = -- Optimistically promote to successful coalescing and append! let info' = info { optdeps = foldr (\mbr -> M.insert (bodyName mbr) (bodyMem mbr)) (optdeps info) mem_body_results } (act', succc') = markSuccessCoal (act, succc) m_b info' in ((act', inhb), succc') foldfun acc ((act, inhb), succc) mem_body_results@(MemBodyResult m_b _ _ _ : _) | Just info <- M.lookup m_b act, all ((==) m_b . bodyMem) mem_body_results, Just info' <- zipWithM (M.lookup . bodyMem) mem_body_results $ map fst acc = -- Treating special case resembling: -- @let x0 = map (+1) a @ -- @let x3 = if cond then let x1 = x0 with [0] <- 2 in x1@ -- @ else let x2 = x0 with [1] <- 3 in x2@ -- @let z[1] = x3 @ -- In this case the result active table should be the union -- of the @m_x@ entries of the then and else active tables. let info'' = foldl unionCoalsEntry info info' act' = M.insert m_b info'' act in ((act', inhb), succc) foldfun _ ((act, inhb), succc) (mbr : _) = -- one of the branches has failed coalescing, -- hence remove the coalescing of the result. 
(markFailedCoal (act, inhb) (patMem mbr), succc) mkCoalsTabStm lutab (Let pat _ (Loop arginis lform body)) td_env bu_env = do let pat_val_elms = patElems pat -- i) Filter @activeCoals@ by the 2nd, 3rd AND 5th safety conditions. In -- other words, for each active coalescing target, the creation of the -- array we're trying to merge should happen before the allocation of the -- merge target and the index function should be translateable. (actv0, inhibit0) = filterSafetyCond2and5 (activeCoals bu_env) (inhibit bu_env) (scals bu_env) td_env pat_val_elms -- ii) Extend @activeCoals@ by transfering the pattern-elements bindings -- existent in @activeCoals@ to the loop-body results, but only if: -- (a) the pattern element is a candidate for coalescing, && -- (b) the pattern element satisfies safety conditions 2 & 5, -- (conditions (a) and (b) have already been checked above), && -- (c) the memory block of the corresponding body result is -- allocated outside the loop, i.e., non-existential, && -- (d) the init name is lastly-used in the initialization -- of the loop variant. -- Otherwise fail and remove from active-coalescing table! bdy_ress = bodyResult body (patmems, argmems, inimems, resmems) = L.unzip4 $ mapMaybe (mapmbFun actv0) (zip3 pat_val_elms arginis $ map resSubExp bdy_ress) -- td_env' -- remove the other pattern elements from the active coalescing table: coal_pat_names = namesFromList $ map fst patmems (actv1, inhibit1) = foldl ( \(act, inhb) (b, MemBlock _ _ m_b _) -> if b `nameIn` coal_pat_names then (act, inhb) -- ok else markFailedCoal (act, inhb) m_b -- remove from active ) (actv0, inhibit0) (getArrMemAssoc pat) -- iii) Process the loop's body. -- If the memory blocks of the loop result and loop variant param differ -- then make the original memory block of the loop result conflict with -- the original memory block of the loop parameter. 
This is done in -- order to prevent the coalescing of @a1@, @a0@, @x@ and @db@ in the -- same memory block of @y@ in the example below: -- @loop(a1 = a0) = for i < n do @ -- @ let x = map (stencil a1) (iota n)@ -- @ let db = copy x @ -- @ in db @ -- @let y[0] = a1 @ -- Meaning the coalescing of @x@ in @let db = copy x@ should fail because -- @a1@ appears in the definition of @let x = map (stencil a1) (iota n)@. res_mem_bdy = zipWith (\(b, m_b) (r, m_r) -> MemBodyResult m_b b r m_r) patmems resmems res_mem_arg = zipWith (\(b, m_b) (r, m_r) -> MemBodyResult m_b b r m_r) patmems argmems res_mem_ini = zipWith (\(b, m_b) (r, m_r) -> MemBodyResult m_b b r m_r) patmems inimems actv2 = let subs_res = mkSubsTab pat $ map resSubExp $ bodyResult body actv11 = foldl (transferCoalsToBody subs_res) actv1 res_mem_bdy subs_arg = mkSubsTab pat $ map (Var . paramName . fst) arginis actv12 = foldl (transferCoalsToBody subs_arg) actv11 res_mem_arg subs_ini = mkSubsTab pat $ map snd arginis in foldl (transferCoalsToBody subs_ini) actv12 res_mem_ini -- The code below adds an aliasing relation to the loop-arg memory -- so that to prevent, e.g., the coalescing of an iterative stencil -- (you need a buffer for the result and a separate one for the stencil). 
-- @ let b = @ -- @ loop (a) for i tab Just etry -> M.insert m_r (etry {alsmem = alsmem etry <> oneName m_a}) tab actv3 = foldl insertMemAliases actv2 (zip res_mem_bdy res_mem_arg) -- analysing the loop body starts from a null memory-reference set; -- the results of the loop body iteration are aggregated later actv4 = M.map (\etry -> etry {memrefs = mempty}) actv3 res_env_body <- mkCoalsTabStms lutab (bodyStms body) td_env' ( bu_env { activeCoals = actv4, inhibit = inhibit1 } ) let scals_loop = scals res_env_body (res_actv0, res_succ0, res_inhb0) = (activeCoals res_env_body, successCoals res_env_body, inhibit res_env_body) -- iv) Aggregate memory references across loop and filter unsound coalescing -- a) Filter the active-table by the FIRST SOUNDNESS condition, namely: -- W_i does not overlap with Union_{j=i+1..n} U_j, -- where W_i corresponds to the Write set of src mem-block m_b, -- and U_j correspond to the uses of the destination -- mem-block m_y, in which m_b is coalesced into. -- W_i and U_j correspond to the accesses within the loop body. mb_loop_idx = mbLoopIndexRange lform res_actv1 <- filterMapM1 (loopSoundness1Entry scals_loop mb_loop_idx) res_actv0 -- b) Update the memory-reference summaries across loop: -- W = Union_{i=0..n-1} W_i Union W_{before-loop} -- U = Union_{i=0..n-1} U_i Union U_{before-loop} res_actv2 <- mapM (aggAcrossLoopEntry (scope td_env' <> scopeOf (bodyStms body)) scals_loop mb_loop_idx) res_actv1 -- c) check soundness of the successful promotions for: -- - the entries that have been promoted to success during the loop-body pass -- - for all the entries of active table -- Filter the entries by the SECOND SOUNDNESS CONDITION, namely: -- Union_{i=1..n-1} W_i does not overlap the before-the-loop uses -- of the destination memory block. 
let res_actv3 = M.filterWithKey (loopSoundness2Entry actv3) res_actv2 let tmp_succ = M.filterWithKey (okLookup actv3) $ M.difference res_succ0 (successCoals bu_env) ver_succ = M.filterWithKey (loopSoundness2Entry actv3) tmp_succ let suc_fail = M.difference tmp_succ ver_succ (res_succ, res_inhb1) = foldl markFailedCoal (res_succ0, res_inhb0) $ M.keys suc_fail -- act_fail = M.difference res_actv0 res_actv3 (_, res_inhb) = foldl markFailedCoal (res_actv0, res_inhb1) $ M.keys act_fail res_actv = M.mapWithKey (addBeforeLoop actv3) res_actv3 -- v) optimistically mark the pattern succesful if there is any chance to succeed ((fin_actv1, fin_inhb1), fin_succ1) = foldl foldFunOptimPromotion ((res_actv, res_inhb), res_succ) $ L.zip4 patmems argmems resmems inimems (fin_actv2, fin_inhb2) = M.foldlWithKey ( \acc k _ -> if k `nameIn` namesFromList (map (paramName . fst) arginis) then markFailedCoal acc k else acc ) (fin_actv1, fin_inhb1) fin_actv1 pure bu_env {activeCoals = fin_actv2, successCoals = fin_succ1, inhibit = fin_inhb2} where allocs_bdy = foldl getAllocs (alloc td_env') $ bodyStms body td_env_allocs = td_env' {alloc = allocs_bdy, scope = scope td_env' <> scopeOf (bodyStms body)} td_env' = updateTopdownEnvLoop td_env arginis lform getAllocs tab (Let (Pat [pe]) _ (Op (Alloc _ sp))) = M.insert (patElemName pe) sp tab getAllocs tab _ = tab okLookup tab m _ | Just _ <- M.lookup m tab = True okLookup _ _ _ = False -- mapmbFun actv0 (patel, (arg, ini), bdyres) | b <- patElemName patel, (_, MemArray _ _ _ (ArrayIn m_b _)) <- patElemDec patel, a <- paramName arg, -- Not safe to short-circuit if the index function of this -- parameter is variant to the loop. not $ any ((`nameIn` freeIn (paramDec arg)) . paramName . 
fst) arginis, Var a0 <- ini, Var r <- bdyres, Just coal_etry <- M.lookup m_b actv0, Just _ <- M.lookup b (vartab coal_etry), Just (MemBlock _ _ m_a _) <- getScopeMemInfo a (scope td_env_allocs), Just (MemBlock _ _ m_a0 _) <- getScopeMemInfo a0 (scope td_env_allocs), Just (MemBlock _ _ m_r _) <- getScopeMemInfo r (scope td_env_allocs), Just nms <- M.lookup a lutab, a0 `nameIn` nms, m_r `elem` M.keys (alloc td_env_allocs) = Just ((b, m_b), (a, m_a), (a0, m_a0), (r, m_r)) mapmbFun _ (_patel, (_arg, _ini), _bdyres) = Nothing foldFunOptimPromotion :: ((CoalsTab, InhibitTab), CoalsTab) -> ((VName, VName), (VName, VName), (VName, VName), (VName, VName)) -> ((CoalsTab, InhibitTab), CoalsTab) foldFunOptimPromotion ((act, inhb), succc) ((b, m_b), (a, m_a), (_r, m_r), (b_i, m_i)) | m_r == m_i, Just info <- M.lookup m_i act, Just vtab_i <- addInvAliasesVarTab td_env (vartab info) b_i = Exc.assert (m_r == m_b && m_a == m_b) ((M.insert m_b (info {vartab = vtab_i}) act, inhb), succc) | m_r == m_i = Exc.assert (m_r == m_b && m_a == m_b) (markFailedCoal (act, inhb) m_b, succc) | Just info_b0 <- M.lookup m_b act, Just info_a0 <- M.lookup m_a act, Just info_i <- M.lookup m_i act, M.member m_r succc, Just vtab_i <- addInvAliasesVarTab td_env (vartab info_i) b_i, [Just info_b, Just info_a] <- map translateIxFnInScope [(b, info_b0), (a, info_a0)] = let info_b' = info_b {optdeps = M.insert b_i m_i $ optdeps info_b} info_a' = info_a {optdeps = M.insert b_i m_i $ optdeps info_a} info_i' = info_i { optdeps = M.insert b m_b $ optdeps info_i, memrefs = mempty, vartab = vtab_i } act' = M.insert m_i info_i' act (act1, succc1) = foldl (\acc (m, info) -> markSuccessCoal acc m info) (act', succc) [(m_b, info_b'), (m_a, info_a')] in -- ToDo: make sure that ixfun translates and update substitutions (?) 
((act1, inhb), succc1) foldFunOptimPromotion ((act, inhb), succc) ((_, m_b), (_a, m_a), (_r, m_r), (_b_i, m_i)) = Exc.assert (m_r /= m_i) (foldl markFailedCoal (act, inhb) [m_b, m_a, m_r, m_i], succc) translateIxFnInScope (x, info) | Just (Coalesced knd mbd@(MemBlock _ _ _ ixfn) _subs0) <- M.lookup x (vartab info), isInScope td_env (dstmem info) = let scope_tab = scope td_env <> scopeOfFParams (map fst arginis) in case freeVarSubstitutions scope_tab (scals bu_env) ixfn of Just fv_subst -> Just $ info {vartab = M.insert x (Coalesced knd mbd fv_subst) (vartab info)} Nothing -> Nothing translateIxFnInScope _ = Nothing se0 = intConst Int64 0 mbLoopIndexRange :: LoopForm -> Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName)) mbLoopIndexRange (WhileLoop _) = Nothing mbLoopIndexRange (ForLoop inm _inttp seN) = Just (inm, (pe64 se0, pe64 seN)) addBeforeLoop actv_bef m_b etry = case M.lookup m_b actv_bef of Nothing -> etry Just etry0 -> etry {memrefs = memrefs etry0 <> memrefs etry} aggAcrossLoopEntry scope_loop scal_tab idx etry = do wrts <- aggSummaryLoopTotal (scope td_env) scope_loop scal_tab idx $ (srcwrts . memrefs) etry uses <- aggSummaryLoopTotal (scope td_env) scope_loop scal_tab idx $ (dstrefs . memrefs) etry pure $ etry {memrefs = MemRefs uses wrts} loopSoundness1Entry scal_tab idx etry = do let wrt_i = (srcwrts . memrefs) etry use_p <- aggSummaryLoopPartial (scal_tab <> scalarTable td_env) idx $ dstrefs $ memrefs etry pure $ noMemOverlap td_env' wrt_i use_p loopSoundness2Entry :: CoalsTab -> VName -> CoalsEntry -> Bool loopSoundness2Entry old_actv m_b etry = case M.lookup m_b old_actv of Nothing -> True Just etry0 -> let uses_before = (dstrefs . memrefs) etry0 write_loop = (srcwrts . 
memrefs) etry in noMemOverlap td_env write_loop uses_before

-- | The case of an in-place update:
--
-- @let x' = x with slice <- elm@
--
-- The result @x'@ and the updated array @x@ are both bound to the memory
-- block @m_x@.  If @m_x@ has an active coalescing entry, it is extended to
-- also cover @x@ (or failed, if that is not possible).  Only an 'Unsafe'
-- update is additionally treated as a new coalescing source in step (c).
-- If the pattern is not a single array in memory, this equation does not
-- match and the catch-all 'Update' equation further below raises an error.
mkCoalsTabStm lutab stm@(Let pat@(Pat [x']) _ (BasicOp (Update safety x _ _elm))) td_env bu_env
  | [(_, MemBlock _ _ m_x _)] <- getArrMemAssoc pat = do
      -- (a) filter by the 3rd safety condition for @elm@ and @x'@
      let (actv, inhbt) = recordMemRefUses td_env bu_env stm
          -- (b) if @x'@ is in the active coalesced table, then add an entry for
          -- @x@ as well.  Coalescing of @m_x@ fails if @x'@ has no entry, if
          -- its index function cannot be translated into the current scope,
          -- or if the destination memory block is not in scope.
          (actv', inhbt') =
            case M.lookup m_x actv of
              Nothing -> (actv, inhbt)
              Just info ->
                case M.lookup (patElemName x') (vartab info) of
                  Nothing -> markFailedCoal (actv, inhbt) m_x
                  Just (Coalesced k mblk@(MemBlock _ _ _ x_indfun) _) ->
                    case freeVarSubstitutions (scope td_env) (scals bu_env) x_indfun of
                      Just fv_subs
                        | isInScope td_env (dstmem info) ->
                            let coal_etry_x = Coalesced k mblk fv_subs
                                info' =
                                  info
                                    { vartab =
                                        M.insert x coal_etry_x $
                                          M.insert (patElemName x') coal_etry_x (vartab info)
                                    }
                             in (M.insert m_x info' actv, inhbt)
                      _ -> markFailedCoal (actv, inhbt) m_x

      -- (c) this stm is also a potential source for coalescing, so process
      -- it -- but only when the update is 'Unsafe'; otherwise the table
      -- computed in (b) is kept as is.
      actv'' <-
        if safety == Unsafe
          then mkCoalsHelper3PatternMatch stm lutab td_env {inhibited = inhbt'} bu_env {activeCoals = actv'}
          else pure actv'
      pure $ bu_env {activeCoals = actv'', inhibit = inhbt'}
-- | The case of a flat in-place update:
--
-- @let x' = x with flat-slice <- elm@
--
-- Handled like the 'Update' equation above, except that step (c) is
-- performed unconditionally.
mkCoalsTabStm lutab stm@(Let pat@(Pat [x']) _ (BasicOp (FlatUpdate x _ _elm))) td_env bu_env
  | [(_, MemBlock _ _ m_x _)] <- getArrMemAssoc pat = do
      -- (a) filter by the 3rd safety condition for @elm@ and @x'@
      let (actv, inhbt) = recordMemRefUses td_env bu_env stm
          -- (b) if @x'@ is in the active coalesced table, then add an entry for
          -- @x@ as well
          (actv', inhbt') =
            case M.lookup m_x actv of
              Nothing -> (actv, inhbt)
              Just info ->
                case M.lookup (patElemName x') (vartab info) of
                  -- this case should not happen, but if it does,
                  -- just fail conservatively
                  Nothing -> markFailedCoal (actv, inhbt) m_x
                  Just (Coalesced k mblk@(MemBlock _ _ _ x_indfun) _) ->
                    case freeVarSubstitutions (scope td_env) (scals bu_env) x_indfun of
                      Just fv_subs
                        | isInScope td_env (dstmem info) ->
                            let coal_etry_x = Coalesced k mblk fv_subs
                                info' =
                                  info
                                    { vartab =
                                        M.insert x coal_etry_x $
                                          M.insert (patElemName x') coal_etry_x (vartab info)
                                    }
                             in (M.insert m_x info' actv, inhbt)
                      _ -> markFailedCoal (actv, inhbt) m_x

      -- (c) this stm is also a potential source for coalescing, so process it
      actv'' <- mkCoalsHelper3PatternMatch stm lutab td_env {inhibited = inhbt'} bu_env {activeCoals = actv'}
      pure $ bu_env {activeCoals = actv'', inhibit = inhbt'}
--
-- An 'Update' whose pattern is not a single array in memory (so that the
-- 'Update' equation above did not match) is an internal error.
mkCoalsTabStm _ (Let pat _ (BasicOp Update {})) _ _ =
  error $ "In ArrayCoalescing.hs, fun mkCoalsTabStm, illegal pattern for in-place update: " ++ show pat
-- Ops: run the representation-specific 'onOp' handler taken from the
-- reader environment, then check whether the statement itself is a
-- potential short-circuit point.
mkCoalsTabStm lutab stm@(Let pat aux (Op op)) td_env bu_env = do
  -- Process body
  on_op <- asks onOp
  bu_env' <- on_op lutab pat (stmAuxCerts aux) op td_env bu_env
  activeCoals' <- mkCoalsHelper3PatternMatch stm lutab td_env bu_env'
  pure $ bu_env' {activeCoals = activeCoals'}
-- Default handling of all remaining statements.
mkCoalsTabStm lutab stm@(Let pat _ e) td_env bu_env = do
  -- i) Filter @activeCoals@ by the 3rd safety condition:
  --    this is now relaxed by use of LMAD eqs:
  --    the memory referenced in stm is added to memrefs::dstrefs
  --    in the corresponding coal-tab entries.
  let (activeCoals', inhibit') = recordMemRefUses td_env bu_env stm

      -- ii) promote any of the entries in @activeCoals@ to @successCoals@ as long as
      --     - this statement defined a variable consumed in a coalesced statement
      --     - and safety conditions 2, 4, and 5 are satisfied.
      --     AND extend @activeCoals@ table for any definition of a variable that
      --     aliases a coalesced variable.
      safe_4 = createsNewArrOK e
      ((activeCoals'', inhibit''), successCoals') =
        foldl (foldfun safe_4) ((activeCoals', inhibit'), successCoals bu_env) (getArrMemAssoc pat)

  -- iii) record a potentially coalesced statement in @activeCoals@
  activeCoals''' <- mkCoalsHelper3PatternMatch stm lutab td_env bu_env {successCoals = successCoals', activeCoals = activeCoals''}
  pure bu_env {activeCoals = activeCoals''', inhibit = inhibit'', successCoals = successCoals'}
  where
    -- Process one array pattern element @b@ residing in memory block @mb@.
    foldfun safe_4 ((a_acc, inhb), s_acc) (b, MemBlock tp shp mb _b_indfun) =
      case M.lookup mb a_acc of
        Nothing -> ((a_acc, inhb), s_acc)
        Just info@(CoalsEntry x_mem _ _ vtab _ _ certs) ->
          let failed = markFailedCoal (a_acc, inhb) mb
           in case M.lookup b vtab of
                Nothing ->
                  -- we hit the definition of some variable @b@ aliased with
                  -- the coalesced variable @x@, hence extend @activeCoals@, e.g.,
                  --   @let x = map f arr  @
                  --   @let b = alias x  @ <- current statement
                  --   @ ... use of b ...  @
                  --   @let c = alias b    @ <- currently fails
                  --   @let y[i] = x       @
                  -- where @alias@ can be @transpose@, @slice@, @reshape@.
                  -- We use the getDirAliasedIxfn helper function to track the
                  -- aliasing through the td_env, and to find the updated ixfun of @b@:
                  case getDirAliasedIxfn td_env a_acc b of
                    Nothing -> (failed, s_acc)
                    Just (_, _, b_indfun') ->
                      case ( freeVarSubstitutions (scope td_env) (scals bu_env) b_indfun',
                             freeVarSubstitutions (scope td_env) (scals bu_env) certs
                           ) of
                        (Just fv_subst, Just fv_subst') ->
                          let mem_info =
                                Coalesced TransitiveCoal (MemBlock tp shp x_mem b_indfun') (fv_subst <> fv_subst')
                              info' = info {vartab = M.insert b mem_info vtab}
                           in ((M.insert mb info' a_acc, inhb), s_acc)
                        _ -> (failed, s_acc)
                Just (Coalesced k mblk@(MemBlock _ _ _ new_indfun) _) ->
                  -- we are at the definition of the coalesced variable @b@
                  -- if 2,4,5 hold promote it to successful coalesced table,
                  -- or if e = transpose, etc. then postpone decision for later on
                  let safe_2 = isInScope td_env x_mem
                   in case ( freeVarSubstitutions (scope td_env) (scals bu_env) new_indfun,
                             freeVarSubstitutions (scope td_env) (scals bu_env) certs
                           ) of
                        (Just fv_subst, Just fv_subst')
                          | safe_2 ->
                              let mem_info = Coalesced k mblk (fv_subst <> fv_subst')
                                  info' = info {vartab = M.insert b mem_info vtab}
                               in if safe_4
                                    then -- array creation point, successful coalescing verified!
                                      let (a_acc', s_acc') = markSuccessCoal (a_acc, s_acc) mb info'
                                       in ((a_acc', inhb), s_acc')
                                    else -- this is an invertible alias case of the kind
                                      --   @ let b = alias a @
                                      --   @ let x[i] = b @
                                      -- do not promote, but update the index function
                                      ((M.insert mb info' a_acc, inhb), s_acc)
                        _ -> (failed, s_acc) -- fail!

-- | Summarise a single LMAD access as a singleton 'Set' access summary.
ixfunToAccessSummary :: LMAD.LMAD (TPrimExp Int64 VName) -> AccessSummary
ixfunToAccessSummary = Set . S.singleton

-- | Check safety conditions 2 and 5 and update new substitutions:
-- called on the pat-elements of loop and if-then-else expressions.
--
-- The safety conditions are: The allocation of merge target should dominate the
-- creation of the array we're trying to merge and the new index function of the
-- array can be translated at the definition site of b. The latter requires that
-- any variables used in the index function of the target array are available at
-- the definition site of b.
filterSafetyCond2and5 ::
  (HasMemBlock (Aliases rep)) =>
  CoalsTab ->
  InhibitTab ->
  ScalarTab ->
  TopdownEnv rep ->
  [PatElem (VarAliases, LetDecMem)] ->
  (CoalsTab, InhibitTab)
filterSafetyCond2and5 act_coal inhb_coal scals_env td_env pes =
  foldl helper (act_coal, inhb_coal) pes
  where
    -- Process one pattern element.  Only array elements whose memory block
    -- @m_b@ has an active coalescing entry are affected; any other element
    -- leaves both tables untouched.
    helper (acc, inhb) patel =
      case (patElemName patel, patElemDec patel) of
        (b, (_, MemArray tp0 shp0 _ (ArrayIn m_b _idxfn_b)))
          | Just info@(CoalsEntry x_mem _ _ vtab _ _ certs) <- M.lookup m_b acc ->
              let failed = markFailedCoal (acc, inhb) m_b
                  -- Translate an index function, together with the entry's
                  -- certificates, into the current scope.
                  translate ixf =
                    ( freeVarSubstitutions (scope td_env) scals_env ixf,
                      freeVarSubstitutions (scope td_env) scals_env certs
                    )
                  -- Record new coalescing info for @b@ in the entry of @m_b@.
                  record mem_info =
                    (M.insert m_b (info {vartab = M.insert b mem_info vtab}) acc, inhb)
               in -- It is not safe to short circuit if some other pattern
                  -- element is aliased to this one, as this indicates the
                  -- two pattern elements reference the same physical
                  -- value somehow.
                  if any ((`nameIn` aliasesOf patel) . patElemName) pes
                    then failed
                    else case M.lookup b vtab of
                      -- @b@ is not yet in the entry: recover its index function
                      -- through direct aliasing and record it transitively.
                      Nothing
                        | Just (_, _, b_indfun') <- getDirAliasedIxfn td_env acc b,
                          (Just fv_subst, Just fv_subst') <- translate b_indfun' ->
                            record $
                              Coalesced TransitiveCoal (MemBlock tp0 shp0 x_mem b_indfun') (fv_subst <> fv_subst')
                      -- @b@ is already coalesced: the destination memory must be
                      -- in scope (condition 2) and the index function must be
                      -- translatable here (condition 5).
                      Just (Coalesced k (MemBlock pt shp _ new_indfun) _)
                        | isInScope td_env x_mem,
                          (Just fv_subst, Just fv_subst') <- translate new_indfun ->
                            record $
                              Coalesced k (MemBlock pt shp x_mem new_indfun) (fv_subst <> fv_subst')
                      _ -> failed
        _ -> (acc, inhb)

-- | Pattern matches a potentially coalesced statement and
-- records a new association in
-- @activeCoals@
mkCoalsHelper3PatternMatch ::
  (Coalesceable rep inner) =>
  Stm (Aliases rep) ->
  LUTabFun ->
  TopdownEnv rep ->
  BotUpEnv ->
  ShortCircuitM rep CoalsTab
mkCoalsHelper3PatternMatch stm lutab td_env bu_env = do
  -- Candidate short-circuit points of this statement, if any.
  clst <- genCoalStmtInfo lutab td_env (scope td_env) stm
  case clst of
    Nothing -> pure activeCoals_tab
    Just clst' -> pure $ foldl processNewCoalesce activeCoals_tab clst'
  where
    successCoals_tab = successCoals bu_env
    activeCoals_tab = activeCoals bu_env
    -- Try to record one candidate: @b@ (in @m_b@) coalesced into @x@ (in @m_x@).
    -- NOTE(review): the transitive lookup below consults the original
    -- @activeCoals_tab@ / @successCoals_tab@, not the fold accumulator @acc@,
    -- so entries added for earlier candidates of the same statement are not
    -- seen here -- presumably intended; confirm.
    processNewCoalesce acc (knd, alias_fn, x, m_x, ind_x, b, m_b, _, tp_b, shp_b, certs) =
      -- test whether we are in a transitive coalesced case, i.e.,
      --   @let b = scratch ...@
      --   @.....@
      --   @let x[j] = b@
      --   @let y[i] = x@
      -- and compose the index function of @x@ with that of @y@,
      -- and update aliasing of the @m_b@ entry to also contain @m_y@
      -- on top of @m_x@, i.e., transitively, any use of @m_y@ should
      -- be checked for the lifetime of @b@.
      let proper_coals_tab = case knd of
            InPlaceCoal -> activeCoals_tab
            _ -> successCoals_tab
          (m_yx, ind_yx, mem_yx_al, x_deps, certs') =
            case M.lookup m_x proper_coals_tab of
              Nothing ->
                (m_x, alias_fn ind_x, oneName m_x, M.empty, mempty)
              Just (CoalsEntry m_y ind_y y_al vtab x_deps0 _ certs'') ->
                let ind = case M.lookup x vtab of
                      Just (Coalesced _ (MemBlock _ _ _ ixf) _) -> ixf
                      Nothing -> ind_y
                 in (m_y, alias_fn ind, oneName m_x <> y_al, x_deps0, certs <> certs'')
          m_b_aliased_m_yx = areAnyAliased td_env m_b [m_yx] -- m_b \= m_yx
       in if not m_b_aliased_m_yx && isInScope td_env m_yx -- nameIn m_yx (alloc td_env)
            -- Finally update the @activeCoals@ table with a fresh
            -- binding for @m_b@; if such one exists then overwrite.
            -- Also, add all variables from the alias chain of @b@ to
            -- @vartab@, for example, in the case of a sequence:
            --   @ b0 = if cond then ... else ... @
            --   @ b1 = alias0 b0 @
            --   @ b  = alias1 b1 @
            --   @ x[j] = b @
            -- Then @b1@ and @b0@ should also be added to @vartab@ if
            -- @alias1@ and @alias0@ are invertible, otherwise fail early!
            then
              let mem_info = Coalesced knd (MemBlock tp_b shp_b m_yx ind_yx) M.empty
                  opts' = if m_yx == m_x then M.empty else M.insert x m_x x_deps
                  vtab = M.singleton b mem_info
                  mvtab = addInvAliasesVarTab td_env vtab b

                  is_inhibited = case M.lookup m_b $ inhibited td_env of
                    Just nms -> m_yx `nameIn` nms
                    Nothing -> False
               in case (is_inhibited, mvtab) of
                    (True, _) -> acc -- fail due to inhibited
                    (_, Nothing) -> acc -- fail early due to non-invertible aliasing
                    (_, Just vtab') ->
                      -- successfully adding a new coalesced entry
                      -- NOTE(review): in the transitive case @certs'@ already
                      -- contains @certs@, so @certs@ is appended twice here.
                      let coal_etry = CoalsEntry m_yx ind_yx mem_yx_al vtab' opts' mempty (certs <> certs')
                       in M.insert m_b coal_etry acc
            else acc

-- | Information about a particular short-circuit point
type SSPointInfo =
  ( CoalescedKind, -- kind of coalescing
    LMAD -> LMAD, -- transformation applied to the destination index function
    VName, -- destination variable (@x@)
    VName, -- memory block of the destination (@m_x@)
    LMAD, -- index function of the destination
    VName, -- source variable (@b@)
    VName, -- memory block of the source (@m_b@)
    LMAD, -- index function of the source
    PrimType, -- element type of the source
    Shape, -- shape of the source
    Certs -- certificates to propagate
  )

-- | Given an op, return a list of potential short-circuit points
type GenSSPoint rep op =
  LUTabFun ->
  TopdownEnv rep ->
  ScopeTab rep ->
  Pat (VarAliases, LetDecMem) ->
  Certs ->
  op ->
  Maybe [SSPointInfo]

-- | Ops in the sequential representation never yield short-circuit points.
genSSPointInfoSeqMem ::
  GenSSPoint SeqMem (Op (Aliases SeqMem))
genSSPointInfoSeqMem _ _ _ _ _ _ = Nothing

-- | For 'SegOp', we currently only handle 'SegMap', and only under the following
-- circumstances:
--
--  1. The 'SegMap' has only one return/pattern value, which is a 'Returns'.
--
--  2. The 'KernelBody' contains an 'Index' statement that is indexing an array using
--  only the values from the 'SegSpace'.
--
--  3. The array being indexed is last-used in that statement, is free in the
--  'SegMap', is unique or has been recently allocated (specifically, it should
--  not be a non-unique argument to the enclosing function), has elements with
--  the same bit-size as the pattern elements, and has the exact same 'LMAD' as
--  the pattern of the 'SegMap' statement.
--
-- There can be multiple candidate arrays, but the current implementation will
-- always just try the first one.
--
-- The first restriction could be relaxed by trying to match up arrays in the
-- 'KernelBody' with patterns of the 'SegMap', but the current implementation
-- should be enough to handle many common cases.
--
-- The result of the 'SegMap' is treated as the destination, while the candidate
-- array from inside the body is treated as the source.
genSSPointInfoSegOp ::
  (Coalesceable rep inner) => GenSSPoint rep (SegOp lvl (Aliases rep))
genSSPointInfoSegOp
  lutab
  td_env
  scopetab
  (Pat [PatElem dst (_, MemArray dst_pt _ _ (ArrayIn dst_mem dst_ixf))])
  certs
  (SegMap _ space _ kernel_body@KernelBody {kernelBodyResult = [Returns {}]})
    | (src, MemBlock src_pt shp src_mem src_ixf) : _ <-
        mapMaybe getPotentialMapShortCircuit $ stmsToList $ kernelBodyStms kernel_body =
        Just [(MapCoal, id, dst, dst_mem, dst_ixf, src, src_mem, src_ixf, src_pt, shp, certs)]
    where
      -- The index variables of the segment space.
      iterators = map fst $ unSegSpace space
      -- Variables free in the kernel body.
      frees = freeIn kernel_body

      -- Recognise @let x = src[i_1, ..., i_n]@ where the indices are exactly
      -- the segment iterators (compared as sorted lists, so in any order) and
      -- @src@ satisfies the conditions listed in the comment above.
      getPotentialMapShortCircuit (Let (Pat [PatElem x _]) _ (BasicOp (Index src slc)))
        | Just inds <- sliceIndices slc,
          L.sort inds == L.sort (map Var iterators),
          Just last_uses <- M.lookup x lutab,
          src `nameIn` last_uses,
          Just memblock@(MemBlock src_pt _ src_mem src_ixf) <- getScopeMemInfo src scopetab,
          src_mem `nameIn` last_uses,
          -- The 'alloc' table contains allocated memory blocks, including
          -- unique memory blocks from the enclosing function. It does _not_
          -- include non-unique memory blocks from the enclosing function.
          src_mem `M.member` alloc td_env,
          src `nameIn` frees,
          src_ixf == dst_ixf,
          primBitSize src_pt == primBitSize dst_pt =
            Just (src, memblock)
      getPotentialMapShortCircuit _ = Nothing
genSSPointInfoSegOp _ _ _ _ _ _ = Nothing

-- | Lift a short-circuit point generator for inner ops to 'MemOp': only
-- 'Inner' ops are inspected; any other 'MemOp' yields 'Nothing'.
genSSPointInfoMemOp ::
  GenSSPoint rep (inner (Aliases rep)) ->
  GenSSPoint rep (MemOp inner (Aliases rep))
genSSPointInfoMemOp onOp lutab td_end scopetab pat certs (Inner op) =
  onOp lutab td_end scopetab pat certs op
genSSPointInfoMemOp _ _ _ _ _ _ _ = Nothing

-- | Short-circuit points for GPU ops: only 'GPU.SegOp's are considered.
genSSPointInfoGPUMem ::
  GenSSPoint GPUMem (Op (Aliases GPUMem))
genSSPointInfoGPUMem = genSSPointInfoMemOp f
  where
    f lutab td_env scopetab pat certs (GPU.SegOp op) =
      genSSPointInfoSegOp lutab td_env scopetab pat certs op
    f _ _ _ _ _ _ = Nothing

-- | Short-circuit points for multicore ops: only a 'MC.ParOp' whose first
-- component is 'Nothing' is considered.
genSSPointInfoMCMem ::
  GenSSPoint MCMem (Op (Aliases MCMem))
genSSPointInfoMCMem = genSSPointInfoMemOp f
  where
    f lutab td_env scopetab pat certs (MC.ParOp Nothing op) =
      genSSPointInfoSegOp lutab td_env scopetab pat certs op
    f _ _ _ _ _ _ = Nothing

-- | Recognise the statement forms that are potential short-circuit points and
-- describe each candidate as an 'SSPointInfo'.  In each 'BasicOp' case below,
-- the source array must be last-used in the statement and reside in the same
-- memory space as the destination.
genCoalStmtInfo ::
  (Coalesceable rep inner) =>
  LUTabFun ->
  TopdownEnv rep ->
  ScopeTab rep ->
  Stm (Aliases rep) ->
  ShortCircuitM rep (Maybe [SSPointInfo])
-- CASE a) @let x <- copy(b^{lu})@ (matched as a 'Replicate' with an empty shape)
genCoalStmtInfo lutab td_env scopetab (Let pat aux (BasicOp (Replicate (Shape []) (Var b))))
  | Pat [PatElem x (_, MemArray _ _ _ (ArrayIn m_x ind_x))] <- pat,
    Just last_uses <- M.lookup x lutab,
    Just (MemBlock tpb shpb m_b ind_b) <- getScopeMemInfo b scopetab,
    sameSpace td_env m_x m_b,
    b `nameIn` last_uses =
      pure $ Just [(CopyCoal, id, x, m_x, ind_x, b, m_b, ind_b, tpb, shpb, stmAuxCerts aux)]
-- CASE c) @let x[i] = b^{lu}@
genCoalStmtInfo lutab td_env scopetab (Let pat aux (BasicOp (Update _ x slice_x (Var b))))
  | Pat [PatElem x' (_, MemArray _ _ _ (ArrayIn m_x ind_x))] <- pat,
    Just last_uses <- M.lookup x' lutab,
    Just (MemBlock tpb shpb m_b ind_b) <- getScopeMemInfo b scopetab,
    sameSpace td_env m_x m_b,
    b `nameIn` last_uses =
      pure $ Just [(InPlaceCoal, (`updateIndFunSlice` slice_x), x, m_x, ind_x, b, m_b, ind_b, tpb, shpb, stmAuxCerts aux)]
  where
    -- Restrict the destination index function to the updated slice.
    updateIndFunSlice :: LMAD -> Slice SubExp -> LMAD
    updateIndFunSlice ind_fun slc_x =
      let slc_x' = map (fmap pe64) $ unSlice slc_x
       in LMAD.slice ind_fun $ Slice slc_x'
-- CASE c') flat in-place update: @let x' = x with flat-slice <- b^{lu}@
genCoalStmtInfo lutab td_env scopetab (Let pat aux (BasicOp (FlatUpdate x slice_x b)))
  | Pat [PatElem x' (_, MemArray _ _ _ (ArrayIn m_x ind_x))] <- pat,
    Just last_uses <- M.lookup x' lutab,
    Just (MemBlock tpb shpb m_b ind_b) <- getScopeMemInfo b scopetab,
    sameSpace td_env m_x m_b,
    b `nameIn` last_uses =
      pure $ Just [(InPlaceCoal, (`updateIndFunSlice` slice_x), x, m_x, ind_x, b, m_b, ind_b, tpb, shpb, stmAuxCerts aux)]
  where
    -- Restrict the destination index function to the updated flat slice.
    updateIndFunSlice :: LMAD -> FlatSlice SubExp -> LMAD
    updateIndFunSlice ind_fun (FlatSlice offset dims) =
      LMAD.flatSlice ind_fun $ FlatSlice (pe64 offset) $ map (fmap pe64) dims

-- CASE b) @let x = concat(a, b^{lu})@
genCoalStmtInfo lutab td_env scopetab (Let pat aux (BasicOp (Concat concat_dim (b0 :| bs) _)))
  | Pat [PatElem x (_, MemArray _ _ _ (ArrayIn m_x ind_x))] <- pat,
    Just last_uses <- M.lookup x lutab =
      pure $
        let (res, _, _) =
              foldl (markConcatParts last_uses x m_x ind_x) ([], zero, True) (b0 : bs)
         in if null res then Nothing else Just res
  where
    zero = pe64 $ intConst Int64 0
    -- Walk the concatenated arrays left to right, tracking the offset of each
    -- part along @concat_dim@.  The flag becomes 'False' (and then stays so)
    -- as soon as a part's memory info or size along @concat_dim@ is not
    -- available; all remaining parts are then skipped.  Parts that are not
    -- last-used, or live in another memory space, only advance the offset.
    markConcatParts _ _ _ _ acc@(_, _, False) _ = acc
    markConcatParts last_uses x m_x ind_x (acc, offs, True) b
      | Just (MemBlock tpb shpb@(Shape dims@(_ : _)) m_b ind_b) <- getScopeMemInfo b scopetab,
        Just d <- maybeNth concat_dim dims,
        offs' <- offs + pe64 d =
          if b `nameIn` last_uses && sameSpace td_env m_x m_b
            then
              let slc =
                    Slice $
                      map (unitSlice zero . pe64) (take concat_dim dims)
                        <> [unitSlice offs (pe64 d)]
                        <> map (unitSlice zero . pe64) (drop (concat_dim + 1) dims)
               in ( acc ++ [(ConcatCoal, (`LMAD.slice` slc), x, m_x, ind_x, b, m_b, ind_b, tpb, shpb, stmAuxCerts aux)],
                    offs',
                    True
                  )
            else (acc, offs', True)
      | otherwise = (acc, offs, False)

-- CASE d) short-circuit points from ops. For instance, the result of a segmap
-- can be considered a short-circuit point.
genCoalStmtInfo lutab td_env scopetab (Let pat aux (Op op)) = do
  ss_op <- asks ssPointFromOp
  pure $ ss_op lutab td_env scopetab pat (stmAuxCerts aux) op
-- CASE other than a), b), c), or d) not supported
genCoalStmtInfo _ _ _ _ = pure Nothing

-- | Whether both names are bound as memory ('MemMem') in the top-down scope
-- and live in the same memory space; 'False' otherwise.
sameSpace ::
  (Coalesceable rep inner) =>
  TopdownEnv rep ->
  VName ->
  VName ->
  Bool
sameSpace td_env m_x m_b
  | Just (MemMem pat_space) <- nameInfoToMemInfo <$> M.lookup m_x scope',
    Just (MemMem return_space) <- nameInfoToMemInfo <$> M.lookup m_b scope' =
      pat_space == return_space
  | otherwise = False
  where
    scope' = removeScopeAliases $ scope td_env

-- | A pattern element (@_patName@, stored in @patMem@) paired with the
-- corresponding body result (@bodyName@, stored in @bodyMem@).
data MemBodyResult = MemBodyResult
  { patMem :: VName,
    _patName :: VName,
    bodyName :: VName,
    bodyMem :: VName
  }

-- | Results in pairs of pattern-blockresult pairs of (var name, mem block)
-- for those if-patterns that are candidates for coalescing, i.e. whose
-- memory block has an active coalescing entry that already records the
-- pattern variable, and whose body result is a variable in memory.
findMemBodyResult ::
  (HasMemBlock (Aliases rep)) =>
  CoalsTab ->
  ScopeTab rep ->
  [PatElem (VarAliases, LetDecMem)] ->
  Body (Aliases rep) ->
  [MemBodyResult]
findMemBodyResult activeCoals_tab scope_env patelms bdy =
  mapMaybe
    findMemBodyResult'
    (zip patelms $ map resSubExp $ bodyResult bdy)
  where
    scope_env' = scope_env <> scopeOf (bodyStms bdy)
    findMemBodyResult' (patel, se_r) =
      case (patElemName patel, patElemDec patel, se_r) of
        (b, (_, MemArray _ _ _ (ArrayIn m_b _)), Var r) ->
          case getScopeMemInfo r scope_env' of
            Nothing -> Nothing
            Just (MemBlock _ _ m_r _) ->
              case M.lookup m_b activeCoals_tab of
                Nothing -> Nothing
                Just coal_etry ->
                  case M.lookup b (vartab coal_etry) of
                    Nothing -> Nothing
                    Just _ -> Just $ MemBodyResult m_b b r m_r
        _ -> Nothing

-- | transfers coalescing from if-pattern to then|else body result
-- in the active coalesced table. The transfer involves, among
-- others, inserting @(r,m_r)@ in the optimistic-dependency
-- set of @m_b@'s entry and inserting @(b,m_b)@ in the opt-deps
-- set of @m_r@'s entry. Meaning, ultimately, @m_b@ can be merged
-- if @m_r@ can be merged (and vice-versa).
-- This is checked by a fix point iteration at the function-definition level.
transferCoalsToBody ::
  M.Map VName (TPrimExp Int64 VName) -> -- (PrimExp VName)
  CoalsTab ->
  MemBodyResult ->
  CoalsTab
transferCoalsToBody exist_subs activeCoals_tab (MemBodyResult m_b b r m_r)
  | -- the @Nothing@ pattern for the two lookups cannot happen
    -- because they were already checked in @findMemBodyResult@
    Just etry <- M.lookup m_b activeCoals_tab,
    Just (Coalesced knd (MemBlock btp shp _ ind_b) subst_b) <- M.lookup b $ vartab etry =
      -- by definition of if-stmt, r and b have the same basic type, shape and
      -- index function, hence, for example, do not need to rebase
      -- We will check whether it is translatable at the definition point of r.
      let ind_r = LMAD.substitute exist_subs ind_b
          subst_r = M.union exist_subs subst_b
          mem_info = Coalesced knd (MemBlock btp shp (dstmem etry) ind_r) subst_r
       in if m_r == m_b -- already unified, just add binding for @r@
            then
              let etry' =
                    etry
                      { optdeps = M.insert b m_b (optdeps etry),
                        vartab = M.insert r mem_info (vartab etry)
                      }
               in M.insert m_r etry' activeCoals_tab
            else -- make them both optimistically depend on each other
              let opts_x_new = M.insert r m_r (optdeps etry)
                  -- Here we should translate the @ind_b@ field of @mem_info@
                  -- across the existential introduced by the if-then-else
                  coal_etry =
                    etry
                      { vartab = M.singleton r mem_info,
                        optdeps = M.insert b m_b (optdeps etry)
                      }
               in M.insert m_b (etry {optdeps = opts_x_new}) $
                    M.insert m_r coal_etry activeCoals_tab
  | otherwise = error "Impossible"

-- | Map each pattern element to its result when the pair denotes a 64-bit
-- integer: a variable result is kept only if the pattern element is declared
-- as an 'Int64' scalar, while an 'Int64' constant result is always kept.
mkSubsTab ::
  Pat (aliases, LetDecMem) ->
  [SubExp] ->
  M.Map VName (TPrimExp Int64 VName)
mkSubsTab pat res =
  let pat_elms = patElems pat
   in M.fromList $ mapMaybe mki64subst $ zip pat_elms res
  where
    mki64subst (a, Var v)
      | (_, MemPrim (IntType Int64)) <- patElemDec a = Just (patElemName a, le64 v)
    mki64subst (a, se@(Constant (IntValue (Int64Value _)))) = Just (patElemName a, pe64 se)
    mki64subst _ = Nothing

-- | Collect the scalar expressions bound by a statement, recursing into the
-- bodies of loops, matches and ops.  A single-element pattern whose expression
-- can be converted to a 'PrimExp' yields a singleton table.
computeScalarTable ::
  (Coalesceable rep inner) =>
  ScopeTab rep ->
  Stm (Aliases rep) ->
  ScalarTableM rep (M.Map VName (PrimExp VName))
computeScalarTable scope_table (Let (Pat [pe]) _ e)
  | Just primexp <- primExpFromExp (vnameToPrimExp scope_table mempty) e =
      pure $ M.singleton (patElemName pe) primexp
computeScalarTable scope_table (Let _ _ (Loop loop_inits loop_form body)) =
  concatMapM
    ( computeScalarTable $
        scope_table
          <> scopeOfFParams (map fst loop_inits)
          <> scopeOfLoopForm loop_form
          <> scopeOf (bodyStms body)
    )
    (stmsToList $ bodyStms body)
computeScalarTable scope_table (Let _ _ (Match _ cases body _)) = do
  -- Scalars of one branch, traversing that branch's OWN statements in a
  -- scope extended with that branch's bindings.  (Previously every case
  -- re-traversed the default body's statements, so scalars bound inside
  -- case branches were never collected.)
  let branchScalars b =
        concatMapM
          (computeScalarTable $ scope_table <> scopeOf (bodyStms b))
          (stmsToList $ bodyStms b)
  body_tab <- branchScalars body
  cases_tab <- concatMapM (\(Case _ b) -> branchScalars b) cases
  pure $ body_tab <> cases_tab
computeScalarTable scope_table (Let _ _ (Op op)) = do
  on_op <- asks scalarTableOnOp
  on_op scope_table op
computeScalarTable _ _ = pure mempty

-- | Scalar table of a 'MemOp': allocations bind no scalars; inner ops are
-- delegated to the given handler.
computeScalarTableMemOp ::
  ComputeScalarTable rep (inner (Aliases rep)) -> ComputeScalarTable rep (MemOp inner (Aliases rep))
computeScalarTableMemOp _ _ (Alloc _ _) = pure mempty
computeScalarTableMemOp onInner scope_table (Inner op) = onInner scope_table op

-- | Scalar table of a 'SegOp': the statements of its kernel body, in a scope
-- extended with the body's bindings and the segment space.
computeScalarTableSegOp ::
  (Coalesceable rep inner) =>
  ComputeScalarTable rep (GPU.SegOp lvl (Aliases rep))
computeScalarTableSegOp scope_table segop =
  concatMapM
    ( computeScalarTable $
        scope_table
          <> scopeOf (kernelBodyStms $ segBody segop)
          <> scopeOfSegSpace (segSpace segop)
    )
    (stmsToList $ kernelBodyStms $ segBody segop)

-- | Scalar table of a GPU host op.
computeScalarTableGPUMem ::
  ComputeScalarTable GPUMem (GPU.HostOp NoOp (Aliases GPUMem))
computeScalarTableGPUMem scope_table (GPU.SegOp segop) =
  computeScalarTableSegOp scope_table segop
computeScalarTableGPUMem _ (GPU.SizeOp _) = pure mempty
computeScalarTableGPUMem _ (GPU.OtherOp NoOp) = pure mempty
computeScalarTableGPUMem scope_table (GPU.GPUBody _ body) =
  concatMapM
    (computeScalarTable $ scope_table <> scopeOf (bodyStms body))
    (stmsToList $ bodyStms body)

-- | Scalar table of a multicore op: both the optional and the main 'SegOp'
-- of a 'MC.ParOp' contribute.
computeScalarTableMCMem ::
  ComputeScalarTable MCMem (MC.MCOp NoOp (Aliases MCMem))
computeScalarTableMCMem _ (MC.OtherOp NoOp) = pure mempty
computeScalarTableMCMem scope_table (MC.ParOp par_op segop) =
  (<>)
    <$> maybe (pure mempty) (computeScalarTableSegOp scope_table) par_op
    <*> computeScalarTableSegOp scope_table segop

-- | Monadic filter over the values of a map.  The predicate is run on the
-- values in ascending key order.  Filtering an ascending association list
-- keeps it strictly ascending, so the map can be rebuilt in linear time
-- without key comparisons.  (The 'Eq' constraint is kept for compatibility.)
filterMapM1 :: (Eq k, Monad m) => (v -> m Bool) -> M.Map k v -> m (M.Map k v)
filterMapM1 f m = M.fromDistinctAscList <$> filterM (f . snd) (M.toAscList m)
futhark-0.25.27/src/Futhark/Optimise/ArrayShortCircuiting/DataStructs.hs000066400000000000000000000332071475065116200262470ustar00rootroot00000000000000{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.Optimise.ArrayShortCircuiting.DataStructs
  ( Coalesced (..),
    CoalescedKind (..),
    ArrayMemBound (..),
    AllocTab,
    HasMemBlock,
    ScalarTab,
    CoalsTab,
    ScopeTab,
    CoalsEntry (..),
    FreeVarSubsts,
    LmadRef,
    MemRefs (..),
    AccessSummary (..),
    BotUpEnv (..),
    InhibitTab,
    unionCoalsEntry,
    vnameToPrimExp,
    getArrMemAssocFParam,
    getScopeMemInfo,
    createsNewArrOK,
    getArrMemAssoc,
    getUniqueMemFParam,
    markFailedCoal,
    accessSubtract,
    markSuccessCoal,
  )
where

import Control.Applicative
import Data.Functor ((<&>))
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.IR.Aliases
import Futhark.IR.GPUMem as GPU
import Futhark.IR.MCMem as MC
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.IR.SeqMem
import Futhark.Util.Pretty hiding (line, sep, (</>))
import Prelude

-- | maps array-variable names to various info, including
-- types, memory block and index function, etc.
type ScopeTab rep = Scope (Aliases rep) -- | An LMAD specialized to TPrimExps (a typed primexp) type LmadRef = LMAD.LMAD (TPrimExp Int64 VName) -- | Summary of all memory accesses at a given point in the code data AccessSummary = -- | The access summary was statically undeterminable, for instance by -- having multiple lmads. In this case, we should conservatively avoid all -- coalescing. Undeterminable | -- | A conservative estimate of the set of accesses up until this point. Set (S.Set LmadRef) instance Semigroup AccessSummary where Undeterminable <> _ = Undeterminable _ <> Undeterminable = Undeterminable (Set a) <> (Set b) = Set $ S.union a b instance Monoid AccessSummary where mempty = Set mempty instance FreeIn AccessSummary where freeIn' Undeterminable = mempty freeIn' (Set s) = freeIn' s accessSubtract :: AccessSummary -> AccessSummary -> AccessSummary accessSubtract Undeterminable _ = Undeterminable accessSubtract _ Undeterminable = Undeterminable accessSubtract (Set s1) (Set s2) = Set $ s1 S.\\ s2 data MemRefs = MemRefs { -- | The access summary of all references (reads -- and writes) to the destination of a coalescing entry dstrefs :: AccessSummary, -- | The access summary of all writes to the source of a coalescing entry srcwrts :: AccessSummary } instance Semigroup MemRefs where m1 <> m2 = MemRefs (dstrefs m1 <> dstrefs m2) (srcwrts m1 <> srcwrts m2) instance Monoid MemRefs where mempty = MemRefs mempty mempty data CoalescedKind = -- | let x = copy b^{lu} CopyCoal | -- | let x[i] = b^{lu} InPlaceCoal | -- | let x = concat(a, b^{lu}) ConcatCoal | -- | transitive, i.e., other variables aliased with b. TransitiveCoal | MapCoal -- | Information about a memory block: type, shape, name and ixfun. 
data ArrayMemBound = MemBlock { primType :: PrimType, shape :: Shape, memName :: VName, ixfun :: LMAD } -- | Free variable substitutions type FreeVarSubsts = M.Map VName (TPrimExp Int64 VName) -- | Coalesced Access Entry data Coalesced = Coalesced -- | the kind of coalescing CoalescedKind -- | destination mem_block info @f_m_x[i]@ (must be ArrayMem) -- (Maybe IxFun) -- the inverse ixfun of a coalesced array, such that -- -- ixfuns can be correctly constructed for aliases; ArrayMemBound -- | substitutions for free vars in index function FreeVarSubsts data CoalsEntry = CoalsEntry { -- | destination memory block dstmem :: VName, -- | index function of the destination (used for rebasing) dstind :: LMAD, -- | aliased destination memory blocks can appear -- due to repeated (optimistic) coalescing. alsmem :: Names, -- | per variable-name coalesced entries vartab :: M.Map VName Coalesced, -- | keys are variable names, values are memblock names; -- it records optimistically added coalesced nodes, e.g., -- in the case of if-then-else expressions. For example: -- @x = map f a@ -- @.. use of y ..@ -- @b = map g a@ -- @x[i] = b @ -- @y[k] = x @ -- the coalescing of @b@ in @x[i]@ succeeds, but -- is dependent of the success of the coalescing -- of @x@ in @y[k]@, which fails in this case -- because @y@ is used before the new array creation -- of @x = map f@. Hence @optdeps@ of the @m_b@ CoalsEntry -- records @x -> m_x@ and at the end of analysis it is removed -- from the successfully coalesced table if @m_x@ is -- unsuccessful. -- Storing @m_x@ would probably be sufficient if memory would -- not be reused--e.g., by register allocation on arrays--the -- @x@ discriminates between memory being reused across semantically -- different arrays (searched in @vartab@ field). optdeps :: M.Map VName VName, -- | Access summaries of uses and writes of destination and source -- respectively. memrefs :: MemRefs, -- | Certificates of the destination, which must be propagated to -- the source. 
When short-circuiting reaches the array creation -- point, we must check whether the certs are in scope for -- short-circuiting to succeed. certs :: Certs } -- | the allocatted memory blocks type AllocTab = M.Map VName Space -- | maps a variable name to its PrimExp scalar expression type ScalarTab = M.Map VName (PrimExp VName) -- | maps a memory-block name to a 'CoalsEntry'. Among other things, it contains -- @vartab@, a map in which each variable associated to that memory block is -- bound to its 'Coalesced' info. type CoalsTab = M.Map VName CoalsEntry -- | inhibited memory-block mergings from the key (memory block) -- to the value (set of memory blocks). type InhibitTab = M.Map VName Names data BotUpEnv = BotUpEnv { -- | maps scalar variables to theirs PrimExp expansion scals :: ScalarTab, -- | Optimistic coalescing info. We are currently trying to coalesce these -- memory blocks. activeCoals :: CoalsTab, -- | Committed (successfull) coalescing info. These memory blocks have been -- successfully coalesced. successCoals :: CoalsTab, -- | The coalescing failures from this pass. We will no longer try to merge -- these memory blocks. inhibit :: InhibitTab } instance Pretty CoalsTab where pretty = pretty . 
M.toList instance Pretty AccessSummary where pretty Undeterminable = "Undeterminable" pretty (Set a) = "Access-Set:" <+> pretty (S.toList a) <+> " " instance Pretty MemRefs where pretty (MemRefs a b) = "( Use-Sum:" <+> pretty a <+> "Write-Sum:" <+> pretty b <> ")" instance Pretty CoalescedKind where pretty CopyCoal = "Copy" pretty InPlaceCoal = "InPlace" pretty ConcatCoal = "Concat" pretty TransitiveCoal = "Transitive" pretty MapCoal = "Map" instance Pretty ArrayMemBound where pretty (MemBlock ptp shp m_nm ixfn) = "{" <> pretty ptp <> "," <+> pretty shp <> "," <+> pretty m_nm <> "," <+> pretty ixfn <> "}" instance Pretty Coalesced where pretty (Coalesced knd mbd _) = "(Kind:" <+> pretty knd <> ", membds:" <+> pretty mbd -- <> ", subs:" <+> pretty subs <> ")" <+> "\n" instance Pretty CoalsEntry where pretty etry = "{" <+> "Dstmem:" <+> pretty (dstmem etry) <> ", AliasMems:" <+> pretty (alsmem etry) <+> ", optdeps:" <+> pretty (M.toList $ optdeps etry) <+> ", memrefs:" <+> pretty (memrefs etry) <+> ", vartab:" <+> pretty (M.toList $ vartab etry) <+> "}" <+> "\n" -- | Compute the union of two 'CoalsEntry'. If two 'CoalsEntry' do not refer to -- the same destination memory and use the same index function, the first -- 'CoalsEntry' is returned. unionCoalsEntry :: CoalsEntry -> CoalsEntry -> CoalsEntry unionCoalsEntry etry1 (CoalsEntry dstmem2 dstind2 alsmem2 vartab2 optdeps2 memrefs2 certs2) = if dstmem etry1 /= dstmem2 || dstind etry1 /= dstind2 then etry1 else etry1 { alsmem = alsmem etry1 <> alsmem2, optdeps = optdeps etry1 <> optdeps2, vartab = vartab etry1 <> vartab2, memrefs = memrefs etry1 <> memrefs2, certs = certs etry1 <> certs2 } -- | Get the names of array 'PatElem's in a 'Pat' and the corresponding -- 'ArrayMemBound' information for each array. 
getArrMemAssoc :: Pat (aliases, LetDecMem) -> [(VName, ArrayMemBound)] getArrMemAssoc pat = mapMaybe ( \patel -> case snd $ patElemDec patel of (MemArray tp shp _ (ArrayIn mem_nm indfun)) -> Just (patElemName patel, MemBlock tp shp mem_nm indfun) MemMem _ -> Nothing MemPrim _ -> Nothing MemAcc {} -> Nothing ) $ patElems pat -- | Get the names of arrays in a list of 'FParam' and the corresponding -- 'ArrayMemBound' information for each array. getArrMemAssocFParam :: [Param FParamMem] -> [(VName, Uniqueness, ArrayMemBound)] getArrMemAssocFParam = mapMaybe ( \param -> case paramDec param of (MemArray tp shp u (ArrayIn mem_nm indfun)) -> Just (paramName param, u, MemBlock tp shp mem_nm indfun) MemMem _ -> Nothing MemPrim _ -> Nothing MemAcc {} -> Nothing ) -- | Get memory blocks in a list of 'FParam' that are used for unique arrays in -- the same list of 'FParam'. getUniqueMemFParam :: [Param FParamMem] -> M.Map VName Space getUniqueMemFParam params = let mems = M.fromList $ mapMaybe justMem params arrayMems = S.fromList $ mapMaybe (justArrayMem . paramDec) params in mems `M.restrictKeys` arrayMems where justMem (Param _ nm (MemMem sp)) = Just (nm, sp) justMem _ = Nothing justArrayMem (MemArray _ _ Unique (ArrayIn mem_nm _)) = Just mem_nm justArrayMem _ = Nothing class HasMemBlock rep where -- | Looks up 'VName' in the given scope. If it is a 'MemArray', return the -- 'ArrayMemBound' information for the array. 
getScopeMemInfo :: VName -> Scope rep -> Maybe ArrayMemBound instance HasMemBlock (Aliases SeqMem) where getScopeMemInfo r scope_env0 = case M.lookup r scope_env0 of Just (LetName (_, MemArray tp shp _ (ArrayIn m idx))) -> Just (MemBlock tp shp m idx) Just (FParamName (MemArray tp shp _ (ArrayIn m idx))) -> Just (MemBlock tp shp m idx) Just (LParamName (MemArray tp shp _ (ArrayIn m idx))) -> Just (MemBlock tp shp m idx) _ -> Nothing instance HasMemBlock (Aliases GPUMem) where getScopeMemInfo r scope_env0 = case M.lookup r scope_env0 of Just (LetName (_, MemArray tp shp _ (ArrayIn m idx))) -> Just (MemBlock tp shp m idx) Just (FParamName (MemArray tp shp _ (ArrayIn m idx))) -> Just (MemBlock tp shp m idx) Just (LParamName (MemArray tp shp _ (ArrayIn m idx))) -> Just (MemBlock tp shp m idx) _ -> Nothing instance HasMemBlock (Aliases MCMem) where getScopeMemInfo r scope_env0 = case M.lookup r scope_env0 of Just (LetName (_, MemArray tp shp _ (ArrayIn m idx))) -> Just (MemBlock tp shp m idx) Just (FParamName (MemArray tp shp _ (ArrayIn m idx))) -> Just (MemBlock tp shp m idx) Just (LParamName (MemArray tp shp _ (ArrayIn m idx))) -> Just (MemBlock tp shp m idx) _ -> Nothing -- | @True@ if the expression returns a "fresh" array. createsNewArrOK :: Exp rep -> Bool createsNewArrOK (BasicOp Replicate {}) = True createsNewArrOK (BasicOp Iota {}) = True createsNewArrOK (BasicOp Manifest {}) = True createsNewArrOK (BasicOp Concat {}) = True createsNewArrOK (BasicOp ArrayLit {}) = True createsNewArrOK (BasicOp ArrayVal {}) = True createsNewArrOK (BasicOp Scratch {}) = True createsNewArrOK _ = False -- | Memory-block removal from active-coalescing table -- should only be handled via this function, it is easy -- to run into infinite execution problem; i.e., the -- fix-pointed iteration of coalescing transformation -- assumes that whenever a coalescing fails it is -- recorded in the @inhibit@ table. 
markFailedCoal :: (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab) markFailedCoal (coal_tab, inhb_tab) src_mem = case M.lookup src_mem coal_tab of Nothing -> (coal_tab, inhb_tab) Just coale -> let failed_set = oneName $ dstmem coale failed_set' = failed_set <> fromMaybe mempty (M.lookup src_mem inhb_tab) in ( M.delete src_mem coal_tab, M.insert src_mem failed_set' inhb_tab ) -- | promotion from active-to-successful coalescing tables -- should be handled with this function (for clarity). markSuccessCoal :: (CoalsTab, CoalsTab) -> VName -> CoalsEntry -> (CoalsTab, CoalsTab) markSuccessCoal (actv, succc) m_b info_b = ( M.delete m_b actv, appendCoalsInfo m_b info_b succc ) -- | merges entries in the coalesced table. appendCoalsInfo :: VName -> CoalsEntry -> CoalsTab -> CoalsTab appendCoalsInfo mb info_new coalstab = case M.lookup mb coalstab of Nothing -> M.insert mb info_new coalstab Just info_old -> M.insert mb (unionCoalsEntry info_old info_new) coalstab -- | Attempt to convert a 'VName' to a PrimExp. -- -- First look in 'ScalarTab' to see if we have recorded the scalar value of the -- argument. Otherwise look up the type of the argument and return a 'LeafExp' -- if it is a 'PrimType'. vnameToPrimExp :: (AliasableRep rep) => ScopeTab rep -> ScalarTab -> VName -> Maybe (PrimExp VName) vnameToPrimExp scopetab scaltab v = M.lookup v scaltab <|> ( M.lookup v scopetab >>= toPrimType . typeOf <&> LeafExp v ) -- | Attempt to extract the 'PrimType' from a 'TypeBase'. 
toPrimType :: TypeBase shp u -> Maybe PrimType toPrimType (Prim pt) = Just pt toPrimType _ = Nothing futhark-0.25.27/src/Futhark/Optimise/ArrayShortCircuiting/MemRefAggreg.hs000066400000000000000000000474661475065116200263120ustar00rootroot00000000000000{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE TypeFamilies #-} module Futhark.Optimise.ArrayShortCircuiting.MemRefAggreg ( recordMemRefUses, freeVarSubstitutions, translateAccessSummary, aggSummaryLoopTotal, aggSummaryLoopPartial, aggSummaryMapPartial, aggSummaryMapTotal, noMemOverlap, ) where import Control.Monad import Data.Function ((&)) import Data.List (intersect, partition, uncons) import Data.List.NonEmpty (NonEmpty (..)) import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S import Futhark.Analysis.AlgSimplify import Futhark.Analysis.PrimExp.Convert import Futhark.IR.Aliases import Futhark.IR.Mem import Futhark.IR.Mem.LMAD qualified as LMAD import Futhark.MonadFreshNames import Futhark.Optimise.ArrayShortCircuiting.DataStructs import Futhark.Optimise.ArrayShortCircuiting.TopdownAnalysis import Futhark.Util ----------------------------------------------------- -- Some translations of Accesses and Ixfuns -- ----------------------------------------------------- -- | Checks whether the index function can be translated at the current program -- point and also returns the substitutions. It comes down to answering the -- question: "can one perform enough substitutions (from the bottom-up scalar -- table) until all vars appearing in the index function are defined in the -- current scope?" 
freeVarSubstitutions :: (FreeIn a) => ScopeTab rep -> ScalarTab -> a -> Maybe FreeVarSubsts freeVarSubstitutions scope0 scals0 indfun = freeVarSubstitutions' mempty $ namesToList $ freeIn indfun where freeVarSubstitutions' :: FreeVarSubsts -> [VName] -> Maybe FreeVarSubsts freeVarSubstitutions' subs [] = Just subs freeVarSubstitutions' subs0 fvs = let fvs_not_in_scope = filter (`M.notMember` scope0) fvs in case mapAndUnzipM getSubstitution fvs_not_in_scope of -- We require that all free variables can be substituted Just (subs, new_fvs) -> freeVarSubstitutions' (subs0 <> mconcat subs) $ concat new_fvs Nothing -> Nothing getSubstitution v | Just pe <- M.lookup v scals0, IntType _ <- primExpType pe = Just (M.singleton v $ TPrimExp pe, namesToList $ freeIn pe) getSubstitution _v = Nothing -- | Translates free variables in an access summary translateAccessSummary :: ScopeTab rep -> ScalarTab -> AccessSummary -> AccessSummary translateAccessSummary _ _ Undeterminable = Undeterminable translateAccessSummary scope0 scals0 (Set slmads) | Just subs <- freeVarSubstitutions scope0 scals0 slmads = slmads & S.map (LMAD.substitute subs) & Set translateAccessSummary _ _ _ = Undeterminable -- | This function computes the written and read memory references for the current statement getUseSumFromStm :: (Op rep ~ MemOp inner rep, HasMemBlock (Aliases rep)) => TopdownEnv rep -> CoalsTab -> Stm (Aliases rep) -> -- | A pair of written and written+read memory locations, along with their -- associated array and the index function used Maybe ([(VName, VName, LMAD)], [(VName, VName, LMAD)]) getUseSumFromStm td_env coal_tab (Let _ _ (BasicOp (Index arr (Slice slc)))) | Just (MemBlock _ shp _ _) <- getScopeMemInfo arr (scope td_env), length slc == length (shapeDims shp) && all isFix slc = do (mem_b, mem_arr, ixfn_arr) <- getDirAliasedIxfn td_env coal_tab arr let new_ixfn = LMAD.slice ixfn_arr $ Slice $ map (fmap pe64) slc pure ([], [(mem_b, mem_arr, new_ixfn)]) where isFix DimFix {} = True 
isFix _ = False getUseSumFromStm _ _ (Let Pat {} _ (BasicOp Index {})) = Just ([], []) -- incomplete slices getUseSumFromStm _ _ (Let Pat {} _ (BasicOp FlatIndex {})) = Just ([], []) -- incomplete slices getUseSumFromStm td_env coal_tab (Let (Pat pes) _ (BasicOp ArrayVal {})) = let wrts = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) pes in Just (wrts, wrts) getUseSumFromStm td_env coal_tab (Let (Pat pes) _ (BasicOp (ArrayLit ses _))) = let rds = mapMaybe (getDirAliasedIxfn td_env coal_tab) $ mapMaybe seName ses wrts = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) pes in Just (wrts, wrts ++ rds) where seName (Var a) = Just a seName (Constant _) = Nothing -- In place update @x[slc] <- a@. In the "in-place update" case, -- summaries should be added after the old variable @x@ has -- been added in the active coalesced table. getUseSumFromStm td_env coal_tab (Let (Pat [x']) _ (BasicOp (Update _ _x (Slice slc) a_se))) = do (m_b, m_x, x_ixfn) <- getDirAliasedIxfn td_env coal_tab (patElemName x') let x_ixfn_slc = LMAD.slice x_ixfn $ Slice $ map (fmap pe64) slc r1 = (m_b, m_x, x_ixfn_slc) case a_se of Constant _ -> Just ([r1], [r1]) Var a -> case getDirAliasedIxfn td_env coal_tab a of Nothing -> Just ([r1], [r1]) Just r2 -> Just ([r1], [r1, r2]) getUseSumFromStm td_env coal_tab (Let (Pat [y]) _ (BasicOp (Replicate (Shape []) (Var x)))) = do -- y = copy x wrt <- getDirAliasedIxfn td_env coal_tab $ patElemName y rd <- getDirAliasedIxfn td_env coal_tab x pure ([wrt], [wrt, rd]) getUseSumFromStm _ _ (Let Pat {} _ (BasicOp (Replicate (Shape []) _))) = error "Impossible" getUseSumFromStm td_env coal_tab (Let (Pat ys) _ (BasicOp (Concat _i (a :| bs) _ses))) = -- concat let ws = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) ys rs = mapMaybe (getDirAliasedIxfn td_env coal_tab) (a : bs) in Just (ws, ws ++ rs) getUseSumFromStm td_env coal_tab (Let (Pat ys) _ (BasicOp (Manifest _perm x))) = let ws = mapMaybe (getDirAliasedIxfn td_env coal_tab . 
patElemName) ys rs = mapMaybe (getDirAliasedIxfn td_env coal_tab) [x] in Just (ws, ws ++ rs) getUseSumFromStm td_env coal_tab (Let (Pat ys) _ (BasicOp (Replicate _shp se))) = let ws = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) ys in case se of Constant _ -> Just (ws, ws) Var x -> Just (ws, ws ++ mapMaybe (getDirAliasedIxfn td_env coal_tab) [x]) getUseSumFromStm td_env coal_tab (Let (Pat [x]) _ (BasicOp (FlatUpdate _ (FlatSlice offset slc) v))) | Just (m_b, m_x, x_ixfn) <- getDirAliasedIxfn td_env coal_tab (patElemName x) = do let x_ixfn_slc = LMAD.flatSlice x_ixfn $ FlatSlice (pe64 offset) $ map (fmap pe64) slc let r1 = (m_b, m_x, x_ixfn_slc) case getDirAliasedIxfn td_env coal_tab v of Nothing -> Just ([r1], [r1]) Just r2 -> Just ([r1], [r1, r2]) -- getUseSumFromStm td_env coal_tab (Let (Pat ys) _ (BasicOp bop)) = -- let wrt = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) ys -- in trace ("getUseBla: " <> show bop) $ pure (wrt, wrt) getUseSumFromStm td_env coal_tab (Let (Pat ys) _ (BasicOp Iota {})) = let wrt = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) ys in pure (wrt, wrt) getUseSumFromStm _ _ (Let Pat {} _ BasicOp {}) = Just ([], []) getUseSumFromStm _ _ (Let Pat {} _ (Op (Alloc _ _))) = Just ([], []) getUseSumFromStm _ _ _ = -- if-then-else, loops are supposed to be treated separately, -- calls are not supported, and Ops are not yet supported Nothing -- | This function: -- 1. computes the written and read memory references for the current statement -- (by calling @getUseSumFromStm@) -- 2. 
fails the entries in active coalesced table for which the write set -- overlaps the uses of the destination (to that point) recordMemRefUses :: (AliasableRep rep, Op rep ~ MemOp inner rep, HasMemBlock (Aliases rep)) => TopdownEnv rep -> BotUpEnv -> Stm (Aliases rep) -> (CoalsTab, InhibitTab) recordMemRefUses td_env bu_env stm = let active_tab = activeCoals bu_env inhibit_tab = inhibit bu_env active_etries = M.toList active_tab in case getUseSumFromStm td_env active_tab stm of Nothing -> M.toList active_tab & foldl ( \state (m_b, entry) -> if not $ null $ patNames (stmPat stm) `intersect` M.keys (vartab entry) then markFailedCoal state m_b else state ) (active_tab, inhibit_tab) Just use_sums -> let (mb_wrts, prev_uses, mb_lmads) = map (checkOverlapAndExpand use_sums active_tab) active_etries & unzip3 -- keep only the entries that do not overlap with the memory -- blocks defined in @pat@ or @inner_free_vars@. -- the others must be recorded in @inhibit_tab@ because -- they violate the 3rd safety condition. active_tab1 = M.fromList $ map ( \(wrts, (uses, prev_use, (k, etry))) -> let mrefs' = (memrefs etry) {dstrefs = prev_use} etry' = etry {memrefs = mrefs'} in (k, addLmads wrts uses etry') ) $ mapMaybe (\(x, y) -> (,y) <$> x) -- only keep successful coals $ zip mb_wrts $ zip3 mb_lmads prev_uses active_etries failed_tab = M.fromList $ map snd $ filter (isNothing . fst) $ zip mb_wrts active_etries (_, inhibit_tab1) = foldl markFailedCoal (failed_tab, inhibit_tab) $ M.keys failed_tab in (active_tab1, inhibit_tab1) where checkOverlapAndExpand (stm_wrts, stm_uses) active_tab (m_b, etry) = let alias_m_b = getAliases mempty m_b stm_uses' = filter ((`notNameIn` alias_m_b) . tupFst) stm_uses all_aliases = foldl getAliases mempty $ namesToList $ alsmem etry ixfns = map tupThd $ filter ((`nameIn` all_aliases) . 
tupSnd) stm_uses' lmads' = mapMaybe mbLmad ixfns lmads'' = if length lmads' == length ixfns then Set $ S.fromList lmads' else Undeterminable wrt_ixfns = map tupThd $ filter ((`nameIn` alias_m_b) . tupFst) stm_wrts wrt_tmps = mapMaybe mbLmad wrt_ixfns prev_use = translateAccessSummary (scope td_env) (scalarTable td_env) $ (dstrefs . memrefs) etry wrt_lmads' = if length wrt_tmps == length wrt_ixfns then Set $ S.fromList wrt_tmps else Undeterminable original_mem_aliases = fmap tupFst stm_uses & uncons & fmap fst & (=<<) (`M.lookup` active_tab) & maybe mempty alsmem (wrt_lmads'', lmads) = if m_b `nameIn` original_mem_aliases then (wrt_lmads' <> lmads'', Set mempty) else (wrt_lmads', lmads'') no_overlap = noMemOverlap td_env (lmads <> prev_use) wrt_lmads'' wrt_lmads = if no_overlap then Just wrt_lmads'' else Nothing in (wrt_lmads, prev_use, lmads) tupFst (a, _, _) = a tupSnd (_, b, _) = b tupThd (_, _, c) = c getAliases acc m = oneName m <> acc <> fromMaybe mempty (M.lookup m (m_alias td_env)) mbLmad indfun | Just subs <- freeVarSubstitutions (scope td_env) (scals bu_env) indfun = Just $ LMAD.substitute subs indfun mbLmad _ = Nothing addLmads wrts uses etry = etry {memrefs = MemRefs uses wrts <> memrefs etry} -- | Check for memory overlap of two access summaries. -- -- This check is conservative, so unless we can guarantee that there is no -- overlap, we return 'False'. noMemOverlap :: (AliasableRep rep) => TopdownEnv rep -> AccessSummary -> AccessSummary -> Bool noMemOverlap _ _ (Set mr) | mr == mempty = True noMemOverlap _ (Set mr) _ | mr == mempty = True noMemOverlap td_env (Set is0) (Set js0) | Just non_negs <- mapM (primExpFromSubExpM (vnameToPrimExp (scope td_env) (scalarTable td_env)) . 
Var) $ namesToList $ nonNegatives td_env = let (_, not_disjoints) = partition ( \i -> all ( \j -> LMAD.disjoint less_thans (nonNegatives td_env) i j || LMAD.disjoint2 () () less_thans (nonNegatives td_env) i j || LMAD.disjoint3 (typeOf <$> scope td_env) asserts less_thans non_negs i j ) js ) is in null not_disjoints where less_thans = map (fmap $ fixPoint $ substituteInPrimExp $ scalarTable td_env) $ knownLessThan td_env asserts = map (fixPoint (substituteInPrimExp $ scalarTable td_env) . primExpFromSubExp Bool) $ td_asserts td_env is = map (fixPoint (LMAD.substitute $ TPrimExp <$> scalarTable td_env)) $ S.toList is0 js = map (fixPoint (LMAD.substitute $ TPrimExp <$> scalarTable td_env)) $ S.toList js0 noMemOverlap _ _ _ = False -- | Computes the total aggregated access summary for a loop by expanding the -- access summary given according to the iterator variable and bounds of the -- loop. -- -- Corresponds to: -- -- \[ -- \bigcup_{j=0}^{j ScopeTab rep -> ScopeTab rep -> ScalarTab -> Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName)) -> AccessSummary -> m AccessSummary aggSummaryLoopTotal _ _ _ _ Undeterminable = pure Undeterminable aggSummaryLoopTotal _ _ _ _ (Set l) | l == mempty = pure $ Set mempty aggSummaryLoopTotal scope_bef scope_loop scals_loop _ access | Set ls <- translateAccessSummary scope_loop scals_loop access, nms <- foldl (<>) mempty $ map freeIn $ S.toList ls, all inBeforeScope $ namesToList nms = do pure $ Set ls where inBeforeScope v = case M.lookup v scope_bef of Nothing -> False Just _ -> True aggSummaryLoopTotal _ _ scalars_loop (Just (iterator_var, (lower_bound, upper_bound))) (Set lmads) = concatMapM ( aggSummaryOne iterator_var lower_bound upper_bound . fixPoint (LMAD.substitute $ fmap TPrimExp scalars_loop) ) (S.toList lmads) aggSummaryLoopTotal _ _ _ _ _ = pure Undeterminable -- | For a given iteration of the loop $i$, computes the aggregated loop access -- summary of all later iterations. 
-- -- Corresponds to: -- -- \[ -- \bigcup_{j=i+1}^{j ScalarTab -> Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName)) -> AccessSummary -> m AccessSummary aggSummaryLoopPartial _ _ Undeterminable = pure Undeterminable aggSummaryLoopPartial _ Nothing _ = pure Undeterminable aggSummaryLoopPartial scalars_loop (Just (iterator_var, (_, upper_bound))) (Set lmads) = do -- map over each index function in the access summary -- Substitube a fresh variable k for the loop iterator -- if k is in stride or span of ixfun: fall back to total -- new_stride = old_offset - old_offset (where k+1 is substituted for k) -- new_offset = old_offset where k = lower bound of iteration -- new_span = upper bound of iteration concatMapM ( aggSummaryOne iterator_var (isInt64 (LeafExp iterator_var $ IntType Int64) + 1) (upper_bound - typedLeafExp iterator_var - 1) . fixPoint (LMAD.substitute $ fmap TPrimExp scalars_loop) ) (S.toList lmads) -- | For a given map with $k$ dimensions and an index $i$ for each dimension, -- compute the aggregated access summary of all other threads. -- -- For the innermost dimension, this corresponds to -- -- \[ -- \bigcup_{j=0}^{j ScalarTab -> [(VName, SubExp)] -> LmadRef -> m AccessSummary aggSummaryMapPartial _ [] = const $ pure mempty aggSummaryMapPartial scalars dims = helper mempty (reverse dims) . Set . S.singleton -- Reverse dims so we work from the inside out where helper acc [] _ = pure acc helper Undeterminable _ _ = pure Undeterminable helper _ _ Undeterminable = pure Undeterminable helper (Set acc) ((gtid, size) : rest) (Set as) = do partial_as <- aggSummaryMapPartialOne scalars (gtid, size) (Set as) total_as <- concatMapM (aggSummaryOne gtid 0 (TPrimExp $ primExpFromSubExp (IntType Int64) size)) (S.toList as) helper (Set acc <> partial_as) rest total_as -- | Given an access summary $a$, a thread id $i$ and the size $n$ of the -- dimension, compute the partial map summary. 
-- -- Corresponds to -- -- \[ -- \bigcup_{j=0}^{j ScalarTab -> (VName, SubExp) -> AccessSummary -> m AccessSummary aggSummaryMapPartialOne _ _ Undeterminable = pure Undeterminable aggSummaryMapPartialOne _ (_, Constant n) (Set _) | oneIsh n = pure mempty aggSummaryMapPartialOne scalars (gtid, size) (Set lmads0) = concatMapM helper [ (0, isInt64 (LeafExp gtid $ IntType Int64)), ( isInt64 (LeafExp gtid $ IntType Int64) + 1, isInt64 (primExpFromSubExp (IntType Int64) size) - isInt64 (LeafExp gtid $ IntType Int64) - 1 ) ] where lmads = map (fixPoint (LMAD.substitute $ fmap TPrimExp scalars)) $ S.toList lmads0 helper (x, y) = concatMapM (aggSummaryOne gtid x y) lmads -- | Computes to total access summary over a multi-dimensional map. aggSummaryMapTotal :: (MonadFreshNames m) => ScalarTab -> [(VName, SubExp)] -> AccessSummary -> m AccessSummary aggSummaryMapTotal _ [] _ = pure mempty aggSummaryMapTotal _ _ (Set lmads) | lmads == mempty = pure mempty aggSummaryMapTotal _ _ Undeterminable = pure Undeterminable aggSummaryMapTotal scalars segspace (Set lmads0) = foldM ( \as' (gtid', size') -> case as' of Set lmads' -> concatMapM ( aggSummaryOne gtid' 0 $ TPrimExp $ primExpFromSubExp (IntType Int64) size' ) (S.toList lmads') Undeterminable -> pure Undeterminable ) (Set lmads) (reverse segspace) where lmads = S.fromList $ map (fixPoint (LMAD.substitute $ fmap TPrimExp scalars)) $ S.toList lmads0 -- | Helper function that aggregates the accesses of single LMAD according to a -- given iterator value, a lower bound and a span. -- -- If successful, the result is an index function with an extra outer -- dimension. The stride of the outer dimension is computed by taking the -- difference between two points in the index function. -- -- The function returns 'Underterminable' if the iterator is free in the output -- LMAD or the dimensions of the input LMAD . 
aggSummaryOne :: (MonadFreshNames m) => VName -> TPrimExp Int64 VName -> TPrimExp Int64 VName -> LmadRef -> m AccessSummary aggSummaryOne iterator_var lower_bound spn lmad@(LMAD.LMAD offset0 dims0) | iterator_var `nameIn` freeIn dims0 = pure Undeterminable | iterator_var `notNameIn` freeIn offset0 = pure $ Set $ S.singleton lmad | otherwise = do new_var <- newVName "k" let offset = replaceIteratorWith (typedLeafExp new_var) offset0 offsetp1 = replaceIteratorWith (typedLeafExp new_var + 1) offset0 new_stride = TPrimExp $ constFoldPrimExp $ simplify $ untyped $ offsetp1 - offset new_offset = replaceIteratorWith lower_bound offset0 new_lmad = LMAD.LMAD new_offset $ LMAD.LMADDim new_stride spn : dims0 if new_var `nameIn` freeIn new_lmad then pure Undeterminable else pure $ Set $ S.singleton new_lmad where replaceIteratorWith se = TPrimExp . substituteInPrimExp (M.singleton iterator_var $ untyped se) . untyped -- | Takes a 'VName' and converts it into a 'TPrimExp' with type 'Int64'. typedLeafExp :: VName -> TPrimExp Int64 VName typedLeafExp vname = isInt64 $ LeafExp vname (IntType Int64) futhark-0.25.27/src/Futhark/Optimise/ArrayShortCircuiting/TopdownAnalysis.hs000066400000000000000000000277661475065116200271610ustar00rootroot00000000000000{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE TypeFamilies #-} module Futhark.Optimise.ArrayShortCircuiting.TopdownAnalysis ( TopdownEnv (..), ScopeTab, TopDownHelper, InhibitTab, updateTopdownEnv, updateTopdownEnvLoop, getDirAliasedIxfn, getDirAliasedIxfn', addInvAliasesVarTab, areAnyAliased, isInScope, nonNegativesInPat, ) where import Data.Map.Strict qualified as M import Data.Maybe import Futhark.Analysis.PrimExp.Convert import Futhark.IR.Aliases import Futhark.IR.GPUMem as GPU import Futhark.IR.MCMem as MC import Futhark.IR.Mem.LMAD qualified as LMAD import Futhark.Optimise.ArrayShortCircuiting.DataStructs type DirAlias = LMAD -> Maybe LMAD -- ^ A direct aliasing transformation type InvAlias = 
Maybe (LMAD -> LMAD)
-- ^ An inverse aliasing transformation

-- | Maps a variable to the variable it directly aliases, together with the
-- direct aliasing transformation ('DirAlias') and the optional inverse
-- transformation ('InvAlias').  Entries are added by 'updateTopdownEnv'.
type VarAliasTab = M.Map VName (VName, DirAlias, InvAlias)

-- | Memory-block aliasing table (the type of the 'm_alias' field, which is
-- documented below as not yet implemented).
type MemAliasTab = M.Map VName Names

-- | Environment threaded through the top-down traversal.
data TopdownEnv rep = TopdownEnv
  { -- | contains the already allocated memory blocks
    alloc :: AllocTab,
    -- | variable info, including var-to-memblock assocs
    scope :: ScopeTab rep,
    -- | the inherited inhibitions from the previous try
    inhibited :: InhibitTab,
    -- | for statements such as transpose, reshape, index, etc., that alias
    -- an array variable: maps var-names to a triple of the aliased var name,
    -- the direct index function transformation and its (optional) inverse.
    -- For example, for @let b = a[slc]@ it should add the binding
    -- @ b |-> (a, `slice` slc, <inverse, if any>)@
    v_alias :: VarAliasTab,
    -- | keeps track of memory block aliasing.
    -- this needs to be implemented
    m_alias :: MemAliasTab,
    -- | Contains symbol information about the variables in the program. Used to
    -- determine if a variable is non-negative.
    nonNegatives :: Names,
    -- | Maps variable names to scalar 'PrimExp's.  Note that
    -- 'updateTopdownEnv' does not modify this field.
    scalarTable :: M.Map VName (PrimExp VName),
    -- | A list of known relations of the form 'VName' @<@ 'SubExp', typically
    -- gotten from 'LoopForm' and 'SegSpace'.
    knownLessThan :: [(VName, PrimExp VName)],
    -- | A list of the asserts encountered so far
    td_asserts :: [SubExp]
  }

-- | Is the given name bound in the environment's 'scope'?
isInScope :: TopdownEnv rep -> VName -> Bool
isInScope td_env m = m `M.member` scope td_env

-- | Get alias and (direct) index function mapping from expression.
-- Returns the aliased variable and a function that transforms its index
-- function into the index function of the result; 'Nothing' for
-- expressions that are not treated as direct aliases (including
-- 'Rearrange', see 'getInvAliasFromExp').
getDirAliasFromExp :: Exp (Aliases rep) -> Maybe (VName, DirAlias)
getDirAliasFromExp (BasicOp (SubExp (Var x))) = Just (x, Just)
getDirAliasFromExp (BasicOp (Opaque _ (Var x))) = Just (x, Just)
getDirAliasFromExp (BasicOp (Reshape ReshapeCoerce shp x)) =
  Just (x, Just . (`LMAD.coerce` shapeDims (fmap pe64 shp)))
getDirAliasFromExp (BasicOp (Reshape ReshapeArbitrary shp x)) =
  -- 'LMAD.reshape' may itself fail, hence no 'Just .' here.
  Just (x, (`LMAD.reshape` shapeDims (fmap pe64 shp)))
getDirAliasFromExp (BasicOp (Rearrange _ _)) =
  Nothing
getDirAliasFromExp (BasicOp (Index x slc)) =
  Just (x, Just . (`LMAD.slice` (Slice $ map (fmap pe64) $ unSlice slc)))
getDirAliasFromExp (BasicOp (Update _ x _ _elm)) = Just (x, Just)
getDirAliasFromExp (BasicOp (FlatIndex x (FlatSlice offset idxs))) =
  Just
    ( x,
      Just . (`LMAD.flatSlice` FlatSlice (pe64 offset) (map (fmap pe64) idxs))
    )
getDirAliasFromExp (BasicOp (FlatUpdate x _ _)) = Just (x, Just)
getDirAliasFromExp _ = Nothing

-- | This was former @createsAliasedArrOK@ from DataStructs
-- While Rearrange creates aliased arrays, we
-- do not yet support them because it would mean we have
-- to "reverse" the index function, for example to support
-- coalescing in the case below,
-- @let a = map f a0   @
-- @let b = transpose a@
-- @let y[4] = copy(b) @
-- we would need to assign to @a@ as index function, the
-- inverse of the transpose, such that, when creating @b@
-- by transposition we get a directly-mapped array, which
-- is expected by the copying in y[4].
-- For the moment we support only transposition and VName-expressions,
-- but rotations and full slices could also be supported.
--
-- This function complements 'getDirAliasFromExp' by returning a function that
-- applies the inverse index function transformation.  'Nothing' means the
-- aliasing cannot be inverted.
getInvAliasFromExp :: Exp (Aliases rep) -> InvAlias
getInvAliasFromExp (BasicOp (SubExp (Var _))) = Just id
getInvAliasFromExp (BasicOp (Opaque _ (Var _))) = Just id
getInvAliasFromExp (BasicOp Update {}) = Just id
getInvAliasFromExp (BasicOp (Rearrange perm _)) =
  Just (`LMAD.permute` rearrangeInverse perm)
getInvAliasFromExp _ = Nothing

-- | Extracts top-down information from the inner operations of a
-- representation: names known to be non-negative (given the pattern names
-- bound by the operation), known upper bounds of the form @v < e@, and the
-- scope introduced by the operation.
class TopDownHelper inner where
  innerNonNegatives :: [VName] -> inner -> Names

  innerKnownLessThan :: inner -> [(VName, PrimExp VName)]

  scopeHelper :: inner -> Scope rep

instance TopDownHelper (SegOp lvl rep) where
  -- All segment-space indices are non-negative.
  innerNonNegatives _ op =
    foldMap (oneName . fst) $ unSegSpace $ segSpace op

  -- Each segment-space index is below its dimension size.
  innerKnownLessThan op =
    map (fmap $ primExpFromSubExp $ IntType Int64) $ unSegSpace $ segSpace op

  scopeHelper op = scopeOfSegSpace $ segSpace op

instance TopDownHelper (HostOp NoOp (Aliases GPUMem)) where
  innerNonNegatives vs (SegOp op) = innerNonNegatives vs op
  -- The results of size queries are treated as non-negative.
  innerNonNegatives [vname] (SizeOp (GetSize _ _)) = oneName vname
  innerNonNegatives [vname] (SizeOp (GetSizeMax _)) = oneName vname
  innerNonNegatives _ _ = mempty

  innerKnownLessThan (SegOp op) = innerKnownLessThan op
  innerKnownLessThan _ = mempty

  scopeHelper (SegOp op) = scopeHelper op
  scopeHelper _ = mempty

instance (TopDownHelper (inner (Aliases MCMem))) => TopDownHelper (MC.MCOp inner (Aliases MCMem)) where
  innerNonNegatives vs (ParOp par_op op) =
    maybe mempty (innerNonNegatives vs) par_op
      <> innerNonNegatives vs op
  innerNonNegatives vs (MC.OtherOp op) =
    innerNonNegatives vs op
  innerKnownLessThan (ParOp par_op op) =
    maybe mempty innerKnownLessThan par_op <> innerKnownLessThan op
  innerKnownLessThan (MC.OtherOp op) =
    innerKnownLessThan op
  scopeHelper (ParOp par_op op) =
    maybe mempty scopeHelper par_op <> scopeHelper op
  scopeHelper MC.OtherOp {} = mempty

instance TopDownHelper (NoOp rep) where
  innerNonNegatives _ NoOp = mempty
  innerKnownLessThan NoOp = mempty
  scopeHelper NoOp = mempty

-- | fills in the TopdownEnv table
--
-- Every statement extends the 'scope'.  Additionally:
--
-- * an 'Alloc' records the allocated block in 'alloc' and marks its size
--   as non-negative;
-- * an inner operation contributes its non-negative names, known bounds
--   and scope via 'TopDownHelper';
-- * an 'Assert' is prepended to 'td_asserts';
-- * a single-result direct aliasing expression (see 'getDirAliasFromExp')
--   is recorded in 'v_alias' together with its inverse, if any;
-- * otherwise (and for aliasing expressions), names used as array
--   dimensions in the pattern are marked non-negative.
updateTopdownEnv ::
  (ASTRep rep, Op rep ~ MemOp inner rep, TopDownHelper (inner (Aliases rep))) =>
  TopdownEnv rep ->
  Stm (Aliases rep) ->
  TopdownEnv rep
updateTopdownEnv env stm@(Let (Pat [pe]) _ (Op (Alloc (Var vname) sp))) =
  env
    { alloc = M.insert (patElemName pe) sp $ alloc env,
      scope = scope env <> scopeOf stm,
      nonNegatives = nonNegatives env <> oneName vname
    }
updateTopdownEnv env stm@(Let pat _ (Op (Inner inner))) =
  env
    { scope = scope env <> scopeOf stm <> scopeHelper inner,
      nonNegatives = nonNegatives env <> innerNonNegatives (patNames pat) inner,
      knownLessThan = knownLessThan env <> innerKnownLessThan inner
    }
updateTopdownEnv env stm@(Let (Pat _) _ (BasicOp (Assert se _ _))) =
  env
    { scope = scope env <> scopeOf stm,
      td_asserts = se : td_asserts env
    }
updateTopdownEnv env stm@(Let (Pat [pe]) _ e)
  | Just (x, ixfn) <- getDirAliasFromExp e =
      let ixfn_inv = getInvAliasFromExp e
       in env
            { v_alias = M.insert (patElemName pe) (x, ixfn, ixfn_inv) (v_alias env),
              scope = scope env <> scopeOf stm,
              nonNegatives = nonNegatives env <> nonNegativesInPat (stmPat stm)
            }
updateTopdownEnv env stm =
  env
    { scope = scope env <> scopeOf stm,
      nonNegatives = nonNegatives env <> nonNegativesInPat (stmPat stm)
    }

-- | The names that appear as array dimensions in the types of the pattern
-- elements; array sizes are treated as non-negative.
nonNegativesInPat :: (Typed rep) => Pat rep -> Names
nonNegativesInPat (Pat elems) =
  foldMap (namesFromList . mapMaybe subExpVar . arrayDims . typeOf) elems

-- | The topdown handler for loops.  Adds the loop parameters and the loop
-- form to the scope; for a 'ForLoop' the index variable is also recorded
-- as non-negative and as less than the loop bound.
updateTopdownEnvLoop :: TopdownEnv rep -> [(FParam rep, SubExp)] -> LoopForm -> TopdownEnv rep
updateTopdownEnvLoop td_env arginis lform =
  let scopetab =
        scope td_env
          <> scopeOfFParams (map fst arginis)
          <> scopeOfLoopForm lform
      non_negatives =
        nonNegatives td_env <> case lform of
          ForLoop v _ _ -> oneName v
          _ -> mempty
      less_than =
        case lform of
          ForLoop v _ b -> [(v, primExpFromSubExp (IntType Int64) b)]
          _ -> mempty
   in td_env
        { scope = scopetab,
          nonNegatives = non_negatives,
          knownLessThan = less_than <> knownLessThan td_env
        }

-- | Get direct aliased index function.  Returns a triple of current memory
-- block to be coalesced, the destination memory block and the index function of
-- the access in the space of the destination block.  Returns 'Nothing' if
-- @x@ has no memory information in scope, or if its block is being
-- coalesced but @x@ cannot be reached through 'walkAliasTab'.
getDirAliasedIxfn :: (HasMemBlock (Aliases rep)) => TopdownEnv rep -> CoalsTab -> VName -> Maybe (VName, VName, LMAD)
getDirAliasedIxfn td_env coals_tab x =
  case getScopeMemInfo x (scope td_env) of
    Just (MemBlock _ _ m_x orig_ixfun) ->
      case M.lookup m_x coals_tab of
        Just coal_etry -> do
          (Coalesced _ (MemBlock _ _ m ixf) _) <- walkAliasTab (v_alias td_env) (vartab coal_etry) x
          pure (m_x, m, ixf)
        Nothing ->
          -- This value is not subject to coalescing at the moment. Just return the
          -- original index function
          Just (m_x, m_x, orig_ixfun)
    Nothing -> Nothing

-- | Like 'getDirAliasedIxfn', but this version returns 'Nothing' if the value
-- is not currently subject to coalescing.
getDirAliasedIxfn' :: (HasMemBlock (Aliases rep)) => TopdownEnv rep -> CoalsTab -> VName -> Maybe (VName, VName, LMAD)
getDirAliasedIxfn' td_env coals_tab x =
  case getScopeMemInfo x (scope td_env) of
    Just (MemBlock _ _ m_x _) ->
      case M.lookup m_x coals_tab of
        Just coal_etry -> do
          (Coalesced _ (MemBlock _ _ m ixf) _) <- walkAliasTab (v_alias td_env) (vartab coal_etry) x
          pure (m_x, m, ixf)
        Nothing ->
          -- This value is not subject to coalescing at the moment; unlike
          -- 'getDirAliasedIxfn', report that by returning 'Nothing'.
          Nothing
    Nothing -> Nothing

-- | Given a 'VName', walk the 'VarAliasTab' until found in the 'Map'.
-- Each step back along the alias chain applies that alias's direct index
-- function transformation to the index function found further up; the
-- walk fails if any transformation fails or the chain ends without a hit.
walkAliasTab ::
  VarAliasTab ->
  M.Map VName Coalesced ->
  VName ->
  Maybe Coalesced
walkAliasTab _ vtab x
  | Just c <- M.lookup x vtab =
      Just c -- @x@ is in @vartab@ together with its new ixfun
walkAliasTab alias_tab vtab x
  | Just (x0, alias0, _) <- M.lookup x alias_tab = do
      Coalesced knd (MemBlock pt shp vname ixf) substs <- walkAliasTab alias_tab vtab x0
      ixf' <- alias0 ixf
      pure $ Coalesced knd (MemBlock pt shp vname ixf') substs
walkAliasTab _ _ _ = Nothing

-- | We assume @x@ is in @vartab@ and we add the variables that @x@ aliases
-- for as long as possible following a chain of direct-aliasing operators,
-- i.e., without considering aliasing of if-then-else, loops, etc. For example:
-- @ x0 = if c then ... else ...@
-- @ x1 = rearrange r1 x0 @
-- @ x2 = reverse x1@
-- @ y[slc] = x2 @
-- We assume @vartab@ contains a binding for @x2@, and calling this function
-- with @x2@ as argument should also insert entries for @x1@ and @x0@ to
-- @vartab@, of course if their aliasing operations are invertible.
-- We assume inverting aliases has been performed by the top-down pass.
addInvAliasesVarTab ::
  (HasMemBlock (Aliases rep)) =>
  TopdownEnv rep ->
  M.Map VName Coalesced ->
  VName ->
  Maybe (M.Map VName Coalesced)
addInvAliasesVarTab td_env vtab x = do
  -- Without an existing coalescing entry for @x@ there is nothing to
  -- propagate, and the result is 'Nothing'.
  Coalesced _ (MemBlock _ _ m_y x_ixfun) fv_subs <- M.lookup x vtab
  case M.lookup x (v_alias td_env) of
    -- @x@ is not a direct alias of anything: the chain ends here.
    Nothing -> Just vtab
    Just (x0, _, mb_inverse) -> do
      -- An alias whose index function cannot be inverted makes the whole
      -- operation fail (conservatively).
      inverse <- mb_inverse
      case getScopeMemInfo x0 (scope td_env) of
        Nothing -> error "impossible"
        Just (MemBlock ptp shp _ _) ->
          let entry = Coalesced TransitiveCoal (MemBlock ptp shp m_y (inverse x_ixfun)) fv_subs
           in addInvAliasesVarTab td_env (M.insert x0 entry vtab) x0

-- | Placeholder aliasing test between two memory blocks: for now a block is
-- only considered aliased with itself, and the environment is ignored.
areAliased :: TopdownEnv rep -> VName -> VName -> Bool
areAliased _ = (==)

-- | Is @m_x@ aliased (according to 'areAliased') with any of the given blocks?
areAnyAliased :: TopdownEnv rep -> VName -> [VName] -> Bool
areAnyAliased td_env m_x ms = or [areAliased td_env m_x m_y | m_y <- ms]
futhark-0.25.27/src/Futhark/Optimise/BlkRegTiling.hs000066400000000000000000001631271475065116200222110ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-}

-- | Perform a restricted form of block+register tiling corresponding to
--   the following pattern:
--     * a redomap is quasi-perfectly nested inside a kernel with at
--       least two parallel dimensions (the perfectly nested restriction
--       is relaxed a bit to allow for SGEMM);
--     * all streamed arrays of redomap are one dimensional;
--     * all streamed arrays are variant to exactly one of the two
--       innermost parallel dimensions, and conversely for each of
--       the two innermost parallel dimensions, there is at least
--       one streamed array variant to it;
--     * the stream's result is a tuple of scalar values, which are
--       also the "thread-in-space" return of the kernel.
--     * We have further restrictions that in principle can be relaxed:
--       the redomap has exactly two array inputs
--       the redomap produces one scalar result
--       the kernel produces one scalar result
module Futhark.Optimise.BlkRegTiling (mmBlkRegTiling, doRegTiling3D) where

import Control.Monad
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Sequence qualified as Seq
import Futhark.IR.GPU
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.MonadFreshNames
import Futhark.Optimise.TileLoops.Shared
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute

-- Frequently used 64-bit integer constants.
se0, se1, se2, se4, se8 :: SubExp
se0 = intConst Int64 0
se1 = intConst Int64 1
se2 = intConst Int64 2
se4 = intConst Int64 4
se8 = intConst Int64 8

-- | Given the slice variable @slc_X@ and the statement that defines it by
-- indexing some array @x@, decide whether @x@ is laid out with unit stride
-- in its innermost dimension, according to the index-function table of the
-- environment.  Arrays absent from the table are assumed not to be
-- transposed, i.e. coalesced.  Any other kind of statement is unexpected
-- and triggers an error.
isInnerCoal :: Env -> VName -> Stm GPU -> Bool
isInnerCoal (_, ixfn_env) slc_X (Let (Pat [pe]) _ (BasicOp (Index x _)))
  | slc_X == patElemName pe =
      maybe True innerStrideIsOne (M.lookup x ixfn_env)
  where
    innerStrideIsOne lmad =
      LMAD.ldStride (last (LMAD.dims lmad)) == pe64 (intConst Int64 1)
isInnerCoal _ _ _ =
  error "kkLoopBody.isInnerCoal: not an error, but I would like to know why!"
-- | Bind a fresh 'Scratch' array of the given element type and shape.
scratch :: (MonadBuilder m) => String -> PrimType -> [SubExp] -> m VName
scratch se_name t shape = letExp se_name $ BasicOp $ Scratch t shape

-- | Main helper function for Register-and-Block Tiling.
--
-- Processes one tile (of size @tk@) of the common (sequential) dimension,
-- starting at @kk = kk0 * tk@:
--
--   1. copies the current tiles of A and B from global memory into the
--      local buffers (see 'copyGlb2ShMem'), reusing the buffers passed in;
--   2. updates every thread's @ry x rx@ register accumulator by running the
--      map lambda followed by the reduce lambda over the tile.
--
-- Returns @[updated accumulator, A buffer, B buffer]@.  When @epilogue@ is
-- true, accesses at or beyond @common_dim@ are guarded, so that a partial
-- last tile can be processed.
kkLoopBody ::
  Env ->
  ( (SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp, SubExp),
    SegLevel,
    [Int],
    (VName, SubExp, VName, SubExp, SubExp),
    (VName, VName),
    (Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
    (Lambda GPU, Lambda GPU)
  ) ->
  VName ->
  (VName, VName, VName) ->
  Bool ->
  Builder GPU [VName]
kkLoopBody
  env
  ( (rx, ry, tx, ty, tk, tk_div_tx, _tk_div_ty, tx_rx),
    segthd_lvl,
    var_dims,
    (gtid_x, width_B, gtid_y, height_A, common_dim),
    (iii, jjj),
    (load_A, inp_A, pt_A, load_B, inp_B, pt_B),
    (map_lam, red_lam)
    )
  kk0
  (thd_res_merge, a_loc_init', b_loc_init')
  epilogue = do
    let (map_t1, map_t2) = (pt_A, pt_B)
    kk <- letExp "kk" =<< toExp (le64 kk0 * pe64 tk)
    -- copy A to shared memory
    (a_loc, aCopyLoc2Reg) <-
      copyGlb2ShMem False kk (gtid_y, iii, map_t1, height_A, inp_A, load_A, a_loc_init')

    -- copy B from global to shared memory
    (b_loc, bCopyLoc2Reg) <-
      copyGlb2ShMem True kk (gtid_x, jjj, map_t2, width_B, inp_B, load_B, b_loc_init')

    -- inner loop updating this thread's accumulator (loop k in mmm_kernels).
    thd_acc <- mkRedomapOneTileBody kk thd_res_merge aCopyLoc2Reg bCopyLoc2Reg True
    pure [thd_acc, a_loc, b_loc]
    where
      -- Computes, for the element copied by thread (thd_y, thd_x) in
      -- scatter iteration (i0, k0), its row index @i@, its index @k@ in the
      -- sequential dimension, and its flat position @e@ in the local buffer.
      mk_ik is_B is_coal (thd_y, thd_x) (i0, k0)
        | is_coal = do
            -- not-transposed case (i.e., already coalesced)
            let (t_par, t_seq) = (tx, tk)
            k <- letExp "k" =<< toExp (le64 thd_x + le64 k0 * pe64 t_par)
            i <- letExp "i" =<< toExp (le64 thd_y + le64 i0 * pe64 t_par)
            -- to optimize bank conflicts, we use padding only for B
            -- iff B has the last dimension permuted.
            let pad_term = if is_B then pe64 se1 else pe64 se0
            let e = le64 k + le64 i * (pe64 t_seq + pad_term)
            pure (i, k, e)
      mk_ik _ _ (thd_y, thd_x) (i0, k0) = do
        -- matrix is transposed case (i.e., uncoalesced):
        let (t_par, tr_par) = (tx, tx_rx)
        k <- letExp "k" =<< toExp (le64 thd_y + le64 k0 * pe64 t_par)
        i <- letExp "i" =<< toExp (le64 thd_x + le64 i0 * pe64 t_par)
        -- no padding
        let e = le64 i + le64 k * pe64 tr_par
        pure (i, k, e)
      --
      -- For one value of the sequential index, updates every cell (i, j)
      -- of a thread's ry x rx accumulator @css_init@ with
      -- red_lam(css[i,j], map_lam(a, b)), where a and b are read through
      -- the given local-buffer indexing functions.
      mkCompLoopRxRy fits_ij css_init (a_idx_fn, b_idx_fn) (ltid_y, ltid_x) = do
        css <- forLoop ry [css_init] $ \i [css_merge] -> do
          css <- forLoop rx [css_merge] $ \j [css_merge'] ->
            (resultBodyM <=< letTupExp' "foo")
              =<< eIf
                ( toExp $
                    if fits_ij
                      then true
                      else -- this condition is never needed because
                      -- if i and j are out of range than css[i,j]
                      -- is garbage anyways and should not be written.
                      -- so fits_ij should be always true!!!
                        (le64 iii + le64 i + pe64 ry * le64 ltid_y .<. pe64 height_A)
                          .&&. (le64 jjj + le64 j + pe64 rx * le64 ltid_x .<. pe64 width_B)
                )
                ( do
                    a <- a_idx_fn ltid_y i
                    b <- b_idx_fn ltid_x j
                    c <- index "c" css_merge' [i, j]

                    map_lam' <- renameLambda map_lam
                    red_lam' <- renameLambda red_lam

                    -- the inputs to map are supposed to be permutted with the
                    -- inverted permutation, so as to reach the original position;
                    -- it just so happens that the inverse of [a,b] is [b,a]
                    let map_inp_reg = if var_dims == [0, 1] then [a, b] else [b, a]

                    map_res <- eLambda map_lam' (map (eSubExp . Var) map_inp_reg)
                    ~[red_res] <- eLambda red_lam' (map eSubExp $ Var c : map resSubExp map_res)
                    css <- update "css" css_merge' [i, j] (resSubExp red_res)

                    resultBodyM [Var css]
                )
                (resultBodyM [Var css_merge'])
          resultBodyM [Var css]
        resultBodyM [Var css]
      --
      -- Runs the redomap over one tile: a 2D segmap over (ty, tx) threads,
      -- each looping over the @tk@ sequential indices of the tile.  In the
      -- epilogue, indices with @kk + k >= common_dim@ leave the accumulator
      -- unchanged.
      mkRedomapOneTileBody kk css_merge a_idx_fn b_idx_fn fits_ij = do
        -- the actual redomap.
        redomap_res <- segMap2D "redomap_res" segthd_lvl ResultPrivate (ty, tx) $ \(ltid_y, ltid_x) -> do
          css_init <- index "css_init" css_merge [ltid_y, ltid_x]

          css <- forLoop tk [css_init] $ \k [acc_merge] ->
            (resultBodyM <=< letTupExp' "foo")
              =<< eIf
                ( toExp $
                    if epilogue
                      then le64 kk + le64 k .<. pe64 common_dim
                      else true -- if in prologue, always compute redomap.
                )
                (mkCompLoopRxRy fits_ij acc_merge (a_idx_fn k, b_idx_fn k) (ltid_y, ltid_x))
                (resultBodyM [Var acc_merge])
          pure [varRes css]
        pure $ head redomap_res
      --
      -- Scatters the current tile of an input array (A when @is_B@ is
      -- false, B otherwise) from global memory into its local buffer,
      -- and returns the updated buffer together with a function for
      -- indexing it as (k, local thread index, register index).
      copyGlb2ShMem ::
        Bool ->
        VName ->
        (VName, VName, PrimType, SubExp, VName, Stm GPU, VName) ->
        Builder GPU (VName, VName -> VName -> VName -> Builder GPU VName)
      copyGlb2ShMem is_B kk (gtid, ii, ptp_X_el, parlen_X, inp_X, load_X, x_loc_init') = do
        let (t_par, r_par, tseq_div_tpar) = (tx, rx, tk_div_tx)
            is_inner_coal = isInnerCoal env inp_X load_X
            str_A = baseString inp_X
        x_loc <-
          segScatter2D (str_A ++ "_glb2loc") x_loc_init' [r_par, tseq_div_tpar] (t_par, t_par) $
            scatterFun is_inner_coal

        pure (x_loc, indexLocMem is_inner_coal str_A x_loc)
        where
          -- Reads the element for sequential index @k@, local thread index
          -- @ltid_yx@ and register index @ij@ from the local buffer, using
          -- the same layout (and B padding) as 'mk_ik'.
          indexLocMem :: Bool -> String -> VName -> VName -> VName -> VName -> Builder GPU VName
          indexLocMem is_inner_coal str_A x_loc k ltid_yx ij = do
            let (r_par, t_seq, tr_par) = (rx, tk, tx_rx)
            let pad_term = if is_B then pe64 se1 else pe64 se0
            x_loc_ind_32 <-
              letExp (str_A ++ "_loc_ind_64")
                =<< toExp
                  ( if is_inner_coal -- ToDo: check this is correct + turn to i32
                      then le64 k + (le64 ltid_yx * pe64 r_par + le64 ij) * (pe64 t_seq + pad_term)
                      else le64 ij + le64 ltid_yx * pe64 r_par + le64 k * pe64 tr_par
                  )
            index (str_A ++ "_loc_elem") x_loc [x_loc_ind_32]
          --
          -- Per-thread scatter body: binds the kernel's gtid for this row,
          -- loads the element from global memory (or a blank value when out
          -- of bounds), and computes its local index (-1 when k >= tk, so
          -- the write is dropped).
          scatterFun :: Bool -> [VName] -> (VName, VName) -> Builder GPU (SubExp, SubExp)
          scatterFun is_inner_coal [i0, k0] (thd_y, thd_x) = do
            let str_A = baseString inp_X
                t_seq = tk
            (i, k, epx_loc_fi) <- mk_ik is_B is_inner_coal (thd_y, thd_x) (i0, k0)
            letBindNames [gtid] =<< toExp (le64 ii + le64 i)
            a_seqdim_idx <- letExp (str_A ++ "_seqdim_idx") =<< toExp (le64 kk + le64 k)

            a_elem <-
              letSubExp (str_A ++ "_elem")
                =<< eIf
                  ( toExp $
                      le64 gtid .<. pe64 parlen_X
                        .&&. if epilogue
                          then le64 a_seqdim_idx .<. pe64 common_dim
                          else true
                  )
                  ( do
                      addStm load_X
                      res <- index "A_elem" inp_X [a_seqdim_idx]
                      resultBodyM [Var res]
                  )
                  (eBody [eBlank $ Prim ptp_X_el])

            a_loc_ind <-
              letSubExp (str_A ++ "_loc_ind")
                =<< eIf
                  (toExp $ le64 k .<. pe64 t_seq)
                  (eBody [toExp epx_loc_fi])
                  (eBody [eSubExp $ intConst Int64 (-1)])
            pure (a_elem, a_loc_ind)
          scatterFun _ _ _ = do
            error "Function scatterFun in Shared.hs: 2nd arg should be an array with 2 elements!"

-- ToDo: we need tx == ty (named t_par), and rx == ry (named r_par)
--       in order to handle all the cases without transpositions.
--       additionally, of course, we need that tk is a multiple of t_par.

-- | Block-and-register tiling of a matrix-multiplication-like kernel.
-- First tries the accumulator-result variant ('mmBlkRegTilingAcc') and
-- falls back to the primitive-result variant ('mmBlkRegTilingNrm').
mmBlkRegTiling :: Env -> Stm GPU -> TileM (Maybe (Stms GPU, Stm GPU))
mmBlkRegTiling env stm = do
  res <- mmBlkRegTilingAcc env stm
  case res of
    Nothing -> mmBlkRegTilingNrm env stm
    _ -> pure res

-- | Variant of the tiling for kernels whose single result is an
-- accumulator that is updated with the redomap result.  Besides the x/y
-- block grid, this variant adds a third grid dimension (@gid_t@) that
-- splits the common dimension into chunks of @rk@ (= 8) tiles of size
-- @tk@; each block processes its full tiles in a loop and, if needed, one
-- final partial tile.
mmBlkRegTilingAcc :: Env -> Stm GPU -> TileM (Maybe (Stms GPU, Stm GPU))
mmBlkRegTilingAcc env (Let pat aux (Op (SegOp (SegMap SegThread {} seg_space ts old_kbody))))
  | KernelBody () kstms [Returns ResultMaySimplify cs (Var res_nm)] <- old_kbody,
    cs == mempty,
    -- check kernel has one result, of accumulator type
    [res_tp] <- ts,
    isAcc res_tp,
    -- we get the global-thread id for the two inner dimensions,
    --   as we are probably going to use it in code generation
    (gtid_x, width_B) : (gtid_y, height_A) : rem_outer_dims_rev <-
      reverse $ unSegSpace seg_space,
    rem_outer_dims <- reverse rem_outer_dims_rev,
    Just
      ( code2',
        (load_A, inp_A, map_t1, load_B, inp_B, map_t2),
        common_dim,
        var_dims,
        (map_lam, red_lam, red_ne, redomap_orig_res, red_t)
        ) <-
      matchesBlkRegTile seg_space kstms,
    checkAccumulatesRedomapRes res_nm code2' redomap_orig_res = do
      -- Here we start the implementation
      --
      let is_B_coal = isInnerCoal env inp_B load_B

      ---- in this binder: host code and outer seggroup (ie. the new kernel) ----
      (new_kernel, host_stms) <- runBuilder $ do
        -- host code
        (rx, ry, tx, ty, tk, tk_div_tx, tk_div_ty, tx_rx, ty_ry, a_loc_sz, b_loc_sz) <-
          mkTileMemSizes height_A width_B common_dim is_B_coal

        rk <- letSubExp "rk" $ BasicOp $ SubExp $ intConst Int64 8 -- 16 and 8 seem good values
        tk_rk <- letSubExp "tk_rk" =<< toExp (pe64 tk * pe64 rk)
        gridDim_t <- letSubExp "gridDim_t" =<< ceilDiv common_dim tk_rk

        gridDim_y <- letSubExp "gridDim_y" =<< ceilDiv height_A ty_ry
        gridDim_x <- letSubExp "gridDim_x" =<< ceilDiv width_B tx_rx
        let gridxyt_pexp = pe64 gridDim_y * pe64 gridDim_x * pe64 gridDim_t
            grid_pexp = foldl (\x d -> pe64 d * x) gridxyt_pexp $ map snd rem_outer_dims_rev

        (grid_size, tblock_size, segthd_lvl) <- mkNewSegthdLvl tx ty grid_pexp

        (gid_x, gid_y, gid_flat) <- mkGidsXYF
        gid_t <- newVName "gid_t"

        ---- in this binder: outer seggroup ----
        (ret_seggroup, stms_seggroup) <- runBuilder $ do
          iii <- letExp "iii" =<< toExp (le64 gid_y * pe64 ty_ry)
          jjj <- letExp "jjj" =<< toExp (le64 gid_x * pe64 tx_rx)
          ttt <- letExp "ttt" =<< toExp (le64 gid_t * pe64 tk_rk)

          -- initialize register mem with neutral elements and create shmem
          (cssss, a_loc_init, b_loc_init) <-
            initRegShmem
              (rx, tx, ry, ty, a_loc_sz, b_loc_sz)
              (map_t1, map_t2, red_t)
              segthd_lvl
              red_ne

          -- build prologue.
          -- Number of complete tk-tiles this block handles: at most rk, and
          -- no more than fit in what remains of the common dimension.
          elems_on_t <- letSubExp "elems_on_t" =<< toExp (pe64 common_dim - le64 ttt)
          tiles_on_t <- letSubExp "tiles_on_t" $ BasicOp $ BinOp (SQuot Int64 Unsafe) elems_on_t tk
          full_tiles <- letExp "full_tiles" $ BasicOp $ BinOp (SMin Int64) rk tiles_on_t

          let ct_arg =
                ( (rx, ry, tx, ty, tk, tk_div_tx, tk_div_ty, tx_rx),
                  segthd_lvl,
                  var_dims,
                  (gtid_x, width_B, gtid_y, height_A, common_dim),
                  (iii, jjj),
                  (load_A, inp_A, map_t1, load_B, inp_B, map_t2),
                  (map_lam, red_lam)
                )

          prologue_res_list <-
            forLoop' (Var full_tiles) [cssss, a_loc_init, b_loc_init] $
              \kk0 [thd_res_merge, a_loc_merge, b_loc_merge] -> do
                off_t <- letExp "off_t" =<< toExp (pe64 rk * le64 gid_t + le64 kk0)
                process_full_tiles <-
                  kkLoopBody env ct_arg off_t (thd_res_merge, a_loc_merge, b_loc_merge) False

                resultBodyM $ map Var process_full_tiles

          let prologue_res : a_loc_reuse : b_loc_reuse : _ = prologue_res_list

          -- Process one more (partial) tile with bounds checks, unless the
          -- block's rk tiles were all full or the common dimension ended
          -- exactly at a tile boundary.
          redomap_res_lst <-
            letTupExp "redomap_res_if"
              =<< eIf
                ( toExp $
                    le64 full_tiles .==. pe64 rk
                      .||. pe64 common_dim .==. (pe64 tk * le64 full_tiles + le64 ttt)
                )
                (resultBodyM $ map Var prologue_res_list)
                ( do
                    off_t <- letExp "off_t" =<< toExp (pe64 rk * le64 gid_t + le64 full_tiles)
                    process_sprs_tile <-
                      kkLoopBody env ct_arg off_t (prologue_res, a_loc_reuse, b_loc_reuse) True

                    resultBodyM $ map Var process_sprs_tile
                )
          let redomap_res : _ = redomap_res_lst

          -- support for non-empty code2'
          --  segmap (ltid_y < ty, ltid_x < tx) {
          --    for i < ry do
          --      for j < rx do
          --        res = if (iii+ltid_y*ry+i < height_A && jjj+ltid_x*rx+j < width_B)
          --              then code2' else dummy
          --        final_res[i,j] = res
          mkEpilogueAccRes
            segthd_lvl
            (redomap_orig_res, redomap_res)
            (res_nm, res_tp)
            (ty, tx, ry, rx)
            (iii, jjj)
            (gtid_y, gtid_x)
            (height_A, width_B, rem_outer_dims)
            code2'

        let grid = KernelGrid (Count grid_size) (Count tblock_size)
            level' = SegBlock SegNoVirt (Just grid)
            space' = SegSpace gid_flat (rem_outer_dims ++ [(gid_t, gridDim_t), (gid_y, gridDim_y), (gid_x, gridDim_x)])
            kbody' = KernelBody () stms_seggroup ret_seggroup
        pure $ Let pat aux $ Op $ SegOp $ SegMap level' space' ts kbody'
      pure $ Just (host_stms, new_kernel)
  where
    -- Do two types denote accumulators with the same singleton name?
    sameAccType acc_sglton (Acc sglton _ _ _) =
      acc_sglton == sglton
    sameAccType _ _ = False
    -- Finds the unique free variable of the original kernel body whose
    -- type is an accumulator of the same kind as the result type; errors
    -- if there is not exactly one.
    getAccumFV (Acc singleton _shp [_eltp] _) = do
      let fvs = namesToList $ freeIn old_kbody -- code
      tps <- localScope (scopeOfSegSpace seg_space) $ do
        mapM lookupType fvs
      let (acc_0s, _) = unzip $ filter (sameAccType singleton . snd) $ zip fvs tps
      case acc_0s of
        [acc_0] -> pure acc_0
        _ -> error "Impossible case reached when treating accumulators!"
    getAccumFV tp = error ("Should be an accumulator type at this point, given: " ++ prettyString tp)
    --
    -- checks that the redomap result is used directly as the accumulated value,
    -- in which case it is safe to parallelize the innermost dimension (of tile tk)
    -- (Only the last single-value 'UpdateAcc' binding @res_nm@ is inspected.)
    checkAccumulatesRedomapRes res_nm acc_code redomap_orig_res = do
      foldl getAccumStm False $ reverse $ stmsToList acc_code
      where
        getAccumStm True _ = True
        getAccumStm False (Let (Pat [pat_el]) _aux (BasicOp (UpdateAcc _ _acc_nm _ind vals)))
          | [v] <- vals,
            patElemName pat_el == res_nm =
              v == Var redomap_orig_res
        getAccumStm False _ = False
    --
    -- epilogue for accumulator result type
    -- Each thread threads the accumulator through its ry x rx results,
    -- running code2' (with the original accumulator renamed to the loop
    -- variable) for every in-bounds (gtid_y, gtid_x).
    mkEpilogueAccRes
      segthd_lvl
      (redomap_orig_res, redomap_res)
      (res_nm, res_tp)
      (ty, tx, ry, rx)
      (iii, jjj)
      (gtid_y, gtid_x)
      (height_A, width_B, _rem_outer_dims)
      code2' = do
        rss_init <- getAccumFV res_tp
        rssss_list <- segMap2D "rssss" segthd_lvl ResultMaySimplify (ty, tx) $ \(ltid_y, ltid_x) -> do
          (css, ii, jj) <- getThdRedomapRes (rx, ry) (ltid_x, ltid_y) (iii, jjj, redomap_res)
          rss <- forLoop ry [rss_init] $ \i [rss_merge] -> do
            rss' <- forLoop rx [rss_merge] $ \j [rss_merge'] -> do
              prereqAddCode2 (gtid_x, gtid_y) (ii, i, jj, j) (css, redomap_orig_res)
              let code2_subs = substituteNames (M.singleton rss_init rss_merge') code2'

              res_el <-
                letSubExp "res_elem"
                  =<< eIf
                    ( toExp $
                        le64 gtid_y .<. pe64 height_A
                          .&&. le64 gtid_x .<. pe64 width_B
                    )
                    ( do
                        addStms code2_subs
                        resultBodyM [Var res_nm]
                    )
                    (resultBodyM [Var rss_merge'])
              resultBodyM [res_el]
            resultBodyM [Var rss']
          pure [varRes rss]
        let epilogue_res_acc : _ = rssss_list
        pure [Returns ResultMaySimplify (Certs []) $ Var epilogue_res_acc]
mmBlkRegTilingAcc _ _ = pure Nothing

--------------------------
--------------------------

-- | Variant of the tiling for kernels whose single result is of primitive
-- type.  All full @tk@-tiles of the common dimension are processed in a
-- loop, followed by one bounds-checked epilogue tile; the per-thread
-- results are returned with 'RegTileReturns'.
mmBlkRegTilingNrm :: Env -> Stm GPU -> TileM (Maybe (Stms GPU, Stm GPU))
mmBlkRegTilingNrm env (Let pat aux (Op (SegOp (SegMap SegThread {} seg_space ts old_kbody))))
  | KernelBody () kstms [Returns ResultMaySimplify cs (Var res_nm)] <- old_kbody,
    cs == mempty,
    -- check kernel has one result of primitive type
    [res_tp] <- ts,
    primType res_tp,
    -- we get the global-thread id for the two inner dimensions,
    --   as we are probably going to use it in code generation
    (gtid_x, width_B) : (gtid_y, height_A) : rem_outer_dims_rev <-
      reverse $ unSegSpace seg_space,
    rem_outer_dims <- reverse rem_outer_dims_rev,
    Just
      ( code2',
        (load_A, inp_A, map_t1, load_B, inp_B, map_t2),
        common_dim,
        var_dims,
        (map_lam, red_lam, red_ne, redomap_orig_res, red_t)
        ) <-
      matchesBlkRegTile seg_space kstms = do
      -- Here we start the implementation
      let is_B_coal = isInnerCoal env inp_B load_B
      ---- in this binder: host code and outer seggroup (ie. the new kernel) ----
      (new_kernel, host_stms) <- runBuilder $ do
        -- host code
        (rx, ry, tx, ty, tk, tk_div_tx, tk_div_ty, tx_rx, ty_ry, a_loc_sz, b_loc_sz) <-
          mkTileMemSizes height_A width_B common_dim is_B_coal

        gridDim_x <- letSubExp "gridDim_x" =<< ceilDiv width_B tx_rx
        gridDim_y <- letSubExp "gridDim_y" =<< ceilDiv height_A ty_ry
        let gridxy_pexp = pe64 gridDim_y * pe64 gridDim_x
        let grid_pexp = foldl (\x d -> pe64 d * x) gridxy_pexp $ map snd rem_outer_dims_rev

        (grid_size, tblock_size, segthd_lvl) <- mkNewSegthdLvl tx ty grid_pexp
        (gid_x, gid_y, gid_flat) <- mkGidsXYF

        ---- in this binder: outer seggroup ----
        (ret_seggroup, stms_seggroup) <- runBuilder $ do
          iii <- letExp "iii" =<< toExp (le64 gid_y * pe64 ty_ry)
          jjj <- letExp "jjj" =<< toExp (le64 gid_x * pe64 tx_rx)

          -- initialize register mem with neutral elements and create shmem
          (cssss, a_loc_init, b_loc_init) <-
            initRegShmem
              (rx, tx, ry, ty, a_loc_sz, b_loc_sz)
              (map_t1, map_t2, red_t)
              segthd_lvl
              red_ne

          -- build prologue.
          full_tiles <- letExp "full_tiles" $ BasicOp $ BinOp (SQuot Int64 Unsafe) common_dim tk
          let ct_arg =
                ( (rx, ry, tx, ty, tk, tk_div_tx, tk_div_ty, tx_rx),
                  segthd_lvl,
                  var_dims,
                  (gtid_x, width_B, gtid_y, height_A, common_dim),
                  (iii, jjj),
                  (load_A, inp_A, map_t1, load_B, inp_B, map_t2),
                  (map_lam, red_lam)
                )

          prologue_res_list <-
            forLoop' (Var full_tiles) [cssss, a_loc_init, b_loc_init] $
              \kk0 [thd_res_merge, a_loc_merge, b_loc_merge] -> do
                process_full_tiles <-
                  kkLoopBody env ct_arg kk0 (thd_res_merge, a_loc_merge, b_loc_merge) False

                resultBodyM $ map Var process_full_tiles

          let prologue_res : a_loc_reuse : b_loc_reuse : _ = prologue_res_list

          -- build epilogue.
          epilogue_res_list <- kkLoopBody env ct_arg full_tiles (prologue_res, a_loc_reuse, b_loc_reuse) True

          let redomap_res : _ = epilogue_res_list

          -- support for non-empty code2'
          --  segmap (ltid_y < ty, ltid_x < tx) {
          --    for i < ry do
          --      for j < rx do
          --        res = if (iii+ltid_y*ry+i < height_A && jjj+ltid_x*rx+j < width_B)
          --              then code2' else dummy
          --        final_res[i,j] = res
          mkEpiloguePrimRes
            segthd_lvl
            (redomap_orig_res, redomap_res)
            (res_nm, res_tp)
            (ty, tx, ry, rx)
            (iii, jjj)
            (gtid_y, gtid_x)
            (height_A, width_B, rem_outer_dims)
            code2'

        let grid = KernelGrid (Count grid_size) (Count tblock_size)
            level' = SegBlock SegNoVirt (Just grid)
            space' = SegSpace gid_flat (rem_outer_dims ++ [(gid_y, gridDim_y), (gid_x, gridDim_x)])
            kbody' = KernelBody () stms_seggroup ret_seggroup
        pure $ Let pat aux $ Op $ SegOp $ SegMap level' space' ts kbody'
      pure $ Just (host_stms, new_kernel)
  where
    -- Builds the kernel's 'RegTileReturns'.  If the kernel result is the
    -- redomap result itself, the register tile is returned directly;
    -- otherwise code2' is run per (i, j) cell for in-bounds indices (blank
    -- values elsewhere).  When there are outer dimensions, unit dimensions
    -- are inserted into the tile's shape to match them.
    mkEpiloguePrimRes
      segthd_lvl
      (redomap_orig_res, redomap_res)
      (res_nm, res_tp)
      (ty, tx, ry, rx)
      (iii, jjj)
      (gtid_y, gtid_x)
      (height_A, width_B, rem_outer_dims)
      code2' = do
        epilogue_res <-
          if redomap_orig_res == res_nm
            then pure redomap_res -- epilogue_res_list
            else do
              rssss_list <- segMap2D "rssss" segthd_lvl ResultPrivate (ty, tx) $ \(ltid_y, ltid_x) -> do
                rss_init <- scratch "rss_init" (elemType res_tp) [ry, rx]
                (css, ii, jj) <- getThdRedomapRes (rx, ry) (ltid_x, ltid_y) (iii, jjj, redomap_res)
                rss <- forLoop ry [rss_init] $ \i [rss_merge] -> do
                  rss' <- forLoop rx [rss_merge] $ \j [rss_merge'] -> do
                    prereqAddCode2 (gtid_x, gtid_y) (ii, i, jj, j) (css, redomap_orig_res)

                    res_el <-
                      letSubExp "res_elem"
                        =<< eIf
                          ( toExp $
                              le64 gtid_y .<. pe64 height_A
                                .&&. le64 gtid_x .<. pe64 width_B
                          )
                          ( do
                              addStms code2'
                              resultBodyM [Var res_nm]
                          )
                          (eBody [eBlank res_tp])
                    rss'' <- update "rss" rss_merge' [i, j] res_el
                    resultBodyM [Var rss'']
                  resultBodyM [Var rss']
                pure [varRes rss]
              let rssss : _ = rssss_list
              pure rssss

        let regtile_ret_dims =
              map (\(_, sz) -> (sz, se1, se1)) rem_outer_dims
                ++ [(height_A, ty, ry), (width_B, tx, rx)]

        -- Add dummy dimensions to tile to reflect the outer dimensions.
        epilogue_res' <-
          if null rem_outer_dims
            then pure epilogue_res
            else do
              epilogue_t <- lookupType epilogue_res
              let (block_dims, rest_dims) = splitAt 2 $ arrayDims epilogue_t
                  ones = map (const $ intConst Int64 1) rem_outer_dims
                  new_shape = Shape $ concat [ones, block_dims, ones, rest_dims]
              letExp "res_reshaped" . BasicOp $
                Reshape ReshapeArbitrary new_shape epilogue_res

        pure [RegTileReturns mempty regtile_ret_dims epilogue_res']
mmBlkRegTilingNrm _ _ = pure Nothing

-- pattern match the properties of the code that we look to
-- tile: a redomap whose two input arrays are each invariant
-- to one of the last two (innermost) parallel dimensions.
-- On success, returns: the statements to run after the redomap (code1
-- statements the redomap does not depend on, followed by the code after
-- it); the load statement, array and element type of A and of B, ordered
-- so that A is the array variant to the second-innermost parallel
-- dimension (var_dims == [0, 1]) and B the one variant to the innermost;
-- the common (reduction) dimension; the variance indices of the two
-- arrays; and the map lambda, reduce lambda, neutral element, original
-- redomap result name and reduction element type.
matchesBlkRegTile ::
  SegSpace ->
  Stms GPU ->
  Maybe
    ( Stms GPU,
      (Stm GPU, VName, PrimType, Stm GPU, VName, PrimType),
      SubExp,
      [Int],
      (Lambda GPU, Lambda GPU, SubExp, VName, PrimType)
    )
matchesBlkRegTile seg_space kstms
  | -- build the variance table, that records, for
    -- each variable name, the variables it depends on
    initial_variance <- M.map mempty $ scopeOfSegSpace seg_space,
    variance <- varianceInStms initial_variance kstms,
    -- check that the code fits the pattern having:
    -- some `code1`, followed by one Screma SOAC, followed by some `code2`
    (code1, Just screma_stmt, code2) <- matchCodeStreamCode kstms,
    Let pat_redomap _ (Op _) <- screma_stmt,
    -- checks that the Screma SOAC is actually a redomap and normalizes it
    Just (common_dim, arrs, (_, red_lam, red_nes, map_lam)) <- isTileableRedomap screma_stmt,
    -- check that exactly two 1D arrays are streamed through redomap,
    -- and the result of redomap is one scalar
    -- !!!I need to rearrange this whole thing!!! including inp_A and inp_B
    length arrs == 2,
    [red_ne] <- red_nes,
    [map_t1t, map_t2t] <- map paramDec $ lambdaParams map_lam,
    [red_t1, _] <- map paramDec $ lambdaParams red_lam,
    primType map_t1t && primType map_t2t && primType red_t1,
    map_t1_0 <- elemType map_t1t,
    map_t2_0 <- elemType map_t2t,
    -- checks that the input arrays to redomap are variant to
    -- exactly one of the two innermost dimensions of the kernel
    Just var_dims <- isInvarTo1of2InnerDims mempty seg_space variance arrs,
    -- get the variables on which the first result of redomap depends on
    [redomap_orig_res] <- patNames pat_redomap,
    Just res_red_var <- M.lookup redomap_orig_res variance, -- variance of the reduce result

    -- we furthermore check that code1 is only formed by
    -- 1. statements that slice some globally-declared arrays
    --    to produce the input for the redomap, and
    -- 2. potentially some statements on which the redomap
    --    is independent; these are recorded in `code2''`
    Just (code2'', tab_inv_stm) <-
      foldl
        (processIndirections (namesFromList arrs) res_red_var)
        (Just (Seq.empty, M.empty))
        code1,
    -- identify load_A, load_B
    tmp_stms <- mapMaybe (`M.lookup` tab_inv_stm) arrs,
    length tmp_stms == length arrs =
      let zip_AB = zip3 tmp_stms arrs [map_t1_0, map_t2_0]
          [(load_A, inp_A, map_t1), (load_B, inp_B, map_t2)] =
            if var_dims == [0, 1]
              then zip_AB
              else reverse zip_AB
          code2' = code2'' <> code2
       in Just
            ( code2',
              (load_A, inp_A, map_t1, load_B, inp_B, map_t2),
              common_dim,
              var_dims,
              (map_lam, red_lam, red_ne, redomap_orig_res, elemType red_t1)
            )
matchesBlkRegTile _ _ = Nothing

-- | Ceiled division expression: builds @x `SDivUp` y@ on 64-bit integers
-- (with the 'Unsafe' flag, i.e. no division-by-zero check).
ceilDiv :: (MonadBuilder m) => SubExp -> SubExp -> m (Exp (Rep m))
ceilDiv x y = pure $ BasicOp $ BinOp (SDivUp Int64 Unsafe) x y

-- | Computes the tile sizes and local-buffer sizes.  Ty/Ry come from
-- 'getParTiles' and Tx/Rx are set equal to them; Tk comes from
-- 'getSeqTile'.  Returns, in order:
-- @(rx, ry, tx, ty, tk, tk_div_tx, tk_div_ty, tx_rx, ty_ry, a_loc_sz, b_loc_sz)@.
-- @b_loc_sz@ gets @ty * ry@ extra elements when @is_B_not_transp@ holds.
-- NOTE(review): this appears to account for the one-element row padding
-- that 'kkLoopBody' uses for coalesced B -- confirm.
mkTileMemSizes ::
  SubExp ->
  SubExp ->
  SubExp ->
  Bool ->
  Builder
    GPU
    ( SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp,
      SubExp
    )
mkTileMemSizes height_A _width_B common_dim is_B_not_transp = do
  tk_name <- nameFromString . prettyString <$> newVName "Tk"
  ty_name <- nameFromString . prettyString <$> newVName "Ty"
  ry_name <- nameFromString . prettyString <$> newVName "Ry"

  -- until we change the copying to use lmads we need to
  -- guarantee that Tx=Ty AND Rx = Ry AND Tx | Tk
  -- for matrix multiplication it would be safe if they aren't
  -- but not for any of the other three cases!
  (ty, ry) <- getParTiles ("Ty", "Ry") (ty_name, ry_name) height_A
  let (tx, rx) = (ty, ry)
  tk <- getSeqTile "Tk" tk_name common_dim tx ty

  tk_div_tx <- letSubExp "tk_div_tx" =<< ceilDiv tk tx
  tk_div_ty <- letSubExp "tk_div_ty" =<< ceilDiv tk ty

  tx_rx <- letSubExp "TxRx" =<< toExp (pe64 tx * pe64 rx)
  ty_ry <- letSubExp "TyRy" =<< toExp (pe64 ty * pe64 ry)

  -- let pad_term = sMax64 (pe64 tk) (pe64 ty * pe64 ry)
  let pad_term = if is_B_not_transp then pe64 ty * pe64 ry else pe64 se0
  a_loc_sz <-
    letSubExp "a_loc_sz"
      =<< toExp (pe64 ty * pe64 ry * pe64 tk)
  -- if B is transposed, its shmem should be [tk][tx*rx]
  -- we pad as above, by assuming tx*rx == ty*ry >= tk
  b_loc_sz <-
    letSubExp "b_loc_sz"
      =<< toExp (pe64 tx * pe64 rx * pe64 tk + pad_term)

  pure (rx, ry, tx, ty, tk, tk_div_tx, tk_div_ty, tx_rx, ty_ry, a_loc_sz, b_loc_sz)

-- | Binds the grid size (from the given expression) and the thread-block
-- size @ty * tx@, and returns them with the intra-block thread level used
-- by the generated segmaps.
mkNewSegthdLvl ::
  SubExp ->
  SubExp ->
  TPrimExp Int64 VName ->
  Builder GPU (SubExp, SubExp, SegLevel)
mkNewSegthdLvl tx ty grid_pexp = do
  grid_size <- letSubExp "grid_size" =<< toExp grid_pexp
  tblock_size <- letSubExp "tblock_size" =<< toExp (pe64 ty * pe64 tx)
  let segthd_lvl = SegThreadInBlock (SegNoVirtFull (SegSeqDims []))
  pure (grid_size, tblock_size, segthd_lvl)

-- | Fresh names for the block indices.  Note the result order
-- @(gid_x, gid_y, gid_flat)@.
mkGidsXYF :: Builder GPU (VName, VName, VName)
mkGidsXYF = do
  gid_y <- newVName "gid_y"
  gid_x <- newVName "gid_x"
  gid_flat <- newVName "gid_flat"
  pure (gid_x, gid_y, gid_flat)

-- | Creates the per-thread @ry x rx@ accumulators (a 2D segmap over
-- @(ty, tx)@ filling each cell with @red_ne@) and the scratch arrays for
-- the A and B tiles, of sizes @a_loc_sz@ and @b_loc_sz@.
initRegShmem ::
  (SubExp, SubExp, SubExp, SubExp, SubExp, SubExp) ->
  (PrimType, PrimType, PrimType) ->
  SegLevel ->
  SubExp ->
  Builder GPU (VName, VName, VName)
initRegShmem
  (rx, tx, ry, ty, a_loc_sz, b_loc_sz)
  (map_t1, map_t2, red_t)
  segthd_lvl
  red_ne = do
    -- initialize register mem with neutral elements.
    cssss_list <- segMap2D "cssss" segthd_lvl ResultPrivate (ty, tx) $ \_ -> do
      css_init <- scratch "css_init" red_t [ry, rx]
      css <- forLoop ry [css_init] $ \i [css_merge] -> do
        css' <- forLoop rx [css_merge] $ \j [css_merge'] -> do
          css'' <- update "css" css_merge' [i, j] red_ne
          resultBodyM [Var css'']
        resultBodyM [Var css']
      pure [varRes css]
    let [cssss] = cssss_list
    -- scratch shared memory
    a_loc_init <- scratch "A_loc" map_t1 [a_loc_sz]
    b_loc_init <- scratch "B_loc" map_t2 [b_loc_sz]
    pure (cssss, a_loc_init, b_loc_init)

-- | For thread @(ltid_y, ltid_x)@: its @ry x rx@ slice of the redomap
-- result, and the global row/column offsets @ii = iii + ltid_y * ry@ and
-- @jj = jjj + ltid_x * rx@ of that slice.
getThdRedomapRes ::
  (SubExp, SubExp) ->
  (VName, VName) ->
  (VName, VName, VName) ->
  Builder GPU (VName, VName, VName)
getThdRedomapRes (rx, ry) (ltid_x, ltid_y) (iii, jjj, redomap_res) = do
  css <- index "redomap_thd" redomap_res [ltid_y, ltid_x]
  ii <- letExp "ii" =<< toExp (le64 iii + le64 ltid_y * pe64 ry)
  jj <- letExp "jj" =<< toExp (le64 jjj + le64 ltid_x * pe64 rx)
  pure (css, ii, jj)

-- | Re-establishes the bindings that the original post-redomap code
-- expects: binds @redomap_orig_res@ to @css[i, j]@ and the original
-- global thread indices to @ii + i@ and @jj + j@.
prereqAddCode2 ::
  (VName, VName) ->
  (VName, VName, VName, VName) ->
  (VName, VName) ->
  Builder GPU ()
prereqAddCode2 (gtid_x, gtid_y) (ii, i, jj, j) (css, redomap_orig_res) = do
  c <- index "redomap_elm" css [i, j]
  cpy_stm <- mkLetNamesM [redomap_orig_res] $ BasicOp $ SubExp $ Var c
  addStm cpy_stm
  letBindNames [gtid_y] =<< toExp (le64 ii + le64 i)
  letBindNames [gtid_x] =<< toExp (le64 jj + le64 j)

-- | Tries to identify the following pattern:
--   code followed by some Screma followed by more code.
matchCodeStreamCode :: Stms GPU -> (Stms GPU, Maybe (Stm GPU), Stms GPU)
matchCodeStreamCode kstms =
  -- Split at the first Screma statement: everything before it is the
  -- first code section, everything after it (including any later
  -- Scremas) is the second.  A single 'break' is linear, whereas the
  -- previous fold appended to the end of a list on every step (quadratic).
  case break isScrema (stmsToList kstms) of
    (code1, screma : code2) -> (stmsFromList code1, Just screma, stmsFromList code2)
    (code1, []) -> (stmsFromList code1, Nothing, mempty)
  where
    isScrema :: Stm GPU -> Bool
    isScrema (Let _ _ (Op (OtherOp Screma {}))) = True
    isScrema _ = False

-- | Checks that all streamed arrays are variant to exactly one of
-- the two innermost parallel dimensions, and conversely, for
-- each of the two innermost parallel dimensions, there is at
-- least one streamed array variant to it. The result is the
-- number of the only variant parallel dimension for each array.
isInvarTo1of2InnerDims :: Names -> SegSpace -> VarianceTable -> [VName] -> Maybe [Int]
isInvarTo1of2InnerDims branch_variant kspace variance arrs =
  let inner_perm0 = map varToOnly1of2InnerDims arrs
      inner_perm = catMaybes inner_perm0
      ok1 = elem 0 inner_perm && elem 1 inner_perm
      ok2 = length inner_perm0 == length inner_perm
   in if ok1 && ok2 then Just inner_perm else Nothing
  where
    varToOnly1of2InnerDims :: VName -> Maybe Int
    varToOnly1of2InnerDims arr = do
      -- j is the innermost dimension, i the one just outside it.
      (j, _) : (i, _) : _ <- Just $ reverse $ unSegSpace kspace
      let variant_to = M.findWithDefault mempty arr variance
          branch_invariant =
            not $ nameIn j branch_variant || nameIn i branch_variant
      if not branch_invariant
        then Nothing -- if i or j in branch_variant; return nothing
        else
          if nameIn i variant_to && j `notNameIn` variant_to
            then Just 0
            else
              if nameIn j variant_to && i `notNameIn` variant_to
                then Just 1
                else Nothing

-- | Folding step over the statements that precede the redomap.
-- An 'Index' statement whose single result is one of the redomap input
-- arrays (@arrs@) is recorded in the map under that array's name.  Any
-- other statement is appended to the sequence, but only if none of its
-- results is among the variables the redomap result depends on
-- (@res_red_var@); otherwise the fold fails with 'Nothing'.  Once the
-- accumulator is 'Nothing' it stays 'Nothing'.
processIndirections ::
  Names -> -- input arrays to redomap
  Names -> -- variables on which the result of redomap depends on.
  Maybe (Stms GPU, M.Map VName (Stm GPU)) ->
  Stm GPU ->
  Maybe (Stms GPU, M.Map VName (Stm GPU))
processIndirections arrs _ acc stm@(Let patt _ (BasicOp (Index _ _)))
  | Just (ss, tab) <- acc,
    [p] <- patElems patt,
    p_nm <- patElemName p,
    p_nm `nameIn` arrs =
      Just (ss, M.insert p_nm stm tab)
processIndirections _ res_red_var acc stm'@(Let patt _ _)
  | Just (ss, tab) <- acc,
    ps <- patElems patt,
    all (\p -> patElemName p `notNameIn` res_red_var) ps =
      Just (ss Seq.|> stm', tab)
  | otherwise = Nothing

-- | Parallel tile and register-tile sizes for a dimension of length
-- @len_dim@.  For the statically known lengths 8, 16 and 32 the tile is
-- fixed at 8 and the register tile at 1, 2 and 4 respectively (so that
-- tile * register tile equals the length).  Otherwise both are tunable
-- sizes obtained with 'GetSize' and bound to names built from
-- @t_str@/@r_str@.
getParTiles :: (String, String) -> (Name, Name) -> SubExp -> Builder GPU (SubExp, SubExp)
getParTiles (t_str, r_str) (t_name, r_name) len_dim =
  case len_dim of
    Constant (IntValue (Int64Value 8)) -> pure (se8, se1)
    Constant (IntValue (Int64Value 16)) -> pure (se8, se2)
    Constant (IntValue (Int64Value 32)) -> pure (se8, se4)
    _ -> do
      t <- letSubExp t_str $ Op $ SizeOp $ GetSize t_name SizeTile
      r <- letSubExp r_str $ Op $ SizeOp $ GetSize r_name SizeRegTile
      pure (t, r)

-- | Sequential tile size.  When both parallel tiles @tx@ and @ty@ are
-- known constants, the result is the constant @min tx ty@, further
-- capped by @len_dim@ when that is a constant too.  Otherwise it is a
-- tunable 'SizeTile' obtained with 'GetSize'.
getSeqTile :: String -> Name -> SubExp -> SubExp -> SubExp -> Builder GPU SubExp
getSeqTile tk_str tk_name len_dim tx ty =
  case (tx, ty) of
    (Constant (IntValue (Int64Value v_x)), Constant (IntValue (Int64Value v_y))) ->
      letSubExp tk_str . BasicOp . SubExp . constant $
        case len_dim of
          Constant (IntValue (Int64Value v_d)) -> min v_d $ min v_x v_y
          _ -> min v_x v_y
    _ ->
      letSubExp tk_str $ Op $ SizeOp $ GetSize tk_name SizeTile

----------------------------------------------------------------------------------------------
--- 3D Tiling (RegTiling for the outermost dimension & Block tiling for the innermost two) ---
----------------------------------------------------------------------------------------------

-- | Register budget that the 3D tiling divides among the results of the
-- redomap and of the kernel.
maxRegTile :: Int64
maxRegTile = 30

-- | Turn a register-tile size into a constant 'SubExp'.
mkRegTileSe :: Int64 -> SubExp
mkRegTileSe = constant

-- | True if @nm@ is @gid_outer@ itself, or if the variance table records
-- that @nm@ depends on @gid_outer@.
variantToDim :: VarianceTable -> VName -> VName -> Bool
variantToDim variance gid_outer nm =
  gid_outer == nm || nameIn gid_outer (M.findWithDefault mempty nm variance)

-- | Checks that all streamed arrays are variant to exactly one of
-- the three innermost parallel dimensions, and conversely, for
-- each of the three innermost parallel dimensions, there is at
-- least one streamed array variant to it. Arrays are also rejected
-- if any of those three dimensions is branch-variant. The result is the
-- number of the only variant parallel dimension for each array.
isInvarTo2of3InnerDims :: Names -> SegSpace -> VarianceTable -> [VName] -> Maybe [Int]
isInvarTo2of3InnerDims branch_variant kspace variance arrs = do
  -- Every array must be variant to exactly one inner dimension (otherwise
  -- mapM short-circuits to Nothing), and each of the three inner
  -- dimensions must be covered by at least one array.
  inner_perm <- mapM onlyVariantInnerDim arrs
  if all (`elem` inner_perm) [0, 1, 2]
    then Just inner_perm
    else Nothing
  where
    -- Position (0, 1 or 2, counted from the outermost of the three
    -- innermost dimensions) of the single inner dimension that @arr@ is
    -- variant to.  Yields Nothing when the space has fewer than three
    -- dimensions, when any of the three is branch-variant, or when @arr@
    -- is variant to none or to several of them.
    onlyVariantInnerDim :: VName -> Maybe Int
    onlyVariantInnerDim arr = do
      (k, _) : (j, _) : (i, _) : _ <- Just $ reverse $ unSegSpace kspace
      let inner_dims = [i, j, k]
          variant_to = M.findWithDefault mempty arr variance
      if any (`nameIn` branch_variant) inner_dims
        then Nothing
        else case [d | (d, gid) <- zip [0 ..] inner_dims, gid `nameIn` variant_to] of
          [d] -> Just d
          _ -> Nothing

-- | Expects a kernel statement as argument.
-- CONDITIONS for 3D tiling optimization to fire are:
-- 1. a) The kernel body can be broken into
--       scalar-code-1 ++ [Redomap stmt] ++ scalar-code-2.
--    b) The kernels has a per-thread result, and obviously
--       the result is variant to the 3rd dimension
--       (counted from innermost to outermost)
-- 2. For the Redomap:
--    a) the streamed arrays are one dimensional
--    b) each of the array arguments of Redomap are variant
--       to exactly one of the three innermost-parallel dimension
--       of the kernel. This condition can be relaxed by interchanging
--       kernel dimensions whenever possible.
-- 3.
For scalar-code-1: -- a) each of the statements is a slice that produces one of the -- streamed arrays -- -- mmBlkRegTiling :: Stm GPU -> TileM (Maybe (Stms GPU, Stm GPU)) -- mmBlkRegTiling (Let pat aux (Op (SegOp (SegMap SegThread{} seg_space ts old_kbody)))) doRegTiling3D :: Stm GPU -> TileM (Maybe (Stms GPU, Stm GPU)) doRegTiling3D (Let pat aux (Op (SegOp old_kernel))) | SegMap SegThread {} space kertp (KernelBody () kstms kres) <- old_kernel, -- build the variance table, that records, for -- each variable name, the variables it depends on initial_variance <- M.map mempty $ scopeOfSegSpace space, variance <- varianceInStms initial_variance kstms, -- we get the global-thread id for the two inner dimensions, -- as we are probably going to use it in code generation (gtid_x, d_Kx) : (gtid_y, d_Ky) : (gtid_z, d_M) : rem_outer_dims_rev <- reverse $ unSegSpace space, rem_outer_dims <- reverse rem_outer_dims_rev, -- check that the code fits the pattern having: -- some `code1`, followed by one Screma SOAC, followed by some `code2` (code1, Just screma_stmt, code2) <- matchCodeStreamCode kstms, Let pat_redomap _ (Op _) <- screma_stmt, -- checks that the Screma SOAC is actually a redomap and normalize it Just (common_dim, inp_soac_arrs, (_, red_lam, red_nes, map_lam)) <- isTileableRedomap screma_stmt, not (null red_nes), -- assuming we have a budget of maxRegTile registers, we distribute -- that budget across the result of redomap and the kernel result num_res <- max (length red_nes) (length kres), reg_tile <- maxRegTile `quot` fromIntegral num_res, reg_tile_se <- mkRegTileSe reg_tile, -- check that the element-type of the map and reduce are scalars: all (primType . 
paramDec) $ lambdaParams map_lam, red_res_tps <- map paramDec $ take (length red_nes) $ lambdaParams red_lam, all primType red_res_tps, -- checks that the input arrays to redomap are variant to -- exactly one of the two innermost dimensions of the kernel Just _ <- isInvarTo2of3InnerDims mempty space variance inp_soac_arrs, -- get the free variables on which the result of redomap depends on redomap_orig_res <- patElems pat_redomap, res_red_var <- -- variance of the reduce result mconcat $ mapMaybe ((`M.lookup` variance) . patElemName) redomap_orig_res, mempty /= res_red_var, -- we furthermore check that code1 is only formed by -- 1. statements that slice some globally-declared arrays -- to produce the input for the redomap, and -- 2. potentially some statements on which the redomap -- is independent; these are recorded in `code2''` Just (code2'', arr_tab0) <- foldl (processIndirections (namesFromList inp_soac_arrs) res_red_var) (Just (Seq.empty, M.empty)) code1, -- check that code1 contains exacly one slice for each of the input array to redomap tmp_stms <- mapMaybe (`M.lookup` arr_tab0) inp_soac_arrs, length tmp_stms == length inp_soac_arrs, -- code1' <- stmsFromList $ stmsToList code1 \\ stmsToList code2'', code2' <- code2'' <> code2, -- we assume the kernel results are variant to the thrid-outer parallel dimension -- (for sanity sake, they should be) ker_res_nms <- mapMaybe getResNm kres, length ker_res_nms == length kres, all primType kertp, all (variantToDim variance gtid_z) ker_res_nms = do -- HERE STARTS THE IMPLEMENTATION: (new_kernel, host_stms) <- runBuilder $ do -- host code -- process the z-variant arrays that need transposition; -- these "manifest" statements should come before the kernel (tab_inn, tab_out) <- foldM (insertTranspose variance (gtid_z, d_M)) (M.empty, M.empty) $ M.toList arr_tab0 tx_name <- nameFromString . prettyString <$> newVName "Tx" ty_name <- nameFromString . 
prettyString <$> newVName "Ty" tx0 <- letSubExp "Tx" $ Op $ SizeOp $ GetSize tx_name SizeTile ty0 <- letSubExp "Ty" $ Op $ SizeOp $ GetSize ty_name SizeTile ty <- limitTile "Ty" ty0 d_Ky tx <- limitTile "Tx" tx0 d_Kx let rz = reg_tile_se gridDim_x <- letSubExp "gridDim_x" =<< ceilDiv d_Kx tx gridDim_y <- letSubExp "gridDim_y" =<< ceilDiv d_Ky ty gridDim_z <- letSubExp "gridDim_z" =<< ceilDiv d_M rz let gridxyz_pexp = pe64 gridDim_z * pe64 gridDim_y * pe64 gridDim_x let grid_pexp = product $ gridxyz_pexp : map (pe64 . snd) rem_outer_dims_rev grid_size <- letSubExp "grid_size_tile3d" =<< toExp grid_pexp tblock_size <- letSubExp "tblock_size_tile3d" =<< toExp (pe64 ty * pe64 tx) let segthd_lvl = SegThreadInBlock (SegNoVirtFull (SegSeqDims [])) count_shmem <- letSubExp "count_shmem" =<< ceilDiv rz tblock_size gid_x <- newVName "gid_x" gid_y <- newVName "gid_y" gid_z <- newVName "gid_z" gid_flat <- newVName "gid_flat" ---- in this binder: outer seggroup ---- (ret_seggroup, stms_seggroup) <- runBuilder $ do ii <- letExp "ii" =<< toExp (le64 gid_z * pe64 rz) jj1 <- letExp "jj1" =<< toExp (le64 gid_y * pe64 ty) jj2 <- letExp "jj2" =<< toExp (le64 gid_x * pe64 tx) -- initialize the register arrays corresponding to the result of redomap; reg_arr_nms <- segMap2D "res" segthd_lvl ResultPrivate (ty, tx) $ \_ -> forM (zip red_nes red_res_tps) $ \(red_ne, red_t) -> do css_init <- scratch "res_init" (elemType red_t) [rz] css <- forLoop rz [css_init] $ \i [css_merge] -> do css' <- update "css" css_merge [i] red_ne resultBodyM [Var css'] pure $ varRes css -- scratch the shared-memory arrays corresponding to the arrays that are -- input to the redomap and are invariant to the outermost parallel dimension. 
loc_arr_nms <- forM (M.toList tab_out) $ \(nm, (ptp, _)) -> scratch (baseString nm ++ "_loc") ptp [rz] prologue_res_list <- forLoop' common_dim (reg_arr_nms ++ loc_arr_nms) $ \q var_nms -> do let reg_arr_merge_nms = take (length red_nes) var_nms let loc_arr_merge_nms = drop (length red_nes) var_nms -- collective copy from global to shared memory loc_arr_nms' <- forLoop' count_shmem loc_arr_merge_nms $ \tt loc_arr_merge2_nms -> do loc_arr_merge2_nms' <- forM (zip loc_arr_merge2_nms (M.toList tab_out)) $ \(loc_Y_nm, (glb_Y_nm, (ptp_Y, load_Y))) -> do ltid_flat <- newVName "ltid_flat" ltid <- newVName "ltid" let segspace = SegSpace ltid_flat [(ltid, tblock_size)] ((res_v, res_i), stms) <- runBuilder $ do offs <- letExp "offs" =<< toExp (pe64 tblock_size * le64 tt) loc_ind <- letExp "loc_ind" =<< toExp (le64 ltid + le64 offs) letBindNames [gtid_z] =<< toExp (le64 ii + le64 loc_ind) let glb_ind = gtid_z y_elm <- letSubExp "y_elem" =<< eIf (toExp $ le64 glb_ind .<. pe64 d_M) ( do addStm load_Y res <- index "Y_elem" glb_Y_nm [q] resultBodyM [Var res] ) (eBody [eBlank $ Prim ptp_Y]) y_ind <- letSubExp "y_loc_ind" =<< eIf (toExp $ le64 loc_ind .<. pe64 rz) (toExp loc_ind >>= letTupExp' "loc_fi" >>= resultBodyM) (eBody [pure $ BasicOp $ SubExp $ intConst Int64 (-1)]) -- y_tp <- subExpType y_elm pure (y_elm, y_ind) let ret = WriteReturns mempty loc_Y_nm [(Slice [DimFix res_i], res_v)] let body = KernelBody () stms [ret] loc_Y_nm_t <- lookupType loc_Y_nm res_nms <- letTupExp "Y_glb2loc" <=< renameExp $ Op . SegOp $ SegMap segthd_lvl segspace [loc_Y_nm_t] body let res_nm : _ = res_nms pure res_nm resultBodyM $ map Var loc_arr_merge2_nms' redomap_res <- segMap2D "redomap_res" segthd_lvl ResultPrivate (ty, tx) $ \(ltid_y, ltid_x) -> do letBindNames [gtid_y] =<< toExp (le64 jj1 + le64 ltid_y) letBindNames [gtid_x] =<< toExp (le64 jj2 + le64 ltid_x) reg_arr_merge_nms_slc <- forM reg_arr_merge_nms $ \reg_arr_nm -> index "res_reg_slc" reg_arr_nm [ltid_y, ltid_x] fmap subExpsRes . 
letTupExp' "redomap_guarded" =<< eIf (toExp $ le64 gtid_y .<. pe64 d_Ky .&&. le64 gtid_x .<. pe64 d_Kx) ( do inp_scals_invar_outer <- forM (M.toList tab_inn) $ \(inp_arr_nm, load_stm) -> do addStm load_stm index (baseString inp_arr_nm) inp_arr_nm [q] -- build the loop of count R whose body is semantically the redomap code reg_arr_merge_nms' <- forLoop' rz reg_arr_merge_nms_slc $ \i reg_arr_mm_nms -> do letBindNames [gtid_z] =<< toExp (le64 ii + le64 i) resultBodyM =<< letTupExp' "redomap_lam" =<< eIf (toExp $ le64 gtid_z .<. pe64 d_M) ( do -- read from shared memory ys <- forM loc_arr_nms' $ \loc_arr_nm -> index "inp_reg_var2z" loc_arr_nm [i] cs <- forM reg_arr_mm_nms $ \reg_arr_nm -> index "res_reg_var2z" reg_arr_nm [i] -- here we need to put in order the scalar inputs to map: let tab_scals = M.fromList $ zip (map fst $ M.toList tab_out) ys ++ zip (map fst $ M.toList tab_inn) inp_scals_invar_outer map_inp_scals <- forM inp_soac_arrs $ \arr_nm -> case M.lookup arr_nm tab_scals of Nothing -> error "Impossible case reached in tiling3D\n" Just nm -> pure nm map_lam' <- renameLambda map_lam red_lam' <- renameLambda red_lam map_res_scals <- eLambda map_lam' (map (eSubExp . 
Var) map_inp_scals) red_res <- eLambda red_lam' (map eSubExp (map Var cs ++ map resSubExp map_res_scals)) css <- forM (zip reg_arr_mm_nms red_res) $ \(reg_arr_nm, c) -> update (baseString reg_arr_nm) reg_arr_nm [i] (resSubExp c) resultBodyM $ map Var css ) (resultBodyM $ map Var reg_arr_mm_nms) resultBodyM $ map Var reg_arr_merge_nms' ) (resultBodyM $ map Var reg_arr_merge_nms_slc) resultBodyM $ map Var $ redomap_res ++ loc_arr_nms' -- support for non-empty code2' -- segmap (ltid_y < ty, ltid_x < tx) { -- for i < rz do -- res = if (ii+i < d_M && jj1+ltid_y < d_Ky && jj2 + ltid_x < d_Kx) -- then code2' else dummy -- final_res[i] = res let redomap_res = take (length red_nes) prologue_res_list epilogue_res <- if length redomap_orig_res == length ker_res_nms && ker_res_nms == map patElemName redomap_orig_res then segMap3D "rssss" segthd_lvl ResultPrivate (se1, ty, tx) $ \(_ltid_z, ltid_y, ltid_x) -> forM (zip kertp redomap_res) $ \(res_tp, res) -> do rss_init <- scratch "rss_init" (elemType res_tp) [rz, se1, se1] fmap varRes $ forLoop rz [rss_init] $ \i [rss] -> do let slice = Slice [DimFix $ Var i, DimFix se0, DimFix se0] thread_res <- index "thread_res" res [ltid_y, ltid_x, i] rss' <- letSubExp "rss" $ BasicOp $ Update Unsafe rss slice $ Var thread_res resultBodyM [rss'] else segMap3D "rssss" segthd_lvl ResultPrivate (se1, ty, tx) $ \(_ltid_z, ltid_y, ltid_x) -> do letBindNames [gtid_y] =<< toExp (le64 jj1 + le64 ltid_y) letBindNames [gtid_x] =<< toExp (le64 jj2 + le64 ltid_x) rss_init <- forM kertp $ \res_tp -> scratch "rss_init" (elemType res_tp) [rz, se1, se1] rss <- forLoop' rz rss_init $ \i rss_merge -> do letBindNames [gtid_z] =<< toExp (le64 ii + le64 i) forM_ (zip redomap_orig_res redomap_res) $ \(o_res, n_res) -> do c <- index "redomap_thd" n_res [ltid_y, ltid_x, i] letBindNames [patElemName o_res] =<< toExp (le64 c) pure c res_els <- letTupExp' "res_elem" =<< eIf ( toExp $ le64 gtid_y .<. pe64 d_Ky .&&. le64 gtid_x .<. pe64 d_Kx .&&. le64 gtid_z .<. 
pe64 d_M ) ( do addStms code2' resultBodyM $ map Var ker_res_nms ) (eBody $ map eBlank kertp) rss' <- forM (zip res_els rss_merge) $ \(res_el, rs_merge) -> do let slice = Slice [DimFix $ Var i, DimFix se0, DimFix se0] letSubExp "rss" $ BasicOp $ Update Unsafe rs_merge slice res_el resultBodyM rss' pure $ varsRes rss ---------------------------------------------------------------- -- Finally, reshape the result arrays for the RegTileReturn --- ---------------------------------------------------------------- let regtile_ret_dims = map (\(_, sz) -> (sz, se1, se1)) rem_outer_dims ++ [(d_M, se1, rz), (d_Ky, ty, se1), (d_Kx, tx, se1)] epilogue_res' <- forM epilogue_res $ \res -> if null rem_outer_dims then pure res else do -- Add dummy dimensions to tile to reflect the outer dimensions res_tp' <- lookupType res let (block_dims, rest_dims) = splitAt 2 $ arrayDims res_tp' ones = map (const se1) rem_outer_dims new_shape = Shape $ concat [ones, block_dims, ones, rest_dims] letExp "res_reshaped" . 
BasicOp $ Reshape ReshapeArbitrary new_shape res pure $ map (RegTileReturns mempty regtile_ret_dims) epilogue_res' -- END (ret_seggroup, stms_seggroup) <- runBuilder $ do let grid = KernelGrid (Count grid_size) (Count tblock_size) level' = SegBlock SegNoVirt (Just grid) space' = SegSpace gid_flat (rem_outer_dims ++ [(gid_z, gridDim_z), (gid_y, gridDim_y), (gid_x, gridDim_x)]) kbody' = KernelBody () stms_seggroup ret_seggroup pure $ Let pat aux $ Op $ SegOp $ SegMap level' space' kertp kbody' -- END (new_kernel, host_stms) <- runBuilder $ do pure $ Just (host_stms, new_kernel) where getResNm (Returns ResultMaySimplify _ (Var res_nm)) = Just res_nm getResNm _ = Nothing limitTile :: String -> SubExp -> SubExp -> Builder GPU SubExp limitTile t_str t d_K = letSubExp t_str $ BasicOp $ BinOp (SMin Int64) t d_K insertTranspose :: VarianceTable -> (VName, SubExp) -> (M.Map VName (Stm GPU), M.Map VName (PrimType, Stm GPU)) -> (VName, Stm GPU) -> Builder GPU (M.Map VName (Stm GPU), M.Map VName (PrimType, Stm GPU)) insertTranspose variance (gidz, _) (tab_inn, tab_out) (p_nm, stm@(Let patt yy (BasicOp (Index arr_nm slc)))) | [p] <- patElems patt, ptp <- elemType $ patElemType p, p_nm == patElemName p = case L.findIndices (variantSliceDim variance gidz) (unSlice slc) of [] -> pure (M.insert p_nm stm tab_inn, tab_out) i : _ -> do arr_tp <- lookupType arr_nm let perm = [i + 1 .. arrayRank arr_tp - 1] ++ [0 .. 
i] let arr_tr_str = baseString arr_nm ++ "_transp" arr_tr_nm <- letExp arr_tr_str $ BasicOp $ Manifest perm arr_nm let e_ind' = BasicOp $ Index arr_tr_nm slc let stm' = Let patt yy e_ind' pure (tab_inn, M.insert p_nm (ptp, stm') tab_out) insertTranspose _ _ _ _ = error "\nUnreachable case reached in insertTranspose case, doRegTiling3D\n" variantSliceDim :: VarianceTable -> VName -> DimIndex SubExp -> Bool variantSliceDim variance gidz (DimFix (Var vnm)) = variantToDim variance gidz vnm variantSliceDim _ _ _ = False doRegTiling3D _ = pure Nothing futhark-0.25.27/src/Futhark/Optimise/CSE.hs000066400000000000000000000237241475065116200203040ustar00rootroot00000000000000{-# LANGUAGE UndecidableInstances #-} -- | This module implements common-subexpression elimination. This -- module does not actually remove the duplicate, but only replaces -- one with a diference to the other. E.g: -- -- @ -- let a = x + y -- let b = x + y -- @ -- -- becomes: -- -- @ -- let a = x + y -- let b = a -- @ -- -- After which copy propagation in the simplifier will actually remove -- the definition of @b@. -- -- Our CSE is still rather stupid. No normalisation is performed, so -- the expressions @x+y@ and @y+x@ will be considered distinct. -- Furthermore, no expression with its own binding will be considered -- equal to any other, since the variable names will be distinct. -- This affects SOACs in particular. 
module Futhark.Optimise.CSE
  ( performCSE,
    performCSEOnFunDef,
    performCSEOnStms,
    CSEInOp,
  )
where

import Control.Monad.Reader
import Data.Map.Strict qualified as M
import Futhark.Analysis.Alias
import Futhark.IR
import Futhark.IR.Aliases
  ( Aliases,
    consumedInStms,
    mkStmsAliases,
    removeFunDefAliases,
    removeProgAliases,
    removeStmAliases,
  )
import Futhark.IR.GPU qualified as GPU
import Futhark.IR.MC qualified as MC
import Futhark.IR.Mem qualified as Memory
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS.SOAC qualified as SOAC
import Futhark.Pass
import Futhark.Transform.Substitute

-- | Perform CSE on every function in a program.
--
-- If the boolean argument is false, the pass will not perform CSE on
-- expressions producing arrays. This should be disabled when the rep has
-- memory information, since at that point arrays have identity beyond their
-- value.
performCSE ::
  (AliasableRep rep, CSEInOp (Op (Aliases rep))) =>
  Bool ->
  Pass rep rep
performCSE cse_arrays =
  Pass "CSE" "Combine common subexpressions." $ \prog ->
    -- Annotate with aliases, run CSE over the constants and every
    -- function, then strip the alias annotations again.
    removeProgAliases
      <$> intraproceduralTransformationWithConsts
        (cseConsts (freeIn (progFuns prog)))
        cseFun
        (aliasAnalysis prog)
  where
    -- Constants referenced by some function are treated as results of
    -- the constant statements; their aliases, together with whatever the
    -- statements themselves consume, are passed to 'cseInStms' as the
    -- consumed set, which keeps CSE from replacing their bindings.
    cseConsts free_in_funs stms = do
      let (res_als, stms_cons) =
            mkStmsAliases stms $ varsRes $ namesToList free_in_funs
          consumed = mconcat res_als <> stms_cons
      pure . fst $
        runReader
          (cseInStms consumed (stmsToList stms) (pure ()))
          (newCSEState cse_arrays)

    cseFun _ = pure . cseInFunDef cse_arrays

-- | Perform CSE on a single function.
--
-- If the boolean argument is false, the pass will not perform CSE on
-- expressions producing arrays. This should be disabled when the rep has
-- memory information, since at that point arrays have identity beyond their
-- value.
performCSEOnFunDef ::
  (AliasableRep rep, CSEInOp (Op (Aliases rep))) =>
  Bool ->
  FunDef rep ->
  FunDef rep
performCSEOnFunDef cse_arrays fundef =
  -- Alias-annotate, CSE, then drop the annotations.
  removeFunDefAliases $ cseInFunDef cse_arrays $ analyseFun fundef

-- | Perform CSE on some statements.
performCSEOnStms :: (AliasableRep rep, CSEInOp (Op (Aliases rep))) => Stms rep -> Stms rep performCSEOnStms = fmap removeStmAliases . f . fst . analyseStms mempty where f stms = fst $ runReader (cseInStms (consumedInStms stms) (stmsToList stms) (pure ())) -- It is never safe to CSE arrays in stms in isolation, -- because we might introduce additional aliasing. (newCSEState False) cseInFunDef :: (Aliased rep, CSEInOp (Op rep)) => Bool -> FunDef rep -> FunDef rep cseInFunDef cse_arrays fundec = fundec { funDefBody = runReader (cseInBody ds $ funDefBody fundec) $ newCSEState cse_arrays } where -- XXX: we treat every array result as a consumption here, because -- it is otherwise complicated to ensure we do not introduce more -- aliasing than specified by the return type. This is not a -- practical problem while we still perform such aggressive -- inlining. ds = map (retDiet . fst) $ funDefRetType fundec retDiet t | primType $ declExtTypeOf t = Observe | otherwise = Consume type CSEM rep = Reader (CSEState rep) cseInBody :: (Aliased rep, CSEInOp (Op rep)) => [Diet] -> Body rep -> CSEM rep (Body rep) cseInBody ds (Body bodydec stms res) = do (stms', res') <- cseInStms (res_cons <> stms_cons) (stmsToList stms) $ do CSEState (_, nsubsts) _ <- ask pure $ substituteNames nsubsts res pure $ Body bodydec stms' res' where (res_als, stms_cons) = mkStmsAliases stms res res_cons = mconcat $ zipWith consumeResult ds res_als consumeResult Consume als = als consumeResult _ _ = mempty cseInLambda :: (Aliased rep, CSEInOp (Op rep)) => Lambda rep -> CSEM rep (Lambda rep) cseInLambda lam = do body' <- cseInBody (map (const Observe) $ lambdaReturnType lam) $ lambdaBody lam pure lam {lambdaBody = body'} cseInStms :: forall rep a. 
(Aliased rep, CSEInOp (Op rep)) => Names -> [Stm rep] -> CSEM rep a -> CSEM rep (Stms rep, a) cseInStms _ [] m = do a <- m pure (mempty, a) cseInStms consumed (stm : stms) m = cseInStm consumed stm $ \stm' -> do (stms', a) <- cseInStms consumed stms m stm'' <- mapM nestedCSE stm' pure (stmsFromList stm'' <> stms', a) where nestedCSE stm' = do let ds = case stmExp stm' of Loop merge _ _ -> map (diet . declTypeOf . fst) merge _ -> map patElemDiet $ patElems $ stmPat stm' e <- mapExpM (cse ds) $ stmExp stm' pure stm' {stmExp = e} cse ds = (identityMapper @rep) { mapOnBody = const $ cseInBody ds, mapOnOp = cseInOp } patElemDiet pe | patElemName pe `nameIn` consumed = Consume | otherwise = Observe -- A small amount of normalisation of expressions that otherwise would -- be different for pointless reasons. normExp :: Exp lore -> Exp lore normExp (Apply fname args ret (safety, _, _)) = Apply fname args ret (safety, mempty, mempty) normExp e = e cseInStm :: (ASTRep rep) => Names -> Stm rep -> ([Stm rep] -> CSEM rep a) -> CSEM rep a cseInStm consumed (Let pat (StmAux cs attrs edec) e) m = do CSEState (esubsts, nsubsts) cse_arrays <- ask let e' = normExp $ substituteNames nsubsts e pat' = substituteNames nsubsts pat if not (alreadyAliases e) && any (bad cse_arrays) (patElems pat) then m [Let pat' (StmAux cs attrs edec) e'] else case M.lookup (edec, e') esubsts of Just (subcs, subpat) -> do let subsumes = all (`elem` unCerts subcs) (unCerts cs) -- We can only do a plain name substitution if it doesn't -- violate any certificate dependencies. 
local (if subsumes then addNameSubst pat' subpat else id) $ do let lets = [ Let (Pat [patElem']) (StmAux cs attrs edec) $ BasicOp (SubExp $ Var $ patElemName patElem) | (name, patElem) <- zip (patNames pat') $ patElems subpat, let patElem' = patElem {patElemName = name} ] m lets _ -> local (addExpSubst pat' edec cs e') $ m [Let pat' (StmAux cs attrs edec) e'] where alreadyAliases (BasicOp Index {}) = True alreadyAliases (BasicOp Reshape {}) = True alreadyAliases _ = False bad cse_arrays pe | Mem {} <- patElemType pe = True | Array {} <- patElemType pe, not cse_arrays = True | patElemName pe `nameIn` consumed = True | otherwise = False type ExpressionSubstitutions rep = M.Map (ExpDec rep, Exp rep) (Certs, Pat (LetDec rep)) type NameSubstitutions = M.Map VName VName data CSEState rep = CSEState { _cseSubstitutions :: (ExpressionSubstitutions rep, NameSubstitutions), _cseArrays :: Bool } newCSEState :: Bool -> CSEState rep newCSEState = CSEState (M.empty, M.empty) mkSubsts :: Pat dec -> Pat dec -> M.Map VName VName mkSubsts pat vs = M.fromList $ zip (patNames pat) (patNames vs) addNameSubst :: Pat dec -> Pat dec -> CSEState rep -> CSEState rep addNameSubst pat subpat (CSEState (esubsts, nsubsts) cse_arrays) = CSEState (esubsts, mkSubsts pat subpat `M.union` nsubsts) cse_arrays addExpSubst :: (ASTRep rep) => Pat (LetDec rep) -> ExpDec rep -> Certs -> Exp rep -> CSEState rep -> CSEState rep addExpSubst pat edec cs e (CSEState (esubsts, nsubsts) cse_arrays) = CSEState (M.insert (edec, e) (cs, pat) esubsts, nsubsts) cse_arrays -- | The operations that permit CSE. class CSEInOp op where -- | Perform CSE within any nested expressions. 
cseInOp :: op -> CSEM rep op instance CSEInOp (NoOp rep) where cseInOp NoOp = pure NoOp subCSE :: CSEM rep r -> CSEM otherrep r subCSE m = do CSEState _ cse_arrays <- ask pure $ runReader m $ newCSEState cse_arrays instance ( Aliased rep, CSEInOp (Op rep), CSEInOp (op rep) ) => CSEInOp (GPU.HostOp op rep) where cseInOp (GPU.SegOp op) = GPU.SegOp <$> cseInOp op cseInOp (GPU.OtherOp op) = GPU.OtherOp <$> cseInOp op cseInOp (GPU.GPUBody types body) = subCSE $ GPU.GPUBody types <$> cseInBody (map (const Observe) types) body cseInOp x = pure x instance ( Aliased rep, CSEInOp (Op rep), CSEInOp (op rep) ) => CSEInOp (MC.MCOp op rep) where cseInOp (MC.ParOp par_op op) = MC.ParOp <$> traverse cseInOp par_op <*> cseInOp op cseInOp (MC.OtherOp op) = MC.OtherOp <$> cseInOp op instance (Aliased rep, CSEInOp (Op rep)) => CSEInOp (GPU.SegOp lvl rep) where cseInOp = subCSE . GPU.mapSegOpM (GPU.SegOpMapper pure cseInLambda cseInKernelBody pure pure) cseInKernelBody :: (Aliased rep, CSEInOp (Op rep)) => GPU.KernelBody rep -> CSEM rep (GPU.KernelBody rep) cseInKernelBody (GPU.KernelBody bodydec stms res) = do Body _ stms' _ <- cseInBody (map (const Observe) res) $ Body bodydec stms [] pure $ GPU.KernelBody bodydec stms' res instance (CSEInOp (op rep)) => CSEInOp (Memory.MemOp op rep) where cseInOp o@Memory.Alloc {} = pure o cseInOp (Memory.Inner k) = Memory.Inner <$> subCSE (cseInOp k) instance (AliasableRep rep, CSEInOp (Op (Aliases rep))) => CSEInOp (SOAC.SOAC (Aliases rep)) where cseInOp = subCSE . SOAC.mapSOACM (SOAC.SOACMapper pure cseInLambda pure) futhark-0.25.27/src/Futhark/Optimise/DoubleBuffer.hs000066400000000000000000000300061475065116200222250ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | The simplification engine is only willing to hoist allocations -- out of loops if the memory block resulting from the allocation is -- dead at the end of the loop. If it is not, we may cause data -- hazards. -- -- This pass tries to rewrite loops with memory parameters. 
-- Specifically, it takes loops of this form: -- -- @ -- loop {..., A_mem, ..., A, ...} ... do { -- ... -- let A_out_mem = alloc(...) -- stores A_out -- in {..., A_out_mem, ..., A_out, ...} -- } -- @ -- -- and turns them into -- -- @ -- let A_in_mem = alloc(...) -- let A_out_mem = alloc(...) -- let A_in = copy A -- in A_in_mem -- loop {..., A_in_mem, A_out_mem, ..., A=A_in, ...} ... do { -- ... -- in {..., A_out_mem, A_mem, ..., A_out, ...} -- } -- @ -- -- The result is essentially "pointer swapping" between the two initial -- memory blocks @A_mem@ and @A_out_mem@. The invariant is that the -- array is always stored in the "first" memory block at the beginning -- of the loop (and also in the final result). We do need to add an -- extra element to the pattern, however. The initial copy of @A@ -- could be elided if @A@ is unique (thus @A_in_mem=A_mem@). This is -- because only then is it safe to use @A_mem@ to store loop results. -- We don't currently do this. -- -- Unfortunately, not all loops fit the pattern above. In particular, -- a nested loop that has been transformed as such does not! -- Therefore we also have another double buffering strategy, which -- turns -- -- @ -- loop {..., A_mem, ..., A, ...} ... do { -- ... -- let A_out_mem = alloc(...) -- -- A in A_out_mem -- in {..., A_out_mem, ..., A, ...} -- } -- @ -- -- into -- -- @ -- let A_res_mem = alloc(...) -- loop {..., A_mem, ..., A, ...} ... do { -- ... -- let A_out_mem = alloc(...) -- -- A in A_out_mem -- let A' = copy A -- -- A' in A_res_mem -- in {..., A_res_mem, ..., A, ...} -- } -- @ -- -- The allocation of A_out_mem can then be hoisted out because it is -- dead at the end of the loop. This always works as long as -- A_out_mem has a loop-invariant allocation size, but requires a copy -- per iteration (and an initial one, elided above).
module Futhark.Optimise.DoubleBuffer (doubleBufferGPU, doubleBufferMC) where

import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Construct
import Futhark.IR.GPUMem as GPU
import Futhark.IR.MCMem as MC
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations.GPU ()
import Futhark.Transform.Substitute
import Futhark.Util (mapAccumLM)

-- | A loop-rewriting strategy.  It is given the loop's pattern, its
-- merge parameters paired with their initial values, and its (already
-- optimised) body.  It returns statements to be placed immediately
-- before the loop, together with the new pattern, merge parameters and
-- body of the loop.
type OptimiseLoop rep =
  Pat (LetDec rep) ->
  [(FParam rep, SubExp)] ->
  Body rep ->
  DoubleBufferM
    rep
    ( Stms rep,
      Pat (LetDec rep),
      [(FParam rep, SubExp)],
      Body rep
    )

-- | How to recurse into a representation-specific 'Op'.
type OptimiseOp rep =
  Op rep -> DoubleBufferM rep (Op rep)

-- | Environment of the traversal.
data Env rep = Env
  { -- | Names in scope at the current point of the traversal.
    envScope :: Scope rep,
    -- | Strategy applied to every loop encountered.  'doubleBuffer'
    -- starts out with a strategy that leaves loops unchanged; the op
    -- handlers switch to 'optimiseLoop' inside segmented operations.
    envOptimiseLoop :: OptimiseLoop rep,
    -- | Handler for 'Op' expressions.
    envOptimiseOp :: OptimiseOp rep
  }

-- | The traversal monad: a reader over 'Env' on top of a name source,
-- so that fresh names can be generated.
newtype DoubleBufferM rep a = DoubleBufferM
  { runDoubleBufferM :: ReaderT (Env rep) (State VNameSource) a
  }
  deriving (Functor, Applicative, Monad, MonadReader (Env rep), MonadFreshNames)

instance (ASTRep rep) => HasScope rep (DoubleBufferM rep) where
  askScope = asks envScope

instance (ASTRep rep) => LocalScope rep (DoubleBufferM rep) where
  localScope scope = local $ \env -> env {envScope = envScope env <> scope}

-- | Optimise every statement of a body.
optimiseBody :: (ASTRep rep) => Body rep -> DoubleBufferM rep (Body rep)
optimiseBody body = do
  stms' <- optimiseStms $ stmsToList $ bodyStms body
  pure $ body {bodyStms = stms'}

-- | Optimise statements in order.  The names bound by the (possibly
-- expanded) result of each statement are added to the scope before the
-- remaining statements are processed.
optimiseStms :: (ASTRep rep) => [Stm rep] -> DoubleBufferM rep (Stms rep)
optimiseStms [] = pure mempty
optimiseStms (e : es) = do
  e_es <- optimiseStm e
  es' <- localScope (castScope $ scopeOf e_es) $ optimiseStms es
  pure $ e_es <> es'

-- | Optimise a single statement; the result may be several statements,
-- because a loop strategy can emit statements before the loop.
optimiseStm :: forall rep. (ASTRep rep) => Stm rep -> DoubleBufferM rep (Stms rep)
optimiseStm (Let pat aux (Loop merge form body)) = do
  -- Inner loops are handled first, with the loop's parameters in scope.
  body' <-
    localScope (scopeOfLoopForm form <> scopeOfFParams (map fst merge)) $
      optimiseBody body
  opt_loop <- asks envOptimiseLoop
  (stms, pat', merge', body'') <- opt_loop pat merge body'
  pure $ stms <> oneStm (Let pat' aux $ Loop merge' form body'')
optimiseStm (Let pat aux e) = do
  onOp <- asks envOptimiseOp
  oneStm . Let pat aux <$> mapExpM (optimise onOp) e
  where
    -- Recurse into nested bodies and ops; everything else is kept.
    optimise onOp =
      (identityMapper @rep)
        { mapOnBody = \_ x ->
            optimiseBody x :: DoubleBufferM rep (Body rep),
          mapOnOp = onOp
        }

-- | GPU op handler.  Inside a 'SegOp', loops in its lambdas and kernel
-- bodies are rewritten with 'optimiseLoop'; all other ops are returned
-- unchanged.
optimiseGPUOp :: OptimiseOp GPUMem
optimiseGPUOp (Inner (SegOp op)) =
  local inSegOp $ Inner . SegOp <$> mapSegOpM mapper op
  where
    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = optimiseLambda,
          mapOnSegOpBody = optimiseKernelBody
        }
    inSegOp env = env {envOptimiseLoop = optimiseLoop}
optimiseGPUOp op = pure op

-- | Multicore op handler.  Inside a 'ParOp' (both its optional
-- outer-level op and its main op), loops are rewritten with
-- 'optimiseLoop'; all other ops are returned unchanged.
optimiseMCOp :: OptimiseOp MCMem
optimiseMCOp (Inner (ParOp par_op op)) =
  local inSegOp $
    Inner
      <$> (ParOp <$> traverse (mapSegOpM mapper) par_op <*> mapSegOpM mapper op)
  where
    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = optimiseLambda,
          mapOnSegOpBody = optimiseKernelBody
        }
    inSegOp env = env {envOptimiseLoop = optimiseLoop}
optimiseMCOp op = pure op

-- | Optimise the statements of a kernel body, leaving its results alone.
optimiseKernelBody ::
  (ASTRep rep) =>
  KernelBody rep ->
  DoubleBufferM rep (KernelBody rep)
optimiseKernelBody kbody = do
  stms' <- optimiseStms $ stmsToList $ kernelBodyStms kbody
  pure $ kbody {kernelBodyStms = stms'}

-- | Optimise a lambda body with the lambda's parameters in scope.
optimiseLambda ::
  (ASTRep rep) =>
  Lambda rep ->
  DoubleBufferM rep (Lambda rep)
optimiseLambda lam = do
  body <- localScope (castScope $ scopeOf lam) $ optimiseBody $ lambdaBody lam
  pure lam {lambdaBody = body}

-- | What 'optimiseLoop' requires of a memory representation: it builds
-- statements with '()' expression/body decorations and 'LetDecMem'
-- pattern decorations.
type Constraints rep inner =
  ( Mem rep inner,
    BuilderOps rep,
    ExpDec rep ~ (),
    BodyDec rep ~ (),
    LetDec rep ~ LetDecMem
  )

-- | @extractAllocOf bound needle stms@ looks for a statement
-- @needle = alloc(size)@ with a single-element pattern whose @size@ is
-- invariant.  Invariant means a constant, or a variable that is neither
-- in @bound@ nor bound by a statement preceding the allocation in
-- @stms@.  On success it returns the allocation and the remaining
-- statements with it removed.  Returns 'Nothing' when no such statement
-- exists.
extractAllocOf :: (Constraints rep inner) => Names -> VName -> Stms rep -> Maybe (Stm rep, Stms rep)
extractAllocOf bound needle stms = do
  (stm, stms') <- stmsHead stms
  case stm of
    Let (Pat [pe]) _ (Op (Alloc size _))
      | patElemName pe == needle,
        invariant size ->
          Just (stm, stms')
    _ ->
      -- Names bound by this statement are not loop-invariant either.
      let bound' = namesFromList (patNames (stmPat stm)) <> bound
       in second (oneStm stm <>) <$> extractAllocOf bound' needle stms'
  where
    invariant Constant {} = True
    invariant (Var v) = v `notNameIn` bound

-- | Is the parameter an array stored in the memory block named @x@?
isArrayIn :: VName -> Param FParamMem -> Bool
isArrayIn x (Param _ _ (MemArray _ _ _ (ArrayIn y _))) = x == y
isArrayIn _ _ = False

-- | Memory spaces that are eligible for double buffering: every space
-- except 'ScalarSpace'.
doubleBufferSpace :: Space -> Bool
doubleBufferSpace ScalarSpace {} = False
doubleBufferSpace _ = True

-- | The first strategy described in the module header.  For every
-- memory merge parameter that qualifies (see @check@), it does the
-- following:
--
-- * allocate an extra "input" block before the loop;
--
-- * hoist the in-loop allocation of the result memory out of the loop;
--
-- * copy the initial array into the input block;
--
-- * swap the two blocks in the loop result.
--
-- A merge parameter qualifies when all of these hold:
--
-- * it is a memory block in a space accepted by 'doubleBufferSpace',
--   and its initial value is a variable;
--
-- * exactly one array merge parameter lives in that block, its initial
--   value is a variable, and its loop result is a variable;
--
-- * the LMAD of that result array can be found in the loop body and
--   does not mention names bound inside the loop;
--
-- * the memory result is a variable whose allocation
--   'extractAllocOf' can pull out of the body.
--
-- Non-qualifying merge parameters are passed through unchanged.
optimiseLoop :: (Constraints rep inner) => OptimiseLoop rep
optimiseLoop (Pat pes) merge body@(Body _ body_stms body_res) = do
  ((pat', merge', body'), outer_stms) <- runBuilder $ do
    -- Thread the parameter rewrites and the (shrinking) body statements
    -- through every (pattern element, merge parameter, result) triple.
    ((param_changes, body_stms'), (pes', merge', body_res')) <-
      second unzip3 <$> mapAccumLM check (id, body_stms) (zip3 pes merge body_res)
    pure
      ( Pat $ mconcat pes',
        map param_changes $ mconcat merge',
        Body () body_stms' $ mconcat body_res'
      )
  pure (outer_stms, pat', merge', body')
  where
    -- Names that vary across iterations: merge parameters and everything
    -- bound in the body.
    bound_in_loop =
      namesFromList (map (paramName . fst) merge) <> boundInBody body

    -- The LMAD with which the body binds array @v@, provided it does
    -- not depend on loop-variant names.
    findLmadOfArray v = listToMaybe . mapMaybe onStm $ stmsToList body_stms
      where
        onStm = listToMaybe . mapMaybe onPatElem . patElems . stmPat
        onPatElem (PatElem pe_v (MemArray _ _ _ (ArrayIn _ lmad)))
          | v == pe_v,
            not $ bound_in_loop `namesIntersect` freeIn lmad =
              Just lmad
        onPatElem _ = Nothing

    -- Replace the merge entry for parameter @p_needle@ with @new@.
    changeParam p_needle new (p, p_initial) =
      if p == p_needle then new else (p, p_initial)

    check (param_changes, body_stms') (pe, (param, arg), res)
      | Mem space <- paramType param,
        doubleBufferSpace space,
        Var arg_v <- arg,
        -- XXX: what happens if there are multiple arrays in the same
        -- memory block?
        [((arr_param, Var arr_param_initial), Var arr_v)] <-
          filter
            (isArrayIn (paramName param) . fst . fst)
            (zip merge $ map resSubExp body_res),
        MemArray pt shape _ (ArrayIn _ param_lmad) <- paramDec arr_param,
        Var arr_mem_out <- resSubExp res,
        Just arr_lmad <- findLmadOfArray arr_v,
        Just (arr_mem_out_alloc, body_stms'') <-
          extractAllocOf bound_in_loop arr_mem_out body_stms' = do
          -- Put the allocations outside the loop.  The input block is
          -- sized to hold (1 + LMAD.range) elements of the array's
          -- primitive type.
          num_bytes <-
            letSubExp "num_bytes" =<< toExp (primByteSize pt * (1 + LMAD.range arr_lmad))
          arr_mem_in <-
            letExp (baseString arg_v <> "_in") $ Op $ Alloc num_bytes space
          addStm arr_mem_out_alloc

          -- Construct additional pattern element and parameter for
          -- the memory block that is not used afterwards.
          pe_unused <-
            PatElem
              <$> newVName (baseString (patElemName pe) <> "_unused")
              <*> pure (MemMem space)
          param_out <-
            newParam (baseString (paramName param) <> "_out") (MemMem space)

          -- Copy the initial array value to the input memory, with
          -- the same index function as the result.
          arr_v_copy <- newVName $ baseString arr_v <> "_db_copy"
          let arr_initial_info =
                MemArray pt shape NoUniqueness $ ArrayIn arr_mem_in arr_lmad
              arr_initial_pe =
                PatElem arr_v_copy arr_initial_info
          addStm . Let (Pat [arr_initial_pe]) (defAux ()) . BasicOp $
            Replicate mempty (Var arr_param_initial)
          -- As a trick we must make the array parameter Unique to
          -- avoid unfortunate hoisting (see #1533) because we are
          -- invalidating the underlying memory.
          let arr_param' =
                Param mempty (paramName arr_param) $
                  MemArray pt shape Unique (ArrayIn (paramName param) param_lmad)

          -- We must also update the initial values of the parameters
          -- used in the index function of this array parameter, such
          -- that they match the result.
          let mkUpdate lmad_v =
                case L.find ((== lmad_v) . paramName . fst . fst) $ zip merge body_res of
                  Nothing -> id
                  Just ((p, _), p_res) -> changeParam p (p, resSubExp p_res)
              updateLmadParam =
                foldl (.) id $ map mkUpdate $ namesToList $ freeIn param_lmad

          pure
            ( ( updateLmadParam
                  . changeParam arr_param (arr_param', Var arr_v_copy)
                  . param_changes,
                -- Inside the body, the hoisted result memory is now the
                -- new @_out@ parameter.
                substituteNames
                  (M.singleton arr_mem_out (paramName param_out))
                  body_stms''
              ),
              ( [pe, pe_unused],
                [(param, Var arr_mem_in), (param_out, Var arr_mem_out)],
                -- Swap the two blocks for the next iteration.
                [ res {resSubExp = Var $ paramName param_out},
                  subExpRes $ Var $ paramName param
                ]
              )
            )
      | otherwise =
          pure
            ( (param_changes, body_stms'),
              ([pe], [(param, arg)], [res])
            )

-- | The double buffering pass definition.  The traversal starts with a
-- loop strategy that returns every loop unchanged.  Only the op handler
-- (e.g. 'optimiseGPUOp') switches to 'optimiseLoop' for loops nested
-- inside the ops it handles.
doubleBuffer :: (Mem rep inner) => String -> String -> OptimiseOp rep -> Pass rep rep
doubleBuffer name desc onOp =
  Pass
    { passName = name,
      passDescription = desc,
      passFunction = intraproceduralTransformation optimise
    }
  where
    optimise scope stms = modifyNameSource $ \src ->
      let m =
            runDoubleBufferM $ localScope scope $ optimiseStms $ stmsToList stms
       in runState (runReaderT m env) src

    env = Env mempty doNotTouchLoop onOp
    doNotTouchLoop pat merge body = pure (mempty, pat, merge, body)

-- | The pass for GPU kernels.
doubleBufferGPU :: Pass GPUMem GPUMem
doubleBufferGPU =
  doubleBuffer
    "Double buffer GPU"
    "Double buffer memory in sequential loops (GPU rep)."
    optimiseGPUOp

-- | The pass for multicore
doubleBufferMC :: Pass MCMem MCMem
doubleBufferMC =
  doubleBuffer
    "Double buffer MC"
    "Double buffer memory in sequential loops (MC rep)."
    optimiseMCOp
futhark-0.25.27/src/Futhark/Optimise/EntryPointMem.hs000066400000000000000000000047571475065116200224510ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-}

-- | We require that entry points return arrays with zero offset in
-- row-major order.  "Futhark.Pass.ExplicitAllocations" is
-- conservative and inserts copies to ensure this is the case.  After
-- simplification, it may turn out that those copies are redundant.
-- This pass removes them.  It's a pretty simple pass, as it only has
-- to look at the top level of entry points.
module Futhark.Optimise.EntryPointMem
  ( entryPointMemGPU,
    entryPointMemMC,
    entryPointMemSeq,
  )
where

import Data.List (find)
import Data.Map.Strict qualified as M
import Futhark.IR.GPUMem (GPUMem)
import Futhark.IR.MCMem (MCMem)
import Futhark.IR.Mem
import Futhark.IR.SeqMem (SeqMem)
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations.GPU ()
import Futhark.Transform.Substitute

-- | Maps every name bound by a statement to that statement.
type Table rep = M.Map VName (Stm rep)

-- | Build a 'Table' from the names in each statement's pattern.
mkTable :: Stms rep -> Table rep
mkTable = foldMap f
  where
    f stm = M.fromList $ map (,stm) (patNames (stmPat stm))

-- | The memory information of @v@ and the expression that binds it, if
-- @v@ is bound by a statement in the table.
varInfo :: (Mem rep inner) => VName -> Table rep -> Maybe (LetDecMem, Exp rep)
varInfo v table = do
  Let pat _ e <- M.lookup v table
  PatElem _ info <- find ((== v) . patElemName) (patElems pat)
  Just (letDecMem info, e)

-- | Rewrite the result of a function body.  Consider a result @v0@
-- bound by @Manifest _ v1@, where @v0@ and @v1@ are both arrays in
-- memory and have identical LMADs.  For each such @v0@, @v1@ is returned
-- instead of @v0@, and @v1@'s memory block instead of @v0@'s.
-- Only the result list is substituted; the statements are left as they
-- are.  NOTE(review): the now-unused copy is presumably removed by a
-- later dead-code pass; that is not shown here.
optimiseFun :: (Mem rep inner) => Table rep -> FunDef rep -> FunDef rep
optimiseFun consts_table fd =
  fd {funDefBody = onBody $ funDefBody fd}
  where
    -- Top-level constants are visible too, so both tables are searched.
    table = consts_table <> mkTable (bodyStms (funDefBody fd))
    mkSubst (Var v0)
      | Just (MemArray _ _ _ (ArrayIn mem0 lmad0), BasicOp (Manifest _ v1)) <-
          varInfo v0 table,
        Just (MemArray _ _ _ (ArrayIn mem1 lmad1), _) <-
          varInfo v1 table,
        lmad0 == lmad1 =
          M.fromList [(mem0, mem1), (v0, v1)]
    mkSubst _ = mempty
    onBody (Body dec stms res) =
      let substs = mconcat $ map (mkSubst . resSubExp) res
       in Body dec stms $ substituteNames substs res

-- | The pass, generic in the memory representation.  Constants are
-- returned unchanged and only serve as extra lookup context for each
-- function.
entryPointMem :: (Mem rep inner) => Pass rep rep
entryPointMem =
  Pass
    { passName = "Entry point memory optimisation",
      passDescription = "Remove redundant copies of entry point results.",
      passFunction = intraproceduralTransformationWithConsts pure onFun
    }
  where
    onFun consts fd = pure $ optimiseFun (mkTable consts) fd

-- | The pass for GPU representation.
entryPointMemGPU :: Pass GPUMem GPUMem
entryPointMemGPU = entryPointMem

-- | The pass for MC representation.
entryPointMemMC :: Pass MCMem MCMem
entryPointMemMC = entryPointMem

-- | The pass for Seq representation.
entryPointMemSeq :: Pass SeqMem SeqMem
entryPointMemSeq = entryPointMem
futhark-0.25.27/src/Futhark/Optimise/Fusion.hs000066400000000000000000000576021475065116200211370ustar00rootroot00000000000000{-# LANGUAGE Strict #-}

-- | Perform horizontal and vertical fusion of SOACs.  See the paper
-- /A T2 Graph-Reduction Approach To Fusion/ for the basic idea (some
-- extensions discussed in /Design and GPGPU Performance of Futhark’s
-- Redomap Construct/).
module Futhark.Optimise.Fusion (fuseSOACs) where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.Graph.Inductive.Graph qualified as G
import Data.Graph.Inductive.Query.DFS qualified as Q
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.HORep.SOAC qualified as H
import Futhark.Construct
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS hiding (SOAC (..))
import Futhark.IR.SOACS qualified as Futhark
import Futhark.IR.SOACS.Simplify (simplifyLambda)
import Futhark.Optimise.Fusion.GraphRep
import Futhark.Optimise.Fusion.RulesWithAccs qualified as SF
import Futhark.Optimise.Fusion.TryFusion qualified as TF
import Futhark.Pass
import Futhark.Transform.Rename
import Futhark.Transform.Substitute

-- | Mutable state of the fusion pass.
data FusionEnv = FusionEnv
  { vNameSource :: VNameSource,
    -- | Number of successful fusions so far.  It is incremented by
    -- 'fusedSomething' and compared before and after a step to detect
    -- whether that step made progress.
    fusionCount :: Int,
    -- | Whether a scan-producing Screma may act as a fusion producer
    -- (consulted by 'okToFuseProducer').
    fuseScans :: Bool
  }

-- | Initial state: blank name source, no fusions yet, scans allowed.
freshFusionEnv :: FusionEnv
freshFusionEnv =
  FusionEnv
    { vNameSource = blankNameSource,
      fusionCount = 0,
      fuseScans = True
    }

-- | The fusion monad: a read-only scope over mutable 'FusionEnv' state.
newtype FusionM a = FusionM (ReaderT (Scope SOACS) (State FusionEnv) a)
  deriving
    ( Monad,
      Applicative,
      Functor,
      MonadState FusionEnv,
      HasScope SOACS,
      LocalScope SOACS
    )

-- | Fresh names come from the 'vNameSource' field of the state.
instance MonadFreshNames FusionM where
  getNameSource = gets vNameSource
  putNameSource source =
    modify (\env -> env {vNameSource = source})

-- | Run a 'FusionM' computation with the given scope and initial state.
-- The name source of the surrounding monad is threaded in and out.
runFusionM :: (MonadFreshNames m) => Scope SOACS -> FusionEnv -> FusionM a -> m a
runFusionM scope fenv (FusionM a) = modifyNameSource $ \src ->
  let x = runReaderT a scope
      (y, z) = runState x (fenv {vNameSource = src})
   in (y, vNameSource z)

-- | Run an action with 'fuseScans' enabled, restoring the previous
-- setting afterwards.
doFuseScans :: FusionM a -> FusionM a
doFuseScans m = do
  fs <- gets fuseScans
  modify (\s -> s {fuseScans = True})
  r <- m
  modify (\s -> s {fuseScans = fs})
  pure r

-- | Run an action with 'fuseScans' disabled, restoring the previous
-- setting afterwards.
dontFuseScans :: FusionM a -> FusionM a
dontFuseScans m = do
  fs <- gets fuseScans
  modify (\s -> s {fuseScans = False})
  r <- m
  modify (\s -> s {fuseScans = fs})
  pure r

-- | The inputs for which 'H.isVarInput' gives 'Nothing'.  These are
-- presumably inputs that carry transformations rather than being plain
-- variables; confirm against "Futhark.Analysis.HORep.SOAC".
isNotVarInput :: [H.Input] -> [H.Input]
isNotVarInput = filter (isNothing . H.isVarInput)

-- | Turn a graph node back into statements.
--
-- * A SOAC node is bound to fresh names, and its output transforms are
--   then applied to produce the original output names.
--
-- * Loop and match nodes emit any delayed nodes attached to them first,
--   then the statement itself.
finalizeNode :: (HasScope SOACS m, MonadFreshNames m) => NodeT -> m (Stms SOACS)
finalizeNode nt = case nt of
  StmNode stm -> pure $ oneStm stm
  SoacNode ots outputs soac aux -> runBuilder_ $ do
    untransformed_outputs <- mapM newName $ patNames outputs
    auxing aux $ letBindNames untransformed_outputs . Op =<< H.toSOAC soac
    forM_ (zip (patNames outputs) untransformed_outputs) $ \(output, v) ->
      letBindNames [output] . BasicOp . SubExp . Var =<< H.applyTransforms ots v
  ResNode _ -> pure mempty
  TransNode output tr ia -> do
    (cs, e) <- H.transformToExp tr ia
    runBuilder_ $ certifying cs $ letBindNames [output] e
  FreeNode _ -> pure mempty
  DoNode stm lst -> do
    lst' <- mapM (finalizeNode . fst) lst
    pure $ mconcat lst' <> oneStm stm
  MatchNode stm lst -> do
    lst' <- mapM (finalizeNode . fst) lst
    pure $ mconcat lst' <> oneStm stm

-- | Emit the statements of all nodes in reversed topological order.
-- NOTE(review): this presumes graph edges point from users to the nodes
-- they depend on; confirm in "Futhark.Optimise.Fusion.GraphRep".
linearizeGraph :: (HasScope SOACS m, MonadFreshNames m) => DepGraph -> m (Stms SOACS)
linearizeGraph dg =
  fmap mconcat $ mapM finalizeNode $ reverse $ Q.topsort' (dgGraph dg)

-- | Record a successful fusion (bumping 'fusionCount') and return the
-- fused node.
fusedSomething :: NodeT -> FusionM (Maybe NodeT)
fusedSomething x = do
  modify $ \s -> s {fusionCount = 1 + fusionCount s}
  pure $ Just x

-- | Try to fuse @node_1@ vertically with its graph neighbour @node_2@.
-- @node_1@ is passed to 'vFuseNodeT' in the producer position.  If any
-- edge between the two is a consumption edge, arrays consumed by the
-- fused lambda are copied, except those that were already consumed
-- before.
vTryFuseNodesInGraph :: G.Node -> G.Node -> DepGraphAug FusionM
-- find the neighbors -> verify that fusion causes no cycles -> fuse
vTryFuseNodesInGraph node_1 node_2 dg@DepGraph {dgGraph = g}
  | not (G.gelem node_1 g && G.gelem node_2 g) = pure dg
  | vFusionFeasability dg node_1 node_2 = do
      let (ctx1, ctx2) = (G.context g node_1, G.context g node_2)
      fres <- vFuseContexts edgs infusable_nodes ctx1 ctx2
      case fres of
        Just (inputs, _, nodeT, outputs) -> do
          nodeT' <-
            if null fusedC
              then pure nodeT
              else do
                let (_, _, _, deps_1) = ctx1
                    (_, _, _, deps_2) = ctx2
                    -- make copies of everything that was not
                    -- previously consumed
                    old_cons = map (getName . fst) $ filter (isCons . fst) (deps_1 <> deps_2)
                makeCopiesOfFusedExcept old_cons nodeT
          contractEdge node_2 (inputs, node_1, nodeT', outputs) dg
        Nothing -> pure dg
  | otherwise = pure dg
  where
    edgs = map G.edgeLabel $ edgesBetween dg node_1 node_2
    -- Names consumed along the edges between the two nodes.
    fusedC = map getName $ filter isCons edgs
    -- Names on edges from node_1 to its predecessors other than node_2.
    infusable_nodes =
      map
        depsFromEdge
        (concatMap (edgesBetween dg node_1) (filter (/= node_2) $ G.pre g node_1))

-- | Try to fuse two nodes horizontally, if 'hFusionFeasability' allows
-- it.
hTryFuseNodesInGraph :: G.Node -> G.Node -> DepGraphAug FusionM
hTryFuseNodesInGraph node_1 node_2 dg@DepGraph {dgGraph = g}
  | not (G.gelem node_1 g && G.gelem node_2 g) = pure dg
  | hFusionFeasability dg node_1 node_2 = do
      fres <- hFuseContexts (G.context g node_1) (G.context g node_2)
      case fres of
        Just ctx -> contractEdge node_2 ctx dg
        Nothing -> pure dg
  | otherwise = pure dg

-- | Horizontally fuse the node labels of two contexts and merge the
-- contexts on success.
hFuseContexts :: DepContext -> DepContext -> FusionM (Maybe DepContext)
hFuseContexts c1 c2 = do
  let (_, _, nodeT1, _) = c1
      (_, _, nodeT2, _) = c2
  fres <- hFuseNodeT nodeT1 nodeT2
  case fres of
    Just nodeT -> pure $ Just (mergedContext nodeT c1 c2)
    Nothing -> pure Nothing

-- | Vertically fuse the node labels of two contexts, first being the
-- producer.  Edges between the two nodes themselves are filtered out of
-- the edge lists handed to 'vFuseNodeT'.
vFuseContexts :: [EdgeT] -> [VName] -> DepContext -> DepContext -> FusionM (Maybe DepContext)
vFuseContexts edgs infusable c1 c2 = do
  let (i1, n1, nodeT1, o1) = c1
      (_i2, n2, nodeT2, o2) = c2
  fres <-
    vFuseNodeT
      edgs
      infusable
      (nodeT1, map fst $ filter ((/=) n2 . snd) i1, map fst o1)
      (nodeT2, map fst $ filter ((/=) n1 . snd) o2)
  case fres of
    Just nodeT -> pure $ Just (mergedContext nodeT c1 c2)
    Nothing -> pure Nothing

-- | For a SOAC node, insert copies of the non-accumulator arrays
-- consumed by its lambda, except those listed in @noCopy@.  Other nodes
-- are returned unchanged.
makeCopiesOfFusedExcept ::
  (LocalScope SOACS m, MonadFreshNames m) =>
  [VName] ->
  NodeT ->
  m NodeT
makeCopiesOfFusedExcept noCopy (SoacNode ots pats soac aux) = do
  let lam = H.lambda soac
  localScope (scopeOf lam) $ do
    fused_inner <-
      filterM (fmap (not . isAcc) . lookupType) . namesToList . consumedByLambda $
        Alias.analyseLambda mempty lam
    lam' <- makeCopiesInLambda (fused_inner L.\\ noCopy) lam
    pure $ SoacNode ots pats (H.setLambda lam' soac) aux
makeCopiesOfFusedExcept _ nodeT = pure nodeT

-- | Copy each of the given names at the start of the lambda body, and
-- make the rest of the body refer to the copies.
makeCopiesInLambda ::
  (LocalScope SOACS m, MonadFreshNames m) =>
  [VName] ->
  Lambda SOACS ->
  m (Lambda SOACS)
makeCopiesInLambda toCopy lam = do
  (copies, nameMap) <- makeCopyStms toCopy
  let l_body = lambdaBody lam
      newBody = insertStms copies (substituteNames nameMap l_body)
      newLambda = lam {lambdaBody = newBody}
  pure newLambda

-- | For each name, a copy statement (@Replicate@ with an empty shape)
-- binding a fresh @_copy@ name, plus the old-to-new name mapping.
makeCopyStms ::
  (LocalScope SOACS m, MonadFreshNames m) =>
  [VName] ->
  m (Stms SOACS, M.Map VName VName)
makeCopyStms vs = do
  vs' <- mapM makeNewName vs
  copies <- forM (zip vs vs') $ \(name, name') ->
    mkLetNames [name'] $ BasicOp $ Replicate mempty $ Var name
  pure (stmsFromList copies, M.fromList $ zip vs vs')
  where
    makeNewName name = newVName $ baseString name <> "_copy"

-- | A Screma containing scans may be fused as a producer only while
-- 'fuseScans' is set.  Every other SOAC may always be fused as a
-- producer.
okToFuseProducer :: H.SOAC SOACS -> FusionM Bool
okToFuseProducer (H.Screma _ _ form) = do
  let is_scan = isJust $ Futhark.isScanomapSOAC form
  gets $ (not is_scan ||) . fuseScans
okToFuseProducer _ = pure True

-- First node is producer, second is consumer.
-- | Vertical fusion of two node labels.  The first argument is the
-- producer (with its incoming and outgoing edges); the second is the
-- consumer.  The @[VName]@ argument lists names that must not be fused
-- away.  The cases handled are:
--
-- * anything real into a match node, which is delayed inside it;
--
-- * a transformation node into a consuming SOAC, whose input then
--   absorbs the transformation;
--
-- * SOAC into SOAC, via 'TF.attemptFusion';
--
-- * a map whose only output is sliced on its outer dimension;
--
-- * a Screma into a WithAcc, and a WithAcc into a Screma;
--
-- * two WithAccs.
--
-- Returns 'Nothing' when no case applies or fusion fails.
vFuseNodeT ::
  [EdgeT] ->
  [VName] ->
  (NodeT, [EdgeT], [EdgeT]) ->
  (NodeT, [EdgeT]) ->
  FusionM (Maybe NodeT)
vFuseNodeT _ infusible (s1, _, e1s) (MatchNode stm2 dfused, _)
  | isRealNode s1,
    null infusible =
      pure $ Just $ MatchNode stm2 $ (s1, e1s) : dfused
vFuseNodeT _ infusible (TransNode stm1_out tr stm1_in, _, _) (SoacNode ots2 pats2 soac2 aux2, _)
  | null infusible = do
      stm1_in_t <- lookupType stm1_in
      -- Replace uses of the transformed array by the untransformed one
      -- with the transformation prepended to the input.
      let onInput inp
            | H.inputArray inp == stm1_out =
                H.Input (tr H.<| H.inputTransforms inp) stm1_in stm1_in_t
            | otherwise =
                inp
          soac2' = map onInput (H.inputs soac2) `H.setInputs` soac2
      pure $ Just $ SoacNode ots2 pats2 soac2' aux2
vFuseNodeT
  _
  _
  (SoacNode ots1 pats1 soac1 aux1, i1s, _e1s)
  (SoacNode ots2 pats2 soac2 aux2, _e2s) = do
    let ker =
          TF.FusedSOAC
            { TF.fsSOAC = soac2,
              TF.fsOutputTransform = ots2,
              TF.fsOutNames = patNames pats2
            }
        -- Producer outputs that other nodes still depend on must survive.
        preserveEdge InfDep {} = True
        preserveEdge e = isDep e
        preserve = namesFromList $ map getName $ filter preserveEdge i1s
    ok <- okToFuseProducer soac1
    r <-
      if ok && ots1 == mempty
        then TF.attemptFusion TF.Vertical preserve (patNames pats1) soac1 ker
        else pure Nothing
    case r of
      Just ker' -> do
        let pats2' =
              zipWith PatElem (TF.fsOutNames ker') (H.typeOf (TF.fsSOAC ker'))
        fusedSomething $
          SoacNode
            (TF.fsOutputTransform ker')
            (Pat pats2')
            (TF.fsSOAC ker')
            (aux1 <> aux2)
      Nothing -> pure Nothing
-- A single-output map whose result is sliced along its outermost
-- dimension with a width different from the map's: the slice is pushed
-- into the map's inputs and the map width becomes the slice width.
vFuseNodeT
  _
  infusible
  (SoacNode ots1 pat1 (H.Screma w inps form) aux1, _, _)
  (TransNode stm2_out (H.Index cs slice@(Slice (ds@(DimSlice _ w' _) : ds_rest))) _, _)
    | null infusible,
      w /= w',
      ots1 == mempty,
      Just _ <- isMapSOAC form,
      [pe] <- patElems pat1 = do
        let out_t = patElemType pe `setArrayShape` sliceShape slice
            inps' = map sliceInput inps
            -- Even if we move the slice of the outermost dimension, there
            -- might still be some slicing of the inner ones.
            ots1' = ots1 H.|> H.Index cs (Slice (sliceDim w' : ds_rest))
        fusedSomething $
          SoacNode
            ots1'
            (Pat [PatElem stm2_out out_t])
            (H.Screma w' inps' form)
            aux1
    where
      sliceInput inp =
        H.addTransform
          (H.Index cs (fullSlice (H.inputType inp) [ds]))
          inp

-- Case of fusing a screma with a WithAcc such as to (hopefully) perform
-- more fusion within the WithAcc.  This would allow the withAcc to move in
-- the code (since up to now they mostly remain where they were introduced.)
-- We conservatively allow the fusion to fire---i.e., to move the soac inside
-- the withAcc---when the following are not part of withAcc's accumulators:
--   1. the in-dependencies of the soac and
--   2. the result of the soac
-- Note that the soac result is allowed to be part of the `infusible`
-- for as long as it is returned by the withAcc.  If `infusible` is empty
-- then the extraneous result will be simplified away.
vFuseNodeT
  _edges
  _infusible
  (SoacNode ots1 pat1 soac@(H.Screma _w _form _s_inps) aux1, is1, _os1)
  (StmNode (Let pat2 aux2 (WithAcc w_inps lam0)), _os2)
    | ots1 == mempty,
      wacc_cons_nms <- namesFromList $ concatMap (\(_, nms, _) -> nms) w_inps,
      soac_prod_nms <- map patElemName $ patElems pat1,
      soac_indep_nms <- map getName is1,
      all (`notNameIn` wacc_cons_nms) (soac_indep_nms ++ soac_prod_nms) = do
        lam <- fst <$> doFusionInLambda lam0
        -- New WithAcc body: the SOAC first, then the original lambda body;
        -- the SOAC's outputs are appended to the lambda's results.
        bdy' <- runBodyBuilder $ inScopeOf lam $ do
          soac' <- H.toExp soac
          addStm $ Let pat1 aux1 soac'
          lam_res <- bodyBind $ lambdaBody lam
          let pat1_res = map (SubExpRes (Certs []) . Var) soac_prod_nms
          pure $ lam_res ++ pat1_res
        let lam_ret_tp = lambdaReturnType lam ++ map patElemType (patElems pat1)
            pat = Pat $ patElems pat2 ++ patElems pat1
        lam' <- renameLambda $ lam {lambdaBody = bdy', lambdaReturnType = lam_ret_tp}
        -- see if bringing the map inside the scatter has actually benefitted fusion
        (lam'', success) <- doFusionInLambda lam'
        if not success
          then pure Nothing
          else do
            -- `aux1` already appears in the moved SOAC stm; is there
            -- any need to add it to the enclosing withAcc stm as well?
            fusedSomething $ StmNode $ Let pat aux2 $ WithAcc w_inps lam''
--
-- The reverse of the case above, i.e., fusing a screma at the back of a
-- WithAcc such as to (hopefully) enable more fusion there.
-- This should be safe as long as the SOAC does not use any of the
-- accumulator arrays produced by the withAcc.
-- We could not provide a test for this case, due to the very restrictive
-- way in which accumulators can be used at source level.
--
vFuseNodeT
  edges
  _infusible
  (StmNode (Let pat1 aux1 (WithAcc w_inps wlam0)), _is1, _os1)
  (SoacNode ots2 pat2 soac@(H.Screma _w _form _s_inps) aux2, _os2)
    | ots2 == mempty,
      -- The first n pattern elements are taken to be the accumulator
      -- results; n is half the number of lambda parameters.
      n <- length (lambdaParams wlam0) `div` 2,
      pat1_acc_nms <- namesFromList $ take n $ map patElemName $ patElems pat1,
      -- not $ namesIntersect (freeIn soac) pat1_acc_nms
      all ((`notNameIn` pat1_acc_nms) . getName) edges = do
        let empty_aux = StmAux mempty mempty mempty
        wlam <- fst <$> doFusionInLambda wlam0
        bdy' <- runBodyBuilder $ inScopeOf wlam $ do
          -- adding stms of withacc's lambda
          wlam_res <- bodyBind $ lambdaBody wlam
          -- add copies of the non-accumulator results of withacc
          let other_pr1 = drop n $ zip (patElems pat1) wlam_res
          forM_ other_pr1 $ \(pat_elm, bdy_res) -> do
            let (nm, se, tp) = (patElemName pat_elm, resSubExp bdy_res, patElemType pat_elm)
                aux = empty_aux {stmAuxCerts = resCerts bdy_res}
            addStm $ Let (Pat [PatElem nm tp]) aux $ BasicOp $ SubExp se
          -- add the soac stmt
          soac' <- H.toExp soac
          addStm $ Let pat2 aux2 soac'
          -- build the body result
          let pat2_res = map (SubExpRes (Certs []) . Var . patElemName) $ patElems pat2
          pure $ wlam_res ++ pat2_res
        let lam_ret_tp = lambdaReturnType wlam ++ map patElemType (patElems pat2)
            pat = Pat $ patElems pat1 ++ patElems pat2
        wlam' <- renameLambda $ wlam {lambdaBody = bdy', lambdaReturnType = lam_ret_tp}
        -- see if bringing the map inside the scatter has actually benefitted fusion
        (wlam'', success) <- doFusionInLambda wlam'
        if not success
          then pure Nothing
          else
            -- `aux2` already appears in the enclosed SOAC stm; is there
            -- any need to add it to the enclosing withAcc stm as well?
            fusedSomething $ StmNode $ Let pat aux1 $ WithAcc w_inps wlam''
-- the case of fusing two withaccs
vFuseNodeT
  _edges
  infusible
  (StmNode (Let pat1 aux1 (WithAcc w1_inps lam1)), is1, _os1)
  (StmNode (Let pat2 aux2 (WithAcc w2_inps lam2)), _os2)
    | wacc2_cons_nms <- namesFromList $ concatMap (\(_, nms, _) -> nms) w2_inps,
      wacc1_indep_nms <- map getName is1,
      all (`notNameIn` wacc2_cons_nms) wacc1_indep_nms = do
        -- \^ the other safety checks are done inside `tryFuseWithAccs`
        lam1' <- fst <$> doFusionInLambda lam1
        lam2' <- fst <$> doFusionInLambda lam2
        let stm1 = Let pat1 aux1 (WithAcc w1_inps lam1')
            stm2 = Let pat2 aux2 (WithAcc w2_inps lam2')
        mstm <- SF.tryFuseWithAccs infusible stm1 stm2
        case mstm of
          Just (Let pat aux (WithAcc w_inps wlam)) -> do
            (wlam', success) <- doFusionInLambda wlam
            let new_stm = Let pat aux (WithAcc w_inps wlam')
            if success then fusedSomething (StmNode new_stm) else pure Nothing
          _ -> error "Illegal result of tryFuseWithAccs called from vFuseNodeT."
--
vFuseNodeT _ _ _ _ = pure Nothing

-- | The result of a lambda's body.
resFromLambda :: Lambda rep -> Result
resFromLambda = bodyResult . lambdaBody

-- | Precondition for horizontal fusion, checked on the non-variable
-- inputs of both SOACs.
--
-- NOTE(review): @vs2@ is taken from @is2@ after removing one occurrence
-- of each element of @is1@.  Since @vs1@ is a subset of @is1@, the
-- intersection is empty, and the result 'True', unless @is2@ contains
-- an @is1@ element more often than @is1@ does.  Confirm this is the
-- intended test.
hasNoDifferingInputs :: [H.Input] -> [H.Input] -> Bool
hasNoDifferingInputs is1 is2 =
  let (vs1, vs2) = (isNotVarInput is1, isNotVarInput $ is2 L.\\ is1)
   in null $ vs1 `L.intersect` vs2

-- | Horizontal fusion of two SOAC nodes without output transforms,
-- keeping all outputs of the first SOAC ('TF.attemptFusion' with
-- 'TF.Horizontal').  Other node combinations are not fused.
hFuseNodeT :: NodeT -> NodeT -> FusionM (Maybe NodeT)
hFuseNodeT (SoacNode ots1 pats1 soac1 aux1) (SoacNode ots2 pats2 soac2 aux2)
  | ots1 == mempty,
    ots2 == mempty,
    hasNoDifferingInputs (H.inputs soac1) (H.inputs soac2) = do
      let ker =
            TF.FusedSOAC
              { TF.fsSOAC = soac2,
                TF.fsOutputTransform = mempty,
                TF.fsOutNames = patNames pats2
              }
          preserve = namesFromList $ patNames pats1
      r <- TF.attemptFusion TF.Horizontal preserve (patNames pats1) soac1 ker
      case r of
        Just ker' -> do
          let pats2' =
                zipWith PatElem (TF.fsOutNames ker') (H.typeOf (TF.fsSOAC ker'))
          fusedSomething $ SoacNode mempty (Pat pats2') (TF.fsSOAC ker') (aux1 <> aux2)
        Nothing -> pure Nothing
hFuseNodeT _ _ = pure Nothing

-- | For a Screma node, drop the map outputs not named in @toKeep@.
-- Map outputs are those after the scan and reduce results; the scan and
-- reduce results are always kept.  The lambda's results and return
-- types are trimmed to match.  Other nodes are returned unchanged.
removeOutputsExcept :: [VName] -> NodeT -> NodeT
removeOutputsExcept toKeep s = case s of
  SoacNode ots (Pat pats1) soac@(H.Screma _ _ (ScremaForm lam_1 scans_1 red_1)) aux1 ->
    SoacNode ots (Pat $ pats_unchanged <> pats_new) (H.setLambda lam_new soac) aux1
    where
      scan_output_size = Futhark.scanResults scans_1
      red_output_size = Futhark.redResults red_1

      (pats_unchanged, pats_toChange) =
        splitAt (scan_output_size + red_output_size) pats1
      (res_unchanged, res_toChange) =
        splitAt
          (scan_output_size + red_output_size)
          (zip (resFromLambda lam_1) (lambdaReturnType lam_1))

      (pats_new, other) =
        unzip $
          filter
            (\(x, _) -> patElemName x `elem` toKeep)
            (zip pats_toChange res_toChange)
      (results, types) = unzip (res_unchanged ++ other)
      lam_new =
        lam_1
          { lambdaReturnType = types,
            lambdaBody = (lambdaBody lam_1) {bodyResult = results}
          }
  node -> node

-- | The name carried by the edge from @n2@ to @n1@.
vNameFromAdj :: G.Node -> (EdgeT, G.Node) -> VName
vNameFromAdj n1 (edge, n2) = depsFromEdge (n2, n1, edge)

-- | Keep only the outputs of a node that its incoming edges refer to.
removeUnusedOutputsFromContext :: DepContext -> FusionM DepContext
removeUnusedOutputsFromContext (incoming, n1, nodeT, outgoing) =
  pure (incoming, n1, nodeT', outgoing)
  where
    toKeep = map (vNameFromAdj n1) incoming
    nodeT' = removeOutputsExcept toKeep nodeT

-- | Apply 'removeUnusedOutputsFromContext' to every node of the graph.
removeUnusedOutputs :: DepGraphAug FusionM
removeUnusedOutputs = mapAcross removeUnusedOutputsFromContext

-- | Vertically fuse a node with its graph predecessors.  First the
-- specialised rule 'SF.ruleMFScat' is tried.  Otherwise
-- 'vTryFuseNodesInGraph' is attempted with every predecessor reached
-- by a dependency edge, or by an 'InfDep' edge from a WithAcc node.
tryFuseNodeInGraph :: DepNode -> DepGraphAug FusionM
tryFuseNodeInGraph node_to_fuse dg@DepGraph {dgGraph = g}
  | not (G.gelem (nodeFromLNode node_to_fuse) g) = pure dg
-- \^ Node might have been fused away since.
tryFuseNodeInGraph node_to_fuse dg@DepGraph {dgGraph = g} = do
  spec_rule_res <- SF.ruleMFScat node_to_fuse dg
  -- \^ specialized fusion rules such as the one
  --   enabling map-flatten-scatter fusion
  case spec_rule_res of
    Just dg' -> pure dg'
    Nothing -> applyAugs (map (vTryFuseNodesInGraph node_to_fuse_id) fuses_with) dg
  where
    node_to_fuse_id = nodeFromLNode node_to_fuse
    relevant (n, InfDep _) = isWithAccNodeId n dg
    relevant (_, e) = isDep e
    fuses_with = map fst $ filter relevant $ G.lpre g node_to_fuse_id

-- | Try vertical fusion for every node, in reverse order of
-- 'G.labNodes'.  Result nodes are skipped, and plain statement nodes
-- are skipped unless they are WithAcc nodes.
doVerticalFusion :: DepGraphAug FusionM
doVerticalFusion dg =
  applyAugs
    (map tryFuseNodeInGraph $ reverse $ filter relevant $ G.labNodes (dgGraph dg))
    dg
  where
    relevant (_, n@(StmNode {})) = isWithAccNodeT n
    relevant (_, ResNode {}) = False
    relevant _ = True

-- | For each pair of SOAC nodes that share an input, attempt to fuse
-- them horizontally.
doHorizontalFusion :: DepGraphAug FusionM
doHorizontalFusion dg = applyAugs pairs dg
  where
    pairs :: [DepGraphAug FusionM]
    pairs = do
      (x, SoacNode _ _ soac_x _) <- G.labNodes $ dgGraph dg
      (y, SoacNode _ _ soac_y _) <- G.labNodes $ dgGraph dg
      guard $ x < y
      -- Must share an input.
      guard $
        any
          ((`elem` map H.inputArray (H.inputs soac_x)) . H.inputArray)
          (H.inputs soac_y)
      pure $ \dg' -> do
        -- Nodes might have been fused away by now.
        if G.gelem x (dgGraph dg') && G.gelem y (dgGraph dg')
          then hTryFuseNodesInGraph x y dg'
          else pure dg'

-- | Recursively run fusion inside the bodies and lambdas of every node.
doInnerFusion :: DepGraphAug FusionM
doInnerFusion = mapAcross runInnerFusionOnContext

-- Fixed-point iteration.
-- | Repeat a graph transformation for as long as it increases
-- 'fusionCount'.
keepTrying :: DepGraphAug FusionM -> DepGraphAug FusionM
keepTrying f g = do
  prev_fused <- gets fusionCount
  g' <- f g
  aft_fused <- gets fusionCount
  if prev_fused /= aft_fused then keepTrying f g' else pure g'

-- | Vertical, horizontal and inner fusion iterated to a fixed point,
-- followed by removal of outputs nobody uses.
doAllFusion :: DepGraphAug FusionM
doAllFusion =
  applyAugs
    [ keepTrying . applyAugs $
        [ doVerticalFusion,
          doHorizontalFusion,
          doInnerFusion
        ],
      removeUnusedOutputs
    ]

-- | Run fusion inside the nested code of a node: loop bodies, match
-- branches, VJP/JVP and WithAcc lambdas, and SOAC lambdas.  For loops
-- and matches, the nodes delayed inside them are emitted at the start
-- of their bodies before fusing.  Scans may be fused everywhere except
-- inside 'H.Stream' lambdas.
runInnerFusionOnContext :: DepContext -> FusionM DepContext
runInnerFusionOnContext c@(incoming, node, nodeT, outgoing) = case nodeT of
  DoNode (Let pat aux (Loop params form body)) to_fuse ->
    doFuseScans . localScope (scopeOfFParams (map fst params) <> scopeOfLoopForm form) $ do
      b <- doFusionWithDelayed body to_fuse
      pure (incoming, node, DoNode (Let pat aux (Loop params form b)) [], outgoing)
  MatchNode (Let pat aux (Match cond cases defbody dec)) to_fuse -> doFuseScans $ do
    -- Each branch receives its own copy of the delayed nodes, so
    -- branches are renamed to keep names unique.
    cases' <- mapM (traverse $ renameBody <=< (`doFusionWithDelayed` to_fuse)) cases
    defbody' <- doFusionWithDelayed defbody to_fuse
    pure (incoming, node, MatchNode (Let pat aux (Match cond cases' defbody' dec)) [], outgoing)
  StmNode (Let pat aux (Op (Futhark.VJP args vec lam))) -> doFuseScans $ do
    lam' <- fst <$> doFusionInLambda lam
    pure (incoming, node, StmNode (Let pat aux (Op (Futhark.VJP args vec lam'))), outgoing)
  StmNode (Let pat aux (Op (Futhark.JVP args vec lam))) -> doFuseScans $ do
    lam' <- fst <$> doFusionInLambda lam
    pure (incoming, node, StmNode (Let pat aux (Op (Futhark.JVP args vec lam'))), outgoing)
  StmNode (Let pat aux (WithAcc inputs lam)) -> doFuseScans $ do
    lam' <- fst <$> doFusionInLambda lam
    pure (incoming, node, StmNode (Let pat aux (WithAcc inputs lam')), outgoing)
  SoacNode ots pat soac aux -> do
    let lam = H.lambda soac
    lam' <- inScopeOf lam $ case soac of
      H.Stream {} ->
        dontFuseScans $ fst <$> doFusionInLambda lam
      _ ->
        doFuseScans $ fst <$> doFusionInLambda lam
    let nodeT' = SoacNode ots pat (H.setLambda lam' soac) aux
    pure (incoming, node, nodeT', outgoing)
  _ -> pure c
  where
    -- Prepend the statements of the delayed nodes to the body, then fuse
    -- the resulting statements.
    doFusionWithDelayed :: Body SOACS -> [(NodeT, [EdgeT])] -> FusionM (Body SOACS)
    doFusionWithDelayed (Body () stms res) extraNodes = inScopeOf stms $ do
      stm_node <- mapM (finalizeNode . fst) extraNodes
      stms' <- fuseGraph (mkBody (mconcat stm_node <> stms) res)
      pure $ Body () stms' res

-- | Fuse inside a lambda body.  The lambda is simplified before fusing,
-- and again afterwards if anything was fused.  The returned flag is
-- 'True' iff 'fusionCount' changed, i.e. some fusion took place.
doFusionInLambda :: Lambda SOACS -> FusionM (Lambda SOACS, Bool)
doFusionInLambda lam = do
  -- To clean up previous instances of fusion.
  lam' <- simplifyLambda lam
  prev_count <- gets fusionCount
  newbody <- inScopeOf lam' $ doFusionBody $ lambdaBody lam'
  aft_count <- gets fusionCount
  -- To clean up any inner fusion.
  lam'' <-
    (if prev_count /= aft_count then simplifyLambda else pure)
      lam' {lambdaBody = newbody}
  pure (lam'', prev_count /= aft_count)
  where
    doFusionBody :: Body SOACS -> FusionM (Body SOACS)
    doFusionBody body = do
      stms' <- fuseGraph body
      pure $ body {bodyStms = stms'}

-- main fusion function.
-- | Build the dependency graph of a body, fuse it, and linearise the
-- result back into statements.
fuseGraph :: Body SOACS -> FusionM (Stms SOACS)
fuseGraph body = inScopeOf (bodyStms body) $ do
  graph_not_fused <- mkDepGraph body
  graph_fused <- doAllFusion graph_not_fused
  linearizeGraph graph_fused

-- | Fuse the top-level constants.  The given names are treated as the
-- results of the constant body, so they are kept alive.
fuseConsts :: [VName] -> Stms SOACS -> PassM (Stms SOACS)
fuseConsts outputs stms =
  runFusionM
    (scopeOf stms)
    freshFusionEnv
    (fuseGraph (mkBody stms (varsRes outputs)))

-- | Fuse the body of a function, with the constants in scope.
fuseFun :: Stms SOACS -> FunDef SOACS -> PassM (FunDef SOACS)
fuseFun consts fun = do
  fun_stms' <-
    runFusionM
      (scopeOf fun <> scopeOf consts)
      freshFusionEnv
      (fuseGraph (funDefBody fun))
  pure fun {funDefBody = (funDefBody fun) {bodyStms = fun_stms'}}

-- | The pass definition.
{-# NOINLINE fuseSOACs #-}
fuseSOACs :: Pass SOACS SOACS
fuseSOACs =
  Pass
    { passName = "Fuse SOACs",
      passDescription = "Perform higher-order optimisation, i.e., fusion.",
      -- Constants are fused with 'fuseConsts'.  Every name occurring
      -- free in the program's functions is kept as a result, so
      -- constants the functions need are not treated as dead.  Each
      -- function body is then fused with 'fuseFun'.
      passFunction = \p ->
        intraproceduralTransformationWithConsts
          (fuseConsts (namesToList $ freeIn (progFuns p)))
          fuseFun
          p
    }
futhark-0.25.27/src/Futhark/Optimise/Fusion/000077500000000000000000000000001475065116200205715ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Optimise/Fusion/Composing.hs000066400000000000000000000230121475065116200230610ustar00rootroot00000000000000
-- | Facilities for composing SOAC functions.  Mostly intended for use
-- by the fusion module, but factored into a separate module for ease
-- of testing, debugging and development.  Of course, there is nothing
-- preventing you from using the exported functions wherever you
-- want.
--
-- Important: this module is \"dumb\" in the sense that it does not
-- check the validity of its inputs, and does not have any
-- functionality for massaging SOACs to be fusible.  It is assumed
-- that the given SOACs are immediately compatible.
--
-- The module will, however, remove duplicate inputs after fusion.
module Futhark.Optimise.Fusion.Composing
  ( fuseMaps,
    fuseRedomap,
  )
where

import Data.List (mapAccumL)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.HORep.SOAC qualified as SOAC
import Futhark.Builder (Buildable (..), insertStm, insertStms, mkLet)
import Futhark.Construct (mapResult)
import Futhark.IR
import Futhark.Util (dropLast, splitAt3, takeLast)

-- | @fuseMaps lam1 inp1 out1 lam2 inp2@ fuses the function @lam1@ into
-- @lam2@.  Both functions must be mapping functions, although @lam2@
-- may have leading reduction parameters.  @inp1@ and @inp2@ are the
-- array inputs to the SOACs containing @lam1@ and @lam2@
-- respectively.  @out1@ are the identifiers to which the output of
-- the SOAC containing @lam1@ is bound.  It is nonsensical to call
-- this function unless the intersection of @out1@ and @inp2@ is
-- non-empty.
--
-- If @lam2@ accepts more parameters than there are elements in
-- @inp2@, it is assumed that the surplus (which are positioned at the
-- beginning of the parameter list) are reduction (accumulator)
-- parameters, that do not correspond to array elements, and they are
-- thus not modified.
--
-- The result is the fused function, and a list of the array inputs
-- expected by the SOAC containing the fused function.  The parameters
-- of the fused function and the returned inputs are taken from the same
-- map, in key order, so they correspond positionally.  Only the
-- parameters and the body of @lam2@ are replaced; its return type is
-- left as it was.
fuseMaps ::
  (Buildable rep) =>
  -- | The producer var names that still need to be returned.  Their
  -- values are appended to the end of the fused body's result.
  Names ->
  -- | Function of SOAC to be fused.
  Lambda rep ->
  -- | Input of SOAC to be fused.
  [SOAC.Input] ->
  -- | Output of SOAC to be fused.  The
  -- first identifier is the name of the
  -- actual output, where the second output
  -- is an identifier that can be used to
  -- bind a single element of that output.
  [(VName, Ident)] ->
  -- | Function to be fused with.
  Lambda rep ->
  -- | Input of SOAC to be fused with.
  [SOAC.Input] ->
  -- | The fused lambda and the inputs of
  -- the resulting SOAC.
  (Lambda rep, [SOAC.Input])
fuseMaps unfus_nms lam1 inp1 out1 lam2 inp2 = (lam2', M.elems inputmap)
  where
    lam2' =
      lam2
        { lambdaParams =
            [ Param mempty name t
              | Ident name t <- lam2redparams ++ M.keys inputmap
            ],
          lambdaBody = new_body2'
        }
    -- The body of lam1 runs first.  Its results are bound to the
    -- identifiers in @pat@, one statement per result, keeping each
    -- result's certificates.  Then the body of lam2 runs.
    new_body2 =
      let stms res =
            [ certify cs $ mkLet [p] $ BasicOp $ SubExp e
              | (p, SubExpRes cs e) <- zip pat res
            ]
          bindLambda res =
            stmsFromList (stms res) `insertStms` makeCopiesInner (lambdaBody lam2)
       in makeCopies $ mapResult bindLambda (lambdaBody lam1)
    new_body2_rses = bodyResult new_body2
    new_body2' =
      new_body2 {bodyResult = new_body2_rses ++ map (varRes . identName) unfus_pat}
    -- infusible variables are added at the end of the result/pattern/type
    (lam2redparams, unfus_pat, pat, inputmap, makeCopies, makeCopiesInner) =
      fuseInputs unfus_nms lam1 inp1 out1 lam2 inp2
    -- (unfus_accpat, unfus_arrpat) = splitAt (length unfus_accs) unfus_pat

-- | Helper for 'fuseMaps'.  The result components are, in order:
--
-- 1. the leading (non-array) parameters of @lam2@;
--
-- 2. for each output of @out1@ whose name is in @unfus_nms@: the @lam2@
--    parameter that reads it, or else the output's element identifier;
--
-- 3. for each output of @out1@: a @lam2@ parameter reading it, or else
--    the output's element identifier.  These are the identifiers that
--    the results of @lam1@ are bound to;
--
-- 4. the parameter-to-input map of the fused SOAC, with duplicate
--    inputs removed;
--
-- 5. and 6. body transformers that bind each removed duplicate
--    parameter to the parameter that was kept.  The first is for the
--    final map; the second is for duplicates among @lam2@'s own inputs.
fuseInputs ::
  (Buildable rep) =>
  Names ->
  Lambda rep ->
  [SOAC.Input] ->
  [(VName, Ident)] ->
  Lambda rep ->
  [SOAC.Input] ->
  ( [Ident],
    [Ident],
    [Ident],
    M.Map Ident SOAC.Input,
    Body rep -> Body rep,
    Body rep -> Body rep
  )
fuseInputs unfus_nms lam1 inp1 out1 lam2 inp2 =
  (lam2redparams, unfus_vars, outstms, inputmap, makeCopies, makeCopiesInner)
  where
    -- The parameters of lam2 that have no array input are the leading ones.
    (lam2redparams, lam2arrparams) =
      splitAt (length lam2params - length inp2) lam2params
    lam1params = map paramIdent $ lambdaParams lam1
    lam2params = map paramIdent $ lambdaParams lam2
    lam1inputmap = M.fromList $ zip lam1params inp1
    lam2inputmap = M.fromList $ zip lam2arrparams inp2
    (lam2inputmap', makeCopiesInner) = removeDuplicateInputs lam2inputmap
    originputmap = lam1inputmap `M.union` lam2inputmap'
    -- The lam2 parameters whose input is one of lam1's outputs.
    outins = uncurry (outParams $ map fst out1) $ unzip $ M.toList lam2inputmap'
    outstms = filterOutParams out1 outins
    -- Parameters reading lam1's outputs are no longer inputs of the
    -- fused SOAC; their values come from lam1's body instead.
    (inputmap, makeCopies) =
      removeDuplicateInputs $ originputmap `M.difference` outins
    -- Cosmin: @unfus_vars@ is supposed to be the lam2 vars corresponding to unfus_nms (?)
    getVarParPair x = case SOAC.isVarInput (snd x) of
      Just nm -> Just (nm, fst x)
      Nothing -> Nothing -- should not be reached!
    outinsrev = M.fromList $ mapMaybe getVarParPair $ M.toList outins
    -- 'M.union' is left-biased, so a lam2 parameter reading the output
    -- is preferred over the output's own element identifier.
    unfusible outname
      | outname `nameIn` unfus_nms =
          outname `M.lookup` M.union outinsrev (M.fromList out1)
    unfusible _ = Nothing
    unfus_vars = mapMaybe (unfusible . fst) out1

-- | Select the parameters (and their inputs) whose input is a plain
-- variable that occurs among the given output names.
outParams ::
  [VName] ->
  [Ident] ->
  [SOAC.Input] ->
  M.Map Ident SOAC.Input
outParams out1 lam2arrparams inp2 =
  M.fromList $ mapMaybe isOutParam $ zip lam2arrparams inp2
  where
    isOutParam (p, inp)
      | Just a <- SOAC.isVarInput inp,
        a `elem` out1 =
          Just (p, inp)
    isOutParam _ = Nothing

-- | For every output in @out1@, pick an identifier to bind it to.  This
-- is one of the parameters in @outins@ that reads the output, with each
-- such parameter used at most once.  When no unused parameter is left,
-- the output's own element identifier is used instead.
filterOutParams ::
  [(VName, Ident)] ->
  M.Map Ident SOAC.Input ->
  [Ident]
filterOutParams out1 outins =
  snd $ mapAccumL checkUsed outUsage out1
  where
    -- For each variable input, the parameters that read it.
    outUsage = M.foldlWithKey' add M.empty outins
      where
        add m p inp =
          case SOAC.isVarInput inp of
            Just v -> M.insertWith (++) v [p] m
            Nothing -> m

    checkUsed m (a, ra) =
      case M.lookup a m of
        Just (p : ps) -> (M.insert a ps m, p)
        _ -> (m, ra)

-- | Remove parameters whose input is identical to that of another
-- parameter.  The map is traversed in ascending key order, and the first
-- parameter seen for a given input is kept.  The returned function
-- inserts, in front of a body, one statement per removed parameter
-- binding it to the kept one.
removeDuplicateInputs ::
  (Buildable rep) =>
  M.Map Ident SOAC.Input ->
  (M.Map Ident SOAC.Input, Body rep -> Body rep)
removeDuplicateInputs = fst . M.foldlWithKey' comb ((M.empty, id), M.empty)
  where
    comb ((parmap, inner), arrmap) par arr =
      case M.lookup arr arrmap of
        Nothing ->
          ( (M.insert par arr parmap, inner),
            M.insert arr (identName par) arrmap
          )
        Just par' ->
          ( (parmap, inner . forward par par'),
            arrmap
          )
    forward to from b =
      mkLet [to] (BasicOp $ SubExp $ Var from)
        `insertStm` b

-- | Fuse the lambda @p_lam@ of a producer into the lambda @c_lam@ of a
-- consumer.  The producer's lambda returns its scan results, then its
-- reduction results, then its map results (one result per element of
-- @p_scan_nes@, @p_red_nes@ and the rest, respectively); @p_inparr@ are
-- its array inputs.  @c_scan_nes@, @c_red_nes@ and @c_inparr@ describe
-- the consumer in the same way.
--
-- @outVars@ are the producer's output names and @outPairs@ its
-- (name, element identifier) output pairs; the leading entries
-- belonging to the scan and reduction results are skipped.  Producer
-- map outputs whose names are in @unfus_nms@ are kept as extra results.
--
-- The resulting lambda takes the producer's accumulator parameters
-- followed by the parameters of the fused lambda.  It returns, in order:
-- the producer's scan results, the consumer's scan results, the
-- producer's reduction results, the consumer's reduction results, and
-- the remaining results of the fused lambda.  The types of the extra
-- producer results are appended at the end of the return type.
fuseRedomap ::
  (Buildable rep) =>
  Names ->
  [VName] ->
  Lambda rep ->
  [SubExp] ->
  [SubExp] ->
  [SOAC.Input] ->
  [(VName, Ident)] ->
  Lambda rep ->
  [SubExp] ->
  [SubExp] ->
  [SOAC.Input] ->
  (Lambda rep, [SOAC.Input])
fuseRedomap
  unfus_nms
  outVars
  p_lam
  p_scan_nes
  p_red_nes
  p_inparr
  outPairs
  c_lam
  c_scan_nes
  c_red_nes
  c_inparr =
    -- We hack the implementation of map o redomap to handle this case:
    --   (i) we remove the accumulator formal parameter and corresponding
    --       (body) result from redomap's fold-lambda body
    let p_num_nes = length p_scan_nes + length p_red_nes
        unfus_arrs = filter (`nameIn` unfus_nms) outVars
        p_lam_body = lambdaBody p_lam
        (p_lam_scan_ts, p_lam_red_ts, p_lam_map_ts) =
          splitAt3 (length p_scan_nes) (length p_red_nes) $ lambdaReturnType p_lam
        (p_lam_scan_res, p_lam_red_res, p_lam_map_res) =
          splitAt3 (length p_scan_nes) (length p_red_nes) $ bodyResult p_lam_body
        p_lam_hacked =
          p_lam
            { lambdaParams = takeLast (length p_inparr) $ lambdaParams p_lam,
              lambdaBody = p_lam_body {bodyResult = p_lam_map_res},
              lambdaReturnType = p_lam_map_ts
            }

        --  (ii) we remove the accumulator's (global) output result from
        --       @outPairs@, then ``map o redomap'' fuse the two lambdas
        --       (in the usual way), and construct the extra return types
        --       for the arrays that fall through.
        (res_lam, new_inp) =
          fuseMaps
            (namesFromList unfus_arrs)
            p_lam_hacked
            p_inparr
            (drop p_num_nes outPairs)
            c_lam
            c_inparr
        (res_lam_scan_ts, res_lam_red_ts, res_lam_map_ts) =
          splitAt3 (length c_scan_nes) (length c_red_nes) $ lambdaReturnType res_lam
        -- 'fuseMaps' appends the retained producer results to the body
        -- but does not extend the return type; these are the missing types.
        (_, extra_map_ts) =
          unzip $
            filter (\(nm, _) -> nm `elem` unfus_arrs) $
              zip (drop p_num_nes outVars) $
                drop p_num_nes $
                  lambdaReturnType p_lam

        -- (iii) Finally, we put back the accumulator's formal parameter and
        --       (body) result in the first position of the obtained lambda.
        accpars = dropLast (length p_inparr) $ lambdaParams p_lam
        res_body = lambdaBody res_lam
        (res_lam_scan_res, res_lam_red_res, res_lam_map_res) =
          splitAt3 (length c_scan_nes) (length c_red_nes) $ bodyResult res_body
        res_body' =
          res_body
            { bodyResult =
                p_lam_scan_res
                  ++ res_lam_scan_res
                  ++ p_lam_red_res
                  ++ res_lam_red_res
                  ++ res_lam_map_res
            }
        res_lam' =
          res_lam
            { lambdaParams = accpars ++ lambdaParams res_lam,
              lambdaBody = res_body',
              lambdaReturnType =
                p_lam_scan_ts
                  ++ res_lam_scan_ts
                  ++ p_lam_red_ts
                  ++ res_lam_red_ts
                  ++ res_lam_map_ts
                  ++ extra_map_ts
            }
     in (res_lam', new_inp)
futhark-0.25.27/src/Futhark/Optimise/Fusion/GraphRep.hs000066400000000000000000000367101475065116200226440ustar00rootroot00000000000000
-- | A graph representation of a sequence of Futhark statements
-- (i.e. a 'Body'), built to handle fusion.  Could perhaps be made
-- more general.  An important property is that it does not handle
-- "nested bodies" (e.g. 'Match'); these are represented as single
-- nodes.
--
-- This is all implemented on top of the graph representation provided
-- by the @fgl@ package ("Data.Graph.Inductive").  The graph provided
-- by this package allows nodes and edges to have arbitrarily-typed
-- "labels".  It is these labels ('EdgeT', 'NodeT') that we use to
-- contain Futhark-specific information.  An edge goes *from*
-- consumers to producers.  There are also edges that do not represent
-- normal data dependencies, but other things.  This means that a node
-- can have multiple edges for the same name, indicating different
-- kinds of dependencies.
module Futhark.Optimise.Fusion.GraphRep
  ( -- * Data structure
    EdgeT (..),
    NodeT (..),
    DepContext,
    DepGraphAug,
    DepGraph (..),
    DepNode,

    -- * Queries
    getName,
    nodeFromLNode,
    mergedContext,
    mapAcross,
    edgesBetween,
    reachable,
    applyAugs,
    depsFromEdge,
    contractEdge,
    isRealNode,
    isCons,
    isDep,
    isInf,

    -- * Construction
    mkDepGraph,
    mkDepGraphForFun,
    pprg,
    isWithAccNodeT,
    isWithAccNodeId,
    vFusionFeasability,
    hFusionFeasability,
  )
where

import Control.Monad.Reader
import Data.Bifunctor (bimap)
import Data.Foldable (foldlM)
import Data.Graph.Inductive.Dot qualified as G
import Data.Graph.Inductive.Graph qualified as G
import Data.Graph.Inductive.Query.DFS qualified as Q
import Data.Graph.Inductive.Tree qualified as G
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe (mapMaybe)
import Data.Set qualified as S
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.HORep.SOAC qualified as H
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS hiding (SOAC (..))
import Futhark.IR.SOACS qualified as Futhark
import Futhark.Util (nubOrd)

-- | Information associated with an edge in the graph.  Every
-- constructor carries the name that the edge refers to (see 'getName').
data EdgeT
  = -- | Added for each name that the statement's pattern aliases
    -- (see @addConsAndAliases@).
    Alias VName
  | -- | Dependency through a use other than as a SOAC array input
    -- (see @addDeps@).  Counted as infusible by 'isInf'.
    InfDep VName
  | -- | Dependency through a use as a SOAC array input (see @addDeps@).
    Dep VName
  | -- | The statement consumes the name (see @addConsAndAliases@).
    Cons VName
  | -- | Ordering-only edge from a consumer to another user of the
    -- consumed name or of one of its aliases (see @addExtraCons@).
    Fake VName
  | -- | From a 'ResNode' to the producer of that result name.
    Res VName
  deriving (Eq, Ord)

-- | Information associated with a node in the graph.
data NodeT
  = -- | A statement that has not been turned into one of the more
    -- specific node kinds below.
    StmNode (Stm SOACS)
  | -- | A SOAC.  NOTE(review): the 'H.ArrayTransforms' are presumably
    -- transforms applied to the SOAC's result - confirm.
    SoacNode H.ArrayTransforms (Pat Type) (H.SOAC SOACS) (StmAux (ExpDec SOACS))
  | -- | First 'VName' is result; last is input.
    TransNode VName H.ArrayTransform VName
  | -- | Node corresponding to a result of the entire computation
    -- (i.e. the 'Result' of a body).  Any node that is not
    -- transitively reachable from one of these can be considered
    -- dead.
    ResNode VName
  | -- | Node corresponding to a free variable.  These are used to
    -- safely handle consumption, which also means we don't have to
    -- create a node for every free single variable.
    FreeNode VName
  | -- | A 'Match' statement, together with delayed nodes that are
    -- later prepended to (and fused into) its branches.
    MatchNode (Stm SOACS) [(NodeT, [EdgeT])]
  | -- | A 'Loop' statement, together with delayed nodes that are later
    -- prepended to (and fused into) its body.
    DoNode (Stm SOACS) [(NodeT, [EdgeT])]
  deriving (Eq)

-- Short, human-readable labels, used when printing the graph ('pprg').
instance Show EdgeT where
  show (Dep vName) = "Dep " <> prettyString vName
  show (InfDep vName) = "iDep " <> prettyString vName
  show (Cons _) = "Cons"
  show (Fake _) = "Fake"
  show (Res _) = "Res"
  show (Alias _) = "Alias"

instance Show NodeT where
  show (StmNode (Let pat _ _)) = L.intercalate ", " $ map prettyString $ patNames pat
  show (SoacNode _ pat _ _) = prettyString pat
  show (TransNode _ tr _) = prettyString tr
  show (ResNode name) = prettyString $ "Res: " ++ prettyString name
  show (FreeNode name) = prettyString $ "Input: " ++ prettyString name
  show (MatchNode stm _) = "Match: " ++ L.intercalate ", " (map prettyString $ stmNames stm)
  show (DoNode stm _) = "Do: " ++ L.intercalate ", " (map prettyString $ stmNames stm)

-- | The name that this edge depends on.
getName :: EdgeT -> VName
getName edgeT = case edgeT of
  Alias vn -> vn
  InfDep vn -> vn
  Dep vn -> vn
  Cons vn -> vn
  Fake vn -> vn
  Res vn -> vn

-- | Does the node actually represent something in the program?  A
-- "non-real" node represents things like fake nodes inserted to
-- express ordering due to consumption.  'ResNode' and 'FreeNode' are
-- the non-real nodes.
isRealNode :: NodeT -> Bool
isRealNode ResNode {} = False
isRealNode FreeNode {} = False
isRealNode _ = True

-- | Prettyprint dependency graph (in Graphviz dot format, with nodes and
-- edges labelled by their 'Show' instances).
pprg :: DepGraph -> String
pprg = G.showDot . G.fglToDotString . G.nemap show show . dgGraph

-- | A pair of a 'G.Node' and the node label.
type DepNode = G.LNode NodeT

-- A labelled edge: source node, target node and 'EdgeT'.
type DepEdge = G.LEdge EdgeT

-- | A tuple with four parts: inbound links to the node, the node
-- itself, the 'NodeT' "label", and outbound links from the node.
-- This type is used to modify the graph in 'mapAcross'.
type DepContext = G.Context NodeT EdgeT

-- | A dependency graph.  Edges go from *consumers* to *producers*
-- (i.e. from usage to definition).  That means the incoming edges of
-- a node are the dependents of that node, and the outgoing edges are
-- the dependencies of that node.
data DepGraph = DepGraph
  { -- | The underlying @fgl@ graph.
    dgGraph :: G.Gr NodeT EdgeT,
    -- | Which node produces each variable; filled in by 'makeMapping'.
    dgProducerMapping :: ProducerMapping,
    -- | A table mapping VNames to VNames that are aliased to it.
    dgAliasTable :: AliasTable
  }

-- | A "graph augmentation" is a monadic action that modifies the graph.
type DepGraphAug m = DepGraph -> m DepGraph

-- | Given a node, the names it depends on, each paired with the kind of
-- edge that should point to the producer of that name.
type EdgeGenerator = NodeT -> [(VName, EdgeT)]

-- | A mapping from variable name to the graph node that produces
-- it.
type ProducerMapping = M.Map VName G.Node

-- Rebuild 'dgProducerMapping' from the outputs of every node.  Nodes
-- are visited in 'G.labNodes' order; if two nodes report the same
-- output, the later one wins (that is how 'M.fromList' resolves
-- duplicate keys).
makeMapping :: (Monad m) => DepGraphAug m
makeMapping dg@DepGraph {dgGraph = gr} =
  pure
    dg
      { dgProducerMapping =
          M.fromList
            [ (out_name, node_id)
              | (node_id, node_t) <- G.labNodes gr,
                out_name <- getOutputs node_t
            ]
      }

-- | Apply several graph augmentations in sequence.  The graph is
-- threaded through the list from left to right.
applyAugs :: (Monad m) => [DepGraphAug m] -> DepGraphAug m
applyAugs augs start = foldlM (\graph aug -> aug graph) start augs

-- | Creates deps for the given nodes on the graph using the 'EdgeGenerator'.
-- For each generated (name, label) pair, an edge goes from the node to
-- the producer of the name.  Names without an entry in
-- 'dgProducerMapping' produce no edge.
genEdges :: (Monad m) => [DepNode] -> EdgeGenerator -> DepGraphAug m
genEdges l_stms edge_fun dg = depGraphInsertEdges new_edges dg
  where
    producers = dgProducerMapping dg
    new_edges =
      [ G.toLEdge (src, dst) label
        | (src, src_node) <- l_stms,
          (dep_name, label) <- edge_fun src_node,
          Just dst <- [M.lookup dep_name producers]
      ]

-- Insert the given labelled edges into the graph.
depGraphInsertEdges :: (Monad m) => [DepEdge] -> DepGraphAug m
depGraphInsertEdges new_edges dg =
  pure dg {dgGraph = G.insEdges new_edges (dgGraph dg)}

-- | Monadically modify every node of the graph.
mapAcross :: (Monad m) => (DepContext -> m DepContext) -> DepGraphAug m
mapAcross f dg = do
  -- Every node of the original graph is visited in turn.  Its context
  -- is taken out of the current graph with 'G.match', transformed by @f@,
  -- and put back with 'G.&'.
  g' <- foldlM (flip helper) (dgGraph dg) (G.nodes (dgGraph dg))
  pure $ dg {dgGraph = g'}
  where
    helper n g' = case G.match n g' of
      (Just c, g_new) -> do
        c' <- f c
        pure $ c' G.& g_new
      -- The node is no longer in the graph: leave the graph unchanged.
      (Nothing, _) -> pure g'

-- Only a plain 'StmNode' yields a statement.  Edges are generated (in
-- 'mkDepGraph') before 'nodeToSoacNode' converts nodes into other kinds.
stmFromNode :: NodeT -> Stms SOACS -- do not use outside of edge generation
stmFromNode (StmNode x) = oneStm x
stmFromNode _ = mempty

-- | Get the underlying @fgl@ node.
nodeFromLNode :: DepNode -> G.Node
nodeFromLNode = fst

-- | Get the variable name that this edge refers to.
depsFromEdge :: DepEdge -> VName
depsFromEdge = getName . G.edgeLabel

-- | Find all the edges connecting the two nodes, in either direction.
edgesBetween :: DepGraph -> G.Node -> G.Node -> [DepEdge]
edgesBetween dg n1 n2 = G.labEdges $ G.subgraph [n1, n2] $ dgGraph dg

-- | @reachable dg from to@ is true if @to@ is reachable from @from@.
-- Edges go from consumers to producers, so this means that @from@
-- (transitively) depends on @to@.
reachable :: DepGraph -> G.Node -> G.Node -> Bool
reachable dg source target = target `elem` Q.reachable source (dgGraph dg)

-- Utility func for augs: generate edges for every node of the graph
-- with the given 'EdgeGenerator'.
augWithFun :: (Monad m) => EdgeGenerator -> DepGraphAug m
augWithFun f dg = genEdges (G.labNodes (dgGraph dg)) f dg

-- Add a 'Dep' edge for each name that a statement uses as a SOAC array
-- input, and an 'InfDep' edge for each name it uses in any other way.
-- A name used in both ways gets both edges.  Only 'StmNode's yield
-- edges (see 'stmFromNode').
addDeps :: (Monad m) => DepGraphAug m
addDeps = augWithFun toDep
  where
    toDep stmt =
      let (fusible, infusible) =
            bimap (map fst) (map fst)
              . L.partition ((== SOACInput) . snd)
              . S.toList
              $ foldMap stmInputs (stmFromNode stmt)
          mkDep vname = (vname, Dep vname)
          mkInfDep vname = (vname, InfDep vname)
       in map mkDep fusible <> map mkInfDep infusible

-- After alias analysis of a statement, add a 'Cons' edge for every
-- name the statement consumes and an 'Alias' edge for every name its
-- pattern aliases.  Only 'StmNode's yield such edges.
addConsAndAliases :: (Monad m) => DepGraphAug m
addConsAndAliases = augWithFun edges
  where
    edges (StmNode s) = consEdges s' <> aliasEdges s'
      where
        s' = Alias.analyseStm mempty s
    edges _ = mempty
    consEdges s = zip names (map Cons names)
      where
        names = namesToList $ consumedInStm s
    aliasEdges =
      map (\vname -> (vname, Alias vname))
        . namesToList
        . mconcat
        . patAliases
        . stmPat

-- extra dependencies mask the fact that consuming nodes "depend" on all other
-- nodes coming before it (now also adds fake edges to aliases - hope this
-- fixes asymptotic complexity guarantees)
--
-- Concretely, take a 'Cons' edge from a consumer @from@ to the producer
-- @to@ of a name @cname@.  Consider every node with an edge into @to@,
-- or into the producer of an alias of @cname@ (per 'dgAliasTable'),
-- where that edge is labelled with @cname@ or one of those aliases.
-- Each such node other than @from@ gets a 'Fake' edge from @from@.
-- Since edges point from users to producers, these predecessors are the
-- other users of the consumed value.
addExtraCons :: (Monad m) => DepGraphAug m
addExtraCons dg =
  depGraphInsertEdges (concatMap makeEdge (G.labEdges g)) dg
  where
    g = dgGraph dg
    alias_table = dgAliasTable dg
    mapping = dgProducerMapping dg
    makeEdge (from, to, Cons cname) = do
      let aliases = namesToList $ M.findWithDefault mempty cname alias_table
          to' = mapMaybe (`M.lookup` mapping) aliases
          p (tonode, toedge) =
            tonode /= from && getName toedge `elem` (cname : aliases)
      (to2, _) <- filter p $ concatMap (G.lpre g) to' <> G.lpre g to
      pure $ G.toLEdge (from, to2) (Fake cname)
    makeEdge _ = []

-- Apply a monadic function to the label of every node, leaving the
-- edges untouched.
mapAcrossNodeTs :: (Monad m) => (NodeT -> m NodeT) -> DepGraphAug m
mapAcrossNodeTs f = mapAcross f'
  where
    f' (ins, n, nodeT, outs) = do
      nodeT' <- f nodeT
      pure (ins, n, nodeT', outs)

-- Turn a plain 'StmNode' into a more specific node:
--
-- * a SOAC recognised by 'H.fromExp' becomes a 'SoacNode' with no
--   output transforms;
-- * a loop becomes a 'DoNode' and a match becomes a 'MatchNode', both
--   with no delayed nodes;
-- * a single-result expression recognised by 'H.transformFromExp'
--   becomes a 'TransNode'.
--
-- Every other node is returned unchanged.
nodeToSoacNode :: (HasScope SOACS m, Monad m) => NodeT -> m NodeT
nodeToSoacNode n@(StmNode s@(Let pat aux op)) = case op of
  Op {} -> do
    maybeSoac <- H.fromExp op
    case maybeSoac of
      Right hsoac -> pure $ SoacNode mempty pat hsoac aux
      Left H.NotSOAC -> pure n
  Loop {} ->
    pure $ DoNode s []
  Match {} ->
    pure $ MatchNode s []
  e
    | [output] <- patNames pat,
      Just (ia, tr) <- H.transformFromExp (stmAuxCerts aux) e ->
        pure $ TransNode output tr ia
  _ -> pure n
nodeToSoacNode n = pure n

-- | Construct a graph with only nodes, but no edges.  Nodes are numbered
-- from 0: first one 'StmNode' per statement of the body, then one
-- 'ResNode' per name free in the body's result, then one 'FreeNode' per
-- name reported as consumed by alias analysis of the statements.
-- NOTE(review): presumably that analysis reports only consumed names
-- that are not bound in the body - confirm.
emptyGraph :: Body SOACS -> DepGraph
emptyGraph body =
  DepGraph
    { dgGraph = G.mkGraph (labelNodes (stmnodes <> resnodes <> inputnodes)) [],
      dgProducerMapping = mempty,
      dgAliasTable = aliases
    }
  where
    labelNodes = zip [0 ..]
    stmnodes = map StmNode $ stmsToList $ bodyStms body
    resnodes = map ResNode $ namesToList $ freeIn $ bodyResult body
    inputnodes = map FreeNode $ namesToList consumed
    (_, (aliases, consumed)) = Alias.analyseStms mempty $ bodyStms body

-- A 'ResNode' depends, through a 'Res' edge, on the producer of its name.
getStmRes :: EdgeGenerator
getStmRes (ResNode name) = [(name, Res name)]
getStmRes _ = []

addResEdges :: (Monad m) => DepGraphAug m
addResEdges = augWithFun getStmRes

-- | Make a dependency graph corresponding to a 'Body'.
mkDepGraph :: (HasScope SOACS m, Monad m) => Body SOACS -> m DepGraph
mkDepGraph body = applyAugs augs $ emptyGraph body
  where
    -- The producer mapping must exist before any edges are generated.
    augs =
      [ makeMapping,
        addDeps,
        addConsAndAliases,
        addExtraCons,
        addResEdges,
        mapAcrossNodeTs nodeToSoacNode -- Must be done after adding edges
      ]

-- | Make a dependency graph corresponding to a function.  The function's
-- parameters and the statements of its body provide the scope.
mkDepGraphForFun :: FunDef SOACS -> DepGraph
mkDepGraphForFun f = runReader (mkDepGraph (funDefBody f)) scope
  where
    scope = scopeOfFParams (funDefParams f) <> scopeOf (bodyStms (funDefBody f))

-- | Merges two contexts.  The result is a context for the first node
-- (@n1@) with the given label.  The inbound and outbound adjacencies of
-- both contexts are combined with duplicates removed.  Any adjacency to
-- @n1@ or @n2@ is dropped, so the merged node has no edges to itself.
mergedContext :: (Ord b) => a -> G.Context a b -> G.Context a b -> G.Context a b
mergedContext mergedlabel (inp1, n1, _, out1) (inp2, n2, _, out2) =
  let new_inp = filter (\n -> snd n /= n1 && snd n /= n2) (nubOrd (inp1 <> inp2))
      new_out = filter (\n -> snd n /= n1 && snd n /= n2) (nubOrd (out1 <> out2))
   in (new_inp, n1, mergedlabel, new_out)

-- | Remove the given node, and insert the 'DepContext' into the
-- graph, replacing any existing information about the node contained
-- in the 'DepContext'.  Both @n2@ and the context's own node are deleted
-- before the context is inserted.
contractEdge :: (Monad m) => G.Node -> DepContext -> DepGraphAug m
contractEdge n2 ctx dg = do
  let n1 = G.node' ctx -- n1 remains
  pure $ dg {dgGraph = ctx G.& G.delNodes [n1, n2] (dgGraph dg)}

-- Utils for fusibility/infusibility
-- find dependencies - either fusible or infusible. edges are generated based on these

-- | A classification of a free variable.
data Classification
  = -- | Used as array input to a SOAC (meaning fusible).
    SOACInput
  | -- | Used in some other way.
    Other
  deriving (Eq, Ord, Show)

type Classifications = S.Set (VName, Classification)

-- Every free variable of the argument, classified as 'Other'.
freeClassifications :: (FreeIn a) => a -> Classifications
freeClassifications =
  S.fromList . (`zip` repeat Other) . namesToList . freeIn

stmInputs :: Stm SOACS -> Classifications
stmInputs (Let pat aux e) =
  freeClassifications (pat, aux) <> expInputs e

-- Note that names bound inside the body are not subtracted.
bodyInputs :: Body SOACS -> Classifications
bodyInputs (Body _ stms res) = foldMap stmInputs stms <> freeClassifications res

-- The array inputs of a SOAC, and the array operand of an expression
-- recognised by 'H.transformFromExp', are 'SOACInput'.  Everything else
-- is 'Other'.  For a match or a loop, the names used inside the nested
-- bodies are included too.
expInputs :: Exp SOACS -> Classifications
expInputs (Match cond cases defbody attr) =
  foldMap (bodyInputs . caseBody) cases
    <> bodyInputs defbody
    <> freeClassifications (cond, attr)
expInputs (Loop params form b1) =
  freeClassifications (params, form) <> bodyInputs b1
expInputs (Op soac) = case soac of
  Futhark.Screma w is form -> inputs is <> freeClassifications (w, form)
  Futhark.Hist w is ops lam -> inputs is <> freeClassifications (w, ops, lam)
  Futhark.Scatter w is lam iws -> inputs is <> freeClassifications (w, lam, iws)
  Futhark.Stream w is nes lam ->
    inputs is <> freeClassifications (w, nes, lam)
  Futhark.JVP {} -> freeClassifications soac
  Futhark.VJP {} -> freeClassifications soac
  where
    inputs = S.fromList . (`zip` repeat SOACInput)
expInputs e
  | Just (arr, _) <- H.transformFromExp mempty e =
      S.singleton (arr, SOACInput)
        <> freeClassifications (freeIn e `namesSubtract` oneName arr)
  | otherwise = freeClassifications e

stmNames :: Stm SOACS -> [VName]
stmNames = patNames . stmPat

-- The names a node produces; these populate 'dgProducerMapping'.  A
-- 'ResNode' produces nothing, and a 'FreeNode' stands for its free
-- variable.
getOutputs :: NodeT -> [VName]
getOutputs node = case node of
  (StmNode stm) -> stmNames stm
  (TransNode v _ _) -> [v]
  (ResNode _) -> []
  (FreeNode name) -> [name]
  (MatchNode stm _) -> stmNames stm
  (DoNode stm _) -> stmNames stm
  (SoacNode _ pat _ _) -> patNames pat

-- | Is there a possibility of fusion?  True for 'Dep' and 'Res' edges.
isDep :: EdgeT -> Bool
isDep (Dep _) = True
isDep (Res _) = True
isDep _ = False

-- | Is this an infusible edge?
isInf :: (G.Node, G.Node, EdgeT) -> Bool
isInf (_, _, e) = case e of
  -- 'InfDep' and 'Fake' edges are infusible; every other kind is not.
  InfDep _ -> True
  Fake _ -> True -- this is infusible to avoid simultaneous cons/dep edges
  _ -> False

-- | Is this a 'Cons' edge?
isCons :: EdgeT -> Bool
isCons (Cons _) = True
isCons _ = False

-- | Is this a withAcc?  Only a plain 'StmNode' whose expression is a
-- 'WithAcc' counts.
isWithAccNodeT :: NodeT -> Bool
isWithAccNodeT (StmNode (Let _ _ (WithAcc _ _))) = True
isWithAccNodeT _ = False

-- | Is the node with the given id a withAcc (per 'isWithAccNodeT')?  The
-- node is looked up with 'G.context', which the fgl documentation says
-- raises an error if the node is not in the graph.
isWithAccNodeId :: G.Node -> DepGraph -> Bool
isWithAccNodeId node_id (DepGraph {dgGraph = g}) =
  let (_, _, nT, _) = G.context g node_id
   in isWithAccNodeT nT

-- True when neither node can reach the other.
unreachableEitherDir :: DepGraph -> G.Node -> G.Node -> Bool
unreachableEitherDir g a b =
  not (reachable g a b || reachable g b a)

-- | Can the nodes @n1@ and @n2@ be fused vertically?  Edges go from
-- consumers to producers, so the checks below treat @n1@ as the producer
-- and @n2@ as the consumer; ruleMFScat in RulesWithAccs calls it as
-- (map node, scatter node).  Both conditions must hold:
--
-- * either @n2@ is a withAcc node, or no edge between the two nodes is
--   infusible (per 'isInf');
--
-- * no other node that depends on @n1@ (a predecessor of @n1@) is
--   reachable from @n2@.  In other words, @n2@ does not also depend on
--   @n1@ indirectly through some other node.
vFusionFeasability :: DepGraph -> G.Node -> G.Node -> Bool
vFusionFeasability dg@DepGraph {dgGraph = g} n1 n2 =
  (isWithAccNodeId n2 dg || not (any isInf (edgesBetween dg n1 n2)))
    && not (any (reachable dg n2) (filter (/= n2) (G.pre g n1)))

-- | Horizontal fusion of two nodes is considered feasible when neither
-- node is reachable from the other.
hFusionFeasability :: DepGraph -> G.Node -> G.Node -> Bool
hFusionFeasability = unreachableEitherDir
futhark-0.25.27/src/Futhark/Optimise/Fusion/RulesWithAccs.hs000066400000000000000000000636451475065116200236630ustar00rootroot00000000000000
{-# LANGUAGE Strict #-}

-- | This module consists of rules for fusion
--   that involves WithAcc constructs.
--   Currently, we support two non-trivial
--   transformations:
--     I. map-flatten-scatter: a map nest produces
--        multi-dimensional index and values arrays
--        that are then flattened and used in a
--        scatter consumer. Such pattern can be fused
--        by re-writing the scatter by means of a WithAcc
--        containing a map-nest, thus eliminating the flatten
--        operations. The obtained WithAcc can then be fused
--        with the producer map nest, e.g., benefiting intra-group
--        kernels. The motivating target for this rule is
--        an efficient implementation of radix-sort.
--
--    II. WithAcc-WithAcc fusion: two withaccs can be
--        fused as long as the common accumulators use
--        the same operator, and as long as the non-accumulator
--        input of a WithAcc is not used as an accumulator in
--        the other. This fusion opens the door for fusing
--        the SOACs appearing inside the WithAccs. This is
--        also intended to demonstrate that it is not so
--        important where exactly the WithAccs were originally
--        introduced in the code, it is more important that
--        they can be transformed by various optimization passes.
module Futhark.Optimise.Fusion.RulesWithAccs
  ( ruleMFScat,
    tryFuseWithAccs,
  )
where

import Control.Monad
import Data.Graph.Inductive.Graph qualified as G
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.HORep.SOAC qualified as H
import Futhark.Construct
import Futhark.IR.SOACS hiding (SOAC (..))
import Futhark.IR.SOACS qualified as F
import Futhark.Optimise.Fusion.GraphRep
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute

-- The 64-bit integer constant 0.
se0 :: SubExp
se0 = intConst Int64 0

-- The 64-bit integer constant 1.
se1 :: SubExp
se1 = intConst Int64 1

-------------------------------------
--- I. Map-Flatten-Scatter Fusion ---
-------------------------------------

-- helper data structures
type IotaInp = ((VName, LParam SOACS), (SubExp, SubExp, SubExp, IntType))
-- ^ ((array-name, lambda param), (len, start, stride, Int64))

type RshpInp = ((VName, LParam SOACS), (Shape, Shape, Type))
-- ^ ((array-name, lambda param), (flat-shape, unflat-shape, elem-type))

-- | Implements a specialized rule for fusing a pattern
--   formed by a map o flatten o scatter, i.e.,
--      let (inds, vals) = map-nest f inps
--          (finds, fvals) = (flatten inds, flatten vals)
--      let res = scatter res0 finds fvals
--   where inds & vals have higher rank than finds & fvals.
-- The rule returns 'Nothing' unless every guard below holds.  If they
-- all hold, the scatter node is replaced (under the same node id) by a
-- 'StmNode' holding the WithAcc statement built by 'mkWithAccStm'.  The
-- new node keeps the scatter's dependents.  Its dependencies are the
-- scatter's dependencies without the iota producers, plus the
-- dependencies of those iota producers.
ruleMFScat ::
  (HasScope SOACS m, MonadFreshNames m) =>
  DepNode ->
  DepGraph ->
  m (Maybe DepGraph)
ruleMFScat node_to_fuse dg@DepGraph {dgGraph = g}
  | soac_nodeT <- snd node_to_fuse,
    scat_node_id <- nodeFromLNode node_to_fuse,
    SoacNode node_out_trsfs scat_pat scat_soac scat_aux <- soac_nodeT,
    H.nullTransforms node_out_trsfs,
    -- \^ for simplicity we do not allow transforms on scatter's result.
    H.Scatter _len scat_inp scat_out scat_lam <- scat_soac,
    -- \^ get the scatter
    scat_trsfs <- map H.inputTransforms (H.inputs scat_soac),
    -- \^ get the transforms on the input
    any (/= mempty) scat_trsfs,
    -- The scatter's context: its inbound edges come from its dependents,
    -- its outbound edges go to its dependencies.
    scat_ctx <- G.context g scat_node_id,
    (out_deps, _, _, inp_deps) <- scat_ctx,
    cons_deps <- filter (isCons . fst) inp_deps,
    drct_deps <- filter (isDep . fst) inp_deps,
    cons_ctxs <- map (G.context g . snd) cons_deps,
    drct_ctxs <- map (G.context g . snd) drct_deps,
    _cons_nTs <- map getNodeTfromCtx cons_ctxs, -- not used!!
    drct_tups0 <- mapMaybe (pairUp (zip drct_ctxs (map fst drct_deps))) scat_inp,
    length drct_tups0 == length scat_inp,
    -- \^ checks that all direct dependencies are also array
    --   inputs to scatter
    (t1s, t2s) <- unzip drct_tups0,
    drct_tups <- zip t1s $ zip t2s (lambdaParams scat_lam),
    (ctxs_iots, drct_iots) <- unzip $ filter (isIota . snd . fst . snd) drct_tups,
    (ctxs_rshp, drct_rshp) <- unzip $ filter (not . isIota . snd . fst . snd) drct_tups,
    length drct_iots + length drct_rshp == length scat_inp,
    -- \^ direct dependencies are either flatten reshapes or iotas.
    rep_iotas <- mapMaybe getRepIota drct_iots,
    length rep_iotas == length drct_iots,
    rep_rshps_certs <- mapMaybe getRepRshpArr drct_rshp,
    (rep_rshps, certs_rshps) <- unzip rep_rshps_certs,
    -- \^ gather the representations for the iotas and reshapes, that use
    --   the helper types `IotaInp` and `RshpInp`
    not (null rep_rshps),
    -- \^ at least one flatten-reshaped array
    length rep_rshps == length drct_rshp,
    (_, (s1, s2, _)) : _ <- rep_rshps,
    all (\(_, (s1', s2', _)) -> s1 == s1' && s2 == s2') rep_rshps,
    -- \^ Check that all unflatten shape dimensions are the same,
    --   so that we can construct a map nest;
    -- check profitability, which is conservatively defined as all
    -- the reshaped and consumer arrays are used solely by the
    -- scatter AND all reshape dependencies originate in the same
    -- map.
    checkSafeAndProfitable dg scat_node_id ctxs_rshp cons_ctxs = do
      -- generate the withAcc statement
      let cons_patels_outs = zip (patElems scat_pat) scat_out
      wacc_stm <- mkWithAccStm rep_iotas rep_rshps cons_patels_outs scat_aux scat_lam
      -- The certificates of the flattening reshapes are added to those
      -- of the new statement.
      let all_cert_rshp = mconcat certs_rshps
          aux = stmAux wacc_stm
          aux' = aux {stmAuxCerts = all_cert_rshp <> stmAuxCerts aux}
          wacc_stm' = wacc_stm {stmAux = aux'}
          -- get the input deps of iotas
          fiot acc (_, _, _, inp_deps_iot) = acc <> inp_deps_iot
          deps_of_iotas = foldl fiot mempty ctxs_iots
          --
          iota_nms = namesFromList $ map (fst . fst) rep_iotas
          inp_deps_wo_iotas = filter ((`notNameIn` iota_nms) . getName . fst) inp_deps
          -- generate a new node for the with-acc-stmt and its associated context:
          -- add the inp-deps of iotas but remove the iota themselves from deps.
          new_withacc_nT = StmNode wacc_stm'
          inp_deps' = inp_deps_wo_iotas <> deps_of_iotas
          new_withacc_ctx = (out_deps, scat_node_id, new_withacc_nT, inp_deps')
          -- construct the new WithAcc node/graph; do we need to use `fusedSomething` ??
          new_node = G.node' new_withacc_ctx
          dg' = dg {dgGraph = new_withacc_ctx G.& G.delNodes [new_node] g}
      -- result
      pure $ Just dg'
  where
    --
    getNodeTfromCtx (_, _, nT, _) = nT
    -- The unique context whose edge refers to the given name, if exactly
    -- one does.
    findCtxOf ctxes nm
      | [ctxe] <- filter (\x -> nm == getName (snd x)) ctxes =
          Just ctxe
    findCtxOf _ _ = Nothing
    -- Pair a scatter input with the context and label of the node that
    -- produces its array.
    pairUp :: [(DepContext, EdgeT)] -> H.Input -> Maybe (DepContext, (H.Input, NodeT))
    pairUp ctxes inp@(H.Input _arrtrsfs nm _tp)
      | Just (ctx@(_, _, nT, _), _) <- findCtxOf ctxes nm =
          Just (ctx, (inp, nT))
    pairUp _ _ = Nothing
    --
    isIota :: NodeT -> Bool
    isIota (StmNode (Let _ _ (BasicOp (Iota {})))) = True
    isIota _ = False
    --
    -- Accept only an untransformed input produced by an Int64 iota.
    getRepIota :: ((H.Input, NodeT), LParam SOACS) -> Maybe IotaInp
    getRepIota ((H.Input iottrsf arr_nm _arr_tp, nt), farg)
      | mempty == iottrsf,
        StmNode (Let _ _ (BasicOp (Iota n x s Int64))) <- nt =
          Just ((arr_nm, farg), (n, x, s, Int64))
    getRepIota _ = Nothing
    --
    -- Accept an input whose only transform is an arbitrary reshape to a
    -- shape of rank one (once the element shape is removed) from a
    -- higher-rank array.  The reshape's certificates are returned too.
    getRepRshpArr :: ((H.Input, NodeT), LParam SOACS) -> Maybe (RshpInp, Certs)
    getRepRshpArr ((H.Input outtrsf arr_nm arr_tp, _nt), farg)
      | rshp_trsfm H.:< other_trsfms <- H.viewf outtrsf,
        (H.Reshape c ReshapeArbitrary shp_flat) <- rshp_trsfm,
        other_trsfms == mempty,
        eltp <- paramDec farg,
        Just shp_flat' <- checkShp eltp shp_flat,
        Array _ptp shp_unflat _ <- arr_tp,
        Just shp_unflat' <- checkShp eltp shp_unflat,
        shapeRank shp_flat' == 1,
        shapeRank shp_flat' < shapeRank shp_unflat' =
          Just (((arr_nm, farg), (shp_flat', shp_unflat', eltp)), c)
    getRepRshpArr _ = Nothing
    --
    -- Strip the element type's shape from the end of an array shape.
    -- Fails if the trailing dimensions do not match the element shape.
    checkShp (Prim _) shp_arr = Just shp_arr
    checkShp (Array _ptp shp_elm _) shp_arr =
      let dims_elm = shapeDims shp_elm
          dims_arr = shapeDims shp_arr
          (m, n) = (length dims_elm, length dims_arr)
          shp' = Shape $ take (n - m) dims_arr
          dims_com = drop (n - m) dims_arr
       in if all (uncurry (==)) (zip dims_com dims_elm)
            then Just shp'
            else Nothing
    checkShp _ _ = Nothing
-- default fails:
ruleMFScat _ _ = pure Nothing

-- The rule applies only if all of the following hold:
--
-- * (prof1) every dependent of the reshape producers and of the consumed
--   nodes is the scatter itself;
-- * (prof2, prof3) all reshape producers are one and the same node, and
--   that node is a plain map (a 'H.Screma' with no scans, no reductions
--   and no output transforms);
-- * (safe) vertical fusion of that map into the scatter is feasible.
--
-- An empty list of reshape contexts never qualifies.
checkSafeAndProfitable :: DepGraph -> G.Node -> [DepContext] -> [DepContext] -> Bool
checkSafeAndProfitable dg scat_node_id ctxs_rshp@(_ : _) ctxs_cons =
  let all_deps = concatMap (\(x, _, _, _) -> x) $ ctxs_rshp ++ ctxs_cons
      prof1 = all (\(_, dep_id) -> dep_id == scat_node_id) all_deps
      -- \^ scatter is the sole target to all consume & unflatten-reshape deps
      (_, map_node_id, map_nT, _) = head ctxs_rshp
      prof2 = all (\(_, nid, _, _) -> nid == map_node_id) ctxs_rshp
      prof3 = isMap map_nT
      -- \^ all reshapes come from the same node, which is a map
      safe = vFusionFeasability dg map_node_id scat_node_id
   in safe && prof1 && prof2 && prof3
  where
    isMap nT
      | SoacNode out_trsfs _pat soac _ <- nT,
        H.Screma _ _ form <- soac,
        ScremaForm _ [] [] <- form =
          H.nullTransforms out_trsfs
    isMap _ = False
checkSafeAndProfitable _ _ _ _ = False

-- | produces the withAcc statement that constitutes the translation of
--   the scatter o flatten o map composition in which the map inputs are
--   reshaped in the same way
--
-- One accumulator is created per scatter output.  Each accumulator
-- covers the output's shape, and its element type is the pattern
-- element's type with that many dimensions stripped.  Each accumulator
-- also gets its own fresh certificate parameter.  The statement reuses
-- the scatter's pattern and aux.  It calls 'error' if there is no
-- reshaped input, or if the first one has an empty unflattened shape.
mkWithAccStm ::
  (HasScope SOACS m, MonadFreshNames m) =>
  [IotaInp] ->
  [RshpInp] ->
  [(PatElem (LetDec SOACS), (Shape, Int, VName))] ->
  StmAux (ExpDec SOACS) ->
  Lambda SOACS ->
  m (Stm SOACS)
mkWithAccStm iota_inps rshp_inps cons_patels_outs scatter_aux scatter_lam
  -- iotas are assumed to operate on Int64 values
  -- ToDo: maybe simplify rshp_inps
  --       check that the unflat shape is the same across reshapes
  --       check that the rank of the unflatten shape is higher than the flatten
  | rshp_inp : _ <- rshp_inps,
    (_, (_, s_unflat, _)) <- rshp_inp,
    (_ : _) <- shapeDims s_unflat = do
      --
      (cert_params, acc_params) <- fmap unzip $ forM cons_patels_outs $ \(patel, (shp, _, nm)) -> do
        cert_param <- newParam "acc_cert_p" $ Prim Unit
        let arr_tp = patElemType patel
            acc_tp = stripArray (shapeRank shp) arr_tp
        acc_param <- newParam (baseString nm) $ Acc (paramName cert_param) shp [acc_tp] NoUniqueness
        pure (cert_param, acc_param)
      let cons_params_outs = zip acc_params $ map snd cons_patels_outs
      acc_bdy <- mkWithAccBdy s_unflat iota_inps rshp_inps cons_params_outs scatter_lam
      let withacc_lam =
            Lambda
              { lambdaParams = cert_params ++ acc_params,
                lambdaReturnType = map paramDec acc_params,
                lambdaBody = acc_bdy
              }
          withacc_inps = map (\(_, (shp, _, nm)) -> (shp, [nm], Nothing)) cons_patels_outs
          withacc_pat = Pat $ map fst cons_patels_outs
          stm = Let withacc_pat scatter_aux $ WithAcc withacc_inps withacc_lam
      pure stm
mkWithAccStm _ _ _ _ _ =
  error "Unreachable case reached!"

-- | Wrapper function for constructing the body of the withAcc
--   translation of the scatter
--
-- Each reshaped input gets a parameter named after its array, with the
-- element type extended by the unflattened shape.  The map-nest body is
-- then built by @mkWithAccBdy'@, starting from the full list of
-- unflattened dimensions.
mkWithAccBdy ::
  (HasScope SOACS m, MonadFreshNames m) =>
  Shape ->
  [IotaInp] ->
  [RshpInp] ->
  [(LParam SOACS, (Shape, Int, VName))] ->
  Lambda SOACS ->
  m (Body SOACS)
mkWithAccBdy shp iota_inps rshp_inps cons_params_outs scat_lam = do
  let cons_ps = map fst cons_params_outs
      scat_res_info = map snd cons_params_outs
      static_arg = (iota_inps, rshp_inps, scat_res_info, scat_lam)
      mkParam ((nm, _), (_, s, t)) = Param mempty nm (arrayOfShape t s)
      rshp_ps = map mkParam rshp_inps
  mkWithAccBdy' static_arg (shapeDims shp) [] [] rshp_ps cons_ps
-- | builds a body that essentially consists of a map-nest with accumulators, -- i.e., one level for each level of the unflatten shape of scatter's reshaped -- input arrays mkWithAccBdy' :: (HasScope SOACS m, MonadFreshNames m) => ([IotaInp], [RshpInp], [(Shape, Int, VName)], Lambda SOACS) -> [SubExp] -> [SubExp] -> [VName] -> [LParam SOACS] -> [LParam SOACS] -> m (Body SOACS) -- | the base case below addapts the scatter's lambda mkWithAccBdy' static_arg [] dims_rev iot_par_nms rshp_ps cons_ps = do let (iota_inps, rshp_inps, scat_res_info, scat_lam) = static_arg tp_int = Prim $ IntType Int64 scope <- askScope runBodyBuilder $ localScope (scope <> scopeOfLParams (rshp_ps ++ cons_ps)) $ do -- handle iota args let strides_rev = scanl (*) (pe64 se1) $ map pe64 dims_rev strides = tail $ reverse strides_rev prods = zipWith (*) (map le64 iot_par_nms) strides i_pe = sum prods i_norm <- letExp "iota_norm_arg" =<< toExp i_pe forM_ iota_inps $ \arg -> do let ((_, i_par), (_, b, s, _)) = arg i_new <- letExp "tmp" =<< toExp (pe64 b + le64 i_norm * pe64 s)
letBind (Pat [PatElem (paramName i_par) tp_int]) $ BasicOp $ SubExp $ Var i_new -- handle rshp args let rshp_lam_args = map (snd . fst) rshp_inps forM_ (zip rshp_lam_args rshp_ps) $ \(old_par, new_par) -> do let pat = Pat [PatElem (paramName old_par) (paramDec old_par)] letBind pat $ BasicOp $ SubExp $ Var $ paramName new_par -- add the body of the scatter's lambda mapM_ addStm $ bodyStms $ lambdaBody scat_lam -- add the withAcc update statements let iv_ses = groupScatterResults' scat_res_info $ bodyResult $ lambdaBody scat_lam res_nms <- forM (zip cons_ps iv_ses) $ \(cons_p, (i_ses, v_se)) -> do -- i_ses is a list let f nm_in i_se = letExp (baseString nm_in) $ BasicOp $ UpdateAcc Safe nm_in [resSubExp i_se] [resSubExp v_se] foldM f (paramName cons_p) i_ses let lam_certs = foldMap resCerts $ bodyResult $ lambdaBody scat_lam pure $ map (SubExpRes lam_certs . Var) res_nms -- \| the recursive case builds a call to a map soac. mkWithAccBdy' static_arg (dim : dims) dims_rev iot_par_nms rshp_ps cons_ps = do scope <- askScope runBodyBuilder $ localScope (scope <> scopeOfLParams (rshp_ps ++ cons_ps)) $ do iota_arr <- letExp "iota_arr" $ BasicOp $ Iota dim se0 se1 Int64 iota_p <- newParam "iota_arg" $ Prim $ IntType Int64 rshp_ps' <- forM (zip [0 .. length rshp_ps - 1] (map paramDec rshp_ps)) $ \(i, arr_tp) -> newParam ("rshp_arg_" ++ show i) $ stripArray 1 arr_tp cons_ps' <- forM (zip [0 .. length cons_ps - 1] (map paramDec cons_ps)) $ \(i, arr_tp) -> newParam ("acc_arg_" ++ show i) arr_tp map_lam_bdy <- mkWithAccBdy' static_arg dims (dim : dims_rev) (iot_par_nms ++ [paramName iota_p]) rshp_ps' cons_ps' let map_lam = Lambda (rshp_ps' ++ [iota_p] ++ cons_ps') (map paramDec cons_ps') map_lam_bdy map_inps = map paramName rshp_ps ++ [iota_arr] ++ map paramName cons_ps map_soac = F.Screma dim map_inps $ ScremaForm map_lam [] [] res_nms <- letTupExp "acc_res" $ Op map_soac pure $ map (subExpRes . Var) res_nms --------------------------------------------------- --- II. 
--- WithAcc-WithAcc Fusion
---------------------------------------------------

-- | Bundles everything that belongs to a single input of a withAcc:
--
--   1. the pattern elements the withAcc binds for that input
--      (one per array of the input);
--   2. the 'WithAccInput' itself;
--   3. the matching certificate parameter of the withAcc lambda;
--   4. the matching accumulator parameter of the withAcc lambda;
--   5. the name of the matching lambda result, with its certificates.
type AccTup =
  ( [PatElem (LetDec SOACS)],
    WithAccInput SOACS,
    LParam SOACS,
    LParam SOACS,
    (VName, Certs)
  )

-- | Component 1 of an 'AccTup': the bound pattern elements.
accTup1 :: AccTup -> [PatElem (LetDec SOACS)]
accTup1 (pes, _, _, _, _) = pes

-- | Component 2 of an 'AccTup': the withAcc input.
accTup2 :: AccTup -> WithAccInput SOACS
accTup2 (_, winp, _, _, _) = winp

-- | Component 3 of an 'AccTup': the certificate parameter.
accTup3 :: AccTup -> LParam SOACS
accTup3 (_, _, cert_p, _, _) = cert_p

-- | Component 4 of an 'AccTup': the accumulator parameter.
accTup4 :: AccTup -> LParam SOACS
accTup4 (_, _, _, acc_p, _) = acc_p

-- | Component 5 of an 'AccTup': the lambda result name and its certificates.
accTup5 :: AccTup -> (VName, Certs)
accTup5 (_, _, _, _, res) = res

-- | Simple case for fusing two withAccs (can be extended):
--     let (b1, ..., bm, x1, ..., xq) = withAcc a1 ... am lam1
--     let (d1, ..., dn, y1, ..., yp) = withAcc c1 ... cn lam2
-- Notation: `b1 ... bm` are the accumulator results of the
--     first withAcc and `d1, ..., dn` of the second withAcc.
--     `x1 ... xq` and `y1, ..., yp` are non-accumulator results.
-- Conservative conditions:
--   1. for any bi (i=1..m) either `bi IN {c1, ..., cn}` OR
--           `bi NOT-IN FV(lam2)`, i.e., perfect producer-consumer
--           relation on accums. Of course the binary-op should
--           be the same.
--   2. The `bs` that are also accumulated upon in lam2
--          do NOT belong to the `infusible` set (they are destroyed)
--   3. x1 ... xq do not overlap with c1 ... cn
-- Fusion will create one withacc that accumulates on the
--   union of `a1 ... am` and `c1 ... cn` and returns, in addition
--   to the accumulator arrays the union of regular variables
--   `x1 ...
-- xq` and `y1, ..., yp`.
--
-- The result is 'Nothing' (fall-through clause) when either statement is
-- not a withAcc or when any of the conditions above is violated.
tryFuseWithAccs ::
  (HasScope SOACS m, MonadFreshNames m) =>
  [VName] ->
  Stm SOACS ->
  Stm SOACS ->
  m (Maybe (Stm SOACS))
tryFuseWithAccs
  infusible
  (Let pat1 aux1 (WithAcc w_inps1 lam1))
  (Let pat2 aux2 (WithAcc w_inps2 lam2))
    | (pat1_els, pat2_els) <- (patElems pat1, patElems pat2),
      (acc_tup1, other_pr1) <- groupAccs pat1_els w_inps1 lam1,
      (acc_tup2, other_pr2) <- groupAccs pat2_els w_inps2 lam2,
      (tup_common, acc_tup1', acc_tup2') <-
        groupCommonAccs acc_tup1 acc_tup2,
      -- safety 0: make sure that the accs from acc_tup1' and
      --           acc_tup2' do not overlap
      pnms_1' <- map patElemName $ concatMap (\(nms, _, _, _, _) -> nms) acc_tup1',
      winp_2' <- concatMap (\(_, (_, nms, _), _, _, _) -> nms) acc_tup2',
      not $ namesIntersect (namesFromList pnms_1') (namesFromList winp_2'),
      -- safety 1: we have already determined the commons;
      --           now we also need to check NOT-IN FV(lam2)
      not $ namesIntersect (namesFromList pnms_1') (freeIn lam2),
      -- safety 2:
      -- bs <- map patElemName $ concatMap accTup1 acc_tup1,
      bs <- map patElemName $ concatMap (accTup1 . fst) tup_common,
      all (`notElem` infusible) bs,
      -- safety 3:
      cs <- namesFromList $ concatMap ((\(_, xs, _) -> xs) . accTup2) acc_tup2,
      all ((`notNameIn` cs) . patElemName . fst) other_pr1 = do
        -- lam2 refers to the certificates of the second withAcc; for the
        -- common accumulators those are replaced by lam1's certificates
        let getCertPairs (t1, t2) = (paramName (accTup3 t2), paramName (accTup3 t1))
            tab_certs = M.fromList $ map getCertPairs tup_common
            lam2_bdy' = substituteNames tab_certs (lambdaBody lam2)
            rcrt_params = map (accTup3 . fst) tup_common ++ map accTup3 acc_tup1' ++ map accTup3 acc_tup2'
            racc_params = map (accTup4 . fst) tup_common ++ map accTup4 acc_tup1' ++ map accTup4 acc_tup2'
            (comm_res_nms, comm_res_certs2) = unzip $ map (accTup5 . snd) tup_common
            (_, comm_res_certs1) = unzip $ map (accTup5 . fst) tup_common
            com_res_certs = zipWith (\x y -> Certs (unCerts x ++ unCerts y)) comm_res_certs1 comm_res_certs2
            bdyres_certs = com_res_certs ++ map (snd . accTup5) (acc_tup1' ++ acc_tup2')
            bdyres_accse = map Var comm_res_nms ++ map (Var . fst . accTup5) (acc_tup1' ++ acc_tup2')
            bdy_res_accs = zipWith SubExpRes bdyres_certs bdyres_accse
            bdy_res_others = map snd $ other_pr1 ++ other_pr2
        scope <- askScope
        lam_bdy <-
          runBodyBuilder $ do
            localScope (scope <> scopeOfLParams (rcrt_params ++ racc_params)) $ do
              -- add the stms of lam1
              mapM_ addStm $ stmsToList $ bodyStms $ lambdaBody lam1
              -- add the copy stms for the common accumulator
              forM_ tup_common $ \(tup1, tup2) -> do
                let (lpar1, lpar2) = (accTup4 tup1, accTup4 tup2)
                    ((nm1, _), nm2, tp_acc) = (accTup5 tup1, paramName lpar2, paramDec lpar1)
                letBind (Pat [PatElem nm2 tp_acc]) $ BasicOp $ SubExp $ Var nm1
              -- add copy stms to bring in scope x1 ... xq
              forM_ other_pr1 $ \(pat_elm, bdy_res) -> do
                let (nm, se, tp) = (patElemName pat_elm, resSubExp bdy_res, patElemType pat_elm)
                certifying (resCerts bdy_res) $
                  letBind (Pat [PatElem nm tp]) $
                    BasicOp (SubExp se)
              -- add the statements of lam2 (in which the acc-certificates have been substituted)
              mapM_ addStm $ stmsToList $ bodyStms lam2_bdy'
              -- build the result of body
              pure $ bdy_res_accs ++ bdy_res_others
        let tp_res_other = map (patElemType . fst) (other_pr1 ++ other_pr2)
            res_lam =
              Lambda
                { lambdaParams = rcrt_params ++ racc_params,
                  lambdaBody = lam_bdy,
                  lambdaReturnType = map paramDec racc_params ++ tp_res_other
                }
        res_lam' <- renameLambda res_lam
        let res_pat =
              concatMap (accTup1 . snd) tup_common
                ++ concatMap accTup1 (acc_tup1' ++ acc_tup2')
                ++ map fst (other_pr1 ++ other_pr2)
            res_w_inps = map (accTup2 . fst) tup_common ++ map accTup2 (acc_tup1' ++ acc_tup2')
        res_w_inps' <- mapM renameLamInWAccInp res_w_inps
        let stm_res = Let (Pat res_pat) (aux1 <> aux2) $ WithAcc res_w_inps' res_lam'
        pure $ Just stm_res
    where
      -- local helpers:

      -- Split a withAcc into one 'AccTup' per input plus the
      -- (pattern element, result) pairs of its non-accumulator results.
      -- Assumes the lambda parameters are all certificate parameters
      -- followed by the same number of accumulator parameters.
      groupAccs ::
        [PatElem (LetDec SOACS)] ->
        [WithAccInput SOACS] ->
        Lambda SOACS ->
        ([AccTup], [(PatElem (LetDec SOACS), SubExpRes)])
      groupAccs pat_els wacc_inps wlam =
        let lam_params = lambdaParams wlam
            n = length lam_params
            (lam_par_crts, lam_par_accs) = splitAt (n `div` 2) lam_params
            lab_res_ses = bodyResult $ lambdaBody wlam
         in groupAccsHlp pat_els wacc_inps lam_par_crts lam_par_accs lab_res_ses
      groupAccsHlp ::
        [PatElem (LetDec SOACS)] ->
        [WithAccInput SOACS] ->
        [LParam SOACS] ->
        [LParam SOACS] ->
        [SubExpRes] ->
        ([AccTup], [(PatElem (LetDec SOACS), SubExpRes)])
      groupAccsHlp pat_els [] [] [] lam_res_ses
        | length pat_els == length lam_res_ses =
            ([], zip pat_els lam_res_ses)
      groupAccsHlp
        pat_els
        (winp@(_, inp, _) : wacc_inps)
        (par_crt : lam_par_crts)
        (par_acc : lam_par_accs)
        (res_se : lam_res_ses)
          | n <- length inp,
            (n <= length pat_els) && (n <= (1 + length lam_res_ses)),
            Var res_nm <- resSubExp res_se =
              -- one withAcc input binds one pattern element per array
              let (pat_els_cur, pat_els') = splitAt n pat_els
                  (rec1, rec2) = groupAccsHlp pat_els' wacc_inps lam_par_crts lam_par_accs lam_res_ses
               in ((pat_els_cur, winp, par_crt, par_acc, (res_nm, resCerts res_se)) : rec1, rec2)
      groupAccsHlp _ _ _ _ _ =
        error "Unreachable case reached in groupAccsHlp!"
      --
      -- Pair each accumulator of the first withAcc with its (unique)
      -- matching accumulator of the second one.  Returns the pairs, the
      -- unpaired accumulators of the first withAcc, and the unpaired
      -- accumulators of the second withAcc.
      groupCommonAccs :: [AccTup] -> [AccTup] -> ([(AccTup, AccTup)], [AccTup], [AccTup])
      groupCommonAccs [] tup_accs2 =
        ([], [], tup_accs2)
      groupCommonAccs (tup_acc1 : tup_accs1) tup_accs2 =
        case filter (matchingAccTup tup_acc1) tup_accs2 of
          [] ->
            -- no partner: tup_acc1 stays private to the first withAcc
            let (rec1, rec2, rec3) = groupCommonAccs tup_accs1 tup_accs2
             in (rec1, tup_acc1 : rec2, rec3)
          [tup_acc2] ->
            -- exactly one partner: pair them and drop the partner from
            -- further consideration.  The unpaired accumulators of the
            -- first withAcc are the ones found by the recursive call
            -- ('rec2'); the raw remainder 'tup_accs1' may still contain
            -- accumulators that get paired further down, and returning it
            -- would list those twice.
            let (rec1, rec2, rec3) =
                  groupCommonAccs tup_accs1 $
                    filter (not . matchingAccTup tup_acc1) tup_accs2
             in ((tup_acc1, tup_acc2) : rec1, rec2, rec3)
          _ ->
            error "Unreachable case reached in groupCommonAccs!"
      -- Give the combining lambda of a withAcc input fresh names.
      renameLamInWAccInp (shp, inps, Just (lam, se)) = do
        lam' <- renameLambda lam
        pure (shp, inps, Just (lam', se))
      renameLamInWAccInp winp = pure winp
--
tryFuseWithAccs _ _ _ =
  pure Nothing

-------------------------------
--- simple helper functions ---
-------------------------------

-- | Conservative syntactic equivalence of two lambdas modulo renaming.
-- The table maps names of the second lambda to names of the first; it is
-- extended with the parameters and with the patterns of pairwise
-- equivalent statements (see 'equivStm').  Parameter types and
-- attributes, return types, body decorations and the number of
-- statements must coincide, and the results must agree after
-- substitution.
equivLambda ::
  M.Map VName VName ->
  Lambda SOACS ->
  Lambda SOACS ->
  Bool
equivLambda stab lam1 lam2
  | (ps1, ps2) <- (lambdaParams lam1, lambdaParams lam2),
    (nms1, nms2) <- (map paramName ps1, map paramName ps2),
    map paramDec ps1 == map paramDec ps2,
    map paramAttrs ps1 == map paramAttrs ps2,
    lambdaReturnType lam1 == lambdaReturnType lam2,
    (bdy1, bdy2) <- (lambdaBody lam1, lambdaBody lam2),
    bodyDec bdy1 == bodyDec bdy2,
    (stms1, stms2) <- (stmsToList (bodyStms bdy1), stmsToList (bodyStms bdy2)),
    -- 'zip' below would otherwise silently ignore trailing statements
    length stms1 == length stms2 =
      let insert tab (x, k) = M.insert k x tab
          stab' = foldl insert stab $ zip nms1 nms2
          fStm (vtab, False) _ = (vtab, False)
          fStm (vtab, True) (s1, s2) = equivStm vtab s1 s2
          (stab'', success) = foldl fStm (stab', True) $ zip stms1 stms2
          sres2 = substInSEs stab'' $ map resSubExp $ bodyResult bdy2
       in success && map resSubExp (bodyResult bdy1) == sres2
equivLambda _ _ _ = False

-- | Equivalence of two statements under a renaming table (second-to-first
-- names).  Only 'BinOp' statements are handled so far: same operator,
-- same aux, same pattern decorations, and operands equal after
-- substitution.  On success the table is extended with the pattern names.
equivStm ::
  M.Map VName VName ->
  Stm SOACS ->
  Stm SOACS ->
  (M.Map VName VName, Bool)
equivStm
  stab
  (Let pat1 aux1 (BasicOp (BinOp bop1 se11 se12)))
  (Let pat2 aux2 (BasicOp (BinOp bop2 se21 se22)))
    | [se11, se12] == substInSEs stab [se21, se22],
      (pels1, pels2) <- (patElems pat1, patElems pat2),
      map patElemDec pels1 == map patElemDec pels2,
      bop1 == bop2 && aux1 == aux2 =
        let stab_new =
              M.fromList $
                zip (map patElemName pels2) (map patElemName pels1)
         in (M.union stab_new stab, True)
-- To Be Continued ...
equivStm vtab _ _ = (vtab, False)

-- | Whether the second accumulator consumes exactly what the first one
-- produces: same accumulator shape, the first's pattern names are the
-- second's input arrays, and either both have no combining operator or
-- both have equivalent operators with equal neutral elements.
matchingAccTup :: AccTup -> AccTup -> Bool
matchingAccTup
  (pat_els1, (shp1, _winp_arrs1, mlam1), _, _, _)
  (_, (shp2, winp_arrs2, mlam2), _, _, _) =
    shapeDims shp1 == shapeDims shp2
      && map patElemName pat_els1 == winp_arrs2
      && case (mlam1, mlam2) of
        (Nothing, Nothing) -> True
        (Just (lam1, see1), Just (lam2, see2)) ->
          (see1 == see2) && equivLambda M.empty lam1 lam2
        _ -> False

-- | Rename the variables of a list of subexpressions according to the table.
substInSEs :: M.Map VName VName -> [SubExp] -> [SubExp]
substInSEs vtab = map substInSE
  where
    substInSE (Var x)
      | Just y <- M.lookup x vtab = Var y
    substInSE z = z
futhark-0.25.27/src/Futhark/Optimise/Fusion/TryFusion.hs000066400000000000000000000755161475065116200231050ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-}

-- | Facilities for fusing two SOACs.
--
-- When the fusion algorithm decides that it's worth fusing two SOAC
-- statements, this is the module that tries to see if that's
-- possible.  May involve massaging either producer or consumer in
-- various ways.
module Futhark.Optimise.Fusion.TryFusion
  ( FusedSOAC (..),
    Mode (..),
    attemptFusion,
  )
where

import Control.Applicative
import Control.Arrow (first)
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.List (find, tails, (\\))
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.HORep.MapNest qualified as MapNest
import Futhark.Analysis.HORep.SOAC qualified as SOAC
import Futhark.Construct
import Futhark.IR.SOACS hiding (SOAC (..))
import Futhark.IR.SOACS qualified as Futhark
import Futhark.Optimise.Fusion.Composing
import Futhark.Pass.ExtractKernels.ISRWIM (rwimPossible)
import Futhark.Transform.Rename (renameLambda)
import Futhark.Transform.Substitute
import Futhark.Util (splitAt3)

newtype TryFusion a
  = TryFusion
      ( ReaderT
          (Scope SOACS)
          (StateT VNameSource Maybe)
          a
      )
  deriving
    ( Functor,
      Applicative,
      Alternative,
      Monad,
      MonadFail,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )

-- | Run a 'TryFusion' computation in the given scope, using the name
-- source of the surrounding monad.  On failure the name source is left
-- untouched and 'Nothing' is returned.
tryFusion ::
  (MonadFreshNames m) =>
  TryFusion a ->
  Scope SOACS ->
  m (Maybe a)
tryFusion (TryFusion m) types = modifyNameSource $ \src ->
  case runStateT (runReaderT m types) src of
    Just (x, src') -> (Just x, src')
    Nothing -> (Nothing, src)

-- | Lift a 'Maybe' into 'TryFusion', failing on 'Nothing'.
liftMaybe :: Maybe a -> TryFusion a
liftMaybe Nothing = fail "Nothing"
liftMaybe (Just x) = pure x

type SOAC = SOAC.SOAC SOACS

type MapNest = MapNest.MapNest SOACS

-- | Split off the first transform (as seen by 'SOAC.viewf') of an input,
-- if it has any.
inputToOutput :: SOAC.Input -> Maybe (SOAC.ArrayTransform, SOAC.Input)
inputToOutput (SOAC.Input ts ia iat) =
  case SOAC.viewf ts of
    t SOAC.:< ts' -> Just (t, SOAC.Input ts' ia iat)
    SOAC.EmptyF -> Nothing

-- | A fused SOAC contains a bit of extra information.
data FusedSOAC = FusedSOAC
  { -- | The actual SOAC.
    fsSOAC :: SOAC,
    -- | A transformation to be applied to *all* results of the SOAC.
    fsOutputTransform :: SOAC.ArrayTransforms,
    -- | The outputs of the SOAC (i.e. the names in the pattern that
    -- the result of this SOAC should be bound to).
    fsOutNames :: [VName]
  }
  deriving (Show)

-- | The inputs of the fused SOAC.
inputs :: FusedSOAC -> [SOAC.Input]
inputs = SOAC.inputs . fsSOAC

-- | Replace the inputs of the fused SOAC.
setInputs :: [SOAC.Input] -> FusedSOAC -> FusedSOAC
setInputs inps ker = ker {fsSOAC = inps `SOAC.setInputs` fsSOAC ker}

-- | Optimise the producer (see 'optimizeSOAC').  The resulting output
-- transforms are pushed onto the consumer inputs that read producer
-- outputs, and the fusion rules are retried.
tryOptimizeSOAC ::
  Mode ->
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
tryOptimizeSOAC mode unfus_nms outVars soac ker = do
  (soac', ots) <- optimizeSOAC Nothing soac mempty
  let ker' = map (addInitialTransformIfRelevant ots) (inputs ker) `setInputs` ker
      outIdents = zipWith Ident outVars $ SOAC.typeOf soac'
      ker'' = fixInputTypes outIdents ker'
  applyFusionRules mode unfus_nms outVars soac' ker''
  where
    addInitialTransformIfRelevant ots inp
      | SOAC.inputArray inp `elem` outVars =
          SOAC.addInitialTransforms ots inp
      | otherwise =
          inp

-- | Optimise the consumer (see 'optimizeKernel') and retry the fusion rules.
tryOptimizeKernel ::
  Mode ->
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
tryOptimizeKernel mode unfus_nms outVars soac ker = do
  ker' <- optimizeKernel (Just outVars) ker
  applyFusionRules mode unfus_nms outVars soac ker'

-- | Expose the producer outputs among the consumer's inputs.  Any
-- transforms this leaves over must be pulled into the producer.  That is
-- only attempted when no producer output has to be kept.  The fusion
-- rules are retried afterwards.
tryExposeInputs ::
  Mode ->
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
tryExposeInputs mode unfus_nms outVars soac ker = do
  (ker', ots) <- exposeInputs outVars ker
  if SOAC.nullTransforms ots
    then fuseSOACwithKer mode unfus_nms outVars soac ker'
    else do
      guard $ unfus_nms == mempty
      (soac', ots') <- pullOutputTransforms soac ots
      let outIdents = zipWith Ident outVars $ SOAC.typeOf soac'
          ker'' = fixInputTypes outIdents ker'
      if SOAC.nullTransforms ots'
        then applyFusionRules mode unfus_nms outVars soac' ker''
        else fail "tryExposeInputs could not pull SOAC transforms"

-- | Set the type of every consumer input that reads one of the given
-- identifiers to that identifier's type.
fixInputTypes :: [Ident] -> FusedSOAC -> FusedSOAC
fixInputTypes outIdents ker =
  ker {fsSOAC = fixInputTypes' $ fsSOAC ker}
  where
    fixInputTypes' soac =
      map fixInputType (SOAC.inputs soac) `SOAC.setInputs` soac
    fixInputType (SOAC.Input ts v _)
      | Just v' <- find ((== v) . identName) outIdents =
          SOAC.Input ts v $ identType v'
    fixInputType inp = inp

-- | Try, in order: optimising the producer, optimising the consumer,
-- direct fusion, and exposing the consumer's inputs.  The first that
-- succeeds wins.
applyFusionRules ::
  Mode ->
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
applyFusionRules mode unfus_nms outVars soac ker =
  tryOptimizeSOAC mode unfus_nms outVars soac ker
    <|> tryOptimizeKernel mode unfus_nms outVars soac ker
    <|> fuseSOACwithKer mode unfus_nms outVars soac ker
    <|> tryExposeInputs mode unfus_nms outVars soac ker

-- | Whether we are doing horizontal or vertical fusion.  Note that
-- vertical also includes "diagonal" fusion, where some producer
-- results are also produced by the final SOAC.
data Mode = Horizontal | Vertical

-- | Attempt fusing the producer into the consumer.
attemptFusion ::
  (HasScope SOACS m, MonadFreshNames m) =>
  Mode ->
  -- | Outputs of the producer that should still be output by the
  -- fusion result (corresponding to "diagonal fusion").
  Names ->
  -- | The outputs of the SOAC.
  [VName] ->
  SOAC ->
  FusedSOAC ->
  m (Maybe FusedSOAC)
attemptFusion mode unfus_nms outVars soac ker = do
  scope <- askScope
  tryFusion (applyFusionRules mode unfus_nms outVars soac ker) scope

-- | Check that the consumer does not use any scan or reduce results.
scremaFusionOK :: ([VName], [VName]) -> FusedSOAC -> Bool
scremaFusionOK (nonmap_outs, _map_outs) ker =
  all (`notElem` nonmap_outs) $ mapMaybe SOAC.isVarishInput (inputs ker)

-- | Check that the consumer uses all the outputs of the producer unmodified.
mapWriteFusionOK :: [VName] -> FusedSOAC -> Bool
mapWriteFusionOK outVars ker = all (`elem` inpIds) outVars
  where
    inpIds = mapMaybe SOAC.isVarishInput (inputs ker)

-- | The brain of this module: Fusing a SOAC with a Kernel.
fuseSOACwithKer ::
  Mode ->
  Names ->
  [VName] ->
  SOAC ->
  FusedSOAC ->
  TryFusion FusedSOAC
fuseSOACwithKer mode unfus_set outVars soac_p ker = do
  -- We are fusing soac_p into soac_c, i.e, the output of soac_p is going
  -- into soac_c.
  let soac_c = fsSOAC ker
      inp_p_arr = SOAC.inputs soac_p
      inp_c_arr = SOAC.inputs soac_c
      lam_p = SOAC.lambda soac_p
      lam_c = SOAC.lambda soac_c
      w = SOAC.width soac_p
      returned_outvars = filter (`nameIn` unfus_set) outVars
      success res_outnms res_soac = do
        -- Avoid name duplication, because the producer lambda is not
        -- removed from the program until much later.
        uniq_lam <- renameLambda $ SOAC.lambda res_soac
        pure $
          ker
            { fsSOAC = uniq_lam `SOAC.setLambda` res_soac,
              fsOutNames = res_outnms
            }

  -- Can only fuse SOACs with same width.
  guard $ SOAC.width soac_p == SOAC.width soac_c

  -- If we are getting rid of a producer output, then it must be used
  -- exclusively without any transformations.
  let ker_inputs = map SOAC.inputArray (inputs ker)
      okInput v inp = v /= SOAC.inputArray inp || isJust (SOAC.isVarishInput inp)
      inputOrUnfus v = all (okInput v) (inputs ker) || v `notElem` ker_inputs

  guard $ all inputOrUnfus outVars

  outPairs <- forM (zip outVars $ map rowType $ SOAC.typeOf soac_p) $ \(outVar, t) -> do
    outVar' <- newVName $ baseString outVar ++ "_elem"
    pure (outVar, Ident outVar' t)

  let mapLikeFusionCheck =
        let (res_lam, new_inp) = fuseMaps unfus_set lam_p inp_p_arr outPairs lam_c inp_c_arr
            (extra_nms, extra_rtps) =
              unzip $
                filter ((`nameIn` unfus_set) . fst) $
                  zip outVars $
                    map (stripArray 1) $
                      SOAC.typeOf soac_p
            res_lam' = res_lam {lambdaReturnType = lambdaReturnType res_lam ++ extra_rtps}
         in (extra_nms, res_lam', new_inp)

  case (soac_c, soac_p, mode) of
    _ | SOAC.width soac_p /= SOAC.width soac_c -> fail "SOAC widths must match."
    (_, _, Horizontal)
      | not (SOAC.nullTransforms $ fsOutputTransform ker) ->
          fail "Horizontal fusion is invalid in the presence of output transforms."
    (_, _, Vertical)
      | unfus_set /= mempty,
        not (SOAC.nullTransforms $ fsOutputTransform ker) ->
          fail "Cannot perform diagonal fusion in the presence of output transforms."
    (SOAC.Screma _ _ (ScremaForm _ scans_c reds_c), SOAC.Screma _ _ (ScremaForm _ scans_p reds_p), _)
      | scremaFusionOK
          ( splitAt
              (Futhark.scanResults scans_p + Futhark.redResults reds_p)
              outVars
          )
          ker -> do
          let red_nes_p = concatMap redNeutral reds_p
              red_nes_c = concatMap redNeutral reds_c
              scan_nes_p = concatMap scanNeutral scans_p
              scan_nes_c = concatMap scanNeutral scans_c
              (res_lam', new_inp) =
                fuseRedomap
                  unfus_set
                  outVars
                  lam_p
                  scan_nes_p
                  red_nes_p
                  inp_p_arr
                  outPairs
                  lam_c
                  scan_nes_c
                  red_nes_c
                  inp_c_arr
              (soac_p_scanout, soac_p_redout, _soac_p_mapout) =
                splitAt3 (length scan_nes_p) (length red_nes_p) outVars
              (soac_c_scanout, soac_c_redout, soac_c_mapout) =
                splitAt3 (length scan_nes_c) (length red_nes_c) $ fsOutNames ker
              unfus_arrs = returned_outvars \\ (soac_p_scanout ++ soac_p_redout)
          success
            ( soac_p_scanout
                ++ soac_c_scanout
                ++ soac_p_redout
                ++ soac_c_redout
                ++ soac_c_mapout
                ++ unfus_arrs
            )
            $ SOAC.Screma w new_inp (ScremaForm res_lam' (scans_p ++ scans_c) (reds_p ++ reds_c))

    ------------------
    -- Scatter fusion --
    ------------------

    -- Map-Scatter fusion.
    --
    -- The 'inplace' mechanism for kernels already takes care of
    -- checking that the Scatter is not writing to any array used in
    -- the Map.
    (SOAC.Scatter _len _ivs dests _lam, SOAC.Screma _ _ form, _)
      | isJust $ isMapSOAC form,
        -- 1. all arrays produced by the map are ONLY used (consumed)
        --    by the scatter, i.e., not used elsewhere.
        all (`notNameIn` unfus_set) outVars,
        -- 2. all arrays produced by the map are input to the scatter.
        mapWriteFusionOK outVars ker -> do
          let (extra_nms, res_lam', new_inp) = mapLikeFusionCheck
          success (fsOutNames ker ++ extra_nms) $
            SOAC.Scatter w new_inp dests res_lam'

    -- Map-Hist fusion.
    --
    -- The 'inplace' mechanism for kernels already takes care of
    -- checking that the Hist is not writing to any array used in
    -- the Map.
    (SOAC.Hist _ _ ops _, SOAC.Screma _ _ form, _)
      | isJust $ isMapSOAC form,
        -- 1. all arrays produced by the map are ONLY used (consumed)
        --    by the hist, i.e., not used elsewhere.
        all (`notNameIn` unfus_set) outVars,
        -- 2. all arrays produced by the map are input to the hist.
        mapWriteFusionOK outVars ker -> do
          let (extra_nms, res_lam', new_inp) = mapLikeFusionCheck
          success (fsOutNames ker ++ extra_nms) $
            SOAC.Hist w new_inp ops res_lam'

    -- Hist-Hist fusion
    (SOAC.Hist _ _ ops_c _, SOAC.Hist _ _ ops_p _, Horizontal) -> do
      let p_num_buckets = length ops_p
          c_num_buckets = length ops_c
          (body_p, body_c) = (lambdaBody lam_p, lambdaBody lam_c)
          body' =
            Body
              { bodyDec = bodyDec body_p, -- body_p and body_c have the same decorations
                bodyStms = bodyStms body_p <> bodyStms body_c,
                bodyResult =
                  take c_num_buckets (bodyResult body_c)
                    ++ take p_num_buckets (bodyResult body_p)
                    ++ drop c_num_buckets (bodyResult body_c)
                    ++ drop p_num_buckets (bodyResult body_p)
              }
          lam' =
            Lambda
              { lambdaParams = lambdaParams lam_c ++ lambdaParams lam_p,
                lambdaBody = body',
                lambdaReturnType =
                  replicate (c_num_buckets + p_num_buckets) (Prim int64)
                    ++ drop c_num_buckets (lambdaReturnType lam_c)
                    ++ drop p_num_buckets (lambdaReturnType lam_p)
              }
      success (fsOutNames ker ++ returned_outvars) $
        SOAC.Hist w (inp_c_arr <> inp_p_arr) (ops_c <> ops_p) lam'

    -- Scatter-write fusion.
    (SOAC.Scatter _w_c ivs_c as_c _lam_c, SOAC.Scatter _w_p ivs_p as_p _lam_p, Horizontal) -> do
      let zipW as_xs xs as_ys ys = xs_indices ++ ys_indices ++ xs_vals ++ ys_vals
            where
              (xs_indices, xs_vals) = splitScatterResults as_xs xs
              (ys_indices, ys_vals) = splitScatterResults as_ys ys
      let (body_p, body_c) = (lambdaBody lam_p, lambdaBody lam_c)
      let body' =
            Body
              { bodyDec = bodyDec body_p, -- body_p and body_c have the same decorations
                bodyStms = bodyStms body_p <> bodyStms body_c,
                bodyResult = zipW as_c (bodyResult body_c) as_p (bodyResult body_p)
              }
      let lam' =
            Lambda
              { lambdaParams = lambdaParams lam_c ++ lambdaParams lam_p,
                lambdaBody = body',
                lambdaReturnType = zipW as_c (lambdaReturnType lam_c) as_p (lambdaReturnType lam_p)
              }
      success (fsOutNames ker ++ returned_outvars) $
        SOAC.Scatter w (ivs_c ++ ivs_p) (as_c ++ as_p) lam'
    (SOAC.Scatter {}, _, _) ->
      fail "Cannot fuse a scatter with anything else than a scatter or a map"
    (_, SOAC.Scatter {}, _) ->
      fail "Cannot fuse a scatter with anything else than a scatter or a map"
    ----------------------------
    -- Stream-Stream Fusions: --
    ----------------------------
    (SOAC.Stream {}, SOAC.Stream {}, _) -> do
      -- fuse two SEQUENTIAL streams
      (res_nms, res_stream) <-
        fuseStreamHelper (fsOutNames ker) unfus_set outVars outPairs soac_c soac_p
      success res_nms res_stream
    -------------------------------------------------------------------
    --- If one is a stream, translate the other to a stream as well.---
    --- This does not get in trouble (infinite computation) because ---
    ---   scan's translation to Stream introduces a hindrance to    ---
    ---   (horizontal fusion), hence repeated application is for the---
    ---   moment impossible. However, if with a dependence-graph rep---
    ---   we could run in an infinite recursion, i.e., repeatedly   ---
    ---   fusing map o scan into an infinity of Stream levels!      ---
    -------------------------------------------------------------------
    (SOAC.Stream {}, _, _) -> do
      -- If this rule is matched then soac_p is NOT a stream.
      -- To fuse a stream kernel, we transform soac_p to a stream, which
      -- borrows the sequential/parallel property of the soac_c Stream,
      -- and recursively perform stream-stream fusion.
      (soac_p', newacc_ids) <- SOAC.soacToStream soac_p
      fuseSOACwithKer
        mode
        (namesFromList (map identName newacc_ids) <> unfus_set)
        (map identName newacc_ids ++ outVars)
        soac_p'
        ker
    (_, SOAC.Screma _ _ form, _)
      | Just _ <- Futhark.isScanomapSOAC form -> do
          -- A Scan soac can be currently only fused as a (sequential) stream,
          -- hence it is first translated to a (sequential) Stream and then
          -- fusion with a kernel is attempted.
          (soac_p', newacc_ids) <- SOAC.soacToStream soac_p
          if soac_p' /= soac_p
            then
              fuseSOACwithKer
                mode
                (namesFromList (map identName newacc_ids) <> unfus_set)
                (map identName newacc_ids ++ outVars)
                soac_p'
                ker
            else fail "SOAC could not be turned into stream."
    (_, SOAC.Stream {}, _) -> do
      -- If it reached this case then soac_c is NOT a Stream kernel,
      -- hence transform the kernel's soac to a stream and attempt
      -- stream-stream fusion recursively.
      -- The newly created stream corresponding to soac_c borrows the
      -- sequential/parallel property of the soac_p stream.
      (soac_c', newacc_ids) <- SOAC.soacToStream soac_c
      if soac_c' /= soac_c
        then
          fuseSOACwithKer mode (namesFromList (map identName newacc_ids) <> unfus_set) outVars soac_p $
            ker {fsSOAC = soac_c', fsOutNames = map identName newacc_ids ++ fsOutNames ker}
        else fail "SOAC could not be turned into stream."

    ---------------------------------
    --- DEFAULT, CANNOT FUSE CASE ---
    ---------------------------------
    _ -> fail "Cannot fuse"

-- | Fuse a producer stream (last argument) into a consumer stream (the
-- argument before it).
--
-- The producer's chunk parameter replaces the consumer's.  The remaining
-- lambdas are combined with 'fuseRedomap'.  The producer's neutral
-- elements and accumulator outputs come first in the result.  Fails
-- unless both SOACs are streams.
fuseStreamHelper ::
  [VName] ->
  Names ->
  [VName] ->
  [(VName, Ident)] ->
  SOAC ->
  SOAC ->
  TryFusion ([VName], SOAC)
fuseStreamHelper
  out_kernms
  unfus_set
  outVars
  outPairs
  (SOAC.Stream w2 inp2_arr nes2 lam2)
  (SOAC.Stream _ inp1_arr nes1 lam1) = do
    -- very similar to redomap o redomap composition, but need
    -- to remove first the `chunk' parameters of streams'
    -- lambdas and put them in the resulting stream lambda.
    let chunk1 = head $ lambdaParams lam1
        chunk2 = head $ lambdaParams lam2
        hmnms = M.fromList [(paramName chunk2, paramName chunk1)]
        lam20 = substituteNames hmnms lam2
        lam1' = lam1 {lambdaParams = tail $ lambdaParams lam1}
        lam2' = lam20 {lambdaParams = tail $ lambdaParams lam20}
        (res_lam', new_inp) =
          fuseRedomap
            unfus_set
            outVars
            lam1'
            []
            nes1
            inp1_arr
            outPairs
            lam2'
            []
            nes2
            inp2_arr
        res_lam'' = res_lam' {lambdaParams = chunk1 : lambdaParams res_lam'}
        unfus_accs = take (length nes1) outVars
        unfus_arrs = filter (`notElem` unfus_accs) $ filter (`nameIn` unfus_set) outVars
    pure
      ( unfus_accs ++ out_kernms ++ unfus_arrs,
        SOAC.Stream w2 new_inp (nes1 ++ nes2) res_lam''
      )
fuseStreamHelper _ _ _ _ _ _ =
  fail "Cannot Fuse Streams!"

-- Here follows optimizations and transforms to expose fusability.
-- | Run the SOAC optimisations on the consumer of a fused SOAC.  Any
-- transforms they produce are folded into the consumer's output
-- transforms.  Fails if no optimisation applies.
optimizeKernel :: Maybe [VName] -> FusedSOAC -> TryFusion FusedSOAC
optimizeKernel inp ker = do
  (soac, resTrans) <- optimizeSOAC inp (fsSOAC ker) (fsOutputTransform ker)
  pure $ ker {fsSOAC = soac, fsOutputTransform = resTrans}

-- | Apply every entry of 'optimizations' in turn.  A failing optimisation
-- is skipped.  Each successful one receives the SOAC and the output
-- transforms produced by the previous successful ones.  Fails with
-- "No optimisation applied" if none succeeded.
optimizeSOAC ::
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
optimizeSOAC inp soac os = do
  res <- foldM comb (False, soac, os) optimizations
  case res of
    (False, _, _) -> fail "No optimisation applied"
    (True, soac', os') -> pure (soac', os')
  where
    -- The accumulated transforms os' (not the initial os) are handed to
    -- the next optimisation; otherwise the transforms produced by an
    -- earlier optimisation would be lost as soon as a later one fires.
    comb (changed, soac', os') f =
      ( do
          (soac'', os'') <- f inp soac' os'
          pure (True, soac'', os'')
      )
        <|> pure (changed, soac', os')

-- | An optimisation rewrites a SOAC together with the transforms applied
-- to its outputs, or fails.
type Optimization =
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)

-- | The optimisations tried by 'optimizeSOAC', in order.
optimizations :: [Optimization]
optimizations = [iswim]

-- | Interchange a single-operator scan whose operator 'rwimPossible'
-- accepts, and whose neutral elements are all variables, into a map over
-- a scan.
--
-- The scanned arrays are transposed (dimensions 0 and 1) to become map
-- inputs.  A 'SOAC.Rearrange' that swaps the two outermost dimensions is
-- appended to the output transforms.
iswim ::
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
iswim _ (SOAC.Screma w arrs form) ots
  | Just [Futhark.Scan scan_fun nes] <- Futhark.isScanSOAC form,
    Just (map_pat, map_cs, map_w, map_fun) <- rwimPossible scan_fun,
    Just nes_names <- mapM subExpVar nes = do
      let nes_idents = zipWith Ident nes_names $ lambdaReturnType scan_fun
          map_nes = map SOAC.identInput nes_idents
          map_arrs' = map_nes ++ map (SOAC.transposeInput 0 1) arrs
          (scan_acc_params, scan_elem_params) =
            splitAt (length arrs) $ lambdaParams scan_fun
          map_params =
            map removeParamOuterDim scan_acc_params
              ++ map (setParamOuterDimTo w) scan_elem_params
          map_rettype = map (`setOuterSize` w) $ lambdaReturnType scan_fun

          scan_params = lambdaParams map_fun
          scan_body = lambdaBody map_fun
          scan_rettype = lambdaReturnType map_fun
          scan_fun' = Lambda scan_params scan_rettype scan_body
          nes' = map Var $ take (length map_nes) $ map paramName map_params
          arrs' = drop (length map_nes) $ map paramName map_params

      scan_form <- scanSOAC [Futhark.Scan scan_fun' nes']

      let map_stm =
            Let (setPatOuterDimTo w map_pat) (defAux ()) . Op $
              Futhark.Screma w arrs' scan_form
          map_body = mkBody (oneStm map_stm) $ varsRes $ patNames map_pat
          map_fun' = Lambda map_params map_rettype map_body
          perm = case lambdaReturnType scan_fun of -- instead of map_fun
            [] -> []
            t : _ -> 1 : 0 : [2 .. arrayRank t]

      pure
        ( SOAC.Screma map_w map_arrs' (mapSOAC map_fun'),
          ots SOAC.|> SOAC.Rearrange map_cs perm
        )
iswim _ _ _ =
  fail "ISWIM does not apply."

-- | Give a parameter the row type of its current type.
removeParamOuterDim :: LParam SOACS -> LParam SOACS
removeParamOuterDim param =
  let t = rowType $ paramType param
   in param {paramDec = t}

-- | Set the outermost dimension of a parameter's type to the given size.
setParamOuterDimTo :: SubExp -> LParam SOACS -> LParam SOACS
setParamOuterDimTo w param =
  let t = paramType param `setOuterSize` w
   in param {paramDec = t}

-- | Set the outermost dimension of every pattern element's type.
setPatOuterDimTo :: SubExp -> Pat Type -> Pat Type
setPatOuterDimTo w = fmap (`setOuterSize` w)

-- Now for fiddling with transpositions...

-- | Peel off the leading transforms shared by all inputs whose array is
-- in @interesting@.  Returns those transforms and the inputs with them
-- removed.
commonTransforms ::
  [VName] ->
  [SOAC.Input] ->
  (SOAC.ArrayTransforms, [SOAC.Input])
commonTransforms interesting inps = commonTransforms' inps'
  where
    inps' =
      [ (SOAC.inputArray inp `elem` interesting, inp)
        | inp <- inps
      ]

-- | Worker for 'commonTransforms'.  Inputs flagged 'True' must all start
-- with the same transform for it to be peeled off.  Peeling repeats until
-- that no longer holds.
commonTransforms' :: [(Bool, SOAC.Input)] -> (SOAC.ArrayTransforms, [SOAC.Input])
commonTransforms' inps =
  case foldM inspect (Nothing, []) inps of
    Just (Just mot, inps') -> first (mot SOAC.<|) $ commonTransforms' $ reverse inps'
    _ -> (SOAC.noTransforms, map snd inps)
  where
    inspect (mot, prev) (True, inp) =
      case (mot, inputToOutput inp) of
        (Nothing, Just (ot, inp')) -> Just (Just ot, (True, inp') : prev)
        (Just ot1, Just (ot2, inp'))
          | ot1 == ot2 -> Just (Just ot2, (True, inp') : prev)
        _ -> Nothing
    inspect (mot, prev) inp = Just (mot, inp : prev)

-- | One more than the smaller of the nest's level count and the minimum
-- rank among the return types of the outermost level (or of the lambda,
-- when there are no levels).
mapDepth :: MapNest -> Int
mapDepth (MapNest.MapNest _ lam levels _) =
  min resDims (length levels) + 1
  where
    resDims = minDim $ case levels of
      [] -> lambdaReturnType lam
      nest : _ -> MapNest.nestingReturnType nest
    minDim [] = 0
    minDim (t : ts) = foldl min (arrayRank t) $ map arrayRank ts

-- | Move a leading 'SOAC.Rearrange' output transform onto the inputs of
-- the map nest.  Only done when the permutation does not reach deeper
-- than 'mapDepth'.
pullRearrange ::
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pullRearrange soac ots = do
  nest <- liftMaybe =<< MapNest.fromSOAC soac
  SOAC.Rearrange cs perm SOAC.:< ots' <- pure $ SOAC.viewf ots
  if rearrangeReach perm <= mapDepth nest
    then do
      let -- Expand perm to cover the full extent of the input dimensionality
          perm' inp = take r perm ++ [length perm .. r - 1]
            where
              r = SOAC.inputRank inp
          addPerm inp = SOAC.addTransform (SOAC.Rearrange cs $ perm' inp) inp
          inputs' = map addPerm $ MapNest.inputs nest
      soac' <-
        MapNest.toSOAC $
          inputs' `MapNest.setInputs` rearrangeReturnTypes nest perm
      pure (soac', ots')
    else fail "Cannot pull transpose"

-- | Move a leading 'SOAC.Index' output transform whose outer slice is a
-- 'DimSlice' into a map.
--
-- The outer slice is applied to every input and the map width becomes
-- the slice width.  The inner slices are applied to the lambda results
-- only when the result shapes differ from the sliced inner shape.
pullIndex ::
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pullIndex (SOAC.Screma _ inps form) ots
  | SOAC.Index cs slice@(Slice (ds@(DimSlice _ w' _) : inner_ds))
      SOAC.:< ots' <-
      SOAC.viewf ots,
    Just lam <- isMapSOAC form = do
      let sliceInput inp =
            SOAC.addTransform
              (SOAC.Index cs (fullSlice (SOAC.inputType inp) [ds]))
              inp
          sliceRes (SubExpRes rcs (Var v)) =
            certifying rcs
              . fmap subExpRes
              . letSubExp (baseString v <> "_sliced")
              $ BasicOp (Index v (Slice inner_ds))
          sliceRes r = pure r
          inner_changed =
            any
              ((/= stripDims 1 (sliceShape slice)) . arrayShape)
              (lambdaReturnType lam)
      lam' <-
        if not inner_changed
          then pure lam
          else runLambdaBuilder (lambdaParams lam) $ mapM sliceRes =<< bodyBind (lambdaBody lam)
      pure (SOAC.Screma w' (map sliceInput inps) (mapSOAC lam'), ots')
pullIndex _ _ = fail "Cannot pull index"

-- | Push a permutation (computed by 'fixupInputs') from the inputs of a
-- map nest to its outputs.  The inverse permutation is prepended to the
-- output transforms.  Only done when the permutation does not reach
-- deeper than 'mapDepth'.
pushRearrange ::
  [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pushRearrange inpIds soac ots = do
  nest <- liftMaybe =<< MapNest.fromSOAC soac
  (perm, inputs') <- liftMaybe $ fixupInputs inpIds $ MapNest.inputs nest
  if rearrangeReach perm <= mapDepth nest
    then do
      let invertRearrange = SOAC.Rearrange mempty $ rearrangeInverse perm
      soac' <-
        MapNest.toSOAC $
          inputs' `MapNest.setInputs` rearrangeReturnTypes nest perm
      pure (soac', invertRearrange SOAC.<| ots)
    else fail "Cannot push transpose"

-- | Actually also rearranges indices.
rearrangeReturnTypes :: MapNest -> [Int] -> MapNest rearrangeReturnTypes nest@(MapNest.MapNest w body nestings inps) perm = MapNest.MapNest w body ( zipWith setReturnType nestings $ drop 1 $ iterate (map rowType) ts ) inps where origts = MapNest.typeOf nest -- The permutation may be deeper than the rank of the type, -- but it is required that it is an identity permutation -- beyond that. This is supposed to be checked as an -- invariant by whoever calls rearrangeReturnTypes. rearrangeType' t = rearrangeType (take (arrayRank t) perm) t ts = map rearrangeType' origts setReturnType nesting t' = nesting {MapNest.nestingReturnType = t'} fixupInputs :: [VName] -> [SOAC.Input] -> Maybe ([Int], [SOAC.Input]) fixupInputs inpIds inps = case mapMaybe inputRearrange $ filter exposable inps of perm : _ -> do inps' <- mapM (fixupInput (rearrangeReach perm) perm) inps pure (perm, inps') _ -> Nothing where exposable = (`elem` inpIds) . SOAC.inputArray inputRearrange (SOAC.Input ts _ _) | _ SOAC.:> SOAC.Rearrange _ perm <- SOAC.viewl ts = Just perm inputRearrange _ = Nothing fixupInput d perm inp | r <- SOAC.inputRank inp, r >= d = Just $ SOAC.addTransform (SOAC.Rearrange mempty $ take r perm) inp | otherwise = Nothing pullReshape :: SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms) pullReshape (SOAC.Screma _ inps form) ots | Just maplam <- Futhark.isMapSOAC form, SOAC.Reshape cs k shape SOAC.:< ots' <- SOAC.viewf ots, all primType $ lambdaReturnType maplam = do let mapw' = case reverse $ shapeDims shape of [] -> intConst Int64 0 d : _ -> d trInput inp | arrayRank (SOAC.inputType inp) == 1 = SOAC.addTransform (SOAC.Reshape cs k shape) inp | otherwise = SOAC.addTransform (SOAC.ReshapeOuter cs k shape) inp inputs' = map trInput inps inputTypes = map SOAC.inputType inputs' let outersoac :: ([SOAC.Input] -> SOAC) -> (SubExp, [SubExp]) -> TryFusion ([SOAC.Input] -> SOAC) outersoac inner (w, outershape) = do let addDims t = arrayOf t (Shape outershape) NoUniqueness 
retTypes = map addDims $ lambdaReturnType maplam ps <- forM inputTypes $ \inpt -> newParam "pullReshape_param" $ stripArray (length shape - length outershape) inpt inner_body <- runBodyBuilder $ varsRes <$> (letTupExp "x" <=< SOAC.toExp $ inner $ map (SOAC.identInput . paramIdent) ps) let inner_fun = Lambda { lambdaParams = ps, lambdaReturnType = retTypes, lambdaBody = inner_body } pure $ flip (SOAC.Screma w) $ Futhark.mapSOAC inner_fun op' <- foldM outersoac (flip (SOAC.Screma mapw') $ Futhark.mapSOAC maplam) $ zip (drop 1 $ reverse $ shapeDims shape) $ drop 1 . reverse . drop 1 . tails $ shapeDims shape pure (op' inputs', ots') pullReshape _ _ = fail "Cannot pull reshape" -- Tie it all together in exposeInputs (for making inputs to a -- consumer available) and pullOutputTransforms (for moving -- output-transforms of a producer to its inputs instead). exposeInputs :: [VName] -> FusedSOAC -> TryFusion (FusedSOAC, SOAC.ArrayTransforms) exposeInputs inpIds ker = (exposeInputs' =<< pushRearrange') <|> (exposeInputs' =<< pullRearrange') <|> (exposeInputs' =<< pullIndex') <|> exposeInputs' ker where ot = fsOutputTransform ker pushRearrange' = do (soac', ot') <- pushRearrange inpIds (fsSOAC ker) ot pure ker { fsSOAC = soac', fsOutputTransform = ot' } pullRearrange' = do (soac', ot') <- pullRearrange (fsSOAC ker) ot unless (SOAC.nullTransforms ot') $ fail "pullRearrange was not enough" pure ker { fsSOAC = soac', fsOutputTransform = SOAC.noTransforms } pullIndex' = do (soac', ot') <- pullIndex (fsSOAC ker) ot unless (SOAC.nullTransforms ot') $ fail "pullIndex was not enough" pure ker { fsSOAC = soac', fsOutputTransform = SOAC.noTransforms } exposeInputs' ker' = case commonTransforms inpIds $ inputs ker' of (ot', inps') | all exposed inps' -> pure (ker' {fsSOAC = inps' `SOAC.setInputs` fsSOAC ker'}, ot') _ -> fail "Cannot expose" exposed (SOAC.Input ts _ _) | SOAC.nullTransforms ts = True exposed inp = SOAC.inputArray inp `notElem` inpIds outputTransformPullers :: [ SOAC -> 
SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms) ] outputTransformPullers = [pullRearrange, pullReshape, pullIndex] pullOutputTransforms :: SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms) pullOutputTransforms = attempt outputTransformPullers where attempt [] _ _ = fail "Cannot pull anything" attempt (p : ps) soac ots = do (soac', ots') <- p soac ots if SOAC.nullTransforms ots' then pure (soac', SOAC.noTransforms) else pullOutputTransforms soac' ots' <|> pure (soac', ots') <|> attempt ps soac ots futhark-0.25.27/src/Futhark/Optimise/GenRedOpt.hs000066400000000000000000000426271475065116200215240ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | Tries to turn a generalized reduction kernel into -- a more specialized construct, for example: -- (a) a map nest with a sequential redomap ripe for tiling -- (b) a SegRed kernel followed by a smallish accumulation kernel. -- (c) a histogram (for this we need to track the withAccs) -- The idea is to identify the first accumulation and -- to separate the initial kernels into two: -- 1. the code up to and including the accumulation, -- which is optimized to turn the accumulation either -- into a map-reduce composition or a histogram, and -- 2. the remaining code, which is recursively optimized. -- Since this is mostly prototyping, when the accumulation -- can be rewritten as a map-reduce, we sequentialize the -- map-reduce, as to potentially enable tiling oportunities. module Futhark.Optimise.GenRedOpt (optimiseGenRed) where import Control.Monad import Control.Monad.Reader import Control.Monad.State import Data.List qualified as L import Data.Map.Strict qualified as M import Data.Maybe import Futhark.Builder import Futhark.IR.GPU import Futhark.Optimise.TileLoops.Shared import Futhark.Pass import Futhark.Tools import Futhark.Transform.Rename type GenRedM = ReaderT (Scope GPU) (State VNameSource) -- | The pass definition. 
optimiseGenRed :: Pass GPU GPU optimiseGenRed = Pass "optimise generalized reductions" "Specializes generalized reductions into map-reductions or histograms" $ intraproceduralTransformation onStms where onStms scope stms = modifyNameSource $ runState $ runReaderT (optimiseStms (M.empty, M.empty) stms) scope optimiseBody :: Env -> Body GPU -> GenRedM (Body GPU) optimiseBody env (Body () stms res) = Body () <$> optimiseStms env stms <*> pure res optimiseStms :: Env -> Stms GPU -> GenRedM (Stms GPU) optimiseStms env stms = localScope (scopeOf stms) $ do (_, stms') <- foldM foldfun (env, mempty) $ stmsToList stms pure stms' where foldfun :: (Env, Stms GPU) -> Stm GPU -> GenRedM (Env, Stms GPU) foldfun (e, ss) s = do (e', s') <- optimiseStm e s pure (e', ss <> s') optimiseStm :: Env -> Stm GPU -> GenRedM (Env, Stms GPU) optimiseStm env stm@(Let _ _ (Op (SegOp (SegMap SegThread {} _ _ _)))) = do res_genred_opt <- genRedOpts env stm let stms' = case res_genred_opt of Just stms -> stms Nothing -> oneStm stm pure (env, stms') optimiseStm env (Let pat aux e) = do env' <- changeEnv env (head $ patNames pat) e e' <- mapExpM (optimise env') e pure (env', oneStm $ Let pat aux e') where optimise env' = identityMapper {mapOnBody = \scope -> localScope scope . 
optimiseBody env'} ------------------------ genRedOpts :: Env -> Stm GPU -> GenRedM (Maybe (Stms GPU)) genRedOpts env ker = do res_tile <- genRed2Tile2d env ker case res_tile of Nothing -> do res_sgrd <- genRed2SegRed env ker helperGenRed res_sgrd _ -> helperGenRed res_tile where helperGenRed Nothing = pure Nothing helperGenRed (Just (stms_before, ker_snd)) = do mb_stms_after <- genRedOpts env ker_snd case mb_stms_after of Just stms_after -> pure $ Just $ stms_before <> stms_after Nothing -> pure $ Just $ stms_before <> oneStm ker_snd genRed2Tile2d :: Env -> Stm GPU -> GenRedM (Maybe (Stms GPU, Stm GPU)) genRed2Tile2d env kerstm@(Let pat_ker aux (Op (SegOp (SegMap seg_thd seg_space kres_tps old_kbody)))) | SegThread _novirt _ <- seg_thd, -- novirt == SegNoVirtFull || novirt == SegNoVirt, KernelBody () kstms kres <- old_kbody, Just (css, r_ses) <- allGoodReturns kres, null css, -- build the variance table, that records, for -- each variable name, the variables it depends on initial_variance <- M.map mempty $ scopeOfSegSpace seg_space, variance <- varianceInStms initial_variance kstms, -- check that the code fits the pattern having: -- some `code1`, followed by one accumulation, followed by some `code2` -- UpdateAcc VName [SubExp] [SubExp] (code1, Just accum_stmt, code2) <- matchCodeAccumCode kstms, Let pat_accum _aux_acc (BasicOp (UpdateAcc safety acc_nm acc_inds acc_vals)) <- accum_stmt, [pat_acc_nm] <- patNames pat_accum, -- check that the `acc_inds` are invariant to at least one -- parallel kernel dimensions, and return the innermost such one: Just (invar_gid, gid_ind) <- isInvarToParDim mempty seg_space variance acc_inds, gid_dims_new_0 <- filter (\x -> invar_gid /= fst x) (unSegSpace seg_space), -- reorder the variant dimensions such that inner(most) accum-indices -- correspond to inner(most) parallel dimensions, so that the babysitter -- does not introduce transpositions -- gid_dims_new <- gid_dims_new_0, gid_dims_new <- reorderParDims variance acc_inds 
gid_dims_new_0, -- check that all global-memory accesses in `code1` on which -- `accum_stmt` depends on are invariant to at least one of -- the remaining parallel dimensions (i.e., excluding `invar_gid`) all (isTileable invar_gid gid_dims_new variance pat_acc_nm) (stmsToList code1), -- need to establish a cost model for the stms that would now -- be redundantly executed by the two kernels. If any recurence -- is redundant than it is a no go. Otherwise we need to look at -- memory accesses: if more than two are re-executed, then we -- should abort. cost <- costRedundantExecution variance pat_acc_nm r_ses kstms, maxCost cost (Small 2) == Small 2 = do -- 1. create the first kernel acc_tp <- lookupType acc_nm let inv_dim_len = segSpaceDims seg_space !! gid_ind -- 1.1. get the accumulation operator ((redop0, neutral), el_tps) = getAccLambda acc_tp redop <- renameLambda redop0 let red = Reduce { redComm = Commutative, redLambda = redop, redNeutral = neutral } -- 1.2. build the sequential map-reduce screma code1' = stmsFromList $ filter (dependsOnAcc pat_acc_nm variance) $ stmsToList code1 (code1'', code1_tr_host) <- transposeFVs (freeIn kerstm) variance invar_gid code1' let map_lam_body = mkBody code1'' $ map (SubExpRes (Certs [])) acc_vals map_lam0 = Lambda [Param mempty invar_gid (Prim int64)] el_tps map_lam_body map_lam <- renameLambda map_lam0 (k1_res, ker1_stms) <- runBuilderT' $ do iota <- letExp "iota" $ BasicOp $ Iota inv_dim_len (intConst Int64 0) (intConst Int64 1) Int64 let op_exp = Op (OtherOp (Screma inv_dim_len [iota] (ScremaForm map_lam [] [red]))) res_redmap <- letTupExp "res_mapred" op_exp letSubExp (baseString pat_acc_nm ++ "_big_update") $ BasicOp (UpdateAcc safety acc_nm acc_inds $ map Var res_redmap) -- 1.3. build the kernel expression and rename it! gid_flat_1 <- newVName "gid_flat" let space1 = SegSpace gid_flat_1 gid_dims_new let level1 = SegThread (SegNoVirtFull (SegSeqDims [])) Nothing -- novirt ? 
kbody1 = KernelBody () ker1_stms [Returns ResultMaySimplify (Certs []) k1_res] -- is it OK here to use the "aux" from the parrent kernel? ker_exp <- renameExp $ Op (SegOp (SegMap level1 space1 [acc_tp] kbody1)) let ker1 = Let pat_accum aux ker_exp -- 2 build the second kernel let ker2_body = old_kbody {kernelBodyStms = code1 <> code2} ker2_exp <- renameExp $ Op (SegOp (SegMap seg_thd seg_space kres_tps ker2_body)) let ker2 = Let pat_ker aux ker2_exp pure $ Just (code1_tr_host <> oneStm ker1, ker2) where isIndVarToParDim _ (Constant _) _ = False isIndVarToParDim variance (Var acc_ind) par_dim = acc_ind == fst par_dim || nameIn (fst par_dim) (M.findWithDefault mempty acc_ind variance) foldfunReorder variance (unused_dims, inner_dims) acc_ind = case L.findIndex (isIndVarToParDim variance acc_ind) unused_dims of Nothing -> (unused_dims, inner_dims) Just i -> ( take i unused_dims ++ drop (i + 1) unused_dims, (unused_dims !! i) : inner_dims ) reorderParDims variance acc_inds gid_dims_new_0 = let (invar_dims, inner_dims) = foldl (foldfunReorder variance) (gid_dims_new_0, []) (reverse acc_inds) in invar_dims ++ inner_dims -- getAccLambda acc_tp = case acc_tp of (Acc tp_id _shp el_tps _) -> case M.lookup tp_id (fst env) of Just lam -> (lam, el_tps) _ -> error $ "Lookup in environment failed! " ++ prettyString tp_id ++ " env: " ++ show (fst env) _ -> error "Illegal accumulator type!" -- is a subexp invariant to a gid of a parallel dimension? isSeInvar2 variance gid (Var x) = let x_deps = M.findWithDefault mempty x variance in gid /= x && gid `notNameIn` x_deps isSeInvar2 _ _ _ = True -- is a DimIndex invar to a gid of a parallel dimension? 
isDimIdxInvar2 variance gid (DimFix d) = isSeInvar2 variance gid d isDimIdxInvar2 variance gid (DimSlice d1 d2 d3) = all (isSeInvar2 variance gid) [d1, d2, d3] -- is an entire slice invariant to at least one gid of a parallel dimension isSliceInvar2 variance slc = any (\gid -> all (isDimIdxInvar2 variance gid) (unSlice slc)) -- are all statements that touch memory invariant to at least one parallel dimension? isTileable :: VName -> [(VName, SubExp)] -> VarianceTable -> VName -> Stm GPU -> Bool isTileable seq_gid gid_dims variance acc_nm (Let (Pat [pel]) _ (BasicOp (Index _ slc))) | acc_deps <- M.findWithDefault mempty acc_nm variance, patElemName pel `nameIn` acc_deps = let invar_par = isSliceInvar2 variance slc (map fst gid_dims) invar_seq = isSliceInvar2 variance slc [seq_gid] in invar_par || invar_seq -- this relies on the cost model, that currently accepts only -- global-memory reads, and for example rejects in-place updates -- or loops inside the code that is transformed in a redomap. isTileable _ _ _ _ _ = True -- does the to-be-reduced accumulator depends on this statement? 
dependsOnAcc pat_acc_nm variance (Let pat _ _) = let acc_deps = M.findWithDefault mempty pat_acc_nm variance in any (`nameIn` acc_deps) $ patNames pat genRed2Tile2d _ _ = pure Nothing genRed2SegRed :: Env -> Stm GPU -> GenRedM (Maybe (Stms GPU, Stm GPU)) genRed2SegRed _ _ = pure Nothing transposeFVs :: Names -> VarianceTable -> VName -> Stms GPU -> GenRedM (Stms GPU, Stms GPU) transposeFVs fvs variance gid stms = do (tab, stms') <- foldM foldfun (M.empty, mempty) $ stmsToList stms let stms_host = M.foldr (\(_, _, s) ss -> ss <> s) mempty tab pure (stms', stms_host) where foldfun (tab, all_stms) stm = do (tab', stm') <- transposeFV (tab, stm) pure (tab', all_stms <> oneStm stm') -- ToDo: currently handles only 2-dim arrays, please generalize transposeFV (tab, Let pat aux (BasicOp (Index arr slc))) | dims <- unSlice slc, all isFixDim dims, arr `nameIn` fvs, iis <- L.findIndices depOnGid dims, [ii] <- iis, -- generalize below: treat any rearange and add to tab if not there. Nothing <- M.lookup arr tab, ii /= length dims - 1, perm <- [0 .. ii - 1] ++ [ii + 1 .. length dims - 1] ++ [ii] = do (arr_tr, stms_tr) <- runBuilderT' $ do arr' <- letExp (baseString arr ++ "_trsp") $ BasicOp $ Rearrange perm arr -- Manifest [1,0] arr letExp (baseString arr' ++ "_opaque") $ BasicOp $ Opaque OpaqueNil $ Var arr' let tab' = M.insert arr (perm, arr_tr, stms_tr) tab slc' = Slice $ map (dims !!) perm stm' = Let pat aux $ BasicOp $ Index arr_tr slc' pure (tab', stm') where isFixDim DimFix {} = True isFixDim _ = False depOnGid (DimFix (Var nm)) = gid == nm || nameIn gid (M.findWithDefault mempty nm variance) depOnGid _ = False transposeFV r = pure r -- | Tries to identify the following pattern: -- code followed by some UpdateAcc-statement -- followed by more code. 
matchCodeAccumCode :: Stms GPU -> (Stms GPU, Maybe (Stm GPU), Stms GPU) matchCodeAccumCode kstms = let (code1, screma, code2) = foldl ( \acc stmt -> case (acc, stmt) of ((cd1, Nothing, cd2), Let _ _ (BasicOp UpdateAcc {})) -> (cd1, Just stmt, cd2) ((cd1, Nothing, cd2), _) -> (cd1 ++ [stmt], Nothing, cd2) ((cd1, Just strm, cd2), _) -> (cd1, Just strm, cd2 ++ [stmt]) ) ([], Nothing, []) (stmsToList kstms) in (stmsFromList code1, screma, stmsFromList code2) -- | Checks that there exist a parallel dimension (among @kids@), -- to which all the indices (@acc_inds@) are invariant to. -- It returns the innermost such parallel dimension, as a tuple -- of the pardim gid ('VName') and its index ('Int') in the -- parallel space. isInvarToParDim :: Names -> SegSpace -> VarianceTable -> [SubExp] -> Maybe (VName, Int) isInvarToParDim branch_variant kspace variance acc_inds = let ker_gids = map fst $ unSegSpace kspace branch_invariant = all (`notNameIn` branch_variant) ker_gids allvar2 = allvariant2 acc_inds ker_gids last_invar_dim = foldl (lastNotIn allvar2) Nothing $ zip ker_gids [0 .. length ker_gids - 1] in if branch_invariant then last_invar_dim else Nothing where variant2 (Var ind) kids = let variant_to = M.findWithDefault mempty ind variance <> (if ind `elem` kids then oneName ind else mempty) in filter (`nameIn` variant_to) kids variant2 _ _ = [] allvariant2 ind_ses kids = namesFromList $ concatMap (`variant2` kids) ind_ses lastNotIn allvar2 acc (kid, k) = if kid `nameIn` allvar2 then acc else Just (kid, k) allGoodReturns :: [KernelResult] -> Maybe ([VName], [SubExp]) allGoodReturns kres | all goodReturn kres = do Just $ foldl addCertAndRes ([], []) kres where goodReturn (Returns ResultMaySimplify _ _) = True goodReturn _ = False addCertAndRes (cs, rs) (Returns ResultMaySimplify c r_se) = (cs ++ unCerts c, rs ++ [r_se]) addCertAndRes _ _ = error "Impossible case reached in GenRedOpt.hs, function allGoodReturns!" 
allGoodReturns _ = Nothing -------------------------- --- Cost Model Helpers --- -------------------------- costRedundantExecution :: VarianceTable -> VName -> [SubExp] -> Stms GPU -> Cost costRedundantExecution variance pat_acc_nm r_ses kstms = let acc_deps = M.findWithDefault mempty pat_acc_nm variance vartab_cut_acc = varianceInStmsWithout (oneName pat_acc_nm) mempty kstms res_deps = mconcat $ map (findDeps vartab_cut_acc) $ mapMaybe se2nm r_ses common_deps = namesIntersection res_deps acc_deps in foldl (addCostOfStmt common_deps) (Small 0) kstms where se2nm (Var nm) = Just nm se2nm _ = Nothing findDeps vartab nm = M.findWithDefault mempty nm vartab addCostOfStmt common_deps cur_cost stm = let pat_nms = patNames $ stmPat stm in if namesIntersect (namesFromList pat_nms) common_deps then addCosts cur_cost $ costRedundantStmt stm else cur_cost varianceInStmsWithout :: Names -> VarianceTable -> Stms GPU -> VarianceTable varianceInStmsWithout nms = L.foldl' (varianceInStmWithout nms) varianceInStmWithout cuts vartab stm = let pat_nms = patNames $ stmPat stm in if namesIntersect (namesFromList pat_nms) cuts then vartab else L.foldl' add vartab pat_nms where add variance' v = M.insert v binding_variance variance' look variance' v = oneName v <> M.findWithDefault mempty v variance' binding_variance = mconcat $ map (look vartab) $ namesToList (freeIn stm) data Cost = Small Int | Big | Break deriving (Eq) addCosts :: Cost -> Cost -> Cost addCosts Break _ = Break addCosts _ Break = Break addCosts Big _ = Big addCosts _ Big = Big addCosts (Small c1) (Small c2) = Small (c1 + c2) maxCost :: Cost -> Cost -> Cost maxCost (Small c1) (Small c2) = Small (max c1 c2) maxCost c1 c2 = addCosts c1 c2 costBody :: Body GPU -> Cost costBody bdy = foldl addCosts (Small 0) $ map costRedundantStmt $ stmsToList $ bodyStms bdy costRedundantStmt :: Stm GPU -> Cost costRedundantStmt (Let _ _ (Op _)) = Big costRedundantStmt (Let _ _ Loop {}) = Big costRedundantStmt (Let _ _ Apply {}) = Big 
costRedundantStmt (Let _ _ WithAcc {}) = Big costRedundantStmt (Let _ _ (Match _ cases defbody _)) = L.foldl' maxCost (costBody defbody) $ map (costBody . caseBody) cases costRedundantStmt (Let _ _ (BasicOp (ArrayLit _ Array {}))) = Big costRedundantStmt (Let _ _ (BasicOp (ArrayLit _ _))) = Small 1 costRedundantStmt (Let _ _ (BasicOp (Index _ slc))) = if all isFixDim (unSlice slc) then Small 1 else Small 0 where isFixDim DimFix {} = True isFixDim _ = False costRedundantStmt (Let _ _ (BasicOp FlatIndex {})) = Small 0 costRedundantStmt (Let _ _ (BasicOp Update {})) = Break costRedundantStmt (Let _ _ (BasicOp FlatUpdate {})) = Break costRedundantStmt (Let _ _ (BasicOp Concat {})) = Big costRedundantStmt (Let _ _ (BasicOp Manifest {})) = Big costRedundantStmt (Let _ _ (BasicOp Replicate {})) = Big costRedundantStmt (Let _ _ (BasicOp UpdateAcc {})) = Break costRedundantStmt (Let _ _ (BasicOp _)) = Small 0 futhark-0.25.27/src/Futhark/Optimise/HistAccs.hs000066400000000000000000000130301475065116200213600ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | Turn certain uses of accumulators into SegHists. module Futhark.Optimise.HistAccs (histAccsGPU) where import Control.Monad import Control.Monad.Reader import Control.Monad.State import Data.Map.Strict qualified as M import Futhark.IR.GPU import Futhark.MonadFreshNames import Futhark.Pass import Futhark.Tools import Futhark.Transform.Rename import Prelude hiding (quot) -- | A mapping from accumulator variables to their source. 
type Accs rep = M.Map VName (WithAccInput rep) type OptM = ReaderT (Scope GPU) (State VNameSource) optimiseBody :: Accs GPU -> Body GPU -> OptM (Body GPU) optimiseBody accs body = mkBody <$> optimiseStms accs (bodyStms body) <*> pure (bodyResult body) optimiseExp :: Accs GPU -> Exp GPU -> OptM (Exp GPU) optimiseExp accs = mapExpM mapper where mapper = identityMapper { mapOnBody = \scope body -> localScope scope $ optimiseBody accs body } extractUpdate :: Accs rep -> VName -> Stms rep -> Maybe ((WithAccInput rep, VName, [SubExp], [SubExp]), Stms rep) extractUpdate accs v stms = do (stm, stms') <- stmsHead stms case stm of Let (Pat [PatElem pe_v _]) _ (BasicOp (UpdateAcc _ acc is vs)) | pe_v == v -> do acc_input <- M.lookup acc accs Just ((acc_input, acc, is, vs), stms') _ -> do (x, stms'') <- extractUpdate accs v stms' pure (x, oneStm stm <> stms'') mkHistBody :: Accs GPU -> KernelBody GPU -> Maybe (KernelBody GPU, WithAccInput GPU, VName) mkHistBody accs (KernelBody () stms [Returns rm cs (Var v)]) = do ((acc_input, acc, is, vs), stms') <- extractUpdate accs v stms pure ( KernelBody () stms' $ map (Returns rm cs) is ++ map (Returns rm cs) vs, acc_input, acc ) mkHistBody _ _ = Nothing withAccLamToHistLam :: (MonadFreshNames m) => Shape -> Lambda GPU -> m (Lambda GPU) withAccLamToHistLam shape lam = renameLambda $ lam {lambdaParams = drop (shapeRank shape) (lambdaParams lam)} addArrsToAcc :: (MonadBuilder m, Rep m ~ GPU) => SegLevel -> Shape -> [VName] -> VName -> m (Exp GPU) addArrsToAcc lvl shape arrs acc = do flat <- newVName "phys_tid" gtids <- replicateM (shapeRank shape) (newVName "gtid") let space = SegSpace flat $ zip gtids $ shapeDims shape (acc', stms) <- localScope (scopeOfSegSpace space) . collectStms $ do vs <- forM arrs $ \arr -> do arr_t <- lookupType arr letSubExp (baseString arr <> "_elem") $ BasicOp $ Index arr $ fullSlice arr_t $ map (DimFix . 
Var) gtids letExp (baseString acc <> "_upd") $ BasicOp $ UpdateAcc Safe acc (map Var gtids) vs acc_t <- lookupType acc pure . Op . SegOp . SegMap lvl space [acc_t] $ KernelBody () stms [Returns ResultMaySimplify mempty (Var acc')] flatKernelBody :: (MonadBuilder m) => SegSpace -> KernelBody (Rep m) -> m (SegSpace, KernelBody (Rep m)) flatKernelBody space kbody = do gtid <- newVName "gtid" dims_prod <- letSubExp "dims_prod" =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) (segSpaceDims space) let space' = SegSpace (segFlat space) [(gtid, dims_prod)] kbody_stms <- localScope (scopeOfSegSpace space') . collectStms_ $ do let new_inds = unflattenIndex (map pe64 (segSpaceDims space)) (pe64 $ Var gtid) zipWithM_ letBindNames (map (pure . fst) (unSegSpace space)) =<< mapM toExp new_inds addStms $ kernelBodyStms kbody pure (space', kbody {kernelBodyStms = kbody_stms}) optimiseStm :: Accs GPU -> Stm GPU -> OptM (Stms GPU) -- TODO: this is very restricted currently, but shows the idea. optimiseStm accs (Let pat aux (WithAcc inputs lam)) = do localScope (scopeOfLParams (lambdaParams lam)) $ do body' <- optimiseBody accs' $ lambdaBody lam let lam' = lam {lambdaBody = body'} pure $ oneStm $ Let pat aux $ WithAcc inputs lam' where acc_names = map paramName $ drop (length inputs) $ lambdaParams lam accs' = M.fromList (zip acc_names inputs) <> accs optimiseStm accs (Let pat aux (Op (SegOp (SegMap lvl space _ kbody)))) | accs /= mempty, Just (kbody', (acc_shape, _, Just (acc_lam, acc_nes)), acc) <- mkHistBody accs kbody, all primType $ lambdaReturnType acc_lam = runBuilder_ $ do hist_dests <- forM acc_nes $ \ne -> letExp "hist_dest" $ BasicOp $ Replicate acc_shape ne acc_lam' <- withAccLamToHistLam acc_shape acc_lam let ts' = replicate (shapeRank acc_shape) (Prim int64) ++ lambdaReturnType acc_lam histop = HistOp { histShape = acc_shape, histRaceFactor = intConst Int64 1, histDest = hist_dests, histNeutral = acc_nes, histOpShape = mempty, histOp = acc_lam' } (space', 
kbody'') <- flatKernelBody space kbody' hist_dest_upd <- letTupExp "hist_dest_upd" $ Op $ SegOp $ SegHist lvl space' [histop] ts' kbody'' addStm . Let pat aux =<< addArrsToAcc lvl acc_shape hist_dest_upd acc optimiseStm accs (Let pat aux e) = oneStm . Let pat aux <$> optimiseExp accs e optimiseStms :: Accs GPU -> Stms GPU -> OptM (Stms GPU) optimiseStms accs stms = localScope (scopeOf stms) $ mconcat <$> mapM (optimiseStm accs) (stmsToList stms) -- | The pass for GPU kernels. histAccsGPU :: Pass GPU GPU histAccsGPU = Pass "hist accs" "Turn certain accumulations into histograms" $ intraproceduralTransformation onStms where onStms scope stms = modifyNameSource . runState $ runReaderT (optimiseStms mempty stms) scope futhark-0.25.27/src/Futhark/Optimise/InliningDeadFun.hs000066400000000000000000000317131475065116200226650ustar00rootroot00000000000000-- | This module implements a compiler pass for inlining functions, -- then removing those that have become dead. module Futhark.Optimise.InliningDeadFun ( inlineAggressively, inlineConservatively, removeDeadFunctions, ) where import Control.Monad import Control.Monad.Identity import Control.Monad.State import Control.Parallel.Strategies import Data.Functor (($>)) import Data.List (partition) import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S import Futhark.Analysis.CallGraph import Futhark.Analysis.SymbolTable qualified as ST import Futhark.Builder import Futhark.IR.SOACS import Futhark.IR.SOACS.Simplify ( simpleSOACS, simplifyConsts, simplifyFun, ) import Futhark.Optimise.CSE import Futhark.Optimise.Simplify.Rep (addScopeWisdom, informStms) import Futhark.Pass import Futhark.Transform.CopyPropagate ( copyPropagateInFun, copyPropagateInProg, ) import Futhark.Transform.Rename parMapM :: (MonadFreshNames m) => (a -> State VNameSource b) -> [a] -> m [b] -- The special-casing of [] is quite important here! 
If 'as' is -- empty, then we might otherwise create an empty name source below, -- which can wreak all kinds of havoc. parMapM _ [] = pure [] parMapM f as = modifyNameSource $ \src -> let f' a = runState (f a) src (bs, srcs) = unzip $ parMap rpar f' as in (bs, mconcat srcs) -- It is more efficient to shrink the program as soon as possible, -- rather than wait until it has balooned after full inlining. This -- is the inverse rate at which we perform full simplification after -- inlining. For the other steps we just do copy propagation. The -- simplification rates used have been determined heuristically and -- are probably not optimal for any given program. inlineFunctions :: (MonadFreshNames m) => Int -> CallGraph -> S.Set Name -> Prog SOACS -> m (Prog SOACS) inlineFunctions simplify_rate cg what_should_be_inlined prog = do let consts = progConsts prog funs = progFuns prog vtable = ST.fromScope (addScopeWisdom (scopeOf consts)) (consts', funs') <- recurse (1, vtable) (consts, funs) what_should_be_inlined pure $ prog {progConsts = consts', progFuns = funs'} where fdmap fds = M.fromList $ zip (map funDefName fds) fds noCallsTo which from = S.null $ allCalledBy from cg `S.intersection` which recurse (i, vtable) (consts, funs) to_inline = do let (to_inline_now, to_inline_later) = S.partition (noCallsTo to_inline) to_inline (dont_inline_in, to_inline_in) = partition (noCallsTo to_inline_now . funDefName) funs if null to_inline_now then pure (consts, funs) else do let inlinemap = fdmap $ filter ((`S.member` to_inline_now) . funDefName) dont_inline_in (vtable', consts') <- if any (`calledByConsts` cg) to_inline_now then do consts' <- simplifyConsts . performCSEOnStms =<< inlineInStms inlinemap consts pure (ST.insertStms (informStms consts') mempty, consts') else pure (vtable, consts) let simplifyFun' fd | i `rem` simplify_rate == 0 = copyPropagateInFun simpleSOACS vtable' . 
performCSEOnFunDef True =<< simplifyFun vtable' fd | otherwise = copyPropagateInFun simpleSOACS vtable' fd onFun = simplifyFun' <=< inlineInFunDef inlinemap to_inline_in' <- parMapM onFun to_inline_in recurse (i + 1, vtable') (consts', dont_inline_in <> to_inline_in') to_inline_later calledOnce :: CallGraph -> S.Set Name calledOnce = S.fromList . map fst . filter ((== 1) . snd) . M.toList . numOccurences inlineBecauseTiny :: Prog SOACS -> S.Set Name inlineBecauseTiny = foldMap onFunDef . progFuns where onFunDef fd | (length (bodyStms (funDefBody fd)) <= k) || ("inline" `inAttrs` funDefAttrs fd) = S.singleton (funDefName fd) | otherwise = mempty where k = length (funDefRetType fd) + length (funDefParams fd) progStms :: Prog SOACS -> Stms SOACS progStms prog = progConsts prog <> foldMap (bodyStms . funDefBody) (progFuns prog) data Used = InSOAC | InAD deriving (Eq, Ord, Show) directlyCalledInSOACs :: Prog SOACS -> M.Map Name Used directlyCalledInSOACs = flip execState mempty . mapM_ (onStm Nothing) . progStms where onBody :: Maybe Used -> Body SOACS -> State (M.Map Name Used) () onBody u = mapM_ (onStm u) . bodyStms onStm u stm = onExp u (stmExp stm) $> stm onExp (Just u) (Apply fname _ _ _) = modify $ M.insertWith max fname u onExp Nothing Apply {} = pure () onExp u e = walkExpM (walker u) e onSOAC u soac = void $ traverseSOACStms (const (traverse (onStm u'))) soac where u' = max u $ Just $ usage soac usage JVP {} = InAD usage VJP {} = InAD usage _ = InSOAC walker u = (identityWalker :: Walker SOACS (State (M.Map Name Used))) { walkOnBody = const (onBody u), walkOnOp = onSOAC u } -- Expand set of function names with all reachable functions. withTransitiveCalls :: CallGraph -> M.Map Name Used -> M.Map Name Used withTransitiveCalls cg fs | fs == fs' = fs | otherwise = withTransitiveCalls cg fs' where look :: (Name, Used) -> M.Map Name Used look (f, u) = M.fromList $ map (,u) (S.toList (allCalledBy f cg)) fs' = foldr (M.unionWith max . 
look) fs $ M.toList fs calledInSOACs :: CallGraph -> Prog SOACS -> M.Map Name Used calledInSOACs cg prog = withTransitiveCalls cg $ directlyCalledInSOACs prog -- Inline those functions that are used in SOACs, and which involve -- arrays of any kind, as well as any functions used in AD. inlineBecauseSOACs :: CallGraph -> Prog SOACS -> S.Set Name inlineBecauseSOACs cg prog = S.fromList $ mapMaybe onFunDef (progFuns prog) where called = calledInSOACs cg prog isArray = not . primType inline _ InAD = True inline fd InSOAC = any (isArray . paramType) (funDefParams fd) || any (isArray . fst) (funDefRetType fd) || arrayInBody (funDefBody fd) onFunDef fd = do guard $ maybe False (inline fd) $ M.lookup (funDefName fd) called Just $ funDefName fd arrayInBody = any arrayInStm . bodyStms arrayInStm stm = any isArray (patTypes (stmPat stm)) || arrayInExp (stmExp stm) arrayInExp (Match _ cases defbody _) = any arrayInBody $ defbody : map caseBody cases arrayInExp (Loop _ _ body) = arrayInBody body arrayInExp _ = False -- Conservative inlining of functions that are called just once, or -- have #[inline] on them. consInlineFunctions :: (MonadFreshNames m) => Prog SOACS -> m (Prog SOACS) consInlineFunctions prog = inlineFunctions 4 cg (calledOnce cg <> inlineBecauseTiny prog) prog where cg = buildCallGraph prog -- Inline aggressively; in particular most things called from a SOAC. aggInlineFunctions :: (MonadFreshNames m) => Prog SOACS -> m (Prog SOACS) aggInlineFunctions prog = inlineFunctions 3 cg (inlineBecauseTiny prog <> inlineBecauseSOACs cg prog) prog where cg = buildCallGraph prog -- | @inlineInFunDef constf fdmap caller@ inlines in @calleer@ the -- functions in @fdmap@ that are called as @constf@. At this point the -- preconditions are that if @fdmap@ is not empty, and, more -- importantly, the functions in @fdmap@ do not call any other -- functions. 
inlineInFunDef ::
  (MonadFreshNames m) =>
  M.Map Name (FunDef SOACS) ->
  FunDef SOACS ->
  m (FunDef SOACS)
inlineInFunDef fdmap fd = do
  -- Only the body is rewritten; every other field of the definition is
  -- carried over unchanged.
  body' <- inlineInBody fdmap $ funDefBody fd
  pure fd {funDefBody = body'}

-- | Produce the statements that replace a call of @fun@ whose result is
-- bound to @pat@.  The function parameters are bound to the call
-- arguments, then comes the (renamed) function body annotated with the
-- attributes and locations of the call, and finally @pat@ is bound to the
-- body result.
inlineFunction ::
  (MonadFreshNames m) =>
  Pat Type ->
  StmAux dec ->
  [(SubExp, Diet)] ->
  (Safety, SrcLoc, [SrcLoc]) ->
  FunDef SOACS ->
  m (Stms SOACS)
inlineFunction pat aux args (safety, loc, locs) fun = do
  Body _ stms res <- renameBody $ mkBody (arg_stms <> callee_stms) callee_res
  let res_stms =
        stmsFromList
          [ certify cs $ bindTo v se
          | (v, SubExpRes cs se) <- zip (patIdents pat) res
          ]
  pure $ stms <> res_stms
  where
    callee_body = funDefBody fun
    callee_res = bodyResult callee_body
    -- Only locations that actually carry information are propagated.
    call_locs = [l | l <- loc : locs, locOf l /= mempty]
    -- Bind each parameter to its argument, under the certificates of the call.
    arg_stms =
      stmsFromList
        [ certify (stmAuxCerts aux) $ bindTo (paramIdent p) arg
        | (p, (arg, _)) <- zip (funDefParams fun) args
        ]
    callee_stms =
      addLocations (stmAuxAttrs aux) safety call_locs (bodyStms callee_body)
    -- Note that the sizes of arrays may not be correct at this
    -- point - it is crucial that we run copy propagation before
    -- the type checker sees this!
    bindTo ident se = mkLet [ident] $ BasicOp $ SubExp se

-- | Inline calls within a bare sequence of statements, by treating the
-- statements as a body without any result.
inlineInStms ::
  (MonadFreshNames m) =>
  M.Map Name (FunDef SOACS) ->
  Stms SOACS ->
  m (Stms SOACS)
inlineInStms fdmap stms = do
  body <- inlineInBody fdmap $ mkBody stms []
  pure $ bodyStms body

-- | Replace every call (also in nested bodies and SOAC lambdas) to a
-- function in @fdmap@ with the inlined body of that function, unless the
-- callee or the call site carries the @noinline@ attribute.
inlineInBody ::
  (MonadFreshNames m) =>
  M.Map Name (FunDef SOACS) ->
  Body SOACS ->
  m (Body SOACS)
inlineInBody fdmap = goBody
  where
    goBody (Body dec stms res) = do
      stms' <- goStms $ stmsToList stms
      pure $ Body dec stms' res

    -- Statements are processed front to back.
    goStms [] = pure mempty
    goStms (stm : stms) = do
      stm' <- goStm stm
      rest <- goStms stms
      pure $ stm' <> rest

    goStm stm@(Let pat aux e) = case e of
      Apply fname args _ what
        | Just fd <- M.lookup fname fdmap,
          not $ "noinline" `inAttrs` funDefAttrs fd,
          not $ "noinline" `inAttrs` stmAuxAttrs aux ->
            inlineFunction pat aux args what fd
      -- Basic operations contain neither calls nor nested bodies.
      BasicOp {} -> pure $ oneStm stm
      _ -> oneStm . Let pat aux <$> mapExpM mapper e

    mapper =
      (identityMapper @SOACS)
        { mapOnBody = const goBody,
          mapOnOp = mapSOACM identitySOACMapper {mapOnSOACLambda = goLambda}
        }

    goLambda lam = do
      body' <- goBody $ lambdaBody lam
      pure lam {lambdaBody = body'}

-- Propagate source locations and attributes to the inlined
-- statements. Attributes are propagated only when applicable (this
-- probably means that every supported attribute needs to be handled
-- specially here).
addLocations :: Attrs -> Safety -> [SrcLoc] -> Stms SOACS -> Stms SOACS
addLocations attrs caller_safety more_locs = fmap annotate
  where
    -- The part of the attributes that is relevant to assertions.
    assert_attrs = attrsForAssert attrs

    addAttrs extra aux = aux {stmAuxAttrs = extra <> stmAuxAttrs aux}

    -- Recursively annotate the statements of a nested body.
    annotateBody extra body =
      body {bodyStms = addLocations extra caller_safety more_locs $ bodyStms body}

    annotateLambda lam =
      lam {lambdaBody = annotateBody assert_attrs $ lambdaBody lam}

    annotate (Let pat aux e) = case e of
      -- Calls inherit all attributes, the 'min' of the two safety levels,
      -- and the caller's locations.
      Apply fname args t (safety, loc, locs) ->
        Let pat (addAttrs attrs aux) $
          Apply fname args t (min caller_safety safety, loc, locs ++ more_locs)
      -- Assertions are only kept (with extended locations) for safe calls.
      BasicOp (Assert cond desc (loc, locs)) ->
        Let pat (addAttrs assert_attrs aux) $
          case caller_safety of
            Safe -> BasicOp $ Assert cond desc (loc, locs ++ more_locs)
            Unsafe -> BasicOp $ SubExp $ Constant UnitValue
      -- A SOAC gets the non-assertion attributes, while its lambdas are
      -- annotated with only the assertion attributes.
      Op soac ->
        Let pat (addAttrs (attrs `withoutAttrs` assert_attrs) aux) . Op . runIdentity $
          mapSOACM
            identitySOACMapper {mapOnSOACLambda = pure . annotateLambda}
            soac
      -- Anything else keeps its own attributes, but nested bodies are
      -- annotated.
      _ ->
        Let pat aux $
          mapExp identityMapper {mapOnBody = const $ pure . annotateBody attrs} e

-- | Remove functions not ultimately called from an entry point or a
-- constant.
removeDeadFunctionsF :: Prog SOACS -> Prog SOACS
removeDeadFunctionsF prog =
  prog {progFuns = filter isLive $ progFuns prog}
  where
    cg = buildCallGraph prog
    isLive fd = funDefName fd `isFunInCallGraph` cg

-- | Inline all functions and remove the resulting dead functions.
inlineAggressively :: Pass SOACS SOACS
inlineAggressively =
  Pass
    { passName = "Inline aggressively",
      passDescription = "Aggressively inline and remove resulting dead functions.",
      passFunction = \prog -> do
        inlined <- aggInlineFunctions prog
        copyPropagateInProg simpleSOACS $ removeDeadFunctionsF inlined
    }

-- | Inline some functions and remove the resulting dead functions.
inlineConservatively :: Pass SOACS SOACS
inlineConservatively =
  Pass
    { passName = "Inline conservatively",
      passDescription = "Conservatively inline and remove resulting dead functions.",
      passFunction = copyPropagateInProg simpleSOACS . removeDeadFunctionsF <=< consInlineFunctions
    }

-- | @removeDeadFunctions prog@ removes from the program the functions that
-- are not reachable from an entry point or a constant.
removeDeadFunctions :: Pass SOACS SOACS
removeDeadFunctions =
  Pass
    { passName = "Remove dead functions",
      passDescription = "Remove the functions that are unreachable from entry points",
      passFunction = pure . removeDeadFunctionsF
    }
futhark-0.25.27/src/Futhark/Optimise/MemoryBlockMerging.hs000066400000000000000000000167721475065116200234330ustar00rootroot00000000000000
{-# LANGUAGE TypeFamilies #-}

-- | This module implements an optimization that tries to statically reuse
-- kernel-level allocations. The goal is to lower the static memory usage, which
-- might allow more programs to run using intra-group parallelism.
module Futhark.Optimise.MemoryBlockMerging (optimise) where

import Control.Exception
import Control.Monad.State.Strict
import Data.Function ((&))
import Data.Map (Map, (!))
import Data.Map qualified as M
import Data.Set (Set)
import Data.Set qualified as S
import Futhark.Analysis.Interference qualified as Interference
import Futhark.Builder.Class
import Futhark.Construct
import Futhark.IR.GPUMem
import Futhark.Optimise.MemoryBlockMerging.GreedyColoring qualified as GreedyColoring
import Futhark.Pass (Pass (..), PassM)
import Futhark.Pass qualified as Pass
import Futhark.Util (invertMap)

-- | A mapping from allocation names to their size and space.
type Allocs = Map VName (SubExp, Space)

-- | The allocations bound by a statement. Descends into the branches of a
-- 'Match' and the body of a 'Loop', but into no other construct (in
-- particular not into nested 'SegOp's).
getAllocsStm :: Stm GPUMem -> Allocs
getAllocsStm (Let (Pat [PatElem name _]) _ (Op (Alloc se sp))) =
  M.singleton name (se, sp)
getAllocsStm (Let _ _ (Op (Alloc _ _))) =
  -- Assumed unreachable: an 'Alloc' is expected to bind exactly one name.
  error "impossible"
getAllocsStm (Let _ _ (Match _ cases defbody _)) =
  foldMap (foldMap getAllocsStm . bodyStms) $ defbody : map caseBody cases
getAllocsStm (Let _ _ (Loop _ _ body)) =
  foldMap getAllocsStm (bodyStms body)
getAllocsStm _ = mempty

-- | The allocations bound in the kernel body of a 'SegOp' (see
-- 'getAllocsStm' for how deep this looks).
getAllocsSegOp :: SegOp lvl GPUMem -> Allocs
getAllocsSegOp (SegMap _ _ _ body) =
  foldMap getAllocsStm (kernelBodyStms body)
getAllocsSegOp (SegRed _ _ _ _ body) =
  foldMap getAllocsStm (kernelBodyStms body)
getAllocsSegOp (SegScan _ _ _ _ body) =
  foldMap getAllocsStm (kernelBodyStms body)
getAllocsSegOp (SegHist _ _ _ _ body) =
  foldMap getAllocsStm (kernelBodyStms body)

-- | Replace every allocation whose name is a key of the map by a plain
-- binding of the corresponding 'SubExp'. Recurses into nested 'SegOp's,
-- the branches of 'Match', and the body of 'Loop'.
setAllocsStm :: Map VName SubExp -> Stm GPUMem -> Stm GPUMem
setAllocsStm m stm@(Let (Pat [PatElem name _]) _ (Op (Alloc _ _)))
  | Just s <- M.lookup name m =
      stm {stmExp = BasicOp $ SubExp s}
setAllocsStm _ stm@(Let _ _ (Op (Alloc _ _))) = stm
setAllocsStm m stm@(Let _ _ (Op (Inner (SegOp segop)))) =
  stm {stmExp = Op $ Inner $ SegOp $ setAllocsSegOp m segop}
setAllocsStm m stm@(Let _ _ (Match cond cases defbody dec)) =
  stm {stmExp = Match cond (map (fmap onBody) cases) (onBody defbody) dec}
  where
    onBody (Body () stms res) = Body () (setAllocsStm m <$> stms) res
setAllocsStm m stm@(Let _ _ (Loop merge form body)) =
  stm
    { stmExp =
        Loop merge form (body {bodyStms = setAllocsStm m <$> bodyStms body})
    }
setAllocsStm _ stm = stm

-- | Apply 'setAllocsStm' to every statement of the kernel body of a 'SegOp'.
setAllocsSegOp :: Map VName SubExp -> SegOp lvl GPUMem -> SegOp lvl GPUMem
setAllocsSegOp m (SegMap lvl sp tps body) =
  SegMap lvl sp tps $
    body {kernelBodyStms = setAllocsStm m <$> kernelBodyStms body}
setAllocsSegOp m (SegRed lvl sp segbinops tps body) =
  SegRed lvl sp segbinops tps $
    body {kernelBodyStms = setAllocsStm m <$> kernelBodyStms body}
setAllocsSegOp m (SegScan lvl sp segbinops tps body) =
  SegScan lvl sp segbinops tps $
    body {kernelBodyStms = setAllocsStm m <$> kernelBodyStms body}
setAllocsSegOp m (SegHist lvl sp segbinops tps body) =
  SegHist lvl sp segbinops tps $
    body {kernelBodyStms = setAllocsStm m <$> kernelBodyStms body}

-- | Emit statements computing the unsigned ('UMax' on 'Int64') maximum of a
-- set of sizes, and return it. A singleton set is returned as is, without
-- emitting anything. The set must not be empty.
maxSubExp :: (MonadBuilder m) => Set SubExp -> m SubExp
maxSubExp = helper . S.toList
  where
    helper (s1 : s2 : sexps) = do
      z <- letSubExp "maxSubHelper" $ BasicOp $ BinOp (UMax Int64) s1 s2
      helper (z : sexps)
    helper [s] = pure s
    helper [] = error "impossible"

-- | Whether an allocation size is a constant, or a variable bound in the
-- given scope. Only such sizes can be used by the statements that
-- 'optimiseKernel' places at the start of the kernel body.
isKernelInvariant :: Scope GPUMem -> (SubExp, space) -> Bool
isKernelInvariant scope (Var vname, _) = vname `M.member` scope
isKernelInvariant _ _ = True

-- | Whether an allocation is in a 'ScalarSpace'.
isScalarSpace :: (subExp, Space) -> Bool
isScalarSpace (_, ScalarSpace _ _) = True
isScalarSpace _ = False

-- | Apply a monadic transformation to the kernel body statements of a
-- 'SegOp', whatever its kind.
onKernelBodyStms ::
  (MonadBuilder m) =>
  SegOp lvl GPUMem ->
  (Stms GPUMem -> m (Stms GPUMem)) ->
  m (SegOp lvl GPUMem)
onKernelBodyStms (SegMap lvl space ts body) f = do
  stms <- f $ kernelBodyStms body
  pure $ SegMap lvl space ts $ body {kernelBodyStms = stms}
onKernelBodyStms (SegRed lvl space binops ts body) f = do
  stms <- f $ kernelBodyStms body
  pure $ SegRed lvl space binops ts $ body {kernelBodyStms = stms}
onKernelBodyStms (SegScan lvl space binops ts body) f = do
  stms <- f $ kernelBodyStms body
  pure $ SegScan lvl space binops ts $ body {kernelBodyStms = stms}
onKernelBodyStms (SegHist lvl space binops ts body) f = do
  stms <- f $ kernelBodyStms body
  pure $ SegHist lvl space binops ts $ body {kernelBodyStms = stms}

-- | This is the actual optimiser. Given an interference graph and a @SegOp@,
-- replace allocations and references to memory blocks inside with a (hopefully)
-- reduced number of allocations.
optimiseKernel ::
  (MonadBuilder m, Rep m ~ GPUMem) =>
  Interference.Graph VName ->
  SegOp lvl GPUMem ->
  m (SegOp lvl GPUMem)
optimiseKernel graph segop0 = do
  -- First optimise the kernels nested inside this one.
  segop <- onKernelBodyStms segop0 $ onKernels $ optimiseKernel graph
  scope_here <- askScope
  let -- Candidate allocations: sizes available in the current scope, and
      -- not in scalar space.
      allocs = M.filter (\alloc -> isKernelInvariant scope_here alloc && not (isScalarSpace alloc)) $ getAllocsSegOp segop
      -- Allocations that interfere never share a color, and allocations only
      -- share a color when they are in the same space.
      (colorspaces, coloring) = GreedyColoring.colorGraph (fmap snd allocs) graph
  -- For each color, the maximum size of the allocations with that color.
  (maxes, maxstms) <-
    invertMap coloring
      & M.elems
      & mapM (maxSubExp . S.map (fst . (allocs !)))
      & collectStms
  -- One allocation per color, in the space of that color.
  -- NOTE(review): pairing with [0 ..] assumes the colors are exactly
  -- 0, 1, ... and that 'M.elems' yields them in ascending order; the
  -- 'assert' only checks the count.
  (colors, stms) <-
    assert (length maxes == M.size colorspaces) maxes
      & zip [0 ..]
      & mapM (\(i, x) -> letSubExp "color" $ Op $ Alloc x $ colorspaces ! i)
      & collectStms
  -- Each original allocation is replaced by the allocation of its color.
  let segop' = setAllocsSegOp (fmap (colors !!) coloring) segop
  -- The size computations and the new allocations go first in the kernel body.
  pure $ case segop' of
    SegMap lvl sp tps body ->
      SegMap lvl sp tps $
        body {kernelBodyStms = maxstms <> stms <> kernelBodyStms body}
    SegRed lvl sp binops tps body ->
      SegRed lvl sp binops tps $
        body {kernelBodyStms = maxstms <> stms <> kernelBodyStms body}
    SegScan lvl sp binops tps body ->
      SegScan lvl sp binops tps $
        body {kernelBodyStms = maxstms <> stms <> kernelBodyStms body}
    SegHist lvl sp binops tps body ->
      SegHist lvl sp binops tps $
        body {kernelBodyStms = maxstms <> stms <> kernelBodyStms body}

-- | Helper function that modifies kernels found inside some statements.
-- Kernels are searched for at the top level, in the branches of 'Match',
-- and in 'Loop' bodies; what happens inside a kernel is up to @f@.
onKernels ::
  (LocalScope GPUMem m) =>
  (SegOp SegLevel GPUMem -> m (SegOp SegLevel GPUMem)) ->
  Stms GPUMem ->
  m (Stms GPUMem)
onKernels f orig_stms = inScopeOf orig_stms $ mapM helper orig_stms
  where
    helper stm@Let {stmExp = Op (Inner (SegOp segop))} = do
      exp' <- f segop
      pure $ stm {stmExp = Op $ Inner $ SegOp exp'}
    helper stm@Let {stmExp = Match c cases defbody dec} = do
      cases' <- mapM (traverse onBody) cases
      defbody' <- onBody defbody
      pure $ stm {stmExp = Match c cases' defbody' dec}
      where
        onBody (Body () stms res) =
          Body () <$> f `onKernels` stms <*> pure res
    helper stm@Let {stmExp = Loop merge form body} = do
      body_stms <- f `onKernels` bodyStms body
      pure $ stm {stmExp = Loop merge form (body {bodyStms = body_stms})}
    helper stm = pure stm

-- | Perform the reuse-allocations optimization.
optimise :: Pass GPUMem GPUMem
optimise =
  Pass "memory block merging" "memory block merging allocations" $ \prog ->
    -- A single interference graph is computed for the whole program and
    -- shared by every function.
    let graph = Interference.analyseProgGPU prog
     in Pass.intraproceduralTransformation (onStms graph) prog
  where
    onStms ::
      Interference.Graph VName ->
      Scope GPUMem ->
      Stms GPUMem ->
      PassM (Stms GPUMem)
    onStms graph scope stms = do
      let m = localScope scope $ optimiseKernel graph `onKernels` stms
      -- 'fst' keeps the transformed statements. NOTE(review): any statements
      -- collected by the builder at this level would be dropped; optimiseKernel
      -- collects its own, so none are expected.
      fmap fst $ modifyNameSource $ runState (runBuilderT m mempty)
futhark-0.25.27/src/Futhark/Optimise/MemoryBlockMerging/000077500000000000000000000000001475065116200230625ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Optimise/MemoryBlockMerging/GreedyColoring.hs000066400000000000000000000037571475065116200263460ustar00rootroot00000000000000
-- | Provides a greedy graph-coloring algorithm.
module Futhark.Optimise.MemoryBlockMerging.GreedyColoring (colorGraph, Coloring) where

import Data.Function ((&))
import Data.Map qualified as M
import Data.Maybe (fromMaybe)
import Data.Set qualified as S
import Futhark.Analysis.Interference qualified as Interference

-- | A map of values to their color, identified by an integer.
type Coloring a = M.Map a Int

-- | A map from each value to the set of its neighbors in the graph.
type Neighbors a = M.Map a (S.Set a)

-- | Computes the neighbor map of a graph. Every edge @(x, y)@ makes @y@ a
-- neighbor of @x@ and @x@ a neighbor of @y@.
neighbors :: (Ord a) => Interference.Graph a -> Neighbors a
neighbors =
  S.foldr
    ( \(x, y) acc ->
        acc
          & M.insertWith S.union x (S.singleton y)
          & M.insertWith S.union y (S.singleton x)
    )
    M.empty

-- | @firstAvailable spaces xs i sp@ finds the smallest color @c >= i@ that is
-- not in @xs@ and that is either already associated with space @sp@, or not
-- associated with any space yet (in which case it is associated with @sp@
-- in the returned map).
firstAvailable ::
  (Eq space) =>
  M.Map Int space ->
  S.Set Int ->
  Int ->
  space ->
  (M.Map Int space, Int)
firstAvailable spaces xs i sp =
  case (i `S.member` xs, spaces M.!? i) of
    (False, Just sp') | sp' == sp -> (spaces, i)
    (False, Nothing) -> (M.insert i sp spaces, i)
    _ -> firstAvailable spaces xs (i + 1) sp

-- | Color a single node with the first color that is not used by any of its
-- already-colored neighbors and that is compatible with the node's space.
colorNode ::
  (Ord a, Eq space) =>
  Neighbors a ->
  (a, space) ->
  (M.Map Int space, Coloring a) ->
  (M.Map Int space, Coloring a)
colorNode nbs (x, sp) (spaces, coloring) =
  let nb_colors =
        foldMap (maybe S.empty S.singleton . (coloring M.!?)) $
          fromMaybe mempty (nbs M.!? x)
      (spaces', color) = firstAvailable spaces nb_colors 0 sp
   in (spaces', M.insert x color coloring)

-- | Graph coloring that takes into account the @space@ of values. Two values
-- can only share the same color if they live in the same space. The result is
-- a map from each color to a space, and a map from each value in the input
-- graph to its new color.
colorGraph ::
  (Ord a, Ord space) =>
  M.Map a space ->
  Interference.Graph a ->
  (M.Map Int space, Coloring a)
colorGraph spaces graph =
  let nodes = S.fromList $ M.toList spaces
      nbs = neighbors graph
   in S.foldr (colorNode nbs) mempty nodes
futhark-0.25.27/src/Futhark/Optimise/MergeGPUBodies.hs000066400000000000000000000572461475065116200224410ustar00rootroot00000000000000
-- |
-- This module implements an optimization pass that merges 'GPUBody' kernels to
-- eliminate memory transactions and reduce the number of kernel launches.
-- This is useful because the "Futhark.Optimise.ReduceDeviceSyncs" pass introduces
-- 'GPUBody' kernels that only execute single statements.
--
-- To merge as many 'GPUBody' kernels as possible, this pass reorders statements
-- with the goal of bringing as many 'GPUBody' statements next to each other in
-- a sequence. Such a sequence can then trivially be merged.
module Futhark.Optimise.MergeGPUBodies (mergeGPUBodies) where

import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.State.Strict hiding (State)
import Data.Bifunctor (first)
import Data.Foldable
import Data.IntMap qualified as IM
import Data.IntSet ((\\))
import Data.IntSet qualified as IS
import Data.Map qualified as M
import Data.Maybe (fromMaybe)
import Data.Sequence ((|>))
import Data.Sequence qualified as SQ
import Futhark.Analysis.Alias
import Futhark.Construct (sliceDim)
import Futhark.Error
import Futhark.IR.Aliases
import Futhark.IR.GPU
import Futhark.MonadFreshNames hiding (newName)
import Futhark.Pass

-- | An optimization pass that reorders and merges 'GPUBody' statements to
-- eliminate memory transactions and reduce the number of kernel launches.
mergeGPUBodies :: Pass GPU GPU
mergeGPUBodies =
  Pass
    { passName = "merge GPU bodies",
      passDescription = "Reorder and merge GPUBody constructs to reduce kernels executions.",
      -- Alias information is computed up front; the transformation itself
      -- produces plain GPU code.
      passFunction = intraproceduralTransformationWithConsts onConsts onFun . aliasAnalysis
    }
  where
    onConsts consts = do
      (consts', _) <- transformStms mempty consts
      pure consts'
    onFun _ (FunDef entry attrs name types params body) = do
      (body', _) <- transformBody mempty body
      pure $ FunDef entry attrs name types params body'

--------------------------------------------------------------------------------
--                              COMMON - TYPES                                --
--------------------------------------------------------------------------------

-- | The 'VName' tags of every variable that a group of statements depends
-- upon. These must be computed before the statements of the group.
type Dependencies = IS.IntSet

-- | The 'VName' tags of every variable that a group of statements binds.
type Bindings = IS.IntSet

-- | A set of 'VName' tags that denote the root aliases of all arrays that some
-- statement consumes.
type Consumption = IS.IntSet

--------------------------------------------------------------------------------
--                             COMMON - HELPERS                               --
--------------------------------------------------------------------------------

-- | The free variables of a construct, viewed as 'Dependencies'.
depsOf :: (FreeIn a) => a -> Dependencies
depsOf x = namesToSet (freeIn x)

-- | The tags of a collection of names, as an integer set.
namesToSet :: Names -> IS.IntSet
namesToSet names = IS.fromList [baseTag n | n <- namesToList names]

--------------------------------------------------------------------------------
--                           AD HOC OPTIMIZATION                              --
--------------------------------------------------------------------------------

-- | Optimize the body of a lambda, also returning its dependencies.
transformLambda ::
  AliasTable ->
  Lambda (Aliases GPU) ->
  PassM (Lambda GPU, Dependencies)
transformLambda aliases (Lambda params types body) =
  first (Lambda params types) <$> transformBody aliases body

-- | Optimize a body, also returning its dependencies.
transformBody ::
  AliasTable ->
  Body (Aliases GPU) ->
  PassM (Body GPU, Dependencies)
transformBody aliases (Body _ stms res) = do
  -- Place every statement into the alternating groups, then let 'collapse'
  -- produce the final group.
  let reorderAll = foldM_ reorderStm aliases stms
  grp <- evalStateT (reorderAll *> collapse) initialState
  -- Variables bound within the body are not dependencies of the body.
  let bound = groupBindings grp
      deps = (groupDependencies grp <> depsOf res) \\ bound
  pure (Body () (groupStms grp) res, deps)

-- | Optimize a sequence of statements, also returning their dependencies.
-- The statements are treated as a body without any result.
transformStms ::
  AliasTable ->
  Stms (Aliases GPU) ->
  PassM (Stms GPU, Dependencies)
transformStms aliases stms = do
  (body, deps) <- transformBody aliases (Body mempty stms [])
  pure (bodyStms body, deps)

-- | Optimizes and reorders a single statement within a sequence while tracking
-- the declaration, observation, and consumption of its dependencies.
-- This creates sequences of GPUBody statements that can be merged into single
-- kernels.
reorderStm :: AliasTable -> Stm (Aliases GPU) -> ReorderM AliasTable
reorderStm aliases (Let pat (StmAux cs attrs _) e) = do
  -- Optimize any nested bodies first; this also yields the dependencies of
  -- the expression.
  (e', deps) <- lift (transformExp aliases e)
  let pat' = removePatAliases pat
  let stm' = Let pat' (StmAux cs attrs ()) e'
  let pes' = patElems pat'

  -- Array aliases can be seen as a directed graph where vertices are arrays
  -- (or the names that bind them) and an edge x -> y denotes that x aliases y.
  -- The root aliases of some array A are then the arrays that can be
  -- reached from A in the graph and which have no edges themselves.
  --
  -- All arrays that share a root alias are considered aliases of each other
  -- and will be consumed if any of them is consumed.
  -- When reordering statements we must ensure that no statement that consumes
  -- an array is moved before any statement that observes one of its aliases.
  --
  -- That is, to move statement X before statement Y the set of root aliases of
  -- arrays consumed by X must not overlap with the root aliases of arrays
  -- observed by Y.
  --
  -- We consider the root aliases of Y's observed arrays as part of Y's
  -- dependencies and simply say that the root aliases of arrays consumed by X
  -- must not overlap those.
  --
  -- To move X before Y, the dependencies of X must also not overlap with
  -- the variables bound by Y.
  let observed = namesToSet $ rootAliasesOf (fold $ patAliases pat) aliases
  let consumed = namesToSet $ rootAliasesOf (consumedInExp e) aliases
  let usage =
        Usage
          { usageBindings = IS.fromList $ map (baseTag . patElemName) pes',
            usageDependencies = observed <> deps <> depsOf pat' <> depsOf cs
          }

  case e' of
    Op GPUBody {} -> moveGPUBody stm' usage consumed
    _ -> moveOther stm' usage consumed

  -- Record the root aliases of the arrays bound here, for later statements.
  pure $ foldl recordAliases aliases (patElems pat)
  where
    -- A name without an entry in the table is its own root alias.
    rootAliasesOf names atable =
      let look n = M.findWithDefault (oneName n) n atable
       in foldMap look (namesToList names)

    recordAliases atable pe
      | aliasesOf pe == mempty = atable
      | otherwise =
          let root_aliases = rootAliasesOf (aliasesOf pe) atable
           in M.insert (patElemName pe) root_aliases atable

-- | Optimize a single expression and determine its dependencies.
transformExp ::
  AliasTable ->
  Exp (Aliases GPU) ->
  PassM (Exp GPU, Dependencies)
transformExp aliases e =
  case e of
    BasicOp {} -> pure (removeExpAliases e, depsOf e)
    Apply {} -> pure (removeExpAliases e, depsOf e)
    Match ses cases defbody dec -> do
      let transformCase (Case vs body) =
            first (Case vs) <$> transformBody aliases body
      (cases', cases_deps) <- mapAndUnzipM transformCase cases
      (defbody', defbody_deps) <- transformBody aliases defbody
      let deps = depsOf ses <> mconcat cases_deps <> defbody_deps <> depsOf dec
      pure (Match ses cases' defbody' dec, deps)
    Loop merge lform body -> do
      -- What merge and lform aliases outside the loop is irrelevant as those
      -- cannot be consumed within the loop.
      (body', body_deps) <- transformBody aliases body
      let (params, args) = unzip merge
      let deps = body_deps <> depsOf params <> depsOf args <> depsOf lform

      -- Names bound by the loop itself are not dependencies of the loop.
      let scope = scopeOfLoopForm lform <> scopeOfFParams params :: Scope (Aliases GPU)
      let bound = IS.fromList $ map baseTag (M.keys scope)
      let deps' = deps \\ bound

      -- Strip the aliases from the merge parameters and loop form by
      -- stripping them from a copy of the loop with an empty body.
      let dummy = Loop merge lform (Body (bodyDec body) SQ.empty []) :: Exp (Aliases GPU)
      let Loop merge' lform' _ = removeExpAliases dummy

      pure (Loop merge' lform' body', deps')
    WithAcc inputs lambda -> do
      accs <- mapM (transformWithAccInput aliases) inputs
      let (inputs', input_deps) = unzip accs
      -- The lambda parameters are all unique and thus have no aliases.
      (lambda', deps) <- transformLambda aliases lambda
      pure (WithAcc inputs' lambda', deps <> fold input_deps)
    Op {} ->
      -- A GPUBody cannot be nested within other HostOp constructs.
      pure (removeExpAliases e, depsOf e)

-- | Optimize a single WithAcc input and determine its dependencies.
transformWithAccInput ::
  AliasTable ->
  WithAccInput (Aliases GPU) ->
  PassM (WithAccInput GPU, Dependencies)
transformWithAccInput aliases (shape, arrs, op) = do
  (op', deps) <- case op of
    Nothing -> pure (Nothing, mempty)
    Just (f, nes) -> do
      -- The lambda parameters have no aliases.
      (f', deps) <- transformLambda aliases f
      pure (Just (f', nes), deps <> depsOf nes)
  let deps' = deps <> depsOf shape <> depsOf arrs
  pure ((shape, arrs, op'), deps')

--------------------------------------------------------------------------------
--                           REORDERING - TYPES                               --
--------------------------------------------------------------------------------

-- | The monad used to reorder statements within a sequence such that its
-- GPUBody statements can be merged into as few kernels as possible.
type ReorderM = StateT State PassM

-- | The state used by a 'ReorderM' monad.
data State = State
  { -- | All statements that already have been processed from the sequence,
    -- divided into alternating groups of non-GPUBody and GPUBody statements.
    -- Blocks at even indices only contain non-GPUBody statements. Blocks at
    -- odd indices only contain GPUBody statements.
    stateBlocks :: Blocks,
    -- | Variables known to be equivalent to values returned from GPUBody
    -- kernels; see 'Entry'.
    stateEquivalents :: EquivalenceTable
  }

-- | A map from variable tags to t'SubExp's returned from within GPUBodies.
type EquivalenceTable = IM.IntMap Entry

-- | An entry in an 'EquivalenceTable'.
data Entry = Entry
  { -- | A value returned from within a GPUBody kernel.
    -- In @let res = gpu { x }@ this is @x@.
    entryValue :: SubExp,
    -- | The type of the 'entryValue'.
    entryType :: Type,
    -- | The name of the variable that binds the return value for 'entryValue'.
    -- In @let res = gpu { x }@ this is @res@.
    entryResult :: VName,
    -- | The index of the group that 'entryResult' is bound in.
    entryBlockIdx :: Int,
    -- | If 'False' then the entry key is a variable that binds the same value
    -- as the 'entryValue'. Otherwise it binds an array with an outer dimension
    -- of one whose row equals that value.
    entryStored :: Bool
  }

-- | The groups that a statement sequence has been divided into, in order.
type Blocks = SQ.Seq Group

-- | A group is a subsequence of statements, usually either only GPUBody
-- statements or only non-GPUBody statements. The 'Usage' statistics of those
-- statements are also stored.
data Group = Group
  { -- | The statements of the group.
    groupStms :: Stms GPU,
    -- | The usage statistics of the statements within the group.
    groupUsage :: Usage
  }

-- | Usage statistics for some set of statements.
data Usage = Usage
  { -- | The variables that the statements bind.
    usageBindings :: Bindings,
    -- | The variables that the statements depend upon, i.e. the free variables
    -- of each statement and the root aliases of every array that they observe.
    usageDependencies :: Dependencies
  }

instance Semigroup Group where
  (Group s1 u1) <> (Group s2 u2) = Group (s1 <> s2) (u1 <> u2)

instance Monoid Group where
  mempty = Group {groupStms = mempty, groupUsage = mempty}

instance Semigroup Usage where
  (Usage b1 d1) <> (Usage b2 d2) = Usage (b1 <> b2) (d1 <> d2)

instance Monoid Usage where
  mempty = Usage {usageBindings = mempty, usageDependencies = mempty}

--------------------------------------------------------------------------------
--                          REORDERING - FUNCTIONS                            --
--------------------------------------------------------------------------------

-- | Return the usage bindings of the group.
groupBindings :: Group -> Bindings
groupBindings = usageBindings . groupUsage

-- | Return the usage dependencies of the group.
groupDependencies :: Group -> Dependencies
groupDependencies = usageDependencies . groupUsage

-- | An initial state to use when running a 'ReorderM' monad.
initialState :: State
initialState =
  State
    { -- A single, empty, non-GPUBody group at index 0.
      stateBlocks = SQ.singleton mempty,
      stateEquivalents = mempty
    }

-- | Modify the groups that the sequence has been split into so far.
modifyBlocks :: (Blocks -> Blocks) -> ReorderM ()
modifyBlocks f =
  modify $ \st -> st {stateBlocks = f (stateBlocks st)}

-- | Remove these keys from the equivalence table.
removeEquivalents :: IS.IntSet -> ReorderM ()
removeEquivalents keys =
  modify $ \st ->
    let eqs' = stateEquivalents st `IM.withoutKeys` keys
     in st {stateEquivalents = eqs'}

-- | Add an entry to the equivalence table.
recordEquivalent :: VName -> Entry -> ReorderM ()
recordEquivalent n entry =
  modify $ \st ->
    let eqs = stateEquivalents st
        eqs' = IM.insert (baseTag n) entry eqs
     in st {stateEquivalents = eqs'}

-- | Moves a GPUBody statement to the furthest possible group of the statement
-- sequence, possibly a new group at the end of sequence.
--
-- To simplify consumption handling a GPUBody is not allowed to merge with a
-- kernel whose result it consumes. Such a GPUBody may therefore not be moved
-- into the same group as such a kernel.
moveGPUBody :: Stm GPU -> Usage -> Consumption -> ReorderM ()
moveGPUBody stm usage consumed = do
  -- Replace dependencies with their GPUBody result equivalents.
  eqs <- gets stateEquivalents
  let g i = maybe i (baseTag . entryResult) (IM.lookup i eqs)
  let deps' = IS.map g (usageDependencies usage)
  let usage' = usage {usageDependencies = deps'}

  -- Move the GPUBody.
  grps <- gets stateBlocks
  let f = groupBlocks usage' consumed
  -- The last group that blocks this statement; if none does, default to
  -- index 1, the first GPUBody group.
  let idx = fromMaybe 1 (SQ.findIndexR f grps)
  -- GPUBody statements only go into odd-indexed groups. If the blocking
  -- group is a non-GPUBody group (even index), use the GPUBody group right
  -- after it. If it is a GPUBody group, join it, unless that group binds
  -- something this statement consumes; then use the next GPUBody group.
  let idx' = case idx `mod` 2 of
        0 -> idx + 1
        _ | consumes idx grps -> idx + 2
        _ -> idx
  -- The group records the original (unsubstituted) usage of the statement.
  modifyBlocks $ moveToGrp (stm, usage) idx'

  -- Record the kernel equivalents of the bound results.
  let pes = patElems (stmPat stm)
  -- This lazy pattern would fail for a non-GPUBody statement; 'reorderStm'
  -- only calls this function on GPUBody statements.
  let Op (GPUBody _ (Body _ _ res)) = stmExp stm
  mapM_ (stores idx') (zip pes (map resSubExp res))
  where
    consumes idx grps
      | Just grp <- SQ.lookup idx grps =
          not $ IS.disjoint (groupBindings grp) consumed
      | otherwise = False

    -- A result of array type is recorded with its row type and marked as
    -- stored; anything else is recorded as a plain value.
    stores idx (PatElem n t, se)
      | Just row_t <- peelArray 1 t =
          recordEquivalent n $ Entry se row_t n idx True
      | otherwise =
          recordEquivalent n $ Entry se t n idx False

-- | Moves a non-GPUBody statement to the furthest possible groups of the
-- statement sequence, possibly a new group at the end of sequence.
moveOther :: Stm GPU -> Usage -> Consumption -> ReorderM ()
moveOther stm usage consumed = do
  grps <- gets stateBlocks
  let f = groupBlocks usage consumed
  let idx = fromMaybe 0 (SQ.findIndexR f grps)
  -- Non-GPUBody statements only go into even-indexed groups: round the index
  -- of the last blocking group (0 if none) up to an even number.
  let idx' = ((idx + 1) `div` 2) * 2
  modifyBlocks $ moveToGrp (stm, usage) idx'
  recordEquivalentsOf stm idx'

-- | @recordEquivalentsOf stm idx@ records the GPUBody result and/or return
-- value that @stm@ is equivalent to. @idx@ is the index of the group that @stm@
-- belongs to.
--
-- A GPUBody can have a dependency substituted with a result equivalent if it
-- merges with the source GPUBody, allowing it to be moved beyond the binding
-- site of that dependency.
--
-- To guarantee that a GPUBody which moves beyond a dependency also merges with
-- its source GPUBody, equivalents are only allowed to be recorded for results
-- bound within the group at index @idx-1@.
recordEquivalentsOf :: Stm GPU -> Int -> ReorderM ()
recordEquivalentsOf stm idx = do
  eqs <- gets stateEquivalents
  -- An entry is usable only if its result is bound in the preceding group.
  let eligible v =
        case IM.lookup (baseTag v) eqs of
          Just entry | entryBlockIdx entry == idx - 1 -> Just entry
          _ -> Nothing
  case stm of
    -- A plain copy of an equivalent variable is equivalent to the same result.
    Let (Pat [PatElem x _]) _ (BasicOp (SubExp (Var n)))
      | Just entry <- eligible n ->
          recordEquivalent x entry
    -- Reading the whole (only) row of a stored result binds the value itself.
    Let (Pat [PatElem x _]) _ (BasicOp (Index arr (Slice (DimFix i : dims))))
      | Just entry <- eligible arr,
        i == intConst Int64 0,
        dims == map sliceDim (arrayDims (entryType entry)) ->
          recordEquivalent x (entry {entryStored = False})
    _ -> pure ()

-- | Does this group block a statement with these usage/consumption statistics
-- from being moved past it? It does if the group binds something the
-- statement depends on, or depends on something the statement consumes.
groupBlocks :: Usage -> Consumption -> Group -> Bool
groupBlocks usage consumed grp =
  not (IS.disjoint (groupBindings grp) (usageDependencies usage))
    || not (IS.disjoint (groupDependencies grp) consumed)

-- | @moveToGrp stm idx grps@ appends @stm@ to the group at index @idx@ of
-- @grps@, first appending empty groups until that index exists.
moveToGrp :: (Stm GPU, Usage) -> Int -> Blocks -> Blocks
moveToGrp item idx grps =
  if idx < SQ.length grps
    then SQ.adjust' (moveTo item) idx grps
    else moveToGrp item idx (grps |> mempty)

-- | Appends the statement to the group and adds its usage statistics to the
-- group's.
moveTo :: (Stm GPU, Usage) -> Group -> Group
moveTo (stm, usage) (Group stms u) = Group (stms |> stm) (u <> usage)

--------------------------------------------------------------------------------
--                       MERGING GPU BODIES - TYPES                           --
--------------------------------------------------------------------------------

-- | The monad used for rewriting a GPUBody to use the t'SubExp's that are
-- returned from kernels it is merged with rather than the results that they
-- bind.
--
-- The state is a prologue of statements to be added at the beginning of the
-- rewritten kernel body.
type RewriteM = StateT (Stms GPU) ReorderM

--------------------------------------------------------------------------------
--                      MERGING GPU BODIES - FUNCTIONS                        --
--------------------------------------------------------------------------------

-- | Collapses the processed sequence of groups into a single group and returns
-- it, merging GPUBody groups into single kernels in the process. The collapsed
-- group also replaces the group sequence in the state.
collapse :: ReorderM Group
collapse = do
  -- Groups alternate between non-GPUBody (even index, tagged False) and
  -- GPUBody (odd index, tagged True) groups.
  grps <- zip (cycle [False, True]) . toList <$> gets stateBlocks
  grp <- foldM clps mempty grps

  modify $ \st -> st {stateBlocks = SQ.singleton grp}
  pure grp
  where
    clps grp0 (gpu_bodies, Group stms usage) = do
      grp1 <-
        if gpu_bodies
          then Group <$> mergeKernels stms <*> pure usage
          else pure (Group stms usage)

      -- Remove equivalents that no longer are relevant for rewriting GPUBody
      -- kernels. This ensures that they are not substituted in later kernels
      -- where the replacement variables might not be in scope.
      removeEquivalents (groupBindings grp1)

      pure (grp0 <> grp1)

-- | Merges a sequence of GPUBody statements into a single kernel. Sequences of
-- fewer than two statements are returned unchanged.
mergeKernels :: Stms GPU -> ReorderM (Stms GPU)
mergeKernels stms
  | SQ.length stms < 2 =
      pure stms
  | otherwise =
      -- foldrM visits the statements right to left, so in each 'merge' step
      -- stm0 is an earlier statement and stm1 the already merged later ones.
      -- Every original body is thus rewritten exactly once (as stm0); the
      -- 'empty' seed is only ever stm1.
      SQ.singleton <$> foldrM merge empty stms
  where
    -- A GPUBody with no pattern, results or statements: the identity of 'merge'.
    empty = Let mempty (StmAux mempty mempty ()) noop
    noop = Op (GPUBody [] (Body () SQ.empty []))

    merge :: Stm GPU -> Stm GPU -> ReorderM (Stm GPU)
    merge stm0 stm1
      | Let pat0 (StmAux cs0 attrs0 _) (Op (GPUBody types0 body)) <- stm0,
        Let pat1 (StmAux cs1 attrs1 _) (Op (GPUBody types1 body1)) <- stm1 =
          do
            Body _ stms0 res0 <- execRewrite (rewriteBody body)
            let Body _ stms1 res1 = body1

                -- Patterns, certificates, attributes, types, statements and
                -- results are concatenated in statement order.
                pat' = pat0 <> pat1
                aux' = StmAux (cs0 <> cs1) (attrs0 <> attrs1) ()
                types' = types0 ++ types1
                body' = Body () (stms0 <> stms1) (res0 <> res1)
             in pure (Let pat' aux' (Op (GPUBody types' body')))
    merge _ _ =
      compilerBugS "mergeGPUBodies: cannot merge non-GPUBody statements"

-- | Perform a rewrite and finish it by adding the rewrite prologue to the start
-- of the body.
execRewrite :: RewriteM (Body GPU) -> ReorderM (Body GPU)
execRewrite m = evalStateT m' SQ.empty
  where
    -- Run the rewrite with an empty prologue, then prepend whatever prologue
    -- statements it accumulated to the rewritten body.
    m' = do
      Body _ stms res <- m
      prologue <- get
      pure (Body () (prologue <> stms) res)

-- | Return the equivalence table.
equivalents :: RewriteM EquivalenceTable
equivalents = lift (gets stateEquivalents)

-- | Rewrite the statements and the result of a body.
rewriteBody :: Body GPU -> RewriteM (Body GPU)
rewriteBody (Body _ stms res) =
  Body () <$> rewriteStms stms <*> rewriteResult res

-- | Rewrite each statement of a sequence, in order.
rewriteStms :: Stms GPU -> RewriteM (Stms GPU)
rewriteStms = mapM rewriteStm

-- | Rewrite the pattern types, certificates and expression of a statement.
rewriteStm :: Stm GPU -> RewriteM (Stm GPU)
rewriteStm (Let (Pat pes) (StmAux cs attrs _) e) = do
  pat' <- Pat <$> mapM rewritePatElem pes
  cs' <- rewriteCerts cs
  e' <- rewriteExp e
  pure $ Let pat' (StmAux cs' attrs ()) e'

-- | Rewrite the type of a pattern element; its name is kept.
rewritePatElem :: PatElem Type -> RewriteM (PatElem Type)
rewritePatElem (PatElem n t) =
  PatElem n <$> rewriteType t

-- | Rewrite an expression so that it refers to the values returned by merged
-- kernels rather than to the results that those kernels bind.
--
-- Indexing the single row of a stored entry (a key that binds an array with
-- an outer dimension of one) is short-circuited to the returned value itself.
-- The shortcut is restricted to stored entries: a non-stored key already
-- binds the returned value, so its leading index must be kept. The generic
-- traversal handles that by substituting the key with the value. The
-- remaining slice dimensions are rewritten too, as they may refer to
-- variables that have equivalents and are not in scope inside the kernel.
rewriteExp :: Exp GPU -> RewriteM (Exp GPU)
rewriteExp e = do
  eqs <- equivalents
  case e of
    BasicOp (Index arr slice)
      | Just entry <- IM.lookup (baseTag arr) eqs,
        entryStored entry,
        DimFix idx : dims <- unSlice slice,
        idx == intConst Int64 0 -> do
          dims' <- mapM (traverse rewriteSubExp) dims
          let se = entryValue entry
          pure . BasicOp $ case (dims', se) of
            ([], _) -> SubExp se
            (_, Var src) -> Index src (Slice dims')
            _ -> compilerBugS "rewriteExp: bad equivalence entry"
    _ -> mapExpM rewriter e
  where
    rewriter =
      Mapper
        { mapOnSubExp = rewriteSubExp,
          mapOnBody = const rewriteBody,
          mapOnVName = rewriteName,
          mapOnRetType = rewriteExtType,
          mapOnBranchType = rewriteExtType,
          mapOnFParam = rewriteParam,
          mapOnLParam = rewriteParam,
          mapOnOp = const opError
        }

    opError = compilerBugS "rewriteExp: unhandled HostOp in GPUBody"

-- | Rewrite every result of a body.
rewriteResult :: Result -> RewriteM Result
rewriteResult = mapM rewriteSubExpRes

-- | Rewrite both the certificates and the value of a result.
rewriteSubExpRes :: SubExpRes -> RewriteM SubExpRes
rewriteSubExpRes (SubExpRes cs se) =
  SubExpRes <$> rewriteCerts cs <*> rewriteSubExp se

-- | Rewrite each certificate name.
rewriteCerts :: Certs -> RewriteM Certs
rewriteCerts (Certs cs) =
  Certs <$> mapM rewriteName cs

-- | Rewrite the sizes of a type.
rewriteType :: TypeBase Shape u -> RewriteM (TypeBase Shape u)
-- Note: mapOnType also maps the VName token of accumulators
rewriteType = mapOnType rewriteSubExp

-- | Rewrite the (existential) sizes of a type.
rewriteExtType :: TypeBase ExtShape u -> RewriteM (TypeBase ExtShape u)
-- Note: mapOnExtType also maps the VName token of accumulators
rewriteExtType = mapOnExtType rewriteSubExp

-- | Rewrite the type of a parameter; its attributes and name are kept.
rewriteParam :: Param (TypeBase Shape u) -> RewriteM (Param (TypeBase Shape u))
rewriteParam (Param attrs n t) =
  Param attrs n <$> rewriteType t

-- | Replace a variable that has an equivalent. A non-stored key binds the
-- returned value itself and is replaced by it. A stored key binds a
-- single-row array, so an array containing the returned value is created in
-- the rewrite prologue and used instead.
rewriteSubExp :: SubExp -> RewriteM SubExp
rewriteSubExp (Constant c) = pure (Constant c)
rewriteSubExp (Var n) = do
  eqs <- equivalents
  case IM.lookup (baseTag n) eqs of
    Nothing -> pure (Var n)
    Just (Entry se _ _ _ False) -> pure se
    Just (Entry se t _ _ True) -> Var <$> asArray se t

-- | Like 'rewriteSubExp' but for positions that require a name. A constant
-- replacement is bound to a fresh variable in the rewrite prologue.
rewriteName :: VName -> RewriteM VName
rewriteName n = do
  se <- rewriteSubExp (Var n)
  case se of
    Var n' -> pure n'
    Constant c -> referConst c

-- | @asArray se t@ adds @let x = [se]@ to the rewrite prologue and returns the
-- name of @x@. @t@ is the type of @se@.
asArray :: SubExp -> Type -> RewriteM VName
asArray se row_t =
  bindInPrologue "arr" (row_t `arrayOfRow` intConst Int64 1) $
    BasicOp (ArrayLit [se] row_t)

-- | @referConst c@ binds @c@ to a fresh variable @x@ by adding @let x = c@ to
-- the rewrite prologue, and returns the name of @x@.
referConst :: PrimValue -> RewriteM VName
referConst c =
  bindInPrologue "cnst" (Prim (primValueType c)) $
    BasicOp (SubExp (Constant c))

-- | @bindInPrologue desc t e@ binds a fresh variable (named after @desc@) of
-- type @t@ to @e@ by appending that binding to the rewrite prologue. It
-- returns the fresh name.
bindInPrologue :: String -> Type -> Exp GPU -> RewriteM VName
bindInPrologue desc t e = do
  name <- newName desc
  modify (|> Let (Pat [PatElem name t]) (StmAux mempty mempty ()) e)
  pure name

-- | Produce a fresh name, using the given string as a template.
newName :: String -> RewriteM VName
newName = lift . lift . newNameFromString
futhark-0.25.27/src/Futhark/Optimise/ReduceDeviceSyncs.hs000066400000000000000000001031571475065116200232400ustar00rootroot00000000000000-- | This module implements an optimization that migrates host
-- statements into 'GPUBody' kernels to reduce the number of
-- host-device synchronizations that occur when a scalar variable is
-- written to or read from device memory. The statements to migrate
-- are determined by a 'MigrationTable' produced by the
-- "Futhark.Optimise.ReduceDeviceSyncs.MigrationTable" module; this module
-- merely performs the migration and rewriting dictated by that table.
module Futhark.Optimise.ReduceDeviceSyncs (reduceDeviceSyncs) where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State hiding (State)
import Data.Bifunctor (second)
import Data.Foldable
import Data.IntMap.Strict qualified as IM
import Data.List (transpose, zip4)
import Data.Map.Strict qualified as M
import Data.Sequence ((><), (|>))
import Data.Text qualified as T
import Futhark.Construct (fullSlice, mkBody, sliceDim)
import Futhark.Error
import Futhark.IR.GPU
import Futhark.MonadFreshNames
import Futhark.Optimise.ReduceDeviceSyncs.MigrationTable
import Futhark.Pass
import Futhark.Transform.Substitute

-- | A pass that moves host statements into 'GPUBody' kernels, so that fewer
-- host-device synchronizations are needed.
--
-- The top-level constants are optimized under a migration table computed for
-- them alone. Each function is then optimized in parallel under that table
-- extended with the function's own analysis.
reduceDeviceSyncs :: Pass GPU GPU
reduceDeviceSyncs =
  Pass
    "reduce device synchronizations"
    "Move host statements to device to reduce blocking memory operations."
    run
  where
    run prog = do
      let funs = progFuns prog
          hof = hostOnlyFunDefs funs
          consts_mt = analyseConsts hof funs (progConsts prog)
      consts' <- runReduceM consts_mt (optimizeStms (progConsts prog))
      funs' <- parPass (optimizeOne hof consts_mt) funs
      pure (prog {progConsts = consts', progFuns = funs'})

    optimizeOne hof consts_mt fd =
      runReduceM (consts_mt <> analyseFunDef hof fd) (optimizeFunDef fd)

-- | Optimize the statements of a function body. Only the body statements are
-- replaced, so the function's type signature and results stay as they were.
optimizeFunDef :: FunDef GPU -> ReduceM (FunDef GPU)
optimizeFunDef fd = do
  stms' <- optimizeStms (bodyStms (funDefBody fd))
  let body' = (funDefBody fd) {bodyStms = stms'}
  pure (fd {funDefBody = body'})

-- | Optimize a body. Scalar results may be replaced with single-element arrays.
optimizeBody :: Body GPU -> ReduceM (Body GPU)
optimizeBody (Body _ stms res) = do
  stms' <- optimizeStms stms
  -- Results that only exist on device are replaced by their array aliases.
  res' <- resolveResult res
  pure (Body () stms' res')

-- | Optimize a sequence of statements. Statements are processed in order, each
-- appending its rewritten form to the statements produced so far.
optimizeStms :: Stms GPU -> ReduceM (Stms GPU)
optimizeStms = foldM optimizeStm mempty

-- | Optimize a single statement, rewriting it into one or more statements to
-- be appended to the provided 'Stms'. Only variables with continued host usage
-- will remain in scope if their statement is migrated.
optimizeStm :: Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
optimizeStm out stm = do
  move <- asks (shouldMoveStm stm)
  if move
    then moveStm out stm
    else case stmExp stm of
      BasicOp (Update safety arr slice (Var v))
        | Just _ <- sliceIndices slice -> do
            -- Rewrite the Update if its write value has been migrated. Copying
            -- is faster than doing a synchronous write, so we use the device
            -- array even if the value has been made available to the host.
            dev <- storedScalar (Var v)
            case dev of
              Nothing -> pure (out |> stm)
              Just dst -> do
                -- Transform the single element Update into a slice Update.
                -- NOTE(review): the irrefutable [DimFix i] match presumably
                -- relies on the sliceIndices guard above only succeeding for
                -- all-DimFix slices.
                let dims = unSlice slice
                let (outer, [DimFix i]) = splitAt (length dims - 1) dims
                let one = intConst Int64 1
                let slice' = Slice $ outer ++ [DimSlice i one one]
                let e = BasicOp (Update safety arr slice' (Var dst))
                let stm' = stm {stmExp = e}

                pure (out |> stm')
      BasicOp (Replicate (Shape dims) (Var v))
        | Pat [PatElem n arr_t] <- stmPat stm -> do
            -- A Replicate can be rewritten to not require its replication value
            -- to be available on host. If its value is migrated the Replicate
            -- thus needs to be transformed.
            --
            -- If the inner dimension of the replication array is one then the
            -- rewrite can be performed more efficiently than the general case.
            v' <- resolveName v
            let v_kept_on_device = v /= v'

            gpubody_ok <- gets stateGPUBodyOk

            case v_kept_on_device of
              False -> pure (out |> stm)
              True
                | all (== intConst Int64 1) dims,
                  Just t' <- peelArray 1 arr_t,
                  gpubody_ok -> do
                    let n' = VName (baseName n `withSuffix` "_inner") 0
                    let pat' = Pat [PatElem n' t']
                    let e' = BasicOp $ Replicate (Shape $ tail dims) (Var v)
                    let stm' = Let pat' (stmAux stm) e'

                    -- `gpu { v }` is slightly faster than `replicate 1 v` and
                    -- can merge with the GPUBody that v was computed by.
                    gpubody <- inGPUBody (rewriteStm stm')
                    pure (out |> gpubody {stmPat = stmPat stm})
              True
                | last dims == intConst Int64 1 ->
                    -- v' already holds one element, so replicating it over the
                    -- outer dimensions yields the same shape.
                    let e' = BasicOp $ Replicate (Shape $ init dims) (Var v')
                        stm' = stm {stmExp = e'}
                     in pure (out |> stm')
              True -> do
                n' <- newName n
                -- v_kept_on_device implies that v is a scalar.
                -- Replicate the one-element array (adding a trailing dimension
                -- of one) and then index that dimension away.
                let dims' = dims ++ [intConst Int64 1]
                let arr_t' = Array (elemType arr_t) (Shape dims') NoUniqueness
                let pat' = Pat [PatElem n' arr_t']
                let e' = BasicOp $ Replicate (Shape dims) (Var v')
                let repl = Let pat' (stmAux stm) e'

                let aux = StmAux mempty mempty ()
                let slice = map sliceDim (arrayDims arr_t)
                let slice' = slice ++ [DimFix $ intConst Int64 0]
                let idx = BasicOp $ Index n' (Slice slice')
                let index = Let (stmPat stm) aux idx

                pure (out |> repl |> index)
      BasicOp {} -> pure (out |> stm)
      Apply {} -> pure (out |> stm)
      Match ses cases defbody (MatchDec btypes sort) -> do
        -- Rewrite branches.
        cases_stms <- mapM (optimizeStms . bodyStms . caseBody) cases
        let cases_res = map (bodyResult . caseBody) cases
        defbody_stms <- optimizeStms $ bodyStms defbody
        let defbody_res = bodyResult defbody

        -- Ensure return values and types match if one or more branches
        -- return a result that now resides on device.
        let bmerge (acc, all_stms) (pe, reses, bt) = do
              let onHost (Var v) = (v ==) <$> resolveName v
                  onHost _ = pure True

              on_host <- and <$> mapM (onHost . resSubExp) reses

              if on_host
                then -- No result resides on device ==> nothing to do.
                  pure ((pe, reses, bt) : acc, all_stms)
                else do
                  -- Otherwise, ensure all results are migrated.
                  (all_stms', arrs) <-
                    fmap unzip $
                      forM (zip all_stms reses) $ \(stms, res) ->
                        storeScalar stms (resSubExp res) (patElemType pe)

                  pe' <- arrayizePatElem pe
                  let bt' = staticShapes1 (patElemType pe')
                      reses' = zipWith SubExpRes (map resCerts reses) (map Var arrs)
                  pure ((pe', reses', bt') : acc, all_stms')

            pes = patElems (stmPat stm)

        -- Process each pattern element together with the corresponding result
        -- of every branch (default body first). The branch statement list
        -- keeps one entry per branch, default body first.
        (acc, ~(defbody_stms' : cases_stms')) <-
          foldM bmerge ([], defbody_stms : cases_stms) $
            zip3 pes (transpose $ defbody_res : cases_res) btypes
        let (pes', reses, btypes') = unzip3 (reverse acc)

        -- Rewrite statement.
        let cases' =
              zipWith Case (map casePat cases) $
                zipWith mkBody cases_stms' $
                  drop 1 $
                    transpose reses
            defbody' = mkBody defbody_stms' $ map head reses
            e' = Match ses cases' defbody' (MatchDec btypes' sort)
            stm' = Let (Pat pes') (stmAux stm) e'

        -- Read migrated scalars that are used on host.
        foldM addRead (out |> stm') (zip pes pes')
      Loop params lform body -> do
        -- Update statement bound variables and parameters if their values
        -- have been migrated to device.
        let lmerge (res, stms, rebinds) (pe, param, StayOnHost) =
              pure ((pe, param) : res, stms, rebinds)
            lmerge (res, stms, rebinds) (pe, (Param _ pn pt, pval), _) = do
              -- Migrate the bound variable.
              pe' <- arrayizePatElem pe

              -- Move the initial value to device if not already there to
              -- ensure that the parameter argument and loop return value
              -- converge.
              (stms', arr) <- storeScalar stms pval (fromDecl pt)

              -- Migrate the parameter.
              pn' <- newName pn
              let pt' = toDecl (patElemType pe') Nonunique
              let pval' = Var arr
              let param' = (Param mempty pn' pt', pval')

              -- Record the migration and rebind the parameter inside the
              -- loop body if necessary.
              rebinds' <- (pe {patElemName = pn}) `migratedTo` (pn', rebinds)

              pure ((pe', param') : res, stms', rebinds')

        mt <- ask

        -- The migration status of each loop parameter.
        let pes = patElems (stmPat stm)
        let mss = map (\(Param _ n _, _) -> statusOf n mt) params
        (zipped', out', rebinds) <-
          foldM lmerge ([], out, mempty) (zip3 pes params mss)
        let (pes', params') = unzip (reverse zipped')

        -- Rewrite body.
        let body1 = body {bodyStms = rebinds >< bodyStms body}
        body2 <- optimizeBody body1
        let zipped =
              zip4
                mss
                (bodyResult body2)
                (map resSubExp $ bodyResult body)
                (map patElemType pes)
        -- Results of migrated parameters are stored into single element
        -- arrays so they match the migrated parameter types.
        let rstore (bstms, res) (StayOnHost, r, _, _) =
              pure (bstms, r : res)
            rstore (bstms, res) (_, SubExpRes certs _, se, t) = do
              (bstms', dev) <- storeScalar bstms se t
              pure (bstms', SubExpRes certs (Var dev) : res)
        (bstms, res) <- foldM rstore (bodyStms body2, []) zipped
        let body3 = body2 {bodyStms = bstms, bodyResult = reverse res}

        -- Rewrite statement.
        let e' = Loop params' lform body3
        let stm' = Let (Pat pes') (stmAux stm) e'

        -- Read migrated scalars that are used on host.
        foldM addRead (out' |> stm') (zip pes pes')
      WithAcc inputs lmd -> do
        let getAcc (Acc a _ _ _) = a
            getAcc _ =
              compilerBugS
                "Type error: WithAcc expression did not return accumulator."

        -- Pair each input with the token of the accumulator returned for it.
        let accs = zipWith (\t i -> (getAcc t, i)) (lambdaReturnType lmd) inputs
        inputs' <- mapM (uncurry optimizeWithAccInput) accs

        let body = lambdaBody lmd
        stms' <- optimizeStms (bodyStms body)

        let rewrite (SubExpRes certs se, t, pe) =
              do
                se' <- resolveSubExp se
                if se == se'
                  then pure (SubExpRes certs se, t, pe)
                  else do
                    pe' <- arrayizePatElem pe
                    let t' = patElemType pe'
                    pure (SubExpRes certs se', t', pe')

        -- Rewrite non-accumulator results that have been migrated.
        --
        -- Accumulator return values do not map to arrays one-to-one but
        -- one-to-many. They are not transformed however and can be mapped
        -- as a no-op.
        let len = length inputs
        let (res0, res1) = splitAt len (bodyResult body)
        let (rts0, rts1) = splitAt len (lambdaReturnType lmd)
        let pes = patElems (stmPat stm)
        let (pes0, pes1) = splitAt (length pes - length res1) pes
        (res1', rts1', pes1') <- unzip3 <$> mapM rewrite (zip3 res1 rts1 pes1)
        let res' = res0 ++ res1'
        let rts' = rts0 ++ rts1'
        let pes' = pes0 ++ pes1'

        -- Rewrite statement.
        let body' = Body () stms' res'
        let lmd' = lmd {lambdaBody = body', lambdaReturnType = rts'}
        let e' = WithAcc inputs' lmd'
        let stm' = Let (Pat pes') (stmAux stm) e'

        -- Read migrated scalars that are used on host.
        foldM addRead (out |> stm') (zip pes pes')
      Op op -> do
        op' <- optimizeHostOp op
        pure (out |> stm {stmExp = Op op'})
  where
    -- Pattern elements whose name is unchanged were not migrated. For the
    -- others, record the migration (and add a read back if used on host).
    addRead stms (pe@(PatElem n _), PatElem dev _)
      | n == dev = pure stms
      | otherwise = pe `migratedTo` (dev, stms)

-- | Optimize an accumulator input. The 'VName' is the accumulator token.
optimizeWithAccInput :: VName -> WithAccInput GPU -> ReduceM (WithAccInput GPU)
optimizeWithAccInput _ (shape, arrs, Nothing) = pure (shape, arrs, Nothing)
optimizeWithAccInput acc (shape, arrs, Just (op, nes)) = do
  device_only <- asks (shouldMove acc)
  if device_only
    then do
      op' <- addReadsToLambda op
      pure (shape, arrs, Just (op', nes))
    else do
      let body = lambdaBody op
      -- To pass type check neither parameters nor results can change.
      --
      -- op may be used on both host and device so we must avoid introducing
      -- any GPUBody statements.
      stms' <- noGPUBody $ optimizeStms (bodyStms body)
      let op' = op {lambdaBody = body {bodyStms = stms'}}
      pure (shape, arrs, Just (op', nes))

-- | Optimize a host operation. 'Index' statements are added to kernel code
-- that depends on migrated scalars.
optimizeHostOp :: HostOp op GPU -> ReduceM (HostOp op GPU)
optimizeHostOp (SegOp (SegMap lvl space types kbody)) =
  SegOp . SegMap lvl space types <$> addReadsToKernelBody kbody
optimizeHostOp (SegOp (SegRed lvl space ops types kbody)) = do
  ops' <- mapM addReadsToSegBinOp ops
  SegOp . SegRed lvl space ops' types <$> addReadsToKernelBody kbody
optimizeHostOp (SegOp (SegScan lvl space ops types kbody)) = do
  ops' <- mapM addReadsToSegBinOp ops
  SegOp . SegScan lvl space ops' types <$> addReadsToKernelBody kbody
optimizeHostOp (SegOp (SegHist lvl space ops types kbody)) = do
  ops' <- mapM addReadsToHistOp ops
  SegOp . SegHist lvl space ops' types <$> addReadsToKernelBody kbody
optimizeHostOp (SizeOp op) =
  pure (SizeOp op)
optimizeHostOp OtherOp {} =
  -- These should all have been taken care of in the unstreamGPU pass.
  compilerBugS "optimizeHostOp: unhandled OtherOp"
optimizeHostOp (GPUBody types body) =
  GPUBody types <$> addReadsToBody body

--------------------------------------------------------------------------------
--                               COMMON HELPERS                               --
--------------------------------------------------------------------------------

-- | Append the given string to a name.
withSuffix :: Name -> String -> Name
withSuffix name sfx = nameFromText $ T.append (nameToText name) (T.pack sfx)

--------------------------------------------------------------------------------
--                             MIGRATION - TYPES                              --
--------------------------------------------------------------------------------

-- | The monad used to perform migration-based synchronization reductions.
newtype ReduceM a = ReduceM (StateT State (Reader MigrationTable) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState State,
      MonadReader MigrationTable
    )

-- | Run a 'ReduceM' computation under the given migration table, threading
-- the name source of the surrounding monad through it.
runReduceM :: (MonadFreshNames m) => MigrationTable -> ReduceM a -> m a
runReduceM mt (ReduceM m) = modifyNameSource $ \src ->
  second stateNameSource (runReader (runStateT m (initialState src)) mt)

instance MonadFreshNames ReduceM where
  getNameSource = gets stateNameSource
  putNameSource src = modify $ \s -> s {stateNameSource = src}

-- | The state used by a 'ReduceM' monad.
data State = State
  { -- | A source to generate new 'VName's from.
    stateNameSource :: VNameSource,
    -- | A table of variables in the original program which have been migrated
    -- to device. Each variable maps to a tuple that describes:
    --
    --   * 'baseName' of the original variable.
    --   * Type of the original variable.
    --   * Name of the single element array holding the migrated value.
    --   * Whether the original variable still can be used on the host.
    stateMigrated :: IM.IntMap (Name, Type, VName, Bool),
    -- | Whether non-migration optimizations may introduce 'GPUBody' kernels at
    -- the current location.
    stateGPUBodyOk :: Bool
  }

--------------------------------------------------------------------------------
--                           MIGRATION - PRIMITIVES                           --
--------------------------------------------------------------------------------

-- | An initial state to use when running a 'ReduceM' monad.
initialState :: VNameSource -> State
initialState ns =
  State
    { stateNameSource = ns,
      stateMigrated = mempty,
      stateGPUBodyOk = True
    }

-- | Perform non-migration optimizations without introducing any GPUBody
-- kernels. The previous setting is restored afterwards.
noGPUBody :: ReduceM a -> ReduceM a
noGPUBody m = do
  prev <- gets stateGPUBodyOk
  modify $ \st -> st {stateGPUBodyOk = False}
  res <- m
  modify $ \st -> st {stateGPUBodyOk = prev}
  pure res

-- | Create a 'PatElem' that binds the array of a migrated variable binding:
-- a fresh name with a "_dev" suffix and a type with an extra outer dimension
-- of one.
arrayizePatElem :: PatElem Type -> ReduceM (PatElem Type)
arrayizePatElem (PatElem n t) = do
  let name = baseName n `withSuffix` "_dev"
  dev <- newName (VName name 0)
  let dev_t = t `arrayOfRow` intConst Int64 1
  pure (PatElem dev dev_t)

-- | @x `movedTo` arr@ registers that the value of @x@ has been migrated to
-- @arr[0]@.
movedTo :: Ident -> VName -> ReduceM ()
movedTo = recordMigration False

-- | @x `aliasedBy` arr@ registers that the value of @x@ also is available on
-- device as @arr[0]@.
aliasedBy :: Ident -> VName -> ReduceM ()
aliasedBy = recordMigration True

-- | @recordMigration host x arr@ records the migration of variable @x@ to
-- @arr[0]@. If @host@ then the original binding can still be used on host.
recordMigration :: Bool -> Ident -> VName -> ReduceM ()
recordMigration host (Ident x t) arr =
  modify $ \st ->
    let migrated = stateMigrated st
        entry = (baseName x, t, arr, host)
        migrated' = IM.insert (baseTag x) entry migrated
     in st {stateMigrated = migrated'}

-- | @pe `migratedTo` (dev, stms)@ registers that the variable @pe@ in the
-- original program has been migrated to @dev@. If the variable is used on
-- host, it is rebound by appending an index statement (reading @dev[0]@) to
-- the given statements.
migratedTo :: PatElem Type -> (VName, Stms GPU) -> ReduceM (Stms GPU)
migratedTo pe (dev, stms) = do
  used <- asks (usedOnHost $ patElemName pe)
  if used
    then patElemIdent pe `aliasedBy` dev >> pure (stms |> bind pe (eIndex dev))
    else patElemIdent pe `movedTo` dev >> pure stms

-- | @useScalar stms n@ returns a variable that binds the result bound by @n@
-- in the original program. If the variable has been migrated to device and has
-- not been copied back to host, a new variable binding is added to the
-- provided statements and returned.
useScalar :: Stms GPU -> VName -> ReduceM (Stms GPU, VName)
useScalar stms n = do
  entry <- gets $ IM.lookup (baseTag n) . stateMigrated
  case entry of
    Nothing ->
      pure (stms, n)
    Just (_, _, _, True) ->
      pure (stms, n)
    Just (name, t, arr, _) ->
      do
        n' <- newName (VName name 0)
        let stm = bind (PatElem n' t) (eIndex arr)
        pure (stms |> stm, n')

-- | Create an expression that reads the first element of a 1-dimensional array.
eIndex :: VName -> Exp GPU
eIndex arr = BasicOp $ Index arr (Slice [DimFix $ intConst Int64 0])

-- | A shorthand for binding a single variable to an expression.
bind :: PatElem Type -> Exp GPU -> Stm GPU
bind pe = Let (Pat [pe]) (StmAux mempty mempty ())

-- | Returns the array alias of @se@ if it is a variable that has been migrated
-- to device. Otherwise returns @Nothing@.
storedScalar :: SubExp -> ReduceM (Maybe VName)
storedScalar (Constant _) = pure Nothing
storedScalar (Var n) = do
  entry <- gets $ IM.lookup (baseTag n) . stateMigrated
  pure $ fmap (\(_, _, arr, _) -> arr) entry

-- | @storeScalar stms se t@ returns a variable that binds a single element
-- array that contains the value of @se@ in the original program. If @se@ is a
-- variable that has been migrated to device, its existing array alias will be
-- used. Otherwise a new variable binding will be added to the provided
-- statements and be returned. @t@ is the type of @se@.
storeScalar :: Stms GPU -> SubExp -> Type -> ReduceM (Stms GPU, VName)
storeScalar stms se t = do
  entry <- case se of
    Var n -> gets $ IM.lookup (baseTag n) . stateMigrated
    _ -> pure Nothing
  case entry of
    Just (_, _, arr, _) -> pure (stms, arr)
    Nothing -> do
      -- How to most efficiently create an array containing the given value
      -- depends on whether it is a variable or a constant. Creating a constant
      -- array is a runtime copy of static memory, while creating an array that
      -- contains a variable results in a synchronous write. The latter is thus
      -- replaced with either a mergeable GPUBody kernel or a Replicate.
      --
      -- Whether it makes sense to hoist arrays out of bodies to enable CSE is
      -- left to the simplifier to figure out. Duplicates will be eliminated if
      -- a scalar is stored multiple times within a body.
      --
      -- TODO: Enable the simplifier to hoist non-consumed, non-returned arrays
      --       out of top-level function definitions. All constant arrays
      --       produced here are in principle top-level hoistable.
      gpubody_ok <- gets stateGPUBodyOk
      case se of
        Var n | gpubody_ok -> do
          n' <- newName n
          let stm = bind (PatElem n' t) (BasicOp $ SubExp se)

          gpubody <- inGPUBody (pure stm)
          let dev = patElemName $ head $ patElems (stmPat gpubody)

          pure (stms |> gpubody, dev)
        -- GPUBody kernels are not allowed here; fall back to a Replicate.
        Var n -> do
          pe <- arrayizePatElem (PatElem n t)
          let shape = Shape [intConst Int64 1]
          let stm = bind pe (BasicOp $ Replicate shape se)
          pure (stms |> stm, patElemName pe)
        _ -> do
          let n = VName (nameFromString "const") 0
          pe <- arrayizePatElem (PatElem n t)
          let stm = bind pe (BasicOp $ ArrayLit [se] t)
          pure (stms |> stm, patElemName pe)

-- | Map a variable name to itself or, if the variable no longer can be used on
-- host, the name of a single element array containing its value.
resolveName :: VName -> ReduceM VName
resolveName n = do
  entry <- gets $ IM.lookup (baseTag n) . stateMigrated
  case entry of
    Nothing -> pure n
    Just (_, _, _, True) -> pure n
    Just (_, _, arr, _) -> pure arr

-- | Like 'resolveName' but for a t'SubExp'. Constants are mapped to themselves.
resolveSubExp :: SubExp -> ReduceM SubExp
resolveSubExp (Var n) = Var <$> resolveName n
resolveSubExp cnst = pure cnst

-- | Like 'resolveSubExp' but for a 'SubExpRes'.
resolveSubExpRes :: SubExpRes -> ReduceM SubExpRes
resolveSubExpRes (SubExpRes certs se) =
  -- Certificates are always read back to host.
  SubExpRes certs <$> resolveSubExp se

-- | Apply 'resolveSubExpRes' to a list of results.
resolveResult :: Result -> ReduceM Result
resolveResult = mapM resolveSubExpRes

-- | Migrate a statement to device, ensuring all its bound variables used on
-- host will remain available with the same names.
moveStm :: Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
moveStm out (Let pat aux (BasicOp (ArrayLit [se] t')))
  | Pat [PatElem n _] <- pat =
      do
        -- Save an 'Index' by rewriting the 'ArrayLit' rather than migrating it.
        -- The GPUBody returns a single element array of type t', so the
        -- original pattern can bind its result directly.
        let n' = VName (baseName n `withSuffix` "_inner") 0
        let pat' = Pat [PatElem n' t']
        let e' = BasicOp (SubExp se)
        let stm' = Let pat' aux e'

        gpubody <- inGPUBody (rewriteStm stm')
        pure (out |> gpubody {stmPat = pat})
moveStm out stm = do
  -- Move the statement to device.
  gpubody <- inGPUBody (rewriteStm stm)

  -- Read non-scalars and scalars that are used on host.
  let arrs = zip (patElems $ stmPat stm) (patElems $ stmPat gpubody)
  foldM addRead (out |> gpubody) arrs
  where
    addRead stms (pe@(PatElem _ t), PatElem dev dev_t) =
      let add' e = pure $ stms |> bind pe e
          add = add' . BasicOp
       in case arrayRank dev_t of
            -- Alias non-arrays with their prior name.
            0 -> add $ SubExp (Var dev)
            -- Read all certificates for free.
            1 | t == Prim Unit -> add' (eIndex dev)
            -- Record the device alias of each scalar variable and read them
            -- if used on host.
            1 -> pe `migratedTo` (dev, stms)
            -- Drop the added dimension of multidimensional arrays.
            _ -> add $ Index dev (fullSlice dev_t [DimFix $ intConst Int64 0])

-- | Create a GPUBody kernel that executes a single statement and stores its
-- results in single element arrays. The rewrite prologue is placed before the
-- statement inside the kernel body.
inGPUBody :: RewriteM (Stm GPU) -> ReduceM (Stm GPU)
inGPUBody m = do
  (stm, st) <- runStateT m initialRState
  let prologue = rewritePrologue st

  let pes = patElems (stmPat stm)
  pat <- Pat <$> mapM arrayizePatElem pes
  let aux = StmAux mempty mempty ()
  let types = map patElemType pes
  let res = map (SubExpRes mempty . Var . patElemName) pes
  let body = Body () (prologue |> stm) res
  let e = Op (GPUBody types body)
  pure (Let pat aux e)

--------------------------------------------------------------------------------
--                         KERNEL REWRITING - TYPES                           --
--------------------------------------------------------------------------------

-- The monad used to rewrite (migrated) kernel code.
type RewriteM = StateT RState ReduceM

-- | The state used by a 'RewriteM' monad.
data RState = RState
  { -- | Maps variables in the original program to names to be used by rewrites.
    rewriteRenames :: IM.IntMap VName,
    -- | Statements to be added as a prologue before rewritten statements.
    rewritePrologue :: Stms GPU
  }

--------------------------------------------------------------------------------
--                        KERNEL REWRITING - FUNCTIONS                        --
--------------------------------------------------------------------------------

-- | An initial state to use when running a 'RewriteM' monad.
initialRState :: RState
initialRState =
  RState
    { rewriteRenames = mempty,
      rewritePrologue = mempty
    }

-- | Rewrite 'SegBinOp' dependencies to scalars that have been migrated.
addReadsToSegBinOp :: SegBinOp GPU -> ReduceM (SegBinOp GPU)
addReadsToSegBinOp op = do
  f' <- addReadsToLambda (segBinOpLambda op)
  pure (op {segBinOpLambda = f'})

-- | Rewrite 'HistOp' dependencies to scalars that have been migrated.
addReadsToHistOp :: HistOp GPU -> ReduceM (HistOp GPU)
addReadsToHistOp op = do
  f' <- addReadsToLambda (histOp op)
  pure (op {histOp = f'})

-- | Rewrite generic lambda dependencies to scalars that have been migrated.
addReadsToLambda :: Lambda GPU -> ReduceM (Lambda GPU)
addReadsToLambda f = do
  body' <- addReadsToBody (lambdaBody f)
  pure (f {lambdaBody = body'})

-- | Rewrite generic body dependencies to scalars that have been migrated.
-- The returned prologue is prepended to the body's statements.
addReadsToBody :: Body GPU -> ReduceM (Body GPU)
addReadsToBody body = do
  (body', prologue) <- addReadsHelper body
  pure body' {bodyStms = prologue >< bodyStms body'}

-- | Rewrite kernel body dependencies to scalars that have been migrated.
-- The returned prologue is prepended to the kernel body's statements.
addReadsToKernelBody :: KernelBody GPU -> ReduceM (KernelBody GPU)
addReadsToKernelBody kbody = do
  (kbody', prologue) <- addReadsHelper kbody
  pure kbody' {kernelBodyStms = prologue >< kernelBodyStms kbody'}

-- | Rewrite migrated scalar dependencies within anything. The returned
-- statements must be added to the scope of the rewritten construct.
addReadsHelper :: (FreeIn a, Substitute a) => a -> ReduceM (a, Stms GPU)
addReadsHelper x = do
  -- Resolve each free variable exactly once. 'rename' accumulates any
  -- statements required to make migrated scalars usable in the prologue.
  let free_vars = namesToList (freeIn x)
  (replacements, final_st) <-
    flip runStateT initialRState $ mapM rename free_vars
  let substitution = M.fromList $ zip free_vars replacements
  pure (substituteNames substitution x, rewritePrologue final_st)

-- | Create a fresh name, registering which name it is a rewrite of.
rewriteName :: VName -> RewriteM VName
rewriteName n = do
  fresh <- lift $ newName n
  let register st =
        st {rewriteRenames = IM.insert (baseTag n) fresh (rewriteRenames st)}
  modify register
  pure fresh

-- | Rewrite all bindings introduced by a body (to ensure they are unique) and
-- fix any dependencies that are broken as a result of migration or rewriting.
rewriteBody :: Body GPU -> RewriteM (Body GPU)
rewriteBody (Body _ stms res) =
  -- The statements are rewritten first so the result can refer to the fresh
  -- names they introduce.
  Body () <$> rewriteStms stms <*> renameResult res

-- | Rewrite all bindings introduced by a sequence of statements (to ensure they
-- are unique) and fix any dependencies that are broken as a result of migration
-- or rewriting.
--
-- A nested 'GPUBody' is flattened: its body statements are spliced in and each
-- of its results is bound to the corresponding pattern element, wrapped in a
-- single element 'ArrayLit' when that element has an array type.
rewriteStms :: Stms GPU -> RewriteM (Stms GPU)
rewriteStms = foldM step mempty
  where
    step acc stm = do
      stm' <- rewriteStm stm
      pure $ splice acc stm'

    splice acc stm'
      | Op (GPUBody _ (Body _ inner res)) <- stmExp stm' =
          foldl' bindResult (acc >< inner) $ zip (patElems $ stmPat stm') res
      | otherwise = acc |> stm'

    bindResult :: Stms GPU -> (PatElem Type, SubExpRes) -> Stms GPU
    bindResult acc (pe, SubExpRes cs se) =
      let e = case peelArray 1 (typeOf pe) of
            Just t' -> ArrayLit [se] t'
            Nothing -> SubExp se
       in acc |> Let (Pat [pe]) (StmAux cs mempty ()) (BasicOp e)

-- | Rewrite all bindings introduced by a single statement (to ensure they are
-- unique) and fix any dependencies that are broken as a result of migration or
-- rewriting.
--
-- NOTE: GPUBody kernels must be rewritten through 'rewriteStms'.
rewriteStm :: Stm GPU -> RewriteM (Stm GPU)
rewriteStm (Let pat aux e) = do
  -- The expression is rewritten before the pattern registers fresh names for
  -- the variables bound by this statement; the certificates come last.
  e' <- rewriteExp e
  rebuild <- Let <$> rewritePat pat <*> rewriteStmAux aux
  pure (rebuild e')

-- | Rewrite all bindings introduced by a pattern (to ensure they are unique)
-- and fix any dependencies that are broken as a result of migration or
-- rewriting.
rewritePat :: Pat Type -> RewriteM (Pat Type)
rewritePat = fmap Pat . mapM rewritePatElem . patElems

-- | Rewrite the binding introduced by a single pattern element (to ensure it is
-- unique) and fix any dependencies that are broken as a result of migration or
-- rewriting.
rewritePatElem :: PatElem Type -> RewriteM (PatElem Type)
rewritePatElem (PatElem n t) = PatElem <$> rewriteName n <*> renameType t

-- | Fix any 'StmAux' certificate references that are broken as a result of
-- migration or rewriting. Attributes are kept as they are.
rewriteStmAux :: StmAux () -> RewriteM (StmAux ())
rewriteStmAux (StmAux certs attrs _) =
  (\certs' -> StmAux certs' attrs ()) <$> renameCerts certs

-- | Rewrite the bindings introduced by an expression (to ensure they are
-- unique) and fix any dependencies that are broken as a result of migration or
-- rewriting.
rewriteExp :: Exp GPU -> RewriteM (Exp GPU)
rewriteExp = mapExpM rewriter
  where
    rewriter =
      Mapper
        { mapOnSubExp = renameSubExp,
          mapOnVName = rename,
          mapOnBody = \_ body -> rewriteBody body,
          mapOnFParam = rewriteParam,
          mapOnLParam = rewriteParam,
          mapOnRetType = renameExtType,
          mapOnBranchType = renameExtType,
          mapOnOp = \_ -> migrationBug
        }

    -- Reaching a host-only operation here means the migration table produced
    -- by the "Futhark.Analysis.MigrationTable" module is fundamentally wrong.
    migrationBug =
      compilerBugS "Cannot migrate a host-only operation to device."

-- | Rewrite the binding introduced by a single parameter (to ensure it is
-- unique) and fix any dependencies that are broken as a result of migration or
-- rewriting.
rewriteParam :: Param (TypeBase Shape u) -> RewriteM (Param (TypeBase Shape u))
rewriteParam (Param attrs n t) = Param attrs <$> rewriteName n <*> renameType t

-- | Return the name to use for a rewritten dependency.
--
-- A previously registered rewrite is reused. Otherwise 'useScalar' supplies a
-- usable name, possibly extending the prologue, and the choice is recorded so
-- that later lookups of the same variable agree.
rename :: VName -> RewriteM VName
rename n = do
  let key = baseTag n
  known <- gets $ IM.lookup key . rewriteRenames
  maybe (resolveFresh key) pure known
  where
    resolveFresh key = do
      st <- get
      (prologue', n') <- lift $ useScalar (rewritePrologue st) n
      put $
        st
          { rewriteRenames = IM.insert key n' (rewriteRenames st),
            rewritePrologue = prologue'
          }
      pure n'

-- | Update the variable names within a 'Result' to account for migration and
-- rewriting.
renameResult :: Result -> RewriteM Result
renameResult = traverse renameSubExpRes

-- | Update the variable names within a 'SubExpRes' to account for migration and
-- rewriting.
renameSubExpRes :: SubExpRes -> RewriteM SubExpRes
renameSubExpRes (SubExpRes certs se) =
  SubExpRes <$> renameCerts certs <*> renameSubExp se

-- | Update the variable names of certificates to account for migration and
-- rewriting.
renameCerts :: Certs -> RewriteM Certs
renameCerts = fmap Certs . traverse rename . unCerts

-- | Update any variable name within a t'SubExp' to account for migration and
-- rewriting. Constants are returned unchanged.
renameSubExp :: SubExp -> RewriteM SubExp
renameSubExp se = case se of
  Var n -> Var <$> rename n
  _ -> pure se

-- | Update the variable names within a type to account for migration and
-- rewriting.
renameType :: TypeBase Shape u -> RewriteM (TypeBase Shape u)
-- Note: mapOnType also maps the VName token of accumulators
renameType t = mapOnType renameSubExp t

-- | Update the variable names within an existential type to account for
-- migration and rewriting.
renameExtType :: TypeBase ExtShape u -> RewriteM (TypeBase ExtShape u) -- Note: mapOnExtType also maps the VName token of accumulators renameExtType = mapOnExtType renameSubExp futhark-0.25.27/src/Futhark/Optimise/ReduceDeviceSyncs/000077500000000000000000000000001475065116200226755ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Optimise/ReduceDeviceSyncs/MigrationTable.hs000066400000000000000000001717241475065116200261460ustar00rootroot00000000000000-- | -- This module implements program analysis to determine which program statements -- the "Futhark.Optimise.ReduceDeviceSyncs" pass should move into 'GPUBody' kernels -- to reduce blocking memory transfers between host and device. The results of -- the analysis is encoded into a 'MigrationTable' which can be queried. -- -- To reduce blocking scalar reads the module constructs a data flow -- dependency graph of program variables (see -- "Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph") in which -- it finds a minimum vertex cut that separates array reads of scalars -- from transitive usage that cannot or should not be migrated to -- device. -- -- The variables of each partition are assigned a 'MigrationStatus' that states -- whether the computation of those variables should be moved to device or -- remain on host. Due to how the graph is built and the vertex cut is found all -- variables bound by a single statement will belong to the same partition. -- -- The vertex cut contains all variables that will reside in device memory but -- are required by host operations. These variables must be read from device -- memory and cannot be reduced further in number merely by migrating -- statements (subject to the accuracy of the graph model). The model is built -- to reduce the worst-case number of scalar reads; an optimal migration of -- statements depends on runtime data. 
-- -- Blocking scalar writes are reduced by either turning such writes into -- asynchronous kernels, as is done with scalar array literals and accumulator -- updates, or by transforming host-device writing into device-device copying. -- -- For details on how the graph is constructed and how the vertex cut is found, -- see the master thesis "Reducing Synchronous GPU Memory Transfers" by Philip -- Børgesen (2022). module Futhark.Optimise.ReduceDeviceSyncs.MigrationTable ( -- * Analysis analyseFunDef, analyseConsts, hostOnlyFunDefs, -- * Types MigrationTable, MigrationStatus (..), -- * Query -- | These functions all assume that no parent statement should be migrated. -- That is @shouldMoveStm stm mt@ should return @False@ for every statement -- @stm@ with a body that a queried 'VName' or 'Stm' is nested within, -- otherwise the query result may be invalid. shouldMoveStm, shouldMove, usedOnHost, statusOf, ) where import Control.Monad import Control.Monad.Trans.Class import Control.Monad.Trans.Reader qualified as R import Control.Monad.Trans.State.Strict () import Control.Monad.Trans.State.Strict hiding (State) import Data.Bifunctor (first, second) import Data.Foldable import Data.IntMap.Strict qualified as IM import Data.IntSet qualified as IS import Data.List qualified as L import Data.Map.Strict qualified as M import Data.Maybe (fromMaybe, isJust, isNothing) import Data.Sequence qualified as SQ import Data.Set (Set, (\\)) import Data.Set qualified as S import Futhark.Error import Futhark.IR.GPU import Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph ( EdgeType (..), Edges (..), Id, IdSet, Result (..), Routing (..), Vertex (..), ) import Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph qualified as MG -------------------------------------------------------------------------------- -- MIGRATION TABLES -- -------------------------------------------------------------------------------- -- | Where the value bound by a name should be computed. 
data MigrationStatus
  = -- | The statement that computes the value should be moved to device.
    -- No host usage of the value will be left after the migration.
    MoveToDevice
  | -- | As 'MoveToDevice' but host usage of the value will remain after
    -- migration.
    UsedOnHost
  | -- | The statement that computes the value should remain on host.
    StayOnHost
  deriving (Eq, Ord, Show)

-- | Identifies
--
-- (1) which statements should be moved from host to device to reduce the
--     worst case number of blocking memory transfers.
--
-- (2) which migrated variables that still will be used on the host after
--     all such statements have been moved.
newtype MigrationTable = MigrationTable (IM.IntMap MigrationStatus)

-- | Combines two tables. Where both classify the same name, the left table
-- wins ('IM.union' is left-biased).
instance Semigroup MigrationTable where
  MigrationTable a <> MigrationTable b = MigrationTable (a `IM.union` b)

-- | Where should the value bound by this name be computed? Names that are not
-- in the table stay on host.
statusOf :: VName -> MigrationTable -> MigrationStatus
statusOf n (MigrationTable mt) =
  fromMaybe StayOnHost $ IM.lookup (baseTag n) mt

-- | Should this whole statement be moved from host to device?
shouldMoveStm :: Stm GPU -> MigrationTable -> Bool
shouldMoveStm (Let (Pat ((PatElem n _) : _)) _ (BasicOp (Index _ slice))) mt =
  -- An 'Index' also moves if any of its slice operands lives on device.
  statusOf n mt == MoveToDevice || any movedOperand slice
  where
    movedOperand (Var op) = statusOf op mt == MoveToDevice
    movedOperand _ = False
shouldMoveStm (Let (Pat ((PatElem n _) : _)) _ (BasicOp _)) mt =
  statusOf n mt /= StayOnHost
shouldMoveStm (Let (Pat ((PatElem n _) : _)) _ Apply {}) mt =
  statusOf n mt /= StayOnHost
shouldMoveStm (Let _ _ (Match cond _ _ _)) mt
  -- The migration status of the condition decides whether the statement moves
  -- as a whole. 'all' is vacuously True for an empty list, so a condition made
  -- up solely of constants must not count as migrated. Such statements fall
  -- through to the catch-all equation below and stay on host.
  | cond_vars@(_ : _) <- subExpVars cond =
      all ((== MoveToDevice) . (`statusOf` mt)) cond_vars
shouldMoveStm (Let _ _ (Loop _ (ForLoop _ _ (Var n)) _)) mt =
  statusOf n mt == MoveToDevice
shouldMoveStm (Let _ _ (Loop _ (WhileLoop n) _)) mt =
  statusOf n mt == MoveToDevice
-- BasicOp and Apply statements might not bind any variables (shouldn't happen).
-- Match statements might use a condition consisting only of constants.
-- For loop statements might use a constant number of iterations.
-- HostOp statements cannot execute on device.
-- WithAcc statements are never moved in their entirety.
shouldMoveStm _ _ = False

-- | Should the value bound by this name be computed on device?
shouldMove :: VName -> MigrationTable -> Bool
shouldMove n mt = statusOf n mt /= StayOnHost

-- | Will the value bound by this name be used on host?
usedOnHost :: VName -> MigrationTable -> Bool
usedOnHost n mt = statusOf n mt /= MoveToDevice

--------------------------------------------------------------------------------
--                         HOST-ONLY FUNCTION ANALYSIS                        --
--------------------------------------------------------------------------------

-- | Identifies top-level function definitions that cannot be run on the
-- device. The application of any such function is host-only.
type HostOnlyFuns = Set Name

-- | Returns the names of all top-level functions that cannot be called from the
-- device. The evaluation of such a function is host-only.
--
-- Functions rejected by 'checkFunDef' are host-only. Host-only status is then
-- propagated to their callers until a fixed point is reached.
hostOnlyFunDefs :: [FunDef GPU] -> HostOnlyFuns
hostOnlyFunDefs funs =
  let names = map funDefName funs
      call_map = M.fromList $ zip names (map checkFunDef funs)
   in S.fromList names \\ keysToSet (removeHostOnly call_map)
  where
    keysToSet = S.fromAscList . M.keys

    -- Repeatedly drop host-only functions and taint their callers until no
    -- host-only function remains; the survivors may run on device.
    removeHostOnly cm =
      let (host_only, cm') = M.partition isHostOnly cm
       in if M.null host_only
            then cm'
            else removeHostOnly $ M.map (checkCalls $ keysToSet host_only) cm'

    isHostOnly = isNothing

    -- A function that calls a host-only function is itself host-only.
    checkCalls hostOnlyFuns (Just calls)
      | hostOnlyFuns `S.disjoint` calls =
          Just calls
    checkCalls _ _ = Nothing

-- | 'checkFunDef' returns 'Nothing' if this function definition uses arrays or
-- HostOps. Otherwise it returns the names of all applied functions, which may
-- include user defined functions that could turn out to be host-only.
checkFunDef :: FunDef GPU -> Maybe (Set Name)
checkFunDef fun = do
  noArrays isArray (funDefParams fun)
  noArrays isArrayType (map fst $ funDefRetType fun)
  callsIn (funDefBody fun)
  where
    -- Yields 'Just ()' only when no element is classified as an array.
    noArrays isArr xs = guard (not (any isArr xs))

    callsIn body = S.unions <$> mapM stmCalls (bodyStms body)

    -- Any expression that produces an array is rejected through its pattern.
    stmCalls (Let (Pat pes) _ e) = noArrays isArray pes >> expCalls e

    expCalls e = case e of
      BasicOp Index {} -> Nothing
      BasicOp {} -> Just S.empty
      WithAcc {} -> Nothing
      Op {} -> Nothing
      Apply fn _ _ _ -> Just (S.singleton fn)
      Match _ cases defbody _ ->
        mconcat <$> mapM callsIn (defbody : map caseBody cases)
      Loop params _ body -> noArrays (isArray . fst) params >> callsIn body

--------------------------------------------------------------------------------
--                             MIGRATION ANALYSIS                             --
--------------------------------------------------------------------------------

-- | HostUsage identifies scalar variables that are used on host.
type HostUsage = [Id]

-- | The graph vertex identifier of a variable is its tag.
nameToId :: VName -> Id
nameToId = baseTag

-- | Analyses top-level constants.
--
-- Scalar constants that are free in any of the given functions are treated as
-- used on host.
analyseConsts :: HostOnlyFuns -> [FunDef GPU] -> Stms GPU -> MigrationTable
analyseConsts hof funs consts = analyseStms hof usage consts
  where
    used_by_funs = freeIn funs
    usage =
      [ nameToId n
        | (n, info) <- M.toDescList (scopeOf consts),
          isScalar info,
          n `nameIn` used_by_funs
      ]

-- | Analyses a top-level function definition.
analyseFunDef :: HostOnlyFuns -> FunDef GPU -> MigrationTable
analyseFunDef hof fd =
  let body = funDefBody fd
      usage = foldl' f [] $ zip (bodyResult body) (map fst $ funDefRetType fd)
      stms = bodyStms body
   in analyseStms hof usage stms
  where
    -- Scalar variables returned from the function are used on host.
    f usage (SubExpRes _ (Var n), t) | isScalarType t = nameToId n : usage
    f usage _ = usage

-- | Analyses statements. The 'HostUsage' list identifies which bound scalar
-- variables that subsequently may be used on host. All free variables such as
-- constants and function parameters are assumed to reside on host.
analyseStms :: HostOnlyFuns -> HostUsage -> Stms GPU -> MigrationTable
analyseStms hof usage stms =
  let (g, srcs, _) = buildGraph hof usage stms
      (routed, unrouted) = srcs
      (_, g') = MG.routeMany unrouted g -- hereby routed
      f st' = MG.fold g' visit st' Normal
      st = foldl' f (initial, MG.none) unrouted
      (vr, vn, tn) = fst $ foldl' f st routed
   in -- TODO: Delay reads into (deeper) branches

      MigrationTable $
        IM.unions
          [ IM.fromSet (const MoveToDevice) vr,
            IM.fromSet (const MoveToDevice) vn,
            -- Read by host if not reached by a reversed edge
            IM.fromSet (const UsedOnHost) tn
          ]
  where
    -- The three sets collected while traversing the graph:
    -- 1) Visited by reversed edge.
    -- 2) Visited by normal edge, no route.
    -- 3) Visited by normal edge, had route; will potentially be read by host.
    initial = (IS.empty, IS.empty, IS.empty)

    visit (vr, vn, tn) Reversed v =
      let vr' = IS.insert (vertexId v) vr
       in (vr', vn, tn)
    visit (vr, vn, tn) Normal v@Vertex {vertexRouting = NoRoute} =
      let vn' = IS.insert (vertexId v) vn
       in (vr, vn', tn)
    visit (vr, vn, tn) Normal v =
      let tn' = IS.insert (vertexId v) tn
       in (vr, vn, tn')

--------------------------------------------------------------------------------
--                                TYPE HELPERS                                --
--------------------------------------------------------------------------------

-- | Is the type of this value a scalar (see 'isScalarType')?
isScalar :: (Typed t) => t -> Bool
isScalar = isScalarType . typeOf

-- | Is this a primitive, non-unit type? Unit-typed values (such as
-- certificates) are not counted as scalars.
isScalarType :: TypeBase shape u -> Bool
isScalarType (Prim Unit) = False
isScalarType (Prim _) = True
isScalarType _ = False

-- | Is the type of this value an array type?
isArray :: (Typed t) => t -> Bool
isArray = isArrayType . typeOf

-- | Is this type of rank one or more?
isArrayType :: (ArrayShape shape) => TypeBase shape u -> Bool
isArrayType = (0 <) . arrayRank

--------------------------------------------------------------------------------
--                               GRAPH BUILDING                               --
--------------------------------------------------------------------------------

-- | Build the data flow graph for a sequence of statements. Every variable in
-- @usage@ is additionally connected to a sink.
buildGraph :: HostOnlyFuns -> HostUsage -> Stms GPU -> (Graph, Sources, Sinks)
buildGraph hof usage stms =
  let (g, srcs, sinks) = execGrapher hof (graphStms stms)
      g' = foldl' (flip MG.connectToSink) g usage
   in (g', srcs, sinks)

-- | Graph a body.
--
-- The body is graphed one level deeper, and its result operands are recorded
-- as operands of the body. If the captured statistics mark the body's depth as
-- having a host-only parent, the current statistics get 'bodyHostOnly' set.
graphBody :: Body GPU -> Grapher ()
graphBody body = do
  let res_ops = namesIntSet $ freeIn (bodyResult body)
  body_stats <-
    captureBodyStats $
      incBodyDepthFor (graphStms (bodyStms body) >> tellOperands res_ops)

  body_depth <- (1 +) <$> getBodyDepth
  let host_only = IS.member body_depth (bodyHostOnlyParents body_stats)
  modify $ \st ->
    let stats = stateStats st
        hops' = IS.delete body_depth (bodyHostOnlyParents stats)
        -- If body contains a variable that is required on host the parent
        -- statement that contains this body cannot be migrated as a whole.
        stats' = if host_only then stats {bodyHostOnly = True} else stats
     in st {stateStats = stats' {bodyHostOnlyParents = hops'}}

-- | Graph multiple statements.
graphStms :: Stms GPU -> Grapher ()
graphStms = mapM_ graphStm

-- | Graph a single statement.
--
-- NOTE: The order of the case alternatives matters. Several patterns overlap
-- (an 'Index' with a fixed slice, any 'BasicOp' producing a single element
-- array, and the general 'Index'), and the first matching alternative wins.
graphStm :: Stm GPU -> Grapher ()
graphStm stm = do
  let bs = boundBy stm
  let e = stmExp stm
  -- IMPORTANT! It is generally assumed that all scalars within types and
  -- shapes are present on host. Any expression of a type wherein one of its
  -- scalar operands appears must therefore ensure that that scalar operand is
  -- marked as a size variable (see the 'hostSize' function).
  case e of
    BasicOp (SubExp se) -> do
      graphSimple bs e
      one bs `reusesSubExp` se
    BasicOp (Opaque _ se) -> do
      graphSimple bs e
      one bs `reusesSubExp` se
    BasicOp (ArrayLit arr t)
      | isScalar t,
        any (isJust . subExpVar) arr ->
          -- Migrating an array literal with free variables saves a write for
          -- every scalar it contains. Under some backends the compiler
          -- generates asynchronous writes for scalar constants but otherwise
          -- each write will be synchronous. If all scalars are constants then
          -- the compiler generates more efficient code that copies static
          -- device memory.
          graphAutoMove (one bs)
    BasicOp UnOp {} -> graphSimple bs e
    BasicOp BinOp {} -> graphSimple bs e
    BasicOp CmpOp {} -> graphSimple bs e
    BasicOp ConvOp {} -> graphSimple bs e
    BasicOp Assert {} ->
      -- == OpenCL =============================================================
      --
      -- The next read after the execution of a kernel containing an assertion
      -- will be made asynchronous, followed by an asynchronous read to check
      -- if any assertion failed. The runtime will then block for all enqueued
      -- operations to finish.
      --
      -- Since an assertion only binds a certificate of unit type, an assertion
      -- cannot increase the number of (read) synchronizations that occur. In
      -- this regard it is free to migrate. The synchronization that does occur
      -- is however (presumably) more expensive as the pipeline of GPU work will
      -- be flushed.
      --
      -- Since this cost is difficult to quantify and amortize over assertion
      -- migration candidates (cost depends on ordering of kernels and reads) we
      -- assume it is insignificant. This will likely hold for a system where
      -- multiple threads or processes schedule GPU work, as system-wide
      -- throughput only will decrease if the GPU utilization decreases as a
      -- result.
      --
      -- == CUDA ===============================================================
      --
      -- Under the CUDA backend every read is synchronous and is followed by
      -- a full synchronization that blocks for all enqueued operations to
      -- finish. If any enqueued kernel contained an assertion, another
      -- synchronous read is then made to check if an assertion failed.
      --
      -- Migrating an assertion to save a read may thus introduce new reads, and
      -- the total number of reads can hence either decrease, remain the same,
      -- or even increase, subject to the ordering of reads and kernels that
      -- perform assertions.
      --
      -- Since it is possible to implement the same failure checking scheme as
      -- OpenCL using asynchronous reads (and doing so would be a good idea!)
      -- we consider this to be acceptable.
      --
      -- TODO: Implement the OpenCL failure checking scheme under CUDA. This
      --       should reduce the number of synchronizations per read to one.
      graphSimple bs e
    BasicOp (Index _ slice)
      | isFixed slice ->
          graphRead (one bs)
    BasicOp {}
      | [(_, t)] <- bs,
        dims <- arrayDims t,
        dims /= [], -- i.e. produces an array
        all (== intConst Int64 1) dims ->
          -- An expression that produces an array that only contains a single
          -- primitive value is as efficient to compute and copy as a scalar,
          -- and introduces no size variables.
          --
          -- This is an exception to the inefficiency rules that comes next.
          graphSimple bs e
    -- Expressions with a cost sublinear to the size of their result arrays are
    -- risky to migrate as we cannot guarantee that their results are not
    -- returned from a GPUBody, which always copies its return values. Since
    -- this would make the effective asymptotic cost of such statements linear
    -- we block them from being migrated on their own.
    --
    -- The parent statement of an enclosing body may still be migrated as a
    -- whole given that each of its returned arrays either
    --   1) is backed by memory used by a migratable statement within its body.
    --   2) contains just a single element.
    -- An array matching either criterion is denoted "copyable memory" because
    -- the asymptotic cost of copying it is less than or equal to the statement
    -- that produced it. This makes the parent of statements with sublinear cost
    -- safe to migrate.
    BasicOp (Index arr s) -> do
      graphInefficientReturn (sliceDims s) e
      one bs `reuses` arr
    BasicOp (Update _ arr slice _)
      | isFixed slice -> do
          graphInefficientReturn [] e
          one bs `reuses` arr
    BasicOp (FlatIndex arr s) -> do
      -- Migrating a FlatIndex leads to a memory allocation error.
      --
      -- TODO: Fix FlatIndex memory allocation error.
      --
      -- Can be replaced with 'graphHostOnly e' to disable migration.
      -- A fix can be verified by enabling tests/migration/reuse2_flatindex.fut
      graphInefficientReturn (flatSliceDims s) e
      one bs `reuses` arr
    BasicOp (FlatUpdate arr _ _) -> do
      graphInefficientReturn [] e
      one bs `reuses` arr
    BasicOp (Scratch _ s) ->
      -- Migrating a Scratch leads to a memory allocation error.
      --
      -- TODO: Fix Scratch memory allocation error.
      --
      -- Can be replaced with 'graphHostOnly e' to disable migration.
      -- A fix can be verified by enabling tests/migration/reuse4_scratch.fut
      graphInefficientReturn s e
    BasicOp (Reshape _ s arr) -> do
      graphInefficientReturn (shapeDims s) e
      one bs `reuses` arr
    BasicOp (Rearrange _ arr) -> do
      graphInefficientReturn [] e
      one bs `reuses` arr
    -- Expressions with a cost linear to the size of their result arrays are
    -- inefficient to migrate into GPUBody kernels as such kernels are single-
    -- threaded. For sufficiently large arrays the cost may exceed what is saved
    -- by avoiding reads. We therefore also block these from being migrated,
    -- as well as their parents.
    BasicOp ArrayLit {} ->
      -- An array literal purely of primitive constants can be hoisted out to be
      -- a top-level constant, unless it is to be returned or consumed.
      -- Otherwise its runtime implementation will copy a precomputed static
      -- array and thus behave like a 'Copy'.
      -- Whether the rows are primitive constants or arrays, without any scalar
      -- variable operands such ArrayLit cannot directly prevent a scalar read.
      graphHostOnly e
    BasicOp ArrayVal {} ->
      -- As above.
      graphHostOnly e
    BasicOp Update {} ->
      graphHostOnly e
    BasicOp Concat {} ->
      -- Is unlikely to prevent a scalar read as the only SubExp operand in
      -- practice is a computation of host-only size variables.
      graphHostOnly e
    BasicOp Manifest {} ->
      -- Takes no scalar operands so cannot directly prevent a scalar read.
      -- It is introduced as part of the BlkRegTiling kernel optimization and
      -- is thus unlikely to prevent the migration of a parent which was not
      -- already blocked by some host-only operation.
      graphHostOnly e
    BasicOp Iota {} ->
      graphHostOnly e
    BasicOp Replicate {} ->
      graphHostOnly e
    -- END
    BasicOp UpdateAcc {} ->
      graphUpdateAcc (one bs) e
    Apply fn _ _ _ ->
      graphApply fn bs e
    Match ses cases defbody _ ->
      graphMatch bs ses cases defbody
    Loop params lform body ->
      graphLoop bs params lform body
    WithAcc inputs f ->
      graphWithAcc bs inputs f
    Op GPUBody {} ->
      -- A GPUBody can be migrated into a parent GPUBody by replacing it with
      -- its body statements and binding its return values inside 'ArrayLit's.
      tellGPUBody
    Op _ ->
      graphHostOnly e
  where
    -- The single binding of a statement that must bind exactly one variable.
    one [x] = x
    one _ = compilerBugS "Type error: unexpected number of pattern elements."

    -- A slice is fixed when every dimension is indexed by a single element.
    isFixed = isJust . sliceIndices

    -- new_dims may introduce new size variables which must be present on host
    -- when this expression is evaluated.
    graphInefficientReturn new_dims e = do
      mapM_ hostSize new_dims
      graphedScalarOperands e >>= addEdges ToSink

    hostSize (Var n) = hostSizeVar n
    hostSize _ = pure ()

    hostSizeVar = requiredOnHost . nameToId

-- | Bindings for all pattern elements bound by a statement.
boundBy :: Stm GPU -> [Binding]
boundBy = map (\(PatElem n t) -> (nameToId n, t)) . patElems . stmPat

-- | Graph a statement which in itself neither reads scalars from device memory
-- nor forces such scalars to be available on host. Such statement can be moved
-- to device to eliminate the host usage of its operands which transitively may
-- depend on a scalar device read.
graphSimple :: [Binding] -> Exp GPU -> Grapher ()
graphSimple bs e = do
  -- Vertices are only added when some operand has a transitive dependency on
  -- an array read. Transitive dependencies through variables that are
  -- connected to sinks do not count.
  operands <- graphedScalarOperands e
  unless (IS.null operands) $ do
    mapM_ addVertex bs
    addEdges (MG.declareEdges $ map fst bs) operands

-- | Graph a statement that reads a scalar from device memory.
graphRead :: Binding -> Grapher ()
graphRead b =
  -- The operands are irrelevant: the source added for b blocks every route
  -- through it.
  addSource b >> tellRead

-- | Graph a statement that always should be moved to device.
graphAutoMove :: Binding -> Grapher ()
graphAutoMove b =
  -- As for reads, the operands are irrelevant since the source blocks every
  -- route through b.
  addSource b

-- | Graph a statement that is unfit for execution in a GPUBody and thus must
-- be executed on host, requiring all its operands to be made available there.
-- Parent statements of enclosing bodies are also blocked from being migrated.
graphHostOnly :: Exp GPU -> Grapher ()
graphHostOnly e = do
  -- Every graphed operand is required on host, so reads it transitively
  -- depends upon can be delayed no further.
  graphedScalarOperands e >>= addEdges ToSink
  tellHostOnly

-- | Graph an 'UpdateAcc' statement.
--
-- Nothing is added to the graph yet; the statement is recorded under its
-- accumulator so that the corresponding 'WithAcc' parent can graph it later.
graphUpdateAcc :: Binding -> Exp GPU -> Grapher ()
graphUpdateAcc b@(_, t) e = case t of
  Acc a _ _ _ ->
    modify $ \st ->
      st
        { stateUpdateAccs =
            IM.insertWith (++) (nameToId a) [(b, e)] (stateUpdateAccs st)
        }
  _ ->
    compilerBugS
      "Type error: UpdateAcc did not produce accumulator typed value."

-- | Graph a function application.
graphApply :: Name -> [Binding] -> Exp GPU -> Grapher ()
graphApply fn bs e = do
  -- Applications of host-only functions are graphed as host-only statements;
  -- all other applications are graphed like any simple statement.
  hof <- isHostOnlyFun fn
  if hof
    then graphHostOnly e
    else graphSimple bs e

-- | Graph a Match statement.
graphMatch :: [Binding] -> [SubExp] -> [Case (Body GPU)] -> Body GPU -> Grapher ()
graphMatch bs ses cases defbody = do
  -- Graph every branch body one fork level deeper; the statement is host-only
  -- if any branch contains a host-only statement.
  body_host_only <-
    incForkDepthFor $
      any bodyHostOnly
        <$> mapM (captureBodyStats . graphBody) (defbody : map caseBody cases)

  let branch_results = results defbody : map (results . caseBody) cases

  -- Record aliases for copyable memory backing returned arrays.
  may_copy_results <- reusesBranches bs branch_results
  let may_migrate = not body_host_only && may_copy_results

  cond_id <-
    if may_migrate
      then onlyGraphedScalars $ subExpVars ses
      else do
        -- The migration status of the condition is what determines
        -- whether the statement may be migrated as a whole or
        -- not. See 'shouldMoveStm'.
        mapM_ (connectToSink . nameToId) (subExpVars ses)
        pure IS.empty

  tellOperands cond_id

  -- Connect branch results to bound variables to allow delaying reads out of
  -- branches. It might also be beneficial to move the whole statement to
  -- device, to avoid reading the branch condition value. This must be balanced
  -- against the need to read the values bound by the if statement.
  --
  -- By connecting the branch condition to each variable bound by the statement
  -- the condition will only stay on device if
  --
  -- (1) the if statement is not required on host, based on the statements
  -- within its body.
  --
  -- (2) no additional reads will be required to use the if statement bound
  -- variables should the whole statement be migrated.
  --
  -- If the condition is migrated to device and stays there, then the if
  -- statement must necessarily execute on device.
  --
  -- While the graph model built by this module generally migrates no more
  -- statements than necessary to obtain a minimum vertex cut, the branches
  -- of if statements are subject to an inaccuracy. Specifically, the model is
  -- not strong enough to capture their mutual exclusivity and thus encodes
  -- that both branches are taken. While this does not affect the resulting
  -- number of host-device reads it means that some reads may needlessly be
  -- delayed out of branches. The overhead as measured on futhark-benchmarks
  -- appears to be negligible though.
  ret <- mapM (comb cond_id) $ L.transpose branch_results
  mapM_ (uncurry createNode) (zip bs ret)
  where
    -- The result subexpressions of a branch body.
    results = map resSubExp . bodyResult

    -- Operands of a bound variable: the condition ids (if any) plus the graphed
    -- scalars returned for it by the branches.
    comb ci a = (ci <>) <$> onlyGraphedScalars (S.fromList $ subExpVars a)

-----------------------------------------------------
-- These type aliases are only used by 'graphLoop' --
-----------------------------------------------------

-- | Ids of loop-bound scalar variables that have been found reachable.
type ReachableBindings = IdSet

-- | Memoisation of graph reductions that collect 'ReachableBindings'.
type ReachableBindingsCache = MG.Visited (MG.Result ReachableBindings)

-- | Edges that were left non-exhausted by a search (see 'findBindings').
type NonExhausted = [Id]

-- | A scalar loop value: its statement binding, the id of its loop parameter,
-- its initial argument, and its per-iteration result.
type LoopValue = (Binding, Id, SubExp, SubExp)

-----------------------------------------------------
-----------------------------------------------------

-- | Graph a loop statement.
graphLoop ::
  [Binding] ->
  [(FParam GPU, SubExp)] ->
  LoopForm ->
  Body GPU ->
  Grapher ()
graphLoop [] _ _ _ =
  -- We expect each loop to bind a value or be eliminated.
  compilerBugS "Loop statement bound no variable; should have been eliminated."
graphLoop (b : bs) params lform body = do
  -- Graph loop params and body while capturing statistics.
  -- The graph before the loop is kept to identify operands defined outside.
  g <- getGraph
  stats <- captureBodyStats (subgraphId `graphIdFor` graphTheLoop)

  -- Record aliases for copyable memory backing returned arrays.
  -- Does the loop return any arrays which prevent it from being migrated?
  let args = map snd params
  let results = map resSubExp (bodyResult body)
  may_copy_results <- reusesBranches (b : bs) [args, results]

  -- Connect the loop condition to a sink if the loop cannot be migrated,
  -- ensuring that it will be available to the host. The migration status
  -- of the condition is what determines whether the loop may be migrated
  -- as a whole or not. See 'shouldMoveStm'.
  let may_migrate = not (bodyHostOnly stats) && may_copy_results
  unless may_migrate $ case lform of
    ForLoop _ _ (Var n) -> connectToSink (nameToId n)
    WhileLoop n
      | Just (_, p, _, res) <- loopValueFor n ->
          do
            connectToSink p
            case res of
              Var v -> connectToSink (nameToId v)
              _ -> pure ()
    _ -> pure ()

  -- Connect graphed return values to their loop parameters.
  mapM_ mergeLoopParam loopValues

  -- Route the sources within the loop body in isolation.
  -- The loop graph must not be altered after this point.
  srcs <- routeSubgraph subgraphId

  -- Graph the variables bound by the statement.
  forM_ loopValues $ \(bnd, p, _, _) -> createNode bnd (IS.singleton p)

  -- If a device read is delayed from one iteration to the next the
  -- corresponding variables bound by the statement must be treated as
  -- sources.
  g' <- getGraph
  let (dbs, rbc) = foldl' (deviceBindings g') (IS.empty, MG.none) srcs
  modifySources $ second (IS.toList dbs <>)

  -- Connect operands to sinks if they can reach a sink within the loop.
  -- Otherwise connect them to the loop bound variables that they can
  -- reach and exhaust their normal entry edges into the loop.
  -- This means a read can be delayed through a loop but not into it if
  -- that would increase the number of reads done by any given iteration.
  let ops = IS.filter (`MG.member` g) (bodyOperands stats)
  foldM_ connectOperand rbc (IS.elems ops)

  -- It might be beneficial to move the whole loop to device, to avoid
  -- reading the (initial) loop condition value. This must be balanced
  -- against the need to read the values bound by the loop statement.
  --
  -- For more details see the similar description for if statements.
  when may_migrate $ case lform of
    ForLoop _ _ n ->
      onlyGraphedScalarSubExp n >>= addEdges (ToNodes bindings Nothing)
    WhileLoop n
      | Just (_, _, arg, _) <- loopValueFor n ->
          onlyGraphedScalarSubExp arg >>= addEdges (ToNodes bindings Nothing)
    _ -> pure ()
  where
    -- The loop's own subgraph is identified by the id of its first binding.
    subgraphId :: Id
    subgraphId = fst b

    -- Only scalar loop values are tracked.
    loopValues :: [LoopValue]
    loopValues =
      let tmp = zip3 (b : bs) params (bodyResult body)
          tmp' = flip map tmp $
            \(bnd, (p, arg), res) ->
              let i = nameToId (paramName p)
               in (bnd, i, arg, resSubExp res)
       in filter (\((_, t), _, _, _) -> isScalar t) tmp'

    -- Ids of the scalar variables bound by the loop statement.
    bindings :: IdSet
    bindings = IS.fromList $ map (\((i, _), _, _, _) -> i) loopValues

    -- The loop value whose parameter is the given variable, if any.
    loopValueFor n =
      find (\(_, p, _, _) -> p == nameToId n) loopValues

    graphTheLoop :: Grapher ()
    graphTheLoop = do
      mapM_ graphParam loopValues

      -- For simplicity we do not currently track memory reuse through merge
      -- parameters. A parameter does not simply reuse the memory of its
      -- argument; it must also consider the iteration return value, which in
      -- turn may depend on other merge parameters.
      --
      -- Situations that would benefit from this tracking are unlikely to
      -- occur at the time of writing, and if they occur current compiler
      -- limitations will prevent successful compilation.
      -- Specifically it requires the merge parameter argument to reuse memory
      -- from an array literal, and both it and the loop must occur within an
      -- if statement branch. Array literals are generally hoisted out of if
      -- statements however, and when they are not, a memory allocation error
      -- occurs.
      --
      -- TODO: Track memory reuse through merge parameters.

      case lform of
        ForLoop _ _ n ->
          onlyGraphedScalarSubExp n >>= tellOperands
        WhileLoop _ -> pure ()

      graphBody body
      where
        graphParam ((_, t), p, arg, _) =
          do
            -- It is unknown whether a read can be delayed via the parameter
            -- from one iteration to the next, so we have to create a vertex
            -- even if the initial value never depends on a read.
            addVertex (p, t)
            ops <- onlyGraphedScalarSubExp arg
            addEdges (MG.oneEdge p) ops

    -- Add an edge from a loop parameter to the variable its iteration result
    -- is taken from, unless the result is the parameter itself or a constant.
    mergeLoopParam :: LoopValue -> Grapher ()
    mergeLoopParam (_, p, _, res)
      | Var n <- res,
        ret <- nameToId n,
        ret /= p =
          addEdges (MG.oneEdge p) (IS.singleton ret)
      | otherwise =
          pure ()

    -- Accumulate the loop-bound variables reachable from an attempted-routed
    -- source; reaching a sink at this point is a compiler bug.
    deviceBindings ::
      Graph ->
      (ReachableBindings, ReachableBindingsCache) ->
      Id ->
      (ReachableBindings, ReachableBindingsCache)
    deviceBindings g (rb, rbc) i =
      let (r, rbc') = MG.reduce g bindingReach rbc Normal i
       in case r of
            Produced rb' -> (rb <> rb', rbc')
            _ ->
              compilerBugS
                "Migration graph sink could be reached from source after it\
                \ had been attempted routed."

    -- Reduction step: record visited vertices that are loop-bound variables.
    bindingReach ::
      ReachableBindings ->
      EdgeType ->
      Vertex Meta ->
      ReachableBindings
    bindingReach rb _ v
      | i <- vertexId v,
        IS.member i bindings =
          IS.insert i rb
      | otherwise =
          rb

    -- Connect an external operand either to a sink (if a sink is reachable
    -- through the loop subgraph) or to the loop-bound variables it reaches.
    connectOperand ::
      ReachableBindingsCache ->
      Id ->
      Grapher ReachableBindingsCache
    connectOperand cache op = do
      g <- getGraph
      case MG.lookup op g of
        Nothing -> pure cache
        Just v ->
          case vertexEdges v of
            ToSink -> pure cache
            ToNodes es Nothing ->
              connectOp g cache op es
            ToNodes _ (Just nx) ->
              connectOp g cache op nx
      where
        connectOp ::
          Graph ->
          ReachableBindingsCache ->
          Id -> -- operand id
          IdSet -> -- its edges
          Grapher ReachableBindingsCache
        connectOp g rbc i es = do
          let (res, nx, rbc') = findBindings g (IS.empty, [], rbc) (IS.elems es)
          case res of
            FoundSink -> connectToSink i
            Produced rb -> modifyGraph $ MG.adjust (updateEdges nx rb) i
          pure rbc'

        -- Add edges to the reached bindings; only those and the edges left
        -- non-exhausted remain non-exhausted.
        updateEdges ::
          NonExhausted ->
          ReachableBindings ->
          Vertex Meta ->
          Vertex Meta
        updateEdges nx rb v
          | ToNodes es _ <- vertexEdges v =
              let nx' = IS.fromList nx
                  es' = ToNodes (rb <> es) $ Just (rb <> nx')
               in v {vertexEdges = es'}
          | otherwise = v

        findBindings ::
          Graph ->
          (ReachableBindings, NonExhausted, ReachableBindingsCache) ->
          [Id] -> -- current non-exhausted edges
          (MG.Result ReachableBindings, NonExhausted, ReachableBindingsCache)
        findBindings _ (rb, nx, rbc) [] =
          (Produced rb, nx, rbc)
        findBindings g (rb, nx, rbc) (i : is)
          | Just v <- MG.lookup i g,
            Just gid <- metaGraphId (vertexMeta v),
            gid == subgraphId -- only search the subgraph
            =
              let (res, rbc') = MG.reduce g bindingReach rbc Normal i
               in case res of
                    FoundSink -> (FoundSink, [], rbc')
                    Produced rb' -> findBindings g (rb <> rb', nx, rbc') is
          | otherwise =
              -- don't exhaust
              findBindings g (rb, i : nx, rbc) is

-- | Graph a 'WithAcc' statement.
graphWithAcc ::
  [Binding] ->
  [WithAccInput GPU] ->
  Lambda GPU ->
  Grapher ()
graphWithAcc bs inputs f = do
  -- Graph the body, capturing 'UpdateAcc' statements for delayed graphing.
  graphBody (lambdaBody f)

  -- Graph each accumulator monoid and its associated 'UpdateAcc' statements.
  mapM_ graph $ zip (lambdaReturnType f) inputs

  -- Record aliases for the backing memory of each returned array.
  -- 'WithAcc' statements are never migrated as a whole and always return
  -- arrays backed by memory allocated elsewhere.
  let arrs = concatMap (\(_, as, _) -> map Var as) inputs
  let res = drop (length inputs) (bodyResult $ lambdaBody f)
  _ <- reusesReturn bs (arrs ++ map resSubExp res)

  -- Connect return variables to bound values. No outgoing edge exists
  -- from an accumulator vertex so skip those. Note that accumulators do
  -- not map to returned arrays one-to-one but one-to-many.
  ret <- mapM (onlyGraphedScalarSubExp . resSubExp) res
  mapM_ (uncurry createNode) $ zip (drop (length arrs) bs) ret
  where
    -- Graph one accumulator together with the 'UpdateAcc' statements that
    -- were captured for it, then forget those captured statements.
    graph (Acc a _ types _, (_, _, comb)) = do
      let i = nameToId a

      delayed <- fromMaybe [] <$> gets (IM.lookup i . stateUpdateAccs)
      modify $ \st -> st {stateUpdateAccs = IM.delete i (stateUpdateAccs st)}

      graphAcc i types (fst <$> comb) delayed

      -- Neutral elements must always be made available on host for 'WithAcc'
      -- to type check.
      mapM_ connectSubExpToSink $ maybe [] snd comb
    graph _ =
      compilerBugS "Type error: WithAcc expression did not return accumulator."

-- Graph the operator and all 'UpdateAcc' statements associated with an
-- accumulator.
--
-- The arguments are the 'Id' for the accumulator token, the element types of
-- the accumulator/operator, its combining function if any, and all associated
-- 'UpdateAcc' statements outside kernels.
graphAcc :: Id -> [Type] -> Maybe (Lambda GPU) -> [Delayed] -> Grapher ()
graphAcc i _ _ [] = addSource (i, Prim Unit) -- Only used on device.
graphAcc i types op delayed = do
  -- Accumulators are intended for use within SegOps but in principle the AST
  -- allows their 'UpdateAcc's to be used outside a kernel. This case handles
  -- that unlikely situation.

  env <- ask
  st <- get

  -- Collect statistics about the operator statements.
  -- The operator body is graphed in a throwaway run seeded with the current
  -- state and environment; only the resulting statistics are kept.
  let lambda = fromMaybe (Lambda [] [] (Body () SQ.empty [])) op
  let m = graphBody (lambdaBody lambda)
  let stats = R.runReader (evalStateT (captureBodyStats m) st) env
  -- We treat GPUBody kernels as host-only to not bother rewriting them inside
  -- operators and to simplify the analysis. They are unlikely to occur anyway.
  --
  -- NOTE: Performance may degrade if a GPUBody is replaced with its contents
  -- but the containing operator is used on host.
  let host_only = bodyHostOnly stats || bodyHasGPUBody stats

  -- op operands are read from arrays and written back so if any of the operands
  -- are scalar then a read can be avoided by moving the UpdateAcc usages to
  -- device. If the op itself performs scalar reads its UpdateAcc usages should
  -- also be moved.
  let does_read = bodyReads stats || any isScalar types

  -- Determine which external variables the operator depends upon.
  -- 'bodyOperands' cannot be used as it might exclude operands that were
  -- connected to sinks within the body, so instead we create an artificial
  -- expression to capture graphed operands from.
  ops <- graphedScalarOperands (WithAcc [] lambda)

  case (host_only, does_read) of
    (True, _) -> do
      -- If the operator cannot run well in a GPUBody then all non-kernel
      -- UpdateAcc statements are host-only. The current analysis is ignorant
      -- of what happens inside kernels so we must assume that the operator
      -- is used within a kernel, meaning that we cannot migrate its statements.
      --
      -- TODO: Improve analysis if UpdateAcc ever is used outside kernels.
      mapM_ (graphHostOnly . snd) delayed
      addEdges ToSink ops
    (_, True) -> do
      -- Migrate all accumulator usage to device to avoid reads and writes.
      mapM_ (graphAutoMove . fst) delayed
      addSource (i, Prim Unit)
    _ -> do
      -- Only migrate operator and UpdateAcc statements if it can allow their
      -- operands to be migrated.
      createNode (i, Prim Unit) ops
      forM_ delayed $
        \(b, e) -> graphedScalarOperands e >>= createNode b . IS.insert i

-- | Returns all scalar operands of an expression that must be made available
-- on host to execute the expression there. Only operands that currently are
-- graphed scalars are returned.
graphedScalarOperands :: Exp GPU -> Grapher Operands
graphedScalarOperands e =
  let is = fst $ execState (collect e) initial
   in IS.intersection is <$> getGraphedScalars
  where
    initial = (IS.empty, S.empty) -- scalar operands, accumulator tokens
    captureName n = modify $ first $ IS.insert (nameToId n)
    captureAcc a = modify $ second $ S.insert a
    collectFree x = mapM_ captureName (namesToList $ freeIn x)

    collect b@BasicOp {} =
      collectBasic b
    collect (Apply _ params _ _) =
      mapM_ (collectSubExp . fst) params
    collect (Match ses cases defbody _) = do
      mapM_ collectSubExp ses
      mapM_ (collectBody . caseBody) cases
      collectBody defbody
    collect (Loop params lform body) = do
      mapM_ (collectSubExp . snd) params
      collectLForm lform
      collectBody body
    collect (WithAcc accs f) =
      collectWithAcc accs f
    collect (Op op) =
      collectHostOp op

    collectBasic (BasicOp (Update _ _ slice _)) =
      -- Writing a scalar to an array can be replaced with copying a single-
      -- element slice. If the scalar originates from device memory its read
      -- can thus be prevented without requiring the 'Update' to be migrated.
      collectFree slice
    collectBasic (BasicOp (Replicate shape _)) =
      -- The replicate of a scalar can be rewritten as a replicate of a single
      -- element array followed by a slice index.
      collectFree shape
    collectBasic e' =
      -- Note: Plain VName values only refer to arrays.
      walkExpM (identityWalker {walkOnSubExp = collectSubExp}) e'

    collectSubExp (Var n) = captureName n
    collectSubExp _ = pure ()

    collectBody body = do
      collectStms (bodyStms body)
      collectFree (bodyResult body)
    collectStms = mapM_ collectStm

    collectStm (Let pat _ ua)
      | BasicOp UpdateAcc {} <- ua,
        Pat [pe] <- pat,
        Acc a _ _ _ <- typeOf pe =
          -- Capture the tokens of accumulators used on host.
          captureAcc a >> collectBasic ua
    collectStm stm = collect (stmExp stm)

    collectLForm (ForLoop _ _ b) = collectSubExp b
    -- WhileLoop condition is declared as a loop parameter.
    collectLForm (WhileLoop _) = pure ()

    -- The collective operands of an operator lambda body are only used on host
    -- if the associated accumulator is used in an UpdateAcc statement outside a
    -- kernel.
    collectWithAcc inputs f = do
      collectBody (lambdaBody f)
      used_accs <- gets snd
      let accs = take (length inputs) (lambdaReturnType f)
      let used = map (\(Acc a _ _ _) -> S.member a used_accs) accs
      mapM_ collectAcc (zip used inputs)

    collectAcc (_, (_, _, Nothing)) = pure ()
    collectAcc (used, (_, _, Just (op, nes))) = do
      mapM_ collectSubExp nes
      when used $ collectBody (lambdaBody op)

    -- Does not collect named operands in
    --
    --   * types and shapes; size variables are assumed available to the host.
    --
    --   * use by a kernel body.
    --
    -- All other operands are conservatively collected even if they generally
    -- appear to be size variables or results computed by a SizeOp.
    collectHostOp (SegOp (SegMap lvl sp _ _)) = do
      collectSegLevel lvl
      collectSegSpace sp
    collectHostOp (SegOp (SegRed lvl sp ops _ _)) = do
      collectSegLevel lvl
      collectSegSpace sp
      mapM_ collectSegBinOp ops
    collectHostOp (SegOp (SegScan lvl sp ops _ _)) = do
      collectSegLevel lvl
      collectSegSpace sp
      mapM_ collectSegBinOp ops
    collectHostOp (SegOp (SegHist lvl sp ops _ _)) = do
      collectSegLevel lvl
      collectSegSpace sp
      mapM_ collectHistOp ops
    collectHostOp (SizeOp op) = collectFree op
    collectHostOp (OtherOp op) = collectFree op
    collectHostOp GPUBody {} = pure ()

    collectSegLevel = mapM_ captureName . namesToList . freeIn
    collectSegSpace space =
      mapM_ collectSubExp (segSpaceDims space)
    collectSegBinOp (SegBinOp _ _ nes _) =
      mapM_ collectSubExp nes
    collectHistOp (HistOp _ rf _ nes _ _) = do
      collectSubExp rf
      mapM_ collectSubExp nes

--------------------------------------------------------------------------------
--                        GRAPH BUILDING - PRIMITIVES                         --
--------------------------------------------------------------------------------

-- | Creates a vertex for the given binding, provided that the set of operands
-- is not empty.
createNode :: Binding -> Operands -> Grapher ()
createNode b ops =
  unless (IS.null ops) (addVertex b >> addEdges (MG.oneEdge $ fst b) ops)

-- | Adds a vertex to the graph for the given binding. Scalars are registered
-- as graphed scalars and arrays are recorded as backed by copyable memory at
-- the current body depth.
addVertex :: Binding -> Grapher ()
addVertex (i, t) = do
  meta <- getMeta
  let v = MG.vertex i meta
  when (isScalar t) $ modifyGraphedScalars (IS.insert i)
  when (isArray t) $ recordCopyableMemory i (metaBodyDepth meta)
  modifyGraph (MG.insert v)

-- | Adds a source connected vertex to the graph for the given binding.
addSource :: Binding -> Grapher ()
addSource b = do
  addVertex b
  -- New sources are prepended to the not-yet-routed source list.
  modifySources $ second (fst b :)

-- | Adds the given edges to each vertex identified by the 'IdSet'. It is
-- assumed that all vertices reside within the body that currently is being
-- graphed.
addEdges :: Edges -> IdSet -> Grapher ()
addEdges ToSink is = do
  modifyGraph $ \g -> IS.foldl' (flip MG.connectToSink) g is
  -- Sink-connected vertices no longer count as graphed scalar operands.
  modifyGraphedScalars (`IS.difference` is)
addEdges es is = do
  modifyGraph $ \g -> IS.foldl' (flip $ MG.addEdges es) g is
  tellOperands is

-- | Ensure that a variable (which is in scope) will be made available on host
-- before its first use. Variables without a vertex are left untouched.
requiredOnHost :: Id -> Grapher ()
requiredOnHost i = do
  mv <- MG.lookup i <$> getGraph
  case mv of
    Nothing -> pure ()
    Just v -> do
      connectToSink i
      -- The body that declares the variable can no longer be migrated whole.
      tellHostOnlyParent (metaBodyDepth $ vertexMeta v)

-- | Connects the vertex of the given id to a sink.
connectToSink :: Id -> Grapher ()
connectToSink i = do
  modifyGraph (MG.connectToSink i)
  modifyGraphedScalars (IS.delete i)

-- | Like 'connectToSink' but vertex is given by a t'SubExp'. This is a no-op if
-- the t'SubExp' is a constant.
connectSubExpToSink :: SubExp -> Grapher ()
connectSubExpToSink (Var n) = connectToSink (nameToId n)
connectSubExpToSink _ = pure ()

-- | Routes all possible routes within the subgraph identified by this id.
-- Returns the ids of the source connected vertices that were attempted routed.
--
-- Assumption: The subgraph with the given id has just been created and no path
-- exists from it to an external sink.
routeSubgraph :: Id -> Grapher [Id]
routeSubgraph si = do
  st <- get
  let g = stateGraph st
  let (routed, unrouted) = stateSources st
  -- Sources of the new subgraph were the most recently added, so they form a
  -- prefix of the unrouted list.
  let (gsrcs, unrouted') = span (inSubGraph si g) unrouted
  let (sinks, g') = MG.routeMany gsrcs g
  put $
    st
      { stateGraph = g',
        stateSources = (gsrcs ++ routed, unrouted'),
        stateSinks = sinks ++ stateSinks st
      }
  pure gsrcs

-- | @inSubGraph si g i@ returns whether @g@ contains a vertex with id @i@ that
-- is declared within the subgraph with id @si@.
inSubGraph :: Id -> Graph -> Id -> Bool
inSubGraph si g i
  | Just v <- MG.lookup i g,
    Just mgi <- metaGraphId (vertexMeta v) =
      si == mgi
inSubGraph _ _ _ = False

-- | @b `reuses` n@ records that @b@ binds an array backed by the same memory
-- as @n@. If @b@ is not array typed or the backing memory is not copyable then
-- this does nothing.
reuses :: Binding -> VName -> Grapher ()
reuses (i, t) n
  | isArray t =
      do
        body_depth <- outermostCopyableArray n
        forM_ body_depth (recordCopyableMemory i)
  | otherwise =
      pure ()

-- | Like 'reuses' but the reused array is given by a t'SubExp'. This is a
-- no-op if the t'SubExp' is a constant.
reusesSubExp :: Binding -> SubExp -> Grapher ()
reusesSubExp b (Var n) = b `reuses` n
reusesSubExp _ _ = pure ()

-- | @reusesReturn bs res@ records each array binding in @bs@ as reusing
-- copyable memory if the corresponding return value in @res@ is backed by
-- copyable memory.
--
-- Returns @True@ only if every array binding is backed by copyable memory
-- whose outermost known declaration lies in a body nested deeper than the
-- current one (i.e. no free array variable is returned); otherwise @False@.
--
-- This is exactly the single-body case of 'reusesBranches', so it delegates
-- to it rather than duplicating the per-binding logic.
reusesReturn :: [Binding] -> [SubExp] -> Grapher Bool
reusesReturn bs res = reusesBranches bs [res]

-- | @reusesBranches bs seses@ records each array binding in @bs@ as
-- reusing copyable memory if each corresponding return value in the
-- lists in @seses@ is backed by copyable memory. Each list is the
-- result of a branch body (i.e. for an 'if' the outer list has two
-- elements).
--
-- Returns @True@ only if every array binding is backed by copyable
-- memory whose outermost known declaration lies in a body nested deeper
-- than the current one; otherwise it returns @False@.
reusesBranches :: [Binding] -> [[SubExp]] -> Grapher Bool
reusesBranches bs seses = do
  depth <- metaBodyDepth <$> getMeta
  let visit acc (binding, branch_ses) = reuseAt depth acc binding branch_ses
  foldM visit True $ zip bs (L.transpose seses)
  where
    -- Process one binding together with the values every branch returns for
    -- it, threading the "all results are copyable" flag.
    reuseAt :: BodyDepth -> Bool -> Binding -> [SubExp] -> Grapher Bool
    reuseAt depth acc (i, t) branch_ses
      | all (== intConst Int64 1) (arrayDims t) =
          -- Single element arrays are trivially copyable and are not recorded.
          -- Primitive values (no dimensions) also take this path.
          pure acc
      | isArray t,
        Just names <- mapM subExpVar branch_ses = do
          depths <- mapM outermostCopyableArray names
          case sequence depths of
            Nothing -> pure False
            Just bds -> do
              let inner = minimum bds
              recordCopyableMemory i (min depth inner)
              -- Memory known at or above the current body depth means a free
              -- variable is returned, which makes the overall result False.
              pure (acc && inner > depth)
      | otherwise =
          pure acc

--------------------------------------------------------------------------------
--                          GRAPH BUILDING - TYPES                            --
--------------------------------------------------------------------------------

-- | The graph-building monad: mutable 'State' over a read-only 'Env'.
type Grapher = StateT State (R.Reader Env)

-- | Read-only context for graph building.
data Env = Env
  { -- | Functions whose applications must stay on host. See 'HostOnlyFuns'.
    envHostOnlyFuns :: HostOnlyFuns,
    -- | Metadata describing the body that is currently being graphed.
    envMeta :: Meta
  }

-- | How many bodies something is nested within.
type BodyDepth = Int

-- | Metadata about the environment in which a variable is declared.
data Meta = Meta
  { -- | Number of if statement branch bodies that enclose the binding. When a
    -- route passes through the edge u->v and the fork depth
    --
    -- 1) increases from u to v, then u is within a conditional branch.
    --
    -- 2) decreases from u to v, then v binds the result of two or more
    -- branches.
    --
    -- Once the graph has been built and routed this can be used to delay
    -- reads into deeper branches, making them less likely to manifest.
    metaForkDepth :: Int,
    -- | Number of bodies that enclose the variable.
    metaBodyDepth :: BodyDepth,
    -- | Id of the subgraph the variable belongs to, assigned per body. A read
    -- may only be delayed to a point within its own subgraph.
    metaGraphId :: Maybe Id
  }

-- | Ids of all variables used as an operand.
type Operands = IdSet

-- | Summary of the statements within a body and their dependencies.
data BodyStats = BodyStats
  { -- | Whether the body contained any host-only statement.
    bodyHostOnly :: Bool,
    -- | Whether the body contained any GPUBody kernel.
    bodyHasGPUBody :: Bool,
    -- | Whether the body performed any read.
    bodyReads :: Bool,
    -- | Every graphed scalar used as a return value of the body or as an
    -- operand within it, including those defined inside the body. Variables
    -- whose vertices are connected to sinks may be missing.
    bodyOperands :: Operands,
    -- | Body depths of parents that declare variables required on host. The
    -- parent statements of such bodies cannot move to device as a whole and
    -- are therefore host-only.
    bodyHostOnlyParents :: IS.IntSet
  }

instance Semigroup BodyStats where
  lhs@BodyStats {} <> rhs@BodyStats {} =
    BodyStats
      { bodyHostOnly = bodyHostOnly lhs || bodyHostOnly rhs,
        bodyHasGPUBody = bodyHasGPUBody lhs || bodyHasGPUBody rhs,
        bodyReads = bodyReads lhs || bodyReads rhs,
        bodyOperands = bodyOperands lhs `IS.union` bodyOperands rhs,
        bodyHostOnlyParents =
          bodyHostOnlyParents lhs `IS.union` bodyHostOnlyParents rhs
      }

instance Monoid BodyStats where
  -- No flags set and no operands or host-only parents recorded.
  mempty = BodyStats False False False IS.empty IS.empty

-- | The migration graph, with 'Meta' attached to each vertex.
type Graph = MG.Graph Meta

-- | Every source-connected vertex, split into those that have been attempted
-- routed (first) and those that have not (second).
type Sources = ([Id], [Id])

-- | The terminal vertices of all routes.
type Sinks = [Id]

-- | A captured statement whose graphing has been postponed.
type Delayed = (Binding, Exp GPU)

-- | The vertex handle for a variable and its type.
type Binding = (Id, Type)

-- | Arrays backed by memory segments that may be copied, each mapped to the
-- outermost known body depth that declares an array backed by a superset of
-- those segments.
type CopyableMemoryMap = IM.IntMap BodyDepth

-- | Mutable state of the graph builder.
data State = State
  { -- | The graph under construction.
    stateGraph :: Graph,
    -- | Every known scalar that has been graphed.
    stateGraphedScalars :: IdSet,
    -- | Every variable that directly binds a scalar read from device memory.
    stateSources :: Sources,
    -- | Graphed scalars used as operands by statements that cannot be
    -- migrated. No read can be delayed past these, so if their binding
    -- statements move to device the variables must be read from device
    -- memory.
    stateSinks :: Sinks,
    -- | 'UpdateAcc' host statements observed so far, to be graphed later.
    stateUpdateAccs :: IM.IntMap [Delayed],
    -- | Encountered arrays that are backed by copyable memory. Trivial cases
    -- such as single element arrays are left out.
    stateCopyableMemory :: CopyableMemoryMap,
    -- | Statistics for the body currently being graphed.
    stateStats :: BodyStats
  }

--------------------------------------------------------------------------------
--                             GRAPHER OPERATIONS                             --
--------------------------------------------------------------------------------

-- | Run a graph-building computation from an empty state at body depth zero
-- and return the resulting graph together with its sources and sinks.
execGrapher :: HostOnlyFuns -> Grapher a -> (Graph, Sources, Sinks)
execGrapher hof m =
  let final = R.runReader (execStateT m initialState) initialEnv
   in (stateGraph final, stateSources final, stateSinks final)
  where
    rootMeta =
      Meta
        { metaForkDepth = 0,
          metaBodyDepth = 0,
          metaGraphId = Nothing
        }
    initialEnv =
      Env
        { envHostOnlyFuns = hof,
          envMeta = rootMeta
        }
    initialState =
      State
        { stateGraph = MG.empty,
          stateGraphedScalars = IS.empty,
          stateSources = ([], []),
          stateSinks = [],
          stateUpdateAccs = IM.empty,
          stateCopyableMemory = IM.empty,
          stateStats = mempty
        }

-- | Execute a computation in a modified environment.
local :: (Env -> Env) -> Grapher a -> Grapher a
local = mapStateT . R.local

-- | Fetch the value of the environment.
ask :: Grapher Env
ask = lift R.ask

-- | Project a value out of the current environment.
asks :: (Env -> a) -> Grapher a
asks f = f <$> ask

-- | Register that the body contains a host-only statement, which makes its
-- parent statement and any enclosing bodies host-only as well. A host-only
-- statement should stay on host, either because it cannot run on device or
-- because running it there would be inefficient.
tellHostOnly :: Grapher ()
tellHostOnly = do
  stats <- gets stateStats
  modify $ \st -> st {stateStats = stats {bodyHostOnly = True}}

-- | Register that the body contains a GPUBody kernel.
tellGPUBody :: Grapher ()
tellGPUBody = do
  stats <- gets stateStats
  modify $ \st -> st {stateStats = stats {bodyHasGPUBody = True}}

-- | Register that the current body contains a statement that reads device
-- memory.
tellRead :: Grapher ()
tellRead = do
  stats <- gets stateStats
  modify $ \st -> st {stateStats = stats {bodyReads = True}}

-- | Register that these variables are used as operands within the current body.
tellOperands :: IdSet -> Grapher ()
tellOperands is = do
  stats <- gets stateStats
  let merged = bodyOperands stats <> is
  modify $ \st -> st {stateStats = stats {bodyOperands = merged}}

-- | Register that the statement owning a body at the given body depth is
-- host-only.
tellHostOnlyParent :: BodyDepth -> Grapher ()
tellHostOnlyParent body_depth = do
  stats <- gets stateStats
  let parents = IS.insert body_depth (bodyHostOnlyParents stats)
  modify $ \st -> st {stateStats = stats {bodyHostOnlyParents = parents}}

-- | The graph under construction.
getGraph :: Grapher Graph
getGraph = stateGraph <$> get

-- | Every scalar variable that has a vertex in the graph.
getGraphedScalars :: Grapher IdSet
getGraphedScalars = stateGraphedScalars <$> get

-- | Every known array that is backed by a memory segment that may be copied,
-- mapped to the outermost known body depth where an array is backed by a
-- superset of that segment.
--
-- A body where all returned arrays are backed by such memory and are written by
-- its own statements will retain its asymptotic cost if migrated as a whole.
getCopyableMemory :: Grapher CopyableMemoryMap
getCopyableMemory = stateCopyableMemory <$> get

-- | Outermost known body depth of an array backed by the same copyable memory
-- as the named array, if any.
outermostCopyableArray :: VName -> Grapher (Maybe BodyDepth)
outermostCopyableArray n = do
  mem <- getCopyableMemory
  pure $ IM.lookup (nameToId n) mem

-- | Narrow the variables down to the 'Id's of those that are graphed scalars,
-- i.e. scalars with a vertex in the graph that have not been connected to a
-- sink.
onlyGraphedScalars :: (Foldable t) => t VName -> Grapher IdSet
onlyGraphedScalars vs = do
  graphed <- getGraphedScalars
  let wanted = foldr (IS.insert . nameToId) IS.empty vs
  pure $ IS.intersection wanted graphed

-- | 'onlyGraphedScalars' for a single 'VName'.
onlyGraphedScalar :: VName -> Grapher IdSet
onlyGraphedScalar n =
  IS.intersection (IS.singleton (nameToId n)) <$> getGraphedScalars

-- | 'onlyGraphedScalars' for a single t'SubExp'; constants yield nothing.
onlyGraphedScalarSubExp :: SubExp -> Grapher IdSet
onlyGraphedScalarSubExp se = case se of
  Constant _ -> pure IS.empty
  Var n -> onlyGraphedScalar n

-- | Apply a function to the graph under construction.
modifyGraph :: (Graph -> Graph) -> Grapher ()
modifyGraph f =
  modify $ \st@State {stateGraph = g} -> st {stateGraph = f g}

-- | Apply a function to the set of graphed scalars.
modifyGraphedScalars :: (IdSet -> IdSet) -> Grapher ()
modifyGraphedScalars f =
  modify $ \st@State {stateGraphedScalars = gs} -> st {stateGraphedScalars = f gs}

-- | Apply a function to the copyable memory map.
modifyCopyableMemory :: (CopyableMemoryMap -> CopyableMemoryMap) -> Grapher ()
modifyCopyableMemory f =
  modify $ \st@State {stateCopyableMemory = cm} -> st {stateCopyableMemory = f cm}

-- | Update the set of source connected vertices.
modifySources :: (Sources -> Sources) -> Grapher () modifySources f = modify $ \st -> st {stateSources = f (stateSources st)} -- | Record that this variable binds an array that is backed by copyable -- memory shared by an array at this outermost body depth. recordCopyableMemory :: Id -> BodyDepth -> Grapher () recordCopyableMemory i bd = modifyCopyableMemory (IM.insert i bd) -- | Increment the fork depth for variables graphed by this action. incForkDepthFor :: Grapher a -> Grapher a incForkDepthFor = local $ \env -> let meta = envMeta env fork_depth = metaForkDepth meta in env {envMeta = meta {metaForkDepth = fork_depth + 1}} -- | Increment the body depth for variables graphed by this action. incBodyDepthFor :: Grapher a -> Grapher a incBodyDepthFor = local $ \env -> let meta = envMeta env body_depth = metaBodyDepth meta in env {envMeta = meta {metaBodyDepth = body_depth + 1}} -- | Change the graph id for variables graphed by this action. graphIdFor :: Id -> Grapher a -> Grapher a graphIdFor i = local $ \env -> let meta = envMeta env in env {envMeta = meta {metaGraphId = Just i}} -- | Capture body stats produced by the given action. captureBodyStats :: Grapher a -> Grapher BodyStats captureBodyStats m = do stats <- gets stateStats modify $ \st -> st {stateStats = mempty} _ <- m stats' <- gets stateStats modify $ \st -> st {stateStats = stats <> stats'} pure stats' -- | Can applications of this function be moved to device? isHostOnlyFun :: Name -> Grapher Bool isHostOnlyFun fn = asks $ S.member fn . envHostOnlyFuns -- | Get the 'Meta' corresponding to the current body. getMeta :: Grapher Meta getMeta = asks envMeta -- | Get the body depth of the current body (its nesting level). getBodyDepth :: Grapher BodyDepth getBodyDepth = asks (metaBodyDepth . 
envMeta) futhark-0.25.27/src/Futhark/Optimise/ReduceDeviceSyncs/MigrationTable/000077500000000000000000000000001475065116200255765ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Optimise/ReduceDeviceSyncs/MigrationTable/Graph.hs000066400000000000000000000531351475065116200272020ustar00rootroot00000000000000-- | This module contains the type definitions and basic operations -- for the graph that -- "Futhark.Optimise.ReduceDeviceSyncs.MigrationTable" internally uses -- to construct a migration table. It is however completely -- Futhark-agnostic and implements a generic graph abstraction. -- -- = Overview -- -- The 'Graph' type is a data flow dependency graph of program variables, each -- variable represented by a 'Vertex'. A vertex may have edges to other vertices -- or to a sink, which is a special vertex with no graph representation. Each -- edge to a vertex is either from another vertex or from a source, which also -- is a special vertex with no graph representation. -- -- The primary graph operation provided by this module is 'route'. Given the -- vertex that some unspecified source has an edge to, a path is attempted -- found to a sink. If a sink can be reached from the source, all edges along -- the path are reversed. The path in the opposite direction of reversed edges -- from a source to some sink is a route. -- -- Routes can be used to find a minimum vertex cut in the graph through what -- amounts to a specialized depth-first search implementation of the -- Ford-Fulkerson method. When viewed in this way each graph edge has a capacity -- of 1 and the reversing of edges to create routes amounts to sending reverse -- flow through a residual network (the current state of the graph). -- The max-flow min-cut theorem allows one to determine a minimum edge cut that -- separates the sources and sinks. 
--
-- If each vertex @v@ in the graph is viewed as two vertices, @v_in@ and
-- @v_out@, with all ingoing edges to @v@ going to @v_in@, all outgoing edges
-- from @v@ going from @v_out@, and @v_in@ connected to @v_out@ with a single
-- edge, then the minimum edge cut of the view amounts to a minimum vertex cut
-- in the actual graph. The view need not be manifested, as whether @v_in@ or
-- @v_out@ is reached by an edge to @v@ can be determined from whether that edge
-- is reversed or not. The presence of an outgoing, reversed edge also gives the
-- state of the virtual edge that connects @v_in@ to @v_out@.
--
-- When routing fails to find a sink in some subgraph reached via an edge then
-- that edge is marked exhausted. No sink can be reached via an exhausted edge,
-- and any subsequent routing attempt will skip pathfinding along such an edge.
module Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph
  ( -- * Types
    Graph,
    Id,
    IdSet,
    Vertex (..),
    Routing (..),
    Exhaustion (..),
    Edges (..),
    EdgeType (..),
    Visited,
    Result (..),

    -- * Construction
    empty,
    vertex,
    declareEdges,
    oneEdge,
    none,

    -- * Insertion
    insert,

    -- * Update
    adjust,
    connectToSink,
    addEdges,

    -- * Query
    member,
    lookup,
    isSinkConnected,

    -- * Routing
    route,
    routeMany,

    -- * Traversal
    fold,
    reduce,
  )
where

import Data.IntMap.Strict qualified as IM
import Data.IntSet qualified as IS
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe (fromJust)
-- 'lookup' is defined by this module for graphs, so hide the list version.
import Prelude hiding (lookup)

--------------------------------------------------------------------------------
--                                   TYPES                                    --
--------------------------------------------------------------------------------

-- | A data flow dependency graph of program variables, each variable
-- represented by a 'Vertex'. Vertices are stored in an 'IM.IntMap' keyed by
-- their 'Id'.
newtype Graph m = Graph (IM.IntMap (Vertex m))

-- | A handle that identifies a specific 'Vertex'.
type Id = Int

-- | A set of 'Id's.
type IdSet = IS.IntSet

-- | A graph representation of some program variable.
data Vertex m = Vertex
  { -- | The handle for this vertex in the graph.
    vertexId :: Id,
    -- | Custom data associated with the variable.
    vertexMeta :: m,
    -- | Whether a route passes through this vertex, and from where.
    vertexRouting :: Routing,
    -- | Handles of vertices that this vertex has an edge to.
    vertexEdges :: Edges
  }

-- | Route tracking for some vertex.
-- If a route passes through the vertex then both an ingoing and an outgoing
-- edge to/from that vertex will have been reversed, and the vertex will in
-- effect have lost one edge and gained another. The gained edge will be to
-- the prior vertex along the route that passes through.
data Routing
  = -- | No route passes through the vertex, and no edges have been reversed,
    -- added, nor deleted compared to what was declared.
    NoRoute
  | -- | A route passes through the vertex, and the prior vertex is the source
    -- of that route. The edge gained by reversal is by definition exhausted.
    FromSource
  | -- | A route passes through the vertex, and this is the handle of the prior
    -- vertex. The edge gained by reversal may be exhausted. Routing assumes
    -- that at most one 'FromNode' routing exists to each vertex in a graph.
    FromNode Id Exhaustion
  deriving (Show, Eq, Ord)

-- | Whether some edge is exhausted or not. No sink can be reached via an
-- exhausted edge.
data Exhaustion = Exhausted | NotExhausted
  deriving (Show, Eq, Ord)

-- | All relevant edges that have been declared from some vertex, plus
-- bookkeeping to track their exhaustion and reversal.
data Edges
  = -- | The vertex has an edge to a sink; all other declared edges are
    -- irrelevant. The edge cannot become exhausted, and it is reversed if a
    -- route passes through the vertex (@vertexRouting v /= NoRoute@).
    ToSink
  | -- | All vertices that the vertex has a declared edge to, and which of
    -- those edges are neither exhausted nor reversed ('Nothing' meaning
    -- that all declared edges still are available).
    ToNodes IdSet (Maybe IdSet)
  deriving (Show, Eq, Ord)

-- | Combines two edge sets. An edge to a sink absorbs everything. Otherwise
-- the declared sets are unioned, and an edge is considered available unless
-- one of the operands declares it while not listing it as available.
instance Semigroup Edges where
  ToSink <> _ = ToSink
  _ <> ToSink = ToSink
  (ToNodes a1 Nothing) <> (ToNodes a2 Nothing) =
    ToNodes (a1 <> a2) Nothing
  -- Edges that only the right operand declares are all available.
  (ToNodes a1 (Just e1)) <> (ToNodes a2 Nothing) =
    ToNodes (a1 <> a2) $ Just (e1 <> IS.difference a2 a1)
  -- Edges that only the left operand declares are all available.
  (ToNodes a1 Nothing) <> (ToNodes a2 (Just e2)) =
    ToNodes (a1 <> a2) $ Just (e2 <> IS.difference a1 a2)
  -- Keep each operand's available edges, minus those that the other operand
  -- declares but has marked unavailable (@a1 \\ e1@ and @a2 \\ e2@).
  (ToNodes a1 (Just e1)) <> (ToNodes a2 (Just e2)) =
    let a = IS.difference e2 (IS.difference a1 e1)
        b = IS.difference e1 (IS.difference a2 e2)
     in ToNodes (a1 <> a2) $ Just (a <> b)

instance Monoid Edges where
  -- The empty set of edges.
  mempty = ToNodes IS.empty Nothing

-- | Whether a vertex is reached via a normal or reversed edge.
data EdgeType = Normal | Reversed
  deriving (Eq, Ord)

-- | State that tracks which vertices a traversal has visited, caching the
-- value computed for each visited (edge type, vertex) pair.
newtype Visited a = Visited {visited :: M.Map (EdgeType, Id) a}

-- | The result of a graph traversal that may abort early in case a sink is
-- reached.
data Result a
  = -- | The traversal finished without encountering a sink, producing this
    -- value.
    Produced a
  | -- | The traversal was aborted because a sink was reached.
    FoundSink
  deriving (Eq)

-- | 'FoundSink' absorbs everything; produced values are combined with '<>'.
instance (Semigroup a) => Semigroup (Result a) where
  FoundSink <> _ = FoundSink
  _ <> FoundSink = FoundSink
  Produced x <> Produced y = Produced (x <> y)

--------------------------------------------------------------------------------
--                                CONSTRUCTION                                --
--------------------------------------------------------------------------------

-- | The empty graph.
empty :: Graph m
empty = Graph IM.empty

-- | Constructs a 'Vertex' without any edges and with no route through it.
vertex :: Id -> m -> Vertex m
vertex i m =
  Vertex
    { vertexId = i,
      vertexMeta = m,
      vertexRouting = NoRoute,
      vertexEdges = mempty
    }

-- | Creates a set of edges where no edge is reversed or exhausted.
declareEdges :: [Id] -> Edges
declareEdges is = ToNodes (IS.fromList is) Nothing

-- | Like 'declareEdges' but for a single vertex.
oneEdge :: Id -> Edges
oneEdge i = ToNodes (IS.singleton i) Nothing

-- | Initial 'Visited' state before any vertex has been visited.
none :: Visited a
none = Visited M.empty

--------------------------------------------------------------------------------
--                                 INSERTION                                  --
--------------------------------------------------------------------------------

-- | Insert a vertex into the graph, keyed by its 'vertexId'. If a vertex with
-- the same id already is present in the graph, it is replaced by the given
-- vertex. (Routing relies on this to write back updated vertices.)
insert :: Vertex m -> Graph m -> Graph m
insert v (Graph m) = Graph $ IM.insert (vertexId v) v m

--------------------------------------------------------------------------------
--                                   UPDATE                                   --
--------------------------------------------------------------------------------

-- | Adjust the vertex with this specific id. When no vertex with that id is a
-- member of the graph, the original graph is returned.
adjust :: (Vertex m -> Vertex m) -> Id -> Graph m -> Graph m
adjust f i (Graph m) = Graph $ IM.adjust f i m

-- | Connect the vertex with this id to a sink. When no vertex with that id is a
-- member of the graph, the original graph is returned.
connectToSink :: Id -> Graph m -> Graph m
connectToSink = adjust $ \v -> v {vertexEdges = ToSink}

-- | Add these edges to the vertex with this id. When no vertex with that id is
-- a member of the graph, the original graph is returned.
addEdges :: Edges -> Id -> Graph m -> Graph m
addEdges es = adjust $ \v -> v {vertexEdges = es <> vertexEdges v}

--------------------------------------------------------------------------------
--                                   QUERY                                    --
--------------------------------------------------------------------------------

-- | Does a vertex for the given id exist in the graph?
member :: Id -> Graph m -> Bool
member i (Graph m) = IM.member i m

-- | Returns the vertex with the given id.
lookup :: Id -> Graph m -> Maybe (Vertex m)
lookup i (Graph m) = IM.lookup i m

-- | Returns whether a vertex with the given id exists in the
-- graph and is connected directly to a sink.
isSinkConnected :: Id -> Graph m -> Bool
isSinkConnected i g =
  maybe False ((ToSink ==) . vertexEdges) (lookup i g)

--------------------------------------------------------------------------------
--                                  ROUTING                                   --
--------------------------------------------------------------------------------

-- | @route src g@ attempts to find a path in @g@ from the source connected
-- vertex with id @src@. If a sink is found, all edges along the path will be
-- reversed to create a route, and the id of the vertex that connects to the
-- sink is returned.
--
-- Calls 'error' if the search reports a cycle that was never resolved, which
-- indicates an internal inconsistency in the graph.
route :: Id -> Graph m -> (Maybe Id, Graph m)
route src g =
  case route' IM.empty 0 Nothing Normal src g of
    (DeadEnd, g') -> (Nothing, g')
    (SinkFound snk, g') -> (Just snk, g')
    (CycleDetected {}, _) ->
      error
        "Routing did not escape cycle in Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph."

-- | @routeMany srcs g@ attempts to create a 'route' in @g@ from every vertex
-- in @srcs@. Returns the ids for the vertices connected to each found sink,
-- most recently found first (i.e. in reverse order of @srcs@).
routeMany :: [Id] -> Graph m -> ([Id], Graph m)
routeMany srcs graph =
  L.foldl' f ([], graph) srcs
  where
    f (snks, g) src =
      case route src g of
        (Nothing, g') -> (snks, g')
        (Just snk, g') -> (snk : snks, g')

--------------------------------------------------------------------------------
--                                 TRAVERSAL                                  --
--------------------------------------------------------------------------------

-- | @fold g f (a, vs) et i@ folds @f@ over the vertices in @g@ that can be
-- reached from the vertex with handle @i@ accessed via an edge of type @et@.
-- Each vertex @v@ may be visited up to two times, once for each type of edge
-- @e@ pointing to it, and each time @f a e v@ is evaluated to produce an
-- updated @a@ value to be used in future @f@ evaluations or to be returned.
-- The @vs@ set records which @f a e v@ evaluations already have taken place.
-- The function returns an updated 'Visited' set recording the evaluations it
-- has performed.
fold ::
  Graph m ->
  (a -> EdgeType -> Vertex m -> a) ->
  (a, Visited ()) ->
  EdgeType ->
  Id ->
  (a, Visited ())
fold g f (res, vs) et i
  -- Only visit each (edge type, vertex) pair once, and only existing vertices.
  | M.notMember (et, i) (visited vs),
    Just v <- lookup i g =
      let res' = f res et v
          vs' = Visited $ M.insert (et, i) () (visited vs)
          st = (res', vs')
       in case (et, vertexRouting v) of
            -- No edge continues from a vertex reached normally whose route
            -- comes from the source.
            (Normal, FromSource) -> st
            -- Reached normally on a routed vertex: only the reversed edge back
            -- to the prior vertex is followed.
            (Normal, FromNode rev _) -> foldReversed st rev
            -- Reached via a reversed edge: follow the declared edges and the
            -- reversed edge.
            (Reversed, FromNode rev _) -> foldAll st rev (vertexEdges v)
            _ -> foldNormals st (vertexEdges v)
  | otherwise =
      (res, vs)
  where
    foldReversed st = fold g f st Reversed

    foldAll st rev es = foldReversed (foldNormals st es) rev

    -- Visits every declared edge target; the availability set of 'ToNodes'
    -- is ignored here.
    foldNormals st ToSink = st
    foldNormals st (ToNodes es _) =
      IS.foldl' (\s -> fold g f s Normal) st es

-- | @reduce g r vs et i@ returns 'FoundSink' if a sink can be reached via the
-- vertex @v@ with id @i@ in @g@. Otherwise it returns 'Produced' @(r x et v)@
-- where @x@ is the '<>' aggregate of all values produced by reducing the
-- vertices that are available via the edges of @v@.
-- @et@ identifies the type of edge that @v@ is accessed by and thereby which
-- edges of @v@ that are available. @vs@ caches reductions of vertices that
-- previously have been visited in the graph.
--
-- The reduction of a cyclic reference resolves to 'mempty'.
reduce ::
  (Monoid a) =>
  Graph m ->
  (a -> EdgeType -> Vertex m -> a) ->
  Visited (Result a) ->
  EdgeType ->
  Id ->
  (Result a, Visited (Result a))
reduce g r vs et i
  | Just res <- M.lookup (et, i) (visited vs) =
      (res, vs)
  | Just v <- lookup i g =
      reduceVertex v
  | otherwise =
      (Produced mempty, vs) -- shouldn't happen
  where
    -- Reduce the available edges, apply @r@ on success, and cache the result.
    reduceVertex v =
      let (res, vs') = reduceEdges v
       in case res of
            Produced x -> cached (Produced $ r x et v) vs'
            FoundSink -> cached res vs'

    cached res vs0 =
      let vs1 = Visited (M.insert (et, i) res $ visited vs0)
       in (res, vs1)

    -- Mirrors the edge selection performed by 'fold'.
    reduceEdges v =
      case (et, vertexRouting v) of
        (Normal, FromSource) -> (Produced mempty, vs)
        (Normal, FromNode rev _) -> entry (reduceReversed rev)
        (Reversed, FromNode rev _) -> entry (reduceAll rev $ vertexEdges v)
        _ -> entry (reduceNormals $ vertexEdges v)

    -- Handle cycles
    -- (a provisional 'mempty' result is cached for this vertex before its
    -- edges are reduced, so a cyclic reference back to it reduces to 'mempty')
    entry f = f $ Visited $ M.insert (et, i) (Produced mempty) (visited vs)

    reduceReversed rev vs' = reduce g r vs' Reversed rev

    -- Declared edges first; the reversed edge only if no sink was found.
    reduceAll rev es vs0 =
      let (res, vs1) = reduceNormals es vs0
       in case res of
            Produced _ ->
              let (res', vs2) = reduceReversed rev vs1
               in (res <> res', vs2)
            FoundSink -> (res, vs1)

    reduceNormals ToSink vs' = (FoundSink, vs')
    reduceNormals (ToNodes es _) vs' = reduceNorms mempty (IS.elems es) vs'

    -- Stops at the first edge through which a sink is reachable.
    reduceNorms x [] vs0 = (Produced x, vs0)
    reduceNorms x (e : es) vs0 =
      let (res, vs1) = reduce g r vs0 Normal e
       in case res of
            Produced y -> reduceNorms (x <> y) es vs1
            FoundSink -> (res, vs1)

--------------------------------------------------------------------------------
--                             ROUTING INTERNALS                              --
--------------------------------------------------------------------------------

-- | A set of vertices visited by a graph traversal, and at what depth they were
-- encountered. Used to detect cycles.
type Pending = IM.IntMap Depth

-- | Search depth. Distance to some vertex from some search root.
type Depth = Int

-- | The outcome of attempting to find a route through a vertex.
data RoutingResult a
  = -- | No sink could be reached through this vertex.
    DeadEnd
  | -- | A cycle was detected. A sink can be reached through this vertex if a
    -- sink can be reached from the vertex at this depth. If no sink can be
    -- reached from the vertex at this depth, then the graph should be updated
    -- by these actions. Until the vertex is reached, the status of these
    -- vertices is pending.
    CycleDetected Depth [Graph a -> Graph a] Pending
  | -- | A sink was found. This is the id of the vertex connected to it.
    SinkFound Id

-- | A found sink dominates; two cycles merge into the shallower depth with
-- both sets of deferred actions (keeping the right operand's pending map).
instance Semigroup (RoutingResult a) where
  SinkFound i <> _ = SinkFound i
  _ <> SinkFound i = SinkFound i
  CycleDetected d1 as1 _ <> CycleDetected d2 as2 p2 =
    CycleDetected (min d1 d2) (as1 ++ as2) p2
  _ <> CycleDetected d as p = CycleDetected d as p
  CycleDetected d as p <> _ = CycleDetected d as p
  DeadEnd <> DeadEnd = DeadEnd

instance Monoid (RoutingResult a) where
  mempty = DeadEnd

-- | @route' p d prev et i g@ searches for a sink from vertex @i@, reached at
-- search depth @d@ via an edge of type @et@ from @prev@ ('Nothing' when
-- reached directly from the source). @p@ holds the vertices on the current
-- search path and the depths at which they were entered.
route' ::
  Pending ->
  Depth ->
  Maybe Id ->
  EdgeType ->
  Id ->
  Graph m ->
  (RoutingResult m, Graph m)
route' p d prev et i g
  | Just d' <- IM.lookup i p =
      let found_cycle = (CycleDetected d' [] p, g)
       in case et of
            -- Accessing some vertex v via a normal edge corresponds to accessing
            -- v_in via a normal edge. If v_in has a reversed edge then that is
            -- the only outgoing edge that is available.
            -- All outgoing edges available via this ingoing edge were already
            -- available via the edge that first reached the vertex.
            Normal -> found_cycle
            -- Accessing some vertex v via a reversed edge corresponds to
            -- accessing v_out via a reversed edge. All other edges of v_out are
            -- available, and the edge from v_in to v_out has been reversed,
            -- implying that v_in has a single reversed edge that also is
            -- available.
            -- There exists at most one reversed edge to each vertex. Since this
            -- vertex was reached via one, and the vertex already have been
            -- reached, then the first reach must have been via a normal edge
            -- that only could traverse a reversed edge. The reversed edge from
            -- v_out to v_in thus completes a cycle, but a sink might be
            -- reachable via any of the other edges from v_out.
            -- The depth for the vertex need not be updated as this is the only
            -- edge to v_out and 'prev' is already in the 'Pending' map.
            -- It follows that no (new) cycle can start here.
            Reversed ->
              let (res, g') = routeNormals (fromJust $ lookup i g) g p
               in (fst found_cycle <> res, g')
  | Just v <- lookup i g =
      routeVertex v
  | otherwise =
      backtrack
  where
    backtrack = (DeadEnd, g)

    routeVertex v =
      case (et, vertexRouting v) of
        (Normal, FromSource) -> backtrack
        (Normal, FromNode _ Exhausted) -> backtrack
        (Normal, FromNode rev _) -> entry (routeReversed rev)
        (Reversed, FromNode rev _) -> entry (routeAll rev v)
        _ -> entry (routeNormals v)

    -- Run @f@ with this vertex pending at depth @d@. A cycle that closes at
    -- this very depth found no sink, so its deferred graph updates are applied
    -- and the vertex is reported as a dead end.
    entry f =
      let (res, g0) = f g (IM.insert i d p)
       in case res of
            CycleDetected d' as _
              | d == d' -> (DeadEnd, L.foldl' (\g1 a -> a g1) g0 as)
            _ | otherwise -> (res, g0)

    -- Try the declared edges first, then the reversed edge.
    routeAll rev v g0 p0 =
      let (res, g1) = routeNormals v g0 p0
       in case res of
            DeadEnd -> routeReversed rev g1 p0
            CycleDetected _ _ p1 ->
              let (res', g2) = routeReversed rev g1 p1
               in (res <> res', g2)
            SinkFound _ -> (res, g1)

    -- Follow the reversed edge back to @rev@. On a dead end the reversed edge
    -- is marked exhausted (deferred if a cycle is pending).
    routeReversed rev g0 p0 =
      let (res, g') = route' p0 (d + 1) (Just i) Reversed rev g0
          exhaust = flip adjust i $ \v -> v {vertexRouting = FromNode rev Exhausted}
       in case (res, et) of
            (DeadEnd, _) -> (res, exhaust g')
            (CycleDetected d' as p', _) -> (CycleDetected d' (exhaust : as) p', g')
            (SinkFound _, Normal) -> (res, setRoute g')
            (SinkFound _, Reversed) ->
              let f v =
                    v
                      { vertexEdges = withPrev (vertexEdges v),
                        vertexRouting = NoRoute
                      }
               in (res, adjust f i g')

    -- Record that the route now passes through this vertex from @prev@.
    setRoute = adjust (\v -> v {vertexRouting = routing}) i

    routing =
      case prev of
        Nothing -> FromSource
        Just i' -> FromNode i' NotExhausted

    -- Re-add @prev@ to the set of available (non-exhausted, non-reversed)
    -- edges.
    withPrev edges
      | Just i' <- prev,
        ToNodes es (Just es') <- edges =
          ToNodes es (Just $ IS.insert i' es')
      | otherwise =
          edges -- shouldn't happen

    routeNormals v g0 p0
      | ToSink <- vertexEdges v =
          -- There cannot be a reversed edge to a vertex with an edge to a sink.
          (SinkFound i, setRoute g0)
      | ToNodes es nx <- vertexEdges v =
          -- Only the currently available edges are tried; @nx'@ are the edges
          -- that remain available afterwards.
          let (res, g', nx') =
                case nx of
                  Just es' -> routeNorms (IS.toAscList es') g0 p0
                  Nothing -> routeNorms (IS.toAscList es) g0 p0
              edges = ToNodes es (Just $ IS.fromDistinctAscList nx')
              exhaust = flip adjust i $ \v' ->
                v' {vertexEdges = ToNodes es (Just IS.empty)}
           in case (res, et) of
                (DeadEnd, _) -> (res, exhaust g')
                (CycleDetected d' as p', _) ->
                  let res' = CycleDetected d' (exhaust : as) p'
                      v' = v {vertexEdges = edges}
                   in (res', insert v' g')
                (SinkFound _, Normal) ->
                  let v' = v {vertexEdges = edges, vertexRouting = routing}
                   in (res, insert v' g')
                (SinkFound _, Reversed) ->
                  let v' = v {vertexEdges = withPrev edges}
                   in (res, insert v' g')

    -- Tries the edges in order. Returns the edges that remain available: a
    -- dead-end edge is dropped (exhausted), the edge that reached a sink is
    -- dropped (it becomes reversed) along with no others, and edges that led
    -- into a pending cycle are kept.
    routeNorms [] g0 _ = (DeadEnd, g0, [])
    routeNorms (e : es) g0 p0 =
      let (res, g1) = route' p0 (d + 1) (Just i) Normal e g0
       in case res of
            DeadEnd -> routeNorms es g1 p0
            SinkFound _ -> (res, g1, es)
            CycleDetected _ _ p1 ->
              let (res', g2, es') = routeNorms es g1 p1
               in (res <> res', g2, e : es')
futhark-0.25.27/src/Futhark/Optimise/Simplify.hs000066400000000000000000000120771475065116200214650ustar00rootroot00000000000000
{-# LANGUAGE Strict #-}

module Futhark.Optimise.Simplify
  ( simplifyProg,
    simplifySomething,
    simplifyFun,
    simplifyLambda,
    simplifyStms,
    Engine.SimpleOps (..),
    Engine.SimpleM,
    Engine.SimplifyOp,
    Engine.bindableSimpleOps,
    Engine.noExtraHoistBlockers,
    Engine.neverHoist,
    Engine.SimplifiableRep,
    Engine.HoistBlockers,
    RuleBook,
  )
where

import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Pass

-- | Usages of every name free in the given definitions, with each of those
-- names additionally recorded as consumed.
funDefUsages :: (FreeIn a) => [a] -> UT.UsageTable
funDefUsages funs =
  -- XXX: treat every constant used in the functions as consumed,
  -- because it is otherwise complicated to ensure we do not introduce
  -- more aliasing than specified by the return types. CSE has the
  -- same problem.
  let free_in_funs = foldMap freeIn funs
   in UT.usages free_in_funs
        <> foldMap UT.consumedUsage (namesToList free_in_funs)

-- | Simplify the given program. Even if the output differs from the
-- input, meaningful simplification may not have taken place - the
-- order of bindings may simply have been rearranged.
simplifyProg ::
  (Engine.SimplifiableRep rep) =>
  Engine.SimpleOps rep ->
  RuleBook (Engine.Wise rep) ->
  Engine.HoistBlockers rep ->
  Prog rep ->
  PassM (Prog rep)
simplifyProg simpl rules blockers prog = do
  let consts = progConsts prog
      funs = progFuns prog

  (consts_vtable, consts') <-
    simplifyConsts
      (funDefUsages funs)
      (mempty, informStms consts)

  -- We deepen the vtable so it will look like the constants are in an
  -- "outer loop"; this communicates useful information to some
  -- simplification rules (e.g. see issue #1302).
  funs' <- parPass (simplifyFun' (ST.deepen consts_vtable) . informFunDef) funs
  -- The constants are simplified once more, now against the usages of the
  -- simplified functions.
  (_, consts'') <- simplifyConsts (funDefUsages funs') (mempty, consts')

  pure $
    prog
      { progConsts = fmap removeStmWisdom consts'',
        progFuns = fmap removeFunDefWisdom funs'
      }
  where
    simplifyFun' consts_vtable =
      simplifySomething
        (Engine.localVtable (consts_vtable <>) . Engine.simplifyFun)
        id
        simpl
        rules
        blockers
        mempty

    simplifyConsts uses =
      simplifySomething
        (onConsts uses . snd)
        id
        simpl
        rules
        blockers
        mempty

    -- Simplify the constants and build a symbol table containing them.
    onConsts uses consts' = do
      consts'' <- Engine.simplifyStmsWithUsage uses consts'
      pure (ST.insertStms consts'' mempty, consts'')

-- | Run a simplification operation to convergence.
simplifySomething ::
  (MonadFreshNames m, Engine.SimplifiableRep rep) =>
  (a -> Engine.SimpleM rep b) ->
  (b -> a) ->
  Engine.SimpleOps rep ->
  RuleBook (Wise rep) ->
  Engine.HoistBlockers rep ->
  ST.SymbolTable (Wise rep) ->
  a ->
  m a
simplifySomething f g simpl rules blockers vtable x = do
  -- Every iteration runs @f@ with @vtable@ prepended to the engine's symbol
  -- table; @g@ turns the result into the input of the next iteration.
  let f' x' = Engine.localVtable (vtable <>) $ f x'
  loopUntilConvergence env simpl f' g x
  where
    env = Engine.emptyEnv rules blockers

-- | Simplify the given function. Even if the output differs from the
-- input, meaningful simplification may not have taken place - the
-- order of bindings may simply have been rearranged. Runs in a loop
-- until convergence.
simplifyFun ::
  (MonadFreshNames m, Engine.SimplifiableRep rep) =>
  Engine.SimpleOps rep ->
  RuleBook (Engine.Wise rep) ->
  Engine.HoistBlockers rep ->
  ST.SymbolTable (Wise rep) ->
  FunDef rep ->
  m (FunDef rep)
simplifyFun simpl rules blockers vtable fd =
  removeFunDefWisdom
    <$> simplifySomething
      Engine.simplifyFun
      id
      simpl
      rules
      blockers
      vtable
      (informFunDef fd)

-- | Simplify just a single t'Lambda'. The symbol table is initialised from
-- the scope of the enclosing monad, and no hoisting out of the lambda is
-- performed.
simplifyLambda ::
  ( MonadFreshNames m,
    HasScope rep m,
    Engine.SimplifiableRep rep
  ) =>
  Engine.SimpleOps rep ->
  RuleBook (Engine.Wise rep) ->
  Engine.HoistBlockers rep ->
  Lambda rep ->
  m (Lambda rep)
simplifyLambda simpl rules blockers orig_lam = do
  vtable <- ST.fromScope . addScopeWisdom <$> askScope
  removeLambdaWisdom
    <$> simplifySomething
      Engine.simplifyLambdaNoHoisting
      id
      simpl
      rules
      blockers
      vtable
      (informLambda orig_lam)

-- | Simplify a sequence of 'Stm's, in the given scope.
simplifyStms ::
  (MonadFreshNames m, Engine.SimplifiableRep rep) =>
  Engine.SimpleOps rep ->
  RuleBook (Engine.Wise rep) ->
  Engine.HoistBlockers rep ->
  Scope rep ->
  Stms rep ->
  m (Stms rep)
simplifyStms simpl rules blockers scope =
  fmap (fmap removeStmWisdom)
    . simplifySomething Engine.simplifyStms id simpl rules blockers vtable
    . informStms
  where
    vtable = ST.fromScope $ addScopeWisdom scope

-- | Repeatedly run @f@ until the engine reports that nothing changed. @g@
-- converts each result back into an input, both for the next iteration and
-- for the final return value.
loopUntilConvergence ::
  (MonadFreshNames m, Engine.SimplifiableRep rep) =>
  Engine.Env rep ->
  Engine.SimpleOps rep ->
  (a -> Engine.SimpleM rep b) ->
  (b -> a) ->
  a ->
  m a
loopUntilConvergence env simpl f g x = do
  (x', changed) <- modifyNameSource $ Engine.runSimpleM (f x) simpl env
  if changed
    then loopUntilConvergence env simpl f g (g x')
    else pure $ g x'
futhark-0.25.27/src/Futhark/Optimise/Simplify/000077500000000000000000000000001475065116200211225ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Optimise/Simplify/Engine.hs000066400000000000000000001107471475065116200226750ustar00rootroot00000000000000
{-# LANGUAGE Strict #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
--
-- Perform general rule-based simplification based on data dependency
-- information. This module will:
--
-- * Perform common-subexpression elimination (CSE).
--
-- * Hoist expressions out of loops (including lambdas) and
-- branches. This is done as aggressively as possible.
--
-- * Apply simplification rules (see
-- "Futhark.Optimise.Simplification.Rules").
--
-- If you just want to run the simplifier as simply as possible, you
-- may prefer to use the "Futhark.Optimise.Simplify" module.
module Futhark.Optimise.Simplify.Engine
  ( -- * Monadic interface
    SimpleM,
    runSimpleM,
    SimpleOps (..),
    SimplifyOp,
    bindableSimpleOps,
    Env (envHoistBlockers, envRules),
    emptyEnv,
    HoistBlockers (..),
    neverBlocks,
    noExtraHoistBlockers,
    neverHoist,
    BlockPred,
    orIf,
    hasFree,
    isConsumed,
    isConsuming,
    isFalse,
    isOp,
    isNotSafe,
    isDeviceMigrated,
    asksEngineEnv,
    askVtable,
    localVtable,

    -- * Building blocks
    SimplifiableRep,
    Simplifiable (..),
    simplifyFun,
    simplifyStms,
    simplifyStmsWithUsage,
    simplifyLambda,
    simplifyLambdaNoHoisting,
    bindLParams,
    simplifyBody,
    ST.SymbolTable,
    hoistStms,
    blockIf,
    blockMigrated,
    enterLoop,
    constructBody,
    module Futhark.Optimise.Simplify.Rep,
  )
where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Bitraversable
import Data.Either
import Data.List (find, foldl', inits, mapAccumL)
import Data.Map qualified as M
import Data.Maybe
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Util (nubOrd)

-- | Predicates that prevent statements from being hoisted out of various
-- constructs.
data HoistBlockers rep = HoistBlockers
  { -- | Blocker for hoisting out of parallel loops.
    blockHoistPar :: BlockPred (Wise rep),
    -- | Blocker for hoisting out of sequential loops.
    blockHoistSeq :: BlockPred (Wise rep),
    -- | Blocker for hoisting out of branches.
    blockHoistBranch :: BlockPred (Wise rep),
    -- NOTE(review): by its name, identifies allocation statements; its
    -- use is not visible in this part of the module.
    isAllocation :: Stm (Wise rep) -> Bool
  }

-- | Blockers that never block hoisting, and classify no statement as an
-- allocation.
noExtraHoistBlockers :: HoistBlockers rep
noExtraHoistBlockers =
  HoistBlockers neverBlocks neverBlocks neverBlocks (const False)

-- | Blockers that block all hoisting, and classify no statement as an
-- allocation.
neverHoist :: HoistBlockers rep
neverHoist =
  HoistBlockers alwaysBlocks alwaysBlocks alwaysBlocks (const False)

-- | The read-only environment of the simplifier.
data Env rep = Env
  { envRules :: RuleBook (Wise rep),
    envHoistBlockers :: HoistBlockers rep,
    envVtable :: ST.SymbolTable (Wise rep)
  }

-- | An environment with the given rules and blockers and an empty symbol
-- table.
emptyEnv :: RuleBook (Wise rep) -> HoistBlockers rep -> Env rep
emptyEnv rules blockers =
  Env
    { envRules = rules,
      envHoistBlockers = blockers,
      envVtable = mempty
    }

-- | A function that protects a hoisted operation (if possible). The
-- first operand is the condition of the 'Case' we have hoisted out of
-- (or equivalently, a boolean indicating whether a loop has nonzero
-- trip count).
type Protect m = SubExp -> Pat (LetDec (Rep m)) -> Op (Rep m) -> Maybe (m ())

-- | Simplify an operation, producing the simplified operation together with
-- some statements (presumably those hoisted out of it - confirm).
type SimplifyOp rep op = op -> SimpleM rep (op, Stms (Wise rep))

-- | Representation-specific operations used by the simplifier.
data SimpleOps rep = SimpleOps
  { -- | Construct the expression decoration for a statement.
    mkExpDecS ::
      ST.SymbolTable (Wise rep) ->
      Pat (LetDec (Wise rep)) ->
      Exp (Wise rep) ->
      SimpleM rep (ExpDec (Wise rep)),
    -- | Construct a body from statements and a result.
    mkBodyS ::
      ST.SymbolTable (Wise rep) ->
      Stms (Wise rep) ->
      Result ->
      SimpleM rep (Body (Wise rep)),
    -- | Make a hoisted Op safe. The SubExp is a boolean
    -- that is true when the value of the statement will
    -- actually be used.
    protectHoistedOpS :: Protect (Builder (Wise rep)),
    -- | Usage information for an 'Op', added to that of the statement.
    opUsageS :: Op (Wise rep) -> UT.UsageTable,
    simplifyOpS :: SimplifyOp rep (Op (Wise rep))
  }

-- | 'SimpleOps' for a 'Buildable' representation: decorations come from
-- 'mkExpDec' and 'mkBody', hoisted ops are never protected, and ops
-- contribute no extra usage.
bindableSimpleOps ::
  (SimplifiableRep rep, Buildable rep) =>
  SimplifyOp rep (Op (Wise rep)) ->
  SimpleOps rep
bindableSimpleOps =
  SimpleOps mkExpDecS' mkBodyS' protectHoistedOpS' (const mempty)
  where
    mkExpDecS' _ pat e = pure $ mkExpDec pat e
    mkBodyS' _ stms res = pure $ mkBody stms res
    protectHoistedOpS' _ _ _ = Nothing

-- | The simplifier monad. The state holds the name source, a flag that is
-- set when anything changed, and the certificates collected so far.
newtype SimpleM rep a
  = SimpleM
      ( ReaderT
          (SimpleOps rep, Env rep)
          (State (VNameSource, Bool, Certs))
          a
      )
  deriving
    ( Applicative,
      Functor,
      Monad,
      MonadReader (SimpleOps rep, Env rep),
      MonadState (VNameSource, Bool, Certs)
    )

instance MonadFreshNames (SimpleM rep) where
  putNameSource src = modify $ \(_, b, c) -> (src, b, c)
  getNameSource = gets $ \(a, _, _) -> a

-- | The scope is derived from the symbol table of the environment.
instance (SimplifiableRep rep) => HasScope (Wise rep) (SimpleM rep) where
  askScope = ST.toScope <$> askVtable
  lookupType name = do
    vtable <- askVtable
    case ST.lookupType name vtable of
      Just t -> pure t
      Nothing ->
        error $
          "SimpleM.lookupType: cannot find variable "
            ++ prettyString name
            ++ " in symbol table."
instance (SimplifiableRep rep) => LocalScope (Wise rep) (SimpleM rep) where
  localScope types = localVtable (<> ST.fromScope types)

-- | Run a simplifier action. Returns its result, whether anything was
-- marked as changed, and the updated name source. Certificates left in the
-- state are discarded.
runSimpleM ::
  SimpleM rep a ->
  SimpleOps rep ->
  Env rep ->
  VNameSource ->
  ((a, Bool), VNameSource)
runSimpleM (SimpleM m) simpl env src =
  let (x, (src', b, _)) = runState (runReaderT m (simpl, env)) (src, False, mempty)
   in ((x, b), src')

askEngineEnv :: SimpleM rep (Env rep)
askEngineEnv = asks snd

asksEngineEnv :: (Env rep -> a) -> SimpleM rep a
asksEngineEnv f = f <$> askEngineEnv

askVtable :: SimpleM rep (ST.SymbolTable (Wise rep))
askVtable = asksEngineEnv envVtable

-- | Run an action with a modified symbol table.
localVtable ::
  (ST.SymbolTable (Wise rep) -> ST.SymbolTable (Wise rep)) ->
  SimpleM rep a ->
  SimpleM rep a
localVtable f = local $ \(ops, env) -> (ops, env {envVtable = f $ envVtable env})

-- | Run an action, then return all certificates present in the state (which
-- includes any recorded before the action started) and reset them to empty.
collectCerts :: SimpleM rep a -> SimpleM rep (a, Certs)
collectCerts m = do
  x <- m
  (a, b, cs) <- get
  put (a, b, mempty)
  pure (x, cs)

-- | Mark that we have changed something and it would be a good idea
-- to re-run the simplifier.
changed :: SimpleM rep ()
changed = modify $ \(src, _, cs) -> (src, True, cs)

-- | Record certificates in the state, to be picked up by 'collectCerts'.
usedCerts :: Certs -> SimpleM rep ()
usedCerts cs = modify $ \(a, b, c) -> (a, b, cs <> c)

-- | Indicate in the symbol table that we have descended into a loop.
enterLoop :: SimpleM rep a -> SimpleM rep a
enterLoop = localVtable ST.deepen

bindFParams :: (SimplifiableRep rep) => [FParam (Wise rep)] -> SimpleM rep a -> SimpleM rep a
bindFParams params =
  localVtable $ ST.insertFParams params

bindLParams :: (SimplifiableRep rep) => [LParam (Wise rep)] -> SimpleM rep a -> SimpleM rep a
bindLParams params =
  localVtable $ \vtable -> foldr ST.insertLParam vtable params

bindMerge ::
  (SimplifiableRep rep) =>
  [(FParam (Wise rep), SubExp, SubExpRes)] ->
  SimpleM rep a ->
  SimpleM rep a
bindMerge = localVtable . ST.insertLoopMerge

bindLoopVar :: (SimplifiableRep rep) => VName -> IntType -> SubExp -> SimpleM rep a -> SimpleM rep a
bindLoopVar var it bound =
  localVtable $ ST.insertLoopVar var it bound

-- | Turn a division or modulo operation into its 'Safe' variant; 'Nothing'
-- for every other expression.
makeSafe :: Exp rep -> Maybe (Exp rep)
makeSafe (BasicOp (BinOp (SDiv t _) x y)) =
  Just $ BasicOp (BinOp (SDiv t Safe) x y)
makeSafe (BasicOp (BinOp (SDivUp t _) x y)) =
  Just $ BasicOp (BinOp (SDivUp t Safe) x y)
makeSafe (BasicOp (BinOp (SQuot t _) x y)) =
  Just $ BasicOp (BinOp (SQuot t Safe) x y)
makeSafe (BasicOp (BinOp (UDiv t _) x y)) =
  Just $ BasicOp (BinOp (UDiv t Safe) x y)
makeSafe (BasicOp (BinOp (UDivUp t _) x y)) =
  Just $ BasicOp (BinOp (UDivUp t Safe) x y)
makeSafe (BasicOp (BinOp (SMod t _) x y)) =
  Just $ BasicOp (BinOp (SMod t Safe) x y)
makeSafe (BasicOp (BinOp (SRem t _) x y)) =
  Just $ BasicOp (BinOp (SRem t Safe) x y)
makeSafe (BasicOp (BinOp (UMod t _) x y)) =
  Just $ BasicOp (BinOp (UMod t Safe) x y)
makeSafe _ =
  Nothing

-- | A placeholder expression of the given type, used as the value of the
-- untaken branch when protecting a hoisted statement. Array dimensions that
-- are among the given (pattern) names are replaced by zero.
emptyOfType :: (MonadBuilder m) => [VName] -> Type -> m (Exp (Rep m))
emptyOfType _ Mem {} =
  error "emptyOfType: Cannot hoist non-existential memory."
emptyOfType _ Acc {} =
  error "emptyOfType: Cannot hoist accumulator."
emptyOfType _ (Prim pt) =
  pure $ BasicOp $ SubExp $ Constant $ blankPrimValue pt
emptyOfType ctx_names (Array et shape _) = do
  let dims = map zeroIfContext $ shapeDims shape
  pure $ BasicOp $ Scratch et dims
  where
    zeroIfContext (Var v) | v `elem` ctx_names = intConst Int64 0
    zeroIfContext se = se

-- | Add a statement guarded by the boolean @taken@:
--
-- * a single-case boolean 'Match' gets @taken@ conjoined into its condition;
--
-- * an 'Assert' is weakened to @not taken || cond@;
--
-- * an 'Op' is passed to the 'Protect' function, if that succeeds;
--
-- * any other expression satisfying the predicate is made safe via
--   'makeSafe', or else wrapped in a 'Match' on @taken@ whose other branch
--   produces placeholder values;
--
-- * everything else is added unchanged.
protectIf ::
  (MonadBuilder m) =>
  Protect m ->
  (Exp (Rep m) -> Bool) ->
  SubExp ->
  Stm (Rep m) ->
  m ()
protectIf _ _ taken (Let pat aux (Match [cond] [Case [Just (BoolValue True)] taken_body] untaken_body (MatchDec if_ts MatchFallback))) = do
  cond' <- letSubExp "protect_cond_conj" $ BasicOp $ BinOp LogAnd taken cond
  auxing aux . letBind pat $
    Match [cond'] [Case [Just (BoolValue True)] taken_body] untaken_body $
      MatchDec if_ts MatchFallback
protectIf _ _ taken (Let pat aux (BasicOp (Assert cond msg loc))) = do
  not_taken <- letSubExp "loop_not_taken" $ BasicOp $ UnOp (Neg Bool) taken
  cond' <- letSubExp "protect_assert_disj" $ BasicOp $ BinOp LogOr not_taken cond
  auxing aux $ letBind pat $ BasicOp $ Assert cond' msg loc
protectIf protect _ taken (Let pat aux (Op op))
  | Just m <- protect taken pat op =
      auxing aux m
protectIf _ f taken (Let pat aux e)
  | f e =
      case makeSafe e of
        Just e' ->
          auxing aux $ letBind pat e'
        Nothing -> do
          taken_body <- eBody [pure e]
          untaken_body <-
            eBody $ map (emptyOfType $ patNames pat) (patTypes pat)
          if_ts <- expTypesFromPat pat
          auxing aux . letBind pat $
            Match [taken] [Case [Just $ BoolValue True] taken_body] untaken_body $
              MatchDec if_ts MatchFallback
protectIf _ _ _ stm =
  addStm stm

-- | We are willing to hoist potentially unsafe statements out of
-- loops, but they must be protected by adding a branch on top of
-- them.
--
-- If any hoisted statement is unsafe, the unsafe ones are guarded by the
-- condition that the loop runs at least once: the initial value of the
-- condition for a while loop, or @0 < bound@ for a for loop.
protectLoopHoisted ::
  (SimplifiableRep rep) =>
  [(FParam (Wise rep), SubExp)] ->
  LoopForm ->
  SimpleM rep (a, b, Stms (Wise rep)) ->
  SimpleM rep (a, b, Stms (Wise rep))
protectLoopHoisted merge form m = do
  (x, y, stms) <- m
  ops <- asks $ protectHoistedOpS . fst
  stms' <- runBuilder_ $ do
    if not $ all (safeExp . stmExp) stms
      then do
        is_nonempty <- checkIfNonEmpty
        mapM_ (protectIf ops (not . safeExp) is_nonempty) stms
      else addStms stms
  pure (x, y, stms')
  where
    checkIfNonEmpty =
      case form of
        WhileLoop cond
          | Just (_, cond_init) <- find ((== cond) . paramName . fst) merge ->
              pure cond_init
          | otherwise ->
              pure $ constant True -- infinite loop
        ForLoop _ it bound ->
          letSubExp "loop_nonempty" $
            BasicOp $
              CmpOp (CmpSlt it) (intConst it 0) bound

-- Produces a true subexpression if the pattern (as in a 'Case')
-- matches the subexpression.
matching ::
  (BuilderOps rep) =>
  [(SubExp, Maybe PrimValue)] ->
  Builder rep SubExp
matching = letSubExp "match" <=< eAll <=< sequence . mapMaybe cmp
  where
    -- A 'True' pattern needs no comparison: the subexpression itself
    -- is the condition.
    cmp (se, Just (BoolValue True)) =
      Just $ pure se
    cmp (se, Just v) =
      Just . letSubExp "match_val" . BasicOp $
        CmpOp (CmpEq (primValueType v)) se (Constant v)
    -- Wildcard patterns contribute no condition.
    cmp (_, Nothing) = Nothing

-- | Produce a boolean that is true if the scrutinee @ses@ matches the
-- pattern @this@ and none of the patterns in @prior@.
matchingExactlyThis ::
  (BuilderOps rep) =>
  [SubExp] ->
  [[Maybe PrimValue]] ->
  [Maybe PrimValue] ->
  Builder rep SubExp
matchingExactlyThis ses prior this = do
  prior_matches <- mapM (matching . zip ses) prior
  letSubExp "matching_just_this"
    =<< eBinOp
      LogAnd
      (eUnOp (Neg Bool) (eAny prior_matches))
      (eSubExp =<< matching (zip ses this))

-- | We are willing to hoist potentially unsafe statements out of
-- matches, but they must be protected by adding a branch on top of
-- them.  (This means such hoisting is not worth it unless they are in
-- turn hoisted out of a loop somewhere.)
--
-- Protection is only applied if at least one hoisted statement is
-- unsafe.  In that case every statement that is unsafe or not cheap is
-- guarded by 'matchingExactlyThis'.
protectCaseHoisted ::
  (SimplifiableRep rep) =>
  -- | Scrutinee.
  [SubExp] ->
  -- | Patterns of previous cases.
  [[Maybe PrimValue]] ->
  -- | Pattern of this case.
  [Maybe PrimValue] ->
  SimpleM rep (Stms (Wise rep), a) ->
  SimpleM rep (Stms (Wise rep), a)
protectCaseHoisted ses prior vs m = do
  (hoisted, x) <- m
  ops <- asks $ protectHoistedOpS . fst
  hoisted' <- runBuilder_ $ do
    if not $ all (safeExp . stmExp) hoisted
      then do
        cond' <- matchingExactlyThis ses prior vs
        mapM_ (protectIf ops unsafeOrCostly cond') hoisted
      else addStms hoisted
  pure (hoisted', x)
  where
    unsafeOrCostly e = not (safeExp e) || not (cheapExp e)

-- | Statements that are not worth hoisting out of loops, because they
-- are unsafe, and added safety (by 'protectLoopHoisted') may inhibit
-- further optimisation.  Concretely: unsafe statements that bind at
-- least one array.
notWorthHoisting :: (ASTRep rep) => BlockPred rep
notWorthHoisting _ _ (Let pat _ e) =
  not (safeExp e) && any ((> 0) . arrayRank) (patTypes pat)

-- Top-down simplify a statement (including copy propagation into the
-- pattern and such).  Does not recurse into any sub-Bodies or Ops.
--
-- Certificates produced while simplifying the pattern are added to the
-- statement's certificates.
nonrecSimplifyStm ::
  (SimplifiableRep rep) =>
  Stm (Wise rep) ->
  SimpleM rep (Stm (Wise rep))
nonrecSimplifyStm (Let pat (StmAux cs attrs (_, dec)) e) = do
  cs' <- simplify cs
  e' <- simplifyExpBase e
  (pat', pat_cs) <- collectCerts $ traverse simplify $ removePatWisdom pat
  let aux' = StmAux (cs' <> pat_cs) attrs dec
  pure $ mkWiseStm pat' aux' e'

-- Bottom-up simplify a statement.  Recurses into sub-Bodies and Ops.
-- Does not copy-propagate into the pattern and similar, as it is
-- assumed 'nonrecSimplifyStm' has already touched it (and worst case,
-- it'll get it on the next round of the overall fixpoint iteration.)
-- Returns the statements hoisted out of the expression together with
-- the rebuilt statement.
recSimplifyStm ::
  (SimplifiableRep rep) =>
  Stm (Wise rep) ->
  UT.UsageTable ->
  SimpleM rep (Stms (Wise rep), Stm (Wise rep))
recSimplifyStm (Let pat (StmAux cs attrs (_, dec)) e) usage = do
  ((e', e_hoisted), e_cs) <- collectCerts $ simplifyExp (usage <> UT.usageInPat pat) pat e
  let aux' = StmAux (cs <> e_cs) attrs dec
  pure (e_hoisted, mkWiseStm (removePatWisdom pat) aux' e')

-- | Simplify a sequence of statements.  The result is split into
-- statements that must stay (those for which @block@ holds, plus their
-- dependents) and statements that may be hoisted out.  @final@ is run
-- with all statements in scope and supplies the result value and its
-- usage table.  Signals 'changed' if anything was hoisted.
hoistStms ::
  (SimplifiableRep rep) =>
  RuleBook (Wise rep) ->
  BlockPred (Wise rep) ->
  Stms (Wise rep) ->
  SimpleM rep (a, UT.UsageTable) ->
  SimpleM rep (a, Stms (Wise rep), Stms (Wise rep))
hoistStms rules block orig_stms final = do
  (a, blocked, hoisted) <- simplifyStmsBottomUp orig_stms
  unless (null hoisted) changed
  pure (a, stmsFromList blocked, stmsFromList hoisted)
  where
    simplifyStmsBottomUp stms = do
      opUsage <- asks $ opUsageS . fst
      let usageInStm stm =
            UT.usageInStm stm
              <> case stmExp stm of
                Op op -> opUsage op
                _ -> mempty
      (x, _, stms') <- hoistableStms usageInStm stms
      -- We need to do a final pass to ensure that nothing is
      -- hoisted past something that it depends on.
      let (blocked, hoisted) = partitionEithers $ blockUnhoistedDeps stms'
      pure (x, blocked, hoisted)

    -- Bring the statements into scope one at a time, run @m@
    -- innermost, then 'process' each statement on the way back out.
    descend usageInStm stms m =
      case stmsHead stms of
        Nothing -> m
        Just (stms_h, stms_t) -> localVtable (ST.insertStm stms_h) $ do
          (x, usage, stms_t') <- descend usageInStm stms_t m
          process usageInStm stms_h stms_t' usage x

    -- Try bottom-up rules on a statement.  If none fire, classify it as
    -- blocked ('Left') or hoistable ('Right') and extend the usage table.
    process usageInStm stm stms usage x = do
      vtable <- askVtable
      res <- bottomUpSimplifyStm rules (vtable, usage) stm
      case res of
        Nothing -- Nothing to optimise - see if hoistable.
          | block vtable usage stm ->
              -- No, not hoistable.
              pure
                ( x,
                  expandUsage usageInStm vtable usage stm
                    `UT.without` provides stm,
                  Left stm : stms
                )
          | otherwise ->
              -- Yes, hoistable.
              pure
                ( x,
                  expandUsage usageInStm vtable usage stm,
                  Right stm : stms
                )
        Just optimstms -> do
          changed
          descend usageInStm optimstms $ pure (x, usage, stms)

    -- Walk the statements top-down, applying non-recursive
    -- simplification and top-down rules, and dropping statements whose
    -- bound names are not directly used.
    hoistableStms usageInStm stms =
      case stmsHead stms of
        Nothing -> do
          (x, usage) <- final
          pure (x, usage, mempty)
        Just (stms_h, stms_t) -> do
          stms_h' <- nonrecSimplifyStm stms_h

          vtable <- askVtable
          simplified <- topDownSimplifyStm rules vtable stms_h'

          case simplified of
            Just newstms -> do
              changed
              hoistableStms usageInStm (newstms <> stms_t)
            Nothing -> do
              (x, usage, stms_t') <-
                localVtable (ST.insertStm stms_h') $
                  hoistableStms usageInStm stms_t
              if not $ any (`UT.isUsedDirectly` usage) $ provides stms_h'
                then -- Dead statement.
                  pure (x, usage, stms_t')
                else do
                  (stms_h_stms, stms_h'') <- recSimplifyStm stms_h' usage
                  descend usageInStm stms_h_stms $
                    process usageInStm stms_h'' stms_t' usage x

-- | Turn every hoistable ('Right') statement that refers to a name
-- bound by an earlier non-hoisted ('Left') statement into a
-- non-hoisted one.  This ensures nothing is hoisted past something it
-- depends on.
blockUnhoistedDeps ::
  (ASTRep rep) =>
  [Either (Stm rep) (Stm rep)] ->
  [Either (Stm rep) (Stm rep)]
blockUnhoistedDeps = snd . mapAccumL block mempty
  where
    -- The accumulator holds the names bound by blocked statements so far.
    block blocked (Left need) =
      (blocked <> namesFromList (provides need), Left need)
    block blocked (Right need)
      | blocked `namesIntersect` freeIn need =
          (blocked <> namesFromList (provides need), Left need)
      | otherwise =
          (blocked, Right need)

-- | The names bound by a statement's pattern.
provides :: Stm rep -> [VName]
provides = patNames . stmPat

-- | Extend a usage table with the usages implied by a statement:
--
-- * its own usages (expanded through the aliases in the symbol table);
--
-- * the uses of each bound name, minus 'UT.presentU', propagated to
--   that name's aliases;
--
-- * if any bound name is used as a size, size usages for everything
--   free in the expression and certificates.
expandUsage ::
  (Aliased rep) =>
  (Stm rep -> UT.UsageTable) ->
  ST.SymbolTable rep ->
  UT.UsageTable ->
  Stm rep ->
  UT.UsageTable
expandUsage usageInStm vtable utable stm@(Let pat aux e) =
  stmUsages <> utable
  where
    stmUsages =
      UT.expand (`ST.lookupAliases` vtable) (usageInStm stm <> usageThroughAliases)
        <> ( if any (`UT.isSize` utable) (patNames pat)
               then UT.sizeUsages (freeIn (stmAuxCerts aux) <> freeIn e)
               else mempty
           )
    usageThroughAliases =
      mconcat . mapMaybe usageThroughBindeeAliases $
        zip (patNames pat) (patAliases pat)
    usageThroughBindeeAliases (name, aliases) = do
      uses <- UT.lookup name utable
      pure . mconcat $
        map (`UT.usage` (uses `UT.withoutU` UT.presentU)) $
          namesToList aliases

-- | A predicate deciding, given the symbol table and the usage table,
-- whether a statement is blocked (kept in place) rather than hoisted.
type BlockPred rep = ST.SymbolTable rep -> UT.UsageTable -> Stm rep -> Bool

-- | Never block.
neverBlocks :: BlockPred rep
neverBlocks _ _ _ = False

-- | Always block.
alwaysBlocks :: BlockPred rep
alwaysBlocks _ _ _ = True

-- | A constant predicate that blocks exactly when its argument is
-- 'False'.  @isFalse False@ therefore blocks everything, and
-- @isFalse True@ blocks nothing.
isFalse :: Bool -> BlockPred rep
isFalse b _ _ _ = not b

-- | Block if either predicate blocks.
-- NOTE(review): despite their names, the three trailing arguments are
-- the symbol table, the usage table and the statement.
orIf :: BlockPred rep -> BlockPred rep -> BlockPred rep
orIf p1 p2 body vtable need = p1 body vtable need || p2 body vtable need

-- | Block only if both predicates block.
andAlso :: BlockPred rep -> BlockPred rep -> BlockPred rep
andAlso p1 p2 body vtable need = p1 body vtable need && p2 body vtable need

-- | Block statements that bind a name marked as consumed in the usage table.
isConsumed :: BlockPred rep
isConsumed _ utable = any (`UT.isConsumed` utable) . patNames . stmPat

-- | Block 'Op' statements.
isOp :: BlockPred rep
isOp _ _ (Let _ _ Op {}) = True
isOp _ _ _ = False

-- | Build a body from statements and a result.
constructBody ::
  (SimplifiableRep rep) =>
  Stms (Wise rep) ->
  Result ->
  SimpleM rep (Body (Wise rep))
constructBody stms res =
  fmap fst . runBuilder . buildBody_ $ do
    addStms stms
    pure res

-- | 'hoistStms' using the rule book from the engine environment.
blockIf ::
  (SimplifiableRep rep) =>
  BlockPred (Wise rep) ->
  Stms (Wise rep) ->
  SimpleM rep (a, UT.UsageTable) ->
  SimpleM rep (a, Stms (Wise rep), Stms (Wise rep))
blockIf block stms m = do
  rules <- asksEngineEnv envRules
  hoistStms rules block stms m

-- | Block statements in which any of the given names occur free.
hasFree :: (ASTRep rep) => Names -> BlockPred rep
hasFree ks _ _ need = ks `namesIntersect` freeIn need

-- | Block statements whose expression is not 'safeExp'.
isNotSafe :: (ASTRep rep) => BlockPred rep
isNotSafe _ _ = not . safeExp . stmExp

-- | Block statements whose expression consumes anything.
isConsuming :: (Aliased rep) => BlockPred rep
isConsuming _ _ = isUpdate . stmExp
  where
    isUpdate e = consumedInExp e /= mempty

-- | Block statements that are not 'cheapStm'.
isNotCheap :: (ASTRep rep) => BlockPred rep
isNotCheap _ _ = not . cheapStm

-- | Is the statement's expression 'cheapExp'?
cheapStm :: (ASTRep rep) => Stm rep -> Bool
cheapStm = cheapExp . stmExp

-- | A heuristic for whether an expression is cheap to compute.  A
-- 'Match' is cheap if all statements in all of its branches are.
cheapExp :: (ASTRep rep) => Exp rep -> Bool
cheapExp (BasicOp BinOp {}) = True
cheapExp (BasicOp SubExp {}) = True
cheapExp (BasicOp UnOp {}) = True
cheapExp (BasicOp CmpOp {}) = True
cheapExp (BasicOp ConvOp {}) = True
cheapExp (BasicOp Assert {}) = True
cheapExp (BasicOp Replicate {}) = False
cheapExp (BasicOp Concat {}) = False
cheapExp (BasicOp Manifest {}) = False
cheapExp Loop {} = False
cheapExp (Match _ cases defbranch _) =
  all (all cheapStm . bodyStms . caseBody) cases
    && all cheapStm (bodyStms defbranch)
cheapExp (Op op) = cheapOp op
cheapExp _ = True -- Used to be False, but
-- let's try it out.

-- | Are all free variables of the statement available at the closest
-- enclosing loop?
loopInvariantStm :: (ASTRep rep) => ST.SymbolTable rep -> Stm rep -> Bool
loopInvariantStm vtable =
  all (`nameIn` ST.availableAtClosestLoop vtable) . namesToList . freeIn

-- | Construct the blocking predicate used when simplifying the
-- branches of a 'Match' with scrutinee @cond@ and the given decoration.
matchBlocker ::
  (SimplifiableRep rep) =>
  [SubExp] ->
  MatchDec rt ->
  SimpleM rep (BlockPred (Wise rep))
matchBlocker cond (MatchDec _ ifsort) = do
  is_alloc_fun <- asksEngineEnv $ isAllocation . envHoistBlockers
  branch_blocker <- asksEngineEnv $ blockHoistBranch . envHoistBlockers
  vtable <- askVtable
  let -- We are unwilling to hoist things that are unsafe or costly,
      -- except if they are invariant to the most enclosing loop,
      -- because in that case they will also be hoisted past that
      -- loop.
      --
      -- We also try very hard to hoist allocations or anything that
      -- contributes to memory or array size, because that will allow
      -- allocations to be hoisted.
      cond_loop_invariant =
        all (`nameIn` ST.availableAtClosestLoop vtable) $ namesToList $ freeIn cond

      desirableToHoist usage stm =
        is_alloc_fun stm
          || ( ST.loopDepth vtable > 0
                 && cond_loop_invariant
                 && ifsort /= MatchFallback
                 && loopInvariantStm vtable stm
                 -- Avoid hoisting out something that might change the
                 -- asymptotics of the program.
                 && ( all primType (patTypes (stmPat stm))
                        || (ifsort == MatchEquiv && isManifest (stmExp stm))
                    )
             )
          || ( ifsort /= MatchFallback
                 && any (`UT.isSize` usage) (patNames (stmPat stm))
                 && all primType (patTypes (stmPat stm))
             )

      notDesirableToHoist _ usage stm = not $ desirableToHoist usage stm

      -- No matter what, we always want to hoist constants as much as
      -- possible.
      isNotHoistableBnd _ _ (Let _ _ (BasicOp ArrayLit {})) = False
      isNotHoistableBnd _ _ (Let _ _ (BasicOp SubExp {})) = False
      -- Hoist things that are free.
      isNotHoistableBnd _ _ (Let _ _ (BasicOp Reshape {})) = False
      isNotHoistableBnd _ _ (Let _ _ (BasicOp Rearrange {})) = False
      isNotHoistableBnd _ _ (Let _ _ (BasicOp (Index _ slice))) =
        null $ sliceDims slice
      -- isNotHoistableBnd _ _ stm | is_alloc_fun stm = False
      isNotHoistableBnd _ _ _ =
        -- Hoist aggressively out of versioning branches.
        ifsort /= MatchEquiv

      isManifest (BasicOp Manifest {}) = True
      isManifest _ = False

      block =
        branch_blocker
          `orIf` ( (isNotSafe `orIf` isNotCheap `orIf` isNotHoistableBnd)
                     `andAlso` notDesirableToHoist
                 )
          `orIf` isConsuming

  pure block

-- | Simplify a single body.  Returns the statements hoisted out of it
-- (those not blocked by @blocker@) together with the simplified body.
simplifyBody ::
  (SimplifiableRep rep) =>
  BlockPred (Wise rep) ->
  UT.UsageTable ->
  [UT.Usages] ->
  Body (Wise rep) ->
  SimpleM rep (Stms (Wise rep), Body (Wise rep))
simplifyBody blocker usage res_usages (Body _ stms res) = do
  (res', stms', hoisted) <-
    blockIf blocker stms $ do
      (res', res_usage) <- simplifyResult res_usages res
      pure (res', res_usage <> usage)
  body' <- constructBody stms' res'
  pure (hoisted, body')

-- | Simplify a single body without hoisting any statements out of it.
simplifyBodyNoHoisting ::
  (SimplifiableRep rep) =>
  UT.UsageTable ->
  [UT.Usages] ->
  Body (Wise rep) ->
  SimpleM rep (Body (Wise rep))
simplifyBodyNoHoisting usage res_usages body =
  -- 'isFalse False' blocks every statement, so nothing is hoisted.
  snd <$> simplifyBody (isFalse False) usage res_usages body

-- | A 'Consume' diet implies a consumption usage; any other diet
-- implies none.
usageFromDiet :: Diet -> UT.Usages
usageFromDiet Consume = UT.consumedU
usageFromDiet _ = mempty

-- | Simplify a single 'Result'.
--
-- The returned usage table records usages and in-result usages for
-- every free variable of the simplified result.  It also applies the
-- per-element @usages@ to each variable in the result and, minus
-- 'UT.presentU', to that variable's aliases.
simplifyResult ::
  (SimplifiableRep rep) => [UT.Usages] -> Result -> SimpleM rep (Result, UT.UsageTable)
simplifyResult usages res = do
  res' <- mapM simplify res
  vtable <- askVtable
  let more_usages = mconcat $ do
        -- NOTE(review): this inspects the original 'res', not the
        -- simplified 'res'', so these usages are attached to the
        -- pre-simplification names - confirm this is intended.
        (u, Var v) <- zip usages $ map resSubExp res
        let als_usages =
              map
                (`UT.usage` (u `UT.withoutU` UT.presentU))
                (namesToList (ST.lookupAliases v vtable))
        UT.usage v u : als_usages
  pure
    ( res',
      UT.usages (freeIn res')
        <> foldMap UT.inResultUsage (namesToList (freeIn res'))
        <> more_usages
    )

-- | Record an in-result usage for every variable in a loop result.
isLoopResult :: Result -> UT.UsageTable
isLoopResult = mconcat . map checkForVar
  where
    checkForVar (SubExpRes _ (Var ident)) = UT.inResultUsage ident
    checkForVar _ = mempty

-- | Simplify statements without hoisting, treating every name they
-- bind as both used and consumed.
simplifyStms ::
  (SimplifiableRep rep) =>
  Stms (Wise rep) ->
  SimpleM rep (Stms (Wise rep))
simplifyStms stms =
  simplifyStmsWithUsage usage stms
  where
    -- XXX: treat everything as consumed, because when these are
    -- constants it is otherwise complicated to ensure we do not
    -- introduce more aliasing than specified by the return types.
    -- CSE has the same problem.
    all_bound = M.keys (scopeOf stms)
    usage =
      UT.usages (namesFromList all_bound)
        <> foldMap UT.consumedUsage all_bound

-- | Simplify statements without hoisting anything, given the usage
-- table of whatever uses them.
simplifyStmsWithUsage ::
  (SimplifiableRep rep) =>
  UT.UsageTable ->
  Stms (Wise rep) ->
  SimpleM rep (Stms (Wise rep))
simplifyStmsWithUsage usage stms = do
  ((), stms', _) <- blockIf (isFalse False) stms $ pure ((), usage)
  pure stms'

-- | Simplify an 'Op' with the operation simplifier from the engine
-- configuration, also returning the statements it hoisted.
simplifyOp :: Op (Wise rep) -> SimpleM rep (Op (Wise rep), Stms (Wise rep))
simplifyOp op = do
  f <- asks $ simplifyOpS . fst
  f op

-- | Simplify an expression, recursing into its branches, loop body,
-- lambdas and operations.  Returns the simplified expression along
-- with the statements hoisted out of it.
simplifyExp ::
  (SimplifiableRep rep) =>
  UT.UsageTable ->
  Pat (LetDec (Wise rep)) ->
  Exp (Wise rep) ->
  SimpleM rep (Exp (Wise rep), Stms (Wise rep))
simplifyExp usage (Pat pes) (Match ses cases defbody ifdec@(MatchDec ts ifsort)) = do
  let pes_usages = map (fromMaybe mempty . (`UT.lookup` usage) . patElemName) pes
  ses' <- mapM simplify ses
  ts' <- mapM simplify ts
  let pats = map casePat cases
  block <- matchBlocker ses ifdec
  (cases_hoisted, cases') <-
    unzip <$> zipWithM (simplifyCase block ses' pes_usages) (inits pats) cases
  -- The default body is protected against the patterns of all cases,
  -- with an empty pattern of its own.
  (defbody_hoisted, defbody') <-
    protectCaseHoisted ses' pats [] $
      simplifyBody block usage pes_usages defbody
  pure
    ( Match ses' cases' defbody' $ MatchDec ts' ifsort,
      mconcat $ defbody_hoisted : cases_hoisted
    )
  where
    -- @prior@ holds the patterns of all preceding cases.
    simplifyCase block ses' pes_usages prior (Case vs body) = do
      (hoisted, body') <-
        protectCaseHoisted ses' prior vs $
          simplifyBody block usage pes_usages body
      pure (hoisted, Case vs body')
simplifyExp _ _ (Loop merge form loopbody) = do
  let (params, args) = unzip merge
  params' <- mapM (traverse simplify) params
  args' <- mapM simplify args
  let merge' = zip params' args'
  (form', boundnames, wrapbody) <- case form of
    ForLoop loopvar it boundexp -> do
      boundexp' <- simplify boundexp
      let form' = ForLoop loopvar it boundexp'
      pure
        ( form',
          oneName loopvar <> fparamnames,
          bindLoopVar loopvar it boundexp'
            . protectLoopHoisted merge' form'
        )
    WhileLoop cond -> do
      cond' <- simplify cond
      pure
        ( WhileLoop cond',
          fparamnames,
          protectLoopHoisted merge' (WhileLoop cond')
        )
  seq_blocker <- asksEngineEnv $ blockHoistSeq . envHoistBlockers
  (loopres, loopstms, hoisted) <-
    enterLoop . consumeMerge $
      bindMerge (zipWith withRes merge' (bodyResult loopbody)) . wrapbody $
        blockIf
          ( hasFree boundnames
              `orIf` isConsumed
              `orIf` seq_blocker
              `orIf` notWorthHoisting
          )
          (bodyStms loopbody)
          $ do
            -- Results flowing into unique merge parameters are consumed.
            let params_usages =
                  map
                    (\p -> if unique (paramDeclType p) then UT.consumedU else mempty)
                    params'
            (res, uses) <- simplifyResult params_usages $ bodyResult loopbody
            pure (res, uses <> isLoopResult res)
  loopbody' <- constructBody loopstms loopres
  pure (Loop merge' form' loopbody', hoisted)
  where
    fparamnames =
      namesFromList (map (paramName . fst) merge)
    -- Mark the arguments of unique merge parameters as consumed.
    consumeMerge =
      localVtable $ flip (foldl' (flip ST.consume)) $ namesToList consumed_by_merge
    consumed_by_merge =
      freeIn $ map snd $ filter (unique . paramDeclType . fst) merge
    withRes (p, x) y = (p, x, y)
simplifyExp _ _ (Op op) = do
  (op', stms) <- simplifyOp op
  pure (Op op', stms)
simplifyExp usage _ (WithAcc inputs lam) = do
  (inputs', inputs_stms) <- fmap unzip . forM inputs $ \(shape, arrs, op) -> do
    (op', op_stms) <- case op of
      Nothing ->
        pure (Nothing, mempty)
      Just (op_lam, nes) -> do
        (op_lam', op_lam_stms) <- blockMigrated (simplifyLambda mempty op_lam)
        nes' <- simplify nes
        pure (Just (op_lam', nes'), op_lam_stms)
    (,op_stms) <$> ((,,op') <$> simplify shape <*> simplify arrs)
  -- Pair each lambda parameter with its corresponding input in the
  -- symbol table (via 'ST.noteAccTokens').
  let noteAcc = ST.noteAccTokens (zip (map paramName (lambdaParams lam)) inputs')
  -- 'isFalse True' adds no blocking beyond what 'simplifyLambdaWith'
  -- imposes itself.
  (lam', lam_stms) <- consumeInput inputs' $ simplifyLambdaWith noteAcc (isFalse True) usage lam
  pure (WithAcc inputs' lam', mconcat inputs_stms <> lam_stms)
  where
    inputArrs (_, arrs, _) = arrs
    -- Mark all input arrays as consumed while simplifying the lambda.
    consumeInput =
      localVtable . flip (foldl' (flip ST.consume)) . concatMap inputArrs
simplifyExp _ _ e = do
  e' <- simplifyExpBase e
  pure (e', mempty)

-- | Block hoisting of 'Index' statements introduced by migration.
--
-- This extends the parallel-hoisting blocker with 'isDeviceMigrated'
-- for the duration of the given action.
blockMigrated ::
  (SimplifiableRep rep) =>
  SimpleM rep (Lambda (Wise rep), Stms (Wise rep)) ->
  SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
blockMigrated = local withMigrationBlocker
  where
    withMigrationBlocker (ops, env) =
      let blockers = envHoistBlockers env
          par_blocker = blockHoistPar blockers
          blocker = par_blocker `orIf` isDeviceMigrated
          blockers' = blockers {blockHoistPar = blocker}
          env' = env {envHoistBlockers = blockers'}
       in (ops, env')

-- | Statement is a scalar read from a single element array of rank one.
--
-- Concretely: an 'Index' with a single fixed index equal to @0@, into
-- an array whose type (according to the symbol table) has exactly one
-- dimension, of size @1@.
isDeviceMigrated :: (SimplifiableRep rep) => BlockPred (Wise rep)
isDeviceMigrated vtable _ stm
  | BasicOp (Index arr slice) <- stmExp stm,
    [DimFix idx] <- unSlice slice,
    idx == intConst Int64 0,
    Just arr_t <- ST.lookupType arr vtable,
    [size] <- arrayDims arr_t,
    size == intConst Int64 1 =
      True
  | otherwise =
      False

-- The simple nonrecursive case that we can perform without bottom-up
-- information.
simplifyExpBase :: (SimplifiableRep rep) => Exp (Wise rep) -> SimpleM rep (Exp (Wise rep))
-- Special case for simplification of commutative BinOps where we
-- arrange the operands in sorted order.  This can make expressions
-- more identical, which helps CSE.
simplifyExpBase (BasicOp (BinOp op x y))
  | commutativeBinOp op = do
      x' <- simplify x
      y' <- simplify y
      pure $ BasicOp $ BinOp op (min x' y') (max x' y')
simplifyExpBase e = mapExpM simplifier e
  where
    -- Simplify subexpressions, variable names and return/branch types
    -- appearing directly in the expression; every other component is
    -- handled as in 'identityMapper'.
    simplifier =
      identityMapper
        { mapOnSubExp = simplify,
          mapOnVName = simplify,
          mapOnRetType = simplify,
          mapOnBranchType = simplify
        }

-- | The constraints a representation must satisfy to be handled by
-- this simplification engine.  (Each constraint is listed once;
-- 'IsOp (OpC rep)' previously appeared twice.)
type SimplifiableRep rep =
  ( ASTRep rep,
    Simplifiable (LetDec rep),
    Simplifiable (FParamInfo rep),
    Simplifiable (LParamInfo rep),
    Simplifiable (RetType rep),
    Simplifiable (BranchType rep),
    TraverseOpStms (Wise rep),
    CanBeWise (OpC rep),
    ST.IndexOp (Op (Wise rep)),
    IsOp (OpC rep),
    ASTConstraints (OpC rep (Wise rep)),
    AliasedOp (OpC (Wise rep)),
    RephraseOp (OpC rep),
    BuilderOps (Wise rep)
  )

-- | Things that can be simplified in the 'SimpleM' monad.
class Simplifiable e where
  simplify :: (SimplifiableRep rep) => e -> SimpleM rep e

instance (Simplifiable a, Simplifiable b) => Simplifiable (a, b) where
  simplify (x, y) = (,) <$> simplify x <*> simplify y

instance (Simplifiable a, Simplifiable b, Simplifiable c) => Simplifiable (a, b, c) where
  simplify (x, y, z) = (,,) <$> simplify x <*> simplify y <*> simplify z

-- Convenient for Scatter.
instance Simplifiable Int where
  simplify = pure

instance (Simplifiable a) => Simplifiable (Maybe a) where
  simplify Nothing = pure Nothing
  simplify (Just x) = Just <$> simplify x

instance (Simplifiable a) => Simplifiable [a] where
  simplify = mapM simplify

-- Copy propagation: a variable that the symbol table maps to a
-- constant or another variable is replaced by it.  The certificates of
-- that binding are recorded and 'changed' is signalled.
instance Simplifiable SubExp where
  simplify (Var name) = do
    stm <- ST.lookupSubExp name <$> askVtable
    case stm of
      Just (Constant v, cs) -> do
        changed
        usedCerts cs
        pure $ Constant v
      Just (Var id', cs) -> do
        changed
        usedCerts cs
        pure $ Var id'
      _ -> pure $ Var name
  simplify (Constant v) =
    pure $ Constant v

-- Certificates collected while simplifying the subexpression are
-- prepended to the result's own (simplified) certificates.
instance Simplifiable SubExpRes where
  simplify (SubExpRes cs se) = do
    cs' <- simplify cs
    (se', se_cs) <- collectCerts $ simplify se
    pure $ SubExpRes (se_cs <> cs') se'

instance Simplifiable () where
  simplify = pure

-- A name is only replaced by another name; names bound to constants
-- are left alone.
instance Simplifiable VName where
  simplify v = do
    se <- ST.lookupSubExp v <$> askVtable
    case se of
      Just (Var v', cs) -> do
        changed
        usedCerts cs
        pure v'
      _ -> pure v

instance (Simplifiable d) => Simplifiable (ShapeBase d) where
  simplify = fmap Shape . simplify . shapeDims

instance Simplifiable ExtSize where
  simplify (Free se) = Free <$> simplify se
  simplify (Ext x) = pure $ Ext x

instance Simplifiable Space where
  simplify (ScalarSpace ds t) = ScalarSpace <$> simplify ds <*> pure t
  simplify s = pure s

instance Simplifiable PrimType where
  simplify = pure

instance (Simplifiable shape) => Simplifiable (TypeBase shape u) where
  simplify (Array et shape u) =
    Array <$> simplify et <*> simplify shape <*> pure u
  simplify (Acc acc ispace ts u) =
    Acc <$> simplify acc <*> simplify ispace <*> simplify ts <*> pure u
  simplify (Mem space) =
    Mem <$> simplify space
  simplify (Prim bt) =
    pure $ Prim bt

instance (Simplifiable d) => Simplifiable (DimIndex d) where
  simplify (DimFix i) = DimFix <$> simplify i
  simplify (DimSlice i n s) = DimSlice <$> simplify i <*> simplify n <*> simplify s

instance (Simplifiable d) => Simplifiable (Slice d) where
  simplify = traverse simplify

-- | Simplify a lambda, hoisting statements out of it unless they are
-- blocked by the engine's parallel-hoisting blocker or mention a name
-- in @extra_bound@.
simplifyLambda ::
  (SimplifiableRep rep) =>
  Names ->
  Lambda (Wise rep) ->
  SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
simplifyLambda extra_bound lam = do
  par_blocker <- asksEngineEnv $ blockHoistPar . envHoistBlockers
  simplifyLambdaMaybeHoist (par_blocker `orIf` hasFree extra_bound) mempty lam

-- | Simplify a lambda without hoisting anything out of it.
simplifyLambdaNoHoisting ::
  (SimplifiableRep rep) =>
  Lambda (Wise rep) ->
  SimpleM rep (Lambda (Wise rep))
simplifyLambdaNoHoisting lam =
  fst <$> simplifyLambdaMaybeHoist (isFalse False) mempty lam

-- | Simplify a lambda with the given blocking predicate and usage
-- table, returning the hoisted statements.
simplifyLambdaMaybeHoist ::
  (SimplifiableRep rep) =>
  BlockPred (Wise rep) ->
  UT.UsageTable ->
  Lambda (Wise rep) ->
  SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
simplifyLambdaMaybeHoist = simplifyLambdaWith id

-- | Like 'simplifyLambdaMaybeHoist', but also applies @f@ to the
-- symbol table while simplifying the body.  Statements mentioning the
-- lambda's parameters, or binding consumed names, are always blocked.
simplifyLambdaWith ::
  (SimplifiableRep rep) =>
  (ST.SymbolTable (Wise rep) -> ST.SymbolTable (Wise rep)) ->
  BlockPred (Wise rep) ->
  UT.UsageTable ->
  Lambda (Wise rep) ->
  SimpleM rep (Lambda (Wise rep), Stms (Wise rep))
simplifyLambdaWith f blocked usage lam@(Lambda params rettype body) = do
  params' <- mapM (traverse simplify) params
  let paramnames = namesFromList $ boundByLambda lam
  (hoisted, body') <-
    bindLParams params' . localVtable f $
      simplifyBody
        (blocked `orIf` hasFree paramnames `orIf` isConsumed)
        usage
        (map (const mempty) rettype)
        body
  rettype' <- simplify rettype
  pure (Lambda params' rettype' body', hoisted)

-- Certificates are copy-propagated.  A certificate bound to a constant
-- is replaced by the certificates of that binding, and one bound to
-- another variable by that variable.  Duplicates are removed.
instance Simplifiable Certs where
  simplify (Certs ocs) = Certs . nubOrd . concat <$> mapM check ocs
    where
      check idd = do
        vv <- ST.lookupSubExp idd <$> askVtable
        case vv of
          Just (Constant _, Certs cs) -> pure cs
          Just (Var idd', _) -> pure [idd']
          _ -> pure [idd]

-- | Simplify a function definition without hoisting.  A return value
-- gets a consumption usage in two cases:
--
-- * its diet is 'Consume';
--
-- * its 'RetAls' omit some array-typed parameter or result index.
simplifyFun ::
  (SimplifiableRep rep) =>
  FunDef (Wise rep) ->
  SimpleM rep (FunDef (Wise rep))
simplifyFun (FunDef entry attrs fname rettype params body) = do
  rettype' <- mapM (bitraverse simplify pure) rettype
  params' <- mapM (traverse simplify) params
  let usages = map usageFromRet rettype'
  -- NOTE(review): the body is simplified with the original 'params'
  -- in scope rather than the simplified params' - confirm intended.
  body' <- bindFParams params $ simplifyBodyNoHoisting mempty usages body
  pure $ FunDef entry attrs fname rettype' params' body'
  where
    aliasable Array {} = True
    aliasable _ = False
    -- Indices of the array-typed parameters and return values.
    aliasable_params =
      map snd $ filter (aliasable . paramType . fst) $ zip params [0 ..]
    aliasable_rets =
      map snd $ filter (aliasable . extTypeOf . fst . fst) $ zip rettype [0 ..]
    -- True if some element of the second (implicit) list is not in @als@.
    restricted als = any (`notElem` als)
    usageFromRet (t, RetAls pals rals) =
      usageFromDiet (diet $ declExtTypeOf t)
        <> if restricted pals aliasable_params || restricted rals aliasable_rets
          then UT.consumedU
          else mempty
futhark-0.25.27/src/Futhark/Optimise/Simplify/Rep.hs000066400000000000000000000244001475065116200222040ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}

-- | Representation used by the simplification engine.  It contains
-- aliasing information and a bit of caching for various information
-- that is looked up frequently.  The name is an old relic; feel free
-- to suggest a better one.
module Futhark.Optimise.Simplify.Rep
  ( Wise,
    VarWisdom (..),
    ExpWisdom,
    removeStmWisdom,
    removeLambdaWisdom,
    removeFunDefWisdom,
    removeExpWisdom,
    removePatWisdom,
    removeBodyWisdom,
    removeScopeWisdom,
    addScopeWisdom,
    addWisdomToPat,
    mkWiseBody,
    mkWiseStm,
    mkWiseExpDec,
    CanBeWise (..),

    -- * Constructing representation
    Informing,
    informLambda,
    informFunDef,
    informStms,
    informBody,
  )
where

import Control.Category
import Control.Monad.Identity
import Control.Monad.Reader
import Data.Map.Strict qualified as M
import Futhark.Builder
import Futhark.IR
import Futhark.IR.Aliases
  ( AliasDec (..),
    ConsumedInExp,
    VarAliases,
    unAliases,
  )
import Futhark.IR.Aliases qualified as Aliases
import Futhark.IR.Prop.Aliases
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util.Pretty
import Prelude hiding (id, (.))

-- | Representative phantom type for the simplifier representation.
data Wise rep

-- | The information associated with a let-bound variable.
newtype VarWisdom = VarWisdom {varWisdomAliases :: VarAliases}
  deriving (Eq, Ord, Show)

instance Rename VarWisdom where
  rename = substituteRename

instance Substitute VarWisdom where
  substituteNames substs (VarWisdom als) =
    VarWisdom (substituteNames substs als)

instance FreeIn VarWisdom where
  freeIn' (VarWisdom als) = freeIn' als

-- | Simplifier information about an expression.
data ExpWisdom = ExpWisdom
  { -- | What the expression consumes.
    _expWisdomConsumed :: ConsumedInExp,
    -- | The free variables in the expression.
    expWisdomFree :: AliasDec
  }
  deriving (Eq, Ord, Show)

-- NOTE(review): 'freeIn'' contributes nothing for 'ExpWisdom'; the
-- cached free variables are exposed via 'precomputed' instead.
instance FreeIn ExpWisdom where
  freeIn' = mempty

-- Ignore the argument and return the cached free variables.
instance FreeDec ExpWisdom where
  precomputed = const . fvNames . unAliases . expWisdomFree

instance Substitute ExpWisdom where
  substituteNames substs (ExpWisdom cons free) =
    ExpWisdom
      (substituteNames substs cons)
      (substituteNames substs free)

instance Rename ExpWisdom where
  rename = substituteRename

-- | Simplifier information about a body.
data BodyWisdom = BodyWisdom
  { -- | Aliases of the body's results (exposed by 'bodyAliases').
    bodyWisdomAliases :: [VarAliases],
    -- | What the body consumes (exposed by 'consumedInBody').
    bodyWisdomConsumed :: ConsumedInExp,
    -- | The cached free variables of the body.
    bodyWisdomFree :: AliasDec
  }
  deriving (Eq, Ord, Show)

instance Rename BodyWisdom where
  rename = substituteRename

instance Substitute BodyWisdom where
  substituteNames substs (BodyWisdom als cons free) =
    BodyWisdom
      (substituteNames substs als)
      (substituteNames substs cons)
      (substituteNames substs free)

instance FreeIn BodyWisdom where
  freeIn' (BodyWisdom als cons free) =
    freeIn' als <> freeIn' cons <> freeIn' free

-- Ignore the argument and return the cached free variables.
instance FreeDec BodyWisdom where
  precomputed = const . fvNames . unAliases . bodyWisdomFree

instance
  ( Informing rep,
    Ord (OpC rep (Wise rep)),
    Eq (OpC rep (Wise rep)),
    Show (OpC rep (Wise rep)),
    IsOp (OpC rep),
    Pretty (OpC rep (Wise rep))
  ) =>
  RepTypes (Wise rep)
  where
  type LetDec (Wise rep) = (VarWisdom, LetDec rep)
  type ExpDec (Wise rep) = (ExpWisdom, ExpDec rep)
  type BodyDec (Wise rep) = (BodyWisdom, BodyDec rep)
  type FParamInfo (Wise rep) = FParamInfo rep
  type LParamInfo (Wise rep) = LParamInfo rep
  type RetType (Wise rep) = RetType rep
  type BranchType (Wise rep) = BranchType rep
  type OpC (Wise rep) = OpC rep

-- | Run a computation that needs the scope without simplifier
-- information, using the current 'Wise' scope stripped of wisdom.
withoutWisdom ::
  (HasScope (Wise rep) m, Monad m) =>
  ReaderT (Scope rep) m a ->
  m a
withoutWisdom m = do
  scope <- asksScope removeScopeWisdom
  runReaderT m scope

instance (Informing rep, IsOp (OpC rep)) => ASTRep (Wise rep) where
  expTypesFromPat =
    withoutWisdom . expTypesFromPat . removePatWisdom

-- Wisdom is not shown when pretty-printing.
instance Pretty VarWisdom where
  pretty _ = pretty ()

instance (Informing rep, Pretty (OpC rep (Wise rep))) => PrettyRep (Wise rep) where
  ppExpDec (_, dec) = ppExpDec dec . removeExpWisdom

instance AliasesOf (VarWisdom, dec) where
  aliasesOf = unAliases . varWisdomAliases . fst

instance (Informing rep) => Aliased (Wise rep) where
  bodyAliases = map unAliases . bodyWisdomAliases . fst . bodyDec
  consumedInBody = unAliases . bodyWisdomConsumed . fst . bodyDec

-- | A rephraser that drops the wisdom component of every decoration.
removeWisdom :: (RephraseOp (OpC rep)) => Rephraser Identity (Wise rep) rep
removeWisdom =
  Rephraser
    { rephraseExpDec = pure . snd,
      rephraseLetBoundDec = pure . snd,
      rephraseBodyDec = pure . snd,
      rephraseFParamDec = pure,
      rephraseLParamDec = pure,
      rephraseRetType = pure,
      rephraseBranchType = pure,
      rephraseOp = rephraseInOp removeWisdom
    }

-- | Remove simplifier information from scope.
removeScopeWisdom :: Scope (Wise rep) -> Scope rep
removeScopeWisdom = M.map unAlias
  where
    unAlias (LetName (_, dec)) = LetName dec
    unAlias (FParamName dec) = FParamName dec
    unAlias (LParamName dec) = LParamName dec
    unAlias (IndexName it) = IndexName it

-- | Add simplifier information to scope.  All the aliasing
-- information will be vacuous, however.
addScopeWisdom :: Scope rep -> Scope (Wise rep)
addScopeWisdom = M.map alias
  where
    alias (LetName dec) = LetName (VarWisdom mempty, dec)
    alias (FParamName dec) = FParamName dec
    alias (LParamName dec) = LParamName dec
    alias (IndexName it) = IndexName it

-- | Remove simplifier information from function.
removeFunDefWisdom :: (RephraseOp (OpC rep)) => FunDef (Wise rep) -> FunDef rep
removeFunDefWisdom = runIdentity . rephraseFunDef removeWisdom

-- | Remove simplifier information from statement.
removeStmWisdom :: (RephraseOp (OpC rep)) => Stm (Wise rep) -> Stm rep
removeStmWisdom = runIdentity . rephraseStm removeWisdom

-- | Remove simplifier information from lambda.
removeLambdaWisdom :: (RephraseOp (OpC rep)) => Lambda (Wise rep) -> Lambda rep
removeLambdaWisdom = runIdentity . rephraseLambda removeWisdom

-- | Remove simplifier information from body.
removeBodyWisdom :: (RephraseOp (OpC rep)) => Body (Wise rep) -> Body rep
removeBodyWisdom = runIdentity . rephraseBody removeWisdom

-- | Remove simplifier information from expression.
removeExpWisdom :: (RephraseOp (OpC rep)) => Exp (Wise rep) -> Exp rep
removeExpWisdom = runIdentity . rephraseExp removeWisdom

-- | Remove simplifier information from pattern.
removePatWisdom :: Pat (VarWisdom, a) -> Pat a
removePatWisdom = runIdentity . rephrasePat (pure . snd)

-- | Add simplifier information to pattern.  The aliases of each
-- pattern element are computed from the expression by
-- 'Aliases.mkAliasedPat'.
addWisdomToPat ::
  (Informing rep) =>
  Pat (LetDec rep) ->
  Exp (Wise rep) ->
  Pat (LetDec (Wise rep))
addWisdomToPat pat e =
  f <$> Aliases.mkAliasedPat pat e
  where
    f (als, dec) = (VarWisdom als, dec)

-- | Produce a body with simplifier information.
--
-- Aliasing and consumption come from 'Aliases.mkBodyAliasing'; the
-- free variables of the statements and result are cached.
mkWiseBody ::
  (Informing rep) =>
  BodyDec rep ->
  Stms (Wise rep) ->
  Result ->
  Body (Wise rep)
mkWiseBody dec stms res =
  let (result_aliases, consumption) = Aliases.mkBodyAliasing stms res
      cached_free = AliasDec (freeIn (freeInStmsAndRes stms res))
   in Body (BodyWisdom result_aliases consumption cached_free, dec) stms res

-- | Produce a statement with simplifier information.
mkWiseStm ::
  (Informing rep) =>
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  Exp (Wise rep) ->
  Stm (Wise rep)
mkWiseStm pat (StmAux cs attrs dec) e =
  Let wise_pat (StmAux cs attrs (mkWiseExpDec wise_pat dec e)) e
  where
    wise_pat = addWisdomToPat pat e

-- | Produce simplifier information for an expression: what it
-- consumes, and the free variables of the pattern, decoration and
-- expression together.
mkWiseExpDec ::
  (Informing rep) =>
  Pat (LetDec (Wise rep)) ->
  ExpDec rep ->
  Exp (Wise rep) ->
  ExpDec (Wise rep)
mkWiseExpDec pat expdec e = (wisdom, expdec)
  where
    wisdom =
      ExpWisdom
        (AliasDec (consumedInExp e))
        (AliasDec (freeIn pat <> freeIn expdec <> freeIn e))

-- Build in the underlying representation, then attach wisdom.
instance (Buildable rep, Informing rep) => Buildable (Wise rep) where
  mkExpPat ids e =
    addWisdomToPat (mkExpPat ids (removeExpWisdom e)) e

  mkExpDec pat e =
    mkWiseExpDec pat (mkExpDec (removePatWisdom pat) (removeExpWisdom e)) e

  mkLetNames names e = do
    plain_scope <- asksScope removeScopeWisdom
    plain_stm <- runReaderT (mkLetNames names (removeExpWisdom e)) plain_scope
    case plain_stm of
      Let pat dec _ -> pure $ mkWiseStm pat dec e

  mkBody stms res = mkWiseBody bodyrep stms res
    where
      Body bodyrep _ _ = mkBody (fmap removeStmWisdom stms) res

-- | Constraints that let us transform a representation into a 'Wise'
-- representation.
type Informing rep =
  ( ASTRep rep,
    AliasedOp (OpC rep),
    RephraseOp (OpC rep),
    CanBeWise (OpC rep),
    FreeIn (OpC rep (Wise rep)),
    ASTConstraints (OpC rep (Wise rep))
  )

-- | A type class for indicating that this operation can be lifted into
-- the simplifier representation.
class CanBeWise op where
  addOpWisdom :: (Informing rep) => op rep -> op (Wise rep)

instance CanBeWise NoOp where
  addOpWisdom NoOp = NoOp

-- | Construct a 'Wise' statement.
informStm :: (Informing rep) => Stm rep -> Stm (Wise rep) informStm (Let pat aux e) = mkWiseStm pat aux $ informExp e -- | Construct 'Wise' statements. informStms :: (Informing rep) => Stms rep -> Stms (Wise rep) informStms = fmap informStm -- | Construct a 'Wise' body. informBody :: (Informing rep) => Body rep -> Body (Wise rep) informBody (Body dec stms res) = mkWiseBody dec (informStms stms) res -- | Construct a 'Wise' lambda. informLambda :: (Informing rep) => Lambda rep -> Lambda (Wise rep) informLambda (Lambda ps ret body) = Lambda ps ret (informBody body) -- | Construct a 'Wise' expression. informExp :: (Informing rep) => Exp rep -> Exp (Wise rep) informExp (Match cond cases defbody (MatchDec ts ifsort)) = Match cond (map (fmap informBody) cases) (informBody defbody) (MatchDec ts ifsort) informExp (Loop merge form loopbody) = Loop merge form $ informBody loopbody informExp e = runIdentity $ mapExpM mapper e where mapper = Mapper { mapOnBody = const $ pure . informBody, mapOnSubExp = pure, mapOnVName = pure, mapOnRetType = pure, mapOnBranchType = pure, mapOnFParam = pure, mapOnLParam = pure, mapOnOp = pure . addOpWisdom } -- | Construct a 'Wise' function definition. informFunDef :: (Informing rep) => FunDef rep -> FunDef (Wise rep) informFunDef (FunDef entry attrs fname rettype params body) = FunDef entry attrs fname rettype params $ informBody body futhark-0.25.27/src/Futhark/Optimise/Simplify/Rule.hs000066400000000000000000000214411475065116200223670ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fno-warn-redundant-constraints #-} -- | This module defines the concept of a simplification rule for -- bindings. The intent is that you pass some context (such as symbol -- table) and a binding, and is given back a sequence of bindings that -- compute the same result, but are "better" in some sense. 
--
-- These rewrite rules are "local", in that they do not maintain any
-- state or look at the program as a whole.  Compare this to the
-- fusion algorithm in @Futhark.Optimise.Fusion.Fusion@, which must be implemented
-- as its own pass.
module Futhark.Optimise.Simplify.Rule
  ( -- * The rule monad
    RuleM,
    cannotSimplify,
    liftMaybe,

    -- * Rule definition
    Rule (..),
    SimplificationRule (..),
    RuleGeneric,
    RuleBasicOp,
    RuleMatch,
    RuleLoop,

    -- * Top-down rules
    TopDown,
    TopDownRule,
    TopDownRuleGeneric,
    TopDownRuleBasicOp,
    TopDownRuleMatch,
    TopDownRuleLoop,
    TopDownRuleOp,

    -- * Bottom-up rules
    BottomUp,
    BottomUpRule,
    BottomUpRuleGeneric,
    BottomUpRuleBasicOp,
    BottomUpRuleMatch,
    BottomUpRuleLoop,
    BottomUpRuleOp,

    -- * Assembling rules
    RuleBook,
    ruleBook,

    -- * Applying rules
    topDownSimplifyStm,
    bottomUpSimplifyStm,
  )
where

import Control.Monad.State
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.Builder
import Futhark.IR

-- | The monad in which simplification rules are evaluated.  A rule can
-- emit statements, draw fresh names, inspect the scope, and give up at
-- any point via 'cannotSimplify'.
newtype RuleM rep a = RuleM (BuilderT rep (StateT VNameSource Maybe) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadFreshNames,
      HasScope rep,
      LocalScope rep
    )

-- Every builder operation is delegated to the wrapped 'BuilderT'.
instance (BuilderOps rep) => MonadBuilder (RuleM rep) where
  type Rep (RuleM rep) = rep
  mkExpDecM pat e = RuleM (mkExpDecM pat e)
  mkBodyM stms res = RuleM (mkBodyM stms res)
  mkLetNamesM names e = RuleM (mkLetNamesM names e)
  addStms stms = RuleM (addStms stms)
  collectStms (RuleM m) = RuleM (collectStms m)

-- | Execute a 'RuleM' action.  If successful, returns the statements
-- the rule produced, together with the updated name source.
simplify ::
  Scope rep ->
  VNameSource ->
  Rule rep ->
  Maybe (Stms rep, VNameSource)
simplify scope src rule =
  case rule of
    Skip -> Nothing
    Simplify (RuleM m) -> runStateT (runBuilderT_ m scope) src

-- | Abort the current rule; the statement is left unchanged.
cannotSimplify :: RuleM rep a
cannotSimplify = RuleM (lift (lift Nothing))

-- | Continue with the value in a 'Just', or abort the rule (as with
-- 'cannotSimplify') on 'Nothing'.
liftMaybe :: Maybe a -> RuleM rep a
liftMaybe = maybe cannotSimplify pure

-- | An efficient way of encoding whether a simplification rule should even be attempted.
data Rule rep
  = -- | Give it a shot.
    Simplify (RuleM rep ())
  | -- | Don't bother.
    Skip

-- | A rule that receives the whole statement.
type RuleGeneric rep a = a -> Stm rep -> Rule rep

-- | A rule for statements whose expression is a t'BasicOp'.
type RuleBasicOp rep a =
  ( a ->
    Pat (LetDec rep) ->
    StmAux (ExpDec rep) ->
    BasicOp ->
    Rule rep
  )

-- | A rule for 'Match' statements; receives the scrutinees, cases,
-- default body and decoration.
type RuleMatch rep a =
  a ->
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  ( [SubExp],
    [Case (Body rep)],
    Body rep,
    MatchDec (BranchType rep)
  ) ->
  Rule rep

-- | A rule for 'Loop' statements; receives the loop parameters with
-- their initial values, the loop form and the body.
type RuleLoop rep a =
  a ->
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  ( [(FParam rep, SubExp)],
    LoopForm,
    Body rep
  ) ->
  Rule rep

-- | A rule for 'Op' statements.
type RuleOp rep a =
  a ->
  Pat (LetDec rep) ->
  StmAux (ExpDec rep) ->
  Op rep ->
  Rule rep

-- | A simplification rule takes some argument and a statement, and
-- tries to simplify the statement.
data SimplificationRule rep a
  = RuleGeneric (RuleGeneric rep a)
  | RuleBasicOp (RuleBasicOp rep a)
  | RuleMatch (RuleMatch rep a)
  | RuleLoop (RuleLoop rep a)
  | RuleOp (RuleOp rep a)

-- | A collection of rules grouped by which forms of statements they
-- may apply to.
data Rules rep a = Rules
  { rulesAny :: [SimplificationRule rep a],
    rulesBasicOp :: [SimplificationRule rep a],
    rulesMatch :: [SimplificationRule rep a],
    rulesLoop :: [SimplificationRule rep a],
    rulesOp :: [SimplificationRule rep a]
  }

-- Combining rule collections concatenates each group, left first.
instance Semigroup (Rules rep a) where
  Rules any1 basic1 match1 loop1 op1 <> Rules any2 basic2 match2 loop2 op2 =
    Rules
      { rulesAny = any1 <> any2,
        rulesBasicOp = basic1 <> basic2,
        rulesMatch = match1 <> match2,
        rulesLoop = loop1 <> loop2,
        rulesOp = op1 <> op2
      }

instance Monoid (Rules rep a) where
  mempty = Rules [] [] [] [] []

-- | Context for a rule applied during top-down traversal of the
-- program.
-- Takes a symbol table as argument.
type TopDown rep = ST.SymbolTable rep

-- | A generic top-down rule.
type TopDownRuleGeneric rep = RuleGeneric rep (TopDown rep)

-- | A top-down rule for t'BasicOp' statements.
type TopDownRuleBasicOp rep = RuleBasicOp rep (TopDown rep)

-- | A top-down rule for 'Match' statements.
type TopDownRuleMatch rep = RuleMatch rep (TopDown rep)

-- | A top-down rule for 'Loop' statements.
type TopDownRuleLoop rep = RuleLoop rep (TopDown rep)

-- | A top-down rule for 'Op' statements.
type TopDownRuleOp rep = RuleOp rep (TopDown rep)

-- | Any top-down rule.
type TopDownRule rep = SimplificationRule rep (TopDown rep)

-- | Context for a rule applied during bottom-up traversal of the
-- program.  Takes a symbol table and usage table as arguments.
type BottomUp rep = (ST.SymbolTable rep, UT.UsageTable)

-- | A generic bottom-up rule.
type BottomUpRuleGeneric rep = RuleGeneric rep (BottomUp rep)

-- | A bottom-up rule for t'BasicOp' statements.
type BottomUpRuleBasicOp rep = RuleBasicOp rep (BottomUp rep)

-- | A bottom-up rule for 'Match' statements.
type BottomUpRuleMatch rep = RuleMatch rep (BottomUp rep)

-- | A bottom-up rule for 'Loop' statements.
type BottomUpRuleLoop rep = RuleLoop rep (BottomUp rep)

-- | A bottom-up rule for 'Op' statements.
type BottomUpRuleOp rep = RuleOp rep (BottomUp rep)

-- | Any bottom-up rule.
type BottomUpRule rep = SimplificationRule rep (BottomUp rep)

-- | A collection of top-down rules.
type TopDownRules rep = Rules rep (TopDown rep)

-- | A collection of bottom-up rules.
type BottomUpRules rep = Rules rep (BottomUp rep)

-- | A collection of both top-down and bottom-up rules.
data RuleBook rep = RuleBook
  { bookTopDownRules :: TopDownRules rep,
    bookBottomUpRules :: BottomUpRules rep
  }

-- Combining rule books combines each direction separately.
instance Semigroup (RuleBook rep) where
  RuleBook td1 bu1 <> RuleBook td2 bu2 =
    RuleBook
      { bookTopDownRules = td1 <> td2,
        bookBottomUpRules = bu1 <> bu2
      }

instance Monoid (RuleBook rep) where
  mempty = RuleBook {bookTopDownRules = mempty, bookBottomUpRules = mempty}

-- | Construct a rule book from a collection of rules.
ruleBook ::
  [TopDownRule m] ->
  [BottomUpRule m] ->
  RuleBook m
ruleBook topdowns bottomups =
  RuleBook
    { bookTopDownRules = groupRules topdowns,
      bookBottomUpRules = groupRules bottomups
    }
  where
    -- Index the rules by the kind of expression they can match.  A
    -- generic rule is a candidate for every kind, and 'rulesAny' keeps
    -- every rule.  Relative order is preserved in every group.
    groupRules :: [SimplificationRule m a] -> Rules m a
    groupRules rs =
      Rules
        { rulesAny = rs,
          rulesBasicOp = filter (orGeneric isBasicOpRule) rs,
          rulesMatch = filter (orGeneric isMatchRule) rs,
          rulesLoop = filter (orGeneric isLoopRule) rs,
          rulesOp = filter (orGeneric isOpRule) rs
        }

    orGeneric p r = isGenericRule r || p r

    isGenericRule RuleGeneric {} = True
    isGenericRule _ = False

    isBasicOpRule RuleBasicOp {} = True
    isBasicOpRule _ = False

    isMatchRule RuleMatch {} = True
    isMatchRule _ = False

    isLoopRule RuleLoop {} = True
    isLoopRule _ = False

    isOpRule RuleOp {} = True
    isOpRule _ = False

-- | @topDownSimplifyStm book vtable stm@ tries the top-down rules of
-- @book@ on the statement @stm@, in the context of the symbol table
-- @vtable@.  Rules are tried in order and the first one that succeeds
-- wins.  If simplification is possible, a replacement list of
-- statements is returned that binds at least the same names as @stm@
-- (and possibly more, for intermediate results).
topDownSimplifyStm ::
  (MonadFreshNames m, HasScope rep m, PrettyRep rep) =>
  RuleBook rep ->
  ST.SymbolTable rep ->
  Stm rep ->
  m (Maybe (Stms rep))
topDownSimplifyStm book = applyRules (bookTopDownRules book)

-- | @bottomUpSimplifyStm book (vtable, uses) stm@ tries the bottom-up
-- rules of @book@ on the statement @stm@, in the context of the symbol
-- table @vtable@ and the usage table @uses@ (describing how names are
-- used after @stm@).  Rules are tried in order and the first one that
-- succeeds wins.  If simplification is possible, a replacement list of
-- statements is returned that binds at least the same names as @stm@
-- (and possibly more, for intermediate results).
bottomUpSimplifyStm ::
  (MonadFreshNames m, HasScope rep m, PrettyRep rep) =>
  RuleBook rep ->
  (ST.SymbolTable rep, UT.UsageTable) ->
  Stm rep ->
  m (Maybe (Stms rep))
bottomUpSimplifyStm book = applyRules (bookBottomUpRules book)

-- Select the candidate rules for a statement, based on the kind of
-- its expression.
rulesForStm :: Stm rep -> Rules rep a -> [SimplificationRule rep a]
rulesForStm stm rules =
  case stmExp stm of
    BasicOp {} -> rulesBasicOp rules
    Loop {} -> rulesLoop rules
    Op {} -> rulesOp rules
    Match {} -> rulesMatch rules
    _ -> rulesAny rules

-- Apply a single rule to a statement.  A rule written for a different
-- kind of expression yields 'Skip'.
applyRule :: SimplificationRule rep a -> a -> Stm rep -> Rule rep
applyRule (RuleGeneric f) ctx stm = f ctx stm
applyRule (RuleBasicOp f) ctx (Let pat aux (BasicOp op)) = f ctx pat aux op
applyRule (RuleLoop f) ctx (Let pat aux (Loop params form body)) =
  f ctx pat aux (params, form, body)
applyRule (RuleMatch f) ctx (Let pat aux (Match ses cases defbody dec)) =
  f ctx pat aux (ses, cases, defbody, dec)
applyRule (RuleOp f) ctx (Let pat aux (Op op)) = f ctx pat aux op
applyRule _ _ _ = Skip

-- Try the candidate rules in order, returning the statements of the
-- first rule that succeeds.  The name source is only advanced when
-- some rule succeeds.
applyRules ::
  (MonadFreshNames m, HasScope rep m, PrettyRep rep) =>
  Rules rep a ->
  a ->
  Stm rep ->
  m (Maybe (Stms rep))
applyRules all_rules context stm = do
  scope <- askScope
  modifyNameSource $ \src ->
    let attempt rule rest =
          case simplify scope src (applyRule rule context stm) of
            Nothing -> rest
            success -> success
     in case foldr attempt Nothing (rulesForStm stm all_rules) of
          Nothing -> (Nothing, src)
          Just (stms, src') -> (Just stms, src')
futhark-0.25.27/src/Futhark/Optimise/Simplify/Rules.hs000066400000000000000000000276221475065116200225610ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-}

-- | This module defines a collection of simplification rules, as per
-- "Futhark.Optimise.Simplify.Rule".  They are used in the
-- simplifier.
--
-- For performance reasons, many sufficiently simple logically
-- separate rules are merged into single "super-rules", like ruleIf
-- and ruleBasicOp.  This is because it is relatively expensive to
-- activate a rule just to determine that it does not apply.  Thus, it
-- is more efficient to have a few very fat rules than a lot of small
-- rules.  This does not affect the compiler result in any way; it is
-- purely an optimisation to speed up compilation.
module Futhark.Optimise.Simplify.Rules
  ( standardRules,
    removeUnnecessaryCopy,
  )
where

import Control.Monad
import Control.Monad.State
import Data.List (insert, unzip4, zip4)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.Construct
import Futhark.IR
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.BasicOp
import Futhark.Optimise.Simplify.Rules.Index
import Futhark.Optimise.Simplify.Rules.Loop
import Futhark.Optimise.Simplify.Rules.Match
import Futhark.Util

-- The generic top-down rules defined in this module.  Rules for loops,
-- basic operations and matches come from their own modules and are
-- combined in 'standardRules'.
topDownRules :: (BuilderOps rep) => [TopDownRule rep]
topDownRules =
  [ RuleGeneric constantFoldPrimFun,
    RuleGeneric withAccTopDown,
    RuleGeneric emptyArrayToScratch
  ]

-- The bottom-up rules defined in this module.
bottomUpRules :: (BuilderOps rep, TraverseOpStms rep) => [BottomUpRule rep]
bottomUpRules =
  [ RuleGeneric withAccBottomUp,
    RuleBasicOp simplifyIndex
  ]

-- | A set of standard simplification rules.  These assume pure
-- functional semantics, and so probably should not be applied after
-- memory block merging.
standardRules :: (BuilderOps rep, TraverseOpStms rep) => RuleBook rep
standardRules =
  ruleBook topDownRules bottomUpRules
    <> loopRules
    <> basicOpRules
    <> matchRules

-- | Turn @copy(x)@ into @x@ iff @x@ is not used after this copy
-- statement and it can be consumed.
--
-- Here a copy is a 'Replicate' with an empty shape of a single
-- variable, and the replacement is a plain 'SubExp' binding.  The exact
-- conditions are the three guards below; note that the last one also
-- admits the rewrite whenever the copy's result is never consumed.
--
-- This simplistic rule is only valid before we introduce memory.
removeUnnecessaryCopy :: (BuilderOps rep) => BottomUpRuleBasicOp rep
removeUnnecessaryCopy (vtable, used) (Pat [d]) aux (Replicate (Shape []) (Var v))
  | not (v `UT.isConsumed` used),
    -- The first two clauses below are too conservative, but the
    -- problem is that 'v' might not look like it has been consumed if
    -- it is consumed in an outer scope.  This is because the
    -- simplifier applies bottom-up rules in a kind of deepest-first
    -- order.
    not (patElemName d `UT.isInResult` used)
      || (patElemName d `UT.isConsumed` used)
      -- Always OK to remove the copy if 'v' has no aliases and is never
      -- used again.
      || (v_is_fresh && v_not_used_again),
    -- If the copy is consumed, the original must itself be consumable
    -- and dead afterwards; otherwise nobody can observe the sharing.
    (v_not_used_again && consumable) || not (patElemName d `UT.isConsumed` used) =
      Simplify $
        auxing aux $
          letBindNames [patElemName d] $
            BasicOp $
              SubExp $
                Var v
  where
    v_not_used_again = not (v `UT.used` used)
    v_is_fresh = v `ST.lookupAliases` vtable == mempty
    -- We need to make sure we can even consume the original.  The big
    -- missing piece here is that we cannot do copy removal inside of
    -- 'map' and other SOACs, but that is handled by SOAC-specific rules.
    -- 'v' must be bound at the current loop depth, and be either a
    -- fresh statement result or a unique function parameter.
    consumable = fromMaybe False $ do
      e <- ST.lookup v vtable
      guard $ ST.entryDepth e == ST.loopDepth vtable
      consumableStm e `mplus` consumableFParam e
    consumableFParam =
      Just . maybe False (unique . declTypeOf) . ST.entryFParam
    consumableStm e = do
      void $ ST.entryStm e -- Must be a stm.
      guard v_is_fresh
      pure True
removeUnnecessaryCopy _ _ _ _ = Skip

-- | Evaluate a call to a function in 'primFuns' at compile time when
-- every argument is a constant and the function produces a result for
-- those arguments.  Certificates and attributes of the call are kept.
constantFoldPrimFun :: (BuilderOps rep) => TopDownRuleGeneric rep
constantFoldPrimFun _ (Let pat (StmAux cs attrs _) (Apply fname args _ _))
  | Just args' <- mapM (isConst . fst) args,
    Just (_, _, fun) <- M.lookup (nameToText fname) primFuns,
    Just result <- fun args' =
      Simplify $
        certifying cs $
          attributing attrs $
            letBind pat $
              BasicOp $
                SubExp $
                  Constant result
  where
    isConst (Constant v) = Just v
    isConst _ = Nothing
constantFoldPrimFun _ _ = Skip

-- | If an expression produces an array with a constant zero anywhere
-- in its shape, just turn that into a Scratch.
emptyArrayToScratch ::
  (BuilderOps rep) =>
  TopDownRuleGeneric rep
emptyArrayToScratch _ (Let pat@(Pat [pe]) aux e)
  | Just (pt, shape) <- isEmptyArray $ patElemType pe,
    -- Avoid rewriting a Scratch into itself forever.
    not $ isScratch e =
      Simplify $ auxing aux $ letBind pat $ BasicOp $ Scratch pt $ shapeDims shape
  where
    isScratch (BasicOp Scratch {}) = True
    isScratch _ = False
emptyArrayToScratch _ _ = Skip

-- | Simplify an 'Index' by delegating to 'simplifyIndexing'.  The
-- outcome is bound either as a plain 'SubExp' or as a new 'Index',
-- keeping the statement's certificates and attributes and adding any
-- certificates the indexing simplification reports.
simplifyIndex :: (BuilderOps rep) => BottomUpRuleBasicOp rep
simplifyIndex (vtable, used) pat@(Pat [pe]) (StmAux cs attrs _) (Index idd inds)
  | Just m <- simplifyIndexing vtable seType idd inds consumed consuming =
      Simplify $ certifying cs $ do
        res <- m
        attributing attrs $ case res of
          SubExpResult cs' se ->
            certifying cs' $
              letBindNames (patNames pat) $
                BasicOp $
                  SubExp se
          IndexResult extra_cs idd' inds' ->
            certifying extra_cs $
              letBindNames (patNames pat) $
                BasicOp $
                  Index idd' inds'
  where
    consuming = (`UT.isConsumed` used)
    -- Whether the result of this Index is consumed later.
    consumed = consuming $ patElemName pe
    seType (Var v) = ST.lookupType v vtable
    seType (Constant v) = Just $ Prim $ primValueType v
simplifyIndex _ _ _ _ = Skip

-- | Top-down simplification of 'WithAcc'.  A 'WithAcc' without inputs
-- is inlined.  Otherwise, accumulator and non-accumulator results that
-- do not actually depend on the lambda are bound outside the 'WithAcc'
-- and removed from it.
withAccTopDown :: (BuilderOps rep) => TopDownRuleGeneric rep
-- A WithAcc with no accumulators is sent to Valhalla.
withAccTopDown _ (Let pat aux (WithAcc [] lam)) = Simplify . auxing aux $ do
  lam_res <- bodyBind $ lambdaBody lam
  forM_ (zip (patNames pat) lam_res) $ \(v, SubExpRes cs se) ->
    certifying cs $ letBindNames [v] $ BasicOp $ SubExp se
-- Identify those results in 'lam' that are free and move them out.
withAccTopDown vtable (Let pat aux (WithAcc inputs lam)) = Simplify . auxing aux $ do
  -- The lambda takes one certificate parameter per input followed by
  -- the accumulator parameters; its results are the accumulator
  -- results followed by 'num_nonaccs' ordinary results.
  let (cert_params, acc_params) =
        splitAt (length inputs) $ lambdaParams lam
      (acc_res, nonacc_res) =
        splitFromEnd num_nonaccs $ bodyResult $ lambdaBody lam
      (acc_pes, nonacc_pes) =
        splitFromEnd num_nonaccs $ patElems pat

  -- Look at accumulator results.
  (acc_pes', inputs', params', acc_res') <-
    fmap (unzip4 . catMaybes) . mapM tryMoveAcc $
      zip4
        (chunks (map inputArrs inputs) acc_pes)
        inputs
        (zip cert_params acc_params)
        acc_res

  let (cert_params', acc_params') = unzip params'

  -- Look at non-accumulator results.
  (nonacc_pes', nonacc_res') <-
    unzip . catMaybes <$> mapM tryMoveNonAcc (zip nonacc_pes nonacc_res)

  -- Nothing was moved out: report failure so the statement is kept.
  when (concat acc_pes' == acc_pes && nonacc_pes' == nonacc_pes) cannotSimplify

  lam' <-
    mkLambda (cert_params' ++ acc_params') $
      bodyBind $
        (lambdaBody lam) {bodyResult = acc_res' <> nonacc_res'}

  letBind (Pat (concat acc_pes' <> nonacc_pes')) $ WithAcc inputs' lam'
  where
    num_nonaccs = length (lambdaReturnType lam) - length inputs
    -- Number of pattern elements contributed by each input.
    inputArrs (_, arrs, _) = length arrs

    -- An accumulator result that is the accumulator parameter itself
    -- (with no certificates): bind its pattern elements directly to the
    -- input arrays and drop this accumulator.
    tryMoveAcc (pes, (_, arrs, _), (_, acc_p), SubExpRes cs (Var v))
      | paramName acc_p == v,
        cs == mempty = do
          forM_ (zip pes arrs) $ \(pe, arr) ->
            letBindNames [patElemName pe] $ BasicOp $ SubExp $ Var arr
          pure Nothing
    tryMoveAcc x =
      pure $ Just x

    -- A non-accumulator result that is a variable already in the
    -- symbol table, or a constant (in both cases without certificates),
    -- is bound outside and removed from the lambda.
    tryMoveNonAcc (pe, SubExpRes cs (Var v))
      | v `ST.elem` vtable,
        cs == mempty = do
          letBindNames [patElemName pe] $ BasicOp $ SubExp $ Var v
          pure Nothing
    tryMoveNonAcc (pe, SubExpRes cs (Constant v))
      | cs == mempty = do
          letBindNames [patElemName pe] $ BasicOp $ SubExp $ Constant v
          pure Nothing
    tryMoveNonAcc x =
      pure $ Just x
withAccTopDown _ _ = Skip

-- | @elimUpdates tokens body@ replaces every 'UpdateAcc' statement
-- (including those nested in bodies and operations) whose result is an
-- accumulator with a token in @tokens@ by a plain binding of the input
-- accumulator.  Also returns the tokens that were actually eliminated,
-- in sorted order and with one entry per eliminated statement.
elimUpdates :: forall rep. (ASTRep rep, TraverseOpStms rep) => [VName] -> Body rep -> (Body rep, [VName])
elimUpdates get_rid_of = flip runState mempty . onBody
  where
    onBody body = do
      stms' <- onStms $ bodyStms body
      pure body {bodyStms = stms'}
    onStms = traverse onStm
    onStm (Let pat@(Pat [PatElem _ dec]) aux (BasicOp (UpdateAcc _ acc _ _)))
      | Acc c _ _ _ <- typeOf dec,
        c `elem` get_rid_of = do
          modify (insert c)
          pure $ Let pat aux $ BasicOp $ SubExp $ Var acc
    onStm (Let pat aux e) = Let pat aux <$> onExp e
    onExp = mapExpM mapper
      where
        mapper =
          (identityMapper :: forall m. (Monad m) => Mapper rep rep m)
            { mapOnOp = traverseOpStms (\_ stms -> onStms stms),
              mapOnBody = \_ body -> onBody body
            }

-- | Bottom-up simplification of 'WithAcc': drop results that are not
-- used later.  See Note [Dead Code Elimination for WithAcc].
withAccBottomUp :: (TraverseOpStms rep, BuilderOps rep) => BottomUpRuleGeneric rep
-- Eliminate dead results.  See Note [Dead Code Elimination for WithAcc]
withAccBottomUp (_, utable) (Let pat aux (WithAcc inputs lam))
  | not $ all (`UT.used` utable) $ patNames pat = Simplify $ do
      let (acc_res, nonacc_res) =
            splitFromEnd num_nonaccs $ bodyResult $ lambdaBody lam
          (acc_pes, nonacc_pes) =
            splitFromEnd num_nonaccs $ patElems pat
          (cert_params, acc_params) =
            splitAt (length inputs) $ lambdaParams lam

      -- Eliminate unused accumulator results
      -- (identified by the name of their certificate parameter).
      let get_rid_of =
            map snd . filter getRidOf $
              zip
                (chunks (map inputArrs inputs) acc_pes)
                $ map paramName cert_params

      -- Eliminate unused non-accumulator results
      let (nonacc_pes', nonacc_res') =
            unzip $ filter keepNonAccRes $ zip nonacc_pes nonacc_res

      when (null get_rid_of && nonacc_pes' == nonacc_pes) cannotSimplify

      let (body', eliminated) = elimUpdates get_rid_of $ lambdaBody lam

      -- No UpdateAcc was removed and no non-accumulator result dropped:
      -- nothing to do.
      when (null eliminated && nonacc_pes' == nonacc_pes) cannotSimplify

      let pes' = acc_pes ++ nonacc_pes'

      lam' <- mkLambda (cert_params ++ acc_params) $ do
        void $ bodyBind body'
        pure $ acc_res ++ nonacc_res'

      auxing aux $ letBind (Pat pes') $ WithAcc inputs lam'
  where
    num_nonaccs = length (lambdaReturnType lam) - length inputs
    inputArrs (_, arrs, _) = length arrs
    getRidOf (pes, _) = not $ any ((`UT.used` utable) . patElemName) pes
    keepNonAccRes (pe, _) = patElemName pe `UT.used` utable
withAccBottomUp _ _ = Skip

-- Note [Dead Code Elimination for WithAcc]
--
-- Our static semantics for accumulators are basically those of linear
-- types.  This makes dead code elimination somewhat tricky.  First,
-- what we consider dead code is when we have a WithAcc where at least
-- one of the array results (that internally correspond to an
-- accumulator) are unused.  E.g
--
-- let {X',Y'} =
--   with_acc {X, Y} (\X_p Y_p X_acc Y_acc -> ... {X_acc', Y_acc'})
--
-- where X' is not used later.  Note that Y' is still used later.  If
-- none of the results of the WithAcc are used, then the Stm as a
-- whole is dead and can be removed.  That's the trivial case, done
-- implicitly by the simplifier.  The interesting case is exactly when
-- some of the results are unused.  How do we get rid of them?
--
-- Naively, we might just remove them:
--
-- let Y' =
--   with_acc Y (\Y_p Y_acc -> ... Y_acc')
--
-- This is safe *only* if X_acc is used *only* in the result (i.e. an
-- "identity" WithAcc).  Otherwise we end up with references to X_acc,
-- which no longer exists.  This simple case is actually handled in
-- the withAccTopDown rule, and is easy enough.
--
-- What we actually do when we decide to eliminate X_acc is that we
-- inspect the body of the WithAcc and eliminate all UpdateAcc
-- operations that refer to the same accumulator as X_acc (identified
-- by the X_p token).  I.e. we turn every
--
-- let B = update_acc(A, ...)
--
-- where 'A' is ultimately derived from X_acc into
--
-- let B = A
--
-- That's it!  We then let ordinary dead code elimination eventually
-- simplify the body enough that we have an "identity" WithAcc.  There
-- is no _guarantee_ that this will happen, but our general dead code
-- elimination tends to be pretty good.
futhark-0.25.27/src/Futhark/Optimise/Simplify/Rules/000077500000000000000000000000001475065116200222145ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Optimise/Simplify/Rules/BasicOp.hs000066400000000000000000000361361475065116200241010ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -Wno-overlapping-patterns -Wno-incomplete-patterns -Wno-incomplete-uni-patterns -Wno-incomplete-record-updates #-}

-- | Some simplification rules for t'BasicOp'.
module Futhark.Optimise.Simplify.Rules.BasicOp
  ( basicOpRules,
  )
where

import Control.Monad
import Data.List (find, foldl', isSuffixOf, sort)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Maybe (isNothing)
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Construct
import Futhark.IR
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.Loop
import Futhark.Optimise.Simplify.Rules.Simple

-- | Is this a constant that 'oneIsh' considers to be one?
isCt1 :: SubExp -> Bool
isCt1 (Constant v) = oneIsh v
isCt1 _ = False

-- | Is this a constant that 'zeroIsh' considers to be zero?
isCt0 :: SubExp -> Bool
isCt0 (Constant v) = zeroIsh v
isCt0 _ = False

-- | A compact description of an operand to 'Concat', used when fusing
-- adjacent operands.
data ConcatArg
  = -- | An array literal with these elements.
    ArgArrayLit [SubExp]
  | -- | A replicate of a value; the list holds the (summed) lengths.
    ArgReplicate [SubExp] SubExp
  | -- | Any other array.
    ArgVar VName

-- | Classify a concat operand by looking up its definition.  The
-- certificates of that definition are returned alongside.
toConcatArg :: ST.SymbolTable rep -> VName -> (ConcatArg, Certs)
toConcatArg vtable v =
  case ST.lookupBasicOp v vtable of
    Just (ArrayLit ses _, cs) ->
      (ArgArrayLit ses, cs)
    Just (Replicate (Shape [d]) se, cs) ->
      (ArgReplicate [d] se, cs)
    _ ->
      (ArgVar v, mempty)

-- | Materialise a 'ConcatArg' as an array variable.  The type argument
-- is only used to determine the row type of array literals.
fromConcatArg ::
  (MonadBuilder m) =>
  Type ->
  (ConcatArg, Certs) ->
  m VName
fromConcatArg t (ArgArrayLit ses, cs) =
  certifying cs $ letExp "concat_lit" $ BasicOp $ ArrayLit ses $ rowType t
fromConcatArg _ (ArgReplicate ws se, cs) =
  certifying cs $ do
    w <- letSubExp "concat_rep_w" =<< toExp (sum $ map pe64 ws)
    letExp "concat_rep" $ BasicOp $ Replicate (Shape [w]) se
fromConcatArg _ (ArgVar v, _) =
  pure v

-- | Push an operand onto a (reversed) list of operands, merging it with
-- the head when both are array literals, or both are replicates of the
-- same value.  Empty literals and replicates of constant length zero
-- are dropped; replicates of constant length one become singleton
-- literals.
fuseConcatArg ::
  [(ConcatArg, Certs)] ->
  (ConcatArg, Certs) ->
  [(ConcatArg, Certs)]
fuseConcatArg xs (ArgArrayLit [], _) =
  xs
fuseConcatArg xs (ArgReplicate [w] se, cs)
  | isCt0 w =
      xs
  | isCt1 w =
      fuseConcatArg xs (ArgArrayLit [se], cs)
fuseConcatArg ((ArgArrayLit x_ses, x_cs) : xs) (ArgArrayLit y_ses, y_cs) =
  (ArgArrayLit (x_ses ++ y_ses), x_cs <> y_cs) : xs
fuseConcatArg ((ArgReplicate x_ws x_se, x_cs) : xs) (ArgReplicate y_ws y_se, y_cs)
  | x_se == y_se =
      (ArgReplicate (x_ws ++ y_ws) x_se, x_cs <> y_cs) : xs
fuseConcatArg xs y =
  y : xs

-- | Bottom-up simplification rules for 'Concat'.
simplifyConcat :: (BuilderOps rep) => BottomUpRuleBasicOp rep
-- concat@1(transpose(x),transpose(y)) == transpose(concat@0(x,y))
simplifyConcat (vtable, _) pat _ (Concat i (x :| xs) new_d)
  | Just r <- arrayRank <$> ST.lookupType x vtable,
    let perm = [i] ++ [0 .. i - 1] ++ [i + 1 .. r - 1],
    Just (x', x_cs) <- transposedBy perm x,
    Just (xs', xs_cs) <- mapAndUnzipM (transposedBy perm) xs =
      Simplify $ do
        concat_rearrange <-
          certifying (x_cs <> mconcat xs_cs) $
            letExp "concat_rearrange" $
              BasicOp $
                Concat 0 (x' :| xs') new_d
        letBind pat $ BasicOp $ Rearrange perm concat_rearrange
  where
    transposedBy perm1 v =
      case ST.lookupExp v vtable of
        Just (BasicOp (Rearrange perm2 v'), vcs)
          | perm1 == perm2 -> Just (v', vcs)
        _ -> Nothing

-- Removing a concatenation that involves only a single array.  This
-- may be produced as a result of other simplification rules.
simplifyConcat (vtable, _) pat aux (Concat _ (x :| []) w)
  | Just x_t <- ST.lookupType x vtable,
    arraySize 0 x_t == w =
      -- Still need a copy because Concat produces a fresh array.
      Simplify $ auxing aux $ letBind pat $ BasicOp $ Replicate mempty $ Var x
-- concat xs (concat ys zs) == concat xs ys zs
simplifyConcat (vtable, _) pat (StmAux cs attrs _) (Concat i (x :| xs) new_d)
  | x' /= x || concat xs' /= xs =
      Simplify $
        certifying (cs <> x_cs <> mconcat xs_cs) $
          attributing attrs $
            letBind pat $
              BasicOp $
                Concat i (x' :| zs ++ concat xs') new_d
  where
    (x' : zs, x_cs) = isConcat x
    (xs', xs_cs) = unzip $ map isConcat xs
    isConcat v = case ST.lookupBasicOp v vtable of
      Just (Concat j (y :| ys) _, v_cs) | j == i -> (y : ys, v_cs)
      _ -> ([v], mempty)

-- Removing empty arrays from concatenations.
simplifyConcat (vtable, _) pat aux (Concat i (x :| xs) new_d)
  | Just ts <- mapM (`ST.lookupType` vtable) $ x : xs,
    x' : xs' <- map fst $ filter (isNothing . isEmptyArray . snd) $ zip (x : xs) ts,
    xs' /= xs =
      Simplify $
        auxing aux $
          letBind pat $
            BasicOp $
              Concat i (x' :| xs') new_d

-- Fusing arguments to the concat when possible.  Only done when
-- concatenating along the outer dimension for now.
simplifyConcat (vtable, _) pat aux (Concat 0 (x :| xs) outer_w)
  | -- We produce the to-be-concatenated arrays in reverse order, so
    -- reverse them back.
    y : ys <-
      forSingleArray . reverse . foldl' fuseConcatArg mempty $
        map (toConcatArg vtable) (x : xs),
    length xs /= length ys =
      Simplify $ do
        elem_type <- lookupType x
        y' <- fromConcatArg elem_type y
        ys' <- mapM (fromConcatArg elem_type) ys
        auxing aux $ letBind pat $ BasicOp $ Concat 0 (y' :| ys') outer_w
  where
    -- If we fuse so much that there is only a single input left, then
    -- it must have the right size.
    forSingleArray [(ArgReplicate _ v, cs)] =
      [(ArgReplicate [outer_w] v, cs)]
    forSingleArray ys = ys
simplifyConcat _ _ _ _ = Skip

-- | Top-down simplification rules for t'BasicOp'.  The clauses are
-- tried in order; the first clause whose guards hold decides the
-- outcome.
ruleBasicOp :: (BuilderOps rep) => TopDownRuleBasicOp rep
-- First try the rules from "Futhark.Optimise.Simplify.Rules.Simple".
ruleBasicOp vtable pat aux op
  | Just (op', cs) <- applySimpleRules defOf seType op =
      Simplify $ certifying (cs <> stmAuxCerts aux) $ letBind pat $ BasicOp op'
  where
    defOf = (`ST.lookupExp` vtable)
    seType (Var v) = ST.lookupType v vtable
    seType (Constant v) = Just $ Prim $ primValueType v
-- Writing a Scratch array into 'src' leaves 'src' as the result
-- (presumably because Scratch contents are unspecified).
ruleBasicOp vtable pat aux (Update _ src _ (Var v))
  | Just (BasicOp Scratch {}, _) <- ST.lookupExp v vtable =
      Simplify $ auxing aux $ letBind pat $ BasicOp $ SubExp $ Var src
-- If we are writing a single-element slice from some array, and the
-- element of that array can be computed as a PrimExp based on the
-- index, let's just write that instead.
ruleBasicOp vtable pat aux (Update safety src (Slice [DimSlice i n s]) (Var v))
  | isCt1 n,
    isCt1 s,
    Just (ST.Indexed cs e) <- ST.index v [intConst Int64 0] vtable =
      Simplify $ do
        e' <- toSubExp "update_elem" e
        auxing aux . certifying cs $
          letBind pat $
            BasicOp $
              Update safety src (Slice [DimFix i]) e'
-- Writing back into 'dest' a value that was read from the same place
-- of 'dest' (possibly through copies) leaves 'dest' unchanged.
ruleBasicOp vtable pat aux (Update _ dest destis (Var v))
  | Just (e, _) <- ST.lookupExp v vtable,
    arrayFrom e =
      Simplify $ auxing aux $ letBind pat $ BasicOp $ SubExp $ Var dest
  where
    arrayFrom (BasicOp (Replicate (Shape []) (Var copy_v)))
      | Just (e', _) <- ST.lookupExp copy_v vtable =
          arrayFrom e'
    arrayFrom (BasicOp (Index src srcis)) =
      src == dest && destis == srcis
    arrayFrom (BasicOp (Replicate v_shape v_se))
      | Just (Replicate dest_shape dest_se, _) <- ST.lookupBasicOp dest vtable,
        v_se == dest_se,
        shapeDims v_shape `isSuffixOf` shapeDims dest_shape =
          True
    arrayFrom _ =
      False
-- An unsafe update that covers the whole of 'dest' is replaced by the
-- written value itself (reshaped to the shape of 'dest').
ruleBasicOp vtable pat aux (Update Unsafe dest is se)
  | Just dest_t <- ST.lookupType dest vtable,
    isFullSlice (arrayShape dest_t) is =
      Simplify . auxing aux $
        case se of
          Var v | not $ null $ sliceDims is -> do
            v_reshaped <-
              letSubExp (baseString v ++ "_reshaped") . BasicOp $
                Reshape ReshapeArbitrary (arrayShape dest_t) v
            letBind pat $ BasicOp $ Replicate mempty v_reshaped
          _ -> letBind pat $ BasicOp $ ArrayLit [se] $ rowType dest_t
-- Updating 'dest1' with an update of a copy of a slice of 'dest1' taken
-- at the same indices: combine the slices into a single update.
ruleBasicOp vtable pat (StmAux cs1 attrs _) (Update safety1 dest1 is1 (Var v1))
  | Just (Update safety2 dest2 is2 se2, cs2) <- ST.lookupBasicOp v1 vtable,
    Just (Replicate (Shape []) (Var v3), cs3) <- ST.lookupBasicOp dest2 vtable,
    Just (Index v4 is4, cs4) <- ST.lookupBasicOp v3 vtable,
    is4 == is1,
    v4 == dest1 =
      Simplify $
        certifying (cs1 <> cs2 <> cs3 <> cs4) $ do
          is5 <- subExpSlice $ sliceSlice (primExpSlice is1) (primExpSlice is2)
          attributing attrs $ letBind pat $ BasicOp $ Update (max safety1 safety2) dest1 is5 se2
-- Comparing against the result of a two-way boolean Match is turned
-- into (p && x == y) || (!p && x == z), provided the compared results
-- do not refer to names bound inside the branches.
ruleBasicOp vtable pat _ (CmpOp (CmpEq t) se1 se2)
  | Just m <- simplifyWith se1 se2 = Simplify m
  | Just m <- simplifyWith se2 se1 = Simplify m
  where
    simplifyWith (Var v) x
      | Just stm <- ST.lookupStm v vtable,
        Match [p] [Case [Just (BoolValue True)] tbranch] fbranch _ <- stmExp stm,
        Just (y, z) <-
          returns v (stmPat stm) tbranch fbranch,
        not $ boundInBody tbranch `namesIntersect` freeIn y,
        not $ boundInBody fbranch `namesIntersect` freeIn z = Just $ do
          eq_x_y <-
            letSubExp "eq_x_y" $ BasicOp $ CmpOp (CmpEq t) x y
          eq_x_z <-
            letSubExp "eq_x_z" $ BasicOp $ CmpOp (CmpEq t) x z
          p_and_eq_x_y <-
            letSubExp "p_and_eq_x_y" $ BasicOp $ BinOp LogAnd p eq_x_y
          not_p <-
            letSubExp "not_p" $ BasicOp $ UnOp (Neg Bool) p
          -- Name hint previously duplicated "p_and_eq_x_y".
          not_p_and_eq_x_z <-
            letSubExp "not_p_and_eq_x_z" $ BasicOp $ BinOp LogAnd not_p eq_x_z
          letBind pat $
            BasicOp $ BinOp LogOr p_and_eq_x_y not_p_and_eq_x_z
    simplifyWith _ _ =
      Nothing

    -- The (true branch, false branch) results that define 'v'.
    returns v ifpat tbranch fbranch =
      fmap snd . find ((== v) . patElemName . fst) $
        zip (patElems ifpat) $
          zip (map resSubExp (bodyResult tbranch)) (map resSubExp (bodyResult fbranch))
ruleBasicOp _ pat _ (Replicate _ se)
  | [Acc {}] <- patTypes pat =
      Simplify $ letBind pat $ BasicOp $ SubExp se
ruleBasicOp _ pat _ (Replicate (Shape []) se)
  | [Prim _] <- patTypes pat =
      Simplify $ letBind pat $ BasicOp $ SubExp se
-- Replicating a replicate: merge the shapes.
ruleBasicOp vtable pat _ (Replicate shape (Var v))
  | Just (BasicOp (Replicate shape2 se), cs) <- ST.lookupExp v vtable,
    ST.subExpAvailable se vtable =
      Simplify $ certifying cs $ letBind pat $ BasicOp $ Replicate (shape <> shape2) se
-- An array literal whose elements are all equal is a replicate.
ruleBasicOp _ pat _ (ArrayLit (se : ses) _)
  | all (== se) ses =
      Simplify $
        let n = constant (fromIntegral (length ses) + 1 :: Int64)
         in letBind pat $ BasicOp $ Replicate (Shape [n]) se
-- Fully indexing a reshaped array: index the original array instead.
ruleBasicOp vtable pat aux (Index idd slice)
  | Just inds <- sliceIndices slice,
    Just (BasicOp (Reshape k newshape idd2), idd_cs) <- ST.lookupExp idd vtable,
    length newshape == length inds =
      Simplify $ case k of
        ReshapeCoerce ->
          certifying idd_cs . auxing aux . letBind pat . BasicOp $ Index idd2 slice
        ReshapeArbitrary -> do
          -- Linearise indices and map to old index space.
          oldshape <- arrayDims <$> lookupType idd2
          let new_inds =
                reshapeIndex
                  (map pe64 oldshape)
                  (map pe64 $ shapeDims newshape)
                  (map pe64 inds)
          new_inds' <-
            mapM (toSubExp "new_index") new_inds
          certifying idd_cs . auxing aux $
            letBind pat $
              BasicOp $
                Index idd2 $
                  Slice $
                    map DimFix new_inds'

-- Copying an iota is pointless; just make it an iota instead.
ruleBasicOp vtable pat aux (Replicate (Shape []) (Var v))
  | Just (Iota n x s it, v_cs) <- ST.lookupBasicOp v vtable =
      Simplify . certifying v_cs . auxing aux $
        letBind pat $
          BasicOp $
            Iota n x s it
-- Handle identity permutation.
ruleBasicOp _ pat _ (Rearrange perm v)
  | sort perm == perm =
      Simplify $ letBind pat $ BasicOp $ SubExp $ Var v
ruleBasicOp vtable pat aux (Rearrange perm v)
  | Just (BasicOp (Rearrange perm2 e), v_cs) <- ST.lookupExp v vtable =
      -- Rearranging a rearranging: compose the permutations.
      Simplify . certifying v_cs . auxing aux $
        letBind pat $
          BasicOp $
            Rearrange (perm `rearrangeCompose` perm2) e
-- Rearranging a replicate where the outer dimension is left untouched.
ruleBasicOp vtable pat aux (Rearrange perm v1)
  | Just (BasicOp (Replicate dims (Var v2)), v1_cs) <- ST.lookupExp v1 vtable,
    num_dims <- shapeRank dims,
    (rep_perm, rest_perm) <- splitAt num_dims perm,
    not $ null rest_perm,
    rep_perm == [0 .. length rep_perm - 1] =
      Simplify $
        certifying v1_cs $
          auxing aux $ do
            v <-
              letSubExp "rearrange_replicate" $
                BasicOp $
                  Rearrange (map (subtract num_dims) rest_perm) v2
            letBind pat $ BasicOp $ Replicate dims v
-- Simplify away 0<=i when 'i' is from a loop of form 'for i < n'.
ruleBasicOp vtable pat aux (CmpOp CmpSle {} x y)
  | Constant (IntValue (Int64Value 0)) <- x,
    Var v <- y,
    Just _ <- ST.lookupLoopVar v vtable =
      Simplify $ auxing aux $ letBind pat $ BasicOp $ SubExp $ constant True
-- NOTE(review): source text between "Simplify away i" and "ST.lookupStm"
-- was lost when this file was extracted (it looks like an HTML-style
-- "<...>" span was stripped).  At least one whole clause, plus the head
-- of a certificate-pruning clause, is missing.  The surviving fragment is
-- kept verbatim below as a comment; restore the clause(s) from upstream
-- rather than guessing them here.
--
-- Simplify away i ST.lookupStm v vtable, cs' <- filter (`notElem` v_cs) cs,
-- cs' /= cs = Simplify . certifying (Certs cs') $ letBind pat $ BasicOp $
-- SubExp $ Var v
--
-- Remove UpdateAccs that contribute the neutral value, which is
-- always a no-op.
ruleBasicOp vtable pat aux (UpdateAcc _ acc _ vs)
  | Pat [pe] <- pat,
    Acc token _ _ _ <- patElemType pe,
    Just (_, _, Just (_, ne)) <- ST.entryAccInput =<< ST.lookup token vtable,
    vs == ne =
      Simplify . auxing aux $ letBind pat $ BasicOp $ SubExp $ Var acc
-- Manifest of a copy (or another Manifest) can be simplified to
-- manifesting the original array, if it is still available.
ruleBasicOp vtable pat aux (Manifest perm v1)
  | Just (Replicate (Shape []) (Var v2), cs) <- ST.lookupBasicOp v1 vtable,
    ST.available v2 vtable =
      Simplify . auxing aux . certifying cs . letBind pat . BasicOp $
        Manifest perm v2
  | Just (Manifest _ v2, cs) <- ST.lookupBasicOp v1 vtable,
    ST.available v2 vtable =
      Simplify . auxing aux . certifying cs . letBind pat . BasicOp $
        Manifest perm v2
ruleBasicOp _ _ _ _ = Skip

topDownRules :: (BuilderOps rep) => [TopDownRule rep]
topDownRules =
  [ RuleBasicOp ruleBasicOp
  ]

bottomUpRules :: (BuilderOps rep) => [BottomUpRule rep]
bottomUpRules =
  [ RuleBasicOp simplifyConcat
  ]

-- | A set of simplification rules for t'BasicOp's.  Includes rules
-- from "Futhark.Optimise.Simplify.Rules.Simple".
basicOpRules :: (BuilderOps rep) => RuleBook rep
basicOpRules = ruleBook topDownRules bottomUpRules <> loopRules
futhark-0.25.27/src/Futhark/Optimise/Simplify/Rules/ClosedForm.hs000066400000000000000000000147171475065116200246170ustar00rootroot00000000000000-- | This module implements facilities for determining whether a
-- reduction or fold can be expressed in a closed form (i.e. not as a
-- SOAC).
--
-- Right now, the module can detect only trivial cases.  In the
-- future, we would like to make it more powerful, as well as possibly
-- also being able to analyse sequential loops.
module Futhark.Optimise.Simplify.Rules.ClosedForm
  ( foldClosedForm,
    loopClosedForm,
  )
where

import Control.Monad
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Construct
import Futhark.IR
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.Simple (VarLookup)
import Futhark.Transform.Rename

{-
Motivation:

  let {*[int,x_size_27] map_computed_shape_1286} = replicate(x_size_27, all_equal_shape_1044) in
  let {*[bool,x_size_27] map_size_checks_1292} = replicate(x_size_27, x_1291) in
  let {bool all_equal_checked_1298, int all_equal_shape_1299} =
    reduceT(fn {bool, int} (bool bacc_1293, int nacc_1294, bool belm_1295, int nelm_1296) =>
              let {bool tuplit_elems_1297} = bacc_1293 && belm_1295 in
              {tuplit_elems_1297, nelm_1296},
            {True, 0},
            map_size_checks_1292, map_computed_shape_1286)
-}

-- | @foldClosedForm look pat lam accs arrs@ tries to replace a fold of
-- @lam@ over the arrays @arrs@, starting from the accumulators @accs@,
-- by a closed-form computation bound to @pat@.
--
-- Only a single primitive, non-floating-point result is handled.  The
-- generated code branches on whether the input is empty.  If it is, the
-- accumulators are returned unchanged.  Otherwise the closed form built
-- by 'checkResults' is used.  Anything else fails with 'cannotSimplify'.
foldClosedForm ::
  (BuilderOps rep) =>
  VarLookup rep ->
  Pat (LetDec rep) ->
  Lambda rep ->
  [SubExp] ->
  [VName] ->
  RuleM rep ()
foldClosedForm look pat lam accs arrs = do
  arr_ts <- mapM lookupType arrs
  let num_elems = arraysSize 0 arr_ts

  res_t <- case patTypes pat of
    [Prim (FloatType _)] -> cannotSimplify
    [Prim prim_t] -> pure prim_t
    _ -> cannotSimplify

  closed_body <-
    checkResults res_names num_elems mempty Int64 known lam_param_names (lambdaBody lam) accs

  is_empty <- newVName "fold_input_is_empty"
  letBindNames [is_empty] . BasicOp $
    CmpOp (CmpEq int64) num_elems (intConst Int64 0)

  -- Build the "empty input" branch first, then the renamed closed form.
  empty_branch <- resultBodyM accs
  closed_branch <- renameBody closed_body
  letBind pat $
    Match
      [Var is_empty]
      [Case [Just (BoolValue True)] empty_branch]
      closed_branch
      (MatchDec [primBodyType res_t] MatchNormal)
  where
    res_names = patNames pat
    lam_param_names = map paramName $ lambdaParams lam
    -- Known values for the lambda parameters.  Accumulators map to their
    -- initial values.  Array parameters map to the replicated element,
    -- where the array is a certificate-free replicate.
    known = determineKnownBindings look lam accs arrs

-- | @loopClosedForm pat merge i it bound body@ determines whether
-- the do-loop can be expressed in a closed form.
loopClosedForm ::
  (BuilderOps rep) =>
  Pat (LetDec rep) ->
  [(FParam rep, SubExp)] ->
  Names ->
  IntType ->
  SubExp ->
  Body rep ->
  RuleM rep ()
loopClosedForm pat merge i it bound body = do
  -- Only a single, primitive, non-floating-point result is handled.
  t <- case patTypes pat of
    [Prim FloatType {}] -> cannotSimplify
    [Prim t] -> pure t
    _ -> cannotSimplify

  closedBody <-
    checkResults
      mergenames
      bound
      i
      it
      knownBnds
      (map identName mergeidents)
      body
      mergeexp

  -- A for-loop whose bound is zero or negative runs no iterations, so
  -- its result is just the initial merge values.  The closed form from
  -- 'checkResults' assumes at least one iteration.  For 'LogAnd' it
  -- yields @acc && el@, which differs from @acc@ when @el@ is false and
  -- the bound is zero.  Hence the test must be @bound <= 0@, not merely
  -- @bound < 0@.
  isEmpty <- newVName "bound_is_zero"
  letBindNames [isEmpty] $
    BasicOp $
      CmpOp (CmpSle it) bound (intConst it 0)

  letBind pat
    =<< ( Match [Var isEmpty]
            <$> (pure . Case [Just (BoolValue True)] <$> resultBodyM mergeexp)
            <*> renameBody closedBody
            <*> pure (MatchDec [primBodyType t] MatchNormal)
        )
  where
    (mergepat, mergeexp) = unzip merge
    mergeidents = map paramIdent mergepat
    mergenames = map paramName mergepat
    -- Every loop parameter is mapped to its initial value.  'checkResults'
    -- consults this when a body operand is one of the parameters.
    knownBnds = M.fromList $ zip mergenames mergeexp

-- | Try to express every result of @body@ in closed form.
--
-- Each name in @pat@ is paired with an accumulator: one of the first
-- @length accs@ entries of @params@, together with its initial value
-- from @accs@.  The corresponding body result must be a variable bound
-- in @body@ by a binary operation.  One operand must be that
-- accumulator; the other must be free in the body, or known through
-- @knownBnds@.
--
-- Supported operators and the closed forms produced:
--
-- * 'LogAnd': @acc && el@
-- * 'Add': @acc + el * size@
-- * 'FAdd': @acc + el * size@, with @size@ converted from @it@
--
-- Any other shape fails with 'cannotSimplify'.  Names in @untouchable@
-- (the loop index, for loops) are never treated as free.
checkResults ::
  (BuilderOps rep) =>
  [VName] ->
  SubExp ->
  Names ->
  IntType ->
  M.Map VName SubExp ->
  -- | Lambda-bound
  [VName] ->
  Body rep ->
  [SubExp] ->
  RuleM rep (Body rep)
checkResults pat size untouchable it knownBnds params body accs = do
  ((), stms) <-
    collectStms $
      zipWithM_ checkResult (zip pat res) (zip accparams accs)
  mkBodyM stms $ varsRes pat
  where
    stmMap = makeBindMap body
    (accparams, _) = splitAt (length accs) params
    res = bodyResult body

    -- Names that vary inside the body and therefore cannot be used as
    -- the "free" operand directly.
    nonFree = boundInBody body <> namesFromList params <> untouchable

    checkResult (p, SubExpRes _ (Var v)) (accparam, acc)
      | Just (BasicOp (BinOp bop x y)) <- M.lookup v stmMap,
        x /= y = do
          -- One of x,y must be *this* accumulator, and the other must
          -- be something that is free in the body.
          let isThisAccum = (== Var accparam)
          (this, el) <- liftMaybe $
            case ( (asFreeSubExp x, isThisAccum y),
                   (asFreeSubExp y, isThisAccum x)
                 ) of
              ((Just free, True), _) -> Just (acc, free)
              (_, (Just free, True)) -> Just (acc, free)
              _ -> Nothing

          case bop of
            LogAnd ->
              letBindNames [p] $ BasicOp $ BinOp LogAnd this el
            Add t w -> do
              size' <- asIntS t size
              letBindNames [p]
                =<< eBinOp
                  (Add t w)
                  (eSubExp this)
                  (pure $ BasicOp $ BinOp (Mul t w) el size')
            FAdd t | Just properly_typed_size <- properFloatSize t -> do
              size' <- properly_typed_size
              letBindNames [p]
                =<< eBinOp
                  (FAdd t)
                  (eSubExp this)
                  (pure $ BasicOp $ BinOp (FMul t) el size')
            _ -> cannotSimplify -- Um... sorry.
    checkResult _ _ = cannotSimplify

    -- A subexpression usable as the loop-invariant operand.  A variable
    -- that varies inside the body is only usable if its value is known.
    asFreeSubExp :: SubExp -> Maybe SubExp
    asFreeSubExp (Var v)
      | v `nameIn` nonFree = M.lookup v knownBnds
    asFreeSubExp se = Just se

    properFloatSize t =
      Just $
        letSubExp "converted_size" $
          BasicOp $
            ConvOp (SIToFP it t) size

-- | Map lambda parameters to values known at the call site.
--
-- * Each accumulator parameter maps to its initial value.
-- * Each array parameter maps to the replicated value, when @look@
--   shows its array to be a certificate-free 'Replicate' with a
--   non-empty shape.
determineKnownBindings ::
  VarLookup rep ->
  Lambda rep ->
  [SubExp] ->
  [VName] ->
  M.Map VName SubExp
determineKnownBindings look lam accs arrs =
  accBnds <> arrBnds
  where
    (accparams, arrparams) =
      splitAt (length accs) $ lambdaParams lam
    accBnds =
      M.fromList $
        zip (map paramName accparams) accs
    arrBnds =
      M.fromList $
        mapMaybe isReplicate $
          zip (map paramName arrparams) arrs

    isReplicate (p, v)
      | Just (BasicOp (Replicate (Shape (_ : _)) ve), cs) <- look v,
        cs == mempty =
          Just (p, ve)
    isReplicate _ = Nothing

-- | Map every name bound by a single-name statement in the body to the
-- expression of that statement.
makeBindMap :: Body rep -> M.Map VName (Exp rep)
makeBindMap = M.fromList . mapMaybe isSingletonStm . stmsToList . bodyStms
  where
    isSingletonStm (Let pat _ e) = case patNames pat of
      [v] -> Just (v, e)
      _ -> Nothing
futhark-0.25.27/src/Futhark/Optimise/Simplify/Rules/Index.hs000066400000000000000000000253571475065116200236330ustar00rootroot00000000000000{-# OPTIONS_GHC -Wno-overlapping-patterns -Wno-incomplete-patterns -Wno-incomplete-uni-patterns -Wno-incomplete-record-updates #-}

-- | Index simplification mechanics.
module Futhark.Optimise.Simplify.Rules.Index
  ( IndexResult (..),
    simplifyIndexing,
  )
where

import Control.Monad (guard)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Maybe
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Construct
import Futhark.IR
import Futhark.Optimise.Simplify.Rules.Simple
import Futhark.Util

-- Is this subexpression a constant that 'oneIsh' considers to be one?
isCt1 :: SubExp -> Bool
isCt1 (Constant v) = oneIsh v
isCt1 _ = False

-- Is this subexpression a constant that 'zeroIsh' considers to be zero?
isCt0 :: SubExp -> Bool
isCt0 (Constant v) = zeroIsh v
isCt0 _ = False

-- | Some index expressions can be simplified to t'SubExp's, while
-- others produce another index expression (which may be further
-- simplifiable).
data IndexResult
  = IndexResult Certs VName (Slice SubExp)
  | SubExpResult Certs SubExp

-- Fake expressions that we can recognise.  The @i@th element is an
-- 'Int64' leaf named @fake_i@ with tag @negate i@.  They stand in for
-- the sliced dimensions when asking the symbol table to index with a
-- partial slice (the 'ST.index'' case in 'simplifyIndexing').
-- NOTE(review): the negative tags are presumably chosen so they cannot
-- clash with real variables; confirm.
fakeIndices :: [TPrimExp Int64 VName]
fakeIndices = map f [0 :: Int ..]
  where
    f i = isInt64 $ LeafExp (VName v (negate i)) $ IntType Int64
      where
        v = nameFromText ("fake_" <> showText i)

-- | Try to simplify an index operation.
--
-- @simplifyIndexing vtable seType idd slice consuming consumed@ looks at
-- how @idd@ is defined (via @vtable@).  If one of the rules below
-- applies, it returns an action producing either a simpler index
-- ('IndexResult') or a plain subexpression ('SubExpResult').  'Nothing'
-- means that no rule applies.
--
-- Rules that make the result refer to another array check @consuming@
-- and @consumed@ where relevant.
simplifyIndexing ::
  (MonadBuilder m) =>
  ST.SymbolTable (Rep m) ->
  TypeLookup ->
  VName ->
  Slice SubExp ->
  Bool ->
  (VName -> Bool) ->
  Maybe (m IndexResult)
simplifyIndexing vtable seType idd (Slice inds) consuming consumed =
  case defOf idd of
    _
      -- Indexing with a full slice is a no-op.
      | Just t <- seType (Var idd),
        Slice inds == fullSlice t [] ->
          Just $ pure $ SubExpResult mempty $ Var idd
      -- The symbol table knows the indexed element as a scalar
      -- expression; inline it if it is cheap enough.
      | Just inds' <- sliceIndices (Slice inds),
        Just (ST.Indexed cs e) <- ST.index idd inds' vtable,
        worthInlining e,
        all (`ST.elem` vtable) (unCerts cs) ->
          Just $ SubExpResult cs <$> toSubExp "index_primexp" e
      -- The symbol table can express the index as an index into
      -- another (available) array.
      | Just inds' <- sliceIndices (Slice inds),
        Just (ST.IndexedArray cs arr inds'') <- ST.index idd inds' vtable,
        all (worthInlining . untyped) inds'',
        arr `ST.available` vtable,
        all (`ST.elem` vtable) (unCerts cs) ->
          Just $
            IndexResult cs arr . Slice . map DimFix
              <$> mapM (toSubExp "index_primexp") inds''
      -- Same as above, but for slices.  Each sliced dimension is
      -- replaced by a fake index.  A fake index that survives in the
      -- result becomes a full 'DimSlice' of the original width again.
      | Just (ST.IndexedArray cs arr inds'') <-
          ST.index' idd (fixSlice (pe64 <$> Slice inds) (map fst matches)) vtable,
        all (worthInlining . untyped) inds'',
        arr `ST.available` vtable,
        all (`ST.elem` vtable) (unCerts cs),
        not consuming,
        not $ consumed arr,
        Just inds''' <- mapM okIdx inds'' -> do
          Just $ IndexResult cs arr . Slice <$> sequence inds'''
      where
        matches = zip fakeIndices $ sliceDims $ Slice inds
        okIdx i =
          case lookup i matches of
            Just w ->
              Just $ pure $ DimSlice (constant (0 :: Int64)) w (constant (1 :: Int64))
            Nothing -> do
              -- A fixed index must not mention any fake index.
              guard $ not $ any ((`namesIntersect` freeIn i) . freeIn . fst) matches
              Just $ DimFix <$> toSubExp "index_primexp" i
    Nothing -> Nothing
    -- The array is just another name for 'v'.
    Just (SubExp (Var v), cs) ->
      Just $ pure $ IndexResult cs v $ Slice inds
    -- Element @j@ of @iota n x s@ is @x + j*s@.
    Just (Iota _ x s to_it, cs)
      | [DimFix ii] <- inds,
        Just (Prim (IntType from_it)) <- seType ii ->
          Just $
            let mul = BinOpExp $ Mul to_it OverflowWrap
                add = BinOpExp $ Add to_it OverflowWrap
             in fmap (SubExpResult cs) $
                  toSubExp "index_iota" $
                    ( sExt to_it (primExpFromSubExp (IntType from_it) ii)
                        `mul` primExpFromSubExp (IntType to_it) s
                    )
                      `add` primExpFromSubExp (IntType to_it) x
      -- Element @k@ of the slice is
      -- @x + (i_offset + k*i_stride)*s@, so the slice is itself an
      -- iota with offset @x + i_offset*s@ and stride @i_stride*s@.
      | [DimSlice i_offset i_n i_stride] <- inds ->
          Just $ do
            i_offset' <- asIntS to_it i_offset
            i_stride' <- asIntS to_it i_stride
            let mul = BinOpExp $ Mul to_it OverflowWrap
                add = BinOpExp $ Add to_it OverflowWrap
            i_offset'' <-
              toSubExp "iota_offset" $
                primExpFromSubExp (IntType to_it) x
                  `add` ( primExpFromSubExp (IntType to_it) i_offset'
                            `mul` primExpFromSubExp (IntType to_it) s
                        )
            i_stride'' <-
              letSubExp "iota_offset" $
                BasicOp $
                  BinOp (Mul to_it OverflowWrap) s i_stride'
            fmap (SubExpResult cs) $
              letSubExp "slice_iota" $
                BasicOp $
                  Iota i_n i_offset'' i_stride'' to_it
    -- Indexing an index: compose the two slices.
    Just (Index aa ais, cs) ->
      Just $
        IndexResult cs aa
          <$> subExpSlice (sliceSlice (primExpSlice ais) (primExpSlice (Slice inds)))
    Just (Replicate (Shape [_]) (Var vv), cs)
      | [DimFix {}] <- inds,
        ST.available vv vtable ->
          Just $ pure $ SubExpResult cs $ Var vv
      | DimFix {} : is' <- inds,
        not consuming,
        not $ consumed vv,
        ST.available vv vtable ->
          Just $ pure $ IndexResult cs vv $ Slice is'
    Just (Replicate (Shape [_]) val@(Constant _), cs)
      | [DimFix {}] <- inds, not consuming -> Just $ pure $ SubExpResult cs val
    -- Slicing the replicated dimensions of a replicate: build a smaller
    -- replicate.  Fixed indices drop their dimension.  A sliced
    -- dimension of @n@ elements becomes a dimension of size @n@,
    -- covered by a full unit-stride slice.  Every element is the same
    -- value, so the original stride does not matter.  Keeping it would
    -- read out of bounds of the smaller array.
    Just (Replicate (Shape ds) v, cs)
      | (ds_inds, rest_inds) <- splitAt (length ds) inds,
        (ds', ds_inds') <- unzip $ mapMaybe index ds_inds,
        ds' /= ds,
        ST.subExpAvailable v vtable ->
          Just $ do
            arr <- letExp "smaller_replicate" $ BasicOp $ Replicate (Shape ds') v
            pure $ IndexResult cs arr $ Slice $ ds_inds' ++ rest_inds
      where
        index DimFix {} = Nothing
        index (DimSlice _ n _) =
          Just (n, DimSlice (constant (0 :: Int64)) n (constant (1 :: Int64)))
    -- Fixed indices covering every permuted dimension can be permuted
    -- back onto the source array.
    Just (Rearrange perm src, cs)
      | rearrangeReach perm <= length (takeWhile isIndex inds) ->
          let inds' = rearrangeShape (rearrangeInverse perm) inds
           in Just $ pure $ IndexResult cs src $ Slice inds'
      where
        isIndex DimFix {} = True
        isIndex _ = False
    -- Indexing a copy.
    Just (Replicate (Shape []) (Var src), cs)
      | Just dims <- arrayDims <$> seType (Var src),
        length inds == length dims,
        not $ consumed src,
        -- It is generally not safe to simplify a slice of a copy,
        -- because the result may be used in an in-place update of the
        -- original.  But we know this can only happen if the original
        -- is bound the same depth as we are!
        all (isJust . dimFix) inds
          || maybe True ((ST.loopDepth vtable /=) . ST.entryDepth) (ST.lookup src vtable),
        not consuming,
        ST.available src vtable ->
          Just $ pure $ IndexResult cs src $ Slice inds
    Just (Reshape ReshapeCoerce newshape src, cs)
      | Just olddims <- arrayDims <$> seType (Var src),
        changed_dims <- zipWith (/=) (shapeDims newshape) olddims,
        not $ or $ drop (length inds) changed_dims ->
          Just $ pure $ IndexResult cs src $ Slice inds
      | Just olddims <- arrayDims <$> seType (Var src),
        length newshape == length inds,
        length olddims == length (shapeDims newshape) ->
          Just $ pure $ IndexResult cs src $ Slice inds
    Just (Reshape _ (Shape [_]) v2, cs)
      | Just [_] <- arrayDims <$> seType (Var v2) ->
          Just $ pure $ IndexResult cs v2 $ Slice inds
    Just (Concat d (x :| xs) _, cs)
      | -- HACK: simplifying the indexing of an N-array concatenation
        -- is going to produce an N-deep if expression, which is bad
        -- when N is large.  To try to avoid that, we use the
        -- heuristic not to simplify as long as any of the operands
        -- are themselves Concats.  The hope is that this will give
        -- simplification some time to cut down the concatenation to
        -- something smaller, before we start inlining.
        not $ any isConcat $ x : xs,
        Just (ibef, DimFix i, iaft) <- focusNth d inds,
        Just (Prim res_t) <-
          (`setArrayDims` sliceDims (Slice inds))
            <$> ST.lookupType x vtable -> Just $ do
          x_len <- arraySize d <$> lookupType x
          xs_lens <- mapM (fmap (arraySize d) . lookupType) xs

          -- Starting offsets (along dimension d) of each array in xs.
          let add n m = do
                added <-
                  letSubExp "index_concat_add" $
                    BasicOp $
                      BinOp (Add Int64 OverflowWrap) n m
                pure (added, n)
          (_, starts) <- mapAccumLM add x_len xs_lens
          let xs_and_starts = reverse $ zip xs starts

          -- Test the operands from last to first; the first one whose
          -- start is <= i is indexed, falling back to x.
          let mkBranch [] =
                letSubExp "index_concat" $ BasicOp $ Index x $ Slice $ ibef ++ DimFix i : iaft
              mkBranch ((x', start) : xs_and_starts') = do
                cmp <- letSubExp "index_concat_cmp" $ BasicOp $ CmpOp (CmpSle Int64) start i
                (thisres, thisstms) <- collectStms $ do
                  i' <- letSubExp "index_concat_i" $ BasicOp $ BinOp (Sub Int64 OverflowWrap) i start
                  letSubExp "index_concat" . BasicOp . Index x' $
                    Slice (ibef ++ DimFix i' : iaft)
                thisbody <- mkBodyM thisstms [subExpRes thisres]
                (altres, altstms) <- collectStms $ mkBranch xs_and_starts'
                altbody <- mkBodyM altstms [subExpRes altres]
                certifying cs . letSubExp "index_concat_branch" $
                  Match [cmp] [Case [Just $ BoolValue True] thisbody] altbody $
                    MatchDec [primBodyType res_t] MatchNormal
          SubExpResult mempty <$> mkBranch xs_and_starts
    -- Constant index into an array literal.
    Just (ArrayLit ses _, cs)
      | DimFix (Constant (IntValue (Int64Value i))) : inds' <- inds,
        Just se <- maybeNth i ses ->
          case inds' of
            [] -> Just $ pure $ SubExpResult cs se
            _ | Var v2 <- se -> Just $ pure $ IndexResult cs v2 $ Slice inds'
            _ -> Nothing
    -- Reading back exactly what an unsafe update wrote.
    Just (Update Unsafe _ (Slice update_inds) se, cs)
      | inds == update_inds,
        ST.subExpAvailable se vtable ->
          Just $ pure $ SubExpResult cs se
    -- Indexing single-element arrays.  We know the index must be 0.
    _
      | Just t <- seType $ Var idd,
        isCt1 $ arraySize 0 t,
        DimFix i : inds' <- inds,
        not $ isCt0 i ->
          Just . pure . IndexResult mempty idd . Slice $
            DimFix (constant (0 :: Int64)) : inds'
    _ -> Nothing
  where
    -- The defining 'BasicOp' of a variable, with its certificates.
    defOf v = do
      (BasicOp op, def_cs) <- ST.lookupExp v vtable
      pure (op, def_cs)
    worthInlining e
      | primExpSizeAtLeast 20 e = False -- totally ad-hoc.
      | otherwise = worthInlining' e
    worthInlining' (BinOpExp Pow {} _ _) = False
    worthInlining' (BinOpExp FPow {} _ _) = False
    worthInlining' (BinOpExp _ x y) = worthInlining' x && worthInlining' y
    worthInlining' (CmpOpExp _ x y) = worthInlining' x && worthInlining' y
    worthInlining' (ConvOpExp _ x) = worthInlining' x
    worthInlining' (UnOpExp _ x) = worthInlining' x
    worthInlining' FunExp {} = False
    worthInlining' _ = True

    isConcat v
      | Just (Concat {}, _) <- defOf v = True
      | otherwise = False
futhark-0.25.27/src/Futhark/Optimise/Simplify/Rules/Loop.hs000066400000000000000000000221201475065116200234560ustar00rootroot00000000000000-- | Loop simplification rules.
module Futhark.Optimise.Simplify.Rules.Loop (loopRules) where

import Control.Monad
import Data.Bifunctor (second)
import Data.List (partition)
import Data.Maybe
import Futhark.Analysis.DataDependencies
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.Construct
import Futhark.IR
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.ClosedForm
import Futhark.Transform.Rename

-- This next one is tricky - it's easy enough to determine that some
-- loop result is not used after the loop, but here, we must also make
-- sure that it does not affect any other values.
--
-- I do not claim that the current implementation of this rule is
-- perfect, but it should suffice for many cases, and should never
-- generate wrong code.
removeRedundantLoopParams :: (BuilderOps rep) => BottomUpRuleLoop rep
removeRedundantLoopParams (_, used) pat aux (merge, form, body)
  | not $ all (usedAfterLoop . fst) merge =
      let -- Loop parameters whose values are needed to compute the
          -- parameters that are used after the loop or in the form.
          necessaryForReturned =
            findNecessaryForReturned
              usedAfterLoopOrInForm
              (zip (map fst merge) (map resSubExp $ bodyResult body))
              (dataDependencies body)

          resIsNecessary ((v, _), _) =
            usedAfterLoop v
              || (paramName v `nameIn` necessaryForReturned)
              || referencedInPat v
              || referencedInForm v

          -- Triples of (pattern element, (parameter, initial value),
          -- body result), split by whether they must be kept.
          (keep_valpart, discard_valpart) =
            partition (resIsNecessary . snd) $
              zip (patElems pat) $
                zip merge $
                  bodyResult body

          (keep_valpatelems, keep_val) = unzip keep_valpart
          (_discard_valpatelems, discard_val) = unzip discard_valpart
          (merge', val_es') = unzip keep_val

          body' = body {bodyResult = val_es'}

          pat' = Pat keep_valpatelems
       in if merge' == merge
            then Skip
            else Simplify $ do
              -- We can't just remove the bindings in 'discard', since the loop
              -- body may still use their names in (now-dead) expressions.
              -- Hence, we add them inside the loop, fully aware that dead-code
              -- removal will eventually get rid of them.  Some care is
              -- necessary to handle unique bindings.
              body'' <- insertStmsM $ do
                mapM_ (uncurry letBindNames) $ dummyStms discard_val
                pure body'
              auxing aux $ letBind pat' $ Loop merge' form body''
  where
    -- Which pattern elements the usage table reports as used directly.
    pat_used = map (`UT.isUsedDirectly` used) $ patNames pat
    -- Names of the merge parameters whose loop result is used.
    used_vals = map fst $ filter snd $ zip (map (paramName . fst) merge) pat_used
    usedAfterLoop = flip elem used_vals . paramName
    usedAfterLoopOrInForm p =
      usedAfterLoop p || paramName p `nameIn` freeIn form
    -- Names free in the merge parameters (presumably sizes in their types).
    patAnnotNames = freeIn $ map fst merge
    referencedInPat = (`nameIn` patAnnotNames) . paramName
    referencedInForm = (`nameIn` freeIn form) . paramName

    -- Re-bind each discarded parameter to its initial value.  A unique
    -- parameter bound to a variable gets a copy ('Replicate' with an
    -- empty shape) instead of a plain alias.
    dummyStms = map dummyStm
    dummyStm ((p, e), _)
      | unique (paramDeclType p),
        Var v <- e =
          ([paramName p], BasicOp $ Replicate mempty $ Var v)
      | otherwise = ([paramName p], BasicOp $ SubExp e)
removeRedundantLoopParams _ _ _ _ =
  Skip

-- We may change the type of the loop if we hoist out a shape
-- annotation, in which case we also need to tweak the bound pattern.
hoistLoopInvariantLoopParams :: (BuilderOps rep) => TopDownRuleLoop rep
hoistLoopInvariantLoopParams vtable pat aux (merge, form, loopbody) = do
  -- Figure out which of the elements of loopresult are
  -- loop-invariant, and hoist them out.
  -- 'explpat' pairs each pattern element with its merge parameter's name.
  let explpat = zip (patElems pat) $ map (paramName . fst) merge
  -- Fold right-to-left, accumulating (hoisted bindings, remaining
  -- pattern, remaining merge, remaining results).
  case foldr checkInvariance ([], explpat, [], []) $ zip3 (patNames pat) merge res of
    ([], _, _, _) ->
      -- Nothing is invariant.
      Skip
    (invariant, explpat', merge', res') -> Simplify . auxing aux $ do
      -- We have moved something invariant out of the loop.
      let loopbody' = loopbody {bodyResult = res'}
          explpat'' = map fst explpat'
      -- Bind each hoisted name to its (initial) value, outside the loop.
      forM_ invariant $ \(v1, (v2, cs)) ->
        certifying cs $ letBindNames [identName v1] $ BasicOp $ SubExp v2
      letBind (Pat explpat'') $ Loop merge' form loopbody'
  where
    res = bodyResult loopbody

    namesOfLoopParams = namesFromList $ map (paramName . fst) merge

    -- If exactly one remaining pattern element belongs to this merge
    -- parameter, remove it and produce a binding of it to the initial
    -- value; otherwise leave the pattern unchanged.
    removeFromResult cs (mergeParam, mergeInit) explpat' =
      case partition ((== paramName mergeParam) . snd) explpat' of
        ([(patelem, _)], rest) ->
          (Just (patElemIdent patelem, (mergeInit, cs)), rest)
        (_, _) ->
          (Nothing, explpat')

    checkInvariance
      (pat_name, (mergeParam, mergeInit), resExp)
      (invariant, explpat', merge', resExps)
        | isInvariant,
          -- Certificates must be available.
          all (`ST.elem` vtable) $ unCerts $ resCerts resExp =
            let (stm, explpat'') =
                  removeFromResult (resCerts resExp) (mergeParam, mergeInit) explpat'
             in ( maybe id (:) stm $
                    (paramIdent mergeParam, (mergeInit, resCerts resExp)) : invariant,
                  explpat'',
                  merge',
                  resExps
                )
        where
          -- A non-unique merge variable is invariant if one of the
          -- following is true:
          isInvariant
            -- (0) The result is a variable of the same name as the
            -- parameter, where all existential parameters are already
            -- known to be invariant
            | Var v2 <- resSubExp resExp,
              paramName mergeParam == v2 =
                allExistentialInvariant
                  (namesFromList $ map (identName . fst) invariant)
                  mergeParam
            -- (1) The result is identical to the initial parameter value.
            | mergeInit == resSubExp resExp = True
            -- (2) The initial parameter value is equal to an outer
            -- loop parameter 'P', where the initial value of 'P' is
            -- equal to 'resExp', AND 'resExp' ultimately becomes the
            -- new value of 'P'.  XXX: it's a bit clumsy that this
            -- only works for one level of nesting, and I think it
            -- would not be too hard to generalise.
            | Var init_v <- mergeInit,
              Just (p_init, p_res) <- ST.lookupLoopParam init_v vtable,
              p_init == resSubExp resExp,
              p_res == Var pat_name =
                True
            -- (3) It is a statically empty array.
            | isJust $ isEmptyArray (paramType mergeParam) = True
            | otherwise = False
    -- Not invariant: keep the parameter and its result in the loop.
    checkInvariance
      (_pat_name, (mergeParam, mergeInit), resExp)
      (invariant, explpat', merge', resExps) =
        (invariant, explpat', (mergeParam, mergeInit) : merge', resExp : resExps)

    allExistentialInvariant namesOfInvariant mergeParam =
      all (invariantOrNotMergeParam namesOfInvariant) $
        namesToList $
          freeIn mergeParam `namesSubtract` oneName (paramName mergeParam)
    -- A free name is acceptable if it is not a loop parameter at all,
    -- or if it is one that has already been found invariant.
    invariantOrNotMergeParam namesOfInvariant name =
      (name `notNameIn` namesOfLoopParams)
        || (name `nameIn` namesOfInvariant)

-- Replace a for-loop by a closed form via 'loopClosedForm'.  The loop
-- variable is passed as the only untouchable name.
simplifyClosedFormLoop :: (BuilderOps rep) => TopDownRuleLoop rep
simplifyClosedFormLoop _ pat _ (val, ForLoop i it bound, body) =
  Simplify $ loopClosedForm pat val (oneName i) it bound body
simplifyClosedFormLoop _ _ _ _ = Skip

-- @unroll n merge (iv, it, i) body@ inlines iterations @i@ to @n - 1@ of
-- @body@.  Each iteration binds the merge parameters to the current
-- values and @iv@ to @i@, then adds a renamed copy of the body.  The
-- body's results become the values for the next iteration.  The final
-- merge values are returned.
unroll ::
  (BuilderOps rep) =>
  Integer ->
  [(FParam rep, SubExpRes)] ->
  (VName, IntType, Integer) ->
  Body rep ->
  RuleM rep [SubExpRes]
unroll n merge (iv, it, i) body
  | i >= n =
      pure $ map snd merge
  | otherwise = do
      iter_body <- insertStmsM $ do
        forM_ merge $ \(mergevar, SubExpRes cs mergeinit) ->
          certifying cs $ letBindNames [paramName mergevar] $ BasicOp $ SubExp mergeinit

        letBindNames [iv] $ BasicOp $ SubExp $ intConst it i

        -- Some of the sizes in the types here might be temporarily wrong
        -- until copy propagation fixes it up.
        pure body

      -- Rename so that each unrolled copy binds fresh names.
      iter_body' <- renameBody iter_body
      addStms $ bodyStms iter_body'

      let merge' = zip (map fst merge) $ bodyResult iter_body'
      unroll n merge' (iv, it, i + 1) body

-- Completely unroll a for-loop with a constant iteration count.  This is
-- only done when the count is zero or one, or when the statement carries
-- the "unroll" attribute.
simplifyKnownIterationLoop :: (BuilderOps rep) => TopDownRuleLoop rep
simplifyKnownIterationLoop _ pat aux (merge, ForLoop i it (Constant iters), body)
  | IntValue n <- iters,
    zeroIshInt n || oneIshInt n || "unroll" `inAttrs` stmAuxAttrs aux = Simplify $ do
      res <- unroll (valueIntegral n) (map (second subExpRes) merge) (i, it, 0) body
      forM_ (zip (patNames pat) res) $ \(v, SubExpRes cs se) ->
        certifying cs $ letBindNames [v] $ BasicOp $ SubExp se
simplifyKnownIterationLoop _ _ _ _ =
  Skip

topDownRules :: (BuilderOps rep) => [TopDownRule rep]
topDownRules =
  [ RuleLoop hoistLoopInvariantLoopParams,
    RuleLoop simplifyClosedFormLoop,
    RuleLoop simplifyKnownIterationLoop
  ]

bottomUpRules :: (BuilderOps rep) => [BottomUpRule rep]
bottomUpRules =
  [ RuleLoop removeRedundantLoopParams
  ]

-- | Standard loop simplification rules.
loopRules :: (BuilderOps rep) => RuleBook rep
loopRules = ruleBook topDownRules bottomUpRules
futhark-0.25.27/src/Futhark/Optimise/Simplify/Rules/Match.hs000066400000000000000000000237621475065116200236160ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-}

-- | Match simplification rules.
module Futhark.Optimise.Simplify.Rules.Match (matchRules) where

import Control.Monad
import Data.Either
import Data.List (partition, transpose, unzip4, zip5)
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.Construct
import Futhark.IR
import Futhark.Optimise.Simplify.Rule
import Futhark.Util

-- Does this case always match the scrutinees?
caseAlwaysMatches :: [SubExp] -> Case a -> Bool
caseAlwaysMatches ses = and . zipWith match ses . casePat
  where
    -- A fixed pattern value only certainly matches an identical
    -- constant; a wildcard ('Nothing') matches anything.
    match se (Just v) = se == Constant v
    match _ Nothing = True

-- Can this case never match the scrutinees?
caseNeverMatches :: [SubExp] -> Case a -> Bool
caseNeverMatches ses = or . zipWith impossible ses . casePat
  where
    -- A position rules the case out only when the scrutinee is a
    -- constant that differs from the pattern value.
    impossible (Constant v1) (Just v2) = v1 /= v2
    impossible _ _ = False

ruleMatch :: (BuilderOps rep) => TopDownRuleMatch rep
-- Remove impossible cases.
ruleMatch _ pat _ (cond, cases, defbody, ifdec)
  | (impossible, cases') <- partition (caseNeverMatches cond) cases,
    not $ null impossible =
      Simplify $ letBind pat $ Match cond cases' defbody ifdec
-- Find new default case.  Cases are tried in order, so the first case
-- that always matches is the one that gets taken.  It becomes the new
-- default, the cases before it are kept, and every case after it is
-- unreachable and dropped.
ruleMatch _ pat _ (cond, cases, _, ifdec)
  | (before, new_default : _) <- break (caseAlwaysMatches cond) cases =
      Simplify $ letBind pat $ Match cond before (caseBody new_default) ifdec
-- Remove caseless match.
ruleMatch _ pat (StmAux cs _ _) (_, [], defbody, _) = Simplify $ do
  defbody_res <- bodyBind defbody
  certifying cs $ forM_ (zip (patElems pat) defbody_res) $ \(pe, res) ->
    certifying (resCerts res) . letBind (Pat [pe]) $
      BasicOp (SubExp $ resSubExp res)
-- IMPROVE: the following two rules can be generalised to work in more
-- cases, especially when the branches have bindings, or return more
-- than one value.
--
-- if c then True else v == c || v
ruleMatch
  _
  pat
  _
  ( [cond],
    [ Case
        [Just (BoolValue True)]
        (Body _ tstms [SubExpRes tcs (Constant (BoolValue True))])
      ],
    Body _ fstms [SubExpRes fcs se],
    MatchDec ts _
    )
    | null tstms,
      null fstms,
      [Prim Bool] <- map extTypeOf ts =
        Simplify $ certifying (tcs <> fcs) $ letBind pat $ BasicOp $ BinOp LogOr cond se
-- When type(x)==bool, if c then x else y == (c && x) || (!c && y)
ruleMatch _ pat _ ([cond], [Case [Just (BoolValue True)] tb], fb, MatchDec ts _)
  | Body _ tstms [SubExpRes tcs tres] <- tb,
    Body _ fstms [SubExpRes fcs fres] <- fb,
    all (safeExp . stmExp) $ tstms <> fstms,
    all ((== Prim Bool) . extTypeOf) ts = Simplify $ do
      -- Both branches are safe to execute, so run both unconditionally.
      addStms tstms
      addStms fstms
      e <-
        eBinOp
          LogOr
          (pure $ BasicOp $ BinOp LogAnd cond tres)
          ( eBinOp
              LogAnd
              (pure $ BasicOp $ UnOp (Neg Bool) cond)
              (pure $ BasicOp $ SubExp fres)
          )
      certifying (tcs <> fcs) $ letBind pat e
-- A fallback match with a single, safe case: just use that case.
ruleMatch _ pat _ (_, [Case _ tbranch], _, MatchDec _ MatchFallback)
  | all (safeExp . stmExp) $ bodyStms tbranch = Simplify $ do
      let ses = bodyResult tbranch
      addStms $ bodyStms tbranch
      sequence_
        [ certifying cs $ letBindNames [patElemName p] $ BasicOp $ SubExp se
        | (p, SubExpRes cs se) <- zip (patElems pat) ses
        ]
-- if c then 1 else 0 == btoi c, and if c then 0 else 1 == btoi (!c).
-- Both rewrites discard the branches (including any result
-- certificates), so both only apply when no certificates are attached.
ruleMatch _ pat _ ([cond], [Case [Just (BoolValue True)] tb], fb, _)
  | Body _ _ [SubExpRes tcs (Constant (IntValue t))] <- tb,
    Body _ _ [SubExpRes fcs (Constant (IntValue f))] <- fb =
      let no_certs = tcs == mempty && fcs == mempty
       in if oneIshInt t && zeroIshInt f && no_certs
            then
              Simplify . letBind pat . BasicOp $
                ConvOp (BToI (intValueType t)) cond
            else
              if zeroIshInt t && oneIshInt f && no_certs
                then Simplify $ do
                  cond_neg <- letSubExp "cond_neg" $ BasicOp $ UnOp (Neg Bool) cond
                  letBind pat $ BasicOp $ ConvOp (BToI (intValueType t)) cond_neg
                else Skip
-- Simplify
--
--   let z = if c then x else y
--
-- to
--
--   let z = y
--
-- in the case where 'x' is a loop parameter with initial value 'y'
-- and the new value of the loop parameter is 'z'.  ('x' and 'y' can
-- be flipped.)
ruleMatch vtable (Pat [pe]) aux (_c, [Case _ tb], fb, MatchDec [_] _)
  | Body _ tstms [SubExpRes xcs x] <- tb,
    null tstms,
    Body _ fstms [SubExpRes ycs y] <- fb,
    null fstms,
    matches x y || matches y x =
      Simplify . certifying (stmAuxCerts aux <> xcs <> ycs) $
        letBind (Pat [pe]) (BasicOp $ SubExp y)
  where
    z = patElemName pe
    matches (Var x) y
      | Just (initial, res) <- ST.lookupLoopParam x vtable =
          initial == y && res == Var z
    matches _ _ = False
ruleMatch _ _ _ _ = Skip

-- | Move out results of a conditional expression whose computation is
-- either invariant to the branches (only done for results used for
-- existentials), or the same in both branches.
hoistBranchInvariant :: (BuilderOps rep) => TopDownRuleMatch rep
hoistBranchInvariant _ pat _ (cond, cases, defbody, MatchDec ret ifsort) =
  -- Each result position is classified by 'branchInvariant'.  'Left'
  -- means it was hoisted out, yielding an existential fix-up that is
  -- applied to the return type.  'Right' means it stays in the match.
  let case_reses = map (bodyResult . caseBody) cases
      defbody_res = bodyResult defbody
      (hoistings, (pes, ts, case_reses_tr, defbody_res')) =
        (fmap unzip4 . partitionEithers) . map branchInvariant $
          zip5 [0 ..] (patElems pat) ret (transpose case_reses) defbody_res
   in if null hoistings
        then Skip
        else Simplify $ do
          ctx_fixes <- sequence hoistings
          let onCase (Case vs body) case_res = Case vs $ body {bodyResult = case_res}
              cases' = zipWith onCase cases $ transpose case_reses_tr
              defbody' = defbody {bodyResult = defbody_res'}
              ret' = foldr (uncurry fixExt) ts ctx_fixes
          -- We may have to add some reshapes if we made the type
          -- less existential.
          cases'' <- mapM (traverse $ reshapeBodyResults $ map extTypeOf ret') cases'
          defbody'' <- reshapeBodyResults (map extTypeOf ret') defbody'
          letBind (Pat pes) $ Match cond cases'' defbody'' (MatchDec ret' ifsort)
  where
    -- Every name bound by a statement in any branch.
    bound_in_branches =
      namesFromList . concatMap (patNames . stmPat) $
        foldMap (bodyStms . caseBody) cases <> bodyStms defbody

    branchInvariant (i, pe, t, case_reses, defres)
      -- If just one branch has a variant result, then we give up.
      | namesIntersect bound_in_branches $ freeIn $ defres : case_reses =
          noHoisting
      -- Do all branches return the same value?
      | all ((== resSubExp defres) . resSubExp) case_reses = Left $ do
          certifying (foldMap resCerts case_reses <> resCerts defres) $
            letBindNames [patElemName pe] . BasicOp . SubExp $
              resSubExp defres
          hoisted i pe

      -- Do all branches return values that are free in the
      -- branch, and are we not the only pattern element?  The
      -- latter is to avoid infinite application of this rule.
      | not $ namesIntersect bound_in_branches $ freeIn $ defres : case_reses,
        patSize pat > 1,
        Prim _ <- patElemType pe = Left $ do
          -- Compute this result with its own, statement-free match.
          bt <- expTypesFromPat $ Pat [pe]
          letBindNames [patElemName pe]
            =<< ( Match cond
                    <$> ( zipWith Case (map casePat cases)
                            <$> mapM (resultBodyM . pure . resSubExp) case_reses
                        )
                    <*> resultBodyM [resSubExp defres]
                    <*> pure (MatchDec bt ifsort)
                )
          hoisted i pe
      | otherwise = noHoisting
      where
        noHoisting = Right (pe, t, case_reses, defres)

    -- The existential fix-up: position i of the return type now refers
    -- to the hoisted variable.
    hoisted i pe = pure (i, Var $ patElemName pe)

    -- Bind a body's statements and coerce every array result whose shape
    -- differs from the (now less existential) return type.
    reshapeBodyResults rets body = buildBody_ $ do
      ses <- bodyBind body
      let (ctx_ses, val_ses) = splitFromEnd (length rets) ses
      (ctx_ses ++) <$> zipWithM reshapeResult val_ses rets
    reshapeResult (SubExpRes cs (Var v)) t@Array {} = do
      v_t <- lookupType v
      let newshape = arrayDims $ removeExistentials t v_t
      SubExpRes cs
        <$> if newshape /= arrayDims v_t
          then letSubExp "branch_ctx_reshaped" (shapeCoerce newshape v)
          else pure $ Var v
    reshapeResult se _ = pure se

-- | Remove the return values of a branch, that are not actually used
-- after a branch.  Standard dead code removal can remove the branch
-- if *none* of the return values are used, but this rule is more
-- precise.
removeDeadBranchResult :: (BuilderOps rep) => BottomUpRuleMatch rep
removeDeadBranchResult (_, used) pat _ (cond, cases, defbody, MatchDec rettype ifsort)
  | -- Figure out which of the names in 'pat' are used...
    patused <- map keep $ patNames pat,
    -- If they are not all used, then this rule applies.
    not (and patused) = do
      -- Remove the parts of the branch-results that correspond to dead
      -- return value bindings.  Note that this leaves dead code in the
      -- branch bodies, but that will be removed later.
      let pick :: [a] -> [a]
          pick = map snd . filter fst . zip patused
          pat' = pick $ patElems pat
          rettype' = pick rettype
          -- We also need to adjust the existential references in the
          -- branch type.
          -- exts !! k is the number of kept results before position k,
          -- i.e. the new index of existential k.
          exts = scanl (+) 0 [if b then 1 else 0 | b <- patused]
          adjust = mapExt (exts !!)
      Simplify $ do
        cases' <- mapM (traverse $ onBody pick) cases
        defbody' <- onBody pick defbody
        letBind (Pat pat') $ Match cond cases' defbody' $ MatchDec (map adjust rettype') ifsort
  | otherwise = Skip
  where
    usedDirectly v = v `UT.isUsedDirectly` used
    -- A name is also kept if it occurs free in a used pattern element
    -- (presumably as a size in that element's type).
    usedIndirectly v =
      any
        (\pe -> v `nameIn` freeIn pe && usedDirectly (patElemName pe))
        (patElems pat)
    keep v = usedDirectly v || usedIndirectly v

    onBody pick (Body _ stms res) = mkBodyM stms $ pick res

-- Top-down rules of this module.
topDownRules :: (BuilderOps rep) => [TopDownRule rep]
topDownRules =
  [ RuleMatch ruleMatch,
    RuleMatch hoistBranchInvariant
  ]

-- Bottom-up rules of this module.
bottomUpRules :: (BuilderOps rep) => [BottomUpRule rep]
bottomUpRules =
  [ RuleMatch removeDeadBranchResult
  ]

-- | Simplification rules for 'Match' expressions.
matchRules :: (BuilderOps rep) => RuleBook rep
matchRules = ruleBook topDownRules bottomUpRules
futhark-0.25.27/src/Futhark/Optimise/Simplify/Rules/Simple.hs000066400000000000000000000325061475065116200240070ustar00rootroot00000000000000-- | Particularly simple simplification rules.
module Futhark.Optimise.Simplify.Rules.Simple
  ( TypeLookup,
    VarLookup,
    applySimpleRules,
  )
where

import Control.Monad
import Data.List (isSuffixOf)
import Data.List.NonEmpty qualified as NE
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR
import Futhark.Util (focusNth)

-- | A function that, given a variable name, returns its definition.
type VarLookup rep = VName -> Maybe (Exp rep, Certs)

-- | A function that, given a subexpression, returns its type.
type TypeLookup = SubExp -> Maybe Type

-- | A simple rule is a top-down rule that can be expressed as a pure
-- function.
type SimpleRule rep = VarLookup rep -> TypeLookup -> BasicOp -> Maybe (BasicOp, Certs)

-- | Is this subexpression a constant that 'oneIsh' considers to be
-- one?  Variables never qualify.
isCt1 :: SubExp -> Bool
isCt1 (Constant v) = oneIsh v
isCt1 _ = False

-- | Is this subexpression a constant that 'zeroIsh' considers to be
-- zero?  Variables never qualify.
isCt0 :: SubExp -> Bool
isCt0 (Constant v) = zeroIsh v
isCt0 _ = False

-- | Simplify comparisons:
--
-- * A comparison of an operand with itself folds to a constant.
--
-- * A comparison of two constants is evaluated with 'doCmpOp'.
--
-- * An equality test between an integer constant and the result of a
--   'BToI' conversion is rewritten in terms of the original boolean
--   (1 gives the boolean, 0 its negation, anything else 'False').
simplifyCmpOp :: SimpleRule rep
simplifyCmpOp _ _ (CmpOp cmp e1 e2)
  | e1 == e2 =
      -- NOTE(review): for floating-point operands this assumes the
      -- operand is not NaN (under IEEE 754, NaN == NaN and NaN <= NaN
      -- are both false) - confirm that this is intended.
      constRes $
        BoolValue $
          case cmp of
            CmpEq {} -> True
            CmpSlt {} -> False
            CmpUlt {} -> False
            CmpSle {} -> True
            CmpUle {} -> True
            FCmpLt {} -> False
            FCmpLe {} -> True
            CmpLlt -> False
            CmpLle -> True
simplifyCmpOp _ _ (CmpOp cmp (Constant v1) (Constant v2)) =
  constRes . BoolValue =<< doCmpOp cmp v1 v2
simplifyCmpOp look _ (CmpOp CmpEq {} (Constant (IntValue x)) (Var v))
  | Just (BasicOp (ConvOp BToI {} b), cs) <- look v =
      case valueIntegral x :: Int of
        1 -> Just (SubExp b, cs)
        0 -> Just (UnOp (Neg Bool) b, cs)
        _ -> Just (SubExp (Constant (BoolValue False)), cs)
simplifyCmpOp _ _ _ = Nothing

-- | Simplify binary operators.  This does three things:
--
-- * Folds operations whose operands are both constants.
--
-- * Reassociates chains of the same associative operator to combine
--   constants.
--
-- * Applies algebraic identities such as @x+0 = x@, @x*1 = x@,
--   @(a+b)-b = a@ and @x xor x = 0@.
--
-- When a rewrite looks through the definition of a variable, the
-- certificates of that definition are returned with the result.
simplifyBinOp :: SimpleRule rep
simplifyBinOp _ _ (BinOp op (Constant v1) (Constant v2))
  | Just res <- doBinOp op v1 v2 =
      constRes res
-- By normalisation, constants are always on the left.
--
-- x+(y+z) = (x+y)+z (where x and y are constants).
simplifyBinOp look _ (BinOp op1 (Constant x1) (Var y1))
  | associativeBinOp op1,
    Just (BasicOp (BinOp op2 (Constant x2) y2), cs) <- look y1,
    op1 == op2,
    Just res <- doBinOp op1 x1 x2 =
      Just (BinOp op1 (Constant res) y2, cs)
simplifyBinOp look _ (BinOp (Add it ovf) e1 e2)
  | isCt0 e1 = resIsSubExp e2
  | isCt0 e2 = resIsSubExp e1
  -- x+(y-x) => y
  | Var v2 <- e2,
    Just (BasicOp (BinOp Sub {} e2_a e2_b), cs) <- look v2,
    e2_b == e1 =
      Just (SubExp e2_a, cs)
  -- x+(-1*y) => x-y
  | Var v2 <- e2,
    Just (BasicOp (BinOp Mul {} (Constant (IntValue x)) e3), cs) <- look v2,
    valueIntegral x == (-1 :: Int) =
      Just (BinOp (Sub it ovf) e1 e3, cs)
simplifyBinOp _ _ (BinOp FAdd {} e1 e2)
  | isCt0 e1 = resIsSubExp e2
  | isCt0 e2 = resIsSubExp e1
simplifyBinOp look _ (BinOp sub@(Sub t _) e1 e2)
  | isCt0 e2 = resIsSubExp e1
  | e1 == e2 = Just (SubExp (intConst t 0), mempty)
  --
  -- Below are cases for simplifying (a+b)-b and permutations.
  --
  -- (e1_a+e1_b)-e1_a == e1_b
  | Var v1 <- e1,
    Just (BasicOp (BinOp Add {} e1_a e1_b), cs) <- look v1,
    e1_a == e2 =
      Just (SubExp e1_b, cs)
  -- (e1_a+e1_b)-e1_b == e1_a
  | Var v1 <- e1,
    Just (BasicOp (BinOp Add {} e1_a e1_b), cs) <- look v1,
    e1_b == e2 =
      Just (SubExp e1_a, cs)
  -- e2_a-(e2_a+e2_b) == 0-e2_b
  | Var v2 <- e2,
    Just (BasicOp (BinOp Add {} e2_a e2_b), cs) <- look v2,
    e2_a == e1 =
      Just (BinOp sub (intConst t 0) e2_b, cs)
  -- e2_b-(e2_a+e2_b) == 0-e2_a
  | Var v2 <- e2,
    Just (BasicOp (BinOp Add {} e2_a e2_b), cs) <- look v2,
    e2_b == e1 =
      Just (BinOp sub (intConst t 0) e2_a, cs)
simplifyBinOp _ _ (BinOp FSub {} e1 e2)
  | isCt0 e2 = resIsSubExp e1
simplifyBinOp _ _ (BinOp Mul {} e1 e2)
  | isCt0 e1 = resIsSubExp e1
  | isCt0 e2 = resIsSubExp e2
  | isCt1 e1 = resIsSubExp e2
  | isCt1 e2 = resIsSubExp e1
simplifyBinOp _ _ (BinOp FMul {} e1 e2)
  | isCt1 e1 = resIsSubExp e2
  | isCt1 e2 = resIsSubExp e1
simplifyBinOp look _ (BinOp (SMod t _) e1 e2)
  | isCt1 e2 = constRes $ IntValue $ intValue t (0 :: Int)
  | e1 == e2 = constRes $ IntValue $ intValue t (0 :: Int)
  -- (x mod y) mod y == x mod y
  | Var v1 <- e1,
    Just (BasicOp (BinOp SMod {} _ e4), v1_cs) <- look v1,
    e4 == e2 =
      Just (SubExp e1, v1_cs)
simplifyBinOp _ _ (BinOp SDiv {} e1 e2)
  | isCt0 e1 = resIsSubExp e1
  | isCt1 e2 = resIsSubExp e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp SDivUp {} e1 e2)
  | isCt0 e1 = resIsSubExp e1
  | isCt1 e2 = resIsSubExp e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp FDiv {} e1 e2)
  | isCt0 e1 = resIsSubExp e1
  | isCt1 e2 = resIsSubExp e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp (SRem t _) e1 e2)
  | isCt1 e2 = constRes $ IntValue $ intValue t (0 :: Int)
  -- x rem x is 0 (not 1); this matches the SMod rule above.
  | e1 == e2 = constRes $ IntValue $ intValue t (0 :: Int)
simplifyBinOp _ _ (BinOp SQuot {} e1 e2)
  | isCt1 e2 = resIsSubExp e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp (Pow t) e1 e2)
  -- 2**y == 1 << y
  | e1 == intConst t 2 = Just (BinOp (Shl t) (intConst t 1) e2, mempty)
simplifyBinOp _ _ (BinOp (FPow t) e1 e2)
  | isCt0 e2 = resIsSubExp $ floatConst t 1
  | isCt0 e1 || isCt1 e1 || isCt1 e2 = resIsSubExp e1
simplifyBinOp _ _ (BinOp (Shl t) e1 e2)
  | isCt0 e2 = resIsSubExp e1
  | isCt0 e1 = resIsSubExp $ intConst t 0
simplifyBinOp _ _ (BinOp AShr {} e1 e2)
  | isCt0 e2 = resIsSubExp e1
simplifyBinOp _ _ (BinOp (And t) e1 e2)
  | isCt0 e1 = resIsSubExp $ intConst t 0
  | isCt0 e2 = resIsSubExp $ intConst t 0
  | e1 == e2 = resIsSubExp e1
simplifyBinOp _ _ (BinOp Or {} e1 e2)
  | isCt0 e1 = resIsSubExp e2
  | isCt0 e2 = resIsSubExp e1
  | e1 == e2 = resIsSubExp e1
simplifyBinOp _ _ (BinOp (Xor t) e1 e2)
  | isCt0 e1 = resIsSubExp e2
  | isCt0 e2 = resIsSubExp e1
  | e1 == e2 = resIsSubExp $ intConst t 0
simplifyBinOp defOf _ (BinOp LogAnd e1 e2)
  | isCt0 e1 = constRes $ BoolValue False
  | isCt0 e2 = constRes $ BoolValue False
  | isCt1 e1 = resIsSubExp e2
  | isCt1 e2 = resIsSubExp e1
  -- (not x) && x == false
  | Var v <- e1,
    Just (BasicOp (UnOp (Neg Bool) e1'), v_cs) <- defOf v,
    e1' == e2 =
      Just (SubExp $ Constant $ BoolValue False, v_cs)
  -- x && (not x) == false
  | Var v <- e2,
    Just (BasicOp (UnOp (Neg Bool) e2'), v_cs) <- defOf v,
    e2' == e1 =
      Just (SubExp $ Constant $ BoolValue False, v_cs)
simplifyBinOp defOf _ (BinOp LogOr e1 e2)
  | isCt0 e1 = resIsSubExp e2
  | isCt0 e2 = resIsSubExp e1
  | isCt1 e1 = constRes $ BoolValue True
  | isCt1 e2 = constRes $ BoolValue True
  -- (not x) || x == true
  | Var v <- e1,
    Just (BasicOp (UnOp (Neg Bool) e1'), v_cs) <- defOf v,
    e1' == e2 =
      Just (SubExp $ Constant $ BoolValue True, v_cs)
  -- x || (not x) == true
  | Var v <- e2,
    Just (BasicOp (UnOp (Neg Bool) e2'), v_cs) <- defOf v,
    e2' == e1 =
      Just (SubExp $ Constant $ BoolValue True, v_cs)
simplifyBinOp defOf _ (BinOp (SMax it) e1 e2)
  | e1 == e2 = resIsSubExp e1
  -- max(max(a, b), a) == max(b, a), and the symmetric variants.
  | Var v1 <- e1,
    Just (BasicOp (BinOp (SMax _) e1_1 e1_2), v1_cs) <- defOf v1,
    e1_1 == e2 =
      Just (BinOp (SMax it) e1_2 e2, v1_cs)
  | Var v1 <- e1,
    Just (BasicOp (BinOp (SMax _) e1_1 e1_2), v1_cs) <- defOf v1,
    e1_2 == e2 =
      Just (BinOp (SMax it) e1_1 e2, v1_cs)
  | Var v2 <- e2,
    Just (BasicOp (BinOp (SMax _) e2_1 e2_2), v2_cs) <- defOf v2,
    e2_1 == e1 =
      Just (BinOp (SMax it) e2_2 e1, v2_cs)
  | Var v2 <- e2,
    Just (BasicOp (BinOp (SMax _) e2_1 e2_2), v2_cs) <- defOf v2,
    e2_2 == e1 =
      Just (BinOp (SMax it) e2_1 e1, v2_cs)
simplifyBinOp _ _ _ = Nothing

-- | A rule result that is the given constant, with no certificates.
constRes :: PrimValue -> Maybe (BasicOp, Certs)
constRes = Just . (,mempty) . SubExp . Constant

-- | A rule result that is the given subexpression, with no certificates.
resIsSubExp :: SubExp -> Maybe (BasicOp, Certs)
resIsSubExp = Just . (,mempty) . SubExp

-- | Simplify unary operators: fold constants with 'doUnOp' and remove
-- double boolean negation.
simplifyUnOp :: SimpleRule rep
simplifyUnOp _ _ (UnOp op (Constant v)) =
  constRes =<< doUnOp op v
simplifyUnOp defOf _ (UnOp (Neg Bool) (Var v))
  | Just (BasicOp (UnOp (Neg Bool) v2), v_cs) <- defOf v =
      Just (SubExp v2, v_cs)
simplifyUnOp _ _ _ = Nothing

-- | Simplify conversions:
--
-- * Fold constants with 'doConvOp'.
--
-- * Remove conversions whose source and target types coincide.
--
-- * Compose a conversion applied to the result of an earlier
--   extension.  This is only done when the earlier conversion did not
--   narrow its operand, i.e. the intermediate type @t2@ is at least as
--   wide as the original type @t3@.
simplifyConvOp :: SimpleRule rep
simplifyConvOp _ _ (ConvOp op (Constant v)) =
  constRes =<< doConvOp op v
simplifyConvOp _ _ (ConvOp op se)
  | (from, to) <- convOpType op,
    from == to =
      resIsSubExp se
simplifyConvOp lookupVar _ (ConvOp (SExt t2 t1) (Var v))
  | Just (BasicOp (ConvOp (SExt t3 _) se), v_cs) <- lookupVar v,
    t2 >= t3 =
      Just (ConvOp (SExt t3 t1) se, v_cs)
simplifyConvOp lookupVar _ (ConvOp (ZExt t2 t1) (Var v))
  | Just (BasicOp (ConvOp (ZExt t3 _) se), v_cs) <- lookupVar v,
    t2 >= t3 =
      Just (ConvOp (ZExt t3 t1) se, v_cs)
simplifyConvOp lookupVar _ (ConvOp (SIToFP t2 t1) (Var v))
  | Just (BasicOp (ConvOp (SExt t3 _) se), v_cs) <- lookupVar v,
    t2 >= t3 =
      Just (ConvOp (SIToFP t3 t1) se, v_cs)
simplifyConvOp lookupVar _ (ConvOp (UIToFP t2 t1) (Var v))
  | Just (BasicOp (ConvOp (ZExt t3 _) se), v_cs) <- lookupVar v,
    t2 >= t3 =
      Just (ConvOp (UIToFP t3 t1) se, v_cs)
simplifyConvOp lookupVar _ (ConvOp (FPConv t2 t1) (Var v))
  | Just (BasicOp (ConvOp (FPConv t3 _) se), v_cs) <- lookupVar v,
    t2 >= t3 =
      Just (ConvOp (FPConv t3 t1) se, v_cs)
simplifyConvOp _ _ _ = Nothing

-- If expression is true then just replace assertion.
simplifyAssert :: SimpleRule rep
simplifyAssert _ _ (Assert (Constant (BoolValue True)) _ _) =
  constRes UnitValue
simplifyAssert _ _ _ = Nothing

-- No-op reshape.
-- | A reshape to the shape the array already has is the array itself.
simplifyIdentityReshape :: SimpleRule rep
simplifyIdentityReshape _ seType (Reshape _ newshape v) = do
  v_t <- seType $ Var v
  guard $ newshape == arrayShape v_t
  resIsSubExp $ Var v
simplifyIdentityReshape _ _ _ = Nothing

-- | A reshape of a reshape becomes a single reshape of the innermost
-- array.  The two reshape kinds are combined with 'max'.
simplifyReshapeReshape :: SimpleRule rep
simplifyReshapeReshape defOf _ (Reshape outer_kind newshape v) = do
  (BasicOp (Reshape inner_kind _ src), v_cs) <- defOf v
  Just (Reshape (max outer_kind inner_kind) newshape src, v_cs)
simplifyReshapeReshape _ _ _ = Nothing

-- | A reshape of a scratch array becomes a scratch array that directly
-- has the new shape.
simplifyReshapeScratch :: SimpleRule rep
simplifyReshapeScratch defOf _ (Reshape _ newshape v) = do
  (BasicOp (Scratch elem_t _), v_cs) <- defOf v
  Just (Scratch elem_t $ shapeDims newshape, v_cs)
simplifyReshapeScratch _ _ _ = Nothing

-- | A reshape of a replicate becomes a replicate with new outer
-- dimensions, provided the shape of the replicated value is a suffix of
-- the new shape.
simplifyReshapeReplicate :: SimpleRule rep
simplifyReshapeReplicate defOf seType (Reshape _ newshape v) = do
  (BasicOp (Replicate _ se), v_cs) <- defOf v
  se_shape <- arrayShape <$> seType se
  guard $ shapeDims se_shape `isSuffixOf` shapeDims newshape
  let outer_dims = take (length newshape - shapeRank se_shape) (shapeDims newshape)
  Just (Replicate (Shape outer_dims) se, v_cs)
simplifyReshapeReplicate _ _ _ = Nothing

-- | A reshape of an iota to a single dimension becomes an iota of that
-- length.
simplifyReshapeIota :: SimpleRule rep
simplifyReshapeIota defOf _ (Reshape _ newshape v) = do
  (BasicOp (Iota _ offset stride it), v_cs) <- defOf v
  case shapeDims newshape of
    [n] -> Just (Iota n offset stride it, v_cs)
    _ -> Nothing
simplifyReshapeIota _ _ _ = Nothing

-- | A size coercion of a concatenation that only changes the size of
-- the concatenated dimension becomes a concatenation of that new size.
-- All other dimensions must match those of the first concatenated
-- array.
simplifyReshapeConcat :: SimpleRule rep
simplifyReshapeConcat defOf seType (Reshape ReshapeCoerce newshape v)
  | Just (BasicOp (Concat d arrs _), v_cs) <- defOf v,
    Just (bef, w', aft) <- focusNth d $ shapeDims newshape,
    Just first_t <- seType $ Var $ NE.head arrs,
    Just (arr_bef, _, arr_aft) <- focusNth d $ arrayDims first_t,
    arr_bef == bef,
    arr_aft == aft =
      Just (Concat d arrs w', v_cs)
simplifyReshapeConcat _ _ _ = Nothing

-- | Replace the sizes of the 'DimSlice' components of a slice, in order,
-- with the given dimensions.  'DimFix' components are kept as they are.
-- The result ends as soon as a 'DimSlice' has no dimension left to pair
-- with, or the slice runs out.
reshapeSlice :: [DimIndex d] -> [d] -> [DimIndex d]
reshapeSlice (DimFix i : rest) dims =
  DimFix i : reshapeSlice rest dims
reshapeSlice (DimSlice offset _ stride : rest) (dim : dims) =
  DimSlice offset dim stride : reshapeSlice rest dims
reshapeSlice _ _ = []

-- If we are size-coercing a slice, then we might as well just use a
-- different slice instead.
simplifyReshapeIndex :: SimpleRule rep
simplifyReshapeIndex defOf _ (Reshape ReshapeCoerce newshape v) = do
  (BasicOp (Index src slice), v_cs) <- defOf v
  let resized = Slice $ reshapeSlice (unSlice slice) $ shapeDims newshape
  guard $ resized /= slice
  Just (Index src resized, v_cs)
simplifyReshapeIndex _ _ _ = Nothing

-- If we are updating a slice with the result of a size coercion, we
-- instead use the original array and update the slice dimensions.
simplifyUpdateReshape :: SimpleRule rep
simplifyUpdateReshape defOf seType (Update safety dest slice (Var v)) = do
  (BasicOp (Reshape ReshapeCoerce _ src), v_cs) <- defOf v
  src_dims <- arrayDims <$> seType (Var src)
  let resized = Slice $ reshapeSlice (unSlice slice) src_dims
  guard $ resized /= slice
  Just (Update safety dest resized $ Var src, v_cs)
simplifyUpdateReshape _ _ _ = Nothing

-- | If we are replicating a scratch array (possibly indirectly), just
-- turn it into a scratch by itself.  "Indirectly" means through any
-- chain of 'Rearrange' and 'Reshape'; the certificates of every
-- definition along the chain are collected.
repScratchToScratch :: SimpleRule rep
repScratchToScratch defOf seType (Replicate shape (Var src)) = do
  src_t <- seType $ Var src
  chain_cs <- scratchCerts src
  Just (Scratch (elemType src_t) (shapeDims shape <> arrayDims src_t), chain_cs)
  where
    scratchCerts name =
      case defOf name of
        Just (BasicOp Scratch {}, cs) -> Just cs
        Just (BasicOp (Rearrange _ name'), cs) -> (cs <>) <$> scratchCerts name'
        Just (BasicOp (Reshape _ _ name'), cs) -> (cs <>) <$> scratchCerts name'
        _ -> Nothing
repScratchToScratch _ _ _ = Nothing

-- | Every simple rule.  They are tried in this order, and the first
-- one that produces a result is used.
simpleRules :: [SimpleRule rep]
simpleRules =
  [ simplifyBinOp,
    simplifyCmpOp,
    simplifyUnOp,
    simplifyConvOp,
    simplifyAssert,
    repScratchToScratch,
    simplifyIdentityReshape,
    simplifyReshapeReshape,
    simplifyReshapeScratch,
    simplifyReshapeReplicate,
    simplifyReshapeIota,
    simplifyReshapeConcat,
    simplifyReshapeIndex,
    simplifyUpdateReshape
  ]

-- | Try to simplify the given t'BasicOp', returning a new t'BasicOp'
-- and certificates that it must depend on.
{-# NOINLINE applySimpleRules #-} applySimpleRules :: VarLookup rep -> TypeLookup -> BasicOp -> Maybe (BasicOp, Certs) applySimpleRules defOf seType op = msum [rule defOf seType op | rule <- simpleRules] futhark-0.25.27/src/Futhark/Optimise/Sink.hs000066400000000000000000000223601475065116200205710ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | "Sinking" is conceptually the opposite of hoisting. The idea is -- to take code that looks like this: -- -- @ -- x = xs[i] -- y = ys[i] -- if x != 0 then { -- y -- } else { -- 0 -- } -- @ -- -- and turn it into -- -- @ -- x = xs[i] -- if x != 0 then { -- y = ys[i] -- y -- } else { -- 0 -- } -- @ -- -- The idea is to delay loads from memory until (if) they are actually -- needed. Code patterns like the above is particularly common in -- code that makes use of pattern matching on sum types. -- -- We are currently quite conservative about when we do this. In -- particular, if any consumption is going on in a body, we don't do -- anything. This is far too conservative. Also, we are careful -- never to duplicate work. -- -- This pass redundantly computes free-variable information a lot. If -- you ever see this pass as being a compilation speed bottleneck, -- start by caching that a bit. -- -- This pass is defined on post-SOACS representations. This is not -- because we do anything GPU-specific here, but simply because more -- explicit indexing is going on after SOACs are gone. 
module Futhark.Optimise.Sink (sinkGPU, sinkMC) where import Control.Monad.State import Data.Bifunctor import Data.List (foldl') import Data.Map qualified as M import Futhark.Analysis.Alias qualified as Alias import Futhark.Analysis.SymbolTable qualified as ST import Futhark.Builder.Class import Futhark.IR.Aliases import Futhark.IR.GPU import Futhark.IR.MC import Futhark.Pass type SymbolTable rep = ST.SymbolTable rep type Sinking rep = M.Map VName (Stm rep) type Sunk = Names type Sinker rep a = SymbolTable rep -> Sinking rep -> a -> (a, Sunk) type Constraints rep = ( ASTRep rep, Aliased rep, Buildable rep, ST.IndexOp (Op rep) ) -- | Given a statement, compute how often each of its free variables -- are used. Not accurate: what we care about are only 1, and greater -- than 1. multiplicity :: (Constraints rep) => Stm rep -> M.Map VName Int multiplicity stm = case stmExp stm of Match cond cases defbody _ -> foldl' comb mempty $ free 1 cond : free 1 defbody : map (free 1 . caseBody) cases Op {} -> free 2 stm Loop {} -> free 2 stm WithAcc {} -> free 2 stm _ -> free 1 stm where free k x = M.fromList $ map (,k) $ namesToList $ freeIn x comb = M.unionWith (+) optimiseBranch :: (Constraints rep) => Sinker rep (Op rep) -> Sinker rep (Body rep) optimiseBranch onOp vtable sinking (Body dec stms res) = let (stms', stms_sunk) = optimiseStms onOp vtable sinking' (sunk_stms <> stms) $ freeIn res in ( Body dec stms' res, sunk <> stms_sunk ) where free_in_stms = freeIn stms <> freeIn res (sinking_here, sinking') = M.partitionWithKey sunkHere sinking sunk_stms = stmsFromList $ M.elems sinking_here sunkHere v stm = v `nameIn` free_in_stms && all (`ST.available` vtable) (namesToList (freeIn stm)) sunk = namesFromList $ foldMap (patNames . 
stmPat) sunk_stms optimiseLoop :: (Constraints rep) => Sinker rep (Op rep) -> Sinker rep ([(FParam rep, SubExp)], LoopForm, Body rep) optimiseLoop onOp vtable sinking (merge, form, body0) = let (body1, sunk) = optimiseBody onOp vtable' sinking body0 in ((merge, form, body1), sunk) where (params, _) = unzip merge scope = case form of WhileLoop {} -> scopeOfFParams params ForLoop i it _ -> M.insert i (IndexName it) $ scopeOfFParams params vtable' = ST.fromScope scope <> vtable optimiseStms :: (Constraints rep) => Sinker rep (Op rep) -> SymbolTable rep -> Sinking rep -> Stms rep -> Names -> (Stms rep, Sunk) optimiseStms onOp init_vtable init_sinking all_stms free_in_res = let (all_stms', sunk) = optimiseStms' init_vtable init_sinking $ stmsToList all_stms in (stmsFromList all_stms', sunk) where multiplicities = foldl' (M.unionWith (+)) (M.fromList (map (,1) (namesToList free_in_res))) (map multiplicity $ stmsToList all_stms) optimiseStms' _ _ [] = ([], mempty) optimiseStms' vtable sinking (stm : stms) | BasicOp Index {} <- stmExp stm, [pe] <- patElems (stmPat stm), primType $ patElemType pe, maybe True (== 1) $ M.lookup (patElemName pe) multiplicities = let (stms', sunk) = optimiseStms' vtable' (M.insert (patElemName pe) stm sinking) stms in if patElemName pe `nameIn` sunk then (stms', sunk) else (stm : stms', sunk) | Match cond cases defbody ret <- stmExp stm = let onCase (Case vs body) = let (body', body_sunk) = optimiseBranch onOp vtable sinking body in (Case vs body', body_sunk) (cases', cases_sunk) = unzip $ map onCase cases (defbody', defbody_sunk) = optimiseBranch onOp vtable sinking defbody (stms', sunk) = optimiseStms' vtable' sinking stms in ( stm {stmExp = Match cond cases' defbody' ret} : stms', mconcat cases_sunk <> defbody_sunk <> sunk ) | Loop merge lform body <- stmExp stm = let comps = (merge, lform, body) (comps', loop_sunk) = optimiseLoop onOp vtable sinking comps (merge', _, body') = comps' (stms', stms_sunk) = optimiseStms' vtable' sinking stms in 
( stm {stmExp = Loop merge' lform body'} : stms', stms_sunk <> loop_sunk ) | Op op <- stmExp stm = let (op', op_sunk) = onOp vtable sinking op (stms', stms_sunk) = optimiseStms' vtable' sinking stms in ( stm {stmExp = Op op'} : stms', stms_sunk <> op_sunk ) | otherwise = let (stms', stms_sunk) = optimiseStms' vtable' sinking stms (e', stm_sunk) = runState (mapExpM mapper (stmExp stm)) mempty in ( stm {stmExp = e'} : stms', stm_sunk <> stms_sunk ) where vtable' = ST.insertStm stm vtable mapper = identityMapper { mapOnBody = \scope body -> do let (body', sunk) = optimiseBody onOp (ST.fromScope scope <> vtable) sinking body modify (<> sunk) pure body' } optimiseBody :: (Constraints rep) => Sinker rep (Op rep) -> Sinker rep (Body rep) optimiseBody onOp vtable sinking (Body attr stms res) = let (stms', sunk) = optimiseStms onOp vtable sinking stms $ freeIn res in (Body attr stms' res, sunk) optimiseKernelBody :: (Constraints rep) => Sinker rep (Op rep) -> Sinker rep (KernelBody rep) optimiseKernelBody onOp vtable sinking (KernelBody attr stms res) = let (stms', sunk) = optimiseStms onOp vtable sinking stms $ freeIn res in (KernelBody attr stms' res, sunk) optimiseSegOp :: (Constraints rep) => Sinker rep (Op rep) -> Sinker rep (SegOp lvl rep) optimiseSegOp onOp vtable sinking op = let scope = scopeOfSegSpace $ segSpace op in runState (mapSegOpM (opMapper scope) op) mempty where opMapper scope = identitySegOpMapper { mapOnSegOpLambda = \lam -> do let (body, sunk) = optimiseBody onOp op_vtable sinking $ lambdaBody lam modify (<> sunk) pure lam {lambdaBody = body}, mapOnSegOpBody = \body -> do let (body', sunk) = optimiseKernelBody onOp op_vtable sinking body modify (<> sunk) pure body' } where op_vtable = ST.fromScope scope <> vtable type SinkRep rep = Aliases rep sink :: ( Buildable rep, AliasableRep rep, ST.IndexOp (Op (Aliases rep)) ) => Sinker (SinkRep rep) (Op (SinkRep rep)) -> Pass rep rep sink onOp = Pass "sink" "move memory loads closer to their uses" $ fmap 
removeProgAliases . intraproceduralTransformationWithConsts onConsts onFun . Alias.aliasAnalysis where onFun _ fd = do let vtable = ST.insertFParams (funDefParams fd) mempty (body, _) = optimiseBody onOp vtable mempty $ funDefBody fd pure fd {funDefBody = body} onConsts consts = pure $ fst $ optimiseStms onOp mempty mempty consts $ namesFromList $ M.keys $ scopeOf consts -- | Sinking in GPU kernels. sinkGPU :: Pass GPU GPU sinkGPU = sink onHostOp where onHostOp :: Sinker (SinkRep GPU) (Op (SinkRep GPU)) onHostOp vtable sinking (SegOp op) = first SegOp $ optimiseSegOp onHostOp vtable sinking op onHostOp vtable sinking (GPUBody types body) = first (GPUBody types) $ optimiseBody onHostOp vtable sinking body onHostOp _ _ op = (op, mempty) -- | Sinking for multicore. sinkMC :: Pass MC MC sinkMC = sink onHostOp where onHostOp :: Sinker (SinkRep MC) (Op (SinkRep MC)) onHostOp vtable sinking (ParOp par_op op) = let (par_op', par_sunk) = maybe (Nothing, mempty) (first Just . optimiseSegOp onHostOp vtable sinking) par_op (op', sunk) = optimiseSegOp onHostOp vtable sinking op in (ParOp par_op' op', par_sunk <> sunk) onHostOp _ _ op = (op, mempty) futhark-0.25.27/src/Futhark/Optimise/TileLoops.hs000066400000000000000000001240061475065116200215770ustar00rootroot00000000000000{-# LANGUAGE LambdaCase #-} {-# LANGUAGE TypeFamilies #-} -- | Perform a restricted form of loop tiling within SegMaps. We only -- tile primitive types, to avoid excessive shared memory use. 
module Futhark.Optimise.TileLoops (tileLoops) where import Control.Monad import Control.Monad.Reader import Control.Monad.State import Data.Map.Strict qualified as M import Data.Maybe (mapMaybe) import Data.Sequence qualified as Seq import Futhark.Analysis.Alias qualified as Alias import Futhark.IR.GPU import Futhark.IR.Prop.Aliases (consumedInStm) import Futhark.MonadFreshNames import Futhark.Optimise.BlkRegTiling import Futhark.Optimise.TileLoops.Shared import Futhark.Pass import Futhark.Tools import Futhark.Transform.Rename import Prelude hiding (quot) -- | The pass definition. tileLoops :: Pass GPU GPU tileLoops = Pass "tile loops" "Tile stream loops inside kernels" $ intraproceduralTransformation onStms where onStms scope stms = modifyNameSource $ runState $ runReaderT (optimiseStms (M.empty, M.empty) stms) scope optimiseBody :: Env -> Body GPU -> TileM (Body GPU) optimiseBody env (Body () stms res) = Body () <$> optimiseStms env stms <*> pure res optimiseStms :: Env -> Stms GPU -> TileM (Stms GPU) optimiseStms env stms = localScope (scopeOf stms) $ do (_, stms') <- foldM foldfun (env, mempty) $ stmsToList stms pure stms' where foldfun :: (Env, Stms GPU) -> Stm GPU -> TileM (Env, Stms GPU) foldfun (e, ss) s = do (e', s') <- optimiseStm e s pure (e', ss <> s') optimiseStm :: Env -> Stm GPU -> TileM (Env, Stms GPU) optimiseStm env stm@(Let pat aux (Op (SegOp (SegMap lvl@SegThread {} space ts kbody)))) = do res3dtiling <- localScope (scopeOfSegSpace space) $ doRegTiling3D stm stms' <- case res3dtiling of Just (extra_stms, stmt') -> pure (extra_stms <> oneStm stmt') Nothing -> do blkRegTiling_res <- mmBlkRegTiling env stm case blkRegTiling_res of Just (extra_stms, stmt') -> pure (extra_stms <> oneStm stmt') Nothing -> localScope (scopeOfSegSpace space) $ do (host_stms, (lvl', space', kbody')) <- tileInKernelBody mempty initial_variance lvl space ts kbody pure $ host_stms <> oneStm (Let pat aux $ Op $ SegOp $ SegMap lvl' space' ts kbody') pure (env, stms') where 
initial_variance = M.map mempty $ scopeOfSegSpace space optimiseStm env (Let pat aux e) = do env' <- changeEnv env (head $ patNames pat) e e' <- mapExpM (optimise env') e pure (env', oneStm $ Let pat aux e') where optimise env' = identityMapper {mapOnBody = \scope -> localScope scope . optimiseBody env'} tileInKernelBody :: Names -> VarianceTable -> SegLevel -> SegSpace -> [Type] -> KernelBody GPU -> TileM (Stms GPU, (SegLevel, SegSpace, KernelBody GPU)) tileInKernelBody branch_variant initial_variance lvl initial_kspace ts kbody | Just kbody_res <- mapM isSimpleResult $ kernelBodyResult kbody = do maybe_tiled <- tileInBody branch_variant initial_variance lvl initial_kspace ts $ Body () (kernelBodyStms kbody) kbody_res case maybe_tiled of Just (host_stms, tiling, tiledBody) -> do (res', stms') <- runBuilder $ mapM (tilingTileReturns tiling) =<< tiledBody mempty mempty pure ( host_stms, ( tilingLevel tiling, tilingSpace tiling, KernelBody () stms' res' ) ) Nothing -> pure (mempty, (lvl, initial_kspace, kbody)) | otherwise = pure (mempty, (lvl, initial_kspace, kbody)) where isSimpleResult (Returns _ cs se) = Just $ SubExpRes cs se isSimpleResult _ = Nothing tileInBody :: Names -> VarianceTable -> SegLevel -> SegSpace -> [Type] -> Body GPU -> TileM (Maybe (Stms GPU, Tiling, TiledBody)) tileInBody branch_variant initial_variance initial_lvl initial_space res_ts (Body () initial_kstms stms_res) = descend mempty $ stmsToList initial_kstms where variance = varianceInStms initial_variance initial_kstms descend _ [] = pure Nothing descend prestms (stm_to_tile : poststms) -- 2D tiling of redomap. 
| (gtids, kdims) <- unzip $ unSegSpace initial_space, Just (w, arrs, form) <- tileable stm_to_tile, Just inputs <- mapM (invariantToOneOfTwoInnerDims branch_variant variance gtids) arrs, not $ null $ tiledInputs inputs, gtid_y : gtid_x : top_gtids_rev <- reverse gtids, kdim_y : kdim_x : top_kdims_rev <- reverse kdims, Just (prestms', poststms') <- preludeToPostlude variance prestms stm_to_tile (stmsFromList poststms), used <- freeIn stm_to_tile <> freeIn poststms' <> freeIn stms_res = Just . injectPrelude initial_space variance prestms' used <$> tileGeneric (tiling2d $ reverse $ zip top_gtids_rev top_kdims_rev) res_ts (stmPat stm_to_tile) (gtid_x, gtid_y) (kdim_x, kdim_y) w form inputs poststms' stms_res -- 1D tiling of redomap. | (gtid, kdim) : top_space_rev <- reverse $ unSegSpace initial_space, Just (w, arrs, form) <- tileable stm_to_tile, inputs <- map (is1DTileable gtid variance) arrs, not $ null $ tiledInputs inputs, gtid `notNameIn` branch_variant, Just (prestms', poststms') <- preludeToPostlude variance prestms stm_to_tile (stmsFromList poststms), used <- freeIn stm_to_tile <> freeIn poststms' <> freeIn stms_res = Just . injectPrelude initial_space variance prestms' used <$> tileGeneric (tiling1d $ reverse top_space_rev) res_ts (stmPat stm_to_tile) gtid kdim w form inputs poststms' stms_res -- Tiling inside for-loop. | Loop merge (ForLoop i it bound) loopbody <- stmExp stm_to_tile, not $ any ((`nameIn` freeIn merge) . paramName . 
fst) merge, Just (prestms', poststms') <- preludeToPostlude variance prestms stm_to_tile (stmsFromList poststms) = do let branch_variant' = branch_variant <> mconcat ( map (flip (M.findWithDefault mempty) variance) (namesToList (freeIn bound)) ) merge_params = map fst merge maybe_tiled <- localScope (M.insert i (IndexName it) $ scopeOfFParams merge_params) $ tileInBody branch_variant' variance initial_lvl initial_space (map paramType merge_params) $ mkBody (bodyStms loopbody) (bodyResult loopbody) case maybe_tiled of Nothing -> next Just tiled -> Just <$> tileLoop initial_space variance prestms' (freeIn loopbody <> freeIn merge) tiled res_ts (stmPat stm_to_tile) (stmAux stm_to_tile) merge i it bound poststms' stms_res | otherwise = next where next = localScope (scopeOf stm_to_tile) $ descend (prestms <> oneStm stm_to_tile) poststms -- | Move statements from prelude to postlude if they are not used in -- the tiled statement anyway. Also, fail if the provided Stm uses -- anything from the resulting prelude whose size is not free in the -- prelude. preludeToPostlude :: VarianceTable -> Stms GPU -> Stm GPU -> Stms GPU -> Maybe (Stms GPU, Stms GPU) preludeToPostlude variance prelude stm_to_tile postlude = do let prelude_sizes = freeIn $ foldMap (patTypes . stmPat) prelude_used prelude_bound = namesFromList $ foldMap (patNames . stmPat) prelude_used guard $ not $ prelude_sizes `namesIntersect` prelude_bound Just (prelude_used, prelude_not_used <> postlude) where used_in_tiled = freeIn stm_to_tile used_in_stm_variant = (used_in_tiled <>) $ mconcat $ map (flip (M.findWithDefault mempty) variance) $ namesToList used_in_tiled used stm = any (`nameIn` used_in_stm_variant) $ patNames $ stmPat stm (prelude_used, prelude_not_used) = Seq.partition used prelude -- | Partition prelude statements preceding a tiled loop (or something -- containing a tiled loop) into three categories: -- -- 1) Group-level statements that are invariant to the threads in the group. 
-- -- 2) Thread-variant statements that should be computed once with a segmap_thread_scalar. -- -- 3) Thread-variant statements that should be recomputed whenever -- they are needed. -- -- The third category duplicates computation, so we only want to do it -- when absolutely necessary. Currently, this is necessary for -- results that are views of an array (slicing, rotate, etc) and which -- results are used after the prelude, because these cannot be -- efficiently represented by a scalar segmap (they'll be manifested -- in memory). To avoid unnecessarily moving computation from -- category 2 to category 3 simply because they depend on a category 3 -- result, everything in category 3 is also in category 2. This is -- efficient only when category 3 contains exclusively "free" or at -- least very cheap expressions (e.g. index space transformations). partitionPrelude :: VarianceTable -> Stms GPU -> Names -> Names -> (Stms GPU, Stms GPU, Stms GPU) partitionPrelude variance prestms private used_after = (invariant_prestms, variant_prestms, recomputed_variant_prestms) where invariantTo names stm = case patNames (stmPat stm) of [] -> True -- Does not matter. v : _ -> all (`notNameIn` names) (namesToList $ M.findWithDefault mempty v variance) consumed_in_prestms = foldMap consumedInStm $ fst $ Alias.analyseStms mempty prestms consumed v = v `nameIn` consumed_in_prestms consumedStm stm = any consumed (patNames (stmPat stm)) later_consumed = namesFromList $ foldMap (patNames . 
stmPat) $ Seq.filter consumedStm prestms groupInvariant stm = invariantTo private stm && all (`notNameIn` later_consumed) (patNames (stmPat stm)) && invariantTo later_consumed stm (invariant_prestms, variant_prestms) = Seq.partition groupInvariant prestms mustBeInlinedExp (BasicOp (Index _ slice)) = not $ null $ sliceDims slice mustBeInlinedExp (BasicOp Iota {}) = True mustBeInlinedExp (BasicOp Rearrange {}) = True mustBeInlinedExp (BasicOp Reshape {}) = True mustBeInlinedExp _ = False mustBeInlined stm = mustBeInlinedExp (stmExp stm) && any (`nameIn` used_after) (patNames (stmPat stm)) must_be_inlined = namesFromList $ foldMap (patNames . stmPat) $ Seq.filter mustBeInlined variant_prestms recompute stm = any (`nameIn` must_be_inlined) (patNames (stmPat stm)) recomputed_variant_prestms = Seq.filter recompute variant_prestms -- Anything that is variant to the "private" names should be -- considered thread-local. injectPrelude :: SegSpace -> VarianceTable -> Stms GPU -> Names -> (Stms GPU, Tiling, TiledBody) -> (Stms GPU, Tiling, TiledBody) injectPrelude initial_space variance prestms used (host_stms, tiling, tiledBody) = (host_stms, tiling, tiledBody') where tiledBody' private privstms = do let nontiled = (`notElem` unSegSpace (tilingSpace tiling)) private' = private <> namesFromList (map fst (filter nontiled $ unSegSpace initial_space)) ( invariant_prestms, precomputed_variant_prestms, recomputed_variant_prestms ) = partitionPrelude variance prestms private' used addStms invariant_prestms let live_set = namesToList $ liveSet precomputed_variant_prestms $ used <> freeIn recomputed_variant_prestms prelude_arrs <- inScopeOf precomputed_variant_prestms $ doPrelude tiling privstms precomputed_variant_prestms live_set let prelude_privstms = PrivStms recomputed_variant_prestms $ mkReadPreludeValues prelude_arrs live_set tiledBody private' (prelude_privstms <> privstms) tileLoop :: SegSpace -> VarianceTable -> Stms GPU -> Names -> (Stms GPU, Tiling, TiledBody) -> [Type] -> 
Pat Type -> StmAux (ExpDec GPU) -> [(FParam GPU, SubExp)] -> VName -> IntType -> SubExp -> Stms GPU -> Result -> TileM (Stms GPU, Tiling, TiledBody) tileLoop initial_space variance prestms used_in_body (host_stms, tiling, tiledBody) res_ts pat aux merge i it bound poststms poststms_res = do let prestms_used = used_in_body <> freeIn poststms <> freeIn poststms_res ( invariant_prestms, precomputed_variant_prestms, recomputed_variant_prestms ) = partitionPrelude variance prestms tiled_kdims prestms_used let (mergeparams, mergeinits) = unzip merge -- Expand the loop merge parameters to be arrays. tileDim t = arrayOf t (tilingTileShape tiling) $ uniqueness t merge_scope = M.insert i (IndexName it) $ scopeOfFParams mergeparams tiledBody' private privstms = localScope (scopeOf host_stms <> merge_scope) $ do addStms invariant_prestms let live_set = namesToList $ liveSet precomputed_variant_prestms $ freeIn recomputed_variant_prestms <> prestms_used prelude_arrs <- inScopeOf precomputed_variant_prestms $ doPrelude tiling privstms precomputed_variant_prestms live_set mergeparams' <- forM mergeparams $ \(Param attrs pname pt) -> Param attrs <$> newVName (baseString pname ++ "_group") <*> pure (tileDim pt) let merge_ts = map paramType mergeparams let inloop_privstms = PrivStms recomputed_variant_prestms $ mkReadPreludeValues prelude_arrs live_set mergeinit' <- fmap (map Var) $ certifying (stmAuxCerts aux) $ tilingSegMap tiling "tiled_loopinit" ResultPrivate $ \in_bounds slice -> fmap varsRes $ protectOutOfBounds "loopinit" in_bounds merge_ts $ do addPrivStms slice inloop_privstms addPrivStms slice privstms pure $ subExpsRes mergeinits let merge' = zip mergeparams' mergeinit' let indexLoopParams slice = localScope (scopeOfFParams mergeparams') $ forM_ (zip mergeparams mergeparams') $ \(to, from) -> letBindNames [paramName to] . BasicOp . 
Index (paramName from) $ fullSlice (paramType from) slice private' = private <> namesFromList (map paramName mergeparams ++ map paramName mergeparams') privstms' = PrivStms mempty indexLoopParams <> privstms <> inloop_privstms loopbody' <- localScope (scopeOfFParams mergeparams') . runBodyBuilder $ varsRes <$> tiledBody private' privstms' accs' <- letTupExp "tiled_inside_loop" $ Loop merge' (ForLoop i it bound) loopbody' postludeGeneric tiling (privstms <> inloop_privstms) pat accs' poststms poststms_res res_ts pure (host_stms, tiling, tiledBody') where tiled_kdims = namesFromList $ map fst $ filter (`notElem` unSegSpace (tilingSpace tiling)) $ unSegSpace initial_space doPrelude :: Tiling -> PrivStms -> Stms GPU -> [VName] -> Builder GPU [VName] doPrelude tiling privstms prestms prestms_live = -- Create a SegMap that takes care of the prelude for every thread. tilingSegMap tiling "prelude" ResultPrivate $ \in_bounds slice -> do ts <- mapM lookupType prestms_live fmap varsRes . protectOutOfBounds "pre" in_bounds ts $ do addPrivStms slice privstms addStms prestms pure $ varsRes prestms_live liveSet :: (FreeIn a) => Stms GPU -> a -> Names liveSet stms after = namesFromList (concatMap (patNames . stmPat) stms) `namesIntersection` freeIn after tileable :: Stm GPU -> Maybe ( SubExp, [VName], (Commutativity, Lambda GPU, [SubExp], Lambda GPU) ) tileable stm | Op (OtherOp (Screma w arrs form)) <- stmExp stm, Just (reds, map_lam) <- isRedomapSOAC form, Reduce red_comm red_lam red_nes <- singleReduce reds, lambdaReturnType map_lam == lambdaReturnType red_lam, -- No mapout arrays. not $ null arrs, all primType $ lambdaReturnType map_lam, all (primType . paramType) $ lambdaParams map_lam, not $ "unroll" `inAttrs` stmAuxAttrs (stmAux stm) = Just (w, arrs, (red_comm, red_lam, red_nes, map_lam)) | otherwise = Nothing -- | We classify the inputs to the tiled loop as whether they are -- tileable (and with what permutation of the kernel indexes) or not. 
-- In practice, we should have at least one tileable array per loop,
-- but this is not enforced in our representation.
data InputArray
  = -- | Tile this array; the list is the permutation of the kernel
    -- indexes under which it is tiled.
    InputTile [Int] VName
  | -- | Do not tile this array.
    InputDontTile VName

-- | The tileable inputs, in their original order, each paired with
-- its permutation.  Untiled inputs are dropped.
tiledInputs :: [InputArray] -> [(VName, [Int])]
tiledInputs inputs = [(arr, perm) | InputTile perm arr <- inputs]

-- | A tile (or an original untiled array).
data InputTile
  = InputTiled [Int] VName
  | InputUntiled VName

-- | Reassemble the inputs after their tiles have been read.
--
-- The first argument is the original list of inputs.  The second
-- holds one tile name per 'InputTile' in that list, in the same order
-- (i.e. aligned with 'tiledInputs').  Every 'InputTile' consumes the
-- next tile name and keeps its permutation; every 'InputDontTile' is
-- passed through as an 'InputUntiled'.  Surplus tile names are
-- ignored, and the result is cut short at the first 'InputTile' for
-- which no tile name remains.
inputsToTiles :: [InputArray] -> [VName] -> [InputTile]
inputsToTiles (InputTile perm _ : inputs) (tile : tiles) =
  InputTiled perm tile : inputsToTiles inputs tiles
inputsToTiles (InputDontTile arr : inputs) tiles =
  InputUntiled arr : inputsToTiles inputs tiles
inputsToTiles _ _ = []

-- | Slice out the part of an untiled array that corresponds to tile
-- number @tile_id@: @this_tile_size@ elements of the outermost
-- dimension, starting at offset @tile_id * full_tile_size@.  The
-- actual tile size may be smaller than the full tile size for the
-- last tile, which is why the two sizes are passed separately.
sliceUntiled ::
  (MonadBuilder m) =>
  -- | The untiled array.
  VName ->
  -- | Index of the tile.
  SubExp ->
  -- | Full tile size; determines the offset.
  SubExp ->
  -- | Size of this tile; determines the length of the slice.
  SubExp ->
  m VName
sliceUntiled arr tile_id full_tile_size this_tile_size = do
  arr_t <- lookupType arr
  slice_offset <-
    letSubExp "slice_offset" =<< toExp (pe64 tile_id * pe64 full_tile_size)
  let slice = DimSlice slice_offset this_tile_size (intConst Int64 1)
  letExp "untiled_slice" $ BasicOp $ Index arr $ fullSlice arr_t [slice]

-- | Statements that we insert directly into every thread-private
-- SegMaps. This is for things that cannot efficiently be computed
-- once in advance in the prelude SegMap, primarily (exclusively?)
-- array slicing operations.
data PrivStms = PrivStms (Stms GPU) ReadPrelude {- | Wrap statements as 'PrivStms' whose prelude reader does nothing. -} privStms :: Stms GPU -> PrivStms privStms stms = PrivStms stms $ const $ pure () {- | Run the prelude reader on the given thread-local slice, then add the statements. -} addPrivStms :: [DimIndex SubExp] -> PrivStms -> Builder GPU () addPrivStms local_slice (PrivStms stms readPrelude) = do readPrelude local_slice addStms stms {- Combining two 'PrivStms' runs the left prelude reader before the right one and concatenates the statements in the same order. -} instance Semigroup PrivStms where PrivStms stms_x readPrelude_x <> PrivStms stms_y readPrelude_y = PrivStms stms_z readPrelude_z where stms_z = stms_x <> stms_y readPrelude_z slice = readPrelude_x slice >> readPrelude_y slice instance Monoid PrivStms where mempty = privStms mempty {- | Action run with the slice that identifies the current thread; for example, mkReadPreludeValues binds each live prelude name by indexing its per-thread array with that slice. -} type ReadPrelude = [DimIndex SubExp] -> Builder GPU () {- | Arguments for 'tilingProcessTile', which folds one tile into the per-thread accumulators. -} data ProcessTileArgs = ProcessTileArgs { processPrivStms :: PrivStms, processComm :: Commutativity, processRedLam :: Lambda GPU, processMapLam :: Lambda GPU, processTiles :: [InputTile], processAcc :: [VName], processTileId :: SubExp } {- | Arguments for 'tilingProcessResidualTile', which handles the elements of the input that remain after residualNumWholeTiles whole tiles. -} data ResidualTileArgs = ResidualTileArgs { residualPrivStms :: PrivStms, residualComm :: Commutativity, residualRedLam :: Lambda GPU, residualMapLam :: Lambda GPU, residualInput :: [InputArray], residualAcc :: [VName], residualInputSize :: SubExp, residualNumWholeTiles :: SubExp } -- | Information about a loop that has been tiled inside a kernel, as -- well as the kinds of changes that we would then like to perform on -- the kernel. data Tiling = Tiling { tilingSegMap :: String -> ResultManifest -> (PrimExp VName -> [DimIndex SubExp] -> Builder GPU Result) -> Builder GPU [VName], -- The boolean PrimExp indicates whether they are in-bounds.
tilingReadTile :: TileKind -> PrivStms -> SubExp -> [InputArray] -> Builder GPU [InputTile], tilingProcessTile :: ProcessTileArgs -> Builder GPU [VName], tilingProcessResidualTile :: ResidualTileArgs -> Builder GPU [VName], tilingTileReturns :: VName -> Builder GPU KernelResult, tilingSpace :: SegSpace, tilingTileShape :: Shape, tilingLevel :: SegLevel, tilingNumWholeTiles :: Builder GPU SubExp } {- | Builds a 'Tiling' from the thread index(es), the kernel dimension(s), and the size of the dimension being tiled (which the tilings divide by the tile size to count whole tiles). -} type DoTiling gtids kdims = gtids -> kdims -> SubExp -> Builder GPU Tiling {- | Evaluates the body only when in_bounds holds; otherwise yields blank values of the given types, reusing a variable of the same accumulator type that is free in the body where one exists. -} protectOutOfBounds :: String -> PrimExp VName -> [Type] -> Builder GPU Result -> Builder GPU [VName] protectOutOfBounds desc in_bounds ts m = do -- This is more complicated than you might expect, because we need -- to be able to produce a blank accumulator, which eBlank cannot -- do. By the linear type rules of accumulators, if the body returns -- an accumulator of type 'acc_t', then a unique variable of type -- 'acc_t' must also be free in the body. This means we can find it -- based just on the type. m_body <- insertStmsM $ mkBody mempty <$> m let m_body_free = namesToList $ freeIn m_body t_to_v <- filter (isAcc . fst) <$> (zip <$> mapM lookupType m_body_free <*> pure m_body_free) let blank t = maybe (eBlank t) (pure . BasicOp . SubExp . Var) $ lookup t t_to_v letTupExp desc =<< eIf (toExp in_bounds) (pure m_body) (eBody $ map blank ts) {- | Per-thread postlude: binds each pattern name to this thread's element of the corresponding tiled-loop result, then runs the post-statements (guarded by protectOutOfBounds when there are any) and returns their result. -} postludeGeneric :: Tiling -> PrivStms -> Pat Type -> [VName] -> Stms GPU -> Result -> [Type] -> Builder GPU [VName] postludeGeneric tiling privstms pat accs' poststms poststms_res res_ts = tilingSegMap tiling "thread_res" ResultPrivate $ \in_bounds slice -> do -- Read our per-thread result from the tiled loop. forM_ (zip (patNames pat) accs') $ \(us, everyone) -> do everyone_t <- lookupType everyone letBindNames [us] $ BasicOp $ Index everyone $ fullSlice everyone_t slice if poststms == mempty then do -- The privstms may still be necessary for the result.
addPrivStms slice privstms pure poststms_res else fmap varsRes $ protectOutOfBounds "postlude" in_bounds res_ts $ do addPrivStms slice privstms addStms poststms pure poststms_res {- | Given names to treat as thread-private and extra private statements, emits the tiled computation and returns its result arrays. -} type TiledBody = Names -> PrivStms -> Builder GPU [VName] {- | Runs doTiling to obtain the setup statements and the 'Tiling'. The returned 'TiledBody' initialises per-thread accumulators from the reduce neutral elements, reduces whole tiles in a sequential loop, processes the residual tile, and finishes with the postlude. -} tileGeneric :: DoTiling gtids kdims -> [Type] -> Pat Type -> gtids -> kdims -> SubExp -> (Commutativity, Lambda GPU, [SubExp], Lambda GPU) -> [InputArray] -> Stms GPU -> Result -> TileM (Stms GPU, Tiling, TiledBody) tileGeneric doTiling res_ts pat gtids kdims w form inputs poststms poststms_res = do (tiling, tiling_stms) <- runBuilder $ doTiling gtids kdims w pure (tiling_stms, tiling, tiledBody tiling) where (red_comm, red_lam, red_nes, map_lam) = form tiledBody :: Tiling -> Names -> PrivStms -> Builder GPU [VName] tiledBody tiling _private privstms = do let tile_shape = tilingTileShape tiling num_whole_tiles <- tilingNumWholeTiles tiling -- We don't use a Replicate here, because we want to enforce a -- scalar memory space. mergeinits <- tilingSegMap tiling "mergeinit" ResultPrivate $ \in_bounds slice -> -- Constant neutral elements (a common case) do not need protection from OOB. if freeIn red_nes == mempty then pure $ subExpsRes red_nes else fmap varsRes $ protectOutOfBounds "neutral" in_bounds (lambdaReturnType red_lam) $ do addPrivStms slice privstms pure $ subExpsRes red_nes merge <- forM (zip (lambdaParams red_lam) mergeinits) $ \(p, mergeinit) -> (,) <$> newParam (baseString (paramName p) ++ "_merge") (paramType p `arrayOfShape` tile_shape `toDecl` Unique) <*> pure (Var mergeinit) tile_id <- newVName "tile_id" let loopform = ForLoop tile_id Int64 num_whole_tiles loopbody <- renameBody <=< runBodyBuilder $ localScope (scopeOfLoopForm loopform <> scopeOfFParams (map fst merge)) $ do -- Collectively read a tile. tile <- tilingReadTile tiling TilePartial privstms (Var tile_id) inputs -- Now each thread performs a traversal of the tile and -- updates its accumulator. let accs = map (paramName .
fst) merge tile_args = ProcessTileArgs privstms red_comm red_lam map_lam tile accs (Var tile_id) varsRes <$> tilingProcessTile tiling tile_args accs <- letTupExp "accs" $ Loop merge loopform loopbody -- We possibly have to traverse a residual tile. red_lam' <- renameLambda red_lam map_lam' <- renameLambda map_lam let residual_args = ResidualTileArgs privstms red_comm red_lam' map_lam' inputs accs w num_whole_tiles accs' <- tilingProcessResidualTile tiling residual_args -- Create a SegMap that takes care of the postlude for every thread. postludeGeneric tiling privstms pat accs' poststms poststms_res res_ts {- | For each (array, name) pair, binds the name to the array indexed by the given thread-local slice. -} mkReadPreludeValues :: [VName] -> [VName] -> ReadPrelude mkReadPreludeValues prestms_live_arrs prestms_live slice = fmap mconcat . forM (zip prestms_live_arrs prestms_live) $ \(arr, v) -> do arr_t <- lookupType arr letBindNames [v] $ BasicOp $ Index arr $ fullSlice arr_t slice {- | Builds a TileReturns result for arr; unless dims_on_top is empty or arr has no dimensions, arr is first reshaped with one leading unit dimension per dims_on_top entry. -} tileReturns :: [(VName, SubExp)] -> [(SubExp, SubExp)] -> VName -> Builder GPU KernelResult tileReturns dims_on_top dims arr = do let unit_dims = replicate (length dims_on_top) (intConst Int64 1) arr_t <- lookupType arr arr' <- if null dims_on_top || null (arrayDims arr_t) -- Second check is for accumulators. then pure arr else do let new_shape = Shape $ unit_dims ++ arrayDims arr_t letExp (baseString arr) .
BasicOp $ Reshape ReshapeArbitrary new_shape arr let tile_dims = zip (map snd dims_on_top) unit_dims ++ dims pure $ TileReturns mempty tile_dims arr' {- | An input is tiled (permutation [0]) when its recorded variance does not include gtid, and left untiled otherwise. -} is1DTileable :: VName -> M.Map VName Names -> VName -> InputArray is1DTileable gtid variance arr | not $ nameIn gtid $ M.findWithDefault mempty arr variance = InputTile [0] arr | otherwise = InputDontTile arr {- | Binds gtid to gid * tblock_size + ltid. -} reconstructGtids1D :: Count BlockSize SubExp -> VName -> VName -> VName -> Builder GPU () reconstructGtids1D tblock_size gtid gid ltid = letBindNames [gtid] =<< toExp (le64 gid * pe64 (unCount tblock_size) + le64 ltid) {- | Collectively reads tile number tile_id of every tileable input, one element per thread at index tile_id * tile_size + ltid. With TilePartial, threads whose index is not below the array length produce a blank value instead. -} readTile1D :: SubExp -> VName -> VName -> KernelGrid -> TileKind -> PrivStms -> SubExp -> [InputArray] -> Builder GPU [InputTile] readTile1D tile_size gid gtid (KernelGrid _num_tblocks tblock_size) kind privstms tile_id inputs = fmap (inputsToTiles inputs) . segMap1D "full_tile" lvl ResultNoSimplify tile_size $ \ltid -> do j <- letSubExp "j" =<< toExp (pe64 tile_id * pe64 tile_size + le64 ltid) reconstructGtids1D tblock_size gtid gid ltid addPrivStms [DimFix $ Var ltid] privstms let arrs = map fst $ tiledInputs inputs arr_ts <- mapM lookupType arrs let tile_ts = map rowType arr_ts w = arraysSize 0 arr_ts let readTileElem arr = -- No need for fullSlice because we are tiling only prims. letExp "tile_elem" (BasicOp $ Index arr $ Slice [DimFix j]) fmap varsRes $ case kind of TilePartial -> letTupExp "pre1d" =<< eIf (toExp $ pe64 j .<. pe64 w) (resultBody <$> mapM (fmap Var .
readTileElem) arrs) (eBody $ map eBlank tile_ts) TileFull -> mapM readTileElem arrs where lvl = SegThreadInBlock SegNoVirt {- | Each thread reduces the current tile into its own accumulator (read at index ltid); threads with gtid not below kdim keep their accumulator unchanged. Untiled inputs are sliced with sliceUntiled using tile_size for both offset and length. -} processTile1D :: VName -> VName -> SubExp -> SubExp -> KernelGrid -> ProcessTileArgs -> Builder GPU [VName] processTile1D gid gtid kdim tile_size (KernelGrid _num_tblocks tblock_size) tile_args = do let red_comm = processComm tile_args privstms = processPrivStms tile_args map_lam = processMapLam tile_args red_lam = processRedLam tile_args tiles = processTiles tile_args tile_id = processTileId tile_args accs = processAcc tile_args segMap1D "acc" lvl ResultPrivate (unCount tblock_size) $ \ltid -> do reconstructGtids1D tblock_size gtid gid ltid addPrivStms [DimFix $ Var ltid] privstms -- We replace the neutral elements with the accumulators (this is -- OK because the parallel semantics are not used after this -- point). thread_accs <- forM accs $ \acc -> letSubExp "acc" $ BasicOp $ Index acc $ Slice [DimFix $ Var ltid] let sliceTile (InputTiled _ arr) = pure arr sliceTile (InputUntiled arr) = sliceUntiled arr tile_id tile_size tile_size tiles' <- mapM sliceTile tiles let form' = redomapSOAC [Reduce red_comm red_lam thread_accs] map_lam fmap varsRes $ letTupExp "acc" =<< eIf (toExp $ le64 gtid .<. pe64 kdim) (eBody [pure $ Op $ OtherOp $ Screma tile_size tiles' form']) (resultBodyM thread_accs) where lvl = SegThreadInBlock SegNoVirt {- NOTE(review): nonemptyTile below calls processTile1D with residual_input as its tile_size, and processTile1D slices untiled inputs with sliceUntiled arr tile_id tile_size tile_size, so their offset becomes num_whole_tiles * residual_input rather than num_whole_tiles times the full tile size. processTile2D avoids this by taking the full tile size and the truncated size separately. Confirm whether InputDontTile arrays (see is1DTileable) can reach this path. -} {- | Reduces the w rem tile_size elements left over after the whole tiles, returning the accumulators unchanged when there are none. -} processResidualTile1D :: VName -> VName -> SubExp -> SubExp -> KernelGrid -> ResidualTileArgs -> Builder GPU [VName] processResidualTile1D gid gtid kdim tile_size grid args = do -- The number of residual elements that are not covered by -- the whole tiles. residual_input <- letSubExp "residual_input" $ BasicOp $ BinOp (SRem Int64 Unsafe) w tile_size letTupExp "acc_after_residual" =<< eIf (toExp $ pe64 residual_input .==.
0) (resultBodyM $ map Var accs) (nonemptyTile residual_input) where red_comm = residualComm args map_lam = residualMapLam args red_lam = residualRedLam args privstms = residualPrivStms args inputs = residualInput args accs = residualAcc args num_whole_tiles = residualNumWholeTiles args w = residualInputSize args nonemptyTile residual_input = runBodyBuilder $ do -- Collectively construct a tile. Threads that are out-of-bounds -- provide a blank dummy value. full_tiles <- readTile1D tile_size gid gtid grid TilePartial privstms num_whole_tiles inputs let sliceTile (InputUntiled arr) = pure $ InputUntiled arr sliceTile (InputTiled perm tile) = do let slice = DimSlice (intConst Int64 0) residual_input (intConst Int64 1) InputTiled perm <$> letExp "partial_tile" (BasicOp $ Index tile $ Slice [slice]) tiles <- mapM sliceTile full_tiles -- Now each thread performs a traversal of the tile and -- updates its accumulator. let tile_args = ProcessTileArgs privstms red_comm red_lam map_lam tiles accs num_whole_tiles varsRes <$> processTile1D gid gtid kdim residual_input grid tile_args {- | 1D tiling: thread blocks of tile_size threads (tile size from a GetSize of kind SizeThreadBlock), with ceil(kdim / tile_size) blocks for every combination of the dims_on_top. -} tiling1d :: [(VName, SubExp)] -> DoTiling VName SubExp tiling1d dims_on_top gtid kdim w = do gid <- newVName "gid" gid_flat <- newVName "gid_flat" tile_size_key <- nameFromString . prettyString <$> newVName "tile_size" tile_size <- letSubExp "tile_size" $ Op $ SizeOp $ GetSize tile_size_key SizeThreadBlock let tblock_size = tile_size (grid, space) <- do -- How many blocks we need to exhaust the innermost dimension. ldim <- letSubExp "ldim" .
BasicOp $ BinOp (SDivUp Int64 Unsafe) kdim tblock_size num_tblocks <- letSubExp "computed_num_tblocks" =<< foldBinOp (Mul Int64 OverflowUndef) ldim (map snd dims_on_top) pure ( KernelGrid (Count num_tblocks) (Count tblock_size), SegSpace gid_flat $ dims_on_top ++ [(gid, ldim)] ) let tiling_lvl = SegThreadInBlock SegNoVirt pure Tiling { tilingSegMap = \desc manifest f -> segMap1D desc tiling_lvl manifest tile_size $ \ltid -> do letBindNames [gtid] =<< toExp (le64 gid * pe64 tile_size + le64 ltid) f (untyped $ le64 gtid .<. pe64 kdim) [DimFix $ Var ltid], tilingReadTile = readTile1D tile_size gid gtid grid, tilingProcessTile = processTile1D gid gtid kdim tile_size grid, tilingProcessResidualTile = processResidualTile1D gid gtid kdim tile_size grid, tilingTileReturns = tileReturns dims_on_top [(kdim, tile_size)], tilingTileShape = Shape [tile_size], tilingNumWholeTiles = letSubExp "num_whole_tiles" $ BasicOp $ BinOp (SQuot Int64 Unsafe) w tile_size, tilingLevel = SegBlock SegNoVirt (Just grid), tilingSpace = space } {- | Looks at the two innermost dims (j innermost, i next). arr is tiled with permutation [0, 1] when it varies with i but not j, with [1, 0] when it varies with j but not i, and is otherwise untiled; tiling also requires that neither i nor j is in branch_variant. Returns Nothing when dims has fewer than two entries. -} invariantToOneOfTwoInnerDims :: Names -> M.Map VName Names -> [VName] -> VName -> Maybe InputArray invariantToOneOfTwoInnerDims branch_variant variance dims arr = do j : i : _ <- Just $ reverse dims let variant_to = M.findWithDefault mempty arr variance branch_invariant = not $ nameIn j branch_variant || nameIn i branch_variant if branch_invariant && i `nameIn` variant_to && j `notNameIn` variant_to then Just $ InputTile [0, 1] arr else if branch_invariant && j `nameIn` variant_to && i `notNameIn` variant_to then Just $ InputTile [1, 0] arr else Just $ InputDontTile arr -- Reconstruct the original gtids from group and local IDs. reconstructGtids2D :: SubExp -> (VName, VName) -> (VName, VName) -> (VName, VName) -> Builder GPU () reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y) = do -- Reconstruct the original gtids from gid_x/gid_y and ltid_x/ltid_y.
letBindNames [gtid_x] =<< toExp (le64 gid_x * pe64 tile_size + le64 ltid_x) letBindNames [gtid_y] =<< toExp (le64 gid_y * pe64 tile_size + le64 ltid_y) {- NOTE(review): the full_tile SegMap below uses SegThread (SegNoVirtFull (SegSeqDims [])) Nothing, whereas readTile1D and processTile2D use SegThreadInBlock levels; confirm this is intended. -} {- | Collectively reads a tile_size by tile_size tile of each tileable input, selecting the row index (tile_id * tile_size + ltid_x or + ltid_y) according to the input's permutation. With TilePartial, threads that are out of bounds produce a blank value instead. -} readTile2D :: (SubExp, SubExp) -> (VName, VName) -> (VName, VName) -> SubExp -> TileKind -> PrivStms -> SubExp -> [InputArray] -> Builder GPU [InputTile] readTile2D (kdim_x, kdim_y) (gtid_x, gtid_y) (gid_x, gid_y) tile_size kind privstms tile_id inputs = fmap (inputsToTiles inputs) . segMap2D "full_tile" (SegThread (SegNoVirtFull (SegSeqDims [])) Nothing) ResultNoSimplify (tile_size, tile_size) $ \(ltid_x, ltid_y) -> do i <- letSubExp "i" =<< toExp (pe64 tile_id * pe64 tile_size + le64 ltid_x) j <- letSubExp "j" =<< toExp (pe64 tile_id * pe64 tile_size + le64 ltid_y) reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y) addPrivStms [DimFix $ Var ltid_x, DimFix $ Var ltid_y] privstms let arrs_and_perms = tiledInputs inputs readTileElem (arr, perm) = -- No need for fullSlice because we are tiling only prims. letExp "tile_elem" ( BasicOp . Index arr $ Slice [DimFix $ last $ rearrangeShape perm [i, j]] ) readTileElemIfInBounds (arr, perm) = do arr_t <- lookupType arr let tile_t = rowType arr_t w = arraySize 0 arr_t idx = last $ rearrangeShape perm [i, j] othercheck = last $ rearrangeShape perm [ le64 gtid_y .<. pe64 kdim_y, le64 gtid_x .<. pe64 kdim_x ] eIf (toExp $ pe64 idx .<. pe64 w .&&.
othercheck) (eBody [pure $ BasicOp $ Index arr $ Slice [DimFix idx]]) (eBody [eBlank tile_t]) fmap varsRes $ case kind of TilePartial -> mapM (letExp "pre2d" <=< readTileElemIfInBounds) arrs_and_perms TileFull -> mapM readTileElem arrs_and_perms {- | Outer dimension size of the first tiled input, or 0 when every input is untiled. -} findTileSize :: (HasScope rep m) => [InputTile] -> m SubExp findTileSize tiles = case mapMaybe isTiled tiles of v : _ -> arraySize 0 <$> lookupType v [] -> pure $ intConst Int64 0 where isTiled InputUntiled {} = Nothing isTiled (InputTiled _ tile) = Just tile {- | Each thread (ltid_x, ltid_y) reduces the current tile into its own accumulator; threads outside kdim_x by kdim_y keep their accumulator unchanged. The reduction length is taken from the tiles themselves, so truncated residual tiles are handled. -} processTile2D :: (VName, VName) -> (VName, VName) -> (SubExp, SubExp) -> SubExp -> ProcessTileArgs -> Builder GPU [VName] processTile2D (gid_x, gid_y) (gtid_x, gtid_y) (kdim_x, kdim_y) tile_size tile_args = do let privstms = processPrivStms tile_args red_comm = processComm tile_args red_lam = processRedLam tile_args map_lam = processMapLam tile_args tiles = processTiles tile_args accs = processAcc tile_args tile_id = processTileId tile_args -- Might be truncated in case of a partial tile. actual_tile_size <- findTileSize tiles segMap2D "acc" (SegThreadInBlock (SegNoVirtFull (SegSeqDims []))) ResultPrivate (tile_size, tile_size) $ \(ltid_x, ltid_y) -> do reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y) addPrivStms [DimFix $ Var ltid_x, DimFix $ Var ltid_y] privstms -- We replace the neutral elements with the accumulators (this is -- OK because the parallel semantics are not used after this -- point).
thread_accs <- forM accs $ \acc -> letSubExp "acc" $ BasicOp $ Index acc $ Slice [DimFix $ Var ltid_x, DimFix $ Var ltid_y] let form' = redomapSOAC [Reduce red_comm red_lam thread_accs] map_lam sliceTile (InputUntiled arr) = sliceUntiled arr tile_id tile_size actual_tile_size sliceTile (InputTiled perm tile) = do tile_t <- lookupType tile let idx = DimFix $ Var $ head $ rearrangeShape perm [ltid_x, ltid_y] letExp "tile" $ BasicOp $ Index tile $ sliceAt tile_t (head perm) [idx] tiles' <- mapM sliceTile tiles fmap varsRes $ letTupExp "acc" =<< eIf ( toExp $ le64 gtid_x .<. pe64 kdim_x .&&. le64 gtid_y .<. pe64 kdim_y ) (eBody [pure $ Op $ OtherOp $ Screma actual_tile_size tiles' form']) (resultBodyM thread_accs) {- | 2D residual step: reduces the w rem tile_size leftover elements, returning the accumulators unchanged when there are none. -} processResidualTile2D :: (VName, VName) -> (VName, VName) -> (SubExp, SubExp) -> SubExp -> ResidualTileArgs -> Builder GPU [VName] processResidualTile2D gids gtids kdims tile_size args = do -- The number of residual elements that are not covered by -- the whole tiles. residual_input <- letSubExp "residual_input" $ BasicOp $ BinOp (SRem Int64 Unsafe) w tile_size letTupExp "acc_after_residual" =<< eIf (toExp $ pe64 residual_input .==. 0) (resultBodyM $ map Var accs) (nonemptyTile residual_input) where privstms = residualPrivStms args red_comm = residualComm args red_lam = residualRedLam args map_lam = residualMapLam args accs = residualAcc args inputs = residualInput args num_whole_tiles = residualNumWholeTiles args w = residualInputSize args nonemptyTile residual_input = renameBody <=< runBodyBuilder $ do -- Collectively construct a tile. Threads that are out-of-bounds -- provide a blank dummy value.
full_tile <- readTile2D kdims gtids gids tile_size TilePartial privstms num_whole_tiles inputs let slice = DimSlice (intConst Int64 0) residual_input (intConst Int64 1) tiles <- forM full_tile $ \case InputTiled perm tile' -> InputTiled perm <$> letExp "partial_tile" (BasicOp $ Index tile' (Slice [slice, slice])) InputUntiled arr -> pure $ InputUntiled arr let tile_args = ProcessTileArgs privstms red_comm red_lam map_lam tiles accs num_whole_tiles -- Now each thread performs a traversal of the tile and -- updates its accumulator. varsRes <$> processTile2D gids gtids kdims tile_size tile_args {- | 2D tiling: tile_size by tile_size thread blocks (tile size from a GetSize of kind SizeTile), with ceil(kdim_x / tile_size) by ceil(kdim_y / tile_size) blocks for every combination of the dims_on_top. -} tiling2d :: [(VName, SubExp)] -> DoTiling (VName, VName) (SubExp, SubExp) tiling2d dims_on_top (gtid_x, gtid_y) (kdim_x, kdim_y) w = do gid_x <- newVName "gid_x" gid_y <- newVName "gid_y" tile_size_key <- nameFromString . prettyString <$> newVName "tile_size" tile_size <- letSubExp "tile_size" $ Op $ SizeOp $ GetSize tile_size_key SizeTile tblock_size <- letSubExp "tblock_size" $ BasicOp $ BinOp (Mul Int64 OverflowUndef) tile_size tile_size num_tblocks_x <- letSubExp "num_tblocks_x" $ BasicOp $ BinOp (SDivUp Int64 Unsafe) kdim_x tile_size num_tblocks_y <- letSubExp "num_tblocks_y" $ BasicOp $ BinOp (SDivUp Int64 Unsafe) kdim_y tile_size num_tblocks <- letSubExp "num_tblocks_top" =<< foldBinOp (Mul Int64 OverflowUndef) num_tblocks_x (num_tblocks_y : map snd dims_on_top) gid_flat <- newVName "gid_flat" let grid = KernelGrid (Count num_tblocks) (Count tblock_size) lvl = SegBlock (SegNoVirtFull (SegSeqDims [])) (Just grid) space = SegSpace gid_flat $ dims_on_top ++ [(gid_x, num_tblocks_x), (gid_y, num_tblocks_y)] tiling_lvl = SegThreadInBlock SegNoVirt pure Tiling { tilingSegMap = \desc manifest f -> segMap2D desc tiling_lvl manifest (tile_size, tile_size) $ \(ltid_x, ltid_y) -> do reconstructGtids2D tile_size (gtid_x, gtid_y) (gid_x, gid_y) (ltid_x, ltid_y) f ( untyped $ le64 gtid_x .<. pe64 kdim_x .&&. le64 gtid_y .<.
pe64 kdim_y ) [DimFix $ Var ltid_x, DimFix $ Var ltid_y], tilingReadTile = readTile2D (kdim_x, kdim_y) (gtid_x, gtid_y) (gid_x, gid_y) tile_size, tilingProcessTile = processTile2D (gid_x, gid_y) (gtid_x, gtid_y) (kdim_x, kdim_y) tile_size, tilingProcessResidualTile = processResidualTile2D (gid_x, gid_y) (gtid_x, gtid_y) (kdim_x, kdim_y) tile_size, tilingTileReturns = tileReturns dims_on_top [(kdim_x, tile_size), (kdim_y, tile_size)], tilingTileShape = Shape [tile_size, tile_size], tilingNumWholeTiles = letSubExp "num_whole_tiles" $ BasicOp $ BinOp (SQuot Int64 Unsafe) w tile_size, tilingLevel = lvl, tilingSpace = space } futhark-0.25.27/src/Futhark/Optimise/TileLoops/000077500000000000000000000000001475065116200212405ustar00000000000000futhark-0.25.27/src/Futhark/Optimise/TileLoops/Shared.hs000066400000000000000000000256751475065116200230210ustar00000000000000module Futhark.Optimise.TileLoops.Shared ( TileM, Env, index, update, forLoop', forLoop, segMap1D, segMap2D, segMap3D, segScatter2D, VarianceTable, varianceInStms, isTileableRedomap, changeEnv, TileKind (..), ) where import Control.Monad import Control.Monad.Reader import Control.Monad.State import Data.List (foldl', zip4) import Data.Map qualified as M import Futhark.IR.GPU import Futhark.IR.Mem.LMAD qualified as LMAD import Futhark.IR.SeqMem qualified as ExpMem import Futhark.MonadFreshNames import Futhark.Tools import Futhark.Transform.Rename {- | Monad for tiling: a reader over the GPU scope on top of a fresh-name source. -} type TileM = ReaderT (Scope GPU) (State VNameSource) -- | Are we working with full or partial tiles? data TileKind = TilePartial | TileFull -- index an array with indices given in outer_indices; any inner -- dims of arr not indexed by outer_indices are sliced entirely index :: (MonadBuilder m) => String -> VName -> [VName] -> m VName index se_desc arr outer_indices = do arr_t <- lookupType arr let slice = fullSlice arr_t $ map (DimFix .
Var) outer_indices letExp se_desc $ BasicOp $ Index arr slice {- | Binds an Update Unsafe of arr at the given scalar indices to new_elem. -} update :: (MonadBuilder m) => String -> VName -> [VName] -> SubExp -> m VName update se_desc arr indices new_elem = letExp se_desc $ BasicOp $ Update Unsafe arr (Slice $ map (DimFix . Var) indices) new_elem {- | Builds a sequential for loop of i_bound iterations whose merge parameters are initialised from merge; body receives the loop index and the merge parameter names. Returns all loop results. -} forLoop' :: SubExp -> -- loop bound (number of iterations) [VName] -> -- loop inits ( VName -> [VName] -> -- (loop var -> loop inits -> loop body) Builder GPU (Body GPU) ) -> Builder GPU [VName] forLoop' i_bound merge body = do i <- newVName "i" -- could give this as arg to the function let loop_form = ForLoop i Int64 i_bound merge_ts <- mapM lookupType merge loop_inits <- mapM (\merge_t -> newParam "merge" $ toDecl merge_t Unique) merge_ts loop_body <- insertStmsM $ localScope (scopeOfLoopForm loop_form <> scopeOfFParams loop_inits) $ body i $ map paramName loop_inits letTupExp "loop" $ Loop (zip loop_inits $ map Var merge) loop_form loop_body {- | Like forLoop', but returns only the first loop result; merge must be non-empty because head is used. -} forLoop :: SubExp -> [VName] -> (VName -> [VName] -> Builder GPU (Body GPU)) -> Builder GPU VName forLoop i_bound merge body = do res_list <- forLoop' i_bound merge body pure $ head res_list {- | One-dimensional SegMap of width w at level lvl; f receives the thread index, and the result types are inferred from what f returns. -} segMap1D :: String -> SegLevel -> ResultManifest -> SubExp -> -- dim_x (VName -> Builder GPU Result) -> Builder GPU [VName] segMap1D desc lvl manifest w f = do ltid <- newVName "ltid" ltid_flat <- newVName "ltid_flat" let space = SegSpace ltid_flat [(ltid, w)] ((ts, res), stms) <- localScope (scopeOfSegSpace space) . runBuilder $ do res <- f ltid ts <- mapM subExpResType res pure (ts, res) Body _ stms' res' <- renameBody $ mkBody stms res let ret (SubExpRes cs se) = Returns manifest cs se letTupExp desc $ Op .
SegOp $ SegMap lvl space ts $ KernelBody () stms' $ map ret res' {- | Two-dimensional SegMap; the first component of the size pair is the outermost dimension, and f receives its indices in the same (outer, inner) order. -} segMap2D :: String -> -- desc SegLevel -> -- lvl ResultManifest -> -- manifest (SubExp, SubExp) -> -- (dim_y, dim_x), outermost first ( (VName, VName) -> -- f Builder GPU Result ) -> Builder GPU [VName] segMap2D desc lvl manifest (dim_y, dim_x) f = do ltid_xx <- newVName "ltid_x" ltid_yy <- newVName "ltid_y" ltid_flat <- newVName "ltid_flat" let segspace = SegSpace ltid_flat [(ltid_yy, dim_y), (ltid_xx, dim_x)] ((ts, res), stms) <- localScope (scopeOfSegSpace segspace) . runBuilder $ do res <- f (ltid_yy, ltid_xx) ts <- mapM subExpResType res pure (ts, res) let ret (SubExpRes cs se) = Returns manifest cs se letTupExp desc <=< renameExp $ Op . SegOp $ SegMap lvl segspace ts $ KernelBody () stms $ map ret res {- | Three-dimensional SegMap over (dim_z, dim_y, dim_x), outermost first; f receives the indices in the same order. -} segMap3D :: String -> -- desc SegLevel -> -- lvl ResultManifest -> -- manifest (SubExp, SubExp, SubExp) -> -- (dim_z, dim_y, dim_x) ( (VName, VName, VName) -> -- f Builder GPU Result ) -> Builder GPU [VName] segMap3D desc lvl manifest (dim_z, dim_y, dim_x) f = do ltid_flat <- newVName "ltid_flat" ltid_z <- newVName "ltid_z" ltid_y <- newVName "ltid_y" ltid_x <- newVName "ltid_x" let segspace = SegSpace ltid_flat [(ltid_z, dim_z), (ltid_y, dim_y), (ltid_x, dim_x)] ((ts, res), stms) <- localScope (scopeOfSegSpace segspace) . runBuilder $ do res <- f (ltid_z, ltid_y, ltid_x) ts <- mapM subExpResType res pure (ts, res) let ret (SubExpRes cs se) = Returns manifest cs se letTupExp desc <=< renameExp $ Op .
SegOp $ SegMap lvl segspace ts $ KernelBody () stms $ map ret res {- NOTE(review): the size pair below is bound as (dim_x, dim_y) and dim_x becomes the innermost segspace dimension, while the signature comment says (dim_y, dim_x) and segMap2D treats the first component as outermost; confirm which order callers use. -} {- | SegThreadInBlock SegMap that scatters into updt_arr: for each thread, f returns a (value, index) pair and the value is written to updt_arr at that index. The seq_dims become outer dimensions marked as sequential (SegSeqDims). -} segScatter2D :: String -> VName -> [SubExp] -> -- dims of sequential loop on top (SubExp, SubExp) -> -- (dim_y, dim_x) ([VName] -> (VName, VName) -> Builder GPU (SubExp, SubExp)) -> -- f Builder GPU VName segScatter2D desc updt_arr seq_dims (dim_x, dim_y) f = do ltid_flat <- newVName "ltid_flat" ltid_y <- newVName "ltid_y" ltid_x <- newVName "ltid_x" seq_is <- replicateM (length seq_dims) (newVName "ltid_seq") let seq_space = zip seq_is seq_dims let segspace = SegSpace ltid_flat $ seq_space ++ [(ltid_y, dim_y), (ltid_x, dim_x)] lvl = SegThreadInBlock (SegNoVirtFull (SegSeqDims [0 .. length seq_dims - 1])) ((res_v, res_i), stms) <- runBuilder . localScope (scopeOfSegSpace segspace) $ f seq_is (ltid_y, ltid_x) let ret = WriteReturns mempty updt_arr [(Slice [DimFix res_i], res_v)] let body = KernelBody () stms [ret] updt_arr_t <- lookupType updt_arr letExp desc <=< renameExp $ Op $ SegOp $ SegMap lvl segspace [updt_arr_t] body -- | The variance table keeps a mapping from a variable name -- (something produced by a 'Stm') to the kernel thread indices -- that name depends on. If a variable is not present in this table, -- that means it is bound outside the kernel (and so can be considered -- invariant to all dimensions). type VarianceTable = M.Map VName Names {- | Recognises a Screma that is a single reduce fused with a map over a non-empty list of arrays, where map parameters and map results are primitive, reduce parameters are primitive or have primitive rows, and map and reduce return types agree. -} isTileableRedomap :: Stm GPU -> Maybe ( SubExp, [VName], (Commutativity, Lambda GPU, [SubExp], Lambda GPU) ) isTileableRedomap stm | Op (OtherOp (Screma w arrs form)) <- stmExp stm, Just (reds, map_lam) <- isRedomapSOAC form, Reduce red_comm red_lam red_nes <- singleReduce reds, all (primType . rowType . paramType) $ lambdaParams red_lam, all (primType . rowType . paramType) $ lambdaParams map_lam, lambdaReturnType map_lam == lambdaReturnType red_lam, -- No mapout arrays. not (null arrs), all primType $ lambdaReturnType map_lam, all (primType .
paramType) $ lambdaParams map_lam = Just (w, arrs, (red_comm, red_lam, red_nes, map_lam)) | otherwise = Nothing {- | Default rule: every name bound by stm varies with each free variable of stm together with that variable's recorded variance. -} defVarianceInStm :: VarianceTable -> Stm GPU -> VarianceTable defVarianceInStm variance stm = foldl' add variance $ patNames $ stmPat stm where add variance' v = M.insert v binding_variance variance' look variance' v = oneName v <> M.findWithDefault mempty v variance' binding_variance = mconcat $ map (look variance) $ namesToList (freeIn stm) -- just in case you need the Screma being treated differently than -- by default; previously Cosmin had to enhance it when dealing with stream. {- NOTE(review): red_lam presumably takes card_red accumulator parameters followed by card_red element parameters, but acc_lam_f and arr_lam_f split red_ps at card_red `quot` 2. For a single reduce (card_red = 1) acc_lam_f is empty, so zip4 yields nothing and no lambda parameter variance is recorded. Confirm whether the split should be at card_red. -} varianceInStm :: VarianceTable -> Stm GPU -> VarianceTable varianceInStm v0 stm@(Let _ _ (Op (OtherOp Screma {}))) | Just (_, arrs, (_, red_lam, red_nes, map_lam)) <- isTileableRedomap stm = let v = defVarianceInStm v0 stm red_ps = lambdaParams red_lam map_ps = lambdaParams map_lam card_red = length red_nes acc_lam_f = take (card_red `quot` 2) red_ps arr_lam_f = drop (card_red `quot` 2) red_ps stm_lam = bodyStms (lambdaBody map_lam) <> bodyStms (lambdaBody red_lam) f vacc (v_a, v_fm, v_fr_acc, v_fr_var) = let vrc = oneName v_a <> M.findWithDefault mempty v_a vacc vacc' = M.insert v_fm vrc vacc vrc' = oneName v_fm <> vrc in M.insert v_fr_acc (oneName v_fr_var <> vrc') $ M.insert v_fr_var vrc' vacc' v' = foldl' f v $ zip4 arrs (map paramName map_ps) (map paramName acc_lam_f) (map paramName arr_lam_f) in varianceInStms v' stm_lam varianceInStm v0 stm = defVarianceInStm v0 stm {- | Folds varianceInStm over the statements in order. -} varianceInStms :: VarianceTable -> Stms GPU -> VarianceTable varianceInStms = foldl' varianceInStm ---------------- ---- Helpers for building the environment that binds array variable names to their index functions ---------------- type LMAD = LMAD.LMAD (TPrimExp Int64 VName) -- | Map from array variable names to their corresponding index functions. -- The info is not guaranteed to be exact, e.g., we assume ifs and loops -- return arrays laid out in normalized (row-major) form in memory.
-- We only record aliasing statements, such as transposition, slice, etc. type IxFnEnv = M.Map VName LMAD {- | Maps accumulator parameters bound by a WithAcc to the accumulator operator (with its leading index parameters dropped) and its neutral elements. -} type WithEnv = M.Map VName (Lambda GPU, [SubExp]) type Env = (WithEnv, IxFnEnv) {- | Updates both environments for a statement that binds y to e. -} changeEnv :: Env -> VName -> Exp GPU -> TileM Env changeEnv (with_env, ixfn_env) y e = do with_env' <- changeWithEnv with_env e ixfn_env' <- changeIxFnEnv ixfn_env y e pure (with_env', ixfn_env') {- | For a WithAcc, adds an entry for each of the first (number of accumulators) parameters of the inner lambda; entries already present win because M.union is left biased. Any other expression leaves the table unchanged. Calls error if an accumulator has no operator. -} changeWithEnv :: WithEnv -> Exp GPU -> TileM WithEnv changeWithEnv with_env (WithAcc accum_decs inner_lam) = do let bindings = map mapfun accum_decs par_tps = take (length bindings) $ map paramName $ lambdaParams inner_lam with_env' = M.union with_env $ M.fromList $ zip par_tps bindings pure with_env' where mapfun (_, _, Nothing) = error "What the hack is an accumulator without operator?" mapfun (shp, _, Just (lam_inds, ne)) = let len_inds = length $ shapeDims shp lam_op = lam_inds {lambdaParams = drop len_inds $ lambdaParams lam_inds} in (lam_op, ne) changeWithEnv with_env _ = pure with_env {- NOTE(review): when x has an entry but ixf_fun rejects it, this falls through to applying ixf_fun to a fresh row-major LMAD of x's shape, discarding x's recorded layout; confirm that is intended. -} {- | Records for y the LMAD obtained by applying ixf_fun to x's LMAD, using a row-major LMAD of x's shape when x has no usable entry; env is left unchanged if that also fails or x is not an array. -} composeIxfuns :: IxFnEnv -> VName -> VName -> (LMAD -> Maybe LMAD) -> TileM IxFnEnv composeIxfuns env y x ixf_fun = case ixf_fun =<< M.lookup x env of Just ixf -> pure $ M.insert y ixf env Nothing -> do tp <- lookupType x pure $ case tp of Array _ptp shp _u | Just ixf <- ixf_fun $ LMAD.iota 0 $ map ExpMem.pe64 (shapeDims shp) -> M.insert y ixf env _ -> env {- | Tracks the LMAD of y when it is bound by a reshape, coerce, manifest, rearrange, index or opaque of a variable; any other expression leaves env unchanged. -} changeIxFnEnv :: IxFnEnv -> VName -> Exp GPU -> TileM IxFnEnv changeIxFnEnv env y (BasicOp (Reshape ReshapeArbitrary shp_chg x)) = composeIxfuns env y x (`LMAD.reshape` fmap ExpMem.pe64 (shapeDims shp_chg)) changeIxFnEnv env y (BasicOp (Reshape ReshapeCoerce shp_chg x)) = composeIxfuns env y x (Just .
(`LMAD.coerce` fmap ExpMem.pe64 (shapeDims shp_chg))) {- Manifest: records permute(iota(shape of x), perm) without consulting any entry for x. -} changeIxFnEnv env y (BasicOp (Manifest perm x)) = do tp <- lookupType x case tp of Array _ptp shp _u -> do let shp' = map ExpMem.pe64 (shapeDims shp) let ixfn = LMAD.permute (LMAD.iota 0 shp') perm pure $ M.insert y ixfn env _ -> error "In TileLoops/Shared.hs, changeIxFnEnv: manifest applied to a non-array!" changeIxFnEnv env y (BasicOp (Rearrange perm x)) = composeIxfuns env y x (Just . (`LMAD.permute` perm)) changeIxFnEnv env y (BasicOp (Index x slc)) = composeIxfuns env y x (Just . (`LMAD.slice` Slice (map (fmap ExpMem.pe64) $ unSlice slc))) changeIxFnEnv env y (BasicOp (Opaque _ (Var x))) = composeIxfuns env y x Just changeIxFnEnv env _ _ = pure env futhark-0.25.27/src/Futhark/Optimise/Unstream.hs000066400000000000000000000122211475065116200214560ustar00000000000000{-# LANGUAGE TypeFamilies #-} -- | Sequentialise any remaining SOACs. It is very important that -- this is run *after* any access-pattern-related optimisation, -- because this pass will destroy information. -- -- This pass conceptually contains three subpasses: -- -- 1. Sequentialise 'Stream' operations, leaving other SOACs intact. -- -- 2. Apply whole-program simplification. -- -- 3. Sequentialise remaining SOACs. -- -- This is because sequentialisation of streams creates many SOACs -- operating on single-element arrays, which can be efficiently -- simplified away, but only *before* they are turned into loops. In -- principle this pass could be split into multiple, but for now it is -- kept together.
module Futhark.Optimise.Unstream (unstreamGPU, unstreamMC) where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Futhark.IR.GPU
import Futhark.IR.GPU qualified as GPU
import Futhark.IR.GPU.Simplify (simplifyGPU)
import Futhark.IR.MC
import Futhark.IR.MC qualified as MC
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Tools
import Futhark.Transform.FirstOrderTransform qualified as FOT

-- | The pass for GPU kernels.
unstreamGPU :: Pass GPU GPU
unstreamGPU = unstream onHostOp simplifyGPU

-- | The pass for multicore.
unstreamMC :: Pass MC MC
unstreamMC = unstream onMCOp MC.simplifyProg

-- | Which SOACs to sequentialise (see 'sequentialise'): 'SeqStreams'
-- only sequentialises 'Stream', while 'SeqAll' sequentialises every SOAC.
data Stage = SeqStreams | SeqAll

-- | Build the pass from a stage-indexed 'Op' handler and a
-- whole-program simplifier.  The pass sequentialises streams, then
-- simplifies, then sequentialises all remaining SOACs.
unstream ::
  (ASTRep rep) =>
  (Stage -> OnOp rep) ->
  (Prog rep -> PassM (Prog rep)) ->
  Pass rep rep
unstream onOp simplify =
  Pass "unstream" "sequentialise remaining SOACs" $
    intraproceduralTransformation (optimise SeqStreams)
      >=> simplify
      >=> intraproceduralTransformation (optimise SeqAll)
  where
    optimise stage scope stms =
      modifyNameSource $
        runState $
          runReaderT (optimiseStms (onOp stage) stms) scope

-- | Traversal monad: the current scope plus a name source.
type UnstreamM rep = ReaderT (Scope rep) (State VNameSource)

-- | Handler for an 'Op' statement; it may replace the statement by
-- any number of statements.
type OnOp rep =
  Pat (LetDec rep) -> StmAux (ExpDec rep) -> Op rep -> UnstreamM rep [Stm rep]

-- | Process every statement, with the statements' own bindings in scope.
optimiseStms ::
  (ASTRep rep) =>
  OnOp rep ->
  Stms rep ->
  UnstreamM rep (Stms rep)
optimiseStms onOp stms =
  localScope (scopeOf stms) $
    stmsFromList . concat <$> mapM (optimiseStm onOp) (stmsToList stms)

-- | Process the statements of a body; the result is kept unchanged.
optimiseBody ::
  (ASTRep rep) =>
  OnOp rep ->
  Body rep ->
  UnstreamM rep (Body rep)
optimiseBody onOp (Body aux stms res) =
  Body aux <$> optimiseStms onOp stms <*> pure res

-- | Process the statements of a kernel body; the results are kept unchanged.
optimiseKernelBody ::
  (ASTRep rep) =>
  OnOp rep ->
  KernelBody rep ->
  UnstreamM rep (KernelBody rep)
optimiseKernelBody onOp (KernelBody attr stms res) =
  localScope (scopeOf stms) $
    KernelBody attr
      <$> (stmsFromList . concat <$> mapM (optimiseStm onOp) (stmsToList stms))
      <*> pure res

-- | Process a lambda body with the lambda parameters in scope.
optimiseLambda ::
  (ASTRep rep) =>
  OnOp rep ->
  Lambda rep ->
  UnstreamM rep (Lambda rep)
optimiseLambda onOp lam = localScope (scopeOfLParams $ lambdaParams lam) $ do
  body <- optimiseBody onOp $ lambdaBody lam
  pure lam {lambdaBody = body}

-- | 'Op' statements go to the handler.  For any other expression, the
-- traversal recurses into the nested bodies.
optimiseStm ::
  (ASTRep rep) =>
  OnOp rep ->
  Stm rep ->
  UnstreamM rep [Stm rep]
optimiseStm onOp (Let pat aux (Op op)) =
  onOp pat aux op
optimiseStm onOp (Let pat aux e) =
  pure <$> (Let pat aux <$> mapExpM optimise e)
  where
    optimise =
      identityMapper
        { mapOnBody = \scope -> localScope scope . optimiseBody onOp
        }

-- | Recurse into the kernel bodies and lambdas of a 'SegOp', with its
-- segment space in scope.
optimiseSegOp ::
  (ASTRep rep) =>
  OnOp rep ->
  SegOp lvl rep ->
  UnstreamM rep (SegOp lvl rep)
optimiseSegOp onOp op =
  localScope (scopeOfSegSpace $ segSpace op) $ mapSegOpM optimise op
  where
    optimise =
      identitySegOpMapper
        { mapOnSegOpBody = optimiseKernelBody onOp,
          mapOnSegOpLambda = optimiseLambda onOp
        }

-- | Multicore 'Op' handler.  For 'ParOp' it recurses into both seg ops.
-- A SOAC selected by the stage is rewritten with the first-order
-- transformation, and the resulting statements are processed again.
-- Any other SOAC is kept, but its lambdas are processed.
onMCOp :: Stage -> OnOp MC
onMCOp stage pat aux (ParOp par_op op) = do
  par_op' <- traverse (optimiseSegOp (onMCOp stage)) par_op
  op' <- optimiseSegOp (onMCOp stage) op
  pure [Let pat aux $ Op $ ParOp par_op' op']
onMCOp stage pat aux (MC.OtherOp soac)
  | sequentialise stage soac = do
      stms <- runBuilder_ $ auxing aux $ FOT.transformSOAC pat soac
      fmap concat . localScope (scopeOf stms) $
        mapM (optimiseStm (onMCOp stage)) (stmsToList stms)
  | otherwise =
      -- Still sequentialise whatever's inside.
      pure <$> (Let pat aux . Op . MC.OtherOp <$> mapSOACM optimise soac)
  where
    optimise =
      identitySOACMapper
        { mapOnSOACLambda = optimiseLambda (onMCOp stage)
        }

-- | Whether a SOAC is to be sequentialised in the given stage.
sequentialise :: Stage -> SOAC rep -> Bool
sequentialise SeqStreams Stream {} = True
sequentialise SeqStreams _ = False
sequentialise SeqAll _ = True

-- | GPU host 'Op' handler.  SOACs are treated as in 'onMCOp'.  For a
-- 'SegOp', the traversal recurses into its bodies.  Any other host op
-- is kept unchanged.
onHostOp :: Stage -> OnOp GPU
onHostOp stage pat aux (GPU.OtherOp soac)
  | sequentialise stage soac = do
      stms <- runBuilder_ $ auxing aux $ FOT.transformSOAC pat soac
      fmap concat . localScope (scopeOf stms) $
        mapM (optimiseStm (onHostOp stage)) (stmsToList stms)
  | otherwise =
      -- Still sequentialise whatever's inside.
      pure <$> (Let pat aux . Op . GPU.OtherOp <$> mapSOACM optimise soac)
  where
    optimise =
      identitySOACMapper
        { mapOnSOACLambda = optimiseLambda (onHostOp stage)
        }
onHostOp stage pat aux (SegOp op) =
  pure <$> (Let pat aux . Op . SegOp <$> optimiseSegOp (onHostOp stage) op)
onHostOp _ pat aux op = pure [Let pat aux $ Op op]
futhark-0.25.27/src/Futhark/Pass.hs000066400000000000000000000067411475065116200170070ustar00rootroot00000000000000
{-# LANGUAGE Strict #-}

-- | Definition of a polymorphic (generic) pass that can work with
-- programs of any rep.
module Futhark.Pass
  ( PassM,
    runPassM,
    Pass (..),
    passLongOption,
    parPass,
    intraproceduralTransformation,
    intraproceduralTransformationWithConsts,
  )
where

import Control.Monad.State.Strict
import Control.Monad.Writer.Strict
import Control.Parallel.Strategies
import Data.Char
import Futhark.IR
import Futhark.MonadFreshNames
import Futhark.Util.Log
import Prelude hiding (log)

-- | The monad in which passes execute: a writer of 'Log' messages
-- layered over a state holding the name source.
newtype PassM a = PassM (WriterT Log (State VNameSource) a)
  deriving (Functor, Applicative, Monad)

instance MonadLogger PassM where
  addLog = PassM . tell

instance MonadFreshNames PassM where
  putNameSource = PassM . put
  getNameSource = PassM get

-- | Execute a 'PassM' action, yielding its result together with the
-- accumulated log.  The name source is taken from, and returned to,
-- the enclosing 'MonadFreshNames' monad.
runPassM ::
  (MonadFreshNames m) =>
  PassM a ->
  m (a, Log)
runPassM (PassM m) = modifyNameSource $ runState (runWriterT m)

-- | A compiler pass transforming a 'Prog' of a given rep to a 'Prog'
-- of another rep.
data Pass fromrep torep = Pass
  { -- | Name of the pass.  Keep this short and simple.  It will
    -- be used to automatically generate a command-line option
    -- name via 'passLongOption'.
    passName :: String,
    -- | A slightly longer description, which will show up in the
    -- command-line --help option.
    passDescription :: String,
    -- | The program transformation performed by the pass.
    passFunction :: Prog fromrep -> PassM (Prog torep)
  }

-- | Take the name of the pass, turn spaces into dashes, and make all
-- characters lowercase.
passLongOption :: Pass fromrep torep -> String
passLongOption = map (spaceToDash . toLower) . passName
  where
    spaceToDash ' ' = '-'
    spaceToDash c = c

-- | Apply a 'PassM' operation in parallel to multiple elements,
-- joining together the name sources and logs.  Every element is
-- processed starting from the same name source.  The resulting name
-- sources are combined with 'mconcat', and the logs are concatenated
-- in input order.
parPass :: (a -> PassM b) -> [a] -> PassM [b]
parPass f as = do
  (x, log) <- modifyNameSource $ \src ->
    let (bs, logs, srcs) = unzip3 $ parMap rpar (f' src) as
     in ((bs, mconcat logs), mconcat srcs)

  addLog log
  pure x
  where
    f' src a =
      let ((x', log), src') = runState (runPassM (f a)) src
       in (x', log, src')

-- | Apply some operation to the top-level constants.  Then applies an
-- operation to all the function definitions, which are also given the
-- transformed constants so they can be brought into scope.
-- The function definition transformations are run in parallel (with
-- 'parPass'), since they cannot affect each other.
intraproceduralTransformationWithConsts ::
  (Stms fromrep -> PassM (Stms torep)) ->
  (Stms torep -> FunDef fromrep -> PassM (FunDef torep)) ->
  Prog fromrep ->
  PassM (Prog torep)
intraproceduralTransformationWithConsts ct ft prog = do
  consts' <- ct (progConsts prog)
  funs' <- parPass (ft consts') (progFuns prog)
  pure $ prog {progConsts = consts', progFuns = funs'}

-- | Like 'intraproceduralTransformationWithConsts', but do not change
-- the top-level constants, and simply pass along their 'Scope'.
-- The constants are processed in an empty scope.  Each function body
-- is processed in the scope of the constants and the function's
-- parameters.
intraproceduralTransformation ::
  (Scope rep -> Stms rep -> PassM (Stms rep)) ->
  Prog rep ->
  PassM (Prog rep)
intraproceduralTransformation f =
  intraproceduralTransformationWithConsts (f mempty) f'
  where
    f' consts fd = do
      stms <-
        f
          (scopeOf consts <> scopeOfFParams (funDefParams fd))
          (bodyStms $ funDefBody fd)
      pure fd {funDefBody = (funDefBody fd) {bodyStms = stms}}
futhark-0.25.27/src/Futhark/Pass/000077500000000000000000000000001475065116200164435ustar00rootroot00000000000000
futhark-0.25.27/src/Futhark/Pass/AD.hs000066400000000000000000000065771475065116200173020ustar00rootroot00000000000000
{-# LANGUAGE TypeFamilies #-}

-- | Apply all AD operators in the program, leaving AD-free code.
module Futhark.Pass.AD (applyAD, applyADInnermost) where

import Control.Monad
import Control.Monad.Reader
import Futhark.AD.Fwd (fwdJVP)
import Futhark.AD.Rev (revVJP)
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.IR.SOACS.Simplify (simplifyLambda)
import Futhark.Pass

-- | Whether we apply only the innermost AD operators, or all of them.
-- The former is very useful for debugging, but probably not useful
-- for actual compilation.
data Mode = Innermost | All
  deriving (Eq)

-- | Inline a lambda applied to @args@.  Each parameter is bound to its
-- argument; array-typed parameters use a 'Replicate' with an empty
-- shape, and only these bindings carry @aux@.  The body is then bound,
-- and each result is bound to the corresponding pattern name under the
-- result's certificates.
bindLambda ::
  (MonadBuilder m, Rep m ~ SOACS) =>
  Pat Type ->
  StmAux (ExpDec SOACS) ->
  Lambda SOACS ->
  [SubExp] ->
  m ()
bindLambda pat aux (Lambda params _ body) args = do
  auxing aux . forM_ (zip params args) $ \(param, arg) ->
    letBindNames [paramName param] $
      BasicOp $ case paramType param of
        Array {} -> Replicate mempty arg
        _ -> SubExp arg
  res <- bodyBind body
  forM_ (zip (patNames pat) res) $ \(v, SubExpRes cs se) ->
    certifying cs $ letBindNames [v] $ BasicOp $ SubExp se

-- | Eliminate a 'VJP' or 'JVP' statement by differentiating its lambda
-- and inlining the result.  The lambda is always processed recursively
-- first.  In 'Innermost' mode, the operator is only differentiated if
-- that processing left the lambda unchanged; otherwise the operator is
-- kept with its processed lambda.  Other statements are traversed
-- recursively.
onStm :: Mode -> Scope SOACS -> Stm SOACS -> PassM (Stms SOACS)
onStm mode scope (Let pat aux (Op (VJP args vec lam))) = do
  lam' <- onLambda mode scope lam
  if mode == All || lam == lam'
    then do
      -- The reverse-mode result is simplified before being inlined.
      lam'' <- (`runReaderT` scope) . simplifyLambda =<< revVJP scope lam'
      runBuilderT_ (bindLambda pat aux lam'' $ args ++ vec) scope
    else pure $ oneStm $ Let pat aux $ Op $ VJP args vec lam'
onStm mode scope (Let pat aux (Op (JVP args vec lam))) = do
  lam' <- onLambda mode scope lam
  if mode == All || lam == lam'
    then do
      lam'' <- fwdJVP scope lam'
      runBuilderT_ (bindLambda pat aux lam'' $ args ++ vec) scope
    else pure $ oneStm $ Let pat aux $ Op $ JVP args vec lam'
onStm mode scope (Let pat aux e) = oneStm . Let pat aux <$> mapExpM mapper e
  where
    mapper =
      (identityMapper @SOACS)
        { mapOnBody = \bscope -> onBody mode (bscope <> scope),
          mapOnOp = mapSOACM soac_mapper
        }
    soac_mapper = identitySOACMapper {mapOnSOACLambda = onLambda mode scope}

-- | Process statements with their own bindings added to the scope.
onStms :: Mode -> Scope SOACS -> Stms SOACS -> PassM (Stms SOACS)
onStms mode scope stms = mconcat <$> mapM (onStm mode scope') (stmsToList stms)
  where
    scope' = scopeOf stms <> scope

onBody :: Mode -> Scope SOACS -> Body SOACS -> PassM (Body SOACS)
onBody mode scope body = do
  stms <- onStms mode scope $ bodyStms body
  pure $ body {bodyStms = stms}

onLambda :: Mode -> Scope SOACS -> Lambda SOACS -> PassM (Lambda SOACS)
onLambda mode scope lam = do
  body <- onBody mode (scopeOfLParams (lambdaParams lam) <> scope) $ lambdaBody lam
  pure $ lam {lambdaBody = body}

-- | Process a function body with the constants and the function's own
-- bindings in scope.
onFun :: Mode -> Stms SOACS -> FunDef SOACS -> PassM (FunDef SOACS)
onFun mode consts fd = do
  body <- onBody mode (scopeOf consts <> scopeOf fd) $ funDefBody fd
  pure $ fd {funDefBody = body}

-- | Pass that eliminates all AD operators.
applyAD :: Pass SOACS SOACS
applyAD =
  Pass
    { passName = "ad",
      passDescription = "Apply AD operators",
      passFunction =
        intraproceduralTransformationWithConsts
          (onStms All mempty)
          (onFun All)
    }

-- | Pass that eliminates only the innermost AD operators (see 'Mode').
applyADInnermost :: Pass SOACS SOACS
applyADInnermost =
  Pass
    { passName = "ad innermost",
      passDescription = "Apply innermost AD operators",
      passFunction =
        intraproceduralTransformationWithConsts
          (onStms Innermost mempty)
          (onFun Innermost)
    }
futhark-0.25.27/src/Futhark/Pass/ExpandAllocations.hs000066400000000000000000001106011475065116200224060ustar00rootroot00000000000000
{-# LANGUAGE TypeFamilies #-}

-- | Expand allocations inside of maps when possible.
module Futhark.Pass.ExpandAllocations (expandAllocations) where

import Control.Monad
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Bifunctor
import Data.Either (rights)
import Data.List (find, foldl')
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Sequence qualified as Seq
import Futhark.Analysis.Alias as Alias
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Error
import Futhark.IR
import Futhark.IR.GPU.Simplify qualified as GPU
import Futhark.IR.GPUMem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Rep (addScopeWisdom)
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations.GPU (explicitAllocationsInStms)
import Futhark.Pass.ExtractKernels.BlockedKernel (nonSegRed)
import Futhark.Pass.ExtractKernels.ToGPU (segThread)
import Futhark.Tools
import Futhark.Transform.CopyPropagate (copyPropagateInFun)
import Futhark.Transform.Rename (renameStm)
import Futhark.Transform.Substitute
import Futhark.Util (mapAccumLM)
import Prelude hiding (quot)

-- | The memory expansion pass definition.  The top-level constants are
-- transformed first, in an empty scope.  The function definitions are
-- then transformed one after another (with 'mapM', not in parallel),
-- each in the scope of the transformed constants.  A failure in the
-- constants is reported through 'limitationOnLeft'.
expandAllocations :: Pass GPUMem GPUMem
expandAllocations =
  Pass "expand allocations" "Expand allocations" $ \prog -> do
    consts' <-
      modifyNameSource $
        limitationOnLeft . runStateT (runReaderT (transformStms (progConsts prog)) mempty)
    funs' <- mapM (transformFunDef $ scopeOf consts') (progFuns prog)
    pure $ prog {progConsts = consts', progFuns = funs'}

-- Cannot use intraproceduralTransformation because it might create
-- duplicate size keys (which are not fixed by renamer, and size
-- keys must currently be globally unique).
-- | The transformation monad: the current scope, a name source, and
-- failure with an error message (reported via 'limitationOnLeft').
type ExpandM = ReaderT (Scope GPUMem) (StateT VNameSource (Either String))

-- | Turn a failure into a compiler-limitation error.
limitationOnLeft :: Either String a -> a
limitationOnLeft = either compilerLimitationS id

-- | Expand allocations in a function body (with the given scope and the
-- function's own bindings in scope), then copy-propagate the result.
transformFunDef ::
  Scope GPUMem ->
  FunDef GPUMem ->
  PassM (FunDef GPUMem)
transformFunDef scope fundec = do
  body' <- modifyNameSource $ limitationOnLeft . runStateT (runReaderT m mempty)
  copyPropagateInFun
    simpleGPUMem
    (ST.fromScope (addScopeWisdom scope))
    fundec {funDefBody = body'}
  where
    m =
      localScope scope $
        inScopeOf fundec $
          transformBody $
            funDefBody fundec

transformBody :: Body GPUMem -> ExpandM (Body GPUMem)
transformBody (Body () stms res) = Body () <$> transformStms stms <*> pure res

transformLambda :: Lambda GPUMem -> ExpandM (Lambda GPUMem)
transformLambda (Lambda params ret body) =
  Lambda params ret
    <$> localScope (scopeOfLParams params) (transformBody body)

-- | Transform statements in order, with their own bindings in scope.
transformStms :: Stms GPUMem -> ExpandM (Stms GPUMem)
transformStms stms =
  inScopeOf stms $ mconcat <$> mapM transformStm (stmsToList stms)

transformStm :: Stm GPUMem -> ExpandM (Stms GPUMem)
-- It is possible that we are unable to expand allocations in some
-- code versions.  If so, we can remove the offending branch.  Only if
-- all versions fail do we propagate the error.
-- FIXME: this can remove safety checks if the default branch fails!
transformStm (Let pat aux (Match cond cases defbody (MatchDec ts MatchEquiv))) = do
  let onCase (Case vs body) =
        (Right . Case vs <$> transformBody body) `catchError` (pure . Left)
  cases' <- rights <$> mapM onCase cases
  defbody' <- (Right <$> transformBody defbody) `catchError` (pure . Left)
  case (cases', defbody') of
    ([], Left e) ->
      throwError e
    (_ : _, Left _) ->
      -- The last surviving case becomes the new default branch.
      pure $
        oneStm $
          Let pat aux $
            Match cond (init cases') (caseBody $ last cases') (MatchDec ts MatchEquiv)
    (_, Right defbody'') ->
      pure $ oneStm $ Let pat aux $ Match cond cases' defbody'' (MatchDec ts MatchEquiv)
transformStm (Let pat aux e) = do
  (stms, e') <- transformExp =<< mapExpM transform e
  pure $ stms <> oneStm (Let pat aux e')
  where
    transform =
      identityMapper
        { mapOnBody = \scope -> localScope scope . transformBody
        }

-- | Hoist and expand allocations from 'SegOp' kernel bodies and operator
-- lambdas, and from the operator lambdas of 'WithAcc' inputs.  Returns
-- the statements that must precede the expression (sizes and
-- allocations) together with the rewritten expression.
transformExp :: Exp GPUMem -> ExpandM (Stms GPUMem, Exp GPUMem)
transformExp (Op (Inner (SegOp (SegMap lvl space ts kbody)))) = do
  (alloc_stms, (lvl', _, kbody')) <- transformScanRed lvl space [] kbody
  pure
    ( alloc_stms,
      Op $ Inner $ SegOp $ SegMap lvl' space ts kbody'
    )
transformExp (Op (Inner (SegOp (SegRed lvl space reds ts kbody)))) = do
  (alloc_stms, (lvl', lams, kbody')) <-
    transformScanRed lvl space (map segBinOpLambda reds) kbody
  let reds' = zipWith (\red lam -> red {segBinOpLambda = lam}) reds lams
  pure
    ( alloc_stms,
      Op $ Inner $ SegOp $ SegRed lvl' space reds' ts kbody'
    )
transformExp (Op (Inner (SegOp (SegScan lvl space scans ts kbody)))) = do
  (alloc_stms, (lvl', lams, kbody')) <-
    transformScanRed lvl space (map segBinOpLambda scans) kbody
  let scans' = zipWith (\red lam -> red {segBinOpLambda = lam}) scans lams
  pure
    ( alloc_stms,
      Op $ Inner $ SegOp $ SegScan lvl' space scans' ts kbody'
    )
transformExp (Op (Inner (SegOp (SegHist lvl space ops ts kbody)))) = do
  (alloc_stms, (lvl', lams', kbody')) <- transformScanRed lvl space lams kbody
  let ops' = zipWith onOp ops lams'
  pure
    ( alloc_stms,
      Op $ Inner $ SegOp $ SegHist lvl' space ops' ts kbody'
    )
  where
    lams = map histOp ops
    onOp op lam = op {histOp = lam}
transformExp (WithAcc inputs lam) = do
  lam' <- transformLambda lam
  (input_alloc_stms, inputs') <- mapAndUnzipM onInput inputs
  pure
    ( mconcat input_alloc_stms,
      WithAcc inputs' lam'
    )
  where
    onInput (shape, arrs, Nothing) =
      pure (mempty, (shape, arrs, Nothing))
    onInput (shape, arrs, Just (op_lam, nes)) = do
      bound_outside <- asks $ namesFromList . M.keys
      -- XXX: fake a SegLevel, which we don't have here.  We will not
      -- use it for anything, as we will not allow irregular
      -- allocations inside the update function.
      let lvl = SegThread SegNoVirt Nothing
          (op_lam', lam_allocs) =
            extractLambdaAllocations (lvl, [0]) bound_outside mempty op_lam
          variantAlloc (_, Var v, _) = v `notNameIn` bound_outside
          variantAlloc _ = False
          (variant_allocs, invariant_allocs) = M.partition variantAlloc lam_allocs

      case M.elems variant_allocs of
        (_, v, _) : _ ->
          throwError $
            "Cannot handle un-sliceable allocation size: "
              ++ prettyString v
              ++ "\nLikely cause: irregular nested operations inside accumulator update operator."
        [] ->
          pure ()

      -- One copy per accumulator element, indexed by the operator's
      -- leading index parameters.
      let num_is = shapeRank shape
          is = map paramName $ take num_is $ lambdaParams op_lam
      (alloc_stms, alloc_offsets) <-
        genericExpandedInvariantAllocations (const $ const (shape, map le64 is)) invariant_allocs

      scope <- askScope
      let scope' = scopeOf op_lam <> scope <> scopeOf alloc_stms
      either throwError pure <=< runOffsetM scope' $ do
        op_lam'' <- offsetMemoryInLambda alloc_offsets op_lam'
        pure (alloc_stms, (shape, arrs, Just (op_lam'', nes)))
transformExp e =
  pure (mempty, e)

-- | Return the kernel grid of a level.  If the level has none, statements
-- are generated that query the sizes @num_tblocks@ and @tblock_size@, and
-- the level is updated with the new grid.  Fails with 'error' on
-- 'SegThreadInBlock'.
ensureGridKnown :: SegLevel -> ExpandM (Stms GPUMem, SegLevel, KernelGrid)
ensureGridKnown lvl =
  case lvl of
    SegThread _ (Just grid) -> pure (mempty, lvl, grid)
    SegBlock _ (Just grid) -> pure (mempty, lvl, grid)
    SegThread virt Nothing -> mkGrid (SegThread virt)
    SegBlock virt Nothing -> mkGrid (SegBlock virt)
    SegThreadInBlock {} -> error "ensureGridKnown: SegThreadInBlock"
  where
    mkGrid f = do
      (grid, stms) <-
        runBuilder $
          KernelGrid
            <$> (Count <$> getSize "num_tblocks" SizeGrid)
            <*> (Count <$> getSize "tblock_size" SizeThreadBlock)
      pure (stms, f $ Just grid, grid)

    getSize desc size_class = do
      size_key <- nameFromString . prettyString <$> newVName desc
      letSubExp desc $ Op $ Inner $ SizeOp $ GetSize size_key size_class

-- | Extract the allocations from a kernel body and its operator lambdas,
-- expand them into per-thread buffers allocated before the kernel, and
-- rewrite memory references in the body and lambdas accordingly.
--
-- An allocation is /variant/ if its size is not bound outside the
-- kernel.  This function fails if a variant size is not bound by the
-- segment space or the top-level kernel statements, or if a 'SegBlock'
-- has any variant allocations.
transformScanRed ::
  SegLevel ->
  SegSpace ->
  [Lambda GPUMem] ->
  KernelBody GPUMem ->
  ExpandM (Stms GPUMem, (SegLevel, [Lambda GPUMem], KernelBody GPUMem))
transformScanRed lvl space ops kbody = do
  bound_outside <- asks $ namesFromList . M.keys
  let user = (lvl, [le64 $ segFlat space])
      (kbody', kbody_allocs) =
        extractKernelBodyAllocations user bound_outside bound_in_kernel kbody
      (ops', ops_allocs) = unzip $ map (extractLambdaAllocations user bound_outside mempty) ops
      variantAlloc (_, Var v, _) = v `notNameIn` bound_outside
      variantAlloc _ = False
      (variant_allocs, invariant_allocs) =
        M.partition variantAlloc $ kbody_allocs <> mconcat ops_allocs
      badVariant (_, Var v, _) = v `notNameIn` bound_in_kernel
      badVariant _ = False

  case find badVariant $ M.elems variant_allocs of
    Just (_, size, _) ->
      -- Report only the size, as the message says (and as the WithAcc
      -- case in 'transformExp' does).
      throwError $
        "Cannot handle un-sliceable allocation size: "
          ++ prettyString size
          ++ "\nLikely cause: irregular nested operations inside parallel constructs."
    Nothing ->
      pure ()

  -- Invariant allocations are supported for SegBlock (see
  -- 'expandedInvariantAllocations'); it is the variant ones that are not.
  case lvl of
    SegBlock {}
      | not $ null variant_allocs ->
          throwError "Cannot handle variant allocations in SegBlock."
    _ ->
      pure ()

  if null variant_allocs && null invariant_allocs
    then pure (mempty, (lvl, ops, kbody))
    else do
      (lvl_stms, lvl', grid) <- ensureGridKnown lvl
      allocsForBody variant_allocs invariant_allocs grid space kbody kbody' $ \offsets alloc_stms kbody'' -> do
        ops'' <- forM ops' $ \op' ->
          localScope (scopeOf op') $ offsetMemoryInLambda offsets op'
        pure (lvl_stms <> alloc_stms, (lvl', ops'', kbody''))
  where
    bound_in_kernel =
      namesFromList (M.keys $ scopeOfSegSpace space)
        <> boundInKernelBody kbody

-- | Names bound by the top-level statements of a kernel body.
boundInKernelBody :: KernelBody GPUMem -> Names
boundInKernelBody = namesFromList . M.keys . scopeOf . kernelBodyStms

-- | Prepend statements to a kernel body.
addStmsToKernelBody :: Stms GPUMem -> KernelBody GPUMem -> KernelBody GPUMem
addStmsToKernelBody stms kbody =
  kbody {kernelBodyStms = stms <> kernelBodyStms kbody}

-- | Compute the expanded allocations for a kernel and rewrite the
-- (allocation-free) body @kbody'@ against them.  Shared-memory
-- allocations are put back at the start of the body.  The continuation
-- receives the rebase map, the remaining (non-shared) allocation
-- statements, and the rewritten body.
allocsForBody ::
  Extraction ->
  Extraction ->
  KernelGrid ->
  SegSpace ->
  KernelBody GPUMem ->
  KernelBody GPUMem ->
  (RebaseMap -> Stms GPUMem -> KernelBody GPUMem -> OffsetM b) ->
  ExpandM b
allocsForBody variant_allocs invariant_allocs grid space kbody kbody' m = do
  (alloc_offsets, alloc_stms) <-
    memoryRequirements
      grid
      space
      (kernelBodyStms kbody)
      variant_allocs
      invariant_allocs

  -- We assume that any shared memory allocations can be inserted back
  -- into kbody'.  This would not work if we had SegRed/SegScan
  -- operations that performed shared memory allocations.  We don't
  -- currently, and if we would in the future, we would need to be
  -- more careful about summarising the allocations in
  -- transformScanRed.
  let (alloc_stms_dev, alloc_stms_shared) =
        Seq.partition (not . isSharedAlloc) alloc_stms

  scope <- askScope
  let scope' = scopeOfSegSpace space <> scope <> scopeOf alloc_stms
  either throwError pure <=< runOffsetM scope' $ do
    kbody'' <-
      addStmsToKernelBody alloc_stms_shared
        <$> offsetMemoryInKernelBody alloc_offsets kbody'
    m alloc_offsets alloc_stms_dev kbody''
  where
    isSharedAlloc (Let _ _ (Op (Alloc _ (Space "shared")))) = True
    isSharedAlloc _ = False

-- | Compute the total number of threads (blocks times block size), then
-- generate the expanded invariant and variant allocations.  Returns the
-- combined rebase map and all generated statements.
memoryRequirements ::
  KernelGrid ->
  SegSpace ->
  Stms GPUMem ->
  Extraction ->
  Extraction ->
  ExpandM (RebaseMap, Stms GPUMem)
memoryRequirements grid space kstms variant_allocs invariant_allocs = do
  (num_threads, num_threads_stms) <-
    runBuilder . letSubExp "num_threads" . BasicOp $
      BinOp
        (Mul Int64 OverflowUndef)
        (unCount $ gridNumBlocks grid)
        (unCount $ gridBlockSize grid)

  (invariant_alloc_stms, invariant_alloc_offsets) <-
    inScopeOf num_threads_stms $
      expandedInvariantAllocations
        num_threads
        (gridNumBlocks grid)
        (gridBlockSize grid)
        invariant_allocs

  (variant_alloc_stms, variant_alloc_offsets) <-
    inScopeOf num_threads_stms $
      expandedVariantAllocations num_threads space kstms variant_allocs

  pure
    ( invariant_alloc_offsets <> variant_alloc_offsets,
      num_threads_stms <> invariant_alloc_stms <> variant_alloc_stms
    )

-- | A 64-bit integer expression over variable names.
type Exp64 = TPrimExp Int64 VName

-- | Identifying the spot where an allocation occurs in terms of its
-- level and unique thread ID.
type User = (SegLevel, [Exp64])

-- | A description of allocations that have been extracted, and how
-- much memory (and which space) is needed.
type Extraction = M.Map VName (User, SubExp, Space)

-- | Remove the extractable allocations from a kernel body.
extractKernelBodyAllocations ::
  User ->
  Names ->
  Names ->
  KernelBody GPUMem ->
  ( KernelBody GPUMem,
    Extraction
  )
extractKernelBodyAllocations lvl bound_outside bound_kernel =
  extractGenericBodyAllocations lvl bound_outside bound_kernel kernelBodyStms $
    \stms kbody -> kbody {kernelBodyStms = stms}

-- | Remove the extractable allocations from a body.
extractBodyAllocations ::
  User ->
  Names ->
  Names ->
  Body GPUMem ->
  (Body GPUMem, Extraction)
extractBodyAllocations user bound_outside bound_kernel =
  extractGenericBodyAllocations user bound_outside bound_kernel bodyStms $
    \stms body -> body {bodyStms = stms}

-- | Remove the extractable allocations from a lambda body.
extractLambdaAllocations ::
  User ->
  Names ->
  Names ->
  Lambda GPUMem ->
  (Lambda GPUMem, Extraction)
extractLambdaAllocations user bound_outside bound_kernel lam =
  (lam {lambdaBody = body'}, allocs)
  where
    (body', allocs) = extractBodyAllocations user bound_outside bound_kernel $ lambdaBody lam

-- | Shared implementation of the @extract*Allocations@ functions.  The
-- names bound by the body's statements count as bound in the kernel.
extractGenericBodyAllocations ::
  User ->
  Names ->
  Names ->
  (body -> Stms GPUMem) ->
  (Stms GPUMem -> body -> body) ->
  body ->
  ( body,
    Extraction
  )
extractGenericBodyAllocations user bound_outside bound_kernel get_stms set_stms body =
  let bound_kernel' = bound_kernel <> boundByStms (get_stms body)
      (stms, allocs) =
        runWriter . fmap catMaybes $
          mapM (extractStmAllocations user bound_outside bound_kernel') $
            stmsToList (get_stms body)
   in (set_stms (stmsFromList stms) body, allocs)

-- | Whether an allocation in the given space may be expanded.  Shared
-- memory in a 'SegBlock' and scalar space are never expandable.
expandable :: User -> Space -> Bool
expandable (SegBlock {}, _) (Space "shared") = False
expandable _ ScalarSpace {} = False
expandable _ _ = True

notScalar :: Space -> Bool
notScalar ScalarSpace {} = False
notScalar _ = True

-- | Record (and remove) a single-element 'Alloc' statement if it can be
-- extracted.  Otherwise, recurse into nested bodies, 'SegOp' kernel
-- bodies, and lambdas.  Nested 'SegOp's extend the user's thread IDs
-- with their own flat ID.
extractStmAllocations ::
  User ->
  Names ->
  Names ->
  Stm GPUMem ->
  Writer Extraction (Maybe (Stm GPUMem))
extractStmAllocations user bound_outside bound_kernel (Let (Pat [patElem]) _ (Op (Alloc size space)))
  | expandable user space && expandableSize size
      -- FIXME: the '&& notScalar space' part is a hack because we
      -- don't otherwise hoist the sizes out far enough, and we
      -- promise to be super-duper-careful about not having variant
      -- scalar allocations.
      || (boundInKernel size && notScalar space) = do
      tell $ M.singleton (patElemName patElem) (user, size, space)
      pure Nothing
  where
    expandableSize (Var v) = v `nameIn` bound_outside || v `nameIn` bound_kernel
    expandableSize Constant {} = True
    boundInKernel (Var v) = v `nameIn` bound_kernel
    boundInKernel Constant {} = False
extractStmAllocations user bound_outside bound_kernel stm = do
  e <- mapExpM (expMapper user) $ stmExp stm
  pure $ Just $ stm {stmExp = e}
  where
    expMapper user' =
      (identityMapper @GPUMem)
        { mapOnBody = const $ onBody user',
          mapOnOp = onOp user'
        }

    onBody user' body = do
      let (body', allocs) = extractBodyAllocations user' bound_outside bound_kernel body
      tell allocs
      pure body'

    onOp (_, user_ids) (Inner (SegOp op)) =
      Inner . SegOp <$> mapSegOpM (opMapper user'') op
      where
        user'' = (segLevel op, user_ids ++ [le64 (segFlat (segSpace op))])
    onOp _ op = pure op

    opMapper user' =
      identitySegOpMapper
        { mapOnSegOpLambda = onLambda user',
          mapOnSegOpBody = onKernelBody user'
        }

    onKernelBody user' body = do
      let (body', allocs) = extractKernelBodyAllocations user' bound_outside bound_kernel body
      tell allocs
      pure body'

    onLambda user' lam = do
      body <- onBody user' $ lambdaBody lam
      pure lam {lambdaBody = body}

-- | For each invariant allocation, allocate its per-user size times the
-- number of users.  @getNumUsers@ gives the users' shape and the
-- current user's indices.  The memory block is mapped to an embedding
-- made of the flattened user index and the total number of users.
genericExpandedInvariantAllocations ::
  (User -> Space -> (Shape, [Exp64])) -> Extraction -> ExpandM (Stms GPUMem, RebaseMap)
genericExpandedInvariantAllocations getNumUsers invariant_allocs = do
  -- We expand the invariant allocations by adding an inner dimension
  -- equal to the number of kernel threads.
  (rebases, alloc_stms) <- runBuilder $ mapM expand $ M.toList invariant_allocs

  pure (alloc_stms, mconcat rebases)
  where
    expand (mem, (user, per_thread_size, space)) = do
      let num_users = fst $ getNumUsers user space
          allocpat = Pat [PatElem mem $ MemMem space]
      total_size <-
        letExp "total_size" <=< toExp . product $
          pe64 per_thread_size : map pe64 (shapeDims num_users)
      letBind allocpat $ Op $ Alloc (Var total_size) space
      pure $ M.singleton mem $ newBase user space

    -- The embedding is the same for every level: the flattened index of
    -- the current user, and the total number of users.
    newBase user space _old_shape =
      let (users_shape, user_ids) = getNumUsers user space
          dims = map pe64 (shapeDims users_shape)
       in ( flattenIndex dims user_ids,
            product dims
          )

-- | 'genericExpandedInvariantAllocations' with the number of users
-- determined by the level, the number of thread IDs, and (for
-- 'SegThreadInBlock') the memory space.
expandedInvariantAllocations ::
  SubExp ->
  Count NumBlocks SubExp ->
  Count BlockSize SubExp ->
  Extraction ->
  ExpandM (Stms GPUMem, RebaseMap)
expandedInvariantAllocations num_threads (Count num_tblocks) (Count tblock_size) =
  genericExpandedInvariantAllocations getNumUsers
  where
    getNumUsers (SegThread {}, [gtid]) _ = (Shape [num_threads], [gtid])
    getNumUsers (SegThread {}, [gid, ltid]) _ = (Shape [num_tblocks, tblock_size], [gid, ltid])
    getNumUsers (SegThreadInBlock {}, [gtid]) _ = (Shape [num_threads], [gtid])
    getNumUsers (SegThreadInBlock {}, [_gid, ltid]) (Space "shared") = (Shape [tblock_size], [ltid])
    getNumUsers (SegThreadInBlock {}, [gid, ltid]) (Space "device") = (Shape [num_tblocks, tblock_size], [gid, ltid])
    getNumUsers (SegBlock {}, [gid]) _ = (Shape [num_tblocks], [gid])
    getNumUsers user space = error $ "getNumUsers: unhandled " ++ show (user, space)

-- | Expand the variant allocations.  Kernels that compute per-thread
-- offsets and total sizes are generated (via 'sliceKernelSizes') and
-- themselves expanded recursively.  Then one buffer of the total size
-- is allocated per memory block.
expandedVariantAllocations ::
  SubExp ->
  SegSpace ->
  Stms GPUMem ->
  Extraction ->
  ExpandM (Stms GPUMem, RebaseMap)
expandedVariantAllocations _ _ _ variant_allocs
  | null variant_allocs = pure (mempty, mempty)
expandedVariantAllocations num_threads kspace kstms variant_allocs = do
  let sizes_to_blocks = removeCommonSizes variant_allocs
      variant_sizes = map fst sizes_to_blocks

  (slice_stms, offsets, size_sums) <-
    sliceKernelSizes num_threads variant_sizes kspace kstms
  -- Note the recursive call to expand allocations inside the newly
  -- produced kernels.
  slice_stms_tmp <- simplifyStms =<< explicitAllocationsInStms slice_stms
  slice_stms' <- transformStms slice_stms_tmp

  let variant_allocs' :: [(VName, (SubExp, SubExp, Space))]
      variant_allocs' =
        concat $
          zipWith
            memInfo
            (map snd sizes_to_blocks)
            (zip offsets size_sums)
      memInfo blocks (offset, total_size) =
        [(mem, (Var offset, Var total_size, space)) | (mem, space) <- blocks]

  -- We expand the variant allocations by adding an inner dimension
  -- equal to the sum of the sizes required by different threads.
  (alloc_stms, rebases) <- mapAndUnzipM expand variant_allocs'

  pure (slice_stms' <> stmsFromList alloc_stms, mconcat rebases)
  where
    expand (mem, (_offset, total_size, space)) = do
      let allocpat = Pat [PatElem mem $ MemMem space]
      pure
        ( Let allocpat (defAux ()) $ Op $ Alloc total_size space,
          M.singleton mem newBase
        )

    num_threads' = pe64 num_threads
    gtid = le64 $ segFlat kspace

    -- For the variant allocations, we add an inner dimension,
    -- which is then offset by a thread-specific amount.
    newBase _old_shape =
      (gtid, num_threads')

-- | A pair of a user index and a user count.  These are presumably
-- consumed as the offset/stride of the added dimension -- TODO confirm
-- against 'offsetMemoryInMemBound'.
type Expansion = (Exp64, Exp64)

-- | A map from memory block names to index function embeddings.
type RebaseMap = M.Map VName ([Exp64] -> Expansion)

--- Modifying the index functions of code.

newtype OffsetM a
  = OffsetM
      ( BuilderT
          GPUMem
          (StateT VNameSource (Either String))
          a
      )
  deriving
    ( Applicative,
      Functor,
      Monad,
      HasScope GPUMem,
      LocalScope GPUMem,
      MonadError String,
      MonadFreshNames
    )

instance MonadBuilder OffsetM where
  type Rep OffsetM = GPUMem
  mkExpDecM pat e = OffsetM $ mkExpDecM pat e
  mkBodyM stms res = OffsetM $ mkBodyM stms res
  mkLetNamesM pat e = OffsetM $ mkLetNamesM pat e

  addStms = OffsetM . addStms
  collectStms (OffsetM m) = OffsetM $ collectStms m

-- | Run an 'OffsetM' computation in the given scope.  On failure the
-- name source is left untouched.  Statements added at the top level of
-- the computation are discarded; only its result is returned.
runOffsetM ::
  (MonadFreshNames m) =>
  Scope GPUMem ->
  OffsetM a ->
  m (Either String a)
runOffsetM scope (OffsetM m) = modifyNameSource $ \src ->
  case runStateT (runBuilderT m scope) src of
    Left e -> (Left e, src)
    Right (x, src') -> (Right (fst x), src')

-- | Apply the embedding for a memory block, if it has been rebased.
lookupNewBase :: VName -> [Exp64] -> RebaseMap -> Maybe Expansion
lookupNewBase name x offsets =
  ($ x) <$> M.lookup name offsets

offsetMemoryInKernelBody :: RebaseMap -> KernelBody GPUMem -> OffsetM (KernelBody GPUMem)
offsetMemoryInKernelBody offsets kbody = do
  stms' <-
    collectStms_ $ mapM_ (addStm <=< offsetMemoryInStm offsets) (kernelBodyStms kbody)
  pure kbody {kernelBodyStms = stms'}

offsetMemoryInBody :: RebaseMap -> Body GPUMem -> OffsetM (Body GPUMem)
offsetMemoryInBody offsets (Body _ stms res) =
  buildBody_ $ do
    mapM_ (addStm <=< offsetMemoryInStm offsets) stms
    pure res

-- | For each array-typed subexpression, its memory block followed by the
-- existential components of its LMAD (bound to fresh @ctx@ variables).
-- Non-array subexpressions contribute nothing.
argsContext :: [SubExp] -> OffsetM [SubExp]
argsContext = fmap concat . mapM resCtx
  where
    resCtx se = do
      v_t <- subExpMemInfo se
      case v_t of
        MemArray _ _ _ (ArrayIn mem lmad) -> do
          ctxs <- mapM (letSubExp "ctx" <=< toExp) (LMAD.existentialized lmad)
          pure $ Var mem : ctxs
        _ -> pure []

-- | Like 'offsetMemoryInBody', but also appends the memory context (see
-- 'argsContext') of the results to the results.
offsetMemoryInBodyReturnCtx :: RebaseMap -> Body GPUMem -> OffsetM (Body GPUMem)
offsetMemoryInBodyReturnCtx offsets (Body _ stms res) =
  buildBody_ $ do
    mapM_ (addStm <=< offsetMemoryInStm offsets) stms
    ctx <- argsContext $ map resSubExp res
    pure $ res <> subExpsRes ctx

-- | Build an LMAD from a shape and a list holding the offset followed by
-- one stride per dimension.
lmadFrom :: LMAD.Shape num -> [num] -> LMAD.LMAD num
lmadFrom shape (offset : strides) =
  LMAD.LMAD offset $ zipWith LMAD.LMADDim strides shape
lmadFrom _ [] =
  error "lmadFrom: expected an offset followed by strides, got an empty list"

-- | Append pattern elements corresponding to memory and index
-- function components for every array bound in the pattern.
addPatternContext :: Pat LetDecMem -> OffsetM (Pat LetDecMem)
addPatternContext (Pat pes) = localScope (scopeOfPat (Pat pes)) $ do
  -- The new context elements (for each array: a memory block followed
  -- by its LMAD components) go after all the original elements.
  (pes_ctx, pes') <- mapAccumLM onType [] pes
  pure $ Pat $ pes' <> pes_ctx
  where
    -- Move an array into a fresh memory block whose LMAD is expressed
    -- entirely in terms of fresh existential "ext" elements.
    onType acc (PatElem pe_v (MemArray pt pe_shape pe_u (ArrayIn pe_mem lmad))) = do
      space <- lookupMemSpace pe_mem
      pe_mem' <- newVName $ baseString pe_mem <> "_ext"
      let num_exts = length (LMAD.existentialized lmad)
      lmad_exts <-
        replicateM num_exts $
          PatElem <$> newVName "ext" <*> pure (MemPrim int64)
      let pe_lmad' = lmadFrom (LMAD.shape lmad) $ map (le64 . patElemName) lmad_exts
      pure
        ( acc ++ PatElem pe_mem' (MemMem space) : lmad_exts,
          PatElem pe_v $ MemArray pt pe_shape pe_u $ ArrayIn pe_mem' pe_lmad'
        )
    onType acc t = pure (acc, t)

-- | Append parameters corresponding to memory and index function
-- components for every array bound in the parameters.  The array
-- parameters themselves are moved to the new (existential) memory
-- blocks and LMADs.
addParamsContext :: [Param FParamMem] -> OffsetM [Param FParamMem]
addParamsContext ps = localScope (scopeOfFParams ps) $ do
  (ps_ctx, ps') <- mapAccumLM onType [] ps
  pure $ ps' <> ps_ctx
  where
    onType acc (Param attr v (MemArray pt shape u (ArrayIn mem lmad))) = do
      space <- lookupMemSpace mem
      mem' <- newVName $ baseString mem <> "_ext"
      let num_exts = length (LMAD.existentialized lmad)
      lmad_exts <-
        replicateM num_exts $
          Param mempty <$> newVName "ext" <*> pure (MemPrim int64)
      let lmad' = lmadFrom (LMAD.shape lmad) $ map (le64 . paramName) lmad_exts
      pure
        ( acc ++ Param mempty mem' (MemMem space) : lmad_exts,
          Param attr v $ MemArray pt shape u $ ArrayIn mem' lmad'
        )
    onType acc t = pure (acc, t)

-- | Make every array result of a 'Match' return in a new, existential
-- memory block with an existential LMAD.  Both the pattern and the
-- branch types are extended.  The original results come first, followed
-- by the new context (for each array: memory block, then LMAD
-- components).
offsetBranch ::
  Pat LetDecMem ->
  [BranchTypeMem] ->
  OffsetM (Pat LetDecMem, [BranchTypeMem])
offsetBranch (Pat pes) ts = do
  ((pes_ctx, ts_ctx), (pes', ts')) <-
    bimap unzip unzip <$> mapAccumLM onType [] (zip pes ts)
  pure (Pat $ pes' <> pes_ctx, ts' <> ts_ctx)
  where
    onType acc (PatElem pe_v (MemArray _ pe_shape pe_u (ArrayIn pe_mem pe_lmad)), MemArray pt shape u meminfo) = do
      (space, lmad) <- case meminfo of
        ReturnsInBlock mem lmad -> do
          space <- lookupMemSpace mem
          pure (space, lmad)
        ReturnsNewBlock space _ lmad ->
          pure (space, lmad)
      pe_mem' <- newVName $ baseString pe_mem <> "_ext"
      -- 'start' is the position of the new memory block in the extended
      -- result list: all original results, then the context so far.
      let start = length ts + length acc
          num_exts = length (LMAD.existentialized lmad)
          ext (Free se) = Free <$> pe64 se
          ext (Ext i) = le64 (Ext i)
      lmad_exts <-
        replicateM num_exts $
          PatElem <$> newVName "ext" <*> pure (MemPrim int64)
      let pe_lmad' = lmadFrom (LMAD.shape pe_lmad) $ map (le64 . patElemName) lmad_exts
      pure
        ( acc ++ (PatElem pe_mem' $ MemMem space, MemMem space) : map (,MemPrim int64) lmad_exts,
          ( PatElem pe_v $ MemArray pt pe_shape pe_u $ ArrayIn pe_mem' pe_lmad',
            -- The LMAD components follow directly after the memory block.
            MemArray pt shape u . ReturnsNewBlock space start . fmap ext $
              LMAD.mkExistential (shapeDims shape) (1 + start)
          )
        )
    onType acc t = pure (acc, t)

-- | Update the memory information of a pattern given the return
-- information of its (rewritten) expression.  Where the return
-- information provides an LMAD, its existentials are resolved to the
-- corresponding pattern element names and the element's memory block is
-- kept.  Otherwise the element is rebased via the 'RebaseMap'.
offsetMemoryInPat :: RebaseMap -> Pat LetDecMem -> [ExpReturns] -> Pat LetDecMem
offsetMemoryInPat offsets (Pat pes) rets =
  Pat $ zipWith onPE pes rets
  where
    onPE (PatElem name (MemArray pt shape u (ArrayIn mem _))) (MemArray _ _ _ info)
      | Just lmad <- getLMAD info =
          PatElem name . MemArray pt shape u . ArrayIn mem $ fmap (fmap unExt) lmad
    onPE pe _ =
      offsetMemoryInMemBound offsets <$> pe
    unExt (Ext i) = patElemName (pes !! i)
    unExt (Free v) = v
    getLMAD (Just (ReturnsNewBlock _ _ lmad)) = Just lmad
    getLMAD (Just (ReturnsInBlock _ lmad)) = Just lmad
    getLMAD _ = Nothing

-- | Rebase the memory information of a parameter (see
-- 'offsetMemoryInMemBound').
offsetMemoryInParam :: RebaseMap -> Param (MemBound u) -> Param (MemBound u)
offsetMemoryInParam offsets = fmap $ offsetMemoryInMemBound offsets

-- | If an array lives in a memory block that has an expansion in the
-- map, expand its LMAD accordingly; anything else is left unchanged.
offsetMemoryInMemBound :: RebaseMap -> MemBound u -> MemBound u
offsetMemoryInMemBound offsets (MemArray pt shape u (ArrayIn mem lmad))
  | Just (o, p) <- lookupNewBase mem (LMAD.shape lmad) offsets =
      MemArray pt shape u $ ArrayIn mem $ LMAD.expand o p lmad
offsetMemoryInMemBound _ info = info

-- | Like 'offsetMemoryInMemBound', for body return types.  Only applied
-- when 'isStaticLMAD' succeeds on the returned LMAD.
offsetMemoryInBodyReturns :: RebaseMap -> BodyReturns -> BodyReturns
offsetMemoryInBodyReturns offsets (MemArray pt shape u (ReturnsInBlock mem lmad))
  | Just lmad' <- isStaticLMAD lmad,
    Just (o, p) <- lookupNewBase mem (LMAD.shape lmad') offsets =
      MemArray pt shape u $ ReturnsInBlock mem $ LMAD.expand (Free <$> o) (fmap Free p) lmad
offsetMemoryInBodyReturns _ br = br

-- | Rewrite a lambda's body (with the lambda in scope) and rebase the
-- memory information of its parameters.
offsetMemoryInLambda :: RebaseMap -> Lambda GPUMem -> OffsetM (Lambda GPUMem)
offsetMemoryInLambda offsets lam = do
  body <- inScopeOf lam $ offsetMemoryInBody offsets $ lambdaBody lam
  let params = map (offsetMemoryInParam offsets) $ lambdaParams lam
  pure $ lam {lambdaBody = body, lambdaParams = params}

-- A loop may have memory parameters, and those memory blocks may
-- be expanded.  We assume (but do not check - FIXME) that if the
-- initial value of a loop parameter is an expanded memory block,
-- then so will the result be.
offsetMemoryInLoopParams ::
  RebaseMap ->
  [(FParam GPUMem, SubExp)] ->
  (RebaseMap -> [(FParam GPUMem, SubExp)] -> OffsetM a) ->
  OffsetM a
offsetMemoryInLoopParams offsets merge f = do
  -- Parameters are extended with their new context parameters, and the
  -- initial arguments with the matching context arguments, in the same
  -- (values first, context after) order.
  let (params, args) = unzip merge
  params' <- addParamsContext params
  args' <- (args <>) <$> argsContext args
  f offsets' $ zip params' args'
  where
    offsets' = extend offsets
    extend rm = foldl' onParamArg rm merge
    -- A loop parameter whose initial argument is a rebased memory block
    -- inherits that block's expansion.
    onParamArg rm (param, Var arg)
      | Just x <- M.lookup arg rm =
          M.insert (paramName param) x rm
    onParamArg rm _ = rm

-- | Handles only the expressions where we do not change the number of
-- results; meaning anything except Loop, Match, and nonscalar Apply.
offsetMemoryInExp :: RebaseMap -> Exp GPUMem -> OffsetM (Exp GPUMem)
offsetMemoryInExp offsets = mapExpM recurse
  where
    recurse =
      (identityMapper @GPUMem)
        { mapOnBody = \bscope -> localScope bscope . offsetMemoryInBody offsets,
          mapOnBranchType = pure . offsetMemoryInBodyReturns offsets,
          mapOnOp = onOp
        }
    onOp (Inner (SegOp op)) =
      Inner . SegOp
        <$> localScope (scopeOfSegSpace (segSpace op)) (mapSegOpM segOpMapper op)
      where
        segOpMapper =
          identitySegOpMapper
            { mapOnSegOpBody = offsetMemoryInKernelBody offsets,
              mapOnSegOpLambda = offsetMemoryInLambda offsets
            }
    onOp op = pure op

-- | Rewrite a statement so its memory bindings reflect the expansions
-- in the 'RebaseMap'.  'Match' and 'Loop' statements get new
-- existential context (memory and LMAD components) for their array
-- results.  Other statements have the index functions of their pattern
-- recomputed or rebased.
offsetMemoryInStm :: RebaseMap -> Stm GPUMem -> OffsetM (Stm GPUMem)
offsetMemoryInStm offsets (Let pat dec (Match cond cases defbody (MatchDec ts kind))) = do
  cases' <- forM cases $ \(Case vs body) ->
    Case vs <$> offsetMemoryInBodyReturnCtx offsets body
  defbody' <- offsetMemoryInBodyReturnCtx offsets defbody
  (pat', ts') <- offsetBranch pat ts
  pure $ Let pat' dec $ Match cond cases' defbody' $ MatchDec ts' kind
offsetMemoryInStm offsets (Let pat dec (Loop merge form body)) = do
  loop' <-
    offsetMemoryInLoopParams offsets merge $ \offsets' merge' -> do
      body' <-
        localScope
          (scopeOfFParams (map fst merge') <> scopeOfLoopForm form)
          (offsetMemoryInBodyReturnCtx offsets' body)
      pure $ Loop merge' form body'
  pat' <- addPatternContext pat
  pure $ Let pat' dec loop'
offsetMemoryInStm offsets (Let pat dec e) = do
  e' <- offsetMemoryInExp offsets e
  -- The return information of the rewritten expression is computed once
  -- and shared by both steps below.
  rts <-
    maybe (throwError "offsetMemoryInStm: ill-typed") pure
      =<< expReturns e'
  -- Fall back to rebasing with the RebaseMap, then try to recompute the
  -- index function from fully known return information.
  let pat' = offsetMemoryInPat offsets pat rts
      pat'' = Pat $ zipWith pick (patElems pat') rts
  pure $ Let pat'' dec e'
  where
    pick (PatElem name (MemArray pt s u _ret)) (MemArray _ _ _ (Just (ReturnsInBlock m extlmad)))
      | Just lmad <- instantiateLMAD extlmad =
          PatElem name (MemArray pt s u (ArrayIn m lmad))
    pick p _ = p

-- | Turn an existential LMAD into a concrete one; 'Nothing' if any
-- component is existential.
instantiateLMAD :: ExtLMAD -> Maybe LMAD
instantiateLMAD = traverse (traverse inst)
  where
    inst Ext {} = Nothing
    inst (Free x) = pure x

---- Slicing allocation sizes out of a kernel.

-- | Strip memory information from statements, going from 'GPUMem' to
-- 'GPU.GPU'.  Top-level allocations become unit-valued bindings.
-- Allocations inside nested bodies, and some operations ('OtherOp',
-- 'GPUBody'), are rejected with an error message.
unAllocGPUStms :: Stms GPUMem -> Either String (Stms GPU.GPU)
unAllocGPUStms = unAllocStms False
  where
    unAllocBody (Body dec stms res) =
      Body dec <$> unAllocStms True stms <*> pure res
    unAllocKernelBody (KernelBody dec stms res) =
      KernelBody dec <$> unAllocStms True stms <*> pure res
    unAllocStms nested =
      mapM (unAllocStm nested)

    unAllocStm nested stm@(Let pat dec (Op Alloc {}))
      | nested = throwError $ "Cannot handle nested allocation: " <> prettyString stm
      | otherwise =
          Let
            <$> unAllocPat pat
            <*> pure dec
            <*> pure (BasicOp (SubExp $ Constant UnitValue))
    unAllocStm _ (Let pat dec e) =
      Let <$> unAllocPat pat <*> pure dec <*> mapExpM unAlloc' e

    unAllocLambda (Lambda params ret body) =
      Lambda (map unParam params) ret <$> unAllocBody body

    unAllocPat (Pat pes) =
      Pat <$> mapM (rephrasePatElem (Right . unMem)) pes

    unAllocOp Alloc {} = Left "unAllocOp: unhandled Alloc"
    unAllocOp (Inner OtherOp {}) = Left "unAllocOp: unhandled OtherOp"
    unAllocOp (Inner GPUBody {}) = Left "unAllocOp: unhandled GPUBody"
    unAllocOp (Inner (SizeOp op)) = pure $ SizeOp op
    unAllocOp (Inner (SegOp op)) = SegOp <$> mapSegOpM mapper op
      where
        mapper =
          identitySegOpMapper
            { mapOnSegOpLambda = unAllocLambda,
              mapOnSegOpBody = unAllocKernelBody
            }

    unParam = fmap unMem

    unT = Right . unMem

    unAlloc' =
      Mapper
        { mapOnBody = const unAllocBody,
          mapOnRetType = unT,
          mapOnBranchType = unT,
          mapOnFParam = Right . unParam,
          mapOnLParam = Right . unParam,
          mapOnOp = unAllocOp,
          mapOnSubExp = Right,
          mapOnVName = Right
        }

-- | Forget memory information, keeping only the type.  Memory blocks
-- become unit values.
unMem :: MemInfo d u ret -> TypeBase (ShapeBase d) u
unMem (MemPrim pt) = Prim pt
unMem (MemArray pt shape u _) = Array pt shape u
unMem (MemAcc acc ispace ts u) = Acc acc ispace ts u
unMem MemMem {} = Prim Unit

-- | Apply 'unMem' to every entry of a scope.
unAllocScope :: Scope GPUMem -> Scope GPU.GPU
unAllocScope = M.map unInfo
  where
    unInfo (LetName dec) = LetName $ unMem dec
    unInfo (FParamName dec) = FParamName $ unMem dec
    unInfo (LParamName dec) = LParamName $ unMem dec
    unInfo (IndexName it) = IndexName it

-- | Group the extracted memory blocks by allocation size, so that all
-- blocks with a common size share one entry.
removeCommonSizes :: Extraction -> [(SubExp, [(VName, Space)])]
removeCommonSizes = M.toList . foldl' comb mempty . M.toList
  where
    comb m (mem, (_, size, space)) = M.insertWith (++) size [(mem, space)] m

-- | Copy every variable consumed by the statements (per alias analysis)
-- and substitute the copies into the statements.  Presumably this lets
-- the statements be replayed without consuming the originals; confirm
-- against callers.
copyConsumed ::
  (MonadBuilder m, AliasableRep (Rep m)) =>
  Stms (Rep m) ->
  m (Stms (Rep m))
copyConsumed stms = do
  let consumed = namesToList $ snd $ snd $ Alias.analyseStms mempty stms
  collectStms_ $ do
    consumed' <- mapM copy consumed
    let substs = M.fromList (zip consumed consumed')
    addStms $ substituteNames substs stms
  where
    copy v = letExp (baseString v <> "_copy") $ BasicOp $ Replicate mempty $ Var v

-- Important for edge cases (#1838) that the Stms here still have the
-- Allocs we are actually trying to get rid of.
-- | For each of the given size expressions, compute its maximum over all
-- threads of the kernel space, and that maximum multiplied by
-- @num_threads@.  The maximum is computed with a 'SegRed' using 'SMax'
-- and neutral element 0.  The kernel statements, with allocations
-- stripped by 'unAllocGPUStms', are replayed inside the reduction for
-- every thread index.  Returns the statements computing all this, the
-- names bound to the maxima, and the names bound to the products.
sliceKernelSizes ::
  SubExp ->
  [SubExp] ->
  SegSpace ->
  Stms GPUMem ->
  ExpandM (Stms GPU.GPU, [VName], [VName])
sliceKernelSizes num_threads sizes space kstms = do
  kstms' <- either throwError pure $ unAllocGPUStms kstms
  let num_sizes = length sizes
      i64s = replicate num_sizes $ Prim int64

  kernels_scope <- asks unAllocScope

  -- Reduction operator: pointwise signed maximum of two size tuples.
  (max_lam, _) <- flip runBuilderT kernels_scope $ do
    xs <- replicateM num_sizes $ newParam "x" (Prim int64)
    ys <- replicateM num_sizes $ newParam "y" (Prim int64)
    (zs, stms) <- localScope (scopeOfLParams $ xs ++ ys) $
      collectStms $
        forM (zip xs ys) $ \(x, y) ->
          fmap subExpRes . letSubExp "z" . BasicOp $
            BinOp (SMax Int64) (Var $ paramName x) (Var $ paramName y)
    pure $ Lambda (xs ++ ys) i64s (mkBody stms zs)

  flat_gtid_lparam <- newParam "flat_gtid" (Prim (IntType Int64))

  -- Map function: given a flat thread index, recompute the sizes that
  -- thread would need.
  size_lam' <- localScope (scopeOfSegSpace space) . fmap fst . flip runBuilderT kernels_scope $
    GPU.simplifyLambda <=< mkLambda [flat_gtid_lparam] $ do
      -- Even though this SegRed is one-dimensional, we need to
      -- provide indexes corresponding to the original potentially
      -- multi-dimensional construct.
      let (kspace_gtids, kspace_dims) = unzip $ unSegSpace space
          new_inds =
            unflattenIndex
              (map pe64 kspace_dims)
              (pe64 $ Var $ paramName flat_gtid_lparam)
      zipWithM_ letBindNames (map pure kspace_gtids) =<< mapM toExp new_inds

      mapM_ addStm =<< copyConsumed kstms'
      pure $ subExpsRes sizes

  ((maxes_per_thread, size_sums), slice_stms) <- flip runBuilderT kernels_scope $ do
    pat <-
      basicPat <$> replicateM num_sizes (newIdent "max_per_thread" $ Prim int64)

    -- Total number of threads in the (flattened) kernel space.
    w <-
      letSubExp "size_slice_w"
        =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) (segSpaceDims space)

    thread_space_iota <-
      letExp "thread_space_iota" $
        BasicOp $
          Iota w (intConst Int64 0) (intConst Int64 1) Int64
    let red_op =
          SegBinOp
            Commutative
            max_lam
            (replicate num_sizes $ intConst Int64 0)
            mempty
    lvl <- segThread "segred"

    addStms =<< mapM renameStm
      =<< nonSegRed lvl pat w [red_op] size_lam' [thread_space_iota]

    size_sums <- forM (patNames pat) $ \threads_max ->
      letExp "size_sum" $
        BasicOp $
          BinOp (Mul Int64 OverflowUndef) (Var threads_max) num_threads

    pure (patNames pat, size_sums)

  pure (slice_stms, maxes_per_thread, size_sums)
futhark-0.25.27/src/Futhark/Pass/ExplicitAllocations.hs000066400000000000000000001123541475065116200227570ustar00000000000000{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}

-- | A generic transformation for adding memory allocations to a
-- Futhark program. Specialised by specific representations in
-- submodules.
module Futhark.Pass.ExplicitAllocations
  ( explicitAllocationsGeneric,
    explicitAllocationsInStmsGeneric,
    ExpHint (..),
    defaultExpHints,
    askDefaultSpace,
    Allocable,
    AllocM,
    AllocEnv (..),
    SizeSubst (..),
    allocInStms,
    allocForArray,
    simplifiable,
    mkLetNamesB',
    mkLetNamesB'',

    -- * Module re-exports

    --
    -- These are highly likely to be needed by any downstream
    -- users.
    module Control.Monad.Reader,
    module Futhark.MonadFreshNames,
    module Futhark.Pass,
    module Futhark.Tools,
  )
where

import Control.Monad
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Bifunctor (first)
import Data.Either (partitionEithers)
import Data.List (foldl', transpose, zip4)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.Analysis.SymbolTable (IndexOp)
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR.Mem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.IR.Prop.Aliases (AliasedOp)
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Engine (SimpleOps (..))
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep (mkWiseBody)
import Futhark.Pass
import Futhark.Tools
import Futhark.Util (maybeNth, splitAt3)

-- | The constraints that the source representation (@fromrep@) and the
-- target memory representation (@torep@, with inner operation @inner@)
-- must satisfy for allocations to be inserted.
type Allocable fromrep torep inner =
  ( PrettyRep fromrep,
    PrettyRep torep,
    Mem torep inner,
    LetDec torep ~ LetDecMem,
    FParamInfo fromrep ~ DeclType,
    LParamInfo fromrep ~ Type,
    BranchType fromrep ~ ExtType,
    RetType fromrep ~ DeclExtType,
    BodyDec fromrep ~ (),
    BodyDec torep ~ (),
    ExpDec torep ~ (),
    SizeSubst (inner torep),
    BuilderOps torep
  )

-- | The environment of 'AllocM'.
data AllocEnv fromrep torep = AllocEnv
  { -- | When allocating memory, put it in this memory space.
    -- This is primarily used to ensure that group-wide
    -- statements store their results in shared memory.
    allocSpace :: Space,
    -- | The set of names that are known to be constants at
    -- kernel compile time.
    envConsts :: S.Set VName,
    -- | How to add allocations to an operation of the source
    -- representation.
    allocInOp :: Op fromrep -> AllocM fromrep torep (Op torep),
    -- | How to compute allocation hints for an expression.
    envExpHints :: Exp torep -> AllocM fromrep torep [ExpHint]
  }

-- | Monad for adding allocations to an entire program.
newtype AllocM fromrep torep a
  = AllocM (BuilderT torep (ReaderT (AllocEnv fromrep torep) (State VNameSource)) a)
  deriving
    ( Applicative,
      Functor,
      Monad,
      MonadFreshNames,
      HasScope torep,
      LocalScope torep,
      MonadReader (AllocEnv fromrep torep)
    )

-- Binding names with 'mkLetNamesM' in this monad also allocates memory
-- for the pattern (see 'patWithAllocations'), using the default space
-- and the expression's hints.
instance (Allocable fromrep torep inner) => MonadBuilder (AllocM fromrep torep) where
  type Rep (AllocM fromrep torep) = torep

  mkExpDecM _ _ = pure ()

  mkLetNamesM names e = do
    def_space <- askDefaultSpace
    hints <- expHints e
    pat <- patWithAllocations def_space names e hints
    pure $ Let pat (defAux ()) e

  mkBodyM stms res = pure $ Body () stms res

  addStms = AllocM . addStms
  collectStms (AllocM m) = AllocM $ collectStms m

-- | Compute allocation hints for an expression using the hint function
-- stored in the environment ('envExpHints').
expHints :: Exp torep -> AllocM fromrep torep [ExpHint]
expHints e = do
  f <- asks envExpHints
  f e

-- | The space in which we allocate memory if we have no other
-- preferences or constraints.
askDefaultSpace :: AllocM fromrep torep Space
askDefaultSpace = asks allocSpace

-- | Run an allocation computation with the given default space, 'Op'
-- handler and hint function.  It starts from an empty scope and no
-- known constants.  Only the computation's own result is returned; the
-- second component of the builder result (presumably the top-level
-- statements) is dropped.
runAllocM ::
  (MonadFreshNames m) =>
  Space ->
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  AllocM fromrep torep a ->
  m a
runAllocM space handleOp hints (AllocM m) =
  fmap fst $ modifyNameSource $ runState $ runReaderT (runBuilderT m mempty) env
  where
    env =
      AllocEnv
        { allocSpace = space,
          envConsts = mempty,
          allocInOp = handleOp,
          envExpHints = hints
        }

-- | Size in bytes of one element of the given type.
elemSize :: (Num a) => Type -> a
elemSize = primByteSize . elemType

-- | Expression for the total size in bytes of an array of the given
-- type: the element size multiplied by every dimension, left to right.
arraySizeInBytesExp :: Type -> PrimExp VName
arraySizeInBytesExp t =
  untyped $ foldl' (*) (elemSize t) $ map pe64 (arrayDims t)

-- | Bind the byte size of an array of the given type (via 'letSubExp'
-- with base name "bytes").
arraySizeInBytes :: (MonadBuilder m) => Type -> m SubExp
arraySizeInBytes = letSubExp "bytes" <=< toExp . arraySizeInBytesExp

-- | Allocate a new memory block in the given space, sized to hold a
-- dense array of the given type, and return its name.
allocForArray' ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner (Rep m)) =>
  Type ->
  Space ->
  m VName
allocForArray' t space = do
  size <- arraySizeInBytes t
  letExp "mem" $ Op $ Alloc size space

-- | Allocate memory for a value of the given type.
allocForArray ::
  (Allocable fromrep torep inner) =>
  Type ->
  Space ->
  AllocM fromrep torep VName
allocForArray t space =
  allocForArray' t space

-- | Repair an expression that cannot be assigned an index function.
-- There is a simple remedy for this: normalise the input arrays and
-- try again.  Only reshapes are handled; the reshaped array is made
-- direct in its current memory space.
repairExpression ::
  (Allocable fromrep torep inner) =>
  Exp torep ->
  AllocM fromrep torep (Exp torep)
repairExpression (BasicOp (Reshape k shape v)) = do
  v_mem <- fst <$> lookupArraySummary v
  space <- lookupMemSpace v_mem
  v' <- snd <$> ensureDirectArray (Just space) v
  pure $ BasicOp $ Reshape k shape v'
repairExpression e =
  error $ "repairExpression:\n" <> prettyString e

-- | Like 'expReturns', but if no return information can be computed,
-- the expression is repaired with 'repairExpression' and tried again.
-- Returns the return information and the (possibly repaired)
-- expression.
expReturns' ::
  (Allocable fromrep torep inner) =>
  Exp torep ->
  AllocM fromrep torep ([ExpReturns], Exp torep)
expReturns' e = do
  maybe_rts <- expReturns e
  case maybe_rts of
    Just rts -> pure (rts, e)
    Nothing -> do
      e' <- repairExpression e
      let bad =
            error . unlines $
              [ "expReturns': impossible index transformation",
                prettyString e,
                prettyString e'
              ]
      rts <- fromMaybe bad <$> expReturns e'
      pure (rts, e')

-- | Build a statement binding the given identifiers to the expression,
-- with memory information (and allocations where needed) for the
-- pattern.
allocsForStm ::
  (Allocable fromrep torep inner) =>
  [Ident] ->
  Exp torep ->
  AllocM fromrep torep (Stm torep)
allocsForStm idents e = do
  def_space <- askDefaultSpace
  hints <- expHints e
  (rts, e') <- expReturns' e
  pes <- allocsForPat def_space idents rts hints
  dec <- mkExpDecM (Pat pes) e'
  pure $ Let (Pat pes) (defAux dec) e'

-- | Build a pattern for binding the given names to an expression,
-- allocating memory for results that need it.
patWithAllocations ::
  (MonadBuilder m, Mem (Rep m) inner) =>
  Space ->
  [VName] ->
  Exp (Rep m) ->
  [ExpHint] ->
  m (Pat LetDecMem)
patWithAllocations def_space names e hints = do
  ts' <- instantiateShapes' names <$> expExtType e
  let idents = zipWith Ident names ts'
  rts <- fromMaybe (error "patWithAllocations: ill-typed") <$> expReturns e
  Pat <$> allocsForPat def_space idents rts hints

-- | Pad the identifiers to one per return.  The given identifiers are
-- matched against the last returns.  Fresh identifiers are created for
-- any leading returns without one: "ext_mem" for memory, "ext" (i64)
-- for everything else.
mkMissingIdents :: (MonadFreshNames m) => [Ident] -> [ExpReturns] -> m [Ident]
mkMissingIdents idents rts =
  reverse
    <$> zipWithM
      f
      (reverse rts)
      (map Just (reverse idents) ++ repeat Nothing)
  where
    f _ (Just ident) = pure ident
    f (MemMem space) Nothing = newIdent "ext_mem" $ Mem space
    f _ Nothing = newIdent "ext" $ Prim int64

-- | Compute pattern elements with memory information from the
-- expression's return information.  Existential references in the
-- returns are resolved to the (padded) identifiers.  Results without a
-- prescribed memory block get a fresh allocation via
-- 'summaryForBindage'.
allocsForPat ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner (Rep m)) =>
  Space ->
  [Ident] ->
  [ExpReturns] ->
  [ExpHint] ->
  m [PatElem LetDecMem]
allocsForPat def_space some_idents rts hints = do
  idents <- mkMissingIdents some_idents rts

  -- NOTE(review): zip3 truncates to the shortest list; presumably the
  -- caller supplies one hint per return - confirm.
  forM (zip3 idents rts hints) $ \(ident, rt, hint) -> do
    let ident_shape = arrayShape $ identType ident
    case rt of
      MemPrim _ -> do
        summary <- summaryForBindage def_space (identType ident) hint
        pure $ PatElem (identName ident) summary
      MemMem space ->
        pure $ PatElem (identName ident) $ MemMem space
      MemArray bt _ u (Just (ReturnsInBlock mem extlmad)) -> do
        let ixfn = instantiateExtLMAD idents extlmad
        pure . PatElem (identName ident) . MemArray bt ident_shape u $ ArrayIn mem ixfn
      MemArray _ extshape _ Nothing
        | Just _ <- knownShape extshape -> do
            summary <- summaryForBindage def_space (identType ident) hint
            pure $ PatElem (identName ident) summary
      MemArray bt _ u (Just (ReturnsNewBlock _ i extixfn)) -> do
        let ixfn = instantiateExtLMAD idents extixfn
        pure . PatElem (identName ident) . MemArray bt ident_shape u $
          ArrayIn (getIdent idents i) ixfn
      MemAcc acc ispace ts u ->
        pure $ PatElem (identName ident) $ MemAcc acc ispace ts u
      _ -> error "Impossible case reached in allocsForPat!"
  where
    knownShape = mapM known . shapeDims
    known (Free v) = Just v
    known Ext {} = Nothing

    getIdent idents i =
      case maybeNth i idents of
        Just ident -> identName ident
        Nothing ->
          error $ "getIdent: Ext " <> show i <> " but pattern has " <> show (length idents) <> " elements: " <> prettyString idents

    instantiateExtLMAD idents = fmap $ fmap inst
      where
        inst (Free v) = v
        inst (Ext i) = getIdent idents i

-- | Turn an existential LMAD into a concrete one.  Existential
-- components are not supported and cause a crash.
instantiateLMAD :: (Monad m) => ExtLMAD -> m LMAD
instantiateLMAD = traverse $ traverse inst
  where
    inst Ext {} = error "instantiateLMAD: not yet"
    inst (Free x) = pure x

-- | Memory information for a binding of the given type whose memory is
-- not dictated by the expression.  Arrays get a fresh allocation:
--
-- * with 'NoHint', a block sized for the whole array in the default
--   space and an 'LMAD.iota' (offset 0) index function;
--
-- * with a 'Hint', the hinted LMAD and space.
summaryForBindage ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner (Rep m)) =>
  Space ->
  Type ->
  ExpHint ->
  m (MemBound NoUniqueness)
summaryForBindage _ (Prim bt) _ =
  pure $ MemPrim bt
summaryForBindage _ (Mem space) _ =
  pure $ MemMem space
summaryForBindage _ (Acc acc ispace ts u) _ =
  pure $ MemAcc acc ispace ts u
summaryForBindage def_space t@(Array pt shape u) NoHint = do
  m <- allocForArray' t def_space
  pure $ MemArray pt shape u $ ArrayIn m $ LMAD.iota 0 $ map pe64 $ arrayDims t
summaryForBindage _ t@(Array pt _ _) (Hint lmad space) = do
  -- Size the block from the LMAD's range rather than the array shape,
  -- presumably so every element the hinted LMAD can address fits.
  bytes <-
    letSubExp "bytes" <=< toExp . untyped $
      primByteSize pt * (1 + LMAD.range lmad)
  m <- letExp "mem" $ Op $ Alloc bytes space
  pure $ MemArray pt (arrayShape t) NoUniqueness $ ArrayIn m lmad

-- | Add memory information to function parameters, each placed in its
-- paired space.  Runs the continuation on the new parameter list, with
-- those parameters in scope.  The list holds memory parameters, then
-- context parameters, then value parameters.
allocInFParams ::
  (Allocable fromrep torep inner) =>
  [(FParam fromrep, Space)] ->
  ([FParam torep] -> AllocM fromrep torep a) ->
  AllocM fromrep torep a
allocInFParams params m = do
  (valparams, (memparams, ctxparams)) <-
    runWriterT $ mapM (uncurry allocInFParam) params
  let params' = memparams <> ctxparams <> valparams
      summary = scopeOfFParams params'
  localScope summary $ m params'

-- | Add memory information to a single parameter.  An array parameter
-- gets a fresh memory parameter in the given space, emitted through the
-- writer, and an 'LMAD.iota' index function.
allocInFParam ::
  (Allocable fromrep torep inner) =>
  FParam fromrep ->
  Space ->
  WriterT
    ([FParam torep], [FParam torep])
    (AllocM fromrep torep)
    (FParam torep)
allocInFParam param pspace =
  case paramDeclType param of
    Array pt shape u -> do
      let memname = baseString (paramName param) <> "_mem"
          lmad = LMAD.iota 0 $ map pe64 $ shapeDims shape
      mem <- lift $ newVName memname
      tell ([Param (paramAttrs param) mem $ MemMem pspace], [])
      pure param {paramDec = MemArray pt shape u $ ArrayIn mem lmad}
    Prim pt ->
      pure param {paramDec = MemPrim pt}
    Mem space ->
      pure param {paramDec = MemMem space}
    Acc acc ispace ts u ->
      pure param {paramDec = MemAcc acc ispace ts u}

-- | If a space is requested and the array is not in it, copy the array
-- into a fresh linear array in that space.  Otherwise (including when no
-- space is requested) the array is returned as is.  Returns the memory
-- block and array name.  NOTE(review): despite the name, only the memory
-- space is checked here, not the layout.
ensureRowMajorArray ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  VName ->
  AllocM fromrep torep (VName, VName)
ensureRowMajorArray space_ok v = do
  (mem, _) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  let space = fromMaybe default_space space_ok
  if maybe True (== mem_space) space_ok
    then pure (mem, v)
    else allocLinearArray space (baseString v) v

-- | Make sure an array argument lives in the given space.  Its memory
-- block and LMAD components (bound to fresh "lmad_arg" variables) are
-- emitted through the writer.
ensureArrayIn ::
  (Allocable fromrep torep inner) =>
  Space ->
  SubExp ->
  WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
ensureArrayIn _ (Constant v) =
  error $ "ensureArrayIn: " ++ prettyString v ++ " cannot be an array."
ensureArrayIn space (Var v) = do
  (mem', v') <- lift $ ensureRowMajorArray (Just space) v
  (_, lmad) <- lift $ lookupArraySummary v'
  ctx <- lift $ mapM (letSubExp "lmad_arg" <=< toExp) (LMAD.existentialized lmad)
  tell ([Var mem'], ctx)
  pure $ Var v'

-- | Add memory information to loop parameters and their initial
-- arguments.  The continuation receives two things:
--
-- * the new merge list (memory parameters, then context parameters,
--   then value parameters, each paired with its argument);
--
-- * a function that, given the loop body's value results, returns the
--   matching context results and the adjusted value results.
allocInLoopParams ::
  (Allocable fromrep torep inner) =>
  [(FParam fromrep, SubExp)] ->
  ( [(FParam torep, SubExp)] ->
    ([SubExp] -> AllocM fromrep torep ([SubExp], [SubExp])) ->
    AllocM fromrep torep a
  ) ->
  AllocM fromrep torep a
allocInLoopParams merge m = do
  ((valparams, valargs, handle_loop_subexps), (mem_params, ctx_params)) <-
    runWriterT $ unzip3 <$> mapM allocInLoopParam merge
  let mergeparams' = mem_params <> ctx_params <> valparams
      summary = scopeOfFParams mergeparams'

      mk_loop_res ses = do
        (ses', (memargs, ctxargs)) <-
          runWriterT $ zipWithM ($) handle_loop_subexps ses
        pure (memargs <> ctxargs, ses')

  (valctx_args, valargs') <- mk_loop_res valargs
  let merge' = zip mergeparams' (valctx_args <> valargs')
  localScope summary $ m merge' mk_loop_res
  where
    param_names = namesFromList $ map (paramName . fst) merge
    anyIsLoopParam names = names `namesIntersect` param_names

    scalarRes param_t v_mem_space v_lmad (Var res) = do
      -- Try really hard to avoid copying needlessly, but the result
      -- _must_ be in ScalarSpace and have the right index function.
      (res_mem, res_lmad) <- lift $ lookupArraySummary res
      res_mem_space <- lift $ lookupMemSpace res_mem
      (res_mem', res') <-
        if (res_mem_space, res_lmad) == (v_mem_space, v_lmad)
          then pure (res_mem, res)
          else lift $ arrayWithLMAD v_mem_space v_lmad (fromDecl param_t) res
      tell ([Var res_mem'], [])
      pure $ Var res'
    scalarRes _ _ _ se = pure se

    allocInLoopParam ::
      (Allocable fromrep torep inner) =>
      (Param DeclType, SubExp) ->
      WriterT
        ([FParam torep], [FParam torep])
        (AllocM fromrep torep)
        ( FParam torep,
          SubExp,
          SubExp -> WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
        )
    allocInLoopParam (mergeparam, Var v)
      | param_t@(Array pt shape u) <- paramDeclType mergeparam = do
          (v_mem, v_lmad) <- lift $ lookupArraySummary v
          v_mem_space <- lift $ lookupMemSpace v_mem

          -- Loop-invariant array parameters that are in scalar space
          -- are special - we do not wish to existentialise their index
          -- function at all (but the memory block is still existential).
          case v_mem_space of
            ScalarSpace {} ->
              if anyIsLoopParam (freeIn shape)
                then do
                  -- Arrays with loop-variant shape cannot be in scalar
                  -- space, so copy them elsewhere and try again.
                  space <- lift askDefaultSpace
                  (_, v') <- lift $ allocLinearArray space (baseString v) v
                  allocInLoopParam (mergeparam, Var v')
                else do
                  p <- newParam "mem_param" $ MemMem v_mem_space
                  tell ([p], [])

                  pure
                    ( mergeparam {paramDec = MemArray pt shape u $ ArrayIn (paramName p) v_lmad},
                      Var v,
                      scalarRes param_t v_mem_space v_lmad
                    )
            _ -> do
              (v_mem', v') <- lift $ ensureRowMajorArray Nothing v
              -- The parameter gets an existential LMAD whose components
              -- are fresh context parameters.
              let lmad_ext =
                    LMAD.existentialize 0 $ LMAD.iota 0 $ map pe64 $ shapeDims shape

              v_mem_space' <- lift $ lookupMemSpace v_mem'

              ctx_params <-
                replicateM (length (LMAD.existentialized lmad_ext)) $
                  newParam "ctx_param_ext" (MemPrim int64)

              param_lmad <-
                instantiateLMAD $
                  LMAD.substitute
                    ( M.fromList . zip (fmap Ext [0 ..]) $
                        map (le64 . Free . paramName) ctx_params
                    )
                    lmad_ext

              mem_param <- newParam "mem_param" $ MemMem v_mem_space'
              tell ([mem_param], ctx_params)

              pure
                ( mergeparam {paramDec = MemArray pt shape u $ ArrayIn (paramName mem_param) param_lmad},
                  Var v',
                  ensureArrayIn v_mem_space'
                )
    allocInLoopParam (mergeparam, se) = doDefault mergeparam se =<< lift askDefaultSpace

    doDefault mergeparam se space = do
      mergeparam' <- allocInFParam mergeparam space
      pure (mergeparam', se, linearFuncallArg (paramType mergeparam) space)

-- | Copy an array into a fresh allocation in the given space with the
-- given LMAD.  The copy is a 'Replicate' with an empty shape.  Returns
-- the new memory block and the name of the copy.  The type must be an
-- array type (partial pattern below).
arrayWithLMAD ::
  (MonadBuilder m, Op (Rep m) ~ MemOp inner (Rep m), LetDec (Rep m) ~ LetDecMem) =>
  Space ->
  LMAD ->
  Type ->
  VName ->
  m (VName, VName)
arrayWithLMAD space lmad v_t v = do
  let Array pt shape u = v_t
  mem <- allocForArray' v_t space
  v_copy <- newVName $ baseString v <> "_scalcopy"
  let pe = PatElem v_copy $ MemArray pt shape u $ ArrayIn mem lmad
  letBind (Pat [pe]) $ BasicOp $ Replicate mempty $ Var v
  pure (mem, v_copy)

-- | Keep the array as is if its index function is direct
-- ('LMAD.isDirect') and it is in the requested space (if any).
-- Otherwise copy it into a fresh linear array in the requested space,
-- or the default space.
ensureDirectArray ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  VName ->
  AllocM fromrep torep (VName, VName)
ensureDirectArray space_ok v = do
  (mem, lmad) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  if LMAD.isDirect lmad && maybe True (== mem_space) space_ok
    then pure (mem, v)
    else needCopy (fromMaybe default_space space_ok)
  where
    needCopy space =
      -- We need to do a new allocation, copy 'v', and make a new
      -- binding for the size of the memory block.
      allocLinearArray space (baseString v) v

-- | Copy an array, via 'Manifest', into a fresh allocation in the given
-- space with dimensions laid out by the given permutation.  The copy's
-- name is derived from the string with suffix "_desired_form".  Crashes
-- if the variable is not an array.
allocPermArray ::
  (Allocable fromrep torep inner) =>
  Space ->
  [Int] ->
  String ->
  VName ->
  AllocM fromrep torep (VName, VName)
allocPermArray space perm s v = do
  t <- lookupType v
  case t of
    Array pt shape u -> do
      mem <- allocForArray t space
      v' <- newVName $ s <> "_desired_form"
      let info =
            MemArray pt shape u . ArrayIn mem $
              LMAD.permute (LMAD.iota 0 $ map pe64 $ arrayDims t) perm
          pat = Pat [PatElem v' info]
      addStm $ Let pat (defAux ()) $ BasicOp $ Manifest perm v
      pure (mem, v')
    _ ->
      error $ "allocPermArray: " ++ prettyString t

-- | If a space is requested and the array is not in it, copy the array
-- with the given permutation (see 'allocPermArray').  NOTE(review): only
-- the space is checked, not whether the current layout matches @perm@.
ensurePermArray ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  [Int] ->
  VName ->
  AllocM fromrep torep (VName, VName)
ensurePermArray space_ok perm v = do
  (mem, _) <- lookupArraySummary v
  mem_space <- lookupMemSpace mem
  default_space <- askDefaultSpace
  if maybe True (== mem_space) space_ok
    then pure (mem, v)
    else allocPermArray (fromMaybe default_space space_ok) perm (baseString v) v

-- | Copy an array into a fresh allocation in the given space, using the
-- identity permutation.
allocLinearArray ::
  (Allocable fromrep torep inner) =>
  Space ->
  String ->
  VName ->
  AllocM fromrep torep (VName, VName)
allocLinearArray space s v = do
  t <- lookupType v
  let perm = [0 .. arrayRank t - 1]
  allocPermArray space perm s v

-- | Prepare function call arguments.  Array arguments are made direct
-- in the default space.  The collected memory arguments are placed
-- (with 'Observe') in front of the value arguments.
funcallArgs ::
  (Allocable fromrep torep inner) =>
  [(SubExp, Diet)] ->
  AllocM fromrep torep [(SubExp, Diet)]
funcallArgs args = do
  (valargs, (ctx_args, mem_and_size_args)) <- runWriterT $
    forM args $ \(arg, d) -> do
      t <- lift $ subExpType arg
      space <- lift askDefaultSpace
      arg' <- linearFuncallArg t space arg
      pure (arg', d)
  pure $ map (,Observe) (ctx_args <> mem_and_size_args) <> valargs

-- | For an array argument, make it direct in the given space and emit
-- its memory block through the writer.  Other arguments pass through
-- unchanged.
linearFuncallArg ::
  (Allocable fromrep torep inner) =>
  Type ->
  Space ->
  SubExp ->
  WriterT ([SubExp], [SubExp]) (AllocM fromrep torep) SubExp
linearFuncallArg Array {} space (Var v) = do
  (mem, arg') <- lift $ ensureDirectArray (Just space) v
  tell ([Var mem], [])
  pure $ Var arg'
linearFuncallArg _ _ arg =
  pure arg

-- | Shift the first index list of a 'RetAls' by @a@ (extra parameters)
-- and the second by @b@ (extra results).
shiftRetAls :: Int -> Int -> RetAls -> RetAls
shiftRetAls a b (RetAls is js) =
  RetAls (map (+ a) is) (map (+ b) js)

-- | A pass that adds memory allocations to a whole program, using the
-- given default space, 'Op' handler and hint function.  Function
-- parameters and array results are placed in the given space.  Memory
-- results are returned before the original results, and the aliasing
-- information of the original results is shifted to match.
explicitAllocationsGeneric ::
  (Allocable fromrep torep inner) =>
  Space ->
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  Pass fromrep torep
explicitAllocationsGeneric space handleOp hints =
  Pass "explicit allocations" "Transform program to explicit memory representation" $
    intraproceduralTransformationWithConsts onStms allocInFun
  where
    onStms stms =
      runAllocM space handleOp hints $ collectStms_ $ allocInStms stms $ pure ()

    allocInFun consts (FunDef entry attrs fname rettype params fbody) =
      runAllocM space handleOp hints . inScopeOf consts $
        allocInFParams (map (,space) params) $ \params' -> do
          (fbody', mem_rets) <-
            allocInFunBody (map (const $ Just space) rettype) fbody
          let num_extra_params = length params' - length params
              num_extra_rets = length mem_rets
              rettype' =
                map (,RetAls mempty mempty) mem_rets
                  ++ zip
                    (memoryInDeclExtType space (length mem_rets) (map fst rettype))
                    (map (shiftRetAls num_extra_params num_extra_rets . snd) rettype)
          pure $ FunDef entry attrs fname rettype' params' fbody'

-- | Add memory allocations to statements, starting from the scope of the
-- calling monad.
explicitAllocationsInStmsGeneric ::
  ( MonadFreshNames m,
    HasScope torep m,
    Allocable fromrep torep inner
  ) =>
  Space ->
  (Op fromrep -> AllocM fromrep torep (Op torep)) ->
  (Exp torep -> AllocM fromrep torep [ExpHint]) ->
  Stms fromrep ->
  m (Stms torep)
explicitAllocationsInStmsGeneric space handleOp hints stms = do
  scope <- askScope
  runAllocM space handleOp hints $
    localScope scope $
      collectStms_ $
        allocInStms stms $
          pure ()

-- | Convert declared return types to 'FunReturns', giving every array a
-- new memory block.  The @i@th array (counting from 0) returns in block
-- @i@.  Existential shape indices are shifted by @k@, the number of
-- memory results placed in front.
memoryInDeclExtType :: Space -> Int -> [DeclExtType] -> [FunReturns]
memoryInDeclExtType space k dets = evalState (mapM addMem dets) 0
  where
    addMem (Prim t) = pure $ MemPrim t
    addMem Mem {} = error "memoryInDeclExtType: too much memory"
    addMem (Array pt shape u) = do
      i <- get <* modify (+ 1)
      let shape' = fmap shift shape
      pure . MemArray pt shape' u . ReturnsNewBlock space i $
        LMAD.iota 0 (map convert $ shapeDims shape')
    addMem (Acc acc ispace ts u) = pure $ MemAcc acc ispace ts u

    convert (Ext i) = le64 $ Ext i
    convert (Free v) = Free <$> pe64 v

    shift (Ext i) = Ext (i + k)
    shift (Free x) = Free x

-- | The memory block (with its space) of an array result, as an extra
-- result with its return type.  Non-array results contribute nothing.
bodyReturnMemCtx ::
  (Allocable fromrep torep inner) =>
  SubExpRes ->
  AllocM fromrep torep [(SubExpRes, MemInfo ExtSize u MemReturn)]
bodyReturnMemCtx (SubExpRes _ Constant {}) =
  pure []
bodyReturnMemCtx (SubExpRes _ (Var v)) = do
  info <- lookupMemInfo v
  case info of
    MemPrim {} -> pure []
    MemAcc {} -> pure []
    MemMem {} -> pure [] -- should not happen
    MemArray _ _ _ (ArrayIn mem _) -> do
      mem_info <- lookupMemInfo mem
      case mem_info of
        MemMem space ->
          pure [(subExpRes $ Var mem, MemMem space)]
        _ -> error $ "bodyReturnMemCtx: not a memory block: " ++ prettyString mem

-- | Add allocations to a function body.  Results are made direct, in the
-- corresponding space when one is given; leading results beyond the
-- given spaces get no space requirement.  The memory blocks of array
-- results are prepended to the result, and their return types are
-- returned alongside the body.
allocInFunBody ::
  (Allocable fromrep torep inner) =>
  [Maybe Space] ->
  Body fromrep ->
  AllocM fromrep torep (Body torep, [FunReturns])
allocInFunBody space_oks (Body _ stms res) =
  buildBody . allocInStms stms $ do
    res' <- zipWithM ensureDirect space_oks' res
    (mem_ctx_res, mem_ctx_rets) <- unzip . concat <$> mapM bodyReturnMemCtx res'
    pure (mem_ctx_res <> res', mem_ctx_rets)
  where
    num_vals = length space_oks
    space_oks' = replicate (length res - num_vals) Nothing ++ space_oks

-- | Make an array result direct (see 'ensureDirectArray'); other results
-- are returned unchanged.
ensureDirect ::
  (Allocable fromrep torep inner) =>
  Maybe Space ->
  SubExpRes ->
  AllocM fromrep torep SubExpRes
ensureDirect space_ok (SubExpRes cs se) = do
  se_info <- subExpMemInfo se
  SubExpRes cs <$> case (se_info, se) of
    (MemArray {}, Var v) -> do
      (_, v') <- ensureDirectArray space_ok v
      pure $ Var v'
    _ ->
      pure se

-- | Add allocations to each statement in turn, then run the given
-- action.  After each statement, the 'stmConsts' of the statements it
-- produced are added to 'envConsts' for the remaining statements.
allocInStms ::
  (Allocable fromrep torep inner) =>
  Stms fromrep ->
  AllocM fromrep torep a ->
  AllocM fromrep torep a
allocInStms origstms m = allocInStms' $ stmsToList origstms
  where
    allocInStms' [] = m
    allocInStms' (stm : stms) = do
      allocstms <- collectStms_ $ auxing (stmAux stm) $ allocInStm stm
      addStms allocstms
      let stms_consts = foldMap stmConsts allocstms
          f env = env {envConsts = stms_consts <> envConsts env}
      local f $ allocInStms' stms

-- | Add allocations to a single statement: first inside its expression,
-- then for its pattern.
allocInStm ::
  (Allocable fromrep torep inner) =>
  Stm fromrep ->
  AllocM fromrep torep ()
allocInStm (Let (Pat pes) _ e) =
  addStm =<< allocsForStm (map patElemIdent pes) =<< allocInExp e

-- | Build a lambda with the given (already memory-annotated) parameters,
-- adding allocations to the body's statements.
allocInLambda ::
  (Allocable fromrep torep inner) =>
  [LParam torep] ->
  Body fromrep ->
  AllocM fromrep torep (Lambda torep)
allocInLambda params body =
  mkLambda params . allocInStms (bodyStms body) $
    pure $
      bodyResult body

-- | A requirement on the memory of a 'Match' result.  Either it must be
-- in the given space, or it additionally needs normalisation because
-- the requirements from different branches disagree (see
-- 'combMemReqs').
data MemReq
  = MemReq Space
  | NeedsNormalisation Space
  deriving (Eq, Show)

-- | Combine two requirements.  'NeedsNormalisation' wins.  Two different
-- 'MemReq's become 'NeedsNormalisation' in the space of the first.
combMemReqs :: MemReq -> MemReq -> MemReq
combMemReqs x@NeedsNormalisation {} _ = x
combMemReqs _ y@NeedsNormalisation {} = y
combMemReqs x@(MemReq x_space) y@MemReq {} =
  if x == y then x else NeedsNormalisation x_space

type MemReqType = MemInfo (Ext SubExp) NoUniqueness MemReq

-- | Combine the requirements of two arrays pointwise; for anything
-- else, the first argument is kept.
combMemReqTypes :: MemReqType -> MemReqType -> MemReqType
combMemReqTypes (MemArray pt shape u x) (MemArray _ _ _ y) =
  MemArray pt shape u $ combMemReqs x y
combMemReqTypes x _ = x

-- | The context results needed for a result of the given type.  Both
-- kinds of requirement need the same context.
contextRets :: MemReqType -> [MemInfo d u r]
contextRets (MemArray _ shape _ req) =
  -- Memory + offset + stride*rank.
  [MemMem (memReqSpace req), MemPrim int64]
    ++ replicate (shapeRank shape) (MemPrim int64)
  where
    memReqSpace (MemReq space) = space
    memReqSpace (NeedsNormalisation space) = space
contextRets _ = []

-- Add memory information to the body, but do not return memory/lmad
-- information. Instead, return restrictions on what the index
-- function should look like. We will then (crudely) unify these
-- restrictions across all bodies.
allocInMatchBody ::
  (Allocable fromrep torep inner) =>
  [ExtType] ->
  Body fromrep ->
  AllocM fromrep torep (Body torep, [MemReqType])
allocInMatchBody rets (Body _ stms res) = buildBody .
allocInStms stms $ do restrictions <- zipWithM restriction rets (map resSubExp res) pure (res, restrictions) where restriction t se = do v_info <- subExpMemInfo se case (t, v_info) of (Array pt shape u, MemArray _ _ _ (ArrayIn mem _)) -> do space <- lookupMemSpace mem pure $ MemArray pt shape u $ MemReq space (_, MemMem space) -> pure $ MemMem space (_, MemPrim pt) -> pure $ MemPrim pt (_, MemAcc acc ispace ts u) -> pure $ MemAcc acc ispace ts u _ -> error $ "allocInMatchBody: mismatch: " ++ show (t, v_info) mkBranchRet :: [MemReqType] -> [BranchTypeMem] mkBranchRet reqs = let (ctx_rets, res_rets) = foldl helper ([], []) $ zip reqs offsets in ctx_rets ++ res_rets where numCtxNeeded = length . contextRets offsets = scanl (+) 0 $ map numCtxNeeded reqs num_new_ctx = last offsets helper (ctx_rets_acc, res_rets_acc) (req, ctx_offset) = ( ctx_rets_acc ++ contextRets req, res_rets_acc ++ [inspect ctx_offset req] ) arrayInfo (NeedsNormalisation space) = space arrayInfo (MemReq space) = space inspect ctx_offset (MemArray pt shape u req) = let shape' = fmap (adjustExt num_new_ctx) shape space = arrayInfo req in MemArray pt shape' u . ReturnsNewBlock space ctx_offset $ convert <$> LMAD.mkExistential (shapeDims shape') (ctx_offset + 1) inspect _ (MemAcc acc ispace ts u) = MemAcc acc ispace ts u inspect _ (MemPrim pt) = MemPrim pt inspect _ (MemMem space) = MemMem space convert (Ext i) = le64 (Ext i) convert (Free v) = Free <$> pe64 v adjustExt :: Int -> Ext a -> Ext a adjustExt _ (Free v) = Free v adjustExt k (Ext i) = Ext (k + i) addCtxToMatchBody :: (Allocable fromrep torep inner) => [MemReqType] -> Body torep -> AllocM fromrep torep (Body torep) addCtxToMatchBody reqs body = buildBody_ $ do res <- zipWithM normaliseIfNeeded reqs =<< bodyBind body ctx <- concat <$> mapM resCtx res pure $ ctx ++ res where normaliseIfNeeded (MemArray _ shape _ (NeedsNormalisation space)) (SubExpRes cs (Var v)) = SubExpRes cs . Var . snd <$> ensurePermArray (Just space) [0 .. 
shapeRank shape - 1] v normaliseIfNeeded _ res = pure res resCtx (SubExpRes _ Constant {}) = pure [] resCtx (SubExpRes _ (Var v)) = do info <- lookupMemInfo v case info of MemPrim {} -> pure [] MemAcc {} -> pure [] MemMem {} -> pure [] -- should not happen MemArray _ _ _ (ArrayIn mem lmad) -> do lmad_exts <- mapM (letSubExp "lmad_ext" <=< toExp) $ LMAD.existentialized lmad pure $ subExpRes (Var mem) : subExpsRes lmad_exts -- Do a a simple form of invariance analysis to simplify a Match. It -- is unfortunate that we have to do it here, but functions such as -- scalarRes will look carefully at the index functions before the -- simplifier has a chance to run. In a perfect world we would -- simplify away those copies afterwards. XXX; this should be fixed by -- a more general copy-removal pass. See -- Futhark.Optimise.EntryPointMem for a very specialised version of -- the idea, but which could perhaps be generalised. simplifyMatch :: (Mem rep inner) => [Case (Body rep)] -> Body rep -> [BranchTypeMem] -> ( [Case (Body rep)], Body rep, [BranchTypeMem] ) simplifyMatch cases defbody ts = let case_reses = map (bodyResult . caseBody) cases defbody_res = bodyResult defbody (ctx_fixes, variant) = partitionEithers . map branchInvariant $ zip4 [0 ..] (transpose case_reses) defbody_res ts (cases_reses, defbody_reses, ts') = unzip3 variant in ( zipWith onCase cases (transpose cases_reses), onBody defbody defbody_reses, foldr (uncurry fixExt) ts' ctx_fixes ) where bound_in_branches = namesFromList . concatMap (patNames . stmPat) $ foldMap (bodyStms . caseBody) cases <> bodyStms defbody onCase c res = fmap (`onBody` res) c onBody body res = body {bodyResult = res} branchInvariant (i, case_reses, defres, t) -- If even one branch has a variant result, then we give up. | namesIntersect bound_in_branches $ freeIn $ defres : case_reses = Right (case_reses, defres, t) -- Do all branches return the same value? | all ((== resSubExp defres) . 
resSubExp) case_reses = Left (i, resSubExp defres) | otherwise = Right (case_reses, defres, t) allocInExp :: (Allocable fromrep torep inner) => Exp fromrep -> AllocM fromrep torep (Exp torep) allocInExp (Loop merge form (Body () bodystms bodyres)) = allocInLoopParams merge $ \merge' mk_loop_val -> do localScope (scopeOfLoopForm form) $ do body' <- buildBody_ . allocInStms bodystms $ do (valctx, valres') <- mk_loop_val $ map resSubExp bodyres pure $ subExpsRes valctx <> zipWith SubExpRes (map resCerts bodyres) valres' pure $ Loop merge' form body' allocInExp (Apply fname args rettype loc) = do args' <- funcallArgs args space <- askDefaultSpace -- We assume that every array is going to be in its own memory. let num_extra_args = length args' - length args rettype' = mems space ++ zip (memoryInDeclExtType space num_arrays (map fst rettype)) (map (shiftRetAls num_extra_args num_arrays . snd) rettype) pure $ Apply fname args' rettype' loc where mems space = replicate num_arrays (MemMem space, RetAls mempty mempty) num_arrays = length $ filter ((> 0) . arrayRank . declExtTypeOf . 
fst) rettype allocInExp (Match ses cases defbody (MatchDec rets ifsort)) = do (defbody', def_reqs) <- allocInMatchBody rets defbody (cases', cases_reqs) <- mapAndUnzipM onCase cases let reqs = zipWith (foldl combMemReqTypes) def_reqs (transpose cases_reqs) defbody'' <- addCtxToMatchBody reqs defbody' cases'' <- mapM (traverse $ addCtxToMatchBody reqs) cases' let (cases''', defbody''', rets') = simplifyMatch cases'' defbody'' $ mkBranchRet reqs pure $ Match ses cases''' defbody''' $ MatchDec rets' ifsort where onCase (Case vs body) = first (Case vs) <$> allocInMatchBody rets body allocInExp (WithAcc inputs bodylam) = WithAcc <$> mapM onInput inputs <*> onLambda bodylam where onLambda lam = do params <- forM (lambdaParams lam) $ \(Param attrs pv t) -> case t of Prim Unit -> pure $ Param attrs pv $ MemPrim Unit Acc acc ispace ts u -> pure $ Param attrs pv $ MemAcc acc ispace ts u _ -> error $ "Unexpected WithAcc lambda param: " ++ prettyString (Param attrs pv t) allocInLambda params (lambdaBody lam) onInput (shape, arrs, op) = (shape,arrs,) <$> traverse (onOp shape arrs) op onOp accshape arrs (lam, nes) = do let num_vs = length (lambdaReturnType lam) num_is = shapeRank accshape (i_params, x_params, y_params) = splitAt3 num_is num_vs $ lambdaParams lam i_params' = map (\(Param attrs v _) -> Param attrs v $ MemPrim int64) i_params is = map (DimFix . Var . paramName) i_params' x_params' <- zipWithM (onXParam is) x_params arrs y_params' <- zipWithM (onYParam is) y_params arrs lam' <- allocInLambda (i_params' <> x_params' <> y_params') (lambdaBody lam) pure (lam', nes) mkP attrs p pt shape u mem lmad is = Param attrs p . MemArray pt shape u . ArrayIn mem . 
LMAD.slice lmad $ fmap pe64 $ Slice $ is ++ map sliceDim (shapeDims shape) onXParam _ (Param attrs p (Prim t)) _ = pure $ Param attrs p (MemPrim t) onXParam is (Param attrs p (Array pt shape u)) arr = do (mem, lmad) <- lookupArraySummary arr pure $ mkP attrs p pt shape u mem lmad is onXParam _ p _ = error $ "Cannot handle MkAcc param: " ++ prettyString p onYParam _ (Param attrs p (Prim t)) _ = pure $ Param attrs p $ MemPrim t onYParam is (Param attrs p (Array pt shape u)) arr = do arr_t <- lookupType arr space <- askDefaultSpace mem <- allocForArray arr_t space let base_dims = map pe64 $ arrayDims arr_t lmad = LMAD.iota 0 base_dims pure $ mkP attrs p pt shape u mem lmad is onYParam _ p _ = error $ "Cannot handle MkAcc param: " ++ prettyString p allocInExp e = mapExpM alloc e where alloc = identityMapper { mapOnBody = error "Unhandled Body in ExplicitAllocations", mapOnRetType = error "Unhandled RetType in ExplicitAllocations", mapOnBranchType = error "Unhandled BranchType in ExplicitAllocations", mapOnFParam = error "Unhandled FParam in ExplicitAllocations", mapOnLParam = error "Unhandled LParam in ExplicitAllocations", mapOnOp = \op -> do handle <- asks allocInOp handle op } class SizeSubst op where opIsConst :: op -> Bool opIsConst = const False instance SizeSubst (NoOp rep) instance (SizeSubst (op rep)) => SizeSubst (MemOp op rep) where opIsConst (Inner op) = opIsConst op opIsConst _ = False stmConsts :: (SizeSubst (Op rep)) => Stm rep -> S.Set VName stmConsts (Let pat _ (Op op)) | opIsConst op = S.fromList $ patNames pat stmConsts _ = mempty mkLetNamesB' :: ( LetDec (Rep m) ~ LetDecMem, Mem (Rep m) inner, MonadBuilder m, ExpDec (Rep m) ~ () ) => Space -> ExpDec (Rep m) -> [VName] -> Exp (Rep m) -> m (Stm (Rep m)) mkLetNamesB' space dec names e = do pat <- patWithAllocations space names e nohints pure $ Let pat (defAux dec) e where nohints = map (const NoHint) names mkLetNamesB'' :: ( Mem rep inner, LetDec rep ~ LetDecMem, OpReturns inner, ExpDec rep ~ (), Rep m 
~ Engine.Wise rep, HasScope (Engine.Wise rep) m, MonadBuilder m, AliasedOp inner, RephraseOp (MemOp inner), Engine.CanBeWise inner, ASTConstraints (inner (Engine.Wise rep)) ) => Space -> [VName] -> Exp (Engine.Wise rep) -> m (Stm (Engine.Wise rep)) mkLetNamesB'' space names e = do pat <- patWithAllocations space names e nohints let pat' = Engine.addWisdomToPat pat e dec = Engine.mkWiseExpDec pat' () e pure $ Let pat' (defAux dec) e where nohints = map (const NoHint) names simplifyMemOp :: (Engine.SimplifiableRep rep) => ( inner (Engine.Wise rep) -> Engine.SimpleM rep (inner (Engine.Wise rep), Stms (Engine.Wise rep)) ) -> MemOp inner (Engine.Wise rep) -> Engine.SimpleM rep (MemOp inner (Engine.Wise rep), Stms (Engine.Wise rep)) simplifyMemOp _ (Alloc size space) = (,) <$> (Alloc <$> Engine.simplify size <*> pure space) <*> pure mempty simplifyMemOp onInner (Inner k) = do (k', hoisted) <- onInner k pure (Inner k', hoisted) simplifiable :: ( Engine.SimplifiableRep rep, LetDec rep ~ LetDecMem, ExpDec rep ~ (), BodyDec rep ~ (), Mem (Engine.Wise rep) inner, Engine.CanBeWise inner, RephraseOp inner, IsOp inner, OpReturns inner, AliasedOp inner, IndexOp (inner (Engine.Wise rep)) ) => (inner (Engine.Wise rep) -> UT.UsageTable) -> ( inner (Engine.Wise rep) -> Engine.SimpleM rep (inner (Engine.Wise rep), Stms (Engine.Wise rep)) ) -> SimpleOps rep simplifiable innerUsage simplifyInnerOp = SimpleOps mkExpDecS' mkBodyS' protectOp opUsage (simplifyMemOp simplifyInnerOp) where mkExpDecS' _ pat e = pure $ Engine.mkWiseExpDec pat () e mkBodyS' _ stms res = pure $ mkWiseBody () stms res protectOp taken pat (Alloc size space) = Just $ do tbody <- resultBodyM [size] fbody <- resultBodyM [intConst Int64 0] size' <- letSubExp "hoisted_alloc_size" $ Match [taken] [Case [Just $ BoolValue True] tbody] fbody $ MatchDec [MemPrim int64] MatchFallback letBind pat $ Op $ Alloc size' space protectOp _ _ _ = Nothing opUsage (Alloc (Var size) _) = UT.sizeUsage size opUsage (Alloc _ _) = mempty 
opUsage (Inner inner) = innerUsage inner data ExpHint = NoHint | Hint LMAD Space defaultExpHints :: (ASTRep rep, HasScope rep m) => Exp rep -> m [ExpHint] defaultExpHints e = map (const NoHint) <$> expExtType e futhark-0.25.27/src/Futhark/Pass/ExplicitAllocations/000077500000000000000000000000001475065116200224155ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Pass/ExplicitAllocations/GPU.hs000066400000000000000000000156221475065116200234120ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | Facilities for converting a 'GPU' program to 'GPUMem'. module Futhark.Pass.ExplicitAllocations.GPU ( explicitAllocations, explicitAllocationsInStms, ) where import Control.Monad import Data.Set qualified as S import Futhark.IR.GPU import Futhark.IR.GPUMem import Futhark.IR.Mem.LMAD qualified as LMAD import Futhark.Pass.ExplicitAllocations import Futhark.Pass.ExplicitAllocations.SegOp instance SizeSubst (HostOp rep op) where opIsConst (SizeOp GetSize {}) = True opIsConst (SizeOp GetSizeMax {}) = True opIsConst _ = False allocAtLevel :: SegLevel -> AllocM GPU GPUMem a -> AllocM GPU GPUMem a allocAtLevel lvl = local $ \env -> env { allocSpace = space, allocInOp = handleHostOp (Just lvl) } where space = case lvl of SegBlock {} -> Space "shared" SegThread {} -> Space "device" SegThreadInBlock {} -> Space "device" handleSegOp :: Maybe SegLevel -> SegOp SegLevel GPU -> AllocM GPU GPUMem (SegOp SegLevel GPUMem) handleSegOp outer_lvl op = do num_threads <- case (outer_lvl, segLevel op) of -- This implies we are in the intragroup parallelism situation. -- Just allocate for a single group; memory expansion will -- handle the rest later. (Just (SegBlock _ (Just grid)), _) -> pure $ unCount $ gridBlockSize grid _ -> letSubExp "num_threads" =<< case maybe_grid of Just grid -> pure . 
BasicOp $ BinOp (Mul Int64 OverflowUndef) (unCount (gridNumBlocks grid)) (unCount (gridBlockSize grid)) Nothing -> foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) (segSpaceDims $ segSpace op) allocAtLevel (segLevel op) $ mapSegOpM (mapper num_threads) op where maybe_grid = case (outer_lvl, segLevel op) of (Just (SegThread _ (Just grid)), _) -> Just grid (Just (SegBlock _ (Just grid)), _) -> Just grid (_, SegThread _ (Just grid)) -> Just grid (_, SegBlock _ (Just grid)) -> Just grid _ -> Nothing scope = scopeOfSegSpace $ segSpace op mapper num_threads = identitySegOpMapper { mapOnSegOpBody = localScope scope . local f . allocInKernelBody, mapOnSegOpLambda = local inThread . allocInBinOpLambda num_threads (segSpace op) } f = case segLevel op of SegThread {} -> inThread SegThreadInBlock {} -> inThread SegBlock {} -> inGroup inThread env = env {envExpHints = inThreadExpHints} inGroup env = env {envExpHints = inGroupExpHints} handleHostOp :: Maybe SegLevel -> HostOp SOAC GPU -> AllocM GPU GPUMem (MemOp (HostOp NoOp) GPUMem) handleHostOp _ (SizeOp op) = pure $ Inner $ SizeOp op handleHostOp _ (OtherOp op) = error $ "Cannot allocate memory in SOAC: " ++ prettyString op handleHostOp outer_lvl (SegOp op) = Inner . SegOp <$> handleSegOp outer_lvl op handleHostOp _ (GPUBody ts (Body _ stms res)) = fmap (Inner . GPUBody ts) . buildBody_ . 
allocInStms stms $ pure res kernelExpHints :: Exp GPUMem -> AllocM GPU GPUMem [ExpHint] kernelExpHints (BasicOp (Manifest perm v)) = do dims <- arrayDims <$> lookupType v let perm_inv = rearrangeInverse perm dims' = rearrangeShape perm dims lmad = LMAD.permute (LMAD.iota 0 $ map pe64 dims') perm_inv pure [Hint lmad $ Space "device"] kernelExpHints (Op (Inner (SegOp (SegMap lvl@(SegThread _ _) space ts body)))) = zipWithM (mapResultHint lvl space) ts $ kernelBodyResult body kernelExpHints (Op (Inner (SegOp (SegRed lvl@(SegThread _ _) space reds ts body)))) = (map (const NoHint) red_res <>) <$> zipWithM (mapResultHint lvl space) (drop num_reds ts) map_res where num_reds = segBinOpResults reds (red_res, map_res) = splitAt num_reds $ kernelBodyResult body kernelExpHints e = defaultExpHints e mapResultHint :: SegLevel -> SegSpace -> Type -> KernelResult -> AllocM GPU GPUMem ExpHint mapResultHint _lvl space = hint where -- Heuristic: do not rearrange for returned arrays that are -- sufficiently small. coalesceReturnOfShape _ [] = False coalesceReturnOfShape bs [Constant (IntValue (Int64Value d))] = bs * d > 4 coalesceReturnOfShape _ _ = True hint t Returns {} | coalesceReturnOfShape (primByteSize (elemType t)) $ arrayDims t = do let space_dims = segSpaceDims space pure $ Hint (innermost space_dims (arrayDims t)) $ Space "device" hint _ _ = pure NoHint innermost :: [SubExp] -> [SubExp] -> LMAD innermost space_dims t_dims = let r = length t_dims dims = space_dims ++ t_dims perm = [length space_dims .. length space_dims + r - 1] ++ [0 .. 
length space_dims - 1] perm_inv = rearrangeInverse perm dims_perm = rearrangeShape perm dims lmad_base = LMAD.iota 0 $ map pe64 dims_perm lmad_rearranged = LMAD.permute lmad_base perm_inv in lmad_rearranged semiStatic :: S.Set VName -> SubExp -> Bool semiStatic _ Constant {} = True semiStatic consts (Var v) = v `S.member` consts inGroupExpHints :: Exp GPUMem -> AllocM GPU GPUMem [ExpHint] inGroupExpHints (Op (Inner (SegOp (SegMap _ space ts body)))) | any private $ kernelBodyResult body = do consts <- asks envConsts pure $ do (t, r) <- zip ts $ kernelBodyResult body pure $ if private r && all (semiStatic consts) (arrayDims t) then let seg_dims = map pe64 $ segSpaceDims space dims = seg_dims ++ map pe64 (arrayDims t) nilSlice d = DimSlice 0 d 0 in Hint ( LMAD.slice (LMAD.iota 0 dims) $ fullSliceNum dims (map nilSlice seg_dims) ) $ ScalarSpace (arrayDims t) $ elemType t else NoHint where private (Returns ResultPrivate _ _) = True private _ = False inGroupExpHints e = defaultExpHints e inThreadExpHints :: Exp GPUMem -> AllocM GPU GPUMem [ExpHint] inThreadExpHints e = do consts <- asks envConsts mapM (maybePrivate consts) =<< expExtType e where maybePrivate consts t | Just (Array pt shape _) <- hasStaticShape t, all (semiStatic consts) $ shapeDims shape = do let lmad = LMAD.iota 0 $ map pe64 $ shapeDims shape pure $ Hint lmad $ ScalarSpace (shapeDims shape) pt | otherwise = pure NoHint -- | The pass from 'GPU' to 'GPUMem'. explicitAllocations :: Pass GPU GPUMem explicitAllocations = explicitAllocationsGeneric (Space "device") (handleHostOp Nothing) kernelExpHints -- | Convert some 'GPU' stms to 'GPUMem'. 
explicitAllocationsInStms :: (MonadFreshNames m, HasScope GPUMem m) => Stms GPU -> m (Stms GPUMem) explicitAllocationsInStms = explicitAllocationsInStmsGeneric (Space "device") (handleHostOp Nothing) kernelExpHints futhark-0.25.27/src/Futhark/Pass/ExplicitAllocations/MC.hs000066400000000000000000000022601475065116200232500ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | Converting 'MC' programs to 'MCMem'. module Futhark.Pass.ExplicitAllocations.MC (explicitAllocations) where import Futhark.IR.MC import Futhark.IR.MCMem import Futhark.Pass.ExplicitAllocations import Futhark.Pass.ExplicitAllocations.SegOp instance SizeSubst (MCOp rep op) handleSegOp :: SegOp () MC -> AllocM MC MCMem (SegOp () MCMem) handleSegOp op = do let num_threads = intConst Int64 256 -- FIXME mapSegOpM (mapper num_threads) op where scope = scopeOfSegSpace $ segSpace op mapper num_threads = identitySegOpMapper { mapOnSegOpBody = localScope scope . allocInKernelBody, mapOnSegOpLambda = allocInBinOpLambda num_threads (segSpace op) } handleMCOp :: Op MC -> AllocM MC MCMem (Op MCMem) handleMCOp (ParOp par_op op) = Inner <$> (ParOp <$> traverse handleSegOp par_op <*> handleSegOp op) handleMCOp (OtherOp soac) = error $ "Cannot allocate memory in SOAC: " ++ prettyString soac -- | The pass from 'MC' to 'MCMem'. 
explicitAllocations :: Pass MC MCMem explicitAllocations = explicitAllocationsGeneric DefaultSpace handleMCOp defaultExpHints futhark-0.25.27/src/Futhark/Pass/ExplicitAllocations/SegOp.hs000066400000000000000000000062221475065116200237700ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} {-# OPTIONS_GHC -fno-warn-orphans #-} module Futhark.Pass.ExplicitAllocations.SegOp ( allocInKernelBody, allocInBinOpLambda, ) where import Control.Monad import Futhark.IR.GPUMem import Futhark.IR.Mem.LMAD qualified as LMAD import Futhark.Pass.ExplicitAllocations instance SizeSubst (SegOp lvl rep) allocInKernelBody :: (Allocable fromrep torep inner) => KernelBody fromrep -> AllocM fromrep torep (KernelBody torep) allocInKernelBody (KernelBody () stms res) = uncurry (flip (KernelBody ())) <$> collectStms (allocInStms stms (pure res)) allocInLambda :: (Allocable fromrep torep inner) => [LParam torep] -> Body fromrep -> AllocM fromrep torep (Lambda torep) allocInLambda params body = mkLambda params . allocInStms (bodyStms body) $ pure $ bodyResult body allocInBinOpParams :: (Allocable fromrep torep inner) => SubExp -> TPrimExp Int64 VName -> TPrimExp Int64 VName -> [LParam fromrep] -> [LParam fromrep] -> AllocM fromrep torep ([LParam torep], [LParam torep]) allocInBinOpParams num_threads my_id other_id xs ys = unzip <$> zipWithM alloc xs ys where alloc x y = case paramType x of Array pt shape u -> do let name = maybe "num_threads" baseString (subExpVar num_threads) twice_num_threads <- letSubExp ("twice_" <> name) . BasicOp $ BinOp (Mul Int64 OverflowUndef) num_threads (intConst Int64 2) let t = paramType x `arrayOfRow` twice_num_threads mem <- allocForArray t =<< askDefaultSpace -- XXX: this iota lmad is a bit inefficient; leading to -- uncoalesced access. 
let base_dims = map pe64 $ arrayDims t lmad_base = LMAD.iota 0 base_dims lmad_x = LMAD.slice lmad_base $ fullSliceNum base_dims [DimFix my_id] lmad_y = LMAD.slice lmad_base $ fullSliceNum base_dims [DimFix other_id] pure ( x {paramDec = MemArray pt shape u $ ArrayIn mem lmad_x}, y {paramDec = MemArray pt shape u $ ArrayIn mem lmad_y} ) Prim bt -> pure ( x {paramDec = MemPrim bt}, y {paramDec = MemPrim bt} ) Mem space -> pure ( x {paramDec = MemMem space}, y {paramDec = MemMem space} ) -- This next case will never happen. Acc acc ispace ts u -> pure ( x {paramDec = MemAcc acc ispace ts u}, y {paramDec = MemAcc acc ispace ts u} ) allocInBinOpLambda :: (Allocable fromrep torep inner) => SubExp -> SegSpace -> Lambda fromrep -> AllocM fromrep torep (Lambda torep) allocInBinOpLambda num_threads (SegSpace flat _) lam = do let (acc_params, arr_params) = splitAt (length (lambdaParams lam) `div` 2) $ lambdaParams lam index_x = TPrimExp $ LeafExp flat int64 index_y = index_x + pe64 num_threads (acc_params', arr_params') <- allocInBinOpParams num_threads index_x index_y acc_params arr_params allocInLambda (acc_params' ++ arr_params') (lambdaBody lam) futhark-0.25.27/src/Futhark/Pass/ExplicitAllocations/Seq.hs000066400000000000000000000005551475065116200235060ustar00rootroot00000000000000module Futhark.Pass.ExplicitAllocations.Seq ( explicitAllocations, simplifiable, ) where import Futhark.IR.Seq import Futhark.IR.SeqMem import Futhark.Pass import Futhark.Pass.ExplicitAllocations explicitAllocations :: Pass Seq SeqMem explicitAllocations = explicitAllocationsGeneric DefaultSpace (const $ pure $ Inner NoOp) defaultExpHints futhark-0.25.27/src/Futhark/Pass/ExtractKernels.hs000066400000000000000000001020221475065116200217320ustar00rootroot00000000000000{-# LANGUAGE LambdaCase #-} {-# LANGUAGE TypeFamilies #-} -- | Kernel extraction. 
-- -- In the following, I will use the term "width" to denote the amount -- of immediate parallelism in a map - that is, the outer size of the -- array(s) being used as input. -- -- = Basic Idea -- -- If we have: -- -- @ -- map -- map(f) -- stms_a... -- map(g) -- @ -- -- Then we want to distribute to: -- -- @ -- map -- map(f) -- map -- stms_a -- map -- map(g) -- @ -- -- But for now only if -- -- (0) it can be done without creating irregular arrays. -- Specifically, the size of the arrays created by @map(f)@, by -- @map(g)@ and whatever is created by @stms_a@ that is also used -- in @map(g)@, must be invariant to the outermost loop. -- -- (1) the maps are _balanced_. That is, the functions @f@ and @g@ -- must do the same amount of work for every iteration. -- -- The advantage is that the map-nests containing @map(f)@ and -- @map(g)@ can now be trivially flattened at no cost, thus exposing -- more parallelism. Note that the @stms_a@ map constitutes array -- expansion, which requires additional storage. -- -- = Distributing Sequential Loops -- -- As a starting point, sequential loops are treated like scalar -- expressions. That is, not distributed. However, sometimes it can -- be worthwhile to distribute if they contain a map: -- -- @ -- map -- loop -- map -- map -- @ -- -- If we distribute the loop and interchange the outer map into the -- loop, we get this: -- -- @ -- loop -- map -- map -- map -- map -- @ -- -- Now more parallelism may be available. -- -- = Unbalanced Maps -- -- Unbalanced maps will as a rule be sequentialised, but sometimes, -- there is another way. Assume we find this: -- -- @ -- map -- map(f) -- map(g) -- map -- @ -- -- Presume that @map(f)@ is unbalanced. 
By the simple rule above, we -- would then fully sequentialise it, resulting in this: -- -- @ -- map -- loop -- map -- map -- @ -- -- == Balancing by Loop Interchange -- -- The above is not ideal, as we cannot flatten the @map-loop@ nest, -- and we are thus limited in the amount of parallelism available. -- -- But assume now that the width of @map(g)@ is invariant to the outer -- loop. Then if possible, we can interchange @map(f)@ and @map(g)@, -- sequentialise @map(f)@ and distribute, interchanging the outer -- parallel loop into the sequential loop: -- -- @ -- loop(f) -- map -- map(g) -- map -- map -- @ -- -- After flattening the two nests we can obtain more parallelism. -- -- When distributing a map, we also need to distribute everything that -- the map depends on - possibly as its own map. When distributing a -- set of scalar bindings, we will need to know which of the binding -- results are used afterwards. Hence, we will need to compute usage -- information. -- -- = Redomap -- -- Redomap can be handled much like map. Distributed loops are -- distributed as maps, with the parameters corresponding to the -- neutral elements added to their bodies. The remaining loop will -- remain a redomap. Example: -- -- @ -- redomap(op, -- fn (v) => -- map(f) -- map(g), -- e,a) -- @ -- -- distributes to -- -- @ -- let b = map(fn v => -- let acc = e -- map(f), -- a) -- redomap(op, -- fn (v,dist) => -- map(g), -- e,a,b) -- @ -- -- Note that there may be further kernel extraction opportunities -- inside the @map(f)@. The downside of this approach is that the -- intermediate array (@b@ above) must be written to main memory. An -- often better approach is to just turn the entire @redomap@ into a -- single kernel. 
module Futhark.Pass.ExtractKernels (extractKernels) where

import Control.Monad
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Data.Bifunctor (first)
import Data.Maybe
import Futhark.IR.GPU
import Futhark.IR.SOACS
import Futhark.IR.SOACS.Simplify (simplifyStms)
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ISRWIM
import Futhark.Pass.ExtractKernels.Intrablock
import Futhark.Pass.ExtractKernels.StreamKernel
import Futhark.Pass.ExtractKernels.ToGPU
import Futhark.Tools
import Futhark.Transform.FirstOrderTransform qualified as FOT
import Futhark.Transform.Rename
import Futhark.Util.Log
import Prelude hiding (log)

-- | Transform a program using SOACs to a program using explicit
-- kernels, using the kernel extraction transformation.
extractKernels :: Pass SOACS GPU
extractKernels =
  Pass
    { passName = "extract kernels",
      passDescription = "Perform kernel extraction",
      passFunction = transformProg
    }

-- | Transform the top-level constants, then every function definition
-- with the transformed constants in scope.
transformProg :: Prog SOACS -> PassM (Prog GPU)
transformProg prog = do
  consts' <- runDistribM $ transformStms mempty $ stmsToList $ progConsts prog
  funs' <- mapM (transformFunDef $ scopeOf consts') $ progFuns prog
  pure $
    prog
      { progConsts = consts',
        progFuns = funs'
      }

-- | State of the distribution monad.  In order to generate more stable
-- threshold names, we keep track of the numbers used for thresholds
-- separately from the ordinary name source.
data State = State
  { -- | The ordinary source of fresh names.
    stateNameSource :: VNameSource,
    -- | The number to use for the next threshold parameter.
    stateThresholdCounter :: Int
  }

-- | The monad in which kernel extraction happens.  The environment is
-- the current GPU scope, log messages are accumulated, and the state is
-- a 'State'.
newtype DistribM a = DistribM (RWS (Scope GPU) Log State a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      HasScope GPU,
      LocalScope GPU,
      MonadState State,
      MonadLogger
    )

instance MonadFreshNames DistribM where
  getNameSource = gets stateNameSource
  putNameSource src = modify $ \s -> s {stateNameSource = src}

-- | Run a 'DistribM' computation with an empty scope and a threshold
-- counter starting at zero.  The name source of the surrounding monad
-- is threaded through the computation, and the accumulated log is
-- re-emitted in the surrounding monad.
runDistribM ::
  (MonadLogger m, MonadFreshNames m) =>
  DistribM a ->
  m a
runDistribM (DistribM m) = do
  (x, msgs) <- modifyNameSource $ \src ->
    let (x, s, msgs) = runRWS m mempty (State src 0)
     in ((x, msgs), stateNameSource s)
  addLog msgs
  pure x

-- | Transform the body of a function definition, with the given scope
-- and the function parameters in scope.
transformFunDef ::
  (MonadFreshNames m, MonadLogger m) =>
  Scope GPU ->
  FunDef SOACS ->
  m (FunDef GPU)
transformFunDef scope (FunDef entry attrs name rettype params body) = runDistribM $ do
  body' <-
    localScope (scope <> scopeOfFParams params) $
      transformBody mempty body
  pure $ FunDef entry attrs name rettype params body'

type GPUStms = Stms GPU

-- | Transform the statements of a body; the body result is unchanged.
transformBody :: KernelPath -> Body SOACS -> DistribM (Body GPU)
transformBody path body = do
  stms <- transformStms path $ stmsToList $ bodyStms body
  pure $ mkBody stms $ bodyResult body

-- | Transform a sequence of statements.  A statement that
-- 'sequentialisedUnbalancedStm' rewrites is replaced by the rewritten
-- statements, which are then transformed in turn.  Every other
-- statement goes through 'transformStm', and the rest of the sequence
-- is transformed with its result in scope.
transformStms :: KernelPath -> [Stm SOACS] -> DistribM GPUStms
transformStms _ [] =
  pure mempty
transformStms path (stm : stms) =
  sequentialisedUnbalancedStm stm >>= \case
    Nothing -> do
      stm' <- transformStm path stm
      inScopeOf stm' $
        (stm' <>) <$> transformStms path stms
    Just stms' ->
      transformStms path $ stmsToList stms' <> stms

-- | A lambda is unbalanced if its body contains a 'Screma' or 'Stream'
-- whose width is bound inside the lambda (a parameter or a local
-- binding), and may therefore differ between iterations.  The check
-- looks inside 'WithAcc' lambdas, and inside the branches of a 'Match'
-- only when some scrutinee is itself bound inside the lambda.  Loops
-- are never considered unbalanced.
unbalancedLambda :: Lambda SOACS -> Bool
unbalancedLambda orig_lam =
  unbalancedBody (namesFromList $ map paramName $ lambdaParams orig_lam) $
    lambdaBody orig_lam
  where
    subExpBound (Var i) bound = i `nameIn` bound
    subExpBound (Constant _) _ = False

    unbalancedBody bound body =
      any (unbalancedStm (bound <> boundInBody body) . stmExp) $
        bodyStms body

    -- XXX - our notion of balancing is probably still too naive.
    unbalancedStm bound (Op (Stream w _ _ _)) =
      w `subExpBound` bound
    unbalancedStm bound (Op (Screma w _ _)) =
      w `subExpBound` bound
    unbalancedStm _ Op {} =
      False
    unbalancedStm _ Loop {} = False
    unbalancedStm bound (WithAcc _ lam) =
      unbalancedBody bound (lambdaBody lam)
    unbalancedStm bound (Match ses cases defbody _) =
      any (`subExpBound` bound) ses
        && ( any (unbalancedBody bound . caseBody) cases
               || unbalancedBody bound defbody
           )
    unbalancedStm _ (BasicOp _) =
      False
    unbalancedStm _ Apply {} =
      False

-- | If the statement is a redomap whose map function is unbalanced and
-- contains nested parallelism, turn it into sequential code with the
-- first-order transform.  The certificates of the original statement
-- are attached to every produced statement, exactly as in the
-- @sequential_outer@ case of 'transformStm', so that the sequentialised
-- code stays guarded by the same checks.
sequentialisedUnbalancedStm :: Stm SOACS -> DistribM (Maybe (Stms SOACS))
sequentialisedUnbalancedStm (Let pat aux (Op soac@(Screma _ _ form)))
  | Just (_, lam2) <- isRedomapSOAC form,
    unbalancedLambda lam2,
    lambdaContainsParallelism lam2 = do
      types <- asksScope scopeForSOACs
      Just . fmap (certify (stmAuxCerts aux)) . snd
        <$> runBuilderT (FOT.transformSOAC pat soac) types
sequentialisedUnbalancedStm _ =
  pure Nothing

-- | Create a fresh threshold parameter named @desc_N@.  @N@ is taken
-- from the threshold counter, which is then incremented.  Also produce
-- statements that multiply the given sizes together and compare the
-- product with the threshold using 'CmpSizeLe'.  Returns the comparison
-- result, the threshold name and the statements.
cmpSizeLe ::
  String ->
  SizeClass ->
  [SubExp] ->
  DistribM ((SubExp, Name), Stms GPU)
cmpSizeLe desc size_class to_what = do
  x <- gets stateThresholdCounter
  modify $ \s -> s {stateThresholdCounter = x + 1}
  let size_key = nameFromString $ desc ++ "_" ++ show x
  runBuilder $ do
    to_what' <-
      letSubExp "comparatee"
        =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) to_what
    cmp_res <- letSubExp desc $ Op $ SizeOp $ CmpSizeLe size_key size_class to_what'
    pure (cmp_res, size_key)

-- | Build a chain of 'Match' expressions that selects between kernel
-- versions.  The first alternative whose condition is true is used, and
-- the default body is used if none is.  With no alternatives, the
-- default body is bound directly to the pattern.
kernelAlternatives ::
  (MonadFreshNames m, HasScope GPU m) =>
  Pat Type ->
  Body GPU ->
  [(SubExp, Body GPU)] ->
  m (Stms GPU)
kernelAlternatives pat default_body [] = runBuilder_ $ do
  ses <- bodyBind default_body
  forM_ (zip (patNames pat) ses) $ \(name, SubExpRes cs se) ->
    certifying cs $ letBindNames [name] $ BasicOp $ SubExp se
kernelAlternatives pat default_body ((cond, alt) : alts) = runBuilder_ $ do
  -- Fresh names for the result of the remaining alternatives.
  alts_pat <- fmap Pat . forM (patElems pat) $ \pe -> do
    name <- newVName $ baseString $ patElemName pe
    pure pe {patElemName = name}

  alt_stms <- kernelAlternatives alts_pat default_body alts
  let alt_body = mkBody alt_stms $ varsRes $ patNames alts_pat

  letBind pat . Match [cond] [Case [Just $ BoolValue True] alt] alt_body $
    MatchDec (staticShapes (patTypes pat)) MatchEquiv

-- | Transform the body of a lambda with its parameters in scope.
transformLambda :: KernelPath -> Lambda SOACS -> DistribM (Lambda GPU)
transformLambda path (Lambda params ret body) =
  Lambda params ret
    <$> localScope
      (scopeOfLParams params)
      (transformBody path body)

-- | Version a scan or reduction.  If the map function has no nested
-- parallelism, or the statement carries the @sequential_inner@
-- attribute, only the outer-parallel statements are produced.
-- Otherwise a threshold on the width @w@ selects between the
-- outer-parallel body and the inner-parallel body.  The inner-parallel
-- body is built with the threshold added, negated, to the kernel path.
versionScanRed ::
  KernelPath ->
  Pat Type ->
  StmAux () ->
  SubExp ->
  Lambda SOACS ->
  DistribM (Stms GPU) ->
  DistribM (Body GPU) ->
  ([(Name, Bool)] -> DistribM (Body GPU)) ->
  DistribM (Stms GPU)
versionScanRed path pat aux w map_lam paralleliseOuter outerParallelBody innerParallelBody =
  if not (lambdaContainsParallelism map_lam)
    || ("sequential_inner" `inAttrs` stmAuxAttrs aux)
    then paralleliseOuter
    else do
      ((outer_suff, outer_suff_key), suff_stms) <-
        sufficientParallelism "suff_outer_screma" [w] path Nothing

      outer_stms <- outerParallelBody
      inner_stms <- innerParallelBody ((outer_suff_key, False) : path)

      (suff_stms <>) <$> kernelAlternatives pat inner_stms [(outer_suff, outer_stms)]

-- | Transform a single SOACS statement into GPU statements.
transformStm :: KernelPath -> Stm SOACS -> DistribM GPUStms
transformStm _ stm
  | "sequential" `inAttrs` stmAuxAttrs (stmAux stm) =
      runBuilder_ $ FOT.transformStmRecursively stm
transformStm path (Let pat aux (Op soac))
  | "sequential_outer" `inAttrs` stmAuxAttrs aux =
      transformStms path . stmsToList . fmap (certify (stmAuxCerts aux))
        =<< runBuilder_ (FOT.transformSOAC pat soac)
transformStm path (Let pat aux (Match c cases defbody rt)) = do
  cases' <- mapM (traverse $ transformBody path) cases
  defbody' <- transformBody path defbody
  pure $ oneStm $ Let pat aux $ Match c cases' defbody' rt
transformStm path (Let pat aux (WithAcc inputs lam)) =
  oneStm . Let pat aux
    <$> (WithAcc (map transformInput inputs) <$> transformLambda path lam)
  where
    transformInput (shape, arrs, op) =
      (shape, arrs, fmap (first soacsLambdaToGPU) op)
transformStm path (Let pat aux (Loop merge form body)) =
  localScope (scopeOfLoopForm form <> scopeOfFParams params) $
    oneStm . Let pat aux . Loop merge form <$> transformBody path body
  where
    params = map fst merge
transformStm path (Let pat aux (Op (Screma w arrs form)))
  | Just lam <- isMapSOAC form =
      onMap path $ MapLoop pat aux w lam arrs
transformStm path (Let pat aux@(StmAux cs _ _) (Op (Screma w arrs form)))
  | Just scans <- isScanSOAC form,
    Scan scan_lam nes <- singleScan scans,
    Just do_iswim <- iswim pat w scan_lam $ zip nes arrs = do
      types <- asksScope scopeForSOACs
      transformStms path . stmsToList . snd
        =<< runBuilderT (certifying cs do_iswim) types
  | Just (scans, map_lam) <- isScanomapSOAC form = do
      -- A single segmented scan, with the map function run sequentially
      -- per element.
      let paralleliseOuter = runBuilder_ $ do
            scan_ops <- forM scans $ \(Scan scan_lam nes) -> do
              (scan_lam', nes', shape) <- determineReduceOp scan_lam nes
              let scan_lam'' = soacsLambdaToGPU scan_lam'
              pure $ SegBinOp Noncommutative scan_lam'' nes' shape
            let map_lam_sequential = soacsLambdaToGPU map_lam
            lvl <- segThreadCapped [w] "segscan" $ NoRecommendation SegNoVirt
            addStms . fmap (certify cs)
              =<< segScan lvl pat mempty w scan_ops map_lam_sequential arrs [] []

          outerParallelBody =
            renameBody
              =<< (mkBody <$> paralleliseOuter <*> pure (varsRes (patNames pat)))

          -- Split into a map and a scan, and transform both again so
          -- that parallelism inside the map function can be exploited.
          paralleliseInner path' = do
            (mapstm, scanstm) <-
              scanomapToMapAndScan pat (w, scans, map_lam, arrs)
            types <- asksScope scopeForSOACs
            transformStms path' . stmsToList
              <=< (`runBuilderT_` types)
              $ addStms =<< simplifyStms (stmsFromList [certify cs mapstm, certify cs scanstm])

          innerParallelBody path' =
            renameBody
              =<< (mkBody <$> paralleliseInner path' <*> pure (varsRes (patNames pat)))

      versionScanRed path pat aux w map_lam paralleliseOuter outerParallelBody innerParallelBody
transformStm path (Let res_pat aux (Op (Screma w arrs form)))
  | Just [Reduce comm red_fun nes] <- isReduceSOAC form,
    let comm'
          | commutativeLambda red_fun = Commutative
          | otherwise = comm,
    Just do_irwim <- irwim res_pat w comm' red_fun $ zip nes arrs = do
      types <- asksScope scopeForSOACs
      stms <- fst <$> runBuilderT (simplifyStms =<< collectStms_ (auxing aux do_irwim)) types
      transformStms path $ stmsToList stms
transformStm path (Let pat aux@(StmAux cs _ _) (Op (Screma w arrs form)))
  | Just (reds, map_lam) <- isRedomapSOAC form = do
      -- A single non-segmented reduction, with the map function run
      -- sequentially per element.
      let paralleliseOuter = runBuilder_ $ do
            red_ops <- forM reds $ \(Reduce comm red_lam nes) -> do
              (red_lam', nes', shape) <- determineReduceOp red_lam nes
              let comm'
                    | commutativeLambda red_lam' = Commutative
                    | otherwise = comm
                  red_lam'' = soacsLambdaToGPU red_lam'
              pure $ SegBinOp comm' red_lam'' nes' shape
            let map_lam_sequential = soacsLambdaToGPU map_lam
            lvl <- segThreadCapped [w] "segred" $ NoRecommendation SegNoVirt
            addStms . fmap (certify cs)
              =<< nonSegRed lvl pat w red_ops map_lam_sequential arrs

          outerParallelBody =
            renameBody
              =<< (mkBody <$> paralleliseOuter <*> pure (varsRes (patNames pat)))

          -- Split into a map and a reduce, and transform both again so
          -- that parallelism inside the map function can be exploited.
          paralleliseInner path' = do
            (mapstm, redstm) <-
              redomapToMapAndReduce pat (w, reds, map_lam, arrs)
            types <- asksScope scopeForSOACs
            transformStms path' . stmsToList
              <=< (`runBuilderT_` types)
              $ addStms =<< simplifyStms (stmsFromList [certify cs mapstm, certify cs redstm])

          innerParallelBody path' =
            renameBody
              =<< (mkBody <$> paralleliseInner path' <*> pure (varsRes (patNames pat)))

      versionScanRed path pat aux w map_lam paralleliseOuter outerParallelBody innerParallelBody
transformStm path (Let pat (StmAux cs _ _) (Op (Screma w arrs form))) = do
  -- This screma is too complicated for us to immediately do
  -- anything, so split it up and try again.
  scope <- asksScope scopeForSOACs
  transformStms path . map (certify cs) . stmsToList . snd
    =<< runBuilderT (dissectScrema pat w form arrs) scope
transformStm path (Let pat (StmAux cs _ _) (Op (Stream w arrs nes fold_fun))) = do
  -- Remove the stream and leave the body parallel.  It will be
  -- distributed.  The certificates of the stream are attached to the
  -- produced statements, as for the other SOACs above.
  types <- asksScope scopeForSOACs
  transformStms path . map (certify cs) . stmsToList . snd
    =<< runBuilderT (sequentialStreamWholeArray pat w nes fold_fun arrs) types
--
-- When we are scattering into a multidimensional array, we want to
-- fully parallelise, such that we do not have threads writing
-- potentially large rows. We do this by fissioning the scatter into a
-- map part and a scatter part, where the former is flattened as
-- usual, and the latter has a thread per primitive element to be
-- written.
--
-- TODO: this could be slightly smarter. If we are dealing with a
-- horizontally fused Scatter that targets both single- and
-- multi-dimensional arrays, we could handle the former in the map
-- stage. This would save us from having to store all the intermediate
-- results to memory. Troels suspects such cases are very rare, but
-- they may appear some day.
transformStm path (Let pat aux (Op (Scatter w arrs as lam)))
  | not $ all primType $ lambdaReturnType lam = do
      -- Produce map stage.
      map_pat <- fmap Pat $ forM (lambdaReturnType lam) $ \t ->
        PatElem <$> newVName "scatter_tmp" <*> pure (t `arrayOfRow` w)
      map_stms <- onMap path $ MapLoop map_pat aux w lam arrs

      -- Now do the scatters.
      runBuilder_ $ do
        addStms map_stms
        zipWithM_ doScatter (patElems pat) $ groupScatterResults as $ patNames map_pat
  where
    -- Generate code for a scatter where each thread writes only a scalar.
    doScatter res_pe (scatter_space, arr, is_vs) = do
      kernel_i <- newVName "write_i"
      arr_t <- lookupType arr
      -- The type of each written value: the destination array type
      -- with the scatter-space dimensions stripped off.
      let val_t = stripArray (shapeRank scatter_space) arr_t
      val_is <- replicateM (arrayRank val_t) (newVName "val_i")
      (kret, kstms) <- collectStms $ do
        is_vs' <- forM is_vs $ \(is, v) -> do
          v' <-
            letSubExp (baseString v <> "_elem") $
              BasicOp $
                Index v $
                  Slice $
                    map (DimFix . Var) $
                      kernel_i : val_is
          is' <- forM is $ \i' ->
            letSubExp (baseString i' <> "_i") $
              BasicOp $
                Index i' $
                  Slice [DimFix $ Var kernel_i]
          pure (Slice $ map DimFix $ is' <> map Var val_is, v')
        pure $ WriteReturns mempty arr is_vs'
      -- One thread per (scatter index, element of the written value).
      (kernel, stms) <-
        mapKernel
          segThreadCapped
          ((kernel_i, w) : zip val_is (arrayDims val_t))
          mempty
          [arr_t]
          (KernelBody () kstms [kret])
      addStms stms
      letBind (Pat [res_pe]) $ Op $ SegOp kernel
--
transformStm _ (Let pat (StmAux cs _ _) (Op (Scatter w ivs as lam))) = runBuilder_ $ do
  let lam' = soacsLambdaToGPU lam
  write_i <- newVName "write_i"
  let krets = do
        (_a_w, a, is_vs) <-
          groupScatterResults as $ bodyResult $ lambdaBody lam'
        let res_cs =
              foldMap (foldMap resCerts . fst) is_vs
                <> foldMap (resCerts . snd) is_vs
            is_vs' = [(Slice $ map (DimFix . resSubExp) is, resSubExp v) | (is, v) <- is_vs]
        pure $ WriteReturns res_cs a is_vs'
      body = KernelBody () (bodyStms $ lambdaBody lam') krets
      inputs = do
        (p, p_a) <- zip (lambdaParams lam') ivs
        pure $ KernelInput (paramName p) (paramType p) p_a [Var write_i]
  (kernel, stms) <-
    mapKernel
      segThreadCapped
      [(write_i, w)]
      inputs
      (patTypes pat)
      body
  certifying cs $ do
    addStms stms
    letBind pat $ Op $ SegOp kernel
transformStm _ (Let orig_pat (StmAux cs _ _) (Op (Hist w imgs ops bucket_fun))) = do
  let bfun' = soacsLambdaToGPU bucket_fun

  -- It is important not to launch unnecessarily many threads for
  -- histograms, because it may mean we unnecessarily need to reduce
  -- subhistograms as well.
  runBuilder_ $ do
    lvl <- segThreadCapped [w] "seghist" $ NoRecommendation SegNoVirt
    addStms =<< histKernel onLambda lvl orig_pat [] [] cs w ops bfun' imgs
  where
    onLambda = pure . soacsLambdaToGPU
transformStm _ stm =
  runBuilder_ $ FOT.transformStmRecursively stm

-- | A threshold comparison on the product of the given widths.  The
-- threshold is parameterised by the kernel path and an optional
-- default value.
sufficientParallelism ::
  String ->
  [SubExp] ->
  KernelPath ->
  Maybe Int64 ->
  DistribM ((SubExp, Name), Stms GPU)
sufficientParallelism desc ws path def =
  cmpSizeLe desc (SizeThreshold path def) ws

-- | Intra-group parallelism is worthwhile if the lambda contains more
-- than one instance of non-map nested parallelism, or any nested
-- parallelism inside a loop.  In the scoring, nested maps and scatters
-- score at least one unless their width is a constant below 32.  Loop
-- bodies are weighted ten times, and a 'Match' scores the maximum of
-- its branches.
worthIntrablock :: Lambda SOACS -> Bool
worthIntrablock lam = bodyInterest (lambdaBody lam) > 1
  where
    bodyInterest body =
      sum $ interest <$> bodyStms body
    interest stm
      | "sequential" `inAttrs` attrs =
          0 :: Int
      | Op (Screma w _ form) <- stmExp stm,
        Just lam' <- isMapSOAC form =
          mapLike w lam'
      | Op (Scatter w _ _ lam') <- stmExp stm =
          mapLike w lam'
      | Loop _ _ body <- stmExp stm =
          bodyInterest body * 10
      | Match _ cases defbody _ <- stmExp stm =
          foldl
            max
            (bodyInterest defbody)
            (map (bodyInterest . caseBody) cases)
      | Op (Screma w _ (ScremaForm lam' _ _)) <- stmExp stm =
          zeroIfTooSmall w + bodyInterest (lambdaBody lam')
      | Op (Stream _ _ _ lam') <- stmExp stm =
          bodyInterest $ lambdaBody lam'
      | otherwise =
          0
      where
        attrs = stmAuxAttrs $ stmAux stm
        sequential_inner = "sequential_inner" `inAttrs` attrs

        zeroIfTooSmall (Constant (IntValue x))
          | intToInt64 x < 32 = 0
        zeroIfTooSmall _ = 1

        mapLike w lam' =
          if sequential_inner
            then 0
            else max (zeroIfTooSmall w) (bodyInterest (lambdaBody lam'))

-- | A lambda is worth sequentialising if it contains enough nested
-- parallelism of an interesting kind.  Nested maps contribute the
-- interest of their bodies.  Other SOACs contribute one point plus the
-- interest of their bodies, and an extra point for a redomap directly
-- inside the outer lambda.  Bodies of @for@-loops are weighted ten
-- times.  The lambda is worth sequentialising if the total exceeds one.
worthSequentialising :: Lambda SOACS -> Bool
worthSequentialising lam = bodyInterest (0 :: Int) (lambdaBody lam) > 1
  where
    bodyInterest depth body =
      sum $ interest depth <$> bodyStms body
    interest depth stm
      | "sequential" `inAttrs` attrs =
          0 :: Int
      | Op (Screma _ _ form@(ScremaForm lam' _ _)) <- stmExp stm,
        isJust $ isMapSOAC form =
          if sequential_inner
            then 0
            else bodyInterest (depth + 1) (lambdaBody lam')
      | Op Scatter {} <- stmExp stm =
          0 -- Basically a map.
      | Loop _ ForLoop {} body <- stmExp stm =
          bodyInterest (depth + 1) body * 10
      | WithAcc _ withacc_lam <- stmExp stm =
          bodyInterest (depth + 1) (lambdaBody withacc_lam)
      | Op (Screma _ _ form@(ScremaForm lam' _ _)) <- stmExp stm =
          1
            + bodyInterest (depth + 1) (lambdaBody lam')
            -- Give this a bigger score if it's a redomap just inside
            -- the outer lambda, as these are often tileable and
            -- thus benefit more from sequentialisation.
            + ( case (isRedomapSOAC form, depth) of
                  (Just _, 0) -> 1
                  _ -> 0
              )
      | otherwise =
          0
      where
        attrs = stmAuxAttrs $ stmAux stm
        sequential_inner = "sequential_inner" `inAttrs` attrs

-- | Distribution callback for statements at the top level of a nest:
-- they are transformed with 'transformStms'.
onTopLevelStms ::
  KernelPath ->
  Stms SOACS ->
  DistNestT GPU DistribM GPUStms
onTopLevelStms path stms =
  liftInner $ transformStms path $ stmsToList stms

-- | Transform a map.  It builds the distribution environment for the
-- map nest and leaves the versioning decision to 'onMap''.  In the
-- outer-parallel version the whole map body becomes a single kernel.
-- In the inner-parallel version the body statements are distributed.
onMap :: KernelPath -> MapLoop -> DistribM GPUStms
onMap path (MapLoop pat aux w lam arrs) = do
  types <- askScope
  let loopnest = MapNesting pat aux w $ zip (lambdaParams lam) arrs
      env path' =
        DistEnv
          { distNest = singleNesting (Nesting mempty loopnest),
            distScope =
              scopeOfPat pat
                <> scopeForGPU (scopeOf lam)
                <> types,
            distOnInnerMap = onInnerMap path',
            distOnTopLevelStms = onTopLevelStms path',
            distSegLevel = segThreadCapped,
            distOnSOACSStms = pure . oneStm . soacsStmToGPU,
            distOnSOACSLambda = pure . soacsLambdaToGPU
          }
      exploitInnerParallelism path' =
        runDistNestT (env path') $
          distributeMapBodyStms acc (bodyStms $ lambdaBody lam)

  let exploitOuterParallelism path' = do
        let lam' = soacsLambdaToGPU lam
        runDistNestT (env path') $
          distribute $
            addStmsToAcc (bodyStms $ lambdaBody lam') acc

  onMap' (newKernel loopnest) path exploitOuterParallelism exploitInnerParallelism pat lam
  where
    acc =
      DistAcc
        { distTargets = singleTarget (pat, bodyResult $ lambdaBody lam),
          distStms = mempty
        }

-- | Does the statement carry @incremental_flattening(only_intra)@?
onlyExploitIntra :: Attrs -> Bool
onlyExploitIntra attrs =
  AttrComp "incremental_flattening" ["only_intra"] `inAttrs` attrs

-- | False if the statement carries @incremental_flattening(no_outer)@
-- or @incremental_flattening(only_inner)@.
mayExploitOuter :: Attrs -> Bool
mayExploitOuter attrs =
  not $
    AttrComp "incremental_flattening" ["no_outer"] `inAttrs` attrs
      || AttrComp "incremental_flattening" ["only_inner"] `inAttrs` attrs

-- | False if the statement carries @incremental_flattening(no_intra)@
-- or @incremental_flattening(only_inner)@.
mayExploitIntra :: Attrs -> Bool
mayExploitIntra attrs =
  not $
    AttrComp "incremental_flattening" ["no_intra"] `inAttrs` attrs
      || AttrComp "incremental_flattening" ["only_inner"] `inAttrs` attrs

-- The minimum amount of inner parallelism we require (by default) in
-- intra-group versions.
-- Less than this is usually pointless on a GPU
-- (but we allow tuning to change it).
intraMinInnerPar :: Int64
intraMinInnerPar = 32 -- One NVIDIA warp

-- | Choose between the versions of a map nest and combine them with
-- 'kernelAlternatives', guarded by threshold comparisons.  There are
-- up to three versions:
--
-- * sequentialised: built by @mk_seq_stms@, where only the outer
--   parallelism is exploited;
--
-- * intra-block: obtained from 'intrablockParallelise';
--
-- * fully parallel: built by @mk_par_stms@, where inner parallelism is
--   exploited.
--
-- The @sequential_inner@ and @incremental_flattening(...)@ attributes
-- of the innermost nesting, together with 'worthIntrablock' and
-- 'worthSequentialising', decide which versions are generated.
onMap' ::
  KernelNest ->
  KernelPath ->
  (KernelPath -> DistribM (Stms GPU)) ->
  (KernelPath -> DistribM (Stms GPU)) ->
  Pat Type ->
  Lambda SOACS ->
  DistribM (Stms GPU)
onMap' loopnest path mk_seq_stms mk_par_stms pat lam = do
  -- Some of the control flow here looks a bit convoluted because we
  -- are trying to avoid generating unneeded threshold parameters,
  -- which means we need to do all the pruning checks up front.

  types <- askScope

  let only_intra = onlyExploitIntra (stmAuxAttrs aux)
      may_intra = worthIntrablock lam && mayExploitIntra attrs

  -- Only attempt intra-block parallelisation if it may be used.
  intra <-
    if only_intra || may_intra
      then flip runReaderT types $ intrablockParallelise loopnest lam
      else pure Nothing

  case intra of
    -- @sequential_inner@: only the sequentialised version, no thresholds.
    _
      | "sequential_inner" `inAttrs` attrs -> do
          seq_body <- renameBody =<< mkBody <$> mk_seq_stms path <*> pure res
          kernelAlternatives pat seq_body []
    --
    -- No intra-block version: sequentialised vs. fully parallel.
    Nothing
      | not only_intra,
        Just m <- mkSeqAlts -> do
          (outer_suff, outer_suff_key, outer_suff_stms, seq_body) <- m
          par_body <-
            renameBody
              =<< mkBody
              <$> mk_par_stms ((outer_suff_key, False) : path)
              <*> pure res
          (outer_suff_stms <>) <$> kernelAlternatives pat par_body [(outer_suff, seq_body)]
      --
      -- Only the fully parallel version.
      | otherwise -> do
          par_body <- renameBody =<< mkBody <$> mk_par_stms path <*> pure res
          kernelAlternatives pat par_body []
    --
    -- @only_intra@: use the intra-block version unconditionally.
    Just intra'@(_, _, log, intra_prelude, intra_stms)
      | only_intra -> do
          addLog log
          group_par_body <- renameBody $ mkBody intra_stms res
          (intra_prelude <>) <$> kernelAlternatives pat group_par_body []
      --
      | otherwise -> do
          addLog log

          case mkSeqAlts of
            -- Intra-block vs. fully parallel.
            Nothing -> do
              (group_par_body, intra_ok, intra_suff_key, intra_suff_stms) <-
                checkSuffIntraPar path intra'

              par_body <-
                renameBody
                  =<< mkBody
                  <$> mk_par_stms ((intra_suff_key, False) : path)
                  <*> pure res

              (intra_suff_stms <>)
                <$> kernelAlternatives pat par_body [(intra_ok, group_par_body)]
            -- Sequentialised, then intra-block, then fully parallel,
            -- tried in that order.
            Just m -> do
              (outer_suff, outer_suff_key, outer_suff_stms, seq_body) <- m

              (group_par_body, intra_ok, intra_suff_key, intra_suff_stms) <-
                checkSuffIntraPar ((outer_suff_key, False) : path) intra'

              par_body <-
                renameBody
                  =<< mkBody
                  <$> mk_par_stms
                    ( [ (outer_suff_key, False),
                        (intra_suff_key, False)
                      ]
                        ++ path
                    )
                  <*> pure res

              ((outer_suff_stms <> intra_suff_stms) <>)
                <$> kernelAlternatives
                  pat
                  par_body
                  [(outer_suff, seq_body), (intra_ok, group_par_body)]
  where
    nest_ws = kernelNestWidths loopnest
    res = varsRes $ patNames pat
    aux = loopNestingAux $ innermostKernelNesting loopnest
    attrs = stmAuxAttrs aux

    -- If sequentialising is worthwhile and permitted, an action that
    -- creates the outer-parallelism threshold and the sequentialised
    -- body.  The sequentialised body is built with that threshold,
    -- marked True, added to the kernel path.
    mkSeqAlts
      | worthSequentialising lam,
        mayExploitOuter attrs = Just $ do
          ((outer_suff, outer_suff_key), outer_suff_stms) <- checkSuffOuterPar
          seq_body <-
            renameBody
              =<< mkBody
              <$> mk_seq_stms ((outer_suff_key, True) : path)
              <*> pure res
          pure (outer_suff, outer_suff_key, outer_suff_stms, seq_body)
      | otherwise =
          Nothing

    checkSuffOuterPar =
      sufficientParallelism "suff_outer_par" nest_ws path Nothing

    -- The intra-block version is used only if its available inner
    -- parallelism passes the "suff_intra_par" threshold (default
    -- 'intraMinInnerPar') AND its block size does not exceed the
    -- maximum thread block size.
    checkSuffIntraPar
      path'
      ((_intra_min_par, intra_avail_par), tblock_size, _, intra_prelude, intra_stms) = do
        -- We must check that all intra-group parallelism fits in a group.
        ((intra_ok, intra_suff_key), intra_suff_stms) <- do
          ((intra_suff, suff_key), check_suff_stms) <-
            sufficientParallelism
              "suff_intra_par"
              [intra_avail_par]
              path'
              (Just intraMinInnerPar)

          runBuilder $ do
            addStms intra_prelude

            max_tblock_size <-
              letSubExp "max_tblock_size" $ Op $ SizeOp $ GetSizeMax SizeThreadBlock
            fits <-
              letSubExp "fits" $
                BasicOp $
                  CmpOp (CmpSle Int64) tblock_size max_tblock_size

            addStms check_suff_stms

            intra_ok <- letSubExp "intra_suff_and_fits" $ BasicOp $ BinOp LogAnd fits intra_suff
            pure (intra_ok, suff_key)

        group_par_body <- renameBody $ mkBody intra_stms res
        pure (group_par_body, intra_ok, intra_suff_key, intra_suff_stms)

-- | Drop the map results whose pattern elements are not free in @res@.
-- Also compute how the kept pattern names relate to @res@ via
-- 'isPermutationOf'.  The result is 'Nothing' when that fails
-- (presumably when the kept names are not a permutation of @res@).
removeUnusedMapResults ::
  Pat Type ->
  [SubExpRes] ->
  Lambda rep ->
  Maybe ([Int], Pat Type, Lambda rep)
removeUnusedMapResults (Pat pes) res lam = do
  let (pes', body_res) =
        unzip $ filter (used . fst) $ zip pes $ bodyResult (lambdaBody lam)
  perm <- map (Var . patElemName) pes' `isPermutationOf` map resSubExp res
  pure (perm, Pat pes', lam {lambdaBody = (lambdaBody lam) {bodyResult = body_res}})
  where
    used pe = patElemName pe `nameIn` freeIn res

-- | Distribution callback for a map nested inside another map.
-- An unbalanced inner map with nested parallelism is kept unchanged in
-- the accumulator, so it is not distributed.  Otherwise the code tries
-- to distribute the map as a statement on its own.  If that succeeds
-- and its unused results can be removed, it is versioned via
-- 'multiVersion'.  Otherwise it falls back to 'distributeMap'.
onInnerMap ::
  KernelPath ->
  MapLoop ->
  DistAcc GPU ->
  DistNestT GPU DistribM (DistAcc GPU)
onInnerMap path maploop@(MapLoop pat aux w lam arrs) acc
  | unbalancedLambda lam,
    lambdaContainsParallelism lam =
      addStmToAcc (mapLoopStm maploop) acc
  | otherwise =
      distributeSingleStm acc (mapLoopStm maploop) >>= \case
        Just (post_kernels, res, nest, acc')
          | Just (perm, pat', lam') <- removeUnusedMapResults pat res lam -> do
              addPostStms post_kernels
              multiVersion perm nest acc' pat' lam'
        _ -> distributeMap maploop acc
  where
    discardTargets acc' =
      -- FIXME: work around bogus targets.
      acc' {distTargets = singleTarget (mempty, mempty)}

    -- Build the versioned statements for the distributed inner map.
    -- The sequential choice is a single kernel over the extended nest
    -- @nest'@.  The parallel choice re-runs distribution of the inner
    -- map inside @nest@.  Both are passed to 'onMap''.
    --
    -- GHC 9.2 loops without the type annotation.
    generate ::
      [Int] ->
      KernelNest ->
      Pat Type ->
      Lambda SOACS ->
      DistEnv GPU DistribM ->
      Scope GPU ->
      DistribM (Stms GPU)
    generate perm nest pat' lam' dist_env extra_scope = localScope extra_scope $ do
      let maploop' = MapLoop pat' aux w lam' arrs

          exploitInnerParallelism path' = do
            let dist_env' =
                  dist_env
                    { distOnTopLevelStms = onTopLevelStms path',
                      distOnInnerMap = onInnerMap path'
                    }
            runDistNestT dist_env' . inNesting nest . localScope extra_scope $
              discardTargets
                <$> distributeMap maploop' acc {distStms = mempty}

      -- Normally the permutation is for the output pattern, but
      -- we can't really change that, so we change the result
      -- order instead.
      let lam_res' =
            rearrangeShape (rearrangeInverse perm) $
              bodyResult $
                lambdaBody lam'
          lam'' = lam' {lambdaBody = (lambdaBody lam') {bodyResult = lam_res'}}
          map_nesting = MapNesting pat' aux w $ zip (lambdaParams lam') arrs
          nest' = pushInnerKernelNesting (pat', lam_res') map_nesting nest

      -- XXX: we do not construct a new KernelPath when
      -- sequentialising.  This is only OK as long as further
      -- versioning does not take place down that branch (it currently
      -- does not).
      (sequentialised_kernel, nestw_stms) <- localScope extra_scope $ do
        let sequentialised_lam = soacsLambdaToGPU lam''
        constructKernel segThreadCapped nest' $ lambdaBody sequentialised_lam

      let outer_pat = loopNestingPat $ fst nest
      (nestw_stms <>)
        <$> onMap'
          nest'
          path
          (const $ pure $ oneStm sequentialised_kernel)
          exploitInnerParallelism
          outer_pat
          lam''

    -- Generate the versioned statements and emit them with 'postStm'.
    -- The accumulator is returned unchanged.
    multiVersion perm nest acc' pat' lam' = do
      -- The kernel can be distributed by itself, so now we can
      -- decide whether to just sequentialise, or exploit inner
      -- parallelism.
      dist_env <- ask
      let extra_scope = targetsScope $ distTargets acc'

      stms <- liftInner $ generate perm nest pat' lam' dist_env extra_scope
      postStm stms
      pure acc'
futhark-0.25.27/src/Futhark/Pass/ExtractKernels/000077500000000000000000000000001475065116200214015ustar00000000000000futhark-0.25.27/src/Futhark/Pass/ExtractKernels/BlockedKernel.hs000066400000000000000000000173351475065116200244520ustar00000000000000{-# LANGUAGE TypeFamilies #-}

module Futhark.Pass.ExtractKernels.BlockedKernel
  ( DistRep,
    MkSegLevel,
    ThreadRecommendation (..),
    segRed,
    nonSegRed,
    segScan,
    segHist,
    segMap,
    mapKernel,
    KernelInput (..),
    readKernelInput,
    mkSegSpace,
    dummyDim,
  )
where

import Control.Monad
import Futhark.Analysis.PrimExp
import Futhark.IR
import Futhark.IR.Aliases (AliasableRep)
import Futhark.IR.GPU.Op (SegVirt (..))
import Futhark.IR.SegOp
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Transform.Rename
import Prelude hiding (quot)

-- | Constraints pertinent to performing distribution/flattening.
type DistRep rep =
  ( Buildable rep,
    HasSegOp rep,
    BuilderOps rep,
    -- The representation must use plain decorations: let-bound names
    -- carry only a 'Type', and expressions and bodies carry nothing.
    LetDec rep ~ Type,
    ExpDec rep ~ (),
    BodyDec rep ~ (),
    AliasableRep rep
  )

-- | A hint passed to a 'MkSegLevel' function.  'mapKernel' asks for
-- 'ManyThreads' when every result type is primitive, and otherwise
-- passes 'NoRecommendation' with 'SegVirt'.
data ThreadRecommendation = ManyThreads | NoRecommendation SegVirt

-- | Given the widths of the segment dimensions, a short description
-- (e.g. @"segmap"@ in 'mapKernel') and a thread recommendation,
-- produce the level at which a 'SegOp' is to run.  Because the result
-- lives in 'BuilderT', such a function may also emit statements.
type MkSegLevel rep m =
  [SubExp] -> String -> ThreadRecommendation -> BuilderT rep m (SegOpLevel rep)

-- | Construct a 'SegSpace' over the given @(index, width)@ dimensions,
-- allocating a fresh name for the physical thread index.
mkSegSpace :: (MonadFreshNames m) => [(VName, SubExp)] -> m SegSpace
mkSegSpace dims = SegSpace <$> newVName "phys_tid" <*> pure dims

-- | Shared setup for 'segRed', 'segScan' and 'segMap'.  Extends the
-- outer index space @ispace@ with a fresh innermost index of width
-- @w@ and builds a kernel body that
--
-- 1. reads the outer kernel inputs @inps@;
-- 2. binds each parameter of @map_lam@ to the element of the
--    corresponding array in @arrs@ at the innermost index, under the
--    certificates @cs@;
-- 3. inlines the body of @map_lam@ and returns its results.
prepareRedOrScan ::
  (MonadBuilder m, DistRep (Rep m)) =>
  Certs ->
  SubExp ->
  Lambda (Rep m) ->
  [VName] ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  m (SegSpace, KernelBody (Rep m))
prepareRedOrScan cs w map_lam arrs ispace inps = do
  -- NOTE(review): fresh names are drawn in statement order ("gtid"
  -- here, then "phys_tid" inside 'mkSegSpace'), so reordering these
  -- binds would presumably change the generated names.
  gtid <- newVName "gtid"
  space <- mkSegSpace $ ispace ++ [(gtid, w)]
  -- 'runBuilder' yields (result, statements); flipping 'KernelBody ()'
  -- turns that pair into a body with the statements first.
  kbody <- fmap (uncurry (flip (KernelBody ()))) $
    runBuilder $
      localScope (scopeOfSegSpace space) $ do
        mapM_ readKernelInput inps
        -- The list-monad 'do' pairs each lambda parameter with its
        -- input array, indexed at the innermost index.
        certifying cs . mapM_ readKernelInput $ do
          (p, arr) <- zip (lambdaParams map_lam) arrs
          pure $ KernelInput (paramName p) (paramType p) arr [Var gtid]
        res <- bodyBind (lambdaBody map_lam)
        forM res $ \(SubExpRes res_cs se) ->
          pure $ Returns ResultMaySimplify res_cs se

  pure (space, kbody)

-- | Build a 'SegRed' statement binding @pat@.  @map_lam@ is applied to
-- @arrs@ along the innermost dimension (of width @w@), and the results
-- are reduced with @ops@.  The statement is built with 'runBuilder_',
-- so it is returned rather than added to the current builder.
segRed ::
  (MonadFreshNames m, DistRep rep, HasScope rep m) =>
  SegOpLevel rep ->
  Pat (LetDec rep) ->
  Certs ->
  SubExp -> -- segment size
  [SegBinOp rep] ->
  Lambda rep ->
  [VName] ->
  [(VName, SubExp)] -> -- ispace = pair of (gtid, size) for the maps on "top" of this reduction
  [KernelInput] -> -- inps = inputs that can be looked up by using the gtids from ispace
  m (Stms rep)
segRed lvl pat cs w ops map_lam arrs ispace inps = runBuilder_ $ do
  (kspace, kbody) <- prepareRedOrScan cs w map_lam arrs ispace inps
  letBind pat $
    Op $
      segOp $
        SegRed lvl kspace ops (lambdaReturnType map_lam) kbody

-- | Like 'segRed', but builds a 'SegScan' that scans with @ops@
-- instead of reducing.
segScan ::
  (MonadFreshNames m, DistRep rep, HasScope rep m) =>
  SegOpLevel rep ->
  Pat (LetDec rep) ->
  Certs ->
  SubExp -> -- segment size
  [SegBinOp rep] ->
  Lambda rep ->
  [VName] ->
  [(VName, SubExp)] -> -- ispace = pair of (gtid, size) for the maps on "top" of this scan
  [KernelInput] -> -- inps = inputs that can be looked up by using the gtids from ispace
  m (Stms rep)
segScan lvl pat cs w ops map_lam arrs ispace inps = runBuilder_ $ do
  (kspace, kbody) <- prepareRedOrScan cs w map_lam arrs ispace inps
  letBind pat $
    Op $
      segOp $
        SegScan lvl kspace ops (lambdaReturnType map_lam) kbody

-- | Build a 'SegMap' statement binding @pat@ that applies @map_lam@
-- to @arrs@ along the innermost dimension.  No certificates are
-- attached to the array reads ('mempty' is passed to
-- 'prepareRedOrScan').
segMap ::
  (MonadFreshNames m, DistRep rep, HasScope rep m) =>
  SegOpLevel rep ->
  Pat (LetDec rep) ->
  SubExp -> -- segment size
  Lambda rep ->
  [VName] ->
  [(VName, SubExp)] -> -- ispace = pair of (gtid, size) for the maps on "top" of this map
  [KernelInput] -> -- inps = inputs that can be looked up by using the gtids from ispace
  m (Stms rep)
segMap lvl pat w map_lam arrs ispace inps = runBuilder_ $ do
  (kspace, kbody) <- prepareRedOrScan mempty w map_lam arrs ispace inps
  letBind pat $
    Op $
      segOp $
        SegMap lvl kspace (lambdaReturnType map_lam) kbody

-- | Wrap a pattern for use by a segmented operation over a single
-- unit-size segment.  Returns:
--
-- * a renamed copy of @pat@ in which every type gains an outer
--   dimension of size 1;
-- * an index space with one fresh dummy index of size 1;
-- * an action that binds the original pattern names from the new
--   ones, either by indexing row 0 or, for accumulator types, by
--   plain aliasing.
dummyDim ::
  (MonadBuilder m) =>
  Pat Type ->
  m (Pat Type, [(VName, SubExp)], m ())
dummyDim pat = do
  -- We add a unit-size segment on top to ensure that the result
  -- of the SegRed is an array, which we then immediately index.
  -- This is useful in the case that the value is used on the
  -- device afterwards, as this may save an expensive
  -- host-device copy (scalars are kept on the host, but arrays
  -- may be on the device).
  let addDummyDim t = t `arrayOfRow` intConst Int64 1
  pat' <- fmap addDummyDim <$> renamePat pat
  dummy <- newVName "dummy"
  let ispace = [(dummy, intConst Int64 1)]

  pure
    ( pat',
      ispace,
      forM_ (zip (patNames pat') (patNames pat)) $ \(from, to) -> do
        from_t <- lookupType from
        letBindNames [to] . BasicOp $ case from_t of
          Acc {} -> SubExp $ Var from
          _ -> Index from $ fullSlice from_t [DimFix $ intConst Int64 0]
    )

-- | A non-segmented reduction.  It is expressed as a 'segRed' over a
-- single unit-size dummy segment (see 'dummyDim'), after which the
-- results are read back into the names of @pat@.
nonSegRed ::
  (MonadFreshNames m, DistRep rep, HasScope rep m) =>
  SegOpLevel rep ->
  Pat Type ->
  SubExp ->
  [SegBinOp rep] ->
  Lambda rep ->
  [VName] ->
  m (Stms rep)
nonSegRed lvl pat w ops map_lam arrs = runBuilder_ $ do
  (pat', ispace, read_dummy) <- dummyDim pat
  addStms =<< segRed lvl pat' mempty w ops map_lam arrs ispace []
  read_dummy

-- | Build a 'SegHist' statement binding @pat@.  The index space is
-- @ispace@ extended with a fresh innermost index of width @arr_w@.
-- The kernel body first reads @inps@, then binds each parameter of
-- @lam@ to the element of the corresponding array in @arrs@ at the
-- innermost index, and finally returns the results of @lam@'s body.
segHist ::
  (DistRep rep, MonadFreshNames m, HasScope rep m) =>
  SegOpLevel rep ->
  Pat Type ->
  SubExp ->
  -- | Segment indexes and sizes.
  [(VName, SubExp)] ->
  [KernelInput] ->
  [HistOp rep] ->
  Lambda rep ->
  [VName] ->
  m (Stms rep)
segHist lvl pat arr_w ispace inps ops lam arrs = runBuilder_ $ do
  gtid <- newVName "gtid"
  space <- mkSegSpace $ ispace ++ [(gtid, arr_w)]

  kbody <- fmap (uncurry (flip $ KernelBody ())) $
    runBuilder $
      localScope (scopeOfSegSpace space) $ do
        mapM_ readKernelInput inps
        forM_ (zip (lambdaParams lam) arrs) $ \(p, arr) -> do
          arr_t <- lookupType arr
          letBindNames [paramName p] $
            BasicOp $
              Index arr $
                fullSlice arr_t [DimFix $ Var gtid]
        res <- bodyBind (lambdaBody lam)
        forM res $ \(SubExpRes cs se) ->
          pure $ Returns ResultMaySimplify cs se

  letBind pat $ Op $ segOp $ SegHist lvl space ops (lambdaReturnType lam) kbody

-- | Produce the statements that read all kernel inputs, together with
-- a 'SegSpace' over the given index space.
mapKernelSkeleton ::
  (DistRep rep, HasScope rep m, MonadFreshNames m) =>
  [(VName, SubExp)] ->
  [KernelInput] ->
  m (SegSpace, Stms rep)
mapKernelSkeleton ispace inputs = do
  read_input_stms <- runBuilder_ $ mapM readKernelInput inputs

  space <- mkSegSpace ispace
  pure (space, read_input_stms)

-- | Build a 'SegMap' over @ispace@ whose body first reads @inputs@
-- and then runs the given kernel body.  The level is chosen by
-- @mk_lvl@.  The returned statements are whatever was emitted in
-- the 'runBuilderT'' block, presumably only by @mk_lvl@.
mapKernel ::
  (DistRep rep, HasScope rep m, MonadFreshNames m) =>
  MkSegLevel rep m ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  [Type] ->
  KernelBody rep ->
  m (SegOp (SegOpLevel rep) rep, Stms rep)
mapKernel mk_lvl ispace inputs rts (KernelBody () kstms krets) = runBuilderT' $ do
  (space, read_input_stms) <- mapKernelSkeleton ispace inputs

  let kbody' = KernelBody () (read_input_stms <> kstms) krets

  -- If the kernel creates arrays (meaning it will require memory
  -- expansion), we want to truncate the amount of threads.
  -- Otherwise, have at it!  This is a bit of a hack - in principle,
  -- we should make this decision later, when we have a clearer idea
  -- of what is happening inside the kernel.
  let r = if all primType rts then ManyThreads else NoRecommendation SegVirt

  lvl <- mk_lvl (map snd ispace) "segmap" r

  pure $ SegMap lvl space rts kbody'

-- | A value to be made available inside a kernel: 'readKernelInput'
-- binds 'kernelInputName' (of type 'kernelInputType') to
-- 'kernelInputArray' indexed by 'kernelInputIndices'.
data KernelInput = KernelInput
  { kernelInputName :: VName,
    kernelInputType :: Type,
    kernelInputArray :: VName,
    kernelInputIndices :: [SubExp]
  }
  deriving (Show)

-- | Bind a 'KernelInput' in the current builder.
--
-- * An accumulator is bound by aliasing the array name directly.
-- * Any other value is bound by indexing the array at the given
--   indices, keeping every dimension of the input's own type whole.
readKernelInput ::
  (DistRep (Rep m), MonadBuilder m) =>
  KernelInput ->
  m ()
readKernelInput inp = do
  let pe = PatElem (kernelInputName inp) $ kernelInputType inp
  letBind (Pat [pe]) . BasicOp $
    case kernelInputType inp of
      Acc {} ->
        SubExp $ Var $ kernelInputArray inp
      _ ->
        Index (kernelInputArray inp) . Slice $
          map DimFix (kernelInputIndices inp)
            ++ map sliceDim (arrayDims (kernelInputType inp))
futhark-0.25.27/src/Futhark/Pass/ExtractKernels/DistributeNests.hs000066400000000000000000001250001475065116200250660ustar00rootroot00000000000000
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -Wno-overlapping-patterns -Wno-incomplete-patterns -Wno-incomplete-uni-patterns -Wno-incomplete-record-updates #-}

module Futhark.Pass.ExtractKernels.DistributeNests
  ( MapLoop (..),
    mapLoopStm,
    bodyContainsParallelism,
    lambdaContainsParallelism,
    determineReduceOp,
    histKernel,
    DistEnv (..),
    DistAcc (..),
    runDistNestT,
    DistNestT,
    liftInner,
    distributeMap,
    distribute,
    distributeSingleStm,
    distributeMapBodyStms,
    addStmsToAcc,
    addStmToAcc,
    permutationAndMissing,
    addPostStms,
    postStm,
    inNesting,
  )
where

import Control.Arrow (first)
import Control.Monad
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Control.Monad.Trans.Maybe
import Control.Monad.Writer.Strict
import Data.List (find, partition, tails)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map
qualified as M
import Data.Maybe
import Futhark.IR
import Futhark.IR.GPU.Op (SegVirt (..))
import Futhark.IR.SOACS (SOACS)
import Futhark.IR.SOACS qualified as SOACS
import Futhark.IR.SOACS.SOAC hiding (HistOp, histDest)
import Futhark.IR.SOACS.Simplify (simpleSOACS, simplifyStms)
import Futhark.IR.SegOp
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ISRWIM
import Futhark.Pass.ExtractKernels.Interchange
import Futhark.Tools
import Futhark.Transform.CopyPropagate
import Futhark.Transform.FirstOrderTransform qualified as FOT
import Futhark.Transform.Rename
import Futhark.Util.Log

-- | View a scope of a representation that shares scope information
-- with 'SOACS' as a 'SOACS' scope.
scopeForSOACs :: (SameScope rep SOACS) => Scope rep -> Scope SOACS
scopeForSOACs = castScope

-- | A @map@ SOAC statement taken apart into its pattern, auxiliary
-- information, width, lambda and input arrays.
data MapLoop = MapLoop (Pat Type) (StmAux ()) SubExp (Lambda SOACS) [VName]

-- | Reassemble a 'MapLoop' into the corresponding @Screma@ map
-- statement.
mapLoopStm :: MapLoop -> Stm SOACS
mapLoopStm (MapLoop pat aux w lam arrs) =
  Let pat aux $ Op $ Screma w arrs $ mapSOAC lam

-- | Read-only environment of a distribution computation.
data DistEnv rep m = DistEnv
  { -- The map nesting currently being distributed over.
    distNest :: Nestings,
    -- The scope reported by the 'HasScope' instance of 'DistNestT'.
    distScope :: Scope rep,
    -- Handles statements that end up at top level (see 'onTopLevelStms').
    distOnTopLevelStms :: Stms SOACS -> DistNestT rep m (Stms rep),
    -- Handles a map found inside the current nesting (see 'onInnerMap').
    distOnInnerMap ::
      MapLoop ->
      DistAcc rep ->
      DistNestT rep m (DistAcc rep),
    -- Translates a SOACS statement into the target representation
    -- (used by 'addStmToAcc').
    distOnSOACSStms :: Stm SOACS -> Builder rep (Stms rep),
    -- Translates a SOACS lambda into the target representation
    -- (used by 'soacsLambda').
    distOnSOACSLambda :: Lambda SOACS -> Builder rep (Lambda rep),
    -- Chooses the level of generated segmented operations.
    distSegLevel :: MkSegLevel rep m
  }

-- | State threaded through distribution: the current targets and the
-- statements that have not been distributed yet.
data DistAcc rep = DistAcc
  { distTargets :: Targets,
    distStms :: Stms rep
  }

-- | Writer output of a distribution computation.
data DistRes rep = DistRes
  { accPostStms :: PostStms rep,
    accLog :: Log
  }

instance Semigroup (DistRes rep) where
  DistRes ks1 log1 <> DistRes ks2 log2 =
    DistRes (ks1 <> ks2) (log1 <> log2)

instance Monoid (DistRes rep) where
  mempty = DistRes mempty mempty

-- | Statements produced by distribution.  They are collected in the
-- writer output and emitted by 'runDistNestT'.
newtype PostStms rep = PostStms {unPostStms :: Stms rep}

-- The order is deliberately flipped: @PostStms xs <> PostStms ys@
-- places @ys@ first.  NOTE(review): presumably this compensates for
-- 'distributeMapBodyStms' visiting statements back to front, so that
-- the final order follows the source order -- confirm.
instance Semigroup (PostStms rep) where
  PostStms xs <> PostStms ys = PostStms $ ys <> xs

instance Monoid (PostStms rep) where
  mempty = PostStms mempty

-- | The scope of the pattern of the outermost distribution target.
typeEnvFromDistAcc :: (DistRep rep) => DistAcc rep -> Scope rep
typeEnvFromDistAcc = scopeOfPat . fst . outerTarget . distTargets

-- | Prepend statements to the not-yet-distributed statements.
addStmsToAcc :: Stms rep -> DistAcc rep -> DistAcc rep
addStmsToAcc stms acc =
  acc {distStms = stms <> distStms acc}

-- | Translate a SOACS statement with 'distOnSOACSStms' and prepend the
-- resulting statements to the not-yet-distributed statements.
addStmToAcc ::
  (MonadFreshNames m, DistRep rep) =>
  Stm SOACS ->
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
addStmToAcc stm acc = do
  onSoacs <- asks distOnSOACSStms
  (stm', _) <- runBuilder $ onSoacs stm
  pure acc {distStms = stm' <> distStms acc}

-- | Translate a SOACS lambda with 'distOnSOACSLambda'.
soacsLambda ::
  (MonadFreshNames m, DistRep rep) =>
  Lambda SOACS ->
  DistNestT rep m (Lambda rep)
soacsLambda lam = do
  onLambda <- asks distOnSOACSLambda
  fst <$> runBuilder (onLambda lam)

-- | The distribution monad.  It reads a 'DistEnv' and writes a
-- 'DistRes', on top of the underlying monad @m@.
newtype DistNestT rep m a
  = DistNestT (ReaderT (DistEnv rep m) (WriterT (DistRes rep) m) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader (DistEnv rep m),
      MonadWriter (DistRes rep)
    )

-- | Run an action of the underlying monad.  Before running it, the
-- action's scope is extended with the bindings from this
-- computation's scope whose names are not already in the underlying
-- monad's scope.
liftInner :: (LocalScope rep m, DistRep rep) => m a -> DistNestT rep m a
liftInner m = do
  outer_scope <- askScope
  DistNestT $
    lift $
      lift $ do
        inner_scope <- askScope
        localScope (outer_scope `M.difference` inner_scope) m

instance (MonadFreshNames m) => MonadFreshNames (DistNestT rep m) where
  getNameSource = DistNestT $ lift getNameSource
  putNameSource = DistNestT . lift . putNameSource

instance (Monad m, ASTRep rep) => HasScope rep (DistNestT rep m) where
  askScope = asks distScope

-- New bindings take precedence over existing ones
-- (left-biased union with the current scope).
instance (Monad m, ASTRep rep) => LocalScope rep (DistNestT rep m) where
  localScope types = local $ \env ->
    env {distScope = types <> distScope env}

instance (Monad m) => MonadLogger (DistNestT rep m) where
  addLog msgs = tell mempty {accLog = msgs}

-- | Run a distribution computation.  Its log is forwarded to @m@.  The
-- result is the accumulated post-statements, followed by statements
-- for any results of the outermost target that are still remaining.
runDistNestT ::
  (MonadLogger m, DistRep rep) =>
  DistEnv rep m ->
  DistNestT rep m (DistAcc rep) ->
  m (Stms rep)
runDistNestT env (DistNestT m) = do
  (acc, res) <- runWriterT $ runReaderT m env
  addLog $ accLog res
  -- There may be a few final targets remaining - these correspond to
  -- arrays that are identity mapped, and must have statements
  -- inserted here.
  pure $
    unPostStms (accPostStms res)
      <> identityStms (outerTarget $ distTargets acc)
  where
    outermost = nestingLoop $
      case distNest env of
        (nest, []) -> nest
        (_, nest : _) -> nest
    params_to_arrs =
      map (first paramName) $
        loopNestingParamsAndArrs outermost

    identityStms (rem_pat, res) =
      stmsFromList $ zipWith identityStm (patElems rem_pat) res
    -- A result that is a parameter of the outermost nesting becomes a
    -- copy of the array it came from ...
    identityStm pe (SubExpRes cs (Var v))
      | Just arr <- lookup v params_to_arrs =
          certify cs . Let (Pat [pe]) (defAux ()) . BasicOp $
            Replicate mempty (Var arr)
    -- ... and any other result is replicated across the outermost width.
    identityStm pe (SubExpRes cs se) =
      certify cs . Let (Pat [pe]) (defAux ()) . BasicOp $
        Replicate (Shape [loopNestingWidth outermost]) se

-- | Record post-statements in the writer output.
addPostStms :: (Monad m) => PostStms rep -> DistNestT rep m ()
addPostStms ks = tell $ mempty {accPostStms = ks}

-- | Record statements as post-statements.
postStm :: (Monad m) => Stms rep -> DistNestT rep m ()
postStm stms = addPostStms $ PostStms stms

-- | Run an action with the statement's bindings in scope.  The names
-- bound by its pattern are also registered as let-bound in the
-- innermost nesting.
withStm ::
  (Monad m, DistRep rep) =>
  Stm SOACS ->
  DistNestT rep m a ->
  DistNestT rep m a
withStm stm = local $ \env ->
  env
    { distScope =
        castScope (scopeOf stm) <> distScope env,
      distNest =
        letBindInInnerNesting provided $
          distNest env
    }
  where
    provided = namesFromList $ patNames $ stmPat stm

-- | Leave the innermost map nesting by popping its target.
--
-- * Leftover undistributed statements are wrapped in a map over the
--   nesting, restricted to the parameters they use, and
--   first-order-transformed into sequential code.
-- * Otherwise, leftover results are materialised with 'Replicate'.
--
-- Calls 'error' if there is no target to pop.
leavingNesting ::
  (MonadFreshNames m, DistRep rep) =>
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
leavingNesting acc =
  case popInnerTarget $ distTargets acc of
    Nothing ->
      error "The kernel targets list is unexpectedly small"
    Just ((pat, res), newtargets)
      | not $ null $ distStms acc -> do
          -- Any statements left over correspond to something that
          -- could not be distributed because it would cause irregular
          -- arrays.  These must be reconstructed into a Map SOAC
          -- that will be sequentialised. XXX: life would be better if
          -- we were able to distribute irregular parallelism.
          (Nesting _ inner, _) <- asks distNest
          let MapNesting _ aux w params_and_arrs = inner
              body = Body () (distStms acc) res
              used_in_body = freeIn body
              -- Only keep the map inputs that the leftover body uses.
              (used_params, used_arrs) =
                unzip $
                  filter ((`nameIn` used_in_body) . paramName . fst) params_and_arrs
              lam' =
                Lambda
                  { lambdaParams = used_params,
                    lambdaBody = body,
                    lambdaReturnType = map rowType $ patTypes pat
                  }
          stms <-
            runBuilder_ . auxing aux . FOT.transformSOAC pat $
              Screma w used_arrs $
                mapSOAC lam'

          pure $ acc {distTargets = newtargets, distStms = stms}
      | otherwise -> do
          -- Any results left over correspond to a Replicate or a Copy in
          -- the parent nesting, depending on whether the argument is a
          -- parameter of the innermost nesting.
          (Nesting _ inner_nesting, _) <- asks distNest
          let w = loopNestingWidth inner_nesting
              aux = loopNestingAux inner_nesting
              inps = loopNestingParamsAndArrs inner_nesting

              remnantStm pe (SubExpRes cs (Var v))
                | Just (_, arr) <- find ((== v) . paramName . fst) inps =
                    certify cs . Let (Pat [pe]) aux . BasicOp $
                      Replicate mempty (Var arr)
              remnantStm pe (SubExpRes cs se) =
                certify cs . Let (Pat [pe]) aux . BasicOp $
                  Replicate (Shape [w]) se

              stms =
                stmsFromList $ zipWith remnantStm (patElems pat) res

          pure $ acc {distTargets = newtargets, distStms = stms}

-- | Run an action inside an additional map nesting, with the lambda's
-- bindings in scope, and then leave the nesting with
-- 'leavingNesting'.
mapNesting ::
  (MonadFreshNames m, DistRep rep) =>
  Pat Type ->
  StmAux () ->
  SubExp ->
  Lambda SOACS ->
  [VName] ->
  DistNestT rep m (DistAcc rep) ->
  DistNestT rep m (DistAcc rep)
mapNesting pat aux w lam arrs m =
  local extend $ leavingNesting =<< m
  where
    nest =
      Nesting mempty $
        MapNesting pat aux w $
          zip (lambdaParams lam) arrs
    extend env =
      env
        { distNest = pushInnerNesting nest $ distNest env,
          distScope = castScope (scopeOf lam) <> distScope env
        }

-- | Run an action inside the given 'KernelNest'.
--
-- * The last loop nesting of the nest becomes the current (innermost)
--   nesting.
-- * The remaining nestings, starting with the outer one, surround it.
-- * The scope of every loop nesting in the nest is added.
inNesting ::
  (Monad m, DistRep rep) =>
  KernelNest ->
  DistNestT rep m a ->
  DistNestT rep m a
inNesting (outer, nests) = local $ \env ->
  env
    { distNest = (inner, nests'),
      distScope = foldMap scopeOfLoopNesting (outer : nests) <> distScope env
    }
  where
    (inner, nests') =
      case reverse nests of
        [] -> (asNesting outer, [])
        (inner' : ns) -> (asNesting inner', map asNesting $ outer : reverse ns)
    asNesting = Nesting mempty

-- | Does the body contain a SOAC that is not marked @sequential@?
--
-- * For-loop bodies and 'WithAcc' lambdas are searched recursively.
-- * Basic operations, function calls, 'Match'es and while-loops are
--   treated as non-parallel.
bodyContainsParallelism :: Body SOACS -> Bool
bodyContainsParallelism = any isParallelStm . bodyStms
  where
    isParallelStm stm =
      isMap (stmExp stm)
        && not ("sequential" `inAttrs` stmAuxAttrs (stmAux stm))
    isMap BasicOp {} = False
    isMap Apply {} = False
    isMap Match {} = False
    isMap (Loop _ ForLoop {} body) = bodyContainsParallelism body
    isMap (Loop _ WhileLoop {} _) = False
    isMap (WithAcc _ lam) = bodyContainsParallelism $ lambdaBody lam
    isMap Op {} = True

-- | 'bodyContainsParallelism' applied to the body of a lambda.
lambdaContainsParallelism :: Lambda SOACS -> Bool
lambdaContainsParallelism = bodyContainsParallelism . lambdaBody

-- | Distribute a sequence of statements from a map body.
--
-- * Statements are processed back to front: the tail is handled
--   first, then each statement is given to 'maybeDistributeStm' with
--   its own bindings in scope.
-- * 'Stream' SOACs are first sequentialised and copy-propagated; the
--   resulting statements, carrying the stream's certificates, take
--   the stream's place.
-- * The final accumulator is passed to 'distribute'.
distributeMapBodyStms ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  DistAcc rep ->
  Stms SOACS ->
  DistNestT rep m (DistAcc rep)
distributeMapBodyStms orig_acc = distribute <=< onStms orig_acc . stmsToList
  where
    onStms acc [] = pure acc
    onStms acc (Let pat (StmAux cs _ _) (Op (Stream w arrs accs lam)) : stms) = do
      types <- asksScope scopeForSOACs
      stream_stms <-
        snd <$> runBuilderT (sequentialStreamWholeArray pat w accs lam arrs) types
      stream_stms' <-
        runReaderT (copyPropagateInStms simpleSOACS types stream_stms) types
      onStms acc $ stmsToList (fmap (certify cs) stream_stms') ++ stms
    onStms acc (stm : stms) =
      -- It is important that stm is in scope if 'maybeDistributeStm'
      -- wants to distribute, even if this causes the slightly silly
      -- situation that stm is in scope of itself.
      withStm stm $ maybeDistributeStm stm =<< onStms acc stms

-- | Handle a map nested in the current nesting, using the
-- 'distOnInnerMap' callback.
onInnerMap :: (Monad m) => MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
onInnerMap loop acc = do
  f <- asks distOnInnerMap
  f loop acc

-- | Translate statements with the 'distOnTopLevelStms' callback and
-- record the result as post-statements.
onTopLevelStms :: (Monad m) => Stms SOACS -> DistNestT rep m ()
onTopLevelStms stms = do
  f <- asks distOnTopLevelStms
  postStm =<< f stms

-- | Try to distribute a single statement, given the accumulator built
-- from the statements that follow it.
--
-- * Equations are tried in order; the first matching one wins.
-- * When a statement cannot be distributed, it is added to the
--   accumulator with 'addStmToAcc'.
maybeDistributeStm ::
  (MonadFreshNames m, LocalScope rep m, DistRep rep) =>
  Stm SOACS ->
  DistAcc rep ->
  DistNestT rep m (DistAcc rep)
-- Statements marked "sequential" are never distributed.
maybeDistributeStm stm acc
  | "sequential" `inAttrs` stmAuxAttrs (stmAux stm) =
      addStmToAcc stm acc
-- SOACs marked "sequential_outer" are first-order transformed, and the
-- resulting statements (carrying the original certificates) are
-- distributed instead.
maybeDistributeStm (Let pat aux (Op soac)) acc
  | "sequential_outer" `inAttrs` stmAuxAttrs aux =
      distributeMapBodyStms acc . fmap (certify (stmAuxCerts aux))
        =<< runBuilder_ (FOT.transformSOAC pat soac)
maybeDistributeStm stm@(Let pat _ (Op (Screma w arrs form))) acc
  | Just lam <- isMapSOAC form =
      -- Only distribute inside the map if we can distribute everything
      -- following the map.
      distributeIfPossible acc >>= \case
        Nothing -> addStmToAcc stm acc
        Just acc' -> distribute =<< onInnerMap (MapLoop pat (stmAux stm) w lam arrs) acc'
-- A parallel for-loop is interchanged with the surrounding map nest.
-- This is only attempted when the pattern's types do not mention
-- names bound by the pattern itself.
maybeDistributeStm stm@(Let pat aux (Loop merge form@ForLoop {} body)) acc
  | all (`notNameIn` freeIn (patTypes pat)) (patNames pat),
    bodyContainsParallelism body =
      distributeSingleStm acc stm >>= \case
        Just (kernels, res, nest, acc')
          | -- XXX: We cannot distribute if this loop depends on
            -- certificates bound within the loop nest (well, we could,
            -- but interchange would not be valid).  This is not a
            -- fundamental restriction, but an artifact of our
            -- certificate representation, which we should probably
            -- rethink.
            not $ (freeIn form <> freeIn aux) `namesIntersect` boundInKernelNest nest,
            Just (perm, pat_unused) <- permutationAndMissing pat res ->
              -- We need to pretend pat_unused was used anyway, by adding
              -- it to the kernel nest.
              localScope (typeEnvFromDistAcc acc') $ do
                addPostStms kernels
                nest' <- expandKernelNest pat_unused nest
                types <- asksScope scopeForSOACs

                -- Simplification is key to hoisting out statements that
                -- were variant to the loop, but invariant to the outer maps
                -- (which are now innermost).
                stms <-
                  (`runReaderT` types) $
                    simplifyStms =<< interchangeLoops nest' (SeqLoop perm pat merge form body)

                onTopLevelStms stms

                pure acc'
        _ ->
          addStmToAcc stm acc
-- A branch is interchanged with the surrounding map nest.  This is
-- attempted when some branch contains parallelism or the branch
-- returns non-primitive values.
maybeDistributeStm stm@(Let pat _ (Match cond cases defbody ret)) acc
  | all (`notNameIn` freeIn pat) (patNames pat),
    any bodyContainsParallelism (defbody : map caseBody cases)
      || not (all primType (matchReturns ret)) =
      distributeSingleStm acc stm >>= \case
        Just (kernels, res, nest, acc')
          | not $ (freeIn cond <> freeIn ret) `namesIntersect` boundInKernelNest nest,
            Just (perm, pat_unused) <- permutationAndMissing pat res ->
              -- We need to pretend pat_unused was used anyway, by adding
              -- it to the kernel nest.
              localScope (typeEnvFromDistAcc acc') $ do
                nest' <- expandKernelNest pat_unused nest
                addPostStms kernels
                types <- asksScope scopeForSOACs
                let branch = Branch perm pat cond cases defbody ret
                stms <-
                  (`runReaderT` types) $
                    simplifyStms . oneStm =<< interchangeBranch nest' branch

                onTopLevelStms stms
                pure acc'
        _ ->
          addStmToAcc stm acc
-- A WithAcc whose lambda contains parallelism is interchanged with the
-- surrounding map nest.
maybeDistributeStm stm@(Let pat _ (WithAcc inputs lam)) acc
  | lambdaContainsParallelism lam =
      distributeSingleStm acc stm >>= \case
        Just (kernels, res, nest, acc')
          | not $ freeIn (drop num_accs (lambdaReturnType lam)) `namesIntersect` boundInKernelNest nest,
            Just (perm, pat_unused) <- permutationAndMissing pat res ->
              -- We need to pretend pat_unused was used anyway, by adding
              -- it to the kernel nest.
              localScope (typeEnvFromDistAcc acc') $ do
                nest' <- expandKernelNest pat_unused nest
                types <- asksScope scopeForSOACs
                addPostStms kernels
                let withacc = WithAccStm perm pat inputs lam
                stms <-
                  (`runReaderT` types) $
                    simplifyStms . oneStm =<< interchangeWithAcc nest' withacc
                onTopLevelStms stms
                pure acc'
        _ ->
          addStmToAcc stm acc
  where
    -- Leading return values of the lambda that correspond to the
    -- accumulator inputs; they are skipped in the check above.
    num_accs = length inputs
-- A single reduction that 'irwim' can rewrite is replaced by the
-- rewritten statements, which are then distributed.
maybeDistributeStm (Let pat aux (Op (Screma w arrs form))) acc
  | Just [Reduce comm lam nes] <- isReduceSOAC form,
    Just m <- irwim pat w comm lam $ zip nes arrs = do
      types <- asksScope scopeForSOACs
      (_, stms) <- runBuilderT (auxing aux m) types
      distributeMapBodyStms acc stms
-- Parallelise segmented scatters.
maybeDistributeStm stm@(Let pat (StmAux cs _ _) (Op (Scatter w ivs as lam))) acc =
  distributeSingleStm acc stm >>= \case
    Just (kernels, res, nest, acc')
      | Just (perm, pat_unused) <- permutationAndMissing pat res ->
          localScope (typeEnvFromDistAcc acc') $ do
            nest' <- expandKernelNest pat_unused nest
            lam' <- soacsLambda lam
            addPostStms kernels
            postStm =<< segmentedScatterKernel nest' perm pat cs w lam' ivs as
            pure acc'
    _ ->
      addStmToAcc stm acc
-- Parallelise segmented Hist.
maybeDistributeStm stm@(Let pat (StmAux cs _ _) (Op (Hist w as ops lam))) acc =
  distributeSingleStm acc stm >>= \case
    Just (kernels, res, nest, acc')
      | Just (perm, pat_unused) <- permutationAndMissing pat res ->
          localScope (typeEnvFromDistAcc acc') $ do
            lam' <- soacsLambda lam
            nest' <- expandKernelNest pat_unused nest
            addPostStms kernels
            postStm =<< segmentedHistKernel nest' perm cs w ops lam' as
            pure acc'
    _ ->
      addStmToAcc stm acc
-- Parallelise Index slices if the result is going to be returned
-- directly from the kernel.  This is because we would otherwise have
-- to sequentialise writing the result, which may be costly.
maybeDistributeStm stm@(Let (Pat [pe]) aux (BasicOp (Index arr slice))) acc
  | not $ null $ sliceDims slice,
    -- Only when the sliced value is returned directly by the innermost
    -- target.
    Var (patElemName pe) `elem` map resSubExp (snd (innerTarget (distTargets acc))) =
      distributeSingleStm acc stm >>= \case
        Just (kernels, _res, nest, acc') ->
          localScope (typeEnvFromDistAcc acc') $ do
            addPostStms kernels
            postStm =<< segmentedGatherKernel nest (stmAuxCerts aux) arr slice
            pure acc'
        _ ->
          addStmToAcc stm acc
-- If the scan can be distributed by itself, we will turn it into a
-- segmented scan.
--
-- If the scan cannot be distributed by itself, it will be
-- sequentialised in the default case for this function.
maybeDistributeStm stm@(Let pat (StmAux cs _ _) (Op (Screma w arrs form))) acc
  | Just (scans, map_lam) <- isScanomapSOAC form,
    Scan lam nes <- singleScan scans =
      distributeSingleStm acc stm >>= \case
        Just (kernels, res, nest, acc')
          | Just (perm, pat_unused) <- permutationAndMissing pat res ->
              -- We need to pretend pat_unused was used anyway, by adding
              -- it to the kernel nest.
              localScope (typeEnvFromDistAcc acc') $ do
                nest' <- expandKernelNest pat_unused nest
                map_lam' <- soacsLambda map_lam
                -- The scope of acc' is already in effect here, so no
                -- second 'localScope' is needed (this matches the
                -- reduction equation below).
                segmentedScanomapKernel nest' perm cs w lam map_lam' nes arrs
                  >>= kernelOrNot mempty stm acc kernels acc'
        _ ->
          addStmToAcc stm acc
-- If the map function of the reduction contains parallelism we split
-- it, so that the parallelism can be exploited.
maybeDistributeStm (Let pat aux (Op (Screma w arrs form))) acc
  | Just (reds, map_lam) <- isRedomapSOAC form,
    lambdaContainsParallelism map_lam = do
      (mapstm, redstm) <-
        redomapToMapAndReduce pat (w, reds, map_lam, arrs)
      -- The map statement keeps the original auxiliary information.
      distributeMapBodyStms acc $ oneStm mapstm {stmAux = aux} <> oneStm redstm
-- if the reduction can be distributed by itself, we will turn it into a
-- segmented reduce.
--
-- If the reduction cannot be distributed by itself, it will be
-- sequentialised in the default case for this function.
maybeDistributeStm stm@(Let pat (StmAux cs _ _) (Op (Screma w arrs form))) acc | Just (reds, map_lam) <- isRedomapSOAC form, Reduce comm lam nes <- singleReduce reds = distributeSingleStm acc stm >>= \case Just (kernels, res, nest, acc') | Just (perm, pat_unused) <- permutationAndMissing pat res -> -- We need to pretend pat_unused was used anyway, by adding -- it to the kernel nest. localScope (typeEnvFromDistAcc acc') $ do nest' <- expandKernelNest pat_unused nest lam' <- soacsLambda lam map_lam' <- soacsLambda map_lam let comm' | commutativeLambda lam = Commutative | otherwise = comm regularSegmentedRedomapKernel nest' perm cs w comm' lam' map_lam' nes arrs >>= kernelOrNot mempty stm acc kernels acc' _ -> addStmToAcc stm acc maybeDistributeStm (Let pat (StmAux cs _ _) (Op (Screma w arrs form))) acc = do -- This Screma is too complicated for us to immediately do -- anything, so split it up and try again. scope <- asksScope scopeForSOACs distributeMapBodyStms acc . fmap (certify cs) . snd =<< runBuilderT (dissectScrema pat w form arrs) scope maybeDistributeStm stm@(Let _ aux (BasicOp (Replicate shape (Var stm_arr)))) acc = do distributeSingleUnaryStm acc stm stm_arr $ \nest outerpat arr -> if shape == mempty then pure $ oneStm $ Let outerpat aux $ BasicOp $ Replicate mempty $ Var arr else runBuilder_ $ auxing aux $ do arr_t <- lookupType arr let arr_r = arrayRank arr_t nest_r = length (snd nest) + 1 res_r = arr_r + shapeRank shape -- Move the to-be-replicated dimensions outermost. arr_tr <- letExp (baseString arr <> "_tr") . BasicOp $ Rearrange ([nest_r .. arr_r - 1] ++ [0 .. nest_r - 1]) arr -- Replicate the now-outermost dimensions appropriately. arr_tr_rep <- letExp (baseString arr <> "_tr_rep") . BasicOp $ Replicate shape (Var arr_tr) -- Move the replicated dimensions back where they belong. letBind outerpat . BasicOp $ Rearrange ([res_r - nest_r .. res_r - 1] ++ [0 .. 
res_r - nest_r - 1]) arr_tr_rep maybeDistributeStm stm@(Let _ aux (BasicOp (Replicate shape v))) acc = do distributeSingleStm acc stm >>= \case Just (kernels, _, nest, acc') | boundInKernelNest nest == mempty -> do addPostStms kernels let outerpat = loopNestingPat $ fst nest nest_shape = Shape $ kernelNestWidths nest localScope (typeEnvFromDistAcc acc') $ do postStm <=< runBuilder_ . auxing aux . letBind outerpat $ BasicOp (Replicate (nest_shape <> shape) v) pure acc' _ -> addStmToAcc stm acc -- Opaques are applied to the full array, because otherwise they can -- drastically inhibit parallelisation in some cases. maybeDistributeStm stm@(Let (Pat [pe]) aux (BasicOp (Opaque _ (Var stm_arr)))) acc | not $ primType $ typeOf pe = distributeSingleUnaryStm acc stm stm_arr $ \_ outerpat arr -> pure $ oneStm $ Let outerpat aux $ BasicOp $ Replicate mempty $ Var arr maybeDistributeStm stm@(Let _ aux (BasicOp (Rearrange perm stm_arr))) acc = distributeSingleUnaryStm acc stm stm_arr $ \nest outerpat arr -> do let r = length (snd nest) + 1 perm' = [0 .. r - 1] ++ map (+ r) perm -- We need to add a copy, because the original map nest -- will have produced an array without aliases, and so must we. 
arr' <- newVName $ baseString arr arr_t <- lookupType arr pure $ stmsFromList [ Let (Pat [PatElem arr' arr_t]) aux $ BasicOp $ Replicate mempty $ Var arr, Let outerpat aux $ BasicOp $ Rearrange perm' arr' ] maybeDistributeStm stm@(Let _ aux (BasicOp (Reshape k reshape stm_arr))) acc = distributeSingleUnaryStm acc stm stm_arr $ \nest outerpat arr -> do let reshape' = Shape (kernelNestWidths nest) <> reshape pure $ oneStm $ Let outerpat aux $ BasicOp $ Reshape k reshape' arr maybeDistributeStm stm@(Let pat aux (BasicOp (Update _ arr slice (Var v)))) acc | not $ null $ sliceDims slice = distributeSingleStm acc stm >>= \case Just (kernels, res, nest, acc') | map resSubExp res == map Var (patNames $ stmPat stm), Just (perm, pat_unused) <- permutationAndMissing pat res -> do addPostStms kernels localScope (typeEnvFromDistAcc acc') $ do nest' <- expandKernelNest pat_unused nest postStm =<< segmentedUpdateKernel nest' perm (stmAuxCerts aux) arr slice v pure acc' _ -> addStmToAcc stm acc maybeDistributeStm stm@(Let _ aux (BasicOp (Concat d (x :| xs) w))) acc = distributeSingleStm acc stm >>= \case Just (kernels, _, nest, acc') -> localScope (typeEnvFromDistAcc acc') $ segmentedConcat nest >>= kernelOrNot (stmAuxCerts aux) stm acc kernels acc' _ -> addStmToAcc stm acc where segmentedConcat nest = isSegmentedOp nest [0] mempty mempty [] (x : xs) $ \pat _ _ _ (x' : xs') -> let d' = d + length (snd nest) + 1 in addStm $ Let pat aux $ BasicOp $ Concat d' (x' :| xs') w maybeDistributeStm stm acc = addStmToAcc stm acc distributeSingleUnaryStm :: (MonadFreshNames m, LocalScope rep m, DistRep rep) => DistAcc rep -> Stm SOACS -> VName -> (KernelNest -> Pat Type -> VName -> DistNestT rep m (Stms rep)) -> DistNestT rep m (DistAcc rep) distributeSingleUnaryStm acc stm stm_arr f = distributeSingleStm acc stm >>= \case Just (kernels, res, nest, acc') | map resSubExp res == map Var (patNames $ stmPat stm), (outer, _) <- nest, [(_, arr)] <- loopNestingParamsAndArrs outer, boundInKernelNest 
nest `namesIntersection` freeIn stm == oneName stm_arr, perfectlyMapped arr nest -> do addPostStms kernels let outerpat = loopNestingPat $ fst nest localScope (typeEnvFromDistAcc acc') $ do postStm =<< f nest outerpat arr pure acc' _ -> addStmToAcc stm acc where perfectlyMapped arr (outer, nest) | [(p, arr')] <- loopNestingParamsAndArrs outer, arr == arr' = case nest of [] -> paramName p == stm_arr x : xs -> perfectlyMapped (paramName p) (x, xs) | otherwise = False {- | Distribute the accumulated statements into a kernel if possible. If distribution fails, the accumulator is returned unchanged. -} distribute :: (MonadFreshNames m, LocalScope rep m, DistRep rep) => DistAcc rep -> DistNestT rep m (DistAcc rep) distribute acc = fromMaybe acc <$> distributeIfPossible acc {- | Wrap the environment's 'distSegLevel' so it can be used directly in 'DistNestT': the statements it builds (run in the inner monad) are added to the current builder, and only the level itself is returned. -} mkSegLevel :: (MonadFreshNames m, LocalScope rep m, DistRep rep) => DistNestT rep m (MkSegLevel rep (DistNestT rep m)) mkSegLevel = do mk_lvl <- asks distSegLevel pure $ \w desc r -> do (lvl, stms) <- lift $ liftInner $ runBuilderT' $ mk_lvl w desc r addStms stms pure lvl {- | Try to turn the accumulated statements into a kernel at the current nesting. On success the kernel is emitted with 'postStm' and an accumulator holding the new targets and no pending statements is returned; otherwise 'Nothing'. -} distributeIfPossible :: (MonadFreshNames m, LocalScope rep m, DistRep rep) => DistAcc rep -> DistNestT rep m (Maybe (DistAcc rep)) distributeIfPossible acc = do nest <- asks distNest mk_lvl <- mkSegLevel tryDistribute mk_lvl nest (distTargets acc) (distStms acc) >>= \case Nothing -> pure Nothing Just (targets, kernel) -> do postStm kernel pure $ Just DistAcc { distTargets = targets, distStms = mempty } {- | Like 'distributeIfPossible', but afterwards also tries to distribute the given statement on its own (via 'tryDistributeStm') against the resulting targets. The distributed accumulator statements are returned as 'PostStms' rather than emitted, together with the statement's result, its kernel nest, and a fresh accumulator with no pending statements. Returns 'Nothing' if either step fails. -} distributeSingleStm :: (MonadFreshNames m, LocalScope rep m, DistRep rep) => DistAcc rep -> Stm SOACS -> DistNestT rep m ( Maybe ( PostStms rep, Result, KernelNest, DistAcc rep ) ) distributeSingleStm acc stm = do nest <- asks distNest mk_lvl <- mkSegLevel tryDistribute mk_lvl nest (distTargets acc) (distStms acc) >>= \case Nothing -> pure Nothing Just (targets, distributed_stms) -> tryDistributeStm nest targets stm >>= \case Nothing -> pure Nothing Just (res, targets', new_kernel_nest) -> pure $ Just ( PostStms distributed_stms, res, new_kernel_nest, DistAcc { distTargets = targets', distStms = mempty } ) {- | Construct a kernel for a scatter nested inside a map nest. The scatter is treated as an extra innermost map nesting, and each destination array (which must be a kernel input) is written in place via 'WriteReturns'. @perm@ reorders the pattern elements of the outermost nesting. -} segmentedScatterKernel :: (MonadFreshNames m, LocalScope rep m,
DistRep rep) => KernelNest -> [Int] -> Pat Type -> Certs -> SubExp -> Lambda rep -> [VName] -> [(Shape, Int, VName)] -> DistNestT rep m (Stms rep) segmentedScatterKernel nest perm scatter_pat cs scatter_w lam ivs dests = do -- We replicate some of the checking done by 'isSegmentedOp', but -- things are different because a scatter is not a reduction or -- scan. -- -- First, pretend that the scatter is also part of the nesting. The -- KernelNest we produce here is technically not sensible, but it's -- good enough for flatKernel to work. let nesting = MapNesting scatter_pat (StmAux cs mempty ()) scatter_w $ zip (lambdaParams lam) ivs nest' = pushInnerKernelNesting (scatter_pat, bodyResult $ lambdaBody lam) nesting nest (ispace, kernel_inps) <- flatKernel nest' let (as_ws, as_ns, as) = unzip3 dests indexes = zipWith (*) as_ns $ map length as_ws {- number of index results per destination: number of writes times the rank of its shape -} -- The input/output arrays ('as') _must_ correspond to some kernel -- input, or else the original nested scatter would have been -- ill-typed. Find them. as_inps <- mapM (findInput kernel_inps) as mk_lvl <- mkSegLevel let (is, vs) = splitAt (sum indexes) $ bodyResult $ lambdaBody lam (is', k_body_stms) <- runBuilder $ do addStms $ bodyStms $ lambdaBody lam pure is let grouped = groupScatterResults (zip3 as_ws as_ns as_inps) (is' ++ vs) (_, dest_arrs, _) = unzip3 grouped dest_arrs_ts <- mapM (lookupType . kernelInputArray) dest_arrs let k_body = KernelBody () k_body_stms (zipWith (inPlaceReturn ispace) dest_arrs_ts grouped) -- Remove unused kernel inputs, since some of these might -- reference the array we are scattering into. kernel_inps' = filter ((`nameIn` freeIn k_body) . kernelInputName) kernel_inps (k, k_stms) <- mapKernel mk_lvl ispace kernel_inps' dest_arrs_ts k_body traverse renameStm <=< runBuilder_ $ do addStms k_stms let pat = Pat . rearrangeShape perm $ patElems $ loopNestingPat $ fst nest letBind pat $ Op $ segOp k where findInput kernel_inps a = maybe bad pure $ find ((== a) .
kernelInputName) kernel_inps bad = error "Ill-typed nested scatter encountered." {- The innermost gtid belongs to the scatter pseudo-nesting pushed above, so it is dropped ('init') from the write indices and replaced by the scatter's own index results. -} inPlaceReturn ispace arr_t (_, inp, is_vs) = WriteReturns ( foldMap (foldMap resCerts . fst) is_vs <> foldMap (resCerts . snd) is_vs ) (kernelInputArray inp) [ (fullSlice arr_t $ map DimFix $ map Var (init gtids) ++ map resSubExp is, resSubExp v) | (is, v) <- is_vs ] where (gtids, _ws) = unzip ispace {- | Construct a kernel for an in-place update of @arr@ at @slice@ with @v@, nested inside a map nest. The index space is the nest's index space extended with one dimension per slice dimension; each thread writes one element of @v@ via 'WriteReturns'. @arr@ must be a kernel input (otherwise this fails with 'error'). @perm@ reorders the pattern elements of the outermost nesting. -} segmentedUpdateKernel :: (MonadFreshNames m, LocalScope rep m, DistRep rep) => KernelNest -> [Int] -> Certs -> VName -> Slice SubExp -> VName -> DistNestT rep m (Stms rep) segmentedUpdateKernel nest perm cs arr slice v = do (base_ispace, kernel_inps) <- flatKernel nest let slice_dims = sliceDims slice slice_gtids <- replicateM (length slice_dims) (newVName "gtid_slice") let ispace = base_ispace ++ zip slice_gtids slice_dims ((dest_t, res), kstms) <- runBuilder $ do -- Compute indexes into full array. v' <- certifying cs . letSubExp "v" . BasicOp . Index v $ Slice (map (DimFix . Var) slice_gtids) slice_is <- traverse (toSubExp "index") $ fixSlice (fmap pe64 slice) $ map (pe64 . Var) slice_gtids let write_is = map (Var . fst) base_ispace ++ slice_is arr' = maybe (error "incorrectly typed Update") kernelInputArray $ find ((== arr) . kernelInputName) kernel_inps arr_t <- lookupType arr' pure ( arr_t, WriteReturns mempty arr' [(Slice $ map DimFix write_is, v')] ) -- Remove unused kernel inputs, since some of these might -- reference the array we are scattering into. let kernel_inps' = filter ((`nameIn` (freeIn kstms <> freeIn res)) . kernelInputName) kernel_inps mk_lvl <- mkSegLevel (k, prestms) <- mapKernel mk_lvl ispace kernel_inps' [dest_t] $ KernelBody () kstms [res] traverse renameStm <=< runBuilder_ $ do addStms prestms let pat = Pat .
rearrangeShape perm $ patElems $ loopNestingPat $ fst nest letBind pat $ Op $ segOp k {- | Construct a kernel for indexing @arr@ with @slice@ inside a map nest. The index space is extended with one dimension per slice dimension and each thread reads one element of the sliced array. Unlike the scatter and update variants, no permutation is applied to the result pattern. -} segmentedGatherKernel :: (MonadFreshNames m, LocalScope rep m, DistRep rep) => KernelNest -> Certs -> VName -> Slice SubExp -> DistNestT rep m (Stms rep) segmentedGatherKernel nest cs arr slice = do let slice_dims = sliceDims slice slice_gtids <- replicateM (length slice_dims) (newVName "gtid_slice") (base_ispace, kernel_inps) <- flatKernel nest let ispace = base_ispace ++ zip slice_gtids slice_dims ((res_t, res), kstms) <- runBuilder $ do -- Compute indexes into full array. slice'' <- subExpSlice . sliceSlice (primExpSlice slice) $ primExpSlice $ Slice $ map (DimFix . Var) slice_gtids v' <- certifying cs $ letSubExp "v" $ BasicOp $ Index arr slice'' v_t <- subExpType v' pure (v_t, Returns ResultMaySimplify mempty v') mk_lvl <- mkSegLevel (k, prestms) <- mapKernel mk_lvl ispace kernel_inps [res_t] $ KernelBody () kstms [res] traverse renameStm <=< runBuilder_ $ do addStms prestms let pat = Pat $ patElems $ loopNestingPat $ fst nest letBind pat $ Op $ segOp k {- | Construct a segmented histogram for a Hist nested inside a map nest. The histogram destination arrays must be kernel inputs (otherwise this fails with 'error'). @perm@ reorders the pattern elements of the outermost nesting. The segment level is computed from the histogram width followed by the nest widths. -} segmentedHistKernel :: (MonadFreshNames m, LocalScope rep m, DistRep rep) => KernelNest -> [Int] -> Certs -> SubExp -> [SOACS.HistOp SOACS] -> Lambda rep -> [VName] -> DistNestT rep m (Stms rep) segmentedHistKernel nest perm cs hist_w ops lam arrs = do -- We replicate some of the checking done by 'isSegmentedOp', but -- things are different because a Hist is not a reduction or -- scan. (ispace, inputs) <- flatKernel nest let orig_pat = Pat . rearrangeShape perm $ patElems $ loopNestingPat $ fst nest -- The input/output arrays _must_ correspond to some kernel input, -- or else the original nested Hist would have been ill-typed. -- Find them. ops' <- forM ops $ \(SOACS.HistOp num_bins rf dests nes op) -> SOACS.HistOp num_bins rf <$> mapM (fmap kernelInputArray . findInput inputs) dests <*> pure nes <*> pure op mk_lvl <- asks distSegLevel onLambda <- asks distOnSOACSLambda let onLambda' = fmap fst . runBuilder .
onLambda liftInner $ runBuilderT'_ $ do -- It is important not to launch unnecessarily many threads for -- histograms, because it may mean we unnecessarily need to reduce -- subhistograms as well. lvl <- mk_lvl (hist_w : map snd ispace) "seghist" $ NoRecommendation SegNoVirt addStms =<< histKernel onLambda' lvl orig_pat ispace inputs cs hist_w ops' lam arrs where findInput kernel_inps a = maybe bad pure $ find ((== a) . kernelInputName) kernel_inps bad = error "Ill-typed nested Hist encountered." {- | Build a 'segHist' from SOACS histogram operators. Each operator is split into a (possibly vectorised) operator with 'determineReduceOp' and then converted with @onLambda@. Kernel inputs whose arrays are histogram destinations are not passed to 'segHist'. The resulting statements are renamed and certified with @cs@. -} histKernel :: (MonadBuilder m, DistRep (Rep m)) => (Lambda SOACS -> m (Lambda (Rep m))) -> SegOpLevel (Rep m) -> Pat Type -> [(VName, SubExp)] -> [KernelInput] -> Certs -> SubExp -> [SOACS.HistOp SOACS] -> Lambda (Rep m) -> [VName] -> m (Stms (Rep m)) histKernel onLambda lvl orig_pat ispace inputs cs hist_w ops lam arrs = runBuilderT'_ $ do ops' <- forM ops $ \(SOACS.HistOp dest_shape rf dests nes op) -> do (op', nes', shape) <- determineReduceOp op nes op'' <- lift $ onLambda op' pure $ HistOp dest_shape rf dests nes' shape op'' let isDest = flip elem $ concatMap histDest ops' inputs' = filter (not . isDest . kernelInputArray) inputs certifying cs $ addStms =<< traverse renameStm =<< segHist lvl orig_pat hist_w ispace inputs' ops' lam arrs {- | Split an operator into an inner operator plus a vectorisation shape (see 'isVectorMap'), indexing each neutral element at 0 along the stripped dimensions. This is only done when every neutral element is a variable; otherwise the operator and neutral elements are returned unchanged with an empty shape. -} determineReduceOp :: (MonadBuilder m) => Lambda SOACS -> [SubExp] -> m (Lambda SOACS, [SubExp], Shape) determineReduceOp lam nes = -- FIXME? We are assuming that the accumulator is a replicate, and -- we fish out its value in a gross way.
case mapM subExpVar nes of Just ne_vs' -> do let (shape, lam') = isVectorMap lam nes' <- forM ne_vs' $ \ne_v -> do ne_v_t <- lookupType ne_v letSubExp "hist_ne" $ BasicOp $ Index ne_v $ fullSlice ne_v_t $ replicate (shapeRank shape) $ DimFix $ intConst Int64 0 pure (lam', nes', shape) Nothing -> pure (lam, nes, mempty) {- | Recognise an operator whose body is a single @map@ over exactly the operator's own parameters, with the map's results returned verbatim. Such maps are peeled recursively; the result is their widths (outermost first) as a shape, together with the innermost lambda. Otherwise returns an empty shape and the original lambda. -} isVectorMap :: Lambda SOACS -> (Shape, Lambda SOACS) isVectorMap lam | [Let (Pat pes) _ (Op (Screma w arrs form))] <- stmsToList $ bodyStms $ lambdaBody lam, map resSubExp (bodyResult (lambdaBody lam)) == map (Var . patElemName) pes, Just map_lam <- isMapSOAC form, arrs == map paramName (lambdaParams lam) = let (shape, lam') = isVectorMap map_lam in (Shape [w] <> shape, lam') | otherwise = (mempty, lam) {- | Construct a segmented scan via 'isSegmentedOp'. The scan operator is always marked 'Noncommutative', and vectorised operators are handled with 'determineReduceOp'. Note that 'isSegmentedOp' is given an empty array list, so @arrs@ are passed directly to 'segScan' without being checked there. Returns 'Nothing' if 'isSegmentedOp' rejects the nest. -} segmentedScanomapKernel :: (MonadFreshNames m, LocalScope rep m, DistRep rep) => KernelNest -> [Int] -> Certs -> SubExp -> Lambda SOACS -> Lambda rep -> [SubExp] -> [VName] -> DistNestT rep m (Maybe (Stms rep)) segmentedScanomapKernel nest perm cs segment_size lam map_lam nes arrs = do mk_lvl <- asks distSegLevel onLambda <- asks distOnSOACSLambda let onLambda' = fmap fst . runBuilder .
onLambda isSegmentedOp nest perm (freeIn lam) (freeIn map_lam) nes [] $ \pat ispace inps nes' _ -> do (lam', nes'', shape) <- determineReduceOp lam nes' lam'' <- onLambda' lam' let scan_op = SegBinOp Noncommutative lam'' nes'' shape lvl <- mk_lvl (segment_size : map snd ispace) "segscan" $ NoRecommendation SegNoVirt addStms =<< traverse renameStm =<< segScan lvl pat cs segment_size [scan_op] map_lam arrs ispace inps {- | Construct a segmented reduction via 'isSegmentedOp' with the given commutativity. Unlike 'segmentedScanomapKernel', the operator is used as-is with an empty vectorisation shape. As in the scan case, 'isSegmentedOp' is given an empty array list and @arrs@ are passed directly to 'segRed'. Returns 'Nothing' if 'isSegmentedOp' rejects the nest. -} regularSegmentedRedomapKernel :: (MonadFreshNames m, LocalScope rep m, DistRep rep) => KernelNest -> [Int] -> Certs -> SubExp -> Commutativity -> Lambda rep -> Lambda rep -> [SubExp] -> [VName] -> DistNestT rep m (Maybe (Stms rep)) regularSegmentedRedomapKernel nest perm cs segment_size comm lam map_lam nes arrs = do mk_lvl <- asks distSegLevel isSegmentedOp nest perm (freeIn lam) (freeIn map_lam) nes [] $ \pat ispace inps nes' _ -> do let red_op = SegBinOp comm lam nes' mempty lvl <- mk_lvl (segment_size : map snd ispace) "segred" $ NoRecommendation SegNoVirt addStms =<< traverse renameStm =<< segRed lvl pat cs segment_size [red_op] map_lam arrs ispace inps {- | Shared checks for segmented operations over a kernel nest. Fails (yielding 'Nothing') if the operator uses names bound by the nest, if a neutral element is a variable bound by the nest, or if an array is neither a kernel input indexed by exactly the full index space nor free in the nest. Arrays that are free in the nest are replicated over the index space. On success, the continuation is run with the outermost pattern (reordered by @perm@), the index space, all kernel inputs, the neutral elements and the prepared arrays. The free names of the fold lambda are currently ignored. -} isSegmentedOp :: (MonadFreshNames m, LocalScope rep m, DistRep rep) => KernelNest -> [Int] -> Names -> Names -> [SubExp] -> [VName] -> ( Pat Type -> [(VName, SubExp)] -> [KernelInput] -> [SubExp] -> [VName] -> BuilderT rep m () ) -> DistNestT rep m (Maybe (Stms rep)) isSegmentedOp nest perm free_in_op _free_in_fold_op nes arrs m = runMaybeT $ do -- We must verify that array inputs to the operation are inputs to -- the outermost loop nesting or free in the loop nest. Nothing -- free in the op may be bound by the nest. Furthermore, the -- neutral elements must be free in the loop nest. -- -- We must summarise any names from free_in_op that are bound in the -- nest, and describe how to obtain them given segment indices. let bound_by_nest = boundInKernelNest nest (ispace, kernel_inps) <- flatKernel nest when (free_in_op `namesIntersect` bound_by_nest) $ fail "Non-fold lambda uses nest-bound parameters."
let indices = map fst ispace prepareNe (Var v) | v `nameIn` bound_by_nest = fail "Neutral element bound in nest" prepareNe ne = pure ne prepareArr arr = case find ((== arr) . kernelInputName) kernel_inps of Just inp | kernelInputIndices inp == map Var indices -> pure $ pure $ kernelInputArray inp Nothing | arr `notNameIn` bound_by_nest -> -- This input is something that is free inside -- the loop nesting. We will have to replicate -- it. pure $ letExp (baseString arr ++ "_repd") (BasicOp $ Replicate (Shape $ map snd ispace) $ Var arr) _ -> fail "Input not free, perfectly mapped, or outermost." nes' <- mapM prepareNe nes mk_arrs <- mapM prepareArr arrs lift $ liftInner $ runBuilderT'_ $ do nested_arrs <- sequence mk_arrs let pat = Pat . rearrangeShape perm $ patElems $ loopNestingPat $ fst nest m pat ispace kernel_inps nes' nested_arrs {- | Relate a pattern to a result. Pattern elements whose names are not free in the result are returned as "missing" and their names are appended to the result. The permutation is the one computed by 'isPermutationOf' between the pattern's names and this expanded result. Returns 'Nothing' if they are not permutations of each other. -} permutationAndMissing :: Pat Type -> Result -> Maybe ([Int], [PatElem Type]) permutationAndMissing (Pat pes) res = do let (_used, unused) = partition ((`nameIn` freeIn res) . patElemName) pes res' = map resSubExp res res_expanded = res' ++ map (Var . patElemName) unused perm <- map (Var . patElemName) pes `isPermutationOf` res_expanded pure (perm, unused) -- Add extra pattern elements to every kernel nesting level.
expandKernelNest :: (MonadFreshNames m) => [PatElem Type] -> KernelNest -> m KernelNest {- Each added element gets a fresh name based on the original one, and its type is expanded with the width of that nesting level and of every level inside it. -} expandKernelNest pes (outer_nest, inner_nests) = do let outer_size = loopNestingWidth outer_nest : map loopNestingWidth inner_nests inner_sizes = tails $ map loopNestingWidth inner_nests outer_nest' <- expandWith outer_nest outer_size inner_nests' <- zipWithM expandWith inner_nests inner_sizes pure (outer_nest', inner_nests') where expandWith nest dims = do pes' <- mapM (expandPatElemWith dims) pes pure nest { loopNestingPat = Pat $ patElems (loopNestingPat nest) <> pes' } expandPatElemWith dims pe = do name <- newVName $ baseString $ patElemName pe pure pe { patElemName = name, patElemDec = patElemType pe `arrayOfShape` Shape dims } {- | If no kernel statements were produced ('Nothing'), add the statement (certified with @cs@) to the accumulator for later distribution. Otherwise emit the pending kernels and the certified kernel statements, and continue with the given accumulator. -} kernelOrNot :: (MonadFreshNames m, DistRep rep) => Certs -> Stm SOACS -> DistAcc rep -> PostStms rep -> DistAcc rep -> Maybe (Stms rep) -> DistNestT rep m (DistAcc rep) kernelOrNot cs stm acc _ _ Nothing = addStmToAcc (certify cs stm) acc kernelOrNot cs _ _ kernels acc' (Just stms) = do addPostStms kernels postStm $ fmap (certify cs) stms pure acc' {- | Distribute a map: the map's pattern and result become the new innermost target, the lambda body's statements are distributed inside 'mapNesting' (defined elsewhere; presumably it enters a new nesting level for this map), and the accumulators are distributed afterwards. NOTE(review): the pending statements of @acc@ are not used here (the inner accumulator starts empty), so the caller presumably distributes them first; confirm. -} distributeMap :: (MonadFreshNames m, LocalScope rep m, DistRep rep) => MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep) distributeMap (MapLoop pat aux w lam arrs) acc = distribute =<< mapNesting pat aux w lam arrs (distribute =<< distributeMapBodyStms acc' lam_stms) where acc' = DistAcc { distTargets = pushInnerTarget (pat, bodyResult $ lambdaBody lam) $ distTargets acc, distStms = mempty } lam_stms = bodyStms $ lambdaBody lam futhark-0.25.27/src/Futhark/Pass/ExtractKernels/Distribution.hs000066400000000000000000000445301475065116200244220ustar00rootroot00000000000000{-# LANGUAGE LambdaCase #-} {-# LANGUAGE TypeFamilies #-} {- | Data structures describing map nests (targets, nestings and kernel nests), and functions for distributing statements out of such nests into flat kernels. -} module Futhark.Pass.ExtractKernels.Distribution ( Target, Targets, ppTargets, singleTarget, outerTarget, innerTarget, pushInnerTarget, popInnerTarget, targetsScope, LoopNesting (..), ppLoopNesting, scopeOfLoopNesting, Nesting (..), Nestings, ppNestings, letBindInInnerNesting,
singleNesting, pushInnerNesting, KernelNest, ppKernelNest, newKernel, innermostKernelNesting, pushKernelNesting, pushInnerKernelNesting, scopeOfKernelNest, kernelNestLoops, kernelNestWidths, boundInKernelNest, boundInKernelNests, flatKernel, constructKernel, tryDistribute, tryDistributeStm, ) where import Control.Monad import Control.Monad.RWS.Strict import Control.Monad.Trans.Maybe import Data.Bifunctor (second) import Data.Foldable import Data.List (elemIndex, sortOn) import Data.Map.Strict qualified as M import Data.Maybe import Futhark.IR import Futhark.IR.SegOp import Futhark.MonadFreshNames import Futhark.Pass.ExtractKernels.BlockedKernel ( DistRep, KernelInput (..), MkSegLevel, mapKernel, readKernelInput, ) import Futhark.Tools import Futhark.Transform.Rename import Futhark.Util import Futhark.Util.Log {- | A target: a pattern together with the result that is to be bound to it. -} type Target = (Pat Type, Result) -- | First pair element is the very innermost ("current") target. In -- the list, the outermost target comes first. Invariant: Every -- element of a pattern must be present as the result of the -- immediately enclosing target. This is ensured by 'pushInnerTarget' -- by removing unused pattern elements.
data Targets = Targets
  { -- | The innermost ("current") target.
    _innerTarget :: Target,
    -- | The enclosing targets, outermost first.
    _outerTargets :: [Target]
  }

-- | Render the targets one per line, outermost first, each as
-- @pattern <- result@.
ppTargets :: Targets -> String
ppTargets (Targets target targets) =
  unlines $ map ppTarget $ targets ++ [target]
  where
    ppTarget (pat, res) =
      prettyString pat ++ " <- " ++ prettyString res

-- | A 'Targets' consisting of only the given (innermost) target.
singleTarget :: Target -> Targets
singleTarget = flip Targets []

-- | The outermost target.  If there are no enclosing targets, this is
-- the innermost one.
outerTarget :: Targets -> Target
outerTarget (Targets inner_target []) = inner_target
outerTarget (Targets _ (outer_target : _)) = outer_target

-- | The innermost ("current") target.
innerTarget :: Targets -> Target
innerTarget (Targets inner_target _) = inner_target

-- | Add a new outermost target.
pushOuterTarget :: Target -> Targets -> Targets
pushOuterTarget target (Targets inner_target targets) =
  Targets inner_target (target : targets)

-- | Add a new innermost target.  The previous innermost target becomes
-- the innermost of the enclosing targets.  Pattern elements (and their
-- results) whose names are not free in the previous innermost target's
-- result are dropped; this maintains the invariant that every pattern
-- element is used by the immediately enclosing target.
pushInnerTarget :: Target -> Targets -> Targets
pushInnerTarget (pat, res) (Targets inner_target targets) =
  Targets (pat', res') (targets ++ [inner_target])
  where
    (pes', res') = unzip $ filter (used . fst) $ zip (patElems pat) res
    pat' = Pat pes'
    inner_used = freeIn $ snd inner_target
    used pe = patElemName pe `nameIn` inner_used

-- | Remove the innermost target, returning it together with the
-- remaining targets, whose new innermost target is the last of the
-- enclosing ones.  'Nothing' if there are no enclosing targets.
popInnerTarget :: Targets -> Maybe (Target, Targets)
popInnerTarget (Targets t ts) =
  case reverse ts of
    x : xs -> Just (t, Targets x $ reverse xs)
    [] -> Nothing

-- | The scope bound by a target's pattern.
targetScope :: (DistRep rep) => Target -> Scope rep
targetScope = scopeOfPat . fst

-- | The scope bound by the patterns of all targets.
targetsScope :: (DistRep rep) => Targets -> Scope rep
targetsScope (Targets t ts) = mconcat $ map targetScope $ t : ts

-- | A single map nesting level.
data LoopNesting = MapNesting
  { -- | The pattern bound by the map.
    loopNestingPat :: Pat Type,
    loopNestingAux :: StmAux (),
    -- | The width of the map.
    loopNestingWidth :: SubExp,
    -- | The map's parameters, each paired with the array it takes
    -- its elements from.
    loopNestingParamsAndArrs :: [(Param Type, VName)]
  }
  deriving (Show)

-- | The scope of the nesting's parameters.
scopeOfLoopNesting :: (LParamInfo rep ~ Type) => LoopNesting -> Scope rep
scopeOfLoopNesting = scopeOfLParams . map fst . loopNestingParamsAndArrs

-- | Render a nesting as @params <- arrays@.
ppLoopNesting :: LoopNesting -> String
ppLoopNesting (MapNesting _ _ _ params_and_arrs) =
  prettyString (map fst params_and_arrs)
    ++ " <- "
    ++ prettyString (map snd params_and_arrs)

-- | The parameters of a nesting.
loopNestingParams :: LoopNesting -> [Param Type]
loopNestingParams = map fst . loopNestingParamsAndArrs

instance FreeIn LoopNesting where
  freeIn' (MapNesting pat aux w params_and_arrs) =
    freeIn' pat <> freeIn' aux <> freeIn' w <> freeIn' params_and_arrs

-- | A loop nesting together with the names let-bound inside it.
data Nesting = Nesting
  { nestingLetBound :: Names,
    nestingLoop :: LoopNesting
  }
  deriving (Show)

-- | Record additional names as let-bound in a nesting.
letBindInNesting :: Names -> Nesting -> Nesting
letBindInNesting newnames (Nesting oldnames loop) =
  Nesting (oldnames <> newnames) loop

-- | First pair element is the very innermost ("current") nest.  In
-- the list, the outermost nest comes first.
type Nestings = (Nesting, [Nesting])

-- | Render the nestings one per line, outermost first.
ppNestings :: Nestings -> String
ppNestings (nesting, nestings) =
  unlines $ map ppNesting $ nestings ++ [nesting]
  where
    ppNesting (Nesting _ loop) =
      ppLoopNesting loop

-- | 'Nestings' consisting of only the given (innermost) nesting.
singleNesting :: Nesting -> Nestings
singleNesting = (,[])

-- | Make the given nesting the new innermost one.  The previous
-- innermost nesting is appended to the list, so it becomes the
-- innermost of the enclosing nestings.
pushInnerNesting :: Nesting -> Nestings -> Nestings
pushInnerNesting nesting (inner_nesting, nestings) =
  (nesting, nestings ++ [inner_nesting])

-- | Both parameters and let-bound.
boundInNesting :: Nesting -> Names
boundInNesting nesting =
  namesFromList (map paramName (loopNestingParams loop))
    <> nestingLetBound nesting
  where
    loop = nestingLoop nesting

-- | Record additional names as let-bound in the innermost nesting.
letBindInInnerNesting :: Names -> Nestings -> Nestings
letBindInInnerNesting names (nest, nestings) =
  (letBindInNesting names nest, nestings)

-- | Note: first element is *outermost* nesting.  This is different
-- from the similar types elsewhere!
type KernelNest = (LoopNesting, [LoopNesting])

-- | Render the kernel nest one nesting per line, outermost first.
ppKernelNest :: KernelNest -> String
ppKernelNest (nesting, nestings) =
  unlines $ map ppLoopNesting $ nesting : nestings

-- | Retrieve the innermost kernel nesting.
innermostKernelNesting :: KernelNest -> LoopNesting
innermostKernelNesting (nest, nests) =
  fromMaybe nest $ maybeHead $ reverse nests

-- | Add new outermost nesting, pushing the current outermost to the
-- list, also taking care to swap patterns if necessary.
pushKernelNesting :: Target -> LoopNesting -> KernelNest -> KernelNest pushKernelNesting target newnest (nest, nests) = ( fixNestingPatOrder newnest target (loopNestingPat nest), nest : nests ) -- | Add new innermost nesting, appending it after the current innermost -- nesting. Its pattern is reordered with 'fixNestingPatOrder' relative -- to the pattern of the current innermost nesting. It is important -- that the 'Target' has the right order -- (non-permuted compared to what is expected by the outer nests). pushInnerKernelNesting :: Target -> LoopNesting -> KernelNest -> KernelNest pushInnerKernelNesting target newnest (nest, nests) = (nest, nests ++ [fixNestingPatOrder newnest target (loopNestingPat innermost)]) where innermost = case reverse nests of [] -> nest n : _ -> n {- | Reorder the pattern of a nesting: each element is sorted (stably) by the position in @inner_pat@ of its corresponding result variable. Elements whose result is not a variable, or is not found in @inner_pat@, get position 0. The pattern is rebuilt from its identifiers with 'basicPat'. -} fixNestingPatOrder :: LoopNesting -> Target -> Pat Type -> LoopNesting fixNestingPatOrder nest (_, res) inner_pat = nest {loopNestingPat = basicPat pat'} where pat = loopNestingPat nest pat' = map fst fixed_target fixed_target = sortOn posInInnerPat $ zip (patIdents pat) res posInInnerPat (_, SubExpRes _ (Var v)) = fromMaybe 0 $ elemIndex v $ patNames inner_pat posInInnerPat _ = 0 {- | A kernel nest consisting of a single nesting. -} newKernel :: LoopNesting -> KernelNest newKernel nest = (nest, []) {- | All nestings of a kernel nest, outermost first. -} kernelNestLoops :: KernelNest -> [LoopNesting] kernelNestLoops (loop, loops) = loop : loops {- | The scope of the parameters of every nesting. -} scopeOfKernelNest :: (LParamInfo rep ~ Type) => KernelNest -> Scope rep scopeOfKernelNest = foldMap scopeOfLoopNesting . kernelNestLoops {- | Names of all map parameters bound anywhere in the kernel nest. -} boundInKernelNest :: KernelNest -> Names boundInKernelNest = mconcat . boundInKernelNests {- | Names of the map parameters bound at each nesting level, outermost first. -} boundInKernelNests :: KernelNest -> [Names] boundInKernelNests = map (namesFromList . map (paramName . fst) . loopNestingParamsAndArrs) . kernelNestLoops {- | The widths of the nestings, outermost first. -} kernelNestWidths :: KernelNest -> [SubExp] kernelNestWidths = map loopNestingWidth .
kernelNestLoops {- | Turn a kernel nest and an innermost body into one kernel statement, bound to the outermost nesting's pattern with that nesting's aux. Only kernel inputs used by the body are read inside the kernel, every body result becomes a 'Returns' with 'ResultMaySimplify', and the kernel's return types are the pattern types with the index-space dimensions stripped. The second component holds the statements produced while constructing the kernel. -} constructKernel :: (DistRep rep, MonadFreshNames m, LocalScope rep m) => MkSegLevel rep m -> KernelNest -> Body rep -> m (Stm rep, Stms rep) constructKernel mk_lvl kernel_nest inner_body = runBuilderT' $ do (ispace, inps) <- flatKernel kernel_nest let aux = loopNestingAux first_nest ispace_scope = M.fromList $ map ((,IndexName Int64) . fst) ispace pat = loopNestingPat first_nest rts = map (stripArray (length ispace)) $ patTypes pat inner_body' <- fmap (uncurry (flip (KernelBody ()))) $ runBuilder . localScope ispace_scope $ do mapM_ readKernelInput $ filter inputIsUsed inps res <- bodyBind inner_body forM res $ \(SubExpRes cs se) -> pure $ Returns ResultMaySimplify cs se (segop, aux_stms) <- lift $ mapKernel mk_lvl ispace [] rts inner_body' addStms aux_stms pure $ Let pat aux $ Op $ segOp segop where first_nest = fst kernel_nest inputIsUsed input = kernelInputName input `nameIn` freeIn inner_body -- | Flatten a kernel nesting to: -- -- (1) The index space. -- -- (2) The kernel inputs - note that some of these may be unused. {- Each nesting level contributes a fresh "gtid" index. Inputs of inner levels whose array is a parameter of an outer level are rewritten to read from the outer array, with the outer index prepended to their indices. -} flatKernel :: (MonadFreshNames m) => KernelNest -> m ([(VName, SubExp)], [KernelInput]) flatKernel (MapNesting _ _ nesting_w params_and_arrs, []) = do i <- newVName "gtid" let inps = [ KernelInput pname ptype arr [Var i] | (Param _ pname ptype, arr) <- params_and_arrs ] pure ([(i, nesting_w)], inps) flatKernel (MapNesting _ _ nesting_w params_and_arrs, nest : nests) = do i <- newVName "gtid" (ispace, inps) <- flatKernel (nest, nests) let inps' = map fixupInput inps isParam inp = snd <$> find ((== kernelInputArray inp) . paramName . fst) params_and_arrs fixupInput inp | Just arr <- isParam inp = inp { kernelInputArray = arr, kernelInputIndices = Var i : kernelInputIndices inp } | otherwise = inp pure ((i, nesting_w) : ispace, extra_inps i <> inps') where extra_inps i = [ KernelInput pname ptype arr [Var i] | (Param _ pname ptype, arr) <- params_and_arrs ] -- | Description of distribution to do.
data DistributionBody = DistributionBody { distributionTarget :: Targets, distributionFreeInBody :: Names, distributionIdentityMap :: M.Map VName Ident, -- | Also related to avoiding identity mapping. distributionExpandTarget :: Target -> Target } {- | The pattern of the innermost target. -} distributionInnerPat :: DistributionBody -> Pat Type distributionInnerPat = fst . innerTarget . distributionTarget {- | Describe the distribution of a sequence of statements. Identity mappings are removed from the innermost target; they are recorded in 'distributionIdentityMap' and can be restored with 'distributionExpandTarget'. The free names are those of the statements and of the certificates of the innermost result, minus the names bound by the statements. Also returns the innermost result after identity removal. -} distributionBodyFromStms :: (ASTRep rep) => Targets -> Stms rep -> (DistributionBody, Result) distributionBodyFromStms (Targets (inner_pat, inner_res) targets) stms = let bound_by_stms = namesFromList $ M.keys $ scopeOf stms (inner_pat', inner_res', inner_identity_map, inner_expand_target) = removeIdentityMappingGeneral bound_by_stms inner_pat inner_res free = (foldMap freeIn stms <> freeIn (map resCerts inner_res)) `namesSubtract` bound_by_stms in ( DistributionBody { distributionTarget = Targets (inner_pat', inner_res') targets, distributionFreeInBody = free, distributionIdentityMap = inner_identity_map, distributionExpandTarget = inner_expand_target }, inner_res' ) {- | Single-statement version of 'distributionBodyFromStms'. -} distributionBodyFromStm :: (ASTRep rep) => Targets -> Stm rep -> (DistributionBody, Result) distributionBodyFromStm targets stm = distributionBodyFromStms targets $ oneStm stm {- | Try to build a kernel nest from the nestings so that the distribution body can run as a flat kernel, returning the new targets alongside it. Fails with 'error' if the number of outer nestings and outer targets differs. Returns 'Nothing' if distribution would induce an irregular array, i.e. a nesting parameter whose type's dimensions mention names bound in the nest. -} createKernelNest :: forall rep m. (MonadFreshNames m, HasScope rep m) => Nestings -> DistributionBody -> m (Maybe (Targets, KernelNest)) createKernelNest (inner_nest, nests) distrib_body = do let Targets target targets = distributionTarget distrib_body unless (length nests == length targets) $ error $ "Nests and targets do not match!\n" ++ "nests: " ++ ppNestings (inner_nest, nests) ++ "\ntargets:" ++ ppTargets (Targets target targets) runMaybeT $ fmap prepare $ recurse $ zip nests targets where prepare (x, _, z) = (z, x) bound_in_nest = mconcat $ map boundInNesting $ inner_nest : nests distributableType = (== mempty) . namesIntersection bound_in_nest . freeIn .
arrayDims {- Distribute at one nesting level. Names the kernel needs that are let-bound in this nesting, or that are returned by the levels inside it, become new map parameters. Each such parameter reads from a fresh array that is added to the target, unless the name is identity-mapped, in which case the existing array is reused. Fails if a resulting parameter type would mention names bound in the nest. -} distributeAtNesting :: Nesting -> Pat Type -> (LoopNesting -> KernelNest, Names) -> M.Map VName Ident -> [Ident] -> (Target -> Targets) -> MaybeT m (KernelNest, Names, Targets) distributeAtNesting (Nesting nest_let_bound nest) pat (add_to_kernel, free_in_kernel) identity_map inner_returned_arrs addTarget = do let nest'@(MapNesting _ aux w params_and_arrs) = removeUnusedNestingParts free_in_kernel nest (params, arrs) = unzip params_and_arrs param_names = namesFromList $ map paramName params free_in_kernel' = (freeIn nest' <> free_in_kernel) `namesSubtract` param_names required_from_nest = free_in_kernel' `namesIntersection` nest_let_bound required_from_nest_idents <- forM (namesToList required_from_nest) $ \name -> do t <- lift $ lookupType name pure $ Ident name t (free_params, free_arrs, bind_in_target) <- fmap unzip3 $ forM (inner_returned_arrs ++ required_from_nest_idents) $ \(Ident pname ptype) -> case M.lookup pname identity_map of Nothing -> do arr <- newIdent (baseString pname ++ "_r") $ arrayOfRow ptype w pure ( Param mempty pname ptype, arr, True ) Just arr -> pure ( Param mempty pname ptype, arr, False ) let free_arrs_pat = basicPat $ map snd $ filter fst $ zip bind_in_target free_arrs free_params_pat = map snd $ filter fst $ zip bind_in_target free_params (actual_params, actual_arrs) = ( params ++ free_params, arrs ++ map identName free_arrs ) actual_param_names = namesFromList $ map paramName actual_params nest'' = removeUnusedNestingParts free_in_kernel $ MapNesting pat aux w $ zip actual_params actual_arrs free_in_kernel'' = (freeIn nest'' <> free_in_kernel) `namesSubtract` actual_param_names unless ( all (distributableType .
paramType) $ loopNestingParams nest'' ) $ fail "Would induce irregular array" pure ( add_to_kernel nest'', free_in_kernel'', addTarget (free_arrs_pat, varsRes $ map paramName free_params_pat) ) {- Build the kernel nest from the inside out. The empty case handles the innermost nesting; for each outer nesting (given outermost first), the levels inside it are processed first, and then that nesting is pushed outside the resulting kernel. -} recurse :: [(Nesting, Target)] -> MaybeT m (KernelNest, Names, Targets) recurse [] = distributeAtNesting inner_nest (distributionInnerPat distrib_body) ( newKernel, distributionFreeInBody distrib_body `namesIntersection` bound_in_nest ) (distributionIdentityMap distrib_body) [] $ singleTarget . distributionExpandTarget distrib_body recurse ((nest, (pat, res)) : nests') = do (kernel@(outer, _), kernel_free, kernel_targets) <- recurse nests' let (pat', res', identity_map, expand_target) = removeIdentityMappingFromNesting (namesFromList $ patNames $ loopNestingPat outer) pat res distributeAtNesting nest pat' ( \k -> pushKernelNesting (pat', res') k kernel, kernel_free ) identity_map (patIdents $ fst $ outerTarget kernel_targets) ((`pushOuterTarget` kernel_targets) . expand_target) {- | Drop the parameters (and their arrays) of a nesting whose names are not in @used@. -} removeUnusedNestingParts :: Names -> LoopNesting -> LoopNesting removeUnusedNestingParts used (MapNesting pat aux w params_and_arrs) = MapNesting pat aux w $ zip used_params used_arrs where (params, arrs) = unzip params_and_arrs (used_params, used_arrs) = unzip $ filter ((`nameIn` used) . paramName . fst) $ zip params arrs {- | Split pattern/result pairs into identity mappings (the result is a variable not in @bound@) and the rest. Returns the non-identity pattern and result; a map from each identity-mapped variable to the identifier of the pattern element it is bound to; and a function that appends the identity elements and results to a target. Note that certificates on identity results are dropped (replaced by 'mempty'). -} removeIdentityMappingGeneral :: Names -> Pat Type -> Result -> ( Pat Type, Result, M.Map VName Ident, Target -> Target ) removeIdentityMappingGeneral bound pat res = let (identities, not_identities) = mapEither isIdentity $ zip (patElems pat) res (not_identity_patElems, not_identity_res) = unzip not_identities (identity_patElems, identity_res) = unzip identities expandTarget (tpat, tres) = ( Pat $ patElems tpat ++ identity_patElems, tres ++ map (uncurry SubExpRes .
second Var) identity_res ) identity_map = M.fromList $ zip (map snd identity_res) $ map patElemIdent identity_patElems in ( Pat not_identity_patElems, not_identity_res, identity_map, expandTarget ) where isIdentity (patElem, SubExpRes _ (Var v)) | v `notNameIn` bound = Left (patElem, (mempty, v)) isIdentity x = Right x {- | Currently identical to 'removeIdentityMappingGeneral'. -} removeIdentityMappingFromNesting :: Names -> Pat Type -> Result -> ( Pat Type, Result, M.Map VName Ident, Target -> Target ) removeIdentityMappingFromNesting bound_in_nesting pat res = let (pat', res', identity_map, expand_target) = removeIdentityMappingGeneral bound_in_nesting pat res in (pat', res', identity_map, expand_target) {- | Try to distribute the statements as a kernel over the given nestings. An empty statement sequence trivially succeeds without producing a kernel. On success, returns the new targets and the renamed kernel statement, preceded by any statements produced while constructing it. The distribution is logged. -} tryDistribute :: ( DistRep rep, MonadFreshNames m, LocalScope rep m, MonadLogger m ) => MkSegLevel rep m -> Nestings -> Targets -> Stms rep -> m (Maybe (Targets, Stms rep)) tryDistribute _ _ targets stms | null stms = -- No point in distributing an empty kernel. pure $ Just (targets, mempty) tryDistribute mk_lvl nest targets stms = createKernelNest nest dist_body >>= \case Just (targets', distributed) -> do (kernel_stm, w_stms) <- localScope (targetsScope targets') $ constructKernel mk_lvl distributed $ mkBody stms inner_body_res distributed' <- renameStm kernel_stm logMsg $ "distributing\n" ++ unlines (map prettyString $ stmsToList stms) ++ prettyString (snd $ innerTarget targets) ++ "\nas\n" ++ prettyString distributed' ++ "\ndue to targets\n" ++ ppTargets targets ++ "\nand with new targets\n" ++ ppTargets targets' pure $ Just (targets', w_stms <> oneStm distributed') Nothing -> pure Nothing where (dist_body, inner_body_res) = distributionBodyFromStms targets stms {- | Like 'tryDistribute' for a single statement, but it only computes the kernel nest and the new targets, without constructing a kernel. Also returns the statement's result after identity removal. -} tryDistributeStm :: (MonadFreshNames m, HasScope t m, ASTRep rep) => Nestings -> Targets -> Stm rep -> m (Maybe (Result, Targets, KernelNest)) tryDistributeStm nest targets stm = fmap addRes <$> createKernelNest nest dist_body where (dist_body, res) = distributionBodyFromStm targets stm addRes (targets', kernel_nest) = (res, targets', kernel_nest)
futhark-0.25.27/src/Futhark/Pass/ExtractKernels/ISRWIM.hs000066400000000000000000000150321475065116200227500ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | Interchanging scans and reductions with inner maps. module Futhark.Pass.ExtractKernels.ISRWIM ( iswim, irwim, rwimPossible, ) where import Control.Arrow (first) import Control.Monad import Futhark.IR.SOACS import Futhark.MonadFreshNames import Futhark.Tools -- | Interchange Scan With Inner Map. Tries to turn a @scan(map)@ into a -- @map(scan)@. The input arrays are transposed before the map and the -- results are transposed back afterwards. Returns 'Nothing' if -- 'rwimPossible' does not recognise the operator. iswim :: (MonadBuilder m, Rep m ~ SOACS) => Pat Type -> SubExp -> Lambda SOACS -> [(SubExp, VName)] -> Maybe (m ()) iswim res_pat w scan_fun scan_input | Just (map_pat, map_cs, map_w, map_fun) <- rwimPossible scan_fun = Just $ do let (accs, arrs) = unzip scan_input arrs' <- transposedArrays arrs accs' <- mapM (letExp "acc" . BasicOp . SubExp) accs let map_arrs' = accs' ++ arrs' (scan_acc_params, scan_elem_params) = splitAt (length arrs) $ lambdaParams scan_fun map_params = map removeParamOuterDim scan_acc_params ++ map (setParamOuterDimTo w) scan_elem_params map_rettype = map (setOuterDimTo w) $ lambdaReturnType scan_fun scan_params = lambdaParams map_fun scan_body = lambdaBody map_fun scan_rettype = lambdaReturnType map_fun scan_fun' = Lambda scan_params scan_rettype scan_body scan_input' = map (first Var) $ uncurry zip $ splitAt (length arrs') $ map paramName map_params (nes', scan_arrs) = unzip scan_input' scan_soac <- scanSOAC [Scan scan_fun' nes'] let map_body = mkBody ( oneStm $ Let (setPatOuterDimTo w map_pat) (defAux ()) $ Op $ Screma w scan_arrs scan_soac ) $ varsRes $ patNames map_pat map_fun' = Lambda map_params map_rettype map_body res_pat' <- fmap basicPat $ mapM (newIdent' (<> "_transposed") . transposeIdentType) $ patIdents res_pat addStm $ Let res_pat' (StmAux map_cs mempty ()) $ Op $ Screma map_w map_arrs' (mapSOAC map_fun') forM_ (zip (patIdents res_pat) (patIdents res_pat')) $ \(to, from) -> do let perm = [1, 0] ++ [2 ..
arrayRank (identType from) - 1] addStm $ Let (basicPat [to]) (defAux ()) $ BasicOp $ Rearrange perm $ identName from | otherwise = Nothing -- | Interchange Reduce With Inner Map. Tries to turn a @reduce(map)@ into a -- @map(reduce)@. The input arrays are transposed, and the interchange -- is applied recursively to the inner reduction when possible. Returns -- 'Nothing' if 'rwimPossible' does not recognise the operator. irwim :: (MonadBuilder m, Rep m ~ SOACS) => Pat Type -> SubExp -> Commutativity -> Lambda SOACS -> [(SubExp, VName)] -> Maybe (m ()) irwim res_pat w comm red_fun red_input | Just (map_pat, map_cs, map_w, map_fun) <- rwimPossible red_fun = Just $ do let (accs, arrs) = unzip red_input arrs' <- transposedArrays arrs -- FIXME? Can we reasonably assume that the accumulator is a -- replicate? We also assume that it is non-empty. let indexAcc (Var v) = do v_t <- lookupType v letSubExp "acc" $ BasicOp $ Index v $ fullSlice v_t [DimFix $ intConst Int64 0] indexAcc Constant {} = error "irwim: array accumulator is a constant." accs' <- mapM indexAcc accs let (_red_acc_params, red_elem_params) = splitAt (length arrs) $ lambdaParams red_fun map_rettype = map rowType $ lambdaReturnType red_fun map_params = map (setParamOuterDimTo w) red_elem_params red_params = lambdaParams map_fun red_body = lambdaBody map_fun red_rettype = lambdaReturnType map_fun red_fun' = Lambda red_params red_rettype red_body red_input' = zip accs' $ map paramName map_params red_pat = stripPatOuterDim map_pat map_body <- case irwim red_pat w comm red_fun' red_input' of Nothing -> do reduce_soac <- reduceSOAC [Reduce comm red_fun' $ map fst red_input'] pure $ mkBody ( oneStm $ Let red_pat (defAux ()) $ Op $ Screma w (map snd red_input') reduce_soac ) $ varsRes $ patNames map_pat Just m -> localScope (scopeOfLParams map_params) $ do map_body_stms <- collectStms_ m pure $ mkBody map_body_stms $ varsRes $ patNames map_pat let map_fun' = Lambda map_params map_rettype map_body addStm $ Let res_pat (StmAux map_cs mempty ()) $ Op $ Screma map_w arrs' $ mapSOAC map_fun' | otherwise = Nothing -- | Does this reduce operator contain an inner map, and if so, what -- does that map look like?
rwimPossible :: Lambda SOACS -> Maybe (Pat Type, Certs, SubExp, Lambda SOACS) rwimPossible fun | Body _ stms res <- lambdaBody fun, [stm] <- stmsToList stms, -- Body has a single binding map_pat <- stmPat stm, map Var (patNames map_pat) == map resSubExp res, -- Returned verbatim Op (Screma map_w map_arrs form) <- stmExp stm, Just map_fun <- isMapSOAC form, map paramName (lambdaParams fun) == map_arrs = Just (map_pat, stmCerts stm, map_w, map_fun) | otherwise = Nothing transposedArrays :: (MonadBuilder m) => [VName] -> m [VName] transposedArrays arrs = forM arrs $ \arr -> do t <- lookupType arr let perm = [1, 0] ++ [2 .. arrayRank t - 1] letExp (baseString arr) $ BasicOp $ Rearrange perm arr removeParamOuterDim :: LParam SOACS -> LParam SOACS removeParamOuterDim param = let t = rowType $ paramType param in param {paramDec = t} setParamOuterDimTo :: SubExp -> LParam SOACS -> LParam SOACS setParamOuterDimTo w param = let t = setOuterDimTo w $ paramType param in param {paramDec = t} setIdentOuterDimTo :: SubExp -> Ident -> Ident setIdentOuterDimTo w ident = let t = setOuterDimTo w $ identType ident in ident {identType = t} setOuterDimTo :: SubExp -> Type -> Type setOuterDimTo w t = arrayOfRow (rowType t) w setPatOuterDimTo :: SubExp -> Pat Type -> Pat Type setPatOuterDimTo w pat = basicPat $ map (setIdentOuterDimTo w) $ patIdents pat transposeIdentType :: Ident -> Ident transposeIdentType ident = ident {identType = transposeType $ identType ident} stripIdentOuterDim :: Ident -> Ident stripIdentOuterDim ident = ident {identType = rowType $ identType ident} stripPatOuterDim :: Pat Type -> Pat Type stripPatOuterDim pat = basicPat $ map stripIdentOuterDim $ patIdents pat futhark-0.25.27/src/Futhark/Pass/ExtractKernels/Interchange.hs000066400000000000000000000306321475065116200241700ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | It is well known that fully parallel loops can always be -- interchanged inwards with a sequential loop. 
This module
-- implements that transformation.
--
-- This is also where we implement loop-switching (for branches),
-- which is semantically similar to interchange.
module Futhark.Pass.ExtractKernels.Interchange
  ( SeqLoop (..),
    interchangeLoops,
    Branch (..),
    interchangeBranch,
    WithAccStm (..),
    interchangeWithAcc,
  )
where

import Control.Monad
import Data.List (find)
import Data.Maybe
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.Distribution
  ( KernelNest,
    LoopNesting (..),
    kernelNestLoops,
    scopeOfKernelNest,
  )
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util (splitFromEnd)

-- | An encoding of a sequential do-loop with no existential context,
-- alongside its result pattern.
--
-- The fields are, in order: a permutation that 'interchangeLoop'
-- applies (via 'rearrangeShape') to the pattern of every map nesting
-- it moves inside the loop; the loop's result pattern; the merge
-- parameters paired with their initial values; the loop form; and the
-- loop body.
data SeqLoop = SeqLoop [Int] (Pat Type) [(FParam SOACS, SubExp)] LoopForm (Body SOACS)

-- | The permutation component of a 'SeqLoop'.
loopPerm :: SeqLoop -> [Int]
loopPerm (SeqLoop perm _ _ _ _) = perm

-- | Turn a 'SeqLoop' back into a loop statement binding its result
-- pattern, with default auxiliary information.
seqLoopStm :: SeqLoop -> Stm SOACS
seqLoopStm (SeqLoop _ pat merge form body) =
  Let pat (defAux ()) $ Loop merge form body

-- | Move one map nesting inside a sequential loop.
--
-- Every merge parameter is replaced by a fresh parameter whose type has
-- an extra outer dimension of the map width @w@ (and is marked
-- 'Unique').  Its initial value is one of two things.  If
-- @isMapParameter@ maps the original initial value to an array, it is
-- that array.  Otherwise it is a 'Replicate' of the original initial
-- value.
--
-- The new loop body is a single @map@ whose lambda body is the old
-- loop body.  The map runs over the map-nesting arrays whose parameters
-- are still used, plus the expanded merge parameters.  The map
-- nesting's pattern, permuted by the loop's permutation, becomes the
-- new loop's result pattern.
interchangeLoop ::
  (MonadBuilder m, Rep m ~ SOACS) =>
  (VName -> Maybe VName) ->
  SeqLoop ->
  LoopNesting ->
  m SeqLoop
interchangeLoop
  isMapParameter
  (SeqLoop perm loop_pat merge form body)
  (MapNesting pat aux w params_and_arrs) = do
    merge_expanded <-
      localScope (scopeOfLParams $ map fst params_and_arrs) $
        mapM expand merge

    let loop_pat_expanded =
          Pat $ map expandPatElem $ patElems loop_pat
        -- Inside the map lambda, the original merge parameters are
        -- bound as ordinary (non-declared) parameters.
        new_params =
          [Param attrs pname $ fromDecl ptype | (Param attrs pname ptype, _) <- merge]
        new_arrs = map (paramName . fst) merge_expanded
        rettype = map rowType $ patTypes loop_pat_expanded

    -- If the map consumes something that is bound outside the loop
    -- (i.e. is not a merge parameter), we have to copy() it.  As a
    -- small simplification, we just remove the parameter outright if
    -- it is not used anymore.  This might happen if the parameter was
    -- used just as the initial value of a merge parameter.
    --
    -- NOTE(review): copyOrRemoveParam (below) only ever removes
    -- parameters; it adds no statements, so pre_copy_stms is always
    -- empty here.  Confirm whether the copy described above is handled
    -- elsewhere (e.g. by maybeCopyInitial).
    ((params', arrs'), pre_copy_stms) <-
      runBuilder $
        localScope (scopeOfLParams new_params) $
          unzip . catMaybes <$> mapM copyOrRemoveParam params_and_arrs

    let lam = Lambda (params' <> new_params) rettype body
        map_stm =
          Let loop_pat_expanded aux $
            Op $
              Screma w (arrs' <> new_arrs) (mapSOAC lam)
        res = varsRes $ patNames loop_pat_expanded
        pat' = Pat $ rearrangeShape perm $ patElems pat

    pure $
      SeqLoop perm pat' merge_expanded form $
        mkBody (pre_copy_stms <> oneStm map_stm) res
  where
    free_in_body = freeIn body

    -- Drop map parameters that the loop body no longer refers to.
    copyOrRemoveParam (param, arr)
      | paramName param `notNameIn` free_in_body =
          pure Nothing
      | otherwise =
          pure $ Just (param, arr)

    -- Initial value for an expanded merge parameter.  A map parameter
    -- is replaced by the array it is drawn from; anything else is
    -- replicated @w@ times.
    expandedInit _ (Var v)
      | Just arr <- isMapParameter v =
          pure $ Var arr
    expandedInit param_name se =
      letSubExp (param_name <> "_expanded_init") $
        BasicOp $
          Replicate (Shape [w]) se

    -- Create the expanded merge parameter and its initial value.
    expand (merge_param, merge_init) = do
      expanded_param <-
        newParam (param_name <> "_expanded") $
          -- FIXME: Unique here is a hack to make sure the copy from
          -- makeCopyInitial is not prematurely simplified away.
          -- It'd be better to fix this somewhere else...
          arrayOf (paramDeclType merge_param) (Shape [w]) Unique
      expanded_init <- expandedInit param_name merge_init
      pure (expanded_param, expanded_init)
      where
        param_name = baseString $ paramName merge_param

    -- Give a loop result an extra outer dimension of size @w@.
    expandPatElem (PatElem name t) =
      PatElem name $ arrayOfRow t w

-- We need to copy some initial arguments because otherwise the result
-- of the loop might alias the input (if the number of iterations is
-- 0), which is a problem if the result is consumed.
-- | For every array-typed merge parameter whose initial value is a
-- variable accepted by @isMapInput@, bind a fresh @_inter_copy@
-- variable to a 'Replicate' with an empty shape of that variable.
-- That fresh variable is used as the new initial value.  All other
-- merge parameters are left unchanged.
maybeCopyInitial ::
  (MonadBuilder m) =>
  (VName -> Bool) ->
  SeqLoop ->
  m SeqLoop
maybeCopyInitial isMapInput (SeqLoop perm loop_pat merge form body) =
  SeqLoop perm loop_pat <$> mapM f merge <*> pure form <*> pure body
  where
    f (p, Var arg)
      | isMapInput arg,
        Array {} <- paramType p =
          (p,)
            <$> letSubExp
              (baseString (paramName p) <> "_inter_copy")
              (BasicOp $ Replicate mempty $ Var arg)
    f (p, arg) =
      pure (p, arg)

-- | Rebuild the given nestings (outermost first) as nested @map@
-- statements around @stms@, whose results are the names @res@.
--
-- Returns the names bound by the outermost nesting's pattern, together
-- with the single statement that binds them.  With no nestings, the
-- inputs are returned unchanged.
manifestMaps :: [LoopNesting] -> [VName] -> Stms SOACS -> ([VName], Stms SOACS)
manifestMaps [] res stms = (res, stms)
manifestMaps (n : ns) res stms =
  let (res', stms') = manifestMaps ns res stms
      (params, arrs) = unzip $ loopNestingParamsAndArrs n
      lam =
        Lambda
          params
          (map rowType $ patTypes (loopNestingPat n))
          (mkBody stms' $ varsRes res')
   in ( patNames $ loopNestingPat n,
        oneStm $ Let (loopNestingPat n) (loopNestingAux n) $ Op $ Screma (loopNestingWidth n) arrs (mapSOAC lam)
      )

-- | Given a (parallel) map nesting and an inner sequential loop, move
-- the maps inside the sequential loop.  The result is several
-- statements - one of these will be the loop, which will then contain
-- statements with @map@ expressions.
interchangeLoops ::
  (MonadFreshNames m, HasScope SOACS m) =>
  KernelNest ->
  SeqLoop ->
  m (Stms SOACS)
interchangeLoops full_nest = recurse (kernelNestLoops full_nest)
  where
    -- Handles the last nesting in the list first.  That is the
    -- innermost one, since 'manifestMaps' treats the list as
    -- outermost-first.
    recurse nest loop
      | (ns, [n]) <- splitFromEnd 1 nest = do
          let isMapParameter v =
                snd <$> find ((== v) . paramName . fst) (loopNestingParamsAndArrs n)
              isMapInput v =
                v `elem` map snd (loopNestingParamsAndArrs n)
          (loop', stms) <-
            runBuilder . localScope (scopeOfKernelNest full_nest) $
              maybeCopyInitial isMapInput
                =<< interchangeLoop isMapParameter loop n

          -- Only safe to continue interchanging if we didn't need to add
          -- any new statements; otherwise we manifest the remaining nests
          -- as Maps and hand them back to the flattener.
          if null stms
            then recurse ns loop'
            else
              let loop_stm = seqLoopStm loop'
                  names = rearrangeShape (loopPerm loop') (patNames (stmPat loop_stm))
               in pure $ snd $ manifestMaps ns names $ stms <> oneStm loop_stm
      | otherwise = pure $ oneStm $ seqLoopStm loop

-- | An encoding of a branch alongside its result pattern.
--
-- The fields are, in order: a permutation applied (via
-- 'rearrangeShape') to the patterns of map nestings moved inside the
-- branch; the result pattern; the match conditions; the cases; the
-- default body; and the match decoration.
data Branch
  = Branch [Int] (Pat Type) [SubExp] [Case (Body SOACS)] (Body SOACS) (MatchDec (BranchType SOACS))

-- | Turn a 'Branch' back into a @Match@ statement binding its result
-- pattern, with default auxiliary information.
branchStm :: Branch -> Stm SOACS
branchStm (Branch _ pat cond cases defbody ret) =
  Let pat (defAux ()) $ Match cond cases defbody ret

-- | Move one map nesting inside a branch.
--
-- Every case body and the default body become a @map@ over the
-- nesting's arrays.  Each map's lambda body is the original branch
-- body, and each generated body is renamed with 'renameBody'
-- (presumably to keep bound names distinct across the cases).  The
-- branch's return types and pattern gain an outer dimension of width
-- @w@.  The nesting's pattern, permuted by the branch's permutation,
-- becomes the new result pattern.  The resulting 'Branch' carries the
-- identity permutation.
interchangeBranch1 ::
  (MonadFreshNames m, HasScope SOACS m) =>
  Branch ->
  LoopNesting ->
  m Branch
interchangeBranch1
  (Branch perm branch_pat cond cases defbody (MatchDec ret if_sort))
  (MapNesting pat aux w params_and_arrs) = do
    let ret' = map (`arrayOfRow` Free w) ret
        pat' = Pat $ rearrangeShape perm $ patElems pat

        (params, arrs) = unzip params_and_arrs
        lam_ret = rearrangeShape perm $ map rowType $ patTypes pat

        branch_pat' =
          Pat $ map (fmap (`arrayOfRow` w)) $ patElems branch_pat

        -- Wrap one branch body in a map over the nesting's arrays.
        mkBranch branch = (renameBody =<<) $ runBodyBuilder $ do
          let lam = Lambda params lam_ret branch
          addStm $ Let branch_pat' aux $ Op $ Screma w arrs $ mapSOAC lam
          pure $ varsRes $ patNames branch_pat'

    cases' <- mapM (traverse mkBranch) cases
    defbody' <- mkBranch defbody
    pure . Branch [0 .. patSize pat - 1] pat' cond cases' defbody' $
      MatchDec ret' if_sort

-- | Given a (parallel) map nesting and an inner branch, move the maps
-- inside the branch.  The result is the resulting branch expression,
-- which will then contain statements with @map@ expressions.
interchangeBranch ::
  (MonadFreshNames m, HasScope SOACS m) =>
  KernelNest ->
  Branch ->
  m (Stm SOACS)
interchangeBranch nest loop =
  -- Fold over the nestings last-to-first, i.e. innermost first.
  branchStm <$> foldM interchangeBranch1 loop (reverse $ kernelNestLoops nest)

-- | An encoding of a WithAcc alongside its result pattern.
--
-- The fields are, in order: a permutation applied (via
-- 'rearrangeShape') to the patterns of map nestings moved inside; the
-- result pattern; the accumulator inputs (shape, arrays, optional
-- combining operator with neutral elements); and the WithAcc lambda.
data WithAccStm
  = WithAccStm [Int] (Pat Type) [(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))] (Lambda SOACS)

-- | Turn a 'WithAccStm' back into a @WithAcc@ statement binding its
-- result pattern, with default auxiliary information.
withAccStm :: WithAccStm -> Stm SOACS
withAccStm (WithAccStm _ pat inputs lam) =
  Let pat (defAux ()) $ WithAcc inputs lam

-- | Move one map nesting inside a @WithAcc@.
--
-- The transformation changes four things:
--
-- * Every accumulator input gains an outer dimension @w@.  Input
--   arrays that are map parameters are replaced by the arrays they are
--   drawn from.
--
-- * Each combining operator gains an extra leading @idx@ parameter.
--
-- * The WithAcc lambda's body becomes a @map@ over @iota w@, the
--   accumulator parameters and the nesting's arrays.
--
-- * Inside that body, accumulator types belonging to this WithAcc get
--   the extra dimension.  Each 'UpdateAcc' on such an accumulator gets
--   the @iota@ element as an extra leading index.
interchangeWithAcc1 ::
  (MonadFreshNames m, LocalScope SOACS m) =>
  WithAccStm ->
  LoopNesting ->
  m WithAccStm
interchangeWithAcc1
  (WithAccStm perm _withacc_pat inputs acc_lam)
  (MapNesting map_pat map_aux w params_and_arrs) = do
    inputs' <- mapM onInput inputs
    lam_params' <- newAccLamParams $ lambdaParams acc_lam
    iota_p <- newParam "iota_p" $ Prim int64
    acc_lam' <- trLam (Var (paramName iota_p)) <=< runLambdaBuilder lam_params' $ do
      let acc_params = drop (length inputs) lam_params'
          orig_acc_params = drop (length inputs) $ lambdaParams acc_lam
      iota_w <-
        letExp "acc_inter_iota" . BasicOp $
          Iota w (intConst Int64 0) (intConst Int64 1) Int64
      let (params, arrs) = unzip params_and_arrs
          maplam_ret = lambdaReturnType acc_lam
          maplam = Lambda (iota_p : orig_acc_params ++ params) maplam_ret (lambdaBody acc_lam)
      auxing map_aux . fmap subExpsRes . letTupExp' "withacc_inter" $
        Op $
          Screma w (iota_w : map paramName acc_params ++ arrs) (mapSOAC maplam)
    let pat = Pat $ rearrangeShape perm $ patElems map_pat
    pure $ WithAccStm perm pat inputs' acc_lam'
  where
    -- The first half of the lambda parameters are certificates and are
    -- kept as-is.  The second half (accumulators) get fresh names.
    newAccLamParams ps = do
      let (cert_ps, acc_ps) = splitAt (length ps `div` 2) ps
      -- Should not rename the certificates.
      acc_ps' <- forM acc_ps $ \(Param attrs v t) ->
        Param attrs <$> newVName (baseString v) <*> pure t
      pure $ cert_ps <> acc_ps'

    num_accs = length inputs
    -- Names of the certificate parameters that identify this WithAcc's
    -- accumulators.
    acc_certs = map paramName $ take num_accs $ lambdaParams acc_lam

    -- Replace an input array that is a map parameter by its source array.
    onArr v =
      pure . maybe v snd $
        find ((== v) . paramName . fst) params_and_arrs
    onInput (shape, arrs, op) =
      (Shape [w] <> shape,,) <$> mapM onArr arrs <*> traverse onOp op

    onOp (op_lam, nes) = do
      -- We need to add an additional index parameter because we are
      -- extending the index space of the accumulator.
      idx_p <- newParam "idx" $ Prim int64
      pure (op_lam {lambdaParams = idx_p : lambdaParams op_lam}, nes)

    -- Prepend @w@ to the index space of this WithAcc's accumulator types.
    trType :: TypeBase shape u -> TypeBase shape u
    trType (Acc acc ispace ts u)
      | acc `elem` acc_certs =
          Acc acc (Shape [w] <> ispace) ts u
    trType t = t

    trParam :: Param (TypeBase shape u) -> Param (TypeBase shape u)
    trParam = fmap trType

    -- The tr* functions below traverse the lambda.  They rewrite types
    -- with 'trType' and prepend the index @i@ to updates of this
    -- WithAcc's accumulators.
    trLam i (Lambda params ret body) =
      localScope (scopeOfLParams params) $
        Lambda (map trParam params) (map trType ret) <$> trBody i body

    trBody i (Body dec stms res) =
      inScopeOf stms $ Body dec <$> traverse (trStm i) stms <*> pure res

    trStm i (Let pat aux e) =
      Let (fmap trType pat) aux <$> trExp i e

    trSOAC i = mapSOACM mapper
      where
        mapper = identitySOACMapper {mapOnSOACLambda = trLam i}

    trExp i (WithAcc acc_inputs lam) =
      WithAcc acc_inputs <$> trLam i lam
    trExp i (BasicOp (UpdateAcc safety acc is ses)) = do
      acc_t <- lookupType acc
      pure $ case acc_t of
        Acc cert _ _ _
          | cert `elem` acc_certs ->
              BasicOp $ UpdateAcc safety acc (i : is) ses
        _ ->
          BasicOp $ UpdateAcc safety acc is ses
    trExp i e = mapExpM mapper e
      where
        mapper =
          identityMapper
            { mapOnBody = \scope -> localScope scope . trBody i,
              mapOnRetType = pure . trType,
              mapOnBranchType = pure . trType,
              mapOnFParam = pure . trParam,
              mapOnLParam = pure . trParam,
              mapOnOp = trSOAC i
            }

-- | Given a (parallel) map nesting and an inner withacc, move the
-- maps inside the withacc.  The result is the resulting withacc
-- expression, which will then contain statements with @map@
-- expressions.
interchangeWithAcc ::
  (MonadFreshNames m, LocalScope SOACS m) =>
  KernelNest ->
  WithAccStm ->
  m (Stm SOACS)
interchangeWithAcc nest withacc =
  -- Fold over the nestings last-to-first, i.e. innermost first.
  withAccStm <$> foldM interchangeWithAcc1 withacc (reverse $ kernelNestLoops nest)

futhark-0.25.27/src/Futhark/Pass/ExtractKernels/Intrablock.hs000066400000000000000000000311041475065116200240240ustar00rootroot00000000000000
{-# LANGUAGE TypeFamilies #-}

-- | Extract limited nested parallelism for execution inside
-- individual kernel threadblocks.
module Futhark.Pass.ExtractKernels.Intrablock (intrablockParallelise) where

import Control.Monad
import Control.Monad.RWS
import Control.Monad.Trans.Maybe
import Data.Map.Strict qualified as M
import Data.Set qualified as S
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.GPU hiding (HistOp)
import Futhark.IR.GPU.Op qualified as GPU
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ToGPU
import Futhark.Tools
import Futhark.Transform.FirstOrderTransform qualified as FOT
import Futhark.Util.Log
import Prelude hiding (log)

-- | Convert the statements inside a map nest to kernel statements,
-- attempting to parallelise any remaining (top-level) parallel
-- statements.  Anything that is not a map, scan or reduction will
-- simply be sequentialised.  This includes sequential loops that
-- contain maps, scans or reduction.  In the future, we could probably
-- do something more clever.  Make sure that the amount of parallelism
-- to be exploited does not exceed the group size.  Further, as a hack
-- we also consider the size of all intermediate arrays as
-- "parallelism to be exploited" to avoid exploding shared memory.
--
-- We distinguish between "minimum group size" and "maximum
-- exploitable parallelism".
--
-- On success the result contains, in order:
--
-- * the minimum and available intra-block parallelism (currently the
--   same value);
--
-- * the computed thread block size;
--
-- * the conversion log;
--
-- * the statements that compute these sizes (presumably placed before
--   the kernel by the caller);
--
-- * the single block-level 'SegMap' statement.
--
-- The result is 'Nothing' when the parallelism sizes refer to
-- variables that are not available outside the kernel.
intrablockParallelise ::
  (MonadFreshNames m, LocalScope GPU m) =>
  KernelNest ->
  Lambda SOACS ->
  m
    ( Maybe
        ( (SubExp, SubExp),
          SubExp,
          Log,
          Stms GPU,
          Stms GPU
        )
    )
intrablockParallelise knest lam = runMaybeT $ do
  (ispace, inps) <- lift $ flatKernel knest

  -- Number of thread blocks: the product of the nest's widths.
  (num_tblocks, w_stms) <-
    lift $
      runBuilder $
        letSubExp "intra_num_tblocks"
          =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) (map snd ispace)

  let body = lambdaBody lam

  tblock_size <- newVName "computed_tblock_size"
  (wss_min, wss_avail, log, kbody) <-
    lift . localScope (scopeOfLParams $ lambdaParams lam) $
      intrablockParalleliseBody body

  outside_scope <- lift askScope
  -- outside_scope may also contain the inputs, even though those are
  -- not actually available outside the kernel.
  let available v =
        v `M.member` outside_scope
          && v `notElem` map kernelInputName inps
  unless (all available $ namesToList $ freeIn (wss_min ++ wss_avail)) $
    fail "Irregular parallelism"

  ((intra_avail_par, kspace, read_input_stms), prelude_stms) <- lift $ runBuilder $ do
    -- Like foldBinOp, but yields 1 for an empty list.
    let foldBinOp' _ [] = eSubExp $ intConst Int64 1
        foldBinOp' bop (x : xs) = foldBinOp bop x xs
    ws_min <-
      mapM (letSubExp "one_intra_par_min" <=< foldBinOp' (Mul Int64 OverflowUndef)) $
        filter (not . null) wss_min
    ws_avail <-
      mapM (letSubExp "one_intra_par_avail" <=< foldBinOp' (Mul Int64 OverflowUndef)) $
        filter (not . null) wss_avail

    -- The amount of parallelism available *in the worst case* is
    -- equal to the smallest parallel loop, or *at least* 1.
    intra_avail_par <-
      letSubExp "intra_avail_par"
        =<< foldBinOp' (SMin Int64) ws_avail

    -- The group size is either the maximum of the minimum parallelism
    -- exploited, or the desired parallelism (bounded by the max group
    -- size) in case there is no minimum.
    letBindNames [tblock_size]
      =<< if null ws_min
        then
          eBinOp
            (SMin Int64)
            (eSubExp =<< letSubExp "max_tblock_size" (Op $ SizeOp $ GetSizeMax SizeThreadBlock))
            (eSubExp intra_avail_par)
        else foldBinOp' (SMax Int64) ws_min

    let inputIsUsed input = kernelInputName input `nameIn` freeIn body
        used_inps = filter inputIsUsed inps

    addStms w_stms
    -- Reads of the used kernel inputs are collected separately; they are
    -- prepended to the kernel body below, not emitted here.
    read_input_stms <- runBuilder_ $ mapM readGroupKernelInput used_inps
    space <- SegSpace <$> newVName "phys_tblock_id" <*> pure ispace
    pure (intra_avail_par, space, read_input_stms)

  let kbody' = kbody {kernelBodyStms = read_input_stms <> kernelBodyStms kbody}

  let nested_pat = loopNestingPat first_nest
      rts = map (length ispace `stripArray`) $ patTypes nested_pat
      grid = KernelGrid (Count num_tblocks) (Count $ Var tblock_size)
      lvl = SegBlock SegNoVirt (Just grid)
      kstm =
        Let nested_pat aux $ Op $ SegOp $ SegMap lvl kspace rts kbody'

  let intra_min_par = intra_avail_par
  pure
    ( (intra_min_par, intra_avail_par),
      Var tblock_size,
      log,
      prelude_stms,
      oneStm kstm
    )
  where
    first_nest = fst knest
    aux = loopNestingAux first_nest

-- | Read a kernel input inside the block-level kernel.
--
-- An array-typed input is read into a fresh name.  The original name is
-- then bound to a 'Replicate' with an empty shape of that fresh name.
-- Any other input is read directly.
readGroupKernelInput ::
  (DistRep (Rep m), MonadBuilder m) =>
  KernelInput ->
  m ()
readGroupKernelInput inp
  | Array {} <- kernelInputType inp = do
      v <- newVName $ baseString $ kernelInputName inp
      readKernelInput inp {kernelInputName = v}
      letBindNames [kernelInputName inp] $ BasicOp $ Replicate mempty $ Var v
  | otherwise =
      readKernelInput inp

-- | Information accumulated while converting the kernel body.  The
-- minimum and available parallelism are sets of width lists; each list
-- is later multiplied out into one size.
data IntraAcc = IntraAcc
  { accMinPar :: S.Set [SubExp],
    accAvailPar :: S.Set [SubExp],
    accLog :: Log
  }

instance Semigroup IntraAcc where
  IntraAcc min_x avail_x log_x <> IntraAcc min_y avail_y log_y =
    IntraAcc (min_x <> min_y) (avail_x <> avail_y) (log_x <> log_y)

instance Monoid IntraAcc where
  mempty = IntraAcc mempty mempty mempty

-- | Builds GPU statements, writes an 'IntraAcc', and threads a name
-- source as state.
type IntrablockM =
  BuilderT GPU (RWS () IntraAcc VNameSource)

instance MonadLogger IntrablockM where
  addLog log = tell mempty {accLog = log}

-- | Run an 'IntrablockM' action in the caller's scope and name source.
-- Returns the accumulated 'IntraAcc' and the statements produced.
runIntrablockM ::
  (MonadFreshNames m, HasScope GPU m) =>
  IntrablockM () ->
  m (IntraAcc, Stms GPU)
runIntrablockM m = do
  scope <- castScope <$> askScope
  modifyNameSource $ \src ->
    let (((), kstms), src', acc) = runRWS (runBuilderT m scope) () src
     in ((acc, kstms), src')

-- | Record @ws@ as both minimum and available parallelism.
parallelMin :: [SubExp] -> IntrablockM ()
parallelMin ws =
  tell
    mempty
      { accMinPar = S.singleton ws,
        accAvailPar = S.singleton ws
      }

-- | Convert the statements of a body, keeping its result.
intrablockBody :: Body SOACS -> IntrablockM (Body GPU)
intrablockBody body = do
  stms <- collectStms_ $ intrablockStms $ bodyStms body
  pure $ mkBody stms $ bodyResult body

-- | Convert the body of a lambda, keeping its parameters.
intrablockLambda :: Lambda SOACS -> IntrablockM (Lambda GPU)
intrablockLambda lam =
  mkLambda (lambdaParams lam) $
    bodyBind =<< intrablockBody (lambdaBody lam)

-- | Convert the combining operator (if any) of a WithAcc input.
intrablockWithAccInput :: WithAccInput SOACS -> IntrablockM (WithAccInput GPU)
intrablockWithAccInput (shape, arrs, Nothing) =
  pure (shape, arrs, Nothing)
intrablockWithAccInput (shape, arrs, Just (lam, nes)) = do
  lam' <- intrablockLambda lam
  pure (shape, arrs, Just (lam', nes))

-- | Convert one statement for use inside a block-level kernel.  Any
-- parallelism it exploits is recorded in the 'IntraAcc'.
--
-- * Loops, matches and withaccs are converted recursively.
--
-- * SOACs carrying the @sequential_outer@ attribute are transformed with
--   'FOT.transformSOAC' and the result converted.
--
-- * Maps are distributed.
--
-- * Single-operator scans, redomaps, hists and scatters become
--   block-level SegOps.
--
-- * Other scremas are dissected and retried.
--
-- * Streams are sequentialised.
--
-- * Anything else is kept unchanged, only converted to the GPU
--   representation.
intrablockStm :: Stm SOACS -> IntrablockM ()
intrablockStm stm@(Let pat aux e) = do
  scope <- askScope
  let lvl = SegThreadInBlock SegNoVirt

  case e of
    Loop merge form loopbody ->
      localScope (scopeOfLoopForm form <> scopeOfFParams (map fst merge)) $ do
        loopbody' <- intrablockBody loopbody
        certifying (stmAuxCerts aux) . letBind pat $
          Loop merge form loopbody'
    Match cond cases defbody ifdec -> do
      cases' <- mapM (traverse intrablockBody) cases
      defbody' <- intrablockBody defbody
      certifying (stmAuxCerts aux) . letBind pat $
        Match cond cases' defbody' ifdec
    WithAcc inputs lam -> do
      inputs' <- mapM intrablockWithAccInput inputs
      lam' <- intrablockLambda lam
      certifying (stmAuxCerts aux) . letBind pat $
        WithAcc inputs' lam'
    Op soac
      | "sequential_outer" `inAttrs` stmAuxAttrs aux ->
          intrablockStms . fmap (certify (stmAuxCerts aux))
            =<< runBuilder_ (FOT.transformSOAC pat soac)
    Op (Screma w arrs form)
      | Just lam <- isMapSOAC form -> do
          let loopnest = MapNesting pat aux w $ zip (lambdaParams lam) arrs
              env =
                DistEnv
                  { distNest =
                      singleNesting $ Nesting mempty loopnest,
                    distScope =
                      scopeOfPat pat
                        <> scopeForGPU (scopeOf lam)
                        <> scope,
                    distOnInnerMap =
                      distributeMap,
                    distOnTopLevelStms =
                      liftInner . collectStms_ . intrablockStms,
                    -- Every segmented operation the distribution creates
                    -- records its width as exploited parallelism.
                    distSegLevel = \minw _ _ -> do
                      lift $ parallelMin minw
                      pure lvl,
                    distOnSOACSStms =
                      pure . oneStm . soacsStmToGPU,
                    distOnSOACSLambda =
                      pure . soacsLambdaToGPU
                  }
              acc =
                DistAcc
                  { distTargets = singleTarget (pat, bodyResult $ lambdaBody lam),
                    distStms = mempty
                  }

          addStms
            =<< runDistNestT env (distributeMapBodyStms acc (bodyStms $ lambdaBody lam))
    Op (Screma w arrs form)
      | Just (scans, mapfun) <- isScanomapSOAC form,
        -- FIXME: Futhark.CodeGen.ImpGen.GPU.Block.compileGroupOp
        -- cannot handle multiple scan operators yet.
        Scan scanfun nes <- singleScan scans -> do
          let scanfun' = soacsLambdaToGPU scanfun
              mapfun' = soacsLambdaToGPU mapfun
          certifying (stmAuxCerts aux) $
            addStms =<< segScan lvl pat mempty w [SegBinOp Noncommutative scanfun' nes mempty] mapfun' arrs [] []
          parallelMin [w]
    Op (Screma w arrs form)
      | Just (reds, map_lam) <- isRedomapSOAC form -> do
          let onReduce (Reduce comm red_lam nes) =
                SegBinOp comm (soacsLambdaToGPU red_lam) nes mempty
              reds' = map onReduce reds
              map_lam' = soacsLambdaToGPU map_lam
          certifying (stmAuxCerts aux) $
            addStms =<< segRed lvl pat mempty w reds' map_lam' arrs [] []
          parallelMin [w]
    Op (Screma w arrs form) ->
      -- This screma is too complicated for us to immediately do
      -- anything, so split it up and try again.
      mapM_ intrablockStm . fmap (certify (stmAuxCerts aux)) . snd
        =<< runBuilderT (dissectScrema pat w form arrs) (scopeForSOACs scope)
    Op (Hist w arrs ops bucket_fun) -> do
      ops' <- forM ops $ \(HistOp num_bins rf dests nes op) -> do
        (op', nes', shape) <- determineReduceOp op nes
        let op'' = soacsLambdaToGPU op'
        pure $ GPU.HistOp num_bins rf dests nes' shape op''

      let bucket_fun' = soacsLambdaToGPU bucket_fun
      certifying (stmAuxCerts aux) $
        addStms =<< segHist lvl pat w [] [] ops' bucket_fun' arrs
      parallelMin [w]
    Op (Stream w arrs accs lam)
      | chunk_size_param : _ <- lambdaParams lam -> do
          types <- asksScope castScope
          ((), stream_stms) <-
            runBuilderT (sequentialStreamWholeArray pat w accs lam arrs) types

          -- The whole array is processed as one chunk, so any recorded
          -- width that is the chunk size is replaced by @w@.
          let replace (Var v) | v == paramName chunk_size_param = w
              replace se = se
              replaceSets (IntraAcc x y log) =
                IntraAcc (S.map (map replace) x) (S.map (map replace) y) log

          censor replaceSets $ intrablockStms stream_stms
    Op (Scatter w ivs dests lam) -> do
      write_i <- newVName "write_i"
      space <- mkSegSpace [(write_i, w)]

      let lam' = soacsLambdaToGPU lam
          krets = do
            (_a_w, a, is_vs) <- groupScatterResults dests $ bodyResult $ lambdaBody lam'
            let cs =
                  foldMap (foldMap resCerts . fst) is_vs
                    <> foldMap (resCerts . snd) is_vs
                is_vs' = [(Slice $ map (DimFix . resSubExp) is, resSubExp v) | (is, v) <- is_vs]
            pure $ WriteReturns cs a is_vs'
          inputs = do
            (p, p_a) <- zip (lambdaParams lam') ivs
            pure $ KernelInput (paramName p) (paramType p) p_a [Var write_i]

      kstms <- runBuilder_ $ localScope (scopeOfSegSpace space) $ do
        mapM_ readKernelInput inputs
        addStms $ bodyStms $ lambdaBody lam'

      certifying (stmAuxCerts aux) $ do
        let body = KernelBody () kstms krets
        letBind pat $ Op $ SegOp $ SegMap lvl space (patTypes pat) body

      parallelMin [w]
    _ ->
      addStm $ soacsStmToGPU stm

-- | Convert a sequence of statements one by one.
intrablockStms :: Stms SOACS -> IntrablockM ()
intrablockStms = mapM_ intrablockStm

-- | Convert a map-lambda body into a kernel body.
--
-- Returns the minimum and available parallelism widths that were
-- recorded, the log, and a kernel body.  Each of the kernel body's
-- results is a 'Returns' with 'ResultMaySimplify'.
intrablockParalleliseBody ::
  (MonadFreshNames m, HasScope GPU m) =>
  Body SOACS ->
  m ([[SubExp]], [[SubExp]], Log, KernelBody GPU)
intrablockParalleliseBody body = do
  (IntraAcc min_ws avail_ws log, kstms) <-
    runIntrablockM $ intrablockStms $ bodyStms body
  pure
    ( S.toList min_ws,
      S.toList avail_ws,
      log,
      KernelBody () kstms $ map ret $ bodyResult body
    )
  where
    ret (SubExpRes cs se) = Returns ResultMaySimplify cs se

futhark-0.25.27/src/Futhark/Pass/ExtractKernels/StreamKernel.hs000066400000000000000000000042271475065116200243360ustar00rootroot00000000000000
{-# LANGUAGE TypeFamilies #-}

module Futhark.Pass.ExtractKernels.StreamKernel
  ( segThreadCapped,
  )
where

import Control.Monad
import Data.List ()
import Futhark.Analysis.PrimExp
import Futhark.IR
import Futhark.IR.GPU hiding
  ( BasicOp,
    Body,
    Exp,
    FParam,
    FunDef,
    LParam,
    Lambda,
    Pat,
    PatElem,
    Prog,
    RetType,
    Stm,
  )
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.ToGPU
import Futhark.Tools
import Prelude hiding (quot)

-- NOTE(review): KernelSize is neither exported nor referenced anywhere
-- else in this module; it looks like a removal candidate.
data KernelSize = KernelSize
  { -- | Int64
    kernelElementsPerThread :: SubExp,
    -- | Int32
    kernelNumThreads :: SubExp
  }
  deriving (Eq, Ord, Show)

-- | Compute a number of thread blocks for width @w@ and the given block
-- size.  It uses a 'CalcNumBlocks' size operation keyed by a fresh name
-- derived from @desc@.  Returns the number of blocks and the total
-- number of threads (blocks times block size).
numberOfBlocks ::
  (MonadBuilder m, Op (Rep m) ~ HostOp inner (Rep m)) =>
  String ->
  SubExp ->
  SubExp ->
  m (SubExp, SubExp)
numberOfBlocks desc w tblock_size = do
  max_num_tblocks_key <- nameFromString . prettyString <$> newVName (desc ++ "_num_tblocks")
  num_tblocks <-
    letSubExp "num_tblocks" $
      Op $
        SizeOp $
          CalcNumBlocks w max_num_tblocks_key tblock_size
  num_threads <-
    letSubExp "num_threads" $
      BasicOp $
        BinOp (Mul Int64 OverflowUndef) num_tblocks tblock_size
  pure (num_tblocks, num_threads)

-- | Like 'segThread', but cap the thread count to the input size.
-- This is more efficient for small kernels, e.g. summing a small
-- array.
--
-- The total width is the product of @ws@.  With 'ManyThreads', the grid
-- has ceil(width / block size) blocks and no virtualisation.  Otherwise
-- the block count comes from 'numberOfBlocks', and the recommended
-- virtualisation is kept.
segThreadCapped :: (MonadFreshNames m) => MkSegLevel GPU m
segThreadCapped ws desc r = do
  w <-
    letSubExp "nest_size"
      =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) ws
  tblock_size <- getSize (desc ++ "_tblock_size") SizeThreadBlock

  case r of
    ManyThreads -> do
      usable_groups <-
        letSubExp "segmap_usable_groups"
          =<< eBinOp
            (SDivUp Int64 Unsafe)
            (eSubExp w)
            (eSubExp =<< asIntS Int64 tblock_size)
      let grid = KernelGrid (Count usable_groups) (Count tblock_size)
      pure $ SegThread SegNoVirt (Just grid)
    NoRecommendation v -> do
      (num_tblocks, _) <- numberOfBlocks desc w tblock_size
      let grid = KernelGrid (Count num_tblocks) (Count tblock_size)
      pure $ SegThread v (Just grid)

futhark-0.25.27/src/Futhark/Pass/ExtractKernels/ToGPU.hs000066400000000000000000000041131475065116200226720ustar00rootroot00000000000000
{-# LANGUAGE TypeFamilies #-}

module Futhark.Pass.ExtractKernels.ToGPU
  ( getSize,
    segThread,
    soacsLambdaToGPU,
    soacsStmToGPU,
    scopeForGPU,
    scopeForSOACs,
    injectSOACS,
  )
where

import Control.Monad.Identity
import Data.List ()
import Futhark.IR
import Futhark.IR.GPU
import Futhark.IR.SOACS (SOACS)
import Futhark.IR.SOACS.SOAC qualified as SOAC
import Futhark.Tools

-- | Bind a variable named after @desc@ to a 'GetSize' size operation of
-- the given class.  The size key is a fresh name derived from @desc@.
getSize ::
  (MonadBuilder m, Op (Rep m) ~ HostOp inner (Rep m)) =>
  String ->
  SizeClass ->
  m SubExp
getSize desc size_class = do
  size_key <- nameFromString . prettyString <$> newVName desc
  letSubExp desc $ Op $ SizeOp $ GetSize size_key size_class

-- | A virtualised 'SegThread' level.  Its grid size and block size are
-- each obtained with 'getSize'.
segThread ::
  (MonadBuilder m, Op (Rep m) ~ HostOp inner (Rep m)) =>
  String ->
  m SegLevel
segThread desc =
  SegThread SegVirt <$> (Just <$> kernelGrid)
  where
    kernelGrid =
      KernelGrid
        <$> (Count <$> getSize (desc ++ "_num_tblocks") SizeGrid)
        <*> (Count <$> getSize (desc ++ "_tblock_size") SizeThreadBlock)

-- | A 'Rephraser' that leaves every decoration unchanged.  It converts
-- SOACs (recursively through their lambdas) into the target
-- representation's operation using @f@.
injectSOACS ::
  ( Monad m,
    SameScope from to,
    ExpDec from ~ ExpDec to,
    BodyDec from ~ BodyDec to,
    RetType from ~ RetType to,
    BranchType from ~ BranchType to,
    Op from ~ SOAC from
  ) =>
  (SOAC to -> Op to) ->
  Rephraser m from to
injectSOACS f =
  Rephraser
    { rephraseExpDec = pure,
      rephraseBodyDec = pure,
      rephraseLetBoundDec = pure,
      rephraseFParamDec = pure,
      rephraseLParamDec = pure,
      rephraseOp = fmap f . onSOAC,
      rephraseRetType = pure,
      rephraseBranchType = pure
    }
  where
    onSOAC = SOAC.mapSOACM mapper
    mapper =
      SOAC.SOACMapper
        { SOAC.mapOnSOACSubExp = pure,
          SOAC.mapOnSOACVName = pure,
          SOAC.mapOnSOACLambda = rephraseLambda $ injectSOACS f
        }

-- | Convert a SOACS statement to GPU, wrapping SOACs in 'OtherOp'.
soacsStmToGPU :: Stm SOACS -> Stm GPU
soacsStmToGPU = runIdentity . rephraseStm (injectSOACS OtherOp)

-- | Convert a SOACS lambda to GPU, wrapping SOACs in 'OtherOp'.
soacsLambdaToGPU :: Lambda SOACS -> Lambda GPU
soacsLambdaToGPU = runIdentity . rephraseLambda (injectSOACS OtherOp)

-- | Reinterpret a GPU scope as a SOACS scope ('castScope').
scopeForSOACs :: Scope GPU -> Scope SOACS
scopeForSOACs = castScope

-- | Reinterpret a SOACS scope as a GPU scope ('castScope').
scopeForGPU :: Scope SOACS -> Scope GPU
scopeForGPU = castScope

futhark-0.25.27/src/Futhark/Pass/ExtractMulticore.hs000066400000000000000000000255711475065116200223070ustar00rootroot00000000000000
{-# LANGUAGE TypeFamilies #-}

-- | Extraction of parallelism from a SOACs program.  This generates
-- parallel constructs aimed at CPU execution, which in particular may
-- involve ad-hoc irregular nested parallelism.
module Futhark.Pass.ExtractMulticore (extractMulticore) where

import Control.Monad
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Data.Bitraversable
import Futhark.IR
import Futhark.IR.MC
import Futhark.IR.MC qualified as MC
import Futhark.IR.SOACS hiding
  ( Body,
    Exp,
    LParam,
    Lambda,
    Pat,
    Stm,
  )
import Futhark.IR.SOACS qualified as SOACS
import Futhark.Pass
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.ToGPU (injectSOACS)
import Futhark.Tools
import Futhark.Transform.Rename (Rename, renameSomething)
import Futhark.Util.Log

-- | The extraction monad: reads the current 'MC' scope and threads a
-- name source as state.
newtype ExtractM a = ExtractM (ReaderT (Scope MC) (State VNameSource) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      HasScope MC,
      LocalScope MC,
      MonadFreshNames
    )

-- XXX: throwing away the log here...
instance MonadLogger ExtractM where
  addLog _ = pure ()

-- | Bind a map-lambda parameter to element @i@ of its input array
-- @arr@.  Accumulator-typed parameters are bound to @arr@ directly,
-- without indexing.
indexArray :: VName -> LParam SOACS -> VName -> Stm MC
indexArray i (Param _ p t) arr =
  Let (Pat [PatElem p t]) (defAux ()) . BasicOp $
    case t of
      Acc {} -> SubExp $ Var arr
      _ -> Index arr $ Slice $ DimFix (Var i) : map sliceDim (arrayDims t)

-- | Turn a map lambda into a body for iteration @i@.  The body
-- transformed by @onBody@ is prefixed with statements that bind each
-- lambda parameter from its input array.
mapLambdaToBody ::
  (Body SOACS -> ExtractM (Body MC)) ->
  VName ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (Body MC)
mapLambdaToBody onBody i lam arrs = do
  let indexings = zipWith (indexArray i) (lambdaParams lam) arrs
  Body () stms res <- inScopeOf indexings $ onBody $ lambdaBody lam
  pure $ Body () (stmsFromList indexings <> stms) res

-- | Like 'mapLambdaToBody', but produce a kernel body whose results are
-- all 'Returns' with 'ResultMaySimplify'.
mapLambdaToKernelBody ::
  (Body SOACS -> ExtractM (Body MC)) ->
  VName ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (KernelBody MC)
mapLambdaToKernelBody onBody i lam arrs = do
  Body () stms res <- mapLambdaToBody onBody i lam arrs
  let ret (SubExpRes cs se) = Returns ResultMaySimplify cs se
  pure $ KernelBody () stms $ map ret res

-- | Convert a reduction to a 'SegBinOp'.  'determineReduceOp' may bind
-- statements; these are returned alongside the operator.  The operator
-- is marked 'Commutative' if 'commutativeLambda' holds for it.
-- Otherwise it keeps the declared commutativity.
reduceToSegBinOp :: Reduce SOACS -> ExtractM (Stms MC, SegBinOp MC)
reduceToSegBinOp (Reduce comm lam nes) = do
  ((lam', nes', shape), stms) <- runBuilder $ determineReduceOp lam nes
  lam'' <- transformLambda lam'
  let comm'
        | commutativeLambda lam' = Commutative
        | otherwise = comm
  pure (stms, SegBinOp comm' lam'' nes' shape)

-- | Convert a scan to a (always 'Noncommutative') 'SegBinOp'.  Any
-- statements bound by 'determineReduceOp' are returned alongside.
scanToSegBinOp :: Scan SOACS -> ExtractM (Stms MC, SegBinOp MC)
scanToSegBinOp (Scan lam nes) = do
  ((lam', nes', shape), stms) <- runBuilder $ determineReduceOp lam nes
  lam'' <- transformLambda lam'
  pure (stms, SegBinOp Noncommutative lam'' nes' shape)

-- | Convert a SOACS histogram operator to an 'MC.HistOp'.  Any
-- statements bound by 'determineReduceOp' are returned alongside.
histToSegBinOp :: SOACS.HistOp SOACS -> ExtractM (Stms MC, MC.HistOp MC)
histToSegBinOp (SOACS.HistOp num_bins rf dests nes op) = do
  ((op', nes', shape), stms) <- runBuilder $ determineReduceOp op nes
  op'' <- transformLambda op'
  pure (stms, MC.HistOp num_bins rf dests nes' shape op'')

-- | A one-dimensional 'SegSpace' of width @w@ with fresh flat and
-- thread-index names.  Returns the thread index and the space.
mkSegSpace :: (MonadFreshNames m) => SubExp -> m (VName, SegSpace)
mkSegSpace w = do
  flat <- newVName "flat_tid"
  gtid <- newVName "gtid"
  let space = SegSpace flat [(gtid, w)]
  pure (gtid, space)

-- | Transform one statement.  Basic operations and function
-- applications are kept as they are.  Loops, matches and withaccs are
-- transformed recursively.  SOACs are handled by 'transformSOAC', and
-- the original statement's certificates are attached to every
-- resulting statement.
transformStm :: Stm SOACS -> ExtractM (Stms MC)
transformStm (Let pat aux (BasicOp op)) =
  pure $ oneStm $ Let pat aux $ BasicOp op
transformStm (Let pat aux (Apply f args ret info)) =
  pure $ oneStm $ Let pat aux $ Apply f args ret info
transformStm (Let pat aux (Loop merge form body)) = do
  body' <-
    localScope (scopeOfFParams (map fst merge) <> scopeOfLoopForm form) $
      transformBody body
  pure $ oneStm $ Let pat aux $ Loop merge form body'
transformStm (Let pat aux (Match ses cases defbody ret)) =
  oneStm . Let pat aux
    <$> (Match ses <$> mapM transformCase cases <*> transformBody defbody <*> pure ret)
  where
    transformCase (Case vs body) =
      Case vs <$> transformBody body
transformStm (Let pat aux (WithAcc inputs lam)) =
  oneStm . Let pat aux
    <$> (WithAcc <$> mapM transformInput inputs <*> transformLambda lam)
  where
    transformInput (shape, arrs, op) =
      (shape,arrs,) <$> traverse (bitraverse transformLambda pure) op
transformStm (Let pat aux (Op op)) =
  fmap (certify (stmAuxCerts aux)) <$> transformSOAC pat (stmAuxAttrs aux) op

-- | Transform the body of a lambda in the scope of its parameters.
transformLambda :: Lambda SOACS -> ExtractM (Lambda MC)
transformLambda (Lambda params ret body) =
  Lambda params ret
    <$> localScope (scopeOfLParams params) (transformBody body)

-- | Transform statements in order.  The result of each statement is
-- brought into scope before the remaining ones are transformed.
transformStms :: Stms SOACS -> ExtractM (Stms MC)
transformStms stms =
  case stmsHead stms of
    Nothing -> pure mempty
    Just (stm, stms') -> do
      stm_stms <- transformStm stm
      inScopeOf stm_stms $ (stm_stms <>) <$> transformStms stms'

-- | Transform the statements of a body, keeping its result.
transformBody :: Body SOACS -> ExtractM (Body MC)
transformBody (Body () stms res) =
  Body () <$> transformStms stms <*> pure res

-- | Convert a body without extracting parallelism.  SOACs are kept
-- as-is, wrapped in 'OtherOp'.
sequentialiseBody :: Body SOACS -> ExtractM (Body MC)
sequentialiseBody = pure . runIdentity . rephraseBody toMC
  where
    toMC = injectSOACS OtherOp

-- | Transform a function body in the scope of its parameters.
transformFunDef :: FunDef SOACS -> ExtractM (FunDef MC)
transformFunDef (FunDef entry attrs name rettype params body) = do
  body' <- localScope (scopeOfFParams params) $ transformBody body
  pure $ FunDef entry attrs name rettype params body'

-- Code generation for each parallel basic block is parameterised over
-- how we handle parallelism in the body (whether it's sequentialised
-- by keeping it as SOACs, or turned into SegOps).
-- | Whether 'renameIfNeeded' should give the names bound inside a
-- freshly built 'SegOp' new, unique names.
data NeedsRename = DoRename | DoNotRename

-- | Apply 'renameSomething' for 'DoRename'; return the value unchanged
-- for 'DoNotRename'.
renameIfNeeded :: (Rename a) => NeedsRename -> a -> ExtractM a
renameIfNeeded DoRename = renameSomething
renameIfNeeded DoNotRename = pure

-- | Build a 'SegMap' over a fresh segment space of width @w@ from a map
-- lambda. @onBody@ decides how the lambda body is transformed (in this
-- module either 'transformBody' or 'sequentialiseBody').
transformMap ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (SegOp () MC)
transformMap rename onBody w map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody onBody gtid map_lam arrs
  renameIfNeeded rename $
    SegMap () space (lambdaReturnType map_lam) kbody

-- | Like 'transformMap', but builds a 'SegRed'. Converting each
-- 'Reduce' to a segmented operator may produce statements; these are
-- returned alongside the operation, and the caller places them before
-- it.
transformRedomap ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  [Reduce SOACS] ->
  Lambda SOACS ->
  [VName] ->
  ExtractM ([Stms MC], SegOp () MC)
transformRedomap rename onBody w reds map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody onBody gtid map_lam arrs
  (reds_stms, reds') <- mapAndUnzipM reduceToSegBinOp reds
  op' <-
    renameIfNeeded rename $
      SegRed () space reds' (lambdaReturnType map_lam) kbody
  pure (reds_stms, op')

-- | Like 'transformRedomap', but builds a 'SegHist' from histogram
-- operators.
transformHist ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  [SOACS.HistOp SOACS] ->
  Lambda SOACS ->
  [VName] ->
  ExtractM ([Stms MC], SegOp () MC)
transformHist rename onBody w hists map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody onBody gtid map_lam arrs
  (hists_stms, hists') <- mapAndUnzipM histToSegBinOp hists
  op' <-
    renameIfNeeded rename $
      SegHist () space hists' (lambdaReturnType map_lam) kbody
  pure (hists_stms, op')

-- | Transform a single SOAC bound to the given pattern. The attributes
-- argument is currently ignored, and produced statements use
-- 'defAux'.
transformSOAC :: Pat Type -> Attrs -> SOAC SOACS -> ExtractM (Stms MC)
-- AD constructs are not handled here; they are expected to have been
-- removed by an earlier pass.
transformSOAC _ _ JVP {} =
  error "transformSOAC: unhandled JVP"
transformSOAC _ _ VJP {} =
  error "transformSOAC: unhandled VJP"
transformSOAC pat _ (Screma w arrs form)
  -- A sequential version is always built. If the lambda contains
  -- nested parallelism, a renamed parallel version is built too
  -- (renamed, presumably, so that the two versions do not bind the
  -- same names), and both are stored in the 'ParOp'.
  | Just lam <- isMapSOAC form = do
      seq_op <- transformMap DoNotRename sequentialiseBody w lam arrs
      if lambdaContainsParallelism lam
        then do
          par_op <- transformMap DoRename transformBody w lam arrs
          pure $ oneStm (Let pat (defAux ()) $ Op $ ParOp (Just par_op) seq_op)
        else pure $ oneStm (Let pat (defAux ()) $ Op $ ParOp Nothing seq_op)
  -- Same scheme as for maps; the operator statements of both versions
  -- are placed before the 'ParOp'.
  | Just (reds, map_lam) <- isRedomapSOAC form = do
      (seq_reds_stms, seq_op) <- transformRedomap DoNotRename sequentialiseBody w reds map_lam arrs
      if lambdaContainsParallelism map_lam
        then do
          (par_reds_stms, par_op) <- transformRedomap DoRename transformBody w reds map_lam arrs
          pure $
            mconcat (seq_reds_stms <> par_reds_stms)
              <> oneStm (Let pat (defAux ()) $ Op $ ParOp (Just par_op) seq_op)
        else
          pure $
            mconcat seq_reds_stms
              <> oneStm (Let pat (defAux ()) $ Op $ ParOp Nothing seq_op)
  -- Scans only get a single version, whose body is processed with
  -- 'transformBody'.
  | Just (scans, map_lam) <- isScanomapSOAC form = do
      (gtid, space) <- mkSegSpace w
      kbody <- mapLambdaToKernelBody transformBody gtid map_lam arrs
      (scans_stms, scans') <- mapAndUnzipM scanToSegBinOp scans
      pure $
        mconcat scans_stms
          <> oneStm
            ( Let pat (defAux ()) $
                Op $
                  ParOp Nothing $
                    SegScan () space scans' (lambdaReturnType map_lam) kbody
            )
  | otherwise = do
      -- This screma is too complicated for us to immediately do
      -- anything, so split it up and try again.
      scope <- castScope <$> askScope
      transformStms =<< runBuilderT_ (dissectScrema pat w form arrs) scope
-- A scatter becomes a 'SegMap' whose results are 'WriteReturns' into
-- the destination arrays.
transformSOAC pat _ (Scatter w ivs dests lam) = do
  (gtid, space) <- mkSegSpace w

  Body () kstms res <- mapLambdaToBody transformBody gtid lam ivs

  (rets, kres) <-
    fmap unzip $
      forM (groupScatterResults dests res) $ \(_a_w, a, is_vs) -> do
        a_t <- lookupType a
        -- Certificates of both the indexes and the values are kept.
        let cs =
              foldMap (foldMap resCerts . fst) is_vs
                <> foldMap (resCerts . snd) is_vs
            is_vs' = [(fullSlice a_t $ map (DimFix . resSubExp) is, resSubExp v) | (is, v) <- is_vs]
        pure (a_t, WriteReturns cs a is_vs')

  pure . oneStm . Let pat (defAux ()) . Op . ParOp Nothing $
    SegMap () space rets (KernelBody () kstms kres)
-- Histograms follow the same sequential/parallel scheme as redomaps.
transformSOAC pat _ (Hist w arrs hists map_lam) = do
  (seq_hist_stms, seq_op) <- transformHist DoNotRename sequentialiseBody w hists map_lam arrs
  if lambdaContainsParallelism map_lam
    then do
      (par_hist_stms, par_op) <- transformHist DoRename transformBody w hists map_lam arrs
      pure $
        mconcat (seq_hist_stms <> par_hist_stms)
          <> oneStm (Let pat (defAux ()) $ Op $ ParOp (Just par_op) seq_op)
    else
      pure $
        mconcat seq_hist_stms
          <> oneStm (Let pat (defAux ()) $ Op $ ParOp Nothing seq_op)
transformSOAC pat _ (Stream w arrs nes lam) = do
  -- Just remove the stream and transform the resulting stms.
  soacs_scope <- castScope <$> askScope
  stream_stms <-
    flip runBuilderT_ soacs_scope $
      sequentialStreamWholeArray pat w nes lam arrs
  transformStms stream_stms

-- | Transform a whole program. The constants are transformed first and
-- are in scope while the function definitions are transformed.
transformProg :: Prog SOACS -> PassM (Prog MC)
transformProg prog =
  modifyNameSource $ runState (runReaderT m mempty)
  where
    ExtractM m = do
      consts' <- transformStms $ progConsts prog
      funs' <- inScopeOf consts' $ mapM transformFunDef $ progFuns prog
      pure $
        prog
          { progConsts = consts',
            progFuns = funs'
          }

-- | Transform a program using SOACs to a program in the 'MC'
-- representation, using some amount of flattening.
extractMulticore :: Pass SOACS MC
extractMulticore =
  Pass
    { passName = "extract multicore parallelism",
      passDescription = "Extract multicore parallelism",
      passFunction = transformProg
    }
futhark-0.25.27/src/Futhark/Pass/FirstOrderTransform.hs000066400000000000000000000015301475065116200227550ustar00rootroot00000000000000
{-# LANGUAGE TypeFamilies #-}

-- | Transform any SOACs to @for@-loops.
--
-- Example:
--
-- @
-- let ys = map (\x -> x + 2) xs
-- @
--
-- becomes something like:
--
-- @
-- let out = scratch n i32
-- let ys =
--   loop (ys' = out) for i < n do
--     let x = xs[i]
--     let y = x + 2
--     let ys'[i] = y
--     in ys'
-- @
module Futhark.Pass.FirstOrderTransform (firstOrderTransform) where

import Futhark.IR.SOACS (SOACS, scopeOf)
import Futhark.Pass
import Futhark.Transform.FirstOrderTransform (FirstOrderRep, transformConsts, transformFunDef)

-- | The first-order transformation pass. Constants and function
-- definitions are transformed separately; each function is given a
-- scope computed with 'scopeOf' (presumably of the transformed
-- constants).
firstOrderTransform :: (FirstOrderRep rep) => Pass SOACS rep
firstOrderTransform =
  Pass
    "first order transform"
    "Transform all SOACs to for-loops."
    $ intraproceduralTransformationWithConsts
      transformConsts
      (transformFunDef . scopeOf)
futhark-0.25.27/src/Futhark/Pass/LiftAllocations.hs000066400000000000000000000121161475065116200220670ustar00rootroot00000000000000
{-# LANGUAGE TypeFamilies #-}

-- | This pass attempts to lift allocations and asserts as far towards
-- the top of their body as possible. This helps memory
-- short-circuiting do a better job, as it is sensitive to statement
-- ordering. It does not try to hoist allocations across body
-- boundaries.
module Futhark.Pass.LiftAllocations
  ( liftAllocationsSeqMem,
    liftAllocationsGPUMem,
    liftAllocationsMCMem,
  )
where

import Control.Monad.Reader
import Data.Sequence (Seq (..))
import Futhark.Analysis.Alias (aliasAnalysis)
import Futhark.IR.Aliases
import Futhark.IR.GPUMem
import Futhark.IR.MCMem
import Futhark.IR.SeqMem
import Futhark.Pass (Pass (..))

-- | Run alias analysis, lift allocations in every function body, and
-- strip the alias information again. The function argument is used on
-- inner operations. Top-level constants are not processed.
liftInProg ::
  (AliasableRep rep, Mem rep inner, ASTConstraints (inner (Aliases rep))) =>
  (inner (Aliases rep) -> LiftM (inner (Aliases rep)) (inner (Aliases rep))) ->
  Prog rep ->
  Prog rep
liftInProg onOp prog =
  prog
    { progFuns = removeFunDefAliases . onFun <$> progFuns (aliasAnalysis prog)
    }
  where
    onFun f = f {funDefBody = onBody (funDefBody f)}
    onBody body = runReader (liftAllocationsInBody body) (Env onOp)

-- | Lifting for sequential code; inner operations are left unchanged.
liftAllocationsSeqMem :: Pass SeqMem SeqMem
liftAllocationsSeqMem =
  Pass "lift allocations" "lift allocations" $
    pure . liftInProg pure

-- | Lifting for GPU code, also inside the kernel bodies of 'SegOp's.
liftAllocationsGPUMem :: Pass GPUMem GPUMem
liftAllocationsGPUMem =
  Pass "lift allocations gpu" "lift allocations gpu" $
    pure . liftInProg liftAllocationsInHostOp

-- | Lifting for multicore code, also inside the kernel bodies of
-- 'SegOp's.
liftAllocationsMCMem :: Pass MCMem MCMem
liftAllocationsMCMem =
  Pass "lift allocations mc" "lift allocations mc" $
    pure . liftInProg liftAllocationsInMCOp

-- | The environment holds the handler used for inner operations.
newtype Env inner = Env {onInner :: inner -> LiftM inner inner}

type LiftM inner a = Reader (Env inner) a

-- | Lift allocations within the statements of a body.
liftAllocationsInBody :: (Mem rep inner, Aliased rep) => Body rep -> LiftM (inner rep) (Body rep)
liftAllocationsInBody body = do
  stms <- liftAllocationsInStms (bodyStms body) mempty mempty mempty
  pure $ body {bodyStms = stms}

-- | Recursively process the bodies nested inside a statement (inner
-- operations, match branches, loop bodies). The statement itself is
-- not moved here.
liftInsideStm :: (Mem rep inner, Aliased rep) => Stm rep -> LiftM (inner rep) (Stm rep)
liftInsideStm stm@(Let _ _ (Op (Inner inner))) = do
  on_inner <- asks onInner
  inner' <- on_inner inner
  pure $ stm {stmExp = Op $ Inner inner'}
liftInsideStm stm@(Let _ _ (Match cond_ses cases body dec)) = do
  cases' <- mapM (\(Case p b) -> Case p <$> liftAllocationsInBody b) cases
  body' <- liftAllocationsInBody body
  pure stm {stmExp = Match cond_ses cases' body' dec}
liftInsideStm stm@(Let _ _ (Loop params form body)) = do
  body' <- liftAllocationsInBody body
  pure stm {stmExp = Loop params form body'}
liftInsideStm stm = pure stm

-- | Walk the statements from last to first. Allocations and asserts
-- are always lifted. Any other statement is lifted if it binds a name
-- that a lifted statement depends on, or if it refers to a name that a
-- lifted statement consumes (so that the use stays ahead of the
-- consumption). The result is the lifted statements followed by the
-- remaining ones, each group in its original relative order.
liftAllocationsInStms ::
  (Mem rep inner, Aliased rep) =>
  -- | The input stms
  Stms rep ->
  -- | The lifted allocations and associated statements
  Stms rep ->
  -- | The other statements processed so far
  Stms rep ->
  -- | (Names we need to lift, consumed names)
  (Names, Names) ->
  LiftM (inner rep) (Stms rep)
liftAllocationsInStms Empty lifted acc _ = pure $ lifted <> acc
liftAllocationsInStms (stms :|> stm) lifted acc (to_lift, consumed) = do
  stm' <- liftInsideStm stm
  case stmExp stm' of
    BasicOp Assert {} -> liftStm stm'
    Op Alloc {} -> liftStm stm'
    _ -> do
      let pat_names = namesFromList $ patNames $ stmPat stm'
      if (pat_names `namesIntersect` to_lift)
        || namesIntersect consumed (freeIn stm)
        then liftStm stm'
        else dontLiftStm stm'
  where
    -- A lifted statement's free names must be lifted as well; the names
    -- it binds itself no longer need to be.
    liftStm stm' =
      liftAllocationsInStms stms (stm' :<| lifted) acc (to_lift', consumed')
      where
        to_lift' =
          freeIn stm'
            <> (to_lift `namesSubtract` namesFromList (patNames (stmPat stm')))
        consumed' = consumed <> consumedInStm stm'
    dontLiftStm stm' =
      liftAllocationsInStms stms lifted (stm' :<| acc) (to_lift, consumed)

-- | Lift allocations within the kernel body of a 'SegOp'.
liftAllocationsInSegOp ::
  (Mem rep inner, Aliased rep) =>
  SegOp lvl rep ->
  LiftM (inner rep) (SegOp lvl rep)
liftAllocationsInSegOp (SegMap lvl sp tps body) = do
  stms <- liftAllocationsInStms (kernelBodyStms body) mempty mempty mempty
  pure $ SegMap lvl sp tps $ body {kernelBodyStms = stms}
liftAllocationsInSegOp (SegRed lvl sp binops tps body) = do
  stms <- liftAllocationsInStms (kernelBodyStms body) mempty mempty mempty
  pure $ SegRed lvl sp binops tps $ body {kernelBodyStms = stms}
liftAllocationsInSegOp (SegScan lvl sp binops tps body) = do
  stms <- liftAllocationsInStms (kernelBodyStms body) mempty mempty mempty
  pure $ SegScan lvl sp binops tps $ body {kernelBodyStms = stms}
liftAllocationsInSegOp (SegHist lvl sp histops tps body) = do
  stms <- liftAllocationsInStms (kernelBodyStms body) mempty mempty mempty
  pure $ SegHist lvl sp histops tps $ body {kernelBodyStms = stms}

-- | Only 'SegOp's are descended into; other host operations are
-- returned unchanged.
liftAllocationsInHostOp ::
  HostOp NoOp (Aliases GPUMem) ->
  LiftM (HostOp NoOp (Aliases GPUMem)) (HostOp NoOp (Aliases GPUMem))
liftAllocationsInHostOp (SegOp op) = SegOp <$> liftAllocationsInSegOp op
liftAllocationsInHostOp op = pure op

-- | Both the optional parallel and the sequential 'SegOp' of a 'ParOp'
-- are processed; other operations are returned unchanged.
liftAllocationsInMCOp ::
  MCOp NoOp (Aliases MCMem) ->
  LiftM (MCOp NoOp (Aliases MCMem)) (MCOp NoOp (Aliases MCMem))
liftAllocationsInMCOp (ParOp par op) =
  ParOp <$> traverse liftAllocationsInSegOp par <*> liftAllocationsInSegOp op
liftAllocationsInMCOp op = pure op
futhark-0.25.27/src/Futhark/Pass/LowerAllocations.hs000066400000000000000000000120121475065116200222540ustar00rootroot00000000000000
{-# LANGUAGE TypeFamilies #-}

-- | This pass attempts to lower allocations as far towards the bottom
-- of their body as possible: an allocation is held back until just
-- before the first statement that refers to it, and allocations that
-- are never referred to end up at the end of the body.
module Futhark.Pass.LowerAllocations
  ( lowerAllocationsSeqMem,
    lowerAllocationsGPUMem,
    lowerAllocationsMCMem,
  )
where

import Control.Monad.Reader
import Data.Function ((&))
import Data.Map qualified as M
import Data.Sequence (Seq (..))
import Data.Sequence qualified as Seq
import Futhark.IR.GPUMem
import Futhark.IR.MCMem
import Futhark.IR.SeqMem
import Futhark.Pass (Pass (..))

-- | Lower allocations in every function body, using the function
-- argument on inner operations. Top-level constants are not
-- processed.
lowerInProg ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  (inner rep -> LowerM (inner rep) (inner rep)) ->
  Prog rep ->
  Prog rep
lowerInProg onOp prog =
  prog {progFuns = fmap onFun (progFuns prog)}
  where
    onFun f = f {funDefBody = onBody (funDefBody f)}
    onBody body = runReader (lowerAllocationsInBody body) (Env onOp)

-- | Lowering for sequential code; inner operations are left unchanged.
lowerAllocationsSeqMem :: Pass SeqMem SeqMem
lowerAllocationsSeqMem =
  Pass "lower allocations" "lower allocations" $
    pure . lowerInProg pure

-- | Lowering for GPU code, also inside the kernel bodies of 'SegOp's.
lowerAllocationsGPUMem :: Pass GPUMem GPUMem
lowerAllocationsGPUMem =
  Pass "lower allocations gpu" "lower allocations gpu" $
    pure . lowerInProg lowerAllocationsInHostOp

-- | Lowering for multicore code, also inside the kernel bodies of
-- 'SegOp's.
lowerAllocationsMCMem :: Pass MCMem MCMem
lowerAllocationsMCMem =
  Pass "lower allocations mc" "lower allocations mc" $
    pure . lowerInProg lowerAllocationsInMCOp

-- | The environment holds the handler used for inner operations.
newtype Env inner = Env {onInner :: inner -> LowerM inner inner}

type LowerM inner a = Reader (Env inner) a

-- | Lower allocations within the statements of a body.
lowerAllocationsInBody ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  Body rep ->
  LowerM (inner rep) (Body rep)
lowerAllocationsInBody body = do
  stms <- lowerAllocationsInStms (bodyStms body) mempty mempty
  pure $ body {bodyStms = stms}

-- | Walk the statements front to back. An allocation binding a single
-- name is removed from its position and held back in the map. Before
-- each other statement, the held-back allocations whose names are free
-- in that statement are emitted. Allocations still held back at the
-- end are appended after all other statements.
lowerAllocationsInStms ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  -- | The input stms
  Stms rep ->
  -- | The allocations currently being lowered
  M.Map VName (Stm rep) ->
  -- | The other statements processed so far
  Stms rep ->
  LowerM (inner rep) (Stms rep)
lowerAllocationsInStms Empty allocs acc = pure $ acc <> Seq.fromList (M.elems allocs)
lowerAllocationsInStms (stm@(Let (Pat [PatElem vname _]) _ (Op (Alloc _ _))) :<| stms) allocs acc =
  lowerAllocationsInStms stms (M.insert vname stm allocs) acc
lowerAllocationsInStms (stm0@(Let _ _ (Op (Inner inner))) :<| stms) alloc0 acc0 = do
  on_inner <- asks onInner
  inner' <- on_inner inner
  let stm = stm0 {stmExp = Op $ Inner inner'}
      (alloc, acc) = insertLoweredAllocs (freeIn stm0) alloc0 acc0
  lowerAllocationsInStms stms alloc (acc :|> stm)
lowerAllocationsInStms (stm@(Let _ _ (Match cond_ses cases body dec)) :<| stms) alloc acc = do
  cases' <- mapM (\(Case pat b) -> Case pat <$> lowerAllocationsInBody b) cases
  body' <- lowerAllocationsInBody body
  let stm' = stm {stmExp = Match cond_ses cases' body' dec}
      (alloc', acc') = insertLoweredAllocs (freeIn stm) alloc acc
  lowerAllocationsInStms stms alloc' (acc' :|> stm')
lowerAllocationsInStms (stm@(Let _ _ (Loop params form body)) :<| stms) alloc acc = do
  body' <- lowerAllocationsInBody body
  let stm' = stm {stmExp = Loop params form body'}
      (alloc', acc') = insertLoweredAllocs (freeIn stm) alloc acc
  lowerAllocationsInStms stms alloc' (acc' :|> stm')
lowerAllocationsInStms (stm :<| stms) alloc acc = do
  let (alloc', acc') = insertLoweredAllocs (freeIn stm) alloc acc
  lowerAllocationsInStms stms alloc' (acc' :|> stm)
-- | Append to @acc@ every held-back allocation whose bound name occurs
-- in @frees@, removing those allocations from the map. Allocations
-- are emitted in the order given by 'namesToList'.
insertLoweredAllocs :: Names -> M.Map VName (Stm rep) -> Stms rep -> (M.Map VName (Stm rep), Stms rep)
insertLoweredAllocs frees alloc acc =
  namesIntersection frees (namesFromList (M.keys alloc))
    & namesToList
    & foldl emit (alloc, acc)
  where
    emit (pending, out) v = (M.delete v pending, out :|> (pending M.! v))

-- | Lower allocations inside the kernel body of any kind of 'SegOp';
-- everything except the body statements is kept as-is.
lowerAllocationsInSegOp ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  SegOp lvl rep ->
  LowerM (inner rep) (SegOp lvl rep)
lowerAllocationsInSegOp segop =
  case segop of
    SegMap lvl sp tps body ->
      SegMap lvl sp tps <$> onKernelBody body
    SegRed lvl sp binops tps body ->
      SegRed lvl sp binops tps <$> onKernelBody body
    SegScan lvl sp binops tps body ->
      SegScan lvl sp binops tps <$> onKernelBody body
    SegHist lvl sp histops tps body ->
      SegHist lvl sp histops tps <$> onKernelBody body
  where
    onKernelBody kbody = do
      stms' <- lowerAllocationsInStms (kernelBodyStms kbody) mempty mempty
      pure kbody {kernelBodyStms = stms'}

-- | Descend into 'SegOp's; any other host operation is left alone.
lowerAllocationsInHostOp :: HostOp NoOp GPUMem -> LowerM (HostOp NoOp GPUMem) (HostOp NoOp GPUMem)
lowerAllocationsInHostOp hostop =
  case hostop of
    SegOp op -> SegOp <$> lowerAllocationsInSegOp op
    _ -> pure hostop

-- | Descend into both the optional parallel and the sequential 'SegOp'
-- of a 'ParOp'; any other operation is left alone.
lowerAllocationsInMCOp :: MCOp NoOp MCMem -> LowerM (MCOp NoOp MCMem) (MCOp NoOp MCMem)
lowerAllocationsInMCOp mcop =
  case mcop of
    ParOp par op ->
      ParOp <$> traverse lowerAllocationsInSegOp par <*> lowerAllocationsInSegOp op
    _ -> pure mcop
futhark-0.25.27/src/Futhark/Pass/Simplify.hs000066400000000000000000000023661475065116200206020ustar00rootroot00000000000000
module Futhark.Pass.Simplify
  ( simplify,
    simplifySOACS,
    simplifySeq,
    simplifyMC,
    simplifyGPU,
    simplifyGPUMem,
    simplifySeqMem,
    simplifyMCMem,
  )
where

import Futhark.IR.GPU.Simplify qualified as GPU
import Futhark.IR.GPUMem qualified as GPUMem
import Futhark.IR.MC qualified as MC
import Futhark.IR.MCMem qualified as MCMem
import Futhark.IR.SOACS.Simplify qualified as SOACS
import Futhark.IR.Seq qualified as Seq
import Futhark.IR.SeqMem qualified as SeqMem
import Futhark.IR.Syntax
import Futhark.Pass

-- | Package a whole-program simplifier as a pass. All simplification
-- passes share the same name and description.
simplify ::
  (Prog rep -> PassM (Prog rep)) ->
  Pass rep rep
simplify simplifier =
  Pass
    { passName = "simplify",
      passDescription = "Perform simple enabling optimisations.",
      passFunction = simplifier
    }

-- | Simplification of SOACS programs.
simplifySOACS :: Pass SOACS.SOACS SOACS.SOACS
simplifySOACS = simplify SOACS.simplifySOACS

-- | Simplification of GPU programs.
simplifyGPU :: Pass GPU.GPU GPU.GPU
simplifyGPU = simplify GPU.simplifyGPU

-- | Simplification of sequential programs.
simplifySeq :: Pass Seq.Seq Seq.Seq
simplifySeq = simplify Seq.simplifyProg

-- | Simplification of multicore programs.
simplifyMC :: Pass MC.MC MC.MC
simplifyMC = simplify MC.simplifyProg

-- | Simplification of GPU programs with memory information.
simplifyGPUMem :: Pass GPUMem.GPUMem GPUMem.GPUMem
simplifyGPUMem = simplify GPUMem.simplifyProg

-- | Simplification of sequential programs with memory information.
simplifySeqMem :: Pass SeqMem.SeqMem SeqMem.SeqMem
simplifySeqMem = simplify SeqMem.simplifyProg

-- | Simplification of multicore programs with memory information.
simplifyMCMem :: Pass MCMem.MCMem MCMem.MCMem
simplifyMCMem = simplify MCMem.simplifyProg
futhark-0.25.27/src/Futhark/Passes.hs000066400000000000000000000136641475065116200173410ustar00rootroot00000000000000
-- | Optimisation pipelines.
module Futhark.Passes ( standardPipeline, seqPipeline, gpuPipeline, seqmemPipeline, gpumemPipeline, mcPipeline, mcmemPipeline, ) where import Control.Category ((>>>)) import Futhark.IR.GPU (GPU) import Futhark.IR.GPUMem (GPUMem) import Futhark.IR.MC (MC) import Futhark.IR.MCMem (MCMem) import Futhark.IR.SOACS (SOACS, usesAD) import Futhark.IR.Seq (Seq) import Futhark.IR.SeqMem (SeqMem) import Futhark.Optimise.ArrayLayout import Futhark.Optimise.ArrayShortCircuiting qualified as ArrayShortCircuiting import Futhark.Optimise.CSE import Futhark.Optimise.DoubleBuffer import Futhark.Optimise.EntryPointMem import Futhark.Optimise.Fusion import Futhark.Optimise.GenRedOpt import Futhark.Optimise.HistAccs import Futhark.Optimise.InliningDeadFun import Futhark.Optimise.MemoryBlockMerging qualified as MemoryBlockMerging import Futhark.Optimise.MergeGPUBodies import Futhark.Optimise.ReduceDeviceSyncs import Futhark.Optimise.Sink import Futhark.Optimise.TileLoops import Futhark.Optimise.Unstream import Futhark.Pass.AD import Futhark.Pass.ExpandAllocations import Futhark.Pass.ExplicitAllocations.GPU qualified as GPU import Futhark.Pass.ExplicitAllocations.MC qualified as MC import Futhark.Pass.ExplicitAllocations.Seq qualified as Seq import Futhark.Pass.ExtractKernels import Futhark.Pass.ExtractMulticore import Futhark.Pass.FirstOrderTransform import Futhark.Pass.LiftAllocations as LiftAllocations import Futhark.Pass.LowerAllocations as LowerAllocations import Futhark.Pass.Simplify import Futhark.Pipeline -- | A pipeline used by all current compilers. Performs inlining, -- fusion, and various forms of cleanup. This pipeline will be -- followed by another one that deals with parallelism and memory. 
standardPipeline :: Pipeline SOACS SOACS
standardPipeline =
  passes cleanup >>> condPipeline usesAD adPipeline
  where
    -- Alternating inlining, CSE and fusion with simplification.
    cleanup =
      [ simplifySOACS,
        inlineConservatively,
        simplifySOACS,
        inlineAggressively,
        simplifySOACS,
        performCSE True,
        simplifySOACS,
        fuseSOACs,
        performCSE True,
        simplifySOACS,
        removeDeadFunctions
      ]

-- | Applies the AD transformation and then cleans up the result with
-- simplification, CSE and another round of fusion.
adPipeline :: Pipeline SOACS SOACS
adPipeline =
  passes afterAD
  where
    afterAD =
      [ applyAD,
        simplifySOACS,
        performCSE True,
        fuseSOACs,
        performCSE True,
        simplifySOACS
      ]

-- | The pipeline used by the CUDA, HIP, and OpenCL backends, before
-- memory information is added. Starts with 'standardPipeline'.
gpuPipeline :: Pipeline SOACS GPU
gpuPipeline =
  standardPipeline >>> onePass extractKernels >>> passes gpuPasses
  where
    gpuPasses =
      [ simplifyGPU,
        optimiseGenRed,
        simplifyGPU,
        tileLoops,
        simplifyGPU,
        histAccsGPU,
        unstreamGPU,
        performCSE True,
        simplifyGPU,
        sinkGPU, -- Sink reads before migrating them.
        reduceDeviceSyncs,
        simplifyGPU, -- Simplify and hoist storages.
        performCSE True, -- Eliminate duplicate storages.
        mergeGPUBodies,
        simplifyGPU, -- Cleanup merged GPUBody kernels.
        sinkGPU, -- Sink reads within GPUBody kernels.
        optimiseArrayLayoutGPU,
        -- Important to simplify after coalescing in order to fix up
        -- redundant manifests.
        simplifyGPU,
        performCSE True
      ]

-- | The pipeline used by the sequential backends: all parallelism is
-- turned into sequential loops. Starts with 'standardPipeline'.
seqPipeline :: Pipeline SOACS Seq
seqPipeline =
  standardPipeline >>> onePass firstOrderTransform >>> passes seqPasses
  where
    seqPasses = [simplifySeq]

-- | Run 'seqPipeline', then add memory information (and
-- optimise it slightly).
seqmemPipeline :: Pipeline SOACS SeqMem
seqmemPipeline =
  seqPipeline >>> onePass Seq.explicitAllocations >>> passes memPasses
  where
    memPasses =
      [ performCSE False,
        simplifySeqMem,
        entryPointMemSeq,
        simplifySeqMem,
        LiftAllocations.liftAllocationsSeqMem,
        simplifySeqMem,
        ArrayShortCircuiting.optimiseSeqMem,
        simplifySeqMem,
        performCSE False,
        simplifySeqMem,
        LowerAllocations.lowerAllocationsSeqMem,
        simplifySeqMem
      ]

-- | Run 'gpuPipeline', then add memory information and optimise its
-- use extensively.
gpumemPipeline :: Pipeline SOACS GPUMem
gpumemPipeline =
  gpuPipeline >>> onePass GPU.explicitAllocations >>> passes memPasses
  where
    memPasses =
      [ simplifyGPUMem,
        performCSE False,
        simplifyGPUMem,
        entryPointMemGPU,
        doubleBufferGPU,
        simplifyGPUMem,
        performCSE False,
        LiftAllocations.liftAllocationsGPUMem,
        simplifyGPUMem,
        ArrayShortCircuiting.optimiseGPUMem,
        simplifyGPUMem,
        performCSE False,
        simplifyGPUMem,
        LowerAllocations.lowerAllocationsGPUMem,
        performCSE False,
        simplifyGPUMem,
        MemoryBlockMerging.optimise,
        simplifyGPUMem,
        expandAllocations,
        simplifyGPUMem
      ]

-- | Run 'standardPipeline', convert to the multicore representation,
-- and then optimise.
mcPipeline :: Pipeline SOACS MC
mcPipeline =
  standardPipeline >>> onePass extractMulticore >>> passes mcPasses
  where
    mcPasses =
      [ simplifyMC,
        unstreamMC,
        performCSE True,
        simplifyMC,
        sinkMC,
        optimiseArrayLayoutMC,
        simplifyMC,
        performCSE True
      ]

-- | Run 'mcPipeline' and then add memory information.
mcmemPipeline :: Pipeline SOACS MCMem
mcmemPipeline =
  mcPipeline >>> onePass MC.explicitAllocations >>> passes memPasses
  where
    memPasses =
      [ simplifyMCMem,
        performCSE False,
        simplifyMCMem,
        entryPointMemMC,
        doubleBufferMC,
        simplifyMCMem,
        performCSE False,
        LiftAllocations.liftAllocationsMCMem,
        simplifyMCMem,
        ArrayShortCircuiting.optimiseMCMem,
        simplifyMCMem,
        performCSE False,
        simplifyMCMem,
        LowerAllocations.lowerAllocationsMCMem,
        performCSE False,
        simplifyMCMem
      ]
futhark-0.25.27/src/Futhark/Pipeline.hs000066400000000000000000000145551475065116200176500ustar00rootroot00000000000000
-- | Definition of the core compiler driver building blocks. The
-- spine of the compiler is the 'FutharkM' monad, although note that
-- individual passes are pure functions, and do not use the 'FutharkM'
-- monad (see "Futhark.Pass").
--
-- Running the compiler involves producing an initial IR program (see
-- "Futhark.Compiler.Program"), running a 'Pipeline' to produce a
-- final program (still in IR), then running an 'Action', which is
-- usually a code generator.
module Futhark.Pipeline
  ( Pipeline,
    PipelineConfig (..),
    Action (..),
    FutharkM,
    runFutharkM,
    Verbosity (..),
    module Futhark.Error,
    onePass,
    passes,
    condPipeline,
    runPipeline,
  )
where

import Control.Category
import Control.Exception (SomeException, catch, throwIO)
import Control.Monad
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Control.Parallel
import Data.Text qualified as T
import Data.Text.IO qualified as T
import Data.Time.Clock
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Compiler.Config (Verbosity (..))
import Futhark.Error
import Futhark.IR (PrettyRep, Prog)
import Futhark.IR.TypeCheck
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Util.Log
import Futhark.Util.Pretty (prettyText)
import System.IO
import Text.Printf
import Prelude hiding (id, (.))

-- | Read-only settings of the compiler driver.
newtype FutharkEnv = FutharkEnv {futharkVerbose :: Verbosity}

-- | Mutable driver state: the time the previous log line was emitted
-- (used to print elapsed time) and the name source.
data FutharkState = FutharkState
  { futharkPrevLog :: UTCTime,
    futharkNameSource :: VNameSource
  }

-- | The main Futhark compiler driver monad: compiler errors, a bit of
-- state and a read-only environment, on top of 'IO'.
newtype FutharkM a
  = FutharkM (ExceptT CompilerError (StateT FutharkState (ReaderT FutharkEnv IO)) a)
  deriving
    ( Applicative,
      Functor,
      Monad,
      MonadError CompilerError,
      MonadState FutharkState,
      MonadReader FutharkEnv,
      MonadIO
    )

instance MonadFreshNames FutharkM where
  getNameSource = futharkNameSource <$> get
  putNameSource src = modify (\st -> st {futharkNameSource = src})

-- | Each log line is prefixed with the time elapsed since the previous
-- one, and only printed (to stderr) at 'Verbose' or above. The
-- timestamp is updated regardless of verbosity.
instance MonadLogger FutharkM where
  addLog = mapM_ logLine . T.lines . toText
    where
      logLine line = do
        verbose <- asks $ (>= Verbose) . futharkVerbose
        before <- gets futharkPrevLog
        now <- liftIO getCurrentTime
        let elapsed :: Double
            elapsed = fromRational $ toRational $ diffUTCTime now before
        modify $ \st -> st {futharkPrevLog = now}
        when verbose . liftIO . T.hPutStrLn stderr $
          T.pack (printf "[ +%7.3f] " elapsed) <> line

-- | Unwrap a 'FutharkM' computation, given initial state and
-- environment; returns the outcome and the final state.
runFutharkM' ::
  FutharkM a ->
  FutharkState ->
  FutharkEnv ->
  IO (Either CompilerError a, FutharkState)
runFutharkM' (FutharkM m) st env = runReaderT (runStateT (runExceptT m) st) env

-- | Run a 'FutharkM' action.
runFutharkM :: FutharkM a -> Verbosity -> IO (Either CompilerError a)
runFutharkM m verbose = do
  start <- getCurrentTime
  let st = FutharkState {futharkPrevLog = start, futharkNameSource = blankNameSource}
  fmap fst $ runFutharkM' m st $ FutharkEnv verbose

-- | Run an action, and if it throws an 'IO' exception, run the handler
-- instead, starting from the state the action started in. Compiler
-- errors from whichever ran are re-thrown in 'FutharkM'.
catchIO :: FutharkM a -> (SomeException -> FutharkM a) -> FutharkM a
catchIO m handler = FutharkM $ do
  st <- get
  env <- ask
  let run act = runFutharkM' act st env
  (result, st') <- liftIO $ run m `catch` (run . handler)
  put st'
  either throwError pure result

-- | A compilation always ends with some kind of action.
data Action rep = Action
  { actionName :: String,
    actionDescription :: String,
    actionProcedure :: Prog rep -> FutharkM ()
  }

-- | Configuration object for running a compiler pipeline.
data PipelineConfig = PipelineConfig
  { pipelineVerbose :: Bool,
    pipelineValidate :: Bool
  }

-- | A compiler pipeline is conceptually a function from programs to
-- programs, where the actual representation may change. It is
-- represented in continuation-passing style, and pipelines can be
-- composed using their 'Category' instance.
newtype Pipeline fromrep torep = Pipeline
  { unPipeline ::
      forall a.
      PipelineConfig ->
      Prog fromrep ->
      FutharkM ((Prog torep -> FutharkM a) -> FutharkM a)
  }

instance Category Pipeline where
  id = Pipeline $ \_ prog -> pure ($ prog)
  p2 . p1 = Pipeline $ \cfg prog -> do
    k <- unPipeline p1 cfg prog
    k $ unPipeline p2 cfg

-- | Run the pipeline on the given program.
runPipeline ::
  Pipeline fromrep torep ->
  PipelineConfig ->
  Prog fromrep ->
  FutharkM (Prog torep)
runPipeline p cfg prog = do
  rc <- unPipeline p cfg prog
  rc pure

-- | Construct a pipeline from a single compiler pass.
onePass ::
  (Checkable torep) =>
  Pass fromrep torep ->
  Pipeline fromrep torep
onePass pass = Pipeline perform
  where
    perform cfg prog = do
      when (pipelineVerbose cfg) . logMsg $
        "Running pass: " <> T.pack (passName pass)
      prog' <- runPass pass prog
      -- Spark validation in a separate task and speculatively execute
      -- next pass.  If the next pass throws an exception, we better
      -- be ready to catch it and check if it might be because the
      -- program was actually ill-typed.
      let check =
            if pipelineValidate cfg
              then validate prog'
              else Right ()
      par check $
        pure $ \c ->
          (errorOnError check pure =<< c prog')
            `catchIO` errorOnError check (liftIO . throwIO)
    validate prog =
      let prog' = Alias.aliasAnalysis prog
       in case checkProg prog' of
            Left err -> Left (prog', err)
            Right () -> Right ()
    errorOnError (Left (prog, err)) _ _ =
      validationError pass prog $ show err
    errorOnError _ c x = c x

-- | Conditionally run pipeline if predicate is true.
condPipeline ::
  (Prog rep -> Bool) -> Pipeline rep rep -> Pipeline rep rep
condPipeline cond (Pipeline f) =
  Pipeline $ \cfg prog ->
    if cond prog
      then f cfg prog
      else pure $ \c -> c prog

-- | Create a pipeline from a list of passes.
passes ::
  (Checkable rep) =>
  [Pass rep rep] ->
  Pipeline rep rep
passes = foldl (>>>) id . map onePass

-- | Report that the program produced by a pass failed type checking.
validationError ::
  (PrettyRep rep) => Pass fromrep torep -> Prog rep -> String -> FutharkM a
validationError pass prog err =
  throwError $ InternalError msg (prettyText prog) CompilerBug
  where
    msg = "Type error after pass '" <> T.pack (passName pass) <> "':\n" <> T.pack err

-- | Run a single pass; its log is only emitted at 'VeryVerbose'.
runPass ::
  Pass fromrep torep ->
  Prog fromrep ->
  FutharkM (Prog torep)
runPass pass prog = do
  (prog', logged) <- runPassM (passFunction pass prog)
  verb <- asks $ (>= VeryVerbose) . futharkVerbose
  when verb $ addLog logged
  pure prog'
futhark-0.25.27/src/Futhark/Pkg/000077500000000000000000000000001475065116200162565ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Pkg/Info.hs000066400000000000000000000300171475065116200175060ustar00rootroot00000000000000
-- | Obtaining information about packages over THE INTERNET!
module Futhark.Pkg.Info
  ( -- * Package info
    PkgInfo (..),
    lookupPkgRev,
    pkgInfo,
    PkgRevInfo (..),
    GetManifest (getManifest),
    GetFiles (getFiles),
    CacheDir (..),

    -- * Package registry
    PkgRegistry,
    MonadPkgRegistry (..),
    lookupPackage,
    lookupPackageRev,
    lookupNewestRev,
  )
where

import Control.Monad (unless, void)
import Control.Monad.IO.Class
import Data.ByteString qualified as BS
import Data.IORef
import Data.List qualified as L
import Data.Map qualified as M
import Data.Maybe
import Data.Text qualified as T
import Data.Text.Encoding qualified as T
import Data.Text.Encoding.Error (lenientDecode)
import Data.Text.IO qualified as T
import Data.Time (UTCTime, defaultTimeLocale, formatTime)
import Data.Time.Format.ISO8601 (iso8601ParseM)
import Data.Time.LocalTime (zonedTimeToUTC)
import Futhark.Pkg.Types
import Futhark.Util (directoryContents, showText, zEncodeText)
import Futhark.Util.Log
import System.Directory (doesDirectoryExist)
import System.Exit
import System.FilePath (makeRelative, (</>))
import System.Process.ByteString (readProcessWithExitCode)

-- | The manifest is stored as a monadic action, because we want to
-- fetch manifests on demand. It would be a waste to fetch this
-- information for every version of every package if we only actually
-- need a small subset of them.
newtype GetManifest m = GetManifest {getManifest :: m PkgManifest}

instance Show (GetManifest m) where
  show _ = "#"

-- | All 'GetManifest' values compare equal; the action itself cannot
-- be inspected.
instance Eq (GetManifest m) where
  _ == _ = True

-- | Get the absolute path to a package directory on disk, as well as
-- /relative/ paths to files that should be installed from this
-- package.  Composing the package directory and one of these paths
-- refers to a local file (pointing into the cache) and is valid at
-- least until the next cache operation.
newtype GetFiles m = GetFiles {getFiles :: m (FilePath, [FilePath])}

instance Show (GetFiles m) where
  show _ = "#"

-- | All 'GetFiles' values compare equal; the action itself cannot be
-- inspected.
instance Eq (GetFiles m) where
  _ == _ = True

-- | Information about a version of a single package.  The version
-- number is stored separately.
data PkgRevInfo m = PkgRevInfo
  { pkgGetFiles :: GetFiles m,
    -- | The commit ID can be used for verification ("freezing"), by
    -- storing what it was at the time this version was last selected.
    pkgRevCommit :: T.Text,
    pkgRevGetManifest :: GetManifest m,
    -- | Timestamp for when the revision was made (rarely used).
    pkgRevTime :: UTCTime
  }
  deriving (Eq, Show)

-- | Create memoisation around a 'GetManifest' action to ensure that
-- multiple inspections of the same revisions will not result in
-- potentially expensive IO operations.
memoiseGetManifest :: (MonadIO m) => GetManifest m -> m (GetManifest m)
memoiseGetManifest (GetManifest m) = do
  ref <- liftIO $ newIORef Nothing
  pure $ GetManifest $ do
    v <- liftIO $ readIORef ref
    case v of
      Just v' -> pure v'
      Nothing -> do
        v' <- m
        liftIO $ writeIORef ref $ Just v'
        pure v'

-- | Information about a package.  The name of the package is stored
-- separately.
data PkgInfo m = PkgInfo
  { pkgVersions :: M.Map SemVer (PkgRevInfo m),
    -- | Look up information about a specific
    -- commit, or HEAD in case of Nothing.
    pkgLookupCommit :: Maybe T.Text -> m (PkgRevInfo m)
  }

-- | Lookup information about a given version of a package.
lookupPkgRev :: SemVer -> PkgInfo m -> Maybe (PkgRevInfo m)
lookupPkgRev v = M.lookup v . pkgVersions

-- | Split a package path of the form @path\@N@ into the path and the
-- single major version @N@. Without a (parseable) suffix, major
-- versions 0 and 1 are accepted.
majorRevOfPkg :: PkgPath -> (T.Text, [Word])
majorRevOfPkg p =
  case T.splitOn "@" p of
    [p', v] | [(v', "")] <- reads $ T.unpack v -> (p', [v'])
    _ -> (p, [0, 1])

-- | Decode process output as UTF-8, replacing invalid byte sequences
-- instead of throwing an exception (as 'T.decodeUtf8' would).
decodeLenient :: BS.ByteString -> T.Text
decodeLenient = T.decodeUtf8With lenientDecode

-- | Run @git@ with the given arguments and return its standard
-- output. Anything written to standard error is logged. A non-zero
-- exit code is reported with 'fail'; exit code 127 is reported as a
-- probably missing @git@ program.
gitCmd :: (MonadIO m, MonadLogger m, MonadFail m) => [String] -> m BS.ByteString
gitCmd opts = do
  logMsg $ "Running command: " <> T.unwords ("git" : map T.pack opts)
  (code, out, err) <- liftIO $ readProcessWithExitCode "git" opts mempty
  -- git may print bytes that are not valid UTF-8 (e.g. in file names);
  -- strict decoding would crash merely because of a log message.
  unless (err == mempty) $ logMsg $ decodeLenient err
  case code of
    ExitFailure 127 -> fail $ "'" <> unwords ("git" : opts) <> "' failed (program not found?)."
    ExitFailure _ -> fail $ "'" <> unwords ("git" : opts) <> "' failed."
    ExitSuccess -> pure out

-- | Like 'gitCmd', but discards the output.
gitCmd_ :: (MonadIO m, MonadLogger m, MonadFail m) => [String] -> m ()
gitCmd_ = void . gitCmd

-- | Like 'gitCmd', but returns the (leniently decoded) output as lines.
gitCmdLines :: (MonadIO m, MonadLogger m, MonadFail m) => [String] -> m [T.Text]
gitCmdLines = fmap (T.lines . decodeLenient) . gitCmd

-- | A temporary directory in which we store Git checkouts while
-- running.  This is to avoid constantly re-cloning.  Will be deleted
-- when @futhark pkg@ terminates.  In principle we could keep this
-- around for longer, but then we would have to 'git pull' now and
-- then also.  Note that the cache is stateful - we are going to use
-- @git checkout@ to move around the history.  It is generally not
-- safe to have multiple operations running concurrently.
newtype CacheDir = CacheDir FilePath

-- | Return the directory of the clone of @url@ inside the cache,
-- cloning it over HTTPS first if the directory does not exist yet. The
-- directory name is the z-encoded URL.
ensureGit ::
  (MonadIO m, MonadLogger m, MonadFail m) =>
  CacheDir ->
  T.Text ->
  m FilePath
ensureGit (CacheDir cachedir) url = do
  exists <- liftIO $ doesDirectoryExist gitdir
  unless exists $
    gitCmd_ ["-C", cachedir, "clone", "https://" <> T.unpack url, url']
  pure gitdir
  where
    url' = T.unpack $ zEncodeText url
    gitdir = cachedir </> url'

-- A git reference (tag, commit, HEAD, etc).
type Ref = String

-- | The git tag under which a released version is published: the
-- letter @v@ followed by the version, e.g. @v1.2.3@.
versionRef :: SemVer -> Ref
versionRef v = T.unpack $ "v" <> prettySemVer v

-- | Gather information about revision @ref@ of the package @path@,
-- whose repository is cloned at @gitdir@.
--
-- The commit ID and commit timestamp are read immediately.  This fails,
-- through 'gitCmd', if git cannot resolve @ref@.  The manifest and the
-- file list are only fetched when demanded.  Each of them first runs
-- @git checkout@ of @ref@, so fetching them moves the working tree of
-- the shared cache clone.  The manifest is memoised.
revInfo ::
  (MonadIO m, MonadLogger m, MonadFail m) =>
  FilePath ->
  PkgPath ->
  Ref ->
  m (PkgRevInfo m)
revInfo gitdir path ref = do
  gitCmd_ ["-C", gitdir, "rev-parse", ref, "--"]
  [sha] <- gitCmdLines ["-C", gitdir, "rev-list", "-n1", ref]
  [time] <- gitCmdLines ["-C", gitdir, "show", "-s", "--format=%cI", ref]
  utc <-
    -- Git sometimes produces timestamps with Z time zone, which are
    -- not valid ZonedTimes.
    if 'Z' `T.elem` time
      then iso8601ParseM (T.unpack time)
      else zonedTimeToUTC <$> iso8601ParseM (T.unpack time)
  gm <- memoiseGetManifest getManifest'
  pure $
    PkgRevInfo
      { pkgGetFiles = getFilesFor gm,
        pkgRevCommit = sha,
        pkgRevGetManifest = gm,
        pkgRevTime = utc
      }
  where
    noPkgDir pdir =
      fail $
        T.unpack path <> "-" <> ref <> " does not contain a directory " <> pdir

    noPkgPath =
      fail $
        "futhark.pkg for "
          <> T.unpack path
          <> "-"
          <> ref
          <> " does not define a package path."

    -- Named so that it does not shadow the 'getFiles' field selector.
    getFilesFor gm = GetFiles $ do
      gitCmd_ ["-C", gitdir, "checkout", ref, "--"]
      pdir <- maybe noPkgPath pure . pkgDir =<< getManifest gm
      let pdir_abs = gitdir </> pdir
      exists <- liftIO $ doesDirectoryExist pdir_abs
      unless exists $ noPkgDir pdir
      fs <- liftIO $ directoryContents pdir_abs
      -- Files are reported relative to the package directory.
      pure (pdir_abs, map (makeRelative pdir_abs) fs)

    getManifest' = GetManifest $ do
      gitCmd_ ["-C", gitdir, "checkout", ref, "--"]
      let f = gitdir </> futharkPkg
      s <- liftIO $ T.readFile f
      let msg =
            "When reading package manifest for "
              <> T.unpack path
              <> " "
              <> ref
              <> ":\n"
      case parsePkgManifest f s of
        Left e -> fail $ msg <> errorBundlePretty e
        Right pm -> pure pm

-- | Retrieve information about a package based on its package path.
-- This uses Semantic Import Versioning when interacting with
-- repositories.  For example, a package @github.com/user/repo@ will
-- match version 0.* or 1.* tags only, a package
-- @github.com/user/repo\@2@ will match 2.* tags, and so forth.
pkgInfo ::
  (MonadIO m, MonadLogger m, MonadFail m) =>
  CacheDir ->
  PkgPath ->
  m (PkgInfo m)
pkgInfo cachedir path = do
  repo <- ensureGit cachedir url
  tags <- gitCmdLines ["-C", repo, "tag"]
  let wanted = mapMaybe tagVersion tags
  revs <- traverse (revInfo repo path . versionRef) wanted
  pure $
    PkgInfo
      { pkgVersions = M.fromList (zip wanted revs),
        pkgLookupCommit = revInfo repo path . maybe "HEAD" T.unpack
      }
  where
    (url, majors) = majorRevOfPkg path

    -- A tag names a usable release when it is "v" followed by a
    -- version whose major number is accepted by the package path.
    tagVersion tag =
      case T.stripPrefix "v" tag of
        Nothing -> Nothing
        Just rest ->
          case parseVersion rest of
            Right v | _svMajor v `elem` majors -> Just v
            _ -> Nothing

-- | A package registry is a mapping from package paths to information
-- about the package.  It is unlikely that any given registry is
-- global; rather small registries are constructed on-demand based on
-- the package paths referenced by the user, and may also be combined
-- monoidically.  In essence, the PkgRegistry is just a cache.
newtype PkgRegistry m = PkgRegistry (M.Map PkgPath (PkgInfo m))

-- Combination is a left-biased union of the underlying maps.
instance Semigroup (PkgRegistry m) where
  PkgRegistry lhs <> PkgRegistry rhs = PkgRegistry (lhs <> rhs)

instance Monoid (PkgRegistry m) where
  mempty = PkgRegistry M.empty

-- | Look up a package that is already present in the registry,
-- without fetching anything.
lookupKnownPackage :: PkgPath -> PkgRegistry m -> Maybe (PkgInfo m)
lookupKnownPackage p (PkgRegistry known) = M.lookup p known

-- | Monads that carry a mutable package registry.  Instances must
-- also support 'MonadIO', because most registry operations end up
-- talking to the network, as well as logging and failure.
class (MonadIO m, MonadLogger m, MonadFail m) => MonadPkgRegistry m where
  getPkgRegistry :: m (PkgRegistry m)
  putPkgRegistry :: PkgRegistry m -> m ()
  modifyPkgRegistry :: (PkgRegistry m -> PkgRegistry m) -> m ()
  modifyPkgRegistry f = getPkgRegistry >>= putPkgRegistry . f

-- | Given a package path, look up information about that package.
lookupPackage ::
  (MonadPkgRegistry m) =>
  CacheDir ->
  PkgPath ->
  m (PkgInfo m)
lookupPackage cachedir p = do
  registry@(PkgRegistry known) <- getPkgRegistry
  maybe (fetchAndRemember known) pure $ lookupKnownPackage p registry
  where
    -- Cache miss: fetch the package and add it to the registry that
    -- was read above.
    fetchAndRemember known = do
      info <- pkgInfo cachedir p
      putPkgRegistry $ PkgRegistry $ M.insert p info known
      pure info

-- | Resolve a commit (or HEAD, for 'Nothing') of a package, and record
-- it in the registry under a commit version built from the commit's
-- timestamp and ID.  Returns that version together with the revision.
lookupPackageCommit ::
  (MonadPkgRegistry m) =>
  CacheDir ->
  PkgPath ->
  Maybe T.Text ->
  m (SemVer, PkgRevInfo m)
lookupPackageCommit cachedir p ref = do
  info <- lookupPackage cachedir p
  rev <- pkgLookupCommit info ref
  let stamp = T.pack (formatTime defaultTimeLocale "%Y%m%d%H%M%S" (pkgRevTime rev))
      version = commitVersion stamp (pkgRevCommit rev)
      info' = info {pkgVersions = M.insert version rev (pkgVersions info)}
  modifyPkgRegistry (\(PkgRegistry known) -> PkgRegistry (M.insert p info' known))
  pure (version, rev)

-- | Look up information about a specific version of a package.
-- Commit versions are resolved through 'lookupPackageCommit'.  For
-- other versions, an unknown version fails with a message listing the
-- known versions, plus a hint when the requested major version needs
-- a different package path.
lookupPackageRev ::
  (MonadPkgRegistry m) =>
  CacheDir ->
  PkgPath ->
  SemVer ->
  m (PkgRevInfo m)
lookupPackageRev cachedir p v =
  case isCommitVersion v of
    Just commit -> snd <$> lookupPackageCommit cachedir p (Just commit)
    Nothing -> do
      info <- lookupPackage cachedir p
      maybe (fail (T.unpack (missing info))) pure (lookupPkgRev v info)
  where
    wanted = _svMajor v

    missing info =
      "package "
        <> p
        <> " does not have a version "
        <> prettySemVer v
        <> ".\n"
        <> knownVersions info
        <> majorHint

    knownVersions info
      | null (pkgVersions info) =
          "Package " <> p <> " has no versions. Invalid package path?"
      | otherwise =
          "Known versions: "
            <> T.intercalate ", " (map prettySemVer (M.keys (pkgVersions info)))

    majorHint
      | wanted `elem` snd (majorRevOfPkg p) = mempty
      | otherwise =
          "\nFor major version "
            <> showText wanted
            <> ", use package path "
            <> p
            <> "@"
            <> showText wanted

-- | Find the newest version of a package.
lookupNewestRev ::
  (MonadPkgRegistry m) =>
  CacheDir ->
  PkgPath ->
  m SemVer
lookupNewestRev cachedir p = do
  pinfo <- lookupPackage cachedir p
  case M.keys $ pkgVersions pinfo of
    -- No tagged releases: fall back to HEAD, which is registered under
    -- a commit version by 'lookupPackageCommit'.
    [] -> do
      logMsg $ "Package " <> p <> " has no released versions. Using HEAD."
      fst <$> lookupPackageCommit cachedir p Nothing
    -- Otherwise pick the greatest known version.
    v : vs -> pure $ L.foldl' max v vs
futhark-0.25.27/src/Futhark/Pkg/Solve.hs000066400000000000000000000101641475065116200177040ustar00rootroot00000000000000-- | Dependency solver
--
-- This is a relatively simple problem due to the choice of the
-- Minimum Package Version algorithm.  In fact, the only failure mode
-- is referencing an unknown package or revision.
module Futhark.Pkg.Solve
  ( solveDeps,
    solveDepsPure,
    PkgRevDepInfo,
  )
where

import Control.Monad
import Control.Monad.Free.Church
import Control.Monad.State
import Data.Map qualified as M
import Data.Set qualified as S
import Data.Text qualified as T
import Futhark.Pkg.Info
import Futhark.Pkg.Types
import Prelude

-- | The single effect the solver needs.  @OpGetDeps p v h k@ asks for
-- the dependencies of package @p@ at version @v@, where @h@ is an
-- optional expected commit hash.  The solver then continues with @k@.
-- It is interpreted by 'solveDeps' (against a registry) and by
-- 'solveDepsPure' (against a fixed table).
data PkgOp a = OpGetDeps PkgPath SemVer (Maybe T.Text) (PkgRevDeps -> a)

-- Mapping only acts on the continuation.
instance Functor PkgOp where
  fmap f (OpGetDeps p v h c) = OpGetDeps p v h (f . c)

-- | A rough build list is like a build list, but may contain packages
-- that are not reachable from the root.  Also contains the
-- dependencies of each package.
newtype RoughBuildList = RoughBuildList (M.Map PkgPath (SemVer, [PkgPath]))
  deriving (Show)

-- | The rough build list the solver starts from: no packages.
emptyRoughBuildList :: RoughBuildList
emptyRoughBuildList = RoughBuildList mempty

-- | The package paths that are directly required by the given
-- dependencies.
depRoots :: PkgRevDeps -> S.Set PkgPath
depRoots (PkgRevDeps m) = S.fromList $ M.keys m

-- Construct a 'BuildList' from a 'RoughBuildList'.  This involves
-- pruning all packages that cannot be reached from the root.
buildList :: S.Set PkgPath -> RoughBuildList -> BuildList
buildList roots (RoughBuildList pkgs) =
  BuildList $ execState (mapM_ addPkg roots) mempty
  where
    -- Depth-first walk from the roots.  A package's dependencies are
    -- only visited the first time the package is seen, which also
    -- stops the walk on dependency cycles.
    addPkg p = case M.lookup p pkgs of
      Nothing -> pure ()
      Just (v, deps) -> do
        listed <- gets $ M.member p
        modify $ M.insert p v
        unless listed $ mapM_ addPkg deps

-- | The solver monad: the rough build list as state, on top of the
-- free monad of 'PkgOp' requests.
type SolveM = StateT RoughBuildList (F PkgOp)

-- | Ask the interpreter for the dependencies of a package revision.
getDeps :: PkgPath -> SemVer -> Maybe T.Text -> SolveM PkgRevDeps
getDeps p v h = lift $ liftF $ OpGetDeps p v h id

-- | Given a list of immediate dependency minimum version constraints,
-- find dependency versions that fit, including transitive
-- dependencies.
doSolveDeps :: PkgRevDeps -> SolveM ()
doSolveDeps (PkgRevDeps deps) = mapM_ add $ M.toList deps
  where
    add (p, (v, maybe_h)) = do
      RoughBuildList l <- get
      case M.lookup p l of
        -- Already satisfied?
        Just (cur_v, _) | v <= cur_v -> pure ()
        -- No; add 'p' and its dependencies.
        _ -> do
          PkgRevDeps p_deps <- getDeps p v maybe_h
          put $ RoughBuildList $ M.insert p (v, M.keys p_deps) l
          mapM_ add $ M.toList p_deps

-- | Run the solver against the package registry and return the
-- resulting build list.  Package lookups go through
-- 'lookupPackageRev', so they are cached in the monad's registry as a
-- side effect.  Fails if a requirement's pinned commit hash does not
-- match the commit that was found.
solveDeps ::
  (MonadPkgRegistry m) =>
  CacheDir ->
  PkgRevDeps ->
  m BuildList
solveDeps cachedir deps =
  buildList (depRoots deps)
    <$> runF
      (execStateT (doSolveDeps deps) emptyRoughBuildList)
      pure
      step
  where
    -- Interpret one request: resolve the revision, check its hash, and
    -- continue with the dependencies listed in its manifest.
    step (OpGetDeps p v h c) = do
      pinfo <- lookupPackageRev cachedir p v
      checkHash p v pinfo h
      d <- fmap pkgRevDeps . getManifest $ pkgRevGetManifest pinfo
      c d

    checkHash _ _ _ Nothing = pure ()
    checkHash p v pinfo (Just h)
      | h == pkgRevCommit pinfo = pure ()
      | otherwise =
          fail $
            T.unpack $
              "Package "
                <> p
                <> " "
                <> prettySemVer v
                <> " has commit hash "
                <> pkgRevCommit pinfo
                <> ", but expected "
                <> h
                <> " from package manifest."

-- | A mapping of package revisions to the dependencies of that
-- package.  Can be considered a 'PkgRegistry' without the option of
-- obtaining more information from the Internet.  Probably useful only
-- for testing the solver.
type PkgRevDepInfo = M.Map (PkgPath, SemVer) PkgRevDeps

-- | Perform package resolution with only pre-known information.  This
-- is useful for testing.
solveDepsPure :: PkgRevDepInfo -> PkgRevDeps -> Either T.Text BuildList
solveDepsPure r deps =
  buildList (depRoots deps)
    <$> runF
      (execStateT (doSolveDeps deps) emptyRoughBuildList)
      Right
      step
  where
    -- Requested hashes are ignored here.  An unknown (package, version)
    -- pair aborts the solve with a 'Left'.
    step (OpGetDeps p v _ c) = do
      let errmsg = "Unknown package/version: " <> p <> "-" <> prettySemVer v
      d <- maybe (Left errmsg) Right $ M.lookup (p, v) r
      c d
futhark-0.25.27/src/Futhark/Pkg/Types.hs000066400000000000000000000241311475065116200177170ustar00rootroot00000000000000-- | Types (and a few other simple definitions) for futhark-pkg.
module Futhark.Pkg.Types
  ( PkgPath,
    pkgPathFilePath,
    PkgRevDeps (..),
    module Data.Versions,

    -- * Versions
    commitVersion,
    isCommitVersion,
    parseVersion,

    -- * Package manifests
    PkgManifest (..),
    newPkgManifest,
    pkgRevDeps,
    pkgDir,
    addRequiredToManifest,
    removeRequiredFromManifest,
    prettyPkgManifest,
    Comment,
    Commented (..),
    Required (..),
    futharkPkg,

    -- * Parsing package manifests
    parsePkgManifest,
    parsePkgManifestFromFile,
    errorBundlePretty,

    -- * Build list
    BuildList (..),
    prettyBuildList,
  )
where

import Control.Applicative
import Control.Monad
import Data.Either
import Data.Foldable
import Data.List (sortOn)
import Data.List.NonEmpty qualified as NE
import Data.Map qualified as M
import Data.Maybe
import Data.Text qualified as T
import Data.Text.IO qualified as T
import Data.Traversable
import Data.Versions (Chunk (Alphanum), Release (Release), SemVer (..), prettySemVer)
import Data.Void
import System.FilePath
import System.FilePath.Posix qualified as Posix
import Text.Megaparsec hiding (many, some)
import Text.Megaparsec.Char
import Prelude

-- | A package path is a unique identifier for a package, for example
-- @github.com/user/foo@.
type PkgPath = T.Text

-- | Convert a package path, whose components are always separated by
-- forward slashes, into a path using the local file system's
-- separator.
pkgPathFilePath :: PkgPath -> FilePath
pkgPathFilePath p = joinPath (Posix.splitPath (T.unpack p))

-- | Recognise a commit version: a version 0.0.0 whose pre-release part
-- is a single chunk (the timestamp) and whose metadata is present.
-- In that case the metadata (typically the Git commit ID) is
-- returned.
isCommitVersion :: SemVer -> Maybe T.Text
isCommitVersion v =
  case v of
    SemVer 0 0 0 (Just (Release (_ NE.:| []))) (Just hash) -> Just hash
    _ -> Nothing

-- | @commitVersion timestamp commit@ builds the version
-- @0.0.0-timestamp+commit@ that stands for a specific commit.
commitVersion :: T.Text -> T.Text -> SemVer
commitVersion time commit = SemVer 0 0 0 (Just prerelease) (Just commit)
  where
    prerelease = Release (NE.singleton (Alphanum time))

-- | Parse a version number, requiring the whole input to be consumed.
-- We do not use the parser from Data.Versions, because it collapses
-- consecutive zeroes in the metadata field.  Only the subset of
-- semver we need is accepted.
parseVersion :: T.Text -> Either (ParseErrorBundle T.Text Void) SemVer
parseVersion t = parse (semver' <* eof) "Semantic Version" t

-- | @major.minor.patch@, optionally followed by @-digits@ and then
-- @+alphanumerics@.
semver' :: Parsec Void T.Text SemVer
semver' = do
  major <- number <* char '.'
  minor <- number <* char '.'
  patch <- number
  timestamp <- optional $ char '-' *> (Alphanum . T.pack <$> some digitChar)
  hash <- optional $ char '+' *> (T.pack <$> some alphaNumChar)
  pure $ SemVer major minor patch (Release . NE.singleton <$> timestamp) hash
  where
    -- A leading "0" is taken on its own, so numbers written with
    -- leading zeroes do not parse.
    number = read <$> (("0" <$ string "0") <|> some digitChar)

-- | The dependencies of a (revision of a) package: a mapping from
-- package paths to minimum versions, each optionally pinned to a
-- commit hash.
newtype PkgRevDeps = PkgRevDeps (M.Map PkgPath (SemVer, Maybe T.Text))
  deriving (Show)

-- Combining dependency maps is a left-biased union: when both sides
-- mention the same package, the left-hand entry is kept.
instance Semigroup PkgRevDeps where
  PkgRevDeps lhs <> PkgRevDeps rhs = PkgRevDeps (M.union lhs rhs)

instance Monoid PkgRevDeps where
  mempty = PkgRevDeps M.empty

--- Package manifest

-- | The text of one line comment, without its leading @--@.
type Comment = T.Text

-- | A value together with the line comments that precede it.  Keeping
-- them lets us rewrite a @futhark.pkg@ file programmatically without
-- losing the user's comments.
data Commented a = Commented
  { comments :: [Comment],
    commented :: a
  }
  deriving (Show, Eq)

instance Functor Commented where
  fmap f (Commented cs x) = Commented cs (f x)

instance Foldable Commented where
  foldMap f (Commented _ x) = f x

instance Traversable Commented where
  traverse f (Commented cs x) = fmap (Commented cs) (f x)

-- | One entry of the @require@ section of a @futhark.pkg@ file.
data Required = Required
  { -- | Path of the package being required.
    requiredPkg :: PkgPath,
    -- | The lowest acceptable version.
    requiredPkgRev :: SemVer,
    -- | Optionally, the commit hash this version had when it was
    -- last seen, used to check its integrity.
    requiredHash :: Maybe T.Text
  }
  deriving (Show, Eq)

-- | File name of the futhark-pkg manifest.
futharkPkg :: FilePath
futharkPkg = "futhark.pkg"

-- | The contents of a @futhark.pkg@ file, comments included.  The
-- same package is intended to be required at most once.  Note that
-- the manifest parser in this module does not currently check this.
data PkgManifest = PkgManifest
  { -- | The path of the package itself, if declared.
    manifestPkgPath :: Commented (Maybe PkgPath),
    manifestRequire :: Commented [Either Comment Required],
    manifestEndComments :: [Comment]
  }
  deriving (Show, Eq)

-- | A manifest with no comments and no requirements, optionally
-- declaring a package path.
newPkgManifest :: Maybe PkgPath -> PkgManifest
newPkgManifest p =
  PkgManifest
    { manifestPkgPath = Commented [] p,
      manifestRequire = Commented [] [],
      manifestEndComments = []
    }

-- | Prettyprint a package manifest such that it can be written to a
-- @futhark.pkg@ file.
prettyPkgManifest :: PkgManifest -> T.Text
prettyPkgManifest (PkgManifest name required endcs) =
  T.unlines $
    concat
      [ prettyComments name,
        -- The extra newline leaves a blank line after the package line.
        maybe [] (pure . ("package " <>) . (<> "\n")) $ commented name,
        prettyComments required,
        ["require {"],
        map ((" " <>) . prettyRequired) $ commented required,
        ["}"],
        map prettyComment endcs
      ]
  where
    prettyComments = map prettyComment . comments
    prettyComment = ("--" <>)
    prettyRequired (Left c) = prettyComment c
    prettyRequired (Right (Required p r h)) =
      T.unwords $
        catMaybes
          [ Just p,
            Just $ prettySemVer r,
            ("#" <>) <$> h
          ]

-- | The required packages listed in a package manifest.
pkgRevDeps :: PkgManifest -> PkgRevDeps
pkgRevDeps =
  PkgRevDeps
    . M.fromList
    . mapMaybe onR
    . commented
    . manifestRequire
  where
    onR (Right r) = Just (requiredPkg r, (requiredPkgRev r, requiredHash r))
    onR (Left _) = Nothing

-- | Where in the corresponding repository archive we can expect to
-- find the package files: @lib/@ followed by the package path, with a
-- trailing separator.
pkgDir :: PkgManifest -> Maybe Posix.FilePath
pkgDir =
  fmap
    ( Posix.addTrailingPathSeparator
        . ("lib" Posix.</>)
        . T.unpack
    )
    . commented
    . manifestPkgPath

-- | Add new required package to the package manifest.  If the package
-- was already present, return the old version.
addRequiredToManifest :: Required -> PkgManifest -> (PkgManifest, Maybe Required)
addRequiredToManifest new_r pm =
  let (old, requires') = mapAccumL add Nothing $ commented $ manifestRequire pm
   in ( if isJust old
          then pm {manifestRequire = requires' <$ manifestRequire pm}
          else pm {manifestRequire = (++ [Right new_r]) <$> manifestRequire pm},
        old
      )
  where
    -- Replace any existing entry for the same package in place, so
    -- that its position and surrounding comments are kept.
    add acc (Left c) = (acc, Left c)
    add acc (Right r)
      | requiredPkg r == requiredPkg new_r = (Just r, Right new_r)
      | otherwise = (acc, Right r)

-- | Check if the manifest specifies a required package with the given
-- package path.
requiredInManifest :: PkgPath -> PkgManifest -> Maybe Required
requiredInManifest p =
  find ((== p) . requiredPkg) . rights . commented . manifestRequire

-- | Remove a required package from the manifest.  Returns 'Nothing'
-- if the package was not found in the manifest, and otherwise the new
-- manifest and the 'Required' that was present.
removeRequiredFromManifest :: PkgPath -> PkgManifest -> Maybe (PkgManifest, Required)
removeRequiredFromManifest p pm = do
  r <- requiredInManifest p pm
  pure
    ( pm {manifestRequire = filter (not . matches) <$> manifestRequire pm},
      r
    )
  where
    matches = either (const False) ((== p) . requiredPkg)

--- Parsing futhark.pkg.

type Parser = Parsec Void T.Text

pPkgManifest :: Parser PkgManifest
pPkgManifest = do
  c1 <- pComments
  p <- optional $ lexstr "package" *> pPkgPath
  space
  c2 <- pComments
  required <-
    ( lexstr "require"
        *> braces (many $ (Left <$> pComment) <|> (Right <$> pRequired))
      )
      <|> pure []
  c3 <- pComments
  eof
  pure $ PkgManifest (Commented c1 p) (Commented c2 required) c3
  where
    lexeme :: Parser a -> Parser a
    lexeme p = p <* space

    -- Like 'lexeme', but does not consume line breaks, so that a
    -- requirement stays on one line.
    lexeme' p = p <* spaceNoEol

    lexstr :: T.Text -> Parser ()
    lexstr = void . try . lexeme . string

    braces :: Parser a -> Parser a
    braces p = lexstr "{" *> p <* lexstr "}"

    spaceNoEol = many $ oneOf (" \t" :: String)

    pPkgPath =
      T.pack
        <$> some (alphaNumChar <|> oneOf ("@-/.:" :: String))
        <?> "package path"

    pRequired =
      space
        *> ( Required
               <$> lexeme' pPkgPath
               <*> lexeme' semver'
               <*> optional (lexeme' pHash)
           )
        <* space
        <?> "package requirement"

    pHash = char '#' *> (T.pack <$> some alphaNumChar)

    pComment = lexeme $ T.pack <$> (string "--" >> anySingle `manyTill` (void eol <|> eof))

    pComments :: Parser [Comment]
    pComments = catMaybes <$> many (comment <|> blankLine)
      where
        comment = Just <$> pComment
        blankLine = some spaceChar >> pure Nothing

-- | Parse a text as a 'PkgManifest'.  The 'FilePath' is used for any
-- error messages.
parsePkgManifest :: FilePath -> T.Text -> Either (ParseErrorBundle T.Text Void) PkgManifest
parsePkgManifest = parse pPkgManifest

-- | Read contents of file and pass it to 'parsePkgManifest'.
parsePkgManifestFromFile :: FilePath -> IO PkgManifest
parsePkgManifestFromFile f = do
  s <- T.readFile f
  -- A parse error is raised as an 'IO' failure carrying the
  -- pretty-printed error bundle.
  case parsePkgManifest f s of
    Left err -> fail $ errorBundlePretty err
    Right m -> pure m

-- | A mapping from package paths to their chosen revisions.  This is
-- the result of the version solver.
newtype BuildList = BuildList {unBuildList :: M.Map PkgPath SemVer}
  deriving (Eq, Show)

-- | Prettyprint a build list; one package per line and
-- newline-terminated.  Lines have the form @path => version@ and are
-- sorted by package path.
prettyBuildList :: BuildList -> T.Text
prettyBuildList (BuildList m) = T.unlines $ map f $ sortOn fst $ M.toList m
  where
    f (p, v) = T.unwords [p, "=>", prettySemVer v]
futhark-0.25.27/src/Futhark/Profile.hs000066400000000000000000000046371475065116200175030ustar00rootroot00000000000000-- | Profiling information emitted by a running Futhark program.
module Futhark.Profile
  ( ProfilingEvent (..),
    ProfilingReport (..),
    profilingReportFromText,
    decodeProfilingReport,
  )
where

import Data.Aeson qualified as JSON
import Data.Aeson.Key qualified as JSON
import Data.Aeson.KeyMap qualified as JSON
import Data.Bifunctor
import Data.ByteString.Builder (toLazyByteString)
import Data.ByteString.Lazy.Char8 qualified as LBS
import Data.Map qualified as M
import Data.Text qualified as T
import Data.Text.Encoding (encodeUtf8Builder)

-- | A thing that has occurred during execution.
data ProfilingEvent = ProfilingEvent
  { -- | Short, single line.
    eventName :: T.Text,
    -- | In microseconds.
    eventDuration :: Double,
    -- | Long, may be multiple lines.
    eventDescription :: T.Text
  }
  deriving (Eq, Ord, Show)

-- Encoded as a JSON object with keys "name", "duration" and
-- "description"; the 'FromJSON' instance reads the same keys.
instance JSON.ToJSON ProfilingEvent where
  toJSON (ProfilingEvent name duration description) =
    JSON.object
      [ ("name", JSON.toJSON name),
        ("duration", JSON.toJSON duration),
        ("description", JSON.toJSON description)
      ]

instance JSON.FromJSON ProfilingEvent where
  parseJSON = JSON.withObject "event" $ \o ->
    ProfilingEvent
      <$> o JSON..: "name"
      <*> o JSON..: "duration"
      <*> o JSON..: "description"

-- | A profiling report contains all profiling information for a
-- single benchmark (meaning a single invocation on an entry point on
-- a specific dataset).
data ProfilingReport = ProfilingReport
  { profilingEvents :: [ProfilingEvent],
    -- | Mapping memory spaces to bytes.
    profilingMemory :: M.Map T.Text Integer
  }
  deriving (Eq, Ord, Show)

-- The memory map is encoded as a JSON object keyed by memory space
-- name.
instance JSON.ToJSON ProfilingReport where
  toJSON (ProfilingReport events memory) =
    JSON.object
      [ ("events", JSON.toJSON events),
        ("memory", JSON.object $ map (bimap JSON.fromText JSON.toJSON) $ M.toList memory)
      ]

instance JSON.FromJSON ProfilingReport where
  parseJSON = JSON.withObject "profiling-info" $ \o ->
    ProfilingReport
      <$> o JSON..: "events"
      <*> (JSON.toMapText <$> o JSON..: "memory")

-- | Read a profiling report from a bytestring containing JSON.
decodeProfilingReport :: LBS.ByteString -> Maybe ProfilingReport
decodeProfilingReport = JSON.decode

-- | Read a profiling report from a text containing JSON.  The text is
-- UTF-8 encoded before decoding.
profilingReportFromText :: T.Text -> Maybe ProfilingReport
profilingReportFromText = JSON.decode . toLazyByteString . encodeUtf8Builder
futhark-0.25.27/src/Futhark/Script.hs000066400000000000000000000526631475065116200173510ustar00rootroot00000000000000-- | FutharkScript is a (tiny) subset of Futhark used to write small
-- expressions that are evaluated by server executables.  The @futhark
-- literate@ command is the main user.
module Futhark.Script
  ( -- * Server
    ScriptServer,
    withScriptServer,
    withScriptServer',

    -- * Expressions, values, and types
    Func (..),
    Exp (..),
    parseExp,
    parseExpFromText,
    varsInExp,
    ScriptValueType (..),
    ScriptValue (..),
    scriptValueType,
    serverVarsInValue,
    ValOrVar (..),
    ExpValue,

    -- * Evaluation
    EvalBuiltin,
    scriptBuiltin,
    evalExp,
    getExpValue,
    evalExpToGround,
    valueToExp,
    freeValue,
  )
where

import Control.Monad
import Control.Monad.Except (MonadError (..))
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Bifunctor (bimap)
import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as LBS
import Data.Char
import Data.Foldable (toList)
import Data.Functor
import Data.IORef
import Data.List (intersperse)
import Data.Map qualified as M
import Data.Set qualified as S
import Data.Text qualified as T
import Data.Traversable
import Data.Vector.Storable qualified as SVec
import Data.Void
import Data.Word (Word8)
import Futhark.Data.Parser qualified as V
import Futhark.Server
import Futhark.Server.Values (getValue, putValue)
import Futhark.Test.Values qualified as V
import Futhark.Util (nubOrd)
import Futhark.Util.Pretty hiding (line, sep, space, (</>))
import Language.Futhark.Core (Name, nameFromText, nameToText)
import Language.Futhark.Tuple (areTupleFields)
import System.FilePath ((</>))
import Text.Megaparsec
import Text.Megaparsec.Char (space)
import Text.Megaparsec.Char.Lexer (charLiteral)

-- | For every type known to the server: its record fields, or
-- 'Nothing' if it has none.
type TypeMap = M.Map TypeName (Maybe [(Name, TypeName)])

-- | Ask the server for all of its types and, for each, its fields.
-- If the types query fails the map is empty; a type whose fields
-- query fails maps to 'Nothing'.
typeMap :: (MonadIO m) => Server -> m TypeMap
typeMap server = do
  liftIO $ either (const (pure mempty)) onTypes =<< cmdTypes server
  where
    onTypes types = M.fromList . zip types <$> mapM onType types
    onType t =
      either (const Nothing) (Just . map onField) <$> cmdFields server t
    -- Each field line is split at its first space into name and type.
    onField = bimap nameFromText (T.drop 1) . T.breakOn " "

-- | The fields of a type, if it is a record type.
isRecord :: TypeName -> TypeMap -> Maybe [(Name, TypeName)]
isRecord t m = join $ M.lookup t m

-- | The element types of a type, if it is a record whose field names
-- make it a tuple (as decided by 'areTupleFields').
isTuple :: TypeName -> TypeMap -> Maybe [TypeName]
isTuple t m = areTupleFields . M.fromList =<< isRecord t m

-- | Like a 'Server', but keeps a bit more state to make FutharkScript
-- more convenient.
data ScriptServer = ScriptServer
  { scriptServer :: Server,
    scriptCounter :: IORef Int,
    scriptTypes :: TypeMap
  }

-- | Run an action with a 'ScriptServer' produced by an existing
-- 'Server', without shutting it down at the end.
withScriptServer' :: (MonadIO m) => Server -> (ScriptServer -> m a) -> m a
withScriptServer' server f = do
  counter <- liftIO $ newIORef 0
  types <- typeMap server
  f $ ScriptServer server counter types

-- | Start a server, execute an action, then shut down the server.
-- Similar to 'withServer'.
withScriptServer :: ServerCfg -> (ScriptServer -> IO a) -> IO a
withScriptServer cfg f = withServer cfg $ flip withScriptServer' f

-- | A function called in a 'Call' expression can be either a Futhark
-- function or a builtin function.
data Func = FuncFut EntryName | FuncBuiltin T.Text
  deriving (Show)

-- | A FutharkScript expression.  This is a simple AST that might not
-- correspond exactly to what the user wrote (e.g. no parentheses or
-- source locations).  This is fine for small expressions, which is
-- all this is meant for.
data Exp
  = Call Func [Exp]
  | Const V.Value
  | Tuple [Exp]
  | Record [(T.Text, Exp)]
  | StringLit T.Text
  | Let [VarName] Exp Exp
  | -- | Server-side variable, *not* Futhark variable (these are
    -- handled in 'Call').
    ServerVar TypeName VarName
  deriving (Show)

instance Pretty Func where
  pretty (FuncFut f) = pretty f
  pretty (FuncBuiltin f) = "$" <> pretty f

instance Pretty Exp where
  pretty = pprPrec (0 :: Int)
    where
      pprPrec _ (ServerVar _ v) = "$" <> pretty v
      pprPrec _ (Const v) = stack $ map pretty $ T.lines $ V.valueText v
      pprPrec i (Let pat e1 e2) =
        parensIf (i > 0) $ "let" <+> pat' <+> equals <+> pretty e1 <+> "in" <+> pretty e2
        where
          pat' = case pat of
            [x] -> pretty x
            _ -> parens $ align $ commasep $ map pretty pat
      pprPrec _ (Call v []) = pretty v
      pprPrec i (Call v args) =
        parensIf (i > 0) $ pretty v <+> hsep (map (align . pprPrec 1) args)
      pprPrec _ (Tuple vs) =
        parens $ commasep $ map (align . pretty) vs
      pprPrec _ (StringLit s) = pretty $ show s
      pprPrec _ (Record m) = braces $ align $ commasep $ map field m
        where
          field (k, v) = align (pretty k <> equals <> pretty v)

type Parser = Parsec Void T.Text

-- | Run a parser, then skip whitespace with the given parser.
lexeme :: Parser () -> Parser a -> Parser a
lexeme sep p = p <* sep

-- | Run a parser between parentheses.
inParens :: Parser () -> Parser a -> Parser a
inParens sep = between (lexeme sep "(") (lexeme sep ")")

-- | Run a parser between braces.
inBraces :: Parser () -> Parser a -> Parser a
inBraces sep = between (lexeme sep "{") (lexeme sep "}")

-- | Parse a FutharkScript expression, given a whitespace parser.
parseExp :: Parsec Void T.Text () -> Parsec Void T.Text Exp
parseExp sep =
  choice
    [ pKeyword "let"
        $> Let
        <*> pPat
        <* lexeme sep "="
        <*> parseExp sep
        <* pKeyword "in"
        <*> parseExp sep,
      try $ Call <$> parseFunc <*> many pAtom,
      pAtom
    ]
    <?> "expression"
  where
    pField = (,) <$> pVarName <*> (pEquals *> parseExp sep)
    pEquals = lexeme sep "="
    pComma = lexeme sep ","
    mkTuple [v] = v
    mkTuple vs = Tuple vs

    pAtom =
      choice
        [ try $ inParens sep (mkTuple <$> (parseExp sep `sepEndBy` pComma)),
          inParens sep $ parseExp sep,
          inBraces sep (Record <$> (pField `sepEndBy` pComma)),
          StringLit . T.pack <$> lexeme sep ("\"" *> manyTill charLiteral "\""),
          Const <$> V.parseValue sep,
          Call <$> parseFunc <*> pure []
        ]

    pPat =
      choice
        [ inParens sep $ pVarName `sepEndBy` pComma,
          pure <$> pVarName
        ]

    parseFunc =
      choice
        [ FuncBuiltin <$> ("$" *> pVarName),
          FuncFut <$> pVarName
        ]

    reserved = ["let", "in"]

    -- A keyword must not be the prefix of a longer identifier.
    -- Otherwise e.g. @lettuce 1@ would be read as @let@ followed by
    -- the pattern @tuce@, and then fail (without backtracking) on the
    -- missing @=@.
    pKeyword kw = lexeme sep . try $ chunk kw <* notFollowedBy (satisfy constituent)

    pVarName = lexeme sep . try $ do
      v <- fmap T.pack $ (:) <$> satisfy isAlpha <*> many (satisfy constituent)
      guard $ v `notElem` reserved
      pure v

    constituent c = isAlphaNum c || c == '\'' || c == '_'

-- | Parse a FutharkScript expression with normal whitespace handling.
parseExpFromText :: FilePath -> T.Text -> Either T.Text Exp
parseExpFromText f s =
  either (Left . T.pack . errorBundlePretty) Right $ parse (parseExp space <* eof) f s

-- | Fetch the value of a server variable, rethrowing a server error
-- message through 'MonadError'.
readVar :: (MonadError T.Text m, MonadIO m) => Server -> VarName -> m V.Value
readVar server v =
  either throwError pure =<< liftIO (getValue server v)

-- | Store a value in a server variable.  Failures are handled by
-- 'cmdMaybe' (presumably rethrown through 'MonadError' - confirm).
writeVar :: (MonadError T.Text m, MonadIO m) => Server -> VarName -> V.Value -> m ()
writeVar server v val =
  cmdMaybe $ liftIO (putValue server v val)

-- | A ScriptValue is either a base value or a partially applied
-- function.  We don't have real first-class functions in
-- FutharkScript, but we sort of have closures.
data ScriptValue v
  = SValue TypeName v
  | -- | Ins, then outs.  Yes, this is the opposite of more or less
    -- everywhere else.
    SFun EntryName [TypeName] [TypeName] [ScriptValue v]
  deriving (Show)

instance Functor ScriptValue where
  fmap = fmapDefault

instance Foldable ScriptValue where
  foldMap = foldMapDefault

instance Traversable ScriptValue where
  traverse f (SValue t v) = SValue t <$> f v
  traverse f (SFun fname ins outs vs) =
    SFun fname ins outs <$> traverse (traverse f) vs

-- | The type of a 'ScriptValue' - either a value type or a function type.
data ScriptValueType
  = STValue TypeName
  | -- | Ins, then outs.
    STFun [TypeName] [TypeName]
  deriving (Eq, Show)

instance Pretty ScriptValueType where
  pretty (STValue t) = pretty t
  pretty (STFun ins outs) =
    hsep $ intersperse "->" (map pretty ins ++ [outs'])
    where
      outs' = case outs of
        [out] -> pretty out
        _ -> parens $ commasep $ map pretty outs

-- | A Haskell-level value or a variable on the server.
data ValOrVar = VVal V.Value | VVar VarName
  deriving (Show)

-- | The intermediate values produced by an expression - in
-- particular, these may not be on the server.
type ExpValue = V.Compound (ScriptValue ValOrVar)

-- | The type of a 'ScriptValue'.
scriptValueType :: ScriptValue v -> ScriptValueType
scriptValueType (SValue t _) = STValue t
scriptValueType (SFun _ ins outs _) = STFun ins outs

-- | The set of server-side variables in the value, including those
-- captured in function closures.
serverVarsInValue :: ExpValue -> S.Set VarName
serverVarsInValue = S.fromList . concatMap isVar . toList
  where
    isVar (SValue _ (VVar x)) = [x]
    isVar (SValue _ (VVal _)) = []
    isVar (SFun _ _ _ closure) = concatMap isVar $ toList closure

-- | Convert a value into a corresponding expression.
valueToExp :: ExpValue -> Exp
valueToExp (V.ValueAtom (SValue t (VVar v))) =
  ServerVar t v
valueToExp (V.ValueAtom (SValue _ (VVal v))) =
  Const v
valueToExp (V.ValueAtom (SFun fname _ _ closure)) =
  Call (FuncFut fname) $ map (valueToExp . V.ValueAtom) closure
valueToExp (V.ValueRecord fs) =
  Record $ M.toList $ M.map valueToExp fs
valueToExp (V.ValueTuple fs) =
  Tuple $ map valueToExp fs

-- Decompose a type name into a rank and an element type.
parseTypeName :: TypeName -> Maybe (Int, V.PrimType)
parseTypeName s
  | Just pt <- lookup s m =
      Just (0, pt)
  | "[]" `T.isPrefixOf` s = do
      (d, pt) <- parseTypeName (T.drop 2 s)
      pure (d + 1, pt)
  | otherwise = Nothing
  where
    prims = [minBound .. maxBound]
    primtexts = map (V.valueTypeText . V.ValueType []) prims
    m = zip primtexts prims

-- | Convert an integer value to the element type named by the given
-- type name, keeping its shape.  The rank in the type name is
-- ignored.  Returns 'Nothing' if the type name is not recognised, if
-- the value is not of an integer type, or if the target element type
-- is not one of the supported integer or float types.
coerceValue :: TypeName -> V.Value -> Maybe V.Value
coerceValue t v = do
  (_, pt) <- parseTypeName t
  case v of
    V.I8Value shape vs ->
      coerceInts pt shape $ map toInteger $ SVec.toList vs
    V.I16Value shape vs ->
      coerceInts pt shape $ map toInteger $ SVec.toList vs
    V.I32Value shape vs ->
      coerceInts pt shape $ map toInteger $ SVec.toList vs
    V.I64Value shape vs ->
      coerceInts pt shape $ map toInteger $ SVec.toList vs
    _ ->
      Nothing
  where
    coerceInts V.I8 shape =
      Just . V.I8Value shape . SVec.fromList . map fromInteger
    coerceInts V.I16 shape =
      Just . V.I16Value shape . SVec.fromList . map fromInteger
    coerceInts V.I32 shape =
      Just . V.I32Value shape . SVec.fromList . map fromInteger
    coerceInts V.I64 shape =
      Just . V.I64Value shape . SVec.fromList . map fromInteger
    coerceInts V.F32 shape =
      Just . V.F32Value shape . SVec.fromList . map fromInteger
    coerceInts V.F64 shape =
      Just . V.F64Value shape . SVec.fromList . map fromInteger
    coerceInts _ _ =
      const Nothing

-- | How to evaluate a builtin function.
type EvalBuiltin m = T.Text -> [V.CompoundValue] -> m V.CompoundValue

-- | Read all values from a data file.  A single value is returned as
-- an atom, and several values as a tuple.  Unreadable contents are
-- reported through 'MonadError'.
loadData ::
  (MonadIO m, MonadError T.Text m) =>
  FilePath ->
  m (V.Compound V.Value)
loadData datafile = do
  contents <- liftIO $ LBS.readFile datafile
  let maybe_vs = V.readValues contents
  case maybe_vs of
    Nothing ->
      throwError $ "Failed to read data file " <> T.pack datafile
    Just [v] ->
      pure $ V.ValueAtom v
    Just vs ->
      pure $ V.ValueTuple $ map V.ValueAtom vs

-- | Extract a file path from the single argument of builtin @cmd@.
-- The argument must be a byte array, and the resulting path is joined
-- onto @dir@ with '</>' (so an absolute path is kept as is).
--
-- NOTE(review): each byte becomes one 'Char', so a path given as
-- UTF-8 with non-ASCII characters would be mangled; confirm how
-- string arguments are encoded before they reach this function.
pathArg ::
  (MonadError T.Text f) =>
  FilePath ->
  T.Text ->
  [V.Compound V.Value] ->
  f FilePath
pathArg dir cmd vs =
  case vs of
    [V.ValueAtom v]
      | Just path <- V.getValue v ->
          pure $ dir </> map (chr . fromIntegral) (path :: [Word8])
    _ ->
      throwError $
        "$"
          <> cmd
          <> " does not accept arguments of types: "
          <> T.intercalate ", " (map (prettyText . fmap V.valueType) vs)

-- | Handles the following builtin functions: @loaddata@, @loadbytes@.
-- Fails for everything else.
-- The 'FilePath' indicates the directory
-- that files should be read relative to.
scriptBuiltin :: (MonadIO m, MonadError T.Text m) => FilePath -> EvalBuiltin m
-- @$loaddata@: parse the named file (resolved by 'pathArg' relative to
-- @dir@) with 'loadData'; several values in the file become a tuple.
scriptBuiltin dir "loaddata" vs = do
  loadData =<< pathArg dir "loaddata" vs
-- @$loadbytes@: read the raw bytes of the named file and wrap them as
-- a single value via 'V.putValue1'.
scriptBuiltin dir "loadbytes" vs = do
  fmap (V.ValueAtom . V.putValue1) . liftIO . BS.readFile =<< pathArg dir "loadbytes" vs
scriptBuiltin _ f _ =
  throwError $ "Unknown builtin function $" <> prettyText f

-- | Symbol table used for local variable lookups during expression evaluation.
type VTable = M.Map VarName ExpValue

-- | Evaluate a FutharkScript expression relative to some running server.
--
-- Every server variable allocated during evaluation is recorded.  On
-- success, those not referenced by the result are freed; on failure,
-- all of them are freed (ignoring free errors) and the original error
-- is rethrown.
evalExp ::
  forall m.
  (MonadError T.Text m, MonadIO m) =>
  EvalBuiltin m ->
  ScriptServer ->
  Exp ->
  m ExpValue
evalExp builtin sserver top_level_e = do
  -- Names of every server variable allocated by 'newVar' below.
  vars <- liftIO $ newIORef []
  let ( ScriptServer
          { scriptServer = server,
            scriptCounter = counter,
            scriptTypes = types
          }
        ) = sserver
      -- Produce a fresh variable name (base ++ counter value), bump the
      -- shared counter, and remember the name so it can be freed later.
      newVar base = liftIO $ do
        x <- readIORef counter
        modifyIORef counter (+ 1)
        let v = base <> prettyText x
        modifyIORef vars (v :)
        pure v

      -- Build a server-side value of type @t@ from existing variables.
      mkRecord t vs = do
        v <- newVar "record"
        cmdMaybe $ cmdNew server v t vs
        pure v

      -- Project field @f@ of the server variable @from@ into a fresh variable.
      getField from (f, _) = do
        to <- newVar "field"
        cmdMaybe $ cmdProject server to from $ nameToText f
        pure to

      toVal :: ValOrVar -> m V.Value
      toVal (VVal v) = pure v
      toVal (VVar v) = readVar server v

      -- Constants are uploaded to a fresh server variable on demand.
      toVar :: ValOrVar -> m VarName
      toVar (VVar v) = pure v
      toVar (VVal val) = do
        v <- newVar "const"
        writeVar server v val
        pure v

      -- Partially applied functions cannot be turned into data.
      scriptValueToValOrVar (SFun f _ _ _) =
        throwError $ "Function " <> f <> " not fully applied."
      scriptValueToValOrVar (SValue _ v) =
        pure v

      scriptValueToVal :: ScriptValue ValOrVar -> m V.Value
      scriptValueToVal = toVal <=< scriptValueToValOrVar

      scriptValueToVar :: ScriptValue ValOrVar -> m VarName
      scriptValueToVar = toVar <=< scriptValueToValOrVar

      interValToVal :: ExpValue -> m V.CompoundValue
      interValToVal = traverse scriptValueToVal

      -- Apart from type checking, this function also converts
      -- FutharkScript tuples/records to Futhark-level tuples/records,
      -- as well as maps between different names for the same
      -- tuple/record.
      --
      -- We also implicitly convert the types of constants.
      --
      -- The first argument is the action run when no case applies.
      interValToVar :: m VarName -> TypeName -> ExpValue -> m VarName
      interValToVar _ t (V.ValueAtom v)
        | STValue t == scriptValueType v = scriptValueToVar v
      interValToVar bad t (V.ValueTuple vs)
        | Just ts <- isTuple t types,
          length vs == length ts =
            mkRecord t =<< zipWithM (interValToVar bad) ts vs
      interValToVar bad t (V.ValueRecord vs)
        | Just fs <- isRecord t types,
          Just vs' <- mapM ((`M.lookup` vs) . nameToText . fst) fs =
            mkRecord t =<< zipWithM (interValToVar bad) (map snd fs) vs'
      -- A server-side record whose type has exactly the same fields as
      -- @t@ (but another name) is rebuilt field by field as a @t@.
      interValToVar _ t (V.ValueAtom (SValue vt (VVar v)))
        | Just t_fs <- isRecord t types,
          Just vt_fs <- isRecord vt types,
          vt_fs == t_fs =
            mkRecord t =<< mapM (getField v) vt_fs
      -- A Haskell-side constant that 'coerceValue' can convert to @t@.
      interValToVar _ t (V.ValueAtom (SValue _ (VVal v)))
        | Just v' <- coerceValue t v =
            scriptValueToVar $ SValue t $ VVal v'
      interValToVar bad _ _ = bad

      valToInterVal :: V.CompoundValue -> ExpValue
      valToInterVal = fmap $ \v ->
        SValue (V.valueTypeTextNoDims (V.valueType v)) $ VVal v

      -- Bind each pattern name to the corresponding component of the
      -- value (as split by 'V.unCompound'); the counts must agree.
      letMatch :: [VarName] -> ExpValue -> m VTable
      letMatch vs val
        | vals <- V.unCompound val,
          length vs == length vals =
            pure $ M.fromList (zip vs vals)
        | otherwise =
            throwError $
              "Pat: "
                <> prettyTextOneLine vs
                <> "\nDoes not match value of type: "
                <> prettyTextOneLine (fmap scriptValueType val)

      evalExp' :: VTable -> Exp -> m ExpValue
      evalExp' _ (ServerVar t v) =
        pure $ V.ValueAtom $ SValue t $ VVar v
      evalExp' vtable (Call (FuncBuiltin name) es) = do
        v <- builtin name =<< mapM (interValToVal <=< evalExp' vtable) es
        pure $ valToInterVal v
      -- A name bound by an enclosing 'Let' takes precedence over an
      -- entry point of the same name, and cannot be applied.
      evalExp' vtable (Call (FuncFut name) es)
        | Just e <- M.lookup name vtable = do
            unless (null es) $
              throwError $
                "Locally bound name cannot be invoked as a function: " <> prettyText name
            pure e
      evalExp' vtable (Call (FuncFut name) es) = do
        in_types <- fmap (map inputType) $ cmdEither $ cmdInputs server name
        out_types <- fmap (map outputType) $ cmdEither $ cmdOutputs server name

        es' <- mapM (evalExp' vtable) es
        let es_types = map (fmap scriptValueType) es'

        let cannotApply =
              throwError $
                "Function \""
                  <> name
                  <> "\" expects "
                  <> prettyText (length in_types)
                  <> " argument(s) of types:\n"
                  <> T.intercalate "\n" (map prettyTextOneLine in_types)
                  <> "\nBut applied to "
                  <> prettyText (length es_types)
                  <> " argument(s) of types:\n"
                  <> T.intercalate "\n" (map prettyTextOneLine es_types)
            -- Convert the given arguments to server variables of the
            -- expected input types (note: 'arg_types' holds variable
            -- names).  A full set performs the call; fewer arguments
            -- yield a partially applied 'SFun' closure instead.
            tryApply args = do
              arg_types <- zipWithM (interValToVar cannotApply) in_types args

              if length in_types == length arg_types
                then do
                  outs <- replicateM (length out_types) $ newVar "out"
                  void $ cmdEither $ cmdCall server name outs arg_types
                  pure $ V.mkCompound $ map V.ValueAtom $ zipWith SValue out_types $ map VVar outs
                else
                  pure . V.ValueAtom . SFun name in_types out_types $
                    zipWith SValue in_types $
                      map VVar arg_types

        -- Careful to not require saturated application, but do still
        -- check for over-saturation.
        when (length es_types > length in_types) cannotApply

        -- Allow automatic uncurrying if applicable.
        case es' of
          [V.ValueTuple es''] | length es'' == length in_types -> tryApply es''
          _ -> tryApply es'
      evalExp' _ (StringLit s) =
        case V.putValue s of
          Just s' ->
            pure $ V.ValueAtom $ SValue (V.valueTypeTextNoDims (V.valueType s')) $ VVal s'
          Nothing -> error $ "Unable to write value " ++ prettyString s
      evalExp' _ (Const val) =
        pure $ V.ValueAtom $ SValue (V.valueTypeTextNoDims (V.valueType val)) $ VVal val
      evalExp' vtable (Tuple es) =
        V.ValueTuple <$> mapM (evalExp' vtable) es
      evalExp' vtable e@(Record m) = do
        when (length (nubOrd (map fst m)) /= length (map fst m)) $
          throwError $
            "Record " <> prettyText e <> " has duplicate fields."
        V.ValueRecord <$> traverse (evalExp' vtable) (M.fromList m)
      -- Pattern bindings shadow outer ones, because '<>' on 'M.Map' is
      -- a left-biased union.
      evalExp' vtable (Let pat e1 e2) = do
        v <- evalExp' vtable e1
        pat_vtable <- letMatch pat v
        evalExp' (pat_vtable <> vtable) e2

  -- On success: free every allocated variable that the result does not
  -- reference.
  let freeNonresultVars v = do
        let v_vars = serverVarsInValue v
        to_free <- liftIO $ filter (`S.notMember` v_vars) <$> readIORef vars
        cmdMaybe $ cmdFree server to_free
        pure v
      freeVarsOnError e = do
        -- We are intentionally ignoring any errors produced by
        -- cmdFree, because we already have another error to
        -- propagate. Also, not all of the variables that we put in
        -- 'vars' might actually exist server-side, if we failed in a
        -- Call.
        void $ liftIO $ cmdFree server =<< readIORef vars
        throwError e
  (freeNonresultVars =<< evalExp' mempty top_level_e) `catchError` freeVarsOnError

-- | Read actual values from the server. Fails for values that have
-- no well-defined external representation.
getExpValue ::
  (MonadError T.Text m, MonadIO m) =>
  ScriptServer ->
  ExpValue ->
  m V.CompoundValue
getExpValue server e = do
  -- Replace every leaf by its concrete value, downloading server
  -- variables as needed...
  fetched <- traverse (traverse fetch) e
  -- ...and only then reject anything that is still a function.
  traverse ground fetched
  where
    fetch (VVal v) = pure v
    fetch (VVar v) = readVar (scriptServer server) v
    ground (SValue _ v) = pure v
    ground (SFun fname _ _ _) =
      throwError $ "Function " <> fname <> " not fully applied."

-- | Like 'evalExp', but requires all values to be non-functional. If
-- the value has a bad type, return that type instead. Other
-- evaluation problems (e.g. type failures) raise errors.
evalExpToGround ::
  (MonadError T.Text m, MonadIO m) =>
  EvalBuiltin m ->
  ScriptServer ->
  Exp ->
  m (Either (V.Compound ScriptValueType) V.CompoundValue)
evalExpToGround builtin server e = do
  result <- evalExp builtin server e
  let onlyTheType = Left $ fmap scriptValueType result
  -- This assumes that the only error that can occur during
  -- getExpValue is trying to read an opaque.
  fmap Right (getExpValue server result) `catchError` \_ -> pure onlyTheType

-- | The set of Futhark variables that are referenced by the
-- expression - these will have to be entry points in the Futhark
-- program.  Names bound by an enclosing @let@ are not included.
varsInExp :: Exp -> S.Set EntryName
varsInExp expr = case expr of
  ServerVar {} -> mempty
  Const {} -> mempty
  StringLit {} -> mempty
  Call (FuncFut v) es -> S.insert v $ foldMap varsInExp es
  Call (FuncBuiltin _) es -> foldMap varsInExp es
  Tuple es -> foldMap varsInExp es
  Record fs -> foldMap (foldMap varsInExp) fs
  Let pat e1 e2 -> varsInExp e1 <> S.filter (`notElem` pat) (varsInExp e2)

-- | Release all the server-side variables in the value. Yes,
-- FutharkScript has manual memory management...
freeValue :: (MonadError T.Text m, MonadIO m) => ScriptServer -> ExpValue -> m ()
freeValue server val =
  cmdMaybe . cmdFree (scriptServer server) . S.toList $ serverVarsInValue val
futhark-0.25.27/src/Futhark/Test.hs000066400000000000000000000426621475065116200170220ustar00rootroot00000000000000-- | Facilities for reading Futhark test programs.
-- A Futhark test program is an ordinary Futhark program where an
-- initial comment block specifies input- and output-sets.
module Futhark.Test
  ( module Futhark.Test.Spec,
    valuesFromByteString,
    FutharkExe (..),
    getValues,
    getValuesBS,
    valuesAsVars,
    V.compareValues,
    checkResult,
    testRunReferenceOutput,
    getExpectedResult,
    compileProgram,
    readResults,
    ensureReferenceOutput,
    determineTuning,
    determineCache,
    binaryName,
    futharkServerCfg,
    V.Mismatch,
    V.Value,
    V.valueText,
  )
where

import Codec.Compression.GZip
import Control.Applicative
import Control.Exception (catch)
import Control.Exception.Base qualified as E
import Control.Monad
import Control.Monad.Except (MonadError (..), runExceptT)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Binary qualified as Bin
import Data.ByteString qualified as SBS
import Data.ByteString.Lazy qualified as BS
import Data.Char
import Data.Maybe
import Data.Set qualified as S
import Data.Text qualified as T
import Data.Text.Encoding qualified as T
import Data.Text.IO qualified as T
import Futhark.Script qualified as Script
import Futhark.Server
import Futhark.Server.Values
import Futhark.Test.Spec
import Futhark.Test.Values qualified as V
import Futhark.Util (isEnvVarAtLeast, pmapIO, showText)
import Futhark.Util.Pretty (prettyText, prettyTextOneLine)
import System.Directory
import System.Exit
import System.FilePath
import System.IO (IOMode (..), hClose, hFileSize, withFile)
import System.IO.Error
import System.IO.Temp
import System.Process.ByteString (readProcessWithExitCode)
import Prelude

-- | Try to parse several values from a byte string. The 'String'
-- parameter is used for error messages.
valuesFromByteString :: String -> BS.ByteString -> Either String [V.Value]
valuesFromByteString srcname =
  maybe (Left $ "Cannot parse values from '" ++ srcname ++ "'") Right . V.readValues

-- | The @futhark@ executable we are using. This is merely a wrapper
-- around the underlying file path, because we will be using a lot of
-- different file paths here, and it is easy to mix them up.
newtype FutharkExe = FutharkExe FilePath
  deriving (Eq, Ord, Show)

-- | Get the actual core Futhark values corresponding to a 'Values'
-- specification. The first 'FilePath' is the path of the @futhark@
-- executable, and the second is the directory which file paths are
-- read relative to.  Fails (via 'fail') if the data cannot be
-- obtained or parsed.
getValues :: (MonadFail m, MonadIO m) => FutharkExe -> FilePath -> Values -> m [V.Value]
getValues _ _ (Values vs) = pure vs
getValues futhark dir v = do
  s <- getValuesBS futhark dir v
  case valuesFromByteString (fileName v) s of
    Left e -> fail e
    Right vs -> pure vs
  where
    -- Source name used in parse error messages.  Values that do not
    -- come from a named file get a descriptive placeholder rather
    -- than an empty string.
    fileName Values {} = "<values>"
    fileName GenValues {} = "<randomly generated>"
    fileName ScriptValues {} = "<FutharkScript expression>"
    fileName (InFile f) = f
    fileName (ScriptFile f) = f

-- | Read a gzip-compressed file and decompress it, returning any
-- decompression failure as 'Left'.
readAndDecompress :: FilePath -> IO (Either DecompressError BS.ByteString)
readAndDecompress file = E.try $ do
  s <- BS.readFile file
  let decompressed = decompress s
  -- 'decompress' produces its output lazily and reports corrupt input
  -- only when the offending chunk is demanded.  Force the whole result
  -- here so that every error surfaces inside 'E.try', instead of
  -- escaping later wherever the bytes happen to be consumed.
  _ <- E.evaluate $ BS.length decompressed
  pure decompressed

-- | Extract a text representation of some 'Values'. In the IO monad
-- because this might involve reading from a file. There is no
-- guarantee that the resulting byte string yields a readable value.
getValuesBS :: (MonadFail m, MonadIO m) => FutharkExe -> FilePath -> Values -> m BS.ByteString
getValuesBS _ _ (Values vs) =
  pure $ BS.fromStrict $ T.encodeUtf8 $ T.unlines $ map V.valueText vs
getValuesBS _ dir (InFile file) =
  case takeExtension file of
    ".gz" -> liftIO $ do
      s <- readAndDecompress file'
      case s of
        Left e -> fail $ show file ++ ": " ++ show e
        Right s' -> pure s'
    _ -> liftIO $ BS.readFile file'
  where
    -- Data file paths are relative to the test program's directory.
    file' = dir </> file
getValuesBS futhark dir (GenValues gens) =
  mconcat <$> mapM (getGenBS futhark dir) gens
getValuesBS _ _ (ScriptValues e) =
  fail $
    "Cannot get values from FutharkScript expression: "
      <> T.unpack (prettyTextOneLine e)
getValuesBS _ _ (ScriptFile f) =
  fail $ "Cannot get values from FutharkScript file: " <> f

-- | Upload a single value to the server under the given variable name.
valueAsVar :: (MonadError T.Text m, MonadIO m) => Server -> VarName -> V.Value -> m ()
valueAsVar server v val =
  cmdMaybe $ putValue server v val

-- Frees the expression on error.
scriptValueAsVars ::
  (MonadError T.Text m, MonadIO m) =>
  Server ->
  [(VarName, TypeName)] ->
  Script.ExpValue ->
  m ()
scriptValueAsVars server names_and_types val
  | vals <- V.unCompound val,
    length names_and_types == length vals,
    Just loads <- zipWithM f names_and_types vals =
      sequence_ loads
  where
    -- Each component must already have exactly the expected type;
    -- server variables are renamed, constants are uploaded.
    f (v, t0) (V.ValueAtom (Script.SValue t1 sval))
      | t0 == t1 =
          Just $ case sval of
            Script.VVar oldname ->
              cmdMaybe $ cmdRename server oldname v
            Script.VVal sval' ->
              valueAsVar server v sval'
    f _ _ = Nothing
scriptValueAsVars server names_and_types val = do
  cmdMaybe $ cmdFree server $ S.toList $ Script.serverVarsInValue val
  throwError $
    "Expected value of type: "
      <> prettyTextOneLine (V.mkCompound (map (V.ValueAtom . snd) names_and_types))
      <> "\nBut got value of type: "
      <> prettyTextOneLine (fmap Script.scriptValueType val)
      <> notes
  where
    notes = mconcat $ mapMaybe note names_and_types
    -- Anonymous tuple/record types cannot be built from FutharkScript,
    -- so point the user at the likely fix.
    note (_, t)
      | "(" `T.isPrefixOf` t =
          Just $
            "\nNote: expected type "
              <> prettyText t
              <> " is an opaque tuple that cannot be constructed\n"
              <> "in FutharkScript. Consider using type annotations to give it a proper name."
      | "{" `T.isPrefixOf` t =
          Just $
            "\nNote: expected type "
              <> prettyText t
              <> " is an opaque record that cannot be constructed\n"
              <> "in FutharkScript. Consider using type annotations to give it a proper name."
      | otherwise =
          Nothing

-- | Make the provided 'Values' available as server-side variables.
-- This may involve arbitrary server-side computation. Error
-- detection... dubious.  File paths are resolved relative to the
-- given directory.
valuesAsVars ::
  (MonadError T.Text m, MonadIO m) =>
  Server ->
  [(VarName, TypeName)] ->
  FutharkExe ->
  FilePath ->
  Values ->
  m ()
valuesAsVars server names_and_types _ dir (InFile file)
  | takeExtension file == ".gz" = do
      s <- liftIO $ readAndDecompress $ dir </> file
      case s of
        Left e ->
          throwError $ showText file <> ": " <> showText e
        Right s' ->
          -- The server restores from a file, so write the decompressed
          -- data to a temporary one first.
          cmdMaybe . withSystemTempFile "futhark-input" $ \tmpf tmpf_h -> do
            BS.hPutStr tmpf_h s'
            hClose tmpf_h
            cmdRestore server tmpf names_and_types
  | otherwise =
      cmdMaybe $ cmdRestore server (dir </> file) names_and_types
valuesAsVars server names_and_types futhark dir (GenValues gens) = do
  unless (length gens == length names_and_types) $
    throwError "Mismatch between number of expected and generated values."
  gen_fs <- mapM (getGenFile futhark dir) gens
  -- 'getGenFile' returns paths relative to 'dir'.
  forM_ (zip gen_fs names_and_types) $ \(file, (v, t)) ->
    cmdMaybe $ cmdRestore server (dir </> file) [(v, t)]
valuesAsVars server names_and_types _ _ (Values vs) = do
  let types = map snd names_and_types
      vs_types = map (V.valueTypeTextNoDims . V.valueType) vs
  unless (types == vs_types) . throwError . T.unlines $
    [ "Expected input of types: " <> T.unwords (map prettyTextOneLine types),
      "Provided input of types: " <> T.unwords (map prettyTextOneLine vs_types)
    ]
  cmdMaybe . withSystemTempFile "futhark-input" $ \tmpf tmpf_h -> do
    mapM_ (BS.hPutStr tmpf_h . Bin.encode) vs
    hClose tmpf_h
    cmdRestore server tmpf names_and_types
valuesAsVars server names_and_types _ dir (ScriptValues e) =
  Script.withScriptServer' server $ \server' -> do
    e_v <- Script.evalExp (Script.scriptBuiltin dir) server' e
    scriptValueAsVars server names_and_types e_v
valuesAsVars server names_and_types futhark dir (ScriptFile f) = do
  e <-
    either throwError pure . Script.parseExpFromText f
      =<< liftIO (T.readFile (dir </> f))
  valuesAsVars server names_and_types futhark dir (ScriptValues e)

-- | There is a risk of race conditions when multiple programs have
-- identical 'GenValues'. In such cases, multiple threads in 'futhark
-- test' might attempt to create the same file (or read from it, while
-- something else is constructing it). This leads to a mess. To
-- avoid this, we create a temporary file, and only when it is
-- complete do we move it into place. It would be better if we could
-- use file locking, but that does not work on some file systems. The
-- approach here seems robust enough for now, but certainly it could
-- be made even better. The race condition that remains should mostly
-- result in duplicate work, not crashes or data corruption.
--
-- The returned path is relative to the given directory.
getGenFile :: (MonadIO m) => FutharkExe -> FilePath -> GenValue -> m FilePath
getGenFile futhark dir gen = do
  liftIO $ createDirectoryIfMissing True $ dir </> "data"
  -- A file of the wrong size is treated as broken/truncated and regenerated.
  exists_and_proper_size <-
    liftIO $
      withFile (dir </> file) ReadMode (fmap (== genFileSize gen) . hFileSize)
        `catch` \ex ->
          if isDoesNotExistError ex
            then pure False
            else E.throw ex
  unless exists_and_proper_size $ liftIO $ do
    s <- genValues futhark [gen]
    withTempFile (dir </> "data") (genFileName gen) $ \tmpfile h -> do
      hClose h -- We will be writing and reading this ourselves.
      SBS.writeFile tmpfile s
      renameFile tmpfile $ dir </> file
  pure file
  where
    file = "data" </> genFileName gen

-- | Read the (possibly freshly generated) data file for a 'GenValue'.
getGenBS :: (MonadIO m) => FutharkExe -> FilePath -> GenValue -> m BS.ByteString
getGenBS futhark dir gen =
  liftIO . BS.readFile . (dir </>) =<< getGenFile futhark dir gen

-- | Run @futhark dataset -b@ to produce binary data for the given
-- generators, failing with the tool's stderr if it exits unsuccessfully.
genValues :: FutharkExe -> [GenValue] -> IO SBS.ByteString
genValues (FutharkExe futhark) gens = do
  (code, stdout, stderr) <-
    readProcessWithExitCode futhark ("dataset" : map T.unpack args) mempty
  case code of
    ExitSuccess ->
      pure stdout
    ExitFailure e ->
      fail $
        "'futhark dataset' failed with exit code "
          ++ show e
          ++ " and stderr:\n"
          ++ map (chr . fromIntegral) (SBS.unpack stderr)
  where
    args = "-b" : concatMap argForGen gens
    argForGen g = ["-g", genValueType g]

-- | File name (without directory) used to cache a generated value.
genFileName :: GenValue -> FilePath
genFileName gen = T.unpack (genValueType gen) <> ".in"

-- | Compute the expected size of the file. We use this to check
-- whether an existing file is broken/truncated.
genFileSize :: GenValue -> Integer
genFileSize = genSize
  where
    header_size = 1 + 1 + 1 + 4 -- 'b' <version> <num_dims> <type>
    -- Each dimension is stored as an 8-byte size, followed by the elements.
    genSize (GenValue (V.ValueType ds t)) =
      toInteger $
        header_size
          + length ds * 8
          + product ds * V.primTypeBytes t
    genSize (GenPrim v) =
      toInteger $
        header_size + product (V.valueShape v) * V.primTypeBytes (V.valueElemType v)

-- | When/if generating a reference output file for this run, what
-- should it be called? Includes the "data/" folder.
testRunReferenceOutput :: FilePath -> T.Text -> TestRun -> FilePath
testRunReferenceOutput prog entry tr =
  "data"
    </> takeBaseName prog
    <> ":"
    <> T.unpack entry
    <> "-"
    <> map clean (T.unpack (runDescription tr))
    <.> "out"
  where
    clean '/' = '_' -- Would this ever happen?
    clean ' ' = '_'
    clean c = c

-- | Get the values corresponding to an expected result, if any.
getExpectedResult ::
  (MonadFail m, MonadIO m) =>
  FutharkExe ->
  FilePath ->
  T.Text ->
  TestRun ->
  m (ExpectedResult [V.Value])
getExpectedResult futhark prog entry tr =
  case runExpectedResult tr of
    (Succeeds (Just (SuccessValues vals))) ->
      Succeeds . Just <$> getValues futhark (takeDirectory prog) vals
    Succeeds (Just SuccessGenerateValues) ->
      getExpectedResult futhark prog entry tr'
      where
        -- Read the reference output file produced by 'ensureReferenceOutput'.
        tr' =
          tr
            { runExpectedResult =
                Succeeds . Just . SuccessValues . InFile $
                  testRunReferenceOutput prog entry tr
            }
    Succeeds Nothing ->
      pure $ Succeeds Nothing
    RunTimeFailure err ->
      pure $ RunTimeFailure err

-- | The name we use for compiled programs.
binaryName :: FilePath -> FilePath
binaryName = dropExtension

-- | @compileProgram extra_options futhark backend program@ compiles
-- @program@ with the command
-- @futhark backend program -o binary extra-options...@, and
-- returns stdout and stderr of the compiler.  If compilation fails,
-- raises an error (via 'throwError') containing the compiler's
-- stderr, or a "command not found" message on exit code 127.
compileProgram ::
  (MonadIO m, MonadError T.Text m) =>
  [String] ->
  FutharkExe ->
  String ->
  FilePath ->
  m (SBS.ByteString, SBS.ByteString)
compileProgram extra_options (FutharkExe futhark) backend program = do
  (futcode, stdout, stderr) <- liftIO $ readProcessWithExitCode futhark (backend : options) ""
  case futcode of
    ExitFailure 127 -> throwError $ progNotFound $ T.pack futhark
    ExitFailure _ -> throwError $ T.decodeUtf8 stderr
    ExitSuccess -> pure ()
  pure (stdout, stderr)
  where
    binOutputf = binaryName program
    options = [program, "-o", binOutputf] ++ extra_options
    progNotFound s = s <> ": command not found"

-- | Read the given variables from a running server.
readResults ::
  (MonadIO m, MonadError T.Text m) =>
  Server ->
  [VarName] ->
  m [V.Value]
readResults server =
  mapM (either throwError pure <=< liftIO . getValue server)

-- | Call an entry point. Returns server variables storing the result.
-- The input variables are freed after the call.
callEntry ::
  (MonadIO m, MonadError T.Text m) =>
  FutharkExe ->
  Server ->
  FilePath ->
  EntryName ->
  Values ->
  m [VarName]
callEntry futhark server prog entry input = do
  output_types <- cmdEither $ cmdOutputs server entry
  input_types <- cmdEither $ cmdInputs server entry
  let outs = ["out" <> showText i | i <- [0 .. length output_types - 1]]
      ins = ["in" <> showText i | i <- [0 .. length input_types - 1]]
      ins_and_types = zip ins (map inputType input_types)
  valuesAsVars server ins_and_types futhark dir input
  _ <- cmdEither $ cmdCall server entry outs ins
  cmdMaybe $ cmdFree server ins
  pure outs
  where
    dir = takeDirectory prog

-- | Ensure that any reference output files exist, or create them (by
-- compiling the program with the reference compiler and running it on
-- the input) if necessary.  A reference file older than the program
-- is considered missing.
ensureReferenceOutput ::
  (MonadIO m, MonadError T.Text m) =>
  Maybe Int ->
  FutharkExe ->
  String ->
  FilePath ->
  [InputOutputs] ->
  m ()
ensureReferenceOutput concurrency futhark compiler prog ios = do
  missing <- filterM isReferenceMissing $ concatMap entryAndRuns ios

  unless (null missing) $ do
    void $ compileProgram ["--server"] futhark compiler prog

    res <- liftIO . flip (pmapIO concurrency) missing $ \(entry, tr) ->
      withServer server_cfg $ \server -> runExceptT $ do
        outs <- callEntry futhark server prog entry $ runInput tr
        let f = file entry tr
        liftIO $ createDirectoryIfMissing True $ takeDirectory f
        cmdMaybe $ cmdStore server f outs
        cmdMaybe $ cmdFree server outs

    either throwError (const (pure ())) (sequence_ res)
  where
    server_cfg = futharkServerCfg ("." </> dropExtension prog) []

    file entry tr =
      takeDirectory prog </> testRunReferenceOutput prog entry tr

    entryAndRuns (InputOutputs entry rts) = map (entry,) rts

    isReferenceMissing (entry, tr)
      | Succeeds (Just SuccessGenerateValues) <- runExpectedResult tr =
          liftIO $
            ((<) <$> getModificationTime (file entry tr) <*> getModificationTime prog)
              `catch` (\e -> if isDoesNotExistError e then pure True else E.throw e)
      | otherwise =
          pure False

-- | Determine the @--tuning@ options to pass to the program. The first
-- argument is the extension of the tuning file, or 'Nothing' if none
-- should be used.
-- The second component of the result is a short note, such as
-- " (using foo.tuning)" or " (no tuning file)", for status messages.
determineTuning :: (MonadIO m) => Maybe FilePath -> FilePath -> m ([String], String)
determineTuning Nothing _ = pure ([], mempty)
determineTuning (Just ext) program = do
  let tuning_file = program <.> ext
  found <- liftIO $ doesFileExist tuning_file
  pure $
    if found
      then (["--tuning", tuning_file], " (using " <> takeFileName tuning_file <> ")")
      else ([], " (no tuning file)")

-- | Determine the @--cache-file@ options to pass to the program. The
-- first argument is the extension of the cache file, or 'Nothing' if
-- none should be used.
determineCache :: Maybe FilePath -> FilePath -> [String]
determineCache ext program =
  maybe [] (\e -> ["--cache-file", program <.> e]) ext

-- | Check that the result is as expected, and write files and throw
-- an error if not.  On mismatch, both value lists are written in binary
-- form next to the program (with extensions @.actual@ and @.expected@).
checkResult ::
  (MonadError T.Text m, MonadIO m) =>
  FilePath ->
  [V.Value] ->
  [V.Value] ->
  m ()
checkResult program expected_vs actual_vs =
  case V.compareSeveralValues (V.Tolerance 0.002) actual_vs expected_vs of
    [] -> pure ()
    mismatch : mismatches -> do
      let actualf = program <.> "actual"
          expectedf = program <.> "expected"
          dump f = liftIO . BS.writeFile f . mconcat . map Bin.encode
          more
            | null mismatches = mempty
            | otherwise = "\n...and " <> prettyText (length mismatches) <> " other mismatches."
      dump actualf actual_vs
      dump expectedf expected_vs
      throwError $
        T.pack actualf
          <> " and "
          <> T.pack expectedf
          <> " do not match:\n"
          <> showText mismatch
          <> more

-- | Create a Futhark server configuration suitable for use when
-- testing/benchmarking Futhark programs.
-- Debug output is enabled when @FUTHARK_COMPILER_DEBUGGING@ is at least 1.
futharkServerCfg :: FilePath -> [String] -> ServerCfg
futharkServerCfg prog opts =
  let base_cfg = newServerCfg prog opts
      debugging = isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 1
   in base_cfg {cfgDebug = debugging}
futhark-0.25.27/src/Futhark/Test/000077500000000000000000000000001475065116200164545ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Test/Spec.hs000066400000000000000000000362421475065116200177110ustar00rootroot00000000000000-- | Definition and parsing of a test specification.
module Futhark.Test.Spec
  ( testSpecFromProgram,
    testSpecFromProgramOrDie,
    testSpecsFromPaths,
    testSpecsFromPathsOrDie,
    testSpecFromFile,
    testSpecFromFileOrDie,
    ProgramTest (..),
    StructureTest (..),
    StructurePipeline (..),
    WarningTest (..),
    TestAction (..),
    ExpectedError (..),
    InputOutputs (..),
    TestRun (..),
    ExpectedResult (..),
    Success (..),
    Values (..),
    GenValue (..),
    genValueType,
  )
where

import Control.Applicative
import Control.Exception (catch)
import Control.Monad
import Data.Char
import Data.Functor
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Text qualified as T
import Data.Text.IO qualified as T
import Data.Void
import Futhark.Analysis.Metrics.Type
import Futhark.Data.Parser
import Futhark.Data.Parser qualified as V
import Futhark.Script qualified as Script
import Futhark.Test.Values qualified as V
import Futhark.Util (directoryContents, nubOrd, showText)
import Futhark.Util.Pretty (prettyTextOneLine)
import System.Exit
import System.FilePath
import System.IO
import System.IO.Error
import Text.Megaparsec hiding (many, some)
import Text.Megaparsec.Char
import Text.Megaparsec.Char.Lexer (charLiteral)
import Text.Regex.TDFA
import Prelude

-- | Description of a test to be carried out on a Futhark program.
-- The Futhark program is stored separately.
data ProgramTest = ProgramTest
  { -- The free-form description preceding the @==@ separator.
    testDescription :: T.Text,
    testTags :: [T.Text],
    testAction :: TestAction
  }
  deriving (Show)

-- | How to test a program.
data TestAction
  = CompileTimeFailure ExpectedError
  | RunCases [InputOutputs] [StructureTest] [WarningTest]
  deriving (Show)

-- | Input and output pairs for some entry point(s).
data InputOutputs = InputOutputs
  { iosEntryPoint :: T.Text,
    iosTestRuns :: [TestRun]
  }
  deriving (Show)

-- | The error expected for a negative test: either any error at all,
-- or one whose message matches a regular expression.  The 'T.Text'
-- is the source text the 'Regex' was compiled from.
data ExpectedError = AnyError | ThisError T.Text Regex

-- Hand-written instance: only the regex source text is shown.
instance Show ExpectedError where
  show err = case err of
    AnyError -> "AnyError"
    ThisError pat _ -> "ThisError " ++ show pat

-- | How a program can be transformed.
data StructurePipeline
  = GpuPipeline
  | MCPipeline
  | SOACSPipeline
  | SeqMemPipeline
  | GpuMemPipeline
  | MCMemPipeline
  | NoPipeline
  deriving (Show)

-- | A structure test specifies a compilation pipeline, as well as
-- metrics for the program coming out the other end.
data StructureTest = StructureTest StructurePipeline AstMetrics
  deriving (Show)

-- | A warning test requires that a warning matching the regular
-- expression is produced. The program must also compile successfully.
data WarningTest = ExpectedWarning T.Text Regex

-- Hand-written instance: only the regex source text is shown.
instance Show WarningTest where
  show (ExpectedWarning pat _) = "ExpectedWarning " <> T.unpack pat

-- | A condition for execution, input, and expected result.
data TestRun = TestRun
  { runTags :: [T.Text],
    runInput :: Values,
    runExpectedResult :: ExpectedResult Success,
    runIndex :: Int,
    runDescription :: T.Text
  }
  deriving (Show)

-- | Several values - either literally, or by reference to a file, or
-- to be generated on demand. All paths are relative to test program.
data Values
  = Values [V.Value]
  | InFile FilePath
  | GenValues [GenValue]
  | ScriptValues Script.Exp
  | ScriptFile FilePath
  deriving (Show)

-- | How to generate a single random value.
data GenValue
  = -- | Generate a value of the given rank and primitive
    -- type. Scalars are considered 0-ary arrays.
    GenValue V.ValueType
  | -- | A fixed non-randomised primitive value.
    GenPrim V.Value
  deriving (Show)

-- | A prettyprinted representation of type of value produced by a
-- 'GenValue'.
genValueType :: GenValue -> T.Text genValueType (GenValue (V.ValueType ds t)) = foldMap (\d -> "[" <> showText d <> "]") ds <> V.primTypeText t genValueType (GenPrim v) = V.valueText v -- | How a test case is expected to terminate. data ExpectedResult values = -- | Execution suceeds, with or without -- expected result values. Succeeds (Maybe values) | -- | Execution fails with this error. RunTimeFailure ExpectedError deriving (Show) -- | The result expected from a succesful execution. data Success = -- | These values are expected. SuccessValues Values | -- | Compute expected values from executing a known-good -- reference implementation. SuccessGenerateValues deriving (Show) type Parser = Parsec Void T.Text lexeme :: Parser () -> Parser a -> Parser a lexeme sep p = p <* sep -- Like 'lexeme', but does not consume trailing linebreaks. lexeme' :: Parser a -> Parser a lexeme' p = p <* hspace -- Like 'lexstr', but does not consume trailing linebreaks. lexstr' :: T.Text -> Parser () lexstr' = void . try . lexeme' . string inBraces :: Parser () -> Parser a -> Parser a inBraces sep = between (lexeme sep "{") (lexeme sep "}") parseNatural :: Parser () -> Parser Int parseNatural sep = lexeme sep $ L.foldl' addDigit 0 . 
map num <$> some digitChar
  where
    addDigit acc x = acc * 10 + x
    num c = ord c - ord '0'

-- | Parse the rest of the current line (excluding the newline) and
-- then consume the line terminator.  A non-empty line may instead be
-- terminated by end of input, but an empty line must be followed by
-- an actual newline, so this parser never succeeds without consuming
-- input (which keeps it safe to use under 'many').
restOfLine :: Parser T.Text
restOfLine = do
  l <- restOfLine_
  if T.null l then void eol else void eol <|> eof
  pure l

-- | Everything up to, but not including, the next newline character.
restOfLine_ :: Parser T.Text
restOfLine_ = takeWhileP Nothing (/= '\n')

-- | Parse description lines up to the @==@ separator, returning them
-- joined with newlines.
parseDescription :: Parser () -> Parser T.Text
parseDescription sep =
  T.unlines <$> pDescLine `manyTill` pDescriptionSeparator
  where
    pDescLine = restOfLine <* sep
    pDescriptionSeparator = void $ "==" *> sep

-- | A tag name: a non-empty run of 'tagConstituent' characters.
lTagName :: Parser () -> Parser T.Text
lTagName sep =
  lexeme sep $ takeWhile1P (Just "tag-constituent character") tagConstituent

-- | Parse an optional @tags { ... }@ clause; yields no tags if absent.
parseTags :: Parser () -> Parser [T.Text]
parseTags sep =
  choice
    [ lexeme' "tags" *> inBraces sep (many (lTagName sep)),
      pure []
    ]

-- | Characters that may appear in a tag name.
tagConstituent :: Char -> Bool
tagConstituent c = isAlphaNum c || c == '_' || c == '-'

-- | Parse the action of a test: either an expected compile-time error,
-- or run cases followed by structure tests and warning tests.
parseAction :: Parser () -> Parser TestAction
parseAction sep =
  choice
    [ CompileTimeFailure <$> (lexstr' "error:" *> parseExpectedError sep),
      RunCases
        <$> parseInputOutputs sep
        <*> many (parseExpectedStructure sep)
        <*> many (parseWarning sep)
    ]

-- | Parse the entry points and the run cases.  Every entry point is
-- paired with the same run cases; no run cases means no entries.
parseInputOutputs :: Parser () -> Parser [InputOutputs]
parseInputOutputs sep = do
  entrys <- parseEntryPoints sep
  cases <- parseRunCases sep
  pure $
    if null cases
      then []
      else map (`InputOutputs` cases) entrys

-- | Parse @entry: name ...@, defaulting to @["main"]@ when absent.
parseEntryPoints :: Parser () -> Parser [T.Text]
parseEntryPoints sep =
  (lexeme' "entry:" *> many entry <* sep) <|> pure ["main"]
  where
    constituent c = not (isSpace c) && c /= '}'
    entry = lexeme' $ takeWhile1P Nothing constituent

-- | Tags written before a run case.  The keywords @input@,
-- @structure@ and @warning@ are never accepted as tags.
parseRunTags :: Parser () -> Parser [T.Text]
parseRunTags sep = many . try . lexeme' $ do
  s <- lTagName sep
  guard $ s `notElem` ["input", "structure", "warning"]
  pure s

-- | A double-quoted string literal, with characters read by
-- 'charLiteral'.
parseStringLiteral :: Parser () -> Parser T.Text
parseStringLiteral sep =
  lexeme sep . fmap T.pack $ char '"' >> manyTill charLiteral (char '"')

-- | Parse a sequence of run cases.  Cases are numbered from 0, and the
-- number is used to build a default description for unnamed cases.
parseRunCases :: Parser () -> Parser [TestRun]
parseRunCases sep = parseRunCases' (0 :: Int)
  where
    parseRunCases' i =
      (:) <$> parseRunCase i <*> parseRunCases' (i + 1)
        <|> pure []

    parseRunCase i = do
      name <- optional $ parseStringLiteral sep
      tags <- parseRunTags sep
      void $ lexeme sep "input"
      input <-
        if "random" `elem` tags
          then parseRandomValues sep
          else
            if "script" `elem` tags
              then parseScriptValues sep
              else parseValues sep
      expr <- parseExpectedResult sep
      pure $ TestRun tags input expr i $ fromMaybe (desc i input) name

    -- If the file is gzipped, we strip the 'gz' extension from
    -- the dataset name.  This makes it more convenient to rename
    -- from 'foo.in' to 'foo.in.gz', as the reported dataset name
    -- does not change (which would make comparisons to historical
    -- data harder).
    desc _ (InFile path)
      | takeExtension path == ".gz" = T.pack $ dropExtension path
      | otherwise = T.pack path
    desc i (Values vs) =
      -- Turn linebreaks into space.
      "#" <> showText i <> " (\"" <> T.unwords (T.lines vs') <> "\")"
      where
        -- Long value listings are truncated to 50 characters.
        vs' = case T.unwords $ map V.valueText vs of
          s
            | T.length s > 50 -> T.take 50 s <> "..."
            | otherwise -> s
    desc _ (GenValues gens) =
      T.unwords $ map genValueType gens
    desc _ (ScriptValues e) =
      prettyTextOneLine e
    desc _ (ScriptFile path) =
      T.pack path

-- | The expected result of a run case: generated output, explicit
-- output values, a run-time error, or (if none is given) success
-- without checking the output.
parseExpectedResult :: Parser () -> Parser (ExpectedResult Success)
parseExpectedResult sep =
  choice
    [ lexeme sep "auto" *> lexeme sep "output" $> Succeeds (Just SuccessGenerateValues),
      Succeeds . Just . SuccessValues <$> (lexeme sep "output" *> parseValues sep),
      RunTimeFailure <$> (lexeme sep "error:" *> parseExpectedError sep),
      pure (Succeeds Nothing)
    ]

-- | The rest of the line is a regular expression for the expected
-- error; an empty line means any error is accepted.
parseExpectedError :: Parser () -> Parser ExpectedError
parseExpectedError sep = lexeme sep $ do
  s <- T.strip <$> restOfLine_ <* sep
  if T.null s
    then pure AnyError
    else
      -- blankCompOpt creates a regular expression that treats
      -- newlines like ordinary characters, which is what we want.
      ThisError s <$> makeRegexOptsM blankCompOpt defaultExecOpt (T.unpack s)

-- | Script input: either an inline expression in braces or @\@ file@.
parseScriptValues :: Parser () -> Parser Values
parseScriptValues sep =
  choice
    [ ScriptValues <$> inBraces sep (Script.parseExp sep),
      ScriptFile . T.unpack <$> (lexeme sep "@" *> lexeme sep nextWord)
    ]
  where
    nextWord = takeWhileP Nothing $ not . isSpace

-- | Random input: a braced list of value generators.
parseRandomValues :: Parser () -> Parser Values
parseRandomValues sep = GenValues <$> inBraces sep (many (parseGenValue sep))

-- | A single generator: either a type to generate or a primitive value.
parseGenValue :: Parser () -> Parser GenValue
parseGenValue sep =
  choice
    [ GenValue <$> lexeme sep parseType,
      GenPrim <$> lexeme sep V.parsePrimValue
    ]

-- | Literal input: either inline values in braces or @\@ file@.
parseValues :: Parser () -> Parser Values
parseValues sep =
  choice
    [ Values <$> inBraces sep (many $ parseValue sep),
      InFile . T.unpack <$> (lexeme sep "@" *> lexeme sep nextWord)
    ]
  where
    nextWord = takeWhileP Nothing $ not . isSpace

-- | @warning:@ followed by a regular expression on the rest of the line.
parseWarning :: Parser () -> Parser WarningTest
parseWarning sep = lexeme sep "warning:" >> parseExpectedWarning
  where
    parseExpectedWarning = lexeme sep $ do
      s <- T.strip <$> restOfLine_
      ExpectedWarning s <$> makeRegexOptsM blankCompOpt defaultExecOpt (T.unpack s)

-- | @structure [pipeline] { metric count ... }@.
parseExpectedStructure :: Parser () -> Parser StructureTest
parseExpectedStructure sep =
  lexeme sep "structure" *> (StructureTest <$> optimisePipeline sep <*> parseMetrics sep)

-- | The pipeline a structure test applies to.  Longer keywords are
-- tried before their prefixes (e.g. @gpu-mem@ before @gpu@); the
-- default is the SOACS pipeline.
optimisePipeline :: Parser () -> Parser StructurePipeline
optimisePipeline sep =
  choice
    [ lexeme sep "gpu-mem" $> GpuMemPipeline,
      lexeme sep "gpu" $> GpuPipeline,
      lexeme sep "mc-mem" $> MCMemPipeline,
      lexeme sep "mc" $> MCPipeline,
      lexeme sep "seq-mem" $> SeqMemPipeline,
      lexeme sep "internalised" $> NoPipeline,
      pure SOACSPipeline
    ]

-- | A braced list of @name count@ pairs, where names consist of
-- letters and slashes.
parseMetrics :: Parser () -> Parser AstMetrics
parseMetrics sep =
  inBraces sep . fmap (AstMetrics . M.fromList) . many $
    (,) <$> lexeme sep (takeWhile1P Nothing constituent) <*> parseNatural sep
  where
    constituent c = isAlpha c || c == '/'

-- | A complete test block: description, tags and action.
testSpec :: Parser () -> Parser ProgramTest
testSpec sep =
  ProgramTest <$> parseDescription sep <*> parseTags sep <*> parseAction sep

-- | Turn an IO exception into a 'Left' carrying its message.
couldNotRead :: IOError -> IO (Either String a)
couldNotRead = pure . Left . show

-- | Parse the test blocks embedded as @--@ comments in a Futhark
-- program.  A program without a test block yields an empty test.
pProgramTest :: Parser ProgramTest
pProgramTest = do
  void $ many pNonTestLine
  maybe_spec <-
    optional ("--" *> sep *> testSpec sep) <* pEndOfTestBlock <* many pNonTestLine
  case maybe_spec of
    Just spec
      | RunCases old_cases structures warnings <- testAction spec -> do
          -- Further test blocks only contribute more run cases.
          cases <- many $ pInputOutputs <* many pNonTestLine
          pure spec {testAction = RunCases (old_cases ++ concat cases) structures warnings}
      | otherwise ->
          -- A compile-error test must be the only test block.  '<?>'
          -- attaches the explanation as the parser's error label.
          many pNonTestLine
            *> notFollowedBy "-- =="
            *> pure spec
            <?> "no more test blocks, since first test block specifies type error."
    Nothing ->
      eof $> noTest
  where
    -- Horizontal space, optionally continuing onto the next line if
    -- that line is also a @--@ comment.
    sep = void $ hspace *> optional (try $ eol *> "--" *> sep)

    noTest =
      ProgramTest mempty mempty (RunCases mempty mempty mempty)

    pEndOfTestBlock =
      (void eol <|> eof) *> notFollowedBy "--"
    pNonTestLine =
      void $ notFollowedBy "-- ==" *> restOfLine
    pInputOutputs =
      "--" *> sep *> parseDescription sep *> parseInputOutputs sep <* pEndOfTestBlock

-- | Reject test specifications where the datasets of one entry point
-- do not have distinct descriptions.
validate :: FilePath -> ProgramTest -> Either String ProgramTest
validate path pt =
  case testAction pt of
    CompileTimeFailure {} -> pure pt
    RunCases ios _ _ -> do
      mapM_ (noDups . map runDescription . iosTestRuns) ios
      Right pt
  where
    noDups xs =
      let xs' = nubOrd xs
       in -- Works because \\ only removes first instance, so what is
          -- left over are exactly the repeated occurrences.
          case xs L.\\ xs' of
            [] -> Right ()
            x : _ -> Left $ path <> ": multiple datasets with name " <> show (T.unpack x)

-- | Read the test specification from the given Futhark program.
testSpecFromProgram :: FilePath -> IO (Either String ProgramTest)
testSpecFromProgram path =
  ( either (Left . errorBundlePretty) (validate path) . parse pProgramTest path
      <$> T.readFile path
  )
    `catch` couldNotRead

-- | Like 'testSpecFromProgram', but exits the process on error.
testSpecFromProgramOrDie :: FilePath -> IO ProgramTest
testSpecFromProgramOrDie prog = do
  spec_or_err <- testSpecFromProgram prog
  case spec_or_err of
    Left err -> do
      hPutStrLn stderr err
      exitFailure
    Right spec -> pure spec

-- | The @.fut@ files found by 'directoryContents' under the given path.
testPrograms :: FilePath -> IO [FilePath]
testPrograms dir = filter isFut <$> directoryContents dir
  where
    isFut = (== ".fut") . takeExtension

-- | Read test specifications from the given path, which can be a file
-- or directory containing @.fut@ files and further directories.
testSpecsFromPath :: FilePath -> IO (Either String [(FilePath, ProgramTest)])
testSpecsFromPath path = do
  programs_or_err <- (Right <$> testPrograms path) `catch` couldNotRead
  case programs_or_err of
    Left err -> pure $ Left err
    Right programs -> do
      specs_or_errs <- mapM testSpecFromProgram programs
      pure $ zip programs <$> sequence specs_or_errs

-- | Read test specifications from the given paths, which can be
-- files or directories containing @.fut@ files and further
-- directories.
testSpecsFromPaths ::
  [FilePath] ->
  IO (Either String [(FilePath, ProgramTest)])
testSpecsFromPaths = fmap (fmap concat . sequence) . mapM testSpecsFromPath

-- | Like 'testSpecsFromPaths', but kills the process on errors.
testSpecsFromPathsOrDie ::
  [FilePath] ->
  IO [(FilePath, ProgramTest)]
testSpecsFromPathsOrDie dirs = do
  specs_or_err <- testSpecsFromPaths dirs
  case specs_or_err of
    Left err -> do
      hPutStrLn stderr err
      exitFailure
    Right specs -> pure specs

-- | Read a test specification from a file.  Expects only a single
-- block, and no comment prefixes.
testSpecFromFile :: FilePath -> IO (Either String ProgramTest)
testSpecFromFile path =
  ( either (Left . errorBundlePretty) Right . parse (testSpec space) path
      <$> T.readFile path
  )
    `catch` couldNotRead

-- | Like 'testSpecFromFile', but kills the process on errors.
testSpecFromFileOrDie :: FilePath -> IO ProgramTest testSpecFromFileOrDie dirs = do spec_or_err <- testSpecFromFile dirs case spec_or_err of Left err -> do hPutStrLn stderr err exitFailure Right spec -> pure spec futhark-0.25.27/src/Futhark/Test/Values.hs000066400000000000000000000042101475065116200202440ustar00rootroot00000000000000{-# LANGUAGE Strict #-} {-# OPTIONS_GHC -fno-warn-orphans #-} -- | This module provides an efficient value representation as well as -- parsing and comparison functions. module Futhark.Test.Values ( module Futhark.Data, module Futhark.Data.Compare, module Futhark.Data.Reader, Compound (..), CompoundValue, mkCompound, unCompound, ) where import Data.Map qualified as M import Data.Text qualified as T import Data.Traversable import Futhark.Data import Futhark.Data.Compare import Futhark.Data.Reader import Futhark.Util.Pretty instance Pretty Value where pretty = pretty . valueText instance Pretty ValueType where pretty = pretty . valueTypeText -- | The structure of a compound value, parameterised over the actual -- values. For most cases you probably want 'CompoundValue'. data Compound v = ValueRecord (M.Map T.Text (Compound v)) | -- | Must not be single value. 
ValueTuple [Compound v] | ValueAtom v deriving (Eq, Ord, Show) instance Functor Compound where fmap = fmapDefault instance Foldable Compound where foldMap = foldMapDefault instance Traversable Compound where traverse f (ValueAtom v) = ValueAtom <$> f v traverse f (ValueTuple vs) = ValueTuple <$> traverse (traverse f) vs traverse f (ValueRecord m) = ValueRecord <$> traverse (traverse f) m instance (Pretty v) => Pretty (Compound v) where pretty (ValueAtom v) = pretty v pretty (ValueTuple vs) = parens $ commasep $ map pretty vs pretty (ValueRecord m) = braces $ commasep $ map field $ M.toList m where field (k, v) = pretty k <> equals <> pretty v -- | Create a tuple for a non-unit list, and otherwise a 'ValueAtom' mkCompound :: [Compound v] -> Compound v mkCompound [v] = v mkCompound vs = ValueTuple vs -- | If the value is a tuple, extract the components, otherwise return -- a singleton list of the value. unCompound :: Compound v -> [Compound v] unCompound (ValueTuple vs) = vs unCompound v = [v] -- | Like a 'Value', but also grouped in compound ways that are not -- supported by raw values. You cannot parse or read these in -- standard ways, and they cannot be elements of arrays. type CompoundValue = Compound Value futhark-0.25.27/src/Futhark/Tools.hs000066400000000000000000000135101475065116200171710ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | An unstructured grab-bag of various tools and inspection -- functions that didn't really fit anywhere else. 
module Futhark.Tools
  ( module Futhark.Construct,
    redomapToMapAndReduce,
    scanomapToMapAndScan,
    dissectScrema,
    sequentialStreamWholeArray,
    partitionChunkedFoldParameters,

    -- * Primitive expressions
    module Futhark.Analysis.PrimExp.Convert,
  )
where

import Control.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Construct
import Futhark.IR
import Futhark.IR.SOACS.SOAC
import Futhark.Util

-- | Turns a binding of a @redomap@ into two separate bindings, a
-- @map@ binding and a @reduce@ binding (returned in that order).
--
-- Reuses the reduction part of the original pattern for the @reduce@,
-- and creates a new pattern with new 'Ident's for the result of the
-- @map@.
redomapToMapAndReduce ::
  ( MonadFreshNames m,
    Buildable rep,
    ExpDec rep ~ (),
    Op rep ~ SOAC rep
  ) =>
  Pat (LetDec rep) ->
  ( SubExp,
    [Reduce rep],
    Lambda rep,
    [VName]
  ) ->
  m (Stm rep, Stm rep)
redomapToMapAndReduce (Pat pes) (w, reds, map_lam, arrs) = do
  (map_pat, red_pat, red_arrs) <-
    splitScanOrRedomap pes w map_lam $ map redNeutral reds
  let map_stm = mkLet map_pat $ Op $ Screma w arrs (mapSOAC map_lam)
  red_stm <-
    Let red_pat (defAux ()) . Op
      <$> (Screma w red_arrs <$> reduceSOAC reds)
  pure (map_stm, red_stm)

-- | Like 'redomapToMapAndReduce', but for a @scanomap@: produces a
-- @map@ binding and a @scan@ binding (returned in that order).
--
-- Reuses the scan part of the original pattern for the @scan@, and
-- creates a new pattern with new 'Ident's for the result of the @map@.
scanomapToMapAndScan ::
  ( MonadFreshNames m,
    Buildable rep,
    ExpDec rep ~ (),
    Op rep ~ SOAC rep
  ) =>
  Pat (LetDec rep) ->
  ( SubExp,
    [Scan rep],
    Lambda rep,
    [VName]
  ) ->
  m (Stm rep, Stm rep)
scanomapToMapAndScan (Pat pes) (w, scans, map_lam, arrs) = do
  (map_pat, scan_pat, scan_arrs) <-
    splitScanOrRedomap pes w map_lam $ map scanNeutral scans
  let map_stm = mkLet map_pat $ Op $ Screma w arrs (mapSOAC map_lam)
  scan_stm <-
    Let scan_pat (defAux ()) . Op
      <$> (Screma w scan_arrs <$> scanSOAC scans)
  pure (map_stm, scan_stm)

-- | Split the pattern of a scanomap/redomap into
--
-- * the pattern of the separated @map@: fresh identifiers (suffixed
--   @_map_acc@, typed as arrays of outer size @w@ of the map lambda's
--   accumulator results), followed by the original map-out elements;
--
-- * the accumulator part of the original pattern, for the scan/reduce;
--
-- * the names of the fresh identifiers, which become the scan/reduce
--   inputs.
--
-- The number of accumulator elements is the total number of neutral
-- elements.
splitScanOrRedomap ::
  (Typed dec, MonadFreshNames m) =>
  [PatElem dec] ->
  SubExp ->
  Lambda rep ->
  [[SubExp]] ->
  m ([Ident], Pat dec, [VName])
splitScanOrRedomap pes w map_lam nes = do
  let num_accs = length $ concat nes
      (acc_pes, arr_pes) = splitAt num_accs pes
      acc_ts = take num_accs $ lambdaReturnType map_lam
  map_accpat <- zipWithM accMapPatElem acc_pes acc_ts
  map_arrpat <- mapM arrMapPatElem arr_pes
  let map_pat = map_accpat ++ map_arrpat
  pure (map_pat, Pat acc_pes, map identName map_accpat)
  where
    accMapPatElem pe acc_t =
      newIdent (baseString (patElemName pe) ++ "_map_acc") $ acc_t `arrayOfRow` w
    arrMapPatElem = pure . patElemIdent

-- | Turn a Screma into a Scanomap (possibly with mapout parts) and a
-- Redomap (here a plain reduction over the intermediate results of the
-- Scanomap).  This is used to handle Scremas that are so complicated
-- that we cannot directly generate efficient parallel code for them.
-- In essence, what happens is the opposite of horizontal fusion.
dissectScrema ::
  ( MonadBuilder m,
    Op (Rep m) ~ SOAC (Rep m),
    Buildable (Rep m)
  ) =>
  Pat (LetDec (Rep m)) ->
  SubExp ->
  ScremaForm (Rep m) ->
  [VName] ->
  m ()
dissectScrema pat w (ScremaForm map_lam scans reds) arrs = do
  let num_reds = redResults reds
      num_scans = scanResults scans
      (scan_res, red_res, map_res) =
        splitAt3 num_scans num_reds $ patNames pat

  -- Fresh names for the per-element values that are reduced afterwards.
  to_red <- replicateM num_reds $ newVName "to_red"

  let scanomap = scanomapSOAC scans map_lam
  letBindNames (scan_res <> to_red <> map_res) $
    Op (Screma w arrs scanomap)

  reduce <- reduceSOAC reds
  letBindNames red_res $ Op $ Screma w to_red reduce

-- | Turn a stream SOAC into statements that apply the stream lambda
-- to the entire input.
sequentialStreamWholeArray ::
  (MonadBuilder m, Buildable (Rep m)) =>
  Pat (LetDec (Rep m)) ->
  SubExp ->
  [SubExp] ->
  Lambda (Rep m) ->
  [VName] ->
  m ()
sequentialStreamWholeArray pat w nes lam arrs = do
  -- We just set the chunksize to w and inline the lambda body.  There
  -- is no difference between parallel and sequential streams here.
  let (chunk_size_param, fold_params, arr_params) =
        partitionChunkedFoldParameters (length nes) $ lambdaParams lam

  -- The chunk size is the full size of the array.
  letBindNames [paramName chunk_size_param] $ BasicOp $ SubExp w

  -- The accumulator parameters are initialised to the neutral element.
  forM_ (zip fold_params nes) $ \(p, ne) ->
    letBindNames [paramName p] $ BasicOp $ SubExp ne

  -- Finally, the array parameters are set to the arrays (but reshaped
  -- to make the types work out; this will be simplified rapidly).
  forM_ (zip arr_params arrs) $ \(p, arr) ->
    letBindNames [paramName p] . BasicOp $
      Reshape ReshapeCoerce (arrayShape $ paramType p) arr

  -- Then we just inline the lambda body.
  mapM_ addStm $ bodyStms $ lambdaBody lam

  -- The number of results in the body matches exactly the size (and
  -- order) of 'pat', so we bind them up here, again with a reshape to
  -- make the types work out.
  forM_ (zip (patElems pat) $ bodyResult $ lambdaBody lam) $ \(pe, SubExpRes cs se) ->
    certifying cs $ case (arrayDims $ patElemType pe, se) of
      (dims, Var v)
        | not $ null dims ->
            letBindNames [patElemName pe] $ BasicOp $ Reshape ReshapeCoerce (Shape dims) v
      _ -> letBindNames [patElemName pe] $ BasicOp $ SubExp se

-- | Split the parameters of a stream reduction lambda into the chunk
-- size parameter, the accumulator parameters, and the input chunk
-- parameters.  The integer argument is how many accumulators are
-- used.
partitionChunkedFoldParameters :: Int -> [Param dec] -> (Param dec, [Param dec], [Param dec]) partitionChunkedFoldParameters _ [] = error "partitionChunkedFoldParameters: lambda takes no parameters" partitionChunkedFoldParameters num_accs (chunk_param : params) = let (acc_params, arr_params) = splitAt num_accs params in (chunk_param, acc_params, arr_params) futhark-0.25.27/src/Futhark/Transform/000077500000000000000000000000001475065116200175105ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Transform/CopyPropagate.hs000066400000000000000000000022461475065116200226250ustar00rootroot00000000000000-- | Perform copy propagation. This is done by invoking the -- simplifier with no rules, so hoisting and dead-code elimination may -- also take place. module Futhark.Transform.CopyPropagate ( copyPropagateInProg, copyPropagateInStms, copyPropagateInFun, ) where import Futhark.Analysis.SymbolTable qualified as ST import Futhark.IR import Futhark.MonadFreshNames import Futhark.Optimise.Simplify import Futhark.Optimise.Simplify.Rep (Wise) import Futhark.Pass -- | Run copy propagation on an entire program. copyPropagateInProg :: (SimplifiableRep rep) => SimpleOps rep -> Prog rep -> PassM (Prog rep) copyPropagateInProg simpl = simplifyProg simpl mempty neverHoist -- | Run copy propagation on some statements. copyPropagateInStms :: (MonadFreshNames m, SimplifiableRep rep) => SimpleOps rep -> Scope rep -> Stms rep -> m (Stms rep) copyPropagateInStms simpl = simplifyStms simpl mempty neverHoist -- | Run copy propagation on a function. 
copyPropagateInFun :: (MonadFreshNames m, SimplifiableRep rep) => SimpleOps rep -> ST.SymbolTable (Wise rep) -> FunDef rep -> m (FunDef rep) copyPropagateInFun simpl = simplifyFun simpl mempty neverHoist futhark-0.25.27/src/Futhark/Transform/FirstOrderTransform.hs000066400000000000000000000414711475065116200240320ustar00rootroot00000000000000{-# LANGUAGE TypeFamilies #-} -- | The code generator cannot handle the array combinators (@map@ and -- friends), so this module was written to transform them into the -- equivalent do-loops. The transformation is currently rather naive, -- and - it's certainly worth considering when we can express such -- transformations in-place. module Futhark.Transform.FirstOrderTransform ( transformFunDef, transformConsts, FirstOrderRep, Transformer, transformStmRecursively, transformLambda, transformSOAC, ) where import Control.Monad import Control.Monad.State import Data.List (find, zip4) import Data.Map.Strict qualified as M import Futhark.Analysis.Alias qualified as Alias import Futhark.IR qualified as AST import Futhark.IR.Prop.Aliases import Futhark.IR.SOACS import Futhark.MonadFreshNames import Futhark.Tools import Futhark.Util (chunks, splitAt3) -- | The constraints that must hold for a rep in order to be the -- target of first-order transformation. type FirstOrderRep rep = ( Buildable rep, BuilderOps rep, LetDec SOACS ~ LetDec rep, LParamInfo SOACS ~ LParamInfo rep, Alias.AliasableRep rep ) -- | First-order-transform a single function, with the given scope -- provided by top-level constants. transformFunDef :: (MonadFreshNames m, FirstOrderRep torep) => Scope torep -> FunDef SOACS -> m (AST.FunDef torep) transformFunDef consts_scope (FunDef entry attrs fname rettype params body) = do (body', _) <- modifyNameSource $ runState $ runBuilderT m consts_scope pure $ FunDef entry attrs fname rettype params body' where m = localScope (scopeOfFParams params) $ transformBody body -- | First-order-transform these top-level constants. 
transformConsts ::
  (MonadFreshNames m, FirstOrderRep torep) =>
  Stms SOACS ->
  m (AST.Stms torep)
transformConsts stms =
  snd <$> modifyNameSource (runState (runBuilderT transformAll mempty))
  where
    -- Transform every statement; the builder collects what is emitted.
    transformAll = mapM_ transformStmRecursively stms

-- | Requirements on a monad for it to drive the first-order
-- transformation.
type Transformer m =
  ( MonadBuilder m,
    LocalScope (Rep m) m,
    Buildable (Rep m),
    BuilderOps (Rep m),
    LParamInfo SOACS ~ LParamInfo (Rep m),
    Alias.AliasableRep (Rep m)
  )

-- Transform each statement of the body in turn, keeping the original
-- result.
transformBody ::
  (Transformer m, LetDec (Rep m) ~ LetDec SOACS) =>
  Body SOACS ->
  m (AST.Body (Rep m))
transformBody (Body () stms res) =
  buildBody_ $ res <$ mapM_ transformStmRecursively stms

-- | Transform the nested t'Body' and t'Lambda' parts first; when the
-- statement is a SOAC, the result is then handed to 'transformSOAC'.
transformStmRecursively ::
  (Transformer m, LetDec (Rep m) ~ LetDec SOACS) =>
  Stm SOACS ->
  m ()
transformStmRecursively (Let pat aux (Op soac)) = auxing aux $ do
  soac' <- mapSOACM (identitySOACMapper {mapOnSOACLambda = transformLambda}) soac
  transformSOAC pat soac'
transformStmRecursively (Let pat aux e) = auxing aux $ do
  e' <- mapExpM innerMapper e
  letBind pat e'
  where
    innerMapper =
      identityMapper
        { mapOnBody = \scope -> localScope scope . transformBody,
          mapOnRetType = pure,
          mapOnBranchType = pure,
          mapOnFParam = pure,
          mapOnLParam = pure,
          mapOnOp = error "Unhandled Op in first order transform"
        }

-- Destinations for the Map and Scan outputs of a Screma.  An
-- accumulator-typed output reuses the first input array of exactly the
-- same type; every other output gets a fresh blank array.
resultArray :: (Transformer m) => [VName] -> [Type] -> m [VName]
resultArray arrs ts = do
  arr_types <- mapM lookupType arrs
  let candidates = zip arr_types arrs
      pick t = case t of
        Acc {}
          | Just v <- lookup t candidates -> pure v
        _ -> letExp "result" =<< eBlank t
  mapM pick ts

-- | Transform a single 'SOAC' into a do-loop.  The body of the lambda
-- is untouched, and may or may not contain further 'SOAC's depending
-- on the given rep.
transformSOAC :: (Transformer m) => Pat (LetDec (Rep m)) -> SOAC (Rep m) -> m ()
-- JVP and VJP are not handled by this transformation; reaching one is
-- reported with 'error'.
transformSOAC _ JVP {} =
  error "transformSOAC: unhandled JVP"
transformSOAC _ VJP {} =
  error "transformSOAC: unhandled VJP"
transformSOAC pat (Screma w arrs form@(ScremaForm map_lam scans reds)) = do
  -- See Note [Translation of Screma].
  --
  -- Start by combining all the reduction and scan parts into a single
  -- operator
  let Reduce _ red_lam red_nes = singleReduce reds
      Scan scan_lam scan_nes = singleScan scans
      (scan_arr_ts, _red_ts, map_arr_ts) =
        splitAt3 (length scan_nes) (length red_nes) $ scremaType w form

  -- Initial destinations.  Only map outputs may reuse an input array
  -- (for accumulator types), so scan outputs get no candidate inputs.
  scan_arrs <- resultArray [] scan_arr_ts
  map_arrs <- resultArray arrs map_arr_ts

  -- Loop parameters, one group per category (0)-(3) in the Note.
  scanacc_params <- mapM (newParam "scanacc" . flip toDecl Nonunique) $ lambdaReturnType scan_lam
  scanout_params <- mapM (newParam "scanout" . flip toDecl Unique) scan_arr_ts
  redout_params <- mapM (newParam "redout" . flip toDecl Nonunique) $ lambdaReturnType red_lam
  mapout_params <- mapM (newParam "mapout" . flip toDecl Unique) map_arr_ts

  arr_ts <- mapM lookupType arrs

  -- For an accumulator-typed input, find the map-output loop parameter
  -- whose accumulator has the same name (first field of 'Acc').
  let paramForAcc (Acc c _ _ _) = find (f . paramType) mapout_params
        where
          f (Acc c2 _ _ _) = c == c2
          f _ = False
      paramForAcc _ = Nothing

  let merge =
        concat
          [ zip scanacc_params scan_nes,
            zip scanout_params $ map Var scan_arrs,
            zip redout_params red_nes,
            zip mapout_params $ map Var map_arrs
          ]
  i <- newVName "i"
  let loopform = ForLoop i Int64 w
      -- Names consumed by the map lambda; the corresponding parameters
      -- are bound to a 'Replicate mempty' of the indexed element below
      -- (presumably a copy, so the consumption does not touch the input
      -- array).
      lam_cons = consumedByLambda $ Alias.analyseLambda mempty map_lam

  loop_body <- runBodyBuilder
    . localScope (scopeOfFParams (map fst merge) <> scopeOfLoopForm loopform)
    $ do
      -- Bind the parameters to the lambda.
      forM_ (zip3 (lambdaParams map_lam) arrs arr_ts) $ \(p, arr, arr_t) ->
        case paramForAcc arr_t of
          -- Accumulator input: use the loop-carried accumulator.
          Just acc_out_p ->
            letBindNames [paramName p] . BasicOp $
              SubExp $ Var $ paramName acc_out_p
          Nothing
            | paramName p `nameIn` lam_cons -> do
                p' <-
                  letExp (baseString (paramName p)) . BasicOp $
                    Index arr $ fullSlice arr_t [DimFix $ Var i]
                letBindNames [paramName p] $
                  BasicOp $ Replicate mempty $ Var p'
            | otherwise ->
                letBindNames [paramName p] . BasicOp . Index arr $
                  fullSlice arr_t [DimFix $ Var i]

      -- Insert the statements of the lambda.  We have taken care to
      -- ensure that the parameters are bound at this point.
      mapM_ addStm $ bodyStms $ lambdaBody map_lam

      -- Split into scan results, reduce results, and map results.
      let (scan_res, red_res, map_res) =
            splitAt3 (length scan_nes) (length red_nes) $
              bodyResult $ lambdaBody map_lam

      -- Combine the loop-carried scan accumulator with this
      -- iteration's scan values.
      scan_res' <-
        eLambda scan_lam $
          map (pure . BasicOp . SubExp) $
            map (Var . paramName) scanacc_params ++ map resSubExp scan_res

      -- Likewise for the reduction.
      red_res' <-
        eLambda red_lam $
          map (pure . BasicOp . SubExp) $
            map (Var . paramName) redout_params ++ map resSubExp red_res

      -- Write the scan accumulator to the scan result arrays.
      scan_outarrs <-
        certifying (foldMap resCerts scan_res) $
          letwith (map paramName scanout_params) (Var i) $
            map resSubExp scan_res'

      -- Write the map results to the map result arrays.
      map_outarrs <-
        certifying (foldMap resCerts map_res) $
          letwith (map paramName mapout_params) (Var i) $
            map resSubExp map_res

      -- Loop results, in the same order as 'merge'.
      pure . concat $
        [ scan_res',
          varsRes scan_outarrs,
          red_res',
          varsRes map_outarrs
        ]

  -- We need to discard the final scan accumulators, as they are not
  -- bound in the original pattern.
  names <- (++ patNames pat) <$> replicateM (length scanacc_params) (newVName "discard")
  letBindNames names $ Loop merge loopform loop_body
transformSOAC pat (Stream w arrs nes lam) = do
  -- Create a loop that repeatedly applies the lambda body to a
  -- chunksize of 1.  Hopefully this will lead to this outer loop
  -- being the only one, as all the innermost loops can be simplified
  -- away (as they will have one iteration each).
  let (chunk_size_param, fold_params, chunk_params) =
        partitionChunkedFoldParameters (length nes) $ lambdaParams lam

  -- One loop parameter per map-out result of the lambda, initialised
  -- with a Scratch array whose outer size is w.
  mapout_merge <- forM (drop (length nes) $ lambdaReturnType lam) $ \t ->
    let t' = t `setOuterSize` w
        scratch = BasicOp $ Scratch (elemType t') (arrayDims t')
     in (,)
          <$> newParam "stream_mapout" (toDecl t' Unique)
          <*> letSubExp "stream_mapout_scratch" scratch

  -- We need to copy the neutral elements because they may be consumed
  -- in the body of the Stream.
  let copyIfArray se = do
        se_t <- subExpType se
        case (se_t, se) of
          (Array {}, Var v) ->
            letSubExp (baseString v) $ BasicOp $ Replicate mempty se
          _ -> pure se
  nes' <- mapM copyIfArray nes

  let onType t = t `toDecl` Unique
      merge = zip (map (fmap onType) fold_params) nes' ++ mapout_merge
      merge_params = map fst merge
      mapout_params = map fst mapout_merge

  i <- newVName "i"

  let loop_form = ForLoop i Int64 w

  -- The chunk size is bound once, to 1, outside the loop.
  letBindNames [paramName chunk_size_param] . BasicOp . SubExp $
    intConst Int64 1

  loop_body <- runBodyBuilder $
    localScope (scopeOfLoopForm loop_form <> scopeOfFParams merge_params) $ do
      -- The slice of length chunk_size_param (i.e. 1) starting at i.
      let slice =
            [DimSlice (Var i) (Var (paramName chunk_size_param)) (intConst Int64 1)]
      forM_ (zip chunk_params arrs) $ \(p, arr) ->
        letBindNames [paramName p] . BasicOp . Index arr $
          fullSlice (paramType p) slice

      (res, mapout_res) <- splitAt (length nes) <$> bodyBind (lambdaBody lam)

      -- Array-typed accumulator results also go through copyIfArray.
      res' <- mapM (copyIfArray . resSubExp) res

      -- Map-out results are written into the loop-carried arrays at
      -- the current slice.
      mapout_res' <- forM (zip mapout_params mapout_res) $ \(p, SubExpRes cs se) ->
        certifying cs . letSubExp "mapout_res" . BasicOp $
          Update Unsafe (paramName p) (fullSlice (paramType p) slice) se

      pure $ subExpsRes $ res' ++ mapout_res'

  letBind pat $ Loop merge loop_form loop_body
transformSOAC pat (Scatter len ivs as lam) = do
  iter <- newVName "write_iter"

  let (as_ws, as_ns, as_vs) = unzip3 as
  ts <- mapM lookupType as_vs
  asOuts <- mapM (newIdent "write_out") ts

  -- Scatter is in-place, so we use the input array as the output array.
  let merge = loopMerge asOuts $ map Var as_vs
  loopBody <- runBodyBuilder $
    localScope (M.insert iter (IndexName Int64) $ scopeOfFParams $ map fst merge) $ do
      -- Element 'iter' of every input array.
      ivs' <- forM ivs $ \iv -> do
        iv_t <- lookupType iv
        letSubExp "write_iv" $ BasicOp $ Index iv $ fullSlice iv_t [DimFix $ Var iter]
      ivs'' <- bindLambda lam (map (BasicOp . SubExp) ivs')

      let indexes = groupScatterResults (zip3 as_ws as_ns $ map identName asOuts) ivs''

      -- Apply each (index, value) write to its destination in turn,
      -- using a 'Safe' update (presumably so that out-of-bounds
      -- indices are handled as Scatter semantics require -- confirm).
      ress <- forM indexes $ \(_, arr, indexes') -> do
        arr_t <- lookupType arr
        let saveInArray arr' (indexCur, SubExpRes value_cs valueCur) =
              certifying (foldMap resCerts indexCur <> value_cs) . letExp "write_out" $
                BasicOp $ Update Safe arr' (fullSlice arr_t $ map (DimFix . resSubExp) indexCur) valueCur

        foldM saveInArray arr indexes'
      pure $ varsRes ress
  letBind pat $ Loop merge (ForLoop iter Int64 len) loopBody
transformSOAC pat (Hist len imgs ops bucket_fun) = do
  iter <- newVName "iter"

  -- Bind arguments to parameters for the merge-variables.
  hists_ts <- mapM lookupType $ concatMap histDest ops
  hists_out <- mapM (newIdent "dests") hists_ts
  let merge = loopMerge hists_out $ concatMap (map Var . histDest) ops

  -- Bind lambda-bodies for operators.
  let iter_scope = M.insert iter (IndexName Int64) $ scopeOfFParams $ map fst merge
  loopBody <- runBodyBuilder . localScope iter_scope $ do
    -- Bind images to parameters of bucket function.
    imgs' <- forM imgs $ \img -> do
      img_t <- lookupType img
      letSubExp "pixel" $ BasicOp $ Index img $ fullSlice img_t [DimFix $ Var iter]
    imgs'' <- map resSubExp <$> bindLambda bucket_fun (map (BasicOp . SubExp) imgs')

    -- Split out values from bucket function: first all the indices
    -- (one group per operator, sized by the rank of its histogram
    -- shape), then the values (one group per operator, sized by the
    -- number of results of its operator lambda).
    let lens = sum $ map (shapeRank . histShape) ops
        ops_inds = chunks (map (shapeRank . histShape) ops) (take lens imgs'')
        vals = chunks (map (length . lambdaReturnType . histOp) ops) $ drop lens imgs''
        hists_out' =
          chunks (map (length . lambdaReturnType . histOp) ops) $
            map identName hists_out

    hists_out'' <- forM (zip4 hists_out' ops ops_inds vals) $ \(hist, op, idxs, val) -> do
      -- Check whether the indexes are in-bound.  If they are not, we
      -- return the histograms unchanged.  An operator with no
      -- destination arrays is treated as out of bounds.
      let outside_bounds_branch = buildBody_ $ pure $ varsRes hist
          oob = case hist of
            [] -> eSubExp $ constant True
            arr : _ -> eOutOfBounds arr $ map eSubExp idxs

      letTupExp "new_histo" <=< eIf oob outside_bounds_branch $ buildBody_ $ do
        -- Read values from histogram.
        h_val <- forM hist $ \arr -> do
          arr_t <- lookupType arr
          letSubExp "read_hist" $ BasicOp $ Index arr $ fullSlice arr_t $ map DimFix idxs

        -- Apply operator.
        h_val' <- bindLambda (histOp op) $ map (BasicOp . SubExp) $ h_val ++ val

        -- Write values back to histograms.
        hist' <- forM (zip hist h_val') $ \(arr, SubExpRes cs v) -> do
          arr_t <- lookupType arr
          certifying cs . letInPlace "hist_out" arr (fullSlice arr_t $ map DimFix idxs) $
            BasicOp $ SubExp v

        pure $ varsRes hist'

    pure $ varsRes $ concat hists_out''

  -- Wrap up the above into a for-loop.
  letBind pat $ Loop merge (ForLoop iter Int64 len) loopBody

-- | Recursively first-order-transform a lambda.
transformLambda ::
  ( MonadFreshNames m,
    Buildable rep,
    BuilderOps rep,
    LocalScope somerep m,
    SameScope somerep rep,
    LetDec rep ~ LetDec SOACS,
    Alias.AliasableRep rep
  ) =>
  Lambda SOACS ->
  m (AST.Lambda rep)
transformLambda (Lambda params rettype body) = do
  body' <-
    fmap fst . runBuilder $
      localScope (scopeOfLParams params) $
        transformBody body
  pure $ Lambda params rettype body'

-- | Write each value into position @i@ of the corresponding array
-- (in place), returning the names of the updated arrays.  An
-- accumulator-typed destination is instead simply rebound to the value.
letwith :: (Transformer m) => [VName] -> SubExp -> [SubExp] -> m [VName]
letwith ks i vs = do
  let update k v = do
        k_t <- lookupType k
        case k_t of
          Acc {} ->
            letExp "lw_acc" $ BasicOp $ SubExp v
          _ ->
            letInPlace "lw_dest" k (fullSlice k_t [DimFix i]) $ BasicOp $ SubExp v
  zipWithM update ks vs

-- | Bind the lambda's parameters to the given expressions and then
-- bind its body, returning the body's result.  Non-primitive arguments
-- are bound through 'eCopy'.
bindLambda ::
  (Transformer m) =>
  AST.Lambda (Rep m) ->
  [AST.Exp (Rep m)] ->
  m Result
bindLambda (Lambda params _ body) args = do
  forM_ (zip params args) $ \(param, arg) ->
    if primType $ paramType param
      then letBindNames [paramName param] arg
      else letBindNames [paramName param] =<< eCopy (pure arg)
  bodyBind body

-- | Loop parameters (all 'Unique') for the given identifiers, paired
-- with their initial values.
loopMerge :: [Ident] -> [SubExp] -> [(Param DeclType, SubExp)]
loopMerge vars = loopMerge' $ map (,Unique) vars

-- | Like 'loopMerge', but with an explicit uniqueness per identifier.
loopMerge' :: [(Ident, Uniqueness)] -> [SubExp] -> [(Param DeclType, SubExp)]
loopMerge' vars vals =
  [ (Param mempty pname $ toDecl ptype u, val)
  | ((Ident pname ptype, u), val) <- zip vars vals
  ]

-- Note [Translation of Screma]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
--
-- Screma is the most general SOAC.  It is translated by constructing
-- a loop that contains several groups of parameters, in this order:
--
-- (0) Scan accumulator, initialised with neutral element.
-- (1) Scan results, initialised with Scratch.
-- (2) Reduce results (also functioning as accumulators),
--     initialised with neutral element.
-- (3) Map results, mostly initialised with Scratch.
--
-- However, category (3) is a little more tricky in the case where one
-- of the results is an Acc.  In that case, the result is not an
-- array, but another Acc.  Any Acc result of a Map must correspond to
-- an Acc that is an input to the map, and the result is initialised
-- to be that input.  This requires a 1:1 relationship between Acc
-- inputs and Acc outputs, which the type checker should enforce.
-- There is no guarantee that the map results appear in any particular
-- order (e.g.
accumulator results before non-accumulator results), so -- we need to do a little sleuthing to establish the relationship. -- -- Inside the loop, the non-Acc parameters to map_lam become for-in -- parameters. Acc parameters refer to the loop parameters for the -- corresponding Map result instead. -- -- Intuitively, a Screma(w, -- (scan_op, scan_ne), -- (red_op, red_ne), -- map_fn, -- {acc_input, arr_input}) -- -- then becomes -- -- loop (scan_acc, scan_arr, red_acc, map_acc, map_arr) = -- for i < w, x in arr_input do -- let (a,b,map_acc',d) = map_fn(map_acc, x) -- let scan_acc' = scan_op(scan_acc, a) -- let scan_arr[i] = scan_acc' -- let red_acc' = red_op(red_acc, b) -- let map_arr[i] = d -- in (scan_acc', scan_arr', red_acc', map_acc', map_arr) futhark-0.25.27/src/Futhark/Transform/Rename.hs000066400000000000000000000255541475065116200212660ustar00rootroot00000000000000{-# LANGUAGE UndecidableInstances #-} -- | This module provides facilities for transforming Futhark programs -- such that names are unique, via the 'renameProg' function. module Futhark.Transform.Rename ( -- * Renaming programs renameProg, -- * Renaming parts of a program. -- -- These all require execution in a 'MonadFreshNames' environment. 
renameExp, renameStm, renameBody, renameLambda, renamePat, renameSomething, renameBound, renameStmsWith, -- * Renaming annotations RenameM, substituteRename, renamingStms, Rename (..), Renameable, ) where import Control.Monad.Reader import Control.Monad.State import Data.Bitraversable import Data.Map.Strict qualified as M import Data.Maybe import Futhark.FreshNames hiding (newName) import Futhark.IR.Prop.Names import Futhark.IR.Prop.Pat import Futhark.IR.Syntax import Futhark.IR.Traversals import Futhark.MonadFreshNames (MonadFreshNames (..), modifyNameSource, newName) import Futhark.Transform.Substitute runRenamer :: RenameM a -> VNameSource -> (a, VNameSource) runRenamer (RenameM m) src = runReader (runStateT m src) env where env = RenameEnv M.empty -- | Rename variables such that each is unique. The semantics of the -- program are unaffected, under the assumption that the program was -- correct to begin with. In particular, the renaming may make an -- invalid program valid. renameProg :: (Renameable rep, MonadFreshNames m) => Prog rep -> m (Prog rep) renameProg prog = modifyNameSource $ runRenamer $ renamingStms (progConsts prog) $ \consts -> do funs <- mapM rename (progFuns prog) pure prog {progConsts = consts, progFuns = funs} -- | Rename bound variables such that each is unique. The semantics -- of the expression is unaffected, under the assumption that the -- expression was correct to begin with. Any free variables are left -- untouched. renameExp :: (Renameable rep, MonadFreshNames m) => Exp rep -> m (Exp rep) renameExp = modifyNameSource . runRenamer . rename -- | Rename bound variables such that each is unique. The semantics -- of the binding is unaffected, under the assumption that the -- binding was correct to begin with. Any free variables are left -- untouched, as are the names in the pattern of the binding. 
renameStm ::
  (Renameable rep, MonadFreshNames m) =>
  Stm rep ->
  m (Stm rep)
renameStm stm = replaceExp <$> renameExp (stmExp stm)
  where
    -- Only the expression is renamed; pattern and auxiliary data are
    -- carried over unchanged.
    replaceExp e' = stm {stmExp = e'}

-- | Give fresh, unique names to every variable bound inside a body.
-- Free variables keep their names, so the meaning of the body is
-- preserved as long as it was correct to begin with.
renameBody ::
  (Renameable rep, MonadFreshNames m) =>
  Body rep ->
  m (Body rep)
renameBody body = modifyNameSource $ runRenamer $ rename body

-- | Give fresh, unique names to every variable bound inside a lambda,
-- including the lambda's own parameters.  Free variables keep their
-- names, so the meaning of the lambda is preserved as long as its body
-- was correct to begin with.
renameLambda ::
  (Renameable rep, MonadFreshNames m) =>
  Lambda rep ->
  m (Lambda rep)
renameLambda lam = modifyNameSource $ runRenamer $ rename lam

-- | Produce an equivalent pattern in which every pattern element has
-- been given a new name.
renamePat ::
  (Rename dec, MonadFreshNames m) =>
  Pat dec ->
  m (Pat dec)
renamePat pat =
  -- The pattern's own names are put in binding position first, so that
  -- renaming the pattern picks up their fresh replacements.
  modifyNameSource . runRenamer . renameBound (patNames pat) $ rename pat

-- | Rename whatever variables are bound inside a value; free variables
-- are left alone.
renameSomething ::
  (Rename a, MonadFreshNames m) =>
  a ->
  m a
renameSomething x = modifyNameSource $ runRenamer $ rename x

-- | Rename a sequence of statements, then rename a second value inside
-- the scope established by those statements.
renameStmsWith ::
  (MonadFreshNames m, Renameable rep, Rename a) =>
  Stms rep ->
  a ->
  m (Stms rep, a)
renameStmsWith stms a =
  modifyNameSource $ runRenamer $ renamingStms stms $ \stms' -> do
    a' <- rename a
    pure (stms', a')

-- | Environment of the renamer: a map from names to the names they are
-- to be replaced with.
newtype RenameEnv = RenameEnv {envNameMap :: M.Map VName VName}

-- | The monad in which renaming is performed.
newtype RenameM a = RenameM (StateT VNameSource (Reader RenameEnv) a)
  -- The state is the source of fresh names; the reader environment
  -- holds the current name substitution map ('RenameEnv').
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadFreshNames,
      MonadReader RenameEnv
    )

-- | Produce a map of the substitutions that should be performed by
-- the renamer.
renamerSubstitutions :: RenameM Substitutions
renamerSubstitutions = asks envNameMap

-- | Perform a renaming using the 'Substitute' instance.  This only
-- works if the argument does not itself perform any name binding, but
-- it can save on boilerplate for simple types.
substituteRename :: (Substitute a) => a -> RenameM a
substituteRename x = do
  substs <- renamerSubstitutions
  pure $ substituteNames substs x

-- | Members of class 'Rename' can be uniquely renamed.
class Rename a where
  -- | Rename the given value such that it does not contain shadowing,
  -- and has incorporated any substitutions present in the 'RenameM'
  -- environment.
  rename :: a -> RenameM a

instance Rename VName where
  -- A name with no entry in the substitution map (e.g. a free
  -- variable) is returned unchanged.
  rename name = asks (fromMaybe name . M.lookup name . envNameMap)

instance (Rename a) => Rename [a] where
  rename = mapM rename

instance (Rename a, Rename b) => Rename (a, b) where
  rename (a, b) = (,) <$> rename a <*> rename b

instance (Rename a, Rename b, Rename c) => Rename (a, b, c) where
  rename (a, b, c) = do
    a' <- rename a
    b' <- rename b
    c' <- rename c
    pure (a', b', c')

instance (Rename a) => Rename (Maybe a) where
  rename = maybe (pure Nothing) (fmap Just . rename)

instance Rename Bool where
  rename = pure

instance Rename Ident where
  rename (Ident name tp) = do
    name' <- rename name
    tp' <- rename tp
    pure $ Ident name' tp'

-- | Rename variables in binding position.  The provided VNames are
-- associated with new, fresh names in the renaming environment.
renameBound :: [VName] -> RenameM a -> RenameM a
renameBound vars body = do
  vars' <- mapM newName vars
  -- This works because map union prefers elements from left
  -- operand.
  local (renameBound' vars') body
  where
    renameBound' vars' env =
      env
        { envNameMap =
            M.fromList (zip vars vars')
              `M.union` envNameMap env
        }

-- | Rename some statements, then execute an action with the name
-- substitutions induced by the statements active.
renamingStms :: (Renameable rep) => Stms rep -> (Stms rep -> RenameM a) -> RenameM a
renamingStms stms m = descend mempty stms
  where
    -- Each statement's pattern names are bound before the statement is
    -- renamed, and the binding stays active for all later statements
    -- and for the final action @m@.
    descend stms' rem_stms = case stmsHead rem_stms of
      Nothing -> m stms'
      Just (stm, rem_stms') -> renameBound (patNames $ stmPat stm) $ do
        stm' <- rename stm
        descend (stms' <> oneStm stm') rem_stms'

instance (Renameable rep) => Rename (FunDef rep) where
  -- Parameters, body and return type are all renamed within the scope
  -- of the fresh parameter names (presumably because the return type
  -- may mention parameters - confirm).
  rename (FunDef entry attrs fname ret params body) =
    renameBound (map paramName params) $ do
      params' <- mapM rename params
      body' <- rename body
      ret' <- mapM (bitraverse rename pure) ret
      pure $ FunDef entry attrs fname ret' params' body'

instance Rename SubExp where
  rename (Var v) = Var <$> rename v
  rename (Constant v) = pure $ Constant v

instance (Rename dec) => Rename (Param dec) where
  rename (Param attrs name dec) = Param <$> rename attrs <*> rename name <*> rename dec

instance (Rename dec) => Rename (Pat dec) where
  rename (Pat xs) = Pat <$> rename xs

instance (Rename dec) => Rename (PatElem dec) where
  rename (PatElem ident dec) = PatElem <$> rename ident <*> rename dec

instance Rename Certs where
  rename (Certs cs) = Certs <$> rename cs

instance Rename Attrs where
  rename = pure

instance (Rename dec) => Rename (StmAux dec) where
  rename (StmAux cs attrs dec) =
    StmAux <$> rename cs <*> rename attrs <*> rename dec

instance Rename SubExpRes where
  rename (SubExpRes cs se) = SubExpRes <$> rename cs <*> rename se

instance (Renameable rep) => Rename (Body rep) where
  rename (Body dec stms res) = do
    dec' <- rename dec
    renamingStms stms $ \stms' ->
      Body dec' stms' <$> rename res

instance (Renameable rep) => Rename (Stm rep) where
  rename (Let pat dec e) = Let <$> rename pat <*> rename dec <*> rename e

instance (Renameable rep) => Rename (Exp rep) where
  rename (WithAcc inputs lam) =
    WithAcc <$> rename inputs <*> rename lam
  rename (Loop merge form loopbody) = do
    let (params, args) = unzip merge
    -- Initial values are renamed outside the loop scope, before any
    -- loop parameter is bound.
    args' <- mapM rename args
    case form of
      -- It is important that 'i' is renamed before the loop_vars, as
      -- 'i' may be used in the annotations for loop_vars (e.g. index
      -- functions).
      ForLoop i it boundexp -> renameBound [i] $ do
        boundexp' <- rename boundexp
        renameBound (map paramName params) $ do
          params' <- mapM rename params
          i' <- rename i
          loopbody' <- rename loopbody
          pure $
            Loop
              (zip params' args')
              (ForLoop i' it boundexp')
              loopbody'
      -- The condition is renamed inside the scope of the loop
      -- parameters.
      WhileLoop cond ->
        renameBound (map paramName params) $ do
          params' <- mapM rename params
          loopbody' <- rename loopbody
          cond' <- rename cond
          pure $ Loop (zip params' args') (WhileLoop cond') loopbody'
  -- All other expression forms are handled generically via 'mapExpM'.
  rename e = mapExpM mapper e
    where
      mapper =
        Mapper
          { mapOnBody = const rename,
            mapOnSubExp = rename,
            mapOnVName = rename,
            mapOnRetType = rename,
            mapOnBranchType = rename,
            mapOnFParam = rename,
            mapOnLParam = rename,
            mapOnOp = rename
          }

instance Rename PrimType where
  rename = pure

instance (Rename shape) => Rename (TypeBase shape u) where
  rename (Array et size u) = Array <$> rename et <*> rename size <*> pure u
  rename (Prim t) = pure $ Prim t
  rename (Mem space) = pure $ Mem space
  rename (Acc acc ispace ts u) =
    Acc <$> rename acc <*> rename ispace <*> rename ts <*> pure u

instance (Renameable rep) => Rename (Lambda rep) where
  rename (Lambda params ret body) =
    renameBound (map paramName params) $
      Lambda <$> mapM rename params <*> mapM rename ret <*> rename body

instance Rename Names where
  rename = fmap namesFromList . mapM rename . namesToList

instance Rename Rank where
  rename = pure

instance (Rename d) => Rename (ShapeBase d) where
  rename (Shape l) = Shape <$> mapM rename l

instance Rename ExtSize where
  rename (Free se) = Free <$> rename se
  rename (Ext x) = pure $ Ext x

instance Rename () where
  rename = pure

instance Rename (NoOp rep) where
  rename NoOp = pure NoOp

instance (Rename d) => Rename (DimIndex d) where
  rename (DimFix i) = DimFix <$> rename i
  rename (DimSlice i n s) = DimSlice <$> rename i <*> rename n <*> rename s

-- | Representations in which all decorations are renameable.
type Renameable rep =
  ( Rename (LetDec rep),
    Rename (ExpDec rep),
    Rename (BodyDec rep),
    Rename (FParamInfo rep),
    Rename (LParamInfo rep),
    Rename (RetType rep),
    Rename (BranchType rep),
    Rename (Op rep)
  )
futhark-0.25.27/src/Futhark/Transform/Substitute.hs000066400000000000000000000164361475065116200222310ustar00rootroot00000000000000{-# LANGUAGE UndecidableInstances #-}

-- |
--
-- This module contains facilities for replacing variable names in
-- syntactic constructs.
module Futhark.Transform.Substitute
  ( Substitutions,
    Substitute (..),
    Substitutable,
  )
where

import Control.Monad.Identity
import Data.Map.Strict qualified as M
import Futhark.Analysis.PrimExp
import Futhark.IR.Prop.Names
import Futhark.IR.Prop.Scope
import Futhark.IR.Syntax
import Futhark.IR.Traversals

-- | The substitutions to be made are given by a mapping from names to
-- names.
type Substitutions = M.Map VName VName

-- | A type that is an instance of this class supports substitution of
-- any names contained within.
class Substitute a where
  -- | @substituteNames m e@ replaces the variable names in @e@ with
  -- new names, based on the mapping in @m@.  It is assumed that all
  -- names in @e@ are unique, i.e. there is no shadowing.
  substituteNames :: M.Map VName VName -> a -> a

instance (Substitute a) => Substitute [a] where
  substituteNames substs = map $ substituteNames substs

instance (Substitute (Stm rep)) => Substitute (Stms rep) where
  substituteNames substs = fmap $ substituteNames substs

instance (Substitute a, Substitute b) => Substitute (a, b) where
  substituteNames substs (x, y) =
    (substituteNames substs x, substituteNames substs y)

instance (Substitute a, Substitute b, Substitute c) => Substitute (a, b, c) where
  substituteNames substs (x, y, z) =
    ( substituteNames substs x,
      substituteNames substs y,
      substituteNames substs z
    )

instance (Substitute a, Substitute b, Substitute c, Substitute d) => Substitute (a, b, c, d) where
  substituteNames substs (x, y, z, u) =
    ( substituteNames substs x,
      substituteNames substs y,
      substituteNames substs z,
      substituteNames substs u
    )

instance (Substitute a) => Substitute (Maybe a) where
  substituteNames substs = fmap $ substituteNames substs

instance Substitute Bool where
  substituteNames = const id

instance Substitute VName where
  -- Names that do not appear in the map are left unchanged.
  substituteNames substs k = M.findWithDefault k k substs

instance Substitute SubExp where
  substituteNames substs (Var v) = Var $ substituteNames substs v
  substituteNames _ (Constant v) = Constant v

instance (Substitutable rep) => Substitute (Exp rep) where
  substituteNames substs = mapExp $ replace substs

instance (Substitute dec) => Substitute (PatElem dec) where
  substituteNames substs (PatElem ident dec) =
    PatElem (substituteNames substs ident) (substituteNames substs dec)

instance Substitute Attrs where
  substituteNames _ attrs = attrs

instance (Substitute dec) => Substitute (StmAux dec) where
  substituteNames substs (StmAux cs attrs dec) =
    StmAux
      (substituteNames substs cs)
      (substituteNames substs attrs)
      (substituteNames substs dec)

instance (Substitute dec) => Substitute (Param dec) where
  substituteNames substs (Param attrs name dec) =
    Param
      (substituteNames substs attrs)
      (substituteNames substs name)
      (substituteNames substs dec)

instance Substitute SubExpRes where
  substituteNames substs (SubExpRes cs se) =
    SubExpRes (substituteNames substs cs) (substituteNames substs se)

instance (Substitute dec) => Substitute (Pat dec) where
  substituteNames substs (Pat xs) = Pat (substituteNames substs xs)

instance Substitute Certs where
  substituteNames substs (Certs cs) =
    Certs $ substituteNames substs cs

instance (Substitutable rep) => Substitute (Stm rep) where
  substituteNames substs (Let pat annot e) =
    Let
      (substituteNames substs pat)
      (substituteNames substs annot)
      (substituteNames substs e)

instance (Substitutable rep) => Substitute (Body rep) where
  substituteNames substs (Body dec stms res) =
    Body
      (substituteNames substs dec)
      (substituteNames substs stms)
      (substituteNames substs res)

-- A 'Mapper' that applies the substitution to every component that
-- 'mapExp' visits.
replace :: (Substitutable rep) => M.Map VName VName -> Mapper rep rep Identity
replace substs =
  Mapper
    { mapOnVName = pure . substituteNames substs,
      mapOnSubExp = pure . substituteNames substs,
      mapOnBody = const $ pure . substituteNames substs,
      mapOnRetType = pure . substituteNames substs,
      mapOnBranchType = pure . substituteNames substs,
      mapOnFParam = pure . substituteNames substs,
      mapOnLParam = pure . substituteNames substs,
      mapOnOp = pure . substituteNames substs
    }

instance Substitute Rank where
  substituteNames _ = id

instance Substitute () where
  substituteNames _ = id

instance Substitute (NoOp rep) where
  substituteNames _ = id

instance (Substitute d) => Substitute (ShapeBase d) where
  substituteNames substs (Shape es) =
    Shape $ map (substituteNames substs) es

instance (Substitute d) => Substitute (Ext d) where
  substituteNames substs (Free x) = Free $ substituteNames substs x
  substituteNames _ (Ext x) = Ext x

instance Substitute Names where
  substituteNames = mapNames . substituteNames

instance Substitute PrimType where
  substituteNames _ t = t

instance (Substitute shape) => Substitute (TypeBase shape u) where
  substituteNames _ (Prim et) = Prim et
  substituteNames substs (Acc acc ispace ts u) =
    Acc
      (substituteNames substs acc)
      (substituteNames substs ispace)
      (substituteNames substs ts)
      u
  substituteNames substs (Array et sz u) =
    Array (substituteNames substs et) (substituteNames substs sz) u
  substituteNames _ (Mem space) = Mem space

instance (Substitutable rep) => Substitute (Lambda rep) where
  substituteNames substs (Lambda params rettype body) =
    Lambda
      (substituteNames substs params)
      (map (substituteNames substs) rettype)
      (substituteNames substs body)

instance Substitute Ident where
  substituteNames substs v =
    v
      { identName = substituteNames substs $ identName v,
        identType = substituteNames substs $ identType v
      }

instance (Substitute d) => Substitute (DimIndex d) where
  substituteNames substs = fmap $ substituteNames substs

instance (Substitute d) => Substitute (Slice d) where
  substituteNames substs = fmap $ substituteNames substs

instance (Substitute d) => Substitute (FlatDimIndex d) where
  substituteNames substs = fmap $ substituteNames substs

instance (Substitute d) => Substitute (FlatSlice d) where
  substituteNames substs = fmap $ substituteNames substs

instance (Substitute v) => Substitute (PrimExp v) where
  substituteNames substs = fmap $ substituteNames substs

instance (Substitute v) => Substitute (TPrimExp t v) where
  substituteNames substs =
    TPrimExp . fmap (substituteNames substs) . untyped

instance (Substitutable rep) => Substitute (NameInfo rep) where
  substituteNames subst (LetName dec) =
    LetName $ substituteNames subst dec
  substituteNames subst (FParamName dec) =
    FParamName $ substituteNames subst dec
  substituteNames subst (LParamName dec) =
    LParamName $ substituteNames subst dec
  substituteNames _ (IndexName it) =
    IndexName it

instance Substitute FV where
  -- Goes through the free-variable set: substitute on 'freeIn', then
  -- rewrap with 'fvNames'.
  substituteNames subst = fvNames . substituteNames subst . freeIn

-- | Representations in which all annotations support name
-- substitution.
type Substitutable rep =
  ( RepTypes rep,
    Substitute (ExpDec rep),
    Substitute (BodyDec rep),
    Substitute (LetDec rep),
    Substitute (FParamInfo rep),
    Substitute (LParamInfo rep),
    Substitute (RetType rep),
    Substitute (BranchType rep),
    Substitute (Op rep)
  )
futhark-0.25.27/src/Futhark/Util.hs000066400000000000000000000376541475065116200170250ustar00rootroot00000000000000-- | Non-Futhark-specific utilities.  If you find yourself writing
-- general functions on generic data structures, consider putting them
-- here.
--
-- Sometimes it is also preferable to copy a small function rather
-- than introducing a large dependency.  In this case, make sure to
-- note where you got it from (and make sure that the license is
-- compatible).
module Futhark.Util
  ( nubOrd,
    nubByOrd,
    mapAccumLM,
    maxinum,
    mininum,
    chunk,
    chunks,
    chunkLike,
    dropAt,
    takeLast,
    dropLast,
    mapEither,
    partitionMaybe,
    maybeNth,
    maybeHead,
    lookupWithIndex,
    splitFromEnd,
    splitAt3,
    focusNth,
    focusMaybe,
    hashText,
    showText,
    unixEnvironment,
    isEnvVarAtLeast,
    startupTime,
    fancyTerminal,
    hFancyTerminal,
    runProgramWithExitCode,
    directoryContents,
    fromPOSIX,
    toPOSIX,
    trim,
    pmapIO,
    interactWithFileSafely,
    convFloat,
    UserText,
    EncodedText,
    zEncodeText,
    atMostChars,
    invertMap,
    cartesian,
    traverseFold,
    fixPoint,
    concatMapM,
    topologicalSort,
    debugTraceM,
  )
where

import Control.Concurrent
import Control.Exception
import Control.Monad
import Control.Monad.State
import Crypto.Hash.MD5 as MD5
import Data.Bifunctor
import Data.ByteString qualified as BS
import Data.ByteString.Base16 qualified as Base16
import Data.Char
import Data.Either
import Data.Foldable (fold, toList)
import Data.Function ((&))
import Data.IntMap qualified as IM
import Data.List qualified as L
import Data.List.NonEmpty qualified as NE
import Data.Map qualified as M
import Data.Maybe
import Data.Set qualified as S
import Data.Text qualified as T
import Data.Text.Encoding qualified as T
import Data.Text.Encoding.Error qualified as T
import Data.Time.Clock (UTCTime, getCurrentTime)
import Data.Tuple (swap)
import Debug.Trace
import Numeric
import System.Directory.Tree qualified as Dir
import System.Environment
import System.Exit
import System.FilePath qualified as Native
import System.FilePath.Posix qualified as Posix
import System.IO (Handle, hIsTerminalDevice, stdout)
import System.IO.Error (isDoesNotExistError)
import System.IO.Unsafe
import System.Process.ByteString
import Text.Read (readMaybe)

-- | Like @nub@, but without the quadratic runtime.
nubOrd :: (Ord a) => [a] -> [a]
nubOrd = nubByOrd compare

-- | Like @nubBy@, but without the quadratic runtime.
--
-- Unlike @nubBy@, the result is ordered by the comparison function
-- rather than by first occurrence, because the input is sorted first.
nubByOrd :: (a -> a -> Ordering) -> [a] -> [a]
nubByOrd cmp = map NE.head . NE.groupBy eq . L.sortBy cmp
  where
    eq x y = cmp x y == EQ

-- | Like 'Data.Traversable.mapAccumL', but monadic and generalised to
-- any 'Traversable'.
mapAccumLM ::
  (Monad m, Traversable t) =>
  (acc -> x -> m (acc, y)) ->
  acc ->
  t x ->
  m (acc, t y)
mapAccumLM op initial l = do
  -- The accumulator is threaded through 'StateT' so that plain
  -- 'traverse' can drive the iteration over any 'Traversable'.
  (l', acc) <- runStateT (traverse f l) initial
  pure (acc, l')
  where
    f x = do
      acc <- get
      (acc', y) <- lift $ op acc x
      put acc'
      pure y

-- | @chunk n a@ splits @a@ into @n@-size-chunks.  If the length of
-- @a@ is not divisible by @n@, the last chunk will have fewer than
-- @n@ elements (but it will never be empty).
--
-- NOTE(review): @n@ must be positive; for @n <= 0@ and a non-empty
-- list, 'splitAt' makes no progress and the result is an infinite list
-- of empty chunks.
chunk :: Int -> [a] -> [[a]]
chunk _ [] = []
chunk n xs =
  let (bef, aft) = splitAt n xs
   in bef : chunk n aft

-- | @chunks ns a@ splits @a@ into chunks determined by the elements
-- of @ns@.  It must hold that @sum ns == length a@, or the resulting
-- list may contain too few chunks, or not all elements of @a@.
chunks :: [Int] -> [a] -> [[a]]
chunks [] _ = []
chunks (n : ns) xs =
  let (bef, aft) = splitAt n xs
   in bef : chunks ns aft

-- | @chunkLike xss ys@ chunks the elements of @ys@ to match the
-- elements of @xss@.  The sum of the lengths of the sublists of @xss@
-- must match the length of @ys@.
chunkLike :: [[a]] -> [b] -> [[b]]
chunkLike as = chunks (map length as)

-- | Like 'maximum', but returns zero for an empty list.
--
-- For a non-empty collection this is exactly 'maximum'.  In particular,
-- a collection containing only negative numbers yields its (negative)
-- maximum rather than being clamped to zero.
maxinum :: (Num a, Ord a, Foldable f) => f a -> a
maxinum xs
  | null xs = 0
  | otherwise = maximum xs

-- | Like 'minimum', but returns zero for an empty list.
mininum :: (Num a, Ord a, Foldable f) => f a -> a
mininum xs
  | null xs = 0
  | otherwise = minimum xs

-- | @dropAt i n@ drops @n@ elements starting at element @i@.
dropAt :: Int -> Int -> [a] -> [a]
dropAt i n xs = take i xs ++ drop (i + n) xs

-- | @takeLast n l@ takes the last @n@ elements of @l@.
takeLast :: Int -> [a] -> [a]
takeLast n = reverse . take n . reverse

-- | @dropLast n l@ drops the last @n@ elements of @l@.
dropLast :: Int -> [a] -> [a]
dropLast n = reverse . drop n . reverse

-- | A combination of 'map' and 'partitionEithers'.
mapEither :: (a -> Either b c) -> [a] -> ([b], [c])
mapEither f l = partitionEithers $ map f l

-- | A combination of 'Data.List.partition' and 'mapMaybe'.  The
-- 'Just' results and the elements that produced 'Nothing' each keep
-- their original relative order.
partitionMaybe :: (a -> Maybe b) -> [a] -> ([b], [a])
partitionMaybe f = helper ([], [])
  where
    -- Accumulators are built in reverse and flipped once at the end.
    helper (acc1, acc2) [] = (reverse acc1, reverse acc2)
    helper (acc1, acc2) (x : xs) =
      case f x of
        Just x' -> helper (x' : acc1, acc2) xs
        Nothing -> helper (acc1, x : acc2) xs

-- | Return the list element at the given index, if the index is valid.
maybeNth :: (Integral int) => int -> [a] -> Maybe a
maybeNth i l
  | i >= 0, v : _ <- L.genericDrop i l = Just v
  | otherwise = Nothing

-- | Return the first element of the list, if it exists.
maybeHead :: [a] -> Maybe a
maybeHead [] = Nothing
maybeHead (x : _) = Just x

-- | Lookup a value, returning also the index at which it appears.
lookupWithIndex :: (Eq a) => a -> [(a, b)] -> Maybe (Int, b)
lookupWithIndex needle haystack =
  lookup needle $ zipWith (\i (k, v) -> (k, (i, v))) [0 ..] haystack

-- | Like 'splitAt', but from the end.
splitFromEnd :: Int -> [a] -> ([a], [a])
splitFromEnd i l = splitAt (length l - i) l

-- | Like 'splitAt', but produces three lists.
splitAt3 :: Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 n m l =
  let (xs, l') = splitAt n l
      (ys, zs) = splitAt m l'
   in (xs, ys, zs)

-- | Return the list element at the given index, if the index is
-- valid, along with the elements before and after.  A negative index
-- is invalid and yields 'Nothing', as with 'maybeNth'.
focusNth :: (Integral int) => int -> [a] -> Maybe ([a], a, [a])
focusNth i xs
  -- 'L.genericSplitAt' treats a negative index like zero, so negative
  -- indices must be rejected explicitly.
  | i >= 0,
    (bef, x : aft) <- L.genericSplitAt i xs =
      Just (bef, x, aft)
  | otherwise = Nothing

-- | Return the first list element that satisfies a predicate, along with the
-- elements before and after.
focusMaybe :: (a -> Maybe b) -> [a] -> Maybe ([a], b, [a])
focusMaybe f = go []
  where
    -- Single pass over the list: @f@ is applied at most once per
    -- element, and the prefix is accumulated in reverse.
    go _ [] = Nothing
    go bef (x : xs) =
      case f x of
        Just y -> Just (reverse bef, y, xs)
        Nothing -> go (x : bef) xs

-- | Compute a hash of a text that is stable across OS versions.
-- Returns the hash as a text as well, ready for human consumption.
hashText :: T.Text -> T.Text
hashText =
  T.decodeUtf8With T.lenientDecode . Base16.encode . MD5.hash . T.encodeUtf8

-- | Like 'show', but produces text.
showText :: (Show a) => a -> T.Text
showText = T.pack . show

{-# NOINLINE unixEnvironment #-}

-- | The Unix environment when the Futhark compiler started.
unixEnvironment :: [(String, String)]
unixEnvironment = unsafePerformIO getEnvironment

-- | True if the environment variable, viewed as an integer, has at
-- least this numeric value.  Returns False if variable is unset or
-- not numeric.
isEnvVarAtLeast :: String -> Int -> Bool
isEnvVarAtLeast s x =
  case readMaybe =<< lookup s unixEnvironment of
    Just y -> y >= x
    _ -> False

{-# NOINLINE startupTime #-}

-- | The time at which the process started - or more accurately, the
-- first time this binding was forced.
startupTime :: UTCTime
startupTime = unsafePerformIO getCurrentTime

{-# NOINLINE fancyTerminal #-}

-- | Are we running in a terminal capable of fancy commands and
-- visualisation?
fancyTerminal :: Bool
fancyTerminal = unsafePerformIO $ hFancyTerminal stdout

-- | Is this handle connected to a terminal capable of fancy commands
-- and visualisation?
hFancyTerminal :: Handle -> IO Bool
hFancyTerminal h = do
  isTTY <- hIsTerminalDevice h
  isDumb <- (Just "dumb" ==) <$> lookupEnv "TERM"
  pure $ isTTY && not isDumb

-- | Like 'readProcessWithExitCode', but also wraps exceptions when
-- the indicated binary cannot be launched, or some other exception is
-- thrown.  Also does shenanigans to handle improperly encoded outputs.
runProgramWithExitCode ::
  FilePath ->
  [String] ->
  BS.ByteString ->
  IO (Either IOException (ExitCode, String, String))
runProgramWithExitCode exe args inp =
  (Right . postprocess <$> readProcessWithExitCode exe args inp)
    `catch` \e -> pure (Left e)
  where
    -- Invalid UTF-8 is replaced rather than causing a decoding error.
    decode = T.unpack . T.decodeUtf8With T.lenientDecode
    postprocess (code, out, err) =
      (code, decode out, decode err)

-- | Every non-directory file contained in a directory tree.
directoryContents :: FilePath -> IO [FilePath]
directoryContents dir = do
  _ Dir.:/ tree <- Dir.readDirectoryWith pure dir
  case Dir.failures tree of
    Dir.Failed _ err : _ -> throw err
    _ -> pure $ mapMaybe isFile $ Dir.flattenDir tree
  where
    isFile (Dir.File _ path) = Just path
    isFile _ = Nothing

-- | Turn a POSIX filepath into a filepath for the native system.
toPOSIX :: Native.FilePath -> Posix.FilePath
toPOSIX = Posix.joinPath . Native.splitDirectories

-- | Some bad operating systems do not use forward slash as
-- directory separator - this is where we convert Futhark includes
-- (which always use forward slash) to native paths.
fromPOSIX :: Posix.FilePath -> Native.FilePath
fromPOSIX = Native.joinPath . Posix.splitDirectories

-- | Remove leading and trailing whitespace from a string.  Not an
-- efficient implementation!
trim :: String -> String
trim = reverse . dropWhile isSpace . reverse . dropWhile isSpace

-- | Run various 'IO' actions concurrently, possibly with a bound on
-- the number of threads.  The list must be finite.  The ordering of
-- the result list is not deterministic - add your own sorting if
-- needed.  If any of the actions throw an exception, then that
-- exception is propagated to this function.  A thread bound below one
-- is treated as one.
pmapIO :: Maybe Int -> (a -> IO b) -> [a] -> IO [b]
pmapIO concurrency f elems = do
  tasks <- newMVar elems
  results <- newEmptyMVar
  -- With zero (or negative) workers nothing would ever fill 'results',
  -- and collecting a non-empty result list below would block forever.
  num_threads <- max 1 <$> maybe getNumCapabilities pure concurrency
  replicateM_ num_threads $ forkIO $ worker tasks results
  replicateM (length elems) $ getResult results
  where
    -- Each worker repeatedly takes a task from the shared list until it
    -- is empty, reporting either the result or the exception thrown.
    worker tasks results = do
      task <- modifyMVar tasks getTask
      case task of
        Nothing -> pure ()
        Just x -> do
          y <- (Right <$> f x) `catch` (pure . Left)
          putMVar results y
          worker tasks results

    getTask [] = pure ([], Nothing)
    getTask (task : tasks) = pure (tasks, Just task)

    getResult results = do
      res <- takeMVar results
      case res of
        Left err -> throw (err :: SomeException)
        Right v -> pure v

-- | Do some operation on a file, returning 'Nothing' if the file does
-- not exist, and 'Left' if some other error occurs.
interactWithFileSafely :: IO a -> IO (Maybe (Either String a))
interactWithFileSafely m =
  (Just . Right <$> m) `catch` couldNotRead
  where
    couldNotRead e
      | isDoesNotExistError e =
          pure Nothing
      | otherwise =
          pure $ Just $ Left $ show e

-- | Convert between different floating-point types, preserving
-- infinities, NaNs, and the sign of zero.
convFloat :: (RealFloat from, RealFloat to) => from -> to
convFloat v
  | isInfinite v, v > 0 = 1 / 0
  | isInfinite v, v < 0 = -1 / 0
  | isNaN v = 0 / 0
  -- 'toRational' maps -0.0 to 0, which would lose the sign.
  | isNegativeZero v = negate 0
  | otherwise = fromRational $ toRational v

-- Z-encoding from https://ghc.haskell.org/trac/ghc/wiki/Commentary/Compiler/SymbolNames
--
-- Slightly simplified as we do not need it to deal with tuples and
-- the like.
--
-- (c) The University of Glasgow, 1997-2006

-- | As the user typed it.
type UserString = String

-- | Encoded form.
type EncodedString = String

-- | As 'zEncodeText', but for strings.
zEncodeString :: UserString -> EncodedString
zEncodeString "" = ""
zEncodeString (c : cs) = encodeDigitChar c ++ concatMap encodeChar cs

-- | As the user typed it.
type UserText = T.Text

-- | Encoded form.
type EncodedText = T.Text

-- | Z-encode a text using a slightly simplified variant of GHC
-- Z-encoding.  The encoded string is a valid identifier in most
-- programming languages.
zEncodeText :: UserText -> EncodedText
zEncodeText = T.pack . zEncodeString . T.unpack

unencodedChar :: Char -> Bool -- True for chars that don't need encoding
unencodedChar 'Z' = False
unencodedChar 'z' = False
unencodedChar '_' = True
unencodedChar c =
  isAsciiLower c
    || isAsciiUpper c
    || isDigit c

-- If a digit is at the start of a symbol then we need to encode it.
-- Otherwise names like 9pH-0.1 give linker errors.
encodeDigitChar :: Char -> EncodedString
encodeDigitChar c
  | isDigit c = encodeAsUnicodeChar c
  | otherwise = encodeChar c

encodeChar :: Char -> EncodedString
encodeChar c | unencodedChar c = [c] -- Common case first

-- Constructors
encodeChar '(' = "ZL" -- Needed for things like (,), and (->)
encodeChar ')' = "ZR" -- For symmetry with (
encodeChar '[' = "ZM"
encodeChar ']' = "ZN"
encodeChar ':' = "ZC"
encodeChar 'Z' = "ZZ"
-- Variables
encodeChar 'z' = "zz"
encodeChar '&' = "za"
encodeChar '|' = "zb"
encodeChar '^' = "zc"
encodeChar '$' = "zd"
encodeChar '=' = "ze"
encodeChar '>' = "zg"
encodeChar '#' = "zh"
encodeChar '.' = "zi"
encodeChar '<' = "zl"
encodeChar '-' = "zm"
encodeChar '!' = "zn"
encodeChar '+' = "zp"
encodeChar '\'' = "zq"
encodeChar '\\' = "zr"
encodeChar '/' = "zs"
encodeChar '*' = "zt"
encodeChar '_' = "zu"
encodeChar '%' = "zv"
encodeChar c = encodeAsUnicodeChar c

-- Encode a character as 'z', its lowercase hex code point and a
-- trailing 'U'; a '0' is inserted when the hex digits would otherwise
-- start with a letter.
encodeAsUnicodeChar :: Char -> EncodedString
encodeAsUnicodeChar c =
  'z'
    : if maybe False isDigit $ maybeHead hex_str
      then hex_str
      else '0' : hex_str
  where
    hex_str = showHex (ord c) "U"

-- | Truncate to at most this many characters, making the last three
-- characters "..." if truncation is necessary.  If the limit is below
-- three, the result is the corresponding prefix of "...", so the
-- length bound always holds.
atMostChars :: Int -> T.Text -> T.Text
atMostChars n s
  | T.length s > n = T.take n $ T.take (n - 3) s <> "..."
  | otherwise = s

-- | Invert a map, handling duplicate values (now keys) by
-- constructing a set of corresponding values.
invertMap :: (Ord v, Ord k) => M.Map k v -> M.Map v (S.Set k)
invertMap m =
  foldr
    (uncurry (M.insertWith (<>)) . swap . first S.singleton)
    mempty
    (M.toList m)

-- | Compute the cartesian product of two foldable collections, using the given
-- combinator function.
cartesian :: (Monoid m, Foldable t) => (a -> a -> m) -> t a -> t a -> m
cartesian f xs ys =
  [(x, y) | x <- toList xs, y <- toList ys]
    & foldMap (uncurry f)

-- | Applicatively fold a traversable.
traverseFold :: (Monoid m, Traversable t, Applicative f) => (a -> f m) -> t a -> f m
traverseFold f = fmap fold . traverse f

-- | Perform fixpoint iteration.
fixPoint :: (Eq a) => (a -> a) -> a -> a
fixPoint f x =
  let x' = f x
   in if x' == x then x else fixPoint f x'

-- | Like 'concatMap', but monoidal and monadic.
concatMapM :: (Monad m, Monoid b) => (a -> m b) -> [a] -> m b
concatMapM f xs = mconcat <$> mapM f xs

-- | Topological sorting of an array with an adjacency function; if
-- there is a cycle, it causes an error.  @dep a b@ means @a -> b@,
-- and the returned array guarantees that for i < j:
--
-- @not ( dep (ret !! j) (ret !! i) )@.
topologicalSort :: (a -> a -> Bool) -> [a] -> [a]
topologicalSort dep nodes =
  fst $ execState (mapM_ (sorting . snd) nodes_idx) (mempty, mempty)
  where
    nodes_idx = zip nodes [0 ..]
    depends_of a (b, i) =
      if a `dep` b
        then Just i
        else Nothing

    -- Using an IntMap Bool
    -- when reading a lookup:
    -- \* Nothing : never explored
    -- \* Just True : being explored
    -- \* Just False : explored
    sorting i = do
      status <- gets $ IM.lookup i . snd
      when (status == Just True) $ error "topological sorting has encountered a cycle"
      unless (status == Just False) $ do
        let node = nodes !! i
        modify $ second $ IM.insert i True
        mapM_ sorting $ mapMaybe (depends_of node) nodes_idx
        -- A node is prepended only after everything it depends on.
        modify $ bimap (node :) (IM.insert i False)

-- | 'traceM', but only if @FUTHARK_COMPILER_DEBUGGING@ is set to
-- the appropriate level.
debugTraceM :: (Monad m) => Int -> String -> m ()
debugTraceM level
  | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" level = traceM
  | otherwise = const $ pure ()
futhark-0.25.27/src/Futhark/Util/000077500000000000000000000000001475065116200164525ustar00rootroot00000000000000futhark-0.25.27/src/Futhark/Util/CMath.hs000066400000000000000000000105551475065116200200100ustar00rootroot00000000000000-- | Bindings to the C math library.
--
-- Follows the naming scheme of the C functions when feasible.
module Futhark.Util.CMath
  ( roundFloat,
    ceilFloat,
    floorFloat,
    roundDouble,
    ceilDouble,
    floorDouble,
    nextafterf,
    nextafter,
    lgamma,
    lgammaf,
    tgamma,
    tgammaf,
    erf,
    erff,
    erfc,
    erfcf,
    cbrt,
    cbrtf,
    hypot,
    hypotf,
    ldexp,
    ldexpf,
    copysign,
    copysignf,
  )
where

import Foreign.C.Types (CInt (..))

foreign import ccall "nearbyint" c_nearbyint :: Double -> Double

foreign import ccall "nearbyintf" c_nearbyintf :: Float -> Float

foreign import ccall "ceil" c_ceil :: Double -> Double

foreign import ccall "ceilf" c_ceilf :: Float -> Float

foreign import ccall "floor" c_floor :: Double -> Double

foreign import ccall "floorf" c_floorf :: Float -> Float

-- | Round a single-precision floating point number correctly.
roundFloat :: Float -> Float
roundFloat = c_nearbyintf

-- | Round a single-precision floating point number upwards correctly.
ceilFloat :: Float -> Float
ceilFloat = c_ceilf

-- | Round a single-precision floating point number downwards correctly.
floorFloat :: Float -> Float
floorFloat = c_floorf

-- | Round a double-precision floating point number correctly.
roundDouble :: Double -> Double
roundDouble = c_nearbyint

-- | Round a double-precision floating point number upwards correctly.
ceilDouble :: Double -> Double ceilDouble = c_ceil -- | Round a double-precision floating point number downwards correctly. floorDouble :: Double -> Double floorDouble = c_floor foreign import ccall "nextafter" c_nextafter :: Double -> Double -> Double foreign import ccall "nextafterf" c_nextafterf :: Float -> Float -> Float -- | The next representable single-precision floating-point value in -- the given direction. nextafterf :: Float -> Float -> Float nextafterf = c_nextafterf -- | The next representable double-precision floating-point value in -- the given direction. nextafter :: Double -> Double -> Double nextafter = c_nextafter foreign import ccall "lgamma" c_lgamma :: Double -> Double foreign import ccall "lgammaf" c_lgammaf :: Float -> Float foreign import ccall "tgamma" c_tgamma :: Double -> Double foreign import ccall "tgammaf" c_tgammaf :: Float -> Float -- | The system-level @lgamma()@ function. lgamma :: Double -> Double lgamma = c_lgamma -- | The system-level @lgammaf()@ function. lgammaf :: Float -> Float lgammaf = c_lgammaf -- | The system-level @tgamma()@ function. tgamma :: Double -> Double tgamma = c_tgamma -- | The system-level @tgammaf()@ function. tgammaf :: Float -> Float tgammaf = c_tgammaf foreign import ccall "hypot" c_hypot :: Double -> Double -> Double foreign import ccall "hypotf" c_hypotf :: Float -> Float -> Float -- | The system-level @hypot@ function. hypot :: Double -> Double -> Double hypot = c_hypot -- | The system-level @hypotf@ function. hypotf :: Float -> Float -> Float hypotf = c_hypotf foreign import ccall "erf" c_erf :: Double -> Double foreign import ccall "erff" c_erff :: Float -> Float foreign import ccall "erfc" c_erfc :: Double -> Double foreign import ccall "erfcf" c_erfcf :: Float -> Float -- | The system-level @erf()@ function. erf :: Double -> Double erf = c_erf -- | The system-level @erff()@ function. erff :: Float -> Float erff = c_erff -- | The system-level @erfc()@ function. 
erfc :: Double -> Double erfc = c_erfc -- | The system-level @erfcf()@ function. erfcf :: Float -> Float erfcf = c_erfcf foreign import ccall "cbrt" c_cbrt :: Double -> Double foreign import ccall "cbrtf" c_cbrtf :: Float -> Float -- | The system-level @cbrt@ function. cbrt :: Double -> Double cbrt = c_cbrt -- | The system-level @cbrtf@ function. cbrtf :: Float -> Float cbrtf = c_cbrtf foreign import ccall "ldexp" c_ldexp :: Double -> CInt -> Double foreign import ccall "ldexpf" c_ldexpf :: Float -> CInt -> Float -- | The system-level @ldexp@ function. ldexp :: Double -> CInt -> Double ldexp = c_ldexp -- | The system-level @ldexpf@ function. ldexpf :: Float -> CInt -> Float ldexpf = c_ldexpf foreign import ccall "copysign" c_copysign :: Double -> Double -> Double foreign import ccall "copysignf" c_copysignf :: Float -> Float -> Float -- | The system-level @copysign@ function. copysign :: Double -> Double -> Double copysign = c_copysign -- | The system-level @copysignf@ function. copysignf :: Float -> Float -> Float copysignf = c_copysignf futhark-0.25.27/src/Futhark/Util/IntegralExp.hs000066400000000000000000000046401475065116200212340ustar00rootroot00000000000000-- | It is occasionally useful to define generic functions that can -- not only compute their result as an integer, but also as a symbolic -- expression in the form of an AST. -- -- There are some Haskell hacks for this - it is for example not hard -- to define an instance of 'Num' that constructs an AST. However, -- this falls down for some other interesting classes, like -- 'Integral', which requires both the problematic method -- 'fromInteger', and also that the type is an instance of 'Enum'. -- -- We can always just define hobbled instances that call 'error' for -- those methods that are impractical, but this is ugly. -- -- Hence, this module defines similes to standard Haskell numeric -- typeclasses that have been modified to make generic functions -- slightly easier to write. 
module Futhark.Util.IntegralExp ( IntegralExp (..), Wrapped (..), ) where import Data.Int import Prelude -- | A twist on the 'Integral' type class that is more friendly to -- symbolic representations. class (Num e) => IntegralExp e where quot :: e -> e -> e rem :: e -> e -> e div :: e -> e -> e mod :: e -> e -> e sgn :: e -> Maybe Int pow :: e -> e -> e -- | Like 'Futhark.Util.IntegralExp.div', but rounds towards -- positive infinity. divUp :: e -> e -> e divUp x y = (x + y - 1) `Futhark.Util.IntegralExp.div` y nextMul :: e -> e -> e nextMul x y = x `divUp` y * y -- | This wrapper allows you to use a type that is an instance of the -- true class whenever the simile class is required. newtype Wrapped a = Wrapped {wrappedValue :: a} deriving (Eq, Ord, Show) instance (Enum a) => Enum (Wrapped a) where toEnum a = Wrapped $ toEnum a fromEnum (Wrapped a) = fromEnum a liftOp :: (a -> a) -> Wrapped a -> Wrapped a liftOp op (Wrapped x) = Wrapped $ op x liftOp2 :: (a -> a -> a) -> Wrapped a -> Wrapped a -> Wrapped a liftOp2 op (Wrapped x) (Wrapped y) = Wrapped $ x `op` y instance (Num a) => Num (Wrapped a) where (+) = liftOp2 (Prelude.+) (-) = liftOp2 (Prelude.-) (*) = liftOp2 (Prelude.*) abs = liftOp Prelude.abs signum = liftOp Prelude.signum fromInteger = Wrapped . Prelude.fromInteger negate = liftOp Prelude.negate instance (Integral a) => IntegralExp (Wrapped a) where quot = liftOp2 Prelude.quot rem = liftOp2 Prelude.rem div = liftOp2 Prelude.div mod = liftOp2 Prelude.mod sgn = Just . fromIntegral . signum . toInteger . wrappedValue pow = liftOp2 (Prelude.^) futhark-0.25.27/src/Futhark/Util/Loc.hs000066400000000000000000000002001475065116200175130ustar00rootroot00000000000000-- | A Safe Haskell-trusted re-export of the @srcloc@ package. 
module Futhark.Util.Loc (module Data.Loc) where import Data.Loc futhark-0.25.27/src/Futhark/Util/Log.hs000066400000000000000000000031571475065116200175350ustar00rootroot00000000000000-- | Opaque type for an operations log that provides fast O(1) -- appends. module Futhark.Util.Log ( Log, toText, ToLog (..), MonadLogger (..), ) where import Control.Monad.RWS.Lazy qualified import Control.Monad.RWS.Strict qualified import Control.Monad.Writer import Data.DList qualified as DL import Data.Text qualified as T import Data.Text.IO qualified as T import System.IO (stderr) -- | An efficiently catenable sequence of log entries. newtype Log = Log {unLog :: DL.DList T.Text} instance Semigroup Log where Log l1 <> Log l2 = Log $ l1 <> l2 instance Monoid Log where mempty = Log mempty -- | Transform a log into pretty. Every log entry becomes its own line -- (or possibly more, in case of multi-line entries). toText :: Log -> T.Text toText = T.intercalate "\n" . DL.toList . unLog -- | Typeclass for things that can be turned into a single-entry log. class ToLog a where toLog :: a -> Log instance ToLog String where toLog = Log . DL.singleton . T.pack instance ToLog T.Text where toLog = Log . DL.singleton -- | Typeclass for monads that support logging. class (Applicative m, Monad m) => MonadLogger m where -- | Add one log entry. logMsg :: (ToLog a) => a -> m () logMsg = addLog . toLog -- | Append an entire log. addLog :: Log -> m () instance (Monad m) => MonadLogger (WriterT Log m) where addLog = tell instance (Monad m) => MonadLogger (Control.Monad.RWS.Lazy.RWST r Log s m) where addLog = tell instance (Monad m) => MonadLogger (Control.Monad.RWS.Strict.RWST r Log s m) where addLog = tell instance MonadLogger IO where addLog = mapM_ (T.hPutStrLn stderr) . unLog futhark-0.25.27/src/Futhark/Util/Options.hs000066400000000000000000000072041475065116200204440ustar00rootroot00000000000000-- | Common code for parsing command line options based on getopt. 
module Futhark.Util.Options ( FunOptDescr, mainWithOptions, commonOptions, optionsError, module System.Console.GetOpt, ) where import Control.Monad.IO.Class import Data.List (sortBy) import Data.Text.IO qualified as T import Futhark.Version import System.Console.GetOpt import System.Environment (getProgName) import System.Exit import System.IO -- | A command line option that either purely updates a configuration, -- or performs an IO action (and stops). type FunOptDescr cfg = OptDescr (Either (IO ()) (cfg -> cfg)) -- | Generate a main action that parses the given command line options -- (while always adding 'commonOptions'). mainWithOptions :: cfg -> [FunOptDescr cfg] -> String -> ([String] -> cfg -> Maybe (IO ())) -> String -> [String] -> IO () mainWithOptions emptyConfig commandLineOptions usage f prog args = case getOpt' Permute commandLineOptions' args of (opts, nonopts, [], []) -> case applyOpts opts of Right config | Just m <- f nonopts config -> m | otherwise -> invalid nonopts [] [] Left m -> m (_, nonopts, unrecs, errs) -> invalid nonopts unrecs errs where applyOpts opts = do fs <- sequence opts pure $ foldl (.) id (reverse fs) emptyConfig invalid nonopts unrecs errs = do help <- helpStr prog usage commandLineOptions' badOptions help nonopts errs unrecs commandLineOptions' = commonOptions prog usage commandLineOptions ++ commandLineOptions helpStr :: String -> String -> [OptDescr a] -> IO String helpStr prog usage opts = do let header = unlines ["Usage: " ++ prog ++ " " ++ usage, "Options:"] pure $ usageInfo header $ sortBy cmp opts where -- Sort first by long option, then by short name, then by description. Hopefully -- everything has a long option. cmp (Option _ (a : _) _ _) (Option _ (b : _) _ _) = compare a b cmp (Option (a : _) _ _ _) (Option (b : _) _ _ _) = compare a b cmp (Option _ _ _ a) (Option _ _ _ b) = compare a b badOptions :: String -> [String] -> [String] -> [String] -> IO () badOptions usage nonopts errs unrecs = do mapM_ (errput . 
("Junk argument: " ++)) nonopts mapM_ (errput . ("Unrecognised argument: " ++)) unrecs hPutStr stderr $ concat errs ++ usage exitWith $ ExitFailure 1 -- | Short-hand for 'liftIO . hPutStrLn stderr' errput :: (MonadIO m) => String -> m () errput = liftIO . hPutStrLn stderr -- | Common definitions for @-v@ and @-h@, given the list of all other -- options. commonOptions :: String -> String -> [FunOptDescr cfg] -> [FunOptDescr cfg] commonOptions prog usage options = [ Option "V" ["version"] ( NoArg $ Left $ do header exitSuccess ) "Print version information and exit.", Option "h" ["help"] ( NoArg $ Left $ do header putStrLn "" putStrLn =<< helpStr prog usage (commonOptions prog usage [] ++ options) exitSuccess ) "Print help and exit." ] where header = do T.putStrLn $ "Futhark " <> versionString T.putStrLn "Copyright (C) DIKU, University of Copenhagen, released under the ISC license." T.putStrLn "This is free software: you are free to change and redistribute it." T.putStrLn "There is NO WARRANTY, to the extent permitted by law." -- | Terminate the program with this error message (but don't report -- it as an ICE, as happens with 'error'). optionsError :: String -> IO () optionsError s = do prog <- getProgName hPutStrLn stderr $ prog <> ": " <> s exitWith $ ExitFailure 2 futhark-0.25.27/src/Futhark/Util/Pretty.hs000066400000000000000000000140361475065116200203010ustar00rootroot00000000000000{-# OPTIONS_GHC -fno-warn-orphans #-} -- | A re-export of the prettyprinting library, along with some -- convenience functions. 
module Futhark.Util.Pretty ( -- * Rendering to texts prettyTuple, prettyTupleLines, prettyString, prettyStringOneLine, prettyText, prettyTextOneLine, docText, docTextForHandle, docString, -- * Rendering to terminal putDoc, hPutDoc, putDocLn, hPutDocLn, -- * Building blocks module Prettyprinter, module Prettyprinter.Symbols.Ascii, module Prettyprinter.Render.Terminal, apply, oneLine, annot, nestedBlock, textwrap, shorten, commastack, commasep, semistack, stack, parensIf, ppTuple', ppTupleLines', -- * Operators (), ) where import Data.Text (Text) import Data.Text qualified as T import Numeric.Half import Prettyprinter import Prettyprinter.Render.Terminal (AnsiStyle, Color (..), bgColor, bgColorDull, bold, color, colorDull, italicized, underlined) import Prettyprinter.Render.Terminal qualified import Prettyprinter.Render.Text qualified import Prettyprinter.Symbols.Ascii import System.IO (Handle, hIsTerminalDevice, hPutStrLn, stdout) -- | Print a doc with styling to the given file; stripping colors if -- the file does not seem to support such things. hPutDoc :: Handle -> Doc AnsiStyle -> IO () hPutDoc h d = do colours <- hIsTerminalDevice h if colours then Prettyprinter.Render.Terminal.renderIO h (layouter d) else Prettyprinter.Render.Text.renderIO h (layouter d) where layouter = removeTrailingWhitespace . layoutSmart defaultLayoutOptions {layoutPageWidth = Unbounded} -- | Like 'hPutDoc', but with a final newline. hPutDocLn :: Handle -> Doc AnsiStyle -> IO () hPutDocLn h d = do hPutDoc h d hPutStrLn h "" -- | Like 'hPutDoc', but to stdout. putDoc :: Doc AnsiStyle -> IO () putDoc = hPutDoc stdout -- | Like 'putDoc', but with a final newline. putDocLn :: Doc AnsiStyle -> IO () putDocLn d = do putDoc d putStrLn "" -- | Produce text suitable for printing on the given handle. This -- mostly means stripping any control characters if the handle is not -- a terminal. 
docTextForHandle :: Handle -> Doc AnsiStyle -> IO T.Text docTextForHandle h d = do colours <- hIsTerminalDevice h let sds = removeTrailingWhitespace $ layoutSmart defaultLayoutOptions d pure $ if colours then Prettyprinter.Render.Terminal.renderStrict sds else Prettyprinter.Render.Text.renderStrict sds -- | Prettyprint a value to a 'String', appropriately wrapped. prettyString :: (Pretty a) => a -> String prettyString = T.unpack . prettyText -- | Prettyprint a value to a 'String' on a single line. prettyStringOneLine :: (Pretty a) => a -> String prettyStringOneLine = T.unpack . prettyTextOneLine -- | Prettyprint a value to a 'Text', appropriately wrapped. prettyText :: (Pretty a) => a -> Text prettyText = docText . pretty -- | Convert a 'Doc' to text. This ignores any annotations (i.e. it -- will be non-coloured output). docText :: Doc a -> T.Text docText = Prettyprinter.Render.Text.renderStrict . layouter where layouter = removeTrailingWhitespace . layoutSmart defaultLayoutOptions {layoutPageWidth = Unbounded} -- | Convert a 'Doc' to a 'String', through 'docText'. Intended for -- debugging. docString :: Doc a -> String docString = T.unpack . docText -- | Prettyprint a value to a 'Text' on a single line. prettyTextOneLine :: (Pretty a) => a -> Text prettyTextOneLine = Prettyprinter.Render.Text.renderStrict . layoutSmart oneLineLayout . group . pretty where oneLineLayout = defaultLayoutOptions {layoutPageWidth = Unbounded} ppTuple' :: [Doc a] -> Doc a ppTuple' ets = braces $ commasep $ map align ets ppTupleLines' :: [Doc a] -> Doc a ppTupleLines' ets = braces $ commastack $ map align ets -- | Prettyprint a list enclosed in curly braces. prettyTuple :: (Pretty a) => [a] -> Text prettyTuple = docText . ppTuple' . map pretty -- | Like 'prettyTuple', but put a linebreak after every element. prettyTupleLines :: (Pretty a) => [a] -> Text prettyTupleLines = docText . ppTupleLines' . 
map pretty -- | The document @'apply' ds@ separates @ds@ with commas and encloses them with -- parentheses. apply :: [Doc a] -> Doc a apply = parens . align . commasep . map align -- | Make sure that the given document is printed on just a single line. oneLine :: Doc a -> Doc a oneLine = group -- | Splits the string into words and permits line breaks between all -- of them. textwrap :: T.Text -> Doc a textwrap = fillSep . map pretty . T.words -- | Stack and prepend a list of 'Doc's to another 'Doc', separated by -- a linebreak. If the list is empty, the second 'Doc' will be -- returned without a preceding linebreak. annot :: [Doc a] -> Doc a -> Doc a annot [] s = s annot l s = vsep (l ++ [s]) -- | Surround the given document with enclosers and add linebreaks and -- indents. nestedBlock :: Doc a -> Doc a -> Doc a -> Doc a nestedBlock pre post body = vsep [pre, indent 2 body, post] -- | Prettyprint on a single line up to at most some appropriate -- number of characters, with trailing ... if necessary. Used for -- error messages. shorten :: Doc a -> Doc b shorten a | T.length s > 70 = pretty (T.take 70 s) <> "..." | otherwise = pretty s where s = Prettyprinter.Render.Text.renderStrict $ layoutCompact a -- | Like 'commasep', but a newline after every comma. commastack :: [Doc a] -> Doc a commastack = align . vsep . punctuate comma -- | Separate with semicolons and newlines. semistack :: [Doc a] -> Doc a semistack = align . vsep . punctuate semi -- | Separate with commas. commasep :: [Doc a] -> Doc a commasep = hsep . punctuate comma -- | Separate with linebreaks. stack :: [Doc a] -> Doc a stack = align . mconcat . punctuate line -- | The document @'parensIf' p d@ encloses the document @d@ in parenthesis if -- @p@ is @True@, and otherwise yields just @d@. 
parensIf :: Bool -> Doc a -> Doc a parensIf True doc = parens doc parensIf False doc = doc instance Pretty Half where pretty = viaShow () :: Doc a -> Doc a -> Doc a a b = a <> line <> b futhark-0.25.27/src/Futhark/Util/ProgressBar.hs000066400000000000000000000034431475065116200212430ustar00rootroot00000000000000-- | Facilities for generating and otherwise handling pretty-based progress bars. module Futhark.Util.ProgressBar ( progressBar, ProgressBar (..), progressSpinner, ) where import Data.Text qualified as T -- | Information about a progress bar to render. The "progress space" -- spans from 0 and up to the `progressBarBound`, but can be -- visualised in any number of steps. data ProgressBar = ProgressBar { -- | Number of steps in the visualisation. progressBarSteps :: Int, -- | The logical upper bound. progressBarBound :: Double, -- | The current position in the progress bar, relative to the -- upper bound. progressBarElapsed :: Double } -- | Render the progress bar. progressBar :: ProgressBar -> T.Text progressBar (ProgressBar steps bound elapsed) = "|" <> T.pack (map cell [1 .. steps]) <> "| " where step_size :: Double step_size = bound / fromIntegral steps chars = " ▏▎▍▍▌▋▊▉█" num_chars = T.length chars char i | i >= 0 && i < num_chars = T.index chars i | otherwise = ' ' cell :: Int -> Char cell i | i' * step_size <= elapsed = char 9 | otherwise = char (floor (((elapsed - (i' - 1) * step_size) * fromIntegral num_chars) / step_size)) where i' = fromIntegral i -- | Render a spinner - a kind of progress bar where there is no upper -- bound because we don't know how long it'll take. You certainly -- know these from THE INTERNET. The non-negative integer is how many -- "steps" have been taken. The spinner looks best if this is -- incremented by one for every call. 
progressSpinner :: Int -> T.Text progressSpinner spin_idx = T.singleton $ T.index spin_load (spin_idx `rem` n) where spin_load = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏" n = T.length spin_load futhark-0.25.27/src/Futhark/Util/Table.hs000066400000000000000000000041501475065116200200350ustar00rootroot00000000000000-- | Basic table building for prettier futhark-test output. module Futhark.Util.Table ( hPutTable, mkEntry, Entry, AnsiStyle, Color (..), color, ) where import Data.List (intersperse, transpose) import Futhark.Util (maxinum) import Futhark.Util.Pretty hiding (sep, width) import System.IO (Handle) data RowTemplate = RowTemplate [Int] Int deriving (Show) -- | A table entry. Consists of the content as well as how it should -- be styled.. data Entry = Entry {entryText :: String, _entryStyle :: AnsiStyle} -- | Makes a table entry. mkEntry :: String -> AnsiStyle -> Entry mkEntry = Entry buildRowTemplate :: [[Entry]] -> Int -> RowTemplate buildRowTemplate rows = RowTemplate widths where widths = map (maxinum . map (length . entryText)) . transpose $ rows buildRow :: RowTemplate -> [Entry] -> Doc AnsiStyle buildRow (RowTemplate widths pad) entries = cells <> hardline where bar = "\x2502" cells = mconcat (zipWith buildCell entries widths) <> bar buildCell (Entry entry sgr) width = let padding = width - length entry + pad in bar <> " " <> annotate sgr (pretty entry) <> mconcat (replicate padding " ") buildSep :: Char -> Char -> Char -> RowTemplate -> Doc AnsiStyle buildSep lCorner rCorner sep (RowTemplate widths pad) = corners . concatMap cellFloor $ widths where cellFloor width = replicate (width + pad + 1) '\x2500' <> [sep] corners [] = "" corners s = pretty lCorner <> pretty (init s) <> pretty rCorner -- | Produce a table from a list of entries and a padding amount that -- determines padding from the right side of the widest entry in each column. 
hPutTable :: Handle -> [[Entry]] -> Int -> IO () hPutTable h rows pad = hPutDoc h $ buildTop template <> sepRows <> buildBottom template <> hardline where sepRows = mconcat $ intersperse (buildFloor template) builtRows builtRows = map (buildRow template) rows template = buildRowTemplate rows pad buildTop rt = buildSep '\x250C' '\x2510' '\x252C' rt <> hardline buildFloor rt = buildSep '\x251C' '\x2524' '\x253C' rt <> hardline buildBottom = buildSep '\x2514' '\x2518' '\x2534' futhark-0.25.27/src/Futhark/Version.hs000066400000000000000000000036041475065116200175210ustar00rootroot00000000000000{-# LANGUAGE CPP #-} {-# LANGUAGE TemplateHaskell #-} -- | This module exports version information about the Futhark -- compiler. module Futhark.Version ( version, versionString, ) where import Data.ByteString.Char8 qualified as BS import Data.FileEmbed import Data.Text qualified as T import Data.Version import Futhark.Util (showText, trim) import GitHash import Paths_futhark qualified {-# NOINLINE version #-} -- | The version of Futhark that we are using. This is equivalent to -- the version defined in the .cabal file. version :: Version version = Paths_futhark.version {-# NOINLINE versionString #-} -- | The version of Futhark that we are using, in human-readable form. versionString :: T.Text versionString = T.pack (showVersion version) <> unreleased <> ".\n" <> gitversion $$tGitInfoCwdTry <> ghcversion where unreleased = if last (versionBranch version) == 0 then " (prerelease - include info below when reporting bugs)" else mempty gitversion (Left _) = case commitIdFromFile of Nothing -> "" Just commit -> "git: " <> T.pack commit <> "\n" gitversion (Right gi) = mconcat [ "git: ", branch, T.pack (take 7 $ giHash gi), " (", T.pack (giCommitDate gi), ")", dirty, "\n" ] where branch | giBranch gi == "master" = "" | otherwise = T.pack (giBranch gi) <> " @ " dirty = if giDirty gi then " [modified]" else "" ghcversion = "Compiled with GHC " <> showText a <> "." <> showText b <> "." 
<> showText c <> ".\n" where a, b, c :: Int a = __GLASGOW_HASKELL__ `div` 100 b = __GLASGOW_HASKELL__ `mod` 100 c = __GLASGOW_HASKELL_PATCHLEVEL1__ commitIdFromFile :: Maybe String commitIdFromFile = trim . BS.unpack <$> $(embedFileIfExists "./commit-id") futhark-0.25.27/src/Language/000077500000000000000000000000001475065116200156545ustar00rootroot00000000000000futhark-0.25.27/src/Language/Futhark.hs000066400000000000000000000005541475065116200176200ustar00rootroot00000000000000-- | Re-export the external Futhark modules for convenience. module Language.Futhark ( module Language.Futhark.Syntax, module Language.Futhark.Prop, module Language.Futhark.FreeVars, module Language.Futhark.Pretty, ) where import Language.Futhark.FreeVars import Language.Futhark.Pretty import Language.Futhark.Prop import Language.Futhark.Syntax futhark-0.25.27/src/Language/Futhark/000077500000000000000000000000001475065116200172605ustar00rootroot00000000000000futhark-0.25.27/src/Language/Futhark/Core.hs000066400000000000000000000141141475065116200205050ustar00rootroot00000000000000-- | This module contains very basic definitions for Futhark - so basic, -- that they can be shared between the internal and external -- representation. 
module Language.Futhark.Core ( Uniqueness (..), NoUniqueness (..), -- * Location utilities SrcLoc, Loc, Located (..), noLoc, L (..), unLoc, srclocOf, locStr, locStrRel, locText, locTextRel, prettyStacktrace, -- * Name handling Name, nameToString, nameFromString, nameToText, nameFromText, VName (..), baseTag, baseName, baseString, baseText, quote, -- * Number re-export Int8, Int16, Int32, Int64, Word8, Word16, Word32, Word64, Half, ) where import Control.Category import Data.Int (Int16, Int32, Int64, Int8) import Data.String import Data.Text qualified as T import Data.Word (Word16, Word32, Word64, Word8) import Futhark.Util (showText) import Futhark.Util.Loc import Futhark.Util.Pretty import Numeric.Half import Prelude hiding (id, (.)) -- | The uniqueness attribute of a type. This essentially indicates -- whether or not in-place modifications are acceptable. With respect -- to ordering, 'Unique' is greater than 'Nonunique'. data Uniqueness = -- | May have references outside current function. Nonunique | -- | No references outside current function. Unique deriving (Eq, Ord, Show) instance Semigroup Uniqueness where (<>) = min instance Monoid Uniqueness where mempty = Unique instance Pretty Uniqueness where pretty Unique = "*" pretty Nonunique = mempty -- | A fancier name for @()@ - encodes no uniqueness information. -- Also has a different prettyprinting instance. data NoUniqueness = NoUniqueness deriving (Eq, Ord, Show) instance Semigroup NoUniqueness where NoUniqueness <> NoUniqueness = NoUniqueness instance Monoid NoUniqueness where mempty = NoUniqueness instance Pretty NoUniqueness where pretty _ = mempty -- | The abstract (not really) type representing names in the Futhark -- compiler. 'String's, being lists of characters, are very slow, -- while 'T.Text's are based on byte-arrays. newtype Name = Name T.Text deriving (Show, Eq, Ord, IsString, Semigroup) instance Pretty Name where pretty = pretty . 
nameToString -- | Convert a name to the corresponding list of characters. nameToString :: Name -> String nameToString (Name t) = T.unpack t -- | Convert a list of characters to the corresponding name. nameFromString :: String -> Name nameFromString = Name . T.pack -- | Convert a name to the corresponding 'T.Text'. nameToText :: Name -> T.Text nameToText (Name t) = t -- | Convert a 'T.Text' to the corresponding name. nameFromText :: T.Text -> Name nameFromText = Name -- | A human-readable location string, of the form -- @filename:lineno:columnno@. This follows the GNU coding standards -- for error messages: -- https://www.gnu.org/prep/standards/html_node/Errors.html -- -- This function assumes that both start and end position is in the -- same file (it is not clear what the alternative would even mean). locStr :: (Located a) => a -> String locStr a = case locOf a of NoLoc -> "unknown location" Loc (Pos file line1 col1 _) (Pos _ line2 col2 _) -- Do not show line2 if it is identical to line1. | line1 == line2 -> first_part ++ "-" ++ show col2 | otherwise -> first_part ++ "-" ++ show line2 ++ ":" ++ show col2 where first_part = file ++ ":" ++ show line1 ++ ":" ++ show col1 -- | Like 'locStr', but @locStrRel prev now@ prints the location @now@ -- with the file name left out if the same as @prev@. This is useful -- when printing messages that are all in the context of some -- initially printed location (e.g. the first mention contains the -- file name; the rest just line and column name). locStrRel :: (Located a, Located b) => a -> b -> String locStrRel a b = case (locOf a, locOf b) of (Loc (Pos a_file _ _ _) _, Loc (Pos b_file line1 col1 _) (Pos _ line2 col2 _)) | a_file == b_file, line1 == line2 -> first_part ++ "-" ++ show col2 | a_file == b_file -> first_part ++ "-" ++ show line2 ++ ":" ++ show col2 where first_part = show line1 ++ ":" ++ show col1 _ -> locStr b -- | 'locStr', but for text. locText :: (Located a) => a -> T.Text locText = T.pack . 
locStr -- | 'locStrRel', but for text. locTextRel :: (Located a, Located b) => a -> b -> T.Text locTextRel a b = T.pack $ locStrRel a b -- | Given a list of strings representing entries in the stack trace -- and the index of the frame to highlight, produce a final -- newline-terminated string for showing to the user. This string -- should also be preceded by a newline. The most recent stack frame -- must come first in the list. prettyStacktrace :: Int -> [T.Text] -> T.Text prettyStacktrace cur = T.unlines . zipWith f [(0 :: Int) ..] where -- Formatting hack: assume no stack is deeper than 100 -- elements. Since Futhark does not support recursion, going -- beyond that would require a truly perverse program. f i x = (if cur == i then "-> " else " ") <> "#" <> showText i <> (if i > 9 then "" else " ") <> " " <> x -- | A name tagged with some integer. Only the integer is used in -- comparisons, no matter the type of @vn@. data VName = VName !Name !Int deriving (Show) -- | Return the tag contained in the 'VName'. baseTag :: VName -> Int baseTag (VName _ tag) = tag -- | Return the name contained in the 'VName'. baseName :: VName -> Name baseName (VName vn _) = vn -- | Return the base 'Name' converted to a string. baseString :: VName -> String baseString = nameToString . baseName -- | Return the base 'Name' converted to a text. baseText :: VName -> T.Text baseText = nameToText . baseName instance Eq VName where VName _ x == VName _ y = x == y instance Ord VName where VName _ x `compare` VName _ y = x `compare` y -- | Enclose a string in the prefered quotes used in error messages. -- These are picked to not collide with characters permitted in -- identifiers. quote :: T.Text -> T.Text quote s = "\"" <> s <> "\"" futhark-0.25.27/src/Language/Futhark/FreeVars.hs000066400000000000000000000120141475065116200213270ustar00rootroot00000000000000-- | Facilities for computing free term variables in various syntactic -- constructs. 
module Language.Futhark.FreeVars
  ( freeInExp,
    freeInPat,
    freeInType,
    freeWithout,
    FV,
    fvVars,
  )
where

import Data.Set qualified as S
import Language.Futhark.Prop
import Language.Futhark.Syntax

-- | A set of names.  Sets are combined by union through the
-- 'Semigroup' and 'Monoid' instances below.
newtype FV = FV {unFV :: S.Set VName}
  deriving (Show)

-- | The set of names in an 'FV'.
fvVars :: FV -> S.Set VName
fvVars = unFV

-- Combining two free-variable sets takes their union.
instance Semigroup FV where
  FV x <> FV y = FV $ x <> y

-- The empty free-variable set.
instance Monoid FV where
  mempty = FV mempty

-- | Set subtraction.  Do not consider those variables as free.
freeWithout :: FV -> S.Set VName -> FV
freeWithout (FV x) y = FV $ x `S.difference` y

-- | As 'freeWithout', but for lists.
freeWithoutL :: FV -> [VName] -> FV
freeWithoutL fv y = fv `freeWithout` S.fromList y

-- The free-variable set containing only the name of the given identifier.
ident :: Ident t -> FV
ident = FV . S.singleton . identName

-- | Compute the set of free variables of an expression.
--
-- For binding constructs, the names a construct binds are subtracted
-- from the free variables of the subterms in their scope, while
-- subterms outside that scope contribute all of their variables.
-- Sizes mentioned in type annotations carried by the expression
-- ('ArrayLit', 'Coerce', 'Lambda') are included via 'freeInType'.
freeInExp :: ExpBase Info VName -> FV
freeInExp expr = case expr of
  Literal {} -> mempty
  IntLit {} -> mempty
  FloatLit {} -> mempty
  StringLit {} -> mempty
  Hole {} -> mempty
  Parens e _ -> freeInExp e
  QualParens _ e _ -> freeInExp e
  TupLit es _ -> foldMap freeInExp es
  RecordLit fs _ ->
    foldMap freeInExpField fs
    where
      freeInExpField (RecordFieldExplicit _ e _) = freeInExp e
      -- An implicit field @{x}@ refers to the variable @x@ itself.
      freeInExpField (RecordFieldImplicit (L _ vn) t _) = ident $ Ident vn t mempty
  ArrayVal {} -> mempty
  ArrayLit es t _ ->
    foldMap freeInExp es <> freeInType (unInfo t)
  AppExp (Range e me incl _) _ ->
    freeInExp e <> foldMap freeInExp me <> foldMap freeInExp incl
  Var qn _ _ -> FV $ S.singleton $ qualLeaf qn
  Ascript e _ _ -> freeInExp e
  Coerce e _ (Info t) _ ->
    freeInExp e <> freeInType t
  -- The bound expression is outside the pattern's scope; the pattern
  -- (through its types) and the body are inside it, so the pattern
  -- names and the let-bound sizes are removed from them.
  AppExp (LetPat let_sizes pat e1 e2 _) _ ->
    freeInExp e1
      <> ( (freeInPat pat <> freeInExp e2)
             `freeWithoutL` (patNames pat <> map sizeName let_sizes)
         )
  -- Parameters and type parameters are bound in the function body;
  -- the function name itself is bound in the rest of the expression.
  AppExp (LetFun vn (tparams, pats, _, _, e1) e2 _) _ ->
    ( (freeInExp e1 <> foldMap freeInPat pats)
        `freeWithoutL` (foldMap patNames pats <> map typeParamName tparams)
    )
      <> (freeInExp e2 `freeWithout` S.singleton vn)
  AppExp (If e1 e2 e3 _) _ -> freeInExp e1 <> freeInExp e2 <> freeInExp e3
  AppExp (Apply f args _) _ -> freeInExp f <> foldMap (freeInExp . snd) args
  Negate e _ -> freeInExp e
  Not e _ -> freeInExp e
  -- Parameter names and the return type's existential sizes are bound.
  Lambda pats e0 _ (Info (RetType dims t)) _ ->
    (foldMap freeInPat pats <> freeInExp e0 <> freeInType t)
      `freeWithoutL` (foldMap patNames pats <> dims)
  OpSection {} -> mempty
  OpSectionLeft _ _ e _ _ _ -> freeInExp e
  OpSectionRight _ _ e _ _ _ -> freeInExp e
  ProjectSection {} -> mempty
  IndexSection idxs _ _ -> foldMap freeInDimIndex idxs
  -- The initial value is outside the loop's scope.  The loop form and
  -- body have the size parameters, the loop parameters and the names
  -- bound by the form (index variable or element pattern) removed.
  -- NOTE(review): for 'For' and 'ForIn' the bound/array expression is
  -- filtered by these names too, although it presumably lies outside
  -- their scope; harmless only if bound names are always unique - confirm.
  AppExp (Loop sparams pat e1 form e3 _) _ ->
    let (e2fv, e2ident) = formVars form
     in freeInExp (loopInitExp e1)
          <> ( (e2fv <> freeInExp e3)
                 `freeWithoutL` (sparams <> patNames pat <> e2ident)
             )
    where
      formVars (For v e2) = (freeInExp e2, [identName v])
      formVars (ForIn p e2) = (freeInExp e2, patNames p)
      formVars (While e2) = (freeInExp e2, mempty)
  -- The operator's own name counts as a free variable.
  AppExp (BinOp (qn, _) _ (e1, _) (e2, _) _) _ ->
    FV (S.singleton (qualLeaf qn))
      <> freeInExp e1
      <> freeInExp e2
  Project _ e _ _ -> freeInExp e
  -- The source array is used; the destination name is bound only in
  -- the body.
  AppExp (LetWith id1 id2 idxs e1 e2 _) _ ->
    ident id2
      <> foldMap freeInDimIndex idxs
      <> freeInExp e1
      <> (freeInExp e2 `freeWithout` S.singleton (identName id1))
  AppExp (Index e idxs _) _ -> freeInExp e <> foldMap freeInDimIndex idxs
  Update e1 idxs e2 _ -> freeInExp e1 <> foldMap freeInDimIndex idxs <> freeInExp e2
  RecordUpdate e1 _ e2 _ _ -> freeInExp e1 <> freeInExp e2
  Assert e1 e2 _ _ -> freeInExp e1 <> freeInExp e2
  Constr _ es _ _ -> foldMap freeInExp es
  Attr _ e _ -> freeInExp e
  -- Each case's pattern names are bound in that case's body.
  AppExp (Match e cs _) _ -> freeInExp e <> foldMap caseFV cs
    where
      caseFV (CasePat p eCase _) =
        (freeInPat p <> freeInExp eCase)
          `freeWithoutL` patNames p

-- Free variables of one index component: a fixed index, or the
-- optional start, end and stride of a slice.
freeInDimIndex :: DimIndexBase Info VName -> FV
freeInDimIndex (DimFix e) = freeInExp e
freeInDimIndex (DimSlice me1 me2 me3) =
  foldMap (foldMap freeInExp) [me1, me2, me3]

-- | Free variables in pattern (including types of the bound identifiers).
freeInPat :: Pat (TypeBase Size u) -> FV
freeInPat = foldMap freeInType

-- | Free variables in the type (meaning those that are used in size
-- expressions).  In a function type, the named parameter and the
-- existential sizes of the return type are not free.
freeInType :: TypeBase Size u -> FV
freeInType t =
  case t of
    Array _ s a ->
      freeInType (Scalar a) <> foldMap freeInExp (shapeDims s)
    Scalar (Record fs) ->
      foldMap freeInType fs
    Scalar Prim {} ->
      mempty
    Scalar (Sum cs) ->
      foldMap (foldMap freeInType) cs
    Scalar (Arrow _ v _ t1 (RetType dims t2)) ->
      FV . S.filter (\k -> notV v k && notElem k dims) $
        unFV (freeInType t1 <> freeInType t2)
    Scalar (TypeVar _ _ targs) ->
      foldMap typeArgDims targs
  where
    -- Only the type arguments of a type variable can mention sizes.
    typeArgDims (TypeArgDim d) = freeInExp d
    typeArgDims (TypeArgType at) = freeInType at

    -- Keep a name unless it is the parameter name of the arrow.
    notV Unnamed = const True
    notV (Named v) = (/= v)
futhark-0.25.27/src/Language/Futhark/Interpreter.hs000066400000000000000000002372611475065116200221320ustar00000000000000-- | An interpreter operating on type-checked source Futhark terms.
-- Relatively slow.
module Language.Futhark.Interpreter
  ( Ctx (..),
    Env,
    InterpreterError,
    prettyInterpreterError,
    initialCtx,
    interpretExp,
    interpretDec,
    interpretImport,
    interpretFunction,
    ctxWithImports,
    ExtOp (..),
    BreakReason (..),
    StackFrame (..),
    typeCheckerEnv,

    -- * Values
    Value,
    fromTuple,
    isEmptyArray,
    prettyEmptyArray,
    prettyValue,
    valueText,
  )
where

import Control.Monad
import Control.Monad.Free.Church
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Trans.Maybe
import Data.Array
import Data.Bifunctor
import Data.Bitraversable
import Data.List
  ( find,
    foldl',
    genericLength,
    genericTake,
    transpose,
  )
import Data.List qualified as L
import Data.List.NonEmpty qualified as NE
import Data.Map qualified as M
import Data.Maybe
import Data.Monoid hiding (Sum)
import Data.Ord
import Data.Set qualified as S
import Data.Text qualified as T
import Futhark.Data qualified as V
import Futhark.Util (chunk, maybeHead)
import Futhark.Util.Loc
import Futhark.Util.Pretty hiding (apply)
import
Language.Futhark hiding (Shape, matchDims)
import Language.Futhark qualified as F
import Language.Futhark.Interpreter.AD qualified as AD
import Language.Futhark.Interpreter.Values hiding (Value)
import Language.Futhark.Interpreter.Values qualified
import Language.Futhark.Primitive (floatValue, intValue)
import Language.Futhark.Primitive qualified as P
import Language.Futhark.Semantic qualified as T
import Prelude hiding (break, mod)

-- | One entry of the interpreter's call stack: a source location and
-- the context (environment and imports) in effect there.
data StackFrame = StackFrame
  { stackFrameLoc :: Loc,
    stackFrameCtx :: Ctx
  }

instance Located StackFrame where
  locOf = stackFrameLoc

-- | What is the reason for this break point?
data BreakReason
  = -- | An explicit breakpoint in the program.
    BreakPoint
  | -- | A
    -- (NOTE(review): this description is unfinished; the constructor
    -- name suggests a break triggered by a NaN value - confirm.)
    BreakNaN

-- | Effects that evaluation requests from whoever runs an 'EvalM'
-- computation: emit a trace message, stop at a break point (with a
-- non-empty backtrace), or report an error.  'ExtOpError' carries no
-- continuation, so evaluation cannot resume after it.
data ExtOp a
  = ExtOpTrace T.Text (Doc ()) a
  | ExtOpBreak Loc BreakReason (NE.NonEmpty StackFrame) a
  | ExtOpError InterpreterError

-- Maps over the continuation; 'ExtOpError' has none and is unchanged.
instance Functor ExtOp where
  fmap f (ExtOpTrace w s x) = ExtOpTrace w s $ f x
  fmap f (ExtOpBreak w why backtrace x) = ExtOpBreak w why backtrace $ f x
  fmap _ (ExtOpError err) = ExtOpError err

-- The call stack, most recent frame first (see 'stacking').
type Stack = [StackFrame]

-- Values of the existential sizes computed so far, keyed by name.
type Exts = M.Map VName Value

-- | The monad in which evaluation takes place.  The reader part holds
-- the call stack and the environments of imported files, the state
-- part holds the existential sizes, and the base is a Church-encoded
-- free monad over 'ExtOp'.
newtype EvalM a
  = EvalM
      ( ReaderT
          (Stack, M.Map ImportName Env)
          (StateT Exts (F ExtOp))
          a
      )
  deriving
    ( Monad,
      Applicative,
      Functor,
      MonadFree ExtOp,
      MonadReader (Stack, M.Map ImportName Env),
      MonadState Exts
    )

-- Run an evaluation with the given imports, an empty call stack and
-- no existential sizes.
runEvalM :: M.Map ImportName Env -> EvalM a -> F ExtOp a
runEvalM imports (EvalM m) = evalStateT (runReaderT m (mempty, imports)) mempty

-- Run an action with a new stack frame for @loc@ pushed on top of the
-- call stack.  Nothing is pushed when the location is unknown.
stacking :: SrcLoc -> Env -> EvalM a -> EvalM a
stacking loc env = local $ \(ss, imports) ->
  if isNoLoc loc
    then (ss, imports)
    else
      let s = StackFrame (locOf loc) (Ctx env imports)
       in (s : ss, imports)
  where
    isNoLoc :: SrcLoc -> Bool
    isNoLoc = (== NoLoc) . locOf

-- The locations of the current call stack, most recent first.
stacktrace :: EvalM [Loc]
stacktrace = asks $ map stackFrameLoc . fst

-- Look up the environment of an imported file, if it is loaded.
lookupImport :: ImportName -> EvalM (Maybe Env)
lookupImport f = asks $ M.lookup f . snd

-- Record the value of an existential size, replacing any previous one.
putExtSize :: VName -> Value -> EvalM ()
putExtSize v x = modify $ M.insert v x

getExts :: EvalM Exts
getExts = get

-- | Disregard any existential sizes computed during this action.
-- This is used so that existentials computed during one iteration of
-- a loop or a function call are not remembered the next time around.
localExts :: EvalM a -> EvalM a
localExts m = do
  s <- get
  x <- m
  put s
  pure x

-- The existential sizes known so far, as an environment of value
-- bindings without type annotations.
extEnv :: EvalM Env
extEnv = valEnv . M.map f <$> getExts
  where
    f v =
      ( Nothing,
        v
      )

-- Turn the concrete sizes of a 'ValueType' into constant size
-- expressions (built with 'sizeFromInteger'; the 'mempty' argument is
-- presumably an empty source location - confirm).
valueStructType :: ValueType -> StructType
valueStructType = first $ flip sizeFromInteger mempty . toInteger

-- | An expression along with an environment in which to evaluate that
-- expression.  Used to represent non-interpreted size expressions,
-- which may still be in reference to some environment.
data SizeClosure = SizeClosure Env Size
  deriving (Show)

instance Pretty SizeClosure where
  pretty (SizeClosure _ e) = pretty e

instance Pretty (F.Shape SizeClosure) where
  pretty = mconcat . map (braces . pretty) . shapeDims

-- | A type where the sizes are unevaluated expressions.
type EvalType = TypeBase SizeClosure NoUniqueness

-- Attach the given environment to every size of a type.
structToEval :: Env -> StructType -> EvalType
structToEval env = first (SizeClosure env)

-- Strip the closures from the sizes of a type, discarding their
-- environments.
evalToStruct :: EvalType -> StructType
evalToStruct = first (\(SizeClosure _ e) -> e)

-- Match a polymorphic type (whose type and size parameters are the
-- @names@) against an instantiated type.  The result lists the
-- bindings found for type parameters (each paired with the names bound
-- around the position where it occurred) and for size parameters, both
-- sorted by name.  A size is not bound if its expression mentions a
-- locally bound name or one of the parameters.
resolveTypeParams ::
  [VName] ->
  StructType ->
  EvalType ->
  ([(VName, ([VName], EvalType))], [(VName, SizeClosure)])
resolveTypeParams names orig_t1 orig_t2 =
  execState (match mempty orig_t1 orig_t2) mempty
  where
    addType v t = modify $ first $ L.insertBy (comparing fst) (v, t)
    addDim v e = modify $ second $ L.insertBy (comparing fst) (v, e)

    match bound (Scalar (TypeVar _ tn _)) t
      | qualLeaf tn `elem` names = addType (qualLeaf tn) (bound, t)
    match bound (Scalar (Record poly_fields)) (Scalar (Record fields)) =
      sequence_ . M.elems $
        M.intersectionWith (match bound) poly_fields fields
    match bound (Scalar (Sum poly_fields)) (Scalar (Sum fields)) =
      sequence_ . mconcat . M.elems $
        M.intersectionWith (zipWith $ match bound) poly_fields fields
    -- Parameter names and return-type existentials of both function
    -- types are local to the arrow, so they are added to 'bound'.
    match bound (Scalar (Arrow _ p1 _ poly_t1 (RetType dims1 poly_t2))) (Scalar (Arrow _ p2 _ t1 (RetType dims2 t2))) = do
      let bound' = mapMaybe paramName [p1, p2] <> dims1 <> dims2 <> bound
      match bound' poly_t1 t1
      match bound' (toStruct poly_t2) (toStruct t2)
    -- Arrays: match the outermost dimensions, then the row types.
    match bound poly_t t
      | d1 : _ <- shapeDims (arrayShape poly_t),
        d2 : _ <- shapeDims (arrayShape t) = do
          matchDims bound d1 d2
          match bound (stripArray 1 poly_t) (stripArray 1 t)
    match bound t1 t2
      | Just t1' <- isAccType t1,
        Just t2' <- isAccType t2 =
          match bound t1' t2'
    match _ _ _ = pure mempty

    -- Nothing can be learned when either side is 'anySize'.
    matchDims bound e1 (SizeClosure env e2)
      | e1 == anySize || e2 == anySize = pure mempty
      | otherwise = matchExps bound env e1 e2

    -- A size that is just a parameter name is bound directly; otherwise
    -- structurally similar expressions are matched part by part.
    matchExps bound env (Var (QualName _ d1) _ _) e
      | d1 `elem` names,
        not $ any problematic $ fvVars $ freeInExp e =
          addDim d1 (SizeClosure env e)
      where
        problematic v = v `elem` bound || v `elem` names
    matchExps bound env e1 e2
      | Just es <- similarExps e1 e2 =
          mapM_ (uncurry $ matchExps bound env) es
    matchExps _ _ _ _ = pure mempty

-- Evaluate an expression with the existential sizes in scope.  They
-- take precedence over @env@, because 'Env' union is left-biased.
evalWithExts :: Env -> Exp -> EvalM Value
evalWithExts env e = do
  size_env <- extEnv
  eval (size_env <> env) e

-- Turn the output of 'resolveTypeParams' into an environment.  Size
-- parameters are evaluated to 'Int64's.  Sizes in type parameters are
-- evaluated where possible and kept as expressions otherwise.
evalResolved ::
  ([(VName, ([VName], EvalType))], [(VName, SizeClosure)]) ->
  EvalM Env
evalResolved (ts, ds) = do
  ts' <- mapM (traverse $ \(bound, t) -> first onDim <$> evalType (S.fromList bound) t) ts
  ds' <- mapM (traverse $ \(SizeClosure env e) -> asInt64 <$> evalWithExts env e) ds
  pure $ typeEnv (M.fromList ts') <> i64Env (M.fromList ds')
  where
    onDim (Left x) = sizeFromInteger (toInteger x) mempty
    -- The closure's environment is dropped here.
    onDim (Right (SizeClosure _ e)) = e -- FIXME

-- Determine the values of the size names @names@ by walking a type
-- alongside the shape of a concrete value.  Only sizes that appear as
-- a plain variable in an array dimension are found.
resolveExistentials :: [VName] -> StructType -> ValueShape -> M.Map VName Int64
resolveExistentials names = match
  where
    match (Scalar (Record poly_fields)) (ShapeRecord fields) =
      mconcat $ M.elems $ M.intersectionWith match poly_fields fields
    match (Scalar (Sum poly_fields)) (ShapeSum fields) =
      mconcat $ map mconcat $ M.elems $ M.intersectionWith (zipWith match) poly_fields fields
    match poly_t (ShapeDim d2 rowshape)
      | d1 : _ <- shapeDims (arrayShape poly_t) =
          matchDims d1 d2 <> match (stripArray 1 poly_t) rowshape
    match _ _ = mempty

    matchDims (Var (QualName _ d1) _ _) d2
      | d1 `elem` names = M.singleton d1 d2
    matchDims _ _ = mempty

-- Check a value's shape against an expected shape.  Returns 'Nothing'
-- if two corresponding dimensions differ.  Otherwise it returns the
-- value's shape, extended with the expected structure where the value
-- only has a 'ShapeLeaf'.  Unrelated shape forms are accepted.
checkShape :: Shape Int64 -> ValueShape -> Maybe ValueShape
checkShape (ShapeDim d1 shape1) (ShapeDim d2 shape2) = do
  guard $ d1 == d2
  ShapeDim d2 <$> checkShape shape1 shape2
checkShape (ShapeDim d1 shape1) ShapeLeaf =
  -- This case is for handling polymorphism, when a function doesn't
  -- know that the array it produced actually has more dimensions.
  ShapeDim d1 <$> checkShape shape1 ShapeLeaf
checkShape (ShapeRecord shapes1) (ShapeRecord shapes2) =
  ShapeRecord <$> sequence (M.intersectionWith checkShape shapes1 shapes2)
checkShape (ShapeRecord shapes1) ShapeLeaf =
  Just $ ShapeRecord shapes1
checkShape (ShapeSum shapes1) (ShapeSum shapes2) =
  ShapeSum <$> sequence (M.intersectionWith (zipWithM checkShape) shapes1 shapes2)
checkShape (ShapeSum shapes1) ShapeLeaf =
  Just $ ShapeSum shapes1
checkShape _ shape2 =
  Just shape2

-- Interpreter values, specialised so that function values run in
-- 'EvalM'.
type Value = Language.Futhark.Interpreter.Values.Value EvalM

-- The following accessors extract a primitive from a value.  AD values
-- use their primal part.  Any other value is an interpreter bug and is
-- reported with 'error'.

-- Unsigned values are zero-extended, so they are never negative.
asInteger :: Value -> Integer
asInteger (ValuePrim (SignedValue v)) = P.valueIntegral v
asInteger (ValuePrim (UnsignedValue v)) =
  toInteger (P.valueIntegral (P.doZExt v Int64) :: Word64)
asInteger (ValueAD _ v)
  | P.IntValue v' <- AD.varPrimal v =
      P.valueIntegral v'
asInteger v = error $ "Unexpectedly not an integer: " <> show v

asInt :: Value -> Int
asInt = fromIntegral . asInteger

asSigned :: Value -> IntValue
asSigned (ValuePrim (SignedValue v)) = v
asSigned (ValueAD _ v)
  | P.IntValue v' <- AD.varPrimal v = v'
asSigned v = error $ "Unexpectedly not a signed integer: " <> show v

asInt64 :: Value -> Int64
asInt64 = fromIntegral . asInteger

asBool :: Value -> Bool
asBool (ValuePrim (BoolValue x)) = x
asBool (ValueAD _ v)
  | P.BoolValue v' <- AD.varPrimal v = v'
asBool v = error $ "Unexpectedly not a boolean: " <> show v

-- Resolve a qualified name.  Each qualifier must name a (non-functor)
-- module in the term environment; the leaf is looked up with @onEnv@
-- in the innermost module.
lookupInEnv ::
  (Env -> M.Map VName x) ->
  QualName VName ->
  Env ->
  Maybe x
lookupInEnv onEnv qv env = f env $ qualQuals qv
  where
    f m (q : qs) =
      case M.lookup q $ envTerm m of
        Just (TermModule (Module mod)) -> f mod qs
        _ -> Nothing
    f m [] = M.lookup (qualLeaf qv) $ onEnv m

lookupVar :: QualName VName -> Env -> Maybe TermBinding
lookupVar = lookupInEnv envTerm

lookupType :: QualName VName -> Env -> Maybe (Env, T.TypeBinding)
lookupType = lookupInEnv envType

-- | A TermValue with a 'Nothing' type annotation is an intrinsic or
-- an existential.
data TermBinding
  = TermValue (Maybe T.BoundV) Value
  | -- | A polymorphic value that must be instantiated.  The
    -- 'StructType' provided is un-evaluated, but parts of it can be
    -- evaluated using the provided 'Eval' function.
    TermPoly (Maybe T.BoundV) (EvalType -> EvalM Value)
  | TermModule Module

instance Show TermBinding where
  show (TermValue bv v) = unwords ["TermValue", show bv, show v]
  show (TermPoly bv _) = unwords ["TermPoly", show bv]
  show (TermModule m) = unwords ["TermModule", show m]

-- A module value: either a plain environment or a parametric module,
-- represented as a function from modules to modules.
data Module
  = Module Env
  | ModuleFun (Module -> EvalM Module)

instance Show Module where
  show (Module env) = "(" <> unwords ["Module", show env] <> ")"
  show (ModuleFun _) = "(ModuleFun _)"

-- | The actual type- and value environment.
data Env = Env
  { envTerm :: M.Map VName TermBinding,
    envType :: M.Map VName (Env, T.TypeBinding)
  }
  deriving (Show)

instance Monoid Env where
  mempty = Env mempty mempty

-- Union of both maps.  On key collisions the left operand wins,
-- because 'Data.Map' union is left-biased.
instance Semigroup Env where
  Env vm1 tm1 <> Env vm2 tm2 = Env (vm1 <> vm2) (tm1 <> tm2)

-- | An error occurred during interpretation due to an error in the
-- user program.  Actual interpreter errors will be signaled with an
-- IO exception ('error').
newtype InterpreterError = InterpreterError T.Text

-- | Prettyprint the error for human consumption.
prettyInterpreterError :: InterpreterError -> Doc AnsiStyle prettyInterpreterError (InterpreterError e) = pretty e valEnv :: M.Map VName (Maybe T.BoundV, Value) -> Env valEnv m = Env { envTerm = M.map (uncurry TermValue) m, envType = mempty } modEnv :: M.Map VName Module -> Env modEnv m = Env { envTerm = M.map TermModule m, envType = mempty } typeEnv :: M.Map VName StructType -> Env typeEnv m = Env { envTerm = mempty, envType = M.map tbind m } where tbind = (mempty,) . T.TypeAbbr Unlifted [] . RetType [] i64Env :: M.Map VName Int64 -> Env i64Env = valEnv . M.map f where f x = ( Just $ T.BoundV [] $ Scalar $ Prim $ Signed Int64, ValuePrim $ SignedValue $ Int64Value x ) instance Show InterpreterError where show (InterpreterError s) = T.unpack s bad :: SrcLoc -> Env -> T.Text -> EvalM a bad loc env s = stacking loc env $ do ss <- map (locText . srclocOf) <$> stacktrace liftF . ExtOpError . InterpreterError $ "Error at\n" <> prettyStacktrace 0 ss <> s trace :: T.Text -> Value -> EvalM () trace w v = do liftF $ ExtOpTrace w (prettyValue v) () typeCheckerEnv :: Env -> T.Env typeCheckerEnv env = -- FIXME: some shadowing issues are probably not right here. let valMap (TermValue (Just t) _) = Just t valMap _ = Nothing vtable = M.mapMaybe valMap $ envTerm env nameMap k | k `M.member` vtable = Just ((T.Term, baseName k), qualName k) | otherwise = Nothing in mempty { T.envNameMap = M.fromList $ mapMaybe nameMap $ M.keys $ envTerm env, T.envVtable = vtable } break :: Env -> Loc -> EvalM () break env loc = do imports <- asks snd backtrace <- asks ((StackFrame loc (Ctx env imports) NE.:|) . 
fst) liftF $ ExtOpBreak loc BreakPoint backtrace () fromArray :: Value -> (ValueShape, [Value]) fromArray (ValueArray shape as) = (shape, elems as) fromArray v = error $ "Expected array value, but found: " <> show v apply :: SrcLoc -> Env -> Value -> Value -> EvalM Value apply loc env (ValueFun f) v = stacking loc env (f v) apply _ _ f _ = error $ "Cannot apply non-function: " <> show f apply2 :: SrcLoc -> Env -> Value -> Value -> Value -> EvalM Value apply2 loc env f x y = stacking loc env $ do f' <- apply noLoc mempty f x apply noLoc mempty f' y matchPat :: Env -> Pat (TypeBase Size u) -> Value -> EvalM Env matchPat env p v = do m <- runMaybeT $ patternMatch env p v case m of Nothing -> error $ "matchPat: missing case for " <> prettyString (toStruct <$> p) ++ " and " <> show v Just env' -> pure env' patternMatch :: Env -> Pat (TypeBase Size u) -> Value -> MaybeT EvalM Env patternMatch env (PatAttr _ p _) val = patternMatch env p val patternMatch env (Id v (Info t) _) val = lift $ pure $ valEnv (M.singleton v (Just $ T.BoundV [] $ toStruct t, val)) <> env patternMatch env Wildcard {} _ = lift $ pure env patternMatch env (TuplePat ps _) (ValueRecord vs) = foldM (\env' (p, v) -> patternMatch env' p v) env $ zip ps (map snd $ sortFields vs) patternMatch env (RecordPat ps _) (ValueRecord vs) = foldM (\env' (p, v) -> patternMatch env' p v) env $ M.intersectionWith (,) (M.fromList $ map (first unLoc) ps) vs patternMatch env (PatParens p _) v = patternMatch env p v patternMatch env (PatAscription p _ _) v = patternMatch env p v patternMatch env (PatLit l t _) v = do l' <- case l of PatLitInt x -> lift $ eval env $ IntLit x (toStruct <$> t) mempty PatLitFloat x -> lift $ eval env $ FloatLit x (toStruct <$> t) mempty PatLitPrim lv -> pure $ ValuePrim lv if v == l' then pure env else mzero patternMatch env (PatConstr n _ ps _) (ValueSum _ n' vs) | n == n' = foldM (\env' (p, v) -> patternMatch env' p v) env $ zip ps vs patternMatch _ _ _ = mzero data Indexing = IndexingFix 
Int64 | IndexingSlice (Maybe Int64) (Maybe Int64) (Maybe Int64) instance Pretty Indexing where pretty (IndexingFix i) = pretty i pretty (IndexingSlice i j (Just s)) = maybe mempty pretty i <> ":" <> maybe mempty pretty j <> ":" <> pretty s pretty (IndexingSlice i (Just j) s) = maybe mempty pretty i <> ":" <> pretty j <> maybe mempty ((":" <>) . pretty) s pretty (IndexingSlice i Nothing Nothing) = maybe mempty pretty i <> ":" indexesFor :: Maybe Int64 -> Maybe Int64 -> Maybe Int64 -> Int64 -> Maybe [Int] indexesFor start end stride n | (start', end', stride') <- slice, end' == start' || signum' (end' - start') == signum' stride', stride' /= 0, is <- [start', start' + stride' .. end' - signum stride'], all inBounds is = Just $ map fromIntegral is | otherwise = Nothing where inBounds i = i >= 0 && i < n slice = case (start, end, stride) of (Just start', _, _) -> let end' = fromMaybe n end in (start', end', fromMaybe 1 stride) (Nothing, Just end', _) -> let start' = 0 in (start', end', fromMaybe 1 stride) (Nothing, Nothing, Just stride') -> ( if stride' > 0 then 0 else n - 1, if stride' > 0 then n else -1, stride' ) (Nothing, Nothing, Nothing) -> (0, n, 1) -- | 'signum', but with 0 as 1. signum' :: (Eq p, Num p) => p -> p signum' 0 = 1 signum' x = signum x indexShape :: [Indexing] -> ValueShape -> ValueShape indexShape (IndexingFix {} : is) (ShapeDim _ shape) = indexShape is shape indexShape (IndexingSlice start end stride : is) (ShapeDim d shape) = ShapeDim n $ indexShape is shape where n = maybe 0 genericLength $ indexesFor start end stride d indexShape _ shape = shape indexArray :: [Indexing] -> Value -> Maybe Value indexArray (IndexingFix i : is) (ValueArray _ arr) | i >= 0, i < n = indexArray is $ arr ! 
fromIntegral i | otherwise = Nothing where n = arrayLength arr indexArray (IndexingSlice start end stride : is) (ValueArray (ShapeDim _ rowshape) arr) = do js <- indexesFor start end stride $ arrayLength arr toArray' (indexShape is rowshape) <$> mapM (indexArray is . (arr !)) js indexArray _ v = Just v writeArray :: [Indexing] -> Value -> Value -> Maybe Value writeArray slice x y = runIdentity $ updateArray (\_ y' -> pure y') slice x y updateArray :: (Monad m) => (Value -> Value -> m Value) -> [Indexing] -> Value -> Value -> m (Maybe Value) updateArray f (IndexingFix i : is) (ValueArray shape arr) v | i >= 0, i < n = do v' <- updateArray f is (arr ! i') v pure $ do v'' <- v' Just $ ValueArray shape $ arr // [(i', v'')] | otherwise = pure Nothing where n = arrayLength arr i' = fromIntegral i updateArray f (IndexingSlice start end stride : is) (ValueArray shape arr) (ValueArray _ v) | Just arr_is <- indexesFor start end stride $ arrayLength arr, length arr_is == arrayLength v = do let update (Just arr') (i, v') = do x <- updateArray f is (arr ! i) v' pure $ do x' <- x Just $ arr' // [(i, x')] update Nothing _ = pure Nothing fmap (fmap (ValueArray shape)) $ foldM update (Just arr) $ zip arr_is $ elems v | otherwise = pure Nothing updateArray f _ x y = Just <$> f x y evalDimIndex :: Env -> DimIndex -> EvalM Indexing evalDimIndex env (DimFix x) = IndexingFix . asInt64 <$> eval env x evalDimIndex env (DimSlice start end stride) = IndexingSlice <$> traverse (fmap asInt64 . eval env) start <*> traverse (fmap asInt64 . eval env) end <*> traverse (fmap asInt64 . eval env) stride evalIndex :: SrcLoc -> Env -> [Indexing] -> Value -> EvalM Value evalIndex loc env is arr = do let oob = bad loc env $ "Index [" <> T.intercalate ", " (map prettyText is) <> "] out of bounds for array of shape " <> prettyText (valueShape arr) <> "." 
maybe oob pure $ indexArray is arr -- | Expand type based on information that was not available at -- type-checking time (the structure of abstract types). expandType :: (Pretty u) => Env -> TypeBase Size u -> TypeBase SizeClosure u expandType _ (Scalar (Prim pt)) = Scalar $ Prim pt expandType env (Scalar (Record fs)) = Scalar $ Record $ fmap (expandType env) fs expandType env (Scalar (Arrow u p d t1 (RetType dims t2))) = Scalar $ Arrow u p d (expandType env t1) (RetType dims (expandType env t2)) expandType env t@(Array u shape _) = let et = stripArray (shapeRank shape) t et' = expandType env et shape' = fmap (SizeClosure env) shape in second (const u) (arrayOf shape' $ toStruct et') expandType env (Scalar (TypeVar u tn args)) = case lookupType tn env of Just (tn_env, T.TypeAbbr _ ps (RetType ext t')) -> let (substs, types) = mconcat $ zipWith matchPtoA ps args onDim (SizeClosure _ (Var v _ _)) | Just e <- M.lookup (qualLeaf v) substs = SizeClosure env e -- The next case can occur when a type with existential size -- has been hidden by a module ascription, -- e.g. tests/modules/sizeparams4.fut. onDim (SizeClosure _ e) | any (`elem` ext) $ fvVars $ freeInExp e = SizeClosure mempty anySize onDim d = d in bimap onDim (const u) $ expandType (Env mempty types <> tn_env) t' Nothing -> -- This case only happens for built-in abstract types, -- e.g. accumulators. Scalar (TypeVar u tn $ map expandArg args) where matchPtoA (TypeParamDim p _) (TypeArgDim e) = (M.singleton p e, mempty) matchPtoA (TypeParamType l p _) (TypeArgType t') = let t'' = evalToStruct $ expandType env t' -- FIXME, we are throwing away the closure here. in (mempty, M.singleton p (mempty, T.TypeAbbr l [] $ RetType [] t'')) matchPtoA _ _ = mempty expandArg (TypeArgDim s) = TypeArgDim $ SizeClosure env s expandArg (TypeArgType t) = TypeArgType $ expandType env t expandType env (Scalar (Sum cs)) = Scalar $ Sum $ (fmap . 
fmap) (expandType env) cs -- | Evaluate all possible sizes, except those that contain free -- variables in the set of names. evalType :: S.Set VName -> EvalType -> EvalM (TypeBase (Either Int64 SizeClosure) NoUniqueness) evalType outer_bound t = do let evalDim bound _ (SizeClosure env e) | canBeEvaluated bound e = Left . asInt64 <$> evalWithExts env e evalDim _ _ e = pure $ Right e traverseDims evalDim t where canBeEvaluated bound e = let free = fvVars $ freeInExp e in not $ any (`S.member` bound) free || any (`S.member` outer_bound) free -- | Evaluate all sizes, and it better work. This implies it must be a -- size-dependent function type, or one that has existentials. evalTypeFully :: EvalType -> EvalM ValueType evalTypeFully t = do let evalDim (SizeClosure env e) = asInt64 <$> evalWithExts env e bitraverse evalDim pure t evalTermVar :: Env -> QualName VName -> StructType -> EvalM Value evalTermVar env qv t = case lookupVar qv env of Just (TermPoly _ v) -> v $ expandType env t Just (TermValue _ v) -> pure v x -> do ss <- map (locText . srclocOf) <$> stacktrace error $ prettyString qv <> " is not bound to a value.\n" <> T.unpack (prettyStacktrace 0 ss) <> "Bound to\n" <> show x typeValueShape :: Env -> StructType -> EvalM ValueShape typeValueShape env t = typeShape <$> evalTypeFully (expandType env t) -- Sometimes type instantiation is not quite enough - then we connect -- up the missing sizes here. In particular used for eta-expanded -- entry points. linkMissingSizes :: [VName] -> Pat (TypeBase Size u) -> Value -> Env -> Env linkMissingSizes [] _ _ env = env linkMissingSizes missing_sizes p v env = env <> i64Env (resolveExistentials missing_sizes p_t (valueShape v)) where p_t = evalToStruct $ expandType env $ patternStructType p evalFunction :: Env -> [VName] -> [Pat ParamType] -> Exp -> ResType -> EvalM Value -- We treat zero-parameter lambdas as simply an expression to -- evaluate immediately. 
Note that this is *not* the same as a lambda -- that takes an empty tuple '()' as argument! Zero-parameter lambdas -- can never occur in a well-formed Futhark program, but they are -- convenient in the interpreter. evalFunction env missing_sizes [] body rettype = -- Eta-expand the rest to make any sizes visible. etaExpand [] env rettype where etaExpand vs env' (Scalar (Arrow _ _ _ p_t (RetType _ rt))) = do pure . ValueFun $ \v -> do let p = Wildcard (Info p_t) noLoc env'' <- linkMissingSizes missing_sizes p v <$> matchPat env' p v etaExpand (v : vs) env'' rt etaExpand vs env' _ = do f <- localExts $ eval env' body foldM (apply noLoc mempty) f $ reverse vs evalFunction env missing_sizes (p : ps) body rettype = pure . ValueFun $ \v -> do env' <- linkMissingSizes missing_sizes p v <$> matchPat env p v evalFunction env' missing_sizes ps body rettype evalFunctionBinding :: Env -> [TypeParam] -> [Pat ParamType] -> ResRetType -> Exp -> EvalM TermBinding evalFunctionBinding env tparams ps ret fbody = do let ftype = funType ps ret retext = case ps of [] -> retDims ret _ -> [] -- Distinguish polymorphic and non-polymorphic bindings here. if null tparams then fmap (TermValue (Just $ T.BoundV [] ftype)) . returned env (retType ret) retext =<< evalFunction env [] ps fbody (retType ret) else pure . TermPoly (Just $ T.BoundV [] ftype) $ \ftype' -> do let resolved = resolveTypeParams (map typeParamName tparams) ftype ftype' tparam_env <- evalResolved resolved let env' = tparam_env <> env -- In some cases (abstract lifted types) there may be -- missing sizes that were not fixed by the type -- instantiation. These will have to be set by looking -- at the actual function arguments. 
missing_sizes = filter (`M.notMember` envTerm env') $ map typeParamName (filter isSizeParam tparams) returned env (retType ret) retext =<< evalFunction env' missing_sizes ps fbody (retType ret) evalArg :: Env -> Exp -> Maybe VName -> EvalM Value evalArg env e ext = do v <- eval env e case ext of Just ext' -> putExtSize ext' v _ -> pure () pure v returned :: Env -> TypeBase Size u -> [VName] -> Value -> EvalM Value returned _ _ [] v = pure v returned env ret retext v = do mapM_ (uncurry putExtSize . second (ValuePrim . SignedValue . Int64Value)) . M.toList $ resolveExistentials retext (evalToStruct $ expandType env $ toStruct ret) $ valueShape v pure v evalAppExp :: Env -> AppExp -> EvalM Value evalAppExp env (Range start maybe_second end loc) = do start' <- asInteger <$> eval env start maybe_second' <- traverse (fmap asInteger . eval env) maybe_second end' <- traverse (fmap asInteger . eval env) end let (end_adj, step, ok) = case (end', maybe_second') of (DownToExclusive end'', Nothing) -> (end'' + 1, -1, start' >= end'') (DownToExclusive end'', Just second') -> (end'' + 1, second' - start', start' >= end'' && second' < start') (ToInclusive end'', Nothing) -> (end'', 1, start' <= end'') (ToInclusive end'', Just second') | second' > start' -> (end'', second' - start', start' <= end'') | otherwise -> (end'', second' - start', start' >= end'' && second' /= start') (UpToExclusive x, Nothing) -> (x - 1, 1, start' <= x) (UpToExclusive x, Just second') -> (x - 1, second' - start', start' <= x && second' > start') if ok then pure $ toArray' ShapeLeaf $ map toInt [start', start' + step .. end_adj] else bad loc env $ badRange start' maybe_second' end' where toInt = case typeOf start of Scalar (Prim (Signed t')) -> ValuePrim . SignedValue . intValue t' Scalar (Prim (Unsigned t')) -> ValuePrim . UnsignedValue . 
intValue t' t -> error $ "Nonsensical range type: " ++ show t badRange start' maybe_second' end' = "Range " <> prettyText start' <> ( case maybe_second' of Nothing -> "" Just second' -> ".." <> prettyText second' ) <> ( case end' of DownToExclusive x -> "..>" <> prettyText x ToInclusive x -> "..." <> prettyText x UpToExclusive x -> "..<" <> prettyText x ) <> " is invalid." evalAppExp env (LetPat sizes p e body _) = do v <- eval env e env' <- matchPat env p v let p_t = evalToStruct $ expandType env $ patternStructType p v_s = valueShape v env'' = env' <> i64Env (resolveExistentials (map sizeName sizes) p_t v_s) eval env'' body evalAppExp env (LetFun f (tparams, ps, _, Info ret, fbody) body _) = do binding <- evalFunctionBinding env tparams ps ret fbody eval (env {envTerm = M.insert f binding $ envTerm env}) body evalAppExp env (BinOp (op, _) op_t (x, Info xext) (y, Info yext) loc) | baseString (qualLeaf op) == "&&" = do x' <- asBool <$> eval env x if x' then eval env y else pure $ ValuePrim $ BoolValue False | baseString (qualLeaf op) == "||" = do x' <- asBool <$> eval env x if x' then pure $ ValuePrim $ BoolValue True else eval env y | otherwise = do x' <- evalArg env x xext y' <- evalArg env y yext op' <- eval env $ Var op op_t loc apply2 loc env op' x' y' evalAppExp env (If cond e1 e2 _) = do cond' <- asBool <$> eval env cond if cond' then eval env e1 else eval env e2 evalAppExp env (Apply f args loc) = do -- It is important that 'arguments' are evaluated in reverse order -- in order to bring any sizes into scope that may be used in the -- type of the functions. 
args' <- reverse <$> mapM evalArg' (reverse $ NE.toList args) f' <- eval env f foldM (apply loc env) f' args' where evalArg' (Info ext, x) = evalArg env x ext evalAppExp env (Index e is loc) = do is' <- mapM (evalDimIndex env) is arr <- eval env e evalIndex loc env is' arr evalAppExp env (LetWith dest src is v body loc) = do let Ident src_vn (Info src_t) _ = src dest' <- maybe oob pure =<< writeArray <$> mapM (evalDimIndex env) is <*> evalTermVar env (qualName src_vn) (toStruct src_t) <*> eval env v let t = T.BoundV [] $ toStruct $ unInfo $ identType dest eval (valEnv (M.singleton (identName dest) (Just t, dest')) <> env) body where oob = bad loc env "Update out of bounds" evalAppExp env (Loop sparams pat loopinit form body _) = do init_v <- eval env $ loopInitExp loopinit case form of For iv bound -> do bound' <- asSigned <$> eval env bound forLoop (identName iv) bound' (zero bound') init_v ForIn in_pat in_e -> do (_, in_vs) <- fromArray <$> eval env in_e foldM (forInLoop in_pat) init_v in_vs While cond -> whileLoop cond init_v where withLoopParams v = let sparams' = resolveExistentials sparams (patternStructType pat) (valueShape v) in matchPat (i64Env sparams' <> env) pat v inc = (`P.doAdd` Int64Value 1) zero = (`P.doMul` Int64Value 0) evalBody env' = localExts $ eval env' body forLoopEnv iv i = valEnv ( M.singleton iv ( Just $ T.BoundV [] $ Scalar $ Prim $ Signed Int64, ValuePrim (SignedValue i) ) ) forLoop iv bound i v | i >= bound = pure v | otherwise = do env' <- withLoopParams v forLoop iv bound (inc i) =<< evalBody (forLoopEnv iv i <> env') whileLoop cond v = do env' <- withLoopParams v continue <- asBool <$> eval env' cond if continue then whileLoop cond =<< evalBody env' else pure v forInLoop in_pat v in_v = do env' <- withLoopParams v env'' <- matchPat env' in_pat in_v evalBody env'' evalAppExp env (Match e cs _) = do v <- eval env e match v (NE.toList cs) where match _ [] = error "Pattern match failure." 
-- (Tail of the 'where' clause of evalAppExp's Match case.)  Try the
-- remaining cases in order; the first whose pattern matches wins.
    match v (c : cs') = do
      c' <- evalCase v env c
      case c' of
        Just v' -> pure v'
        Nothing -> match v cs'

-- | Evaluate a (non-application) expression in the given environment.
-- Application-like expressions are delegated to 'evalAppExp'.
eval :: Env -> Exp -> EvalM Value
eval _ (Literal v _) = pure $ ValuePrim v
-- Reaching a typed hole at run time is reported as an evaluation error.
eval env (Hole (Info t) loc) =
  bad loc env $ "Hole of type: " <> prettyTextOneLine t
eval env (Parens e _) = eval env e
-- Local open: evaluate 'e' with the module's bindings added in front
-- of the current environment.
eval env (QualParens (qv, _) e loc) = do
  m <- evalModuleVar env qv
  case m of
    ModuleFun {} -> error $ "Local open of module function at " ++ locStr loc
    Module m' -> eval (m' <> env) e
eval env (TupLit vs _) = toTuple <$> mapM (eval env) vs
eval env (RecordLit fields _) =
  ValueRecord . M.fromList <$> mapM evalField fields
  where
    evalField (RecordFieldExplicit (L _ k) e _) = do
      v <- eval env e
      pure (k, v)
    -- Implicit field: its value is the variable of the same name.
    evalField (RecordFieldImplicit (L _ k) t loc) = do
      v <- eval env $ Var (qualName k) t loc
      pure (baseName k, v)
-- String literals become one-dimensional arrays of u8 values.
eval _ (StringLit vs _) =
  pure $
    toArray' ShapeLeaf $
      map (ValuePrim . UnsignedValue . Int8Value . fromIntegral) vs
-- An empty array literal gets its shape from its type annotation.
eval env (ArrayLit [] (Info t) _) = do
  t' <- typeValueShape env $ toStruct t
  pure $ toArray t' []
-- A non-empty array literal takes its row shape from the first element.
eval env (ArrayLit (v : vs) _ _) = do
  v' <- eval env v
  vs' <- mapM (eval env) vs
  pure $ toArray' (valueShape v') (v' : vs')
eval _ (ArrayVal vs _ _) =
  -- Probably will not ever be used.
  pure $ toArray' ShapeLeaf $ map ValuePrim vs
-- NOTE(review): 'returned' presumably binds the existential sizes in
-- 'retext' from the result value — confirm against its definition.
eval env (AppExp e (Info (AppRes t retext))) = do
  v <- evalAppExp env e
  returned env (toStruct t) retext v
eval env (Var qv (Info t) _) = evalTermVar env qv (toStruct t)
eval env (Ascript e _ _) = eval env e
-- Size coercion: check at run time that the value's shape fits the
-- shape of the target type, and fail with a descriptive error if not.
eval env (Coerce e te (Info t) loc) = do
  v <- eval env e
  t' <- evalTypeFully $ expandType env $ toStruct t
  case checkShape (typeShape t') (valueShape v) of
    Just _ -> pure v
    Nothing ->
      bad loc env . docText $
        "Value `"
          <> prettyValue v
          <> "` of shape `"
          <> pretty (valueShape v)
          <> "` cannot match shape of type `"
          <> pretty te
          <> "` (`"
          <> pretty t'
          <> "`)"
-- Integer literals may have been given any integer or float type.
eval _ (IntLit v (Info t) _) =
  case t of
    Scalar (Prim (Signed it)) ->
      pure $ ValuePrim $ SignedValue $ intValue it v
    Scalar (Prim (Unsigned it)) ->
      pure $ ValuePrim $ UnsignedValue $ intValue it v
    Scalar (Prim (FloatType ft)) ->
      pure $ ValuePrim $ FloatValue $ floatValue ft v
    _ -> error $ "eval: nonsensical type for integer literal: " <> prettyString t
eval _ (FloatLit v (Info t) _) =
  case t of
    Scalar (Prim (FloatType ft)) ->
      pure $ ValuePrim $ FloatValue $ floatValue ft v
    _ -> error $ "eval: nonsensical type for float literal: " <> prettyString t
-- Negation and logical/bitwise 'not' are implemented by applying the
-- corresponding intrinsic values.
eval env (Negate e loc) = do
  ev <- eval env e
  apply loc env intrinsicsNeg ev
eval env (Not e loc) =
  apply loc env intrinsicsNot =<< eval env e
-- Functional array update; 'writeArray' yields Nothing when the
-- indices are out of bounds, which is reported as an error.
eval env (Update src is v loc) =
  maybe oob pure
    =<< writeArray <$> mapM (evalDimIndex env) is <*> eval env src <*> eval env v
  where
    oob = bad loc env "Bad update"
-- Record update along a field path: descend into nested records and
-- replace the innermost field.
eval env (RecordUpdate src all_fs v _ _) =
  update <$> eval env src <*> pure all_fs <*> eval env v
  where
    update _ [] v' = v'
    update (ValueRecord src') (f : fs) v'
      | Just f_v <- M.lookup f src' =
          ValueRecord $ M.insert f (update f_v fs v') src'
    update _ _ _ = error "eval RecordUpdate: invalid value."
-- We treat zero-parameter lambdas as simply an expression to
-- evaluate immediately.  Note that this is *not* the same as a lambda
-- that takes an empty tuple '()' as argument!  Zero-parameter lambdas
-- can never occur in a well-formed Futhark program, but they are
-- convenient in the interpreter.
eval env (Lambda ps body _ (Info (RetType _ rt)) _) =
  evalFunction env [] ps body rt
eval env (OpSection qv (Info t) _) =
  evalTermVar env qv $ toStruct t
-- Left section: the left operand is evaluated immediately and applied.
eval env (OpSectionLeft qv _ e (Info (_, _, argext), _) (Info (RetType _ t), _) loc) = do
  v <- evalArg env e argext
  f <- evalTermVar env qv (toStruct t)
  apply loc env f v
-- Right section: the right operand is evaluated now; the result is a
-- function that waits for the left operand.
eval env (OpSectionRight qv _ e (Info _, Info (_, _, argext)) (Info (RetType _ t)) loc) = do
  y <- evalArg env e argext
  pure $ ValueFun $ \x -> do
    f <- evalTermVar env qv $ toStruct t
    apply2 loc env f x y
-- Index section: the indices are evaluated now, the array later.
eval env (IndexSection is _ loc) = do
  is' <- mapM (evalDimIndex env) is
  pure $ ValueFun $ evalIndex loc env is'
-- Projection section: a function that follows the field path 'ks'.
eval _ (ProjectSection ks _ _) =
  pure $ ValueFun $ flip (foldM walk) ks
  where
    walk (ValueRecord fs) f
      | Just v' <- M.lookup f fs = pure v'
    walk _ _ = error "Value does not have expected field."
eval env (Project f e _ _) = do
  v <- eval env e
  case v of
    ValueRecord fs | Just v' <- M.lookup f fs -> pure v'
    _ -> error "Value does not have expected field."
-- Assertion: fail with message 's' at 'loc' unless the condition holds.
eval env (Assert what e (Info s) loc) = do
  cond <- asBool <$> eval env what
  unless cond $ bad loc env s
  eval env e
eval env (Constr c es (Info t) _) = do
  vs <- mapM (eval env) es
  shape <- typeValueShape env $ toStruct t
  pure $ ValueSum shape c vs
-- The "break" attribute calls 'break' (presumably a debugger
-- breakpoint) before evaluating the expression.
eval env (Attr (AttrAtom (AtomName "break") _) e loc) = do
  break env (locOf loc)
  eval env e
-- The "trace" attribute passes the value to 'trace', tagged with the
-- source location.
eval env (Attr (AttrAtom (AtomName "trace") _) e loc) = do
  v <- eval env e
  trace (locText (locOf loc)) v
  pure v
-- "trace(tag)" is like "trace", but tagged with the given name.
eval env (Attr (AttrComp "trace" [AttrAtom (AtomName tag) _] _) e _) = do
  v <- eval env e
  trace (nameToText tag) v
  pure v
-- All other attributes are ignored by the interpreter.
eval env (Attr _ e _) = eval env e

-- | Try one case of a @match@ expression: if the pattern matches the
-- value, evaluate the case body in the extended environment; otherwise
-- produce 'Nothing'.
evalCase :: Value -> Env -> CaseBase Info VName -> EvalM (Maybe Value)
evalCase v env (CasePat p cExp _) = runMaybeT $ do
  env' <- patternMatch env p v
  lift $ eval env' cExp

-- We hackily do multiple substitutions in modules, because otherwise
-- we would lose in cases where the parameter substitutions are [a->x,
-- b->x] when we reverse. (See issue #1250.)
-- | Invert a substitution map: every target name is associated with
-- all of the source names that were substituted by it.
reverseSubstitutions :: M.Map VName VName -> M.Map VName [VName]
reverseSubstitutions substs =
  M.fromListWith (<>) $ map (second pure . flipPair) $ M.toList substs
  where
    flipPair (from, to) = (to, from)

-- | Rename the bindings of a module according to a substitution map.
-- Since the map is inverted first, a single binding may end up bound
-- under several names.
substituteInModule :: M.Map VName VName -> Module -> Module
substituteInModule substs = goModule
  where
    inverted = reverseSubstitutions substs

    -- Every name a binding should appear under (itself if it is not
    -- the target of any substitution).
    namesFor name = M.findWithDefault [name] name inverted

    -- A qualified name is replaced by the first source name, if any.
    renameQual qn =
      case M.lookup (qualLeaf qn) inverted >>= maybeHead of
        Just name -> qualName name
        Nothing -> qn

    renameKeys f bindings =
      M.fromList [(k', f x) | (k, x) <- M.toList bindings, k' <- namesFor k]

    goEnv (Env terms types) =
      Env
        (renameKeys goTerm terms)
        (renameKeys (bimap goEnv goType) types)

    goModule m =
      case m of
        Module env -> Module $ goEnv env
        ModuleFun f -> ModuleFun $ \arg -> goModule <$> f (substituteInModule substs arg)

    goTerm term =
      case term of
        TermModule m -> TermModule $ goModule m
        TermValue t x -> TermValue t x
        TermPoly t x -> TermPoly t x

    goType (T.TypeAbbr l ps t) = T.TypeAbbr l ps $ first goDim t

    goDim dim =
      case dim of
        Var qn typ loc -> Var (renameQual qn) typ loc
        IntLit x t loc -> IntLit x t loc
        _ -> error "Arbitrary expression not supported yet"

-- | Look up a name that must be bound to a module; anything else is an
-- internal error.
evalModuleVar :: Env -> QualName VName -> EvalM Module
evalModuleVar env qv
  | Just (TermModule m) <- lookupVar qv env = pure m
  | otherwise = error $ prettyString qv <> " is not bound to a module."

-- We also return a new Env here, because we want the definitions
-- inside any constructed modules to also be in scope at the top
-- level. This is because types may contain un-qualified references to
-- definitions in modules, and sometimes those definitions may not
-- actually *have* any qualified name! See tests/modules/sizes7.fut.
-- This occurs solely because of evalType.
-- | Evaluate a module expression to a module, together with the
-- environment of definitions to expose at the top level.
evalModExp :: Env -> ModExp -> EvalM (Env, Module)
evalModExp env modexp =
  case modexp of
    ModImport _ (Info f) _ -> do
      found <- lookupImport f
      known <- asks snd
      case found of
        Just m -> pure (mempty, Module m)
        Nothing ->
          error $
            unlines
              [ "Unknown interpreter import: " ++ show f,
                "Known: " ++ show (M.keys known)
              ]
    ModDecs ds _ -> do
      Env terms types <- foldM evalDec env ds
      -- Keep only what the declarations added on top of 'env'.
      let added =
            Env
              (M.difference terms (envTerm env))
              (M.difference types (envType env))
      pure (added, Module added)
    ModVar qv _ -> do
      m <- evalModuleVar env qv
      pure (mempty, m)
    ModAscript inner _ (Info substs) _ -> do
      (inner_env, inner_mod) <- evalModExp env inner
      -- The environment is wrapped as a module only so that it can be
      -- renamed; the (lazy) pattern unwraps it again.
      let Module renamed_env = substituteInModule substs (Module inner_env)
      pure (renamed_env, substituteInModule substs inner_mod)
    ModParens inner _ -> evalModExp env inner
    ModLambda p ret body loc -> do
      let body' = maybe body (\(se, rsubsts) -> ModAscript body se rsubsts loc) ret
          instantiate arg = do
            let env' = env {envTerm = M.insert (modParamName p) (TermModule arg) $ envTerm env}
            snd <$> evalModExp env' body'
      pure (mempty, ModuleFun instantiate)
    ModApply fun_me arg_me (Info psubst) (Info rsubst) _ -> do
      (fun_env, fun_mod) <- evalModExp env fun_me
      (arg_env, arg_mod) <- evalModExp env arg_me
      case fun_mod of
        ModuleFun apply_fun -> do
          result <- substituteInModule rsubst <$> apply_fun (substituteInModule psubst arg_mod)
          let result_env =
                case result of
                  Module x -> x
                  _ -> mempty
          pure (fun_env <> arg_env <> result_env, result)
        _ -> error "Expected ModuleFun."
-- | Evaluate a declaration, returning the environment extended with
-- whatever the declaration binds.
evalDec :: Env -> Dec -> EvalM Env
evalDec env dec =
  case dec of
    ValDec (ValBind _ v _ (Info ret) tparams ps fbody _ _ _) ->
      localExts $ do
        binding <- evalFunctionBinding env tparams ps ret fbody
        sizes <- extEnv
        let terms' = M.insert v binding $ envTerm env
        pure $ env {envTerm = terms'} <> sizes
    OpenDec me _ -> do
      (opened_env, opened) <- evalModExp env me
      case opened of
        Module inner -> pure $ inner <> opened_env <> env
        _ -> error "Expected Module"
    -- An import is treated as a local open of the imported module.
    ImportDec name name' loc ->
      evalDec env $ LocalDec (OpenDec (ModImport name name' loc) loc) loc
    LocalDec d _ -> evalDec env d
    ModTypeDec {} -> pure env
    TypeDec (TypeBind v l ps _ (Info (RetType dims t)) _ _) -> do
      let abbr_t = T.TypeAbbr l ps $ RetType dims $ evalToStruct $ expandType env t
      pure env {envType = M.insert v (env, abbr_t) $ envType env}
    ModDec (ModBind v ps ret body _ loc) -> do
      -- A parameterised module becomes nested module lambdas; the
      -- result ascription goes on the innermost lambda (or on the body
      -- itself when there are no parameters).
      let ascribed = maybe body (\(se, substs) -> ModAscript body se substs loc) ret
          curried [] = ascribed
          curried [p] = ModLambda p ret body loc
          curried (p : rest) = ModLambda p Nothing (curried rest) loc
      (mod_env, m) <- evalModExp env $ curried ps
      pure $ modEnv (M.singleton v m) <> mod_env <> env

-- | The interpreter context. All evaluation takes place with respect
-- to a context, and it can be extended with more definitions, which
-- is how the REPL works.
data Ctx = Ctx
  { ctxEnv :: Env,
    ctxImports :: M.Map ImportName Env
  }

-- | Is this primitive value a floating-point NaN?  Values that are not
-- floats never are.
nanValue :: PrimValue -> Bool
nanValue (FloatValue v) =
  case v of
    Float16Value x -> isNaN x
    Float32Value x -> isNaN x
    Float64Value x -> isNaN x
nanValue _ = False

-- | Signal a 'BreakNaN' breakpoint when an operation produced a NaN
-- from inputs that contained no NaN.  Nothing happens when the
-- backtrace is empty.
breakOnNaN :: [PrimValue] -> PrimValue -> EvalM ()
breakOnNaN inputs result
  | not (any nanValue inputs) && nanValue result = do
      backtrace <- asks fst
      case NE.nonEmpty backtrace of
        Nothing -> pure ()
        Just backtrace' ->
          let loc = stackFrameLoc $ NE.head backtrace'
           in liftF $ ExtOpBreak loc BreakNaN backtrace' ()
breakOnNaN _ _ = pure ()

-- | The initial environment contains definitions of the various intrinsic functions.
initialCtx :: Ctx
initialCtx =
  Ctx
    ( Env
        ( M.insert
            (VName (nameFromText "intrinsics") 0)
            (TermModule (Module $ Env terms types))
            terms
        )
        types
    )
    mempty
  where
    terms = M.mapMaybeWithKey (const . def . baseText) intrinsics
    types = M.mapMaybeWithKey (const . tdef . baseName) intrinsics

    -- Operator tables.  Each entry is (decode operand, encode result,
    -- primitive implementation, AD implementation); there is one entry
    -- per primitive type the operator supports.
    sintOp f =
      [ (getS, putS, P.doBinOp (f Int8), adBinOp $ AD.OpBin (f Int8)),
        (getS, putS, P.doBinOp (f Int16), adBinOp $ AD.OpBin (f Int16)),
        (getS, putS, P.doBinOp (f Int32), adBinOp $ AD.OpBin (f Int32)),
        (getS, putS, P.doBinOp (f Int64), adBinOp $ AD.OpBin (f Int64))
      ]
    uintOp f =
      [ (getU, putU, P.doBinOp (f Int8), adBinOp $ AD.OpBin (f Int8)),
        (getU, putU, P.doBinOp (f Int16), adBinOp $ AD.OpBin (f Int16)),
        (getU, putU, P.doBinOp (f Int32), adBinOp $ AD.OpBin (f Int32)),
        (getU, putU, P.doBinOp (f Int64), adBinOp $ AD.OpBin (f Int64))
      ]
    intOp f = sintOp f ++ uintOp f
    floatOp f =
      [ (getF, putF, P.doBinOp (f Float16), adBinOp $ AD.OpBin (f Float16)),
        (getF, putF, P.doBinOp (f Float32), adBinOp $ AD.OpBin (f Float32)),
        (getF, putF, P.doBinOp (f Float64), adBinOp $ AD.OpBin (f Float64))
      ]
    arithOp f g = Just $ bopDef $ intOp f ++ floatOp g

    -- Swap the operands of a comparison table, so that e.g. '>' can be
    -- defined in terms of the less-than primitives.
    flipCmps = map (\(f, g, h, o) -> (f, g, flip h, flip o))
    sintCmp f =
      [ (getS, Just . BoolValue, P.doCmpOp (f Int8), adBinOp $ AD.OpCmp (f Int8)),
        (getS, Just . BoolValue, P.doCmpOp (f Int16), adBinOp $ AD.OpCmp (f Int16)),
        (getS, Just . BoolValue, P.doCmpOp (f Int32), adBinOp $ AD.OpCmp (f Int32)),
        (getS, Just . BoolValue, P.doCmpOp (f Int64), adBinOp $ AD.OpCmp (f Int64))
      ]
    uintCmp f =
      [ (getU, Just . BoolValue, P.doCmpOp (f Int8), adBinOp $ AD.OpCmp (f Int8)),
        (getU, Just . BoolValue, P.doCmpOp (f Int16), adBinOp $ AD.OpCmp (f Int16)),
        (getU, Just . BoolValue, P.doCmpOp (f Int32), adBinOp $ AD.OpCmp (f Int32)),
        (getU, Just . BoolValue, P.doCmpOp (f Int64), adBinOp $ AD.OpCmp (f Int64))
      ]
    floatCmp f =
      [ (getF, Just . BoolValue, P.doCmpOp (f Float16), adBinOp $ AD.OpCmp (f Float16)),
        (getF, Just . BoolValue, P.doCmpOp (f Float32), adBinOp $ AD.OpCmp (f Float32)),
        (getF, Just . BoolValue, P.doCmpOp (f Float64), adBinOp $ AD.OpCmp (f Float64))
      ]
    boolCmp f = [(getB, Just . BoolValue, P.doCmpOp f, adBinOp $ AD.OpCmp f)]

    -- Conversions between interpreter primitive values and the 'P'
    -- primitive representation.  'getV'/'putV' accept any type; the
    -- others only accept (and produce) one kind of value.
    getV (SignedValue x) = Just $ P.IntValue x
    getV (UnsignedValue x) = Just $ P.IntValue x
    getV (FloatValue x) = Just $ P.FloatValue x
    getV (BoolValue x) = Just $ P.BoolValue x
    putV (P.IntValue x) = SignedValue x
    putV (P.FloatValue x) = FloatValue x
    putV (P.BoolValue x) = BoolValue x
    putV P.UnitValue = BoolValue True

    getS (SignedValue x) = Just $ P.IntValue x
    getS _ = Nothing
    putS (P.IntValue x) = Just $ SignedValue x
    putS _ = Nothing

    getU (UnsignedValue x) = Just $ P.IntValue x
    getU _ = Nothing
    putU (P.IntValue x) = Just $ UnsignedValue x
    putU _ = Nothing

    getF (FloatValue x) = Just $ P.FloatValue x
    getF _ = Nothing
    putF (P.FloatValue x) = Just $ FloatValue x
    putF _ = Nothing

    getB (BoolValue x) = Just $ P.BoolValue x
    getB _ = Nothing
    putB (P.BoolValue x) = Just $ BoolValue x
    putB _ = Nothing

    -- Conversions between interpreter values and AD values; ordinary
    -- primitives become AD constants.
    getAD (ValuePrim v) = AD.Constant <$> getV v
    getAD (ValueAD d v) = Just $ AD.Variable d v
    getAD _ = Nothing
    putAD (AD.Variable d s) = ValueAD d s
    putAD (AD.Constant v) = ValuePrim $ putV v

    adToPrim v = putV $ AD.primitive v

    adBinOp op x y = AD.doOp op [x, y]
    adUnOp op x = AD.doOp op [x]

    -- Build curried intrinsic function values of the given arity.
    fun1 f =
      TermValue Nothing $ ValueFun $ \x -> f x
    fun2 f =
      TermValue Nothing . ValueFun $ \x ->
        pure . ValueFun $ \y -> f x y
    fun3 f =
      TermValue Nothing . ValueFun $ \x ->
        pure . ValueFun $ \y ->
          pure . ValueFun $ \z -> f x y z
    fun5 f =
      TermValue Nothing . ValueFun $ \x ->
        pure . ValueFun $ \y ->
          pure . ValueFun $ \z ->
            pure . ValueFun $ \a ->
              pure . ValueFun $ \b -> f x y z a b
    fun6 f =
      TermValue Nothing . ValueFun $ \x ->
        pure . ValueFun $ \y ->
          pure . ValueFun $ \z ->
            pure . ValueFun $ \a ->
              pure . ValueFun $ \b ->
                pure . ValueFun $ \c -> f x y z a b c
    fun7 f =
      TermValue Nothing . ValueFun $ \x ->
        pure . ValueFun $ \y ->
          pure . ValueFun $ \z ->
            pure . ValueFun $ \a ->
              pure . ValueFun $ \b ->
                pure . ValueFun $ \c ->
                  pure . ValueFun $ \d -> f x y z a b c d
    fun8 f =
      TermValue Nothing . ValueFun $ \x ->
        pure . ValueFun $ \y ->
          pure . ValueFun $ \z ->
            pure . ValueFun $ \a ->
              pure . ValueFun $ \b ->
                pure . ValueFun $ \c ->
                  pure . ValueFun $ \d ->
                    pure . ValueFun $ \e -> f x y z a b c d e
    fun10 f =
      TermValue Nothing . ValueFun $ \x ->
        pure . ValueFun $ \y ->
          pure . ValueFun $ \z ->
            pure . ValueFun $ \a ->
              pure . ValueFun $ \b ->
                pure . ValueFun $ \c ->
                  pure . ValueFun $ \d ->
                    pure . ValueFun $ \e ->
                      pure . ValueFun $ \g ->
                        pure . ValueFun $ \h -> f x y z a b c d e g h

    -- A curried binary operator defined by an operator table: the first
    -- entry accepting both (primitive) operands is used; otherwise the
    -- AD implementations are tried on the operands as AD values.
    bopDef fs = fun2 $ \x y ->
      case (x, y) of
        (ValuePrim x', ValuePrim y')
          | Just z <- msum $ map (`bopDef'` (x', y')) fs -> do
              breakOnNaN [x', y'] z
              pure $ ValuePrim z
        _
          | Just x' <- getAD x,
            Just y' <- getAD y,
            Just z <- msum $ map (`bopDefAD'` (x', y')) fs -> do
              breakOnNaN [adToPrim x', adToPrim y'] $ adToPrim z
              pure $ putAD z
        _ ->
          bad noLoc mempty . docText $
            "Cannot apply operator to arguments"
              <+> dquotes (prettyValue x)
              <+> "and"
              <+> dquotes (prettyValue y)
              <> "."
      where
        bopDef' (valf, retf, op, _) (x, y) = do
          x' <- valf x
          y' <- valf y
          retf =<< op x' y'
        bopDefAD' (_, _, _, dop) (x, y) = dop x y

    -- Like 'bopDef', but for unary functions.
    unopDef fs = fun1 $ \x ->
      case x of
        (ValuePrim x')
          | Just r <- msum $ map (`unopDef'` x') fs -> do
              breakOnNaN [x'] r
              pure $ ValuePrim r
        _
          | Just x' <- getAD x,
            Just r <- msum $ map (`unopDefAD'` x') fs -> do
              breakOnNaN [adToPrim x'] $ adToPrim r
              pure $ putAD r
        _ ->
          bad noLoc mempty . docText $
            "Cannot apply function to argument"
              <+> dquotes (prettyValue x)
              <> "."
      where
        unopDef' (valf, retf, op, _) x = do
          x' <- valf x
          retf =<< op x'
        unopDefAD' (_, _, _, dop) = dop

    -- A binary operator taking its two operands as a single pair.
    tbopDef op f = fun1 $ \v ->
      case fromTuple v of
        Just [ValuePrim x, ValuePrim y]
          | Just x' <- getV x,
            Just y' <- getV y,
            Just z <- putV <$> f x' y' -> do
              breakOnNaN [x, y] z
              pure $ ValuePrim z
        Just [x, y]
          | Just x' <- getAD x,
            Just y' <- getAD y,
            Just z <- AD.doOp op [x', y'] -> do
              breakOnNaN [adToPrim x', adToPrim y'] $ adToPrim z
              pure $ putAD z
        _ ->
          bad noLoc mempty . docText $
            "Cannot apply operator to argument"
              <+> dquotes (prettyValue v)
              <> "."

    -- The definition of each intrinsic, by name; 'Nothing' for names
    -- that do not denote a term (e.g. primitive type names).
    def :: T.Text -> Maybe TermBinding
    def "!" =
      Just $
        unopDef
          [ (getS, putS, P.doUnOp $ P.Complement Int8, adUnOp $ AD.OpUn $ P.Complement Int8),
            (getS, putS, P.doUnOp $ P.Complement Int16, adUnOp $ AD.OpUn $ P.Complement Int16),
            (getS, putS, P.doUnOp $ P.Complement Int32, adUnOp $ AD.OpUn $ P.Complement Int32),
            (getS, putS, P.doUnOp $ P.Complement Int64, adUnOp $ AD.OpUn $ P.Complement Int64),
            (getU, putU, P.doUnOp $ P.Complement Int8, adUnOp $ AD.OpUn $ P.Complement Int8),
            (getU, putU, P.doUnOp $ P.Complement Int16, adUnOp $ AD.OpUn $ P.Complement Int16),
            (getU, putU, P.doUnOp $ P.Complement Int32, adUnOp $ AD.OpUn $ P.Complement Int32),
            (getU, putU, P.doUnOp $ P.Complement Int64, adUnOp $ AD.OpUn $ P.Complement Int64),
            (getB, putB, P.doUnOp $ P.Neg P.Bool, adUnOp $ AD.OpUn $ P.Neg P.Bool)
          ]
    def "neg" =
      Just $
        unopDef
          [ (getS, putS, P.doUnOp $ P.Neg $ P.IntType Int8, adUnOp $ AD.OpUn $ P.Neg $ P.IntType Int8),
            (getS, putS, P.doUnOp $ P.Neg $ P.IntType Int16, adUnOp $ AD.OpUn $ P.Neg $ P.IntType Int16),
            (getS, putS, P.doUnOp $ P.Neg $ P.IntType Int32, adUnOp $ AD.OpUn $ P.Neg $ P.IntType Int32),
            (getS, putS, P.doUnOp $ P.Neg $ P.IntType Int64, adUnOp $ AD.OpUn $ P.Neg $ P.IntType Int64),
            (getU, putU, P.doUnOp $ P.Neg $ P.IntType Int8, adUnOp $ AD.OpUn $ P.Neg $ P.IntType Int8),
            (getU, putU, P.doUnOp $ P.Neg $ P.IntType Int16, adUnOp $ AD.OpUn $ P.Neg $ P.IntType Int16),
            (getU, putU, P.doUnOp $ P.Neg $ P.IntType Int32, adUnOp $ AD.OpUn $ P.Neg $ P.IntType Int32),
            (getU, putU, P.doUnOp $ P.Neg $ P.IntType Int64, adUnOp $ AD.OpUn $ P.Neg $ P.IntType Int64),
            (getF, putF, P.doUnOp $ P.Neg $ P.FloatType Float16, adUnOp $ AD.OpUn $ P.Neg $ P.FloatType Float16),
            (getF, putF, P.doUnOp $ P.Neg $ P.FloatType Float32, adUnOp $ AD.OpUn $ P.Neg $ P.FloatType Float32),
            (getF, putF, P.doUnOp $ P.Neg $ P.FloatType Float64, adUnOp $ AD.OpUn $ P.Neg $ P.FloatType Float64),
            (getB, putB, P.doUnOp $ P.Neg P.Bool, adUnOp $ AD.OpUn $ P.Neg P.Bool)
          ]
    def "+" = arithOp (`P.Add` P.OverflowWrap) P.FAdd
    def "-" = arithOp (`P.Sub` P.OverflowWrap) P.FSub
    def "*" = arithOp (`P.Mul` P.OverflowWrap) P.FMul
    def "**" = arithOp P.Pow P.FPow
    def "/" =
      Just $
        bopDef $
          sintOp (`P.SDiv` P.Unsafe)
            ++ uintOp (`P.UDiv` P.Unsafe)
            ++ floatOp P.FDiv
    def "%" =
      Just $
        bopDef $
          sintOp (`P.SMod` P.Unsafe)
            ++ uintOp (`P.UMod` P.Unsafe)
            ++ floatOp P.FMod
    def "//" =
      Just $
        bopDef $
          sintOp (`P.SQuot` P.Unsafe)
            ++ uintOp (`P.UDiv` P.Unsafe)
    def "%%" =
      Just $
        bopDef $
          sintOp (`P.SRem` P.Unsafe)
            ++ uintOp (`P.UMod` P.Unsafe)
    def "^" = Just $ bopDef $ intOp P.Xor
    def "&" = Just $ bopDef $ intOp P.And
    def "|" = Just $ bopDef $ intOp P.Or
    def ">>" = Just $ bopDef $ sintOp P.AShr ++ uintOp P.LShr
    def "<<" = Just $ bopDef $ intOp P.Shl
    def ">>>" = Just $ bopDef $ sintOp P.LShr ++ uintOp P.LShr
    def "==" = Just $
      fun2 $
        \xs ys -> pure $ ValuePrim $ BoolValue $ xs == ys
    def "!=" = Just $
      fun2 $
        \xs ys -> pure $ ValuePrim $ BoolValue $ xs /= ys
    -- The short-circuiting is handled directly in 'evalAppExp'; these
    -- cases are only used when partially applying and such.
    def "&&" = Just $
      fun2 $ \x y ->
        pure $ ValuePrim $ BoolValue $ asBool x && asBool y
    def "||" = Just $
      fun2 $ \x y ->
        pure $ ValuePrim $ BoolValue $ asBool x || asBool y
    def "<" =
      Just $
        bopDef $
          sintCmp P.CmpSlt
            ++ uintCmp P.CmpUlt
            ++ floatCmp P.FCmpLt
            ++ boolCmp P.CmpLlt
    def ">" =
      Just $
        bopDef $
          flipCmps $
            sintCmp P.CmpSlt
              ++ uintCmp P.CmpUlt
              ++ floatCmp P.FCmpLt
              ++ boolCmp P.CmpLlt
    def "<=" =
      Just $
        bopDef $
          sintCmp P.CmpSle
            ++ uintCmp P.CmpUle
            ++ floatCmp P.FCmpLe
            ++ boolCmp P.CmpLle
    def ">=" =
      Just $
        bopDef $
          flipCmps $
            sintCmp P.CmpSle
              ++ uintCmp P.CmpUle
              ++ floatCmp P.FCmpLe
              ++ boolCmp P.CmpLle
    -- Named primitive operations, conversions and functions, looked up
    -- by their pretty-printed names.
    def s
      | Just bop <- find ((s ==) . prettyText) P.allBinOps =
          Just $ tbopDef (AD.OpBin bop) $ P.doBinOp bop
      | Just cmpop <- find ((s ==) . prettyText) P.allCmpOps =
          Just $ tbopDef (AD.OpCmp cmpop) $ \x y -> P.BoolValue <$> P.doCmpOp cmpop x y
      | Just cop <- find ((s ==) . prettyText) P.allConvOps =
          Just $ unopDef [(getV, Just . putV, P.doConvOp cop, adUnOp $ AD.OpConv cop)]
      | Just unop <- find ((s ==) . prettyText) P.allUnOps =
          Just $ unopDef [(getV, Just . putV, P.doUnOp unop, adUnOp $ AD.OpUn unop)]
      | Just (pts, _, f) <- M.lookup s P.primFuns =
          case length pts of
            1 -> Just $ unopDef [(getV, Just . putV, f . pure, adUnOp $ AD.OpFn s)]
            _ -> Just $
              fun1 $ \x -> do
                let getV' (ValuePrim v) = Just v
                    getV' _ = Nothing
                case mapM getV' =<< fromTuple x of
                  Just vs
                    | Just res <- fmap putV . f =<< mapM getV vs -> do
                        breakOnNaN vs res
                        pure $ ValuePrim res
                  _ ->
                    error $ "Cannot apply " <> prettyString s ++ " to " <> show x
      | "sign_" `T.isPrefixOf` s =
          Just $ fun1 $ \x ->
            case x of
              (ValuePrim (UnsignedValue x')) ->
                pure $ ValuePrim $ SignedValue x'
              _ -> error $ "Cannot sign: " <> show x
      | "unsign_" `T.isPrefixOf` s =
          Just $ fun1 $ \x ->
            case x of
              (ValuePrim (SignedValue x')) ->
                pure $ ValuePrim $ UnsignedValue x'
              _ -> error $ "Cannot unsign: " <> show x
    def "map" = Just $
      TermPoly Nothing $ \t -> do
        t' <- evalTypeFully t
        pure $ ValueFun $ \f -> pure . ValueFun $ \xs ->
          case unfoldFunType t' of
            ([_, _], ret_t)
              | rowshape <- typeShape $ stripArray 1 ret_t ->
                  toArray' rowshape <$> mapM (apply noLoc mempty f) (snd $ fromArray xs)
            _ ->
              error $
                "Invalid arguments to map intrinsic:\n"
                  ++ unlines [prettyString t, show f, show xs]
    def s
      | "reduce" `T.isPrefixOf` s = Just $
          fun3 $ \f ne xs ->
            foldM (apply2 noLoc mempty f) ne $ snd $ fromArray xs
    def "scan" = Just $
      fun3 $ \f ne xs -> do
        let next (out, acc) x = do
              x' <- apply2 noLoc mempty f acc x
              pure (x' : out, x')
        toArray' (valueShape ne) . reverse . fst
          <$> foldM next ([], ne) (snd $ fromArray xs)
    -- Out-of-bounds indices are silently ignored by all scatters.
    def "scatter" = Just $
      fun3 $ \arr is vs ->
        case arr of
          ValueArray shape arr' ->
            pure $
              ValueArray shape $
                foldl' update arr' $
                  zip (map asInt $ snd $ fromArray is) (snd $ fromArray vs)
          _ ->
            error $ "scatter expects array, but got: " <> show arr
      where
        update arr' (i, v) =
          if i >= 0 && i < arrayLength arr'
            then arr' // [(i, v)]
            else arr'
    def "scatter_2d" = Just $
      fun3 $ \arr is vs ->
        case arr of
          ValueArray _ _ ->
            pure $
              foldl' update arr $
                zip (map fromTuple $ snd $ fromArray is) (snd $ fromArray vs)
          _ ->
            error $ "scatter_2d expects array, but got: " <> show arr
      where
        update :: Value -> (Maybe [Value], Value) -> Value
        update arr (Just idxs@[_, _], v) =
          fromMaybe arr $ writeArray (map (IndexingFix . asInt64) idxs) arr v
        update _ _ =
          error "scatter_2d expects 2-dimensional indices"
    def "scatter_3d" = Just $
      fun3 $ \arr is vs ->
        case arr of
          ValueArray _ _ ->
            pure $
              foldl' update arr $
                zip (map fromTuple $ snd $ fromArray is) (snd $ fromArray vs)
          _ ->
            error $ "scatter_3d expects array, but got: " <> show arr
      where
        update :: Value -> (Maybe [Value], Value) -> Value
        update arr (Just idxs@[_, _, _], v) =
          fromMaybe arr $ writeArray (map (IndexingFix . asInt64) idxs) arr v
        update _ _ =
          error "scatter_3d expects 3-dimensional indices"
    -- Histograms: combine each value into its bin with the operator;
    -- out-of-bounds bins are ignored.
    def "hist_1d" = Just . fun6 $ \_ arr fun _ is vs ->
      foldM
        (update fun)
        arr
        (zip (map asInt64 $ snd $ fromArray is) (snd $ fromArray vs))
      where
        op = apply2 mempty mempty
        update fun arr (i, v) =
          fromMaybe arr <$> updateArray (op fun) [IndexingFix i] arr v
    def "hist_2d" = Just . fun6 $ \_ arr fun _ is vs ->
      foldM
        (update fun)
        arr
        (zip (map fromTuple $ snd $ fromArray is) (snd $ fromArray vs))
      where
        op = apply2 mempty mempty
        update fun arr (Just idxs@[_, _], v) =
          fromMaybe arr <$> updateArray (op fun) (map (IndexingFix . asInt64) idxs) arr v
        update _ _ _ = error "hist_2d: bad index value"
    def "hist_3d" = Just . fun6 $ \_ arr fun _ is vs ->
      foldM
        (update fun)
        arr
        (zip (map fromTuple $ snd $ fromArray is) (snd $ fromArray vs))
      where
        op = apply2 mempty mempty
        update fun arr (Just idxs@[_, _, _], v) =
          fromMaybe arr <$> updateArray (op fun) (map (IndexingFix . asInt64) idxs) arr v
        update _ _ _ = error "hist_3d: bad index value"
    -- Returns the elements grouped by partition, together with the
    -- (scalar) number of elements in each partition.
    def "partition" = Just $
      fun3 $ \k f xs -> do
        let (ShapeDim _ rowshape, xs') = fromArray xs

            next outs x = do
              i <- asInt <$> apply noLoc mempty f x
              pure $ insertAt i x outs
            pack parts =
              toTuple
                [ toArray' rowshape $ concat parts,
                  -- The counts are scalars, whatever the row shape is.
                  toArray' ShapeLeaf $
                    map (ValuePrim . SignedValue . Int64Value . genericLength) parts
                ]

        pack . map reverse
          <$> foldM next (replicate (asInt k) []) xs'
      where
        insertAt 0 x (l : ls) = (x : l) : ls
        insertAt i x (l : ls) = l : insertAt (i - 1) x ls
        insertAt _ _ ls = ls
    def "scatter_stream" = Just $
      fun3 $ \dest f vs ->
        case (dest, vs) of
          (ValueArray dest_shape dest_arr, ValueArray _ vs_arr) -> do
            let acc = ValueAcc dest_shape (\_ x -> pure x) dest_arr
            acc' <- foldM (apply2 noLoc mempty f) acc vs_arr
            case acc' of
              ValueAcc _ _ dest_arr' ->
                pure $ ValueArray dest_shape dest_arr'
              _ ->
                error $ "scatter_stream produced: " <> show acc'
          _ ->
            error $
              "scatter_stream expects array, but got: "
                <> prettyString (show dest, show vs)
    def "hist_stream" = Just $
      fun5 $ \dest op _ne f vs ->
        case (dest, vs) of
          (ValueArray dest_shape dest_arr, ValueArray _ vs_arr) -> do
            let acc = ValueAcc dest_shape (apply2 noLoc mempty op) dest_arr
            acc' <- foldM (apply2 noLoc mempty f) acc vs_arr
            case acc' of
              ValueAcc _ _ dest_arr' ->
                pure $ ValueArray dest_shape dest_arr'
              _ ->
                error $ "hist_stream produced: " <> show acc'
          _ ->
            error $
              "hist_stream expects array, but got: "
                <> prettyString (show dest, show vs)
    -- Combine a value into an accumulator at an index (which may be an
    -- AD value); out-of-bounds writes leave the accumulator unchanged.
    def "acc_write" = Just $
      fun3 $ \acc i v ->
        case (acc, i) of
          (ValueAcc shape op acc_arr, ValuePrim (SignedValue (Int64Value i'))) ->
            write acc v shape op acc_arr i'
          (ValueAcc shape op acc_arr, adv@(ValueAD {}))
            | Just (SignedValue (Int64Value i')) <- putV . AD.primitive <$> getAD adv ->
                write acc v shape op acc_arr i'
          _ ->
            error $ "acc_write invalid arguments: " <> prettyString (show acc, show i, show v)
      where
        write acc v shape op acc_arr i' =
          if i' >= 0 && i' < arrayLength acc_arr
            then do
              let x = acc_arr ! fromIntegral i'
              res <- op x v
              pure $ ValueAcc shape op $ acc_arr // [(fromIntegral i', res)]
            else pure acc
    -- Strided views into a flat array: element (i, j, ...) of the
    -- result is 'arr[offset + i*s1 + j*s2 + ...]'.
    def "flat_index_2d" = Just . fun6 $ \arr offset n1 s1 n2 s2 -> do
      let offset' = asInt64 offset
          n1' = asInt64 n1
          n2' = asInt64 n2
          s1' = asInt64 s1
          s2' = asInt64 s2
          shapeFromDims = foldr ShapeDim ShapeLeaf
          mk1 = fmap (toArray (shapeFromDims [n1', n2'])) . sequence
          mk2 = fmap (toArray $ shapeFromDims [n2']) . sequence
          iota x = [0 .. x - 1]
          f i j =
            indexArray [IndexingFix $ offset' + i * s1' + j * s2'] arr

      case mk1 [mk2 [f i j | j <- iota n2'] | i <- iota n1'] of
        Just arr' -> pure arr'
        Nothing ->
          bad mempty mempty $
            "Index out of bounds: " <> prettyText [((n1', s1'), (n2', s2'))]
    -- Strided writes into a flat array: the inverse of the
    -- corresponding flat_index_Nd, with the dimensions taken from 'v'.
    def "flat_update_2d" = Just . fun5 $ \arr offset s1 s2 v -> do
      let offset' = asInt64 offset
          s1' = asInt64 s1
          s2' = asInt64 s2
      case valueShape v of
        ShapeDim n1 (ShapeDim n2 _) -> do
          let iota x = [0 .. x - 1]
              f arr' (i, j) =
                writeArray [IndexingFix $ offset' + i * s1' + j * s2'] arr'
                  =<< indexArray [IndexingFix i, IndexingFix j] v
          case foldM f arr [(i, j) | i <- iota n1, j <- iota n2] of
            Just arr' -> pure arr'
            Nothing ->
              bad mempty mempty $
                "Index out of bounds: " <> prettyText [((n1, s1'), (n2, s2'))]
        s -> error $ "flat_update_2d: invalid arg shape: " ++ show s
    def "flat_index_3d" = Just . fun8 $ \arr offset n1 s1 n2 s2 n3 s3 -> do
      let offset' = asInt64 offset
          n1' = asInt64 n1
          n2' = asInt64 n2
          n3' = asInt64 n3
          s1' = asInt64 s1
          s2' = asInt64 s2
          s3' = asInt64 s3
          shapeFromDims = foldr ShapeDim ShapeLeaf
          mk1 = fmap (toArray (shapeFromDims [n1', n2', n3'])) . sequence
          mk2 = fmap (toArray $ shapeFromDims [n2', n3']) . sequence
          mk3 = fmap (toArray $ shapeFromDims [n3']) . sequence
          iota x = [0 .. x - 1]
          f i j l =
            indexArray [IndexingFix $ offset' + i * s1' + j * s2' + l * s3'] arr

      case mk1 [mk2 [mk3 [f i j l | l <- iota n3'] | j <- iota n2'] | i <- iota n1'] of
        Just arr' -> pure arr'
        Nothing ->
          bad mempty mempty $
            "Index out of bounds: " <> prettyText [((n1', s1'), (n2', s2'), (n3', s3'))]
    def "flat_update_3d" = Just . fun6 $ \arr offset s1 s2 s3 v -> do
      let offset' = asInt64 offset
          s1' = asInt64 s1
          s2' = asInt64 s2
          s3' = asInt64 s3
      case valueShape v of
        ShapeDim n1 (ShapeDim n2 (ShapeDim n3 _)) -> do
          let iota x = [0 .. x - 1]
              f arr' (i, j, l) =
                writeArray [IndexingFix $ offset' + i * s1' + j * s2' + l * s3'] arr'
                  =<< indexArray [IndexingFix i, IndexingFix j, IndexingFix l] v
          case foldM f arr [(i, j, l) | i <- iota n1, j <- iota n2, l <- iota n3] of
            Just arr' -> pure arr'
            Nothing ->
              bad mempty mempty $
                "Index out of bounds: " <> prettyText [((n1, s1'), (n2, s2'), (n3, s3'))]
        s -> error $ "flat_update_3d: invalid arg shape: " ++ show s
    def "flat_index_4d" = Just . fun10 $ \arr offset n1 s1 n2 s2 n3 s3 n4 s4 -> do
      let offset' = asInt64 offset
          n1' = asInt64 n1
          n2' = asInt64 n2
          n3' = asInt64 n3
          n4' = asInt64 n4
          s1' = asInt64 s1
          s2' = asInt64 s2
          s3' = asInt64 s3
          s4' = asInt64 s4
          shapeFromDims = foldr ShapeDim ShapeLeaf
          mk1 = fmap (toArray (shapeFromDims [n1', n2', n3', n4'])) . sequence
          mk2 = fmap (toArray $ shapeFromDims [n2', n3', n4']) . sequence
          mk3 = fmap (toArray $ shapeFromDims [n3', n4']) . sequence
          mk4 = fmap (toArray $ shapeFromDims [n4']) . sequence
          iota x = [0 .. x - 1]
          f i j l m =
            indexArray [IndexingFix $ offset' + i * s1' + j * s2' + l * s3' + m * s4'] arr

      case mk1 [mk2 [mk3 [mk4 [f i j l m | m <- iota n4'] | l <- iota n3'] | j <- iota n2'] | i <- iota n1'] of
        Just arr' -> pure arr'
        Nothing ->
          bad mempty mempty $
            "Index out of bounds: " <> prettyText [(((n1', s1'), (n2', s2')), ((n3', s3'), (n4', s4')))]
    def "flat_update_4d" = Just . fun7 $ \arr offset s1 s2 s3 s4 v -> do
      let offset' = asInt64 offset
          s1' = asInt64 s1
          s2' = asInt64 s2
          s3' = asInt64 s3
          s4' = asInt64 s4
      case valueShape v of
        ShapeDim n1 (ShapeDim n2 (ShapeDim n3 (ShapeDim n4 _))) -> do
          let iota x = [0 .. x - 1]
              f arr' (i, j, l, m) =
                writeArray [IndexingFix $ offset' + i * s1' + j * s2' + l * s3' + m * s4'] arr'
                  =<< indexArray [IndexingFix i, IndexingFix j, IndexingFix l, IndexingFix m] v
          case foldM f arr [(i, j, l, m) | i <- iota n1, j <- iota n2, l <- iota n3, m <- iota n4] of
            Just arr' -> pure arr'
            Nothing ->
              bad mempty mempty $
                "Index out of bounds: " <> prettyText [(((n1, s1'), (n2, s2')), ((n3, s3'), (n4, s4')))]
        s -> error $ "flat_update_4d: invalid arg shape: " ++ show s
    -- Array reshaping intrinsics.
    def "unzip" = Just $
      fun1 $ \x -> do
        let ShapeDim _ (ShapeRecord fs) = valueShape x
            Just [xs_shape, ys_shape] = areTupleFields fs
            listPair (xs, ys) = [toArray' xs_shape xs, toArray' ys_shape ys]

        pure $ toTuple $ listPair $ unzip $ map (fromPair . fromTuple) $ snd $ fromArray x
      where
        fromPair (Just [x, y]) = (x, y)
        fromPair _ = error "Not a pair"
    def "zip" = Just $
      fun2 $ \xs ys -> do
        let ShapeDim _ xs_rowshape = valueShape xs
            ShapeDim _ ys_rowshape = valueShape ys
        pure $
          toArray' (ShapeRecord (tupleFields [xs_rowshape, ys_rowshape])) $
            map toTuple $
              transpose [snd $ fromArray xs, snd $ fromArray ys]
    def "concat" = Just $
      fun2 $ \xs ys -> do
        let (ShapeDim _ rowshape, xs') = fromArray xs
            (_, ys') = fromArray ys
        pure $ toArray' rowshape $ xs' ++ ys'
    def "transpose" = Just $
      fun1 $ \xs -> do
        let (ShapeDim n (ShapeDim m shape), xs') = fromArray xs
        pure $
          toArray (ShapeDim m (ShapeDim n shape)) $
            map (toArray (ShapeDim n shape)) $
              -- Slight hack to work around empty dimensions.
              genericTake m $
                transpose (map (snd . fromArray) xs') ++ repeat []
    def "flatten" = Just $
      fun1 $ \xs -> do
        let (ShapeDim n (ShapeDim m shape), xs') = fromArray xs
        pure $ toArray (ShapeDim (n * m) shape) $ concatMap (snd . fromArray) xs'
    def "unflatten" = Just $
      fun3 $ \n m xs -> do
        let (ShapeDim xs_size innershape, xs') = fromArray xs
            rowshape = ShapeDim (asInt64 m) innershape
            shape = ShapeDim (asInt64 n) rowshape
        if asInt64 n * asInt64 m /= xs_size || asInt64 n < 0 || asInt64 m < 0
          then
            bad mempty mempty $
              "Cannot unflatten array of shape ["
                <> prettyText xs_size
                <> "] to array of shape ["
                <> prettyText (asInt64 n)
                <> "]["
                <> prettyText (asInt64 m)
                <> "]"
          else pure $ toArray shape $ map (toArray rowshape) $ chunk (asInt m) xs'
    def "manifest" = Just $ fun1 pure
    def "vjp2" = Just $
      -- TODO: This could be much better. Currently, it is very inefficient
      -- Perhaps creating VJPValues could be abstracted into a function
      -- exposed by the AD module?
      fun3 $ \f v s -> do
        -- Get the depth
        depth <- length <$> stacktrace

        -- Augment the values
        let v' =
              fromMaybe (error $ "vjp: invalid values " ++ show v) $
                modifyValueM (\i lv -> ValueAD depth . AD.VJP . AD.VJPValue . AD.TapeID i <$> getAD lv) v
        -- Turn the seeds into a list of ADValues
        let s' =
              fromMaybe (error $ "vjp: invalid seeds " ++ show s) $
                mapM getAD $
                  fst $
                    valueAccum (\a b -> (b : a, b)) [] s

        -- Run the function, and turn its outputs into a list of Values
        o <- apply noLoc mempty f v'
        let o' = fst $ valueAccum (\a b -> (b : a, b)) [] o

        -- For each output..
        let m = flip map (zip o' s') $ \(on, sn) -> case on of
              -- If it is a VJP variable of the correct depth, run
              -- deriveTape on it and its corresponding seed
              (ValueAD d (AD.VJP (AD.VJPValue t)))
                | d == depth ->
                    (putAD $ AD.tapePrimal t, AD.deriveTape t sn)
              -- Otherwise, its partial derivatives are all 0
              _ -> (on, M.empty)

        -- Add together every derivative
        let drvs = M.map (Just . putAD) $ M.unionsWith add $ map snd m

        -- Extract the output values, and the partial derivatives
        let ov = modifyValue (\i _ -> fst $ m !! (length m - 1 - i)) o
        let od =
              fromMaybe (error "vjp: differentiation failed") $
                modifyValueM
                  ( \i vo ->
                      M.findWithDefault
                        (ValuePrim . putV . P.blankPrimValue . P.primValueType . AD.primitive <$> getAD vo)
                        i
                        drvs
                  )
                  v

        -- Return a tuple of the output values, and partial derivatives
        pure $ toTuple [ov, od]
      where
        modifyValue f v = snd $ valueAccum (\a b -> (a + 1, f a b)) 0 v
        modifyValueM f v =
          snd
            <$> valueAccumLM
              ( \a b -> do
                  b' <- f a b
                  pure (a + 1, b')
              )
              0
              v

        -- TODO: Perhaps this could be fully abstracted by AD?
        -- Making addFor private would be nice..
        add x y =
          fromMaybe (error "vjp: illtyped add") $
            AD.doOp (AD.OpBin $ AD.addFor $ P.primValueType $ AD.primitive x) [x, y]
    def "jvp2" = Just $
      -- TODO: This could be much better. Currently, it is very inefficient
      -- Perhaps creating JVPValues could be abstracted into a function
      -- exposed by the AD module?
      fun3 $ \f v s -> do
        -- Get the depth
        depth <- length <$> stacktrace

        -- Turn the seeds into a list of ADValues
        let s' =
              expectJust ("jvp: invalid seeds " ++ show s) $
                mapM getAD $
                  fst $
                    valueAccum (\a b -> (b : a, b)) [] s
        -- Augment the values
        let v' =
              expectJust ("jvp: invalid values " ++ show v) $
                modifyValueM
                  ( \i lv -> do
                      lv' <- getAD lv
                      pure $ ValueAD depth . AD.JVP . AD.JVPValue lv' $ s' !! (length s' - 1 - i)
                  )
                  v

        -- Run the function, and turn its outputs into a list of Values
        o <- apply noLoc mempty f v'
        let o' = fst $ valueAccum (\a b -> (b : a, b)) [] o

        -- For each output..
        let m =
              expectJust "jvp: differentiation failed" $
                mapM
                  ( \on -> case on of
                      -- If it is a JVP variable of the correct depth, return its primal and derivative
                      (ValueAD d (AD.JVP (AD.JVPValue pv dv)))
                        | d == depth -> Just (putAD pv, putAD dv)
                      -- Otherwise, its partial derivatives are all 0
                      _ -> (on,) . ValuePrim . putV . P.blankPrimValue . P.primValueType . AD.primitive <$> getAD on
                  )
                  o'

        -- Extract the output values, and the partial derivatives
        let ov = modifyValue (\i _ -> fst $ m !! (length m - 1 - i)) o
            od = modifyValue (\i _ -> snd $ m !! (length m - 1 - i)) o

        -- Return a tuple of the output values, and partial derivatives
        pure $ toTuple [ov, od]
      where
        modifyValue f v = snd $ valueAccum (\a b -> (a + 1, f a b)) 0 v
        modifyValueM f v =
          snd
            <$> valueAccumLM
              ( \a b -> do
                  b' <- f a b
                  pure (a + 1, b')
              )
              0
              v

        expectJust _ (Just v) = v
        expectJust s Nothing = error s
    def "acc" = Nothing
    def s | nameFromText s `M.member` namesToPrimTypes = Nothing
    def s = error $ "Missing intrinsic: " ++ T.unpack s

    -- Primitive type names become unlifted type abbreviations.
    tdef :: Name -> Maybe (Env, T.TypeBinding)
    tdef s = do
      t <- s `M.lookup` namesToPrimTypes
      pure (mempty, T.TypeAbbr Unlifted [] $ RetType [] $ Scalar $ Prim t)

-- | The value of the named intrinsic, as defined in 'initialCtx'.
intrinsicVal :: Name -> Value
intrinsicVal name =
  case M.lookup (intrinsicVar name) $ envTerm $ ctxEnv initialCtx of
    Just (TermValue _ v) -> v
    _ -> error $ "intrinsicVal: " <> prettyString name

-- | The @neg@ intrinsic (numeric negation).
intrinsicsNeg :: Value
intrinsicsNeg = intrinsicVal "neg"

-- | The @!@ intrinsic (logical or bitwise complement).
intrinsicsNot :: Value
intrinsicsNot = intrinsicVal "!"

-- | Evaluate an expression in the given context.
interpretExp :: Ctx -> Exp -> F ExtOp Value
interpretExp ctx e = runEvalM (ctxImports ctx) $ eval (ctxEnv ctx) e

-- | Evaluate a sequence of declarations in the given context and
-- return the resulting environment.
interpretDecs :: Ctx -> [Dec] -> F ExtOp Env
interpretDecs ctx decs =
  runEvalM (ctxImports ctx) $ do
    env <- foldM evalDec (ctxEnv ctx) decs
    -- We need to extract any new existential sizes and add them as
    -- ordinary bindings to the context, or we will not be able to
    -- look up their values later.
    sizes <- extEnv
    pure $ env <> sizes

-- | Extend the context with a single declaration.
interpretDec :: Ctx -> Dec -> F ExtOp Ctx
interpretDec ctx d = do
  env <- interpretDecs ctx [d]
  pure ctx {ctxEnv = env}

-- | Evaluate an imported program and record its environment under the
-- given import name.
interpretImport :: Ctx -> (ImportName, Prog) -> F ExtOp Ctx
interpretImport ctx (fp, prog) = do
  env <- interpretDecs ctx $ progDecs prog
  pure ctx {ctxImports = M.insert fp env $ ctxImports ctx}

-- | Produce a context, based on the one passed in, where all of
-- the provided imports have been @open@ened in order.
ctxWithImports :: [Env] -> Ctx -> Ctx ctxWithImports envs ctx = ctx {ctxEnv = mconcat (reverse envs) <> ctxEnv ctx} valueType :: V.Value -> ValueType valueType v = let V.ValueType shape pt = V.valueType v in arrayOf (F.Shape (map fromIntegral shape)) (Scalar (Prim (toPrim pt))) where toPrim V.I8 = Signed Int8 toPrim V.I16 = Signed Int16 toPrim V.I32 = Signed Int32 toPrim V.I64 = Signed Int64 toPrim V.U8 = Unsigned Int8 toPrim V.U16 = Unsigned Int16 toPrim V.U32 = Unsigned Int32 toPrim V.U64 = Unsigned Int64 toPrim V.Bool = Bool toPrim V.F16 = FloatType Float16 toPrim V.F32 = FloatType Float32 toPrim V.F64 = FloatType Float64 checkEntryArgs :: VName -> [V.Value] -> StructType -> Either T.Text () checkEntryArgs entry args entry_t | args_ts == map toStruct param_ts = pure () | otherwise = Left . docText $ expected "Got input of types" indent 2 (stack (map pretty args_ts)) where (param_ts, _) = unfoldFunType entry_t args_ts = map (valueStructType . valueType) args expected | null param_ts = "Entry point " <> dquotes (prettyName entry) <> " is not a function." | otherwise = "Entry point " <> dquotes (prettyName entry) <> " expects input of type(s)" indent 2 (stack (map pretty param_ts)) -- | Execute the named function on the given arguments; may fail -- horribly if these are ill-typed. interpretFunction :: Ctx -> VName -> [V.Value] -> Either T.Text (F ExtOp Value) interpretFunction ctx fname vs = do let env = ctxEnv ctx (ft, mkf) <- case lookupVar (qualName fname) env of Just (TermValue (Just (T.BoundV _ t)) v) -> do ft <- updateType (map valueType vs) t pure (ft, pure v) Just (TermPoly (Just (T.BoundV _ t)) v) -> do ft <- updateType (map valueType vs) t pure (ft, v (structToEval env ft)) _ -> Left $ "Unknown function `" <> nameToText (toName fname) <> "`." 
let vs' = map fromDataValue vs checkEntryArgs fname vs ft Right $ runEvalM (ctxImports ctx) $ do f <- mkf foldM (apply noLoc mempty) f vs' where updateType (vt : vts) (Scalar (Arrow als pn d pt (RetType dims rt))) = do checkInput vt pt Scalar . Arrow als pn d (valueStructType vt) . RetType dims . toRes Nonunique <$> updateType vts (toStruct rt) updateType _ t = Right t checkInput :: ValueType -> StructType -> Either T.Text () checkInput (Scalar (Prim vt)) (Scalar (Prim pt)) | vt /= pt = badPrim vt pt checkInput (Array _ _ (Prim vt)) (Array _ _ (Prim pt)) | vt /= pt = badPrim vt pt checkInput vArr@(Array _ (F.Shape vd) _) pArr@(Array _ (F.Shape pd) _) | length vd /= length pd = badDim vArr pArr | not . and $ zipWith sameShape vd pd = badDim vArr pArr where sameShape :: Int64 -> Size -> Bool sameShape shape0 (IntLit shape1 _ _) = fromIntegral shape0 == shape1 sameShape _ _ = True checkInput _ _ = Right () badPrim vt pt = Left . docText $ "Invalid argument type." "Expected:" <+> align (pretty pt) "Got: " <+> align (pretty vt) badDim vd pd = Left . docText $ "Invalid argument dimensions." 
"Expected:" <+> align (pretty pd) "Got: " <+> align (pretty vd) futhark-0.25.27/src/Language/Futhark/Interpreter/000077500000000000000000000000001475065116200215635ustar00rootroot00000000000000futhark-0.25.27/src/Language/Futhark/Interpreter/AD.hs000066400000000000000000000262451475065116200224140ustar00rootroot00000000000000module Language.Futhark.Interpreter.AD ( Op (..), ADVariable (..), ADValue (..), Tape (..), VJPValue (..), JVPValue (..), doOp, addFor, tapePrimal, primitive, varPrimal, deriveTape, ) where import Control.Monad (foldM, zipWithM) import Data.Either (isRight) import Data.List (find, foldl') import Data.Map qualified as M import Data.Maybe (fromMaybe) import Data.Text qualified as T import Futhark.AD.Derivatives (pdBinOp, pdBuiltin, pdUnOp) import Futhark.Analysis.PrimExp (PrimExp (..)) import Language.Futhark.Core (VName (..), nameFromString, nameFromText) import Language.Futhark.Primitive -- Mathematical operations subject to AD. data Op = OpBin BinOp | OpCmp CmpOp | OpUn UnOp | OpFn T.Text | OpConv ConvOp deriving (Show) -- Checks if an operation matches the types of its operands opTypeMatch :: Op -> [PrimType] -> Bool opTypeMatch (OpBin op) p = all (\x -> binOpType op == x) p opTypeMatch (OpCmp op) p = all (\x -> cmpOpType op == x) p opTypeMatch (OpUn op) p = all (\x -> unOpType op == x) p opTypeMatch (OpConv op) p = all (\x -> fst (convOpType op) == x) p opTypeMatch (OpFn fn) p = case M.lookup fn primFuns of Just (t, _, _) -> and $ zipWith (==) t p Nothing -> error "opTypeMatch" -- It is assumed that the function exists -- Gets the return type of an operation opReturnType :: Op -> PrimType opReturnType (OpBin op) = binOpType op opReturnType (OpCmp op) = cmpOpType op opReturnType (OpUn op) = unOpType op opReturnType (OpConv op) = snd $ convOpType op opReturnType (OpFn fn) = case M.lookup fn primFuns of Just (_, t, _) -> t Nothing -> error "opReturnType" -- It is assumed that the function exists -- Returns the operation which performs addition 
(or an -- equivalent operation) on the given type addFor :: PrimType -> BinOp addFor (IntType t) = Add t OverflowWrap addFor (FloatType t) = FAdd t addFor Bool = LogOr addFor t = error $ "addFor: " ++ show t -- Returns the function which performs multiplication -- (or an equivalent operation) on the given type mulFor :: PrimType -> BinOp mulFor (IntType t) = Mul t OverflowWrap mulFor (FloatType t) = FMul t mulFor Bool = LogAnd mulFor t = error $ "mulFor: " ++ show t -- Types and utility functions-- -- When taking the partial derivative of a function, we -- must differentiate between the values which are kept -- constant, and those which are not data ADValue = Variable Int ADVariable | Constant PrimValue deriving (Show) -- When performing automatic differentiation, each derived -- variable must be augmented with additional data. This -- value holds the primitive value of the variable, as well -- as its data data ADVariable = VJP VJPValue | JVP JVPValue deriving (Show) depth :: ADValue -> Int depth (Variable d _) = d depth (Constant _) = 0 primal :: ADValue -> ADValue primal (Variable _ (VJP (VJPValue t))) = tapePrimal t primal (Variable _ (JVP (JVPValue v _))) = primal v primal (Constant v) = Constant v primitive :: ADValue -> PrimValue primitive (Variable _ v) = varPrimal v primitive (Constant v) = v varPrimal :: ADVariable -> PrimValue varPrimal (VJP (VJPValue t)) = primitive $ tapePrimal t varPrimal (JVP (JVPValue v _)) = primitive $ primal v -- Evaluates a PrimExp using doOp evalPrimExp :: M.Map VName ADValue -> PrimExp VName -> Maybe ADValue evalPrimExp m (LeafExp n _) = M.lookup n m evalPrimExp _ (ValueExp pv) = Just $ Constant pv evalPrimExp m (BinOpExp op x y) = do x' <- evalPrimExp m x y' <- evalPrimExp m y doOp (OpBin op) [x', y'] evalPrimExp m (CmpOpExp op x y) = do x' <- evalPrimExp m x y' <- evalPrimExp m y doOp (OpCmp op) [x', y'] evalPrimExp m (UnOpExp op x) = do x' <- evalPrimExp m x doOp (OpUn op) [x'] evalPrimExp m (ConvOpExp op x) = do x' <- 
evalPrimExp m x doOp (OpConv op) [x'] evalPrimExp m (FunExp fn p _) = do p' <- mapM (evalPrimExp m) p doOp (OpFn fn) p' -- Returns a list of PrimExps calculating the partial -- derivative of each operands of a given operation lookupPDs :: Op -> [PrimExp VName] -> Maybe [PrimExp VName] lookupPDs (OpBin op) [x, y] = Just $ do let (a, b) = pdBinOp op x y [a, b] lookupPDs (OpUn op) [x] = Just [pdUnOp op x] lookupPDs (OpFn fn) p = pdBuiltin (nameFromText fn) p lookupPDs _ _ = Nothing -- Shared AD logic-- -- This function performs a mathematical operation on a -- list of operands, performing automatic differentiation -- if one or more operands is a Variable (of depth > 0) doOp :: Op -> [ADValue] -> Maybe ADValue doOp op o | not $ opTypeMatch op (map primValueType pv) = -- This function may be called with arguments of invalid types, -- because it is used as part of an overloaded operator. Nothing | otherwise = do let dep = case op of OpCmp _ -> 0 -- AD is not well-defined for comparason operations -- There are no derivatives for those written in -- PrimExp (check lookupPDs) _ -> maximum (map depth o) if dep == 0 then constCase else nonconstCase dep where pv = map primitive o divideDepths :: Int -> ADValue -> Either ADValue ADVariable divideDepths _ v@(Constant {}) = Left v divideDepths d v@(Variable d' v') = if d' < d then Left v else Right v' -- TODO: There may be a more graceful way of -- doing this extractVJP :: Either ADValue ADVariable -> Either ADValue VJPValue extractVJP (Right (VJP v)) = Right v extractVJP (Left v) = Left v extractVJP _ = -- This will never be called when the maximum depth layer is JVP error "extractVJP" -- TODO: There may be a more graceful way of -- doing this extractJVP :: Either ADValue ADVariable -> Either ADValue JVPValue extractJVP (Right (JVP v)) = Right v extractJVP (Left v) = Left v extractJVP _ = -- This will never be called when the maximum depth layer is VJP error "extractJVP" -- In this case, every operand is a constant, and the -- 
mathematical operation can be applied as it would be -- otherwise constCase = Constant <$> case (op, pv) of (OpBin op', [x, y]) -> doBinOp op' x y (OpCmp op', [x, y]) -> BoolValue <$> doCmpOp op' x y (OpUn op', [x]) -> doUnOp op' x (OpConv op', [x]) -> doConvOp op' x (OpFn fn, _) -> do (_, _, f) <- M.lookup fn primFuns f pv _ -> error "doOp: opTypeMatch" nonconstCase dep = do -- In this case, some values are variables. We therefore -- have to perform the necessary steps for AD -- First, we calculate the value for the previous depth let oprev = map primal o vprev <- doOp op oprev -- Then we separate the values of the maximum depth from -- those of a lower depth let o' = map (divideDepths dep) o -- Then we find out what type of AD is being performed case find isRight o' of -- Finally, we perform the necessary steps for the given -- type of AD Just (Right (VJP {})) -> Just . Variable dep . VJP . VJPValue $ vjpHandleOp op (map extractVJP o') vprev Just (Right (JVP {})) -> Variable dep . JVP . JVPValue vprev <$> jvpHandleFn op (map extractJVP o') _ -> -- Since the maximum depth is non-zero, there must be at -- least one variable of depth > 0 error "find isRight" calculatePDs :: Op -> [ADValue] -> [ADValue] calculatePDs op p = -- Create a unique VName for each operand let n = map (\i -> VName (nameFromString $ "x" ++ show i) i) [1 .. length p] -- Put the operands in the environment m = M.fromList $ zip n p -- Look up, and calculate the partial derivative -- of the operation with respect to each operand pde = fromMaybe (error "lookupPDs failed") $ lookupPDs op $ map (`LeafExp` opReturnType op) n in map (fromMaybe (error "evalPrimExp failed") . evalPrimExp m) pde -- VJP / Reverse mode automatic differentiation-- -- In reverse mode AD, the entire computation -- leading up to a variable must be saved -- This is represented as a Tape newtype VJPValue = VJPValue Tape deriving (Show) -- | Represents a computation tree, as well as every intermediate -- value in its evaluation. 
TODO: make this a graph. data Tape = -- | This represents a variable. Each variable is given a unique ID, -- and has an initial value TapeID Int ADValue | -- | This represents a constant. TapeConst ADValue | -- | This represents the application of a mathematical operation. -- Each parameter is given by its Tape, and the return value of -- the operation is saved TapeOp Op [Tape] ADValue deriving (Show) -- | Returns the primal value of a Tape. tapePrimal :: Tape -> ADValue tapePrimal (TapeID _ v) = v tapePrimal (TapeConst v) = v tapePrimal (TapeOp _ _ v) = v -- This updates Tape of a VJPValue with a new operation, -- treating all operands of a lower depth as constants vjpHandleOp :: Op -> [Either ADValue VJPValue] -> ADValue -> Tape vjpHandleOp op p v = do TapeOp op (map toTape p) v where toTape (Left v') = TapeConst v' toTape (Right (VJPValue t)) = t -- | This calculates every partial derivative of a 'Tape'. The result -- is a map of the partial derivatives, each key corresponding to the -- ID of a free variable (see TapeID). deriveTape :: Tape -> ADValue -> M.Map Int ADValue deriveTape (TapeID i _) s = M.fromList [(i, s)] deriveTape (TapeConst _) _ = M.empty deriveTape (TapeOp op p _) s = -- Calculate the new sensitivities let s'' = case op of OpConv op' -> -- In case of type conversion, simply convert the sensitivity [ fromMaybe (error "deriveTape: doOp failed") $ doOp (OpConv $ flipConvOp op') [s] ] _ -> map (mul s) $ calculatePDs op $ map tapePrimal p -- Propagate the new sensitivities pd = zipWith deriveTape p s'' in -- Add up the results foldl' (M.unionWith add) M.empty pd where add x y = fromMaybe (error "deriveTape: add failed") $ doOp (OpBin $ addFor $ opReturnType op) [x, y] mul x y = fromMaybe (error "deriveTape: mul failed") $ doOp (OpBin $ mulFor $ opReturnType op) [x, y] -- JVP / Forward mode automatic differentiation-- -- | In JVP, the derivative of the variable must be saved. This is -- represented as a second value. 
data JVPValue = JVPValue ADValue ADValue deriving (Show) -- | This calculates the derivative part of the JVPValue resulting -- from the application of a mathematical operation on one or more -- JVPValues. jvpHandleFn :: Op -> [Either ADValue JVPValue] -> Maybe ADValue jvpHandleFn op p = do case op of OpConv _ -> -- In case of type conversion, simply convert -- the old derivative doOp op [derivative $ head p] _ -> do -- Calculate the new derivative using the chain -- rule let pds = calculatePDs op $ map primal' p vs <- zipWithM mul pds $ map derivative p foldM add (Constant $ blankPrimValue $ opReturnType op) vs where primal' (Left v) = v primal' (Right (JVPValue v _)) = v derivative (Left v) = Constant $ blankPrimValue $ primValueType $ primitive v derivative (Right (JVPValue _ d)) = d add x y = doOp (OpBin $ addFor $ opReturnType op) [x, y] mul x y = doOp (OpBin $ mulFor $ opReturnType op) [x, y] futhark-0.25.27/src/Language/Futhark/Interpreter/Values.hs000066400000000000000000000264031475065116200233630ustar00rootroot00000000000000-- | The value representation used in the interpreter. -- -- Kept simple and free of unnecessary operational details (in -- particular, no references to the interpreter monad). 
module Language.Futhark.Interpreter.Values ( -- * Shapes Shape (..), ValueShape, typeShape, structTypeShape, -- * Values Value (..), valueShape, prettyValue, valueText, valueAccum, valueAccumLM, fromTuple, arrayLength, isEmptyArray, prettyEmptyArray, toArray, toArray', toTuple, -- * Conversion fromDataValue, ) where import Data.Array import Data.Bifunctor (Bifunctor (second)) import Data.List (genericLength) import Data.Map qualified as M import Data.Maybe import Data.Monoid hiding (Sum) import Data.Text qualified as T import Data.Vector.Storable qualified as SVec import Futhark.Data qualified as V import Futhark.Util (chunk, mapAccumLM) import Futhark.Util.Pretty import Language.Futhark hiding (Shape, matchDims) import Language.Futhark.Interpreter.AD qualified as AD import Language.Futhark.Primitive qualified as P import Prelude hiding (break, mod) prettyRecord :: (a -> Doc ann) -> M.Map Name a -> Doc ann prettyRecord p m | Just vs <- areTupleFields m = parens $ align $ vsep $ punctuate comma $ map p vs | otherwise = braces $ align $ vsep $ punctuate comma $ map field $ M.toList m where field (k, v) = pretty k <+> equals <+> p v -- | A shape is a tree to accomodate the case of records. It is -- parameterised over the representation of dimensions. data Shape d = ShapeDim d (Shape d) | ShapeLeaf | ShapeRecord (M.Map Name (Shape d)) | ShapeSum (M.Map Name [Shape d]) deriving (Eq, Show, Functor, Foldable, Traversable) -- | The shape of an array. 
type ValueShape = Shape Int64 instance (Pretty d) => Pretty (Shape d) where pretty ShapeLeaf = mempty pretty (ShapeDim d s) = brackets (pretty d) <> pretty s pretty (ShapeRecord m) = prettyRecord pretty m pretty (ShapeSum cs) = mconcat (punctuate " | " cs') where ppConstr (name, fs) = sep $ ("#" <> pretty name) : map pretty fs cs' = map ppConstr $ M.toList cs emptyShape :: ValueShape -> Bool emptyShape (ShapeDim d s) = d == 0 || emptyShape s emptyShape _ = False typeShape :: TypeBase d u -> Shape d typeShape (Array _ shape et) = foldr ShapeDim (typeShape (Scalar et)) $ shapeDims shape typeShape (Scalar (Record fs)) = ShapeRecord $ M.map typeShape fs typeShape (Scalar (Sum cs)) = ShapeSum $ M.map (map typeShape) cs typeShape t | Just t' <- isAccType t = typeShape t' | otherwise = ShapeLeaf structTypeShape :: StructType -> Shape (Maybe Int64) structTypeShape = fmap dim . typeShape where dim (IntLit x _ _) = Just $ fromIntegral x dim _ = Nothing -- | A fully evaluated Futhark value. data Value m = ValuePrim !PrimValue | ValueArray ValueShape !(Array Int (Value m)) | -- Stores the full shape. ValueRecord (M.Map Name (Value m)) | ValueFun (Value m -> m (Value m)) | -- Stores the full shape. ValueSum ValueShape Name [Value m] | -- The shape, the update function, and the array. 
ValueAcc ValueShape (Value m -> Value m -> m (Value m)) !(Array Int (Value m)) | -- A primitive value with added information used in automatic differentiation ValueAD Int AD.ADVariable instance Show (Value m) where show (ValuePrim v) = "ValuePrim " <> show v <> "" show (ValueArray shape vs) = unwords ["ValueArray", "(" <> show shape <> ")", "(" <> show vs <> ")"] show (ValueRecord fs) = "ValueRecord " <> "(" <> show fs <> ")" show (ValueSum shape c vs) = unwords ["ValueSum", "(" <> show shape <> ")", show c, "(" <> show vs <> ")"] show ValueFun {} = "ValueFun _" show ValueAcc {} = "ValueAcc _" show (ValueAD d v) = unwords ["ValueAD", show d, show v] instance Eq (Value m) where ValuePrim (SignedValue x) == ValuePrim (SignedValue y) = P.doCmpEq (P.IntValue x) (P.IntValue y) ValuePrim (UnsignedValue x) == ValuePrim (UnsignedValue y) = P.doCmpEq (P.IntValue x) (P.IntValue y) ValuePrim (FloatValue x) == ValuePrim (FloatValue y) = P.doCmpEq (P.FloatValue x) (P.FloatValue y) ValuePrim (BoolValue x) == ValuePrim (BoolValue y) = P.doCmpEq (P.BoolValue x) (P.BoolValue y) ValueArray _ x == ValueArray _ y = x == y ValueRecord x == ValueRecord y = x == y ValueSum _ n1 vs1 == ValueSum _ n2 vs2 = n1 == n2 && vs1 == vs2 ValueAcc _ _ x == ValueAcc _ _ y = x == y _ == _ = False prettyValueWith :: (PrimValue -> Doc a) -> Value m -> Doc a prettyValueWith pprPrim = pprPrec 0 where pprPrec _ (ValuePrim v) = pprPrim v pprPrec _ (ValueArray _ a) = let elements = elems a -- [Value] separator = case elements of ValueArray _ _ : _ -> vsep _ -> hsep in brackets $ align $ separator $ punctuate comma $ map pprElem elements pprPrec _ (ValueRecord m) = prettyRecord (pprPrec 0) m pprPrec _ ValueFun {} = "#" pprPrec _ ValueAcc {} = "#" pprPrec p (ValueSum _ n vs) = parensIf (p > (0 :: Int)) $ "#" <> sep (pretty n : map (pprPrec 1) vs) pprPrec _ (ValueAD _ v) = pprPrim $ putV $ AD.varPrimal v pprElem v@ValueArray {} = pprPrec 0 v pprElem v = group $ pprPrec 0 v putV (P.IntValue x) = SignedValue x 
putV (P.FloatValue x) = FloatValue x putV (P.BoolValue x) = BoolValue x putV P.UnitValue = BoolValue True -- | Prettyprint value. prettyValue :: Value m -> Doc a prettyValue = prettyValueWith pprPrim where pprPrim (UnsignedValue (Int8Value v)) = pretty (fromIntegral v :: Word8) pprPrim (UnsignedValue (Int16Value v)) = pretty (fromIntegral v :: Word16) pprPrim (UnsignedValue (Int32Value v)) = pretty (fromIntegral v :: Word32) pprPrim (UnsignedValue (Int64Value v)) = pretty (fromIntegral v :: Word64) pprPrim (SignedValue (Int8Value v)) = pretty v pprPrim (SignedValue (Int16Value v)) = pretty v pprPrim (SignedValue (Int32Value v)) = pretty v pprPrim (SignedValue (Int64Value v)) = pretty v pprPrim (BoolValue True) = "true" pprPrim (BoolValue False) = "false" pprPrim (FloatValue (Float16Value v)) = pprFloat "f16." v pprPrim (FloatValue (Float32Value v)) = pprFloat "f32." v pprPrim (FloatValue (Float64Value v)) = pprFloat "f64." v pprFloat t v | isInfinite v, v >= 0 = t <> "inf" | isInfinite v, v < 0 = "-" <> t <> "inf" | isNaN v = t <> "nan" | otherwise = pretty $ show v -- | The value in the textual format. valueText :: Value m -> T.Text valueText = docText . 
prettyValueWith pretty valueShape :: Value m -> ValueShape valueShape (ValueArray shape _) = shape valueShape (ValueAcc shape _ _) = shape valueShape (ValueRecord fs) = ShapeRecord $ M.map valueShape fs valueShape (ValueSum shape _ _) = shape valueShape _ = ShapeLeaf -- TODO: Perhaps there is some clever way to reuse the code between -- valueAccum and valueAccumLM valueAccum :: (a -> Value m -> (a, Value m)) -> a -> Value m -> (a, Value m) valueAccum f i v@(ValuePrim {}) = f i v valueAccum f i v@(ValueAD {}) = f i v valueAccum f i (ValueRecord m) = second ValueRecord $ M.mapAccum (valueAccum f) i m valueAccum f i (ValueArray s a) = do -- TODO: This could probably be better -- Transform into a map let m = M.fromList $ assocs a -- Accumulate over the map let (i', m') = M.mapAccum (valueAccum f) i m -- Transform back into an array and return let a' = array (bounds a) (M.toList m') (i', ValueArray s a') valueAccum _ _ v = error $ "valueAccum not implemented for " ++ show v valueAccumLM :: (Monad f) => (a -> Value m -> f (a, Value m)) -> a -> Value m -> f (a, Value m) valueAccumLM f i v@(ValuePrim {}) = f i v valueAccumLM f i v@(ValueAD {}) = f i v valueAccumLM f i (ValueRecord m) = do (a, b) <- mapAccumLM (valueAccumLM f) i m pure (a, ValueRecord b) valueAccumLM f i (ValueArray s a) = do -- TODO: This could probably be better -- Transform into a map let m = M.fromList $ assocs a -- Accumulate over the map (i', m') <- mapAccumLM (valueAccumLM f) i m -- Transform back into an array and return let a' = array (bounds a) (M.toList m') pure (i', ValueArray s a') valueAccumLM _ _ v = error $ "valueAccum not implemented for " ++ show v -- | Does the value correspond to an empty array? isEmptyArray :: Value m -> Bool isEmptyArray = emptyShape . valueShape -- | String representation of an empty array with the provided element -- type. This is pretty ad-hoc - don't expect good results unless the -- element type is a primitive. 
prettyEmptyArray :: TypeBase () () -> Value m -> T.Text prettyEmptyArray t v = "empty(" <> dims (valueShape v) <> prettyText t' <> ")" where t' = stripArray (arrayRank t) t dims (ShapeDim n rowshape) = "[" <> prettyText n <> "]" <> dims rowshape dims _ = "" toArray :: ValueShape -> [Value m] -> Value m toArray shape vs = ValueArray shape (listArray (0, length vs - 1) vs) toArray' :: ValueShape -> [Value m] -> Value m toArray' rowshape vs = ValueArray shape (listArray (0, length vs - 1) vs) where shape = ShapeDim (genericLength vs) rowshape arrayLength :: (Integral int) => Array Int (Value m) -> int arrayLength = fromIntegral . (+ 1) . snd . bounds toTuple :: [Value m] -> Value m toTuple = ValueRecord . M.fromList . zip tupleFieldNames fromTuple :: Value m -> Maybe [Value m] fromTuple (ValueRecord m) = areTupleFields m fromTuple _ = Nothing fromDataShape :: V.Vector Int -> ValueShape fromDataShape = foldr (ShapeDim . fromIntegral) ShapeLeaf . SVec.toList fromDataValueWith :: (SVec.Storable a) => (a -> PrimValue) -> SVec.Vector Int -> SVec.Vector a -> Value m fromDataValueWith f shape vector | SVec.null shape = ValuePrim $ f $ SVec.head vector | SVec.null vector = toArray (fromDataShape shape) $ replicate (SVec.head shape) (fromDataValueWith f shape' vector) | otherwise = toArray (fromDataShape shape) . map (fromDataValueWith f shape' . SVec.fromList) $ chunk (SVec.product shape') (SVec.toList vector) where shape' = SVec.tail shape -- | Convert a Futhark value in the externally observable data format -- to an interpreter value. fromDataValue :: V.Value -> Value m fromDataValue (V.I8Value shape vector) = fromDataValueWith (SignedValue . Int8Value) shape vector fromDataValue (V.I16Value shape vector) = fromDataValueWith (SignedValue . Int16Value) shape vector fromDataValue (V.I32Value shape vector) = fromDataValueWith (SignedValue . Int32Value) shape vector fromDataValue (V.I64Value shape vector) = fromDataValueWith (SignedValue . 
Int64Value) shape vector fromDataValue (V.U8Value shape vector) = fromDataValueWith (UnsignedValue . Int8Value . fromIntegral) shape vector fromDataValue (V.U16Value shape vector) = fromDataValueWith (UnsignedValue . Int16Value . fromIntegral) shape vector fromDataValue (V.U32Value shape vector) = fromDataValueWith (UnsignedValue . Int32Value . fromIntegral) shape vector fromDataValue (V.U64Value shape vector) = fromDataValueWith (UnsignedValue . Int64Value . fromIntegral) shape vector fromDataValue (V.F16Value shape vector) = fromDataValueWith (FloatValue . Float16Value) shape vector fromDataValue (V.F32Value shape vector) = fromDataValueWith (FloatValue . Float32Value) shape vector fromDataValue (V.F64Value shape vector) = fromDataValueWith (FloatValue . Float64Value) shape vector fromDataValue (V.BoolValue shape vector) = fromDataValueWith BoolValue shape vector futhark-0.25.27/src/Language/Futhark/Parser.hs000066400000000000000000000036321475065116200210540ustar00rootroot00000000000000-- | Interface to the Futhark parser. module Language.Futhark.Parser ( parseFuthark, parseFutharkWithComments, parseExp, parseModExp, parseType, parseDecOrExp, SyntaxError (..), Comment (..), ) where import Data.Text qualified as T import Language.Futhark.Parser.Parser import Language.Futhark.Prop import Language.Futhark.Syntax -- | Parse an entire Futhark program from the given 'T.Text', using -- the 'FilePath' as the source name for error messages. parseFuthark :: FilePath -> T.Text -> Either SyntaxError UncheckedProg parseFuthark = parse prog -- | Parse an entire Futhark program from the given 'T.Text', using -- the 'FilePath' as the source name for error messages. Also returns -- the comments encountered. parseFutharkWithComments :: FilePath -> T.Text -> Either SyntaxError (UncheckedProg, [Comment]) parseFutharkWithComments = parseWithComments prog -- | Parse an Futhark expression from the given 'String', using the -- 'FilePath' as the source name for error messages. 
parseExp :: FilePath -> T.Text -> Either SyntaxError UncheckedExp parseExp = parse expression -- | Parse a Futhark module expression from the given 'String', using the -- 'FilePath' as the source name for error messages. parseModExp :: FilePath -> T.Text -> Either SyntaxError (ModExpBase NoInfo Name) parseModExp = parse modExpression -- | Parse an Futhark type from the given 'String', using the -- 'FilePath' as the source name for error messages. parseType :: FilePath -> T.Text -> Either SyntaxError UncheckedTypeExp parseType = parse futharkType -- | Parse either an expression or a declaration; favouring -- declarations in case of ambiguity. parseDecOrExp :: FilePath -> T.Text -> Either SyntaxError (Either UncheckedDec UncheckedExp) parseDecOrExp file input = case parse declaration file input of Left {} -> Right <$> parseExp file input Right d -> Right $ Left d futhark-0.25.27/src/Language/Futhark/Parser/000077500000000000000000000000001475065116200205145ustar00rootroot00000000000000futhark-0.25.27/src/Language/Futhark/Parser/Lexer.x000066400000000000000000000224511475065116200217700ustar00rootroot00000000000000{ {-# OPTIONS_GHC -w #-} -- | The Futhark lexer. Takes a string, produces a list of tokens with position information. module Language.Futhark.Parser.Lexer ( Token(..) 
, getToken , scanTokensText ) where import Data.Bifunctor (second) import qualified Data.ByteString.Lazy as BS import qualified Data.Text as T import qualified Data.Text.Encoding as T import qualified Data.Text.Read as T import Data.Char (chr, ord, toLower) import Data.Int (Int8, Int16, Int32, Int64) import Data.Word (Word8) import Data.Loc (Loc (..), L(..), Pos(..)) import Data.Function (fix) import Language.Futhark.Core (Int8, Int16, Int32, Int64, Word8, Word16, Word32, Word64, Name, nameFromText, nameToText, nameFromString) import Language.Futhark.Prop (leadingOperator) import Language.Futhark.Syntax (BinOp(..)) import Language.Futhark.Parser.Lexer.Wrapper import Language.Futhark.Parser.Lexer.Tokens } @charlit = ($printable#['\\]|\\($printable|[0-9]+)) @stringcharlit = ($printable#[\"\\]|\\($printable|[0-9]+)) @hexlit = 0[xX][0-9a-fA-F][0-9a-fA-F_]* @declit = [0-9][0-9_]* @binlit = 0[bB][01][01_]* @romlit = 0[rR][IVXLCDM][IVXLCDM_]* @reallit = (([0-9][0-9_]*("."[0-9][0-9_]*)?))([eE][\+\-]?[0-9]+)? @hexreallit = 0[xX][0-9a-fA-F][0-9a-fA-F_]*"."[0-9a-fA-F][0-9a-fA-F_]*([pP][\+\-]?[0-9_]+) @field = [a-zA-Z0-9] [a-zA-Z0-9_]* @constituent = [a-zA-Z0-9_'] @identifier = [a-zA-Z] @constituent* | "_" [a-zA-Z0-9] @constituent* @qualidentifier = (@identifier ".")+ @identifier $opchar = [\+\-\*\/\%\=\!\>\<\|\&\^\.] @binop = ($opchar # \.) $opchar* @qualbinop = (@identifier ".")+ @binop @space = [\ \t\f\v] @doc = "-- |".*(\n@space*"--".*)* tokens :- $white+ ; @doc { tokenS $ DOC . T.intercalate "\n" . map (T.drop 3 . T.stripStart) . T.split (== '\n') . ("--"<>) . T.drop 4 } "--".* { tokenS COMMENT } "=" { tokenC EQU } "(" { tokenC LPAR } ")" { tokenC RPAR } [a-zA-Z0-9_'] ^ "[" { tokenC INDEXING } \) ^ "[" { tokenC INDEXING } \] ^ "[" { tokenC INDEXING } "[" { tokenC LBRACKET } "]" { tokenC RBRACKET } "{" { tokenC LCURLY } "}" { tokenC RCURLY } "," { tokenC COMMA } "_" { tokenC UNDERSCORE } "?" 
{ tokenC QUESTION_MARK } "->" { tokenC RIGHT_ARROW } ":" { tokenC COLON } ":>" { tokenC COLON_GT } "\" { tokenC BACKSLASH } "~" { tokenC TILDE } "'" { tokenC APOSTROPHE } "'^" { tokenC APOSTROPHE_THEN_HAT } "'~" { tokenC APOSTROPHE_THEN_TILDE } "`" { tokenC BACKTICK } "#[" { tokenC HASH_LBRACKET } "..<" { tokenC TWO_DOTS_LT } "..>" { tokenC TWO_DOTS_GT } "..." { tokenC THREE_DOTS } ".." { tokenC TWO_DOTS } "." { tokenC DOT } "!" { tokenC BANG } "$" { tokenC DOLLAR } "???" { tokenC HOLE } @declit i8 { decToken I8LIT . BS.dropEnd 2 } @binlit i8 { binToken I8LIT . BS.drop 2 . BS.dropEnd 2 } @hexlit i8 { hexToken I8LIT . BS.drop 2 . BS.dropEnd 2 } @romlit i8 { romToken I8LIT . BS.drop 2 . BS.dropEnd 2 } @declit i16 { decToken I16LIT . BS.dropEnd 3 } @binlit i16 { binToken I16LIT . BS.drop 2 . BS.dropEnd 3 } @hexlit i16 { hexToken I16LIT . BS.drop 2 . BS.dropEnd 3 } @romlit i16 { romToken I16LIT . BS.drop 2 . BS.dropEnd 3 } @declit i32 { decToken I32LIT . BS.dropEnd 3 } @binlit i32 { binToken I32LIT . BS.drop 2 . BS.dropEnd 3 } @hexlit i32 { hexToken I32LIT . BS.drop 2 . BS.dropEnd 3 } @romlit i32 { romToken I32LIT . BS.drop 2 . BS.dropEnd 3 } @declit i64 { decToken I64LIT . BS.dropEnd 3 } @binlit i64 { binToken I64LIT . BS.drop 2 . BS.dropEnd 3 } @hexlit i64 { hexToken I64LIT . BS.drop 2 . BS.dropEnd 3 } @romlit i64 { romToken I64LIT . BS.drop 2 . BS.dropEnd 3 } @declit u8 { decToken U8LIT . BS.dropEnd 2 } @binlit u8 { binToken U8LIT . BS.drop 2 . BS.dropEnd 2 } @hexlit u8 { hexToken U8LIT . BS.drop 2 . BS.dropEnd 2 } @romlit u8 { romToken U8LIT . BS.drop 2 . BS.dropEnd 2 } @declit u16 { decToken U16LIT . BS.dropEnd 3 } @binlit u16 { binToken U16LIT . BS.drop 2 . BS.dropEnd 3 } @hexlit u16 { hexToken U16LIT . BS.drop 2 . BS.dropEnd 3 } @romlit u16 { romToken U16LIT . BS.drop 2 . BS.dropEnd 3 } @declit u32 { decToken U32LIT . BS.dropEnd 3 } @binlit u32 { binToken U32LIT . BS.drop 2 . BS.dropEnd 3 } @hexlit u32 { hexToken U32LIT . BS.drop 2 . 
BS.dropEnd 3 } @romlit u32 { romToken U32LIT . BS.drop 2 . BS.dropEnd 3 } @declit u64 { decToken U64LIT . BS.dropEnd 3 } @binlit u64 { binToken U64LIT . BS.drop 2 . BS.dropEnd 3 } @hexlit u64 { hexToken U64LIT . BS.drop 2 . BS.dropEnd 3 } @romlit u64 { romToken U64LIT . BS.drop 2 . BS.dropEnd 3 } @declit { \s -> decToken (NATLIT (nameFromBS s)) s } @binlit { binToken INTLIT . BS.drop 2 } @hexlit { hexToken INTLIT . BS.drop 2 } @romlit { romToken INTLIT . BS.drop 2 } [\n[^\.]] ^ @reallit f16 { tokenS $ F16LIT . tryRead "f16" . suffZero . T.filter (/= '_') . T.takeWhile (/='f') } [\n[^\.]] ^ @reallit f32 { tokenS $ F32LIT . tryRead "f32" . suffZero . T.filter (/= '_') . T.takeWhile (/='f') } [\n[^\.]] ^ @reallit f64 { tokenS $ F64LIT . tryRead "f64" . suffZero . T.filter (/= '_') . T.takeWhile (/='f') } [\n[^\.]] ^ @reallit { tokenS $ FLOATLIT . tryRead "f64" . suffZero . T.filter (/= '_') } @hexreallit f16 { tokenS $ F16LIT . readHexRealLit . T.filter (/= '_') . T.dropEnd 3 } @hexreallit f32 { tokenS $ F32LIT . readHexRealLit . T.filter (/= '_') . T.dropEnd 3 } @hexreallit f64 { tokenS $ F64LIT . readHexRealLit . T.filter (/= '_') . T.dropEnd 3 } @hexreallit { tokenS $ FLOATLIT . readHexRealLit . T.filter (/= '_') } "'" @charlit "'" { tokenS $ CHARLIT . tryRead "char" } \" @stringcharlit* \" { tokenS $ STRINGLIT . T.pack . tryRead "string" } "true" { tokenC TRUE } "false" { tokenC FALSE } "if" { tokenC IF } "then" { tokenC THEN } "else" { tokenC ELSE } "def" { tokenC DEF } "let" { tokenC LET } "loop" { tokenC LOOP } "in" { tokenC IN } "val" { tokenC VAL } "for" { tokenC FOR } "do" { tokenC DO } "with" { tokenC WITH } "local" { tokenC LOCAL } "open" { tokenC OPEN } "include" { tokenC INCLUDE } "import" { tokenC IMPORT } "type" { tokenC TYPE } "entry" { tokenC ENTRY } "module" { tokenC MODULE } "while" { tokenC WHILE } "assert" { tokenC ASSERT } "match" { tokenC MATCH } "case" { tokenC CASE } @identifier { tokenS $ ID . 
nameFromText } "#" @identifier { tokenS $ CONSTRUCTOR . nameFromText . T.drop 1 } @binop { tokenS $ symbol [] . nameFromText } @qualbinop { tokenS $ uncurry symbol . mkQualId } . { tokenS ERROR } { nameFromBS :: BS.ByteString -> Name nameFromBS = nameFromString . map (chr . fromIntegral) . BS.unpack getToken :: AlexInput -> Either LexerError (AlexInput, (Pos, Pos, Token)) getToken state@(pos,c,s,n) = case alexScan state 0 of AlexEOF -> Right (state, (pos, pos, EOF)) AlexError (pos,_,_,_) -> Left $ LexerError (Loc pos pos) "Invalid lexical syntax." AlexSkip state' _len -> getToken state' AlexToken state'@(pos',_,_,n') _ action -> do let x = action (BS.take (n'-n) s) x `seq` Right (state', (pos, pos', x)) scanTokens :: Pos -> BS.ByteString -> Either LexerError [L Token] scanTokens pos str = fmap reverse $ loop [] $ initialLexerState pos str where loop toks s = do (s', tok) <- getToken s case tok of (start, end, EOF) -> pure toks (start, end, t) -> loop (L (Loc start end) t:toks) s' -- | Given a starting position, produce tokens from the given text (or -- a lexer error). Returns the final position. scanTokensText :: Pos -> T.Text -> Either LexerError [L Token] scanTokensText pos = scanTokens pos . BS.fromStrict . T.encodeUtf8 } futhark-0.25.27/src/Language/Futhark/Parser/Lexer/000077500000000000000000000000001475065116200215735ustar00rootroot00000000000000futhark-0.25.27/src/Language/Futhark/Parser/Lexer/Tokens.hs000066400000000000000000000125621475065116200234000ustar00rootroot00000000000000{-# LANGUAGE Strict #-} -- | Definition of the tokens used in the lexer. -- -- Also defines other useful building blocks for constructing tokens. 
module Language.Futhark.Parser.Lexer.Tokens ( Token (..), fromRoman, symbol, mkQualId, tokenC, tokenS, suffZero, tryRead, decToken, binToken, hexToken, romToken, readHexRealLit, ) where import Data.ByteString.Lazy qualified as BS import Data.Either import Data.List (find) import Data.Text qualified as T import Data.Text.Encoding qualified as T import Data.Text.Read qualified as T import Language.Futhark.Core ( Int16, Int32, Int64, Int8, Name, Word16, Word32, Word64, Word8, ) import Language.Futhark.Prop (leadingOperator) import Language.Futhark.Syntax (BinOp, nameFromText, nameToText) import Numeric.Half import Prelude hiding (exponent) -- | A lexical token. It does not itself contain position -- information, so in practice the parser will consume tokens tagged -- with a source position. data Token = ID Name | COMMENT T.Text | INDEXING -- A left bracket immediately following an identifier. | SYMBOL BinOp [Name] Name | CONSTRUCTOR Name | NATLIT Name Integer | INTLIT Integer | STRINGLIT T.Text | I8LIT Int8 | I16LIT Int16 | I32LIT Int32 | I64LIT Int64 | U8LIT Word8 | U16LIT Word16 | U32LIT Word32 | U64LIT Word64 | FLOATLIT Double | F16LIT Half | F32LIT Float | F64LIT Double | CHARLIT Char | COLON | COLON_GT | BACKSLASH | APOSTROPHE | APOSTROPHE_THEN_HAT | APOSTROPHE_THEN_TILDE | BACKTICK | HASH_LBRACKET | DOT | TWO_DOTS | TWO_DOTS_LT | TWO_DOTS_GT | THREE_DOTS | LPAR | RPAR | LBRACKET | RBRACKET | LCURLY | RCURLY | COMMA | UNDERSCORE | RIGHT_ARROW | QUESTION_MARK | EQU | ASTERISK | NEGATE | BANG | DOLLAR | LTH | HAT | TILDE | PIPE | IF | THEN | ELSE | DEF | LET | LOOP | IN | FOR | DO | WITH | ASSERT | TRUE | FALSE | WHILE | INCLUDE | IMPORT | ENTRY | TYPE | MODULE | VAL | OPEN | LOCAL | MATCH | CASE | DOC T.Text | EOF | HOLE | ERROR T.Text deriving (Show, Eq, Ord) mkQualId :: T.Text -> ([Name], Name) mkQualId s = case reverse $ T.splitOn "." 
s of [] -> error "mkQualId: no components" k : qs -> (map nameFromText (reverse qs), nameFromText k) -- | Suffix a zero if the last character is dot. suffZero :: T.Text -> T.Text suffZero s = if T.last s == '.' then s <> "0" else s tryRead :: (Read a) => String -> T.Text -> a tryRead desc s = case reads s' of [(x, "")] -> x _ -> error $ "Invalid " ++ desc ++ " literal: `" ++ T.unpack s ++ "'." where s' = T.unpack s {-# INLINE tokenC #-} tokenC :: a -> BS.ByteString -> a tokenC v _ = v {-# INLINE decToken #-} decToken :: (Integral a) => (a -> Token) -> BS.ByteString -> Token decToken f = f . BS.foldl' digit 0 where digit x c = if c >= 48 && c <= 57 then x * 10 + fromIntegral (c - 48) else x {-# INLINE binToken #-} binToken :: (Integral a) => (a -> Token) -> BS.ByteString -> Token binToken f = f . BS.foldl' digit 0 where digit x c = if c >= 48 && c <= 49 then x * 2 + fromIntegral (c - 48) else x {-# INLINE hexToken #-} hexToken :: (Integral a) => (a -> Token) -> BS.ByteString -> Token hexToken f = f . BS.foldl' digit 0 where digit x c | c >= 48 && c <= 57 = x * 16 + fromIntegral (c - 48) | c >= 65 && c <= 70 = x * 16 + fromIntegral (10 + c - 65) | c >= 97 && c <= 102 = x * 16 + fromIntegral (10 + c - 97) | otherwise = x {-# INLINE romToken #-} romToken :: (Integral a) => (a -> Token) -> BS.ByteString -> Token romToken f = tokenS $ f . fromRoman {-# INLINE tokenS #-} tokenS :: (T.Text -> a) -> BS.ByteString -> a tokenS f = f . T.decodeUtf8 . 
BS.toStrict symbol :: [Name] -> Name -> Token symbol [] q | nameToText q == "*" = ASTERISK | nameToText q == "-" = NEGATE | nameToText q == "<" = LTH | nameToText q == "^" = HAT | nameToText q == "|" = PIPE | otherwise = SYMBOL (leadingOperator q) [] q symbol qs q = SYMBOL (leadingOperator q) qs q romanNumerals :: (Integral a) => [(T.Text, a)] romanNumerals = reverse [ ("I", 1), ("IV", 4), ("V", 5), ("IX", 9), ("X", 10), ("XL", 40), ("L", 50), ("XC", 90), ("C", 100), ("CD", 400), ("D", 500), ("CM", 900), ("M", 1000) ] fromRoman :: (Integral a) => T.Text -> a fromRoman s = case find ((`T.isPrefixOf` s) . fst) romanNumerals of Nothing -> 0 Just (d, n) -> n + fromRoman (T.drop (T.length d) s) readHexRealLit :: (RealFloat a) => T.Text -> a readHexRealLit s = let num = T.drop 2 s in -- extract number into integer, fractional and (optional) exponent let comps = T.split (`elem` ['.', 'p', 'P']) num in case comps of [i, f, p] -> let runTextReader r = fromInteger . fst . fromRight (error "internal error") . r intPart = runTextReader T.hexadecimal i fracPart = runTextReader T.hexadecimal f exponent = runTextReader (T.signed T.decimal) p fracLen = fromIntegral $ T.length f fracVal = fracPart / (16.0 ** fracLen) totalVal = (intPart + fracVal) * (2.0 ** exponent) in totalVal _ -> error "bad hex real literal" futhark-0.25.27/src/Language/Futhark/Parser/Lexer/Wrapper.hs000066400000000000000000000035631475065116200235560ustar00rootroot00000000000000{-# OPTIONS_GHC -funbox-strict-fields #-} -- | Utility definitions used by the lexer. None of the default Alex -- "wrappers" are precisely what we need. The code here is highly -- minimalistic. Lexers should not be complicated! 
module Language.Futhark.Parser.Lexer.Wrapper ( initialLexerState, AlexInput, alexInputPrevChar, LexerError (..), alexGetByte, alexGetPos, ) where import Data.ByteString.Internal qualified as BS (w2c) import Data.ByteString.Lazy qualified as BS import Data.Int (Int64) import Data.Loc (Loc, Pos (..)) import Data.Text qualified as T import Data.Word (Word8) type Byte = Word8 -- | The input type. Contains: -- -- 1. current position -- -- 2. previous char -- -- 3. current input string -- -- 4. bytes consumed so far type AlexInput = ( Pos, -- current position, Char, -- previous char BS.ByteString, -- current input string Int64 -- bytes consumed so far ) alexInputPrevChar :: AlexInput -> Char alexInputPrevChar (_, prev, _, _) = prev {-# INLINE alexGetByte #-} alexGetByte :: AlexInput -> Maybe (Byte, AlexInput) alexGetByte (p, _, cs, n) = case BS.uncons cs of Nothing -> Nothing Just (b, cs') -> let c = BS.w2c b p' = alexMove p c n' = n + 1 in p' `seq` cs' `seq` n' `seq` Just (b, (p', c, cs', n')) alexGetPos :: AlexInput -> Pos alexGetPos (pos, _, _, _) = pos tabSize :: Int tabSize = 8 {-# INLINE alexMove #-} alexMove :: Pos -> Char -> Pos alexMove (Pos !f !l !c !a) '\t' = Pos f l (c + tabSize - ((c - 1) `mod` tabSize)) (a + 1) alexMove (Pos !f !l _ !a) '\n' = Pos f (l + 1) 1 (a + 1) alexMove (Pos !f !l !c !a) _ = Pos f l (c + 1) (a + 1) initialLexerState :: Pos -> BS.ByteString -> AlexInput initialLexerState start_pos input = (start_pos, '\n', input, 0) data LexerError = LexerError Loc T.Text instance Show LexerError where show (LexerError _ s) = T.unpack s futhark-0.25.27/src/Language/Futhark/Parser/Monad.hs000066400000000000000000000206021475065116200221060ustar00rootroot00000000000000-- | Utility functions and definitions used in the Happy-generated -- parser. They are defined here because the @.y@ file is opaque to -- linters and other tools. In particular, we cannot enable warnings -- for that file, because Happy-generated code is very dirty by GHC's -- standards. 
module Language.Futhark.Parser.Monad ( ParserMonad, ParserState, Comment (..), parse, parseWithComments, lexer, mustBeEmpty, arrayFromList, binOp, binOpName, mustBe, primNegate, applyExp, arrayLitExp, addDocSpec, addAttrSpec, addDoc, addAttr, twoDotsRange, SyntaxError (..), emptyArrayError, parseError, parseErrorAt, backOneCol, -- * Reexports L, Token, ) where import Control.Monad import Control.Monad.Except (ExceptT, MonadError (..), runExceptT) import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.State import Data.Array hiding (index) import Data.ByteString.Lazy qualified as BS import Data.List.NonEmpty qualified as NE import Data.Monoid import Data.Text qualified as T import Data.Text.Encoding qualified as T import Futhark.Util.Loc import Futhark.Util.Pretty hiding (line, line') import Language.Futhark.Parser.Lexer import Language.Futhark.Parser.Lexer.Wrapper (AlexInput, LexerError (..), initialLexerState) import Language.Futhark.Pretty () import Language.Futhark.Prop import Language.Futhark.Syntax import Prelude hiding (mod) addDoc :: DocComment -> UncheckedDec -> UncheckedDec addDoc doc (ValDec val) = ValDec (val {valBindDoc = Just doc}) addDoc doc (TypeDec tp) = TypeDec (tp {typeDoc = Just doc}) addDoc doc (ModTypeDec sig) = ModTypeDec (sig {modTypeDoc = Just doc}) addDoc doc (ModDec mod) = ModDec (mod {modDoc = Just doc}) addDoc doc (LocalDec dec loc) = LocalDec (addDoc doc dec) loc addDoc _ dec = dec addDocSpec :: DocComment -> SpecBase NoInfo Name -> SpecBase NoInfo Name addDocSpec doc (TypeAbbrSpec tpsig) = TypeAbbrSpec (tpsig {typeDoc = Just doc}) addDocSpec doc (ValSpec name ps t NoInfo _ loc) = ValSpec name ps t NoInfo (Just doc) loc addDocSpec doc (TypeSpec l name ps _ loc) = TypeSpec l name ps (Just doc) loc addDocSpec doc (ModSpec name se _ loc) = ModSpec name se (Just doc) loc addDocSpec _ spec@IncludeSpec {} = spec addAttr :: AttrInfo Name -> UncheckedDec -> UncheckedDec addAttr attr (ValDec val) = ValDec $ val {valBindAttrs = attr : 
valBindAttrs val} addAttr _ dec = dec -- We will extend this function once we actually start tracking these. addAttrSpec :: AttrInfo Name -> UncheckedSpec -> UncheckedSpec addAttrSpec _attr dec = dec mustBe :: L Token -> T.Text -> ParserMonad () mustBe (L _ (ID got)) expected | nameToText got == expected = pure () mustBe (L loc _) expected = parseErrorAt loc . Just $ "Only the keyword '" <> expected <> "' may appear here." mustBeEmpty :: (Located loc) => loc -> ValueType -> ParserMonad () mustBeEmpty _ (Array _ (Shape dims) _) | 0 `elem` dims = pure () mustBeEmpty loc t = parseErrorAt loc $ Just $ prettyText t <> " is not an empty array." -- | A comment consists of its starting and end position, as well as -- its text. The contents include the comment start marker. data Comment = Comment {commentLoc :: Loc, commentText :: T.Text} deriving (Eq, Ord, Show) instance Located Comment where locOf = commentLoc data ParserState = ParserState { _parserFile :: FilePath, parserInput :: T.Text, -- | Note: reverse order. parserComments :: [Comment], parserLexerState :: AlexInput } type ParserMonad = ExceptT SyntaxError (State ParserState) arrayFromList :: [a] -> Array Int a arrayFromList l = listArray (0, length l - 1) l arrayLitExp :: [UncheckedExp] -> SrcLoc -> UncheckedExp arrayLitExp es loc | Just (v : vs) <- mapM isLiteral es, all ((primValueType v ==) . primValueType) vs = ArrayVal (v : vs) (primValueType v) loc | otherwise = ArrayLit es NoInfo loc where isLiteral (Literal v _) = Just v isLiteral _ = Nothing applyExp :: NE.NonEmpty UncheckedExp -> ParserMonad UncheckedExp applyExp all_es@((Constr n [] _ loc1) NE.:| es) = pure $ Constr n es NoInfo (srcspan loc1 (NE.last all_es)) applyExp es = foldM op (NE.head es) (NE.tail es) where op (AppExp (Index e is floc) _) (ArrayLit xs _ xloc) = parseErrorAt (srcspan floc xloc) . Just . docText $ "Incorrect syntax for multi-dimensional indexing." 
"Use" <+> align (pretty index) where index = AppExp (Index e (is ++ map DimFix xs) xloc) NoInfo op f x = pure $ mkApplyUT f x binOpName :: L Token -> (QualName Name, Loc) binOpName (L loc (SYMBOL _ qs op)) = (QualName qs op, loc) binOpName t = error $ "binOpName: unexpected " ++ show t binOp :: UncheckedExp -> L Token -> UncheckedExp -> UncheckedExp binOp x (L loc (SYMBOL _ qs op)) y = AppExp (BinOp (QualName qs op, srclocOf loc) NoInfo (x, NoInfo) (y, NoInfo) (srcspan x y)) NoInfo binOp _ t _ = error $ "binOp: unexpected " ++ show t putComment :: Comment -> ParserMonad () putComment c = lift $ modify $ \env -> env {parserComments = c : parserComments env} intNegate :: IntValue -> IntValue intNegate (Int8Value v) = Int8Value (-v) intNegate (Int16Value v) = Int16Value (-v) intNegate (Int32Value v) = Int32Value (-v) intNegate (Int64Value v) = Int64Value (-v) floatNegate :: FloatValue -> FloatValue floatNegate (Float16Value v) = Float16Value (-v) floatNegate (Float32Value v) = Float32Value (-v) floatNegate (Float64Value v) = Float64Value (-v) primNegate :: PrimValue -> PrimValue primNegate (FloatValue v) = FloatValue $ floatNegate v primNegate (SignedValue v) = SignedValue $ intNegate v primNegate (UnsignedValue v) = UnsignedValue $ intNegate v primNegate (BoolValue v) = BoolValue $ not v lexer :: (L Token -> ParserMonad a) -> ParserMonad a lexer cont = do ls <- lift $ gets parserLexerState case getToken ls of Left e -> throwError $ lexerErrToParseErr e Right (ls', (start, end, tok)) -> do let loc = Loc start end lift $ modify $ \s -> s {parserLexerState = ls'} case tok of COMMENT text -> do putComment $ Comment loc text lexer cont _ -> cont $ L loc tok parseError :: (L Token, [String]) -> ParserMonad a parseError (L loc EOF, expected) = parseErrorAt (locOf loc) . Just . 
T.unlines $ [ "Unexpected end of file.", "Expected one of the following: " <> T.unwords (map T.pack expected) ] parseError (L loc DOC {}, _) = parseErrorAt (locOf loc) $ Just "Documentation comments ('-- |') are only permitted when preceding declarations." parseError (L loc (ERROR "\""), _) = parseErrorAt (locOf loc) $ Just "Unclosed string literal." parseError (L loc _, expected) = do input <- lift $ gets parserInput let ~(Loc (Pos _ _ _ beg) (Pos _ _ _ end)) = locOf loc tok_src = T.take (end - beg) $ T.drop beg input parseErrorAt loc . Just . T.unlines $ [ "Unexpected token: '" <> tok_src <> "'", "Expected one of the following: " <> T.unwords (map T.pack expected) ] parseErrorAt :: (Located loc) => loc -> Maybe T.Text -> ParserMonad a parseErrorAt loc Nothing = throwError $ SyntaxError (locOf loc) "Syntax error." parseErrorAt loc (Just s) = throwError $ SyntaxError (locOf loc) s emptyArrayError :: Loc -> ParserMonad a emptyArrayError loc = parseErrorAt loc $ Just "write empty arrays as 'empty(t)', for element type 't'.\n" twoDotsRange :: Loc -> ParserMonad a twoDotsRange loc = parseErrorAt loc $ Just "use '...' for ranges, not '..'.\n" -- | Move the end position back one column. backOneCol :: Loc -> Loc backOneCol (Loc start (Pos f l c o)) = Loc start $ Pos f l (c - 1) (o - 1) backOneCol NoLoc = NoLoc --- Now for the parser interface. -- | A syntax error. data SyntaxError = SyntaxError {syntaxErrorLoc :: Loc, syntaxErrorMsg :: T.Text} lexerErrToParseErr :: LexerError -> SyntaxError lexerErrToParseErr (LexerError loc msg) = SyntaxError loc msg parseWithComments :: ParserMonad a -> FilePath -> T.Text -> Either SyntaxError (a, [Comment]) parseWithComments p file program = onRes $ runState (runExceptT p) env where env = ParserState file program [] (initialLexerState start $ BS.fromStrict . 
T.encodeUtf8 $ program) start = Pos file 1 1 0 onRes (Left err, _) = Left err onRes (Right x, s) = Right (x, reverse $ parserComments s) parse :: ParserMonad a -> FilePath -> T.Text -> Either SyntaxError a parse p file program = fst <$> parseWithComments p file program futhark-0.25.27/src/Language/Futhark/Parser/Parser.y000066400000000000000000001077541475065116200221600ustar00rootroot00000000000000{ -- | Futhark parser written with Happy. module Language.Futhark.Parser.Parser ( prog , expression , declaration , modExpression , futharkType , parse , parseWithComments , SyntaxError(..) , Comment(..) ) where import Data.Bifunctor (second) import Control.Monad import Control.Monad.Trans import Control.Monad.Except import Control.Monad.Reader import Control.Monad.Trans.State import Data.Array import qualified Data.ByteString as BS import qualified Data.Text as T import qualified Data.Text.Encoding as T import Data.Char (ord) import Data.Maybe (fromMaybe, fromJust) import Data.List (genericLength) import qualified Data.List.NonEmpty as NE import qualified Data.Map.Strict as M import Data.Monoid import Language.Futhark.Syntax hiding (ID) import Language.Futhark.Prop import Language.Futhark.Pretty import Language.Futhark.Parser.Lexer (Token(..)) import Futhark.Util.Pretty import Futhark.Util.Loc import Language.Futhark.Parser.Monad } %name prog Prog %name futharkType TypeExp %name expression Exp %name modExpression ModExp %name declaration Dec %tokentype { L Token } %error { parseError } %errorhandlertype explist %monad { ParserMonad } %lexer { lexer } { L _ EOF } %token if { L $$ IF } then { L $$ THEN } else { L $$ ELSE } let { L $$ LET } def { L $$ DEF } loop { L $$ LOOP } in { L $$ IN } match { L $$ MATCH } case { L $$ CASE } id { L _ (ID _) } '...[' { L _ INDEXING } constructor { L _ (CONSTRUCTOR _) } natlit { L _ (NATLIT _ _) } intlit { L _ (INTLIT _) } i8lit { L _ (I8LIT _) } i16lit { L _ (I16LIT _) } i32lit { L _ (I32LIT _) } i64lit { L _ (I64LIT _) } u8lit { L _ 
(U8LIT _) } u16lit { L _ (U16LIT _) } u32lit { L _ (U32LIT _) } u64lit { L _ (U64LIT _) } floatlit { L _ (FLOATLIT _) } f16lit { L _ (F16LIT _) } f32lit { L _ (F32LIT _) } f64lit { L _ (F64LIT _) } stringlit { L _ (STRINGLIT _) } charlit { L _ (CHARLIT _) } '.' { L $$ DOT } '..' { L $$ TWO_DOTS } '...' { L $$ THREE_DOTS } '..<' { L $$ TWO_DOTS_LT } '..>' { L $$ TWO_DOTS_GT } '=' { L $$ EQU } '*' { L $$ ASTERISK } '-' { L $$ NEGATE } '!' { L $$ BANG } '<' { L $$ LTH } '^' { L $$ HAT } '~' { L $$ TILDE } '|' { L $$ PIPE } '+...' { L _ (SYMBOL Plus _ _) } '-...' { L _ (SYMBOL Minus _ _) } '*...' { L _ (SYMBOL Times _ _) } '/...' { L _ (SYMBOL Divide _ _) } '%...' { L _ (SYMBOL Mod _ _) } '//...' { L _ (SYMBOL Quot _ _) } '%%...' { L _ (SYMBOL Rem _ _) } '==...' { L _ (SYMBOL Equal _ _) } '!=...' { L _ (SYMBOL NotEqual _ _) } '<...' { L _ (SYMBOL Less _ _) } '>...' { L _ (SYMBOL Greater _ _) } '<=...' { L _ (SYMBOL Leq _ _) } '>=...' { L _ (SYMBOL Geq _ _) } '**...' { L _ (SYMBOL Pow _ _) } '<<...' { L _ (SYMBOL ShiftL _ _) } '>>...' { L _ (SYMBOL ShiftR _ _) } '|>...' { L _ (SYMBOL PipeRight _ _) } '<|...' { L _ (SYMBOL PipeLeft _ _) } '|...' { L _ (SYMBOL Bor _ _) } '&...' { L _ (SYMBOL Band _ _) } '^...' { L _ (SYMBOL Xor _ _) } '||...' { L _ (SYMBOL LogOr _ _) } '&&...' { L _ (SYMBOL LogAnd _ _) } '!...' { L _ (SYMBOL Bang _ _) } '=...' { L _ (SYMBOL Equ _ _) } '(' { L $$ LPAR } ')' { L $$ RPAR } '{' { L $$ LCURLY } '}' { L $$ RCURLY } '[' { L $$ LBRACKET } ']' { L $$ RBRACKET } '#[' { L $$ HASH_LBRACKET } ',' { L $$ COMMA } '_' { L $$ UNDERSCORE } '\\' { L $$ BACKSLASH } '\'' { L $$ APOSTROPHE } '\'^' { L $$ APOSTROPHE_THEN_HAT } '\'~' { L $$ APOSTROPHE_THEN_TILDE } '`' { L $$ BACKTICK } entry { L $$ ENTRY } '->' { L $$ RIGHT_ARROW } ':' { L $$ COLON } ':>' { L $$ COLON_GT } '?' 
{ L $$ QUESTION_MARK } for { L $$ FOR } do { L $$ DO } with { L $$ WITH } assert { L $$ ASSERT } true { L $$ TRUE } false { L $$ FALSE } while { L $$ WHILE } include { L $$ INCLUDE } import { L $$ IMPORT } type { L $$ TYPE } module { L $$ MODULE } val { L $$ VAL } open { L $$ OPEN } local { L $$ LOCAL } doc { L _ (DOC _) } hole { L $$ HOLE } %left bottom %left ifprec letprec caseprec typeprec enumprec sumprec %left ',' case id constructor '(' '{' %right ':' ':>' %right '...' '..<' '..>' '..' %left '`' %right '->' %left with %left '=' %left '|>...' %right '<|...' %left '||...' %left '&&...' %left '<=...' '>=...' '>...' '<' '<...' '==...' '!=...' '!...' '=...' %left '&...' '^...' '^' '|...' '|' %left '<<...' '>>...' %left '+...' '-...' '-' %left '*...' '*' '/...' '%...' '//...' '%%...' %left '**...' %left juxtprec %left '[' '...[' indexprec %left top %% -- The main parser. Doc :: { DocComment } : doc { let L loc (DOC s) = $1 in DocComment s (srclocOf loc) } -- Four cases to avoid ambiguities. Prog :: { UncheckedProg } -- File begins with a file comment, followed by a Dec with a comment. : Doc Doc Dec_ Decs { Prog (Just $1) (addDoc $2 $3 : $4) } -- File begins with a file comment, followed by a Dec with no comment. | Doc Dec_ Decs { Prog (Just $1) ($2 : $3) } -- File begins with a dec with no comment. | Dec_ Decs { Prog Nothing ($1 : $2) } -- File is empty. 
| { Prog Nothing [] } ; Dec :: { UncheckedDec } : Dec_ { $1 } | Doc Dec_ { addDoc $1 $2 } Decs :: { [UncheckedDec] } : Decs_ { reverse $1 } Decs_ :: { [UncheckedDec] } : { [] } | Decs_ Dec { $2 : $1 } Dec_ :: { UncheckedDec } : Val { ValDec $1 } | TypeAbbr { TypeDec $1 } | ModTypeBind { ModTypeDec $1 } | ModBind { ModDec $1 } | open ModExp { OpenDec $2 (srclocOf $1) } | import stringlit { let L _ (STRINGLIT s) = $2 in ImportDec (T.unpack s) NoInfo (srcspan $1 $>) } | local Dec { LocalDec $2 (srcspan $1 $>) } | '#[' AttrInfo ']' Dec_ { addAttr $2 $4 } ; ModTypeExp :: { UncheckedModTypeExp } : QualName { let (v, loc) = $1 in ModTypeVar v NoInfo (srclocOf loc) } | '{' Specs '}' { ModTypeSpecs $2 (srcspan $1 $>) } | ModTypeExp with TypeRef { ModTypeWith $1 $3 (srcspan $1 $>) } | '(' ModTypeExp ')' { ModTypeParens $2 (srcspan $1 $>) } | '(' id ':' ModTypeExp ')' '->' ModTypeExp { let L _ (ID name) = $2 in ModTypeArrow (Just name) $4 $7 (srcspan $1 $>) } | ModTypeExp '->' ModTypeExp { ModTypeArrow Nothing $1 $3 (srcspan $1 $>) } TypeRef :: { TypeRefBase NoInfo Name } : QualName TypeParams '=' TypeExpTerm { TypeRef (fst $1) $2 $4 (srcspan (snd $1) $>) } ModTypeBind :: { ModTypeBindBase NoInfo Name } : module type id '=' ModTypeExp { let L _ (ID name) = $3 in ModTypeBind name $5 Nothing (srcspan $1 $>) } ModExp :: { UncheckedModExp } : ModExp ':' ModTypeExp { ModAscript $1 $3 NoInfo (srcspan $1 $>) } | '\\' ModParam maybeAscription(SimpleModTypeExp) '->' ModExp { ModLambda $2 (fmap (,NoInfo) $3) $5 (srcspan $1 $>) } | import stringlit { let L _ (STRINGLIT s) = $2 in ModImport (T.unpack s) NoInfo (srcspan $1 $>) } | ModExpApply { $1 } | ModExpAtom { $1 } ModExpApply :: { UncheckedModExp } : ModExpAtom ModExpAtom %prec juxtprec { ModApply $1 $2 NoInfo NoInfo (srcspan $1 $>) } | ModExpApply ModExpAtom %prec juxtprec { ModApply $1 $2 NoInfo NoInfo (srcspan $1 $>) } ModExpAtom :: { UncheckedModExp } : '(' ModExp ')' { ModParens $2 (srcspan $1 $>) } | QualName { let (v, loc) = 
$1 in ModVar v (srclocOf loc) } | '{' Decs '}' { ModDecs $2 (srcspan $1 $>) } SimpleModTypeExp :: { UncheckedModTypeExp } : QualName { let (v, loc) = $1 in ModTypeVar v NoInfo (srclocOf loc) } | '(' ModTypeExp ')' { $2 } ModBind :: { ModBindBase NoInfo Name } : module id ModParams maybeAscription(ModTypeExp) '=' ModExp { let L floc (ID fname) = $2; in ModBind fname $3 (fmap (,NoInfo) $4) $6 Nothing (srcspan $1 $>) } ModParam :: { ModParamBase NoInfo Name } : '(' id ':' ModTypeExp ')' { let L _ (ID name) = $2 in ModParam name $4 NoInfo (srcspan $1 $>) } ModParams :: { [ModParamBase NoInfo Name] } : ModParam ModParams { $1 : $2 } | { [] } Liftedness :: { Liftedness } : { Unlifted } | '~' { SizeLifted } | '^' { Lifted } Spec :: { SpecBase NoInfo Name } : val id TypeParams ':' TypeExp { let L loc (ID name) = $2 in ValSpec name $3 $5 NoInfo Nothing (srcspan $1 $>) } | val BindingBinOp TypeParams ':' TypeExp { ValSpec $2 $3 $5 NoInfo Nothing (srcspan $1 $>) } | TypeAbbr { TypeAbbrSpec $1 } | type Liftedness id TypeParams { let L _ (ID name) = $3 in TypeSpec $2 name $4 Nothing (srcspan $1 $>) } | module id ':' ModTypeExp { let L _ (ID name) = $2 in ModSpec name $4 Nothing (srcspan $1 $>) } | include ModTypeExp { IncludeSpec $2 (srcspan $1 $>) } | Doc Spec { addDocSpec $1 $2 } | '#[' AttrInfo ']' Spec { addAttrSpec $2 $4 } Specs :: { [SpecBase NoInfo Name] } : Specs_ { reverse $1 } Specs_ :: { [SpecBase NoInfo Name] } : Specs_ Spec { $2 : $1 } | { [] } SizeBinder :: { SizeBinder Name } : '[' id ']' { let L _ (ID name) = $2 in SizeBinder name (srcspan $1 $>) } | '...[' id ']' { let L _ (ID name) = $2 in SizeBinder name (srcspan $1 $>) } SizeBinders1 :: { [SizeBinder Name] } : SizeBinder SizeBinders1 { $1 : $2 } | SizeBinder { [$1] } TypeTypeParam :: { TypeParamBase Name } : '\'' id { let L _ (ID name) = $2 in TypeParamType Unlifted name (srcspan $1 $>) } | '\'~' id { let L _ (ID name) = $2 in TypeParamType SizeLifted name (srcspan $1 $>) } | '\'^' id { let L _ (ID name) = 
$2 in TypeParamType Lifted name (srcspan $1 $>) }

-- A single type parameter: a size parameter written in brackets, or a
-- type parameter as recognised by TypeTypeParam.  Both the '[' and the
-- '...[' bracket tokens are accepted ('...[' is presumably the lexer's
-- token for a bracket directly adjacent to the preceding token; confirm
-- in the lexer).
TypeParam :: { TypeParamBase Name }
          : '[' id ']'    { let L _ (ID name) = $2 in TypeParamDim name (srcspan $1 $>) }
          | '...[' id ']' { let L _ (ID name) = $2 in TypeParamDim name (srcspan $1 $>) }
          | TypeTypeParam { $1 }

-- Zero or more type parameters, in source order.
TypeParams :: { [TypeParamBase Name] }
           : TypeParam TypeParams { $1 : $2 }
           |                      { [] }

-- Due to an ambiguity between in-place updates ("let x[i] ...") and
-- local functions with size parameters, the latter need a special
-- nonterminal.  A leading size parameter must use the '[' token here,
-- whereas the in-place update alternative of LetExp uses '...['.
LocalFunTypeParams :: { [TypeParamBase Name] }
                   : '[' id ']' TypeParams
                     { let L _ (ID name) = $2 in TypeParamDim name (srcspan $1 $>) : $4 }
                   | TypeTypeParam TypeParams { $1 : $2 }
                   |                          { [] }

-- Infix operators usable in operator sections (SectionExp) and in
-- binding position (BindingBinOp).  The plain '-' token is deliberately
-- not included: SectionExp and BindingBinOp handle it separately
-- (presumably to avoid clashing with prefix negation).  Operators that
-- merely start with '-' (the '-...' family) are included, as are
-- backtick-quoted names.  The bare '*', '^', '|' and '<' tokens need
-- their own alternatives because they are distinct terminals that are
-- also used elsewhere in the grammar, e.g. '*' for uniqueness in
-- TypeExpTerm and '|' between sum-type clauses.
BinOp :: { (QualName Name, Loc) }
      : '+...'     { binOpName $1 }
      | '-...'     { binOpName $1 }
      | '*...'     { binOpName $1 }
      | '*'        { (qualName (nameFromString "*"), $1) }
      | '/...'     { binOpName $1 }
      | '%...'     { binOpName $1 }
      | '//...'    { binOpName $1 }
      | '%%...'    { binOpName $1 }
      | '==...'    { binOpName $1 }
      | '!=...'    { binOpName $1 }
      | '<...'     { binOpName $1 }
      | '<=...'    { binOpName $1 }
      | '>...'     { binOpName $1 }
      | '>=...'    { binOpName $1 }
      | '&&...'    { binOpName $1 }
      | '||...'    { binOpName $1 }
      | '**...'    { binOpName $1 }
      | '^...'     { binOpName $1 }
      | '^'        { (qualName (nameFromString "^"), $1) }
      | '&...'     { binOpName $1 }
      | '|...'     { binOpName $1 }
      | '|'        { (qualName (nameFromString "|"), $1) }
      | '>>...'    { binOpName $1 }
      | '<<...'    { binOpName $1 }
      | '<|...'    { binOpName $1 }
      | '|>...'    { binOpName $1 }
      | '<'        { (qualName (nameFromString "<"), $1) }
      | '!...'     { binOpName $1 }
      | '=...'     { binOpName $1 }
      | '`' QualName '`' { $2 }

-- An operator name in binding position.  A qualified operator name is
-- rejected with a parse error.  The plain '-' is accepted.  '!' is
-- rejected because it is only a prefix operator.
BindingBinOp :: { Name }
             : BinOp {% let (QualName qs name, loc) = $1 in do
                           unless (null qs) $ parseErrorAt loc $
                             Just "Cannot use a qualified name in binding position."
                           pure name }
             | '-'   { nameFromString "-" }
             | '!'   {% parseErrorAt $1 $ Just $ "'!' is a prefix operator and cannot be used as infix operator." }

-- A name in binding position: a plain identifier, or an operator in
-- parentheses.  For the operator form, the returned Loc is that of the
-- opening parenthesis.
BindingId :: { (Name, Loc) }
          : id                   { let L loc (ID name) = $1 in (name, loc) }
          | '(' BindingBinOp ')' { ($2, $1) }

-- Top-level value bindings.  'entry' differs from 'def' only in passing
-- Just NoInfo (instead of Nothing) as the first ValBind field.  The
-- infix form "def x op y = e" binds an operator with exactly two
-- parameters and no type parameters.
Val :: { ValBindBase NoInfo Name }
Val : def BindingId TypeParams FunParams maybeAscription(TypeExp) '=' Exp
      { let (name, _) = $2 in ValBind Nothing name $5 NoInfo $3 $4 $7 Nothing mempty (srcspan $1 $>) }

    -- NOTE(review): this alternative binds 'loc' from $2 but never uses it.
    | entry BindingId TypeParams FunParams maybeAscription(TypeExp) '=' Exp
      { let (name, loc) = $2 in ValBind (Just NoInfo) name $5 NoInfo $3 $4 $7 Nothing mempty (srcspan $1 $>) }

    | def FunParam BindingBinOp FunParam maybeAscription(TypeExp) '=' Exp
      { ValBind Nothing $3 $5 NoInfo [] [$2,$4] $7 Nothing mempty (srcspan $1 $>) }

    -- The next two exist for backwards compatibility: a top-level 'let'
    -- builds exactly the same ValBind as the corresponding 'def' form.
    | let BindingId TypeParams FunParams maybeAscription(TypeExp) '=' Exp
      { let (name, _) = $2 in ValBind Nothing name $5 NoInfo $3 $4 $7 Nothing mempty (srcspan $1 $>) }

    | let FunParam BindingBinOp FunParam maybeAscription(TypeExp) '=' Exp
      { ValBind Nothing $3 $5 NoInfo [] [$2,$4] $7 Nothing mempty (srcspan $1 $>) }

    -- Some error cases: tuple patterns cannot be bound at top level.
    -- The error is reported at the span of the parenthesised pattern.
    | def '(' Pat ',' Pats1 ')' '=' Exp
      {% parseErrorAt (srcspan $2 $6) $ Just $ T.unlines ["Cannot bind patterns at top level.", "Bind a single name instead."] }

    | let '(' Pat ',' Pats1 ')' '=' Exp
      {% parseErrorAt (srcspan $2 $6) $ Just $ T.unlines ["Cannot bind patterns at top level.", "Bind a single name instead."] }

-- A type abbreviation: 'type' with a liftedness marker, a name, type
-- parameters, and the defining type expression.
TypeAbbr :: { TypeBindBase NoInfo Name }
TypeAbbr : type Liftedness id TypeParams '=' TypeExp
           { let L _ (ID name) = $3 in TypeBind name $2 $4 $6 NoInfo Nothing (srcspan $1 $>) }

-- Type expressions.  The right operand of '->' is itself a TypeExp, so
-- arrows associate to the right.  A parenthesised "(x: t) -> u" names
-- the parameter.  "?[n][m].t" binds the listed size names over t
-- (TEDim).
TypeExp :: { UncheckedTypeExp }
         : '(' id ':' TypeExp ')' '->' TypeExp
           { let L _ (ID v) = $2 in TEArrow (Just v) $4 $7 (srcspan $1 $>) }
         | TypeExpTerm '->' TypeExp
           { TEArrow Nothing $1 $3 (srcspan $1 $>) }
         | '?' TypeExpDims '.' TypeExp
           { TEDim $2 $4 (srcspan $1 $>) }
         | TypeExpTerm %prec typeprec { $1 }

-- One or more bracketed size names, as written after '?' in TypeExp.
TypeExpDims :: { [Name] }
         : '[' id ']'                { let L _ (ID v) = $2 in [v] }
         | '[' id ']' TypeExpDims    { let L _ (ID v) = $2 in v : $4 }
         | '...[' id ']'             { let L _ (ID v) = $2 in [v] }
         | '...[' id ']' TypeExpDims { let L _ (ID v) = $2 in v : $4 }

-- A type without a top-level arrow: '*' marks uniqueness (TEUnique);
-- otherwise it is a type application or a sum type.
TypeExpTerm :: { UncheckedTypeExp }
         : '*' TypeExpTerm             { TEUnique $2 (srcspan $1 $>) }
         | TypeExpApply %prec typeprec { $1 }
         | SumClauses %prec sumprec
           { let (cs, loc) = $1 in TESum cs (srclocOf loc) }

-- One or more '|'-separated constructor clauses of a sum type, collected
-- left-recursively in source order.  The Loc spans from the first
-- clause to the last.
-- NOTE(review): each step appends with ++, which makes this quadratic
-- in the number of clauses.
SumClauses :: { ([(Name, [UncheckedTypeExp])], Loc) }
            : SumClauses '|' SumClause %prec sumprec
              { let (cs, loc1) = $1; (c, ts, loc2) = $3 in (cs++[(c, ts)], locOf (srcspan loc1 loc2)) }
            | SumClause %prec sumprec
              { let (n, ts, loc) = $1 in ([(n, ts)], loc) }

-- Zero or more atomic payload types following a constructor.
SumPayload :: { [UncheckedTypeExp] }
           : %prec bottom           { [] }
           | TypeExpAtom SumPayload { $1 : $2 }

-- A constructor followed by its payload types.  The Loc spans from the
-- constructor to the payload.
SumClause :: { (Name, [UncheckedTypeExp], Loc) }
          : Constr SumPayload { (fst $1, $2, locOf (srcspan (snd $1) $>)) }

-- Type application, nested to the left: "t a b" is "(t a) b".
TypeExpApply :: { UncheckedTypeExp }
              : TypeExpApply TypeArg { TEApply $1 $2 (srcspan $1 $>) }
              | TypeExpAtom          { $1 }

-- Atomic type expressions: parenthesised types, tuple types (including
-- the empty tuple), record types, array types with a leading size, and
-- possibly-qualified type names.
TypeExpAtom :: { UncheckedTypeExp }
             : '(' TypeExp ')'                { TEParens $2 (srcspan $1 $>) }
             | '(' ')'                        { TETuple [] (srcspan $1 $>) }
             | '(' TypeExp ',' TupleTypes ')' { TETuple ($2:$4) (srcspan $1 $>) }
             | '{' FieldTypes '}'             { TERecord $2 (srcspan $1 $>) }
             | SizeExp TypeExpTerm            { TEArray $1 $2 (srcspan $1 $>) }
             | QualName                       { TEVar (fst $1) (srclocOf (snd $1)) }

-- A constructor name together with the location of its token.
Constr :: { (Name, Loc) }
        : constructor { let L _ (CONSTRUCTOR c) = $1 in (c, locOf $1) }

-- An argument in a type application: a size, or an atomic type.
TypeArg :: { TypeArgExp UncheckedExp Name }
         : SizeExp %prec top { TypeArgExpSize $1 }
         | TypeExpAtom       { TypeArgExpType $1 }

-- A single record field type, "name: type".
FieldType :: { (L Name, UncheckedTypeExp) }
FieldType : FieldId ':' TypeExp { ($1, $3) }

-- Comma-separated record field types.  The list may be empty, and a
-- trailing comma is accepted, because the recursive FieldTypes may
-- itself be empty.
FieldTypes :: { [(L Name, UncheckedTypeExp)] }
FieldTypes :                          { [] }
           | FieldType                { [$1] }
           | FieldType ',' FieldTypes { $1 : $3 }

-- The components of a tuple type after the first one.  A trailing
-- comma is accepted.
TupleTypes :: { [UncheckedTypeExp] }
            : TypeExp                { [$1] }
            | TypeExp ','            { [$1] }
            | TypeExp ',' TupleTypes { $1 : $3 }
-- A size in brackets: "[e]" gives SizeExp, and an empty "[]" gives
-- SizeExpAny.  Both the '[' and the '...[' bracket tokens are accepted.
SizeExp :: { SizeExp UncheckedExp }
         : '[' Exp ']'    { SizeExp $2 (srcspan $1 $>) }
         | '[' ']'        { SizeExpAny (srcspan $1 $>) }
         | '...[' Exp ']' { SizeExp $2 (srcspan $1 $>) }
         | '...[' ']'     { SizeExpAny (srcspan $1 $>) }

-- A function parameter: a ParamPat whose type component is mapped with
-- toParam Observe (presumably marking it as observed, not consumed).
FunParam :: { PatBase NoInfo Name ParamType }
FunParam : ParamPat { fmap (toParam Observe) $1 }

-- One or more parameters, returned as (first, rest) so that users of
-- this rule are guaranteed at least one parameter.
FunParams1 :: { (PatBase NoInfo Name ParamType, [PatBase NoInfo Name ParamType]) }
FunParams1 : FunParam            { ($1, []) }
           | FunParam FunParams1 { ($1, fst $2 : snd $2) }

-- Zero or more parameters.
FunParams :: { [PatBase NoInfo Name ParamType ] }
FunParams :                    { [] }
          | FunParam FunParams { $1 : $2 }

-- A possibly qualified name "a.b.c": the qualifiers plus the leaf name,
-- with a location.
-- NOTE(review): the second alternative passes the leaf's location to
-- srcspan first and the prefix's location second.  Atom '.' id uses the
-- opposite order.  This is harmless only if srcspan does not depend on
-- argument order; confirm.
QualName :: { (QualName Name, Loc) }
          : id { let L vloc (ID v) = $1 in (QualName [] v, vloc) }
          | QualName '.' id
            { let {L ploc (ID f) = $3; (QualName qs v,vloc) = $1;} in (QualName (qs++[v]) f, locOf (srcspan ploc vloc)) }

-- Expressions are divided into several layers.  The first distinction,
-- between Exp and Exp2, factors out type ascription (':' and ':>').
-- Ascription is not permitted inside array slices.  DimIndex is built
-- from Exp2, and there ':' separates the slice components, which would
-- be ambiguous with the ':' of an ascription.
Exp :: { UncheckedExp }
     : Exp ':' TypeExp  { Ascript $1 $3 (srcspan $1 $>) }
     | Exp ':>' TypeExp { Coerce $1 $3 NoInfo (srcspan $1 $>) }
     | Exp2 %prec ':'   { $1 }

-- Expressions without a top-level ascription.  The two alternatives
-- with a bare '..' get their result from twoDotsRange applied to the
-- '..' token.  This is presumably a parse error for a range missing its
-- end marker; confirm.  Prefix '-' and '!' take the juxtprec precedence
-- (presumably that of function application).
Exp2 :: { UncheckedExp }
     : IfExp                   { $1 }
     | LoopExp                 { $1 }
     | LetExp %prec letprec    { $1 }
     | MatchExp                { $1 }
     | assert Atom Atom        { Assert $2 $3 NoInfo (srcspan $1 $>) }
     | '#[' AttrInfo ']' Exp %prec bottom
                               { Attr $2 $4 (srcspan $1 $>) }
     | BinOpExp                { $1 }
     | RangeExp                { $1 }
     | Exp2 '..' Atom          {% twoDotsRange $2 }
     | Atom '..' Exp2          {% twoDotsRange $2 }
     | '-' Exp2 %prec juxtprec { Negate $2 (srcspan $1 $>) }
     | '!' Exp2 %prec juxtprec { Not $2 (srcspan $1 $>) }
     | Exp2 with '[' DimIndices ']' '=' Exp2
                               { Update $1 $4 $7 (srcspan $1 $>) }
     | Exp2 with '...[' DimIndices ']' '=' Exp2
                               { Update $1 $4 $7 (srcspan $1 $>) }
     | Exp2 with FieldAccesses_ '=' Exp2
                               { RecordUpdate $1 (map unLoc $3) $5 NoInfo (srcspan $1 $>) }
     | ApplyList               {% applyExp $1 }

-- Juxtaposition: one or more atoms in a row, of which only the last may
-- be an anonymous function.  applyExp turns the non-empty list into an
-- expression.
ApplyList :: { NE.NonEmpty UncheckedExp }
          : Atom ApplyList %prec juxtprec { NE.cons $1 $2 }
          | LastArg                       { NE.singleton $1 }

-- The final element of an ApplyList: an anonymous function with one or
-- more parameters and an optional return type ascription, or an Atom.
LastArg :: { UncheckedExp }
        : '\\' FunParams1 maybeAscription(TypeExpTerm) '->' Exp %prec letprec
          { Lambda (fst $2 : snd $2) $5 $3 NoInfo (srcspan $1 $>) }
        | Atom %prec juxtprec { $1 }

-- Atomic expressions: literals, holes, parenthesised expressions,
-- tuple/array/record literals, variables, projections, local module
-- opening "m.(e)", indexing with '...[', and operator sections.
-- Character literals become integer literals of their code point.
-- String literals are stored as their UTF-8 encoded bytes.  "a.b" on a
-- variable extends its qualified name; on any other atom it is a record
-- projection.
Atom :: { UncheckedExp }
Atom : PrimLit        { Literal (fst $1) (srclocOf (snd $1)) }
     | Constr         { Constr (fst $1) [] NoInfo (srclocOf (snd $1)) }
     | charlit        { let L loc (CHARLIT x) = $1 in IntLit (toInteger (ord x)) NoInfo (srclocOf loc) }
     | intlit         { let L loc (INTLIT x) = $1 in IntLit x NoInfo (srclocOf loc) }
     | natlit         { let L loc (NATLIT _ x) = $1 in IntLit x NoInfo (srclocOf loc) }
     | floatlit       { let L loc (FLOATLIT x) = $1 in FloatLit x NoInfo (srclocOf loc) }
     | stringlit      { let L loc (STRINGLIT s) = $1 in StringLit (BS.unpack (T.encodeUtf8 s)) (srclocOf loc) }
     | hole           { Hole NoInfo (srclocOf $1) }
     | '(' Exp ')'    { Parens $2 (srcspan $1 $>) }
     | '(' Exp ',' Exps1 ')' { TupLit ($2 : $4) (srcspan $1 $>) }
     | '(' ')'        { TupLit [] (srcspan $1 $>) }
     | '[' Exps1 ']'  { arrayLitExp $2 (srcspan $1 $>) }
     | '[' ']'        { arrayLitExp [] (srcspan $1 $>) }
     | id             { let L loc (ID v) = $1 in Var (QualName [] v) NoInfo (srclocOf loc) }
     | Atom '.' id
       { let L ploc (ID f) = $3 in case $1 of
           Var (QualName qs v) NoInfo vloc ->
             Var (QualName (qs++[v]) f) NoInfo (srcspan vloc ploc)
           _ ->
             Project f $1 NoInfo (srcspan $1 ploc) }
     | Atom '.' natlit
       { let L ploc (NATLIT f _) = $3 in Project f $1 NoInfo (srcspan $1 ploc) }
     | Atom '.' '(' Exp ')'
       {% case $1 of
           Var qn NoInfo vloc ->
             pure (QualParens (qn, srclocOf vloc) $4 (srcspan vloc $>))
           _ ->
             parseErrorAt $3 (Just "Can only locally open module names, not arbitrary expressions") }
     | Atom '...[' DimIndices ']' { AppExp (Index $1 $3 (srcspan $1 $>)) NoInfo }
     | '{' Fields '}' { RecordLit $2 (srcspan $1 $>) }
     | SectionExp     { $1 }

-- Numeric literals with an explicit type suffix.  Unsigned literals are
-- converted with fromIntegral into the same-width Int*Value
-- constructors and wrapped in UnsignedValue.
NumLit :: { (PrimValue, Loc) }
        : i8lit  { let L loc (I8LIT num) = $1 in (SignedValue $ Int8Value num, loc) }
        | i16lit { let L loc (I16LIT num) = $1 in (SignedValue $ Int16Value num, loc) }
        | i32lit { let L loc (I32LIT num) = $1 in (SignedValue $ Int32Value num, loc) }
        | i64lit { let L loc (I64LIT num) = $1 in (SignedValue $ Int64Value num, loc) }
        | u8lit  { let L loc (U8LIT num) = $1 in (UnsignedValue $ Int8Value $ fromIntegral num, loc) }
        | u16lit { let L loc (U16LIT num) = $1 in (UnsignedValue $ Int16Value $ fromIntegral num, loc) }
        | u32lit { let L loc (U32LIT num) = $1 in (UnsignedValue $ Int32Value $ fromIntegral num, loc) }
        | u64lit { let L loc (U64LIT num) = $1 in (UnsignedValue $ Int64Value $ fromIntegral num, loc) }
        | f16lit { let L loc (F16LIT num) = $1 in (FloatValue $ Float16Value num, loc) }
        | f32lit { let L loc (F32LIT num) = $1 in (FloatValue $ Float32Value num, loc) }
        | f64lit { let L loc (F64LIT num) = $1 in (FloatValue $ Float64Value num, loc) }

-- Primitive literals: booleans and type-suffixed numbers.
PrimLit :: { (PrimValue, Loc) }
        : true   { (BoolValue True, $1) }
        | false  { (BoolValue False, $1) }
        | NumLit { $1 }

-- One or more comma-separated expressions with an optional trailing
-- comma.  Exps1_ accumulates them in reverse (left recursion with
-- cons), and Exps1 reverses the result once.
-- NOTE(review): because "Exps1_ ','" is itself an Exps1_, repeated
-- commas such as "a,,b" are also accepted; confirm whether intended.
Exps1 :: { [UncheckedExp] }
       : Exps1_ { reverse $1 }

Exps1_ :: { [UncheckedExp] }
        : Exps1_ ',' Exp { $3 : $1 }
        | Exps1_ ','     { $1 }
        | Exp            { [$1] }

-- Zero or more ".field" accesses.
FieldAccesses :: { [L Name] }
               : '.' FieldId FieldAccesses { $2 : $3 }
               |                           { [] }

-- One or more field names separated by '.', without a leading '.'.
FieldAccesses_ :: { [L Name] }
                : FieldId FieldAccesses { $1 : $2 }

-- A record literal field: "name = e", or a bare identifier as shorthand
-- (RecordFieldImplicit).
Field :: { FieldBase NoInfo Name }
       : FieldId '=' Exp { RecordFieldExplicit $1 $3 (srcspan $1 $>) }
       | id              { let L loc (ID s) = $1 in RecordFieldImplicit (L loc s) NoInfo (srclocOf loc) }

-- Comma-separated record literal fields.  The list may be empty, and a
-- trailing comma is accepted.
Fields :: { [FieldBase NoInfo Name] }
        : Field ',' Fields { $1 : $3 }
        | Field            { [$1] }
        |                  { [] }

-- Local bindings, in three forms:
--   * a pattern binding, optionally preceded by size binders;
--   * a local function definition;
--   * an in-place update "let x[i] = v", which builds a LetWith with x
--     as both the source and the destination.
LetExp :: { UncheckedExp }
     : let SizeBinders1 Pat '=' Exp LetBody
       { AppExp (LetPat $2 $3 $5 $6 (srcspan $1 $>)) NoInfo }
     | let Pat '=' Exp LetBody
       { AppExp (LetPat [] $2 $4 $5 (srcspan $1 $>)) NoInfo }
     | let id LocalFunTypeParams FunParams1 maybeAscription(TypeExp) '=' Exp LetBody
       { let L _ (ID name) = $2 in AppExp (LetFun name ($3, fst $4 : snd $4, $5, NoInfo, $7) $8 (srcspan $1 $>)) NoInfo}
     | let id '...[' DimIndices ']' '=' Exp LetBody
       { let L vloc (ID v) = $2; ident = Ident v NoInfo (srclocOf vloc) in AppExp (LetWith ident ident $4 $7 $8 (srcspan $1 $>)) NoInfo }

-- What follows a let binding: "in e", or another LetExp directly
-- (chained lets without 'in').  A 'def', 'type' or 'module' keyword in
-- this position produces an error suggesting a missing 'in'.
LetBody :: { UncheckedExp }
    : in Exp %prec letprec { $2 }
    | LetExp %prec letprec { $1 }
    | def    {% parseErrorAt $1 (Just "Unexpected \"def\" - missing \"in\"?") }
    | type   {% parseErrorAt $1 (Just "Unexpected \"type\" - missing \"in\"?") }
    | module {% parseErrorAt $1 (Just "Unexpected \"module\" - missing \"in\"?") }

-- Binary operator applications.  Tokens of the '<op>...' families are
-- passed to binOp unchanged.  The bare '-', '*', '|', '^' and '<'
-- tokens are first wrapped in a synthesised SYMBOL token so that binOp
-- can treat them the same way.  Backtick-quoted names build a BinOp
-- node directly.  Precedence and associativity come from declarations
-- outside this chunk (presumably this grammar's precedence section).
BinOpExp :: { UncheckedExp }
  : Exp2 '+...' Exp2   { binOp $1 $2 $3 }
  | Exp2 '-...' Exp2   { binOp $1 $2 $3 }
  | Exp2 '-' Exp2      { binOp $1 (L $2 (SYMBOL Minus [] (nameFromString "-"))) $3 }
  | Exp2 '*...' Exp2   { binOp $1 $2 $3 }
  | Exp2 '*' Exp2      { binOp $1 (L $2 (SYMBOL Times [] (nameFromString "*"))) $3 }
  | Exp2 '/...' Exp2   { binOp $1 $2 $3 }
  | Exp2 '%...' Exp2   { binOp $1 $2 $3 }
  | Exp2 '//...' Exp2  { binOp $1 $2 $3 }
  | Exp2 '%%...' Exp2  { binOp $1 $2 $3 }
  | Exp2 '**...' Exp2  { binOp $1 $2 $3 }
  | Exp2 '>>...' Exp2  { binOp $1 $2 $3 }
  | Exp2 '<<...' Exp2  { binOp $1 $2 $3 }
  | Exp2 '&...' Exp2   { binOp $1 $2 $3 }
  | Exp2 '|...' Exp2   { binOp $1 $2 $3 }
  | Exp2 '|' Exp2      { binOp $1 (L $2 (SYMBOL Bor [] (nameFromString "|"))) $3 }
  | Exp2 '&&...' Exp2  { binOp $1 $2 $3 }
  | Exp2 '||...' Exp2  { binOp $1 $2 $3 }
  | Exp2 '^...' Exp2   { binOp $1 $2 $3 }
  | Exp2 '^' Exp2      { binOp $1 (L $2 (SYMBOL Xor [] (nameFromString "^"))) $3 }
  | Exp2 '==...' Exp2  { binOp $1 $2 $3 }
  | Exp2 '!=...' Exp2  { binOp $1 $2 $3 }
  | Exp2 '<...' Exp2   { binOp $1 $2 $3 }
  | Exp2 '<=...' Exp2  { binOp $1 $2 $3 }
  | Exp2 '>...' Exp2   { binOp $1 $2 $3 }
  | Exp2 '>=...' Exp2  { binOp $1 $2 $3 }
  | Exp2 '|>...' Exp2  { binOp $1 $2 $3 }
  | Exp2 '<|...' Exp2  { binOp $1 $2 $3 }
  | Exp2 '<' Exp2      { binOp $1 (L $2 (SYMBOL Less [] (nameFromString "<"))) $3 }
  | Exp2 '!...' Exp2   { binOp $1 $2 $3 }
  | Exp2 '=...' Exp2   { binOp $1 $2 $3 }
  | Exp2 '`' QualName '`' Exp2
    { AppExp (BinOp (second srclocOf $3) NoInfo ($1, NoInfo) ($5, NoInfo) (srcspan $1 $>)) NoInfo }

-- Operator sections: "(op)", "(e op)" and "(op e)".  The plain '-' has
-- its own "(-)" and "(e -)" forms.  There is no "(- e)" right section,
-- since that text parses as a parenthesised negation.  Also projection
-- sections "(.a.b)" and index sections "(.[i])".
SectionExp :: { UncheckedExp }
  : '(' '-' ')'
    { OpSection (qualName (nameFromString "-")) NoInfo (srcspan $1 $>) }
  | '(' Exp2 '-' ')'
    { OpSectionLeft (qualName (nameFromString "-")) NoInfo $2 (NoInfo, NoInfo) (NoInfo, NoInfo) (srcspan $1 $>) }
  | '(' BinOp Exp2 ')'
    { OpSectionRight (fst $2) NoInfo $3 (NoInfo, NoInfo) NoInfo (srcspan $1 $>) }
  | '(' Exp2 BinOp ')'
    { OpSectionLeft (fst $3) NoInfo $2 (NoInfo, NoInfo) (NoInfo, NoInfo) (srcspan $1 $>) }
  | '(' BinOp ')'
    { OpSection (fst $2) NoInfo (srcspan $1 $>) }
  | '(' '.' FieldAccesses_ ')'
    { ProjectSection (map unLoc $3) NoInfo (srcspan $1 $>) }
  | '(' '.' '[' DimIndices ']' ')'
    { IndexSection $4 NoInfo (srcspan $1 $>) }

-- Ranges: "a...b" (ToInclusive), "a..<b" (UpToExclusive) and "a..>b"
-- (DownToExclusive).  Each may carry a middle expression, as in
-- "a..m...b", which is stored as the Just component of Range.
RangeExp :: { UncheckedExp }
  : Exp2 '...' Exp2           { AppExp (Range $1 Nothing (ToInclusive $3) (srcspan $1 $>)) NoInfo }
  | Exp2 '..<' Exp2           { AppExp (Range $1 Nothing (UpToExclusive $3) (srcspan $1 $>)) NoInfo }
  | Exp2 '..>' Exp2           { AppExp (Range $1 Nothing (DownToExclusive $3) (srcspan $1 $>)) NoInfo }
  | Exp2 '..' Exp2 '...' Exp2 { AppExp (Range $1 (Just $3) (ToInclusive $5) (srcspan $1 $>)) NoInfo }
  | Exp2 '..' Exp2 '..<' Exp2 { AppExp (Range $1 (Just $3) (UpToExclusive $5) (srcspan $1 $>)) NoInfo }
  | Exp2 '..' Exp2 '..>' Exp2 { AppExp (Range $1 (Just $3) (DownToExclusive $5) (srcspan $1 $>)) NoInfo }

-- Conditional expression; both branches are required.
IfExp :: { UncheckedExp }
       : if Exp then Exp else Exp %prec ifprec
         { AppExp (If $2 $4 $6 (srcspan $1 $>)) NoInfo }

-- Loops: "loop p [= e] <LoopForm> do body".  Without "= e" the initial
-- value is left implicit (LoopInitImplicit); with it, it is explicit.
-- The loop pattern's type component is mapped with toParam Observe.
LoopExp :: { UncheckedExp }
         : loop Pat LoopForm do Exp %prec ifprec
           { AppExp (Loop [] (fmap (toParam Observe) $2) (LoopInitImplicit NoInfo) $3 $5 (srcspan $1 $>)) NoInfo }
         | loop Pat '=' Exp LoopForm do Exp %prec ifprec
           { AppExp (Loop [] (fmap (toParam Observe) $2) (LoopInitExplicit $4) $5 $7 (srcspan $1 $>)) NoInfo }

-- "match e" followed by one or more cases.  The location spans from the
-- 'match' token to the cases.
MatchExp :: { UncheckedExp }
          : match Exp Cases
            { let loc = srcspan $1 (NE.toList $>) in AppExp (Match $2 $> loc) NoInfo }

-- A non-empty sequence of case arms.
Cases :: { NE.NonEmpty (CaseBase NoInfo Name) }
       : Case %prec caseprec { NE.singleton $1 }
       | Case Cases          { NE.cons $1 $2 }

-- A single "case p -> e" arm.
Case :: { CaseBase NoInfo Name }
      : case Pat '->' Exp
        { let loc = srcspan $1 $> in CasePat $2 $> loc }

-- General patterns: an attribute-annotated pattern, an InnerPat with an
-- optional type ascription, or a constructor applied to one or more
-- InnerPat fields.
Pat :: { PatBase NoInfo Name StructType }
     : '#[' AttrInfo ']' Pat { PatAttr $2 $4 (srcspan $1 $>) }
     | InnerPat ':' TypeExp  { PatAscription $1 $3 (srcspan $1 $>) }
     | InnerPat              { $1 }
     | Constr ConstrFields   { let (n, loc) = $1; loc' = srcspan loc $> in PatConstr n NoInfo $2 loc'}

-- Parameter patterns are slightly restricted; see #2017.
-- Patterns allowed as function parameters (see #2017).  The
-- alternatives mirror InnerPat, except that literal patterns are
-- limited to PatLiteralNoNeg.  A negative literal is therefore not
-- accepted directly as a parameter pattern, though it still is inside
-- parentheses, via '(' Pat ')'.
ParamPat :: { PatBase NoInfo Name StructType }
          : id                    { let L loc (ID name) = $1 in Id name NoInfo (srclocOf loc) }
          | '(' BindingBinOp ')'  { Id $2 NoInfo (srcspan $1 $>) }
          | '_'                   { Wildcard NoInfo (srclocOf $1) }
          | '(' ')'               { TuplePat [] (srcspan $1 $>) }
          | '(' Pat ')'           { PatParens $2 (srcspan $1 $>) }
          | '(' Pat ',' Pats1 ')' { TuplePat ($2:$4) (srcspan $1 $>) }
          | '{' CFieldPats '}'    { RecordPat $2 (srcspan $1 $>) }
          | PatLiteralNoNeg       { PatLit (fst $1) NoInfo (srclocOf (snd $1)) }
          | Constr                { let (n, loc) = $1 in PatConstr n NoInfo [] (srclocOf loc) }

-- One or more comma-separated patterns with an optional trailing comma.
Pats1 :: { [PatBase NoInfo Name StructType] }
       : Pat           { [$1] }
       | Pat ','       { [$1] }
       | Pat ',' Pats1 { $1 : $3 }

-- Patterns that may appear as constructor fields and before a type
-- ascription in Pat.  Unlike ParamPat, literal patterns here may be
-- negative (PatLiteral).
InnerPat :: { PatBase NoInfo Name StructType }
          : id                    { let L loc (ID name) = $1 in Id name NoInfo (srclocOf loc) }
          | '(' BindingBinOp ')'  { Id $2 NoInfo (srcspan $1 $>) }
          | '_'                   { Wildcard NoInfo (srclocOf $1) }
          | '(' ')'               { TuplePat [] (srcspan $1 $>) }
          | '(' Pat ')'           { PatParens $2 (srcspan $1 $>) }
          | '(' Pat ',' Pats1 ')' { TuplePat ($2:$4) (srcspan $1 $>) }
          | '{' CFieldPats '}'    { RecordPat $2 (srcspan $1 $>) }
          | PatLiteral            { PatLit (fst $1) NoInfo (srclocOf (snd $1)) }
          | Constr                { let (n, loc) = $1 in PatConstr n NoInfo [] (srclocOf loc) }

-- One or more constructor field patterns, in source order.
-- NOTE(review): appending with ++ makes this quadratic in the number of
-- fields.
ConstrFields :: { [PatBase NoInfo Name StructType] }
              : InnerPat              { [$1] }
              | ConstrFields InnerPat { $1 ++ [$2] }

-- A record pattern field, in one of three forms:
--   * "f = p" matches field f against p;
--   * "f : t" binds f with a type ascription;
--   * a bare "f" binds f.
CFieldPat :: { (L Name, PatBase NoInfo Name StructType) }
           : FieldId '=' Pat     { ($1, $3) }
           | FieldId ':' TypeExp { ($1, PatAscription (Id (unLoc $1) NoInfo (srclocOf $1)) $3 (srcspan $1 $>)) }
           | FieldId             { ($1, Id (unLoc $1) NoInfo (srclocOf $1)) }

-- Comma-separated record pattern fields.  The list may be empty, and a
-- trailing comma is accepted.
CFieldPats :: { [(L Name, PatBase NoInfo Name StructType)] }
            : CFieldPats1 { $1 }
            |             { [] }

CFieldPats1 :: { [(L Name, PatBase NoInfo Name StructType)] }
             : CFieldPat ',' CFieldPats1 { $1 : $3 }
             | CFieldPat ','             { [$1] }
             | CFieldPat                 { [$1] }

-- Literal patterns without a leading minus.  Character literals become
-- integer patterns of their code point.
PatLiteralNoNeg :: { (PatLit, Loc) }
                 : charlit  { let L loc (CHARLIT x) = $1 in (PatLitInt (toInteger (ord x)), loc) }
                 | PrimLit  { (PatLitPrim (fst $1), snd $1) }
                 | intlit   { let L loc (INTLIT x) = $1 in (PatLitInt x, loc) }
                 | natlit   { let L loc (NATLIT _ x) = $1 in (PatLitInt x, loc) }
                 | floatlit { let L loc (FLOATLIT x) = $1 in (PatLitFloat x, loc) }

-- Literal patterns, including negated numeric literals.  The Loc of a
-- negated literal spans from the '-' to the literal.
-- NOTE(review): the '-' floatlit alternative lacks the %prec bottom
-- carried by the other three negated alternatives; confirm that this
-- difference is intended.
PatLiteral :: { (PatLit, Loc) }
            : PatLiteralNoNeg         { $1 }
            | '-' NumLit %prec bottom { (PatLitPrim (primNegate (fst $2)), locOf (srcspan $1 (snd $2))) }
            | '-' intlit %prec bottom { let L loc (INTLIT x) = $2 in (PatLitInt (negate x), locOf (srcspan $1 $>)) }
            | '-' natlit %prec bottom { let L loc (NATLIT _ x) = $2 in (PatLitInt (negate x), locOf (srcspan $1 $>)) }
            | '-' floatlit            { let L loc (FLOATLIT x) = $2 in (PatLitFloat (negate x), locOf (srcspan $1 $>)) }

-- The iteration part of a loop: "for i < n", "for p in e", or
-- "while e".
LoopForm :: { LoopFormBase NoInfo Name }
LoopForm : for VarId '<' Exp { For $2 $4 }
         | for Pat in Exp    { ForIn $2 $4 }
         | while Exp         { While $2 }

-- A single index "i", or a slice whose start, end and stride may be
-- omitted in the combinations listed below.  The components are Exp2,
-- not Exp, so that ':' is never read as a type ascription.
DimIndex :: { UncheckedDimIndex }
          : Exp2                   { DimFix $1 }
          | Exp2 ':' Exp2          { DimSlice (Just $1) (Just $3) Nothing }
          | Exp2 ':'               { DimSlice (Just $1) Nothing Nothing }
          | ':' Exp2               { DimSlice Nothing (Just $2) Nothing }
          | ':'                    { DimSlice Nothing Nothing Nothing }
          | Exp2 ':' Exp2 ':' Exp2 { DimSlice (Just $1) (Just $3) (Just $5) }
          | ':' Exp2 ':' Exp2      { DimSlice Nothing (Just $2) (Just $4) }
          | Exp2 ':' ':' Exp2      { DimSlice (Just $1) Nothing (Just $4) }
          | ':' ':' Exp2           { DimSlice Nothing Nothing (Just $3) }

-- Comma-separated indices.  The list may be empty, and a trailing comma
-- is accepted.
DimIndices :: { [UncheckedDimIndex] }
            :                         { [] }
            | DimIndex                { [$1] }
            | DimIndex ',' DimIndices { $1 : $3 }

-- An unqualified variable identifier with its location.
VarId :: { IdentBase NoInfo Name StructType }
VarId : id { let L loc (ID name) = $1 in Ident name NoInfo (srclocOf loc) }

-- A record field name: an identifier or a natural-number literal.
FieldId :: { L Name }
         : id     { let L loc (ID name) = $1 in L loc name }
         | natlit { let L loc (NATLIT x _) = $1 in L loc x }

-- An optional ": p" ascription, parameterised over the nonterminal p.
maybeAscription(p) : ':' p { Just $2 }
                   |       { Nothing }

-- An attribute atom: a name or an integer.
AttrAtom :: { (AttrAtom Name, Loc) }
  : id     { let L loc (ID s) = $1 in (AtomName s, loc) }
  | intlit { let L loc (INTLIT x) = $1 in (AtomInt x, loc) }
  | natlit { let L loc (NATLIT _ x) = $1 in (AtomInt x, loc) }

-- An attribute as written between '#[' and ']': an atom, or a name
-- applied to a parenthesised, comma-separated list of attributes.
AttrInfo :: { AttrInfo Name }
  : AttrAtom         { let (x,y) = $1 in AttrAtom x (srclocOf y) }
  | id '(' Attrs ')' { let L _ (ID s) = $1 in AttrComp s $3 (srcspan $1 $>) }

-- Comma-separated attributes.  The list may be empty, and a trailing
-- comma is accepted.
Attrs :: { [AttrInfo Name] }
       :                    { [] }
       | AttrInfo           { [$1] }
       | AttrInfo ',' Attrs { $1 : $3 }
futhark-0.25.27/src/Language/Futhark/Prelude.hs000066400000000000000000000014221475065116200212130ustar00rootroot00000000000000{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}

-- | The Futhark Prelude Library embedded as strings read
-- during compilation of the Futhark compiler. The advantage is that
-- the prelude can be accessed without reading it from disk, thus
-- saving users from include path headaches.
module Language.Futhark.Prelude (prelude) where

import Data.FileEmbed
import Data.Text qualified as T
import Data.Text.Encoding qualified as T
import Futhark.Util (toPOSIX)
import System.FilePath.Posix qualified as Posix

-- | Prelude embedded as 'T.Text' values, one for every file.
--
-- Each path is the embedded file's path, converted with 'toPOSIX' and
-- placed under @/prelude@.  The file contents are decoded as UTF-8.
prelude :: [(Posix.FilePath, T.Text)]
prelude = map fixup prelude_bs
  where
    -- Embedded at compile time by Template Haskell from the "prelude"
    -- directory; each entry is a (path, contents) pair.
    prelude_bs = $(embedDir "prelude")
    -- Combine with the POSIX '</>' so the result always uses '/'
    -- separators, whatever the host platform.
    fixup (path, s) = ("/prelude" Posix.</> toPOSIX path, T.decodeUtf8 s)
futhark-0.25.27/src/Language/Futhark/Pretty.hs000066400000000000000000000534441475065116200211150ustar00rootroot00000000000000{-# OPTIONS_GHC -fno-warn-orphans #-}

-- | Futhark prettyprinter. This module defines 'Pretty' instances
-- for the AST defined in "Language.Futhark.Syntax".
module Language.Futhark.Pretty
  ( prettyString,
    prettyTuple,
    leadingOperator,
    IsName (..),
    prettyNameString,
    Annot (..),
  )
where

import Control.Monad
import Data.Char (chr)
import Data.Functor
import Data.List (intersperse)
import Data.List.NonEmpty qualified as NE
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Monoid hiding (Sum)
import Data.Ord
import Data.Text qualified as T
import Data.Word
import Futhark.Util
import Futhark.Util.Pretty
import Language.Futhark.Prop
import Language.Futhark.Syntax
import Prelude

-- | A class for types that are variable names in the Futhark source
-- language.
This is used instead of a mere 'Pretty' instance because -- in the compiler frontend we want to print VNames differently -- depending on whether the FUTHARK_COMPILER_DEBUGGING environment -- variable is set, yet in the backend we want to always print VNames -- with the tag. To avoid erroneously using the 'Pretty' instance for -- VNames, we in fact only define it inside the modules for the core -- language (as an orphan instance). class IsName v where prettyName :: v -> Doc a toName :: v -> Name -- | Depending on the environment variable FUTHARK_COMPILER_DEBUGGING, -- VNames are printed as either the name with an internal tag, or just -- the base name. instance IsName VName where prettyName | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 1 = \(VName vn i) -> pretty vn <> "_" <> pretty (show i) | otherwise = pretty . baseName toName = baseName instance IsName Name where prettyName = pretty toName = id -- | Prettyprint name as string. Only use this for debugging. prettyNameString :: (IsName v) => v -> String prettyNameString = T.unpack . docText . prettyName -- | Class for type constructors that represent annotations. Used in -- the prettyprinter to either print the original AST, or the computed -- decoration. class Annot f where -- | Extract value, if any. unAnnot :: f a -> Maybe a instance Annot NoInfo where unAnnot = const Nothing instance Annot Info where unAnnot = Just . 
unInfo instance Pretty PrimValue where pretty (UnsignedValue (Int8Value v)) = pretty (show (fromIntegral v :: Word8)) <> "u8" pretty (UnsignedValue (Int16Value v)) = pretty (show (fromIntegral v :: Word16)) <> "u16" pretty (UnsignedValue (Int32Value v)) = pretty (show (fromIntegral v :: Word32)) <> "u32" pretty (UnsignedValue (Int64Value v)) = pretty (show (fromIntegral v :: Word64)) <> "u64" pretty (SignedValue v) = pretty v pretty (BoolValue True) = "true" pretty (BoolValue False) = "false" pretty (FloatValue v) = pretty v instance (Pretty d) => Pretty (SizeExp d) where pretty SizeExpAny {} = brackets mempty pretty (SizeExp e _) = brackets $ pretty e instance Pretty (Shape Size) where pretty (Shape ds) = mconcat (map (brackets . pretty) ds) instance Pretty (Shape ()) where pretty (Shape ds) = mconcat $ replicate (length ds) "[]" instance Pretty (Shape Int64) where pretty (Shape ds) = mconcat (map (brackets . pretty) ds) instance Pretty (Shape Bool) where pretty (Shape ds) = mconcat (map (brackets . pretty) ds) prettyRetType :: (Pretty (Shape dim), Pretty u) => Int -> RetTypeBase dim u -> Doc a prettyRetType p (RetType [] t) = prettyType p t prettyRetType _ (RetType dims t) = "?" <> mconcat (map (brackets . prettyName) dims) <> "." 
<> pretty t instance (Pretty (Shape dim), Pretty u) => Pretty (RetTypeBase dim u) where pretty = prettyRetType 0 instance Pretty Diet where pretty Consume = "*" pretty Observe = "" prettyScalarType :: (Pretty (Shape dim), Pretty u) => Int -> ScalarTypeBase dim u -> Doc a prettyScalarType _ (Prim et) = pretty et prettyScalarType p (TypeVar u v targs) = parensIf (not (null targs) && p > 3) $ pretty u <> hsep (pretty v : map (prettyTypeArg 3) targs) prettyScalarType _ (Record fs) | Just ts <- areTupleFields fs = group $ parens $ align $ mconcat $ punctuate ("," <> line) $ map pretty ts | otherwise = group $ braces $ align $ mconcat $ punctuate ("," <> line) fs' where ppField (name, t) = pretty (nameToString name) <> colon <+> align (pretty t) fs' = map ppField $ M.toList fs prettyScalarType p (Arrow _ (Named v) d t1 t2) = parensIf (p > 1) $ parens (prettyName v <> colon <+> pretty d <> align (pretty t1)) <+> "->" <+> prettyRetType 1 t2 prettyScalarType p (Arrow _ Unnamed d t1 t2) = parensIf (p > 1) $ (pretty d <> prettyType 2 t1) <+> "->" <+> prettyRetType 1 t2 prettyScalarType p (Sum cs) = parensIf (p > 0) $ group (align (mconcat $ punctuate (" |" <> line) cs')) where ppConstr (name, fs) = sep $ ("#" <> pretty name) : map (prettyType 2) fs cs' = map ppConstr $ M.toList cs instance (Pretty (Shape dim), Pretty u) => Pretty (ScalarTypeBase dim u) where pretty = prettyScalarType 0 prettyType :: (Pretty (Shape dim), Pretty u) => Int -> TypeBase dim u -> Doc a prettyType _ (Array u shape at) = pretty u <> pretty shape <> align (prettyScalarType 1 at) prettyType p (Scalar t) = prettyScalarType p t instance (Pretty (Shape dim), Pretty u) => Pretty (TypeBase dim u) where pretty = prettyType 0 prettyTypeArg :: (Pretty (Shape dim)) => Int -> TypeArg dim -> Doc a prettyTypeArg _ (TypeArgDim d) = pretty $ Shape [d] prettyTypeArg p (TypeArgType t) = prettyType p t instance Pretty (TypeArg Size) where pretty = prettyTypeArg 0 instance (IsName vn, Pretty d) => Pretty (TypeExp d vn) 
where pretty (TEUnique t _) = "*" <> pretty t pretty (TEArray d at _) = pretty d <> pretty at pretty (TETuple ts _) = parens $ commasep $ map pretty ts pretty (TERecord fs _) = braces $ commasep $ map ppField fs where ppField (L _ name, t) = prettyName name <> colon <+> pretty t pretty (TEVar name _) = pretty name pretty (TEParens te _) = parens $ pretty te pretty (TEApply t arg _) = pretty t <+> pretty arg pretty (TEArrow (Just v) t1 t2 _) = parens v' <+> "->" <+> pretty t2 where v' = prettyName v <> colon <+> pretty t1 pretty (TEArrow Nothing t1 t2 _) = pretty t1 <+> "->" <+> pretty t2 pretty (TESum cs _) = align $ cat $ punctuate (" |" <> softline) $ map ppConstr cs where ppConstr (name, fs) = "#" <> pretty name <+> sep (map pretty fs) pretty (TEDim dims te _) = "?" <> mconcat (map (brackets . prettyName) dims) <> "." <> pretty te instance (Pretty d, IsName vn) => Pretty (TypeArgExp d vn) where pretty (TypeArgExpSize d) = pretty d pretty (TypeArgExpType t) = pretty t instance (IsName vn) => Pretty (QualName vn) where pretty (QualName names name) = mconcat $ punctuate "." $ map prettyName names ++ [prettyName name] instance (IsName vn) => Pretty (IdentBase f vn t) where pretty = prettyName . identName hasArrayLit :: ExpBase ty vn -> Bool hasArrayLit ArrayLit {} = True hasArrayLit (TupLit es2 _) = any hasArrayLit es2 hasArrayLit _ = False instance (Eq vn, IsName vn, Annot f) => Pretty (DimIndexBase f vn) where pretty (DimFix e) = pretty e pretty (DimSlice i j (Just s)) = maybe mempty pretty i <> ":" <> maybe mempty pretty j <> ":" <> pretty s pretty (DimSlice i (Just j) s) = maybe mempty pretty i <> ":" <> pretty j <> maybe mempty ((":" <>) . 
pretty) s pretty (DimSlice i Nothing Nothing) = maybe mempty pretty i <> ":" instance (IsName vn) => Pretty (SizeBinder vn) where pretty (SizeBinder v _) = brackets $ prettyName v letBody :: (Eq vn, IsName vn, Annot f) => ExpBase f vn -> Doc a letBody body@(AppExp LetPat {} _) = pretty body letBody body@(AppExp LetFun {} _) = pretty body letBody body = "in" <+> align (pretty body) prettyAppExp :: (Eq vn, IsName vn, Annot f) => Int -> AppExpBase f vn -> Doc a prettyAppExp p (BinOp (bop, _) _ (x, _) (y, _) _) = prettyBinOp p bop x y prettyAppExp _ (Match e cs _) = "match" <+> pretty e (stack . map pretty) (NE.toList cs) prettyAppExp _ (Loop sizeparams pat initexp form loopbody _) = "loop" <+> align ( hsep (map (brackets . prettyName) sizeparams ++ [pretty pat]) <+> equals <+> pretty initexp pretty form "do" ) indent 2 (pretty loopbody) prettyAppExp _ (Index e idxs _) = prettyExp 9 e <> brackets (commasep (map pretty idxs)) prettyAppExp p (LetPat sizes pat e body _) = parensIf (p /= -1) . align $ hsep ("let" : map pretty sizes ++ [align (pretty pat)]) <+> ( if linebreak then equals indent 2 (pretty e) else equals <+> align (pretty e) ) letBody body where linebreak = case e of AppExp {} -> True Coerce {} -> True Attr {} -> True ArrayLit {} -> False Lambda {} -> True _ -> hasArrayLit e prettyAppExp _ (LetFun fname (tparams, params, retdecl, rettype, e) body _) = "let" <+> hsep (prettyName fname : map pretty tparams ++ map pretty params) <> retdecl' <+> equals indent 2 (pretty e) letBody body where retdecl' = case (pretty <$> unAnnot rettype) `mplus` (pretty <$> retdecl) of Just rettype' -> colon <+> align rettype' Nothing -> mempty prettyAppExp _ (LetWith dest src idxs ve body _) | dest == src = "let" <+> pretty dest <> list (map pretty idxs) <+> equals <+> align (pretty ve) letBody body | otherwise = "let" <+> pretty dest <+> equals <+> pretty src <+> "with" <+> brackets (commasep (map pretty idxs)) <+> "=" <+> align (pretty ve) letBody body prettyAppExp p (Range start 
maybe_step end _) = parensIf (p /= -1) $ pretty start <> maybe mempty ((".." <>) . pretty) maybe_step <> case end of DownToExclusive end' -> "..>" <> pretty end' ToInclusive end' -> "..." <> pretty end' UpToExclusive end' -> "..<" <> pretty end' prettyAppExp _ (If c t f _) = "if" <+> pretty c "then" <+> align (pretty t) "else" <+> align (pretty f) prettyAppExp p (Apply f args _) = parensIf (p >= 10) $ prettyExp 0 f <+> hsep (map (prettyExp 10 . snd) $ NE.toList args) instance (Eq vn, IsName vn, Annot f) => Pretty (AppExpBase f vn) where pretty = prettyAppExp (-1) prettyInst :: (Annot f, Pretty t) => f t -> Doc a prettyInst t = case unAnnot t of Just t' | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 2 -> "@" <> parens (align $ pretty t') _ -> mempty prettyAttr :: (Pretty a) => a -> Doc ann prettyAttr attr = "#[" <> pretty attr <> "]" operatorName :: Name -> Bool operatorName = (`elem` opchars) . T.head . nameToText where opchars :: String opchars = "+-*/%=!><|&^." prettyExp :: (Eq vn, IsName vn, Annot f) => Int -> ExpBase f vn -> Doc a prettyExp _ (Var name t _) -- The first case occurs only for programs that have been normalised -- by the compiler. | operatorName (toName (qualLeaf name)) = parens $ pretty name <> prettyInst t | otherwise = pretty name <> prettyInst t prettyExp _ (Hole t _) = "???" <> prettyInst t prettyExp _ (Parens e _) = align $ parens $ pretty e prettyExp _ (QualParens (v, _) e _) = pretty v <> "." 
<> align (parens $ pretty e) prettyExp p (Ascript e t _) = parensIf (p /= -1) $ prettyExp 0 e <+> ":" <+> align (pretty t) prettyExp p (Coerce e t _ _) = parensIf (p /= -1) $ prettyExp 0 e <+> ":>" <+> align (pretty t) prettyExp _ (Literal v _) = pretty v prettyExp _ (IntLit v t _) = pretty v <> prettyInst t prettyExp _ (FloatLit v t _) = pretty v <> prettyInst t prettyExp _ (TupLit es _) | any hasArrayLit es = parens $ commastack $ map pretty es | otherwise = parens $ commasep $ map pretty es prettyExp _ (RecordLit fs _) | any fieldArray fs = braces $ commastack $ map pretty fs | otherwise = braces $ commasep $ map pretty fs where fieldArray (RecordFieldExplicit _ e _) = hasArrayLit e fieldArray RecordFieldImplicit {} = False prettyExp _ (ArrayVal vs _ _) = brackets (commasep $ map pretty vs) prettyExp _ (ArrayLit es t _) = brackets (commasep $ map pretty es) <> prettyInst t prettyExp _ (StringLit s _) = pretty $ show $ map (chr . fromIntegral) s prettyExp _ (Project k e _ _) = pretty e <> "." <> pretty k prettyExp _ (Negate e _) = "-" <> pretty e prettyExp _ (Not e _) = "!" <> pretty e prettyExp _ (Update src idxs ve _) = pretty src <+> "with" <+> brackets (commasep (map pretty idxs)) <+> "=" <+> align (pretty ve) prettyExp _ (RecordUpdate src fs ve _ _) = pretty src <+> "with" <+> mconcat (intersperse "." (map pretty fs)) <+> "=" <+> align (pretty ve) prettyExp _ (Assert e1 e2 _ _) = "assert" <+> prettyExp 10 e1 <+> prettyExp 10 e2 prettyExp p (Lambda params body rettype _ _) = parensIf (p /= -1) $ "\\" <> hsep (map pretty params) <> ppAscription rettype <+> "->" indent 2 (align (pretty body)) prettyExp _ (OpSection binop _ _) = parens $ pretty binop prettyExp _ (OpSectionLeft binop _ x _ _ _) = parens $ pretty x <+> ppBinOp binop prettyExp _ (OpSectionRight binop _ x _ _ _) = parens $ ppBinOp binop <+> pretty x prettyExp _ (ProjectSection fields _ _) = parens $ mconcat $ map p fields where p name = "." 
<> pretty name prettyExp _ (IndexSection idxs _ _) = parens $ "." <> brackets (commasep (map pretty idxs)) prettyExp p (Constr n cs t _) = parensIf (p >= 10) $ "#" <> pretty n <+> sep (map (prettyExp 10) cs) <> prettyInst t prettyExp _ (Attr attr e _) = prettyAttr attr prettyExp (-1) e prettyExp i (AppExp e res) | isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 2, Just (AppRes t ext) <- unAnnot res, not $ null ext = parens (prettyAppExp i e) "@" <> parens (pretty t <> "," <+> brackets (commasep $ map prettyName ext)) | otherwise = prettyAppExp i e instance (Eq vn, IsName vn, Annot f) => Pretty (ExpBase f vn) where pretty = prettyExp (-1) instance (IsName vn) => Pretty (AttrAtom vn) where pretty (AtomName v) = pretty v pretty (AtomInt x) = pretty x instance (IsName vn) => Pretty (AttrInfo vn) where pretty (AttrAtom attr _) = pretty attr pretty (AttrComp f attrs _) = pretty f <> parens (commasep $ map pretty attrs) instance (Eq vn, IsName vn, Annot f) => Pretty (FieldBase f vn) where pretty (RecordFieldExplicit (L _ name) e _) = pretty name <> equals <> pretty e pretty (RecordFieldImplicit (L _ name) _ _) = prettyName name instance (Eq vn, IsName vn, Annot f) => Pretty (CaseBase f vn) where pretty (CasePat p e _) = "case" <+> pretty p <+> "->" indent 2 (pretty e) instance (Eq vn, IsName vn, Annot f) => Pretty (LoopInitBase f vn) where pretty (LoopInitImplicit e) = maybe "_" pretty $ unAnnot e pretty (LoopInitExplicit e) = pretty e instance (Eq vn, IsName vn, Annot f) => Pretty (LoopFormBase f vn) where pretty (For i ubound) = "for" <+> pretty i <+> "<" <+> align (pretty ubound) pretty (ForIn x e) = "for" <+> pretty x <+> "in" <+> pretty e pretty (While cond) = "while" <+> pretty cond instance Pretty PatLit where pretty (PatLitInt x) = pretty x pretty (PatLitFloat f) = pretty f pretty (PatLitPrim v) = pretty v instance (Eq vn, IsName vn, Annot f, Pretty t) => Pretty (PatBase f vn t) where pretty (PatAscription p t _) = pretty p <> colon <+> align (pretty t) pretty 
(PatParens p _) = parens $ pretty p pretty (Id v t _) = case unAnnot t of Just t' -> parens $ prettyName v <> colon <+> align (pretty t') Nothing -> prettyName v pretty (TuplePat pats _) = parens $ commasep $ map pretty pats pretty (RecordPat fs _) = braces $ commasep $ map ppField fs where ppField (L _ name, t) = prettyName name <> equals <> pretty t pretty (Wildcard t _) = case unAnnot t of Just t' -> parens $ "_" <> colon <+> pretty t' Nothing -> "_" pretty (PatLit e _ _) = pretty e pretty (PatConstr n _ ps _) = "#" <> pretty n <+> sep (map pretty ps) pretty (PatAttr attr p _) = "#[" <> pretty attr <> "]" pretty p ppAscription :: (Pretty t) => Maybe t -> Doc a ppAscription Nothing = mempty ppAscription (Just t) = colon <> align (pretty t) instance (Eq vn, IsName vn, Annot f) => Pretty (ProgBase f vn) where pretty = stack . punctuate line . map pretty . progDecs instance (Eq vn, IsName vn, Annot f) => Pretty (DecBase f vn) where pretty (ValDec dec) = pretty dec pretty (TypeDec dec) = pretty dec pretty (ModTypeDec sig) = pretty sig pretty (ModDec sd) = pretty sd pretty (OpenDec x _) = "open" <+> pretty x pretty (LocalDec dec _) = "local" <+> pretty dec pretty (ImportDec x _ _) = "import" <+> pretty x prettyModExp :: (Eq vn, IsName vn, Annot f) => Int -> ModExpBase f vn -> Doc a prettyModExp _ (ModVar v _) = pretty v prettyModExp _ (ModParens e _) = align $ parens $ pretty e prettyModExp _ (ModImport v _ _) = "import" <+> pretty (show v) prettyModExp _ (ModDecs ds _) = nestedBlock "{" "}" $ stack $ punctuate line $ map pretty ds prettyModExp p (ModApply f a _ _ _) = parensIf (p >= 10) $ prettyModExp 0 f <+> prettyModExp 10 a prettyModExp p (ModAscript me se _ _) = parensIf (p /= -1) $ pretty me <> colon <+> pretty se prettyModExp p (ModLambda param maybe_sig body _) = parensIf (p /= -1) $ "\\" <> pretty param <> maybe_sig' <+> "->" indent 2 (pretty body) where maybe_sig' = case maybe_sig of Nothing -> mempty Just (sig, _) -> colon <+> pretty sig instance (Eq vn, 
IsName vn, Annot f) => Pretty (ModExpBase f vn) where pretty = prettyModExp (-1) instance Pretty Liftedness where pretty Unlifted = "" pretty SizeLifted = "~" pretty Lifted = "^" instance (Eq vn, IsName vn, Annot f) => Pretty (TypeBindBase f vn) where pretty (TypeBind name l params te rt _ _) = "type" <> pretty l <+> hsep (prettyName name : map pretty params) <+> equals <+> maybe (pretty te) pretty (unAnnot rt) instance (Eq vn, IsName vn) => Pretty (TypeParamBase vn) where pretty (TypeParamDim name _) = brackets $ prettyName name pretty (TypeParamType l name _) = "'" <> pretty l <> prettyName name instance (Eq vn, IsName vn, Annot f) => Pretty (ValBindBase f vn) where pretty (ValBind entry name retdecl rettype tparams args body _ attrs _) = mconcat (map ((<> line) . prettyAttr) attrs) <> fun <+> align ( sep ( prettyName name : map pretty tparams ++ map pretty args ++ retdecl' ++ ["="] ) ) indent 2 (pretty body) where fun | isJust entry = "entry" | otherwise = "def" retdecl' = case (pretty <$> unAnnot rettype) `mplus` (pretty <$> retdecl) of Just rettype' -> [colon <+> align rettype'] Nothing -> mempty instance (Eq vn, IsName vn, Annot f) => Pretty (SpecBase f vn) where pretty (TypeAbbrSpec tpsig) = pretty tpsig pretty (TypeSpec l name ps _ _) = "type" <> pretty l <+> hsep (prettyName name : map pretty ps) pretty (ValSpec name tparams vtype _ _ _) = "val" <+> hsep (prettyName name : map pretty tparams) <> colon <+> pretty vtype pretty (ModSpec name sig _ _) = "module" <+> prettyName name <> colon <+> pretty sig pretty (IncludeSpec e _) = "include" <+> pretty e instance (Eq vn, IsName vn, Annot f) => Pretty (ModTypeExpBase f vn) where pretty (ModTypeVar v _ _) = pretty v pretty (ModTypeParens e _) = parens $ pretty e pretty (ModTypeSpecs ss _) = nestedBlock "{" "}" (stack $ punctuate line $ map pretty ss) pretty (ModTypeWith s (TypeRef v ps td _) _) = pretty s <+> "with" <+> pretty v <+> hsep (map pretty ps) <> " =" <+> pretty td pretty (ModTypeArrow (Just v) e1 e2 
_) = parens (prettyName v <> colon <+> pretty e1) <+> "->" <+> pretty e2 pretty (ModTypeArrow Nothing e1 e2 _) = pretty e1 <+> "->" <+> pretty e2 instance (Eq vn, IsName vn, Annot f) => Pretty (ModTypeBindBase f vn) where pretty (ModTypeBind name e _ _) = "module type" <+> prettyName name <+> equals <+> pretty e instance (Eq vn, IsName vn, Annot f) => Pretty (ModParamBase f vn) where pretty (ModParam pname psig _ _) = parens (prettyName pname <> colon <+> pretty psig) instance (Eq vn, IsName vn, Annot f) => Pretty (ModBindBase f vn) where pretty (ModBind name ps sig e _ _) = "module" <+> hsep (prettyName name : map pretty ps) <> sig' <> " =" <+> pretty e where sig' = case sig of Nothing -> mempty Just (s, _) -> " " <> colon <+> pretty s <> " " ppBinOp :: (IsName v) => QualName v -> Doc a ppBinOp bop = case leading of Backtick -> "`" <> pretty bop <> "`" _ -> pretty bop where leading = leadingOperator $ toName $ qualLeaf bop prettyBinOp :: (Eq vn, IsName vn, Annot f) => Int -> QualName vn -> ExpBase f vn -> ExpBase f vn -> Doc a prettyBinOp p bop x y = parensIf (p > symPrecedence) $ prettyExp symPrecedence x <+> bop' <+> prettyExp symRPrecedence y where bop' = case leading of Backtick -> "`" <> pretty bop <> "`" _ -> pretty bop leading = leadingOperator $ toName $ qualLeaf bop symPrecedence = precedence leading symRPrecedence = rprecedence leading precedence PipeRight = -1 precedence PipeLeft = -1 precedence LogAnd = 0 precedence LogOr = 0 precedence Band = 1 precedence Bor = 1 precedence Xor = 1 precedence Equal = 2 precedence NotEqual = 2 precedence Bang = 2 precedence Equ = 2 precedence Less = 2 precedence Leq = 2 precedence Greater = 2 precedence Geq = 2 precedence ShiftL = 3 precedence ShiftR = 3 precedence Plus = 4 precedence Minus = 4 precedence Times = 5 precedence Divide = 5 precedence Mod = 5 precedence Quot = 5 precedence Rem = 5 precedence Pow = 6 precedence Backtick = 9 rprecedence Minus = 10 rprecedence Divide = 10 rprecedence PipeLeft = -1 rprecedence 
op = precedence op + 1 futhark-0.25.27/src/Language/Futhark/Primitive.hs000066400000000000000000001755531475065116200216040ustar00rootroot00000000000000{-# LANGUAGE LambdaCase #-} -- | Definitions of primitive types, the values that inhabit these -- types, and operations on these values. A primitive value can also -- be called a scalar. -- -- This module diverges from the actual Futhark language in that it -- does not distinguish signed and unsigned types. Further, we allow -- a "unit" type that is only indirectly present in source Futhark in -- the form of empty tuples. module Language.Futhark.Primitive ( -- * Types IntType (..), allIntTypes, FloatType (..), allFloatTypes, PrimType (..), allPrimTypes, module Data.Int, module Data.Word, Half, -- * Values IntValue (..), intValue, intValueType, valueIntegral, FloatValue (..), floatValue, floatValueType, PrimValue (..), primValueType, blankPrimValue, onePrimValue, -- * Operations Overflow (..), Safety (..), UnOp (..), allUnOps, BinOp (..), allBinOps, ConvOp (..), allConvOps, CmpOp (..), allCmpOps, -- ** Unary Operations doUnOp, doComplement, doAbs, doFAbs, doSSignum, doUSignum, -- ** Binary Operations doBinOp, doAdd, doMul, doSDiv, doSMod, doPow, -- ** Conversion Operations doConvOp, doZExt, doSExt, doFPConv, doFPToUI, doFPToSI, doUIToFP, doSIToFP, intToInt64, intToWord64, flipConvOp, -- * Comparison Operations doCmpOp, doCmpEq, doCmpUlt, doCmpUle, doCmpSlt, doCmpSle, doFCmpLt, doFCmpLe, -- * Type Of binOpType, unOpType, cmpOpType, convOpType, -- * Primitive functions primFuns, condFun, isCondFun, -- * Utility zeroIsh, zeroIshInt, oneIsh, oneIshInt, negativeIsh, primBitSize, primByteSize, intByteSize, floatByteSize, commutativeBinOp, associativeBinOp, -- * Prettyprinting convOpFun, prettySigned, ) where import Control.Category import Data.Binary.Get qualified as G import Data.Binary.Put qualified as P import Data.Bits ( complement, countLeadingZeros, countTrailingZeros, popCount, shift, shiftR, xor, (.&.), (.|.), ) 
import Data.Fixed (mod') -- Weird location.
import Data.Int (Int16, Int32, Int64, Int8)
import Data.List qualified as L
import Data.Map qualified as M
import Data.Text qualified as T
import Data.Word (Word16, Word32, Word64, Word8)
import Foreign.C.Types (CUShort (..))
import Futhark.Util (convFloat)
import Futhark.Util.CMath
import Futhark.Util.Pretty
import Numeric (log1p)
import Numeric.Half
import Prelude hiding (id, (.))

-- | An integer type, ordered by size.  Note that signedness is not a
-- property of the type, but a property of the operations performed on
-- values of these types.
data IntType
  = Int8
  | Int16
  | Int32
  | Int64
  deriving (Eq, Ord, Show, Enum, Bounded)

instance Pretty IntType where
  pretty Int8 = "i8"
  pretty Int16 = "i16"
  pretty Int32 = "i32"
  pretty Int64 = "i64"

-- | A list of all integer types.
allIntTypes :: [IntType]
allIntTypes = [minBound .. maxBound]

-- | A floating point type.
data FloatType
  = Float16
  | Float32
  | Float64
  deriving (Eq, Ord, Show, Enum, Bounded)

instance Pretty FloatType where
  pretty Float16 = "f16"
  pretty Float32 = "f32"
  pretty Float64 = "f64"

-- | A list of all floating-point types.
allFloatTypes :: [FloatType]
allFloatTypes = [minBound .. maxBound]

-- | Low-level primitive types.
data PrimType
  = IntType IntType
  | FloatType FloatType
  | Bool
  | -- | An informationless type - An array of this type takes up no space.
    Unit
  deriving (Eq, Ord, Show)

-- | Enumerates the primitive types in the order integers (by size),
-- floats (by size), 'Bool', 'Unit'.
instance Enum PrimType where
  toEnum 0 = IntType Int8
  toEnum 1 = IntType Int16
  toEnum 2 = IntType Int32
  toEnum 3 = IntType Int64
  toEnum 4 = FloatType Float16
  toEnum 5 = FloatType Float32
  toEnum 6 = FloatType Float64
  toEnum 7 = Bool
  toEnum _ = Unit

  fromEnum (IntType Int8) = 0
  fromEnum (IntType Int16) = 1
  fromEnum (IntType Int32) = 2
  fromEnum (IntType Int64) = 3
  fromEnum (FloatType Float16) = 4
  fromEnum (FloatType Float32) = 5
  fromEnum (FloatType Float64) = 6
  fromEnum Bool = 7
  fromEnum Unit = 8

  -- The class defaults enumerate @[fromEnum x ..]@ with no upper bound,
  -- and since 'toEnum' maps every index past 7 to 'Unit', e.g.
  -- @[Bool ..]@ would be an infinite list ending in 'Unit's.  Stop at
  -- the 'Bounded' limits instead, as derived instances do.
  enumFrom x = enumFromTo x maxBound
  enumFromThen x y = enumFromThenTo x y limit
    where
      limit
        | fromEnum y >= fromEnum x = maxBound
        | otherwise = minBound

instance Bounded PrimType where
  minBound = IntType Int8
  maxBound = Unit

instance Pretty PrimType where
  pretty (IntType t) = pretty t
  pretty (FloatType t) = pretty t
  pretty Bool = "bool"
  pretty Unit = "unit"

-- | A list of all primitive types.
allPrimTypes :: [PrimType]
allPrimTypes =
  map IntType allIntTypes
    ++ map FloatType allFloatTypes
    ++ [Bool, Unit]

-- | An integer value.
data IntValue
  = Int8Value !Int8
  | Int16Value !Int16
  | Int32Value !Int32
  | Int64Value !Int64
  deriving (Eq, Ord, Show)

instance Pretty IntValue where
  pretty (Int8Value v) = pretty $ show v ++ "i8"
  pretty (Int16Value v) = pretty $ show v ++ "i16"
  pretty (Int32Value v) = pretty $ show v ++ "i32"
  pretty (Int64Value v) = pretty $ show v ++ "i64"

-- | Create an t'IntValue' from a type and an 'Integer'.
intValue :: (Integral int) => IntType -> int -> IntValue
intValue Int8 = Int8Value . fromIntegral
intValue Int16 = Int16Value . fromIntegral
intValue Int32 = Int32Value . fromIntegral
intValue Int64 = Int64Value . fromIntegral

-- | The type of an integer value.
intValueType :: IntValue -> IntType
intValueType Int8Value {} = Int8
intValueType Int16Value {} = Int16
intValueType Int32Value {} = Int32
intValueType Int64Value {} = Int64

-- | Convert an t'IntValue' to any 'Integral' type.
valueIntegral :: (Integral int) => IntValue -> int
valueIntegral iv =
  case iv of
    Int8Value v -> fromIntegral v
    Int16Value v -> fromIntegral v
    Int32Value v -> fromIntegral v
    Int64Value v -> fromIntegral v

-- | A floating-point value of one of the three supported widths.
data FloatValue
  = Float16Value !Half
  | Float32Value !Float
  | Float64Value !Double
  deriving (Show)

-- Values of different widths are never equal.  Within one width, two
-- NaNs are considered equal to each other, otherwise the underlying
-- '==' decides.
instance Eq FloatValue where
  a == b =
    case (a, b) of
      (Float16Value x, Float16Value y) -> sameFloat x y
      (Float32Value x, Float32Value y) -> sameFloat x y
      (Float64Value x, Float64Value y) -> sameFloat x y
      _ -> False
    where
      sameFloat :: (RealFloat t) => t -> t -> Bool
      sameFloat x y = (isNaN x && isNaN y) || x == y

-- Hand-written rather than derived: within one width the underlying
-- '<=' and '<' are used directly, and values of different widths are
-- ordered by their width (f16 < f32 < f64).
instance Ord FloatValue where
  Float16Value x <= Float16Value y = x <= y
  Float32Value x <= Float32Value y = x <= y
  Float64Value x <= Float64Value y = x <= y
  a <= b = floatValueType a <= floatValueType b

  Float16Value x < Float16Value y = x < y
  Float32Value x < Float32Value y = x < y
  Float64Value x < Float64Value y = x < y
  a < b = floatValueType a < floatValueType b

  a > b = b < a
  a >= b = b <= a

instance Pretty FloatValue where
  pretty fv =
    case fv of
      Float16Value v -> literal "f16" v
      Float32Value v -> literal "f32" v
      Float64Value v -> literal "f64" v
    where
      -- Render a value as Futhark source syntax: @<suffix>.inf@,
      -- @-<suffix>.inf@, @<suffix>.nan@, or the shown number followed by
      -- its type suffix.
      literal :: (RealFloat t, Show t) => String -> t -> Doc ann
      literal suffix v
        | isInfinite v = pretty $ (if v < 0 then "-" else "") ++ suffix ++ ".inf"
        | isNaN v = pretty $ suffix ++ ".nan"
        | otherwise = pretty $ show v ++ suffix

-- | Create a t'FloatValue' of the given type from any 'Real' number,
-- converting through its exact 'Rational' representation.
floatValue :: (Real num) => FloatType -> num -> FloatValue
floatValue ft x =
  case ft of
    Float16 -> Float16Value $ fromRational r
    Float32 -> Float32Value $ fromRational r
    Float64 -> Float64Value $ fromRational r
  where
    r = toRational x

-- | The type of a floating-point value.
floatValueType :: FloatValue -> FloatType
floatValueType = \case
  Float16Value {} -> Float16
  Float32Value {} -> Float32
  Float64Value {} -> Float64

-- | Non-array values.
data PrimValue
  = IntValue !IntValue
  | FloatValue !FloatValue
  | BoolValue !Bool
  | -- | The only value of type 'Unit'.
    UnitValue
  deriving (Eq, Ord, Show)

instance Pretty PrimValue where
  pretty = \case
    IntValue v -> pretty v
    FloatValue v -> pretty v
    BoolValue b -> if b then "true" else "false"
    UnitValue -> "()"

-- | The type of a basic value.
primValueType :: PrimValue -> PrimType
primValueType = \case
  IntValue v -> IntType (intValueType v)
  FloatValue v -> FloatType (floatValueType v)
  BoolValue {} -> Bool
  UnitValue -> Unit

-- | A "blank" value of the given primitive type: zero for numbers,
-- 'False' for booleans, and the unit value for 'Unit'.  Don't depend
-- on this particular value, but use it for e.g. creating arrays to be
-- populated by do-loops.
blankPrimValue :: PrimType -> PrimValue
blankPrimValue = \case
  IntType it -> IntValue $ intValue it (0 :: Int)
  FloatType ft -> FloatValue $ floatValue ft (0 :: Double)
  Bool -> BoolValue False
  Unit -> UnitValue

-- | A "one" value of the given primitive type - this is one, or
-- whatever is closest to it.
onePrimValue :: PrimType -> PrimValue
onePrimValue (IntType Int8) = IntValue $ Int8Value 1
onePrimValue (IntType Int16) = IntValue $ Int16Value 1
onePrimValue (IntType Int32) = IntValue $ Int32Value 1
onePrimValue (IntType Int64) = IntValue $ Int64Value 1
onePrimValue (FloatType Float16) = FloatValue $ Float16Value 1.0
onePrimValue (FloatType Float32) = FloatValue $ Float32Value 1.0
onePrimValue (FloatType Float64) = FloatValue $ Float64Value 1.0
onePrimValue Bool = BoolValue True
onePrimValue Unit = UnitValue

-- | Various unary operators.  It is a bit ad-hoc what is a unary
-- operator and what is a built-in function.  Perhaps these should all
-- go away eventually.
data UnOp
  = -- | Flip sign.  Logical negation for booleans.
    Neg PrimType
  | -- | E.g., @~(~1) = 1@.
    Complement IntType
  | -- | @abs(-2) = 2@.
    Abs IntType
  | -- | @fabs(-2.0) = 2.0@.
    FAbs FloatType
  | -- | Signed sign function: @ssignum(-2)@ = -1.
    SSignum IntType
  | -- | Unsigned sign function: @usignum(2)@ = 1.
    USignum IntType
  | -- | Floating-point sign function.
    FSignum FloatType
  deriving (Eq, Ord, Show)

-- | What to do in case of arithmetic overflow.  Futhark's semantics
-- are that overflow does wraparound, but for generated code (like
-- address arithmetic), it can be beneficial for overflow to be
-- undefined behaviour, as it allows better optimisation of things
-- such as GPU kernels.
--
-- Note that all values of this type are considered equal for 'Eq' and
-- 'Ord'.
data Overflow = OverflowWrap | OverflowUndef
  deriving (Show)

instance Eq Overflow where
  _ == _ = True

instance Ord Overflow where
  _ `compare` _ = EQ

-- | Whether something is safe or unsafe (mostly function calls, and
-- in the context of whether operations are dynamically checked).
-- When we inline an 'Unsafe' function, we remove all safety checks in
-- its body.  The 'Ord' instance picks 'Unsafe' as being less than
-- 'Safe'.
--
-- For operations like integer division, a safe division will not
-- explode the computer in case of division by zero, but instead
-- return some unspecified value.  This always involves a run-time
-- check, so generally the unsafe variant is what the compiler will
-- insert, but guarded by an explicit assertion elsewhere.  Safe
-- operations are useful when the optimiser wants to move e.g. a
-- division to a location where the divisor may be zero, but where the
-- result will only be used when it is non-zero (so it doesn't matter
-- what result is provided with a zero divisor, as long as the program
-- keeps running).
data Safety = Unsafe | Safe deriving (Eq, Ord, Show)

-- | Binary operators.  These correspond closely to the binary operators in
-- LLVM.  Most are parametrised by their expected input and output
-- types.
data BinOp
  = -- | Integer addition.
    Add IntType Overflow
  | -- | Floating-point addition.
    FAdd FloatType
  | -- | Integer subtraction.
    Sub IntType Overflow
  | -- | Floating-point subtraction.
    FSub FloatType
  | -- | Integer multiplication.
    Mul IntType Overflow
  | -- | Floating-point multiplication.
    FMul FloatType
  | -- | Unsigned integer division.  Rounds towards negative
    -- infinity.  Note: this is different from LLVM.
    UDiv IntType Safety
  | -- | Unsigned integer division.  Rounds towards positive
    -- infinity.
    UDivUp IntType Safety
  | -- | Signed integer division.  Rounds towards negative
    -- infinity.  Note: this is different from LLVM.
    SDiv IntType Safety
  | -- | Signed integer division.  Rounds towards positive
    -- infinity.
    SDivUp IntType Safety
  | -- | Floating-point division.
    FDiv FloatType
  | -- | Floating-point modulus.
    FMod FloatType
  | -- | Unsigned integer modulus; the counterpart to 'UDiv'.
    UMod IntType Safety
  | -- | Signed integer modulus; the counterpart to 'SDiv'.
    SMod IntType Safety
  | -- | Signed integer division.  Rounds towards zero.  This
    -- corresponds to the @sdiv@ instruction in LLVM and
    -- integer division in C.
    SQuot IntType Safety
  | -- | Signed integer remainder; the counterpart to 'SQuot'.  This
    -- corresponds to the @srem@ instruction in LLVM and the
    -- integer remainder operator (@%@) in C.
    SRem IntType Safety
  | -- | Returns the smallest of two signed integers.
    SMin IntType
  | -- | Returns the smallest of two unsigned integers.
    UMin IntType
  | -- | Returns the smallest of two floating-point numbers.
    FMin FloatType
  | -- | Returns the greatest of two signed integers.
    SMax IntType
  | -- | Returns the greatest of two unsigned integers.
    UMax IntType
  | -- | Returns the greatest of two floating-point numbers.
    FMax FloatType
  | -- | Left-shift.
    Shl IntType
  | -- | Logical right-shift, zero-extended.
    LShr IntType
  | -- | Arithmetic right-shift, sign-extended.
    AShr IntType
  | -- | Bitwise and.
    And IntType
  | -- | Bitwise or.
    Or IntType
  | -- | Bitwise exclusive-or.
    Xor IntType
  | -- | Integer exponentiation.
    Pow IntType
  | -- | Floating-point exponentiation.
    FPow FloatType
  | -- | Boolean and - not short-circuiting.
    LogAnd
  | -- | Boolean or - not short-circuiting.
    LogOr
  deriving (Eq, Ord, Show)

-- | Comparison operators are like 'BinOp's, but they always return a
-- boolean value.  The somewhat ugly constructor names are straight
-- out of LLVM.
data CmpOp
  = -- | All types equality.
    CmpEq PrimType
  | -- | Unsigned less than.
    CmpUlt IntType
  | -- | Unsigned less than or equal.
    CmpUle IntType
  | -- | Signed less than.
    CmpSlt IntType
  | -- | Signed less than or equal.
    CmpSle IntType
  | -- Comparison operators for floating-point values.  TODO: extend
    -- this to handle NaNs and such, like the LLVM fcmp instruction.

    -- | Floating-point less than.
    FCmpLt FloatType
  | -- | Floating-point less than or equal.
    FCmpLe FloatType
  | -- Boolean comparison.

    -- | Boolean less than.
    CmpLlt
  | -- | Boolean less than or equal.
    CmpLle
  deriving (Eq, Ord, Show)

-- | Conversion operators try to generalise the @from t0 x to t1@
-- instructions from LLVM.
data ConvOp
  = -- | Zero-extend the former integer type to the latter.
    -- If the new type is smaller, the result is a
    -- truncation.
    ZExt IntType IntType
  | -- | Sign-extend the former integer type to the latter.
    -- If the new type is smaller, the result is a
    -- truncation.
    SExt IntType IntType
  | -- | Convert value of the former floating-point type to
    -- the latter.  If the new type is smaller, the value may
    -- lose precision or range.
    FPConv FloatType FloatType
  | -- | Convert a floating-point value to the nearest
    -- unsigned integer (rounding towards zero).
    FPToUI FloatType IntType
  | -- | Convert a floating-point value to the nearest
    -- signed integer (rounding towards zero).
    FPToSI FloatType IntType
  | -- | Convert an unsigned integer to a floating-point value.
    UIToFP IntType FloatType
  | -- | Convert a signed integer to a floating-point value.
    SIToFP IntType FloatType
  | -- | Convert an integer to a boolean value.  Zero
    -- becomes false; anything else is true.
    IToB IntType
  | -- | Convert a boolean to an integer.  True is converted
    -- to 1 and False to 0.
    BToI IntType
  | -- | Convert a float to a boolean value.  Zero becomes false;
    -- anything else is true.
    FToB FloatType
  | -- | Convert a boolean to a float.  True is converted
    -- to 1 and False to 0.
    BToF FloatType
  deriving (Eq, Ord, Show)

-- | A list of all unary operators for all types.  Uses the same
-- per-type lists as 'allBinOps', so the enumeration order matches
-- 'allPrimTypes', 'allIntTypes' and 'allFloatTypes'.
allUnOps :: [UnOp]
allUnOps =
  map Neg allPrimTypes
    ++ map Complement allIntTypes
    ++ map Abs allIntTypes
    ++ map FAbs allFloatTypes
    ++ map SSignum allIntTypes
    ++ map USignum allIntTypes
    ++ map FSignum allFloatTypes

-- | A list of all binary operators for all types.
allBinOps :: [BinOp]
allBinOps =
  concat
    [ overflowing Add,
      FAdd <$> allFloatTypes,
      overflowing Sub,
      FSub <$> allFloatTypes,
      overflowing Mul,
      FMul <$> allFloatTypes,
      checked UDiv,
      checked UDivUp,
      checked SDiv,
      checked SDivUp,
      FDiv <$> allFloatTypes,
      FMod <$> allFloatTypes,
      checked UMod,
      checked SMod,
      checked SQuot,
      checked SRem,
      SMin <$> allIntTypes,
      UMin <$> allIntTypes,
      FMin <$> allFloatTypes,
      SMax <$> allIntTypes,
      UMax <$> allIntTypes,
      FMax <$> allFloatTypes,
      Shl <$> allIntTypes,
      LShr <$> allIntTypes,
      AShr <$> allIntTypes,
      And <$> allIntTypes,
      Or <$> allIntTypes,
      Xor <$> allIntTypes,
      Pow <$> allIntTypes,
      FPow <$> allFloatTypes,
      [LogAnd, LogOr]
    ]
  where
    -- Every integer type paired with every overflow mode (type-major).
    overflowing op = [op t o | t <- allIntTypes, o <- [OverflowWrap, OverflowUndef]]
    -- Every integer type paired with every safety level (type-major).
    checked op = [op t s | t <- allIntTypes, s <- [Unsafe, Safe]]

-- | A list of all comparison operators for all types.
allCmpOps :: [CmpOp]
allCmpOps =
  map CmpEq allPrimTypes
    ++ [op t | op <- [CmpUlt, CmpUle, CmpSlt, CmpSle], t <- allIntTypes]
    ++ [op t | op <- [FCmpLt, FCmpLe], t <- allFloatTypes]
    ++ [CmpLlt, CmpLle]

-- | A list of all conversion operators for all types.
allConvOps :: [ConvOp]
allConvOps =
  concat
    [ [ZExt a b | a <- allIntTypes, b <- allIntTypes],
      [SExt a b | a <- allIntTypes, b <- allIntTypes],
      [FPConv a b | a <- allFloatTypes, b <- allFloatTypes],
      [FPToUI a b | a <- allFloatTypes, b <- allIntTypes],
      [FPToSI a b | a <- allFloatTypes, b <- allIntTypes],
      [UIToFP a b | a <- allIntTypes, b <- allFloatTypes],
      [SIToFP a b | a <- allIntTypes, b <- allFloatTypes],
      map IToB allIntTypes,
      map BToI allIntTypes,
      map FToB allFloatTypes,
      map BToF allFloatTypes
    ]

-- | Apply an 'UnOp' to an operand.  Returns 'Nothing' if the
-- application is mistyped.
doUnOp :: UnOp -> PrimValue -> Maybe PrimValue
doUnOp (Neg _) (BoolValue b) = Just $ BoolValue $ not b
doUnOp (Neg _) (FloatValue v) = Just $ FloatValue $ doFNeg v
doUnOp (Neg _) (IntValue v) = Just $ IntValue $ doIntNeg v
doUnOp Complement {} (IntValue v) = Just $ IntValue $ doComplement v
doUnOp Abs {} (IntValue v) = Just $ IntValue $ doAbs v
doUnOp FAbs {} (FloatValue v) = Just $ FloatValue $ doFAbs v
doUnOp SSignum {} (IntValue v) = Just $ IntValue $ doSSignum v
doUnOp USignum {} (IntValue v) = Just $ IntValue $ doUSignum v
doUnOp FSignum {} (FloatValue v) = Just $ FloatValue $ doFSignum v
doUnOp _ _ = Nothing

-- | Negate a floating-point value, keeping its width.
doFNeg :: FloatValue -> FloatValue
doFNeg (Float16Value x) = Float16Value $ negate x
doFNeg (Float32Value x) = Float32Value $ negate x
doFNeg (Float64Value x) = Float64Value $ negate x

-- | Negate an integer value, keeping its width (fixed-width arithmetic,
-- so the most negative value negates to itself).
doIntNeg :: IntValue -> IntValue
doIntNeg (Int8Value x) = Int8Value $ negate x
doIntNeg (Int16Value x) = Int16Value $ negate x
doIntNeg (Int32Value x) = Int32Value $ negate x
doIntNeg (Int64Value x) = Int64Value $ negate x

-- | E.g., @~(~1) = 1@.
doComplement :: IntValue -> IntValue
doComplement v = intValue (intValueType v) $ complement $ intToInt64 v

-- | @abs(-2) = 2@.
doAbs :: IntValue -> IntValue
doAbs v = intValue (intValueType v) $ abs $ intToInt64 v

-- | @fabs(-2.0) = 2.0@.
doFAbs :: FloatValue -> FloatValue
doFAbs (Float16Value x) = Float16Value $ abs x
doFAbs (Float32Value x) = Float32Value $ abs x
doFAbs (Float64Value x) = Float64Value $ abs x

-- | @ssignum(-2)@ = -1.
doSSignum :: IntValue -> IntValue
doSSignum v = intValue (intValueType v) $ signum $ intToInt64 v

-- | Unsigned sign function.  The operand is interpreted as unsigned
-- (via 'Word64'), so the result is 0 for zero and 1 for anything
-- else; e.g. @usignum(2)@ = 1 and @usignum(-2)@ = 1.
doUSignum :: IntValue -> IntValue
doUSignum v = intValue (intValueType v) $ signum $ intToWord64 v

-- | @fsignum(-2.0)@ = -1.0.
doFSignum :: FloatValue -> FloatValue
doFSignum (Float16Value v) = Float16Value $ signum v
doFSignum (Float32Value v) = Float32Value $ signum v
doFSignum (Float64Value v) = Float64Value $ signum v

-- | Apply a 'BinOp' to two operands.  Returns 'Nothing' if the
-- application is mistyped, or outside the domain (e.g. division by
-- zero).
doBinOp :: BinOp -> PrimValue -> PrimValue -> Maybe PrimValue
doBinOp Add {} = doIntBinOp doAdd
doBinOp FAdd {} = doFloatBinOp (+) (+) (+)
doBinOp Sub {} = doIntBinOp doSub
doBinOp FSub {} = doFloatBinOp (-) (-) (-)
doBinOp Mul {} = doIntBinOp doMul
doBinOp FMul {} = doFloatBinOp (*) (*) (*)
doBinOp UDiv {} = doRiskyIntBinOp doUDiv
doBinOp UDivUp {} = doRiskyIntBinOp doUDivUp
doBinOp SDiv {} = doRiskyIntBinOp doSDiv
doBinOp SDivUp {} = doRiskyIntBinOp doSDivUp
doBinOp FDiv {} = doFloatBinOp (/) (/) (/)
doBinOp FMod {} = doFloatBinOp mod' mod' mod'
doBinOp UMod {} = doRiskyIntBinOp doUMod
doBinOp SMod {} = doRiskyIntBinOp doSMod
doBinOp SQuot {} = doRiskyIntBinOp doSQuot
doBinOp SRem {} = doRiskyIntBinOp doSRem
doBinOp SMin {} = doIntBinOp doSMin
doBinOp UMin {} = doIntBinOp doUMin
doBinOp FMin {} = doFloatBinOp fmin fmin fmin
  where
    -- A NaN operand is ignored in favour of the other one.
    fmin x y
      | isNaN x = y
      | isNaN y = x
      | otherwise = min x y
doBinOp SMax {} = doIntBinOp doSMax
doBinOp UMax {} = doIntBinOp doUMax
doBinOp FMax {} = doFloatBinOp fmax fmax fmax
  where
    -- A NaN operand is ignored in favour of the other one.
    fmax x y
      | isNaN x = y
      | isNaN y = x
      | otherwise = max x y
doBinOp Shl {} = doIntBinOp doShl
doBinOp LShr {} = doIntBinOp doLShr
doBinOp AShr {} = doIntBinOp doAShr
doBinOp And {} = doIntBinOp doAnd
doBinOp Or {} = doIntBinOp doOr
doBinOp Xor {} = doIntBinOp doXor
doBinOp Pow {} = doRiskyIntBinOp doPow
doBinOp FPow {} = doFloatBinOp (**) (**) (**)
doBinOp LogAnd {} = doBoolBinOp (&&)
doBinOp LogOr {} = doBoolBinOp (||)

-- | Lift a total integer operation to 'PrimValue's; 'Nothing' unless
-- both operands are integers.
doIntBinOp ::
  (IntValue -> IntValue -> IntValue) ->
  PrimValue ->
  PrimValue ->
  Maybe PrimValue
doIntBinOp f (IntValue v1) (IntValue v2) =
  Just $ IntValue $ f v1 v2
doIntBinOp _ _ _ = Nothing

-- | Lift a partial integer operation to 'PrimValue's; 'Nothing' if the
-- operands are not integers or the operation itself fails.
doRiskyIntBinOp ::
  (IntValue -> IntValue -> Maybe IntValue) ->
  PrimValue ->
  PrimValue ->
  Maybe PrimValue
doRiskyIntBinOp f (IntValue v1) (IntValue v2) =
  IntValue <$> f v1 v2
doRiskyIntBinOp _ _ _ = Nothing

-- | Lift one floating-point operation per width to 'PrimValue's;
-- 'Nothing' unless both operands are floats of the same width.
doFloatBinOp ::
  (Half -> Half -> Half) ->
  (Float -> Float -> Float) ->
  (Double -> Double -> Double) ->
  PrimValue ->
  PrimValue ->
  Maybe PrimValue
doFloatBinOp f16 _ _ (FloatValue (Float16Value v1)) (FloatValue (Float16Value v2)) =
  Just $ FloatValue $ Float16Value $ f16 v1 v2
doFloatBinOp _ f32 _ (FloatValue (Float32Value v1)) (FloatValue (Float32Value v2)) =
  Just $ FloatValue $ Float32Value $ f32 v1 v2
doFloatBinOp _ _ f64 (FloatValue (Float64Value v1)) (FloatValue (Float64Value v2)) =
  Just $ FloatValue $ Float64Value $ f64 v1 v2
doFloatBinOp _ _ _ _ _ = Nothing

-- | Lift a boolean operation to 'PrimValue's; 'Nothing' unless both
-- operands are booleans.
doBoolBinOp ::
  (Bool -> Bool -> Bool) ->
  PrimValue ->
  PrimValue ->
  Maybe PrimValue
doBoolBinOp f (BoolValue v1) (BoolValue v2) =
  Just $ BoolValue $ f v1 v2
doBoolBinOp _ _ _ = Nothing

-- | Integer addition.
doAdd :: IntValue -> IntValue -> IntValue
doAdd v1 v2 = intValue (intValueType v1) $ intToInt64 v1 + intToInt64 v2

-- | Integer subtraction.
doSub :: IntValue -> IntValue -> IntValue
doSub v1 v2 = intValue (intValueType v1) $ intToInt64 v1 - intToInt64 v2

-- | Integer multiplication.
doMul :: IntValue -> IntValue -> IntValue
doMul v1 v2 = intValue (intValueType v1) $ intToInt64 v1 * intToInt64 v2

-- | Unsigned integer division.  Rounds towards negative infinity.
-- Note: this is different from LLVM.
doUDiv :: IntValue -> IntValue -> Maybe IntValue
doUDiv v1 v2
  | zeroIshInt v2 = Nothing
  | otherwise =
      Just . intValue (intValueType v1) $
        intToWord64 v1 `div` intToWord64 v2

-- | Unsigned integer division.  Rounds towards positive infinity.
doUDivUp :: IntValue -> IntValue -> Maybe IntValue
doUDivUp v1 v2
  | zeroIshInt v2 = Nothing
  | otherwise =
      Just . intValue (intValueType v1) $
        (intToWord64 v1 + intToWord64 v2 - 1) `div` intToWord64 v2

-- | Signed integer division.  Rounds towards negative infinity.
-- Note: this is different from LLVM.
doSDiv :: IntValue -> IntValue -> Maybe IntValue
doSDiv v1 v2
  | zeroIshInt v2 = Nothing
  | otherwise =
      Just $ intValue (intValueType v1) $
        intToInt64 v1 `div` intToInt64 v2

-- | Signed integer division.  Rounds towards positive infinity.
--
-- NOTE(review): the @x + y - 1@ adjustment only rounds up for a
-- positive divisor; confirm this matches the code generators'
-- definition before relying on it for negative divisors.
doSDivUp :: IntValue -> IntValue -> Maybe IntValue
doSDivUp v1 v2
  | zeroIshInt v2 = Nothing
  | otherwise =
      Just . intValue (intValueType v1) $
        (intToInt64 v1 + intToInt64 v2 - 1) `div` intToInt64 v2

-- | Unsigned integer modulus; the counterpart to 'UDiv'.
doUMod :: IntValue -> IntValue -> Maybe IntValue
doUMod v1 v2
  | zeroIshInt v2 = Nothing
  | otherwise = Just $ intValue (intValueType v1) $ intToWord64 v1 `mod` intToWord64 v2

-- | Signed integer modulus; the counterpart to 'SDiv'.
doSMod :: IntValue -> IntValue -> Maybe IntValue
doSMod v1 v2
  | zeroIshInt v2 = Nothing
  | otherwise = Just $ intValue (intValueType v1) $ intToInt64 v1 `mod` intToInt64 v2

-- | Signed integer division.  Rounds towards zero.
-- This corresponds to the @sdiv@ instruction in LLVM.
doSQuot :: IntValue -> IntValue -> Maybe IntValue
doSQuot v1 v2
  | zeroIshInt v2 = Nothing
  | otherwise = Just $ intValue (intValueType v1) $ intToInt64 v1 `quot` intToInt64 v2

-- | Signed integer remainder; the counterpart to 'SQuot' (computed
-- with 'rem', which pairs with truncating division).
-- This corresponds to the @srem@ instruction in LLVM.
doSRem :: IntValue -> IntValue -> Maybe IntValue
doSRem v1 v2
  | zeroIshInt v2 = Nothing
  | otherwise = Just $ intValue (intValueType v1) $ intToInt64 v1 `rem` intToInt64 v2

-- | Minimum of two signed integers.
doSMin :: IntValue -> IntValue -> IntValue
doSMin v1 v2 = intValue (intValueType v1) $ intToInt64 v1 `min` intToInt64 v2

-- | Minimum of two unsigned integers.
doUMin :: IntValue -> IntValue -> IntValue
doUMin v1 v2 = intValue (intValueType v1) $ intToWord64 v1 `min` intToWord64 v2

-- | Maximum of two signed integers.
doSMax :: IntValue -> IntValue -> IntValue
doSMax v1 v2 = intValue (intValueType v1) $ intToInt64 v1 `max` intToInt64 v2

-- | Maximum of two unsigned integers.
doUMax :: IntValue -> IntValue -> IntValue
doUMax v1 v2 = intValue (intValueType v1) $ intToWord64 v1 `max` intToWord64 v2

-- | Left-shift.
doShl :: IntValue -> IntValue -> IntValue
doShl x amount =
  intValue (intValueType x) $ shift (intToInt64 x) (intToInt amount)

-- | Logical right-shift, zero-extended.  The operand is shifted in its
-- 'Word64' form; a right-shift is a 'shift' by the negated amount.
doLShr :: IntValue -> IntValue -> IntValue
doLShr x amount = intValue (intValueType x) shifted
  where
    shifted = shift (intToWord64 x) (negate (intToInt amount))

-- | Arithmetic right-shift, sign-extended.  The operand is shifted in
-- its 'Int64' form; a right-shift is a 'shift' by the negated amount.
doAShr :: IntValue -> IntValue -> IntValue
doAShr x amount = intValue (intValueType x) shifted
  where
    shifted = shift (intToInt64 x) (negate (intToInt amount))

-- | Bitwise and, computed on the operands' 'Word64' forms.
doAnd :: IntValue -> IntValue -> IntValue
doAnd x y = intValue (intValueType x) (a .&. b)
  where
    a = intToWord64 x
    b = intToWord64 y

-- | Bitwise or, computed on the operands' 'Word64' forms.
doOr :: IntValue -> IntValue -> IntValue
doOr x y = intValue (intValueType x) (a .|. b)
  where
    a = intToWord64 x
    b = intToWord64 y

-- | Bitwise exclusive-or, computed on the operands' 'Word64' forms.
doXor :: IntValue -> IntValue -> IntValue
doXor x y = intValue (intValueType x) (a `xor` b)
  where
    a = intToWord64 x
    b = intToWord64 y

-- | Signed integer exponentiation.  Returns 'Nothing' for a negative
-- exponent.
doPow :: IntValue -> IntValue -> Maybe IntValue
doPow base ex =
  if negativeIshInt ex
    then Nothing
    else Just . intValue (intValueType base) $ intToInt64 base ^ intToInt64 ex

-- | Apply a 'ConvOp' to an operand.  Returns 'Nothing' if the
-- application is mistyped.
doConvOp :: ConvOp -> PrimValue -> Maybe PrimValue
doConvOp (ZExt _ to) (IntValue v) = Just $ IntValue $ doZExt v to
doConvOp (SExt _ to) (IntValue v) = Just $ IntValue $ doSExt v to
doConvOp (FPConv _ to) (FloatValue v) = Just $ FloatValue $ doFPConv v to
doConvOp (FPToUI _ to) (FloatValue v) = Just $ IntValue $ doFPToUI v to
doConvOp (FPToSI _ to) (FloatValue v) = Just $ IntValue $ doFPToSI v to
doConvOp (UIToFP _ to) (IntValue v) = Just $ FloatValue $ doUIToFP v to
doConvOp (SIToFP _ to) (IntValue v) = Just $ FloatValue $ doSIToFP v to
doConvOp (IToB _) (IntValue v) = Just $ BoolValue $ intToInt64 v /= 0
doConvOp (BToI to) (BoolValue v) = Just $ IntValue $ intValue to (if v then 1 else 0 :: Int)
doConvOp (FToB _) (FloatValue v) = Just $ BoolValue $ floatToDouble v /= 0
doConvOp (BToF to) (BoolValue v) = Just $ FloatValue $ floatValue to (if v then 1 else 0 :: Double)
doConvOp _ _ = Nothing

-- | Turn the conversion the other way around.  Note that most
-- conversions are lossy, so there is no guarantee the value will
-- round-trip.  Signedness is preserved: the unsigned conversions
-- 'FPToUI' and 'UIToFP' flip into each other, as do the signed
-- 'FPToSI' and 'SIToFP'.
flipConvOp :: ConvOp -> ConvOp
flipConvOp (ZExt from to) = ZExt to from
flipConvOp (SExt from to) = SExt to from
flipConvOp (FPConv from to) = FPConv to from
flipConvOp (FPToUI from to) = UIToFP to from
flipConvOp (FPToSI from to) = SIToFP to from
flipConvOp (UIToFP from to) = FPToUI to from
flipConvOp (SIToFP from to) = FPToSI to from
flipConvOp (IToB from) = BToI from
flipConvOp (BToI to) = IToB to
flipConvOp (FToB from) = BToF from
flipConvOp (BToF to) = FToB to

-- | Zero-extend the given integer value to the size of the given
-- type.  If the type is smaller than the given value, the result is a
-- truncation.
doZExt :: IntValue -> IntType -> IntValue
doZExt (Int8Value x) t = intValue t $ toInteger (fromIntegral x :: Word8)
doZExt (Int16Value x) t = intValue t $ toInteger (fromIntegral x :: Word16)
doZExt (Int32Value x) t = intValue t $ toInteger (fromIntegral x :: Word32)
doZExt (Int64Value x) t = intValue t $ toInteger (fromIntegral x :: Word64)

-- | Sign-extend the given integer value to the size of the given
-- type.  If the type is smaller than the given value, the result is a
-- truncation.
doSExt :: IntValue -> IntType -> IntValue
doSExt (Int8Value x) t = intValue t $ toInteger x
doSExt (Int16Value x) t = intValue t $ toInteger x
doSExt (Int32Value x) t = intValue t $ toInteger x
doSExt (Int64Value x) t = intValue t $ toInteger x

-- | Convert the former floating-point type to the latter.
doFPConv :: FloatValue -> FloatType -> FloatValue
doFPConv v Float16 = Float16Value $ floatToHalf v
doFPConv v Float32 = Float32Value $ floatToFloat v
doFPConv v Float64 = Float64Value $ floatToDouble v

-- | Convert a floating-point value to the nearest
-- unsigned integer (rounding towards zero).
doFPToUI :: FloatValue -> IntType -> IntValue
doFPToUI v t = intValue t (truncate $ floatToDouble v :: Word64)

-- | Convert a floating-point value to the nearest
-- signed integer (rounding towards zero).
--
-- The intermediate is a signed 'Int64', so negative inputs keep their
-- sign before being narrowed to the target type.
doFPToSI :: FloatValue -> IntType -> IntValue
doFPToSI v t = intValue t (truncate $ floatToDouble v :: Int64)

-- | Convert an unsigned integer to a floating-point value.
doUIToFP :: IntValue -> FloatType -> FloatValue
doUIToFP v t = floatValue t $ intToWord64 v

-- | Convert a signed integer to a floating-point value.
doSIToFP :: IntValue -> FloatType -> FloatValue
doSIToFP v t = floatValue t $ intToInt64 v

-- | Apply a 'CmpOp' to an operand.  Returns 'Nothing' if the
-- application is mistyped.
doCmpOp :: CmpOp -> PrimValue -> PrimValue -> Maybe Bool
doCmpOp CmpEq {} v1 v2 = Just $ doCmpEq v1 v2
doCmpOp CmpUlt {} (IntValue v1) (IntValue v2) = Just $ doCmpUlt v1 v2
doCmpOp CmpUle {} (IntValue v1) (IntValue v2) = Just $ doCmpUle v1 v2
doCmpOp CmpSlt {} (IntValue v1) (IntValue v2) = Just $ doCmpSlt v1 v2
doCmpOp CmpSle {} (IntValue v1) (IntValue v2) = Just $ doCmpSle v1 v2
doCmpOp FCmpLt {} (FloatValue v1) (FloatValue v2) = Just $ doFCmpLt v1 v2
doCmpOp FCmpLe {} (FloatValue v1) (FloatValue v2) = Just $ doFCmpLe v1 v2
-- Booleans are ordered with False < True.
doCmpOp CmpLlt {} (BoolValue v1) (BoolValue v2) = Just $ not v1 && v2
doCmpOp CmpLle {} (BoolValue v1) (BoolValue v2) = Just $ not (v1 && not v2)
doCmpOp _ _ _ = Nothing

-- | Compare any two primitive values for exact equality.
--
-- Floating-point operands of the same width are compared with the
-- '==' of the underlying Haskell float type, not the 'Eq' instance of
-- 'PrimValue'.  This is done uniformly for f16, f32 and f64.
doCmpEq :: PrimValue -> PrimValue -> Bool
doCmpEq (FloatValue (Float16Value v1)) (FloatValue (Float16Value v2)) = v1 == v2
doCmpEq (FloatValue (Float32Value v1)) (FloatValue (Float32Value v2)) = v1 == v2
doCmpEq (FloatValue (Float64Value v1)) (FloatValue (Float64Value v2)) = v1 == v2
doCmpEq v1 v2 = v1 == v2

-- | Unsigned less than.
doCmpUlt :: IntValue -> IntValue -> Bool
doCmpUlt v1 v2 = intToWord64 v1 < intToWord64 v2

-- | Unsigned less than or equal.
doCmpUle :: IntValue -> IntValue -> Bool
doCmpUle v1 v2 = intToWord64 v1 <= intToWord64 v2

-- | Signed less than.
doCmpSlt :: IntValue -> IntValue -> Bool
doCmpSlt = (<)

-- | Signed less than or equal.
doCmpSle :: IntValue -> IntValue -> Bool
doCmpSle = (<=)

-- | Floating-point less than.
doFCmpLt :: FloatValue -> FloatValue -> Bool
doFCmpLt = (<)

-- | Floating-point less than or equal.
doFCmpLe :: FloatValue -> FloatValue -> Bool
doFCmpLe = (<=)

-- | Translate an t'IntValue' to 'Word64'.  This is guaranteed to fit.
intToWord64 :: IntValue -> Word64
intToWord64 iv =
  case iv of
    -- Narrow values pass through the unsigned type of the same width
    -- first, so they are zero-extended rather than sign-extended.
    Int8Value v -> fromIntegral (fromIntegral v :: Word8)
    Int16Value v -> fromIntegral (fromIntegral v :: Word16)
    Int32Value v -> fromIntegral (fromIntegral v :: Word32)
    Int64Value v -> fromIntegral v

-- | Translate an t'IntValue' to t'Int64'.  This is guaranteed to fit.
intToInt64 :: IntValue -> Int64
intToInt64 iv =
  case iv of
    Int8Value v -> fromIntegral v
    Int16Value v -> fromIntegral v
    Int32Value v -> fromIntegral v
    Int64Value v -> v

-- | Careful - there is no guarantee this will fit.
intToInt :: IntValue -> Int
intToInt v = fromIntegral (intToInt64 v)

-- | Convert between two floating-point representations.  Infinities
-- and NaN are produced explicitly because they have no 'Rational'
-- representation.  All other values go through 'toRational'.
convertIEEE :: (RealFloat a, RealFloat b) => a -> b
convertIEEE v
  | isInfinite v, v > 0 = 1 / 0
  | isInfinite v, v < 0 = -1 / 0
  | isNaN v = 0 / 0
  | otherwise = fromRational $ toRational v

-- | Convert a t'FloatValue' of any width to a 'Double'.
floatToDouble :: FloatValue -> Double
floatToDouble fv =
  case fv of
    Float16Value v -> convertIEEE v
    Float32Value v -> convertIEEE v
    Float64Value v -> v

-- | Convert a t'FloatValue' of any width to a 'Float'.
floatToFloat :: FloatValue -> Float
floatToFloat fv =
  case fv of
    Float16Value v -> convertIEEE v
    Float32Value v -> v
    Float64Value v -> convertIEEE v

-- | Convert a t'FloatValue' of any width to a 'Half'.
floatToHalf :: FloatValue -> Half
floatToHalf fv =
  case fv of
    Float16Value v -> v
    Float32Value v -> convertIEEE v
    Float64Value v -> convertIEEE v

-- | The result type of a binary operator.
binOpType :: BinOp -> PrimType
binOpType op =
  case op of
    -- Integer arithmetic, min/max, shifts and bitwise operators.
    Add t _ -> IntType t
    Sub t _ -> IntType t
    Mul t _ -> IntType t
    SDiv t _ -> IntType t
    SDivUp t _ -> IntType t
    SMod t _ -> IntType t
    SQuot t _ -> IntType t
    SRem t _ -> IntType t
    UDiv t _ -> IntType t
    UDivUp t _ -> IntType t
    UMod t _ -> IntType t
    SMin t -> IntType t
    UMin t -> IntType t
    SMax t -> IntType t
    UMax t -> IntType t
    Shl t -> IntType t
    LShr t -> IntType t
    AShr t -> IntType t
    And t -> IntType t
    Or t -> IntType t
    Xor t -> IntType t
    Pow t -> IntType t
    -- Floating-point operators.
    FMin t -> FloatType t
    FMax t -> FloatType t
    FPow t -> FloatType t
    FAdd t -> FloatType t
    FSub t -> FloatType t
    FMul t -> FloatType t
    FDiv t -> FloatType t
    FMod t -> FloatType t
    -- Logical operators.
    LogAnd -> Bool
    LogOr -> Bool

-- | The operand types of a comparison operator.
cmpOpType :: CmpOp -> PrimType
cmpOpType op =
  case op of
    CmpEq t -> t
    CmpSlt t -> IntType t
    CmpSle t -> IntType t
    CmpUlt t -> IntType t
    CmpUle t -> IntType t
    FCmpLt t -> FloatType t
    FCmpLe t -> FloatType t
    CmpLlt -> Bool
    CmpLle -> Bool

-- | The operand and result type of a unary operator.
unOpType :: UnOp -> PrimType
unOpType op =
  case op of
    SSignum t -> IntType t
    USignum t -> IntType t
    Neg t -> t
    Complement t -> IntType t
    Abs t -> IntType t
    FAbs t -> FloatType t
    FSignum t -> FloatType t

-- | The input and output types of a conversion operator.
convOpType :: ConvOp -> (PrimType, PrimType)
convOpType (ZExt from to) = (IntType from, IntType to)
convOpType (SExt from to) = (IntType from, IntType to)
convOpType (FPConv from to) = (FloatType from, FloatType to)
convOpType (FPToUI from to) = (FloatType from, IntType to)
convOpType (FPToSI from to) = (FloatType from, IntType to)
convOpType (UIToFP from to) = (IntType from, FloatType to)
convOpType (SIToFP from to) = (IntType from, FloatType to)
convOpType (IToB from) = (IntType from, Bool)
convOpType (BToI to) = (Bool, IntType to)
convOpType (FToB from) = (FloatType from, Bool)
convOpType (BToF to) = (Bool, FloatType to)

-- | The raw 16-bit pattern of a 'Half', read from its 'CUShort'
-- representation.
halfToWord :: Half -> Word16
halfToWord (Half (CUShort x)) = x

-- | Build a 'Half' from a raw 16-bit pattern (inverse of 'halfToWord').
wordToHalf :: Word16 -> Half
wordToHalf = Half . CUShort

-- | Reinterpret the bits of a 'Float' as a 'Word32'.  This works by
-- serialising the float little-endian and reading the bytes back.
floatToWord :: Float -> Word32
floatToWord = G.runGet G.getWord32le . P.runPut . P.putFloatle

-- | Reinterpret a 'Word32' bit pattern as a 'Float' (inverse of
-- 'floatToWord').
wordToFloat :: Word32 -> Float
wordToFloat = G.runGet G.getFloatle . P.runPut . P.putWord32le

-- | Reinterpret the bits of a 'Double' as a 'Word64', via a
-- little-endian serialisation round trip.
doubleToWord :: Double -> Word64
doubleToWord = G.runGet G.getWord64le . P.runPut . P.putDoublele

-- | Reinterpret a 'Word64' bit pattern as a 'Double' (inverse of
-- 'doubleToWord').
wordToDouble :: Word64 -> Double
wordToDouble = G.runGet G.getDoublele . P.runPut . P.putWord64le

-- | @condFun t@ is the name of the ternary conditional function that
-- accepts operands of type @[Bool, t, t]@, and returns either the
-- first or second @t@ based on the truth value of the @Bool@.
condFun :: PrimType -> T.Text
condFun t = "cond_" <> prettyText t

-- | Is this the name of a condition function as per 'condFun', and
-- for which type?
isCondFun :: T.Text -> Maybe PrimType
isCondFun v = L.find (\t -> condFun t == v) allPrimTypes

-- | A mapping from names of primitive functions to their parameter
-- types, their result type, and a function for evaluating them.
--
-- Each evaluation function returns 'Nothing' when the argument list
-- does not match the expected arity or value types.
primFuns ::
  M.Map
    T.Text
    ( [PrimType],
      PrimType,
      [PrimValue] -> Maybe PrimValue
    )
primFuns =
  M.fromList $
    -- Unary floating-point functions.  Some f16 variants are built as
    -- @convFloat . g . convFloat@, i.e. they convert, apply the wider
    -- function @g@, and convert back.
    [ f16 "sqrt16" sqrt,
      f32 "sqrt32" sqrt,
      f64 "sqrt64" sqrt,
      --
      f16 "cbrt16" $ convFloat . cbrtf . convFloat,
      f32 "cbrt32" cbrtf,
      f64 "cbrt64" cbrt,
      --
      f16 "log16" log,
      f32 "log32" log,
      f64 "log64" log,
      --
      f16 "log10_16" (logBase 10),
      f32 "log10_32" (logBase 10),
      f64 "log10_64" (logBase 10),
      --
      f16 "log1p_16" log1p,
      f32 "log1p_32" log1p,
      f64 "log1p_64" log1p,
      --
      f16 "log2_16" (logBase 2),
      f32 "log2_32" (logBase 2),
      f64 "log2_64" (logBase 2),
      --
      f16 "exp16" exp,
      f32 "exp32" exp,
      f64 "exp64" exp,
      --
      f16 "sin16" sin,
      f32 "sin32" sin,
      f64 "sin64" sin,
      --
      f16 "sinh16" sinh,
      f32 "sinh32" sinh,
      f64 "sinh64" sinh,
      --
      f16 "cos16" cos,
      f32 "cos32" cos,
      f64 "cos64" cos,
      --
      f16 "cosh16" cosh,
      f32 "cosh32" cosh,
      f64 "cosh64" cosh,
      --
      f16 "tan16" tan,
      f32 "tan32" tan,
      f64 "tan64" tan,
      --
      f16 "tanh16" tanh,
      f32 "tanh32" tanh,
      f64 "tanh64" tanh,
      --
      f16 "asin16" asin,
      f32 "asin32" asin,
      f64 "asin64" asin,
      --
      f16 "asinh16" asinh,
      f32 "asinh32" asinh,
      f64 "asinh64" asinh,
      --
      f16 "acos16" acos,
      f32 "acos32" acos,
      f64 "acos64" acos,
      --
      f16 "acosh16" acosh,
      f32 "acosh32" acosh,
      f64 "acosh64" acosh,
      --
      f16 "atan16" atan,
      f32 "atan32" atan,
      f64 "atan64" atan,
      --
      f16 "atanh16" atanh,
      f32 "atanh32" atanh,
      f64 "atanh64" atanh,
      --
      f16 "round16" $ convFloat . roundFloat . convFloat,
      f32 "round32" roundFloat,
      f64 "round64" roundDouble,
      --
      f16 "ceil16" $ convFloat . ceilFloat . convFloat,
      f32 "ceil32" ceilFloat,
      f64 "ceil64" ceilDouble,
      --
      f16 "floor16" $ convFloat . floorFloat . convFloat,
      f32 "floor32" floorFloat,
      f64 "floor64" floorDouble,
      --
      f16_2 "nextafter16" (\x y -> convFloat $ nextafterf (convFloat x) (convFloat y)),
      f32_2 "nextafter32" nextafterf,
      f64_2 "nextafter64" nextafter,
      --
      -- ldexp takes a float and an i32 exponent, so it does not fit the
      -- homogeneous helpers and is spelled out.  The f16 variant is
      -- computed directly as x * 2^y in 'Half'.
      ( "ldexp16",
        ( [FloatType Float16, IntType Int32],
          FloatType Float16,
          \case
            [FloatValue (Float16Value x), IntValue (Int32Value y)] ->
              Just $ FloatValue $ Float16Value $ x * (2 ** fromIntegral y)
            _ -> Nothing
        )
      ),
      ( "ldexp32",
        ( [FloatType Float32, IntType Int32],
          FloatType Float32,
          \case
            [FloatValue (Float32Value x), IntValue (Int32Value y)] ->
              Just $ FloatValue $ Float32Value $ ldexpf x $ fromIntegral y
            _ -> Nothing
        )
      ),
      ( "ldexp64",
        ( [FloatType Float64, IntType Int32],
          FloatType Float64,
          \case
            [FloatValue (Float64Value x), IntValue (Int32Value y)] ->
              Just $ FloatValue $ Float64Value $ ldexp x $ fromIntegral y
            _ -> Nothing
        )
      ),
      --
      f16 "gamma16" $ convFloat . tgammaf . convFloat,
      f32 "gamma32" tgammaf,
      f64 "gamma64" tgamma,
      --
      f16 "lgamma16" $ convFloat . lgammaf . convFloat,
      f32 "lgamma32" lgammaf,
      f64 "lgamma64" lgamma,
      --
      --
      f16 "erf16" $ convFloat . erff . convFloat,
      f32 "erf32" erff,
      f64 "erf64" erf,
      --
      f16 "erfc16" $ convFloat . erfcf . convFloat,
      f32 "erfc32" erfcf,
      f64 "erfc64" erfc,
      --
      f16_2 "copysign16" $ \x y -> convFloat (copysign (convFloat x) (convFloat y)),
      f32_2 "copysign32" copysignf,
      f64_2 "copysign64" copysign,
      --
      -- Bit-counting functions.  Whatever the operand width, the result
      -- is always an i32 (see the i8/i16/i32/i64 helpers below).
      i8 "clz8" $ IntValue . Int32Value . fromIntegral . countLeadingZeros,
      i16 "clz16" $ IntValue . Int32Value . fromIntegral . countLeadingZeros,
      i32 "clz32" $ IntValue . Int32Value . fromIntegral . countLeadingZeros,
      i64 "clz64" $ IntValue . Int32Value . fromIntegral . countLeadingZeros,
      i8 "ctz8" $ IntValue . Int32Value . fromIntegral . countTrailingZeros,
      i16 "ctz16" $ IntValue . Int32Value . fromIntegral . countTrailingZeros,
      i32 "ctz32" $ IntValue . Int32Value . fromIntegral . countTrailingZeros,
      i64 "ctz64" $ IntValue . Int32Value . fromIntegral . countTrailingZeros,
      i8 "popc8" $ IntValue . Int32Value . fromIntegral . popCount,
      i16 "popc16" $ IntValue . Int32Value . fromIntegral . popCount,
      i32 "popc32" $ IntValue . Int32Value . fromIntegral . popCount,
      i64 "popc64" $ IntValue . Int32Value . fromIntegral . popCount,
      -- High-half multiply (and multiply-add) helpers, same width in
      -- and out.
      i8_3 "umad_hi8" umad_hi8,
      i16_3 "umad_hi16" umad_hi16,
      i32_3 "umad_hi32" umad_hi32,
      i64_3 "umad_hi64" umad_hi64,
      i8_2 "umul_hi8" umul_hi8,
      i16_2 "umul_hi16" umul_hi16,
      i32_2 "umul_hi32" umul_hi32,
      i64_2 "umul_hi64" umul_hi64,
      i8_3 "smad_hi8" smad_hi8,
      i16_3 "smad_hi16" smad_hi16,
      i32_3 "smad_hi32" smad_hi32,
      i64_3 "smad_hi64" smad_hi64,
      i8_2 "smul_hi8" smul_hi8,
      i16_2 "smul_hi16" smul_hi16,
      i32_2 "smul_hi32" smul_hi32,
      i64_2 "smul_hi64" smul_hi64,
      --
      ( "atan2_16",
        ( [FloatType Float16, FloatType Float16],
          FloatType Float16,
          \case
            [FloatValue (Float16Value x), FloatValue (Float16Value y)] ->
              Just $ FloatValue $ Float16Value $ atan2 x y
            _ -> Nothing
        )
      ),
      ( "atan2_32",
        ( [FloatType Float32, FloatType Float32],
          FloatType Float32,
          \case
            [FloatValue (Float32Value x), FloatValue (Float32Value y)] ->
              Just $ FloatValue $ Float32Value $ atan2 x y
            _ -> Nothing
        )
      ),
      ( "atan2_64",
        ( [FloatType Float64, FloatType Float64],
          FloatType Float64,
          \case
            [FloatValue (Float64Value x), FloatValue (Float64Value y)] ->
              Just $ FloatValue $ Float64Value $ atan2 x y
            _ -> Nothing
        )
      ),
      --
      ( "hypot16",
        ( [FloatType Float16, FloatType Float16],
          FloatType Float16,
          \case
            [FloatValue (Float16Value x), FloatValue (Float16Value y)] ->
              Just $ FloatValue $ Float16Value $ convFloat $ hypotf (convFloat x) (convFloat y)
            _ -> Nothing
        )
      ),
      ( "hypot32",
        ( [FloatType Float32, FloatType Float32],
          FloatType Float32,
          \case
            [FloatValue (Float32Value x), FloatValue (Float32Value y)] ->
              Just $ FloatValue $ Float32Value $ hypotf x y
            _ -> Nothing
        )
      ),
      ( "hypot64",
        ( [FloatType Float64, FloatType Float64],
          FloatType Float64,
          \case
            [FloatValue (Float64Value x), FloatValue (Float64Value y)] ->
              Just $ FloatValue $ Float64Value $ hypot x y
            _ -> Nothing
        )
      ),
      -- Float classification predicates, returning a Bool.
      ( "isinf16",
        ( [FloatType Float16],
          Bool,
          \case
            [FloatValue (Float16Value x)] -> Just $ BoolValue $ isInfinite x
            _ -> Nothing
        )
      ),
      ( "isinf32",
        ( [FloatType Float32],
          Bool,
          \case
            [FloatValue (Float32Value x)] -> Just $ BoolValue $ isInfinite x
            _ -> Nothing
        )
      ),
      ( "isinf64",
        ( [FloatType Float64],
          Bool,
          \case
            [FloatValue (Float64Value x)] -> Just $ BoolValue $ isInfinite x
            _ -> Nothing
        )
      ),
      ( "isnan16",
        ( [FloatType Float16],
          Bool,
          \case
            [FloatValue (Float16Value x)] -> Just $ BoolValue $ isNaN x
            _ -> Nothing
        )
      ),
      ( "isnan32",
        ( [FloatType Float32],
          Bool,
          \case
            [FloatValue (Float32Value x)] -> Just $ BoolValue $ isNaN x
            _ -> Nothing
        )
      ),
      ( "isnan64",
        ( [FloatType Float64],
          Bool,
          \case
            [FloatValue (Float64Value x)] -> Just $ BoolValue $ isNaN x
            _ -> Nothing
        )
      ),
      -- Bit-pattern reinterpretation between floats and same-width
      -- integers, via the word helpers defined above.
      ( "to_bits16",
        ( [FloatType Float16],
          IntType Int16,
          \case
            [FloatValue (Float16Value x)] ->
              Just $ IntValue $ Int16Value $ fromIntegral $ halfToWord x
            _ -> Nothing
        )
      ),
      ( "to_bits32",
        ( [FloatType Float32],
          IntType Int32,
          \case
            [FloatValue (Float32Value x)] ->
              Just $ IntValue $ Int32Value $ fromIntegral $ floatToWord x
            _ -> Nothing
        )
      ),
      ( "to_bits64",
        ( [FloatType Float64],
          IntType Int64,
          \case
            [FloatValue (Float64Value x)] ->
              Just $ IntValue $ Int64Value $ fromIntegral $ doubleToWord x
            _ -> Nothing
        )
      ),
      ( "from_bits16",
        ( [IntType Int16],
          FloatType Float16,
          \case
            [IntValue (Int16Value x)] ->
              Just $ FloatValue $ Float16Value $ wordToHalf $ fromIntegral x
            _ -> Nothing
        )
      ),
      ( "from_bits32",
        ( [IntType Int32],
          FloatType Float32,
          \case
            [IntValue (Int32Value x)] ->
              Just $ FloatValue $ Float32Value $ wordToFloat $ fromIntegral x
            _ -> Nothing
        )
      ),
      ( "from_bits64",
        ( [IntType Int64],
          FloatType Float64,
          \case
            [IntValue (Int64Value x)] ->
              Just $ FloatValue $ Float64Value $ wordToDouble $ fromIntegral x
            _ -> Nothing
        )
      ),
      -- Linear interpolation with the weight t clamped to [0, 1].
      f16_3 "lerp16" (\v0 v1 t -> v0 + (v1 - v0) * max 0 (min 1 t)),
      f32_3 "lerp32" (\v0 v1 t -> v0 + (v1 - v0) * max 0 (min 1 t)),
      f64_3 "lerp64" (\v0 v1 t -> v0 + (v1 - v0) * max 0 (min 1 t)),
      -- Note: fma is evaluated here exactly like mad (a * b + c), i.e.
      -- with two roundings rather than one.
      f16_3 "mad16" (\a b c -> a * b + c),
      f32_3 "mad32" (\a b c -> a * b + c),
      f64_3 "mad64" (\a b c -> a * b + c),
      f16_3 "fma16" (\a b c -> a * b + c),
      f32_3 "fma32" (\a b c -> a * b + c),
      f64_3 "fma64" (\a b c -> a * b + c)
    ]
      -- One conditional function per primitive type (see 'condFun').
      <> [ ( condFun t,
             ( [Bool, t, t],
               t,
               \case
                 [BoolValue b, tv, fv] ->
                   Just $ if b then tv else fv
                 _ -> Nothing
             )
           )
           | t <- allPrimTypes
         ]
  where
    -- Unary integer functions; the result type is always i32.
    i8 s f = (s, ([IntType Int8], IntType Int32, i8PrimFun f))
    i16 s f = (s, ([IntType Int16], IntType Int32, i16PrimFun f))
    i32 s f = (s, ([IntType Int32], IntType Int32, i32PrimFun f))
    i64 s f = (s, ([IntType Int64], IntType Int32, i64PrimFun f))
    -- Unary float functions with the same type in and out.
    f16 s f = (s, ([FloatType Float16], FloatType Float16, f16PrimFun f))
    f32 s f = (s, ([FloatType Float32], FloatType Float32, f32PrimFun f))
    f64 s f = (s, ([FloatType Float64], FloatType Float64, f64PrimFun f))
    -- Homogeneous binary/ternary signatures: every parameter and the
    -- result have type t.
    t_2 t s f = (s, ([t, t], t, f))
    t_3 t s f = (s, ([t, t, t], t, f))
    f16_2 s f = t_2 (FloatType Float16) s (f16PrimFun2 f)
    f32_2 s f = t_2 (FloatType Float32) s (f32PrimFun2 f)
    f64_2 s f = t_2 (FloatType Float64) s (f64PrimFun2 f)
    f16_3 s f = t_3 (FloatType Float16) s (f16PrimFun3 f)
    f32_3 s f = t_3 (FloatType Float32) s (f32PrimFun3 f)
    f64_3 s f = t_3 (FloatType Float64) s (f64PrimFun3 f)
    i8_2 s f = t_2 (IntType Int8) s (i8PrimFun2 f)
    i16_2 s f = t_2 (IntType Int16) s (i16PrimFun2 f)
    i32_2 s f = t_2 (IntType Int32) s (i32PrimFun2 f)
    i64_2 s f = t_2 (IntType Int64) s (i64PrimFun2 f)
    i8_3 s f = t_3 (IntType Int8) s (i8PrimFun3 f)
    i16_3 s f = t_3 (IntType Int16) s (i16PrimFun3 f)
    i32_3 s f = t_3 (IntType Int32) s (i32PrimFun3 f)
    i64_3 s f = t_3 (IntType Int64) s (i64PrimFun3 f)

    -- Argument unwrappers.  Each matches the exact expected list of
    -- values and yields 'Nothing' for anything else.  The unary integer
    -- ones pass the raw value to f, which builds the result PrimValue
    -- itself.
    i8PrimFun f [IntValue (Int8Value x)] = Just $ f x
    i8PrimFun _ _ = Nothing

    i16PrimFun f [IntValue (Int16Value x)] = Just $ f x
    i16PrimFun _ _ = Nothing

    i32PrimFun f [IntValue (Int32Value x)] = Just $ f x
    i32PrimFun _ _ = Nothing

    i64PrimFun f [IntValue (Int64Value x)] = Just $ f x
    i64PrimFun _ _ = Nothing

    f16PrimFun f [FloatValue (Float16Value x)] = Just $ FloatValue $ Float16Value $ f x
    f16PrimFun _ _ = Nothing

    f32PrimFun f [FloatValue (Float32Value x)] = Just $ FloatValue $ Float32Value $ f x
    f32PrimFun _ _ = Nothing

    f64PrimFun f [FloatValue (Float64Value x)] = Just $ FloatValue $ Float64Value $ f x
    f64PrimFun _ _ = Nothing

    f16PrimFun2 f [FloatValue (Float16Value a), FloatValue (Float16Value b)] = Just $ FloatValue $ Float16Value $ f a b
    f16PrimFun2 _ _ = Nothing

    f32PrimFun2 f [FloatValue (Float32Value a), FloatValue (Float32Value b)] = Just $ FloatValue $ Float32Value $ f a b
    f32PrimFun2 _ _ = Nothing

    f64PrimFun2 f [FloatValue (Float64Value a), FloatValue (Float64Value b)] = Just $ FloatValue $ Float64Value $ f a b
    f64PrimFun2 _ _ = Nothing

    f16PrimFun3 f [FloatValue (Float16Value a), FloatValue (Float16Value b), FloatValue (Float16Value c)] = Just $ FloatValue $ Float16Value $ f a b c
    f16PrimFun3 _ _ = Nothing

    f32PrimFun3 f [FloatValue (Float32Value a), FloatValue (Float32Value b), FloatValue (Float32Value c)] = Just $ FloatValue $ Float32Value $ f a b c
    f32PrimFun3 _ _ = Nothing

    f64PrimFun3 f [FloatValue (Float64Value a), FloatValue (Float64Value b), FloatValue (Float64Value c)] = Just $ FloatValue $ Float64Value $ f a b c
    f64PrimFun3 _ _ = Nothing

    i8PrimFun2 f [IntValue (Int8Value a), IntValue (Int8Value b)] = Just $ IntValue $ Int8Value $ f a b
    i8PrimFun2 _ _ = Nothing

    i16PrimFun2 f [IntValue (Int16Value a), IntValue (Int16Value b)] = Just $ IntValue $ Int16Value $ f a b
    i16PrimFun2 _ _ = Nothing

    i32PrimFun2 f [IntValue (Int32Value a), IntValue (Int32Value b)] = Just $ IntValue $ Int32Value $ f a b
    i32PrimFun2 _ _ = Nothing

    i64PrimFun2 f [IntValue (Int64Value a), IntValue (Int64Value b)] = Just $ IntValue $ Int64Value $ f a b
    i64PrimFun2 _ _ = Nothing

    i8PrimFun3 f [IntValue (Int8Value a), IntValue (Int8Value b), IntValue (Int8Value c)] = Just $ IntValue $ Int8Value $ f a b c
    i8PrimFun3 _ _ = Nothing

    i16PrimFun3 f [IntValue (Int16Value a), IntValue (Int16Value b), IntValue (Int16Value c)] = Just $ IntValue $ Int16Value $ f a b c
    i16PrimFun3 _ _ = Nothing

    i32PrimFun3 f [IntValue (Int32Value a), IntValue (Int32Value b), IntValue (Int32Value c)] = Just $ IntValue $ Int32Value $ f a b c
    i32PrimFun3 _ _ = Nothing

    i64PrimFun3 f [IntValue (Int64Value a), IntValue (Int64Value b), IntValue (Int64Value c)] = Just $ IntValue $ Int64Value $ f a b c
    i64PrimFun3 _ _ = Nothing

-- | Is the given value kind of zero?  'False' counts as zero; the unit
-- value does not.
zeroIsh :: PrimValue -> Bool
zeroIsh (IntValue k) = zeroIshInt k
zeroIsh (FloatValue (Float16Value k)) = k == 0
zeroIsh (FloatValue (Float32Value k)) = k == 0
zeroIsh (FloatValue (Float64Value k)) = k == 0
zeroIsh (BoolValue False) = True
zeroIsh _ = False

-- | Is the given value kind of one?  'True' counts as one; the unit
-- value does not.
oneIsh :: PrimValue -> Bool
oneIsh (IntValue k) = oneIshInt k
oneIsh (FloatValue (Float16Value k)) = k == 1
oneIsh (FloatValue (Float32Value k)) = k == 1
oneIsh (FloatValue (Float64Value k)) = k == 1
oneIsh (BoolValue True) = True
oneIsh _ = False

-- | Is the given value kind of negative?  Booleans and unit never are.
negativeIsh :: PrimValue -> Bool
negativeIsh (IntValue k) = negativeIshInt k
negativeIsh (FloatValue (Float16Value k)) = k < 0
negativeIsh (FloatValue (Float32Value k)) = k < 0
negativeIsh (FloatValue (Float64Value k)) = k < 0
negativeIsh (BoolValue _) = False
negativeIsh UnitValue = False

-- | Is the given integer value kind of zero?
zeroIshInt :: IntValue -> Bool
zeroIshInt (Int8Value k) = k == 0
zeroIshInt (Int16Value k) = k == 0
zeroIshInt (Int32Value k) = k == 0
zeroIshInt (Int64Value k) = k == 0

-- | Is the given integer value kind of one?
oneIshInt :: IntValue -> Bool
oneIshInt (Int8Value k) = k == 1
oneIshInt (Int16Value k) = k == 1
oneIshInt (Int32Value k) = k == 1
oneIshInt (Int64Value k) = k == 1

-- | Is the given integer value kind of negative?
negativeIshInt :: IntValue -> Bool
negativeIshInt (Int8Value k) = k < 0
negativeIshInt (Int16Value k) = k < 0
negativeIshInt (Int32Value k) = k < 0
negativeIshInt (Int64Value k) = k < 0

-- | The size of a value of a given integer type in eight-bit bytes.
intByteSize :: (Num a) => IntType -> a
intByteSize t =
  case t of
    Int8 -> 1
    Int16 -> 2
    Int32 -> 4
    Int64 -> 8

-- | The size of a value of a given floating-point type in eight-bit bytes.
floatByteSize :: (Num a) => FloatType -> a
floatByteSize t =
  case t of
    Float16 -> 2
    Float32 -> 4
    Float64 -> 8

-- | The size of a value of a given primitive type in eight-bit bytes.
--
-- Warning: note that this is 0 for 'Unit', but a 'Unit' takes up a
-- byte in the binary data format.
primByteSize :: (Num a) => PrimType -> a
primByteSize pt =
  case pt of
    IntType t -> intByteSize t
    FloatType t -> floatByteSize t
    Bool -> 1
    Unit -> 0

-- | The size of a value of a given primitive type in bits (eight bits
-- per byte of 'primByteSize').
primBitSize :: PrimType -> Int
primBitSize t = primByteSize t * 8

-- | True if the given binary operator is commutative.
commutativeBinOp :: BinOp -> Bool
commutativeBinOp op =
  case op of
    Add {} -> True
    FAdd {} -> True
    Mul {} -> True
    FMul {} -> True
    And {} -> True
    Or {} -> True
    Xor {} -> True
    LogOr {} -> True
    LogAnd {} -> True
    SMax {} -> True
    SMin {} -> True
    UMax {} -> True
    UMin {} -> True
    FMax {} -> True
    FMin {} -> True
    _ -> False

-- | True if the given binary operator is associative.
associativeBinOp :: BinOp -> Bool
associativeBinOp op =
  -- Note that FAdd and FMul are deliberately absent: floating-point
  -- addition and multiplication are not associative.
  case op of
    Add {} -> True
    Mul {} -> True
    And {} -> True
    Or {} -> True
    Xor {} -> True
    LogOr {} -> True
    LogAnd {} -> True
    SMax {} -> True
    SMin {} -> True
    UMax {} -> True
    UMin {} -> True
    FMax {} -> True
    FMin {} -> True
    _ -> False

-- Prettyprinting instances

instance Pretty BinOp where
  pretty op =
    case op of
      Add t o -> taggedI (withOverflow "add" o) t
      Sub t o -> taggedI (withOverflow "sub" o) t
      Mul t o -> taggedI (withOverflow "mul" o) t
      FAdd t -> taggedF "fadd" t
      FSub t -> taggedF "fsub" t
      FMul t -> taggedF "fmul" t
      UDiv t s -> taggedI (withSafety "udiv" s) t
      UDivUp t s -> taggedI (withSafety "udiv_up" s) t
      UMod t s -> taggedI (withSafety "umod" s) t
      SDiv t s -> taggedI (withSafety "sdiv" s) t
      SDivUp t s -> taggedI (withSafety "sdiv_up" s) t
      SMod t s -> taggedI (withSafety "smod" s) t
      SQuot t s -> taggedI (withSafety "squot" s) t
      SRem t s -> taggedI (withSafety "srem" s) t
      FDiv t -> taggedF "fdiv" t
      FMod t -> taggedF "fmod" t
      SMin t -> taggedI "smin" t
      UMin t -> taggedI "umin" t
      FMin t -> taggedF "fmin" t
      SMax t -> taggedI "smax" t
      UMax t -> taggedI "umax" t
      FMax t -> taggedF "fmax" t
      Shl t -> taggedI "shl" t
      LShr t -> taggedI "lshr" t
      AShr t -> taggedI "ashr" t
      And t -> taggedI "and" t
      Or t -> taggedI "or" t
      Xor t -> taggedI "xor" t
      Pow t -> taggedI "pow" t
      FPow t -> taggedF "fpow" t
      LogAnd -> "logand"
      LogOr -> "logor"
    where
      -- Operators with undefined overflow carry an "_nw" suffix.
      withOverflow base OverflowWrap = base
      withOverflow base OverflowUndef = base ++ "_nw"
      -- Division-like operators marked 'Safe' carry a "_safe" suffix.
      withSafety base Safe = base ++ "_safe"
      withSafety base Unsafe = base

instance Pretty CmpOp where
  pretty op =
    case op of
      CmpEq t -> "eq_" <> pretty t
      CmpUlt t -> taggedI "ult" t
      CmpUle t -> taggedI "ule" t
      CmpSlt t -> taggedI "slt" t
      CmpSle t -> taggedI "sle" t
      FCmpLt t -> taggedF "lt" t
      FCmpLe t -> taggedF "le" t
      CmpLlt -> "llt"
      CmpLle -> "lle"

instance Pretty ConvOp where
  -- Printed as <name>_<from type>_<to type>.
  pretty op = uncurry (convOp (convOpFun op)) (convOpType op)

instance Pretty UnOp where
  pretty op =
    case op of
      Neg t -> "neg_" <> pretty t
      Abs t -> taggedI "abs" t
      FAbs t -> taggedF "fabs" t
      SSignum t -> taggedI "ssignum" t
      USignum t -> taggedI "usignum" t
      FSignum t -> taggedF "fsignum" t
      Complement t -> taggedI "complement" t

-- | The human-readable name for a 'ConvOp'.  This is used to expose
-- the 'ConvOp' in the @intrinsics@ module of a Futhark program.
convOpFun :: ConvOp -> String
convOpFun op =
  case op of
    ZExt {} -> "zext"
    SExt {} -> "sext"
    FPConv {} -> "fpconv"
    FPToUI {} -> "fptoui"
    FPToSI {} -> "fptosi"
    UIToFP {} -> "uitofp"
    SIToFP {} -> "sitofp"
    IToB {} -> "itob"
    BToI {} -> "btoi"
    FToB {} -> "ftob"
    BToF {} -> "btof"

-- | Append the bit width of an integer type to an operator name.
taggedI :: String -> IntType -> Doc a
taggedI s t = pretty (s ++ width)
  where
    width = case t of
      Int8 -> "8"
      Int16 -> "16"
      Int32 -> "32"
      Int64 -> "64"

-- | Append the bit width of a floating-point type to an operator name.
taggedF :: String -> FloatType -> Doc a
taggedF s t = pretty (s ++ width)
  where
    width = case t of
      Float16 -> "16"
      Float32 -> "32"
      Float64 -> "64"

-- | Render a conversion as @name_from_to@.
convOp :: (Pretty from, Pretty to) => String -> from -> to -> Doc a
convOp s from to = pretty s <> "_" <> pretty from <> "_" <> pretty to

-- | Pretty-print a primitive type, choosing the unsigned spelling for
-- integer types when the flag is 'True'.  In that case the leading
-- character of the integer type's name is replaced by @u@.  The flag
-- only makes a difference for integer types.
prettySigned :: Bool -> PrimType -> T.Text
prettySigned unsigned t =
  -- NOTE: a 'True' flag selects the unsigned spelling.
  case t of
    IntType it
      | unsigned -> T.cons 'u' (T.drop 1 (prettyText it))
    _ -> prettyText t

-- | High half of the unsigned 8-bit product, computed in 'Word64'.
umul_hi8 :: Int8 -> Int8 -> Int8
umul_hi8 a b = fromIntegral (shiftR (widen a * widen b) 8)
  where
    widen :: Int8 -> Word64
    widen x = fromIntegral (fromIntegral x :: Word8)

-- | High half of the unsigned 16-bit product, computed in 'Word64'.
umul_hi16 :: Int16 -> Int16 -> Int16
umul_hi16 a b = fromIntegral (shiftR (widen a * widen b) 16)
  where
    widen :: Int16 -> Word64
    widen x = fromIntegral (fromIntegral x :: Word16)

-- | High half of the unsigned 32-bit product, computed in 'Word64'.
umul_hi32 :: Int32 -> Int32 -> Int32
umul_hi32 a b = fromIntegral (shiftR (widen a * widen b) 32)
  where
    widen :: Int32 -> Word64
    widen x = fromIntegral (fromIntegral x :: Word32)

-- | High half of the unsigned 64-bit product, computed in 'Integer'
-- because the full product does not fit in 64 bits.
umul_hi64 :: Int64 -> Int64 -> Int64
umul_hi64 a b = fromIntegral (shiftR (widen a * widen b) 64)
  where
    widen :: Int64 -> Integer
    widen x = toInteger (fromIntegral x :: Word64)

-- | Unsigned multiply-high of the first two operands, plus the third.
umad_hi8 :: Int8 -> Int8 -> Int8 -> Int8
umad_hi8 a b c = c + umul_hi8 a b

umad_hi16 :: Int16 -> Int16 -> Int16 -> Int16
umad_hi16 a b c = c + umul_hi16 a b

umad_hi32 :: Int32 -> Int32 -> Int32 -> Int32
umad_hi32 a b c = c + umul_hi32 a b

umad_hi64 :: Int64 -> Int64 -> Int64 -> Int64
umad_hi64 a b c = c + umul_hi64 a b

-- | High half of the signed 8-bit product, computed in 'Int64'
-- (arithmetic shift).
smul_hi8 :: Int8 -> Int8 -> Int8
smul_hi8 a b = fromIntegral (shiftR (widen a * widen b) 8)
  where
    widen :: Int8 -> Int64
    widen = fromIntegral

-- | High half of the signed 16-bit product, computed in 'Int64'.
smul_hi16 :: Int16 -> Int16 -> Int16
smul_hi16 a b = fromIntegral (shiftR (widen a * widen b) 16)
  where
    widen :: Int16 -> Int64
    widen = fromIntegral

-- | High half of the signed 32-bit product, computed in 'Int64'.
smul_hi32 :: Int32 -> Int32 -> Int32
smul_hi32 a b = fromIntegral (shiftR (widen a * widen b) 32)
  where
    widen :: Int32 -> Int64
    widen = fromIntegral

-- | High half of the signed 64-bit product, computed in 'Integer'.
smul_hi64 :: Int64 -> Int64 -> Int64
smul_hi64 a b = fromIntegral (shiftR (toInteger a * toInteger b) 64)

-- | Signed multiply-high of the first two operands, plus the third.
smad_hi8 :: Int8 -> Int8 -> Int8 -> Int8
smad_hi8 a b c = c + smul_hi8 a b

smad_hi16 :: Int16 -> Int16 -> Int16 -> Int16
smad_hi16 a b c = c + smul_hi16 a b
smad_hi32 :: Int32 -> Int32 -> Int32 -> Int32
smad_hi32 a b c = c + smul_hi32 a b

smad_hi64 :: Int64 -> Int64 -> Int64 -> Int64
smad_hi64 a b c = c + smul_hi64 a b
futhark-0.25.27/src/Language/Futhark/Primitive/000077500000000000000000000000001475065116200212305ustar00rootroot00000000000000futhark-0.25.27/src/Language/Futhark/Primitive/Parse.hs000066400000000000000000000061171475065116200226430ustar00rootroot00000000000000-- | Parsers for primitive values and types.
module Language.Futhark.Primitive.Parse
  ( pPrimValue,
    pPrimType,
    pFloatType,
    pIntType,

    -- * Building blocks
    constituent,
    lexeme,
    keyword,
    whitespace,
  )
where

import Data.Char (isAlphaNum)
import Data.Functor
import Data.Text qualified as T
import Data.Void
import Futhark.Util.Pretty
import Language.Futhark.Primitive
import Text.Megaparsec
import Text.Megaparsec.Char
import Text.Megaparsec.Char.Lexer qualified as L

-- | Can this character occur inside an identifier?  Alphanumeric
-- characters qualify, as do the listed operator-like symbols.
constituent :: Char -> Bool
constituent c
  | isAlphaNum c = True
  | otherwise = c `elem` ("_/'+-=!&^.<>*|%" :: String)

-- | Consume whitespace, skipping any @--@ line comments along the way.
whitespace :: Parsec Void T.Text ()
whitespace = L.space space1 lineComment empty
  where
    lineComment = L.skipLineComment "--"

-- | Run a parser and then consume trailing whitespace.  The whole thing
-- backtracks on failure.
lexeme :: Parsec Void T.Text a -> Parsec Void T.Text a
lexeme p = try (L.lexeme whitespace p)

-- | @keyword k@ parses @k@, which must not be immediately followed by
-- a 'constituent' character.  This ensures that @iff@ is not seen as
-- the @if@ keyword followed by @f@.  Sometimes called the "maximum
-- munch" rule.
keyword :: T.Text -> Parsec Void T.Text ()
keyword s = lexeme $ do
  _ <- chunk s
  notFollowedBy (satisfy constituent)

-- | Parse an integer value: an optionally signed decimal literal
-- followed by an integer type suffix.
pIntValue :: Parsec Void T.Text IntValue
pIntValue = try (build <$> L.signed (pure ()) L.decimal <*> pIntType)
  where
    build x t = intValue t (x :: Integer)

-- | Parse a floating-point value.
pFloatValue :: Parsec Void T.Text FloatValue
pFloatValue =
  choice
    [ pNum,
      keyword "f16.nan" $> Float16Value (0 / 0),
      keyword "f16.inf" $> Float16Value (1 / 0),
      keyword "-f16.inf" $> Float16Value (-1 / 0),
      keyword "f32.nan" $> Float32Value (0 / 0),
      keyword "f32.inf" $> Float32Value (1 / 0),
      keyword "-f32.inf" $> Float32Value (-1 / 0),
      keyword "f64.nan" $> Float64Value (0 / 0),
      keyword "f64.inf" $> Float64Value (1 / 0),
      keyword "-f64.inf" $> Float64Value (-1 / 0)
    ]
  where
    -- An optionally signed float literal followed by a float type.
    pNum = try $ do
      x <- L.signed (pure ()) L.float
      t <- pFloatType
      pure $ floatValue t (x :: Double)

-- | Parse a boolean value.
pBoolValue :: Parsec Void T.Text Bool
pBoolValue =
  choice
    [ keyword "true" $> True,
      keyword "false" $> False
    ]

-- | Parse any primitive value: a floating-point literal, an integer
-- literal, a boolean, or the unit value @()@.  On failure the error
-- reports that a "primitive value" was expected.
pPrimValue :: Parsec Void T.Text PrimValue
pPrimValue =
  -- 'label' is the prefix form of Megaparsec's '<?>' operator.
  label "primitive value" $
    choice
      [ FloatValue <$> pFloatValue,
        IntValue <$> pIntValue,
        BoolValue <$> pBoolValue,
        UnitValue <$ try (lexeme "(" *> lexeme ")")
      ]

-- | Parse a floating-point type.
pFloatType :: Parsec Void T.Text FloatType
pFloatType = choice $ map p allFloatTypes
  where
    p t = keyword (prettyText t) $> t

-- | Parse an integer type.
pIntType :: Parsec Void T.Text IntType
pIntType = choice $ map p allIntTypes
  where
    p t = keyword (prettyText t) $> t

-- | Parse a primitive type.
pPrimType :: Parsec Void T.Text PrimType
pPrimType =
  choice [p Bool, p Unit, FloatType <$> pFloatType, IntType <$> pIntType]
  where
    p t = keyword (prettyText t) $> t
futhark-0.25.27/src/Language/Futhark/Prop.hs000066400000000000000000001552711475065116200205470ustar00rootroot00000000000000-- | This module provides various simple ways to query and manipulate
-- fundamental Futhark terms, such as types and values.  The intent is to
-- keep "Language.Futhark.Syntax" simple, and put whatever embellishments
-- we need here.
module Language.Futhark.Prop ( -- * Various Intrinsic (..), intrinsics, intrinsicVar, isBuiltin, isBuiltinLoc, maxIntrinsicTag, namesToPrimTypes, qualName, qualify, primValueType, leadingOperator, progImports, decImports, progModuleTypes, identifierReference, prettyStacktrace, progHoles, defaultEntryPoint, paramName, anySize, -- * Queries on expressions typeOf, valBindTypeScheme, valBindBound, funType, stripExp, subExps, similarExps, sameExp, -- * Queries on patterns and params patIdents, patNames, patternMap, patternType, patternStructType, patternParam, patternOrderZero, -- * Queries on types uniqueness, unique, diet, arrayRank, arrayShape, orderZero, unfoldFunType, foldFunType, typeVars, isAccType, -- * Operations on types peelArray, stripArray, arrayOf, arrayOfWithAliases, toStructural, toStruct, toRes, toParam, resToParam, paramToRes, toResRet, setUniqueness, noSizes, traverseDims, DimPos (..), tupleRecord, isTupleRecord, areTupleFields, tupleFields, tupleFieldNames, sortFields, sortConstrs, isTypeParam, isSizeParam, matchDims, -- * Un-typechecked ASTs UncheckedType, UncheckedTypeExp, UncheckedIdent, UncheckedDimIndex, UncheckedSlice, UncheckedExp, UncheckedModExp, UncheckedModTypeExp, UncheckedTypeParam, UncheckedPat, UncheckedValBind, UncheckedTypeBind, UncheckedModTypeBind, UncheckedModBind, UncheckedDec, UncheckedSpec, UncheckedProg, UncheckedCase, -- * Type-checked ASTs Ident, DimIndex, Slice, AppExp, Exp, Pat, ModExp, ModParam, ModTypeExp, ModBind, ModTypeBind, ValBind, Dec, Spec, Prog, TypeBind, StructTypeArg, ScalarType, TypeParam, Case, ) where import Control.Monad import Control.Monad.State import Data.Bifunctor import Data.Bitraversable (bitraverse) import Data.Char import Data.Foldable import Data.List (genericLength, isPrefixOf, sortOn) import Data.List.NonEmpty qualified as NE import Data.Loc (Loc (..), posFile) import Data.Map.Strict qualified as M import Data.Maybe import Data.Ord import Data.Set qualified as S import Data.Text qualified as T 
import Futhark.Util (maxinum)
import Futhark.Util.Pretty
import Language.Futhark.Primitive qualified as Primitive
import Language.Futhark.Syntax
import Language.Futhark.Traversals
import Language.Futhark.Tuple
import System.FilePath (takeDirectory)

-- | The name of the default program entry point (@main@).
defaultEntryPoint :: Name
defaultEntryPoint = nameFromString "main"

-- | Return the dimensionality of a type.  For non-arrays, this is
-- zero.  For a one-dimensional array it is one, for a two-dimensional
-- it is two, and so forth.
arrayRank :: TypeBase d u -> Int
arrayRank = shapeRank . arrayShape

-- | Return the shape of a type - for non-arrays, this is 'mempty'.
arrayShape :: TypeBase dim as -> Shape dim
arrayShape (Array _ ds _) = ds
arrayShape _ = mempty

-- | Change the shape of a type to be just the rank, by replacing
-- every size with @()@.
noSizes :: TypeBase Size as -> TypeBase () as
noSizes = first $ const ()

-- | Where does this dimension occur?
data DimPos
  = -- | Immediately in the argument to 'traverseDims'.
    PosImmediate
  | -- | In a function parameter type.
    PosParam
  | -- | In a function return type.
    PosReturn
  deriving (Eq, Ord, Show)

-- | Perform a traversal (possibly including replacement) on sizes
-- that are parameters in a function type, but also including the type
-- immediately passed to the function.  Also passes along a set of the
-- parameter names inside the type that have come in scope at the
-- occurrence of the dimension.
traverseDims ::
  forall f fdim tdim als.
  (Applicative f) =>
  (S.Set VName -> DimPos -> fdim -> f tdim) ->
  TypeBase fdim als ->
  f (TypeBase tdim als)
traverseDims f = go mempty PosImmediate
  where
    -- @go bound b t@: @bound@ holds the names in scope, @b@ is the
    -- position of @t@.  Polymorphic in the annotation because
    -- components of a type carry different annotation types.
    go ::
      forall als'.
      S.Set VName ->
      DimPos ->
      TypeBase fdim als' ->
      f (TypeBase tdim als')
    go bound b t@Array {} =
      bitraverse (f bound b) pure t
    go bound b (Scalar (Record fields)) =
      Scalar . Record <$> traverse (go bound b) fields
    go bound b (Scalar (TypeVar as tn targs)) =
      Scalar <$> (TypeVar as tn <$> traverse (onTypeArg tn bound b) targs)
    go bound b (Scalar (Sum cs)) =
      Scalar . Sum <$> traverse (traverse (go bound b)) cs
    go _ _ (Scalar (Prim t)) =
      pure $ Scalar $ Prim t
    go bound _ (Scalar (Arrow als p u t1 (RetType dims t2))) =
      Scalar <$> (Arrow als p u <$> go bound' PosParam t1 <*> (RetType dims <$> go bound' PosReturn t2))
      where
        -- Inside an arrow, the existential sizes of the return type and
        -- the parameter name (if any) come into scope.
        bound' =
          S.fromList dims <> case p of
            Named p' -> S.insert p' bound
            Unnamed -> bound

    onTypeArg _ bound b (TypeArgDim d) =
      TypeArgDim <$> f bound b d
    onTypeArg tn bound b (TypeArgType t) =
      TypeArgType <$> go bound b' t
      where
        -- Type arguments of the builtin @acc@ type keep the current
        -- position; arguments of any other type constructor are treated
        -- as parameter positions.
        b' =
          if qualLeaf tn == fst intrinsicAcc
            then b
            else PosParam

-- | Return the uniqueness of a type.  Sum and record types are
-- 'Unique' if any of their components are.
uniqueness :: TypeBase shape Uniqueness -> Uniqueness
uniqueness (Array u _ _) = u
uniqueness (Scalar (TypeVar u _ _)) = u
uniqueness (Scalar (Sum ts))
  | any (any unique) ts = Unique
uniqueness (Scalar (Record fs))
  | any unique fs = Unique
uniqueness _ = Nonunique

-- | @unique t@ is 'True' if the type of the argument is unique.
unique :: TypeBase shape Uniqueness -> Bool
unique = (== Unique) . uniqueness

-- | @diet t@ returns a description of how a function parameter of
-- type @t@ consumes its argument.  For records and sums this is the
-- maximum (per the 'Ord' instance of 'Diet') over the components,
-- starting from 'Observe'; primitives and functions are always
-- 'Observe'.
diet :: TypeBase shape Diet -> Diet
diet (Scalar (Record ets)) = foldl max Observe $ fmap diet ets
diet (Scalar (Prim _)) = Observe
diet (Scalar (Arrow {})) = Observe
diet (Array d _ _) = d
diet (Scalar (TypeVar d _ _)) = d
diet (Scalar (Sum cs)) = foldl max Observe $ foldMap (map diet) cs

-- | Convert any type to one that has rank information, no alias
-- information, and no embedded names.
toStructural :: TypeBase dim as -> TypeBase () ()
toStructural = bimap (const ()) (const ())

-- | Remove uniqueness information from a type.
toStruct :: TypeBase dim u -> TypeBase dim NoUniqueness
toStruct = second (const NoUniqueness)

-- | Convert to 'ParamType', using the given 'Diet' as the annotation.
toParam :: Diet -> TypeBase Size u -> ParamType
toParam d t = d <$ t

-- | Convert to 'ResType', using the given 'Uniqueness' as the
-- annotation.
toRes :: Uniqueness -> TypeBase Size u -> ResType
toRes u t = u <$ t

-- | Convert to 'ResRetType', using the given 'Uniqueness' as the
-- annotation.
toResRet :: Uniqueness -> RetTypeBase Size u -> ResRetType
toResRet u = bimap id (const u)

-- | Preserves relation between 'Diet' and 'Uniqueness': 'Unique'
-- becomes 'Consume' and 'Nonunique' becomes 'Observe'.
resToParam :: ResType -> ParamType
resToParam = second $ \u -> case u of
  Unique -> Consume
  Nonunique -> Observe

-- | Preserves relation between 'Diet' and 'Uniqueness': 'Consume'
-- becomes 'Unique' and 'Observe' becomes 'Nonunique'.
paramToRes :: ParamType -> ResType
paramToRes = second $ \d -> case d of
  Consume -> Unique
  Observe -> Nonunique

-- | @peelArray n t@ returns the type resulting from peeling the first
-- @n@ array dimensions from @t@.  Returns @Nothing@ if @t@ is not an
-- array or has fewer than @n@ dimensions.  Peeling exactly all
-- dimensions yields the element type.
peelArray :: Int -> TypeBase dim u -> Maybe (TypeBase dim u)
peelArray n t = case t of
  Array u shape et
    | shapeRank shape == n -> Just $ second (const u) (Scalar et)
    | otherwise -> (\shape' -> Array u shape' et) <$> stripDims n shape
  _ -> Nothing

-- | @arrayOf s t@ constructs an array type.  The convenience
-- compared to using the 'Array' constructor directly is that @t@ can
-- itself be an array.  If @t@ is an @n@-dimensional array, and @s@ is
-- a shape of rank @m@, the resulting type has @n+m@ dimensions, with
-- the dimensions of @s@ outermost.
arrayOf ::
  Shape dim ->
  TypeBase dim NoUniqueness ->
  TypeBase dim NoUniqueness
arrayOf shape t = arrayOfWithAliases mempty shape t

-- | Like 'arrayOf', but you can pass in uniqueness info of the
-- resulting array.
arrayOfWithAliases ::
  u ->
  Shape dim ->
  TypeBase dim u' ->
  TypeBase dim u
arrayOfWithAliases u outer t = case t of
  Array _ inner et -> Array u (outer <> inner) et
  Scalar st -> Array u outer (second (const mempty) st)

-- | @stripArray n t@ removes the @n@ outermost layers of the array.
-- Essentially, it is the type of indexing an array of type @t@ with
-- @n@ indexes.
stripArray :: Int -> TypeBase dim as -> TypeBase dim as
stripArray n t = case t of
  Array u shape et ->
    maybe (second (const u) (Scalar et)) (\shape' -> Array u shape' et) $
      stripDims n shape
  _ -> t

-- | Create a record type corresponding to a tuple with the given
-- element types.  Fields are named by 'tupleFieldNames', in order.
tupleRecord :: [TypeBase dim as] -> ScalarTypeBase dim as
tupleRecord ts = Record $ M.fromList $ zip tupleFieldNames ts

-- | Does this type correspond to a tuple?  If so, return the elements
-- of that tuple.
isTupleRecord :: TypeBase dim as -> Maybe [TypeBase dim as]
isTupleRecord t = case t of
  Scalar (Record fs) -> areTupleFields fs
  _ -> Nothing

-- | Sort the constructors of a sum type in some well-defined (but not
-- otherwise significant) manner; currently by constructor name.
sortConstrs :: M.Map Name a -> [(Name, a)]
sortConstrs = sortOn fst . M.toList

-- | Is this a 'TypeParamType'?
isTypeParam :: TypeParamBase vn -> Bool
isTypeParam p = case p of
  TypeParamType {} -> True
  TypeParamDim {} -> False

-- | Is this a 'TypeParamDim'?
isSizeParam :: TypeParamBase vn -> Bool
isSizeParam p = not $ isTypeParam p

-- | The name, if any.
paramName :: PName -> Maybe VName
paramName p = case p of
  Named v -> Just v
  Unnamed -> Nothing

-- | A special expression representing no known size.  When present in
-- a type, each instance represents a distinct size.  The type checker
-- should _never_ produce these - they are a (hopefully temporary)
-- thing introduced by defunctorisation and monomorphisation.  They
-- represent a flaw in our implementation.  When they occur in a
-- return type, they can be replaced with freshly created existential
-- sizes.  When they occur in parameter types, they can be replaced
-- with size parameters.
anySize :: Size
anySize =
  -- A string literal (the bytes of "ANY") rather than a variable, so
  -- that it is never seen as a free variable.
  StringLit [65, 78, 89] mempty

-- | Match the dimensions of otherwise assumed-equal types.  The
-- combining function is also passed the names bound within the type
-- (from named parameters or return types).
matchDims ::
  forall as m d1 d2.
  (Monoid as, Monad m) =>
  ([VName] -> d1 -> d2 -> m d1) ->
  TypeBase d1 as ->
  TypeBase d2 as ->
  m (TypeBase d1 as)
matchDims onDims = matchDims' mempty
  where
    -- Walks both types in lockstep.  Wherever their structure differs
    -- (the final catch-all), the first type is returned unchanged.
    matchDims' ::
      forall u'. (Monoid u') => [VName] -> TypeBase d1 u' -> TypeBase d2 u' -> m (TypeBase d1 u')
    matchDims' bound t1 t2 =
      case (t1, t2) of
        (Array u1 shape1 et1, Array u2 shape2 et2) ->
          arrayOfWithAliases u1
            <$> onShapes bound shape1 shape2
            <*> matchDims' bound (second (const u2) (Scalar et1)) (second (const u2) (Scalar et2))
        -- Only fields present in both records are kept.
        (Scalar (Record f1), Scalar (Record f2)) ->
          Scalar . Record
            <$> traverse (uncurry (matchDims' bound)) (M.intersectionWith (,) f1 f2)
        -- Only constructors present in both sums are kept; payloads are
        -- paired positionally with 'zip'.
        (Scalar (Sum cs1), Scalar (Sum cs2)) ->
          Scalar . Sum
            <$> traverse
              (traverse (uncurry (matchDims' bound)))
              (M.intersectionWith zip cs1 cs2)
        ( Scalar (Arrow als1 p1 d1 a1 (RetType dims1 b1)),
          Scalar (Arrow als2 p2 _d2 a2 (RetType dims2 b2))
          ) ->
            -- Parameter names and existential sizes of both arrows are
            -- in scope for the parameter and return types.
            let bound' = mapMaybe paramName [p1, p2] <> dims1 <> dims2 <> bound
             in Scalar
                  <$> ( Arrow (als1 <> als2) p1 d1
                          <$> matchDims' bound' a1 a2
                          <*> (RetType dims1 <$> matchDims' bound' b1 b2)
                      )
        ( Scalar (TypeVar als1 v targs1),
          Scalar (TypeVar als2 _ targs2)
          ) ->
            Scalar . TypeVar (als1 <> als2) v
              <$> zipWithM (matchTypeArg bound) targs1 targs2
        _ -> pure t1

    -- Mismatched kinds of type argument keep the first argument as-is.
    matchTypeArg bound (TypeArgType t1) (TypeArgType t2) =
      TypeArgType <$> matchDims' bound t1 t2
    matchTypeArg bound (TypeArgDim x) (TypeArgDim y) =
      TypeArgDim <$> onDims bound x y
    matchTypeArg _ a _ = pure a

    -- Pairs dimensions positionally; 'zipWithM' truncates to the
    -- shorter shape.
    onShapes bound shape1 shape2 =
      Shape <$> zipWithM (onDims bound) (shapeDims shape1) (shapeDims shape2)

-- | Set the uniqueness attribute of a type.  If the type is a record
-- or sum type, the uniqueness of its components will be modified.
setUniqueness :: TypeBase dim u1 -> u2 -> TypeBase dim u2
setUniqueness t u = second (const u) t

-- | The width of an integer value.
intValueType :: IntValue -> IntType
intValueType Int8Value {} = Int8
intValueType Int16Value {} = Int16
intValueType Int32Value {} = Int32
intValueType Int64Value {} = Int64

-- | The width of a floating-point value.
floatValueType :: FloatValue -> FloatType
floatValueType Float16Value {} = Float16
floatValueType Float32Value {} = Float32
floatValueType Float64Value {} = Float64

-- | The type of a basic value.
primValueType :: PrimValue -> PrimType
primValueType (SignedValue v) = Signed $ intValueType v
primValueType (UnsignedValue v) = Unsigned $ intValueType v
primValueType (FloatValue v) = FloatType $ floatValueType v
primValueType BoolValue {} = Bool

-- | The type of a Futhark term.  Most cases read the type annotation
-- stored by the type checker; wrappers (parentheses, ascriptions,
-- attributes, ...) defer to the wrapped expression.
--
-- NOTE(review): older documentation said "the aliasing will refer to
-- itself"; the result is a 'StructType', which carries no aliasing,
-- so that remark no longer seems to apply.
typeOf :: ExpBase Info VName -> StructType
typeOf (Literal val _) = Scalar $ Prim $ primValueType val
typeOf (IntLit _ (Info t) _) = t
typeOf (FloatLit _ (Info t) _) = t
typeOf (Parens e _) = typeOf e
typeOf (QualParens _ e _) = typeOf e
typeOf (TupLit es _) = Scalar $ tupleRecord $ map typeOf es
typeOf (RecordLit fs _) =
  Scalar $ Record $ M.fromList $ map record fs
  where
    record (RecordFieldExplicit (L _ name) e _) = (name, typeOf e)
    -- Implicit fields (@{x}@) are named after the variable's base name.
    record (RecordFieldImplicit (L _ name) (Info t) _) = (baseName name, t)
typeOf (ArrayLit _ (Info t) _) = t
-- Array values and string literals are one-dimensional, with the size
-- computed from the number of elements.
typeOf (ArrayVal vs t loc) =
  Array mempty (Shape [sizeFromInteger (genericLength vs) loc]) (Prim t)
typeOf (StringLit vs loc) =
  Array mempty (Shape [sizeFromInteger (genericLength vs) loc]) (Prim (Unsigned Int8))
typeOf (Project _ _ (Info t) _) = t
typeOf (Var _ (Info t) _) = t
typeOf (Hole (Info t) _) = t
typeOf (Ascript e _ _) = typeOf e
typeOf (Coerce _ _ (Info t) _) = t
typeOf (Negate e _) = typeOf e
typeOf (Not e _) = typeOf e
typeOf (Update e _ _ _) = typeOf e
typeOf (RecordUpdate _ _ _ (Info t) _) = t
typeOf (Assert _ e _ _) = typeOf e
typeOf (Lambda params _ _ (Info t) _) = funType params t
typeOf (OpSection _ (Info t) _) = t
-- Operator sections are functions of the missing operand.
typeOf (OpSectionLeft _ _ _ (_, Info (pn, pt2)) (Info ret, _) _) =
  Scalar $ Arrow mempty pn (diet pt2) (toStruct pt2) ret
typeOf (OpSectionRight _ _ _ (Info (pn, pt1), _) (Info ret) _) =
  Scalar $ Arrow mempty pn (diet pt1) (toStruct pt1) ret
typeOf (ProjectSection _ (Info t) _) = t
typeOf (IndexSection _ (Info t) _) = t
typeOf (Constr _ _ (Info t) _) = t
typeOf (Attr _ e _) = typeOf e
typeOf (AppExp _ (Info res)) = appResType res

-- | The type of a function with the given parameters and return type.
funType :: [Pat ParamType] -> ResRetType -> StructType
funType params ret =
  let RetType _ t = foldr (arrow . patternParam) ret params
   in toStruct t
  where
    -- Existential sizes of the outermost 'RetType' are dropped by the
    -- pattern match above; inner ones stay on their arrows.
    arrow (xp, d, xt) yt = RetType [] $ Scalar $ Arrow Nonunique xp d xt yt

-- | @foldFunType ts ret@ creates a function type ('Arrow') that takes
-- @ts@ as parameters and returns @ret@.  All parameters are 'Unnamed'.
foldFunType :: [ParamType] -> ResRetType -> StructType
foldFunType ps ret =
  let RetType _ t = foldr arrow ret ps
   in toStruct t
  where
    arrow t1 t2 =
      RetType [] $ Scalar $ Arrow Nonunique Unnamed (diet t1) (toStruct t1) t2

-- | Extract the parameter types and return type from a type.
-- If the type is not an arrow type, the list of parameter types is empty.
unfoldFunType :: TypeBase dim as -> ([TypeBase dim Diet], TypeBase dim NoUniqueness)
unfoldFunType (Scalar (Arrow _ _ d t1 (RetType _ t2))) =
  let (ps, r) = unfoldFunType t2
   in (second (const d) t1 : ps, r)
unfoldFunType t = ([], toStruct t)

-- | The type scheme of a value binding, comprising the type
-- parameters and the actual type.
valBindTypeScheme :: ValBindBase Info VName -> ([TypeParamBase VName], StructType)
valBindTypeScheme vb =
  ( valBindTypeParams vb,
    funType (valBindParams vb) (unInfo (valBindRetType vb))
  )

-- | The names that are brought into scope by this value binding (not
-- including its own parameter names, but including any existential
-- sizes).
valBindBound :: ValBindBase Info VName -> [VName]
valBindBound vb
  -- Existential sizes of the return type are only bound by constants
  -- (bindings without parameters).
  | null (valBindParams vb) = valBindName vb : retDims (unInfo (valBindRetType vb))
  | otherwise = [valBindName vb]

-- | The type names mentioned in a type, including those appearing in
-- type arguments, but not size arguments.
typeVars :: TypeBase dim as -> S.Set VName
typeVars (Scalar Prim {}) = mempty
typeVars (Scalar (TypeVar _ tn targs)) =
  S.insert (qualLeaf tn) $ foldMap argVars targs
  where
    argVars (TypeArgType ta) = typeVars ta
    argVars TypeArgDim {} = mempty
typeVars (Scalar (Arrow _ _ _ t1 (RetType _ t2))) = typeVars t1 <> typeVars t2
typeVars (Scalar (Record fields)) = foldMap typeVars fields
typeVars (Scalar (Sum cs)) = foldMap (foldMap typeVars) cs
typeVars (Array _ _ rt) = typeVars $ Scalar rt

-- | @orderZero t@ is 'True' if the argument type has order 0, i.e., it is not
-- a function type, does not contain a function type as a subcomponent, and may
-- not be instantiated with a function type.
orderZero :: TypeBase dim as -> Bool
orderZero t = case t of
  Scalar Arrow {} -> False
  Scalar (Record fs) -> all orderZero fs
  Scalar (Sum cs) -> all (all orderZero) cs
  Scalar Prim {} -> True
  Scalar TypeVar {} -> True
  Array {} -> True

-- | @patternOrderZero pat@ is 'True' if all of the types in the given pattern
-- have order 0.
patternOrderZero :: Pat (TypeBase d u) -> Bool
patternOrderZero p = orderZero $ patternType p

-- | The identifiers bound in a pattern, in left-to-right order.
patIdents :: PatBase f vn t -> [IdentBase f vn t]
patIdents pat = case pat of
  Id v t loc -> [Ident v t loc]
  PatParens p _ -> patIdents p
  PatAscription p _ _ -> patIdents p
  PatAttr _ p _ -> patIdents p
  TuplePat ps _ -> concatMap patIdents ps
  PatConstr _ _ ps _ -> concatMap patIdents ps
  RecordPat fs _ -> concatMap (patIdents . snd) fs
  Wildcard {} -> []
  PatLit {} -> []

-- | The names bound in a pattern.
patNames :: Pat t -> [VName]
patNames p = fst <$> patternMap p

-- | Each name bound in a pattern alongside its type.
patternMap :: Pat t -> [(VName, t)]
patternMap p = map (\(Ident v (Info t) _) -> (v, t)) (patIdents p)

-- | The type of values bound by the pattern.
patternType :: Pat (TypeBase d u) -> TypeBase d u
patternType pat = case pat of
  Wildcard (Info t) _ -> t
  Id _ (Info t) _ -> t
  PatLit _ (Info t) _ -> t
  PatConstr _ (Info t) _ _ -> t
  PatParens p _ -> patternType p
  PatAscription p _ _ -> patternType p
  PatAttr _ p _ -> patternType p
  TuplePat ps _ -> Scalar $ tupleRecord $ map patternType ps
  RecordPat fs _ ->
    Scalar $ Record $ M.fromList [(unLoc k, patternType p) | (k, p) <- fs]

-- | The type matched by the pattern, including shape declarations if present.
patternStructType :: Pat (TypeBase Size u) -> StructType
patternStructType p = toStruct $ patternType p

-- | When viewed as a function parameter, does this pattern correspond
-- to a named parameter of some type?  Only a plain identifier
-- (possibly parenthesised, attributed, or ascribed) gives a 'Named'
-- parameter.
patternParam :: Pat ParamType -> (PName, Diet, StructType)
patternParam pat = case pat of
  PatParens p _ -> patternParam p
  PatAttr _ p _ -> patternParam p
  PatAscription (Id v (Info t) _) _ _ -> named v t
  Id v (Info t) _ -> named v t
  _ -> (Unnamed, diet pt, toStruct pt)
  where
    pt = patternType pat
    named v t = (Named v, diet t, toStruct t)

-- | Names of primitive types to types.  This is only valid if no
-- shadowing is going on, but useful for tools.
namesToPrimTypes :: M.Map Name PrimType
namesToPrimTypes = M.fromList $ map withName primTypes
  where
    withName t = (nameFromString $ prettyString t, t)
    primTypes =
      Bool
        : map Signed [minBound .. maxBound]
        ++ map Unsigned [minBound .. maxBound]
        ++ map FloatType [minBound .. maxBound]

-- | The nature of something predefined.  For functions, these can
-- either be monomorphic or overloaded.
An overloaded builtin is a
-- list of valid types it can be instantiated with, to the parameter and
-- result type, with 'Nothing' representing the overloaded parameter
-- type.
data Intrinsic
  = IntrinsicMonoFun [PrimType] PrimType
  | IntrinsicOverloadedFun [PrimType] [Maybe PrimType] (Maybe PrimType)
  | IntrinsicPolyFun [TypeParamBase VName] [ParamType] (RetTypeBase Size Uniqueness)
  | IntrinsicType Liftedness [TypeParamBase VName] StructType
  | IntrinsicEquality -- Special cased.

-- | The builtin @acc@ type: a size-lifted type constructor with one
-- type parameter @t@.  Uses the fixed tags 10 and 11.
intrinsicAcc :: (VName, Intrinsic)
intrinsicAcc =
  ( acc_v,
    IntrinsicType SizeLifted [TypeParamType Unlifted t_v mempty] $
      Scalar $
        TypeVar mempty (qualName acc_v) [arg]
  )
  where
    acc_v = VName "acc" 10
    t_v = VName "t" 11
    arg = TypeArgType $ Scalar (TypeVar mempty (qualName t_v) [])

-- | If this type corresponds to the builtin "acc" type, return the
-- type of the underlying array.
isAccType :: TypeBase d u -> Maybe (TypeBase d NoUniqueness)
isAccType (Scalar (TypeVar _ (QualName [] v) [TypeArgType t]))
  | v == fst intrinsicAcc =
      Just t
isAccType _ = Nothing

-- | Find the 'VName' corresponding to a builtin.  Crashes if that
-- name cannot be found.  The lookup is a linear scan over the keys of
-- 'intrinsics'.
--
-- NOTE(review): the error message says "findBuiltin" although this
-- function is named 'intrinsicVar'.
intrinsicVar :: Name -> VName
intrinsicVar v =
  fromMaybe bad $ find ((v ==) . baseName) $ M.keys intrinsics
  where
    bad = error $ "findBuiltin: " <> nameToString v

-- | Build an application of the named intrinsic binary operator with
-- the given result type and operands.
mkBinOp :: Name -> StructType -> Exp -> Exp -> Exp
mkBinOp op t x y =
  AppExp
    ( BinOp
        (qualName (intrinsicVar op), mempty)
        (Info t)
        (x, Info Nothing)
        (y, Info Nothing)
        mempty
    )
    (Info $ AppRes t [])

-- | @i64@ addition and multiplication, used to build size expressions
-- in the intrinsic types below.
mkAdd, mkMul :: Exp -> Exp -> Exp
mkAdd = mkBinOp "+" $ Scalar $ Prim $ Signed Int64
mkMul = mkBinOp "*" $ Scalar $ Prim $ Signed Int64

-- | A map of all built-ins.  Primitive operations are numbered
-- starting at tag 20; the polymorphic builtins listed here continue
-- from 'intrinsicStart'.  The @acc@ type from 'intrinsicAcc' is added
-- with a left-biased union.
intrinsics :: M.Map VName Intrinsic
intrinsics =
  (M.fromList [intrinsicAcc] <>) $
    M.fromList $
      primOp
        ++ zipWith
          namify
          [intrinsicStart ..]
          ( [ ( "manifest",
                IntrinsicPolyFun [tp_a] [Scalar $ t_a mempty] $ RetType [] $ Scalar $ t_a mempty
              ),
              ( "flatten",
                IntrinsicPolyFun
                  [tp_a, sp_n, sp_m]
                  [Array Observe (shape [n, m]) $ t_a mempty]
                  $ RetType []
                  $ Array Nonunique (Shape [size n `mkMul` size m]) (t_a mempty)
              ),
              ( "unflatten",
                IntrinsicPolyFun
                  [tp_a, sp_n, sp_m]
                  [ Scalar $ Prim $ Signed Int64,
                    Scalar $ Prim $ Signed Int64,
                    Array Observe (Shape [size n `mkMul` size m]) $ t_a mempty
                  ]
                  $ RetType []
                  $ Array Nonunique (shape [n, m]) (t_a mempty)
              ),
              ( "concat",
                IntrinsicPolyFun
                  [tp_a, sp_n, sp_m]
                  [ array_a Observe $ shape [n],
                    array_a Observe $ shape [m]
                  ]
                  $ RetType []
                  $ array_a Unique
                  $ Shape [size n `mkAdd` size m]
              ),
              ( "transpose",
                IntrinsicPolyFun
                  [tp_a, sp_n, sp_m]
                  [array_a Observe $ shape [n, m]]
                  $ RetType []
                  $ array_a Nonunique
                  $ shape [m, n]
              ),
              ( "scatter",
                IntrinsicPolyFun
                  [tp_a, sp_n, sp_l]
                  [ Array Consume (shape [n]) $ t_a mempty,
                    Array Observe (shape [l]) (Prim $ Signed Int64),
                    Array Observe (shape [l]) $ t_a mempty
                  ]
                  $ RetType []
                  $ Array Unique (shape [n]) (t_a mempty)
              ),
              ( "scatter_2d",
                IntrinsicPolyFun
                  [tp_a, sp_n, sp_m, sp_l]
                  [ array_a Consume $ shape [n, m],
                    Array Observe (shape [l]) (tupInt64 2),
                    Array Observe (shape [l]) $ t_a mempty
                  ]
                  $ RetType []
                  $ array_a Unique
                  $ shape [n, m]
              ),
              ( "scatter_3d",
                IntrinsicPolyFun
                  [tp_a, sp_n, sp_m, sp_k, sp_l]
                  [ array_a Consume $ shape [n, m, k],
                    Array Observe (shape [l]) (tupInt64 3),
                    Array Observe (shape [l]) $ t_a mempty
                  ]
                  $ RetType []
                  $ array_a Unique
                  $ shape [n, m, k]
              ),
              ( "zip",
                IntrinsicPolyFun
                  [tp_a, tp_b, sp_n]
                  [ array_a Observe (shape [n]),
                    array_b Observe (shape [n])
                  ]
                  $ RetType []
                  $ tuple_array Unique (Scalar $ t_a mempty) (Scalar $ t_b mempty)
                  $ shape [n]
              ),
              ( "unzip",
                IntrinsicPolyFun
                  [tp_a, tp_b, sp_n]
                  [tuple_array Observe (Scalar $ t_a mempty) (Scalar $ t_b mempty) $ shape [n]]
                  $ RetType [] . Scalar . Record . M.fromList
                  $ zip tupleFieldNames [array_a Unique $ shape [n], array_b Unique $ shape [n]]
              ),
              ( "hist_1d",
                IntrinsicPolyFun
                  [tp_a, sp_n, sp_m]
                  [ Scalar $ Prim $ Signed Int64,
                    array_a Consume $ shape [m],
                    Scalar (t_a mempty) `arr` (Scalar (t_a mempty) `arr` Scalar (t_a Nonunique)),
                    Scalar $ t_a Observe,
                    Array Observe (shape [n]) (tupInt64 1),
                    array_a Observe (shape [n])
                  ]
                  $ RetType []
                  $ array_a Unique
                  $ shape [m]
              ),
              ( "hist_2d",
                IntrinsicPolyFun
                  [tp_a, sp_n, sp_m, sp_k]
                  [ Scalar $ Prim $ Signed Int64,
                    array_a Consume $ shape [m, k],
                    Scalar (t_a mempty) `arr` (Scalar (t_a mempty) `arr` Scalar (t_a Nonunique)),
                    Scalar $ t_a Observe,
                    Array Observe (shape [n]) (tupInt64 2),
                    array_a Observe (shape [n])
                  ]
                  $ RetType []
                  $ array_a Unique
                  $ shape [m, k]
              ),
              ( "hist_3d",
                IntrinsicPolyFun
                  [tp_a, sp_n, sp_m, sp_k, sp_l]
                  [ Scalar $ Prim $ Signed Int64,
                    array_a Consume $ shape [m, k, l],
                    Scalar (t_a mempty) `arr` (Scalar (t_a mempty) `arr` Scalar (t_a Nonunique)),
                    Scalar $ t_a Observe,
                    Array Observe (shape [n]) (tupInt64 3),
                    array_a Observe (shape [n])
                  ]
                  $ RetType []
                  $ array_a Unique
                  $ shape [m, k, l]
              ),
              ( "map",
                IntrinsicPolyFun
                  [tp_a, tp_b, sp_n]
                  [ Scalar (t_a mempty) `arr` Scalar (t_b Nonunique),
                    array_a Observe $ shape [n]
                  ]
                  $ RetType []
                  $ array_b Unique
                  $ shape [n]
              ),
              ( "reduce",
                IntrinsicPolyFun
                  [tp_a, sp_n]
                  [ Scalar (t_a mempty) `arr` (Scalar (t_a mempty) `arr` Scalar (t_a Nonunique)),
                    Scalar $ t_a Observe,
                    array_a Observe $ shape [n]
                  ]
                  $ RetType []
                  $ Scalar (t_a Unique)
              ),
              ( "reduce_comm",
                IntrinsicPolyFun
                  [tp_a, sp_n]
                  [ Scalar (t_a mempty) `arr` (Scalar (t_a mempty) `arr` Scalar (t_a Nonunique)),
                    Scalar $ t_a Observe,
                    array_a Observe $ shape [n]
                  ]
                  $ RetType [] (Scalar (t_a Unique))
              ),
              ( "scan",
                IntrinsicPolyFun
                  [tp_a, sp_n]
                  [ Scalar (t_a mempty) `arr` (Scalar (t_a mempty) `arr` Scalar (t_a Nonunique)),
                    Scalar $ t_a Observe,
                    array_a Observe $ shape [n]
                  ]
                  $ RetType [] (array_a Unique $ shape [n])
              ),
              ( "partition",
                IntrinsicPolyFun
                  [tp_a, sp_n]
                  [ Scalar (Prim $ Signed Int32),
                    Scalar (t_a mempty) `arr` Scalar (Prim $ Signed Int64),
                    array_a Observe $ shape [n]
                  ]
                  -- The size k of the result is existential.
                  ( RetType [k] . Scalar $
                      tupleRecord
                        [ array_a Unique $ shape [n],
                          Array Unique (shape [k]) (Prim $ Signed Int64)
                        ]
                  )
              ),
              ( "acc_write",
                IntrinsicPolyFun
                  [sp_k, tp_a]
                  [ Scalar $ accType Consume $ array_ka mempty,
                    Scalar (Prim $ Signed Int64),
                    Scalar $ t_a Observe
                  ]
                  $ RetType []
                  $ Scalar
                  $ accType Unique (array_ka mempty)
              ),
              ( "scatter_stream",
                IntrinsicPolyFun
                  [tp_a, tp_b, sp_k, sp_n]
                  [ array_ka Consume,
                    Scalar (accType mempty (array_ka mempty))
                      `carr` ( Scalar (t_b mempty)
                                 `arr` Scalar (accType Nonunique $ array_a mempty $ shape [k])
                             ),
                    array_b Observe $ shape [n]
                  ]
                  $ RetType []
                  $ array_ka Unique
              ),
              ( "hist_stream",
                IntrinsicPolyFun
                  [tp_a, tp_b, sp_k, sp_n]
                  [ array_a Consume $ shape [k],
                    Scalar (t_a mempty) `arr` (Scalar (t_a mempty) `arr` Scalar (t_a Nonunique)),
                    Scalar $ t_a Observe,
                    Scalar (accType mempty $ array_ka mempty)
                      `carr` ( Scalar (t_b mempty)
                                 `arr` Scalar (accType Nonunique $ array_a mempty $ shape [k])
                             ),
                    array_b Observe $ shape [n]
                  ]
                  $ RetType []
                  $ array_a Unique
                  $ shape [k]
              ),
              ( "jvp2",
                IntrinsicPolyFun
                  [tp_a, tp_b]
                  [ Scalar (t_a mempty) `arr` Scalar (t_b Nonunique),
                    Scalar (t_a Observe),
                    Scalar (t_a Observe)
                  ]
                  $ RetType []
                  $ Scalar
                  $ tupleRecord [Scalar $ t_b Nonunique, Scalar $ t_b Nonunique]
              ),
              ( "vjp2",
                IntrinsicPolyFun
                  [tp_a, tp_b]
                  [ Scalar (t_a mempty) `arr` Scalar (t_b Nonunique),
                    Scalar (t_a Observe),
                    Scalar (t_b Observe)
                  ]
                  $ RetType []
                  $ Scalar
                  $ tupleRecord [Scalar $ t_b Nonunique, Scalar $ t_a Nonunique]
              )
            ]
              ++
              -- Experimental LMAD ones.
              [ ( "flat_index_2d",
                  IntrinsicPolyFun
                    [tp_a, sp_n]
                    [ array_a Observe $ shape [n],
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64)
                    ]
                    $ RetType [m, k]
                    $ array_a Nonunique
                    $ shape [m, k]
                ),
                ( "flat_update_2d",
                  IntrinsicPolyFun
                    [tp_a, sp_n, sp_k, sp_l]
                    [ array_a Consume $ shape [n],
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      array_a Observe $ shape [k, l]
                    ]
                    $ RetType []
                    $ array_a Unique
                    $ shape [n]
                ),
                ( "flat_index_3d",
                  IntrinsicPolyFun
                    [tp_a, sp_n]
                    [ array_a Observe $ shape [n],
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64)
                    ]
                    $ RetType [m, k, l]
                    $ array_a Nonunique
                    $ shape [m, k, l]
                ),
                ( "flat_update_3d",
                  IntrinsicPolyFun
                    [tp_a, sp_n, sp_k, sp_l, sp_p]
                    [ array_a Consume $ shape [n],
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      array_a Observe $ shape [k, l, p]
                    ]
                    $ RetType []
                    $ array_a Unique
                    $ shape [n]
                ),
                ( "flat_index_4d",
                  IntrinsicPolyFun
                    [tp_a, sp_n]
                    [ array_a Observe $ shape [n],
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64)
                    ]
                    $ RetType [m, k, l, p]
                    $ array_a Nonunique
                    $ shape [m, k, l, p]
                ),
                ( "flat_update_4d",
                  IntrinsicPolyFun
                    [tp_a, sp_n, sp_k, sp_l, sp_p, sp_q]
                    [ array_a Consume $ shape [n],
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      Scalar (Prim $ Signed Int64),
                      array_a Observe $ shape [k, l, p, q]
                    ]
                    $ RetType []
                    $ array_a Unique
                    $ shape [n]
                )
              ]
          )
  where
    -- Primitive functions, operators, conversions, and primitive types,
    -- numbered consecutively from tag 20.
    primOp =
      zipWith namify [20 ..] $
        map primFun (M.toList Primitive.primFuns)
          ++ map unOpFun Primitive.allUnOps
          ++ map binOpFun Primitive.allBinOps
          ++ map cmpOpFun Primitive.allCmpOps
          ++ map convOpFun Primitive.allConvOps
          ++ map signFun Primitive.allIntTypes
          ++ map unsignFun Primitive.allIntTypes
          ++ map
            intrinsicPrim
            ( map Signed [minBound .. maxBound]
                ++ map Unsigned [minBound .. maxBound]
                ++ map FloatType [minBound .. maxBound]
                ++ [Bool]
            )
          ++
          -- This overrides the ! from Primitive.
          [ ( "!",
              IntrinsicOverloadedFun
                ( map Signed [minBound .. maxBound]
                    ++ map Unsigned [minBound .. maxBound]
                    ++ [Bool]
                )
                [Nothing]
                Nothing
            ),
            ( "neg",
              IntrinsicOverloadedFun
                ( map Signed [minBound .. maxBound]
                    ++ map Unsigned [minBound .. maxBound]
                    ++ map FloatType [minBound .. maxBound]
                    ++ [Bool]
                )
                [Nothing]
                Nothing
            )
          ]
          ++
          -- The reason for the loop formulation is to ensure that we
          -- get a missing case warning if we forget a case.
          -- NOTE(review): 'intrinsicBinOp' ends in a catch-all clause, so
          -- GHC will not actually warn about a forgotten constructor.
          mapMaybe mkIntrinsicBinOp [minBound .. maxBound]

    -- The polymorphic builtins are numbered right after the last primOp.
    intrinsicStart = 1 + baseTag (fst $ last primOp)

    -- Shared type/size variable names, with tags 0..7.
    [a, b, n, m, k, l, p, q] = zipWith VName (map nameFromText ["a", "b", "n", "m", "k", "l", "p", "q"]) [0 ..]

    t_a u = TypeVar u (qualName a) []
    array_a u s = Array u s $ t_a mempty
    tp_a = TypeParamType Unlifted a mempty

    t_b u = TypeVar u (qualName b) []
    array_b u s = Array u s $ t_b mempty
    tp_b = TypeParamType Unlifted b mempty

    [sp_n, sp_m, sp_k, sp_l, sp_p, sp_q] = map (`TypeParamDim` mempty) [n, m, k, l, p, q]

    size = flip sizeFromName mempty . qualName
    shape = Shape . map size

    tuple_array u x y s =
      Array u s (Record (M.fromList $ zip tupleFieldNames [x, y]))

    -- Unnamed function arrows whose parameter is observed (arr) or
    -- consumed (carr).
    arr x y = Scalar $ Arrow mempty Unnamed Observe x (RetType [] y)
    carr x y = Scalar $ Arrow mempty Unnamed Consume x (RetType [] y)

    array_ka u = Array u (Shape [sizeFromName (qualName k) mempty]) $ t_a mempty

    accType u t =
      TypeVar u (qualName (fst intrinsicAcc)) [TypeArgType t]

    -- Pair a textual name with a tag to form a 'VName'.
    namify i (x, y) = (VName (nameFromText x) i, y)

    primFun (name, (ts, t, _)) =
      (name, IntrinsicMonoFun (map unPrim ts) $ unPrim t)

    unOpFun bop = (prettyText bop, IntrinsicMonoFun [t] t)
      where
        t = unPrim $ Primitive.unOpType bop

    binOpFun bop = (prettyText bop, IntrinsicMonoFun [t, t] t)
      where
        t = unPrim $ Primitive.binOpType bop

    cmpOpFun bop = (prettyText bop, IntrinsicMonoFun [t, t] Bool)
      where
        t = unPrim $ Primitive.cmpOpType bop

    convOpFun cop = (prettyText cop, IntrinsicMonoFun [unPrim ft] $ unPrim tt)
      where
        (ft, tt) = Primitive.convOpType cop

    signFun t = ("sign_" <> prettyText t, IntrinsicMonoFun [Unsigned t] $ Signed t)

    unsignFun t = ("unsign_" <> prettyText t, IntrinsicMonoFun [Signed t] $ Unsigned t)

    -- Map core primitive types to source-language ones.  Integer types
    -- become signed, and the core unit type is mapped to 'Bool'.
    unPrim (Primitive.IntType t) = Signed t
    unPrim (Primitive.FloatType t) = FloatType t
    unPrim Primitive.Bool = Bool
    unPrim Primitive.Unit = Bool

    intrinsicPrim t = (prettyText t, IntrinsicType Unlifted [] $ Scalar $ Prim t)

    anyIntType =
      map Signed [minBound .. maxBound]
        ++ map Unsigned [minBound .. maxBound]
    anyNumberType =
      anyIntType
        ++ map FloatType [minBound .. maxBound]
    anyPrimType = Bool : anyNumberType

    mkIntrinsicBinOp :: BinOp -> Maybe (T.Text, Intrinsic)
    mkIntrinsicBinOp op = do
      op' <- intrinsicBinOp op
      pure (prettyText op, op')

    binOp ts = Just $ IntrinsicOverloadedFun ts [Nothing, Nothing] Nothing
    ordering = Just $ IntrinsicOverloadedFun anyPrimType [Nothing, Nothing] (Just Bool)

    intrinsicBinOp Plus = binOp anyNumberType
    intrinsicBinOp Minus = binOp anyNumberType
    intrinsicBinOp Pow = binOp anyNumberType
    intrinsicBinOp Times = binOp anyNumberType
    intrinsicBinOp Divide = binOp anyNumberType
    intrinsicBinOp Mod = binOp anyNumberType
    intrinsicBinOp Quot = binOp anyIntType
    intrinsicBinOp Rem = binOp anyIntType
    intrinsicBinOp ShiftR = binOp anyIntType
    intrinsicBinOp ShiftL = binOp anyIntType
    intrinsicBinOp Band = binOp anyIntType
    intrinsicBinOp Xor = binOp anyIntType
    intrinsicBinOp Bor = binOp anyIntType
    intrinsicBinOp LogAnd = binOp [Bool]
    intrinsicBinOp LogOr = binOp [Bool]
    intrinsicBinOp Equal = Just IntrinsicEquality
    intrinsicBinOp NotEqual = Just IntrinsicEquality
    intrinsicBinOp Less = ordering
    intrinsicBinOp Leq = ordering
    intrinsicBinOp Greater = ordering
    intrinsicBinOp Geq = ordering
    intrinsicBinOp _ = Nothing

    -- A tuple of x i64 values; a 1-tuple is just i64.
    tupInt64 1 = Prim $ Signed Int64
    tupInt64 x = tupleRecord $ replicate x $ Scalar $ Prim $ Signed Int64

-- | Is this include part of the built-in prelude?  True when the
-- file's directory is exactly @/prelude@.
isBuiltin :: FilePath -> Bool
isBuiltin = (== "/prelude") . takeDirectory

-- | Is the position of this thing builtin as per 'isBuiltin'?  Things
-- without location are considered not built-in.
isBuiltinLoc :: (Located a) => a -> Bool
isBuiltinLoc x =
  case locOf x of
    NoLoc -> False
    Loc pos _ -> isBuiltin $ posFile pos

-- | The largest tag used by an intrinsic - this can be used to
-- determine whether a 'VName' refers to an intrinsic or a user-defined name.
maxIntrinsicTag :: Int
maxIntrinsicTag = maxinum $ map baseTag $ M.keys intrinsics

-- | Create a name with no qualifiers from a name.
qualName :: v -> QualName v
qualName = QualName []

-- | Add another qualifier (at the head) to a qualified name.
qualify :: v -> QualName v -> QualName v
qualify k (QualName ks v) = QualName (k : ks) v

-- | The modules imported by a Futhark program, each paired with the
-- location of the import.
progImports :: ProgBase f vn -> [(String, Loc)]
progImports = concatMap decImports . progDecs

-- | The modules imported by a single declaration.
decImports :: DecBase f vn -> [(String, Loc)]
decImports (OpenDec x _) = modExpImports x
decImports (ModDec md) = modExpImports $ modExp md
decImports ModTypeDec {} = []
decImports TypeDec {} = []
decImports ValDec {} = []
decImports (LocalDec d _) = decImports d
decImports (ImportDec x _ loc) = [(x, locOf loc)]

-- Imports syntactically present in a module expression.
-- NOTE(review): for 'ModApply' only the second module-expression field
-- is searched, and 'ModLambda' bodies are not searched at all -
-- confirm that this is intended.
modExpImports :: ModExpBase f vn -> [(String, Loc)]
modExpImports ModVar {} = []
modExpImports (ModParens p _) = modExpImports p
modExpImports (ModImport f _ loc) = [(f, locOf loc)]
modExpImports (ModDecs ds _) = concatMap decImports ds
modExpImports (ModApply _ me _ _ _) = modExpImports me
modExpImports (ModAscript me _ _ _) = modExpImports me
modExpImports ModLambda {} = []

-- | The set of module types used in any exported (non-local)
-- declaration.
progModuleTypes :: ProgBase Info VName -> S.Set VName
progModuleTypes prog = foldMap reach mtypes_used
  where
    -- Fixed point iteration.
    -- 'reach v' is 'v' plus everything transitively reachable from it
    -- through 'reachable_from_mtype'.  NOTE(review): this only
    -- terminates if that relation is acyclic.
    reach v = S.singleton v <> maybe mempty (foldMap reach) (M.lookup v reachable_from_mtype)

    -- For every module type declared anywhere (local declarations
    -- included), the names its definition refers to.
    reachable_from_mtype = foldMap onDec $ progDecs prog
      where
        onDec OpenDec {} = mempty
        onDec ModDec {} = mempty
        onDec (ModTypeDec sb) =
          M.singleton (modTypeName sb) (onModTypeExp (modTypeExp sb))
        onDec TypeDec {} = mempty
        onDec ValDec {} = mempty
        onDec (LocalDec d _) = onDec d
        onDec ImportDec {} = mempty

        onModTypeExp (ModTypeVar v _ _) = S.singleton $ qualLeaf v
        onModTypeExp (ModTypeParens e _) = onModTypeExp e
        onModTypeExp (ModTypeSpecs ss _) = foldMap onSpec ss
        onModTypeExp (ModTypeWith e _ _) = onModTypeExp e
        onModTypeExp (ModTypeArrow _ e1 e2 _) = onModTypeExp e1 <> onModTypeExp e2

        onSpec ValSpec {} = mempty
        onSpec TypeSpec {} = mempty
        onSpec TypeAbbrSpec {} = mempty
        onSpec (ModSpec vn e _ _) = S.singleton vn <> onModTypeExp e
        onSpec (IncludeSpec e _) = onModTypeExp e

    -- Module types named directly by non-local declarations: module
    -- ascriptions, functor parameter/result types and 'open'ed module
    -- expressions.  Local declarations contribute nothing, and inline
    -- 'ModTypeSpecs' are not descended into here.
    mtypes_used = foldMap onDec $ progDecs prog
      where
        onDec (OpenDec x _) = onModExp x
        onDec (ModDec md) =
          maybe mempty (onModTypeExp . fst) (modType md) <> onModExp (modExp md)
        onDec ModTypeDec {} = mempty
        onDec TypeDec {} = mempty
        onDec ValDec {} = mempty
        onDec LocalDec {} = mempty
        onDec ImportDec {} = mempty

        onModExp ModVar {} = mempty
        onModExp (ModParens p _) = onModExp p
        onModExp ModImport {} = mempty
        onModExp (ModDecs ds _) = mconcat $ map onDec ds
        onModExp (ModApply me1 me2 _ _ _) = onModExp me1 <> onModExp me2
        onModExp (ModAscript me se _ _) = onModExp me <> onModTypeExp se
        onModExp (ModLambda p r me _) =
          onModParam p <> maybe mempty (onModTypeExp . fst) r <> onModExp me

        onModParam = onModTypeExp . modParamType

        onModTypeExp (ModTypeVar v _ _) = S.singleton $ qualLeaf v
        onModTypeExp (ModTypeParens e _) = onModTypeExp e
        onModTypeExp ModTypeSpecs {} = mempty
        onModTypeExp (ModTypeWith e _ _) = onModTypeExp e
        onModTypeExp (ModTypeArrow _ e1 e2 _) = onModTypeExp e1 <> onModTypeExp e2

-- | Extract a leading @((name, namespace, file), remainder)@ from a
-- documentation comment string.  These are formatted as
-- \`name\`\@namespace[\@file].  Let us hope that this pattern does not occur
-- anywhere else.
--
-- The namespace is the maximal run of alphabetic characters after the
-- \@ and must be non-empty.  The optional file part must be written as
-- \@ followed by a double-quoted string.  If the file part is malformed
-- (e.g. the closing quote is missing), the reference is still accepted
-- with 'Nothing' as the file, and the remainder starts at that \@.
identifierReference :: String -> Maybe ((String, String, Maybe FilePath), String)
identifierReference ('`' : s)
  | (identifier, '`' : '@' : s') <- break (== '`') s,
    (namespace, s'') <- span isAlpha s',
    not $ null namespace =
      case s'' of
        '@' : '"' : s'''
          | (file, '"' : s'''') <- span (/= '"') s''' ->
              Just ((identifier, namespace, Just file), s'''')
        -- Reached both when there is no file part and when the guard
        -- above fails.
        _ -> Just ((identifier, namespace, Nothing), s'')
identifierReference _ = Nothing

-- | Given an operator name, return the operator that determines its
-- syntactical properties: the built-in operator with the longest
-- printed form that is a prefix of the name, or 'Backtick' if there is
-- none.
leadingOperator :: Name -> BinOp
leadingOperator s =
  maybe Backtick snd $
    find ((`isPrefixOf` s') . fst) $
      -- Longest printed names first, so the longest prefix wins.
      sortOn (Down . length . fst) $
        zip (map prettyString operators) operators
  where
    s' = nameToString s
    operators :: [BinOp]
    operators = [minBound .. maxBound :: BinOp]

-- | Find instances of typed holes in the program.  Each hole is
-- reported with its location and its type.  Holes are searched for in
-- value bindings, including those inside (local, opened, applied,
-- ascribed or parametric) module expressions.
progHoles :: ProgBase Info VName -> [(Loc, StructType)]
progHoles = foldMap holesInDec . progDecs
  where
    holesInDec (ValDec vb) = holesInExp $ valBindBody vb
    holesInDec (ModDec me) = holesInModExp $ modExp me
    holesInDec (OpenDec me _) = holesInModExp me
    holesInDec (LocalDec d _) = holesInDec d
    holesInDec TypeDec {} = mempty
    holesInDec ModTypeDec {} = mempty
    holesInDec ImportDec {} = mempty

    holesInModExp (ModDecs ds _) = foldMap holesInDec ds
    holesInModExp (ModParens me _) = holesInModExp me
    holesInModExp (ModApply x y _ _ _) = holesInModExp x <> holesInModExp y
    holesInModExp (ModAscript me _ _ _) = holesInModExp me
    holesInModExp (ModLambda _ _ me _) = holesInModExp me
    holesInModExp ModVar {} = mempty
    holesInModExp ModImport {} = mempty

    -- Run the traversal, accumulating holes in a state list.
    holesInExp = flip execState mempty . onExp

    -- Record a hole (prepending it to the state); otherwise recurse
    -- into the immediate subexpressions.
    onExp e@(Hole (Info t) loc) = do
      modify ((locOf loc, toStruct t) :)
      pure e
    onExp e = astMap (identityMapper {mapOnExp = onExp}) e

-- | Strip semantically irrelevant stuff from the top level of
-- expression.  This is used to provide a slightly fuzzy notion of
-- expression equality.
--
-- Parentheses, assertions, attributes and type ascriptions are peeled
-- off repeatedly, and the innermost remaining expression is returned.
-- The result is 'Nothing' if the expression does not start with one of
-- these wrappers.
--
-- Ideally we'd implement unification on a simpler representation in
-- which these wrappers could not occur at all.
stripExp :: Exp -> Maybe Exp
stripExp (Parens e _) = stripExp e `mplus` Just e
stripExp (Assert _ e _ _) = stripExp e `mplus` Just e
stripExp (Attr _ e _) = stripExp e `mplus` Just e
stripExp (Ascript e _ _) = stripExp e `mplus` Just e
stripExp _ = Nothing

-- | All non-trivial subexpressions (as by stripExp) of some
-- expression, not including the expression itself.
subExps :: Exp -> [Exp]
subExps e
  | Just e' <- stripExp e = subExps e'
  | otherwise = astMap mapper e `execState` mempty
  where
    -- Skip wrappers; record every other subexpression and descend
    -- into it.
    mapOnExp e'
      | Just e'' <- stripExp e' = mapOnExp e''
      | otherwise = do
          modify (e' :)
          astMap mapper e'
    mapper = identityMapper {mapOnExp}

-- Pair up the corresponding index expressions of two slices.  This is
-- 'Nothing' if the slices have different lengths, if an index is fixed
-- in one slice and ranged in the other, or if a range component is
-- present in one slice but absent in the other.
similarSlices :: Slice -> Slice -> Maybe [(Exp, Exp)]
similarSlices slice1 slice2
  | length slice1 == length slice2 = do
      concat <$> zipWithM match slice1 slice2
  | otherwise = Nothing
  where
    match (DimFix e1) (DimFix e2) = Just [(e1, e2)]
    match (DimSlice a1 b1 c1) (DimSlice a2 b2 c2) =
      concat <$> sequence [pair (a1, a2), pair (b1, b2), pair (c1, c2)]
    match _ _ = Nothing
    pair (Nothing, Nothing) = Just []
    pair (Just x, Just y) = Just [(x, y)]
    pair _ = Nothing

-- | If these two expressions are structurally similar at top level as
-- sizes, produce their subexpressions (which are not necessarily
-- similar, but you can check for that!). This is the machinery
-- underlying expression unification. We assume that the expressions
-- have the same type.
similarExps :: Exp -> Exp -> Maybe [(Exp, Exp)]
-- Syntactically identical (ignoring annotations): nothing left to compare.
similarExps e1 e2 | bareExp e1 == bareExp e2 = Just []
-- Look through parentheses, assertions, attributes and ascriptions.
similarExps e1 e2 | Just e1' <- stripExp e1 = similarExps e1' e2
similarExps e1 e2 | Just e2' <- stripExp e2 = similarExps e1 e2'
-- An integer literal matches an already-typed signed literal with the
-- same value.
similarExps (IntLit x _ _) (Literal v _) =
  case v of
    SignedValue (Int8Value y) | x == toInteger y -> Just []
    SignedValue (Int16Value y) | x == toInteger y -> Just []
    SignedValue (Int32Value y) | x == toInteger y -> Just []
    SignedValue (Int64Value y) | x == toInteger y -> Just []
    _ -> Nothing
-- The same comparison with the arguments swapped.  Without this,
-- 'similarExps' (and therefore 'sameExp') would not be symmetric.
-- Swapping is safe because the literal case never yields sub-pairs.
similarExps e1@Literal {} e2@IntLit {} = similarExps e2 e1
similarExps
  (AppExp (BinOp (op1, _) _ (x1, _) (y1, _) _) _)
  (AppExp (BinOp (op2, _) _ (x2, _) (y2, _) _) _)
    | op1 == op2 = Just [(x1, x2), (y1, y2)]
similarExps (AppExp (Apply f1 args1 _) _) (AppExp (Apply f2 args2 _) _)
  | f1 == f2 = Just $ zip (map snd $ NE.toList args1) (map snd $ NE.toList args2)
similarExps (AppExp (Index arr1 slice1 _) _) (AppExp (Index arr2 slice2 _) _)
  | arr1 == arr2,
    length slice1 == length slice2 =
      similarSlices slice1 slice2
similarExps (TupLit es1 _) (TupLit es2 _)
  | length es1 == length es2 = Just $ zip es1 es2
similarExps (RecordLit fs1 _) (RecordLit fs2 _)
  | length fs1 == length fs2 = zipWithM onFields fs1 fs2
  where
    -- Fields must appear in the same order with the same names.
    onFields (RecordFieldExplicit (L _ n1) fe1 _) (RecordFieldExplicit (L _ n2) fe2 _)
      | n1 == n2 = Just (fe1, fe2)
    onFields (RecordFieldImplicit (L _ vn1) ty1 _) (RecordFieldImplicit (L _ vn2) ty2 _) =
      Just (Var (qualName vn1) ty1 mempty, Var (qualName vn2) ty2 mempty)
    onFields _ _ = Nothing
similarExps (ArrayLit es1 _ _) (ArrayLit es2 _ _)
  | length es1 == length es2 = Just $ zip es1 es2
similarExps (Project field1 e1 _ _) (Project field2 e2 _ _)
  | field1 == field2 = Just [(e1, e2)]
similarExps (Negate e1 _) (Negate e2 _) = Just [(e1, e2)]
similarExps (Not e1 _) (Not e2 _) = Just [(e1, e2)]
similarExps (Constr n1 es1 _ _) (Constr n2 es2 _ _)
  | length es1 == length es2, n1 == n2 = Just $ zip es1 es2
similarExps (Update e1 slice1 e'1 _) (Update e2 slice2 e'2 _) =
  ([(e1, e2), (e'1, e'2)] ++) <$> similarSlices slice1 slice2
similarExps (RecordUpdate e1 names1 e'1 _ _) (RecordUpdate e2 names2 e'2 _ _)
  | names1 == names2 = Just [(e1, e2), (e'1, e'2)]
similarExps (OpSection op1 _ _) (OpSection op2 _ _)
  | op1 == op2 = Just []
similarExps (OpSectionLeft op1 _ x1 _ _ _) (OpSectionLeft op2 _ x2 _ _ _)
  | op1 == op2 = Just [(x1, x2)]
similarExps (OpSectionRight op1 _ x1 _ _ _) (OpSectionRight op2 _ x2 _ _ _)
  | op1 == op2 = Just [(x1, x2)]
similarExps (ProjectSection names1 _ _) (ProjectSection names2 _ _)
  | names1 == names2 = Just []
similarExps (IndexSection slice1 _ _) (IndexSection slice2 _ _) =
  similarSlices slice1 slice2
similarExps _ _ = Nothing

-- | Are these the same expression as per recursively invoking
-- 'similarExps'?
sameExp :: Exp -> Exp -> Bool
sameExp e1 e2
  | Just es <- similarExps e1 e2 =
      all (uncurry sameExp) es
  | otherwise = False

-- | An identifier with type- and aliasing information.
type Ident = IdentBase Info VName

-- | An index with type information.
type DimIndex = DimIndexBase Info VName

-- | A slice with type information.
type Slice = SliceBase Info VName

-- | An expression with type information.
type Exp = ExpBase Info VName

-- | An application expression with type information.
type AppExp = AppExpBase Info VName

-- | A pattern with type information.
type Pat = PatBase Info VName

-- | A value binding (constant or function) with type information.
type ValBind = ValBindBase Info VName

-- | A type binding with type information.
type TypeBind = TypeBindBase Info VName

-- | A type-checked module binding.
type ModBind = ModBindBase Info VName

-- | A type-checked module type binding.
type ModTypeBind = ModTypeBindBase Info VName

-- | A type-checked module expression.
type ModExp = ModExpBase Info VName

-- | A type-checked module parameter.
type ModParam = ModParamBase Info VName

-- | A type-checked module type expression.
type ModTypeExp = ModTypeExpBase Info VName

-- | A type-checked declaration.
type Dec = DecBase Info VName

-- | A type-checked specification.
type Spec = SpecBase Info VName

-- | A Futhark program with type information.
type Prog = ProgBase Info VName

-- | A known type arg with shape annotations.
type StructTypeArg = TypeArg Size

-- | A type-checked type parameter.
type TypeParam = TypeParamBase VName

-- | A known scalar type with no shape annotations.
type ScalarType = ScalarTypeBase ()

-- | A type-checked case (of a match expression).
type Case = CaseBase Info VName

-- | A type with no aliasing information but shape annotations.
type UncheckedType = TypeBase (Shape Name) ()

-- | An unchecked type expression.
type UncheckedTypeExp = TypeExp UncheckedExp Name

-- | An identifier with no type annotations.
type UncheckedIdent = IdentBase NoInfo Name

-- | An index with no type annotations.
type UncheckedDimIndex = DimIndexBase NoInfo Name

-- | A slice with no type annotations.
type UncheckedSlice = SliceBase NoInfo Name

-- | An expression with no type annotations.
type UncheckedExp = ExpBase NoInfo Name

-- | A module expression with no type annotations.
type UncheckedModExp = ModExpBase NoInfo Name

-- | A module type expression with no type annotations.
type UncheckedModTypeExp = ModTypeExpBase NoInfo Name

-- | A type parameter with no type annotations.
type UncheckedTypeParam = TypeParamBase Name

-- | A pattern with no type annotations.
type UncheckedPat = PatBase NoInfo Name

-- | A function declaration with no type annotations.
type UncheckedValBind = ValBindBase NoInfo Name

-- | A type binding with no type annotations.
type UncheckedTypeBind = TypeBindBase NoInfo Name

-- | A module type binding with no type annotations.
type UncheckedModTypeBind = ModTypeBindBase NoInfo Name

-- | A module binding with no type annotations.
type UncheckedModBind = ModBindBase NoInfo Name

-- | A declaration with no type annotations.
type UncheckedDec = DecBase NoInfo Name

-- | A spec with no type annotations.
type UncheckedSpec = SpecBase NoInfo Name

-- | A Futhark program with no type annotations.
type UncheckedProg = ProgBase NoInfo Name

-- | A case (of a match expression) with no type annotations.
type UncheckedCase = CaseBase NoInfo Name
futhark-0.25.27/src/Language/Futhark/Query.hs000066400000000000000000000312251475065116200207240ustar00rootroot00000000000000-- | Facilities for answering queries about a program, such as "what
-- appears at this source location", or "where is this name bound".
-- The intent is that this is used as a building block for IDE-like
-- functionality.
module Language.Futhark.Query
  ( BoundTo (..),
    boundLoc,
    AtPos (..),
    atPos,
    Pos (..),
  )
where

import Control.Monad
import Control.Monad.State
import Data.List (find)
import Data.Map qualified as M
import Futhark.Util.Loc (Loc (..), Pos (..))
import Language.Futhark
import Language.Futhark.Semantic
import Language.Futhark.Traversals
import System.FilePath.Posix qualified as Posix

-- | What a name is bound to.
data BoundTo
  = BoundTerm StructType Loc
  | BoundModule Loc
  | BoundModuleType Loc
  | BoundType Loc
  deriving (Eq, Show)

-- A definition is either a direct binding, or an indirection to
-- another name whose definition should be used instead (see
-- 'allBindings').
data Def = DefBound BoundTo | DefIndirect VName
  deriving (Eq, Show)

type Defs = M.Map VName Def

-- | Where was a bound variable actually bound?  That is, what is the
-- location of its definition?
boundLoc :: BoundTo -> Loc
boundLoc (BoundTerm _ loc) = loc
boundLoc (BoundModule loc) = loc
boundLoc (BoundModuleType loc) = loc
boundLoc (BoundType loc) = loc

-- A size binder introduces an @i64@ term.
sizeDefs :: SizeBinder VName -> Defs
sizeDefs (SizeBinder v loc) =
  M.singleton v $ DefBound $ BoundTerm (Scalar (Prim (Signed Int64))) (locOf loc)

-- Every variable bound anywhere inside a pattern.
patternDefs :: Pat (TypeBase Size u) -> Defs
patternDefs (Id vn (Info t) loc) =
  M.singleton vn $ DefBound $ BoundTerm (toStruct t) (locOf loc)
patternDefs (TuplePat pats _) =
  mconcat $ map patternDefs pats
patternDefs (RecordPat fields _) =
  mconcat $ map (patternDefs . snd) fields
patternDefs (PatParens pat _) = patternDefs pat
patternDefs (PatAttr _ pat _) = patternDefs pat
patternDefs Wildcard {} = mempty
patternDefs PatLit {} = mempty
patternDefs (PatAscription pat _ _) =
  patternDefs pat
patternDefs (PatConstr _ _ pats _) =
  mconcat $ map patternDefs pats

-- Size parameters are @i64@ terms; type parameters are types.
typeParamDefs :: TypeParamBase VName -> Defs
typeParamDefs (TypeParamDim vn loc) =
  M.singleton vn $ DefBound $ BoundTerm (Scalar $ Prim $ Signed Int64) (locOf loc)
typeParamDefs (TypeParamType _ vn loc) =
  M.singleton vn $ DefBound $ BoundType $ locOf loc

-- Everything bound by an expression or any of its subexpressions.
expDefs :: Exp -> Defs
expDefs e =
  execState (astMap mapper e) extra
  where
    mapper = identityMapper {mapOnExp = onExp}
    -- Each immediate subexpression contributes all of its own
    -- (recursively computed) definitions.
    onExp e' = do
      modify (<> expDefs e')
      pure e'

    identDefs (Ident v (Info vt) vloc) =
      M.singleton v $ DefBound $ BoundTerm (toStruct vt) $ locOf vloc

    -- Names bound directly by this node.
    extra =
      case e of
        AppExp (LetPat sizes pat _ _ _) _ ->
          foldMap sizeDefs sizes <> patternDefs pat
        Lambda params _ _ _ _ ->
          mconcat (map patternDefs params)
        AppExp (LetFun name (tparams, params, _, Info ret, _) _ loc) _ ->
          let name_t = funType params ret
           in M.singleton name (DefBound $ BoundTerm name_t (locOf loc))
                <> mconcat (map typeParamDefs tparams)
                <> mconcat (map patternDefs params)
        AppExp (LetWith v _ _ _ _ _) _ ->
          identDefs v
        AppExp (Loop _ merge _ form _ _) _ ->
          patternDefs merge
            <> case form of
              For i _ -> identDefs i
              ForIn pat _ -> patternDefs pat
              While {} -> mempty
        _ ->
          mempty

-- The bound name itself (with its function type), plus its type
-- parameters, parameters and everything bound in its body.
valBindDefs :: ValBind -> Defs
valBindDefs vbind =
  M.insert (valBindName vbind) (DefBound $ BoundTerm vbind_t (locOf vbind)) $
    mconcat (map typeParamDefs (valBindTypeParams vbind))
      <> mconcat (map patternDefs (valBindParams vbind))
      <> expDefs (valBindBody vbind)
  where
    vbind_t =
      funType (valBindParams vbind) $ unInfo $ valBindRetType vbind

typeBindDefs :: TypeBind -> Defs
typeBindDefs tbind =
  M.singleton (typeAlias tbind) $ DefBound $ BoundType $ locOf tbind

-- A module parameter binds a module; its module type may bind more.
modParamDefs :: ModParam -> Defs
modParamDefs (ModParam p se _ loc) =
  M.singleton p (DefBound $ BoundModule $ locOf loc)
    <> modTypeExpDefs se

-- Definitions inside a module expression.  Application and ascription
-- contribute their substitutions as indirections.
modExpDefs :: ModExp -> Defs
modExpDefs ModVar {} =
  mempty
modExpDefs (ModParens me _) =
  modExpDefs me
modExpDefs ModImport {} =
  mempty
modExpDefs (ModDecs decs _) =
  mconcat $ map decDefs decs
modExpDefs (ModApply e1 e2 _ (Info substs) _) =
  modExpDefs e1 <> modExpDefs e2 <> M.map DefIndirect substs
modExpDefs (ModAscript e _ (Info substs) _) =
  modExpDefs e <> M.map DefIndirect substs
modExpDefs (ModLambda p _ e _) =
  modParamDefs p <> modExpDefs e

modBindDefs :: ModBind -> Defs
modBindDefs mbind =
  M.singleton (modName mbind) (DefBound $ BoundModule $ locOf mbind)
    <> mconcat (map modParamDefs (modParams mbind))
    <> modExpDefs (modExp mbind)
    <> case modType mbind of
      Nothing -> mempty
      Just (_, Info substs) ->
        M.map DefIndirect substs

specDefs :: Spec -> Defs
specDefs spec =
  case spec of
    ValSpec v tparams _ (Info t) _ loc ->
      let vdef = DefBound $ BoundTerm t (locOf loc)
       in M.insert v vdef $ mconcat (map typeParamDefs tparams)
    TypeAbbrSpec tbind -> typeBindDefs tbind
    TypeSpec _ v _ _ loc ->
      M.singleton v $ DefBound $ BoundType $ locOf loc
    -- A module specification (@module v: se@) binds a module, not a
    -- module type.
    ModSpec v se _ loc ->
      M.singleton v (DefBound $ BoundModule $ locOf loc)
        <> modTypeExpDefs se
    IncludeSpec se _ -> modTypeExpDefs se

modTypeExpDefs :: ModTypeExp -> Defs
modTypeExpDefs se =
  case se of
    ModTypeVar _ (Info substs) _ -> M.map DefIndirect substs
    ModTypeParens e _ -> modTypeExpDefs e
    ModTypeSpecs specs _ -> mconcat $ map specDefs specs
    ModTypeWith e _ _ -> modTypeExpDefs e
    ModTypeArrow _ e1 e2 _ -> modTypeExpDefs e1 <> modTypeExpDefs e2

sigBindDefs :: ModTypeBind -> Defs
sigBindDefs sbind =
  M.singleton (modTypeName sbind) (DefBound $ BoundModuleType $ locOf sbind)
    <> modTypeExpDefs (modTypeExp sbind)

decDefs :: Dec -> Defs
decDefs (ValDec vbind) = valBindDefs vbind
decDefs (TypeDec vbind) = typeBindDefs vbind
decDefs (ModDec mbind) = modBindDefs mbind
decDefs (ModTypeDec mbind) = sigBindDefs mbind
decDefs (OpenDec me _) = modExpDefs me
decDefs (LocalDec dec _) = decDefs dec
decDefs ImportDec {} = mempty

-- | All bindings of everything in the program.
progDefs :: Prog -> Defs
progDefs = mconcat . map decDefs . progDecs

-- All bindings across all imported files, with indirections followed
-- to their final binding.  Indirections whose target is unknown are
-- dropped.  NOTE(review): a cycle of 'DefIndirect' entries would make
-- 'forward' loop; presumably substitutions are acyclic.
allBindings :: Imports -> M.Map VName BoundTo
allBindings imports = M.mapMaybe forward defs
  where
    defs = mconcat $ map (progDefs . fileProg . snd) imports
    forward (DefBound x) = Just x
    forward (DefIndirect v) = forward =<< M.lookup v defs

data RawAtPos = RawAtName (QualName VName) Loc

-- Is the position within the (inclusive) source span of the thing?
-- Things without a location contain nothing.
contains :: (Located a) => a -> Pos -> Bool
contains a pos =
  case locOf a of
    Loc start end -> pos >= start && pos <= end
    NoLoc -> False

atPosInTypeExp :: TypeExp Exp VName -> Pos -> Maybe RawAtPos
atPosInTypeExp te pos =
  case te of
    TEVar qn loc -> do
      guard $ loc `contains` pos
      Just $ RawAtName qn $ locOf loc
    TEParens te' _ ->
      atPosInTypeExp te' pos
    TETuple es _ ->
      msum $ map (`atPosInTypeExp` pos) es
    TERecord fields _ ->
      msum $ map ((`atPosInTypeExp` pos) . snd) fields
    TEArray dim te' _ ->
      atPosInTypeExp te' pos `mplus` inDim dim
    TEUnique te' _ ->
      atPosInTypeExp te' pos
    TEApply e1 arg _ ->
      atPosInTypeExp e1 pos `mplus` inArg arg
    TEArrow _ e1 e2 _ ->
      atPosInTypeExp e1 pos `mplus` atPosInTypeExp e2 pos
    TESum cs _ ->
      msum $ map (`atPosInTypeExp` pos) $ concatMap snd cs
    TEDim _ t _ ->
      atPosInTypeExp t pos
  where
    inArg (TypeArgExpSize dim) = inDim dim
    inArg (TypeArgExpType e2) = atPosInTypeExp e2 pos
    inDim (SizeExp e _) = atPosInExp e pos
    inDim SizeExpAny {} = Nothing

atPosInPat :: Pat (TypeBase Size u) -> Pos -> Maybe RawAtPos
atPosInPat (Id vn _ loc) pos = do
  guard $ loc `contains` pos
  Just $ RawAtName (qualName vn) $ locOf loc
atPosInPat (TuplePat pats _) pos =
  msum $ map (`atPosInPat` pos) pats
atPosInPat (RecordPat fields _) pos =
  msum $ map ((`atPosInPat` pos) . snd) fields
atPosInPat (PatParens pat _) pos =
  atPosInPat pat pos
atPosInPat (PatAttr _ pat _) pos =
  atPosInPat pat pos
atPosInPat (PatAscription pat te _) pos =
  atPosInPat pat pos `mplus` atPosInTypeExp te pos
atPosInPat (PatConstr _ _ pats _) pos =
  msum $ map (`atPosInPat` pos) pats
atPosInPat PatLit {} _ = Nothing
atPosInPat Wildcard {} _ = Nothing

atPosInExp :: Exp -> Pos -> Maybe RawAtPos
atPosInExp (Var qn _ loc) pos = do
  guard $ loc `contains` pos
  Just $ RawAtName qn $ locOf loc
atPosInExp (QualParens (qn, loc) _ _) pos
  | loc `contains` pos = Just $ RawAtName qn $ locOf loc
-- All the value cases are TODO - we need another RawAtPos constructor.
atPosInExp Literal {} _ = Nothing
atPosInExp IntLit {} _ = Nothing
atPosInExp FloatLit {} _ = Nothing
atPosInExp (AppExp (LetPat _ pat _ _ _) _) pos
  | pat `contains` pos = atPosInPat pat pos
atPosInExp (AppExp (LetWith a b _ _ _ _) _) pos
  | a `contains` pos = Just $ RawAtName (qualName $ identName a) (locOf a)
  | b `contains` pos = Just $ RawAtName (qualName $ identName b) (locOf b)
atPosInExp (AppExp (Loop _ merge _ _ _ _) _) pos
  | merge `contains` pos = atPosInPat merge pos
atPosInExp (Ascript _ te _) pos
  | te `contains` pos = atPosInTypeExp te pos
atPosInExp (Coerce _ te _ _) pos
  | te `contains` pos = atPosInTypeExp te pos
atPosInExp e pos = do
  guard $ e `contains` pos
  -- Use the Either monad for short-circuiting for efficiency reasons.
  -- The first hit is going to be the only one.
  case astMap mapper e of
    Left atpos -> Just atpos
    Right _ -> Nothing
  where
    mapper = identityMapper {mapOnExp = onExp}
    onExp e' =
      case atPosInExp e' pos of
        Just atpos -> Left atpos
        Nothing -> Right e'

atPosInModExp :: ModExp -> Pos -> Maybe RawAtPos
atPosInModExp (ModVar qn loc) pos = do
  guard $ loc `contains` pos
  Just $ RawAtName qn $ locOf loc
atPosInModExp (ModParens me _) pos =
  atPosInModExp me pos
atPosInModExp ModImport {} _ =
  Nothing
atPosInModExp (ModDecs decs _) pos =
  msum $ map (`atPosInDec` pos) decs
atPosInModExp (ModApply e1 e2 _ _ _) pos =
  atPosInModExp e1 pos `mplus` atPosInModExp e2 pos
-- The ascribed module type is searched too, as 'atPosInModBind' does
-- for a module binding's signature.
atPosInModExp (ModAscript e se _ _) pos =
  atPosInModExp e pos `mplus` atPosInModTypeExp se pos
-- Search the parameter's module type and the optional result module
-- type as well as the body, mirroring 'atPosInModBind'.
atPosInModExp (ModLambda (ModParam _ pse _ _) ret e _) pos =
  atPosInModTypeExp pse pos
    `mplus` (ret >>= (`atPosInModTypeExp` pos) . fst)
    `mplus` atPosInModExp e pos

atPosInSpec :: Spec -> Pos -> Maybe RawAtPos
atPosInSpec spec pos =
  case spec of
    ValSpec _ _ te _ _ _ -> atPosInTypeExp te pos
    TypeAbbrSpec tbind -> atPosInTypeBind tbind pos
    TypeSpec {} -> Nothing
    ModSpec _ se _ _ -> atPosInModTypeExp se pos
    IncludeSpec se _ -> atPosInModTypeExp se pos

atPosInModTypeExp :: ModTypeExp -> Pos -> Maybe RawAtPos
atPosInModTypeExp se pos =
  case se of
    ModTypeVar qn _ loc -> do
      guard $ loc `contains` pos
      Just $ RawAtName qn $ locOf loc
    ModTypeParens e _ -> atPosInModTypeExp e pos
    ModTypeSpecs specs _ -> msum $ map (`atPosInSpec` pos) specs
    ModTypeWith e _ _ -> atPosInModTypeExp e pos
    ModTypeArrow _ e1 e2 _ -> atPosInModTypeExp e1 pos `mplus` atPosInModTypeExp e2 pos

atPosInValBind :: ValBind -> Pos -> Maybe RawAtPos
atPosInValBind vbind pos =
  msum (map (`atPosInPat` pos) (valBindParams vbind))
    `mplus` atPosInExp (valBindBody vbind) pos
    `mplus` join (atPosInTypeExp <$> valBindRetDecl vbind <*> pure pos)

atPosInTypeBind :: TypeBind -> Pos -> Maybe RawAtPos
atPosInTypeBind = atPosInTypeExp . typeExp

atPosInModBind :: ModBind -> Pos -> Maybe RawAtPos
atPosInModBind (ModBind _ params sig e _ _) pos =
  msum (map inParam params)
    `mplus` atPosInModExp e pos
    `mplus` case sig of
      Nothing -> Nothing
      Just (se, _) -> atPosInModTypeExp se pos
  where
    inParam (ModParam _ se _ _) = atPosInModTypeExp se pos

atPosInModTypeBind :: ModTypeBind -> Pos -> Maybe RawAtPos
atPosInModTypeBind = atPosInModTypeExp . modTypeExp

atPosInDec :: Dec -> Pos -> Maybe RawAtPos
atPosInDec dec pos = do
  guard $ dec `contains` pos
  case dec of
    ValDec vbind -> atPosInValBind vbind pos
    TypeDec tbind -> atPosInTypeBind tbind pos
    ModDec mbind -> atPosInModBind mbind pos
    ModTypeDec sbind -> atPosInModTypeBind sbind pos
    OpenDec e _ -> atPosInModExp e pos
    LocalDec dec' _ -> atPosInDec dec' pos
    ImportDec {} -> Nothing

atPosInProg :: Prog -> Pos -> Maybe RawAtPos
atPosInProg prog pos =
  msum $ map (`atPosInDec` pos) (progDecs prog)

-- The imported file whose import name matches the file of the
-- position, once its extension is removed.
containingModule :: Imports -> Pos -> Maybe FileModule
containingModule imports (Pos file _ _ _) =
  snd <$> find ((== file') . fst) imports
  where
    file' = mkInitialImport $ fst $ Posix.splitExtension file

-- | Information about what is at the given source location.
data AtPos = AtName (QualName VName) (Maybe BoundTo) Loc
  deriving (Eq, Show)

-- | Information about what's at the given source position.  Returns
-- 'Nothing' if there is nothing there, including if the source
-- position is invalid.
atPos :: Imports -> Pos -> Maybe AtPos
atPos imports pos = do
  prog <- fileProg <$> containingModule imports pos
  RawAtName qn loc <- atPosInProg prog pos
  Just $ AtName qn (qualLeaf qn `M.lookup` allBindings imports) loc
futhark-0.25.27/src/Language/Futhark/Semantic.hs000066400000000000000000000140061475065116200213600ustar00rootroot00000000000000-- | Definitions of various semantic objects (*not* the Futhark
-- semantics themselves).
module Language.Futhark.Semantic
  ( ImportName,
    mkInitialImport,
    mkImportFrom,
    includeToFilePath,
    includeToString,
    includeToText,
    FileModule (..),
    Imports,
    Namespace (..),
    Env (..),
    TySet,
    FunModType (..),
    NameMap,
    BoundV (..),
    Mod (..),
    TypeBinding (..),
    MTy (..),
  )
where

import Data.Map.Strict qualified as M
import Data.Text qualified as T
import Futhark.Util (fromPOSIX, toPOSIX)
import Futhark.Util.Pretty
import Language.Futhark
import System.FilePath qualified as Native
import System.FilePath.Posix qualified as Posix
import Prelude hiding (mod)

-- | Create an import name immediately from a file path specified by
-- the user.  The path is converted to POSIX form and normalised.
mkInitialImport :: Native.FilePath -> ImportName
mkInitialImport = ImportName . Posix.normalise . toPOSIX

-- | Resolve @includee@ relative to the directory of the file
-- @includer@.  Absolute includees are returned as-is.
--
-- We resolve '..' paths here and assume that no shenanigans are
-- going on with symbolic links.  If there is, too bad.  Don't do
-- that.  @./@ components are dropped before resolving '..', so
-- importing @"./../x"@ from @"a/b"@ yields @"x"@.  Leading '..'
-- components that cannot be resolved are kept.
mkImportFrom :: ImportName -> String -> ImportName
mkImportFrom (ImportName includer) includee
  | Posix.isAbsolute includee = ImportName includee
  | otherwise =
      ImportName . Posix.normalise . Posix.joinPath . resolveDotDot [] $
        includerDir ++ Posix.splitPath includee
  where
    -- Directory components of the includer (its file name dropped).
    -- An empty includer has no components; plain 'init' would crash.
    includerDir =
      case Posix.splitPath includer of
        [] -> []
        parts -> init parts

    -- 'parts' holds the resolved components in reverse order.
    resolveDotDot parts [] = reverse parts
    -- "./" is a no-op; it must not be popped by a following "../".
    resolveDotDot parts ("./" : todo) = resolveDotDot parts todo
    -- Only "../" is left to pop, so this one must be kept as well.
    resolveDotDot parts@("../" : _) ("../" : todo) = resolveDotDot ("../" : parts) todo
    resolveDotDot (_ : parts) ("../" : todo) = resolveDotDot parts todo
    resolveDotDot parts (p : todo) = resolveDotDot (p : parts) todo

-- | Create a @.fut@ file corresponding to an 'ImportName'.
includeToFilePath :: ImportName -> Native.FilePath
includeToFilePath (ImportName s) = fromPOSIX $ Posix.normalise s Posix.<.> "fut"

-- | Produce a human-readable canonicalized string from an
-- 'ImportName'.
includeToString :: ImportName -> String
includeToString (ImportName s) = Posix.normalise s

-- | Produce a human-readable canonicalized text from an
-- 'ImportName'.
includeToText :: ImportName -> T.Text
includeToText (ImportName path) = T.pack (Posix.normalise path)

-- | For each abstract type, its liftedness.
type TySet = M.Map (QualName VName) Liftedness

-- | The outcome of type checking a single file; it can be handed on to
-- later invocations of the type checker.
data FileModule = FileModule
  { -- | The abstract types of the file.
    fileAbs :: TySet,
    -- | What importing this file brings into scope.
    fileEnv :: Env,
    fileProg :: Prog,
    -- | The environment as it stands at the end of the file, local
    -- declarations included.
    fileScope :: Env
  }

-- | Import names paired with the checked files they denote.  The order
-- of this list matters.
type Imports = [(ImportName, FileModule)]

-- | The namespace a name belongs to.
data Namespace
  = -- | Values and functions.
    Term
  | Type
  | Signature
  deriving (Show, Eq, Ord, Enum)

-- | A module type: the abstract types it declares, and the module
-- shape it describes.
data MTy = MTy
  { -- | The abstract types declared by the module type.
    mtyAbs :: TySet,
    mtyMod :: Mod
  }
  deriving Show

-- | A module: either an ordinary environment ('ModEnv') or a
-- parametric module, i.e. an SML-style functor ('ModFun').
data Mod = ModEnv Env | ModFun FunModType
  deriving Show

-- | A parametric module (functor): its abstract types, the environment
-- of its parameter, and the module type of its result.
data FunModType = FunModType
  { funModTypeAbs :: TySet,
    funModTypeMod :: Mod,
    funModTypeMty :: MTy
  }
  deriving Show

-- | What a type name is bound to: an abbreviation with its liftedness,
-- type parameters and definition.  The definition is a return type so
-- that an abbreviation can hide inner sizes; such abbreviations are
-- necessarily 'Lifted' or 'SizeLifted'.
data TypeBinding = TypeAbbr Liftedness [TypeParam] StructRetType
  deriving (Show, Eq)

-- | A bound value: its type parameters together with its type, which
-- for a function is a function type whose parameters may be named.
-- The type parameters are in scope throughout the type.  Non-functional
-- values have no parameters in their type.
data BoundV = BoundV
  { boundValTParams :: [TypeParam],
    boundValType :: StructType
  }
  deriving (Show)

-- | A mapping from names (which always exist in some namespace) to a
-- unique (tagged) name.
type NameMap = M.Map (Namespace, Name) (QualName VName)

-- | Modules produce environments with this representation.
data Env = Env
  { envVtable :: M.Map VName BoundV,
    envTypeTable :: M.Map VName TypeBinding,
    envModTypeTable :: M.Map VName MTy,
    envModTable :: M.Map VName Mod,
    envNameMap :: NameMap
  }
  deriving (Show)

-- Table-wise union.  'M.Map''s '<>' is a left-biased union, so on a
-- clashing key the entry from the left environment wins.
instance Semigroup Env where
  Env vt1 tt1 st1 mt1 nt1 <> Env vt2 tt2 st2 mt2 nt2 =
    Env (vt1 <> vt2) (tt1 <> tt2) (st1 <> st2) (mt1 <> mt2) (nt1 <> nt2)

instance Pretty Namespace where
  pretty Term = "name"
  pretty Type = "type"
  pretty Signature = "module type"

instance Monoid Env where
  mempty = Env mempty mempty mempty mempty mempty

instance Pretty MTy where
  pretty = pretty . mtyMod

-- A functor is shown as @param -> result@.  The original code applied
-- the string literal "->" to @pretty mty@ as if it were a function,
-- which does not type check; it is now joined with '<+>'.
instance Pretty Mod where
  pretty (ModEnv e) = pretty e
  pretty (ModFun (FunModType _ mod mty)) = pretty mod <+> "->" <+> pretty mty

-- An environment is shown as a braced block listing its types, values,
-- module types and modules.  The name map is not shown.
instance Pretty Env where
  pretty (Env vtable ttable sigtable modtable _) =
    nestedBlock "{" "}" $
      stack $
        punctuate line $
          concat
            [ map renderTypeBind (M.toList ttable),
              map renderValBind (M.toList vtable),
              map renderModType (M.toList sigtable),
              map renderMod (M.toList modtable)
            ]
    where
      renderTypeBind (name, TypeAbbr l tps tp) =
        p l
          <+> prettyName name
          <> mconcat (map ((" " <>) . pretty) tps)
          <> " ="
          <+> pretty tp
        where
          -- The keyword reflects the liftedness of the abbreviation.
          p Lifted = "type^"
          p SizeLifted = "type~"
          p Unlifted = "type"
      renderValBind (name, BoundV tps t) =
        "val"
          <+> prettyName name
          <> mconcat (map ((" " <>) . pretty) tps)
          <> " ="
          <+> pretty t
      renderModType (name, _sig) =
        "module type" <+> prettyName name
      renderMod (name, mod) =
        "module" <+> prettyName name <> " =" <+> pretty mod
futhark-0.25.27/src/Language/Futhark/Syntax.hs000066400000000000000000001246371475065116200211170ustar00rootroot00000000000000{-# LANGUAGE Strict #-}

-- | The Futhark source language AST definition.  Many types, such as
-- 'ExpBase', are parametrised by type and name representation.
-- E.g. in a value of type @ExpBase f vn@, annotations are wrapped in
-- the functor @f@, and all names are of type @vn@.  See
-- https://futhark.readthedocs.org for a language reference, or this
-- module may be a little hard to understand.
--
-- The system of primitive types is interesting in itself.  See
-- "Language.Futhark.Primitive".
module Language.Futhark.Syntax
  ( module Language.Futhark.Core,
    prettyString,
    prettyText,

    -- * Types
    Uniqueness (..),
    IntType (..),
    FloatType (..),
    PrimType (..),
    Size,
    Shape (..),
    shapeRank,
    stripDims,
    TypeBase (..),
    TypeArg (..),
    SizeExp (..),
    TypeExp (..),
    TypeArgExp (..),
    PName (..),
    ScalarTypeBase (..),
    RetTypeBase (..),
    StructType,
    ParamType,
    ResType,
    StructRetType,
    ResRetType,
    ValueType,
    Diet (..),

    -- * Values
    IntValue (..),
    FloatValue (..),
    PrimValue (..),
    IsPrimValue (..),

    -- * Abstract syntax tree
    AttrInfo (..),
    AttrAtom (..),
    BinOp (..),
    IdentBase (..),
    Inclusiveness (..),
    DimIndexBase (..),
    SliceBase,
    SizeBinder (..),
    AppExpBase (..),
    AppRes (..),
    ExpBase (..),
    FieldBase (..),
    CaseBase (..),
    LoopInitBase (..),
    LoopFormBase (..),
    PatLit (..),
    PatBase (..),

    -- * Module language
    ImportName (..),
    SpecBase (..),
    ModTypeExpBase (..),
    TypeRefBase (..),
    ModTypeBindBase (..),
    ModExpBase (..),
    ModBindBase (..),
    ModParamBase (..),

    -- * Definitions
    DocComment (..),
    ValBindBase (..),
    EntryPoint (..),
    EntryType (..),
    EntryParam (..),
    Liftedness (..),
    TypeBindBase (..),
    TypeParamBase (..),
    typeParamName,
    ProgBase (..),
    DecBase (..),

    -- * Miscellaneous
    L (..),
    NoInfo (..),
    Info (..),
    QualName (..),
    mkApply,
    mkApplyUT,
    sizeFromName,
    sizeFromInteger,
    loopInitExp,
  )
where

import Control.Applicative
import Control.Monad
import Data.Bifoldable
import Data.Bifunctor
import Data.Bitraversable
import Data.Foldable
import Data.List.NonEmpty qualified as NE
import Data.Map.Strict qualified as M
import Data.Monoid hiding (Sum)
import Data.Ord
import Data.Text qualified as T
import Data.Traversable
import Futhark.Util.Loc
import Futhark.Util.Pretty
import Language.Futhark.Core
import Language.Futhark.Primitive
  ( FloatType (..),
    FloatValue (..),
    IntType (..),
    IntValue (..),
  )
import System.FilePath.Posix qualified as Posix
import Prelude

-- | No information functor.  Usually used for placeholder type- or
-- aliasing information.
data NoInfo a = NoInfo
  deriving (Eq, Ord, Show)

instance Functor NoInfo where
  fmap _ NoInfo = NoInfo

instance Foldable NoInfo where
  foldr _ b NoInfo = b

instance Traversable NoInfo where
  traverse _ NoInfo = pure NoInfo

-- | Some information.  The dual to 'NoInfo'.
newtype Info a = Info {unInfo :: a}
  deriving (Eq, Ord, Show)

instance Functor Info where
  fmap f (Info x) = Info $ f x

instance Foldable Info where
  foldr f b (Info x) = f x b

instance Traversable Info where
  traverse f (Info x) = Info <$> f x

-- | Low-level primitive types.
data PrimType
  = Signed IntType
  | Unsigned IntType
  | FloatType FloatType
  | Bool
  deriving (Eq, Ord, Show)

-- | Non-array values.
data PrimValue
  = SignedValue !IntValue
  | UnsignedValue !IntValue
  | FloatValue !FloatValue
  | BoolValue !Bool
  deriving (Eq, Ord, Show)

-- | A class for converting ordinary Haskell values to primitive
-- Futhark values.
class IsPrimValue v where
  primValue :: v -> PrimValue

-- Note: a Haskell 'Int' becomes a 32-bit Futhark integer; 'fromIntegral'
-- silently wraps values outside the 32-bit range.
instance IsPrimValue Int where
  primValue = SignedValue . Int32Value . fromIntegral

instance IsPrimValue Int8 where
  primValue = SignedValue . Int8Value

instance IsPrimValue Int16 where
  primValue = SignedValue . Int16Value

instance IsPrimValue Int32 where
  primValue = SignedValue . Int32Value

instance IsPrimValue Int64 where
  primValue = SignedValue . Int64Value

-- Unsigned Haskell words are stored bit-for-bit in the same-width
-- signed representation, tagged as unsigned.
instance IsPrimValue Word8 where
  primValue = UnsignedValue . Int8Value . fromIntegral

instance IsPrimValue Word16 where
  primValue = UnsignedValue . Int16Value . fromIntegral

instance IsPrimValue Word32 where
  primValue = UnsignedValue . Int32Value . fromIntegral

instance IsPrimValue Word64 where
  primValue = UnsignedValue . Int64Value . fromIntegral

instance IsPrimValue Float where
  primValue = FloatValue . Float32Value

instance IsPrimValue Double where
  primValue = FloatValue . Float64Value

instance IsPrimValue Bool where
  primValue = BoolValue

-- | The value of an v'AttrAtom'.
data AttrAtom vn
  = AtomName Name
  | AtomInt Integer
  deriving (Eq, Ord, Show)

-- | The payload of an attribute.
data AttrInfo vn
  = AttrAtom (AttrAtom vn) SrcLoc
  | AttrComp Name [AttrInfo vn] SrcLoc
  deriving (Eq, Ord, Show)

-- | The elaborated size of a dimension is just an expression.
type Size = ExpBase Info VName

-- | Create a 'Size' from a name.
sizeFromName :: QualName VName -> SrcLoc -> Size
sizeFromName name = Var name (Info $ Scalar $ Prim $ Signed Int64)

-- | Create a 'Size' from a constant integer.  The literal is annotated
-- with type @i64@, like 'sizeFromName'.
sizeFromInteger :: Integer -> SrcLoc -> Size
sizeFromInteger x = IntLit x (Info $ Scalar $ Prim $ Signed Int64)

-- | The size of an array type is a list of its dimension sizes.  If
-- 'Nothing', that dimension is of a (statically) unknown size.
newtype Shape dim = Shape {shapeDims :: [dim]}
  deriving (Eq, Ord, Show)

instance Foldable Shape where
  foldr f x (Shape ds) = foldr f x ds

instance Traversable Shape where
  traverse f (Shape ds) = Shape <$> traverse f ds

instance Functor Shape where
  fmap f (Shape ds) = Shape $ map f ds

-- Concatenation of dimensions: the left shape's dimensions come first.
instance Semigroup (Shape dim) where
  Shape l1 <> Shape l2 = Shape $ l1 ++ l2

instance Monoid (Shape dim) where
  mempty = Shape []

-- | The number of dimensions contained in a shape.
shapeRank :: Shape dim -> Int
shapeRank = length . shapeDims

-- | @stripDims n shape@ strips the outer @n@ dimensions from
-- @shape@, returning 'Nothing' if this would result in zero or
-- fewer dimensions.
stripDims :: Int -> Shape dim -> Maybe (Shape dim)
stripDims i (Shape l)
  | i < length l = Just $ Shape $ drop i l
  | otherwise = Nothing

-- | The name (if any) of a function parameter.  The 'Eq' and 'Ord'
-- instances always compare values of this type equal.
data PName = Named VName | Unnamed
  deriving (Show)

-- Parameter names are deliberately ignored when comparing types, so
-- every pair of 'PName's is considered equal.
instance Eq PName where
  _ == _ = True

-- With '<=' always 'True' and '==' always 'True', the default
-- 'compare' yields 'EQ' for every pair.
instance Ord PName where
  _ <= _ = True

-- | Types that can appear to the right of a function arrow.  This
-- just means they can be existentially quantified.
data RetTypeBase dim as = RetType
  { -- | Names of the existentially quantified sizes.
    retDims :: [VName],
    -- | The underlying type, which may refer to 'retDims'.
    retType :: TypeBase dim as
  }
  deriving (Eq, Ord, Show)

instance Bitraversable RetTypeBase where
  bitraverse f g (RetType dims t) = RetType dims <$> bitraverse f g t

instance Functor (RetTypeBase dim) where
  fmap = fmapDefault

instance Foldable (RetTypeBase dim) where
  foldMap = foldMapDefault

instance Traversable (RetTypeBase dim) where
  traverse = bitraverse pure

instance Bifunctor RetTypeBase where
  bimap = bimapDefault

instance Bifoldable RetTypeBase where
  bifoldMap = bifoldMapDefault

-- | Types that can be elements of arrays.  This representation does
-- allow arrays of records of functions, which is nonsensical, but it
-- convolutes the code too much if we try to statically rule it out.
data ScalarTypeBase dim u
  = Prim PrimType
  | TypeVar u (QualName VName) [TypeArg dim]
  | Record (M.Map Name (TypeBase dim u))
  | Sum (M.Map Name [TypeBase dim u])
  | -- | A function type.  The @u@ annotation applies to the function
    -- value itself.  The parameter has an optional name, a 'Diet',
    -- and a type without uniqueness; the result is a 'RetTypeBase'
    -- carrying 'Uniqueness'.
    Arrow u PName Diet (TypeBase dim NoUniqueness) (RetTypeBase dim Uniqueness)
  deriving (Eq, Ord, Show)

instance Bitraversable ScalarTypeBase where
  bitraverse _ _ (Prim t) = pure $ Prim t
  bitraverse f g (Record fs) = Record <$> traverse (bitraverse f g) fs
  bitraverse f g (TypeVar als t args) =
    TypeVar <$> g als <*> pure t <*> traverse (traverse f) args
  -- The parameter and result types have fixed annotation types, so
  -- only the size function is applied to them.
  bitraverse f g (Arrow u v d t1 t2) =
    Arrow <$> g u <*> pure v <*> pure d <*> bitraverse f pure t1 <*> bitraverse f pure t2
  bitraverse f g (Sum cs) = Sum <$> (traverse . traverse) (bitraverse f g) cs

instance Functor (ScalarTypeBase dim) where
  fmap = fmapDefault

instance Foldable (ScalarTypeBase dim) where
  foldMap = foldMapDefault

instance Traversable (ScalarTypeBase dim) where
  traverse = bitraverse pure

instance Bifunctor ScalarTypeBase where
  bimap = bimapDefault

instance Bifoldable ScalarTypeBase where
  bifoldMap = bifoldMapDefault

-- | An expanded Futhark type is either an array, or something that
-- can be an element of an array.  When comparing types for equality,
-- function parameter names are ignored.  This representation permits
-- some malformed types (arrays of functions), but importantly rules
-- out arrays-of-arrays.
data TypeBase dim u
  = Scalar (ScalarTypeBase dim u)
  | Array u (Shape dim) (ScalarTypeBase dim NoUniqueness)
  deriving (Eq, Ord, Show)

instance Bitraversable TypeBase where
  bitraverse f g (Scalar t) = Scalar <$> bitraverse f g t
  -- The element type carries no annotation of its own; only sizes
  -- are traversed inside it.
  bitraverse f g (Array als shape t) =
    Array <$> g als <*> traverse f shape <*> bitraverse f pure t

instance Functor (TypeBase dim) where
  fmap = fmapDefault

instance Foldable (TypeBase dim) where
  foldMap = foldMapDefault

instance Traversable (TypeBase dim) where
  traverse = bitraverse pure

instance Bifunctor TypeBase where
  bimap = bimapDefault

instance Bifoldable TypeBase where
  bifoldMap = bifoldMapDefault

-- | An argument passed to a type constructor.
data TypeArg dim
  = TypeArgDim dim
  | TypeArgType (TypeBase dim NoUniqueness)
  deriving (Eq, Ord, Show)

instance Traversable TypeArg where
  traverse f (TypeArgDim v) = TypeArgDim <$> f v
  traverse f (TypeArgType t) = TypeArgType <$> bitraverse f pure t

instance Functor TypeArg where
  fmap = fmapDefault

instance Foldable TypeArg where
  foldMap = foldMapDefault

-- | A "structural" type with shape annotations and no uniqueness (or
-- aliasing) information, used for declarations.
type StructType = TypeBase Size NoUniqueness

-- | A type with consumption information, used for function parameters
-- (but not in function types).
type ParamType = TypeBase Size Diet

-- | A type with uniqueness information, used for function return types.
type ResType = TypeBase Size Uniqueness

-- | A value type contains full, manifest size information.
type ValueType = TypeBase Int64 NoUniqueness

-- | The return type version of a 'StructType'.
type StructRetType = RetTypeBase Size NoUniqueness

-- | The return type version of a 'ResType'.
type ResRetType = RetTypeBase Size Uniqueness

-- | A dimension declaration expression for use in a 'TypeExp'.
-- Syntactically includes the brackets.
data SizeExp d
  = -- | The size of the dimension is this expression, all of whose
    -- free variables must be in scope.
    SizeExp d SrcLoc
  | -- | No dimension declaration.
    SizeExpAny SrcLoc
  deriving (Eq, Ord, Show)

instance Functor SizeExp where
  fmap = fmapDefault

instance Foldable SizeExp where
  foldMap = foldMapDefault

instance Traversable SizeExp where
  traverse _ (SizeExpAny loc) = pure (SizeExpAny loc)
  traverse f (SizeExp d loc) = SizeExp <$> f d <*> pure loc

instance Located (SizeExp d) where
  locOf (SizeExp _ loc) = locOf loc
  locOf (SizeExpAny loc) = locOf loc

-- | A type argument expression passed to a type constructor.
data TypeArgExp d vn
  = TypeArgExpSize (SizeExp d)
  | TypeArgExpType (TypeExp d vn)
  deriving (Eq, Ord, Show)

instance Functor (TypeArgExp d) where
  fmap = fmapDefault

instance Foldable (TypeArgExp d) where
  foldMap = foldMapDefault

instance Traversable (TypeArgExp d) where
  traverse = bitraverse pure

instance Bifunctor TypeArgExp where
  bimap = bimapDefault

instance Bifoldable TypeArgExp where
  bifoldMap = bifoldMapDefault

-- The first function is applied to size expressions, the second to names.
instance Bitraversable TypeArgExp where
  bitraverse f _ (TypeArgExpSize d) = TypeArgExpSize <$> traverse f d
  bitraverse f g (TypeArgExpType te) = TypeArgExpType <$> bitraverse f g te

instance Located (TypeArgExp f vn) where
  locOf (TypeArgExpSize e) = locOf e
  locOf (TypeArgExpType t) = locOf t

-- | An unstructured syntactic type with type variables and possibly
-- shape declarations - this is what the user types in the source
-- program.  These are used to construct 'TypeBase's in the type
-- checker.
data TypeExp d vn
  = TEVar (QualName vn) SrcLoc
  | TEParens (TypeExp d vn) SrcLoc
  | TETuple [TypeExp d vn] SrcLoc
  | TERecord [(L Name, TypeExp d vn)] SrcLoc
  | TEArray (SizeExp d) (TypeExp d vn) SrcLoc
  | TEUnique (TypeExp d vn) SrcLoc
  | TEApply (TypeExp d vn) (TypeArgExp d vn) SrcLoc
  | TEArrow (Maybe vn) (TypeExp d vn) (TypeExp d vn) SrcLoc
  | TESum [(Name, [TypeExp d vn])] SrcLoc
  | TEDim [vn] (TypeExp d vn) SrcLoc
  deriving (Eq, Ord, Show)

-- The first function is applied to size expressions, the second to
-- every name (variables, arrow parameter names, bound dimensions).
instance Bitraversable TypeExp where
  bitraverse _ g (TEVar v loc) = TEVar <$> traverse g v <*> pure loc
  bitraverse f g (TEParens te loc) = TEParens <$> bitraverse f g te <*> pure loc
  bitraverse f g (TETuple tes loc) = TETuple <$> traverse (bitraverse f g) tes <*> pure loc
  bitraverse f g (TERecord fs loc) = TERecord <$> traverse (traverse (bitraverse f g)) fs <*> pure loc
  bitraverse f g (TESum cs loc) = TESum <$> traverse (traverse (traverse (bitraverse f g))) cs <*> pure loc
  bitraverse f g (TEArray d te loc) = TEArray <$> traverse f d <*> bitraverse f g te <*> pure loc
  bitraverse f g (TEUnique te loc) = TEUnique <$> bitraverse f g te <*> pure loc
  bitraverse f g (TEApply te arg loc) = TEApply <$> bitraverse f g te <*> bitraverse f g arg <*> pure loc
  bitraverse f g (TEArrow pn te1 te2 loc) =
    TEArrow <$> traverse g pn <*> bitraverse f g te1 <*> bitraverse f g te2 <*> pure loc
  bitraverse f g (TEDim dims te loc) = TEDim <$> traverse g dims <*> bitraverse f g te <*> pure loc

instance Functor (TypeExp d) where
  fmap = fmapDefault

instance Foldable (TypeExp dim) where
  foldMap = foldMapDefault

instance Traversable (TypeExp dim) where
  traverse = bitraverse pure

instance Bifunctor TypeExp where
  bimap = bimapDefault

instance Bifoldable TypeExp where
  bifoldMap = bifoldMapDefault

instance Located (TypeExp f vn) where
  locOf (TEArray _ _ loc) = locOf loc
  locOf (TETuple _ loc) = locOf loc
  locOf (TERecord _ loc) = locOf loc
  locOf (TEVar _ loc) = locOf loc
  locOf (TEParens _ loc) = locOf loc
  locOf (TEUnique _ loc) = locOf loc
  locOf (TEApply _ _ loc) = locOf loc
  locOf (TEArrow _ _ _ loc) = locOf loc
  locOf (TESum _ loc) = locOf loc
  locOf (TEDim _ _ loc) = locOf loc

-- | Information about which parts of a parameter are consumed.  This
-- can be considered kind of an effect on the function.
data Diet
  = -- | Does not consume the parameter.
    Observe
  | -- | Consumes the parameter.
    Consume
  deriving (Eq, Ord, Show)

-- By the derived 'Ord' instance, @Observe < Consume@, so combining two
-- diets yields 'Consume' if either of them consumes.
instance Semigroup Diet where
  (<>) = max

instance Monoid Diet where
  mempty = Observe

-- | An identifier consists of its name and the type of the value
-- bound to the identifier.
data IdentBase f vn t = Ident
  { identName :: vn,
    identType :: f t,
    identSrcLoc :: SrcLoc
  }

deriving instance (Show (Info t)) => Show (IdentBase Info VName t)

deriving instance (Show (Info t), Show vn) => Show (IdentBase NoInfo vn t)

-- Identifiers are compared by name only; type and location are ignored.
instance (Eq vn) => Eq (IdentBase ty vn t) where
  x == y = identName x == identName y

instance (Ord vn) => Ord (IdentBase ty vn t) where
  compare = comparing identName

instance Located (IdentBase ty vn t) where
  locOf = locOf . identSrcLoc

-- | Default binary operators.
data BinOp
  = -- | A pseudo-operator standing in for any normal
    -- identifier used as an operator (they all have the
    -- same fixity).
    Backtick
  | -- | Not a real operator, but operators with this as a prefix may
    -- be defined by the user.
    Bang
  | -- | Not a real operator, but operators with this as a prefix
    -- may be defined by the user.
    Equ
  | Plus
  | Minus
  | Pow
  | Times
  | Divide
  | Mod
  | Quot
  | Rem
  | ShiftR
  | ShiftL
  | Band
  | Xor
  | Bor
  | LogAnd
  | LogOr
  | -- Relational operators, defined at least for all primitive types.
    Equal
  | NotEqual
  | Less
  | Leq
  | Greater
  | Geq
  | -- | @|>@, a functional (pipe) operator.
    PipeRight
  | -- | @<|@, a functional (pipe) operator.
    PipeLeft
  deriving (Eq, Ord, Show, Enum, Bounded)

-- | Whether a bound for an end-point of a 'DimSlice' or a range
-- literal is inclusive or exclusive.
data Inclusiveness a
  = DownToExclusive a
  | -- | May be "down to" if step is negative.
    ToInclusive a
  | UpToExclusive a
  deriving (Eq, Ord, Show)

instance (Located a) => Located (Inclusiveness a) where
  locOf (DownToExclusive x) = locOf x
  locOf (ToInclusive x) = locOf x
  locOf (UpToExclusive x) = locOf x

instance Functor Inclusiveness where
  fmap = fmapDefault

instance Foldable Inclusiveness where
  foldMap = foldMapDefault

instance Traversable Inclusiveness where
  traverse f (DownToExclusive x) = DownToExclusive <$> f x
  traverse f (ToInclusive x) = ToInclusive <$> f x
  traverse f (UpToExclusive x) = UpToExclusive <$> f x

-- | An indexing of a single dimension.
data DimIndexBase f vn
  = -- | Index with a single expression.
    DimFix (ExpBase f vn)
  | -- | A slice with three optional components; presumably start, end
    -- and stride (as in @i:j:s@) - TODO confirm against the parser.
    DimSlice
      (Maybe (ExpBase f vn))
      (Maybe (ExpBase f vn))
      (Maybe (ExpBase f vn))

deriving instance Show (DimIndexBase Info VName)

deriving instance (Show vn) => Show (DimIndexBase NoInfo vn)

deriving instance Eq (DimIndexBase NoInfo VName)

deriving instance Eq (DimIndexBase Info VName)

deriving instance Ord (DimIndexBase NoInfo VName)

deriving instance Ord (DimIndexBase Info VName)

-- | A slicing of an array (potentially multiple dimensions).
type SliceBase f vn = [DimIndexBase f vn]

-- | A name qualified with a breadcrumb of module accesses.
data QualName vn = QualName
  { -- | The module path leading to the name.
    qualQuals :: ![vn],
    -- | The name itself.
    qualLeaf :: !vn
  }
  deriving (Show)

-- Qualified names are compared by their leaf only; the qualifiers are
-- ignored.
instance (Eq v) => Eq (QualName v) where
  QualName _ v1 == QualName _ v2 = v1 == v2

instance (Ord v) => Ord (QualName v) where
  QualName _ v1 `compare` QualName _ v2 = compare v1 v2

instance Functor QualName where
  fmap = fmapDefault

instance Foldable QualName where
  foldMap = foldMapDefault

instance Traversable QualName where
  traverse f (QualName qs v) = QualName <$> traverse f qs <*> f v

-- | A binding of a size in a pattern (essentially a size parameter in
-- a @let@ expression).
data SizeBinder vn = SizeBinder {sizeName :: !vn, sizeLoc :: !SrcLoc}
  deriving (Eq, Ord, Show)

instance Located (SizeBinder vn) where
  locOf = locOf . sizeLoc

-- | An "application expression" is a semantic (not syntactic)
-- grouping of expressions that have "funcall-like" semantics, mostly
-- meaning that they can return existential sizes.  In our type
-- theory, these are all thought to be bound to names (*Administrative
-- Normal Form*), but as this is not practical in a real language, we
-- instead use an annotation ('AppRes') that stores the information we
-- need, so we can pretend that an application expression was really
-- bound to a name.
data AppExpBase f vn
  = -- | Function application.  Parts of the compiler expect that the
    -- function expression is never itself an 'Apply'.  Use the
    -- 'mkApply' function to maintain this invariant, rather than
    -- constructing 'Apply' directly.
    --
    -- The @Maybe VNames@ are existential sizes generated by this
    -- argument.  May have duplicates across the program, but they
    -- will all produce the same value (the expressions will be
    -- identical).
    Apply
      (ExpBase f vn)
      (NE.NonEmpty (f (Maybe VName), ExpBase f vn))
      SrcLoc
  | Range
      (ExpBase f vn)
      (Maybe (ExpBase f vn))
      (Inclusiveness (ExpBase f vn))
      SrcLoc
  | LetPat
      [SizeBinder vn]
      (PatBase f vn StructType)
      (ExpBase f vn)
      (ExpBase f vn)
      SrcLoc
  | LetFun
      vn
      ( [TypeParamBase vn],
        [PatBase f vn ParamType],
        Maybe (TypeExp (ExpBase f vn) vn),
        f ResRetType,
        ExpBase f vn
      )
      (ExpBase f vn)
      SrcLoc
  | If (ExpBase f vn) (ExpBase f vn) (ExpBase f vn) SrcLoc
  | Loop
      [VName] -- Size parameters.
      (PatBase f vn ParamType) -- Loop parameter pattern.
      (LoopInitBase f vn) -- Possibly initial value.
      (LoopFormBase f vn) -- The loop form (@for@, @for ... in@ or @while@).
      (ExpBase f vn) -- Loop body.
      SrcLoc
  | BinOp
      (QualName vn, SrcLoc)
      (f StructType)
      (ExpBase f vn, f (Maybe VName))
      (ExpBase f vn, f (Maybe VName))
      SrcLoc
  | LetWith
      (IdentBase f vn StructType)
      (IdentBase f vn StructType)
      (SliceBase f vn)
      (ExpBase f vn)
      (ExpBase f vn)
      SrcLoc
  | Index (ExpBase f vn) (SliceBase f vn) SrcLoc
  | -- | A match expression.
    Match (ExpBase f vn) (NE.NonEmpty (CaseBase f vn)) SrcLoc

deriving instance Show (AppExpBase Info VName)

deriving instance (Show vn) => Show (AppExpBase NoInfo vn)

deriving instance Eq (AppExpBase NoInfo VName)

deriving instance Eq (AppExpBase Info VName)

deriving instance Ord (AppExpBase NoInfo VName)

deriving instance Ord (AppExpBase Info VName)

instance Located (AppExpBase f vn) where
  locOf (Range _ _ _ pos) = locOf pos
  locOf (BinOp _ _ _ _ loc) = locOf loc
  locOf (If _ _ _ loc) = locOf loc
  locOf (Apply _ _ loc) = locOf loc
  locOf (LetPat _ _ _ _ loc) = locOf loc
  locOf (LetFun _ _ _ loc) = locOf loc
  locOf (LetWith _ _ _ _ _ loc) = locOf loc
  locOf (Index _ _ loc) = locOf loc
  locOf (Loop _ _ _ _ _ loc) = locOf loc
  locOf (Match _ _ loc) = locOf loc

-- | An annotation inserted by the type checker on constructs that are
-- "function calls" (either literally or conceptually).  This
-- annotation encodes the result type, as well as any existential
-- sizes that are generated here.
data AppRes = AppRes
  { -- | The result type.
    appResType :: StructType,
    -- | Existential sizes generated by the application.
    appResExt :: [VName]
  }
  deriving (Eq, Ord, Show)

-- | The Futhark expression language.
--
-- This allows us to encode whether or not the expression has been
-- type-checked in the Haskell type of the expression.  Specifically,
-- the parser will produce expressions of type @Exp 'NoInfo' 'Name'@,
-- and the type checker will convert these to @Exp 'Info' 'VName'@, in
-- which type information is always present and all names are unique.
data ExpBase f vn
  = Literal PrimValue SrcLoc
  | -- | A polymorphic integral literal.
    IntLit Integer (f StructType) SrcLoc
  | -- | A polymorphic decimal literal.
    FloatLit Double (f StructType) SrcLoc
  | -- | A string literal is just a fancy syntax for an array
    -- of bytes.
    StringLit [Word8] SrcLoc
  | Hole (f StructType) SrcLoc
  | Var (QualName vn) (f StructType) SrcLoc
  | -- | A parenthesized expression.
    Parens (ExpBase f vn) SrcLoc
  | QualParens (QualName vn, SrcLoc) (ExpBase f vn) SrcLoc
  | -- | Tuple literals, e.g., @(1+3, (x, y+z))@.
    TupLit [ExpBase f vn] SrcLoc
  | -- | Record literals, e.g. @{x=2,y=3,z}@.
    RecordLit [FieldBase f vn] SrcLoc
  | -- | Array literals, e.g., @[ [1+x, 3], [2, 1+4] ]@.
    -- Second arg is the row type of the rows of the array.
    ArrayLit [ExpBase f vn] (f StructType) SrcLoc
  | -- | Array value constants, where the elements are known to be
    -- constant primitives.  This is a fast-path variant of 'ArrayLit'
    -- that will in some cases be constructed by the parser, and also
    -- result from normalisation later on.  Has exactly the same
    -- semantics as an 'ArrayLit'.
    ArrayVal [PrimValue] PrimType SrcLoc
  | -- | An attribute applied to the following expression.
    Attr (AttrInfo vn) (ExpBase f vn) SrcLoc
  | Project Name (ExpBase f vn) (f StructType) SrcLoc
  | -- | Numeric negation (ugly special case; Haskell did it first).
    Negate (ExpBase f vn) SrcLoc
  | -- | Logical and bitwise negation.
    Not (ExpBase f vn) SrcLoc
  | -- | Fail if the first expression does not return true,
    -- and return the value of the second expression if it
    -- does.
    Assert (ExpBase f vn) (ExpBase f vn) (f T.Text) SrcLoc
  | -- | An n-ary value constructor.
    Constr Name [ExpBase f vn] (f StructType) SrcLoc
  | Update (ExpBase f vn) (SliceBase f vn) (ExpBase f vn) SrcLoc
  | RecordUpdate (ExpBase f vn) [Name] (ExpBase f vn) (f StructType) SrcLoc
  | Lambda
      [PatBase f vn ParamType]
      (ExpBase f vn)
      (Maybe (TypeExp (ExpBase f vn) vn))
      (f ResRetType)
      SrcLoc
  | -- | An operator section with no operands, e.g. @(+)@.  The single
    -- 'StructType' annotation is presumably the operator's type
    -- (TODO confirm against the type checker).
    OpSection (QualName vn) (f StructType) SrcLoc
  | -- | A left operator section, e.g. @(2+)@, where the left operand
    -- is supplied.  Besides the operator's 'StructType', it carries
    -- annotations for the two operator parameters (the first with an
    -- optional existential size name) and a result type with a list of
    -- names (NOTE(review): presumably existential sizes - confirm).
    OpSectionLeft
      (QualName vn)
      (f StructType)
      (ExpBase f vn)
      (f (PName, ParamType, Maybe VName), f (PName, ParamType))
      (f ResRetType, f [VName])
      SrcLoc
  | -- | A right operator section, e.g. @(+2)@, where the right operand
    -- is supplied.  Besides the operator's 'StructType', it carries
    -- annotations for the two operator parameters (the second with an
    -- optional existential size name) and the result type.
    OpSectionRight
      (QualName vn)
      (f StructType)
      (ExpBase f vn)
      (f (PName, ParamType), f (PName, ParamType, Maybe VName))
      (f ResRetType)
      SrcLoc
  | -- | Field projection as a section: @(.x.y.z)@.
    ProjectSection [Name] (f StructType) SrcLoc
  | -- | Array indexing as a section: @(.[i,j])@.
    IndexSection (SliceBase f vn) (f StructType) SrcLoc
  | -- | Type ascription: @e : t@.
    Ascript (ExpBase f vn) (TypeExp (ExpBase f vn) vn) SrcLoc
  | -- | Size coercion: @e :> t@.
    Coerce (ExpBase f vn) (TypeExp (ExpBase f vn) vn) (f StructType) SrcLoc
  | AppExp (AppExpBase f vn) (f AppRes)

deriving instance Show (ExpBase Info VName)

deriving instance (Show vn) => Show (ExpBase NoInfo vn)

deriving instance Eq (ExpBase NoInfo VName)

deriving instance Ord (ExpBase NoInfo VName)

deriving instance Eq (ExpBase Info VName)

deriving instance Ord (ExpBase Info VName)

instance Located (ExpBase f vn) where
  locOf (Literal _ loc) = locOf loc
  locOf (IntLit _ _ loc) = locOf loc
  locOf (FloatLit _ _ loc) = locOf loc
  locOf (Parens _ loc) = locOf loc
  locOf (QualParens _ _ loc) = locOf loc
  locOf (TupLit _ pos) = locOf pos
  locOf (RecordLit _ pos) = locOf pos
  locOf (Project _ _ _ pos) = locOf pos
  locOf (ArrayLit _ _ pos) = locOf pos
  locOf (ArrayVal _ _ loc) = locOf loc
  locOf (StringLit _ loc) = locOf loc
  locOf (Var _ _ loc) = locOf loc
  locOf (Ascript _ _ loc) = locOf loc
  locOf (Coerce _ _ _ loc) = locOf loc
  locOf (Negate _ pos) = locOf pos
  locOf (Not _ pos) = locOf pos
  locOf (Update _ _ _ pos) = locOf pos
  locOf (RecordUpdate _ _ _ _ pos) = locOf pos
  locOf (Lambda _ _ _ _ loc) = locOf loc
  locOf (Hole _ loc) = locOf loc
  locOf (OpSection _ _ loc) = locOf loc
  locOf (OpSectionLeft _ _ _ _ _ loc) = locOf loc
  locOf (OpSectionRight _ _ _ _ _ loc) = locOf loc
  locOf (ProjectSection _ _ loc) = locOf loc
  locOf (IndexSection _ _ loc) = locOf loc
  locOf (Assert _ _ _ loc) = locOf loc
  locOf (Constr _ _ _ loc) = locOf loc
  locOf (Attr _ _ loc) = locOf loc
  -- Application expressions carry their location in the inner node.
  locOf (AppExp e _) = locOf e

-- | An entry in a record literal.
data FieldBase f vn
  = -- | A field with an explicit value, e.g. @x=2@.
    RecordFieldExplicit (L Name) (ExpBase f vn) SrcLoc
  | -- | A field given only by a variable name, e.g. the @z@ in
    -- @{x=2,y=3,z}@.
    RecordFieldImplicit (L vn) (f StructType) SrcLoc

deriving instance Show (FieldBase Info VName)

deriving instance (Show vn) => Show (FieldBase NoInfo vn)

deriving instance Eq (FieldBase NoInfo VName)

deriving instance Eq (FieldBase Info VName)

deriving instance Ord (FieldBase NoInfo VName)

deriving instance Ord (FieldBase Info VName)

instance Located (FieldBase f vn) where
  locOf (RecordFieldExplicit _ _ loc) = locOf loc
  locOf (RecordFieldImplicit _ _ loc) = locOf loc

-- | A case in a match expression.
data CaseBase f vn = CasePat (PatBase f vn StructType) (ExpBase f vn) SrcLoc

deriving instance Show (CaseBase Info VName)

deriving instance (Show vn) => Show (CaseBase NoInfo vn)

deriving instance Eq (CaseBase NoInfo VName)

deriving instance Eq (CaseBase Info VName)

deriving instance Ord (CaseBase NoInfo VName)

deriving instance Ord (CaseBase Info VName)

instance Located (CaseBase f vn) where
  locOf (CasePat _ _ loc) = locOf loc

-- | Initial value for the loop.  If none is provided, then an
-- expression will be synthesised based on the parameter.
data LoopInitBase f vn
  = LoopInitExplicit (ExpBase f vn)
  | -- | The synthesised initial expression, stored as an annotation.
    LoopInitImplicit (f (ExpBase f vn))

deriving instance Show (LoopInitBase Info VName)

deriving instance (Show vn) => Show (LoopInitBase NoInfo vn)

deriving instance Eq (LoopInitBase NoInfo VName)

deriving instance Eq (LoopInitBase Info VName)

deriving instance Ord (LoopInitBase NoInfo VName)

deriving instance Ord (LoopInitBase Info VName)

-- Only defined for 'Info', since the implicit initialiser is only
-- available once it has been synthesised.
instance Located (LoopInitBase Info vn) where
  locOf (LoopInitExplicit e) = locOf e
  locOf (LoopInitImplicit (Info e)) = locOf e

-- | Whether the loop is a @for@-loop or a @while@-loop.
data LoopFormBase f vn
  = For (IdentBase f vn StructType) (ExpBase f vn)
  | ForIn (PatBase f vn StructType) (ExpBase f vn)
  | While (ExpBase f vn)

deriving instance Show (LoopFormBase Info VName)

deriving instance (Show vn) => Show (LoopFormBase NoInfo vn)

deriving instance Eq (LoopFormBase NoInfo VName)

deriving instance Eq (LoopFormBase Info VName)

deriving instance Ord (LoopFormBase NoInfo VName)

deriving instance Ord (LoopFormBase Info VName)

-- | A literal in a pattern.
data PatLit
  = PatLitInt Integer
  | PatLitFloat Double
  | PatLitPrim PrimValue
  deriving (Eq, Ord, Show)

-- | A pattern as used most places where variables are bound (function
-- parameters, @let@ expressions, etc).
data PatBase f vn t
  = TuplePat [PatBase f vn t] SrcLoc
  | RecordPat [(L Name, PatBase f vn t)] SrcLoc
  | PatParens (PatBase f vn t) SrcLoc
  | Id vn (f t) SrcLoc
  | Wildcard (f t) SrcLoc -- The underscore pattern, which binds nothing.
  | PatAscription (PatBase f vn t) (TypeExp (ExpBase f vn) vn) SrcLoc
  | PatLit PatLit (f t) SrcLoc
  | PatConstr Name (f t) [PatBase f vn t] SrcLoc
  | PatAttr (AttrInfo vn) (PatBase f vn t) SrcLoc

deriving instance (Show (Info t)) => Show (PatBase Info VName t)

deriving instance (Show (NoInfo t), Show vn) => Show (PatBase NoInfo vn t)

deriving instance (Eq (NoInfo t)) => Eq (PatBase NoInfo VName t)

deriving instance (Eq (Info t)) => Eq (PatBase Info VName t)

deriving instance (Ord (NoInfo t)) => Ord (PatBase NoInfo VName t)

deriving instance (Ord (Info t)) => Ord (PatBase Info VName t)

instance Located (PatBase f vn t) where
  locOf (TuplePat _ loc) = locOf loc
  locOf (RecordPat _ loc) = locOf loc
  locOf (PatParens _ loc) = locOf loc
  locOf (Id _ _ loc) = locOf loc
  locOf (Wildcard _ loc) = locOf loc
  locOf (PatAscription _ _ loc) = locOf loc
  locOf (PatLit _ _ loc) = locOf loc
  locOf (PatConstr _ _ _ loc) = locOf loc
  locOf (PatAttr _ _ loc) = locOf loc

instance (Traversable f) => Functor (PatBase f vn) where
  fmap = fmapDefault

instance (Traversable f) => Foldable (PatBase f vn) where
  foldMap = foldMapDefault

-- Traverses the type annotations of the pattern.  The type expression
-- in a 'PatAscription' is left untouched.
instance (Traversable f) => Traversable (PatBase f vn) where
  traverse f (Id v t loc) = Id v <$> traverse f t <*> pure loc
  traverse f (TuplePat ps loc) = TuplePat <$> traverse (traverse f) ps <*> pure loc
  traverse f (RecordPat ps loc) = RecordPat <$> traverse (traverse $ traverse f) ps <*> pure loc
  traverse f (PatParens p loc) = PatParens <$> traverse f p <*> pure loc
  traverse f (Wildcard t loc) = Wildcard <$> traverse f t <*> pure loc
  traverse f (PatAscription p te loc) = PatAscription <$> traverse f p <*> pure te <*> pure loc
  traverse f (PatLit l t loc) = PatLit l <$> traverse f t <*> pure loc
  traverse f (PatConstr c t ps loc) = PatConstr c <$> traverse f t <*> traverse (traverse f) ps <*> pure loc
  traverse f (PatAttr attr p loc) = PatAttr attr <$> traverse f p <*> pure loc

-- | Documentation strings, including source location.  The string may
-- contain newline characters, but it does not contain comment prefix
-- markers.
data DocComment = DocComment T.Text SrcLoc
  deriving (Show)

instance Located DocComment where
  locOf (DocComment _ loc) = locOf loc

-- | Part of the type of an entry point.  Has an actual type, and
-- maybe also an ascribed type expression.  Note that although size
-- expressions in the elaborated type can contain variables, they are
-- no longer in scope, and are considered more like equivalence
-- classes.
data EntryType = EntryType
  { entryType :: StructType,
    entryAscribed :: Maybe (TypeExp (ExpBase Info VName) VName)
  }
  deriving (Show)

-- | A parameter of an entry point.
data EntryParam = EntryParam
  { entryParamName :: Name,
    entryParamType :: EntryType
  }
  deriving (Show)

-- | Information about the external interface exposed by an entry
-- point.  The important thing is that we remember the original
-- source-language types, without desugaring them at all.  The
-- annoying thing is that we do not require type annotations on entry
-- points, so the types can be either ascribed or inferred.
data EntryPoint = EntryPoint
  { entryParams :: [EntryParam],
    entryReturn :: EntryType
  }
  deriving (Show)

-- | A value binding (a function, or a constant when it has no
-- parameters).
data ValBindBase f vn = ValBind
  { -- | Just if this function is an entry point.  If so, it also
    -- contains the externally visible interface.  Note that this may not
    -- strictly be well-typed after some desugaring operations, as it
    -- may refer to abstract types that are no longer in scope.
    valBindEntryPoint :: Maybe (f EntryPoint),
    valBindName :: vn,
    -- | The return type as written by the programmer, if any.
    valBindRetDecl :: Maybe (TypeExp (ExpBase f vn) vn),
    -- | If 'valBindParams' is null, then the 'retDims' are brought
    -- into scope at this point.
    valBindRetType :: f ResRetType,
    valBindTypeParams :: [TypeParamBase vn],
    valBindParams :: [PatBase f vn ParamType],
    valBindBody :: ExpBase f vn,
    valBindDoc :: Maybe DocComment,
    valBindAttrs :: [AttrInfo vn],
    valBindLocation :: SrcLoc
  }

deriving instance Show (ValBindBase Info VName)

deriving instance Show (ValBindBase NoInfo Name)

instance Located (ValBindBase f vn) where
  locOf = locOf . valBindLocation

-- | A type declaration.
data TypeBindBase f vn = TypeBind
  { typeAlias :: vn,
    typeLiftedness :: Liftedness,
    typeParams :: [TypeParamBase vn],
    -- | The type as written in the source.
    typeExp :: TypeExp (ExpBase f vn) vn,
    -- | The elaborated type.
    typeElab :: f StructRetType,
    typeDoc :: Maybe DocComment,
    typeBindLocation :: SrcLoc
  }

deriving instance Show (TypeBindBase Info VName)

deriving instance Show (TypeBindBase NoInfo Name)

instance Located (TypeBindBase f vn) where
  locOf = locOf . typeBindLocation

-- | The liftedness of a type parameter.  By the @Ord@ instance,
-- @Unlifted < SizeLifted < Lifted@.
data Liftedness
  = -- | May only be instantiated with a zero-order type of (possibly
    -- symbolically) known size.
    Unlifted
  | -- | May only be instantiated with a zero-order type, but the size
    -- can be varying.
    SizeLifted
  | -- | May be instantiated with a functional type.
    Lifted
  deriving (Eq, Ord, Show)

-- | A type parameter.
data TypeParamBase vn
  = -- | A type parameter that must be a size.
    TypeParamDim vn SrcLoc
  | -- | A type parameter that must be a type.
    TypeParamType Liftedness vn SrcLoc
  deriving (Eq, Ord, Show)

instance Functor TypeParamBase where
  fmap = fmapDefault

instance Foldable TypeParamBase where
  foldMap = foldMapDefault

instance Traversable TypeParamBase where
  traverse f (TypeParamDim v loc) = TypeParamDim <$> f v <*> pure loc
  traverse f (TypeParamType l v loc) = TypeParamType l <$> f v <*> pure loc

instance Located (TypeParamBase vn) where
  locOf (TypeParamDim _ loc) = locOf loc
  locOf (TypeParamType _ _ loc) = locOf loc

-- | The name of a type parameter.
typeParamName :: TypeParamBase vn -> vn
typeParamName (TypeParamDim v _) = v
typeParamName (TypeParamType _ v _) = v

-- | A spec is a component of a module type.
data SpecBase f vn
  = -- | A value specification.
    ValSpec
      { specName :: vn,
        specTypeParams :: [TypeParamBase vn],
        specTypeExp :: TypeExp (ExpBase f vn) vn,
        specType :: f StructType,
        specDoc :: Maybe DocComment,
        specLocation :: SrcLoc
      }
  | -- | A type abbreviation.
    TypeAbbrSpec (TypeBindBase f vn)
  | -- | Abstract type.
    TypeSpec Liftedness vn [TypeParamBase vn] (Maybe DocComment) SrcLoc
  | -- | A module specification.
    ModSpec vn (ModTypeExpBase f vn) (Maybe DocComment) SrcLoc
  | -- | Inclusion of another module type.
    IncludeSpec (ModTypeExpBase f vn) SrcLoc

deriving instance Show (SpecBase Info VName)

deriving instance Show (SpecBase NoInfo Name)

instance Located (SpecBase f vn) where
  locOf (ValSpec _ _ _ _ _ loc) = locOf loc
  locOf (TypeAbbrSpec tbind) = locOf tbind
  locOf (TypeSpec _ _ _ _ loc) = locOf loc
  locOf (ModSpec _ _ _ loc) = locOf loc
  locOf (IncludeSpec _ loc) = locOf loc

-- | A module type expression.
data ModTypeExpBase f vn
  = ModTypeVar (QualName vn) (f (M.Map VName VName)) SrcLoc
  | ModTypeParens (ModTypeExpBase f vn) SrcLoc
  | ModTypeSpecs [SpecBase f vn] SrcLoc
  | ModTypeWith (ModTypeExpBase f vn) (TypeRefBase f vn) SrcLoc
  | ModTypeArrow (Maybe vn) (ModTypeExpBase f vn) (ModTypeExpBase f vn) SrcLoc

deriving instance Show (ModTypeExpBase Info VName)

deriving instance Show (ModTypeExpBase NoInfo Name)

-- | A type refinement.
data TypeRefBase f vn = TypeRef (QualName vn) [TypeParamBase vn] (TypeExp (ExpBase f vn) vn) SrcLoc

deriving instance Show (TypeRefBase Info VName)

deriving instance Show (TypeRefBase NoInfo Name)

instance Located (TypeRefBase f vn) where
  locOf (TypeRef _ _ _ loc) = locOf loc

-- Every module type expression stores its location as the last field.
instance Located (ModTypeExpBase f vn) where
  locOf (ModTypeVar _ _ loc) = locOf loc
  locOf (ModTypeParens _ loc) = locOf loc
  locOf (ModTypeSpecs _ loc) = locOf loc
  locOf (ModTypeWith _ _ loc) = locOf loc
  locOf (ModTypeArrow _ _ _ loc) = locOf loc

-- | Module type binding.
data ModTypeBindBase f vn = ModTypeBind
  { modTypeName :: vn,
    modTypeExp :: ModTypeExpBase f vn,
    modTypeDoc :: Maybe DocComment,
    modTypeLoc :: SrcLoc
  }

deriving instance Show (ModTypeBindBase Info VName)

deriving instance Show (ModTypeBindBase NoInfo Name)

instance Located (ModTypeBindBase f vn) where
  locOf = locOf . modTypeLoc

-- | Canonical reference to a Futhark code file.  Does not include the
-- @.fut@ extension.  This is most often a path relative to the
-- working directory of the compiler.  In a multi-file program, a file
-- is known by exactly one import name, even if it is referenced
-- relatively by different names by files in different subdirectories.
newtype ImportName = ImportName Posix.FilePath
  deriving (Eq, Ord, Show)

-- | Module expression.
data ModExpBase f vn
  = ModVar (QualName vn) SrcLoc
  | ModParens (ModExpBase f vn) SrcLoc
  | -- | The contents of another file as a module.
    ModImport FilePath (f ImportName) SrcLoc
  | ModDecs [DecBase f vn] SrcLoc
  | -- | Functor application.  The first mapping is from parameter
    -- names to argument names, while the second maps names in the
    -- constructed module to the names inside the functor.
    ModApply
      (ModExpBase f vn)
      (ModExpBase f vn)
      (f (M.Map VName VName))
      (f (M.Map VName VName))
      SrcLoc
  | ModAscript (ModExpBase f vn) (ModTypeExpBase f vn) (f (M.Map VName VName)) SrcLoc
  | ModLambda
      (ModParamBase f vn)
      (Maybe (ModTypeExpBase f vn, f (M.Map VName VName)))
      (ModExpBase f vn)
      SrcLoc

deriving instance Show (ModExpBase Info VName)

deriving instance Show (ModExpBase NoInfo Name)

-- Every module expression stores its location as the last field.
instance Located (ModExpBase f vn) where
  locOf (ModVar _ loc) = locOf loc
  locOf (ModParens _ loc) = locOf loc
  locOf (ModImport _ _ loc) = locOf loc
  locOf (ModDecs _ loc) = locOf loc
  locOf (ModApply _ _ _ _ loc) = locOf loc
  locOf (ModAscript _ _ _ loc) = locOf loc
  locOf (ModLambda _ _ _ loc) = locOf loc

-- | A module binding.
data ModBindBase f vn = ModBind
  { modName :: vn,
    modParams :: [ModParamBase f vn],
    -- | Optional module type ascription, paired with a name mapping
    -- that is only present when @f@ carries information.
    modType :: Maybe (ModTypeExpBase f vn, f (M.Map VName VName)),
    modExp :: ModExpBase f vn,
    modDoc :: Maybe DocComment,
    modLocation :: SrcLoc
  }

deriving instance Show (ModBindBase Info VName)

deriving instance Show (ModBindBase NoInfo Name)

instance Located (ModBindBase f vn) where
  locOf = locOf . modLocation

-- | A module parameter.
data ModParamBase f vn = ModParam
  { modParamName :: vn,
    modParamType :: ModTypeExpBase f vn,
    modParamAbs :: f [VName],
    modParamLocation :: SrcLoc
  }

deriving instance Show (ModParamBase Info VName)

deriving instance Show (ModParamBase NoInfo Name)

instance Located (ModParamBase f vn) where
  locOf = locOf . modParamLocation

-- | A top-level binding.
data DecBase f vn
  = ValDec (ValBindBase f vn)
  | TypeDec (TypeBindBase f vn)
  | ModTypeDec (ModTypeBindBase f vn)
  | ModDec (ModBindBase f vn)
  | OpenDec (ModExpBase f vn) SrcLoc
  | LocalDec (DecBase f vn) SrcLoc
  | ImportDec FilePath (f ImportName) SrcLoc

deriving instance Show (DecBase Info VName)

deriving instance Show (DecBase NoInfo Name)

-- Binding-wrapping declarations defer to the wrapped binding; the
-- remaining forms carry their own location as the last field.
instance Located (DecBase f vn) where
  locOf dec = case dec of
    ValDec vb -> locOf vb
    TypeDec tb -> locOf tb
    ModTypeDec mtb -> locOf mtb
    ModDec mb -> locOf mb
    OpenDec _ loc -> locOf loc
    LocalDec _ loc -> locOf loc
    ImportDec _ _ loc -> locOf loc

-- | The program described by a single Futhark file.  May depend on
-- other files.
data ProgBase f vn = Prog
  { progDoc :: Maybe DocComment,
    progDecs :: [DecBase f vn]
  }

deriving instance Show (ProgBase Info VName)

deriving instance Show (ProgBase NoInfo Name)

-- | Construct an 'Apply' node, with type information.
--
-- With no arguments the function expression is returned unchanged.  If
-- the function is itself a typed 'Apply', the new arguments are appended
-- to its argument list (so applications stay flat) and its existential
-- sizes are concatenated with the given ones.
mkApply :: ExpBase Info vn -> [(Maybe VName, ExpBase Info vn)] -> AppRes -> ExpBase Info vn
mkApply f args (AppRes t ext) =
  case NE.nonEmpty [(Info v, x) | (v, x) <- args] of
    Nothing -> f
    Just new_args ->
      let last_arg = snd (NE.last new_args)
       in case f of
            AppExp (Apply inner_f old_args old_loc) (Info (AppRes _ old_ext)) ->
              AppExp
                (Apply inner_f (old_args <> new_args) (srcspan old_loc last_arg))
                (Info (AppRes t (old_ext <> ext)))
            _ ->
              AppExp
                (Apply f new_args (srcspan f last_arg))
                (Info (AppRes t ext))

-- | Construct an 'Apply' node, without type information.  An existing
-- untyped application is extended with one more argument rather than
-- nested.
mkApplyUT :: ExpBase NoInfo vn -> ExpBase NoInfo vn -> ExpBase NoInfo vn
mkApplyUT fun x =
  case fun of
    AppExp (Apply f args loc) _ ->
      AppExp (Apply f (args <> NE.singleton (NoInfo, x)) (srcspan loc x)) NoInfo
    _ ->
      AppExp (Apply fun (NE.singleton (NoInfo, x)) (srcspan fun x)) NoInfo

-- | The initial expression of a loop, whether written explicitly or
-- filled in by the type checker.
loopInitExp :: LoopInitBase Info VName -> ExpBase Info VName
loopInitExp initial =
  case initial of
    LoopInitExplicit e -> e
    LoopInitImplicit (Info e) -> e

--- Some prettyprinting definitions are here because we need them in
--- the Attributes module.
-- Unsigned integer types are spelled out here; signed integer and
-- floating-point types defer to their own Pretty instances.
instance Pretty PrimType where
  pretty prim = case prim of
    Unsigned Int8 -> "u8"
    Unsigned Int16 -> "u16"
    Unsigned Int32 -> "u32"
    Unsigned Int64 -> "u64"
    Signed it -> pretty it
    FloatType ft -> pretty ft
    Bool -> "bool"

-- Each binary operator is rendered as its surface-syntax spelling.
instance Pretty BinOp where
  pretty op = case op of
    Backtick -> "``"
    Bang -> "!"
    Equ -> "="
    Plus -> "+"
    Minus -> "-"
    Pow -> "**"
    Times -> "*"
    Divide -> "/"
    Mod -> "%"
    Quot -> "//"
    Rem -> "%%"
    ShiftR -> ">>"
    ShiftL -> "<<"
    Band -> "&"
    Xor -> "^"
    Bor -> "|"
    LogAnd -> "&&"
    LogOr -> "||"
    Equal -> "=="
    NotEqual -> "!="
    Less -> "<"
    Leq -> "<="
    Greater -> ">"
    Geq -> ">="
    PipeLeft -> "<|"
    PipeRight -> "|>"
futhark-0.25.27/src/Language/Futhark/Traversals.hs000066400000000000000000000517561475065116200217600ustar00rootroot00000000000000
-- |
--
-- Functions for generic traversals across Futhark syntax trees.  The
-- motivation for this module came from dissatisfaction with rewriting
-- the same trivial tree recursions for every module.  A possible
-- alternative would be to use normal \"Scrap your
-- boilerplate\"-techniques, but these are rejected for two reasons:
--
-- * They are too slow.
--
-- * More importantly, they do not tell you whether you have missed
-- some cases.
--
-- Instead, this module defines various traversals of the Futhark syntax
-- tree.  The implementation is rather tedious, but the interface is
-- easy to use.
--
-- A traversal of the Futhark syntax tree is expressed as a record of
-- functions expressing the operations to be performed on the various
-- types of nodes.
module Language.Futhark.Traversals
  ( ASTMapper (..),
    ASTMappable (..),
    identityMapper,
    bareExp,
  )
where

import Data.Bifunctor
import Data.Bitraversable
import Data.List.NonEmpty qualified as NE
import Language.Futhark.Syntax

-- | Express a monad mapping operation on a syntax node.
-- Each element of this structure expresses the operation to be
-- performed on a given child.
data ASTMapper m = ASTMapper
  { mapOnExp :: ExpBase Info VName -> m (ExpBase Info VName),
    mapOnName :: QualName VName -> m (QualName VName),
    mapOnStructType :: StructType -> m StructType,
    mapOnParamType :: ParamType -> m ParamType,
    mapOnResRetType :: ResRetType -> m ResRetType
  }

-- | An 'ASTMapper' that just leaves its input unchanged.
identityMapper :: (Monad m) => ASTMapper m
identityMapper =
  ASTMapper
    { mapOnExp = pure,
      mapOnName = pure,
      mapOnStructType = pure,
      mapOnParamType = pure,
      mapOnResRetType = pure
    }

-- | The class of things that we can map an 'ASTMapper' across.
class ASTMappable x where
  -- | Map a monadic action across the immediate children of an
  -- object.  Importantly, the 'astMap' action is not invoked for
  -- the object itself, and the mapping does not descend recursively
  -- into subexpressions.  The mapping is done left-to-right.
  astMap :: (Monad m) => ASTMapper m -> x -> m x

instance ASTMappable (AppExpBase Info VName) where
  astMap tv (Range start next end loc) =
    Range
      <$> mapOnExp tv start
      <*> traverse (mapOnExp tv) next
      <*> traverse (mapOnExp tv) end
      <*> pure loc
  astMap tv (If c texp fexp loc) =
    If <$> mapOnExp tv c <*> mapOnExp tv texp <*> mapOnExp tv fexp <*> pure loc
  astMap tv (Match e cases loc) =
    Match <$> mapOnExp tv e <*> astMap tv cases <*> pure loc
  astMap tv (Apply f args loc) = do
    f' <- mapOnExp tv f
    args' <- traverse (traverse $ mapOnExp tv) args
    -- Safe to disregard return type because existentials cannot be
    -- instantiated here, as the return is necessarily a function.
    -- If the mapped function became an application itself, flatten
    -- the two argument lists into a single 'Apply'.
    pure $ case f' of
      AppExp (Apply f_inner args_inner _) _ ->
        Apply f_inner (args_inner <> args') loc
      _ ->
        Apply f' args' loc
  astMap tv (LetPat sizes pat e body loc) =
    LetPat sizes <$> astMap tv pat <*> mapOnExp tv e <*> mapOnExp tv body <*> pure loc
  astMap tv (LetFun name (tparams, params, ret, t, e) body loc) =
    LetFun name
      <$> ( (tparams,,,,)
              <$> mapM (astMap tv) params
              <*> traverse (astMap tv) ret
              <*> traverse (mapOnResRetType tv) t
              <*> mapOnExp tv e
          )
      <*> mapOnExp tv body
      <*> pure loc
  astMap tv (LetWith dest src idxexps vexp body loc) =
    LetWith
      <$> astMap tv dest
      <*> astMap tv src
      <*> mapM (astMap tv) idxexps
      <*> mapOnExp tv vexp
      <*> mapOnExp tv body
      <*> pure loc
  astMap tv (BinOp (fname, fname_loc) t (x, xext) (y, yext) loc) =
    BinOp
      <$> ((,) <$> mapOnName tv fname <*> pure fname_loc)
      <*> traverse (mapOnStructType tv) t
      <*> ((,) <$> mapOnExp tv x <*> pure xext)
      <*> ((,) <$> mapOnExp tv y <*> pure yext)
      <*> pure loc
  astMap tv (Loop sparams mergepat loopinit form loopbody loc) =
    Loop sparams
      <$> astMap tv mergepat
      <*> astMap tv loopinit
      <*> astMap tv form
      <*> mapOnExp tv loopbody
      <*> pure loc
  astMap tv (Index arr idxexps loc) =
    Index <$> mapOnExp tv arr <*> mapM (astMap tv) idxexps <*> pure loc

instance ASTMappable (ExpBase Info VName) where
  astMap tv (Var name t loc) =
    Var <$> mapOnName tv name <*> traverse (mapOnStructType tv) t <*> pure loc
  astMap tv (Hole t loc) =
    Hole <$> traverse (mapOnStructType tv) t <*> pure loc
  astMap _ (Literal val loc) =
    pure $ Literal val loc
  astMap _ (StringLit vs loc) =
    pure $ StringLit vs loc
  astMap tv (IntLit val t loc) =
    IntLit val <$> traverse (mapOnStructType tv) t <*> pure loc
  astMap tv (FloatLit val t loc) =
    FloatLit val <$> traverse (mapOnStructType tv) t <*> pure loc
  astMap tv (Parens e loc) =
    Parens <$> mapOnExp tv e <*> pure loc
  astMap tv (QualParens (name, nameloc) e loc) =
    QualParens
      <$> ((,) <$> mapOnName tv name <*> pure nameloc)
      <*> mapOnExp tv e
      <*> pure loc
  astMap tv (TupLit els loc) =
    TupLit <$> mapM (mapOnExp tv) els <*> pure loc
  astMap tv (RecordLit fields loc) =
    RecordLit <$> astMap tv fields <*> pure loc
  astMap _ (ArrayVal vs t loc) =
    pure $ ArrayVal vs t loc
  astMap tv (ArrayLit els t loc) =
    ArrayLit
      <$> mapM (mapOnExp tv) els
      <*> traverse (mapOnStructType tv) t
      <*> pure loc
  astMap tv (Ascript e tdecl loc) =
    Ascript <$> mapOnExp tv e <*> astMap tv tdecl <*> pure loc
  astMap tv (Coerce e tdecl t loc) =
    Coerce
      <$> mapOnExp tv e
      <*> astMap tv tdecl
      <*> traverse (mapOnStructType tv) t
      <*> pure loc
  astMap tv (Negate x loc) =
    Negate <$> mapOnExp tv x <*> pure loc
  astMap tv (Not x loc) =
    Not <$> mapOnExp tv x <*> pure loc
  astMap tv (Update src slice v loc) =
    Update
      <$> mapOnExp tv src
      <*> mapM (astMap tv) slice
      <*> mapOnExp tv v
      <*> pure loc
  astMap tv (RecordUpdate src fs v (Info t) loc) =
    RecordUpdate
      <$> mapOnExp tv src
      <*> pure fs
      <*> mapOnExp tv v
      <*> (Info <$> mapOnStructType tv t)
      <*> pure loc
  astMap tv (Project field e t loc) =
    Project field <$> mapOnExp tv e <*> traverse (mapOnStructType tv) t <*> pure loc
  astMap tv (Assert e1 e2 desc loc) =
    Assert <$> mapOnExp tv e1 <*> mapOnExp tv e2 <*> pure desc <*> pure loc
  astMap tv (Lambda params body ret t loc) =
    Lambda
      <$> mapM (astMap tv) params
      <*> mapOnExp tv body
      <*> traverse (astMap tv) ret
      <*> traverse (mapOnResRetType tv) t
      <*> pure loc
  astMap tv (OpSection name t loc) =
    OpSection <$> mapOnName tv name <*> traverse (mapOnStructType tv) t <*> pure loc
  astMap tv (OpSectionLeft name t arg (Info (pa, t1a, argext), Info (pb, t1b)) (ret, retext) loc) =
    OpSectionLeft
      <$> mapOnName tv name
      <*> traverse (mapOnStructType tv) t
      <*> mapOnExp tv arg
      <*> ( (,)
              <$> (Info <$> ((pa,,) <$> mapOnParamType tv t1a <*> pure argext))
              <*> (Info <$> ((pb,) <$> mapOnParamType tv t1b))
          )
      <*> ((,) <$> traverse (mapOnResRetType tv) ret <*> pure retext)
      <*> pure loc
  astMap tv (OpSectionRight name t arg (Info (pa, t1a), Info (pb, t1b, argext)) t2 loc) =
    OpSectionRight
      <$> mapOnName tv name
      <*> traverse (mapOnStructType tv) t
      <*> mapOnExp tv arg
      <*> ( (,)
              <$> (Info <$> ((pa,) <$> mapOnParamType tv t1a))
              <*> (Info <$> ((pb,,) <$> mapOnParamType tv t1b <*> pure argext))
          )
      <*> traverse (mapOnResRetType tv) t2
      <*> pure loc
  astMap tv (ProjectSection fields t loc) =
    ProjectSection fields <$> traverse (mapOnStructType tv) t <*> pure loc
  astMap tv (IndexSection idxs t loc) =
    IndexSection
      <$> mapM (astMap tv) idxs
      <*> traverse (mapOnStructType tv) t
      <*> pure loc
  astMap tv (Constr name es t loc) =
    Constr name
      <$> traverse (mapOnExp tv) es
      <*> traverse (mapOnStructType tv) t
      <*> pure loc
  astMap tv (Attr attr e loc) =
    Attr attr <$> mapOnExp tv e <*> pure loc
  astMap tv (AppExp e res) =
    AppExp <$> astMap tv e <*> astMap tv res

instance ASTMappable (LoopInitBase Info VName) where
  astMap tv (LoopInitExplicit e) = LoopInitExplicit <$> mapOnExp tv e
  astMap tv (LoopInitImplicit (Info e)) = LoopInitImplicit . Info <$> mapOnExp tv e

instance ASTMappable (LoopFormBase Info VName) where
  astMap tv (For i bound) = For <$> astMap tv i <*> mapOnExp tv bound
  astMap tv (ForIn pat e) = ForIn <$> astMap tv pat <*> mapOnExp tv e
  astMap tv (While e) = While <$> mapOnExp tv e

instance ASTMappable (TypeExp (ExpBase Info VName) VName) where
  astMap tv (TEVar qn loc) =
    TEVar <$> mapOnName tv qn <*> pure loc
  astMap tv (TEParens te loc) =
    TEParens <$> astMap tv te <*> pure loc
  astMap tv (TETuple ts loc) =
    TETuple <$> traverse (astMap tv) ts <*> pure loc
  astMap tv (TERecord ts loc) =
    TERecord <$> traverse (traverse $ astMap tv) ts <*> pure loc
  -- The size comes first and the element type second (cf. 'bareTypeExp').
  astMap tv (TEArray dim te loc) =
    TEArray <$> astMap tv dim <*> astMap tv te <*> pure loc
  astMap tv (TEUnique t loc) =
    TEUnique <$> astMap tv t <*> pure loc
  astMap tv (TEApply t1 t2 loc) =
    TEApply <$> astMap tv t1 <*> astMap tv t2 <*> pure loc
  astMap tv (TEArrow v t1 t2 loc) =
    TEArrow v <$> astMap tv t1 <*> astMap tv t2 <*> pure loc
  astMap tv (TESum cs loc) =
    TESum <$> traverse (traverse $ astMap tv) cs <*> pure loc
  astMap tv (TEDim dims t loc) =
    TEDim dims <$> astMap tv t <*> pure loc

instance ASTMappable (TypeArgExp (ExpBase Info VName) VName) where
  astMap tv (TypeArgExpSize dim) = TypeArgExpSize <$> astMap tv dim
  astMap tv (TypeArgExpType te) = TypeArgExpType <$> astMap tv te

instance ASTMappable (SizeExp (ExpBase Info VName)) where
  astMap tv (SizeExp e loc) = SizeExp <$> mapOnExp tv e <*> pure loc
  astMap _ (SizeExpAny loc) = pure $ SizeExpAny loc

instance ASTMappable (DimIndexBase Info VName) where
  astMap tv (DimFix j) = DimFix <$> mapOnExp tv j
  -- Absent slice components stay absent; present ones are mapped.
  astMap tv (DimSlice i j stride) =
    DimSlice
      <$> traverse (mapOnExp tv) i
      <*> traverse (mapOnExp tv) j
      <*> traverse (mapOnExp tv) stride

instance ASTMappable AppRes where
  astMap tv (AppRes t ext) =
    AppRes <$> mapOnStructType tv t <*> pure ext

-- A traversal of a type constructor @t@: the first function is applied
-- to type names, the second to sizes, and the third to the second type
-- argument of @t@ (its top-level annotation).
type TypeTraverser f t dim1 als1 dim2 als2 =
  (QualName VName -> f (QualName VName)) ->
  (dim1 -> f dim2) ->
  (als1 -> f als2) ->
  t dim1 als1 ->
  f (t dim2 als2)

-- The annotation function is applied to this type, record fields and
-- sum payloads; array elements, function parameter/result types and
-- type arguments keep their annotation ('pure').
traverseScalarType ::
  (Applicative f) =>
  TypeTraverser f ScalarTypeBase dim1 als1 dim2 als2
traverseScalarType _ _ _ (Prim t) =
  pure $ Prim t
traverseScalarType f g h (Record fs) =
  Record <$> traverse (traverseType f g h) fs
traverseScalarType f g h (TypeVar als t args) =
  TypeVar <$> h als <*> f t <*> traverse (traverseTypeArg f g) args
traverseScalarType f g h (Arrow als v u t1 (RetType dims t2)) =
  Arrow
    <$> h als
    <*> pure v
    <*> pure u
    <*> traverseType f g pure t1
    <*> (RetType dims <$> traverseType f g pure t2)
traverseScalarType f g h (Sum cs) =
  Sum <$> (traverse . traverse) (traverseType f g h) cs

traverseType ::
  (Applicative f) =>
  TypeTraverser f TypeBase dim1 als1 dim2 als2
traverseType f g h (Array als shape et) =
  Array <$> h als <*> traverse g shape <*> traverseScalarType f g pure et
traverseType f g h (Scalar t) =
  Scalar <$> traverseScalarType f g h t

traverseTypeArg ::
  (Applicative f) =>
  (QualName VName -> f (QualName VName)) ->
  (dim1 -> f dim2) ->
  TypeArg dim1 ->
  f (TypeArg dim2)
traverseTypeArg _ g (TypeArgDim d) =
  TypeArgDim <$> g d
traverseTypeArg f g (TypeArgType t) =
  TypeArgType <$> traverseType f g pure t

instance ASTMappable StructType where
  astMap tv = traverseType (mapOnName tv) (mapOnExp tv) pure

instance ASTMappable ParamType where
  astMap tv = traverseType (mapOnName tv) (mapOnExp tv) pure

instance ASTMappable (TypeBase Size Uniqueness) where
  astMap tv = traverseType (mapOnName tv) (mapOnExp tv) pure

instance ASTMappable ResRetType where
  astMap tv (RetType ext t) = RetType ext <$> astMap tv t

instance ASTMappable (IdentBase Info VName StructType) where
  astMap tv (Ident name (Info t) loc) =
    Ident name <$> (Info <$> mapOnStructType tv t) <*> pure loc

-- Traverse a pattern: @f@ is applied to the type annotations carried in
-- 'Info', @g@ to expressions inside type ascriptions.
traversePat ::
  (Monad m) =>
  (t1 -> m t2) ->
  (ExpBase Info VName -> m (ExpBase Info VName)) ->
  PatBase Info VName t1 ->
  m (PatBase Info VName t2)
traversePat f _ (Id name (Info t) loc) =
  Id name <$> (Info <$> f t) <*> pure loc
traversePat f g (TuplePat pats loc) =
  TuplePat <$> mapM (traversePat f g) pats <*> pure loc
traversePat f g (RecordPat fields loc) =
  RecordPat <$> mapM (traverse $ traversePat f g) fields <*> pure loc
traversePat f g (PatParens pat loc) =
  PatParens <$> traversePat f g pat <*> pure loc
traversePat f g (PatAscription pat t loc) =
  PatAscription <$> traversePat f g pat <*> bitraverse g pure t <*> pure loc
traversePat f _ (Wildcard (Info t) loc) =
  Wildcard <$> (Info <$> f t) <*> pure loc
traversePat f _ (PatLit v (Info t) loc) =
  PatLit v <$> (Info <$> f t) <*> pure loc
traversePat f g (PatConstr n (Info t) ps loc) =
  PatConstr n <$> (Info <$> f t) <*> mapM (traversePat f g) ps <*> pure loc
traversePat f g (PatAttr attr p loc) =
  PatAttr attr <$> traversePat f g p <*> pure loc

instance ASTMappable (PatBase Info VName StructType) where
  astMap tv = traversePat (mapOnStructType tv) (mapOnExp tv)

instance ASTMappable (PatBase Info VName ParamType) where
  astMap tv = traversePat (mapOnParamType tv) (mapOnExp tv)

instance ASTMappable (FieldBase Info VName) where
  astMap tv (RecordFieldExplicit name e loc) =
    RecordFieldExplicit name <$> mapOnExp tv e <*> pure loc
  -- The implicit field name is mapped as an unqualified name and the
  -- leaf of the result is kept.
  astMap tv (RecordFieldImplicit (L nameloc name) t loc) =
    RecordFieldImplicit
      <$> (L nameloc <$> (qualLeaf <$> mapOnName tv (QualName [] name)))
      <*> traverse (mapOnStructType tv) t
      <*> pure loc

instance ASTMappable (CaseBase Info VName) where
  astMap tv (CasePat pat e loc) =
    CasePat <$> astMap tv pat <*> mapOnExp tv e <*> pure loc

instance (ASTMappable a) => ASTMappable (Info a) where
  astMap tv = traverse $ astMap tv

instance (ASTMappable a) => ASTMappable [a] where
  astMap tv = traverse $ astMap tv

instance (ASTMappable a) => ASTMappable (NE.NonEmpty a) where
  astMap tv = traverse $ astMap tv

instance (ASTMappable a, ASTMappable b) => ASTMappable (a, b) where
  astMap tv (x, y) = (,) <$> astMap tv x <*> astMap tv y

instance (ASTMappable a, ASTMappable b, ASTMappable c) => ASTMappable (a, b, c) where
  astMap tv (x, y, z) = (,,) <$> astMap tv x <*> astMap tv y <*> astMap tv z

-- It would be lovely if the following code would be written in terms
-- of ASTMappable, but unfortunately it involves changing the Info
-- functor.  For simplicity, the general traversals do not support
-- that.  Sometimes a little duplication is better than an overly
-- complex abstraction.  The types ensure that this will be correct
-- anyway, so it's just tedious, and not actually fragile.
bareField :: FieldBase Info VName -> FieldBase NoInfo VName
bareField (RecordFieldExplicit name e loc) =
  RecordFieldExplicit name (bareExp e) loc
bareField (RecordFieldImplicit name _ loc) =
  RecordFieldImplicit name NoInfo loc

barePat :: PatBase Info VName t -> PatBase NoInfo VName t
barePat (TuplePat ps loc) = TuplePat (map barePat ps) loc
barePat (RecordPat fs loc) = RecordPat (map (fmap barePat) fs) loc
barePat (PatParens p loc) = PatParens (barePat p) loc
barePat (Id v _ loc) = Id v NoInfo loc
barePat (Wildcard _ loc) = Wildcard NoInfo loc
barePat (PatAscription pat t loc) = PatAscription (barePat pat) (bareTypeExp t) loc
barePat (PatLit v _ loc) = PatLit v NoInfo loc
barePat (PatConstr c _ ps loc) = PatConstr c NoInfo (map barePat ps) loc
barePat (PatAttr attr p loc) = PatAttr attr (barePat p) loc

bareDimIndex :: DimIndexBase Info VName -> DimIndexBase NoInfo VName
bareDimIndex (DimFix e) =
  DimFix $ bareExp e
bareDimIndex (DimSlice x y z) =
  DimSlice (bareExp <$> x) (bareExp <$> y) (bareExp <$> z)

-- An implicit loop initialiser only exists after type checking, so
-- stripping it leaves just the marker.
bareLoopInit :: LoopInitBase Info VName -> LoopInitBase NoInfo VName
bareLoopInit (LoopInitExplicit e) = LoopInitExplicit $ bareExp e
bareLoopInit (LoopInitImplicit _) = LoopInitImplicit NoInfo

bareLoopForm :: LoopFormBase Info VName -> LoopFormBase NoInfo VName
bareLoopForm (For (Ident i _ loc) e) = For (Ident i NoInfo loc) (bareExp e)
bareLoopForm (ForIn pat e) = ForIn (barePat pat) (bareExp e)
bareLoopForm (While e) = While (bareExp e)

bareCase :: CaseBase Info VName -> CaseBase NoInfo VName
bareCase (CasePat pat e loc) = CasePat (barePat pat) (bareExp e) loc

bareSizeExp :: SizeExp (ExpBase Info VName) -> SizeExp (ExpBase NoInfo VName)
bareSizeExp (SizeExp e loc) = SizeExp (bareExp e) loc
bareSizeExp (SizeExpAny loc) = SizeExpAny loc

bareTypeExp :: TypeExp (ExpBase Info VName) VName -> TypeExp (ExpBase NoInfo VName) VName
bareTypeExp (TEVar qn loc) = TEVar qn loc
bareTypeExp (TEParens te loc) = TEParens (bareTypeExp te) loc
bareTypeExp (TETuple tys loc) = TETuple (map bareTypeExp tys) loc
bareTypeExp (TERecord fs loc) = TERecord (map (second bareTypeExp) fs) loc
bareTypeExp (TEArray size ty loc) = TEArray (bareSizeExp size) (bareTypeExp ty) loc
bareTypeExp (TEUnique ty loc) = TEUnique (bareTypeExp ty) loc
bareTypeExp (TEApply ty ta loc) = TEApply (bareTypeExp ty) (bareTypeArgExp ta) loc
  where
    bareTypeArgExp (TypeArgExpSize size) = TypeArgExpSize $ bareSizeExp size
    bareTypeArgExp (TypeArgExpType tya) = TypeArgExpType $ bareTypeExp tya
bareTypeExp (TEArrow arg tya tyr loc) = TEArrow arg (bareTypeExp tya) (bareTypeExp tyr) loc
bareTypeExp (TESum cs loc) = TESum (map (second $ map bareTypeExp) cs) loc
bareTypeExp (TEDim names ty loc) = TEDim names (bareTypeExp ty) loc

-- | Remove all annotations from an expression, but retain the
-- name/scope information.
bareExp :: ExpBase Info VName -> ExpBase NoInfo VName
bareExp (Var name _ loc) = Var name NoInfo loc
bareExp (Hole _ loc) = Hole NoInfo loc
bareExp (Literal v loc) = Literal v loc
bareExp (IntLit val _ loc) = IntLit val NoInfo loc
bareExp (FloatLit val _ loc) = FloatLit val NoInfo loc
bareExp (Parens e loc) = Parens (bareExp e) loc
bareExp (QualParens name e loc) = QualParens name (bareExp e) loc
bareExp (TupLit els loc) = TupLit (map bareExp els) loc
bareExp (StringLit vs loc) = StringLit vs loc
bareExp (RecordLit fields loc) = RecordLit (map bareField fields) loc
bareExp (ArrayVal vs t loc) = ArrayVal vs t loc
bareExp (ArrayLit els _ loc) = ArrayLit (map bareExp els) NoInfo loc
bareExp (Ascript e te loc) = Ascript (bareExp e) (bareTypeExp te) loc
bareExp (Coerce e te _ loc) = Coerce (bareExp e) (bareTypeExp te) NoInfo loc
bareExp (Negate x loc) = Negate (bareExp x) loc
bareExp (Not x loc) = Not (bareExp x) loc
bareExp (Update src slice v loc) =
  Update (bareExp src) (map bareDimIndex slice) (bareExp v) loc
bareExp (RecordUpdate src fs v _ loc) =
  RecordUpdate (bareExp src) fs (bareExp v) NoInfo loc
bareExp (Project field e _ loc) = Project field (bareExp e) NoInfo loc
bareExp (Assert e1 e2 _ loc) = Assert (bareExp e1) (bareExp e2) NoInfo loc
bareExp (Lambda params body ret _ loc) =
  Lambda (map barePat params) (bareExp body) (fmap bareTypeExp ret) NoInfo loc
bareExp (OpSection name _ loc) = OpSection name NoInfo loc
bareExp (OpSectionLeft name _ arg _ _ loc) =
  OpSectionLeft name NoInfo (bareExp arg) (NoInfo, NoInfo) (NoInfo, NoInfo) loc
bareExp (OpSectionRight name _ arg _ _ loc) =
  OpSectionRight name NoInfo (bareExp arg) (NoInfo, NoInfo) NoInfo loc
bareExp (ProjectSection fields _ loc) = ProjectSection fields NoInfo loc
bareExp (IndexSection slice _ loc) =
  IndexSection (map bareDimIndex slice) NoInfo loc
bareExp (Constr name es _ loc) = Constr name (map bareExp es) NoInfo loc
bareExp (AppExp appexp _) =
  AppExp appexp' NoInfo
  where
    appexp' =
      case appexp of
        Match e cases loc ->
          Match (bareExp e) (fmap bareCase cases) loc
        -- Size parameters of the loop are dropped along with the types.
        Loop _ mergepat loopinit form loopbody loc ->
          Loop
            []
            (barePat mergepat)
            (bareLoopInit loopinit)
            (bareLoopForm form)
            (bareExp loopbody)
            loc
        LetWith (Ident dest _ destloc) (Ident src _ srcloc) idxexps vexp body loc ->
          LetWith
            (Ident dest NoInfo destloc)
            (Ident src NoInfo srcloc)
            (map bareDimIndex idxexps)
            (bareExp vexp)
            (bareExp body)
            loc
        BinOp fname _ (x, _) (y, _) loc ->
          BinOp fname NoInfo (bareExp x, NoInfo) (bareExp y, NoInfo) loc
        If c texp fexp loc ->
          If (bareExp c) (bareExp texp) (bareExp fexp) loc
        Apply f args loc ->
          Apply (bareExp f) (fmap ((NoInfo,) . bareExp . snd) args) loc
        LetPat sizes pat e body loc ->
          LetPat sizes (barePat pat) (bareExp e) (bareExp body) loc
        LetFun name (fparams, params, ret, _, e) body loc ->
          LetFun
            name
            (fparams, map barePat params, fmap bareTypeExp ret, NoInfo, bareExp e)
            (bareExp body)
            loc
        Range start next end loc ->
          Range (bareExp start) (fmap bareExp next) (fmap bareExp end) loc
        Index arr slice loc ->
          Index (bareExp arr) (map bareDimIndex slice) loc
bareExp (Attr attr e loc) = Attr attr (bareExp e) loc
futhark-0.25.27/src/Language/Futhark/Tuple.hs000066400000000000000000000026461475065116200207150ustar00rootroot00000000000000
-- \* Basic utilities for interpreting tuples as records.
module Language.Futhark.Tuple
  ( areTupleFields,
    tupleFields,
    tupleFieldNames,
    sortFields,
  )
where

import Data.Char (isDigit, ord)
import Data.List (sortOn)
import Data.Map qualified as M
import Data.Text qualified as T
import Language.Futhark.Core (Name, nameFromString, nameToText)

-- | Does this record map correspond to a tuple?
--
-- The fields, in 'sortFields' order, must be named exactly @0@, @1@,
-- ..., and there must be either no fields or more than one: a
-- single-field record is never treated as a tuple.
areTupleFields :: M.Map Name a -> Maybe [a]
areTupleFields fs =
  let fs' = sortFields fs
   in if (null fs || length fs' > 1)
        && and (zipWith (==) (map fst fs') tupleFieldNames)
        then Just $ map snd fs'
        else Nothing

-- | Construct a record map corresponding to a tuple.
tupleFields :: [a] -> M.Map Name a
tupleFields as = M.fromList $ zip tupleFieldNames as

-- | Increasing field names for a tuple (starts at 0).
tupleFieldNames :: [Name]
tupleFieldNames = map (nameFromString . show) [(0 :: Int) ..]

-- | Sort fields by their name; taking care to sort numeric fields by
-- their numeric value.  This ensures that tuples and tuple-like
-- records match.  Numeric names sort before all other names.
sortFields :: M.Map Name a -> [(Name, a)]
sortFields l = map snd $ sortOn fst $ zip (map (fieldish . fst) l') l'
  where
    l' = M.toList l
    -- Digits are accumulated in 'Integer' rather than 'Int' so that very
    -- long numeric field names cannot overflow and sort out of order.
    onDigit :: Maybe Integer -> Char -> Maybe Integer
    onDigit Nothing _ = Nothing
    onDigit (Just d) c
      | isDigit c = Just $ d * 10 + toInteger (ord c - ord '0')
      | otherwise = Nothing
    -- All-digit names become 'Left' (compared numerically); anything
    -- else stays 'Right' (compared as a name).  'Left' sorts first.
    fieldish s = maybe (Right s) Left $ T.foldl' onDigit (Just 0) $ nameToText s
futhark-0.25.27/src/Language/Futhark/TypeChecker.hs000066400000000000000000000662271475065116200220370ustar00rootroot00000000000000
-- | The type checker checks whether the program is type-consistent
-- and adds type annotations and various other elaborations.  The
-- program does not need to have any particular properties for the
-- type checker to function; in particular it does not need unique
-- names.
module Language.Futhark.TypeChecker
  ( checkProg,
    checkExp,
    checkDec,
    checkModExp,
    Notes,
    TypeError (..),
    prettyTypeError,
    prettyTypeErrorNoLoc,
    Warnings,
    initialEnv,
    envWithImports,
  )
where

import Control.Monad
import Data.Bifunctor
import Data.Either
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Ord
import Data.Set qualified as S
import Futhark.FreshNames hiding (newName)
import Futhark.Util.Pretty hiding (space)
import Language.Futhark
import Language.Futhark.Semantic
import Language.Futhark.TypeChecker.Modules
import Language.Futhark.TypeChecker.Monad
import Language.Futhark.TypeChecker.Names
import Language.Futhark.TypeChecker.Terms
import Language.Futhark.TypeChecker.Types
import Prelude hiding (abs, mod)

--- The main checker

-- | Type check a program containing no type information, yielding
-- either a type error or a program with complete type information.
-- Accepts a mapping from file names (excluding extension) to
-- previously type checked results.  The 'ImportName' is used to resolve
-- relative @import@s.
checkProg ::
  Imports ->
  VNameSource ->
  ImportName ->
  UncheckedProg ->
  (Warnings, Either TypeError (FileModule, VNameSource))
checkProg files src name prog =
  runTypeM initialEnv files' name src $ checkProgM prog
  where
    files' = M.map fileEnv $ M.fromList files

-- | Type check a single expression containing no type information,
-- yielding either a type error or the same expression annotated with
-- type information.  Also returns a list of type parameters, which
-- will be nonempty if the expression is polymorphic.  See also
-- 'checkProg'.
checkExp ::
  Imports ->
  VNameSource ->
  Env ->
  UncheckedExp ->
  (Warnings, Either TypeError ([TypeParam], Exp))
checkExp files src env e =
  -- The final name source is discarded.
  second (fmap fst) $
    runTypeM env files' (mkInitialImport "") src (checkOneExp =<< resolveExp e)
  where
    files' = M.map fileEnv $ M.fromList files

-- | Type check a single declaration containing no type information,
-- yielding either a type error or the same declaration annotated with
-- type information, along with the Env produced by that declaration
-- (combined with the given Env) and the updated name source.  See
-- also 'checkProg'.
checkDec ::
  Imports ->
  VNameSource ->
  Env ->
  ImportName ->
  UncheckedDec ->
  (Warnings, Either TypeError (Env, Dec, VNameSource))
checkDec files src env name d =
  second (fmap massage) $
    runTypeM env files' name src $ do
      (_, env', d') <- checkOneDec d
      pure (env' <> env, d')
  where
    massage ((env', d'), src') =
      (env', d', src')
    files' = M.map fileEnv $ M.fromList files

-- | Type check a single module expression containing no type
-- information, yielding either a type error or the same expression
-- annotated with type information, along with its module type.  See
-- also 'checkProg'.
checkModExp ::
  Imports ->
  VNameSource ->
  Env ->
  ModExpBase NoInfo Name ->
  (Warnings, Either TypeError (MTy, ModExpBase Info VName))
checkModExp files src env me =
  second (fmap fst) . runTypeM env files' (mkInitialImport "") src $ do
    (_abs, mty, me') <- checkOneModExp me
    pure (mty, me')
  where
    files' = M.map fileEnv $ M.fromList files

-- | An initial environment for the type checker, containing
-- intrinsics and such.
initialEnv :: Env
initialEnv =
  intrinsicsModule
    { envModTable = initialModTable,
      envNameMap =
        M.insert
          (Term, nameFromString "intrinsics")
          (qualName intrinsics_v)
          topLevelNameMap
    }
  where
    initialTypeTable = M.fromList $ mapMaybe addIntrinsicT $ M.toList intrinsics
    initialModTable = M.singleton intrinsics_v (ModEnv intrinsicsModule)

    -- The intrinsics are also reachable as a module named @intrinsics@.
    intrinsics_v = VName (nameFromString "intrinsics") 0

    intrinsicsModule = Env mempty initialTypeTable mempty mempty intrinsicsNameMap

    -- Only intrinsic types contribute to the type table.
    addIntrinsicT (name, IntrinsicType l ps t) =
      Just (name, TypeAbbr l ps $ RetType [] t)
    addIntrinsicT _ =
      Nothing

-- | Produce an environment, based on the one passed in, where all of
-- the provided imports have been @open@ened in order.  This could in principle
-- also be done with 'checkDec', but this is more precise.
envWithImports :: Imports -> Env -> Env
envWithImports imports env =
  mconcat (map (fileEnv . snd) (reverse imports)) <> env

checkProgM :: UncheckedProg -> TypeM FileModule
checkProgM (Prog doc decs) = do
  checkForDuplicateDecs decs
  (abs, env, decs', full_env) <- checkDecs decs
  pure (FileModule abs env (Prog doc decs') full_env)

-- Report that @name@ in @space@ is defined twice: at @loc1@ (where the
-- error is reported) and previously at @loc2@.
dupDefinitionError ::
  (MonadTypeChecker m) =>
  Namespace ->
  Name ->
  SrcLoc ->
  SrcLoc ->
  m a
dupDefinitionError space name loc1 loc2 =
  typeError loc1 mempty $
    "Duplicate definition of"
      <+> pretty space
      <+> prettyName name
      <> "."
      </> "Previously defined at"
      <+> pretty (locStr loc2)
      <> "."
-- | Reject a list of declarations in which the same name is declared
-- twice in the same namespace.  @open@, @local@ and @import@
-- declarations are not inspected.
checkForDuplicateDecs :: [DecBase NoInfo Name] -> TypeM ()
checkForDuplicateDecs decs = foldM_ step M.empty decs
  where
    -- Thread the map of already-declared names through the declarations.
    step seen dec = maybe (pure seen) (register seen) (declared dec)

    register seen (ns, name, loc) =
      case M.lookup (ns, name) seen of
        Nothing -> pure $ M.insert (ns, name) loc seen
        Just prev_loc -> dupDefinitionError ns name loc prev_loc

    -- The namespace, name and location a declaration introduces, if any.
    declared (ValDec vb) = Just (Term, valBindName vb, srclocOf vb)
    declared (TypeDec (TypeBind name _ _ _ _ _ loc)) = Just (Type, name, loc)
    declared (ModTypeDec (ModTypeBind name _ _ loc)) = Just (Signature, name, loc)
    declared (ModDec (ModBind name _ _ _ _ loc)) = Just (Term, name, loc)
    declared OpenDec {} = Nothing
    declared LocalDec {} = Nothing
    declared ImportDec {} = Nothing

-- | Run an action with the given type parameters in scope.  Size
-- parameters are bound as variables of type @i64@.  Type parameters
-- are bound to a type abbreviation that refers to the parameter itself.
bindingTypeParams :: [TypeParam] -> TypeM a -> TypeM a
bindingTypeParams tparams m = localEnv (foldMap paramEnv tparams) m
  where
    paramEnv (TypeParamDim v _) =
      mempty {envVtable = M.singleton v $ BoundV [] $ Scalar $ Prim $ Signed Int64}
    paramEnv (TypeParamType l v _) =
      mempty
        { envTypeTable =
            M.singleton v $ TypeAbbr l [] $ RetType [] $ Scalar $ TypeVar mempty (qualName v) []
        }

-- | Resolve and check a type expression.  The returned names are the
-- size variables reported by 'checkTypeExp', followed by the
-- existential sizes of the resulting 'RetType'.
checkTypeDecl ::
  UncheckedTypeExp ->
  TypeM ([VName], TypeExp Exp VName, StructType, Liftedness)
checkTypeDecl te = do
  resolved <- resolveTypeExp te
  (te', svars, RetType dims st, l) <- checkTypeExp checkSizeExp resolved
  pure (svars ++ dims, te', toStruct st, l)

-- In this function, after the recursion, we add the Env of the
-- current Spec *after* the one that is returned from the recursive
-- call.  This implements the behaviour that specs later in a module
-- type can override those earlier (it rarely matters, but it affects
-- the specific structure of substitutions in case some module type is
-- redundantly imported multiple times).
-- | Check the specs of a module type in order, each spec's bindings
-- being in scope while checking the later ones.  Returns the abstract
-- types introduced, the 'Env' defined by the specs, and the annotated
-- specs.
checkSpecs :: [SpecBase NoInfo Name] -> TypeM (TySet, Env, [SpecBase Info VName])
checkSpecs [] = pure (mempty, mempty, [])
checkSpecs (ValSpec name tparams vtype NoInfo doc loc : specs) = do
  (tparams', vtype', vtype_t) <-
    resolveTypeParams tparams $ \tparams' -> bindingTypeParams tparams' $ do
      (ext, vtype', vtype_t, _) <- checkTypeDecl vtype

      unless (null ext) $
        typeError loc mempty $
          "All function parameters must have non-anonymous sizes."
            </> "Hint: add size parameters to"
            <+> dquotes (pretty name)
            <> "."

      pure (tparams', vtype', vtype_t)

  bindSpaced1 Term name loc $ \name' -> do
    let valenv =
          mempty
            { envVtable = M.singleton name' $ BoundV tparams' vtype_t,
              envNameMap = M.singleton (Term, name) $ qualName name'
            }
    usedName name'
    (abstypes, env, specs') <- localEnv valenv $ checkSpecs specs
    pure
      ( abstypes,
        env <> valenv,
        ValSpec name' tparams' vtype' (Info vtype_t) doc loc : specs'
      )
checkSpecs (TypeAbbrSpec tdec : specs) = do
  (tenv, tdec') <- checkTypeBind tdec
  bindSpaced1 Type (typeAlias tdec) (srclocOf tdec) $ \name' -> do
    usedName name'
    (abstypes, env, specs') <- localEnv tenv $ checkSpecs specs
    pure
      ( abstypes,
        env <> tenv,
        TypeAbbrSpec tdec' : specs'
      )
checkSpecs (TypeSpec l name ps doc loc : specs) = do
  ps' <- resolveTypeParams ps pure
  bindSpaced1 Type name loc $ \name' -> do
    usedName name'
    -- An abstract type: it is bound to a type variable of its own name.
    let tenv =
          mempty
            { envNameMap =
                M.singleton (Type, name) $ qualName name',
              envTypeTable =
                M.singleton name' $
                  TypeAbbr l ps' . RetType [] . Scalar $
                    TypeVar mempty (qualName name') $
                      map typeParamToArg ps'
            }
    (abstypes, env, specs') <- localEnv tenv $ checkSpecs specs
    pure
      ( M.insert (qualName name') l abstypes,
        env <> tenv,
        TypeSpec l name' ps' doc loc : specs'
      )
checkSpecs (ModSpec name sig doc loc : specs) = do
  (_sig_abs, mty, sig') <- checkModTypeExp sig
  bindSpaced1 Term name loc $ \name' -> do
    usedName name'
    let senv =
          mempty
            { envNameMap = M.singleton (Term, name) $ qualName name',
              envModTable = M.singleton name' $ mtyMod mty
            }
    (abstypes, env, specs') <- localEnv senv $ checkSpecs specs
    pure
      ( -- The submodule's abstract types are qualified by its name.
        M.mapKeys (qualify name') (mtyAbs mty) <> abstypes,
        env <> senv,
        ModSpec name' sig' doc loc : specs'
      )
checkSpecs (IncludeSpec e loc : specs) = do
  (e_abs, env_abs, e_env, e') <- checkModTypeExpToEnv e

  mapM_ warnIfShadowing $ M.keys env_abs

  (abstypes, env, specs') <- localEnv e_env $ checkSpecs specs
  pure
    ( e_abs <> env_abs <> abstypes,
      env <> e_env,
      IncludeSpec e' loc : specs'
    )
  where
    -- Warn when an included abstract type is already a known type.
    warnIfShadowing qn = do
      known <- isKnownType qn
      when known $ warnAbout qn
    warnAbout qn =
      warn loc $ "Inclusion shadows type" <+> dquotes (pretty qn) <> "."

-- | Check a module type expression, returning the abstract types it
-- introduces, the resulting module type, and the annotated expression.
checkModTypeExp :: ModTypeExpBase NoInfo Name -> TypeM (TySet, MTy, ModTypeExpBase Info VName)
checkModTypeExp (ModTypeParens e loc) = do
  (abs, mty, e') <- checkModTypeExp e
  pure (abs, mty, ModTypeParens e' loc)
checkModTypeExp (ModTypeVar name NoInfo loc) = do
  (name', mty) <- lookupMTy loc name
  -- Each reference to a named module type gets fresh names.
  (mty', substs) <- newNamesForMTy mty
  pure (mtyAbs mty', mty', ModTypeVar name' (Info substs) loc)
checkModTypeExp (ModTypeSpecs specs loc) = do
  checkForDuplicateSpecs specs
  (abstypes, env, specs') <- checkSpecs specs
  pure (abstypes, MTy abstypes $ ModEnv env, ModTypeSpecs specs' loc)
checkModTypeExp (ModTypeWith s (TypeRef tname ps te trloc) loc) = do
  (abs, s_abs, s_env, s') <- checkModTypeExpToEnv s
  resolveTypeParams ps $ \ps' -> do
    (ext, te', te_t, _) <- bindingTypeParams ps' $ checkTypeDecl te
    unless (null ext) $
      typeError te' mempty "Anonymous dimensions are not allowed here."
    (tname', s_abs', s_env') <- refineEnv loc s_abs s_env tname ps' te_t
    pure (abs, MTy s_abs' $ ModEnv s_env', ModTypeWith s' (TypeRef tname' ps' te' trloc) loc)
checkModTypeExp (ModTypeArrow maybe_pname e1 e2 loc) = do
  (e1_abs, MTy s_abs e1_mod, e1') <- checkModTypeExp e1
  -- A named parameter is in scope (as a module) in the result type.
  (env_for_e2, maybe_pname') <-
    case maybe_pname of
      Just pname -> bindSpaced1 Term pname loc $ \pname' ->
        pure
          ( mempty
              { envNameMap = M.singleton (Term, pname) $ qualName pname',
                envModTable = M.singleton pname' e1_mod
              },
            Just pname'
          )
      Nothing ->
        pure (mempty, Nothing)
  (e2_abs, e2_mod, e2') <- localEnv env_for_e2 $ checkModTypeExp e2
  pure
    ( e1_abs <> e2_abs,
      MTy mempty $ ModFun $ FunModType s_abs e1_mod e2_mod,
      ModTypeArrow maybe_pname' e1' e2' loc
    )

-- | Like 'checkModTypeExp', but the module type must not be a functor
-- type.  Also returns the module type's own abstract types separately.
checkModTypeExpToEnv ::
  ModTypeExpBase NoInfo Name ->
  TypeM (TySet, TySet, Env, ModTypeExpBase Info VName)
checkModTypeExpToEnv e = do
  (abs, MTy mod_abs mod, e') <- checkModTypeExp e
  case mod of
    ModEnv env -> pure (abs, mod_abs, env, e')
    ModFun {} -> unappliedFunctor $ srclocOf e

-- | Check a module type binding, producing an 'Env' that binds its name
-- to the checked module type.
checkModTypeBind ::
  ModTypeBindBase NoInfo Name ->
  TypeM (TySet, Env, ModTypeBindBase Info VName)
checkModTypeBind (ModTypeBind name e doc loc) = do
  (abs, env, e') <- checkModTypeExp e
  bindSpaced1 Signature name loc $ \name' -> do
    usedName name'
    pure
      ( abs,
        mempty
          { envModTypeTable = M.singleton name' env,
            envNameMap = M.singleton (Signature, name) (qualName name')
          },
        ModTypeBind name' e' doc loc
      )

-- | Check a module expression, returning the abstract types it
-- introduces, its module type, and the annotated expression.
checkOneModExp ::
  ModExpBase NoInfo Name ->
  TypeM (TySet, MTy, ModExpBase Info VName)
checkOneModExp (ModParens e loc) = do
  (abs, mty, e') <- checkOneModExp e
  pure (abs, mty, ModParens e' loc)
checkOneModExp (ModDecs decs loc) = do
  checkForDuplicateDecs decs
  (abstypes, env, decs', _) <- checkDecs decs
  pure
    ( abstypes,
      MTy abstypes $ ModEnv env,
      ModDecs decs' loc
    )
checkOneModExp (ModVar v loc) = do
  (v', env) <- lookupMod loc v
  -- The builtin intrinsics module cannot be used as a module value.
  when
    ( baseName (qualLeaf v') == nameFromString "intrinsics"
        && baseTag (qualLeaf v') <= maxIntrinsicTag
    )
    $ typeError loc mempty "The 'intrinsics' module may not be used in module expressions."
  pure (mempty, MTy mempty env, ModVar v' loc)
checkOneModExp (ModImport name NoInfo loc) = do
  (name', env) <- lookupImport loc name
  pure
    ( mempty,
      MTy mempty $ ModEnv env,
      ModImport name (Info name') loc
    )
checkOneModExp (ModApply f e NoInfo NoInfo loc) = do
  (f_abs, f_mty, f') <- checkOneModExp f
  case mtyMod f_mty of
    ModFun functor -> do
      (e_abs, e_mty, e') <- checkOneModExp e
      (mty, psubsts, rsubsts) <- applyFunctor (locOf loc) functor e_mty
      pure
        ( mtyAbs mty <> f_abs <> e_abs,
          mty,
          ModApply f' e' (Info psubsts) (Info rsubsts) loc
        )
    _ ->
      typeError loc mempty "Cannot apply non-parametric module."
checkOneModExp (ModAscript me se NoInfo loc) = do
  (me_abs, me_mod, me') <- checkOneModExp me
  (se_abs, se_mty, se') <- checkModTypeExp se
  match_subst <- badOnLeft $ matchMTys me_mod se_mty (locOf loc)
  pure (se_abs <> me_abs, se_mty, ModAscript me' se' (Info match_subst) loc)
checkOneModExp (ModLambda param maybe_fsig_e body_e loc) =
  withModParam param $ \param' param_abs param_mod -> do
    (abs, maybe_fsig_e', body_e', mty) <-
      checkModBody (fst <$> maybe_fsig_e) body_e loc
    pure
      ( abs,
        MTy mempty $ ModFun $ FunModType param_abs param_mod mty,
        ModLambda param' maybe_fsig_e' body_e' loc
      )

-- | Like 'checkOneModExp', but the result must not be a functor.  The
-- expression's abstract types and those of its module type are combined.
checkOneModExpToEnv :: ModExpBase NoInfo Name -> TypeM (TySet, Env, ModExpBase Info VName)
checkOneModExpToEnv e = do
  (e_abs, MTy abs mod, e') <- checkOneModExp e
  case mod of
    ModEnv env -> pure (e_abs <> abs, env, e')
    ModFun {} -> unappliedFunctor $ srclocOf e

-- | Check a functor parameter's signature, then run the continuation
-- with the parameter bound as a module.  The continuation receives the
-- annotated parameter, the signature's abstract types and its 'Mod'.
withModParam ::
  ModParamBase NoInfo Name ->
  (ModParamBase Info VName -> TySet -> Mod -> TypeM a) ->
  TypeM a
withModParam (ModParam pname psig_e NoInfo loc) m = do
  (_abs, MTy p_abs p_mod, psig_e') <- checkModTypeExp psig_e
  bindSpaced1 Term pname loc $ \pname' -> do
    let in_body_env = mempty {envModTable = M.singleton pname' p_mod}
    localEnv in_body_env $
      m (ModParam pname' psig_e' (Info $ map qualLeaf $ M.keys p_abs) loc) p_abs p_mod

-- | 'withModParam' for a list of parameters, bound left to right.
withModParams ::
  [ModParamBase NoInfo Name] ->
  ([(ModParamBase Info VName, TySet, Mod)] -> TypeM a) ->
  TypeM a
withModParams [] m = m []
withModParams (p : ps) m =
  withModParam p $ \p' pabs pmod ->
    withModParams ps $ \ps' -> m $ (p', pabs, pmod) : ps'

-- | Check a module body inside 'enteringModule'.  When a signature is
-- given, the body is matched against it and the signature's module
-- type (not the body's) becomes the result.
checkModBody ::
  Maybe (ModTypeExpBase NoInfo Name) ->
  ModExpBase NoInfo Name ->
  SrcLoc ->
  TypeM
    ( TySet,
      Maybe (ModTypeExp, Info (M.Map VName VName)),
      ModExp,
      MTy
    )
checkModBody maybe_fsig_e body_e loc = enteringModule $ do
  (body_e_abs, body_mty, body_e') <- checkOneModExp body_e
  case maybe_fsig_e of
    Nothing ->
      pure
        ( mtyAbs body_mty <> body_e_abs,
          Nothing,
          body_e',
          body_mty
        )
    Just fsig_e -> do
      (fsig_abs, fsig_mty, fsig_e') <- checkModTypeExp fsig_e
      fsig_subst <- badOnLeft $ matchMTys body_mty fsig_mty (locOf loc)
      pure
        ( fsig_abs <> body_e_abs,
          Just (fsig_e', Info fsig_subst),
          body_e',
          fsig_mty
        )

-- | Check a module binding.  A binding with parameters defines a
-- functor, whose type nests one 'FunModType' per parameter.
checkModBind :: ModBindBase NoInfo Name -> TypeM (TySet, Env, ModBindBase Info VName)
checkModBind (ModBind name [] maybe_fsig_e e doc loc) = do
  (e_abs, maybe_fsig_e', e', mty) <- checkModBody (fst <$> maybe_fsig_e) e loc
  bindSpaced1 Term name loc $ \name' -> do
    usedName name'
    pure
      ( e_abs,
        mempty
          { envModTable = M.singleton name' $ mtyMod mty,
            envNameMap = M.singleton (Term, name) $ qualName name'
          },
        ModBind name' [] maybe_fsig_e' e' doc loc
      )
checkModBind (ModBind name (p : ps) maybe_fsig_e body_e doc loc) = do
  (abs, params', maybe_fsig_e', body_e', funsig) <-
    withModParam p $ \p' p_abs p_mod ->
      withModParams ps $ \params_stuff -> do
        let (ps', ps_abs, ps_mod) = unzip3 params_stuff
        (abs, maybe_fsig_e', body_e', mty) <- checkModBody (fst <$> maybe_fsig_e) body_e loc
        -- Wrap the body type in one functor type per remaining parameter.
        let addParam (x, y) mty' = MTy mempty $ ModFun $ FunModType x y mty'
        pure
          ( abs,
            p' : ps',
            maybe_fsig_e',
            body_e',
            FunModType p_abs p_mod $ foldr addParam mty $ zip ps_abs ps_mod
          )
  bindSpaced1 Term name loc $ \name' -> do
    usedName name'
    pure
      ( abs,
        mempty
          { envModTable =
              M.singleton name' $ ModFun funsig,
            envNameMap =
              M.singleton (Term, name) $ qualName name'
          },
        ModBind name' params' maybe_fsig_e' body_e' doc loc
      )

-- | Reject a module type in which two specs bind the same name in the
-- same namespace.  @include@ specs are not inspected here.
checkForDuplicateSpecs :: [SpecBase NoInfo Name] -> TypeM ()
checkForDuplicateSpecs =
  foldM_ (flip f) mempty
  where
    check namespace name loc known =
      case M.lookup (namespace, name) known of
        Just loc' ->
          dupDefinitionError namespace name loc loc'
        _ -> pure $ M.insert (namespace, name) loc known

    f (ValSpec name _ _ _ _ loc) =
      check Term name loc
    f (TypeAbbrSpec (TypeBind name _ _ _ _ _ loc)) =
      check Type name loc
    f (TypeSpec _ name _ _ loc) =
      check Type name loc
    f (ModSpec name _ _ loc) =
      check Term name loc
    f IncludeSpec {} =
      pure

-- | Check a type abbreviation.  Rejects anonymous sizes that are not
-- used constructively, unused size parameters, and a right-hand side
-- whose liftedness exceeds the declared one.
checkTypeBind ::
  TypeBindBase NoInfo Name ->
  TypeM (Env, TypeBindBase Info VName)
checkTypeBind (TypeBind name l tps te NoInfo doc loc) =
  resolveTypeParams tps $ \tps' -> do
    (te', svars, RetType dims t, l') <-
      bindingTypeParams tps' $ checkTypeExp checkSizeExp =<< resolveTypeExp te

    let (witnessed, _) = determineSizeWitnesses $ toStruct t
    case L.find (`S.notMember` witnessed) svars of
      Just _ ->
        typeError (locOf te) mempty . withIndexLink "anonymous-nonconstructive" $
          "Type abbreviation contains an anonymous size not used constructively as an array size."
      Nothing ->
        pure ()

    let elab_t = RetType (svars ++ dims) $ toStruct t

    let used_dims = fvVars $ freeInType t
    case filter ((`S.notMember` used_dims) . typeParamName) $ filter isSizeParam tps' of
      [] -> pure ()
      tp : _ ->
        typeError loc mempty $
          "Size parameter" <+> dquotes (pretty tp) <+> "unused."

    -- l is the declared liftedness, l' the one required by the definition.
    case (l, l') of
      (_, Lifted)
        | l < Lifted ->
            typeError loc mempty $
              "Non-lifted type abbreviations may not contain functions."
                </> "Hint: consider using 'type^'."
      (_, SizeLifted)
        | l < SizeLifted ->
            typeError loc mempty $
              "Non-size-lifted type abbreviations may not contain size-lifted types."
                </> "Hint: consider using 'type~'."
      (Unlifted, _)
        | not $ null $ svars ++ dims ->
            typeError loc mempty $
              "Non-lifted type abbreviations may not use existential sizes in their definition."
                </> "Hint: use 'type~' or add size parameters to"
                <+> dquotes (prettyName name)
                <> "."
      _ -> pure ()

    bindSpaced1 Type name loc $ \name' -> do
      usedName name'
      pure
        ( mempty
            { envTypeTable =
                M.singleton name' $ TypeAbbr l tps' elab_t,
              envNameMap =
                M.singleton (Type, name) $ qualName name'
            },
          TypeBind name' l tps' te' (Info elab_t) doc loc
        )

-- | Compute the externally visible description of an entry point.  If
-- the return type is a function type, its parameters are appended to
-- the entry point's parameters, using the return type ascription (if
-- any) for their type expressions.
entryPoint :: [Pat ParamType] -> Maybe (TypeExp Exp VName) -> ResRetType -> EntryPoint
entryPoint params orig_ret_te (RetType _ret orig_ret) =
  EntryPoint (map patternEntry params ++ more_params) rettype'
  where
    (more_params, rettype') = onRetType orig_ret_te $ toStruct orig_ret

    patternEntry (PatParens p _) =
      patternEntry p
    patternEntry (PatAscription p te _) =
      EntryParam (patternName p) $ EntryType (patternStructType p) (Just te)
    patternEntry p =
      EntryParam (patternName p) $ EntryType (patternStructType p) Nothing

    patternName (Id x _ _) = baseName x
    patternName (PatParens p _) = patternName p
    patternName _ = "_"

    pname (Named v) = baseName v
    pname Unnamed = "_"

    -- Peel function arrows off the return type, one parameter at a time.
    onRetType (Just (TEArrow p t1_te t2_te _)) (Scalar (Arrow _ _ _ t1 (RetType _ t2))) =
      let (xs, y) = onRetType (Just t2_te) $ toStruct t2
       in (EntryParam (maybe "_" baseName p) (EntryType t1 (Just t1_te)) : xs, y)
    onRetType _ (Scalar (Arrow _ p _ t1 (RetType _ t2))) =
      let (xs, y) = onRetType Nothing $ toStruct t2
       in (EntryParam (pname p) (EntryType t1 Nothing) : xs, y)
    onRetType te t =
      ([], EntryType t te)

-- | Check the restrictions on entry points.  Errors: type parameters,
-- higher-order parameters or results, sizes occurring only in the
-- return type, and size parameters not used constructively.  Warnings:
-- parameters or return types that will get an opaque type.
checkEntryPoint ::
  SrcLoc ->
  [TypeParam] ->
  [Pat ParamType] ->
  Maybe (TypeExp Exp VName) ->
  ResRetType ->
  TypeM ()
checkEntryPoint loc tparams params maybe_tdecl rettype
  | any isTypeParam tparams =
      typeError loc mempty $
        withIndexLink
          "polymorphic-entry"
          "Entry point functions may not be polymorphic."
  | not (all orderZero param_ts)
      || not (orderZero rettype') =
      typeError loc mempty $
        withIndexLink
          "higher-order-entry"
          "Entry point functions may not be higher-order."
  | sizes_only_in_ret <-
      S.fromList (map typeParamName tparams)
        `S.intersection` fvVars (freeInType rettype')
        `S.difference` foldMap (fvVars . freeInType) param_ts,
    not $ S.null sizes_only_in_ret =
      typeError loc mempty $
        withIndexLink
          "size-polymorphic-entry"
          "Entry point functions must not be size-polymorphic in their return type."
  | (constructive, _) <- foldMap (determineSizeWitnesses . toStruct) param_ts,
    Just p <- L.find (flip S.notMember constructive . typeParamName) tparams =
      typeError p mempty . withIndexLink "nonconstructive-entry" $
        "Entry point size parameter "
          <> pretty p
          <> " only used non-constructively."
  | p : _ <- filter nastyParameter params =
      warn p $
        "Entry point parameter\n"
          </> indent 2 (pretty p)
          </> "\nwill have an opaque type, so the entry point will likely not be callable."
  | nastyReturnType maybe_tdecl rettype_t =
      warn loc $
        "Entry point return type\n"
          </> indent 2 (pretty rettype)
          </> "\nwill have an opaque type, so the result will likely not be usable."
  | otherwise =
      pure ()
  where
    (RetType _ rettype_t) = rettype
    -- Parameters of a functional return type count as entry point parameters.
    (rettype_params, rettype') = unfoldFunType rettype_t
    param_ts = map patternType params ++ rettype_params

-- | Check a value binding, producing an 'Env' that binds its name.
-- Entry points are only allowed at top level and are additionally
-- checked by 'checkEntryPoint'.
checkValBind :: ValBindBase NoInfo Name -> TypeM (Env, ValBind)
checkValBind vb = do
  (ValBind entry fname maybe_tdecl NoInfo tparams params body doc attrs loc) <-
    resolveValBind vb

  top_level <- atTopLevel
  when (not top_level && isJust entry) $
    typeError loc mempty $
      withIndexLink "nested-entry" "Entry points may not be declared inside modules."

  attrs' <- mapM checkAttr attrs

  (tparams', params', maybe_tdecl', rettype, body') <-
    checkFunDef (fname, maybe_tdecl, tparams, params, body, loc)

  let entry' = Info (entryPoint params' maybe_tdecl' rettype) <$ entry
  case entry' of
    Just _ -> checkEntryPoint loc tparams' params' maybe_tdecl' rettype
    _ -> pure ()

  let vb' = ValBind entry' fname maybe_tdecl' (Info rettype) tparams' params' body' doc attrs' loc
  pure
    ( mempty
        { envVtable =
            M.singleton fname $ uncurry BoundV $ valBindTypeScheme vb',
          envNameMap =
            M.singleton (Term, baseName fname) $ qualName fname
        },
      vb'
    )

-- | True unless the type is a primitive type or an array of primitives.
nastyType :: (Monoid als) => TypeBase dim als -> Bool
nastyType (Scalar Prim {}) = False
nastyType t@Array {} = nastyType $ stripArray 1 t
nastyType _ = True

-- | Whether an entry point return type is "nasty" in the sense used for
-- the opaque-type warning in 'checkEntryPoint'.  A component is not
-- nasty if its ascribed type expression satisfies 'niceTypeExp'.
-- Function types and tuples are inspected component-wise.
nastyReturnType :: (Monoid als) => Maybe (TypeExp Exp VName) -> TypeBase dim als -> Bool
nastyReturnType Nothing (Scalar (Arrow _ _ _ t1 (RetType _ t2))) =
  nastyType t1 || nastyReturnType Nothing t2
nastyReturnType (Just (TEArrow _ te1 te2 _)) (Scalar (Arrow _ _ _ t1 (RetType _ t2))) =
  (not (niceTypeExp te1) && nastyType t1)
    || nastyReturnType (Just te2) t2
nastyReturnType (Just te) _
  | niceTypeExp te = False
nastyReturnType te t
  | Just ts <- isTupleRecord t =
      case te of
        Just (TETuple tes _) -> or $ zipWith nastyType' (map Just tes) ts
        _ -> any nastyType ts
  | otherwise = nastyType' te t
  where
    nastyType' (Just te') _ | niceTypeExp te' = False
    nastyType' _ t' = nastyType t'

-- | Whether an entry point parameter is "nasty": its type is nasty and
-- it has no ascription satisfying 'niceTypeExp'.
nastyParameter :: Pat ParamType -> Bool
nastyParameter p = nastyType (patternType p) && not (ascripted p)
  where
    ascripted (PatAscription _ te _) = niceTypeExp te
    ascripted (PatParens p' _) = ascripted p'
    ascripted _ = False

-- | Whether a type expression names an unqualified type, possibly
-- wrapped in size arguments, array dimensions, uniqueness or 'TEDim'.
niceTypeExp :: TypeExp Exp VName -> Bool
niceTypeExp (TEVar (QualName [] _) _) = True
niceTypeExp (TEApply te TypeArgExpSize {} _) = niceTypeExp te
niceTypeExp (TEArray _ te _) = niceTypeExp te
niceTypeExp (TEUnique te _) = niceTypeExp te
niceTypeExp (TEDim _ te _) = niceTypeExp te
niceTypeExp _ = False

-- | Check a single declaration, returning its abstract types, the
-- 'Env' it defines, and the annotated declaration.
checkOneDec :: DecBase NoInfo Name -> TypeM (TySet, Env, DecBase Info VName)
checkOneDec (ModDec struct) = do
  (abs, modenv, struct') <- checkModBind struct
  pure (abs, modenv, ModDec struct')
checkOneDec (ModTypeDec sig) = do
  (abs, sigenv, sig') <- checkModTypeBind sig
  pure (abs, sigenv, ModTypeDec sig')
checkOneDec (TypeDec tdec) = do
  (tenv, tdec') <- checkTypeBind tdec
  pure (mempty, tenv, TypeDec tdec')
checkOneDec (OpenDec x loc) = do
  (x_abs, x_env, x') <- checkOneModExpToEnv x
  pure (x_abs, x_env, OpenDec x' loc)
checkOneDec (LocalDec d loc) = do
  (abstypes, env, d') <- checkOneDec d
  pure (abstypes, env, LocalDec d' loc)
checkOneDec (ImportDec name NoInfo loc) = do
  (name', env) <- lookupImport loc name
  when (isBuiltin name) $
    typeError loc mempty $
      pretty name <+> "may not be explicitly imported."
  pure (mempty, env, ImportDec name (Info name') loc)
checkOneDec (ValDec vb) = do
  (env, vb') <- checkValBind vb
  pure (mempty, env, ValDec vb')

-- | Check declarations in order, each in scope for the following ones.
-- @local@ and @import@ declarations are not part of the returned
-- 'Env'.  The last component is the environment in scope after all the
-- declarations.
checkDecs :: [DecBase NoInfo Name] -> TypeM (TySet, Env, [DecBase Info VName], Env)
checkDecs (d : ds) = do
  (d_abstypes, d_env, d') <- checkOneDec d
  (ds_abstypes, ds_env, ds', full_env) <- localEnv d_env $ checkDecs ds
  pure
    ( d_abstypes <> ds_abstypes,
      case d' of
        LocalDec {} -> ds_env
        ImportDec {} -> ds_env
        _ -> ds_env <> d_env,
      d' : ds',
      full_env
    )
checkDecs [] = do
  full_env <- askEnv
  pure (mempty, mempty, [], full_env)

futhark-0.25.27/src/Language/Futhark/TypeChecker/000077500000000000000000000000001475065116200214665ustar00rootroot00000000000000futhark-0.25.27/src/Language/Futhark/TypeChecker/Consumption.hs000066400000000000000000001125251475065116200243460ustar00rootroot00000000000000-- | Check that a value definition does not violate any consumption
-- constraints.
module Language.Futhark.TypeChecker.Consumption
  ( checkValDef,
  )
where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifoldable
import Data.Bifunctor
import Data.DList qualified as DL
import Data.Foldable
import Data.List qualified as L
import Data.List.NonEmpty qualified as NE
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.Util.Pretty hiding (space)
import Language.Futhark
import Language.Futhark.Traversals
import Language.Futhark.TypeChecker.Monad (Notes, TypeError (..), withIndexLink)
import Prelude hiding (mod)

-- | A set of variable names.
type Names = S.Set VName

-- | A variable that is aliased.  Can be still in-scope, or have gone
-- out of scope and be free.  In the latter case, it behaves more like
-- an equivalence class.  See uniqueness-error18.fut for an example of
-- why this is necessary.
data Alias
  = AliasBound {aliasVar :: VName}
  | AliasFree {aliasVar :: VName}
  deriving (Eq, Ord, Show)

-- Free aliases are printed with a leading tilde.
instance Pretty Alias where
  pretty alias = case alias of
    AliasBound v -> prettyName v
    AliasFree v -> "~" <> prettyName v

-- An alias set is printed as a brace-enclosed, comma-separated list.
instance Pretty (S.Set Alias) where
  pretty als = braces $ commasep [pretty a | a <- S.toList als]

-- | The set of in-scope variables that are being aliased.
boundAliases :: Aliases -> S.Set VName
boundAliases als = S.fromList [v | AliasBound v <- S.toList als]

-- | Aliases for a type, which is a set of the variables that are
-- aliased.
type Aliases = S.Set Alias

-- | A type whose alias annotations are 'Aliases'.
type TypeAliases = TypeBase Size Aliases

-- | @t \`setAliases\` als@ returns @t@, but with @als@ substituted for
-- any already present aliases.
setAliases :: TypeBase dim asf -> ast -> TypeBase dim ast
setAliases t als = second (const als) t

-- | @t \`addAliases\` f@ returns @t@, but with any already present
-- aliases replaced by the result of applying @f@ to them.
addAliases :: TypeBase dim asf -> (asf -> ast) -> TypeBase dim ast
addAliases = flip second

-- | All alias annotations occurring anywhere in the type (sizes are
-- ignored).
aliases :: TypeAliases -> Aliases
aliases = bifoldMap (const mempty) id

-- | @setFieldAliases ve_als path t@ follows the record field names in
-- @path@ into @t@ and replaces the component found there by @ve_als@.
-- When the path is exhausted, or the type is not a record, the whole
-- (sub)type is replaced.  A field missing from the record leaves the
-- record unchanged ('M.adjust').
setFieldAliases :: TypeAliases -> [Name] -> TypeAliases -> TypeAliases
setFieldAliases ve_als (x : xs) (Scalar (Record fs)) =
  Scalar $ Record $ M.adjust (setFieldAliases ve_als xs) x fs
setFieldAliases ve_als _ _ = ve_als

-- | A variable table entry: the aliases (or aliased type) of a bound
-- name, tagged with whether the name may be consumed.
data Entry a = Consumable {entryAliases :: a} | Nonconsumable {entryAliases :: a}
  deriving (Eq, Ord, Show)

instance Functor Entry where
  fmap f (Consumable als) = Consumable $ f als
  fmap f (Nonconsumable als) = Nonconsumable $ f als

-- | Read-only context of the consumption checker.
data CheckEnv = CheckEnv
  { envVtable :: M.Map VName (Entry TypeAliases),
    -- | Location of the definition we are checking.
    envLoc :: Loc
  }

-- | A description of where an artificial compiler-generated
-- intermediate name came from.
data NameReason
  = -- | Name is the result of a function application.
    NameAppRes (Maybe (QualName VName)) SrcLoc
  | -- | Name is the result of a loop.
    NameLoopRes SrcLoc

-- | Describe a 'NameReason' for an error message.  Locations are
-- rendered relative to the first argument.
nameReason :: SrcLoc -> NameReason -> Doc a
nameReason loc (NameAppRes Nothing apploc) =
  "result of application at" <+> pretty (locStrRel loc apploc)
nameReason loc (NameAppRes fname apploc) =
  "result of applying"
    <+> dquotes (pretty fname)
    <+> parens ("at" <+> pretty (locStrRel loc apploc))
nameReason loc (NameLoopRes apploc) =
  "result of loop at" <+> pretty (locStrRel loc apploc)

-- | Consumed variables, each mapped to the location where it was
-- consumed.
type Consumed = M.Map VName Loc

-- | Mutable state of the consumption checker.
data CheckState = CheckState
  { -- Variables consumed so far.
    stateConsumed :: Consumed,
    -- Errors reported so far, in the order they were added.
    stateErrors :: DL.DList TypeError,
    -- Explanations for artificial names, used by 'describeVar'.
    stateNames :: M.Map VName NameReason,
    -- Counter read and incremented by 'incCounter'.
    stateCounter :: Int
  }

-- | The consumption checking monad.
newtype CheckM a = CheckM (ReaderT CheckEnv (State CheckState) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader CheckEnv,
      MonadState CheckState
    )

-- | Run a checker computation for the definition at the given location,
-- starting from an empty variable table and state.  Returns the result
-- together with all reported errors, in the order they were added.
runCheckM :: Loc -> CheckM a -> (a, [TypeError])
runCheckM loc (CheckM m) =
  let (a, s) = runState (runReaderT m env) initial_state
   in (a, DL.toList (stateErrors s))
  where
    env =
      CheckEnv
        { envVtable = mempty,
          envLoc = loc
        }
    initial_state =
      CheckState
        { stateConsumed = mempty,
          stateErrors = mempty,
          stateNames = mempty,
          stateCounter = 0
        }

-- | A human-readable description of a name for error messages: its
-- recorded 'NameReason' if there is one, otherwise @variable "v"@.
describeVar :: VName -> CheckM (Doc a)
describeVar v = do
  loc <- asks envLoc
  gets $
    maybe ("variable" <+> dquotes (prettyName v)) (nameReason (srclocOf loc))
      . M.lookup v
      . stateNames

-- | Run an action with every variable in scope marked as not
-- consumable; the recorded aliases are kept.
noConsumable :: CheckM a -> CheckM a
noConsumable = local $ \env -> env {envVtable = M.map f $ envVtable env}
  where
    f = Nonconsumable . entryAliases

-- | Record an error.  Checking continues; errors are collected in
-- 'stateErrors'.
addError :: (Located loc) => loc -> Notes -> Doc () -> CheckM ()
addError loc notes e = modify $ \s ->
  s {stateErrors = DL.snoc (stateErrors s) (TypeError (locOf loc) notes e)}

-- | Return the current counter value and increment the counter.
incCounter :: CheckM Int
incCounter =
  state $ \s -> (stateCounter s, s {stateCounter = stateCounter s + 1})

-- | Report that a unique return value aliases a non-consumable name.
returnAliased :: Name -> SrcLoc -> CheckM ()
returnAliased name loc =
  addError loc mempty . withIndexLink "return-aliased" $
    "Unique-typed return value is aliased to"
      <+> dquotes (prettyName name)
      <> ", which is not consumable."

-- | Report that a unique component of a return value aliases another
-- component.
uniqueReturnAliased :: SrcLoc -> CheckM ()
uniqueReturnAliased loc =
  addError loc mempty . withIndexLink "unique-return-aliased" $
    "A unique-typed component of the return value is aliased to some other component."

-- | Check the aliases of a function body against its declared return
-- type.  The return value is split into components (following records
-- present in both types), and the components are processed in order.
-- An error is reported if:
--
-- * a unique component aliases a name already aliased by an earlier
--   component;
-- * a non-unique component aliases a name aliased by an earlier unique
--   component;
-- * a unique component aliases a parameter name whose type is not
--   consumable.
checkReturnAlias :: SrcLoc -> [Pat ParamType] -> ResType -> TypeAliases -> CheckM ()
checkReturnAlias loc params rettp =
  foldM_ (checkReturnAlias' params) S.empty . returnAliases rettp
  where
    -- 'seen' holds (uniqueness, name) pairs of the components processed so far.
    checkReturnAlias' params' seen (Unique, names) = do
      when (any (`S.member` S.map snd seen) $ S.toList names) $
        uniqueReturnAliased loc
      notAliasesParam params' names
      pure $ seen `S.union` tag Unique names
    checkReturnAlias' _ seen (Nonunique, names) = do
      when (any (`S.member` seen) $ S.toList $ tag Unique names) $
        uniqueReturnAliased loc
      pure $ seen `S.union` tag Nonunique names

    -- Reports at most one offending name per parameter.
    notAliasesParam params' names =
      forM_ params' $ \p ->
        let consumedNonunique (v, t) =
              not (consumableParamType t) && (v `S.member` names)
         in case find consumedNonunique $ patternMap p of
              Just (v, _) ->
                returnAliased (baseName v) loc
              Nothing ->
                pure ()

    tag u = S.map (u,)

    returnAliases (Scalar (Record ets1)) (Scalar (Record ets2)) =
      concat $ M.elems $ M.intersectionWith returnAliases ets1 ets2
    returnAliases expected got =
      [(uniqueness expected, S.map aliasVar $ aliases got)]

    consumableParamType (Array u _ _) = u == Consume
    consumableParamType (Scalar Prim {}) = True
    consumableParamType (Scalar (TypeVar u _ _)) = u == Consume
    consumableParamType (Scalar (Record fs)) = all consumableParamType fs
    consumableParamType (Scalar (Sum fs)) = all (all consumableParamType) fs
    consumableParamType (Scalar Arrow {}) = False

-- | Turn 'AliasBound' entries for the given names into 'AliasFree';
-- used when those names go out of scope.
unscope :: [VName] -> Aliases -> Aliases
unscope bound = S.map f
  where
    f (AliasFree v) = AliasFree v
    f (AliasBound v) = if v `elem` bound then AliasFree v else AliasBound v

-- | Figure out the aliases of each bound name in a pattern.
-- Structurally matches the pattern against the aliased type; names in
-- parts of the pattern that do not match the type's shape are dropped
-- (the catch-all equations return 'mempty').
matchPat :: Pat t -> TypeAliases -> DL.DList (VName, (t, TypeAliases))
matchPat (PatParens p _) t = matchPat p t
matchPat (TuplePat ps _) t
  | Just ts <- isTupleRecord t = mconcat $ zipWith matchPat ps ts
matchPat (RecordPat fs1 _) (Scalar (Record fs2)) =
  -- Pair up fields by sorting both sides with 'sortFields'.
  mconcat $
    zipWith
      matchPat
      (map snd (sortFields (M.fromList (map (first unLoc) fs1))))
      (map snd (sortFields fs2))
matchPat (Id v (Info t) _) als = DL.singleton (v, (t, als))
matchPat (PatAscription p _ _) t = matchPat p t
matchPat (PatConstr v _ ps _) (Scalar (Sum cs))
  | Just ts <- M.lookup v cs = mconcat $ zipWith matchPat ps ts
matchPat TuplePat {} _ = mempty
matchPat RecordPat {} _ = mempty
matchPat PatConstr {} _ = mempty
matchPat Wildcard {} _ = mempty
matchPat PatLit {} _ = mempty
matchPat (PatAttr _ p _) t = matchPat p t

-- | Run an action with the names of a pattern bound (as consumable) to
-- the matching parts of the given aliased type.  Each name also aliases
-- itself.  In the action's resulting aliases, the pattern's names are
-- marked as out of scope.
bindingPat ::
  Pat StructType ->
  TypeAliases ->
  CheckM (a, TypeAliases) ->
  CheckM (a, TypeAliases)
bindingPat p t = fmap (second (second (unscope (patNames p)))) . local bind
  where
    bind env =
      env
        { envVtable =
            foldr (uncurry M.insert . f) (envVtable env) (matchPat p t)
        }
      where
        f (v, (_, als)) = (v, Consumable $ second (S.insert (AliasBound v)) als)

-- | Run an action with the names of a function parameter in scope.
-- Each name aliases only itself, and is consumable iff its diet is
-- 'Consume'.  Expressions inside the pattern are checked first, with
-- nothing consumable.
bindingParam :: Pat ParamType -> CheckM (a, TypeAliases) -> CheckM (a, TypeAliases)
bindingParam p m = do
  mapM_ (noConsumable . bitraverse_ checkExp pure) p
  second (second (unscope (patNames p))) <$> local bind m
  where
    bind env =
      env
        { envVtable =
            foldr (uncurry M.insert . f) (envVtable env) (patternMap p)
        }
    f (v, t)
      | diet t == Consume = (v, Consumable $ t `setAliases` S.singleton (AliasBound v))
      | otherwise = (v, Nonconsumable $ t `setAliases` S.singleton (AliasBound v))

-- | Run an action with a single identifier in scope, aliasing only
-- itself; it is consumable iff the diet is 'Consume'.
bindingIdent ::
  Diet ->
  Ident StructType ->
  CheckM (a, TypeAliases) ->
  CheckM (a, TypeAliases)
bindingIdent d (Ident v (Info t) _) =
  fmap (second (second (unscope [v]))) . local bind
  where
    bind env = env {envVtable = M.insert v t' (envVtable env)}
    d' = case d of
      Consume -> Consumable
      Observe -> Nonconsumable
    t' = d' $ t `setAliases` S.singleton (AliasBound v)

-- | Bind several function parameters (leftmost outermost) around an
-- action.  The whole thing runs under 'noConsumable', so only the
-- parameters bound here can be consumable.
bindingParams :: [Pat ParamType] -> CheckM (a, TypeAliases) -> CheckM (a, TypeAliases)
bindingParams params m =
  noConsumable $
    second (second (unscope (foldMap patNames params)))
      <$> foldr bindingParam m params

-- | Bind the variables introduced by a loop form (observed, never
-- consumable).
bindingLoopForm :: LoopFormBase Info VName -> CheckM (a, TypeAliases) -> CheckM (a, TypeAliases)
bindingLoopForm (For ident _) m = bindingIdent Observe ident m
bindingLoopForm (ForIn pat _) m = bindingParam pat' m
  where
    pat' = fmap (second (const Observe)) pat
bindingLoopForm While {} m = m

-- | Bind a local function name with the given aliases; it is not
-- consumable.
bindingFun :: VName -> TypeAliases -> CheckM a -> CheckM a
bindingFun v t =
  local $ \env -> env {envVtable = M.insert v (Nonconsumable t) (envVtable env)}

-- | Report a use-after-consume error for every alias that has already
-- been consumed.
checkIfConsumed :: Loc -> Aliases -> CheckM ()
checkIfConsumed rloc als = do
  cons <- gets stateConsumed
  let bad v = fmap (v,) $ v `M.lookup` cons
  forM_ (mapMaybe (bad . aliasVar) $ S.toList als) $ \(v, wloc) -> do
    v' <- describeVar v
    addError rloc mempty . withIndexLink "use-after-consume" $
      "Using"
        <+> v'
        <> ", but this was consumed at"
        <+> pretty (locStrRel rloc wloc)
        <> ". (Possibly through aliases.)"

-- | Mark variables as consumed.  On a key collision, the existing
-- location is kept ('M.Map' left-biased union).
consumed :: Consumed -> CheckM ()
consumed vs = modify $ \s -> s {stateConsumed = stateConsumed s <> vs}

-- | Consume a set of aliases at the given location.  Bound aliases
-- that are not in the variable table, or are marked non-consumable,
-- are reported.  Then use-after-consume is checked, and finally all
-- aliases are recorded as consumed.
consumeAliases :: Loc -> Aliases -> CheckM ()
consumeAliases loc als = do
  vtable <- asks envVtable
  let isBad v =
        case v `M.lookup` vtable of
          Just (Nonconsumable {}) -> True
          Just _ -> False
          Nothing -> True
      checkIfConsumable (AliasBound v)
        | isBad v = do
            v' <- describeVar v
            addError loc mempty . withIndexLink "not-consumable" $
              "Consuming" <+> v' <> ", which is not consumable."
      checkIfConsumable _ = pure ()
  mapM_ checkIfConsumable $ S.toList als
  checkIfConsumed loc als
  consumed als'
  where
    als' = M.fromList $ map ((,loc) . aliasVar) $ S.toList als

-- | Consume a variable (and everything it aliases).
consume :: Loc -> VName -> StructType -> CheckM ()
consume loc v t =
  consumeAliases loc . aliases =<< observeVar loc v t

-- | Observe the given name here and return its aliases.
observeVar :: Loc -> VName -> StructType -> CheckM TypeAliases
observeVar loc v t = do
  als <- asks $ \env ->
    maybe (isGlobal (envVtable env)) isLocal $ M.lookup v (envVtable env)
  checkIfConsumed loc (aliases als)
  pure als
  where
    isLocal = entryAliases
    -- Handling globals is tricky.  For arrays and such, we do want to
    -- track their aliases.  We do not want to track the aliases of
    -- functions.  However, array bindings that are *polymorphic*
    -- should be treated like functions.  However, we do not have
    -- access to the original binding information here.  To avoid
    -- having to plumb that all the way here, we infer that an array
    -- binding is a polymorphic instantiation if its size contains any
    -- locally bound names.
    isGlobal vtable
      | isInstantiation vtable t = second (const mempty) t
      | otherwise = selfAlias $ second (const mempty) t
    isInstantiation vtable = any (`M.member` vtable) . fvVars . freeInType
    -- Make array components (except those under type variables) alias v.
    selfAlias (Array als shape et) = Array (S.insert (AliasBound v) als) shape et
    selfAlias (Scalar st) = Scalar $ selfAlias' st
    selfAlias' (TypeVar als tn args) = TypeVar als tn args -- #1675 FIXME
    selfAlias' (Record fs) = Record $ fmap selfAlias fs
    selfAlias' (Sum fs) = Sum $ fmap (map selfAlias) fs
    selfAlias' et@Arrow {} = et
    selfAlias' et@Prim {} = et

-- | Capture any newly consumed variables that occur during the
-- provided action.  The consumed set is restored to what it was
-- before the action; the new consumptions are returned instead.
contain :: CheckM a -> CheckM (a, Consumed)
contain m = do
  prev_cons <- gets stateConsumed
  x <- m
  new_cons <- gets $ (`M.difference` prev_cons) . stateConsumed
  modify $ \s -> s {stateConsumed = prev_cons}
  pure (x, new_cons)

-- | The two types are assumed to be approximately structurally equal,
-- but not necessarily regarding sizes.  Combines aliases and prefers
-- other information from first argument.
combineAliases :: TypeAliases -> TypeAliases -> TypeAliases
combineAliases (Array als1 et1 shape1) t2 =
  Array (als1 <> aliases t2) et1 shape1
combineAliases (Scalar (TypeVar als1 tv1 targs1)) t2 =
  Scalar $ TypeVar (als1 <> aliases t2) tv1 targs1
combineAliases t1 (Scalar (TypeVar als2 tv2 targs2)) =
  Scalar $ TypeVar (als2 <> aliases t1) tv2 targs2
combineAliases (Scalar (Record ts1)) (Scalar (Record ts2))
  | length ts1 == length ts2,
    L.sort (M.keys ts1) == L.sort (M.keys ts2) =
      Scalar $ Record $ M.intersectionWith combineAliases ts1 ts2
combineAliases
  (Scalar (Arrow als1 mn1 d1 pt1 (RetType dims1 rt1)))
  (Scalar (Arrow als2 _ _ _ (RetType _ _))) =
    Scalar (Arrow (als1 <> als2) mn1 d1 pt1 (RetType dims1 rt1))
combineAliases (Scalar (Sum cs1)) (Scalar (Sum cs2))
  | length cs1 == length cs2,
    L.sort (M.keys cs1) == L.sort (M.keys cs2) =
      Scalar $ Sum $ M.intersectionWith (zipWith combineAliases) cs1 cs2
combineAliases (Scalar (Prim t)) _ = Scalar $ Prim t
-- Any other combination of shapes is treated as an internal error.
combineAliases t1 t2 =
  error $ "combineAliases invalid args: " ++ show (t1, t2)

-- | The names that occur as aliases in more than one component of a
-- (possibly nested) record type. An alias inhibits uniqueness if it
-- is used in disjoint values.
aliasesMultipleTimes :: TypeAliases -> Names
aliasesMultipleTimes = M.keysSet . M.filter (> 1) . delve
  where
    -- Count, per alias name, in how many record leaves it occurs.
    delve (Scalar (Record fs)) =
      foldl' (M.unionWith (+)) mempty $ map delve $ M.elems fs
    delve t =
      M.fromList $ map ((,1 :: Int) . aliasVar) $ S.toList $ aliases t

-- | The names bound by the given parameters with a 'Consume' diet.
consumingParams :: [Pat ParamType] -> Names
consumingParams =
  S.fromList . map fst . filter ((== Consume) . diet . snd) . foldMap patternMap

-- | The aliases of the array-like parts of a type. Arrays and type
-- variables contribute their own aliases, and records and sums
-- contribute the union of their components' aliases. Primitive and
-- function types contribute none.
arrayAliases :: TypeAliases -> Aliases
arrayAliases (Array als _ _) = als
arrayAliases (Scalar Prim {}) = mempty
arrayAliases (Scalar (Record fs)) = foldMap arrayAliases fs
arrayAliases (Scalar (TypeVar als _ _)) = als
arrayAliases (Scalar Arrow {}) = mempty
arrayAliases (Scalar (Sum fs)) =
  mconcat $ concatMap (map arrayAliases) $ M.elems fs

-- | Report an error if any alias of the update value @ve@ is also an
-- alias of the source @src@ of an in-place update.
overlapCheck :: (Pretty src, Pretty ve) => Loc -> (src, TypeAliases) -> (ve, TypeAliases) -> CheckM ()
overlapCheck loc (src, src_als) (ve, ve_als) =
  when (any (`S.member` aliases src_als) (aliases ve_als)) $
    addError loc mempty $
      "Source array for in-place update"
        </> indent 2 (pretty src)
        </> "might alias update value"
        </> indent 2 (pretty ve)
        </> "Hint: use"
        <+> dquotes "copy"
        <+> "to remove aliases from the value."

-- | Infer the uniqueness of a function's return type from the aliases
-- of its body.
--
-- A component that is not a record or sum becomes 'Unique' when both
-- of these hold:
--
-- * every bound name it may alias (via 'arrayAliases') is a consuming
--   parameter;
-- * none of its aliases occurs in more than one record component of
--   the result.
--
-- Otherwise the component becomes 'Nonunique'. Without parameters the
-- result is always 'Nonunique'.
inferReturnUniqueness :: [Pat ParamType] -> ResType -> TypeAliases -> ResType
inferReturnUniqueness [] ret _ = ret `setUniqueness` Nonunique
inferReturnUniqueness params ret ret_als = delve ret ret_als
  where
    forbidden = aliasesMultipleTimes ret_als
    consumings = consumingParams params
    delve (Scalar (Record fs1)) (Scalar (Record fs2)) =
      Scalar $ Record $ M.intersectionWith delve fs1 fs2
    delve (Scalar (Sum cs1)) (Scalar (Sum cs2)) =
      Scalar $ Sum $ M.intersectionWith (zipWith delve) cs1 cs2
    delve t t_als
      | all (`S.member` consumings) $ boundAliases (arrayAliases t_als),
        not $ any ((`S.member` forbidden) . aliasVar) (aliases t_als) =
          t `setUniqueness` Unique
      | otherwise =
          t `setUniqueness` Nonunique

-- | Check all immediate subexpressions, discarding their aliases.
checkSubExps :: (ASTMappable e) => e -> CheckM e
checkSubExps = astMap identityMapper {mapOnExp = fmap fst . checkExp}

-- | Check the subexpressions, and give the expression its type with
-- no aliases.
noAliases :: Exp -> CheckM (Exp, TypeAliases)
noAliases e = do
  e' <- checkSubExps e
  pure (e', second (const mempty) (typeOf e))

-- | The aliases of each record component, with nested records
-- flattened. A type that is not a record is a single part.
aliasParts :: TypeAliases -> [Aliases]
aliasParts (Scalar (Record ts)) = foldMap aliasParts $ M.elems ts
aliasParts t = [aliases t]

-- | Report an error if two record components of the type share an
-- alias.
noSelfAliases :: Loc -> TypeAliases -> CheckM ()
noSelfAliases loc = foldM_ check mempty . aliasParts
  where
    -- 'seen' accumulates the aliases of the components visited so far.
    check seen als = do
      when (any (`S.member` seen) als) $
        addError loc mempty . withIndexLink "self-aliases-arg" $
          "Argument passed for consuming parameter is self-aliased."
      pure $ als <> seen

-- | Consume the aliases of each component whose parameter diet is
-- 'Consume'. Records are handled component-wise.
consumeAsNeeded :: Loc -> ParamType -> TypeAliases -> CheckM ()
consumeAsNeeded loc (Scalar (Record fs1)) (Scalar (Record fs2)) =
  sequence_ $ M.elems $ M.intersectionWith (consumeAsNeeded loc) fs1 fs2
consumeAsNeeded loc pt t =
  when (diet pt == Consume) $ consumeAliases loc $ aliases t

-- | Check an argument passed for a parameter of type @p_t@.
--
-- @prev@ holds the already-checked arguments with their aliases. A
-- consumed argument must not alias any of them. Consumption inside an
-- argument of non-order-zero (functional) type is an error.
checkArg :: [(Exp, TypeAliases)] -> ParamType -> Exp -> CheckM (Exp, TypeAliases)
checkArg prev p_t e = do
  ((e', e_als), e_cons) <- contain $ checkExp e
  consumed e_cons
  let e_t = typeOf e'
  when (e_cons /= mempty && not (orderZero e_t)) $
    addError (locOf e) mempty $
      "Argument of functional type"
        </> indent 2 (pretty e_t)
        </> "contains consumption, which is not allowed."
  when (diet p_t == Consume) $ do
    noSelfAliases (locOf e) e_als
    consumeAsNeeded (locOf e) p_t e_als
    case mapMaybe prevAlias $ S.toList $ boundAliases $ aliases e_als of
      [] -> pure ()
      (v, prev_arg) : _ ->
        addError (locOf e) mempty $
          "Argument is consumed, but aliases"
            </> indent 2 (prettyName v)
            </> "which is also aliased by other argument"
            </> indent 2 (pretty prev_arg)
            </> "at"
            <+> pretty (locTextRel (locOf e) (locOf prev_arg))
            <> "."
  pure (e', e_als)
  where
    -- The first earlier argument whose aliases include 'v', if any.
    prevAlias v =
      (v,) . fst <$> find (S.member v . boundAliases . aliases . snd) prev

-- | @returnType appres ret_type arg_diet arg_type@ gives the result of
-- applying an argument of the given type to a function with the given
-- return type, consuming the argument with the given diet.
returnType :: Aliases -> ResType -> Diet -> TypeAliases -> TypeAliases
returnType appres ret d arg =
  case ret of
    Array u shape et ->
      Array (byUniqueness u) shape et
    Scalar (TypeVar u tn targs) ->
      Scalar $ TypeVar (byUniqueness u) tn targs
    Scalar (Record fs) ->
      Scalar $ Record $ fmap descend fs
    Scalar (Sum cs) ->
      Scalar $ Sum $ fmap (fmap descend) cs
    Scalar (Prim pt) ->
      Scalar $ Prim pt
    Scalar (Arrow _ pn pd t1 (RetType dims t2)) ->
      Scalar $ Arrow byDiet pn pd t1 $ RetType dims t2
  where
    descend et = returnType appres et d arg
    -- A unique component aliases nothing. A nonunique component
    -- aliases according to the diet.
    byUniqueness Unique = mempty
    byUniqueness Nonunique = byDiet
    -- An observed argument adds its own aliases to the application
    -- aliases. A consumed argument adds none.
    byDiet =
      case d of
        Consume -> appres
        Observe -> appres <> aliases arg

-- | The aliases of the result of applying a value of function type to
-- an argument with the given aliases.
applyArg :: TypeAliases -> TypeAliases -> TypeAliases
applyArg fun_t arg_als =
  case fun_t of
    Scalar (Arrow closure_als _ d _ (RetType _ res_t)) ->
      returnType closure_als res_t d arg_als
    _ -> error $ "applyArg: " <> show fun_t

-- | The aliases of every free variable of the expression that is
-- bound in the local symbol table.
boundFreeInExp :: Exp -> CheckM (M.Map VName TypeAliases)
boundFreeInExp e = do
  let free_vars = fvVars $ freeInExp e
  locals <- asks envVtable
  pure $ M.map entryAliases $ locals `M.restrictKeys` free_vars

-- Loops need special care. Unlike for ordinary functions, we infer
-- the uniqueness (diet) of their parameters.
type Loop =
  ( Pat ParamType,
    LoopInitBase Info VName,
    LoopFormBase Info VName,
    Exp
  )

-- | Set the diet of every name bound by the pattern. A name becomes
-- 'Consume' if the predicate holds for it and 'Observe' otherwise.
-- Wildcards become 'Observe'. Subpatterns under a 'PatAscription' are
-- left unchanged.
updateParamDiet :: (VName -> Bool) -> Pat ParamType -> Pat ParamType
updateParamDiet cons = recurse
  where
    recurse (Wildcard (Info t) wloc) =
      Wildcard (Info $ t `setUniqueness` Observe) wloc
    recurse (PatParens p ploc) =
      PatParens (recurse p) ploc
    recurse (PatAttr attr p ploc) =
      PatAttr attr (recurse p) ploc
    recurse (Id name (Info t) iloc)
      | cons name =
          let t' = t `setUniqueness` Consume
           in Id name (Info t') iloc
      | otherwise =
          let t' = t `setUniqueness` Observe
           in Id name (Info t') iloc
    recurse (TuplePat pats ploc) =
      TuplePat (map recurse pats) ploc
    recurse (RecordPat fs ploc) =
      RecordPat (map (fmap recurse) fs) ploc
    recurse (PatAscription p t ploc) =
      PatAscription p t ploc
    recurse p@PatLit {} = p
    recurse (PatConstr n t ps ploc) =
      PatConstr n t (map recurse ps) ploc

-- | Infer the diets of a loop parameter by fixed-point iteration.
--
-- Each round does the following:
--
-- * names bound by the parameter that occur in @body_cons@ are marked
--   'Consume';
-- * the body result aliases @body_als@ are checked against the
--   parameter (see @checkMergeReturn@);
-- * the aliases returned for consuming parameters are added to the
--   consumed set.
--
-- Iteration stops once neither the consumed set nor the parameter
-- type changes.
convergeLoopParam :: Loc -> Pat ParamType -> Names -> TypeAliases -> CheckM (Pat ParamType)
convergeLoopParam loop_loc param body_cons body_als = do
  let -- Make the pattern Consume where needed.
      param' = updateParamDiet (`S.member` S.filter (`elem` patNames param) body_cons) param

  -- Check that the new values of consumed merge parameters do not
  -- alias something bound outside the loop, AND that anything
  -- returned for a unique merge parameter does not alias anything
  -- else returned.
  --
  -- The state is (aliases returned for consuming parameters, aliases
  -- returned for observing parameters) seen so far.
  let checkMergeReturn (Id pat_v (Info pat_v_t) patloc) t = do
        let free_als = S.filter (`notElem` patNames param) $ boundAliases (aliases t)
        when (diet pat_v_t == Consume) $ forM_ free_als $ \v ->
          lift . addError loop_loc mempty $
            "Return value for consuming loop parameter"
              <+> dquotes (prettyName pat_v)
              <+> "aliases"
              <+> dquotes (prettyName v)
              <> "."
        (cons, obs) <- get
        unless (S.null $ aliases t `S.intersection` cons) $
          lift . addError loop_loc mempty $
            "Return value for loop parameter"
              <+> dquotes (prettyName pat_v)
              <+> "aliases other consumed loop parameter."
        when
          ( diet pat_v_t == Consume
              && not (S.null (aliases t `S.intersection` (cons <> obs)))
          )
          $ lift . addError loop_loc mempty
          $ "Return value for consuming loop parameter"
            <+> dquotes (prettyName pat_v)
            <+> "aliases previously returned value."
        if diet pat_v_t == Consume
          then put (cons <> aliases t, obs)
          else put (cons, obs <> aliases t)

        pure $ Id pat_v (Info pat_v_t) patloc
      checkMergeReturn (Wildcard (Info pat_v_t) patloc) _ =
        pure $ Wildcard (Info pat_v_t) patloc
      checkMergeReturn (PatParens p _) t =
        checkMergeReturn p t
      checkMergeReturn (PatAscription p _ _) t =
        checkMergeReturn p t
      checkMergeReturn (RecordPat pfs patloc) (Scalar (Record tfs)) =
        RecordPat . map unshuffle . M.toList <$> sequence pfs' <*> pure patloc
        where
          pfs' = M.intersectionWith check (M.fromList (map shuffle pfs)) tfs
          check (loc, x) y = (loc,) <$> checkMergeReturn x y
          shuffle (L loc v, t) = (v, (loc, t))
          unshuffle (v, (loc, t)) = (L loc v, t)
      checkMergeReturn (TuplePat pats patloc) t
        | Just ts <- isTupleRecord t =
            TuplePat <$> zipWithM checkMergeReturn pats ts <*> pure patloc
      checkMergeReturn p _ =
        pure p

  (param'', (param_cons, _)) <-
    runStateT (checkMergeReturn param' body_als) (mempty, mempty)

  let body_cons' = body_cons <> S.map aliasVar param_cons
  if body_cons' == body_cons && patternType param'' == patternType param
    then pure param'
    else convergeLoopParam loop_loc param'' body_cons' body_als

-- | Check a loop and compute the aliases of its result.
--
-- The body is checked with every loop parameter bound as consumable.
-- The parameter diets are then inferred with 'convergeLoopParam', and
-- the initial argument is checked against the inferred parameter
-- type. It is an error for the body to use a free variable that the
-- initial argument consumes.
checkLoop :: Loc -> Loop -> CheckM (Loop, TypeAliases)
checkLoop loop_loc (param, arg, form, body) = do
  form' <- checkSubExps form
  -- We pretend that every part of the loop parameter has a consuming
  -- diet, as we need to allow consumption in the body, which we then
  -- use to infer the proper diet of the parameter.
  ((body', body_cons), body_als) <-
    noConsumable
      . bindingParam (updateParamDiet (const True) param)
      . bindingLoopForm form'
      $ do
        ((body', body_als), body_cons) <- contain $ checkExp body
        pure ((body', body_cons), body_als)
  param' <- convergeLoopParam loop_loc param (M.keysSet body_cons) body_als

  let param_t = patternType param'
  ((arg', arg_als), arg_cons) <- case arg of
    LoopInitImplicit (Info e) ->
      contain $ first (LoopInitImplicit . Info) <$> checkArg [] param_t e
    LoopInitExplicit e ->
      contain $ first LoopInitExplicit <$> checkArg [] param_t e
  consumed arg_cons
  free_bound <- boundFreeInExp body

  let bad = any (`M.member` arg_cons) . boundAliases . aliases . snd
  forM_ (filter bad $ M.toList free_bound) $ \(v, _) -> do
    v' <- describeVar v
    addError loop_loc mempty $
      "Loop body uses"
        <+> v'
        <> " (or an alias),"
        </> "but this is consumed by the initial loop argument."

  -- The loop result is modelled as the application of a function
  -- (aliasing a fresh name) to the initial argument.
  v <- VName "internal_loop_result" <$> incCounter
  modify $ \s -> s {stateNames = M.insert v (NameLoopRes (srclocOf loop_loc)) $ stateNames s}

  let loopt =
        funType [param'] (RetType [] $ paramToRes param_t)
          `setAliases` S.singleton (AliasFree v)
  pure
    ( (param', arg', form', body'),
      applyArg loopt arg_als `combineAliases` body_als
    )

-- | Compute the aliases of a function application by folding
-- 'applyArg' over the argument aliases.
--
-- The function type first receives an extra 'AliasFree' alias for a
-- fresh @internal_app_result@ name. That name is recorded as
-- 'NameAppRes' in 'stateNames'.
checkFuncall ::
  (Foldable f) =>
  SrcLoc ->
  Maybe (QualName VName) ->
  TypeAliases ->
  f TypeAliases ->
  CheckM TypeAliases
checkFuncall loc fname f_als arg_als = do
  v <- VName "internal_app_result" <$> incCounter
  modify $ \s -> s {stateNames = M.insert v (NameAppRes fname loc) $ stateNames s}
  pure $ foldl applyArg (second (S.insert (AliasFree v)) f_als) arg_als

-- | Check consumption in an expression, returning the expression with
-- updated annotations along with its aliases.
--
-- The updated annotations are the inferred return uniqueness of
-- lambdas and local functions, and the inferred loop parameter diets.
checkExp :: Exp -> CheckM (Exp, TypeAliases)
-- First we have the complicated cases.

--
checkExp (AppExp (Apply f args loc) appres) = do
  (f', f_als) <- checkExp f
  (args', args_als) <- NE.unzip <$> checkArgs (toRes Nonunique f_als) args
  res_als <- checkFuncall loc (fname f) f_als args_als
  pure (AppExp (Apply f' args' loc) appres, res_als)
  where
    fname (Var v _ _) = Just v
    fname (AppExp (Apply e _ _) _) = fname e
    fname _ = Nothing
    checkArg' prev d (Info p, e) = do
      (e', e_als) <- checkArg prev (second (const d) (typeOf e)) e
      pure ((Info p, e'), e_als)

    checkArgs (Scalar (Arrow _ _ d _ (RetType _ rt))) (x NE.:| args') = do
      -- Note Futhark uses right-to-left evaluation of applications.
      args'' <- maybe (pure []) (fmap NE.toList . checkArgs rt) $ NE.nonEmpty args'
      (x', x_als) <- checkArg' (map (first snd) args'') d x
      pure $ (x', x_als) NE.:| args''
    checkArgs t _ =
      error $ "checkArgs: " <> prettyString t

--
checkExp (AppExp (Loop sparams pat loopinit form body loc) appres) = do
  ((pat', loopinit', form', body'), als) <-
    checkLoop (locOf loc) (pat, loopinit, form, body)
  pure
    ( AppExp (Loop sparams pat' loopinit' form' body' loc) appres,
      als
    )

--
checkExp (AppExp (LetPat sizes p e body loc) appres) = do
  ((e', e_als), e_cons) <- contain $ checkExp e
  consumed e_cons
  let e_t = typeOf e'
  when (e_cons /= mempty && not (orderZero e_t)) $
    addError (locOf e) mempty $
      "Let-bound expression of higher-order type"
        </> indent 2 (pretty e_t)
        </> "contains consumption, which is not allowed."
  bindingPat p e_als $ do
    (body', body_als) <- checkExp body
    pure
      ( AppExp (LetPat sizes p e' body' loc) appres,
        body_als
      )

--
checkExp (AppExp (If cond te fe loc) appres) = do
  (cond', _) <- checkExp cond
  ((te', te_als), te_cons) <- contain $ checkExp te
  ((fe', fe_als), fe_cons) <- contain $ checkExp fe
  -- Aliases of names consumed in either branch are dropped from the
  -- result.
  let all_cons = te_cons <> fe_cons
      notConsumed = not . (`M.member` all_cons) . aliasVar
      comb_als = second (S.filter notConsumed) $ te_als `combineAliases` fe_als
  consumed all_cons
  pure
    ( AppExp (If cond' te' fe' loc) appres,
      appResType (unInfo appres) `setAliases` mempty `combineAliases` comb_als
    )

--
checkExp (AppExp (Match cond cs loc) appres) = do
  (cond', cond_als) <- checkExp cond
  ((cs', cs_als), cs_cons) <- first NE.unzip . NE.unzip <$> mapM (checkCase cond_als) cs
  let all_cons = fold cs_cons
      notConsumed = not . (`M.member` all_cons) . aliasVar
      comb_als = second (S.filter notConsumed) $ foldl1 combineAliases cs_als
  consumed all_cons
  pure
    ( AppExp (Match cond' cs' loc) appres,
      appResType (unInfo appres) `setAliases` mempty `combineAliases` comb_als
    )
  where
    checkCase cond_als (CasePat p body caseloc) =
      contain $ bindingPat p cond_als $ do
        (body', body_als) <- checkExp body
        pure (CasePat p body' caseloc, body_als)

--
checkExp (AppExp (LetFun fname (typarams, params, te, Info (RetType ext ret), funbody) letbody loc) appres) = do
  ((ret', funbody'), ftype) <- bindingParams params $ do
    -- Throw away the consumption - it can refer only to the parameters
    -- anyway.
    ((funbody', funbody_als), _body_cons) <- contain $ checkExp funbody
    checkReturnAlias loc params ret funbody_als
    checkGlobalAliases loc params funbody_als
    free_bound <- boundFreeInExp funbody
    let ret' = inferReturnUniqueness params ret funbody_als
        als = foldMap aliases (M.elems free_bound)
        ftype = funType params (RetType ext ret') `setAliases` als
    pure ((ret', funbody'), ftype)

  (letbody', letbody_als) <- bindingFun fname ftype $ checkExp letbody
  pure
    ( AppExp (LetFun fname (typarams, params, te, Info (RetType ext ret'), funbody') letbody' loc) appres,
      letbody_als
    )

--
checkExp (AppExp (BinOp (op, oploc) opt (x, xp) (y, yp) loc) appres) = do
  op_als <- observeVar (locOf oploc) (qualLeaf op) (unInfo opt)
  let at1 : at2 : _ = fst $ unfoldFunType op_als
  (x', x_als) <- checkArg [] at1 x
  (y', y_als) <- checkArg [(x', x_als)] at2 y
  res_als <- checkFuncall loc (Just op) op_als [x_als, y_als]
  pure
    ( AppExp (BinOp (op, oploc) opt (x', xp) (y', yp) loc) appres,
      res_als
    )

--
checkExp e@(Lambda params body te (Info (RetType ext ret)) loc) =
  bindingParams params $ do
    -- Throw away the consumption - it can refer only to the parameters
    -- anyway.
    ((body', body_als), _body_cons) <- contain $ checkExp body
    checkReturnAlias loc params ret body_als
    checkGlobalAliases loc params body_als
    free_bound <- boundFreeInExp e
    let ret' = inferReturnUniqueness params ret body_als
        als = foldMap aliases (M.elems free_bound)
        ftype = funType params (RetType ext ret') `setAliases` als
    pure
      ( Lambda params body' te (Info (RetType ext ret')) loc,
        ftype
      )

--
checkExp (AppExp (LetWith dst src slice ve body loc) appres) = do
  src_als <- observeVar (locOf dst) (identName src) (unInfo $ identType src)
  slice' <- checkSubExps slice
  (ve', ve_als) <- checkExp ve
  consume (locOf src) (identName src) (unInfo (identType src))
  overlapCheck (locOf ve) (src, src_als) (ve', ve_als)
  (body', body_als) <- bindingIdent Consume dst $ checkExp body
  pure (AppExp (LetWith dst src slice' ve' body' loc) appres, body_als)

--
checkExp (Update src slice ve loc) = do
  slice' <- checkSubExps slice
  (ve', ve_als) <- checkExp ve
  (src', src_als) <- checkExp src
  overlapCheck (locOf ve) (src', src_als) (ve', ve_als)
  consumeAliases (locOf loc) $ aliases src_als
  pure (Update src' slice' ve' loc, second (const mempty) src_als)

-- Cases that simply propagate aliases directly.
checkExp (Var v (Info t) loc) = do
  -- 'observeVar' already reports uses of consumed aliases, so a
  -- separate 'checkIfConsumed' here would only repeat that error.
  als <- observeVar (locOf loc) (qualLeaf v) t
  pure (Var v (Info t) loc, als)
checkExp (OpSection v (Info t) loc) = do
  -- As for 'Var': 'observeVar' performs the use-after-consume check.
  als <- observeVar (locOf loc) (qualLeaf v) t
  pure (OpSection v (Info t) loc, als)
checkExp (OpSectionLeft op ftype arg arginfo retinfo loc) = do
  let (_, Info (pn, pt2)) = arginfo
      (Info ret, _) = retinfo
  als <- observeVar (locOf loc) (qualLeaf op) (unInfo ftype)
  (arg', arg_als) <- checkExp arg
  pure
    ( OpSectionLeft op ftype arg' arginfo retinfo loc,
      Scalar $ Arrow (aliases arg_als <> aliases als) pn (diet pt2) (toStruct pt2) ret
    )
checkExp (OpSectionRight op ftype arg arginfo retinfo loc) = do
  let (Info (pn, pt2), _) = arginfo
      Info ret = retinfo
  als <- observeVar (locOf loc) (qualLeaf op) (unInfo ftype)
  (arg', arg_als) <- checkExp arg
  pure
    ( OpSectionRight op ftype arg' arginfo retinfo loc,
      Scalar $ Arrow (aliases arg_als <> aliases als) pn (diet pt2) (toStruct pt2) ret
    )
checkExp (IndexSection slice t loc) = do
  slice' <- checkSubExps slice
  pure (IndexSection slice' t loc, unInfo t `setAliases` mempty)
checkExp (ProjectSection fs t loc) = do
  pure (ProjectSection fs t loc, unInfo t `setAliases` mempty)
checkExp (Coerce e te t loc) = do
  (e', e_als) <- checkExp e
  pure (Coerce e' te t loc, e_als)
checkExp (Ascript e te loc) = do
  (e', e_als) <- checkExp e
  pure (Ascript e' te loc, e_als)
checkExp (AppExp (Index v slice loc) appres) = do
  (v', v_als) <- checkExp v
  slice' <- checkSubExps slice
  pure
    ( AppExp (Index v' slice' loc) appres,
      appResType (unInfo appres) `setAliases` aliases v_als
    )
checkExp (Assert e1 e2 t loc) = do
  (e1', _) <- checkExp e1
  (e2', e2_als) <- checkExp e2
  pure (Assert e1' e2' t loc, e2_als)
checkExp (Parens e loc) = do
  (e', e_als) <- checkExp e
  pure (Parens e' loc, e_als)
checkExp (QualParens v e loc) = do
  (e', e_als) <- checkExp e
  pure (QualParens v e' loc, e_als)
checkExp (Attr attr e loc) = do
  (e', e_als) <- checkExp e
  pure (Attr attr e' loc, e_als)
checkExp (Project name e t loc) = do
  (e', e_als) <- checkExp e
  pure
    ( Project name e' t loc,
      case e_als of
        Scalar (Record fs)
          | Just name_als <- M.lookup name fs -> name_als
        _ -> error $ "checkExp Project: bad type " <> prettyString e_als
    )
checkExp (TupLit es loc) = do
  (es', es_als) <- mapAndUnzipM checkExp es
  pure (TupLit es' loc, Scalar $ tupleRecord es_als)
checkExp (Constr name es t loc) = do
  (es', es_als) <- mapAndUnzipM checkExp es
  pure
    ( Constr name es' t loc,
      case unInfo t of
        Scalar (Sum cs) ->
          Scalar . Sum . M.insert name es_als $ M.map (map (`setAliases` mempty)) cs
        t' -> error $ "checkExp Constr: bad type " <> prettyString t'
    )
checkExp (RecordUpdate src fields ve t loc) = do
  (src', src_als) <- checkExp src
  (ve', ve_als) <- checkExp ve
  pure
    ( RecordUpdate src' fields ve' t loc,
      setFieldAliases ve_als fields src_als
    )
checkExp (RecordLit fs loc) = do
  (fs', fs_als) <- mapAndUnzipM checkField fs
  pure (RecordLit fs' loc, Scalar $ Record $ M.fromList fs_als)
  where
    checkField (RecordFieldExplicit name e floc) = do
      (e', e_als) <- checkExp e
      pure (RecordFieldExplicit name e' floc, (unLoc name, e_als))
    checkField (RecordFieldImplicit name t floc) = do
      name_als <- observeVar (locOf floc) (unLoc name) $ unInfo t
      pure (RecordFieldImplicit name t floc, (baseName (unLoc name), name_als))

-- Cases that create alias-free values.
checkExp e@(AppExp Range {} _) = noAliases e
checkExp e@IntLit {} = noAliases e
checkExp e@FloatLit {} = noAliases e
checkExp e@Literal {} = noAliases e
checkExp e@StringLit {} = noAliases e
checkExp e@ArrayVal {} = noAliases e
checkExp e@ArrayLit {} = noAliases e
checkExp e@Negate {} = noAliases e
checkExp e@Not {} = noAliases e
checkExp e@Hole {} = noAliases e

-- | For functions with at least one parameter, report an error for
-- each name the result may alias (via 'arrayAliases') that is not
-- bound in the local symbol table, i.e. a free variable.
checkGlobalAliases :: SrcLoc -> [Pat ParamType] -> TypeAliases -> CheckM ()
checkGlobalAliases loc params body_t = do
  vtable <- asks envVtable
  let global = flip M.notMember vtable
  unless (null params) $ forM_ (boundAliases $ arrayAliases body_t) $ \v ->
    when (global v) . addError loc mempty . withIndexLink "alias-free-variable" $
      "Function result aliases the free variable "
        <> dquotes (prettyName v)
        <> "."
        </> "Use"
        <+> dquotes "copy"
        <+> "to break the aliasing."

-- | Type-check a value definition. This also infers a new return
-- type that may be more unique than previously.
checkValDef ::
  (VName, [Pat ParamType], Exp, ResRetType, Maybe (TypeExp Exp VName), SrcLoc) ->
  ((Exp, ResRetType), [TypeError])
checkValDef (_fname, params, body, RetType ext ret, retdecl, loc) =
  runCheckM (locOf loc) $ do
    fmap fst . bindingParams params $ do
      (body', body_als) <- checkExp body
      checkReturnAlias loc params ret body_als
      checkGlobalAliases loc params body_als
      -- If the user did not provide an annotation (meaning the return
      -- type is fully inferred), we infer the uniqueness. Otherwise,
      -- we go with whatever they wanted. This lets the user define
      -- non-unique return types even if the body actually has no
      -- aliases.
      ret' <- case retdecl of
        Just retdecl' -> do
          when (null params && unique ret) $
            addError retdecl' mempty "A top-level constant cannot have a unique type."
          pure $ RetType ext ret
        Nothing ->
          pure $ RetType ext $ inferReturnUniqueness params ret body_als
      pure
        ( (body', ret'),
          body_als -- Discarded by the 'fmap fst' above.
        )
{-# NOINLINE checkValDef #-}

-- NOTE(review): the next line is tar archive header residue, not Haskell.
futhark-0.25.27/src/Language/Futhark/TypeChecker/Match.hs000066400000000000000000000134151475065116200230620ustar00rootroot00000000000000
-- | Checking for missing cases in a match expression. Based on
-- "Warnings for pattern matching" by Luc Maranget. We only detect
-- inexhaustiveness here - ideally, we would also like to check for
-- redundant cases.
module Language.Futhark.TypeChecker.Match
  ( unmatched,
    Match,
  )
where

import Data.Bifunctor (first)
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Util (maybeHead, nubOrd)
import Futhark.Util.Pretty hiding (group, space)
import Language.Futhark hiding (ExpBase (Constr))

-- | The head constructor of a 'Match'.
data Constr
  = Constr Name
  | ConstrTuple
  | ConstrRecord [Name]
  | -- | Treated as 0-ary.
    ConstrLit PatLit
  deriving (Eq, Ord, Show)

-- | A representation of the essentials of a pattern.
data Match t
  = MatchWild t
  | MatchConstr Constr [Match t] t
  deriving (Eq, Ord, Show)

-- | The type of the value a 'Match' matches against.
matchType :: Match StructType -> StructType
matchType (MatchWild t) = t
matchType (MatchConstr _ _ t) = t

-- | Render a 'Match' in Futhark pattern syntax. The 'Int' is a
-- precedence level: a constructor with arguments is parenthesised when
-- it is at least 10.
pprMatch :: Int -> Match t -> Doc a
pprMatch _ MatchWild {} = "_"
pprMatch _ (MatchConstr (ConstrLit l) _ _) = pretty l
pprMatch p (MatchConstr (Constr c) ps _) =
  parensIf (not (null ps) && p >= 10) $
    "#" <> pretty c <> mconcat (map ((" " <>) . pprMatch 10) ps)
pprMatch _ (MatchConstr ConstrTuple ps _) =
  parens $ commasep $ map (pprMatch (-1)) ps
pprMatch _ (MatchConstr (ConstrRecord fs) ps _) =
  braces $ commasep $ zipWith ppField fs ps
  where
    ppField name t = pretty (nameToString name) <> equals <> pprMatch (-1) t

instance Pretty (Match t) where
  pretty = pprMatch (-1)

-- | Reduce a pattern to its 'Match' essentials.
--
-- Names and wildcards become 'MatchWild'. Parentheses, attributes and
-- ascriptions are looked through. Record fields are ordered by
-- 'sortFields'.
patternToMatch :: Pat StructType -> Match StructType
patternToMatch (Id _ (Info t) _) = MatchWild t
patternToMatch (Wildcard (Info t) _) = MatchWild t
patternToMatch (PatParens p _) = patternToMatch p
patternToMatch (PatAttr _ p _) = patternToMatch p
patternToMatch (PatAscription p _ _) = patternToMatch p
patternToMatch (PatLit l (Info t) _) = MatchConstr (ConstrLit l) [] t
patternToMatch p@(TuplePat ps _) =
  MatchConstr ConstrTuple (map patternToMatch ps) $
    patternStructType p
patternToMatch p@(RecordPat fs _) =
  MatchConstr (ConstrRecord fnames) (map patternToMatch ps) $
    patternStructType p
  where
    (fnames, ps) = unzip $ sortFields $ M.fromList $ map (first unLoc) fs
patternToMatch (PatConstr c (Info t) args _) =
  MatchConstr (Constr c) (map patternToMatch args) t

-- | The sum-type constructor name at the head of a 'Match', if any.
isConstr :: Match t -> Maybe Name
isConstr (MatchConstr (Constr c) _ _) = Just c
isConstr _ = Nothing

-- | The boolean literal at the head of a 'Match', if any.
isBool :: Match t -> Maybe Bool
isBool (MatchConstr (ConstrLit (PatLitPrim (BoolValue b))) _ _) = Just b
isBool _ = Nothing

-- | Whether the given column heads cover every constructor of their
-- type. This holds when any of the following is true:
--
-- * all constructors of a sum type are present;
-- * both boolean literals are present;
-- * the column consists only of record patterns;
-- * the column consists only of tuple patterns.
complete :: [Match StructType] -> Bool
complete xs
  | Just x <- maybeHead xs,
    Scalar (Sum all_cs) <- matchType x,
    Just xs_cs <- mapM isConstr xs =
      all (`elem` xs_cs) (M.keys all_cs)
  | otherwise =
      all (`elem` fromMaybe [] (mapM isBool xs)) [True, False]
        || all isRecord xs
        || all isTuple xs
  where
    isRecord (MatchConstr ConstrRecord {} _ _) = True
    isRecord _ = False
    isTuple (MatchConstr ConstrTuple _ _) = True
    isTuple _ = False

-- | Specialise a pattern matrix on the constructor @c1@, whose
-- arguments have types @ats@.
--
-- * Rows headed by the same constructor keep their arguments spliced
--   in place of the head.
-- * Rows headed by a wildcard get one wildcard per argument.
-- * All other rows are dropped.
--
-- An empty row ends the traversal.
specialise ::
  [StructType] ->
  Match StructType ->
  [[Match StructType]] ->
  [[Match StructType]]
specialise ats c1 = go
  where
    go ((c2 : row) : ps)
      | Just args <- match c1 c2 =
          (args ++ row) : go ps
      | otherwise =
          go ps
    go _ = []

    match (MatchConstr c1' _ _) (MatchConstr c2' args _)
      | c1' == c2' =
          Just args
      | otherwise =
          Nothing
    match _ MatchWild {} =
      Just $ map MatchWild ats
    match _ _ =
      Nothing

-- | The default matrix. Rows headed by a wildcard are kept with that
-- wildcard removed. Rows headed by a constructor are dropped.
defaultMat :: [[Match t]] -> [[Match t]]
defaultMat = mapMaybe onRow
  where
    onRow (MatchConstr {} : _) = Nothing
    onRow (MatchWild {} : ps) = Just ps
    onRow [] = Nothing -- Should not happen.

-- | @findUnmatched pmat n@ produces example rows of @n@ patterns that
-- no row of the pattern matrix @pmat@ matches. It follows the
-- algorithm by Maranget cited in the module header.
findUnmatched :: [[Match StructType]] -> Int -> [[Match ()]]
findUnmatched pmat n
  | ((p : _) : _) <- pmat,
    Just heads <- mapM maybeHead pmat =
      if complete heads
        then completeCase heads
        else incompleteCase (matchType p) heads
  where
    completeCase cs = do
      c <- cs
      let ats = case c of
            MatchConstr _ args _ -> map matchType args
            MatchWild _ -> []
          a_k = length ats
          pmat' = specialise ats c pmat
      u <- findUnmatched pmat' (a_k + n - 1)
      pure $ case c of
        MatchConstr c' _ _ ->
          let (r, p) = splitAt a_k u
           in MatchConstr c' r () : p
        MatchWild _ ->
          MatchWild () : u

    incompleteCase pt cs = do
      u <- findUnmatched (defaultMat pmat) (n - 1)
      if null cs
        then pure $ MatchWild () : u
        else case pt of
          Scalar (Sum all_cs) -> do
            -- Figure out which constructors are missing.
            let sigma = mapMaybe isConstr cs
                notCovered (k, _) = k `notElem` sigma
            (cname, ts) <- filter notCovered $ M.toList all_cs
            pure $ MatchConstr (Constr cname) (map (const (MatchWild ())) ts) () : u
          Scalar (Prim Bool) -> do
            -- Figure out which constants are missing.
            let sigma = mapMaybe isBool cs
            b <- filter (`notElem` sigma) [True, False]
            pure $ MatchConstr (ConstrLit (PatLitPrim (BoolValue b))) [] () : u
          _ -> do
            -- FIXME: this is wrong in the unlikely case where someone
            -- is pattern-matching every single possible number for
            -- some numeric type. It should be handled more like Bool
            -- above.
            pure $ MatchWild () : u
findUnmatched [] n = [replicate n $ MatchWild ()]
findUnmatched _ _ = []

{-# NOINLINE unmatched #-}

-- | Find the unmatched cases.
unmatched :: [Pat StructType] -> [Match ()]
unmatched orig_ps =
  -- The algorithm may find duplicate examples, which we filter away
  -- here.
  nubOrd $
    mapMaybe maybeHead $
      findUnmatched (map (L.singleton . patternToMatch) orig_ps) 1

-- NOTE(review): the next line is tar archive header residue, not Haskell.
futhark-0.25.27/src/Language/Futhark/TypeChecker/Modules.hs000066400000000000000000000553371475065116200234470ustar00rootroot00000000000000
-- | Implementation of the Futhark module system (at least most of it;
-- some is scattered elsewhere in the type checker).
module Language.Futhark.TypeChecker.Modules
  ( matchMTys,
    newNamesForMTy,
    refineEnv,
    applyFunctor,
  )
where

import Control.Monad
import Data.Either
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Ord
import Data.Set qualified as S
import Futhark.Util.Pretty
import Language.Futhark
import Language.Futhark.Semantic
import Language.Futhark.TypeChecker.Monad
import Language.Futhark.TypeChecker.Types
import Language.Futhark.TypeChecker.Unify (doUnification)
import Prelude hiding (abs, mod)

-- | Apply a type substitution throughout a module, including both
-- halves of a functor signature.
substituteTypesInMod :: TypeSubs -> Mod -> Mod
substituteTypesInMod substs (ModEnv e) =
  ModEnv $ substituteTypesInEnv substs e
substituteTypesInMod substs (ModFun (FunModType abs mod mty)) =
  ModFun $ FunModType abs (substituteTypesInMod substs mod) (substituteTypesInMTy substs mty)

-- | Apply a type substitution to the module part of a module type.
-- The abstract type set is left unchanged.
substituteTypesInMTy :: TypeSubs -> MTy -> MTy
substituteTypesInMTy substs (MTy abs mod) = MTy abs $ substituteTypesInMod substs mod

-- | Apply a type substitution to the value, type and module tables of
-- an environment. The other fields of the environment are left
-- unchanged.
substituteTypesInEnv :: TypeSubs -> Env -> Env
substituteTypesInEnv substs env =
  env
    { envVtable = M.map (snd . substituteTypesInBoundV substs) $ envVtable env,
      envTypeTable = M.mapWithKey subT $ envTypeTable env,
      envModTable = M.map (substituteTypesInMod substs) $ envModTable env
    }
  where
    -- A type name that is itself substituted is replaced wholesale.
    -- Otherwise the substitution is applied inside its definition.
    subT name (TypeAbbr l _ _)
      | Just (Subst ps rt) <- substs name = TypeAbbr l ps rt
    subT _ (TypeAbbr l ps (RetType dims t)) =
      TypeAbbr l ps $ applySubst substs $ RetType dims t

-- Also returns the names of new sizes arising from substituting a
-- size-lifted type at the outermost part of the type. This is a
-- somewhat rare case (see #2120). The right solution is to generally
-- use fresh (or at least unique) names.
substituteTypesInBoundV :: TypeSubs -> BoundV -> ([VName], BoundV) substituteTypesInBoundV substs (BoundV tps t) = let RetType dims t' = applySubst substs $ RetType [] t in (dims, BoundV (tps <> map (`TypeParamDim` mempty) dims) t') -- | All names defined anywhere in the 'Env'. allNamesInEnv :: Env -> S.Set VName allNamesInEnv (Env vtable ttable stable modtable _names) = S.fromList ( M.keys vtable ++ M.keys ttable ++ M.keys stable ++ M.keys modtable ) <> mconcat ( map allNamesInMTy (M.elems stable) ++ map allNamesInMod (M.elems modtable) ++ map allNamesInType (M.elems ttable) ) where allNamesInType (TypeAbbr _ ps _) = S.fromList $ map typeParamName ps allNamesInMod :: Mod -> S.Set VName allNamesInMod (ModEnv env) = allNamesInEnv env allNamesInMod ModFun {} = mempty allNamesInMTy :: MTy -> S.Set VName allNamesInMTy (MTy abs mod) = S.fromList (map qualLeaf $ M.keys abs) <> allNamesInMod mod -- | Create unique renames for the module type. This is used for -- e.g. generative functor application. newNamesForMTy :: MTy -> TypeM (MTy, M.Map VName VName) newNamesForMTy orig_mty = do pairs <- forM (S.toList $ allNamesInMTy orig_mty) $ \v -> do v' <- newName v pure (v, v') let substs = M.fromList pairs rev_substs = M.fromList $ map (uncurry $ flip (,)) pairs pure (substituteInMTy substs orig_mty, rev_substs) where substituteInMTy :: M.Map VName VName -> MTy -> MTy substituteInMTy substs (MTy mty_abs mty_mod) = MTy (M.mapKeys (fmap substitute) mty_abs) (substituteInMod mty_mod) where substituteInEnv (Env vtable ttable _stable modtable names) = let vtable' = substituteInMap substituteInBinding vtable ttable' = substituteInMap substituteInTypeBinding ttable mtable' = substituteInMap substituteInMod modtable in Env { envVtable = vtable', envTypeTable = ttable', envModTypeTable = mempty, envModTable = mtable', envNameMap = M.map (fmap substitute) names } substitute v = fromMaybe v $ M.lookup v substs -- For applySubst and friends. subst v = ExpSubst . flip sizeFromName mempty . 
qualName <$> M.lookup v substs substituteInMap f m = let (ks, vs) = unzip $ M.toList m in M.fromList $ zip (map (\k -> fromMaybe k $ M.lookup k substs) ks) (map f vs) substituteInBinding (BoundV ps t) = BoundV (map substituteInTypeParam ps) (substituteInType t) substituteInMod (ModEnv env) = ModEnv $ substituteInEnv env substituteInMod (ModFun funsig) = ModFun $ substituteInFunModType funsig substituteInFunModType (FunModType abs mod mty) = FunModType (M.mapKeys (fmap substitute) abs) (substituteInMod mod) (substituteInMTy substs mty) substituteInTypeBinding (TypeAbbr l ps (RetType dims t)) = TypeAbbr l (map substituteInTypeParam ps) $ RetType dims $ substituteInType t substituteInTypeParam (TypeParamDim p loc) = TypeParamDim (substitute p) loc substituteInTypeParam (TypeParamType l p loc) = TypeParamType l (substitute p) loc substituteInScalarType :: ScalarTypeBase Size u -> ScalarTypeBase Size u substituteInScalarType (TypeVar u (QualName qs v) targs) = TypeVar u (QualName (map substitute qs) $ substitute v) $ map substituteInTypeArg targs substituteInScalarType (Prim t) = Prim t substituteInScalarType (Record ts) = Record $ fmap substituteInType ts substituteInScalarType (Sum ts) = Sum $ (fmap . 
fmap) substituteInType ts substituteInScalarType (Arrow als v d1 t1 (RetType dims t2)) = Arrow als v d1 (substituteInType t1) $ RetType dims $ substituteInType t2 substituteInType :: TypeBase Size u -> TypeBase Size u substituteInType (Scalar t) = Scalar $ substituteInScalarType t substituteInType (Array u shape t) = Array u (substituteInShape shape) $ substituteInScalarType t substituteInShape (Shape ds) = Shape $ map (applySubst subst) ds substituteInTypeArg (TypeArgDim e) = TypeArgDim $ applySubst subst e substituteInTypeArg (TypeArgType t) = TypeArgType $ substituteInType t mtyTypeAbbrs :: MTy -> M.Map VName TypeBinding mtyTypeAbbrs (MTy _ mod) = modTypeAbbrs mod modTypeAbbrs :: Mod -> M.Map VName TypeBinding modTypeAbbrs (ModEnv env) = envTypeAbbrs env modTypeAbbrs (ModFun (FunModType _ mod mty)) = modTypeAbbrs mod <> mtyTypeAbbrs mty envTypeAbbrs :: Env -> M.Map VName TypeBinding envTypeAbbrs env = envTypeTable env <> (mconcat . map modTypeAbbrs . M.elems . envModTable) env -- | Refine the given type name in the given env. refineEnv :: SrcLoc -> TySet -> Env -> QualName Name -> [TypeParam] -> StructType -> TypeM (QualName VName, TySet, Env) refineEnv loc tset env tname ps t | Just (tname', TypeAbbr _ cur_ps (RetType _ (Scalar (TypeVar _ (QualName qs v) _)))) <- findTypeDef tname (ModEnv env), QualName (qualQuals tname') v `M.member` tset = if paramsMatch cur_ps ps then pure ( tname', QualName qs v `M.delete` tset, substituteTypesInEnv ( flip M.lookup $ M.fromList [ (qualLeaf tname', Subst cur_ps $ RetType [] t), (v, Subst ps $ RetType [] t) ] ) env ) else typeError loc mempty $ "Cannot refine a type having" <+> tpMsg ps <> " with a type having " <> tpMsg cur_ps <> "." | otherwise = typeError loc mempty $ dquotes (pretty tname) <+> "is not an abstract type in the module type." 
where tpMsg [] = "no type parameters" tpMsg xs = "type parameters" <+> hsep (map pretty xs) paramsMatch :: [TypeParam] -> [TypeParam] -> Bool paramsMatch ps1 ps2 = length ps1 == length ps2 && all match (zip ps1 ps2) where match (TypeParamType l1 _ _, TypeParamType l2 _ _) = l1 <= l2 match (TypeParamDim _ _, TypeParamDim _ _) = True match _ = False findBinding :: (Env -> M.Map VName v) -> Namespace -> Name -> Env -> Maybe (VName, v) findBinding table namespace name the_env = do QualName _ name' <- M.lookup (namespace, name) $ envNameMap the_env (name',) <$> M.lookup name' (table the_env) findTypeDef :: QualName Name -> Mod -> Maybe (QualName VName, TypeBinding) findTypeDef _ ModFun {} = Nothing findTypeDef (QualName [] name) (ModEnv the_env) = do (name', tb) <- findBinding envTypeTable Type name the_env pure (qualName name', tb) findTypeDef (QualName (q : qs) name) (ModEnv the_env) = do (q', q_mod) <- findBinding envModTable Term q the_env (QualName qs' name', tb) <- findTypeDef (QualName qs name) q_mod pure (QualName (q' : qs') name', tb) resolveAbsTypes :: TySet -> Mod -> TySet -> Loc -> Either TypeError (M.Map VName (QualName VName, TypeBinding)) resolveAbsTypes mod_abs mod sig_abs loc = do let abs_mapping = M.fromList $ zip (map (fmap baseName . fst) $ M.toList mod_abs) (M.toList mod_abs) fmap M.fromList . forM (M.toList sig_abs) $ \(name, name_l) -> case findTypeDef (fmap baseName name) mod of Just (name', TypeAbbr mod_l ps t) | mod_l > name_l -> mismatchedLiftedness name_l (map qualLeaf $ M.keys mod_abs) name (mod_l, ps, t) | name_l < SizeLifted, not $ null $ retDims t -> anonymousSizes (map qualLeaf $ M.keys mod_abs) name (mod_l, ps, t) | Just (abs_name, _) <- M.lookup (fmap baseName name) abs_mapping -> pure (qualLeaf name, (abs_name, TypeAbbr name_l ps t)) | otherwise -> pure (qualLeaf name, (name', TypeAbbr name_l ps t)) _ -> missingType loc $ fmap baseName name where mismatchedLiftedness name_l abs name mod_t = Left . 
TypeError (locOf loc) mempty $ "Module defines" indent 2 (ppTypeAbbr abs name mod_t) "but module type requires" <+> what <> "." where what = case name_l of Unlifted -> "a non-lifted type" SizeLifted -> "a size-lifted type" Lifted -> "a lifted type" anonymousSizes abs name mod_t = Left . TypeError (locOf loc) mempty $ "Module defines" indent 2 (ppTypeAbbr abs name mod_t) "which contains anonymous sizes, but module type requires non-lifted type." resolveMTyNames :: MTy -> MTy -> M.Map VName (QualName VName) resolveMTyNames = resolveMTyNames' where resolveMTyNames' (MTy _mod_abs mod) (MTy _sig_abs sig) = resolveModNames mod sig resolveModNames (ModEnv mod_env) (ModEnv sig_env) = resolveEnvNames mod_env sig_env resolveModNames (ModFun mod_fun) (ModFun sig_fun) = resolveModNames (funModTypeMod mod_fun) (funModTypeMod sig_fun) <> resolveMTyNames' (funModTypeMty mod_fun) (funModTypeMty sig_fun) resolveModNames _ _ = mempty resolveEnvNames mod_env sig_env = let mod_substs = resolve Term mod_env $ envModTable sig_env onMod (modname, mod_env_mod) = case M.lookup modname mod_substs of Just (QualName _ modname') | Just sig_env_mod <- M.lookup modname' $ envModTable mod_env -> resolveModNames sig_env_mod mod_env_mod _ -> mempty in mconcat [ resolve Term mod_env $ envVtable sig_env, resolve Type mod_env $ envVtable sig_env, resolve Signature mod_env $ envVtable sig_env, mod_substs, mconcat $ map onMod $ M.toList $ envModTable sig_env ] resolve namespace mod_env = M.mapMaybeWithKey resolve' where resolve' name _ = M.lookup (namespace, baseName name) $ envNameMap mod_env missingType :: (Pretty a) => Loc -> a -> Either TypeError b missingType loc name = Left . TypeError loc mempty $ "Module does not define a type named" <+> pretty name <> "." missingVal :: (Pretty a) => Loc -> a -> Either TypeError b missingVal loc name = Left . TypeError loc mempty $ "Module does not define a value named" <+> pretty name <> "." 
topLevelSize :: Loc -> VName -> Either TypeError b topLevelSize loc name = Left . TypeError loc mempty $ "Type substitution in" <+> dquotes (prettyName name) <+> "results in a top-level size." missingMod :: (Pretty a) => Loc -> a -> Either TypeError b missingMod loc name = Left . TypeError loc mempty $ "Module does not define a module named" <+> pretty name <> "." mismatchedType :: Loc -> [VName] -> [VName] -> VName -> (Liftedness, [TypeParam], StructRetType) -> (Liftedness, [TypeParam], StructRetType) -> Either TypeError b mismatchedType loc abs quals name spec_t env_t = Left . TypeError loc mempty $ "Module defines" indent 2 (ppTypeAbbr abs (QualName quals name) env_t) "but module type requires" indent 2 (ppTypeAbbr abs (QualName quals name) spec_t) ppTypeAbbr :: [VName] -> QualName VName -> (Liftedness, [TypeParam], StructRetType) -> Doc a ppTypeAbbr abs name (l, ps, RetType [] (Scalar (TypeVar _ tn args))) | qualLeaf tn `elem` abs, map typeParamToArg ps == args = "type" <> pretty l <+> pretty name <+> hsep (map pretty ps) ppTypeAbbr _ name (l, ps, t) = "type" <> pretty l <+> hsep (pretty name : map pretty ps) <+> equals <+> nest 2 (align (pretty t)) -- | Return new renamed/abstracted env, as well as a mapping from -- names in the signature to names in the new env. This is used for -- functor application. The first env is the module env, and the -- second the env it must match. matchMTys :: MTy -> MTy -> Loc -> Either TypeError (M.Map VName VName) matchMTys orig_mty orig_mty_sig = matchMTys' (M.map (ExpSubst . flip sizeFromName mempty) $ resolveMTyNames orig_mty orig_mty_sig) [] orig_mty orig_mty_sig where matchMTys' :: M.Map VName (Subst StructRetType) -> [VName] -> MTy -> MTy -> Loc -> Either TypeError (M.Map VName VName) matchMTys' _ _ (MTy _ ModFun {}) (MTy _ ModEnv {}) loc = Left $ TypeError loc mempty "Cannot match parametric module with non-parametric module type." 
matchMTys' _ _ (MTy _ ModEnv {}) (MTy _ ModFun {}) loc = Left $ TypeError loc mempty "Cannot match non-parametric module with paramatric module type." matchMTys' old_abs_subst_to_type quals (MTy mod_abs mod) (MTy sig_abs sig) loc = do -- Check that abstract types in 'sig' have an implementation in -- 'mod'. This also gives us a substitution that we use to check -- the types of values. abs_substs <- resolveAbsTypes mod_abs mod sig_abs loc let abs_subst_to_type = old_abs_subst_to_type <> M.map (substFromAbbr . snd) abs_substs abs_name_substs = M.map (qualLeaf . fst) abs_substs substs <- matchMods abs_subst_to_type quals mod sig loc pure (substs <> abs_name_substs) matchMods :: M.Map VName (Subst StructRetType) -> [VName] -> Mod -> Mod -> Loc -> Either TypeError (M.Map VName VName) matchMods _ _ ModEnv {} ModFun {} loc = Left $ TypeError loc mempty "Cannot match non-parametric module with parametric module type." matchMods _ _ ModFun {} ModEnv {} loc = Left $ TypeError loc mempty "Cannot match parametric module with non-parametric module type." matchMods abs_subst_to_type quals (ModEnv mod) (ModEnv sig) loc = matchEnvs abs_subst_to_type quals mod sig loc matchMods old_abs_subst_to_type quals (ModFun (FunModType mod_abs mod_pmod mod_mod)) (ModFun (FunModType sig_abs sig_pmod sig_mod)) loc = do -- We need to use different substitutions when matching -- parameter and body signatures - this is because the -- concrete parameter must be *at least as* general as the -- ascripted parameter, while the concrete body must be *at -- most as* general as the ascripted body. abs_substs <- resolveAbsTypes mod_abs mod_pmod sig_abs loc p_abs_substs <- resolveAbsTypes sig_abs sig_pmod mod_abs loc let abs_subst_to_type = old_abs_subst_to_type <> M.map (substFromAbbr . snd) abs_substs p_abs_subst_to_type = old_abs_subst_to_type <> M.map (substFromAbbr . snd) p_abs_substs abs_name_substs = M.map (qualLeaf . 
fst) abs_substs pmod_substs <- matchMods p_abs_subst_to_type quals sig_pmod mod_pmod loc mod_substs <- matchMTys' abs_subst_to_type quals mod_mod sig_mod loc pure (pmod_substs <> mod_substs <> abs_name_substs) matchEnvs :: M.Map VName (Subst StructRetType) -> [VName] -> Env -> Env -> Loc -> Either TypeError (M.Map VName VName) matchEnvs abs_subst_to_type quals env sig loc = do -- XXX: we only want to create substitutions for visible names. -- This must be wrong in some cases. Probably we need to -- rethink how we do shadowing for module types. let visible = S.fromList $ map qualLeaf $ M.elems $ envNameMap sig isVisible name = name `S.member` visible -- Check that all type abbreviations are correctly defined. abbr_name_substs <- fmap M.fromList $ forM (filter (isVisible . fst) $ M.toList $ envTypeTable sig) $ \(name, TypeAbbr spec_l spec_ps spec_t) -> case findBinding envTypeTable Type (baseName name) env of Just (name', TypeAbbr l ps t) -> matchTypeAbbr loc abs_subst_to_type quals name spec_l spec_ps spec_t name' l ps t Nothing -> missingType loc $ baseName name -- Check that all values are defined correctly, substituting the -- abstract types first. val_substs <- fmap M.fromList $ forM (M.toList $ envVtable sig) $ \(name, spec_bv) -> do let (spec_dims, spec_bv') = substituteTypesInBoundV (`M.lookup` abs_subst_to_type) spec_bv (spec_witnesses, _) = determineSizeWitnesses $ boundValType spec_bv' -- The hacky check for #2120. when (any (`S.member` spec_witnesses) spec_dims) $ topLevelSize loc name case findBinding envVtable Term (baseName name) env of Just (name', bv) -> matchVal loc quals name spec_bv' name' bv _ -> missingVal loc (baseName name) -- Check for correct modules. 
mod_substs <- fmap M.unions $ forM (M.toList $ envModTable sig) $ \(name, modspec) -> case findBinding envModTable Term (baseName name) env of Just (name', mod) -> M.insert name name' <$> matchMods abs_subst_to_type (quals ++ [name]) mod modspec loc Nothing -> missingMod loc $ baseName name pure $ val_substs <> mod_substs <> abbr_name_substs matchTypeAbbr :: Loc -> M.Map VName (Subst StructRetType) -> [VName] -> VName -> Liftedness -> [TypeParam] -> StructRetType -> VName -> Liftedness -> [TypeParam] -> StructRetType -> Either TypeError (VName, VName) matchTypeAbbr loc abs_subst_to_type quals spec_name spec_l spec_ps spec_t name l ps t = do -- Number of type parameters must match. unless (length spec_ps == length ps) $ nomatch spec_t -- Abstract types have a particular restriction to ensure that -- if we have a value of an abstract type 't [n]', then there is -- an array of size 'n' somewhere inside. when (M.member spec_name abs_subst_to_type) $ case filter (`S.notMember` fst (determineSizeWitnesses (retType t))) (map typeParamName $ filter isSizeParam ps) of [] -> pure () d : _ -> Left . TypeError loc mempty $ "Type" indent 2 (ppTypeAbbr [] (QualName quals name) (l, ps, t)) textwrap "cannot be made abstract because size parameter" indent 2 (prettyName d) textwrap "is not used constructively as an array size in the definition." 
let spec_t' = applySubst (`M.lookup` abs_subst_to_type) spec_t nonrigid = ps <> map (`TypeParamDim` mempty) (retDims t) case doUnification loc spec_ps nonrigid (retType spec_t') (retType t) of Right _ -> pure (spec_name, name) _ -> nomatch spec_t' where nomatch spec_t' = mismatchedType loc (M.keys abs_subst_to_type) quals spec_name (spec_l, spec_ps, spec_t') (l, ps, t) matchVal :: Loc -> [VName] -> VName -> BoundV -> VName -> BoundV -> Either TypeError (VName, VName) matchVal loc quals spec_name spec_v name v = case matchValBinding loc spec_v v of Nothing -> pure (spec_name, name) Just problem -> Left $ TypeError loc mempty $ "Module type specifies" indent 2 (ppValBind (QualName quals spec_name) spec_v) "but module provides" indent 2 (ppValBind (QualName quals spec_name) v) problem matchValBinding :: Loc -> BoundV -> BoundV -> Maybe (Doc ()) matchValBinding loc (BoundV spec_tps orig_spec_t) (BoundV tps orig_t) = do case doUnification loc spec_tps tps (toStruct orig_spec_t) (toStruct orig_t) of Left (TypeError _ notes msg) -> Just $ msg <> pretty notes Right _ -> Nothing ppValBind v (BoundV tps t) = "val" <+> pretty v <+> hsep (map pretty tps) <+> colon indent 2 (align (pretty t)) -- | Apply a parametric module to an argument. applyFunctor :: Loc -> FunModType -> MTy -> TypeM ( MTy, M.Map VName VName, M.Map VName VName ) applyFunctor applyloc (FunModType p_abs p_mod body_mty) a_mty = do p_subst <- badOnLeft $ matchMTys a_mty (MTy p_abs p_mod) applyloc -- Apply type abbreviations from a_mty to body_mty. 
let a_abbrs = mtyTypeAbbrs a_mty isSub v = case M.lookup v a_abbrs of Just abbr -> Just $ substFromAbbr abbr _ -> Just $ ExpSubst $ sizeFromName (qualName v) mempty type_subst = M.mapMaybe isSub p_subst body_mty' = substituteTypesInMTy (`M.lookup` type_subst) body_mty (body_mty'', body_subst) <- newNamesForMTy body_mty' pure (body_mty'', p_subst, body_subst) futhark-0.25.27/src/Language/Futhark/TypeChecker/Monad.hs000066400000000000000000000452741475065116200230740ustar00rootroot00000000000000-- | Main monad in which the type checker runs, as well as ancillary -- data definitions. module Language.Futhark.TypeChecker.Monad ( TypeM, runTypeM, askEnv, askImportName, atTopLevel, enteringModule, bindSpaced, bindSpaced1, bindIdents, qualifyTypeVars, lookupMTy, lookupImport, lookupMod, localEnv, TypeError (..), prettyTypeError, prettyTypeErrorNoLoc, withIndexLink, unappliedFunctor, unknownVariable, underscoreUse, Notes, aNote, MonadTypeChecker (..), TypeState (stateNameSource), usedName, checkName, checkAttr, checkQualName, checkValName, badOnLeft, isKnownType, module Language.Futhark.Warnings, Env (..), TySet, FunModType (..), ImportTable, NameMap, BoundV (..), Mod (..), TypeBinding (..), MTy (..), anySignedType, anyUnsignedType, anyIntType, anyFloatType, anyNumberType, anyPrimType, Namespace (..), intrinsicsNameMap, topLevelNameMap, mkTypeVarName, ) where import Control.Monad import Control.Monad.Except import Control.Monad.Identity import Control.Monad.Reader import Control.Monad.State.Strict import Data.Either import Data.List (find) import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S import Data.Text qualified as T import Data.Version qualified as Version import Futhark.FreshNames hiding (newName) import Futhark.FreshNames qualified import Futhark.Util.Pretty hiding (space) import Language.Futhark import Language.Futhark.Semantic import Language.Futhark.Traversals import Language.Futhark.Warnings import Paths_futhark qualified 
import Prelude hiding (mapM, mod) newtype Note = Note (Doc ()) -- | A collection of extra information regarding a type error. newtype Notes = Notes [Note] deriving (Semigroup, Monoid) instance Pretty Note where pretty (Note msg) = unAnnotate $ "Note:" <+> align msg instance Pretty Notes where pretty (Notes notes) = unAnnotate $ foldMap (((line <> line) <>) . pretty) notes -- | A single note. aNote :: Doc () -> Notes aNote = Notes . pure . Note -- | Information about an error during type checking. data TypeError = TypeError Loc Notes (Doc ()) -- | Prettyprint type error. prettyTypeError :: TypeError -> Doc AnsiStyle prettyTypeError (TypeError loc notes msg) = annotate (bold <> color Red) ("Error at " <> pretty (locText (srclocOf loc)) <> ":") prettyTypeErrorNoLoc (TypeError loc notes msg) -- | Prettyprint type error, without location information. This can -- be used for cases where the location is printed in some other way. prettyTypeErrorNoLoc :: TypeError -> Doc AnsiStyle prettyTypeErrorNoLoc (TypeError _ notes msg) = unAnnotate msg <> pretty notes <> hardline errorIndexUrl :: Doc a errorIndexUrl = version_url <> "error-index.html" where version = Paths_futhark.version base_url = "https://futhark.readthedocs.io/en/" version_url | last (Version.versionBranch version) == 0 = base_url <> "latest/" | otherwise = base_url <> "v" <> pretty (Version.showVersion version) <> "/" -- | Attach a reference to documentation explaining the error in more detail. withIndexLink :: Doc a -> Doc a -> Doc a withIndexLink href msg = stack [ msg, "\nFor more information, see:", indent 2 (errorIndexUrl <> "#" <> href) ] -- | An unexpected functor appeared! unappliedFunctor :: (MonadTypeChecker m) => SrcLoc -> m a unappliedFunctor loc = typeError loc mempty "Cannot have parametric module here." -- | An unknown variable was referenced. 
unknownVariable :: (MonadTypeChecker m) => Namespace -> QualName Name -> SrcLoc -> m a unknownVariable space name loc = typeError loc mempty $ "Unknown" <+> pretty space <+> dquotes (pretty name) -- | A name prefixed with an underscore was used. underscoreUse :: (MonadTypeChecker m) => SrcLoc -> QualName Name -> m a underscoreUse loc name = typeError loc mempty $ "Use of" <+> dquotes (pretty name) <> ": variables prefixed with underscore may not be accessed." -- | A mapping from import import names to 'Env's. This is used to -- resolve @import@ declarations. type ImportTable = M.Map ImportName Env data Context = Context { contextEnv :: Env, contextImportTable :: ImportTable, contextImportName :: ImportName, -- | Currently type-checking at the top level? If false, we are -- inside a module. contextAtTopLevel :: Bool } data TypeState = TypeState { stateNameSource :: VNameSource, stateWarnings :: Warnings, -- | Which names have been used. stateUsed :: S.Set VName, stateCounter :: Int } -- | The type checker runs in this monad. newtype TypeM a = TypeM ( ReaderT Context (StateT TypeState (Except (Warnings, TypeError))) a ) deriving ( Monad, Functor, Applicative, MonadReader Context, MonadState TypeState ) instance MonadError TypeError TypeM where throwError e = TypeM $ do ws <- gets stateWarnings throwError (ws, e) catchError (TypeM m) f = TypeM $ m `catchError` f' where f' (_, e) = let TypeM m' = f e in m' -- | Run a 'TypeM' computation. runTypeM :: Env -> ImportTable -> ImportName -> VNameSource -> TypeM a -> (Warnings, Either TypeError (a, VNameSource)) runTypeM env imports fpath src (TypeM m) = do let ctx = Context env imports fpath True s = TypeState src mempty mempty 0 case runExcept $ runStateT (runReaderT m ctx) s of Left (ws, e) -> (ws, Left e) Right (x, s') -> (stateWarnings s', Right (x, stateNameSource s')) -- | Retrieve the current 'Env'. askEnv :: TypeM Env askEnv = asks contextEnv -- | The name of the current file/import. 
askImportName :: TypeM ImportName askImportName = asks contextImportName -- | Are we type-checking at the top level, or are we inside a nested -- module? atTopLevel :: TypeM Bool atTopLevel = asks contextAtTopLevel -- | We are now going to type-check the body of a module. enteringModule :: TypeM a -> TypeM a enteringModule = local $ \ctx -> ctx {contextAtTopLevel = False} -- | Look up a module type. lookupMTy :: SrcLoc -> QualName Name -> TypeM (QualName VName, MTy) lookupMTy loc qn = do (scope, qn'@(QualName _ name)) <- checkQualNameWithEnv Signature qn loc (qn',) <$> maybe explode pure (M.lookup name $ envModTypeTable scope) where explode = unknownVariable Signature qn loc -- | Look up an import. lookupImport :: SrcLoc -> FilePath -> TypeM (ImportName, Env) lookupImport loc file = do imports <- asks contextImportTable my_path <- asks contextImportName let canonical_import = mkImportFrom my_path file case M.lookup canonical_import imports of Nothing -> typeError loc mempty $ "Unknown import" <+> dquotes (pretty (includeToText canonical_import)) "Known:" <+> commasep (map (pretty . includeToText) (M.keys imports)) Just scope -> pure (canonical_import, scope) -- | Evaluate a 'TypeM' computation within an extended (/not/ -- replaced) environment. localEnv :: Env -> TypeM a -> TypeM a localEnv env = local $ \ctx -> let env' = env <> contextEnv ctx in ctx {contextEnv = env'} incCounter :: TypeM Int incCounter = do s <- get put s {stateCounter = stateCounter s + 1} pure $ stateCounter s bindNameMap :: NameMap -> TypeM a -> TypeM a bindNameMap m = local $ \ctx -> let env = contextEnv ctx in ctx {contextEnv = env {envNameMap = m <> envNameMap env}} -- | Monads that support type checking. The reason we have this -- internal interface is because we use distinct monads for checking -- expressions and declarations. 
class (Monad m) => MonadTypeChecker m where warn :: (Located loc) => loc -> Doc () -> m () warnings :: Warnings -> m () newName :: VName -> m VName newID :: Name -> m VName newID s = newName $ VName s 0 newTypeName :: Name -> m VName bindVal :: VName -> BoundV -> m a -> m a lookupType :: QualName VName -> m ([TypeParam], StructRetType, Liftedness) typeError :: (Located loc) => loc -> Notes -> Doc () -> m a warnIfUnused :: (Namespace, VName, SrcLoc) -> TypeM () warnIfUnused (ns, name, loc) = do used <- gets stateUsed unless (name `S.member` used || "_" `T.isPrefixOf` nameToText (baseName name)) $ warn loc $ "Unused" <+> pretty ns <+> dquotes (prettyName name) <> "." -- | Map source-level names to fresh unique internal names, and -- evaluate a type checker context with the mapping active. bindSpaced :: [(Namespace, Name, SrcLoc)] -> ([VName] -> TypeM a) -> TypeM a bindSpaced names body = do names' <- mapM (\(_, v, _) -> newID v) names let mapping = M.fromList $ zip (map (\(ns, v, _) -> (ns, v)) names) $ map qualName names' bindNameMap mapping (body names') <* mapM_ warnIfUnused [(ns, v, loc) | ((ns, _, loc), v) <- zip names names'] -- | Map single source-level name to fresh unique internal names, and -- evaluate a type checker context with the mapping active. bindSpaced1 :: Namespace -> Name -> SrcLoc -> (VName -> TypeM a) -> TypeM a bindSpaced1 ns name loc body = do name' <- newID name let mapping = M.singleton (ns, name) $ qualName name' bindNameMap mapping (body name') <* warnIfUnused (ns, name', loc) -- | Bind these identifiers in the name map and also check whether -- they have been used. bindIdents :: [IdentBase NoInfo VName t] -> TypeM a -> TypeM a bindIdents idents body = do let mapping = M.fromList $ zip (map ((Term,) . (baseName . identName)) idents) (map (qualName . identName) idents) bindNameMap mapping body <* mapM_ warnIfUnused [(Term, v, loc) | Ident v _ loc <- idents] -- | Indicate that this name has been used. 
This is usually done -- implicitly by other operations, but sometimes we want to make a -- "fake" use to avoid things like top level functions being -- considered unused. usedName :: VName -> TypeM () usedName name = modify $ \s -> s {stateUsed = S.insert name $ stateUsed s} instance MonadTypeChecker TypeM where warnings ws = modify $ \s -> s {stateWarnings = stateWarnings s <> ws} warn loc problem = warnings $ singleWarning (locOf loc) problem newName v = do s <- get let (v', src') = Futhark.FreshNames.newName (stateNameSource s) v put $ s {stateNameSource = src'} pure v' newTypeName name = do i <- incCounter newID $ mkTypeVarName name i bindVal v t = local $ \ctx -> ctx { contextEnv = (contextEnv ctx) { envVtable = M.insert v t $ envVtable $ contextEnv ctx } } lookupType qn = do outer_env <- askEnv scope <- lookupQualNameEnv qn case M.lookup (qualLeaf qn) $ envTypeTable scope of Nothing -> error $ "lookupType: " <> show qn Just (TypeAbbr l ps (RetType dims def)) -> pure (ps, RetType dims $ qualifyTypeVars outer_env mempty (qualQuals qn) def, l) typeError loc notes s = throwError $ TypeError (locOf loc) notes s lookupQualNameEnv :: QualName VName -> TypeM Env lookupQualNameEnv qn@(QualName quals _) = do env <- askEnv descend env quals where descend scope [] = pure scope descend scope (q : qs) | Just (ModEnv q_scope) <- M.lookup q $ envModTable scope = descend q_scope qs | otherwise = error $ "lookupQualNameEnv: " ++ show qn checkQualNameWithEnv :: Namespace -> QualName Name -> SrcLoc -> TypeM (Env, QualName VName) checkQualNameWithEnv space qn@(QualName quals name) loc = do env <- askEnv descend env quals where descend scope [] | Just name' <- M.lookup (space, name) $ envNameMap scope = do usedName $ qualLeaf name' pure (scope, name') | otherwise = unknownVariable space qn loc descend scope (q : qs) | Just (QualName _ q') <- M.lookup (Term, q) $ envNameMap scope, Just res <- M.lookup q' $ envModTable scope = do usedName q' case res of ModEnv q_scope -> do (scope', 
QualName qs' name') <- descend q_scope qs pure (scope', QualName (q' : qs') name') ModFun {} -> unappliedFunctor loc | otherwise = unknownVariable space qn loc -- | Elaborate the given qualified name in the given namespace at the -- given location, producing the corresponding unique 'QualName'. -- Fails if the name is a module. checkValName :: QualName Name -> SrcLoc -> TypeM (QualName VName) checkValName name loc = do (env, name') <- checkQualNameWithEnv Term name loc case M.lookup (qualLeaf name') $ envModTable env of Just _ -> unknownVariable Term name loc Nothing -> pure name' -- | Elaborate the given qualified name in the given namespace at the -- given location, producing the corresponding unique 'QualName'. checkQualName :: Namespace -> QualName Name -> SrcLoc -> TypeM (QualName VName) checkQualName space name loc = snd <$> checkQualNameWithEnv space name loc -- | Elaborate the given name in the given namespace at the given -- location, producing the corresponding unique 'VName'. checkName :: Namespace -> Name -> SrcLoc -> TypeM VName checkName space name loc = qualLeaf <$> checkQualName space (qualName name) loc -- | Does a type with this name already exist? This is used for -- warnings, so it is OK it's a little unprincipled. 
isKnownType :: QualName VName -> TypeM Bool isKnownType qn = do env <- askEnv descend env (qualQuals qn) (qualLeaf qn) where descend env [] v | Just v' <- M.lookup (Type, baseName v) $ envNameMap env = pure $ M.member (qualLeaf v') $ envTypeTable env descend env (q : qs) v | Just q' <- M.lookup (Term, baseName q) $ envNameMap env, Just (ModEnv env') <- M.lookup (qualLeaf q') $ envModTable env = descend env' qs v descend _ _ _ = pure False lookupMod :: SrcLoc -> QualName Name -> TypeM (QualName VName, Mod) lookupMod loc qn = do (scope, qn'@(QualName _ name)) <- checkQualNameWithEnv Term qn loc case M.lookup name $ envModTable scope of Nothing -> unknownVariable Term qn loc Just m -> pure (qn', m) -- | Try to prepend qualifiers to the type names such that they -- represent how to access the type in some scope. qualifyTypeVars :: Env -> [VName] -> [VName] -> TypeBase Size as -> TypeBase Size as qualifyTypeVars outer_env orig_except ref_qs = onType (S.fromList orig_except) where onType :: S.Set VName -> TypeBase Size as -> TypeBase Size as onType except (Array u shape et) = Array u (fmap (onDim except) shape) (onScalar except et) onType except (Scalar t) = Scalar $ onScalar except t onScalar _ (Prim t) = Prim t onScalar except (TypeVar u qn targs) = TypeVar u (qual except qn) (map (onTypeArg except) targs) onScalar except (Record m) = Record $ M.map (onType except) m onScalar except (Sum m) = Sum $ M.map (map $ onType except) m onScalar except (Arrow as p d t1 (RetType dims t2)) = Arrow as p d (onType except' t1) $ RetType dims (onType except' t2) where except' = case p of Named p' -> S.insert p' except Unnamed -> except onTypeArg except (TypeArgDim d) = TypeArgDim $ onDim except d onTypeArg except (TypeArgType t) = TypeArgType $ onType except t onDim except e = runIdentity $ onDimM except e onDimM except (Var qn typ loc) = pure $ Var (qual except qn) typ loc onDimM except e = astMap (identityMapper {mapOnExp = onDimM except}) e qual except (QualName orig_qs name) | 
name `elem` except || reachable orig_qs name outer_env = QualName orig_qs name | otherwise = prependAsNecessary [] ref_qs $ QualName orig_qs name prependAsNecessary qs rem_qs (QualName orig_qs name) | reachable (qs ++ orig_qs) name outer_env = QualName (qs ++ orig_qs) name | otherwise = case rem_qs of q : rem_qs' -> prependAsNecessary (qs ++ [q]) rem_qs' (QualName orig_qs name) [] -> QualName orig_qs name reachable [] name env = name `M.member` envVtable env || isJust (find matches $ M.elems (envTypeTable env)) where matches (TypeAbbr _ _ (RetType _ (Scalar (TypeVar _ (QualName x_qs name') _)))) = null x_qs && name == name' matches _ = False reachable (q : qs') name env | Just (ModEnv env') <- M.lookup q $ envModTable env = reachable qs' name env' | otherwise = False -- | Turn a 'Left' 'TypeError' into an actual error. badOnLeft :: Either TypeError a -> TypeM a badOnLeft = either throwError pure -- | All signed integer types. anySignedType :: [PrimType] anySignedType = map Signed [minBound .. maxBound] -- | All unsigned integer types. anyUnsignedType :: [PrimType] anyUnsignedType = map Unsigned [minBound .. maxBound] -- | All integer types. anyIntType :: [PrimType] anyIntType = anySignedType ++ anyUnsignedType -- | All floating-point types. anyFloatType :: [PrimType] anyFloatType = map FloatType [minBound .. maxBound] -- | All number types. anyNumberType :: [PrimType] anyNumberType = anyIntType ++ anyFloatType -- | All primitive types. anyPrimType :: [PrimType] anyPrimType = Bool : anyIntType ++ anyFloatType --- Name handling -- | The 'NameMap' corresponding to the intrinsics module. intrinsicsNameMap :: NameMap intrinsicsNameMap = M.fromList $ map mapping $ M.toList intrinsics where mapping (v, IntrinsicType {}) = ((Type, baseName v), QualName [] v) mapping (v, _) = ((Term, baseName v), QualName [] v) -- | The names that are available in the initial environment. 
topLevelNameMap :: NameMap topLevelNameMap = M.filterWithKey (\k _ -> available k) intrinsicsNameMap where available :: (Namespace, Name) -> Bool available (Type, _) = True available (Term, v) = v `S.member` (type_names <> binop_names <> fun_names) where type_names = S.fromList $ map (nameFromText . prettyText) anyPrimType binop_names = S.fromList $ map (nameFromText . prettyText) [minBound .. (maxBound :: BinOp)] fun_names = S.fromList [nameFromString "shape"] available _ = False -- | Construct the name of a new type variable given a base -- description and a tag number (note that this is distinct from -- actually constructing a VName; the tag here is intended for human -- consumption but the machine does not care). mkTypeVarName :: Name -> Int -> Name mkTypeVarName desc i = desc <> nameFromString (mapMaybe subscript (show i)) where subscript = flip lookup $ zip "0123456789" "₀₁₂₃₄₅₆₇₈₉" -- | Type-check an attribute. checkAttr :: (MonadTypeChecker m) => AttrInfo VName -> m (AttrInfo VName) checkAttr (AttrComp f attrs loc) = AttrComp f <$> mapM checkAttr attrs <*> pure loc checkAttr (AttrAtom (AtomName v) loc) = pure $ AttrAtom (AtomName v) loc checkAttr (AttrAtom (AtomInt x) loc) = pure $ AttrAtom (AtomInt x) loc futhark-0.25.27/src/Language/Futhark/TypeChecker/Names.hs000066400000000000000000000457351475065116200231030ustar00rootroot00000000000000-- | Resolve names. -- -- This also performs a small amount of rewriting; specifically -- turning 'Var's with qualified names into 'Project's, based on -- whether they are referencing a module or not. -- -- Also checks for other name-related problems, such as duplicate -- names. 
module Language.Futhark.TypeChecker.Names
  ( resolveValBind,
    resolveTypeParams,
    resolveTypeExp,
    resolveExp,
  )
where

import Control.Monad
import Control.Monad.Except
import Control.Monad.State
import Data.List qualified as L
import Data.Map qualified as M
import Data.Text qualified as T
import Futhark.Util.Pretty
import Language.Futhark
import Language.Futhark.Semantic (includeToFilePath)
import Language.Futhark.TypeChecker.Monad
import Prelude hiding (mod)

-- | Names that may not be shadowed.
doNotShadow :: [Name]
doNotShadow = ["&&", "||"]

-- | Raise a type error (linked to the @may-not-be-redefined@ index
-- entry) if the given name is one of 'doNotShadow'.
checkDoNotShadow :: (Located a) => a -> Name -> TypeM ()
checkDoNotShadow loc v =
  when (v `elem` doNotShadow) $
    typeError loc mempty . withIndexLink "may-not-be-redefined" $
      "The" <+> prettyName v <+> "operator may not be redefined."

-- | Check whether the type contains arrow types that define the same
-- parameter.  These might also exist further down, but that's not
-- really a problem - we mostly do this checking to help the user,
-- since it is likely an error, but it's easy to assign a semantics to
-- it (normal name shadowing).
checkForDuplicateNamesInType ::
  TypeExp (ExpBase NoInfo Name) Name ->
  TypeM ()
checkForDuplicateNamesInType = check mempty
  where
    -- Report that @v@ (at @loc@) was already bound at @prev_loc@.
    bad v loc prev_loc =
      typeError loc mempty $
        "Name"
          <+> dquotes (pretty v)
          <+> "also bound at"
          <+> pretty (locStr prev_loc)
          <> "."

    -- @seen@ maps every name bound by an enclosing named arrow or
    -- 'TEDim' size binder to the location where it was bound.  A named
    -- arrow's parameter is added before checking both its domain and
    -- its codomain.
    check seen (TEArrow (Just v) t1 t2 loc)
      | Just prev_loc <- M.lookup v seen =
          bad v loc prev_loc
      | otherwise =
          check seen' t1 >> check seen' t2
      where
        seen' = M.insert v loc seen
    check seen (TEArrow Nothing t1 t2 _) =
      check seen t1 >> check seen t2
    check seen (TETuple ts _) =
      mapM_ (check seen) ts
    check seen (TERecord fs _) =
      mapM_ (check seen . snd) fs
    check seen (TEUnique t _) =
      check seen t
    check seen (TESum cs _) =
      mapM_ (mapM (check seen) . snd) cs
    check seen (TEApply t1 (TypeArgExpType t2) _) =
      check seen t1 >> check seen t2
    check seen (TEApply t1 TypeArgExpSize {} _) =
      check seen t1
    -- Size names in a 'TEDim' are added one at a time, so a repeated
    -- name within the same binder list is also reported.
    check seen (TEDim (v : vs) t loc)
      | Just prev_loc <- M.lookup v seen =
          bad v loc prev_loc
      | otherwise =
          check (M.insert v loc seen) (TEDim vs t loc)
    check seen (TEDim [] t _) =
      check seen t
    -- Array and variable types are not searched any further.
    check _ TEArray {} = pure ()
    check _ TEVar {} = pure ()
    check seen (TEParens te _) =
      check seen te

-- | Check for duplication of names inside a binding group.  Type
-- parameters and pattern variables are tracked in a single map keyed
-- by namespace and name; size parameters and pattern variables both
-- live in the 'Term' namespace, so they clash with each other.
checkForDuplicateNames ::
  (MonadTypeChecker m) => [UncheckedTypeParam] -> [UncheckedPat t] -> m ()
checkForDuplicateNames tps pats = (`evalStateT` mempty) $ do
  mapM_ checkTypeParam tps
  mapM_ checkPat pats
  where
    checkTypeParam (TypeParamType _ v loc) = seen Type v loc
    checkTypeParam (TypeParamDim v loc) = seen Term v loc

    -- Visit every variable bound by a pattern.
    checkPat (Id v _ loc) = seen Term v loc
    checkPat (PatParens p _) = checkPat p
    checkPat (PatAttr _ p _) = checkPat p
    checkPat Wildcard {} = pure ()
    checkPat (TuplePat ps _) = mapM_ checkPat ps
    checkPat (RecordPat fs _) = mapM_ (checkPat . snd) fs
    checkPat (PatAscription p _ _) = checkPat p
    checkPat PatLit {} = pure ()
    checkPat (PatConstr _ _ ps _) = mapM_ checkPat ps

    -- Record a binding, or fail if the same (namespace, name) pair
    -- was already recorded.
    seen ns v loc = do
      already <- gets $ M.lookup (ns, v)
      case already of
        Just prev_loc ->
          lift $
            typeError loc mempty $
              "Name"
                <+> dquotes (pretty v)
                <+> "also bound at"
                <+> pretty (locStr prev_loc)
                <> "."
        Nothing ->
          modify $ M.insert (ns, v) loc

-- | Resolve a possibly-qualified value name with 'checkValName'.  If
-- the resolved name is qualified by a name whose tag is at most
-- 'maxIntrinsicTag' (presumably the intrinsics module - confirm) and
-- the current file is not a builtin, a warning is emitted.
resolveQualName :: QualName Name -> SrcLoc -> TypeM (QualName VName)
resolveQualName v loc = do
  v' <- checkValName v loc
  case v' of
    QualName (q : _) _
      | baseTag q <= maxIntrinsicTag -> do
          me <- askImportName
          unless (isBuiltin (includeToFilePath me)) $
            warn loc "Using intrinsic functions directly can easily crash the compiler or result in wrong code generation."
    _ -> pure ()
  pure v'

-- | Resolve an unqualified value name, returning only its 'VName'.
resolveName :: Name -> SrcLoc -> TypeM VName
resolveName v loc = qualLeaf <$> resolveQualName (qualName v) loc

-- | Attribute atoms contain no names to look up; they are rebuilt
-- unchanged at the resolved type.
resolveAttrAtom :: AttrAtom Name -> TypeM (AttrAtom VName)
resolveAttrAtom (AtomName v) = pure $ AtomName v
resolveAttrAtom (AtomInt x) = pure $ AtomInt x

-- | Rebuild an attribute (recursively, for composite attributes) at
-- the resolved type.
resolveAttrInfo :: AttrInfo Name -> TypeM (AttrInfo VName)
resolveAttrInfo (AttrAtom atom loc) =
  AttrAtom <$> resolveAttrAtom atom <*> pure loc
resolveAttrInfo (AttrComp name infos loc) =
  AttrComp name <$> mapM resolveAttrInfo infos <*> pure loc

-- | Resolve names in a size expression; anonymous sizes are kept as-is.
resolveSizeExp :: SizeExp (ExpBase NoInfo Name) -> TypeM (SizeExp (ExpBase NoInfo VName))
resolveSizeExp (SizeExpAny loc) = pure $ SizeExpAny loc
resolveSizeExp (SizeExp e loc) = SizeExp <$> resolveExp e <*> pure loc

-- | Resolve names in a single type expression.
--
-- Duplicate parameter names are rejected first with
-- 'checkForDuplicateNamesInType'.  Named arrow parameters and 'TEDim'
-- size names are bound while resolving the types they scope over; a
-- named arrow parameter is also marked with 'usedName' (presumably so
-- it is not reported as unused - confirm).
resolveTypeExp ::
  TypeExp (ExpBase NoInfo Name) Name ->
  TypeM (TypeExp (ExpBase NoInfo VName) VName)
resolveTypeExp orig = checkForDuplicateNamesInType orig >> f orig
  where
    f (TEVar v loc) =
      TEVar <$> checkQualName Type v loc <*> pure loc
    f (TEParens te loc) =
      TEParens <$> f te <*> pure loc
    f (TETuple tes loc) =
      TETuple <$> mapM f tes <*> pure loc
    f (TERecord fs loc) =
      TERecord <$> mapM (traverse f) fs <*> pure loc
    f (TEUnique te loc) =
      TEUnique <$> f te <*> pure loc
    f (TEApply te1 args loc) =
      TEApply <$> f te1 <*> onArg args <*> pure loc
      where
        onArg (TypeArgExpSize size) = TypeArgExpSize <$> resolveSizeExp size
        onArg (TypeArgExpType te) = TypeArgExpType <$> f te
    f (TEArrow Nothing te1 te2 loc) =
      TEArrow Nothing <$> f te1 <*> f te2 <*> pure loc
    -- The parameter name is in scope in both the domain and codomain.
    f (TEArrow (Just v) te1 te2 loc) =
      bindSpaced1 Term v loc $ \v' -> do
        usedName v'
        TEArrow (Just v') <$> f te1 <*> f te2 <*> pure loc
    f (TESum cs loc) =
      TESum <$> mapM (traverse $ mapM f) cs <*> pure loc
    f (TEDim vs te loc) =
      bindSpaced (map (Term,,loc) vs) $ \vs' ->
        TEDim vs' <$> f te <*> pure loc
    f (TEArray size te loc) =
      TEArray <$> resolveSizeExp size <*> f te <*> pure loc

-- | Resolve names in a single expression.
resolveExp :: ExpBase NoInfo Name -> TypeM (ExpBase NoInfo VName)
--
-- First all the trivial cases.
resolveExp (Literal x loc) = pure $ Literal x loc
resolveExp (IntLit x NoInfo loc) = pure $ IntLit x NoInfo loc
resolveExp (FloatLit x NoInfo loc) = pure $ FloatLit x NoInfo loc
resolveExp (StringLit x loc) = pure $ StringLit x loc
resolveExp (Hole NoInfo loc) = pure $ Hole NoInfo loc
--
-- The main interesting cases (except for the ones in AppExp).
resolveExp (Var qn NoInfo loc) = do
  -- The qualifiers of a variable are divided into two parts: first a
  -- possibly-empty sequence of module qualifiers, followed by a
  -- possibly-empty sequence of record field accesses.  We use scope
  -- information to perform the split, by taking qualifiers off the
  -- end until we find something that is not a module.
  (qn', fields) <- findRootVar (qualQuals qn) (qualLeaf qn)
  when ("_" `T.isPrefixOf` nameToText (qualLeaf qn)) $
    underscoreUse loc qn
  pure $ L.foldl' project (Var qn' NoInfo loc) fields
  where
    findRootVar qs name =
      (whenFound <$> resolveQualName (QualName qs name) loc)
        `catchError` notFound qs name

    whenFound qn' = (qn', [])

    -- If no prefix resolves, the error for the full name is reported.
    notFound qs name err
      | null qs = throwError err
      | otherwise = do
          (qn', fields) <-
            findRootVar (init qs) (last qs) `catchError` const (throwError err)
          pure (qn', fields ++ [name])

    project e k = Project k e NoInfo loc
--
resolveExp (Lambda params body ret NoInfo loc) = do
  checkForDuplicateNames [] params
  resolveParams params $ \params' -> do
    body' <- resolveExp body
    ret' <- traverse resolveTypeExp ret
    pure $ Lambda params' body' ret' NoInfo loc
--
resolveExp (QualParens (modname, modnameloc) e loc) = do
  (modname', mod) <- lookupMod loc modname
  case mod of
    ModEnv env -> localEnv (qualifyEnv modname' env) $ do
      e' <- resolveExp e
      pure $ QualParens (modname', modnameloc) e' loc
    ModFun {} ->
      -- '<+>' already inserts the separating space.
      typeError loc mempty . withIndexLink "module-is-parametric" $
        "Module" <+> pretty modname <+> "is a parametric module."
  where
    -- Names brought into scope by the opened module are qualified by
    -- the module's own qualified name.
    qualifyEnv modname' env =
      env {envNameMap = qualify' modname' <$> envNameMap env}
    qualify' modname' (QualName qs name) =
      QualName (qualQuals modname' ++ [qualLeaf modname'] ++ qs) name
--
-- The tedious recursive cases.
resolveExp (Parens e loc) =
  Parens <$> resolveExp e <*> pure loc
resolveExp (Attr attr e loc) =
  Attr <$> resolveAttrInfo attr <*> resolveExp e <*> pure loc
resolveExp (TupLit es loc) =
  TupLit <$> mapM resolveExp es <*> pure loc
resolveExp (ArrayVal vs t loc) =
  pure $ ArrayVal vs t loc
resolveExp (ArrayLit es NoInfo loc) =
  ArrayLit <$> mapM resolveExp es <*> pure NoInfo <*> pure loc
resolveExp (Negate e loc) =
  Negate <$> resolveExp e <*> pure loc
resolveExp (Not e loc) =
  Not <$> resolveExp e <*> pure loc
resolveExp (Assert e1 e2 NoInfo loc) =
  Assert <$> resolveExp e1 <*> resolveExp e2 <*> pure NoInfo <*> pure loc
resolveExp (RecordLit fs loc) =
  RecordLit <$> mapM resolveField fs <*> pure loc
  where
    resolveField (RecordFieldExplicit k e floc) =
      RecordFieldExplicit k <$> resolveExp e <*> pure floc
    resolveField (RecordFieldImplicit (L vnloc vn) NoInfo floc) =
      RecordFieldImplicit
        <$> (L vnloc <$> resolveName vn floc)
        <*> pure NoInfo
        <*> pure floc
resolveExp (Project k e NoInfo loc) =
  Project k <$> resolveExp e <*> pure NoInfo <*> pure loc
resolveExp (Constr k es NoInfo loc) =
  Constr k <$> mapM resolveExp es <*> pure NoInfo <*> pure loc
resolveExp (Update e1 slice e2 loc) =
  Update <$> resolveExp e1 <*> resolveSlice slice <*> resolveExp e2 <*> pure loc
resolveExp (RecordUpdate e1 fs e2 NoInfo loc) =
  RecordUpdate
    <$> resolveExp e1
    <*> pure fs
    <*> resolveExp e2
    <*> pure NoInfo
    <*> pure loc
resolveExp (OpSection v NoInfo loc) =
  OpSection <$> resolveQualName v loc <*> pure NoInfo <*> pure loc
resolveExp (OpSectionLeft v info1 e info2 info3 loc) =
  OpSectionLeft
    <$> resolveQualName v loc
    <*> pure info1
    <*> resolveExp e
    <*> pure info2
    <*> pure info3
    <*> pure loc
resolveExp (OpSectionRight v info1 e info2 info3 loc) =
  OpSectionRight
    <$> resolveQualName v loc
    <*> pure info1
    <*> resolveExp e
    <*> pure info2
    <*> pure info3
    <*> pure loc
resolveExp (ProjectSection ks info loc) =
  pure $ ProjectSection ks info loc
resolveExp (IndexSection slice info loc) =
  IndexSection <$> resolveSlice slice <*> pure info <*> pure loc
resolveExp (Ascript e te loc) =
  Ascript <$> resolveExp e <*> resolveTypeExp te <*> pure loc
resolveExp (Coerce e te info loc) =
  Coerce <$> resolveExp e <*> resolveTypeExp te <*> pure info <*> pure loc
resolveExp (AppExp e NoInfo) =
  AppExp <$> resolveAppExp e <*> pure NoInfo

-- | View a size binder as a size type parameter, so that it can be
-- checked for duplicates together with other binders.
sizeBinderToParam :: SizeBinder Name -> UncheckedTypeParam
sizeBinderToParam (SizeBinder v loc) = TypeParamDim v loc

-- | Read a pattern as an expression that refers to the names the
-- pattern would bind, resolved in the current scope.  Used for loops
-- whose initial value is implicit.  Wildcards, literals and
-- constructors have no expression reading and produce a type error.
patternExp :: UncheckedPat t -> TypeM (ExpBase NoInfo VName)
patternExp (Id v _ loc) =
  Var <$> resolveQualName (qualName v) loc <*> pure NoInfo <*> pure loc
patternExp (TuplePat pats loc) =
  TupLit <$> mapM patternExp pats <*> pure loc
patternExp (Wildcard _ loc) =
  typeError loc mempty "Cannot have wildcard here."
patternExp (PatLit _ _ loc) =
  typeError loc mempty "Cannot have literal here."
patternExp (PatConstr _ _ _ loc) =
  typeError loc mempty "Cannot have constructor here."
patternExp (PatAttr _ p _) = patternExp p
patternExp (PatAscription pat _ _) = patternExp pat
patternExp (PatParens pat _) = patternExp pat
patternExp (RecordPat fs loc) =
  RecordLit <$> mapM field fs <*> pure loc
  where
    field (L nameloc name, pat) =
      RecordFieldExplicit (L nameloc name) <$> patternExp pat <*> pure (srclocOf loc)

-- | Resolve names in an application-like expression.  Binding forms
-- bring their names into scope only for the subexpressions in which
-- those names are visible.
resolveAppExp :: AppExpBase NoInfo Name -> TypeM (AppExpBase NoInfo VName)
resolveAppExp (Apply f args loc) =
  Apply <$> resolveExp f <*> traverse (traverse resolveExp) args <*> pure loc
resolveAppExp (Range e1 e2 e3 loc) =
  Range
    <$> resolveExp e1
    <*> traverse resolveExp e2
    <*> traverse resolveExp e3
    <*> pure loc
resolveAppExp (If e1 e2 e3 loc) =
  If <$> resolveExp e1 <*> resolveExp e2 <*> resolveExp e3 <*> pure loc
resolveAppExp (Match e cases loc) =
  Match <$> resolveExp e <*> mapM resolveCase cases <*> pure loc
  where
    resolveCase (CasePat p body cloc) =
      resolvePat p $ \p' -> CasePat p' <$> resolveExp body <*> pure cloc
resolveAppExp (LetPat sizes p e1 e2 loc) = do
  checkForDuplicateNames (map sizeBinderToParam sizes) [p]
  -- The bound expression does not see the sizes or the pattern.
  e1' <- resolveExp e1
  resolveSizes sizes $ \sizes' -> do
    resolvePat p $ \p' -> do
      e2' <- resolveExp e2
      pure $ LetPat sizes' p' e1' e2' loc
resolveAppExp (LetFun fname (tparams, params, ret, NoInfo, fbody) body loc) = do
  checkForDuplicateNames tparams params
  checkDoNotShadow loc fname
  (tparams', params', ret', fbody') <-
    resolveTypeParams tparams $ \tparams' ->
      resolveParams params $ \params' -> do
        ret' <- traverse resolveTypeExp ret
        (tparams',params',ret',) <$> resolveExp fbody
  -- The function name is only in scope in 'body', not in its own
  -- definition.
  bindSpaced1 Term fname loc $ \fname' -> do
    body' <- resolveExp body
    pure $ LetFun fname' (tparams', params', ret', NoInfo, fbody') body' loc
resolveAppExp (LetWith (Ident dst _ dstloc) (Ident src _ srcloc) slice e1 e2 loc) = do
  src' <- Ident <$> resolveName src srcloc <*> pure NoInfo <*> pure srcloc
  e1' <- resolveExp e1
  slice' <- resolveSlice slice
  bindSpaced1 Term dst loc $ \dstv -> do
    let dst' = Ident dstv NoInfo dstloc
    e2' <- resolveExp e2
    pure $ LetWith dst' src' slice' e1' e2' loc
resolveAppExp (BinOp (f, floc) finfo (e1, info1) (e2, info2) loc) = do
  f' <- resolveQualName f floc
  e1' <- resolveExp e1
  e2' <- resolveExp e2
  pure $ BinOp (f', floc) finfo (e1', info1) (e2', info2) loc
resolveAppExp (Index e1 slice loc) =
  Index <$> resolveExp e1 <*> resolveSlice slice <*> pure loc
resolveAppExp (Loop sizes pat loopinit form body loc) = do
  -- An implicit initial value is the loop pattern read as an
  -- expression, resolved in the scope outside the loop.
  e' <- case loopinit of
    LoopInitExplicit e -> LoopInitExplicit <$> resolveExp e
    LoopInitImplicit NoInfo -> LoopInitExplicit <$> patternExp pat
  case form of
    For (Ident i _ iloc) bound -> do
      bound' <- resolveExp bound
      bindSpaced1 Term i iloc $ \iv -> do
        let i' = Ident iv NoInfo iloc
        resolvePat pat $ \pat' -> do
          body' <- resolveExp body
          pure $ Loop sizes pat' e' (For i' bound') body' loc
    ForIn elemp arr -> do
      arr' <- resolveExp arr
      resolvePat elemp $ \elemp' -> resolvePat pat $ \pat' -> do
        body' <- resolveExp body
        pure $ Loop sizes pat' e' (ForIn elemp' arr') body' loc
    While cond -> resolvePat pat $ \pat' -> do
      cond' <- resolveExp cond
      body' <- resolveExp body
      pure $ Loop sizes pat' e' (While cond') body' loc

-- | Resolve names in every index of a slice.
resolveSlice :: SliceBase NoInfo Name -> TypeM (SliceBase NoInfo VName)
resolveSlice = mapM onDimIndex
  where
    onDimIndex (DimFix e) = DimFix <$> resolveExp e
    onDimIndex (DimSlice e1 e2 e3) =
      DimSlice
        <$> traverse resolveExp e1
        <*> traverse resolveExp e2
        <*> traverse resolveExp e3

-- | @resolvePat p m@ gives every variable bound by @p@ a fresh name,
-- resolves names in the pattern's type ascriptions and attributes
-- (before the pattern's own variables are in scope), and then runs
-- @m@ on the result with the pattern's variables in scope.
resolvePat :: PatBase NoInfo Name t -> (PatBase NoInfo VName t -> TypeM a) -> TypeM a
resolvePat outer m = do
  outer' <- resolve outer
  bindIdents (patIdents outer') $ m outer'
  where
    resolve (Id v NoInfo loc) = do
      checkDoNotShadow loc v
      Id <$> newID v <*> pure NoInfo <*> pure loc
    resolve (Wildcard NoInfo loc) =
      pure $ Wildcard NoInfo loc
    resolve (PatParens p loc) =
      PatParens <$> resolve p <*> pure loc
    resolve (TuplePat ps loc) =
      TuplePat <$> mapM resolve ps <*> pure loc
    resolve (RecordPat ps loc) =
      RecordPat <$> mapM (traverse resolve) ps <*> pure loc
    resolve (PatAscription p t loc) =
      PatAscription <$> resolve p <*> resolveTypeExp t <*> pure loc
    resolve (PatLit l NoInfo loc) =
      pure $ PatLit l NoInfo loc
    resolve (PatConstr k NoInfo ps loc) =
      PatConstr k NoInfo <$> mapM resolve ps <*> pure loc
    resolve (PatAttr attr p loc) =
      PatAttr <$> resolveAttrInfo attr <*> resolve p <*> pure loc

-- | Resolve parameters left to right.  Each parameter is in scope while
-- the later ones are resolved, and all of them are in scope for the
-- continuation.
resolveParams ::
  [PatBase NoInfo Name ParamType] ->
  ([PatBase NoInfo VName ParamType] -> TypeM a) ->
  TypeM a
resolveParams [] m = m []
resolveParams (p : ps) m = resolvePat p $ \p' -> resolveParams ps (m . (p' :))

-- | @resolveTypeParams ps m@ resolves the type parameters @ps@, then
-- invokes the continuation @m@ with the resolved parameters, while
-- extending the monadic name map with @ps@.
resolveTypeParams ::
  [TypeParamBase Name] -> ([TypeParamBase VName] -> TypeM a) -> TypeM a
resolveTypeParams ps m =
  bindSpaced (map typeParamSpace ps) $ \_ ->
    m =<< evalStateT (mapM checkTypeParam ps) mempty
  where
    typeParamSpace (TypeParamDim pv loc) = (Term, pv, loc)
    typeParamSpace (TypeParamType _ pv loc) = (Type, pv, loc)

    -- Reject a parameter name already used in the same namespace,
    -- then look it up in the (now extended) environment.
    checkParamName ns v loc = do
      seen <- gets $ M.lookup (ns, v)
      case seen of
        Just prev ->
          lift $
            typeError loc mempty $
              "Type parameter"
                <+> dquotes (pretty v)
                <+> "previously defined at"
                <+> pretty (locStr prev)
                <> "."
        Nothing -> do
          modify $ M.insert (ns, v) loc
          lift $ checkName ns v loc

    checkTypeParam (TypeParamDim pv loc) =
      TypeParamDim <$> checkParamName Term pv loc <*> pure loc
    checkTypeParam (TypeParamType l pv loc) =
      TypeParamType l <$> checkParamName Type pv loc <*> pure loc

-- | Resolve size binders, rejecting repeated size names, and run the
-- continuation with the sizes bound in the 'Term' namespace.
resolveSizes :: [SizeBinder Name] -> ([SizeBinder VName] -> TypeM a) -> TypeM a
resolveSizes [] m = m [] -- Minor optimisation.
resolveSizes sizes m = do
  foldM_ lookForDuplicates mempty sizes
  bindSpaced (map sizeWithSpace sizes) $ \sizes' ->
    m $ zipWith SizeBinder sizes' $ map srclocOf sizes
  where
    lookForDuplicates prev size
      | Just (_, prevloc) <- L.find ((== sizeName size) . fst) prev =
          typeError size mempty $
            "Size name also bound at "
              <> pretty (locStrRel (srclocOf size) prevloc)
              <> "."
      | otherwise =
          pure $ (sizeName size, srclocOf size) : prev

    sizeWithSpace size = (Term, sizeName size, srclocOf size)

-- | Resolve names in a value binding.  If this succeeds, then it is
-- guaranteed that all names reference things that are in scope.  The
-- binding's own name is bound only after its body has been resolved,
-- so the body cannot refer to the function being defined.
resolveValBind :: ValBindBase NoInfo Name -> TypeM (ValBindBase NoInfo VName)
resolveValBind (ValBind entry fname ret NoInfo tparams params body doc attrs loc) = do
  attrs' <- mapM resolveAttrInfo attrs
  checkForDuplicateNames tparams params
  checkDoNotShadow loc fname
  resolveTypeParams tparams $ \tparams' ->
    resolveParams params $ \params' -> do
      ret' <- traverse resolveTypeExp ret
      body' <- resolveExp body
      bindSpaced1 Term fname loc $ \fname' -> do
        usedName fname'
        pure $ ValBind entry fname' ret' NoInfo tparams' params' body' doc attrs' loc
futhark-0.25.27/src/Language/Futhark/TypeChecker/Terms.hs000066400000000000000000001667261475065116200231360ustar00rootroot00000000000000-- | Facilities for type-checking Futhark terms. Checking a term -- requires a little more context to track uniqueness and such. -- -- Type inference is implemented through a variation of -- Hindley-Milner. The main complication is supporting the rich -- number of built-in language constructs, as well as uniqueness -- types. This is mostly done in an ad hoc way, and many programs -- will require the programmer to fall back on type annotations.
module Language.Futhark.TypeChecker.Terms
  ( checkOneExp,
    checkSizeExp,
    checkFunDef,
  )
where

import Control.Monad
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Bifunctor
import Data.Bitraversable
import Data.Char (isAscii)
import Data.Either
import Data.List (delete, find, genericLength, partition)
import Data.List.NonEmpty qualified as NE
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.Util (mapAccumLM, nubOrd)
import Futhark.Util.Pretty hiding (space)
import Language.Futhark
import Language.Futhark.Primitive (intByteSize)
import Language.Futhark.Traversals
import Language.Futhark.TypeChecker.Consumption qualified as Consumption
import Language.Futhark.TypeChecker.Match
import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod)
import Language.Futhark.TypeChecker.Terms.Loop
import Language.Futhark.TypeChecker.Terms.Monad
import Language.Futhark.TypeChecker.Terms.Pat
import Language.Futhark.TypeChecker.Types
import Language.Futhark.TypeChecker.Unify
import Prelude hiding (mod)

-- | Does the expression bind names?  True for lambdas, @let@
-- patterns, local functions, loops, in-place updates and @match@
-- expressions, and for any expression where one of the subexpressions
-- visited by 'astMap' has a binding.
hasBinding :: Exp -> Bool
hasBinding e = case e of
  Lambda {} -> True
  AppExp LetPat {} _ -> True
  AppExp LetFun {} _ -> True
  AppExp Loop {} _ -> True
  AppExp LetWith {} _ -> True
  AppExp Match {} _ -> True
  _ -> isNothing $ astMap abortOnBinding e
  where
    -- A traversal in 'Maybe' that aborts as soon as it reaches a
    -- subexpression containing a binding.
    abortOnBinding =
      identityMapper
        { mapOnExp = \sub -> if hasBinding sub then Nothing else Just sub
        }

-- | The type variables that occur in the field types of 'HasFields'
-- constraints.
overloadedTypeVars :: Constraints -> Names
overloadedTypeVars = foldMap fieldTypeVars
  where
    fieldTypeVars (_, HasFields _ fs _) = foldMap typeVars fs
    fieldTypeVars _ = mempty

--- Basic checking

-- | Determine if the two types are identical, ignoring uniqueness.
-- Mismatched dimensions are turned into fresh rigid type variables.
-- Causes a 'TypeError' if they fail to match, and otherwise returns
-- one of them.
unifyBranchTypes :: SrcLoc -> StructType -> StructType -> TermTypeM (StructType, [VName]) unifyBranchTypes loc t1 t2 = onFailure (CheckingBranches t1 t2) $ unifyMostCommon (mkUsage loc "unification of branch results") t1 t2 unifyBranches :: SrcLoc -> Exp -> Exp -> TermTypeM (StructType, [VName]) unifyBranches loc e1 e2 = do e1_t <- expTypeFully e1 e2_t <- expTypeFully e2 unifyBranchTypes loc e1_t e2_t sliceShape :: Maybe (SrcLoc, Rigidity) -> [DimIndex] -> TypeBase Size as -> TermTypeM (TypeBase Size as, [VName]) sliceShape r slice t@(Array u (Shape orig_dims) et) = runStateT (setDims <$> adjustDims slice orig_dims) [] where setDims [] = stripArray (length orig_dims) t setDims dims' = Array u (Shape dims') et -- If the result is supposed to be a nonrigid size variable, then -- don't bother trying to create non-existential sizes. This is -- necessary to make programs type-check without too much -- ceremony; see e.g. tests/inplace5.fut. isRigid Rigid {} = True isRigid _ = False refine_sizes = maybe False (isRigid . snd) r sliceSize orig_d i j stride = case r of Just (loc, Rigid _) -> do (d, ext) <- lift . extSize loc $ SourceSlice orig_d' (bareExp <$> i) (bareExp <$> j) (bareExp <$> stride) modify (maybeToList ext ++) pure d Just (loc, Nonrigid) -> lift $ flip sizeFromName loc . qualName <$> newFlexibleDim (mkUsage loc "size of slice") "slice_dim" Nothing -> do v <- lift $ newID "slice_anydim" modify (v :) pure $ sizeFromName (qualName v) mempty where -- The original size does not matter if the slice is fully specified. orig_d' | isJust i, isJust j = Nothing | otherwise = Just orig_d warnIfBinding binds d i j stride size = if binds then do lift . warn (srclocOf size) $ withIndexLink "size-expression-bind" "Size expression with binding is replaced by unknown size." (:) <$> sliceSize d i j stride else pure (size :) adjustDims (DimFix {} : idxes') (_ : dims) = adjustDims idxes' dims -- Pat match some known slices to be non-existential. 
adjustDims (DimSlice i j stride : idxes') (d : dims) | refine_sizes, maybe True ((== Just 0) . isInt64) i, maybe True ((== Just 1) . isInt64) stride = do let binds = maybe False hasBinding j warnIfBinding binds d i j stride (fromMaybe d j) <*> adjustDims idxes' dims adjustDims ((DimSlice i j stride) : idxes') (d : dims) | refine_sizes, Just i' <- i, -- if i ~ 0, previous case maybe True ((== Just 1) . isInt64) stride = do let j' = fromMaybe d j binds = hasBinding j' || hasBinding i' warnIfBinding binds d i j stride (sizeMinus j' i') <*> adjustDims idxes' dims -- stride == -1 adjustDims ((DimSlice Nothing Nothing stride) : idxes') (d : dims) | refine_sizes, maybe True ((== Just (-1)) . isInt64) stride = (d :) <$> adjustDims idxes' dims adjustDims ((DimSlice (Just i) (Just j) stride) : idxes') (d : dims) | refine_sizes, maybe True ((== Just (-1)) . isInt64) stride = do let binds = hasBinding i || hasBinding j warnIfBinding binds d (Just i) (Just j) stride (sizeMinus i j) <*> adjustDims idxes' dims -- existential adjustDims ((DimSlice i j stride) : idxes') (d : dims) = (:) <$> sliceSize d i j stride <*> adjustDims idxes' dims adjustDims _ dims = pure dims sizeMinus j i = AppExp ( BinOp (qualName (intrinsicVar "-"), mempty) sizeBinOpInfo (j, Info Nothing) (i, Info Nothing) mempty ) $ Info $ AppRes i64 [] i64 = Scalar $ Prim $ Signed Int64 sizeBinOpInfo = Info $ foldFunType [i64, i64] $ RetType [] i64 sliceShape _ _ t = pure (t, []) --- Main checkers checkAscript :: SrcLoc -> TypeExp (ExpBase NoInfo VName) VName -> ExpBase NoInfo VName -> TermTypeM (TypeExp Exp VName, Exp) checkAscript loc te e = do (te', decl_t, _) <- checkTypeExpNonrigid te e' <- checkExp e e_t <- expTypeFully e' onFailure (CheckingAscription (toStruct decl_t) e_t) $ unify (mkUsage loc "type ascription") (toStruct decl_t) e_t pure (te', e') checkCoerce :: SrcLoc -> TypeExp (ExpBase NoInfo VName) VName -> ExpBase NoInfo VName -> TermTypeM (TypeExp Exp VName, StructType, Exp) checkCoerce loc te e = do 
(te', te_t, ext) <- checkTypeExpNonrigid te e' <- checkExp e e_t <- expTypeFully e' te_t_nonrigid <- makeNonExtFresh ext $ toStruct te_t onFailure (CheckingAscription (toStruct te_t) e_t) $ unify (mkUsage loc "size coercion") e_t te_t_nonrigid -- If the type expression had any anonymous dimensions, these will -- now be in 'ext'. Those we keep nonrigid and unify with e_t. -- This ensures that 'x :> [1][]i32' does not make the second -- dimension unknown. Use of matchDims is sensible because the -- structure of e_t' will be fully known due to the unification, and -- te_t because type expressions are complete. pure (te', toStruct te_t, e') where makeNonExtFresh ext = bitraverse onDim pure where onDim d@(Var v _ _) | qualLeaf v `elem` ext = pure d onDim d = do v <- newTypeName "coerce" constrain v . Size Nothing $ mkUsage loc "a size coercion where the underlying expression size cannot be determined" pure $ sizeFromName (qualName v) (srclocOf d) -- Used to remove unknown sizes from function body types before we -- perform let-generalisation. This is because if a function is -- inferred to return something of type '[x+y]t' where 'x' or 'y' are -- unknown, we want to turn that into '[z]t', where ''z' is a fresh -- unknown, which is then by let-generalisation turned into -- '?[z].[z]t'. unscopeUnknown :: TypeBase Size u -> TermTypeM (TypeBase Size u) unscopeUnknown t = do constraints <- getConstraints -- These sizes will be immediately turned into existentials, so we -- do not need to care about their location. 
fst <$> sizeFree mempty (expKiller constraints) t where expKiller _ Var {} = Nothing expKiller constraints e = S.lookupMin $ S.filter (isUnknown constraints) $ (`S.difference` witnesses) $ fvVars $ freeInExp e isUnknown constraints vn | Just UnknownSize {} <- snd <$> M.lookup vn constraints = True isUnknown _ _ = False (witnesses, _) = determineSizeWitnesses $ toStruct t unscopeType :: SrcLoc -> [VName] -> TypeBase Size as -> TermTypeM (TypeBase Size as, [VName]) unscopeType tloc unscoped = sizeFree tloc $ find (`elem` unscoped) . fvVars . freeInExp checkExp :: ExpBase NoInfo VName -> TermTypeM Exp checkExp (Literal val loc) = pure $ Literal val loc checkExp (Hole _ loc) = do t <- newTypeVar loc "t" pure $ Hole (Info t) loc checkExp (StringLit vs loc) = pure $ StringLit vs loc checkExp (IntLit val NoInfo loc) = do t <- newTypeVar loc "t" mustBeOneOf anyNumberType (mkUsage loc "integer literal") t pure $ IntLit val (Info t) loc checkExp (FloatLit val NoInfo loc) = do t <- newTypeVar loc "t" mustBeOneOf anyFloatType (mkUsage loc "float literal") t pure $ FloatLit val (Info t) loc checkExp (TupLit es loc) = TupLit <$> mapM checkExp es <*> pure loc checkExp (RecordLit fs loc) = RecordLit <$> evalStateT (mapM checkField fs) mempty <*> pure loc where checkField (RecordFieldExplicit f e rloc) = do errIfAlreadySet (unLoc f) rloc modify $ M.insert (unLoc f) rloc RecordFieldExplicit f <$> lift (checkExp e) <*> pure rloc checkField (RecordFieldImplicit name NoInfo rloc) = do errIfAlreadySet (baseName (unLoc name)) rloc t <- lift $ lookupVar rloc $ qualName $ unLoc name modify $ M.insert (baseName (unLoc name)) rloc pure $ RecordFieldImplicit name (Info t) rloc errIfAlreadySet f rloc = do maybe_sloc <- gets $ M.lookup f case maybe_sloc of Just sloc -> lift . typeError rloc mempty $ "Field" <+> dquotes (pretty f) <+> "previously defined at" <+> pretty (locStrRel rloc sloc) <> "." 
Nothing -> pure () -- No need to type check this, as these are only produced by the -- parser if the elements are monomorphic and all match. checkExp (ArrayVal vs t loc) = pure $ ArrayVal vs t loc checkExp (ArrayLit all_es _ loc) = -- Construct the result type and unify all elements with it. We -- only create a type variable for empty arrays; otherwise we use -- the type of the first element. This significantly cuts down on -- the number of type variables generated for pathologically large -- multidimensional array literals. case all_es of [] -> do et <- newTypeVar loc "t" t <- arrayOfM loc et (Shape [sizeFromInteger 0 mempty]) pure $ ArrayLit [] (Info t) loc e : es -> do e' <- checkExp e et <- expType e' es' <- mapM (unifies "type of first array element" et <=< checkExp) es t <- arrayOfM loc et (Shape [sizeFromInteger (genericLength all_es) mempty]) pure $ ArrayLit (e' : es') (Info t) loc checkExp (AppExp (Range start maybe_step end loc) _) = do start' <- require "use in range expression" anySignedType =<< checkExp start start_t <- expType start' maybe_step' <- case maybe_step of Nothing -> pure Nothing Just step -> do let warning = warn loc "First and second element of range are identical, this will produce an empty array." case (start, step) of (Literal x _, Literal y _) -> when (x == y) warning (Var x_name _ _, Var y_name _ _) -> when (x_name == y_name) warning _ -> pure () Just <$> (unifies "use in range expression" start_t =<< checkExp step) let unifyRange e = unifies "use in range expression" start_t =<< checkExp e end' <- traverse unifyRange end end_t <- case end' of DownToExclusive e -> expType e ToInclusive e -> expType e UpToExclusive e -> expType e -- Special case some ranges to give them a known size. let warnIfBinding binds size = if binds then do warn (srclocOf size) $ withIndexLink "size-expression-bind" "Size expression with binding is replaced by unknown size." 
d <- newRigidDim loc RigidRange "range_dim" pure (sizeFromName (qualName d) mempty, Just d) else pure (size, Nothing) (dim, retext) <- case (isInt64 start', isInt64 <$> maybe_step', end') of (Just 0, Just (Just 1), UpToExclusive end'') | Scalar (Prim (Signed Int64)) <- end_t -> warnIfBinding (hasBinding end'') end'' (Just 0, Nothing, UpToExclusive end'') | Scalar (Prim (Signed Int64)) <- end_t -> warnIfBinding (hasBinding end'') end'' (_, Nothing, UpToExclusive end'') | Scalar (Prim (Signed Int64)) <- end_t -> warnIfBinding (hasBinding end'' || hasBinding start') $ sizeMinus end'' start' (_, Nothing, ToInclusive end'') -- No stride means we assume a stride of one. | Scalar (Prim (Signed Int64)) <- end_t -> warnIfBinding (hasBinding end'' || hasBinding start') $ sizeMinusInc end'' start' (Just 1, Just (Just 2), ToInclusive end'') | Scalar (Prim (Signed Int64)) <- end_t -> warnIfBinding (hasBinding end'') end'' _ -> do d <- newRigidDim loc RigidRange "range_dim" pure (sizeFromName (qualName d) mempty, Just d) t <- arrayOfM loc start_t (Shape [dim]) let res = AppRes t (maybeToList retext) pure $ AppExp (Range start' maybe_step' end' loc) (Info res) where i64 = Scalar $ Prim $ Signed Int64 mkBinOp op t x y = AppExp ( BinOp (qualName (intrinsicVar op), mempty) sizeBinOpInfo (x, Info Nothing) (y, Info Nothing) mempty ) (Info $ AppRes t []) mkSub = mkBinOp "-" i64 mkAdd = mkBinOp "+" i64 sizeMinus j i = j `mkSub` i sizeMinusInc j i = (j `mkSub` i) `mkAdd` sizeFromInteger 1 mempty sizeBinOpInfo = Info $ foldFunType [i64, i64] $ RetType [] i64 checkExp (Ascript e te loc) = do (te', e') <- checkAscript loc te e pure $ Ascript e' te' loc checkExp (Coerce e te NoInfo loc) = do (te', te_t, e') <- checkCoerce loc te e t <- expTypeFully e' t' <- matchDims (const . 
const pure) t te_t pure $ Coerce e' te' (Info t') loc checkExp (AppExp (BinOp (op, oploc) NoInfo (e1, _) (e2, _) loc) NoInfo) = do ftype <- lookupVar oploc op e1' <- checkExp e1 e2' <- checkExp e2 -- Note that the application to the first operand cannot fix any -- existential sizes, because it must by necessity be a function. (_, rt, p1_ext, _) <- checkApply loc (Just op, 0) ftype e1' (_, rt', p2_ext, retext) <- checkApply loc (Just op, 1) rt e2' pure $ AppExp ( BinOp (op, oploc) (Info ftype) (e1', Info p1_ext) (e2', Info p2_ext) loc ) (Info (AppRes rt' retext)) checkExp (Project k e NoInfo loc) = do e' <- checkExp e t <- expType e' kt <- mustHaveField (mkUsage loc $ docText $ "projection of field " <> dquotes (pretty k)) k t pure $ Project k e' (Info kt) loc checkExp (AppExp (If e1 e2 e3 loc) _) = do e1' <- checkExp e1 e2' <- checkExp e2 e3' <- checkExp e3 let bool = Scalar $ Prim Bool e1_t <- expType e1' onFailure (CheckingRequired [bool] e1_t) $ unify (mkUsage e1' "use as 'if' condition") bool e1_t (brancht, retext) <- unifyBranches loc e2' e3' zeroOrderType (mkUsage loc "returning value of this type from 'if' expression") "type returned from branch" brancht pure $ AppExp (If e1' e2' e3' loc) (Info $ AppRes brancht retext) checkExp (Parens e loc) = Parens <$> checkExp e <*> pure loc checkExp (QualParens (modname, modnameloc) e loc) = do mod <- lookupMod modname case mod of ModEnv env -> local (`withEnv` env) $ do e' <- checkExp e pure $ QualParens (modname, modnameloc) e' loc ModFun {} -> typeError loc mempty . withIndexLink "module-is-parametric" $ "Module" <+> pretty modname <+> " is a parametric module." 
checkExp (Var qn NoInfo loc) = do t <- lookupVar loc qn pure $ Var qn (Info t) loc checkExp (Negate arg loc) = do arg' <- require "numeric negation" anyNumberType =<< checkExp arg pure $ Negate arg' loc checkExp (Not arg loc) = do arg' <- require "logical negation" (Bool : anyIntType) =<< checkExp arg pure $ Not arg' loc checkExp (AppExp (Apply fe args loc) NoInfo) = do fe' <- checkExp fe args' <- mapM (checkExp . snd) args t <- expType fe' let fname = case fe' of Var v _ _ -> Just v _ -> Nothing ((_, exts, rt), args'') <- mapAccumLM (onArg fname) (0, [], t) args' pure $ AppExp (Apply fe' args'' loc) $ Info $ AppRes rt exts where onArg fname (i, all_exts, t) arg' = do (_, rt, argext, exts) <- checkApply loc (fname, i) t arg' pure ( (i + 1, all_exts <> exts, rt), (Info argext, arg') ) checkExp (AppExp (LetPat sizes pat e body loc) _) = do e' <- checkExp e -- Not technically an ascription, but we want the pattern to have -- exactly the type of 'e'. t <- expType e' bindingSizes sizes . incLevel . bindingPat sizes pat t $ \pat' -> do body' <- incLevel $ checkExp body body_t <- expTypeFully body' -- If the bound expression is of type i64, then we replace the -- pattern name with the expression in the type of the body. -- Otherwise, we need to come up with unknown sizes for the -- sizes going out of scope. t' <- normType t -- Might be overloaded integer until now. 
(body_t', retext) <- case (t', patNames pat') of (Scalar (Prim (Signed Int64)), [v]) | not $ hasBinding e' -> do let f x = if x == v then Just (ExpSubst e') else Nothing pure (applySubst f body_t, []) _ -> unscopeType loc (map sizeName sizes <> patNames pat') body_t pure $ AppExp (LetPat sizes (fmap toStruct pat') e' body' loc) (Info $ AppRes body_t' retext) checkExp (AppExp (LetFun name (tparams, params, maybe_retdecl, NoInfo, e) body loc) _) = do (tparams', params', maybe_retdecl', rettype, e') <- checkBinding (name, maybe_retdecl, tparams, params, e, loc) let entry = BoundV tparams' $ funType params' rettype bindF scope = scope { scopeVtable = M.insert name entry $ scopeVtable scope } body' <- localScope bindF $ checkExp body (body_t, ext) <- unscopeType loc [name] =<< expTypeFully body' pure $ AppExp ( LetFun name (tparams', params', maybe_retdecl', Info rettype, e') body' loc ) (Info $ AppRes body_t ext) checkExp (AppExp (LetWith dest src slice ve body loc) _) = do src' <- checkIdent src slice' <- checkSlice slice (t, _) <- newArrayType (mkUsage src "type of source array") "src" $ sliceDims slice' unify (mkUsage loc "type of target array") t $ unInfo $ identType src' (elemt, _) <- sliceShape (Just (loc, Nonrigid)) slice' =<< normTypeFully t ve' <- unifies "type of target array" elemt =<< checkExp ve bindingIdent dest (unInfo (identType src')) $ \dest' -> do body' <- checkExp body (body_t, ext) <- unscopeType loc [identName dest'] =<< expTypeFully body' pure $ AppExp (LetWith dest' src' slice' ve' body' loc) (Info $ AppRes body_t ext) checkExp (Update src slice ve loc) = do slice' <- checkSlice slice (t, _) <- newArrayType (mkUsage' src) "src" $ sliceDims slice' (elemt, _) <- sliceShape (Just (loc, Nonrigid)) slice' =<< normTypeFully t ve' <- unifies "type of target array" elemt =<< checkExp ve src' <- unifies "type of target array" t =<< checkExp src pure $ Update src' slice' ve' loc -- Record updates are a bit hacky, because we do not have row typing -- 
(yet?). For now, we only permit record updates where we know the -- full type up to the field we are updating. checkExp (RecordUpdate src fields ve NoInfo loc) = do src' <- checkExp src ve' <- checkExp ve a <- expTypeFully src' foldM_ (flip $ mustHaveField usage) a fields ve_t <- expType ve' updated_t <- updateField fields ve_t =<< expTypeFully src' pure $ RecordUpdate src' fields ve' (Info updated_t) loc where usage = mkUsage loc "record update" updateField [] ve_t src_t = do (src_t', _) <- allDimsFreshInType usage Nonrigid "any" src_t onFailure (CheckingRecordUpdate fields src_t' ve_t) $ unify usage src_t' ve_t pure ve_t updateField (f : fs) ve_t (Scalar (Record m)) | Just f_t <- M.lookup f m = do f_t' <- updateField fs ve_t f_t pure $ Scalar $ Record $ M.insert f f_t' m updateField _ _ _ = typeError loc mempty . withIndexLink "record-type-not-known" $ "Full type of" indent 2 (pretty src) textwrap " is not known at this point. Add a type annotation to the original record to disambiguate." -- checkExp (AppExp (Index e slice loc) _) = do slice' <- checkSlice slice (t, _) <- newArrayType (mkUsage' loc) "e" $ sliceDims slice' e' <- unifies "being indexed at" t =<< checkExp e -- XXX, the RigidSlice here will be overridden in sliceShape with a proper value. (t', retext) <- sliceShape (Just (loc, Rigid (RigidSlice Nothing ""))) slice' =<< expTypeFully e' pure $ AppExp (Index e' slice' loc) (Info $ AppRes t' retext) checkExp (Assert e1 e2 NoInfo loc) = do e1' <- require "being asserted" [Bool] =<< checkExp e1 e2' <- checkExp e2 pure $ Assert e1' e2' (Info (prettyText e1)) loc checkExp (Lambda params body rettype_te NoInfo loc) = do (params', body', rettype', RetType dims ty) <- incLevel . 
bindingParams [] params $ \params' -> do rettype_checked <- traverse checkTypeExpNonrigid rettype_te let declared_rettype = case rettype_checked of Just (_, st, _) -> Just st Nothing -> Nothing body' <- checkFunBody params' body declared_rettype loc body_t <- expTypeFully body' params'' <- mapM updateTypes params' (rettype', rettype_st) <- case rettype_checked of Just (te, st, ext) -> pure (Just te, RetType ext st) Nothing -> do ret <- inferReturnSizes params'' $ toRes Nonunique body_t pure (Nothing, ret) pure (params'', body', rettype', rettype_st) verifyFunctionParams Nothing params' (ty', dims') <- unscopeType loc dims ty pure $ Lambda params' body' rettype' (Info (RetType dims' ty')) loc where -- Inferring the sizes of the return type of a lambda is a lot -- like let-generalisation. We wish to remove any rigid sizes -- that were created when checking the body, except for those that -- are visible in types that existed before we entered the body, -- are parameters, or are used in parameters. inferReturnSizes params' ret = do cur_lvl <- curLevel let named (Named x, _, _) = Just x named (Unnamed, _, _) = Nothing param_names = mapMaybe (named . patternParam) params' pos_sizes = sizeNamesPos $ funType params' $ RetType [] ret hide k (lvl, _) = lvl >= cur_lvl && k `notElem` param_names && k `S.notMember` pos_sizes hidden_sizes <- S.fromList . M.keys . 
M.filterWithKey hide <$> getConstraints let onDim name | name `S.member` hidden_sizes = S.singleton name onDim _ = mempty pure $ RetType (S.toList $ foldMap onDim $ fvVars $ freeInType ret) ret checkExp (OpSection op _ loc) = do ftype <- lookupVar loc op pure $ OpSection op (Info ftype) loc checkExp (OpSectionLeft op _ e _ _ loc) = do ftype <- lookupVar loc op e' <- checkExp e (t1, rt, argext, retext) <- checkApply loc (Just op, 0) ftype e' case (ftype, rt) of (Scalar (Arrow _ m1 d1 _ _), Scalar (Arrow _ m2 d2 t2 rettype)) -> pure $ OpSectionLeft op (Info ftype) e' (Info (m1, toParam d1 t1, argext), Info (m2, toParam d2 t2)) (Info rettype, Info retext) loc _ -> typeError loc mempty $ "Operator section with invalid operator of type" <+> pretty ftype checkExp (OpSectionRight op _ e _ NoInfo loc) = do ftype <- lookupVar loc op e' <- checkExp e case ftype of Scalar (Arrow _ m1 d1 t1 (RetType [] (Scalar (Arrow _ m2 d2 t2 (RetType dims2 ret))))) -> do (t2', arrow', argext, _) <- checkApply loc (Just op, 1) (Scalar $ Arrow mempty m2 d2 t2 $ RetType [] $ Scalar $ Arrow Nonunique m1 d1 t1 $ RetType dims2 ret) e' case arrow' of Scalar (Arrow _ _ _ t1' (RetType dims2' ret')) -> pure $ OpSectionRight op (Info ftype) e' (Info (m1, toParam d1 t1'), Info (m2, toParam d2 t2', argext)) (Info $ RetType dims2' ret') loc _ -> error $ "OpSectionRight: impossible type\n" <> prettyString arrow' _ -> typeError loc mempty $ "Operator section with invalid operator of type" <+> pretty ftype checkExp (ProjectSection fields NoInfo loc) = do a <- newTypeVar loc "a" let usage = mkUsage loc "projection at" b <- foldM (flip $ mustHaveField usage) a fields let ft = Scalar $ Arrow mempty Unnamed Observe a $ RetType [] $ toRes Nonunique b pure $ ProjectSection fields (Info ft) loc checkExp (IndexSection slice NoInfo loc) = do slice' <- checkSlice slice (t, _) <- newArrayType (mkUsage' loc) "e" $ sliceDims slice' (t', retext) <- sliceShape Nothing slice' t let ft = Scalar $ Arrow mempty Unnamed 
Observe t $ RetType retext $ toRes Nonunique t' pure $ IndexSection slice' (Info ft) loc checkExp (AppExp (Loop _ mergepat loopinit form loopbody loc) _) = do ((sparams, mergepat', loopinit', form', loopbody'), appres) <- checkLoop checkExp (mergepat, loopinit, form, loopbody) loc pure $ AppExp (Loop sparams mergepat' loopinit' form' loopbody' loc) (Info appres) checkExp (Constr name es NoInfo loc) = do t <- newTypeVar loc "t" es' <- mapM checkExp es ets <- mapM expType es' mustHaveConstr (mkUsage loc "use of constructor") name t ets pure $ Constr name es' (Info t) loc checkExp (AppExp (Match e cs loc) _) = do e' <- checkExp e mt <- expType e' (cs', t, retext) <- checkCases mt cs zeroOrderType (mkUsage loc "being returned 'match'") "type returned from pattern match" t pure $ AppExp (Match e' cs' loc) (Info $ AppRes t retext) checkExp (Attr info e loc) = Attr <$> checkAttr info <*> checkExp e <*> pure loc checkCases :: StructType -> NE.NonEmpty (CaseBase NoInfo VName) -> TermTypeM (NE.NonEmpty (CaseBase Info VName), StructType, [VName]) checkCases mt rest_cs = case NE.uncons rest_cs of (c, Nothing) -> do (c', t, retext) <- checkCase mt c pure (NE.singleton c', t, retext) (c, Just cs) -> do ((c', c_t, _), (cs', cs_t, _)) <- (,) <$> checkCase mt c <*> checkCases mt cs (brancht, retext) <- unifyBranchTypes (srclocOf c) c_t cs_t pure (NE.cons c' cs', brancht, retext) checkCase :: StructType -> CaseBase NoInfo VName -> TermTypeM (CaseBase Info VName, StructType, [VName]) checkCase mt (CasePat p e loc) = bindingPat [] p mt $ \p' -> do e' <- checkExp e e_t <- expTypeFully e' (e_t', retext) <- unscopeType loc (patNames p') e_t pure (CasePat (fmap toStruct p') e' loc, e_t', retext) -- | An unmatched pattern. Used in in the generation of -- unmatched pattern warnings by the type checker. 
data Unmatched p
  = UnmatchedNum p [PatLit]
  | UnmatchedBool p
  | UnmatchedConstr p
  | Unmatched p
  deriving (Functor, Show)

-- Render an unmatched pattern for use in diagnostics.  'UnmatchedNum'
-- additionally lists the literals that the pattern is known to differ
-- from; all other constructors just print the pattern itself.
instance Pretty (Unmatched (Pat StructType)) where
  pretty um = case um of
    (UnmatchedNum p nums) -> pretty' p <+> "where p is not one of" <+> pretty nums
    (UnmatchedBool p) -> pretty' p
    (UnmatchedConstr p) -> pretty' p
    (Unmatched p) -> pretty' p
    where
      pretty' (PatAscription p t _) = pretty p <> ":" <+> pretty t
      pretty' (PatParens p _) = parens $ pretty' p
      pretty' (PatAttr _ p _) = parens $ pretty' p
      pretty' (Id v _ _) = prettyName v
      pretty' (TuplePat pats _) = parens $ commasep $ map pretty' pats
      pretty' (RecordPat fs _) = braces $ commasep $ map ppField fs
        where
          ppField (L _ name, t) = pretty (nameToString name) <> equals <> pretty' t
      pretty' Wildcard {} = "_"
      pretty' (PatLit e _ _) = pretty e
      pretty' (PatConstr n _ ps _) = "#" <> pretty n <+> sep (map pretty' ps)

-- | Type-check an identifier by looking up the type of the variable it
-- names and annotating the identifier with it.
checkIdent :: IdentBase NoInfo VName StructType -> TermTypeM (Ident StructType)
checkIdent (Ident name _ loc) = do
  vt <- lookupVar loc $ qualName name
  pure $ Ident name (Info vt) loc

-- | Type-check every component of a slice.  A fixed index may be of any
-- signed integer type, while the start, end and stride of a slice range
-- (when present) must be of type @i64@.
checkSlice :: SliceBase NoInfo VName -> TermTypeM [DimIndex]
checkSlice = mapM checkDimIndex
  where
    checkDimIndex (DimFix i) =
      DimFix <$> (require "use as index" anySignedType =<< checkExp i)
    checkDimIndex (DimSlice i j s) =
      DimSlice <$> check i <*> check j <*> check s

    check =
      maybe (pure Nothing) $
        fmap Just . unifies "use as index" (Scalar $ Prim $ Signed Int64) <=< checkExp

-- The number of dimensions affected by this slice (so the minimum
-- rank of the array we are slicing).
sliceDims :: [DimIndex] -> Int
sliceDims = length

-- | Replace every existential size of a return type with a fresh rigid
-- size (tagged with 'RigidRet' and the function name, if known).
-- Returns the instantiated type together with the names of the fresh
-- sizes, in the same order as the original existentials.
instantiateDimsInReturnType ::
  SrcLoc ->
  Maybe (QualName VName) ->
  ResRetType ->
  TermTypeM (ResType, [VName])
instantiateDimsInReturnType loc fname (RetType dims t)
  | null dims =
      pure (t, mempty)
  | otherwise = do
      dims' <- mapM new dims
      pure (first (onDim $ zip dims $ map (ExpSubst . (`sizeFromName` loc) . qualName) dims') t, dims')
  where
    -- The fresh name is derived from the ASCII prefix of the old name.
    new =
      newRigidDim loc (RigidRet fname)
        . nameFromString
        . takeWhile isAscii
        . baseString
    onDim dims' = applySubst (`lookup` dims')

-- Some information about the function/operator we are trying to
-- apply, and how many arguments it has previously accepted. Used for
-- generating nicer type errors.
type ApplyOp = (Maybe (QualName VName), Int)

-- | Extract all those names that are bound inside the type.
boundInsideType :: TypeBase Size as -> S.Set VName
boundInsideType (Array _ _ t) = boundInsideType (Scalar t)
boundInsideType (Scalar Prim {}) = mempty
boundInsideType (Scalar (TypeVar _ _ targs)) = foldMap f targs
  where
    f (TypeArgType t) = boundInsideType t
    f TypeArgDim {} = mempty
boundInsideType (Scalar (Record fs)) = foldMap boundInsideType fs
boundInsideType (Scalar (Sum cs)) = foldMap (foldMap boundInsideType) cs
boundInsideType (Scalar (Arrow _ pn _ t1 (RetType dims t2))) =
  pn' <> boundInsideType t1 <> S.fromList dims <> boundInsideType t2
  where
    pn' = case pn of
      Unnamed -> mempty
      Named v -> S.singleton v

-- | Returns a pair of the size names used in the immediate type
-- produced (first component) and the size names used in parameter
-- types (second component).  Sizes occurring in return types of
-- function types are ignored, as are names bound within the type
-- itself.
dimUses :: TypeBase Size u -> (Names, Names)
dimUses = flip execState mempty . traverseDims f
  where
    f bound pos e =
      case pos of
        PosImmediate ->
          modify ((fvVars fv, mempty) <>)
        PosParam ->
          modify ((mempty, fvVars fv) <>)
        PosReturn -> pure ()
      where
        fv = freeInExp e `freeWithout` bound

-- | Type-check the application of a function of the given type to a
-- single (already checked) argument.  Returns, in order:
--
-- * the parameter type of the function;
--
-- * the type of the result of the application;
--
-- * a fresh size name standing in for the argument, if the parameter
--   is named and used in the result type, but the argument contains a
--   binding (and so cannot be substituted into the type);
--
-- * the names of the fresh sizes that instantiated the existential
--   sizes of the return type.
--
-- If the function type is still a type variable, it is first unified
-- with an arrow type whose result is a fresh type variable.  Any other
-- non-function type is a type error.
checkApply ::
  SrcLoc ->
  ApplyOp ->
  StructType ->
  Exp ->
  TermTypeM (StructType, StructType, Maybe VName, [VName])
checkApply loc (fname, _) (Scalar (Arrow _ pname _ tp1 tp2)) argexp = do
  let argtype = typeOf argexp
  onFailure (CheckingApply fname argexp tp1 argtype) $ do
    unify (mkUsage argexp "use as function argument") tp1 argtype

    -- Perform substitutions of instantiated variables in the types.
    (tp2', ext) <- instantiateDimsInReturnType loc fname =<< normTypeFully tp2
    argtype' <- normTypeFully argtype

    -- Check whether this would produce an impossible return type: a
    -- size that is existential in the return type (or bound inside
    -- the argument type) must not be used in a parameter position of
    -- the result unless it is also produced in immediate position.
    let (tp2_produced_dims, tp2_paramdims) = dimUses tp2'
        problematic = S.fromList ext <> boundInsideType argtype'
        problem = any (`S.member` problematic) (tp2_paramdims `S.difference` tp2_produced_dims)
    when (not (S.null problematic) && problem) $
      typeError loc mempty . withIndexLink "existential-param-ret" $
        "Existential size would appear in function parameter of return type:"
          </> indent 2 (pretty (RetType ext tp2'))
          </> textwrap "This is usually because a higher-order function is used with functional arguments that return existential sizes or locally named sizes, which are then used as parameters of other function arguments."

    -- If the parameter is named and used in the result type, the
    -- argument is substituted for it; arguments containing bindings
    -- are replaced by a fresh rigid size instead.
    (argext, tp2'') <-
      case pname of
        Named pname'
          | S.member pname' (fvVars $ freeInType tp2') ->
              if hasBinding argexp
                then do
                  warn (srclocOf argexp) $
                    withIndexLink
                      "size-expression-bind"
                      "Size expression with binding is replaced by unknown size."
                  d <- newRigidDim argexp (RigidArg fname $ prettyTextOneLine $ bareExp argexp) "n"
                  let parsubst v =
                        if v == pname'
                          then Just $ ExpSubst $ sizeFromName (qualName d) $ srclocOf argexp
                          else Nothing
                  pure (Just d, applySubst parsubst $ toStruct tp2')
                else
                  let parsubst v =
                        if v == pname'
                          then Just $ ExpSubst $ fromMaybe argexp $ stripExp argexp
                          else Nothing
                   in pure (Nothing, applySubst parsubst $ toStruct tp2')
        _ -> pure (Nothing, toStruct tp2')

    pure (tp1, tp2'', argext, ext)
checkApply loc fname tfun@(Scalar TypeVar {}) arg = do
  tv <- newTypeVar loc "b"
  unify (mkUsage loc "use as function") tfun $
    Scalar (Arrow mempty Unnamed Observe (typeOf arg) $ RetType [] $ paramToRes tv)
  tfun' <- normType tfun
  checkApply loc fname tfun' arg
checkApply loc (fname, prev_applied) ftype argexp = do
  let fname' = maybe "expression" (dquotes . pretty) fname

  typeError loc mempty $
    if prev_applied == 0
      then
        "Cannot apply"
          <+> fname'
          <+> "as function, as it has type:"
          </> indent 2 (pretty ftype)
      else
        "Cannot apply"
          <+> fname'
          <+> "to argument #"
          <> pretty (prev_applied + 1)
          <+> dquotes (shorten $ group $ pretty argexp)
          <> ","
          </> "as"
          <+> fname'
          <+> "only takes"
          <+> pretty prev_applied
          <+> arguments
          <> "."
  where
    arguments
      | prev_applied == 1 = "argument"
      | otherwise = "arguments"

-- | Type-check a single expression in isolation. This expression may
-- turn out to be polymorphic, in which case the list of type
-- parameters will be non-empty.
checkOneExp :: ExpBase NoInfo VName -> TypeM ([TypeParam], Exp)
checkOneExp e = runTermTypeM checkExp $ do
  e' <- checkExp e
  let t = typeOf e'
  (tparams, _, _) <-
    letGeneralise (nameFromString "") (srclocOf e) [] [] $ toRes Nonunique t
  fixOverloadedTypes $ typeVars t
  e'' <- normTypeFully e'
  localChecks e''
  causalityCheck e''
  pure (tparams, e'')

-- | Type-check a single size expression in isolation. This expression may
-- turn out to be polymorphic, in which case it is unified with i64.
-- Size expressions containing bindings are rejected.
checkSizeExp :: ExpBase NoInfo VName -> TypeM Exp
checkSizeExp e = runTermTypeM checkExp $ do
  e' <- checkExp e
  let t = typeOf e'
  when (hasBinding e') $
    typeError (srclocOf e') mempty . withIndexLink "size-expression-bind" $
      "Size expression with binding is forbidden."
  unify (mkUsage e' "Size expression") t (Scalar (Prim (Signed Int64)))
  normTypeFully e'

-- Verify that all sum type constructors and empty array literals have
-- a size that is known (rigid or a type parameter). This is to
-- ensure that we can actually determine their shape at run-time.
causalityCheck :: Exp -> TermTypeM ()
causalityCheck binding_body = do
  constraints <- getConstraints

  -- A size in the type 't' is a problem if it is not yet known at this
  -- point of evaluation and its constraint marks it as an unknown size.
  let checkCausality what known t loc
        | (d, dloc) : _ <-
            mapMaybe (unknown constraints known) $
              S.toList (fvVars $ freeInType t) =
            Just $ lift $ causality what (locOf loc) d dloc t
        | otherwise = Nothing

      checkParamCausality known p =
        checkCausality (pretty p) known (patternType p) (locOf p)

      -- Run a traversal starting from an empty state and return the
      -- set of sizes that became known during it.
      collectingNewKnown = lift . flip execStateT mempty

      -- 'known' holds the sizes known before evaluating the
      -- expression; the state accumulates sizes that become known by
      -- evaluating it.
      onExp ::
        S.Set VName ->
        Exp ->
        StateT (S.Set VName) (Either TypeError) Exp

      onExp known (Var v (Info t) loc)
        | Just bad <- checkCausality (dquotes (pretty v)) known t loc =
            bad
      onExp known (ProjectSection _ (Info t) loc)
        | Just bad <- checkCausality "projection section" known t loc =
            bad
      onExp known (IndexSection _ (Info t) loc)
        | Just bad <- checkCausality "index section" known t loc =
            bad
      onExp known (OpSectionRight _ (Info t) _ _ _ loc)
        | Just bad <- checkCausality "operator section" known t loc =
            bad
      onExp known (OpSectionLeft _ (Info t) _ _ _ loc)
        | Just bad <- checkCausality "operator section" known t loc =
            bad
      onExp known (ArrayLit [] (Info t) loc)
        | Just bad <- checkCausality "empty array" known t loc =
            bad
      onExp known (Hole (Info t) loc)
        | Just bad <- checkCausality "hole" known t loc =
            bad
      onExp known e@(Lambda params body _ _ _)
        | bad : _ <- mapMaybe (checkParamCausality known) params =
            bad
        | otherwise = do
            -- Existentials coming into existence in the lambda body
            -- are not known outside of it.
            void $ collectingNewKnown $ onExp known body
            pure e
      onExp known e@(AppExp (LetPat _ _ bindee_e body_e _) (Info res)) = do
        sequencePoint known bindee_e body_e $ appResExt res
        pure e
      onExp known e@(AppExp (Match scrutinee cs _) (Info res)) = do
        new_known <- collectingNewKnown $ onExp known scrutinee
        void $ recurse (new_known <> known) cs
        modify ((new_known <> S.fromList (appResExt res)) <>)
        pure e
      onExp known e@(AppExp (Apply f args _) (Info res)) = do
        -- Arguments are visited last-to-first; the function itself is
        -- visited after all arguments.
        seqArgs known $ reverse $ NE.toList args
        pure e
        where
          seqArgs known' [] = do
            void $ onExp known' f
            modify (S.fromList (appResExt res) <>)
          seqArgs known' ((Info p, x) : xs) = do
            new_known <- collectingNewKnown $ onExp known' x
            void $ seqArgs (new_known <> known') xs
            modify ((new_known <> S.fromList (maybeToList p)) <>)
      onExp known e@(Constr v args (Info t) loc) = do
        seqArgs known args
        pure e
        where
          seqArgs known' []
            | Just bad <- checkCausality (dquotes ("#" <> pretty v)) known' t loc =
                bad
            | otherwise = pure ()
          seqArgs known' (x : xs) = do
            new_known <- collectingNewKnown $ onExp known' x
            void $ seqArgs (new_known <> known') xs
            modify (new_known <>)
      onExp
        known
        e@(AppExp (BinOp (f, floc) ft (x, Info xp) (y, Info yp) _) (Info res)) = do
          args_known <-
            collectingNewKnown $ sequencePoint known x y $ catMaybes [xp, yp]
          void $ onExp (args_known <> known) (Var f ft floc)
          modify ((args_known <> S.fromList (appResExt res)) <>)
          pure e
      onExp known e@(AppExp e' (Info res)) = do
        recurse known e'
        modify (<> S.fromList (appResExt res))
        pure e
      onExp known e = do
        recurse known e
        pure e

      recurse known = void . astMap mapper
        where
          mapper = identityMapper {mapOnExp = onExp known}

      -- Visit 'x' before 'y': sizes that become known in 'x' are
      -- visible when checking 'y'.
      sequencePoint known x y ext = do
        new_known <- collectingNewKnown $ onExp known x
        void $ onExp (new_known <> known) y
        modify ((new_known <> S.fromList ext) <>)

  either throwError (const $ pure ()) $
    evalStateT (onExp mempty binding_body) mempty
  where
    unknown constraints known v = do
      guard $ v `S.notMember` known
      loc <- case snd <$> M.lookup v constraints of
        Just (UnknownSize loc _) -> Just loc
        _ -> Nothing
      pure (v, loc)

    causality what loc d dloc t =
      Left . TypeError loc mempty . withIndexLink "causality-check" $
        "Causality check: size"
          <+> dquotes (prettyName d)
          <+> "needed for type of"
          <+> what
          <> colon
          </> indent 2 (pretty t)
          </> "But"
          <+> dquotes (prettyName d)
          <+> "is computed at"
          <+> pretty (locStrRel loc dloc)
          <> "."
          </> ""
          </> "Hint:"
          <+> align
            ( textwrap "Bind the expression producing"
                <+> dquotes (prettyName d)
                <+> "with 'let' beforehand."
            )

-- | Fail with a type error listing the unmatched cases if the pattern
-- does not cover every value of its type.
mustBeIrrefutable :: (MonadTypeChecker f) => Pat StructType -> f ()
mustBeIrrefutable p = do
  case unmatched [p] of
    [] -> pure ()
    ps' ->
      typeError p mempty . withIndexLink "refutable-pattern" $
        "Refutable pattern not allowed here.\nUnmatched cases:"
          </> indent 2 (stack (map pretty ps'))

-- | Traverse the expression, emitting warnings and errors for various
-- problems:
--
-- * Unmatched cases.
--
-- * Refutable patterns in let-bindings, lambda and function
--   parameters, and loop parameters.
--
-- * If any of the literals overflow their inferred types. Note:
--   currently unable to detect float underflow (such as 1e-400 -> 0)
--
-- * Use of the intrinsic @==@ on arrays (a deprecation warning).
localChecks :: Exp -> TermTypeM ()
localChecks = void . check
  where
    check e@(AppExp (Match _ cs loc) _) = do
      let ps = fmap (\(CasePat p _ _) -> p) cs
      case unmatched $ NE.toList ps of
        [] -> recurse e
        ps' ->
          typeError loc mempty . withIndexLink "unmatched-cases" $
            "Unmatched cases in match expression:"
              </> indent 2 (stack (map pretty ps'))
    check e@(AppExp (LetPat _ p _ _ _) _) =
      mustBeIrrefutable p *> recurse e
    check e@(Lambda ps _ _ _ _) =
      mapM_ (mustBeIrrefutable . fmap toStruct) ps *> recurse e
    check e@(AppExp (LetFun _ (_, ps, _, _, _) _ _) _) =
      mapM_ (mustBeIrrefutable . fmap toStruct) ps *> recurse e
    check e@(AppExp (Loop _ p _ form _ _) _) = do
      mustBeIrrefutable (fmap toStruct p)
      case form of
        ForIn form_p _ -> mustBeIrrefutable form_p
        _ -> pure ()
      recurse e
    check e@(IntLit x ty loc) =
      e <$ case ty of
        Info (Scalar (Prim t)) -> errorBounds (inBoundsI x t) x t loc
        _ -> error "Inferred type of int literal is not a number"
    check e@(FloatLit x ty loc) =
      e <$ case ty of
        Info (Scalar (Prim (FloatType t))) -> errorBounds (inBoundsF x t) x t loc
        _ -> error "Inferred type of float literal is not a float"
    check e@(Negate (IntLit x ty loc1) loc2) =
      e <$ case ty of
        Info (Scalar (Prim t)) -> errorBounds (inBoundsI (-x) t) (-x) t (loc1 <> loc2)
        _ -> error "Inferred type of int literal is not a number"
    -- Only the intrinsic "==" is warned about, not user-defined ones.
    check e@(AppExp (BinOp (QualName [] v, _) _ (x, _) _ loc) _)
      | baseName v == "==",
        Array {} <- typeOf x,
        baseTag v <= maxIntrinsicTag = do
          warn loc $
            textwrap
              "Comparing arrays with \"==\" is deprecated and will stop working in a future revision of the language."
          recurse e
    check e = recurse e
    recurse = astMap identityMapper {mapOnExp = check}

    bitWidth ty = 8 * intByteSize ty :: Int

    inBoundsI x (Signed t) = x >= -2 ^ (bitWidth t - 1) && x < 2 ^ (bitWidth t - 1)
    inBoundsI x (Unsigned t) = x >= 0 && x < 2 ^ bitWidth t
    inBoundsI x (FloatType Float16) = not $ isInfinite (fromIntegral x :: Half)
    inBoundsI x (FloatType Float32) = not $ isInfinite (fromIntegral x :: Float)
    inBoundsI x (FloatType Float64) = not $ isInfinite (fromIntegral x :: Double)
    inBoundsI _ Bool = error "Inferred type of int literal is not a number"
    -- Convert to the literal's actual precision ('Half' for f16, as in
    -- 'inBoundsI'), so values beyond that range are reported.
    inBoundsF x Float16 = not $ isInfinite (realToFrac x :: Half)
    inBoundsF x Float32 = not $ isInfinite (realToFrac x :: Float)
    inBoundsF x Float64 = not $ isInfinite x

    errorBounds inBounds x ty loc =
      unless inBounds $
        typeError loc mempty . withIndexLink "literal-out-of-bounds" $
          "Literal "
            <> pretty x
            <> " out of bounds for inferred type "
            <> pretty ty
            <> "."

-- | Type-check a top-level (or module-level) function definition.
-- Despite the name, this is also used for checking constant
-- definitions, by treating them as 0-ary functions.
checkFunDef ::
  ( VName,
    Maybe (TypeExp (ExpBase NoInfo VName) VName),
    [TypeParam],
    [PatBase NoInfo VName ParamType],
    ExpBase NoInfo VName,
    SrcLoc
  ) ->
  TypeM
    ( [TypeParam],
      [Pat ParamType],
      Maybe (TypeExp Exp VName),
      ResRetType,
      Exp
    )
checkFunDef (fname, maybe_retdecl, tparams, params, body, loc) =
  runTermTypeM checkExp $ do
    (tparams', params', maybe_retdecl', RetType dims rettype', body') <-
      checkBinding (fname, maybe_retdecl, tparams, params, body, loc)

    -- Since this is a top-level function, we also resolve overloaded
    -- types, using either defaults or complaining about ambiguities.
    fixOverloadedTypes $
      typeVars rettype' <> foldMap (typeVars . patternType) params'

    -- Then replace all inferred types in the body and parameters.
    body'' <- normTypeFully body'
    params'' <- mapM normTypeFully params'
    maybe_retdecl'' <- traverse updateTypes maybe_retdecl'
    rettype'' <- normTypeFully rettype'

    -- Check if the function body can actually be evaluated.
    causalityCheck body''

    -- Check for various problems.
    mapM_ (mustBeIrrefutable . fmap toStruct) params'
    localChecks body''

    let ((body''', updated_ret), errors) =
          Consumption.checkValDef
            ( fname,
              params'',
              body'',
              RetType dims rettype'',
              maybe_retdecl'',
              loc
            )

    mapM_ throwError errors

    pure (tparams', params'', maybe_retdecl'', updated_ret, body''')

-- | This is "fixing" as in "setting them", not "correcting them". We
-- only make very conservative fixing: integer-overloaded type variables
-- that admit i32 default to i32, float-overloaded ones that admit f64
-- default to f64, and unconstrained ones default to the unit type.
-- Other remaining constraints are reported as ambiguities.  Defaulting
-- warnings are only emitted for the type variables in the given set.
fixOverloadedTypes :: Names -> TermTypeM ()
fixOverloadedTypes tyvars_at_toplevel =
  getConstraints >>= mapM_ fixOverloaded . M.toList . M.map snd
  where
    fixOverloaded (v, Overloaded ots usage)
      | Signed Int32 `elem` ots = do
          unify usage (Scalar (TypeVar mempty (qualName v) [])) $
            Scalar (Prim $ Signed Int32)
          when (v `S.member` tyvars_at_toplevel) $
            warn usage "Defaulting ambiguous type to i32."
      | FloatType Float64 `elem` ots = do
          unify usage (Scalar (TypeVar mempty (qualName v) [])) $
            Scalar (Prim $ FloatType Float64)
          when (v `S.member` tyvars_at_toplevel) $
            warn usage "Defaulting ambiguous type to f64."
      | otherwise =
          typeError usage mempty . withIndexLink "ambiguous-type" $
            "Type is ambiguous (could be one of"
              <+> commasep (map pretty ots)
              <> ")."
              </> "Add a type annotation to disambiguate the type."
    fixOverloaded (v, NoConstraint _ usage) = do
      -- See #1552.
      unify usage (Scalar (TypeVar mempty (qualName v) [])) $
        Scalar (tupleRecord [])
      when (v `S.member` tyvars_at_toplevel) $
        warn usage "Defaulting ambiguous type to ()."
    fixOverloaded (_, Equality usage) =
      typeError usage mempty . withIndexLink "ambiguous-type" $
        "Type is ambiguous (must be equality type)."
          </> "Add a type annotation to disambiguate the type."
    fixOverloaded (_, HasFields _ fs usage) =
      typeError usage mempty . withIndexLink "ambiguous-type" $
        "Type is ambiguous. Must be record with fields:"
          </> indent 2 (stack $ map field $ M.toList fs)
          </> "Add a type annotation to disambiguate the type."
      where
        field (l, t) = pretty l <> colon <+> align (pretty t)
    fixOverloaded (_, HasConstrs _ cs usage) =
      typeError usage mempty . withIndexLink "ambiguous-type" $
        "Type is ambiguous (must be a sum type with constructors:"
          <+> pretty (Sum cs)
          <> ")."
          </> "Add a type annotation to disambiguate the type."
    fixOverloaded (v, Size Nothing (Usage Nothing loc)) =
      typeError loc mempty . withIndexLink "ambiguous-size" $
        "Ambiguous size" <+> dquotes (prettyName v) <> "."
    fixOverloaded (v, Size Nothing (Usage (Just u) loc)) =
      typeError loc mempty . withIndexLink "ambiguous-size" $
        "Ambiguous size" <+> dquotes (prettyName v) <+> "arising from" <+> pretty u <> "."
    fixOverloaded _ = pure ()

-- | The names bound by the parameter patterns that are not themselves
-- the name of a named parameter (as determined by 'patternParam') --
-- presumably names nested inside tuple or record patterns.  The result
-- preserves the order of the pattern names.
hiddenParamNames :: [Pat ParamType] -> [VName]
hiddenParamNames params = hidden
  where
    param_all_names = mconcat $ map patNames params
    named (Named x, _, _) = Just x
    named (Unnamed, _, _) = Nothing
    param_names = S.fromList $ mapMaybe (named . patternParam) params
    hidden = filter (`S.notMember` param_names) param_all_names

inferredReturnType :: SrcLoc -> [Pat ParamType] -> StructType -> TermTypeM StructType
inferredReturnType loc params t = do
  -- The inferred type may refer to names that are bound by the
  -- parameter patterns, but which will not be visible in the type.
  -- These we must turn into fresh type variables, which will be
  -- existential in the return type.
  fst <$> unscopeType loc hidden_params t
  where
    -- 'hiddenParamNames' already yields exactly the hidden pattern
    -- names in pattern order, so no further filtering is needed.
    hidden_params = hiddenParamNames params

-- | Type-check a function (or value) binding: bind the type and value
-- parameters, check the optional return type annotation and the body,
-- infer the return type if no annotation is given, verify the
-- parameters, and let-generalise the result.  A binding without
-- parameters whose type parameters include sizes may not have
-- existential sizes in its type.
checkBinding ::
  ( VName,
    Maybe (TypeExp (ExpBase NoInfo VName) VName),
    [TypeParam],
    [PatBase NoInfo VName ParamType],
    ExpBase NoInfo VName,
    SrcLoc
  ) ->
  TermTypeM
    ( [TypeParam],
      [Pat ParamType],
      Maybe (TypeExp Exp VName),
      ResRetType,
      Exp
    )
checkBinding (fname, maybe_retdecl, tparams, params, body, loc) =
  incLevel . bindingParams tparams params $ \params' -> do
    maybe_retdecl' <- traverse checkTypeExpNonrigid maybe_retdecl

    body' <-
      checkFunBody
        params'
        body
        ((\(_, x, _) -> x) <$> maybe_retdecl')
        (maybe loc srclocOf maybe_retdecl)

    params'' <- mapM updateTypes params'
    body_t <- expTypeFully body'

    (maybe_retdecl'', rettype) <- case maybe_retdecl' of
      Just (retdecl', ret, _) -> do
        ret' <- normTypeFully ret
        pure (Just retdecl', ret')
      Nothing
        | null params ->
            pure (Nothing, toRes Nonunique body_t)
        | otherwise -> do
            body_t' <- inferredReturnType loc params'' body_t
            pure (Nothing, toRes Nonunique body_t')

    verifyFunctionParams (Just fname) params''

    (tparams', params''', rettype') <-
      letGeneralise (baseName fname) loc tparams params''
        =<< unscopeUnknown rettype

    when (null params && any isSizeParam tparams' && not (null (retDims rettype'))) $
      typeError loc mempty $
        textwrap "A size-polymorphic value binding may not have a type with an existential size."
          </> "Type of this binding is:"
          </> indent 2 (pretty rettype')
          </> "with the following type parameters:"
          </> indent 2 (sep $ map pretty $ filter isSizeParam tparams')

    pure (tparams', params''', maybe_retdecl'', rettype', body')

-- | Extract all the shape names that occur in positive position
-- (roughly, left side of an arrow) in a given type.  Concretely: the
-- sizes occurring in each parameter type along the chain of arrows,
-- ignoring parameters that are themselves functions; for type variable
-- arguments, only sizes given directly as a variable are included.
sizeNamesPos :: TypeBase Size als -> S.Set VName
sizeNamesPos (Scalar (Arrow _ _ _ t1 (RetType _ t2))) = onParam t1 <> sizeNamesPos t2
  where
    onParam :: TypeBase Size als -> S.Set VName
    onParam (Scalar Arrow {}) = mempty
    onParam (Scalar (Record fs)) = mconcat $ map onParam $ M.elems fs
    onParam (Scalar (TypeVar _ _ targs)) = mconcat $ map onTypeArg targs
    onParam t = fvVars $ freeInType t
    onTypeArg (TypeArgDim (Var d _ _)) = S.singleton $ qualLeaf d
    onTypeArg (TypeArgDim _) = mempty
    onTypeArg (TypeArgType t) = onParam t
sizeNamesPos _ = mempty

-- | Verify certain restrictions on function parameters, and bail out
-- on dubious constructions.
--
-- These restrictions apply to all functions (anonymous or otherwise).
-- Top-level functions have further restrictions that are checked -- during let-generalisation. verifyFunctionParams :: Maybe VName -> [Pat ParamType] -> TermTypeM () verifyFunctionParams fname params = onFailure (CheckingParams (baseName <$> fname)) $ verifyParams (foldMap patNames params) =<< mapM updateTypes params where verifyParams forbidden (p : ps) | d : _ <- filter (`elem` forbidden) $ S.toList $ fvVars $ freeInPat p = typeError p mempty . withIndexLink "inaccessible-size" $ "Parameter" <+> dquotes (pretty p) "refers to size" <+> dquotes (prettyName d) <> comma textwrap "which will not be accessible to the caller" <> comma textwrap "possibly because it is nested in a tuple or record." textwrap "Consider ascribing an explicit type that does not reference " <> dquotes (prettyName d) <> "." | otherwise = verifyParams forbidden' ps where forbidden' = case patternParam p of (Named v, _, _) -> delete v forbidden _ -> forbidden verifyParams _ [] = pure () -- | Move existentials down to the level where they are actually used -- (i.e. have their "witnesses"). E.g. 
changes -- -- @ -- ?[n].bool -> [n]bool -- @ -- -- to -- -- @ -- bool -> ?[n].[n]bool -- @ injectExt :: [VName] -> TypeBase Size u -> RetTypeBase Size u injectExt [] ret = RetType [] ret injectExt ext ret = RetType ext_here $ deeper ret where (immediate, _) = dimUses ret (ext_here, ext_there) = partition (`S.member` immediate) ext deeper :: TypeBase Size u -> TypeBase Size u deeper (Scalar (Prim t)) = Scalar $ Prim t deeper (Scalar (Record fs)) = Scalar $ Record $ M.map deeper fs deeper (Scalar (Sum cs)) = Scalar $ Sum $ M.map (map deeper) cs deeper (Scalar (Arrow als p d1 t1 (RetType t2_ext t2))) = Scalar $ Arrow als p d1 t1 $ injectExt (nubOrd (ext_there <> t2_ext)) t2 deeper (Scalar (TypeVar u tn targs)) = Scalar $ TypeVar u tn $ map deeperArg targs deeper t@Array {} = t deeperArg (TypeArgType t) = TypeArgType $ deeper t deeperArg (TypeArgDim d) = TypeArgDim d -- | Find all type variables in the given type that are covered by the -- constraints, and produce type parameters that close over them. -- -- The passed-in list of type parameters is always prepended to the -- produced list of type parameters. closeOverTypes :: Name -> SrcLoc -> [TypeParam] -> [StructType] -> ResType -> Constraints -> TermTypeM ([TypeParam], ResRetType) closeOverTypes defname defloc tparams paramts ret substs = do (more_tparams, retext) <- partitionEithers . catMaybes <$> mapM closeOver (M.toList $ M.map snd to_close_over) let mkExt v = case M.lookup v substs of Just (_, UnknownSize {}) -> Just v _ -> Nothing pure ( tparams ++ more_tparams, injectExt (nubOrd $ retext ++ mapMaybe mkExt (S.toList $ fvVars $ freeInType ret)) ret ) where -- Diet does not matter here. t = foldFunType (map (toParam Observe) paramts) $ RetType [] ret to_close_over = M.filterWithKey (\k _ -> k `S.member` visible) substs visible = typeVars t <> fvVars (freeInType t) (produced_sizes, param_sizes) = dimUses t -- Avoid duplicate type parameters. 
closeOver (k, _) | k `elem` map typeParamName tparams = pure Nothing closeOver (k, NoConstraint l usage) = pure $ Just $ Left $ TypeParamType l k $ srclocOf usage closeOver (k, ParamType l loc) = pure $ Just $ Left $ TypeParamType l k $ srclocOf loc closeOver (k, Size Nothing usage) = pure $ Just $ Left $ TypeParamDim k $ srclocOf usage closeOver (k, UnknownSize _ _) | k `S.member` param_sizes, k `S.notMember` produced_sizes = do notes <- dimNotes defloc $ sizeFromName (qualName k) mempty typeError defloc notes . withIndexLink "unknown-param-def" $ "Unknown size" <+> dquotes (prettyName k) <+> "in parameter of" <+> dquotes (prettyName defname) <> ", which is inferred as:" indent 2 (pretty t) | k `S.member` produced_sizes = pure $ Just $ Right k closeOver (_, _) = pure Nothing letGeneralise :: Name -> SrcLoc -> [TypeParam] -> [Pat ParamType] -> ResType -> TermTypeM ([TypeParam], [Pat ParamType], ResRetType) letGeneralise defname defloc tparams params restype = onFailure (CheckingLetGeneralise defname) $ do now_substs <- getConstraints -- Candidates for let-generalisation are those type variables that -- -- (1) were not known before we checked this function, and -- -- (2) are not used in the (new) definition of any type variables -- known before we checked this function. -- -- (3) are not referenced from an overloaded type (for example, -- are the element types of an incompletely resolved record type). -- This is a bit more restrictive than I'd like, and SML for -- example does not have this restriction. -- -- Criteria (1) and (2) is implemented by looking at the binding -- level of the type variables. 
let keep_type_vars = overloadedTypeVars now_substs cur_lvl <- curLevel let candidate k (lvl, _) = (k `S.notMember` keep_type_vars) && lvl >= (cur_lvl - length params) new_substs = M.filterWithKey candidate now_substs (tparams', RetType ret_dims restype') <- closeOverTypes defname defloc tparams (map patternStructType params) restype new_substs restype'' <- updateTypes restype' let used_sizes = freeInType restype'' <> foldMap (freeInType . patternType) params case filter ((`S.notMember` fvVars used_sizes) . typeParamName) $ filter isSizeParam tparams' of [] -> pure () tp : _ -> unusedSize $ SizeBinder (typeParamName tp) (srclocOf tp) -- We keep those type variables that were not closed over by -- let-generalisation. modifyConstraints $ M.filterWithKey $ \k _ -> k `notElem` map typeParamName tparams' pure (tparams', params, RetType ret_dims restype'') checkFunBody :: [Pat ParamType] -> ExpBase NoInfo VName -> Maybe ResType -> SrcLoc -> TermTypeM Exp checkFunBody params body maybe_rettype loc = do body' <- checkExp body -- Unify body return type with return annotation, if one exists. case maybe_rettype of Just rettype -> do body_t <- expTypeFully body' -- We need to turn any sizes provided by "hidden" parameter -- names into existential sizes instead. 
let hidden = hiddenParamNames params (body_t', _) <- unscopeType loc (filter (`elem` hidden) $ foldMap patNames params) body_t let usage = mkUsage body "return type annotation" onFailure (CheckingReturn rettype body_t') $ unify usage (toStruct rettype) body_t' Nothing -> pure () pure body' arrayOfM :: SrcLoc -> StructType -> Shape Size -> TermTypeM StructType arrayOfM loc t shape = do arrayElemType (mkUsage loc "use as array element") "type used in array" t pure $ arrayOf shape t futhark-0.25.27/src/Language/Futhark/TypeChecker/Terms/000077500000000000000000000000001475065116200225605ustar00rootroot00000000000000futhark-0.25.27/src/Language/Futhark/TypeChecker/Terms/Loop.hs000066400000000000000000000264471475065116200240420ustar00rootroot00000000000000-- | Type inference of @loop@. This is complicated because of the -- uniqueness and size inference, so the implementation is separate -- from the main type checker. module Language.Futhark.TypeChecker.Terms.Loop ( UncheckedLoop, CheckedLoop, checkLoop, ) where import Control.Monad import Control.Monad.Reader import Control.Monad.State import Data.Bifunctor import Data.Bitraversable import Data.List qualified as L import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S import Futhark.Util (nubOrd) import Futhark.Util.Pretty hiding (group, space) import Language.Futhark import Language.Futhark.TypeChecker.Monad hiding (BoundV) import Language.Futhark.TypeChecker.Terms.Monad import Language.Futhark.TypeChecker.Terms.Pat import Language.Futhark.TypeChecker.Types import Language.Futhark.TypeChecker.Unify import Prelude hiding (mod) -- | Retrieve an oracle that can be used to decide whether two are in -- the same equivalence class (i.e. have been unified). This is an -- exotic operation. 
getAreSame :: (MonadUnify m) => m (VName -> VName -> Bool) getAreSame = check <$> getConstraints where check constraints x y = case (M.lookup x constraints, M.lookup y constraints) of (Just (_, Size (Just (Var x' _ _)) _), _) -> check constraints (qualLeaf x') y (_, Just (_, Size (Just (Var y' _ _)) _)) -> check constraints x (qualLeaf y') _ -> x == y -- | Replace specified sizes with distinct fresh size variables. someDimsFreshInType :: SrcLoc -> Name -> [VName] -> TypeBase Size als -> TermTypeM (TypeBase Size als) someDimsFreshInType loc desc fresh t = do areSameSize <- getAreSame let freshen v = any (areSameSize v) fresh bitraverse (onDim freshen) pure t where onDim freshen (Var d _ _) | freshen $ qualLeaf d = do v <- newFlexibleDim (mkUsage' loc) desc pure $ sizeFromName (qualName v) loc onDim _ d = pure d -- | Replace the specified sizes with fresh size variables of the -- specified ridigity. Returns the new fresh size variables. freshDimsInType :: Usage -> Rigidity -> Name -> [VName] -> TypeBase Size u -> TermTypeM (TypeBase Size u, [VName]) freshDimsInType usage r desc fresh t = do areSameSize <- getAreSame second (map snd) <$> runStateT (bitraverse (onDim areSameSize) pure t) mempty where onDim areSameSize (Var (QualName _ d) _ _) | any (areSameSize d) fresh = do prev_subst <- gets $ L.find (areSameSize d . 
fst) case prev_subst of Just (_, d') -> pure $ sizeFromName (qualName d') $ srclocOf usage Nothing -> do v <- lift $ newDimVar usage r desc modify ((d, v) :) pure $ sizeFromName (qualName v) $ srclocOf usage onDim _ d = pure d data ArgSource = Initial | BodyResult wellTypedLoopArg :: ArgSource -> [VName] -> Pat ParamType -> Exp -> TermTypeM () wellTypedLoopArg src sparams pat arg = do (merge_t, _) <- freshDimsInType (mkUsage arg desc) Nonrigid "loop" sparams $ toStruct (patternType pat) arg_t <- toStruct <$> expTypeFully arg onFailure (checking merge_t arg_t) $ unify (mkUsage arg desc) merge_t arg_t where (checking, desc) = case src of Initial -> (CheckingLoopInitial, "matching initial loop values to pattern") BodyResult -> (CheckingLoopBody, "matching loop body to pattern") -- | An un-checked loop. type UncheckedLoop = (PatBase NoInfo VName ParamType, LoopInitBase NoInfo VName, LoopFormBase NoInfo VName, ExpBase NoInfo VName) -- | A loop that has been type-checked. type CheckedLoop = ([VName], Pat ParamType, LoopInitBase Info VName, LoopFormBase Info VName, Exp) checkForImpossible :: Loc -> S.Set VName -> ParamType -> TermTypeM () checkForImpossible loc known_before pat_t = do cs <- getConstraints let bad v = do guard $ v `S.notMember` known_before (_, UnknownSize v_loc _) <- M.lookup v cs Just . typeError (srclocOf loc) mempty $ "Inferred type for loop parameter is" indent 2 (pretty pat_t) "but" <+> dquotes (prettyName v) <+> "is an existential size created inside the loop body at" <+> pretty (locStrRel loc v_loc) <> "." case mapMaybe bad $ S.toList $ fvVars $ freeInType pat_t of problem : _ -> problem [] -> pure () -- | Type-check a @loop@ expression, passing in a function for -- type-checking subexpressions. 
checkLoop :: (ExpBase NoInfo VName -> TermTypeM Exp) -> UncheckedLoop -> SrcLoc -> TermTypeM (CheckedLoop, AppRes) checkLoop checkExp (mergepat, loopinit, form, loopbody) loc = do loopinit' <- checkExp $ case loopinit of LoopInitExplicit e -> e LoopInitImplicit _ -> -- Should have been filled out in Names error "Unspected LoopInitImplicit" known_before <- M.keysSet <$> getConstraints zeroOrderType (mkUsage loopinit' "use as loop variable") "type used as loop variable" . toStruct =<< expTypeFully loopinit' -- The handling of dimension sizes is a bit intricate, but very -- similar to checking a function, followed by checking a call to -- it. The overall procedure is as follows: -- -- (1) All empty dimensions in the merge pattern are instantiated -- with nonrigid size variables. All explicitly specified -- dimensions are preserved. -- -- (2) The body of the loop is type-checked. The result type is -- combined with the merge pattern type to determine which sizes are -- variant, and these are turned into size parameters for the merge -- pattern. -- -- (3) We now conceptually have a function parameter type and -- return type. We check that it can be called with the body type -- as argument. -- -- (4) Similarly to (3), we check that the "function" can be -- called with the initial merge values as argument. The result -- of this is the type of the loop as a whole. 
(merge_t, new_dims_map) <- -- dim handling (1) allDimsFreshInType (mkUsage loc "loop parameter type inference") Nonrigid "loop_d" =<< expTypeFully loopinit' let new_dims_to_initial_dim = M.toList new_dims_map new_dims = map fst new_dims_to_initial_dim -- dim handling (2) let checkLoopReturnSize mergepat' loopbody' = do loopbody_t <- expTypeFully loopbody' mergepat_t <- normTypeFully (patternType mergepat') let ok_names = known_before <> S.fromList new_dims checkForImpossible (locOf mergepat) ok_names mergepat_t pat_t <- someDimsFreshInType loc "loop" new_dims mergepat_t -- We are ignoring the dimensions here, because any mismatches -- should be turned into fresh size variables. onFailure (CheckingLoopBody (toStruct pat_t) (toStruct loopbody_t)) $ unify (mkUsage loopbody "matching loop body to loop pattern") (toStruct pat_t) (toStruct loopbody_t) -- Figure out which of the 'new_dims' dimensions are variant. -- This works because we know that each dimension from -- new_dims in the pattern is unique and distinct. areSameSize <- getAreSame let onDims _ x y | x == y = pure x onDims _ e d = do forM_ (fvVars $ freeInExp e) $ \v -> do case L.find (areSameSize v . fst) new_dims_to_initial_dim of Just (_, e') -> if e' == d then modify $ first $ M.insert v $ ExpSubst e' else unless (v `S.member` known_before) $ modify (second (v :)) _ -> pure () pure e loopbody_t' <- normTypeFully loopbody_t merge_t' <- normTypeFully merge_t let (init_substs, sparams) = execState (matchDims onDims merge_t' loopbody_t') mempty -- Make sure that any of new_dims that are invariant will be -- replaced with the invariant size in the loop body. Failure -- to do this can cause type annotations to still refer to -- new_dims. 
let dimToInit (v, ExpSubst e) = constrain v $ Size (Just e) (mkUsage loc "size of loop parameter") dimToInit _ = pure () mapM_ dimToInit $ M.toList init_substs mergepat'' <- applySubst (`M.lookup` init_substs) <$> updateTypes mergepat' -- Eliminate those new_dims that turned into sparams so it won't -- look like we have ambiguous sizes lying around. modifyConstraints $ M.filterWithKey $ \k _ -> k `notElem` sparams -- dim handling (3) -- -- The only trick here is that we have to turn any instances -- of loop parameters in the type of loopbody' rigid, -- because we are no longer in a position to change them, -- really. wellTypedLoopArg BodyResult sparams mergepat'' loopbody' pure (nubOrd sparams, mergepat'') (sparams, mergepat', form', loopbody') <- case form of For i uboundexp -> do uboundexp' <- require "being the bound in a 'for' loop" anySignedType =<< checkExp uboundexp bound_t <- expTypeFully uboundexp' bindingIdent i bound_t $ \i' -> bindingPat [] mergepat merge_t $ \mergepat' -> incLevel $ do loopbody' <- checkExp loopbody (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody' pure ( sparams, mergepat'', For i' uboundexp', loopbody' ) ForIn xpat e -> do (arr_t, _) <- newArrayType (mkUsage' (srclocOf e)) "e" 1 e' <- unifies "being iterated in a 'for-in' loop" arr_t =<< checkExp e t <- expTypeFully e' case t of _ | Just t' <- peelArray 1 t -> bindingPat [] xpat t' $ \xpat' -> bindingPat [] mergepat merge_t $ \mergepat' -> incLevel $ do loopbody' <- checkExp loopbody (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody' pure ( sparams, mergepat'', ForIn (fmap toStruct xpat') e', loopbody' ) | otherwise -> typeError (srclocOf e) mempty $ "Iteratee of a for-in loop must be an array, but expression has type" <+> pretty t While cond -> bindingPat [] mergepat merge_t $ \mergepat' -> incLevel $ do cond' <- checkExp cond >>= unifies "being the condition of a 'while' loop" (Scalar $ Prim Bool) loopbody' <- checkExp loopbody (sparams, mergepat'') <- 
checkLoopReturnSize mergepat' loopbody' pure ( sparams, mergepat'', While cond', loopbody' ) -- dim handling (4) wellTypedLoopArg Initial sparams mergepat' loopinit' (loopt, retext) <- freshDimsInType (mkUsage loc "inference of loop result type") (Rigid RigidLoop) "loop" sparams (patternType mergepat') pure ( (sparams, mergepat', LoopInitExplicit loopinit', form', loopbody'), AppRes (toStruct loopt) retext ) futhark-0.25.27/src/Language/Futhark/TypeChecker/Terms/Monad.hs000066400000000000000000000503211475065116200241530ustar00rootroot00000000000000{-# LANGUAGE Strict #-} -- | Facilities for type-checking terms. Factored out of -- "Language.Futhark.TypeChecker.Terms" to prevent the module from -- being gigantic. -- -- Incidentally also a nice place to put Haddock comments to make the -- internal API of the type checker easier to browse. module Language.Futhark.TypeChecker.Terms.Monad ( TermTypeM, runTermTypeM, ValBinding (..), SizeSource (SourceSlice), Inferred (..), Checking (..), withEnv, localScope, TermEnv (..), TermScope (..), TermTypeState (..), onFailure, extSize, expType, expTypeFully, constrain, newArrayType, allDimsFreshInType, updateTypes, Names, -- * Primitive checking unifies, require, checkTypeExpNonrigid, lookupVar, lookupMod, -- * Sizes isInt64, -- * Control flow incLevel, -- * Errors unusedSize, ) where import Control.Monad import Control.Monad.Except import Control.Monad.Reader import Control.Monad.State.Strict import Data.Bitraversable import Data.Char (isAscii) import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S import Data.Text qualified as T import Futhark.FreshNames hiding (newName) import Futhark.FreshNames qualified import Futhark.Util.Pretty hiding (space) import Language.Futhark import Language.Futhark.Traversals import Language.Futhark.TypeChecker.Monad hiding (BoundV, lookupMod, stateNameSource) import Language.Futhark.TypeChecker.Monad qualified as TypeM import Language.Futhark.TypeChecker.Types 
import Language.Futhark.TypeChecker.Unify import Prelude hiding (mod) type Names = S.Set VName data ValBinding = BoundV [TypeParam] StructType | OverloadedF [PrimType] [Maybe PrimType] (Maybe PrimType) | EqualityF deriving (Show) unusedSize :: (MonadTypeChecker m) => SizeBinder VName -> m a unusedSize p = typeError p mempty . withIndexLink "unused-size" $ "Size" <+> pretty p <+> "unused in pattern." data Inferred t = NoneInferred | Ascribed t instance Functor Inferred where fmap _ NoneInferred = NoneInferred fmap f (Ascribed t) = Ascribed (f t) data Checking = CheckingApply (Maybe (QualName VName)) Exp StructType StructType | CheckingReturn ResType StructType | CheckingAscription StructType StructType | CheckingLetGeneralise Name | CheckingParams (Maybe Name) | CheckingPat (PatBase NoInfo VName StructType) (Inferred StructType) | CheckingLoopBody StructType StructType | CheckingLoopInitial StructType StructType | CheckingRecordUpdate [Name] StructType StructType | CheckingRequired [StructType] StructType | CheckingBranches StructType StructType instance Pretty Checking where pretty (CheckingApply f e expected actual) = header "Expected:" <+> align (pretty expected) "Actual: " <+> align (pretty actual) where header = case f of Nothing -> "Cannot apply function to" <+> dquotes (shorten $ group $ pretty e) <> " (invalid type)." Just fname -> "Cannot apply" <+> dquotes (pretty fname) <+> "to" <+> dquotes (align $ shorten $ group $ pretty e) <> " (invalid type)." pretty (CheckingReturn expected actual) = "Function body does not have expected type." "Expected:" <+> align (pretty expected) "Actual: " <+> align (pretty actual) pretty (CheckingAscription expected actual) = "Expression does not have expected type from explicit ascription." "Expected:" <+> align (pretty expected) "Actual: " <+> align (pretty actual) pretty (CheckingLetGeneralise fname) = "Cannot generalise type of" <+> dquotes (pretty fname) <> "." 
pretty (CheckingParams fname) = "Invalid use of parameters in" <+> dquotes fname' <> "." where fname' = maybe "anonymous function" pretty fname pretty (CheckingPat pat NoneInferred) = "Invalid pattern" <+> dquotes (pretty pat) <> "." pretty (CheckingPat pat (Ascribed t)) = "Pattern" indent 2 (pretty pat) "cannot match value of type" indent 2 (pretty t) pretty (CheckingLoopBody expected actual) = "Loop body does not have expected type." "Expected:" <+> align (pretty expected) "Actual: " <+> align (pretty actual) pretty (CheckingLoopInitial expected actual) = "Initial loop values do not have expected type." "Expected:" <+> align (pretty expected) "Actual: " <+> align (pretty actual) pretty (CheckingRecordUpdate fs expected actual) = "Type mismatch when updating record field" <+> dquotes fs' <> "." "Existing:" <+> align (pretty expected) "New: " <+> align (pretty actual) where fs' = mconcat $ punctuate "." $ map pretty fs pretty (CheckingRequired [expected] actual) = "Expression must have type" <+> pretty expected <> "." "Actual type:" <+> align (pretty actual) pretty (CheckingRequired expected actual) = "Type of expression must be one of " <+> expected' <> "." "Actual type:" <+> align (pretty actual) where expected' = commasep (map pretty expected) pretty (CheckingBranches t1 t2) = "Branches differ in type." "Former:" <+> pretty t1 "Latter:" <+> pretty t2 -- | Type checking happens with access to this environment. The -- 'TermScope' will be extended during type-checking as bindings come into -- scope. 
data TermEnv = TermEnv { termScope :: TermScope, termChecking :: Maybe Checking, termLevel :: Level, termChecker :: ExpBase NoInfo VName -> TermTypeM Exp, termOuterEnv :: Env, termImportName :: ImportName } data TermScope = TermScope { scopeVtable :: M.Map VName ValBinding, scopeTypeTable :: M.Map VName TypeBinding, scopeModTable :: M.Map VName Mod } deriving (Show) instance Semigroup TermScope where TermScope vt1 tt1 mt1 <> TermScope vt2 tt2 mt2 = TermScope (vt2 `M.union` vt1) (tt2 `M.union` tt1) (mt1 `M.union` mt2) envToTermScope :: Env -> TermScope envToTermScope env = TermScope { scopeVtable = vtable, scopeTypeTable = envTypeTable env, scopeModTable = envModTable env } where vtable = M.map valBinding $ envVtable env valBinding (TypeM.BoundV tps v) = BoundV tps v withEnv :: TermEnv -> Env -> TermEnv withEnv tenv env = tenv {termScope = termScope tenv <> envToTermScope env} -- | Wrap a function name to give it a vacuous Eq instance for SizeSource. newtype FName = FName (Maybe (QualName VName)) deriving (Show) instance Eq FName where _ == _ = True instance Ord FName where compare _ _ = EQ -- | What was the source of some existential size? This is used for -- using the same existential variable if the same source is -- encountered in multiple locations. data SizeSource = SourceArg FName (ExpBase NoInfo VName) | SourceSlice (Maybe Size) (Maybe (ExpBase NoInfo VName)) (Maybe (ExpBase NoInfo VName)) (Maybe (ExpBase NoInfo VName)) deriving (Eq, Ord, Show) -- | The state is a set of constraints and a counter for generating -- type names. This is distinct from the usual counter we use for -- generating unique names, as these will be user-visible. 
data TermTypeState = TermTypeState { stateConstraints :: Constraints, stateCounter :: !Int, stateWarnings :: Warnings, stateNameSource :: VNameSource } newtype TermTypeM a = TermTypeM ( ReaderT TermEnv (StateT TermTypeState (Except (Warnings, TypeError))) a ) deriving ( Monad, Functor, Applicative, MonadReader TermEnv, MonadState TermTypeState ) instance MonadError TypeError TermTypeM where throwError e = TermTypeM $ do ws <- gets stateWarnings throwError (ws, e) catchError (TermTypeM m) f = TermTypeM $ m `catchError` f' where f' (_, e) = let TermTypeM m' = f e in m' incCounter :: TermTypeM Int incCounter = do s <- get put s {stateCounter = stateCounter s + 1} pure $ stateCounter s constrain :: VName -> Constraint -> TermTypeM () constrain v c = do lvl <- curLevel modifyConstraints $ M.insert v (lvl, c) instance MonadUnify TermTypeM where getConstraints = gets stateConstraints putConstraints x = modify $ \s -> s {stateConstraints = x} newTypeVar loc desc = do i <- incCounter v <- newID $ mkTypeVarName desc i constrain v $ NoConstraint Lifted $ mkUsage' loc pure $ Scalar $ TypeVar mempty (qualName v) [] curLevel = asks termLevel newDimVar usage rigidity name = do dim <- newTypeName name case rigidity of Rigid rsrc -> constrain dim $ UnknownSize (locOf usage) rsrc Nonrigid -> constrain dim $ Size Nothing usage pure dim unifyError loc notes bcs doc = do checking <- asks termChecking case checking of Just checking' -> throwError $ TypeError (locOf loc) notes $ pretty checking' <> line doc <> pretty bcs Nothing -> throwError $ TypeError (locOf loc) notes $ doc <> pretty bcs matchError loc notes bcs t1 t2 = do checking <- asks termChecking case checking of Just checking' | hasNoBreadCrumbs bcs -> throwError $ TypeError (locOf loc) notes $ pretty checking' | otherwise -> throwError $ TypeError (locOf loc) notes $ pretty checking' <> line doc <> pretty bcs Nothing -> throwError $ TypeError (locOf loc) notes $ doc <> pretty bcs where doc = "Types" indent 2 (pretty t1) "and" 
indent 2 (pretty t2) "do not match." -- | Instantiate a type scheme with fresh type variables for its type -- parameters. Returns the names of the fresh type variables, the -- instance list, and the instantiated type. instantiateTypeScheme :: QualName VName -> SrcLoc -> [TypeParam] -> StructType -> TermTypeM ([VName], StructType) instantiateTypeScheme qn loc tparams t = do let tnames = map typeParamName tparams (tparam_names, tparam_substs) <- mapAndUnzipM (instantiateTypeParam qn loc) tparams let substs = M.fromList $ zip tnames tparam_substs t' = applySubst (`M.lookup` substs) t pure (tparam_names, t') -- | Create a new type name and insert it (unconstrained) in the -- substitution map. instantiateTypeParam :: (Monoid as) => QualName VName -> SrcLoc -> TypeParam -> TermTypeM (VName, Subst (RetTypeBase dim as)) instantiateTypeParam qn loc tparam = do i <- incCounter let name = nameFromString (takeWhile isAscii (baseString (typeParamName tparam))) v <- newID $ mkTypeVarName name i case tparam of TypeParamType x _ _ -> do constrain v . NoConstraint x . mkUsage loc . docText $ "instantiated type parameter of " <> dquotes (pretty qn) pure (v, Subst [] $ RetType [] $ Scalar $ TypeVar mempty (qualName v) []) TypeParamDim {} -> do constrain v . Size Nothing . mkUsage loc . docText $ "instantiated size parameter of " <> dquotes (pretty qn) pure (v, ExpSubst $ sizeFromName (qualName v) loc) lookupQualNameEnv :: QualName VName -> TermTypeM TermScope lookupQualNameEnv (QualName [q] _) | baseTag q <= maxIntrinsicTag = asks termScope -- Magical intrinsic module. 
lookupQualNameEnv qn@(QualName quals _) = do scope <- asks termScope descend scope quals where descend scope [] = pure scope descend scope (q : qs) | Just (ModEnv q_scope) <- M.lookup q $ scopeModTable scope = descend (envToTermScope q_scope) qs | otherwise = error $ "lookupQualNameEnv " <> show qn lookupMod :: QualName VName -> TermTypeM Mod lookupMod qn@(QualName _ name) = do scope <- lookupQualNameEnv qn case M.lookup name $ scopeModTable scope of Nothing -> error $ "lookupMod: " <> show qn Just m -> pure m localScope :: (TermScope -> TermScope) -> TermTypeM a -> TermTypeM a localScope f = local $ \tenv -> tenv {termScope = f $ termScope tenv} instance MonadTypeChecker TermTypeM where warnings ws = modify $ \s -> s {stateWarnings = stateWarnings s <> ws} warn loc problem = warnings $ singleWarning (locOf loc) problem newName v = do s <- get let (v', src') = Futhark.FreshNames.newName (stateNameSource s) v put $ s {stateNameSource = src'} pure v' newTypeName name = do i <- incCounter newID $ mkTypeVarName name i bindVal v (TypeM.BoundV tps t) = localScope $ \scope -> scope {scopeVtable = M.insert v (BoundV tps t) $ scopeVtable scope} lookupType qn = do outer_env <- asks termOuterEnv scope <- lookupQualNameEnv qn case M.lookup (qualLeaf qn) $ scopeTypeTable scope of Nothing -> error $ "lookupType: " <> show qn Just (TypeAbbr l ps (RetType dims def)) -> pure ( ps, RetType dims $ qualifyTypeVars outer_env (map typeParamName ps) (qualQuals qn) def, l ) typeError loc notes s = do checking <- asks termChecking case checking of Just checking' -> throwError $ TypeError (locOf loc) notes (pretty checking' <> line s) Nothing -> throwError $ TypeError (locOf loc) notes s lookupVar :: SrcLoc -> QualName VName -> TermTypeM StructType lookupVar loc qn@(QualName qs name) = do scope <- lookupQualNameEnv qn let usage = mkUsage loc $ docText $ "use of " <> dquotes (pretty qn) case M.lookup name $ scopeVtable scope of Nothing -> error $ "lookupVar: " <> show qn Just (BoundV tparams 
t) -> do if null tparams && null qs then pure t else do (tnames, t') <- instantiateTypeScheme qn loc tparams t outer_env <- asks termOuterEnv pure $ qualifyTypeVars outer_env tnames qs t' Just EqualityF -> do argtype <- newTypeVar loc "t" equalityType usage argtype pure $ Scalar . Arrow mempty Unnamed Observe argtype . RetType [] $ Scalar $ Arrow mempty Unnamed Observe argtype $ RetType [] $ Scalar $ Prim Bool Just (OverloadedF ts pts rt) -> do argtype <- newTypeVar loc "t" mustBeOneOf ts usage argtype let (pts', rt') = instOverloaded argtype pts rt pure $ foldFunType (map (toParam Observe) pts') $ RetType [] $ toRes Nonunique rt' where instOverloaded argtype pts rt = ( map (maybe (toStruct argtype) (Scalar . Prim)) pts, maybe (toStruct argtype) (Scalar . Prim) rt ) onFailure :: Checking -> TermTypeM a -> TermTypeM a onFailure c = local $ \env -> env {termChecking = Just c} extSize :: SrcLoc -> SizeSource -> TermTypeM (Size, Maybe VName) extSize loc e = do let rsrc = case e of SourceArg (FName fname) e' -> RigidArg fname $ prettyTextOneLine e' SourceSlice d i j s -> RigidSlice d $ prettyTextOneLine $ DimSlice i j s d <- newRigidDim loc rsrc "n" pure ( sizeFromName (qualName d) loc, Just d ) incLevel :: TermTypeM a -> TermTypeM a incLevel = local $ \env -> env {termLevel = termLevel env + 1} -- | Get the type of an expression, with top level type variables -- substituted. Never call 'typeOf' directly (except in a few -- carefully inspected locations)! expType :: Exp -> TermTypeM StructType expType = normType . typeOf -- | Get the type of an expression, with all type variables -- substituted. Slower than 'expType', but sometimes necessary. -- Never call 'typeOf' directly (except in a few carefully inspected -- locations)! expTypeFully :: Exp -> TermTypeM StructType expTypeFully = normTypeFully . 
typeOf newArrayType :: Usage -> Name -> Int -> TermTypeM (StructType, StructType) newArrayType usage desc r = do v <- newTypeName desc constrain v $ NoConstraint Unlifted usage dims <- replicateM r $ newDimVar usage Nonrigid "dim" let rowt = TypeVar mempty (qualName v) [] mkSize = flip sizeFromName (srclocOf usage) . qualName pure ( Array mempty (Shape $ map mkSize dims) rowt, Scalar rowt ) -- | Replace *all* dimensions with distinct fresh size variables. allDimsFreshInType :: Usage -> Rigidity -> Name -> TypeBase Size als -> TermTypeM (TypeBase Size als, M.Map VName Size) allDimsFreshInType usage r desc t = runStateT (bitraverse onDim pure t) mempty where onDim d = do v <- lift $ newDimVar usage r desc modify $ M.insert v d pure $ sizeFromName (qualName v) $ srclocOf usage -- | Replace all type variables with their concrete types. updateTypes :: (ASTMappable e) => e -> TermTypeM e updateTypes = astMap tv where tv = ASTMapper { mapOnExp = astMap tv, mapOnName = pure, mapOnStructType = normTypeFully, mapOnParamType = normTypeFully, mapOnResRetType = normTypeFully } --- Basic checking unifies :: T.Text -> StructType -> Exp -> TermTypeM Exp unifies why t e = do unify (mkUsage (srclocOf e) why) t . toStruct =<< expType e pure e -- | @require ts e@ causes a 'TypeError' if @expType e@ is not one of -- the types in @ts@. Otherwise, simply returns @e@. require :: T.Text -> [PrimType] -> Exp -> TermTypeM Exp require why ts e = do mustBeOneOf ts (mkUsage (srclocOf e) why) . 
toStruct =<< expType e pure e checkExpForSize :: ExpBase NoInfo VName -> TermTypeM Exp checkExpForSize e = do checker <- asks termChecker e' <- checker e let t = toStruct $ typeOf e' unify (mkUsage (locOf e') "Size expression") t (Scalar (Prim (Signed Int64))) updateTypes e' checkTypeExpNonrigid :: TypeExp (ExpBase NoInfo VName) VName -> TermTypeM (TypeExp Exp VName, ResType, [VName]) checkTypeExpNonrigid te = do (te', svars, rettype, _l) <- checkTypeExp checkExpForSize te -- No guarantee that the locally bound sizes in rettype are globally -- unique, but we want to turn them into size variables, so let's -- give them some unique names. RetType dims st <- renameRetType rettype forM_ (svars ++ dims) $ \v -> constrain v $ Size Nothing $ mkUsage (srclocOf te) "anonymous size in type expression" pure (te', st, svars ++ dims) --- Sizes isInt64 :: Exp -> Maybe Int64 isInt64 (Literal (SignedValue (Int64Value k')) _) = Just $ fromIntegral k' isInt64 (IntLit k' _ _) = Just $ fromInteger k' isInt64 (Negate x _) = negate <$> isInt64 x isInt64 (Parens x _) = isInt64 x isInt64 _ = Nothing -- Running initialTermScope :: TermScope initialTermScope = TermScope { scopeVtable = initialVtable, scopeTypeTable = mempty, scopeModTable = mempty } where initialVtable = M.fromList $ mapMaybe addIntrinsicF $ M.toList intrinsics prim = Scalar . 
Prim arrow x y = Scalar $ Arrow mempty Unnamed Observe x y addIntrinsicF (name, IntrinsicMonoFun pts t) = Just (name, BoundV [] $ arrow pts' $ RetType [] $ prim t) where pts' = case pts of [pt] -> prim pt _ -> Scalar $ tupleRecord $ map prim pts addIntrinsicF (name, IntrinsicOverloadedFun ts pts rts) = Just (name, OverloadedF ts pts rts) addIntrinsicF (name, IntrinsicPolyFun tvs pts rt) = Just ( name, BoundV tvs $ foldFunType pts rt ) addIntrinsicF (name, IntrinsicEquality) = Just (name, EqualityF) addIntrinsicF _ = Nothing runTermTypeM :: (ExpBase NoInfo VName -> TermTypeM Exp) -> TermTypeM a -> TypeM a runTermTypeM checker (TermTypeM m) = do initial_scope <- (initialTermScope <>) . envToTermScope <$> askEnv name <- askImportName outer_env <- askEnv src <- gets TypeM.stateNameSource let initial_tenv = TermEnv { termScope = initial_scope, termChecking = Nothing, termLevel = 0, termChecker = checker, termImportName = name, termOuterEnv = outer_env } initial_state = TermTypeState { stateConstraints = mempty, stateCounter = 0, stateWarnings = mempty, stateNameSource = src } case runExcept (runStateT (runReaderT m initial_tenv) initial_state) of Left (ws, e) -> do warnings ws throwError e Right (a, TermTypeState {stateNameSource, stateWarnings}) -> do warnings stateWarnings modify $ \s -> s {TypeM.stateNameSource = stateNameSource} pure a futhark-0.25.27/src/Language/Futhark/TypeChecker/Terms/Pat.hs000066400000000000000000000243751475065116200236530ustar00rootroot00000000000000-- | Type checking of patterns. 
module Language.Futhark.TypeChecker.Terms.Pat
  ( binding,
    bindingParams,
    bindingPat,
    bindingIdent,
    bindingSizes,
  )
where

import Control.Monad
import Data.Bifunctor
import Data.Either
import Data.List (find, isPrefixOf, sort, sortBy)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Ord (comparing)
import Data.Set qualified as S
import Futhark.Util.Pretty hiding (group, space)
import Language.Futhark
import Language.Futhark.TypeChecker.Monad hiding (BoundV)
import Language.Futhark.TypeChecker.Terms.Monad
import Language.Futhark.TypeChecker.Types
import Language.Futhark.TypeChecker.Unify hiding (Usage)
import Prelude hiding (mod)

-- | In the dimensions of a type, rename every size variable that names
-- one of the given size binders to the name paired with that binder.
-- Dimensions that mention no such binder are left untouched.
nonrigidFor :: [(SizeBinder VName, QualName VName)] -> StructType -> StructType
nonrigidFor sizes
  | null sizes = id -- Nothing to rename, so skip the traversal.
  | otherwise = first renameDim
  where
    renameDim dim@(Var (QualName _ name) info loc) =
      case find (\(binder, _) -> sizeName binder == name) sizes of
        Just (_, replacement) -> Var replacement info loc
        Nothing -> dim
    renameDim dim = dim

-- | Bind these identifiers locally while running the provided action.
binding :: [Ident StructType] -> TermTypeM a -> TermTypeM a
binding idents m =
  localScope extendScope $ do
    -- Any of these identifiers might end up being used as a size, so
    -- each one is recorded with a 'ParamSize' constraint; this is what
    -- lets unification detect scope violations.  Identifiers that are
    -- not (yet) integers are included too, because inference may make
    -- them integers later.
    mapM_ constrainAsSize idents
    m
  where
    extendScope scope = foldl addIdent scope idents
    addIdent scope (Ident name (Info tp) _) =
      scope {scopeVtable = M.insert name (BoundV [] tp) (scopeVtable scope)}
    constrainAsSize ident =
      constrain (identName ident) $ ParamSize $ locOf ident

-- | Bring type bindings into scope and register constraints at the
-- current level while running the action.
bindingTypes ::
  [Either (VName, TypeBinding) (VName, Constraint)] ->
  TermTypeM a ->
  TermTypeM a
bindingTypes types m = do
  lvl <- curLevel
  let (tbinds, constraints) = partitionEithers types
      atLevel = M.fromList [(v, (lvl, c)) | (v, c) <- constraints]
      addTypes scope =
        scope {scopeTypeTable = M.fromList tbinds <> scopeTypeTable scope}
  -- Map union is left-biased: constraints that already exist win.
  modifyConstraints (<> atLevel)
  localScope addTypes m

-- | Bind type parameters: size parameters become @i64@ identifiers, and
-- every parameter gets a matching constraint (type parameters also get
-- a type abbreviation referring to themselves).
bindingTypeParams :: [TypeParam] -> TermTypeM a -> TermTypeM a
bindingTypeParams tparams m =
  binding (mapMaybe typeParamIdent tparams) $
    bindingTypes (concatMap typeParamType tparams) m
  where
    typeParamType (TypeParamDim v loc) =
      [Right (v, ParamSize $ locOf loc)]
    typeParamType (TypeParamType l v loc) =
      [ Left (v, TypeAbbr l [] $ RetType [] $ Scalar $ TypeVar mempty (qualName v) []),
        Right (v, ParamType l $ locOf loc)
      ]

-- | The @i64@ identifier introduced by a size parameter, if any.
typeParamIdent :: TypeParam -> Maybe (Ident StructType)
typeParamIdent tparam = case tparam of
  TypeParamDim v loc -> Just $ Ident v (Info $ Scalar $ Prim $ Signed Int64) loc
  _ -> Nothing

-- | Bind @let@-bound sizes as @i64@ identifiers.  This is usually
-- followed immediately by 'bindingPat'.
bindingSizes :: [SizeBinder VName] -> TermTypeM a -> TermTypeM a
bindingSizes sizes m
  | null sizes = m -- Nothing to bind.
  | otherwise = binding [Ident (sizeName s) (Info i64) (srclocOf s) | s <- sizes] m
  where
    i64 = Scalar (Prim (Signed Int64))

-- | Bind a single term-level identifier at the given type and pass the
-- typed identifier to the continuation.
bindingIdent ::
  IdentBase NoInfo VName StructType ->
  StructType ->
  (Ident StructType -> TermTypeM a) ->
  TermTypeM a
bindingIdent (Ident v NoInfo vloc) t k =
  binding [typed] (k typed)
  where
    typed = Ident v (Info t) vloc

-- All this complexity is just so we can handle un-suffixed numeric
-- literals in patterns.
-- | Create the type of a literal occurring in a pattern.  Un-suffixed
-- integer and float literals get a fresh type variable restricted to the
-- numeric (respectively floating-point) types; primitive literals get
-- their exact type.
patLitMkType :: PatLit -> SrcLoc -> TermTypeM ParamType
patLitMkType (PatLitInt _) loc = do
  t <- newTypeVar loc "t"
  mustBeOneOf anyNumberType (mkUsage loc "integer literal") (toStruct t)
  pure t
patLitMkType (PatLitFloat _) loc = do
  t <- newTypeVar loc "t"
  mustBeOneOf anyFloatType (mkUsage loc "float literal") (toStruct t)
  pure t
patLitMkType (PatLitPrim v) _ =
  pure $ Scalar $ Prim $ primValueType v

-- | Fail with a type error if a record pattern (printed as @p@) names the
-- same field more than once.
checkNoDuplicateFields :: (Pretty p) => p -> [(L Name, b)] -> SrcLoc -> TermTypeM ()
checkNoDuplicateFields p p_fs loc =
  -- The distinct names in ascending order equal the sorted list of all
  -- names exactly when there are no duplicates.
  when (S.toAscList (S.fromList names) /= sort names) $
    typeError loc mempty $
      "Duplicate fields in record pattern" <+> pretty p <> "."
  where
    names = map (unLoc . fst) p_fs

-- | Check a pattern against a type (if one is known), producing a
-- type-annotated pattern.  The @sizes@ list pairs @let@-bound size
-- binders with replacement names; 'nonrigidFor' uses them before
-- unifying with an explicit type ascription.
checkPat' ::
  [(SizeBinder VName, QualName VName)] ->
  PatBase NoInfo VName ParamType ->
  Inferred ParamType ->
  TermTypeM (Pat ParamType)
checkPat' sizes (PatParens p loc) t =
  PatParens <$> checkPat' sizes p t <*> pure loc
checkPat' sizes (PatAttr attr p loc) t =
  PatAttr <$> checkAttr attr <*> checkPat' sizes p t <*> pure loc
checkPat' _ (Id name NoInfo loc) (Ascribed t) =
  pure $ Id name (Info t) loc
checkPat' _ (Id name NoInfo loc) NoneInferred = do
  t <- newTypeVar loc "t"
  pure $ Id name (Info t) loc
checkPat' _ (Wildcard _ loc) (Ascribed t) =
  pure $ Wildcard (Info t) loc
checkPat' _ (Wildcard NoInfo loc) NoneInferred = do
  t <- newTypeVar loc "t"
  pure $ Wildcard (Info t) loc
checkPat' sizes p@(TuplePat ps loc) (Ascribed t)
  | Just ts <- isTupleRecord t,
    length ts == length ps =
      TuplePat
        <$> zipWithM (checkPat' sizes) ps (map Ascribed ts)
        <*> pure loc
  | otherwise = do
      -- The type is not (yet) known to be a tuple of the right arity:
      -- unify it with a tuple of fresh type variables and try again.
      ps_t <- replicateM (length ps) (newTypeVar loc "t")
      unify (mkUsage loc "matching a tuple pattern") (Scalar (tupleRecord ps_t)) (toStruct t)
      checkPat' sizes p $ Ascribed $ toParam Observe $ Scalar $ tupleRecord ps_t
checkPat' sizes (TuplePat ps loc) NoneInferred =
  TuplePat <$> mapM (\p -> checkPat' sizes p NoneInferred) ps <*> pure loc
checkPat' _ (RecordPat p_fs _) _
  | Just (L loc f, _) <- find (("_" `isPrefixOf`) . nameToString . unLoc . fst) p_fs =
      typeError loc mempty $
        "Underscore-prefixed fields are not allowed."
          </> "Did you mean"
          <+> dquotes (pretty (drop 1 (nameToString f)) <> "=_")
          <> "?"
checkPat' sizes p@(RecordPat p_fs loc) (Ascribed t)
  | Scalar (Record t_fs) <- t,
    p_fs' <- sortBy (comparing fst) p_fs,
    t_fs' <- sortBy (comparing fst) (M.toList t_fs),
    map fst t_fs' == map (unLoc . fst) p_fs' =
      RecordPat <$> zipWithM check p_fs' t_fs' <*> pure loc
  | otherwise = do
      checkNoDuplicateFields p p_fs loc
      -- Unify with a record of fresh type variables and try again.
      p_fs' <- traverse (const $ newTypeVar loc "t") $ M.fromList $ map (first unLoc) p_fs
      unify (mkUsage loc "matching a record pattern") (Scalar (Record p_fs')) (toStruct t)
      checkPat' sizes p $ Ascribed $ toParam Observe $ Scalar (Record p_fs')
  where
    check (L f_loc f, p_f) (_, t_f) =
      (L f_loc f,) <$> checkPat' sizes p_f (Ascribed t_f)
checkPat' sizes p@(RecordPat fs loc) NoneInferred = do
  -- Without this check, 'M.fromList' below would silently drop all but
  -- one of any duplicated fields.
  checkNoDuplicateFields p fs loc
  RecordPat . M.toList
    <$> traverse (\p' -> checkPat' sizes p' NoneInferred) (M.fromList fs)
    <*> pure loc
checkPat' sizes (PatAscription p t loc) maybe_outer_t = do
  (t', st, _) <- checkTypeExpNonrigid t
  -- An outer type, if known, must agree with the ascription.
  case maybe_outer_t of
    Ascribed outer_t -> do
      let st_forunify = nonrigidFor sizes $ toStruct st
      unify (mkUsage loc "explicit type ascription") st_forunify (toStruct outer_t)
    NoneInferred -> pure ()
  PatAscription
    <$> checkPat' sizes p (Ascribed (resToParam st))
    <*> pure t'
    <*> pure loc
checkPat' _ (PatLit l NoInfo loc) (Ascribed t) = do
  t' <- patLitMkType l loc
  unify (mkUsage loc "matching against literal") (toStruct t') (toStruct t)
  pure $ PatLit l (Info t') loc
checkPat' _ (PatLit l NoInfo loc) NoneInferred = do
  t' <- patLitMkType l loc
  pure $ PatLit l (Info t') loc
checkPat' sizes (PatConstr n NoInfo ps loc) (Ascribed (Scalar (Sum cs)))
  | Just ts <- M.lookup n cs = do
      when (length ps /= length ts) $
        typeError loc mempty $
          "Pattern #"
            <> pretty n
            <> " expects"
            <+> pretty (length ps)
            <+> "constructor arguments, but type provides"
            <+> pretty (length ts)
            <+> "arguments."
      ps' <- zipWithM (checkPat' sizes) ps $ map Ascribed ts
      pure $ PatConstr n (Info (Scalar (Sum cs))) ps' loc
checkPat' sizes (PatConstr n NoInfo ps loc) (Ascribed t) = do
  t' <- newTypeVar loc "t"
  ps' <- forM ps $ \p -> do
    p_t <- newTypeVar (srclocOf p) "t"
    checkPat' sizes p $ Ascribed p_t
  mustHaveConstr usage n (toStruct t') (patternStructType <$> ps')
  unify usage t' (toStruct t)
  pure $ PatConstr n (Info t) ps' loc
  where
    usage = mkUsage loc "matching against constructor"
checkPat' sizes (PatConstr n NoInfo ps loc) NoneInferred = do
  ps' <- mapM (\p -> checkPat' sizes p NoneInferred) ps
  t <- newTypeVar loc "t"
  mustHaveConstr usage n (toStruct t) (patternStructType <$> ps')
  pure $ PatConstr n (Info t) ps' loc
  where
    usage = mkUsage loc "matching against constructor"

-- | Check a pattern and pass the result to the continuation.  Fails if
-- any of the given size binders is among the sizes that
-- 'mustBeExplicitInType' reports for the pattern's type.
checkPat ::
  [(SizeBinder VName, QualName VName)] ->
  PatBase NoInfo VName (TypeBase Size u) ->
  Inferred StructType ->
  (Pat ParamType -> TermTypeM a) ->
  TermTypeM a
checkPat sizes p t m = do
  p' <-
    onFailure (CheckingPat (fmap toStruct p) t) $
      checkPat' sizes (fmap (toParam Observe) p) (fmap (toParam Observe) t)

  let explicit = mustBeExplicitInType $ patternStructType p'

  case filter ((`S.member` explicit) . sizeName . fst) sizes of
    (size, _) : _ ->
      typeError size mempty $
        "Cannot bind"
          <+> pretty size
          <+> "as it is never used as the size of a concrete (non-function) value."
    [] ->
      m p'

-- | Check and bind a @let@-pattern.
bindingPat ::
  [SizeBinder VName] ->
  PatBase NoInfo VName (TypeBase Size u) ->
  StructType ->
  (Pat ParamType -> TermTypeM a) ->
  TermTypeM a
bindingPat sizes p t m = do
  substs <- mapM mkSizeSubst sizes
  checkPat substs p (Ascribed t) $ \p' ->
    binding (patIdents (fmap toStruct p')) $
      -- Every size binder must occur free in the checked pattern.
      case filter ((`S.notMember` fvVars (freeInPat p')) . sizeName) sizes of
        [] -> m p'
        size : _ -> unusedSize size
  where
    -- Pair each size binder with a fresh name constrained as a size whose
    -- value is not yet known.
    mkSizeSubst v = do
      v' <- newID $ baseName $ sizeName v
      constrain v' . Size Nothing $
        mkUsage v "ambiguous size of bound expression"
      pure (v, qualName v')

-- | Check and bind type and value parameters.
bindingParams ::
  [TypeParam] ->
  [PatBase NoInfo VName ParamType] ->
  ([Pat ParamType] -> TermTypeM a) ->
  TermTypeM a
bindingParams tps orig_ps m = bindingTypeParams tps $ do
  -- Value parameters are checked left to right.  Each one is bound
  -- before the next is checked, so later parameters can refer to earlier
  -- ones, and every step runs one level deeper ('incLevel').
  let descend ps' (p : ps) =
        checkPat [] p NoneInferred $ \p' ->
          binding (patIdents $ fmap toStruct p') $ incLevel $ descend (p' : ps') ps
      -- The checked parameters were accumulated in reverse order.
      descend ps' [] = m $ reverse ps'

  incLevel $ descend [] orig_ps
futhark-0.25.27/src/Language/Futhark/TypeChecker/Types.hs000066400000000000000000000470241475065116200231350ustar00rootroot00000000000000-- | Type checker building blocks that do not involve unification.
module Language.Futhark.TypeChecker.Types
  ( checkTypeExp,
    renameRetType,
    typeParamToArg,
    Subst (..),
    substFromAbbr,
    TypeSubs,
    Substitutable (..),
    substTypesAny,

    -- * Witnesses
    mustBeExplicitInType,
    mustBeExplicitInBinding,
    determineSizeWitnesses,
  )
where

import Control.Monad
import Control.Monad.Identity
import Control.Monad.State
import Data.Bifunctor
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Set qualified as S
import Futhark.Util (nubOrd)
import Futhark.Util.Pretty
import Language.Futhark
import Language.Futhark.Traversals
import Language.Futhark.TypeChecker.Monad

-- Map each size variable mentioned in the dimensions of the type to a
-- flag: 'False' when the variable occurs as a bare size variable that is
-- either in the set of names 'traverseDims' reports as bound, or in an
-- immediate ('PosImmediate') position; 'True' for every free variable of
-- any other size occurrence.  Entries are combined with '(&&)', so a
-- single 'False' occurrence makes the final flag 'False'.
mustBeExplicitAux :: StructType -> M.Map VName Bool
mustBeExplicitAux t =
  execState (traverseDims onDim t) mempty
  where
    onDim bound _ (Var d _ _)
      | qualLeaf d `S.member` bound =
          modify $ \s -> M.insertWith (&&) (qualLeaf d) False s
    onDim _ PosImmediate (Var d _ _) =
      modify $ \s -> M.insertWith (&&) (qualLeaf d) False s
    onDim _ _ e =
      modify $ flip (S.foldr (\v -> M.insertWith (&&) v True)) $ fvVars $ freeInExp e

-- | Determine which of the sizes in a type are used as sizes outside
-- of functions in the type, and which are not.  The former are said
-- to be "witnessed" by this type, while the latter are not.  In
-- practice, the latter means that the actual sizes must come from
-- somewhere else.
determineSizeWitnesses :: StructType -> (S.Set VName, S.Set VName)
determineSizeWitnesses t =
  -- 'M.partition not' puts the 'False' (witnessed) entries in the first
  -- component, so the result is (witnessed, not witnessed).
  bimap (S.fromList . M.keys) (S.fromList . M.keys) $
    M.partition not $
      mustBeExplicitAux t

-- | Figure out which of the sizes in a binding type must be passed
-- explicitly, because their first use is as something else than just
-- an array dimension.
mustBeExplicitInBinding :: StructType -> S.Set VName
mustBeExplicitInBinding bind_t =
  let (ts, ret) = unfoldFunType bind_t
      -- Sizes free in the return type are marked 'True' unless a
      -- parameter already classified them: 'M.unionWith (&&)' with
      -- 'True' keeps the parameters' existing flag.
      alsoRet =
        M.unionWith (&&) $
          M.fromList $
            map (,True) (S.toList (fvVars (freeInType ret)))
   in S.fromList $ M.keys $ M.filter id $ alsoRet $ L.foldl' onType mempty $ map toStruct ts
  where
    -- Left-biased union: the first parameter that mentions a size
    -- decides its classification.
    onType uses t = uses <> mustBeExplicitAux t -- Left-biased union.

-- | Figure out which of the sizes in a parameter type must be passed
-- explicitly, because their first use is as something else than just
-- an array dimension.
mustBeExplicitInType :: StructType -> S.Set VName
mustBeExplicitInType = snd . determineSizeWitnesses

-- | Ensure that the dimensions of the RetType are unique by
-- generating new names for them.  This is to avoid name capture.
renameRetType :: (MonadTypeChecker m) => ResRetType -> m ResRetType
renameRetType (RetType dims st)
  | dims /= mempty = do
      dims' <- mapM newName dims
      -- Replace every use of an old existential size with a reference to
      -- its fresh name.
      let mkSubst = ExpSubst . flip sizeFromName mempty . qualName
          m = M.fromList . zip dims $ map mkSubst dims'
          st' = applySubst (`M.lookup` m) st
      pure $ RetType dims' st'
  | otherwise =
      pure $ RetType dims st

-- Worker for 'checkTypeExp'; see there for the meaning of the result.
-- The function @df@ is used to check every explicit size expression.
evalTypeExp ::
  (MonadTypeChecker m, Pretty df) =>
  (df -> m Exp) ->
  TypeExp df VName ->
  m (TypeExp Exp VName, [VName], ResRetType, Liftedness)
evalTypeExp _ (TEVar name loc) = do
  (ps, t, l) <- lookupType name
  t' <- renameRetType $ toResRet Nonunique t
  case ps of
    [] -> pure (TEVar name loc, [], t', l)
    -- A parameterised type may only appear applied (the 'TEApply' case).
    _ ->
      typeError loc mempty $
        "Type constructor"
          <+> dquotes (hsep (pretty name : map pretty ps))
          <+> "used without any arguments."
--
evalTypeExp df (TEParens te loc) = do
  (te', svars, ts, ls) <- evalTypeExp df te
  pure (TEParens te' loc, svars, ts, ls)
--
evalTypeExp df (TETuple ts loc) = do
  (ts', svars, ts_s, ls) <- L.unzip4 <$> mapM (evalTypeExp df) ts
  pure
    ( TETuple ts' loc,
      mconcat svars,
      RetType (foldMap retDims ts_s) $ Scalar $ tupleRecord $ map retType ts_s,
      -- The maximum liftedness (per its 'Ord' instance) of the components.
      L.foldl' max Unlifted ls
    )
--
evalTypeExp df t@(TERecord fs loc) = do
  -- Check for duplicate field names.
  let field_names = map fst fs
  unless (L.sort field_names == L.sort (nubOrd field_names)) $
    typeError loc mempty $
      "Duplicate record fields in" <+> pretty t <> "."

  -- Fields are checked via a map, so the checked fields come back in
  -- 'M.toList' (key) order rather than source order.
  checked <- traverse (evalTypeExp df) $ M.fromList fs
  let fs' = fmap (\(x, _, _, _) -> x) checked
      fs_svars = foldMap (\(_, y, _, _) -> y) checked
      ts_s = fmap (\(_, _, z, _) -> z) checked
      ls = fmap (\(_, _, _, v) -> v) checked
  pure
    ( TERecord (M.toList fs') loc,
      fs_svars,
      RetType (foldMap retDims ts_s) . Scalar . Record $
        M.mapKeys unLoc $
          M.map retType ts_s,
      L.foldl' max Unlifted ls
    )
--
evalTypeExp df (TEArray d t loc) = do
  (d_svars, d', d'') <- checkSizeExp d
  (t', svars, RetType dims st, l) <- evalTypeExp df t
  case (l, arrayOfWithAliases Nonunique (Shape [d'']) st) of
    (Unlifted, st') ->
      pure
        ( TEArray d' t' loc,
          svars,
          -- The name invented for an anonymous size becomes an
          -- existential dimension of the result.
          RetType (d_svars ++ dims) st',
          Unlifted
        )
    (SizeLifted, _) ->
      typeError loc mempty $
        "Cannot create array with elements of size-lifted type"
          <+> dquotes (pretty t)
          <+> "(might cause irregular array)."
    (Lifted, _) ->
      typeError loc mempty $
        "Cannot create array with elements of lifted type"
          <+> dquotes (pretty t)
          <+> "(might contain function)."
  where
    -- An anonymous size gets a fresh name; an explicit size is checked
    -- with @df@.
    checkSizeExp (SizeExpAny dloc) = do
      dv <- newTypeName "d"
      pure ([dv], SizeExpAny dloc, sizeFromName (qualName dv) dloc)
    checkSizeExp (SizeExp e dloc) = do
      e' <- df e
      pure ([], SizeExp e' dloc, e')
--
evalTypeExp df (TEUnique t loc) = do
  (t', svars, RetType dims st, l) <- evalTypeExp df t
  unless (mayContainArray st) $
    warn loc $
      "Declaring" <+> dquotes (pretty st) <+> "as unique has no effect."
  pure (TEUnique t' loc, svars, RetType dims $ st `setUniqueness` Unique, l)
  where
    -- Type variables are conservatively assumed to possibly contain
    -- arrays; function types are not.
    mayContainArray (Scalar Prim {}) = False
    mayContainArray Array {} = True
    mayContainArray (Scalar (Record fs)) = any mayContainArray fs
    mayContainArray (Scalar TypeVar {}) = True
    mayContainArray (Scalar Arrow {}) = False
    mayContainArray (Scalar (Sum cs)) = (any . any) mayContainArray cs
--
evalTypeExp df (TEArrow (Just v) t1 t2 loc) = do
  (t1', svars1, RetType dims1 st1, _) <- evalTypeExp df t1
  -- The named parameter is in scope while checking the result type.
  bindVal v (BoundV [] $ toStruct st1) $ do
    (t2', svars2, RetType dims2 st2, _) <- evalTypeExp df t2
    pure
      ( TEArrow (Just v) t1' t2' loc,
        -- Existential sizes of the parameter type are reported as size
        -- variables; those of the result stay in its 'RetType'.
        svars1 ++ dims1 ++ svars2,
        RetType [] $ Scalar $ Arrow Nonunique (Named v) (diet $ resToParam st1) (toStruct st1) (RetType dims2 st2),
        Lifted
      )
--
evalTypeExp df (TEArrow Nothing t1 t2 loc) = do
  (t1', svars1, RetType dims1 st1, _) <- evalTypeExp df t1
  (t2', svars2, RetType dims2 st2, _) <- evalTypeExp df t2
  pure
    ( TEArrow Nothing t1' t2' loc,
      svars1 ++ dims1 ++ svars2,
      RetType [] . Scalar $
        Arrow Nonunique Unnamed (diet $ resToParam st1) (toStruct st1) $
          RetType dims2 st2,
      Lifted
    )
--
evalTypeExp df (TEDim dims t loc) = do
  bindDims dims $ do
    (t', svars, RetType t_dims st, l) <- evalTypeExp df t
    -- Every explicitly quantified size must be witnessed by the type.
    let (witnessed, _) = determineSizeWitnesses $ toStruct st
    case L.find (`S.notMember` witnessed) dims of
      Just d ->
        typeError loc mempty . withIndexLink "unused-existential" $
          "Existential size "
            <> dquotes (prettyName d)
            <> " not used as array size."
      Nothing ->
        pure
          ( TEDim dims t' loc,
            svars,
            RetType (dims ++ t_dims) st,
            -- The result is at least size-lifted.
            max l SizeLifted
          )
  where
    -- Bring each quantified size into scope as an i64 value.
    bindDims [] m = m
    bindDims (d : ds) m =
      bindVal d (BoundV [] $ Scalar $ Prim $ Signed Int64) $ bindDims ds m
--
evalTypeExp df t@(TESum cs loc) = do
  let constructors = map fst cs
  unless (L.sort constructors == L.sort (nubOrd constructors)) $
    typeError loc mempty $
      "Duplicate constructors in" <+> pretty t

  unless (length constructors < 256) $
    typeError loc mempty "Sum types must have less than 256 constructors."

  checked <- (traverse . traverse) (evalTypeExp df) $ M.fromList cs
  let cs' = (fmap . fmap) (\(x, _, _, _) -> x) checked
      cs_svars = (foldMap . foldMap) (\(_, y, _, _) -> y) checked
      ts_s = (fmap . fmap) (\(_, _, z, _) -> z) checked
      ls = (concatMap . fmap) (\(_, _, _, v) -> v) checked
  pure
    ( TESum (M.toList cs') loc,
      cs_svars,
      RetType (foldMap (foldMap retDims) ts_s) $
        Scalar $
          Sum $
            M.map (map retType) ts_s,
      L.foldl' max Unlifted ls
    )
evalTypeExp df ote@TEApply {} = do
  (tname, tname_loc, targs) <- rootAndArgs ote
  (ps, tname_t, l) <- lookupType tname
  RetType t_dims t <- renameRetType $ toResRet Nonunique tname_t
  if length ps /= length targs
    then
      typeError tloc mempty $
        "Type constructor"
          <+> dquotes (pretty tname)
          <+> "requires"
          <+> pretty (length ps)
          <+> "arguments, but provided"
          <+> pretty (length targs)
          <> "."
    else do
      (targs', dims, substs) <- unzip3 <$> zipWithM checkArgApply ps targs
      pure
        ( foldl (\x y -> TEApply x y tloc) (TEVar tname tname_loc) targs',
          [],
          -- Size names produced while checking the arguments become
          -- existential dimensions of the result.
          RetType (t_dims ++ mconcat dims) $ applySubst (`M.lookup` mconcat substs) t,
          l
        )
  where
    tloc = srclocOf ote

    -- Split a nested application into its root type name and the
    -- arguments in application order.
    rootAndArgs (TEVar qn loc) = pure (qn, loc, [])
    rootAndArgs (TEApply op arg _) = do
      (op', loc, args) <- rootAndArgs op
      pure (op', loc, args ++ [arg])
    rootAndArgs te' =
      typeError (srclocOf te') mempty $
        "Type" <+> dquotes (pretty te') <+> "is not a type constructor."

    checkSizeExp (SizeExp e dloc) = do
      e' <- df e
      pure
        ( TypeArgExpSize (SizeExp e' dloc),
          [],
          ExpSubst e'
        )
    checkSizeExp (SizeExpAny loc) = do
      d <- newTypeName "d"
      pure
        ( TypeArgExpSize (SizeExpAny loc),
          [d],
          ExpSubst $ sizeFromName (qualName d) loc
        )

    -- Check one argument against its parameter, yielding the checked
    -- argument, any size names it introduced, and its substitution.
    checkArgApply (TypeParamDim pv _) (TypeArgExpSize d) = do
      (d', svars, subst) <- checkSizeExp d
      pure (d', svars, M.singleton pv subst)
    checkArgApply (TypeParamType _ pv _) (TypeArgExpType te) = do
      (te', svars, RetType dims st, _) <- evalTypeExp df te
      pure
        ( TypeArgExpType te',
          svars ++ dims,
          M.singleton pv $ Subst [] $ RetType [] $ toStruct st
        )
    checkArgApply p a =
      typeError tloc mempty $
        "Type argument"
          <+> pretty a
          <+> "not valid for a type parameter"
          <+> pretty p
          <> "."

-- | Check a type expression, producing:
--
-- * The checked expression.
-- * Size variables for any anonymous sizes in the expression.
-- * The elaborated type.
-- * The liftedness of the type.
--
-- Note that the name invented for an anonymous array size (@[]@) is
-- returned as an existential dimension of the elaborated type's
-- 'RetType', not in the list of size variables.
checkTypeExp ::
  (MonadTypeChecker m, Pretty df) =>
  (df -> m Exp) ->
  TypeExp df VName ->
  m (TypeExp Exp VName, [VName], ResRetType, Liftedness)
checkTypeExp = evalTypeExp

-- | Construct a type argument corresponding to a type parameter.
typeParamToArg :: TypeParam -> StructTypeArg
typeParamToArg (TypeParamDim v ploc) =
  TypeArgDim $ sizeFromName (qualName v) ploc
typeParamToArg (TypeParamType _ v _) =
  TypeArgType $ Scalar $ TypeVar mempty (qualName v) []

-- | A substitution for a name: either a type, possibly parameterised by
-- the given type parameters ('Subst'), or an expression, which is what
-- replaces a size variable ('ExpSubst').
data Subst t = Subst [TypeParam] t | ExpSubst Exp
  deriving (Show)

instance (Pretty t) => Pretty (Subst t) where
  pretty (Subst [] t) = pretty t
  pretty (Subst tps t) = mconcat (map pretty tps) <> colon <+> pretty t
  pretty (ExpSubst e) = pretty e

-- Maps over the type in a 'Subst'; an 'ExpSubst' is left unchanged.
instance Functor Subst where
  fmap f (Subst ps t) = Subst ps $ f t
  fmap _ (ExpSubst e) = ExpSubst e

-- | Create a type substitution corresponding to a type binding.
substFromAbbr :: TypeBinding -> Subst StructRetType
substFromAbbr (TypeAbbr _ ps rt) = Subst ps rt

-- | Substitutions to apply in a type.
type TypeSubs = VName -> Maybe (Subst StructRetType)

-- | Class of types which allow for substitution of types with no
-- annotations for type variable names.
class Substitutable a where
  applySubst :: TypeSubs -> a -> a

-- Existential sizes produced by the substitution are appended after the
-- existing ones.
instance Substitutable (RetTypeBase Size Uniqueness) where
  applySubst f (RetType dims t) =
    let RetType more_dims t' = substTypesRet f' t
     in RetType (dims ++ more_dims) t'
    where
      -- The substituted types carry no uniqueness; give them 'mempty'.
      f' = fmap (fmap (second (const mempty))) . f

instance Substitutable (RetTypeBase Size NoUniqueness) where
  applySubst f (RetType dims t) =
    let RetType more_dims t' = substTypesRet f t
     in RetType (dims ++ more_dims) t'

instance Substitutable StructType where
  applySubst = substTypesAny

-- Substituted types are given the 'Observe' diet.
instance Substitutable ParamType where
  applySubst f = substTypesAny $ fmap (fmap $ second $ const Observe) . f

-- Substituted types are given 'Nonunique' uniqueness.
instance Substitutable (TypeBase Size Uniqueness) where
  applySubst f = substTypesAny $ fmap (fmap $ second $ const Nonunique) . f

instance Substitutable Exp where
  applySubst f = runIdentity . mapOnExp
    where
      -- A variable with an expression substitution is replaced outright
      -- (the replacement is not traversed further); anything else is
      -- traversed, substituting in every type it contains.
      mapOnExp (Var (QualName _ v) _ _)
        | Just (ExpSubst e') <- f v = pure e'
      mapOnExp e' = astMap mapper e'

      mapper =
        ASTMapper
          { mapOnExp,
            mapOnName = pure,
            mapOnStructType = pure . applySubst f,
            mapOnParamType = pure . applySubst f,
            mapOnResRetType = pure . applySubst f
          }

-- Substitute in every dimension of the shape.
instance (Substitutable d) => Substitutable (Shape d) where
  applySubst f = fmap $ applySubst f

-- Substitute in every expression and type inside the pattern; names are
-- left unchanged.
instance Substitutable (Pat StructType) where
  applySubst f = runIdentity . astMap mapper
    where
      mapper =
        ASTMapper
          { mapOnExp = pure . applySubst f,
            mapOnName = pure,
            mapOnStructType = pure . applySubst f,
            mapOnParamType = pure . applySubst f,
            mapOnResRetType = pure . applySubst f
          }

instance Substitutable (Pat ParamType) where
  applySubst f = runIdentity . astMap mapper
    where
      mapper =
        ASTMapper
          { mapOnExp = pure . applySubst f,
            mapOnName = pure,
            mapOnStructType = pure . applySubst f,
            mapOnParamType = pure . applySubst f,
            mapOnResRetType = pure . applySubst f
          }

-- Instantiate the type parameters @ps@ in @t@ with the corresponding
-- arguments.  Parameters and arguments are paired with 'zipWith', so any
-- surplus on either side is ignored.
applyType ::
  (Monoid u) =>
  [TypeParam] ->
  TypeBase Size u ->
  [StructTypeArg] ->
  TypeBase Size u
applyType ps t args = substTypesAny (`M.lookup` substs) t
  where
    substs = M.fromList $ zipWith mkSubst ps args
    -- We are assuming everything has already been type-checked for correctness.
    mkSubst (TypeParamDim pv _) (TypeArgDim e) =
      (pv, ExpSubst e)
    mkSubst (TypeParamType _ pv _) (TypeArgType at) =
      (pv, Subst [] $ RetType [] $ second mempty at)
    mkSubst p a =
      error $ "applyType mkSubst: cannot substitute " ++ prettyString a ++ " for " ++ prettyString p

-- Substitute in a type, collecting the existential sizes that the
-- substitution introduces; they become the dimensions of the resulting
-- 'RetType'.
substTypesRet ::
  (Monoid u) =>
  (VName -> Maybe (Subst (RetTypeBase Size u))) ->
  TypeBase Size u ->
  RetTypeBase Size u
substTypesRet lookupSubst ot =
  uncurry (flip RetType) $ runState (onType ot) []
  where
    -- In case we are substituting the same RetType in multiple
    -- places, we must ensure each instance is given distinct
    -- dimensions.  E.g. substituting 'a ↦ ?[n].[n]bool' into '(a,a)'
    -- should give '?[n][m].([n]bool,[m]bool)'.
    --
    -- XXX: the size names we invent here are not globally unique.  This
    -- is _probably_ not a problem, since substituting types with
    -- outermost non-null existential sizes is done only when type
    -- checking modules and monomorphising.
    freshDims (RetType [] t) = pure $ RetType [] t
    freshDims (RetType ext t) = do
      seen_ext <- get
      if not $ any (`elem` seen_ext) ext
        then pure $ RetType ext t
        else do
          -- New tags start just above the largest tag seen so far.
          -- NOTE(review): the lazy 'RetType []' pattern below fails at
          -- runtime if the inner substitution introduces existential sizes.
          let start = maximum $ map baseTag seen_ext
              ext' = zipWith VName (map baseName ext) [start + 1 ..]
              mkSubst = ExpSubst . flip sizeFromName mempty . qualName
              extsubsts = M.fromList $ zip ext $ map mkSubst ext'
              RetType [] t' = substTypesRet (`M.lookup` extsubsts) t
          pure $ RetType ext' t'

    onType ::
      forall as.
      (Monoid as) =>
      TypeBase Size as ->
      State [VName] (TypeBase Size as)
    onType (Array u shape et) =
      arrayOfWithAliases u (applySubst lookupSubst' shape)
        <$> onType (Scalar et)
    onType (Scalar (Prim t)) = pure $ Scalar $ Prim t
    onType (Scalar (TypeVar u v targs)) = do
      targs' <- mapM subsTypeArg targs
      case lookupSubst $ qualLeaf v of
        -- Give the substituted type fresh existential sizes if needed,
        -- record them, and instantiate its parameters with the
        -- (already substituted) arguments.
        Just (Subst ps rt) -> do
          RetType ext t <- freshDims rt
          modify (ext ++)
          pure $ second (<> u) $ applyType ps (second (const u) t) targs'
        _ -> pure $ Scalar $ TypeVar u v targs'
    onType (Scalar (Record ts)) =
      Scalar . Record <$> traverse onType ts
    onType (Scalar (Arrow u v d t1 t2)) =
      Scalar <$> (Arrow u v d <$> onType t1 <*> onRetType t2)
    onType (Scalar (Sum ts)) =
      Scalar . Sum <$> traverse (traverse onType) ts

    -- Sizes introduced inside a return type are bound by that return
    -- type, unless it is itself a function type, in which case they are
    -- propagated outwards through the state.
    onRetType (RetType dims t) = do
      ext <- get
      let (t', ext') = runState (onType t) ext
          new_ext = ext' L.\\ ext
      case t of
        Scalar Arrow {} -> do
          put ext'
          pure $ RetType dims t'
        _ ->
          pure $ RetType (new_ext <> dims) t'

    subsTypeArg (TypeArgType t) = do
      let RetType dims t' = substTypesRet lookupSubst' t
      modify (dims ++)
      pure $ TypeArgType t'
    subsTypeArg (TypeArgDim v) =
      pure $ TypeArgDim $ applySubst lookupSubst' v

    lookupSubst' = fmap (fmap $ second (const NoUniqueness)) . lookupSubst

-- | Perform substitutions, from type names to types, on a type.  Works
-- regardless of what shape and uniqueness information is attached to the type.
substTypesAny ::
  (Monoid u) =>
  (VName -> Maybe (Subst (RetTypeBase Size u))) ->
  TypeBase Size u ->
  TypeBase Size u
substTypesAny lookupSubst ot =
  case substTypesRet lookupSubst ot of
    RetType [] ot' -> ot'
    RetType dims ot' ->
      -- XXX HACK FIXME: turn any sizes that propagate to the top into
      -- AnySize.  This should _never_ happen during type-checking, but
      -- may happen as we substitute types during monomorphisation and
      -- defunctorisation later on.  See Note [AnySize]
      let toAny (Var v _ _)
            | qualLeaf v `elem` dims = anySize
          toAny d = d
       in first toAny ot'

-- Note [AnySize]
--
-- Consider a program:
--
-- module m : { type~ t } = { type~ t = ?[n].[n]bool }
-- let f (x: m.t) (y: m.t) = 0
--
-- After defunctorisation (and inlining the definitions of types), we
-- want this:
--
-- let f [n][m] (x: [n]bool) (y: [m]bool) = 0
--
-- But this means that defunctorisation would need to redo some amount
-- of size inference.  Not so complicated in the example above, but
-- what if loops and branches are involved?
--
-- So instead, what defunctorisation actually does is produce this:
--
-- let f (x: []bool) (y: []bool) = 0
--
-- I.e. we put in empty dimensions (AnySize), which are much later
-- turned into distinct sizes in Futhark.Internalise.Exps.  This will
-- result in unnecessary dynamic size checks, which will hopefully be
-- optimised away.
--
-- Important: The type checker will _never_ produce programs with
-- AnySize, but unfortunately some of the compilation steps
-- (defunctorisation, monomorphisation, defunctionalisation) will do
-- so.  Similarly, the core language is also perfectly well behaved.
--
-- Example with monomorphisation:
--
-- let f '~a (b: bool) (x: () -> a) (y: () -> a) : a = if b then x () else y ()
-- let g = f true (\() -> [1]) (\() -> [1,2])
--
-- This should produce:
--
-- let f (b: bool) (x: () -> ?[n].[n]i32) (y: () -> ?[m].[m]i32) : ?[k].[k]i32 =
--   if b then x () else y ()
--
-- Not so easy!  Again, what we actually produce is
--
-- let f (b: bool) (x: () -> []i32) (y: () -> []i32) : []i32 =
--   if b then x () else y ()
futhark-0.25.27/src/Language/Futhark/TypeChecker/Unify.hs000066400000000000000000001327461475065116200231310ustar00rootroot00000000000000-- | Implementation of unification and other core type system building
-- blocks.
module Language.Futhark.TypeChecker.Unify
  ( Constraint (..),
    Usage (..),
    mkUsage,
    mkUsage',
    Level,
    Constraints,
    MonadUnify (..),
    Rigidity (..),
    RigidSource (..),
    BreadCrumbs,
    sizeFree,
    noBreadCrumbs,
    hasNoBreadCrumbs,
    dimNotes,
    zeroOrderType,
    arrayElemType,
    mustHaveConstr,
    mustHaveField,
    mustBeOneOf,
    equalityType,
    normType,
    normTypeFully,
    unify,
    unifyMostCommon,
    doUnification,
  )
where

import Control.Monad
import Control.Monad.Except
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Data.Text qualified as T
import Futhark.Util (topologicalSort)
import Futhark.Util.Pretty
import Language.Futhark
import Language.Futhark.Traversals
import Language.Futhark.TypeChecker.Monad hiding (BoundV)
import Language.Futhark.TypeChecker.Types

-- | A piece of information that describes what the type checker is
-- currently doing.  This is used to give better error messages for
-- unification errors.
data BreadCrumb
  = MatchingTypes StructType StructType
  | MatchingFields [Name]
  | MatchingConstructor Name
  | Matching (Doc ())

instance Pretty BreadCrumb where
  -- Multi-part messages are separated with '</>' (a line break);
  -- juxtaposing 'Doc' values without an operator does not type-check.
  pretty (MatchingTypes t1 t2) =
    "When matching type"
      </> indent 2 (pretty t1)
      </> "with"
      </> indent 2 (pretty t2)
  pretty (MatchingFields fields) =
    "When matching types of record field"
      <+> dquotes (mconcat $ punctuate "." $ map pretty fields)
      <> dot
  pretty (MatchingConstructor c) =
    "When matching types of constructor" <+> dquotes (pretty c) <> dot
  pretty (Matching s) =
    unAnnotate s

-- | Unification failures can occur deep down inside complicated types
-- (consider nested records).  We leave breadcrumbs behind us so we
-- can report the path we took to find the mismatch.
newtype BreadCrumbs = BreadCrumbs [BreadCrumb]

-- | An empty path.
noBreadCrumbs :: BreadCrumbs
noBreadCrumbs = BreadCrumbs []

-- | Is the path empty?
hasNoBreadCrumbs :: BreadCrumbs -> Bool
hasNoBreadCrumbs (BreadCrumbs []) = True
hasNoBreadCrumbs (BreadCrumbs _) = False

-- | Drop a breadcrumb on the path behind you.  When the new crumb and
-- the most recent one are both 'MatchingFields', they are fused into a
-- single crumb whose field path is the outer path followed by the inner.
breadCrumb :: BreadCrumb -> BreadCrumbs -> BreadCrumbs
breadCrumb bc (BreadCrumbs bcs) =
  case (bc, bcs) of
    (MatchingFields inner, MatchingFields outer : rest) ->
      BreadCrumbs $ MatchingFields (outer ++ inner) : rest
    _ ->
      BreadCrumbs $ bc : bcs

instance Pretty BreadCrumbs where
  -- An empty path renders as nothing; otherwise each crumb goes on its
  -- own line, starting on a fresh line.
  pretty (BreadCrumbs bcs)
    | null bcs = mempty
    | otherwise = line <> stack (map pretty bcs)

-- | A usage that caused a type constraint: an optional description
-- and the location of the use.
data Usage = Usage (Maybe T.Text) Loc
  deriving (Show)

-- | Construct a 'Usage' from a location and a description.
mkUsage :: (Located a) => a -> T.Text -> Usage
mkUsage x desc = Usage (Just desc) (locOf x)

-- | Construct a 'Usage' that has just a location, but no particular
-- description.
mkUsage' :: (Located a) => a -> Usage
mkUsage' x = Usage Nothing (locOf x)

instance Pretty Usage where
  pretty (Usage desc loc) =
    case desc of
      Nothing -> "use at " <> textwrap (locText loc)
      Just s -> textwrap s <+> "at" <+> textwrap (locText loc)

instance Located Usage where
  locOf (Usage _ loc) = locOf loc

-- | The level at which a type variable is bound.  Higher means
-- deeper.  We can only unify a type variable at level @i@ with a type
-- @t@ if all type names that occur in @t@ are at most at level @i@.
type Level = Int

-- | A constraint on a yet-ambiguous type variable.
data Constraint
  = NoConstraint Liftedness Usage
  | ParamType Liftedness Loc
  | Constraint StructRetType Usage
  | Overloaded [PrimType] Usage
  | HasFields Liftedness (M.Map Name StructType) Usage
  | Equality Usage
  | HasConstrs Liftedness (M.Map Name [StructType]) Usage
  | ParamSize Loc
  | -- | Is not actually a type, but a term-level size,
    -- possibly already set to something specific.
    Size (Maybe Exp) Usage
  | -- | A size that does not unify with anything -
    -- created from the result of applying a function
    -- whose return size is existential, or otherwise
    -- hiding a size.
    UnknownSize Loc RigidSource
  deriving (Show)

instance Located Constraint where
  locOf c =
    case c of
      NoConstraint _ usage -> locOf usage
      ParamType _ loc -> locOf loc
      Constraint _ usage -> locOf usage
      Overloaded _ usage -> locOf usage
      HasFields _ _ usage -> locOf usage
      Equality usage -> locOf usage
      HasConstrs _ _ usage -> locOf usage
      ParamSize loc -> locOf loc
      Size _ usage -> locOf usage
      UnknownSize loc _ -> locOf loc

-- | Mapping from fresh type variables, instantiated from the type
-- schemes of polymorphic functions, to (possibly) specific types as
-- determined on application and the location of that application, or
-- a partial constraint on their type.
type Constraints = M.Map VName (Level, Constraint)

-- | Look up the substitution recorded for a type variable
-- ('Constraint') or a size variable ('Size' with a known expression).
-- The substitution itself has the constraint map applied to it, so
-- chains of variables are followed.
lookupSubst :: VName -> Constraints -> Maybe (Subst StructRetType)
lookupSubst v constraints =
  M.lookup v constraints >>= fromConstraint . snd
  where
    recur = (`lookupSubst` constraints)
    fromConstraint (Constraint t _) = Just $ Subst [] $ applySubst recur t
    fromConstraint (Size (Just d) _) = Just $ ExpSubst $ applySubst recur d
    fromConstraint _ = Nothing

-- | The source of a rigid size.  Constructor order matters for the
-- derived 'Ord' instance.
data RigidSource
  = -- | A function argument that is not a constant or variable name.
    RigidArg (Maybe (QualName VName)) T.Text
  | -- | An existential return size.
    RigidRet (Maybe (QualName VName))
  | -- | Similarly to 'RigidRet', but produced by a loop.
    RigidLoop
  | -- | Produced by a complicated slice expression.
    RigidSlice (Maybe Size) T.Text
  | -- | Produced by a complicated range expression.
    RigidRange
  | -- | Mismatch in branches.
    RigidCond StructType StructType
  | -- | Invented during unification.
    RigidUnify
  | -- | A name used in a size went out of scope.
    RigidOutOfScope Loc VName
  deriving (Eq, Ord, Show)

-- | The rigidity of a size variable.  All rigid sizes are tagged with
-- information about how they were generated.
data Rigidity = Rigid RigidSource | Nonrigid
  deriving (Eq, Ord, Show)

-- | Explain where a rigid size came from.  The first location is the
-- current context; the second is where the size arose, and it is
-- rendered relative to the first via 'locStrRel'.  Multi-line
-- explanations are joined with '</>' (juxtaposing 'Doc' values
-- without an operator does not type-check).
prettySource :: Loc -> Loc -> RigidSource -> Doc ()
prettySource ctx loc (RigidRet Nothing) =
  "is unknown size returned by function at"
    <+> pretty (locStrRel ctx loc)
    <> "."
prettySource ctx loc (RigidRet (Just fname)) =
  "is unknown size returned by"
    <+> dquotes (pretty fname)
    <+> "at"
    <+> pretty (locStrRel ctx loc)
    <> "."
prettySource ctx loc (RigidArg fname arg) =
  "is value of argument"
    </> indent 2 (shorten (pretty arg))
    </> "passed to"
    <+> fname'
    <+> "at"
    <+> pretty (locStrRel ctx loc)
    <> "."
  where
    fname' = maybe "function" (dquotes . pretty) fname
prettySource ctx loc (RigidSlice d slice) =
  "is size produced by slice"
    </> indent 2 (shorten (pretty slice))
    </> d_desc
    <> "at"
    <+> pretty (locStrRel ctx loc)
    <> "."
  where
    -- Ends in a space when present, so that "at" is separated from it.
    d_desc = case d of
      Just d' -> "of dimension of size " <> dquotes (pretty d') <> " "
      Nothing -> mempty
prettySource ctx loc RigidLoop =
  "is unknown size of value returned at" <+> pretty (locStrRel ctx loc) <> "."
prettySource ctx loc RigidRange =
  "is unknown length of range at" <+> pretty (locStrRel ctx loc) <> "."
prettySource ctx loc (RigidOutOfScope boundloc v) =
  "is an unknown size arising from "
    <> dquotes (prettyName v)
    <> " going out of scope at "
    <> pretty (locStrRel ctx loc)
    <> "."
    </> "Originally bound at "
    <> pretty (locStrRel ctx boundloc)
    <> "."
prettySource _ _ RigidUnify =
  "is an artificial size invented during unification of functions with anonymous sizes."
prettySource ctx loc (RigidCond t1 t2) =
  "is unknown due to conditional expression at "
    <> pretty (locStrRel ctx loc)
    <> "."
    </> "One branch returns array of type: "
    <> align (pretty t1)
    </> "The other an array of type: "
    <> align (pretty t2)

-- | Retrieve notes describing the purpose or origin of the given
-- t'Size'.  The location is used as the *current* location, for the
-- purpose of reporting relative locations.
dimNotes :: (Located a, MonadUnify m) => a -> Exp -> m Notes
dimNotes ctx e =
  case e of
    Var d _ _ -> do
      c <- M.lookup (qualLeaf d) <$> getConstraints
      -- Only sizes recorded as 'UnknownSize' have an origin worth
      -- reporting.
      pure $ case c of
        Just (_, UnknownSize loc rsrc) ->
          aNote $ dquotes (pretty d) <+> prettySource (locOf ctx) loc rsrc
        _ -> mempty
    _ -> pure mempty

-- | Collect 'dimNotes' for every size variable free in the type.
typeNotes :: (Located a, MonadUnify m) => a -> StructType -> m Notes
typeNotes ctx t =
  mconcat <$> mapM noteFor (S.toList $ fvVars $ freeInType t)
  where
    noteFor v = dimNotes ctx $ sizeFromName (qualName v) mempty

-- | Describe the partial constraint (sum type, overloaded primitive,
-- or record) currently imposed on a type variable, if any.
typeVarNotes :: (MonadUnify m) => VName -> m Notes
typeVarNotes v = do
  constraints <- getConstraints
  pure $ maybe mempty (describe . snd) $ M.lookup v constraints
  where
    describe constraint =
      case constraint of
        HasConstrs _ cs _ ->
          aNote $
            prettyName v
              <+> "="
              <+> hsep (map ppConstr (M.toList cs))
              <+> "..."
        Overloaded ts _ ->
          aNote $
            prettyName v
              <+> "must be one of"
              <+> mconcat (punctuate ", " (map pretty ts))
        HasFields _ fs _ ->
          aNote $
            prettyName v
              <+> "="
              <+> braces (mconcat (punctuate ", " (map ppField (M.toList fs))))
        _ -> mempty

    ppConstr (cname, _) = "#" <> pretty cname <+> "..." <+> "|"
    ppField (fname, _) = prettyName fname <> ":" <+> "..."

-- | Monads that wish to perform unification must implement this type
-- class.
class (Monad m) => MonadUnify m where
  getConstraints :: m Constraints
  putConstraints :: Constraints -> m ()

  modifyConstraints :: (Constraints -> Constraints) -> m ()
  modifyConstraints f = getConstraints >>= putConstraints . f

  newTypeVar :: (Monoid als, Located a) => a -> Name -> m (TypeBase dim als)
  newDimVar :: Usage -> Rigidity -> Name -> m VName

  newRigidDim :: (Located a) => a -> RigidSource -> Name -> m VName
  newRigidDim loc src = newDimVar (mkUsage' loc) (Rigid src)

  newFlexibleDim :: Usage -> Name -> m VName
  newFlexibleDim usage = newDimVar usage Nonrigid

  curLevel :: m Level

  matchError ::
    (Located loc) =>
    loc ->
    Notes ->
    BreadCrumbs ->
    StructType ->
    StructType ->
    m a

  unifyError ::
    (Located loc) =>
    loc ->
    Notes ->
    BreadCrumbs ->
    Doc () ->
    m a

-- | Replace all type variables with their substitution.
normTypeFully :: (Substitutable a, MonadUnify m) => a -> m a
normTypeFully t = substituteWith <$> getConstraints
  where
    substituteWith constraints = applySubst (flip lookupSubst constraints) t

-- | Replace any top-level type variable with its substitution.  Only
-- variables linked to a type without existential sizes are followed,
-- repeatedly, until something else is reached.
normType :: (MonadUnify m) => StructType -> m StructType
normType t =
  case t of
    Scalar (TypeVar _ (QualName [] v) []) -> do
      c <- M.lookup v <$> getConstraints
      case c of
        Just (_, Constraint (RetType [] t') _) -> normType t'
        _ -> pure t
    _ -> pure t

-- | Constraints that mark a variable as rigid, i.e. not substitutable
-- by unification.
rigidConstraint :: Constraint -> Bool
rigidConstraint c =
  case c of
    ParamType {} -> True
    ParamSize {} -> True
    UnknownSize {} -> True
    _ -> False

-- | Error message listing the constructors present in only one of the
-- two maps (those of the second map first).
unsharedConstructorsMsg :: M.Map Name t -> M.Map Name t -> Doc a
unsharedConstructorsMsg cs1 cs2 =
  "Unshared constructors:" <+> commasep (map (("#" <>) . pretty) missing) <> "."
  where
    onlyIn2 = M.keys $ cs2 `M.difference` cs1
    onlyIn1 = M.keys $ cs1 `M.difference` cs2
    missing = onlyIn2 ++ onlyIn1

-- | Is the given type variable the name of an abstract type or type
-- parameter, which we cannot substitute?  Variables with no recorded
-- constraint are treated as rigid.
isRigid :: VName -> Constraints -> Bool
isRigid v constraints =
  case M.lookup v constraints of
    Nothing -> True
    Just (_, c) -> rigidConstraint c

-- | If the given type variable is nonrigid, what is its level?
isNonRigid :: VName -> Constraints -> Maybe Level
isNonRigid v constraints = do
  (lvl, c) <- M.lookup v constraints
  guard $ not $ rigidConstraint c
  pure lvl

-- | Callback used to unify two sizes.  It receives the breadcrumbs,
-- the names currently bound (parameter names and existential sizes of
-- function types being unified), a function giving the level of
-- nonrigid variables, and the two size expressions.
type UnifySizes m =
  BreadCrumbs -> [VName] -> (VName -> Maybe Int) -> Exp -> Exp -> m ()

-- | Swap the order of the two sizes passed to a 'UnifySizes' callback.
flipUnifySizes :: UnifySizes m -> UnifySizes m
flipUnifySizes onDims bcs bound nonrigid t1 t2 =
  onDims bcs bound nonrigid t2 t1

-- | Structurally unify two types, delegating size unification to the
-- given callback.  The internal @ord@ flag records whether the types
-- have been swapped relative to the caller's order (it flips when
-- descending into function parameter types).  Order-sensitive checks
-- (consumption, return uniqueness, size callbacks) use it to see the
-- types in the caller's orientation.
unifyWith ::
  (MonadUnify m) =>
  UnifySizes m ->
  Usage ->
  [VName] ->
  BreadCrumbs ->
  StructType ->
  StructType ->
  m ()
unifyWith onDims usage = subunify False
  where
    swap True x y = (y, x)
    swap False x y = (x, y)

    subunify ord bound bcs t1 t2 = do
      constraints <- getConstraints

      t1' <- normType t1
      t2' <- normType t2

      let nonrigid v = isNonRigid v constraints

          failure = matchError (srclocOf usage) mempty bcs t1' t2'

          link ord' =
            linkVarToType linkDims usage bound bcs
            where
              -- We may have to flip the order of future calls to
              -- onDims inside linkVarToType.
              linkDims
                | ord' = flipUnifySizes onDims
                | otherwise = onDims

          unifyTypeArg bcs' (TypeArgDim d1) (TypeArgDim d2) =
            onDims' bcs' (swap ord d1 d2)
          unifyTypeArg bcs' (TypeArgType t) (TypeArgType arg_t) =
            subunify ord bound bcs' t arg_t
          unifyTypeArg bcs' _ _ =
            unifyError
              usage
              mempty
              bcs'
              "Cannot unify a type argument with a dimension argument (or vice versa)."

          onDims' bcs' (d1, d2) =
            onDims
              bcs'
              bound
              nonrigid
              (applySubst (`lookupSubst` constraints) d1)
              (applySubst (`lookupSubst` constraints) d2)

      case (t1', t2') of
        (Scalar (Prim pt1), Scalar (Prim pt2))
          | pt1 == pt2 -> pure ()
        ( Scalar (Record fs),
          Scalar (Record arg_fs)
          )
            | M.keys fs == M.keys arg_fs ->
                unifySharedFields onDims usage bound bcs fs arg_fs
            | otherwise -> do
                let missing =
                      filter (`notElem` M.keys arg_fs) (M.keys fs)
                        ++ filter (`notElem` M.keys fs) (M.keys arg_fs)
                unifyError usage mempty bcs $
                  "Unshared fields:" <+> commasep (map pretty missing) <> "."
        ( Scalar (TypeVar _ (QualName _ tn) targs),
          Scalar (TypeVar _ (QualName _ arg_tn) arg_targs)
          )
            | tn == arg_tn,
              length targs == length arg_targs -> do
                let bcs' = breadCrumb (Matching "When matching type arguments.") bcs
                zipWithM_ (unifyTypeArg bcs') targs arg_targs
        ( Scalar (TypeVar _ (QualName [] v1) []),
          Scalar (TypeVar _ (QualName [] v2) [])
          ) ->
            -- When both are nonrigid, link the one at the deeper level
            -- to the other.
            case (nonrigid v1, nonrigid v2) of
              (Nothing, Nothing) -> failure
              (Just lvl1, Nothing) -> link ord v1 lvl1 t2'
              (Nothing, Just lvl2) -> link (not ord) v2 lvl2 t1'
              (Just lvl1, Just lvl2)
                | lvl1 <= lvl2 -> link ord v1 lvl1 t2'
                | otherwise -> link (not ord) v2 lvl2 t1'
        (Scalar (TypeVar _ (QualName [] v1) []), _)
          | Just lvl <- nonrigid v1 ->
              link ord v1 lvl t2'
        (_, Scalar (TypeVar _ (QualName [] v2) []))
          | Just lvl <- nonrigid v2 ->
              link (not ord) v2 lvl t1'
        ( Scalar (Arrow _ p1 d1 a1 (RetType b1_dims b1)),
          Scalar (Arrow _ p2 d2 a2 (RetType b2_dims b2))
          )
            | uncurry (<) $ swap ord d1 d2 -> do
                unifyError usage mempty bcs . withIndexLink "unify-consuming-param" $
                  "Parameters"
                    </> indent 2 (pretty d1 <> pretty a1)
                    </> "and"
                    </> indent 2 (pretty d2 <> pretty a2)
                    </> "are incompatible regarding consuming their arguments."
            | uncurry (<) $ swap ord (uniqueness b2) (uniqueness b1) -> do
                unifyError usage mempty bcs . withIndexLink "unify-return-uniqueness" $
                  "Return types"
                    </> indent 2 (pretty b1)
                    </> "and"
                    </> indent 2 (pretty b2)
                    </> "have incompatible uniqueness."
            | otherwise -> do
                -- Introduce the existentials as size variables so they
                -- are subject to unification.  We will remove them again
                -- afterwards.
                let (r1, r2) =
                      swap
                        ord
                        (Size Nothing $ Usage Nothing mempty)
                        (UnknownSize mempty RigidUnify)
                lvl <- curLevel
                modifyConstraints (M.fromList (map (,(lvl, r1)) b1_dims) <>)
                modifyConstraints (M.fromList (map (,(lvl, r2)) b2_dims) <>)

                let bound' = bound <> mapMaybe pname [p1, p2] <> b1_dims <> b2_dims
                subunify
                  (not ord)
                  bound
                  (breadCrumb (Matching "When matching parameter types.") bcs)
                  a1
                  a2
                subunify
                  ord
                  bound'
                  (breadCrumb (Matching "When matching return types.") bcs)
                  (toStruct b1')
                  (toStruct b2')

                -- Delete the size variables we introduced to represent
                -- the existential sizes.
                modifyConstraints $ \m -> L.foldl' (flip M.delete) m (b1_dims <> b2_dims)
            where
              (b1', b2') =
                -- Replace one parameter name with the other in the
                -- return type, in case of dependent types.  I.e.,
                -- we want type '(n: i32) -> [n]i32' to unify with
                -- type '(x: i32) -> [x]i32'.
                case (p1, p2) of
                  (Named p1', Named p2') ->
                    let f v
                          | v == p2' = Just $ ExpSubst $ sizeFromName (qualName p1') mempty
                          | otherwise = Nothing
                     in (b1, applySubst f b2)
                  (_, _) ->
                    (b1, b2)

              pname (Named x) = Just x
              pname Unnamed = Nothing
        (Array {}, Array {})
          | Shape (t1_d : _) <- arrayShape t1',
            Shape (t2_d : _) <- arrayShape t2',
            Just t1'' <- peelArray 1 t1',
            Just t2'' <- peelArray 1 t2' -> do
              onDims' bcs (swap ord t1_d t2_d)
              subunify ord bound bcs t1'' t2''
        ( Scalar (Sum cs),
          Scalar (Sum arg_cs)
          )
            | M.keys cs == M.keys arg_cs ->
                unifySharedConstructors onDims usage bound bcs cs arg_cs
            | otherwise ->
                unifyError usage mempty bcs $ unsharedConstructorsMsg arg_cs cs
        _ -> failure

-- | Does the expression mention any of the given names?
anyBound :: [VName] -> ExpBase Info VName -> Bool
anyBound bound e = any (`S.member` fvVars (freeInExp e)) bound

-- | The default size unifier.  Structurally similar expressions are
-- unified componentwise.  Otherwise a nonrigid size variable on either
-- side is linked to the other expression, provided that expression
-- mentions no bound name (or the variable is itself bound).  Anything
-- else is an error annotated with the origins of both sizes.
unifySizes :: (MonadUnify m) => Usage -> UnifySizes m
unifySizes usage bcs bound nonrigid e1 e2
  | Just es <- similarExps e1 e2 =
      mapM_ (uncurry $ unifySizes usage bcs bound nonrigid) es
unifySizes usage bcs bound nonrigid (Var v1 _ _) e2
  | Just lvl1 <- nonrigid (qualLeaf v1),
    not (anyBound bound e2) || (qualLeaf v1 `elem` bound) =
      linkVarToDim usage bcs (qualLeaf v1) lvl1 e2
unifySizes usage bcs bound nonrigid e1 (Var v2 _ _)
  | Just lvl2 <- nonrigid (qualLeaf v2),
    not (anyBound bound e1) || (qualLeaf v2 `elem` bound) =
      linkVarToDim usage bcs (qualLeaf v2) lvl2 e1
unifySizes usage bcs _ _ e1 e2 = do
  -- Report the origin of both sizes (previously e2 was consulted twice
  -- and the origin of e1 was never shown).
  notes <- (<>) <$> dimNotes usage e1 <*> dimNotes usage e2
  unifyError usage notes bcs $
    "Sizes" <+> dquotes (pretty e1) <+> "and" <+> dquotes (pretty e2) <+> "do not match."

-- | Unifies two types.
unify :: (MonadUnify m) => Usage -> StructType -> StructType -> m ()
unify usage = unifyWith (unifySizes usage) usage mempty noBreadCrumbs

-- | Fail if the type variable occurs in the type it is about to be
-- instantiated with.
occursCheck ::
  (MonadUnify m) =>
  Usage ->
  BreadCrumbs ->
  VName ->
  StructType ->
  m ()
occursCheck usage bcs vn tp =
  when (vn `S.member` typeVars tp) $
    unifyError usage mempty bcs $
      "Occurs check: cannot instantiate"
        <+> prettyName vn
        <+> "with"
        <+> pretty tp
        <> "."

-- | Ensure that no type variable or free size variable of the type is
-- bound deeper than @max_lvl@.  Nonrigid variables are moved up to
-- @max_lvl@; a rigid one is a scope violation.
scopeCheck ::
  (MonadUnify m) =>
  Usage ->
  BreadCrumbs ->
  VName ->
  Level ->
  StructType ->
  m ()
scopeCheck usage bcs vn max_lvl tp = do
  constraints <- getConstraints
  checkType constraints tp
  where
    checkType constraints t =
      mapM_ (check constraints) $ typeVars t <> fvVars (freeInType t)

    check constraints v
      | Just (lvl, c) <- M.lookup v constraints,
        lvl > max_lvl =
          if rigidConstraint c
            then scopeViolation v
            else modifyConstraints $ M.insert v (max_lvl, c)
      | otherwise =
          pure ()

    scopeViolation v = do
      notes <- typeNotes usage tp
      unifyError usage notes bcs $
        "Cannot unify type"
          </> indent 2 (pretty tp)
          </> "with"
          <+> dquotes (prettyName vn)
          <+> "(scope violation)."
          </> "This is because"
          <+> dquotes (prettyName v)
          <+> "is rigidly bound in a deeper scope."

-- Expressions witnessed by type, topologically sorted.
topWit :: TypeBase Exp u -> [Exp]
topWit = topologicalSort depends . witnessedExps
  where
    witnessedExps t = execState (traverseDims onDim t) mempty
      where
        onDim _ PosImmediate e = modify (e :)
        onDim _ _ _ = pure ()
    depends a b = any (sameExp b) $ subExps a

-- | Replace every size for which @expKiller@ reports a cause with a
-- fresh rigid size ('RigidOutOfScope').  The result is the rewritten
-- type and the fresh names.  Names created inside a function's return
-- type are instead added to that return type's existential sizes.
sizeFree ::
  (MonadUnify m) =>
  SrcLoc ->
  (Exp -> Maybe VName) ->
  TypeBase Size u ->
  m (TypeBase Size u, [VName])
sizeFree tloc expKiller orig_t = do
  runReaderT (toBeReplaced orig_t $ onType orig_t) mempty `runStateT` mempty
  where
    lookReplacement e repl = snd <$> L.find (sameExp e . fst) repl
    expReplace mapping e
      | Just e' <- lookReplacement e mapping = e'
      | otherwise = runIdentity $ astMap mapper e
      where
        mapper = identityMapper {mapOnExp = pure . expReplace mapping}

    replacing e = do
      e' <- asks (`expReplace` e)
      case expKiller e' of
        Nothing -> pure e'
        Just cause -> do
          vn <- lift $ lift $ newRigidDim tloc (RigidOutOfScope (locOf e) cause) "d"
          modify (vn :)
          pure $ sizeFromName (qualName vn) (srclocOf e)

    -- Compute replacements for the witnessed expressions of the type
    -- (in topological order) and make them visible to the action.
    toBeReplaced t m' = foldl f m' $ topWit t
      where
        f m e = do
          e' <- replacing e
          local ((e, e') :) m

    onScalar (Record fs) =
      Record <$> traverse onType fs
    onScalar (Sum cs) =
      Sum <$> (traverse . traverse) onType cs
    onScalar (Arrow as pn d argT (RetType dims retT)) = do
      argT' <- onType argT
      old_bound <- get
      retT' <- toBeReplaced retT $ onType retT
      rl <- state $ L.partition (`notElem` old_bound)
      let dims' = dims <> rl
      pure $ Arrow as pn d argT' (RetType dims' retT')
    onScalar (TypeVar u v args) =
      TypeVar u v <$> mapM onTypeArg args
      where
        onTypeArg (TypeArgDim d) = TypeArgDim <$> replacing d
        onTypeArg (TypeArgType ty) = TypeArgType <$> onType ty
    onScalar (Prim pt) = pure $ Prim pt

    onType ::
      (MonadUnify m) =>
      TypeBase Size u ->
      ReaderT [(Exp, Exp)] (StateT [VName] m) (TypeBase Size u)
    onType (Array u shape scalar) =
      Array u <$> traverse replacing shape <*> onScalar scalar
    onType (Scalar ty) =
      Scalar <$> onScalar ty

-- | Instantiate the nonrigid type variable @vn@ (at level @lvl@) with
-- the given type, after an occurs check and a scope check.  Then check
-- that the type satisfies whatever constraint @vn@ carried before.
linkVarToType ::
  (MonadUnify m) =>
  UnifySizes m ->
  Usage ->
  [VName] ->
  BreadCrumbs ->
  VName ->
  Level ->
  StructType ->
  m ()
linkVarToType onDims usage bound bcs vn lvl tp_unnorm = do
  -- We have to expand anyway for the occurs check, so we might as
  -- well link the fully expanded type.
  tp <- normTypeFully tp_unnorm
  occursCheck usage bcs vn tp
  scopeCheck usage bcs vn lvl tp

  let link = do
        let (witnessed, not_witnessed) = determineSizeWitnesses tp
            used v = v `S.member` witnessed || v `S.member` not_witnessed
            (ext_witnessed, ext_not_witnessed) =
              L.partition (`elem` witnessed) $ filter used bound

            -- Any size that uses an ext_not_witnessed variable must
            -- be replaced with a fresh existential.
            problematic e =
              L.find (`elem` ext_not_witnessed) $ S.toList $ fvVars $ freeInExp e

        (tp', ext_new) <- sizeFree (srclocOf usage) problematic tp

        modifyConstraints $
          M.insert vn (lvl, Constraint (RetType (ext_new <> ext_witnessed) tp') usage)

  let unliftedBcs unlifted_usage =
        breadCrumb
          ( Matching $
              "When verifying that"
                <+> dquotes (prettyName vn)
                <+> textwrap "is not instantiated with a function type, due to"
                <+> pretty unlifted_usage
          )
          bcs

  constraints <- getConstraints
  case snd <$> M.lookup vn constraints of
    Just (NoConstraint Unlifted unlift_usage) -> do
      link

      arrayElemTypeWith usage (unliftedBcs unlift_usage) tp
      when (any (`elem` bound) (fvVars (freeInType tp))) $
        unifyError usage mempty bcs $
          "Type variable"
            <+> prettyName vn
            <+> "cannot be instantiated with type containing anonymous sizes:"
            </> indent 2 (pretty tp)
            </> textwrap "This is usually because the size of an array returned by a higher-order function argument cannot be determined statically. This can also be due to the return size being a value parameter. Add type annotation to clarify."
    Just (Equality _) -> do
      link
      equalityType usage tp
    Just (Overloaded ts old_usage)
      | tp `notElem` map (Scalar . Prim) ts -> do
          link
          case tp of
            Scalar (TypeVar _ (QualName [] v) [])
              | not $ isRigid v constraints ->
                  linkVarToTypes usage v ts
            _ ->
              unifyError usage mempty bcs $
                "Cannot instantiate"
                  <+> dquotes (prettyName vn)
                  <+> "with type"
                  </> indent 2 (pretty tp)
                  </> "as"
                  <+> dquotes (prettyName vn)
                  <+> "must be one of"
                  <+> commasep (map pretty ts)
                  </> "due to"
                  <+> pretty old_usage
                  <> "."
    Just (HasFields l required_fields old_usage) -> do
      when (l == Unlifted) $ arrayElemTypeWith usage (unliftedBcs old_usage) tp
      case tp of
        Scalar (Record tp_fields)
          | all (`M.member` tp_fields) $ M.keys required_fields -> do
              required_fields' <- mapM normTypeFully required_fields
              let tp' = Scalar $ Record $ required_fields <> tp_fields -- Crucially left-biased.
                  ext = filter (`S.member` fvVars (freeInType tp')) bound
              modifyConstraints $
                M.insert vn (lvl, Constraint (RetType ext tp') usage)
              unifySharedFields onDims usage bound bcs required_fields' tp_fields
        Scalar (TypeVar _ (QualName [] v) []) -> do
          case M.lookup v constraints of
            Just (_, HasFields _ tp_fields _) ->
              unifySharedFields onDims usage bound bcs required_fields tp_fields
            Just (_, NoConstraint {}) -> pure ()
            Just (_, Equality {}) -> pure ()
            _ -> do
              notes <- (<>) <$> typeVarNotes vn <*> typeVarNotes v
              noRecordType notes
          link
          modifyConstraints $
            M.insertWith
              combineFields
              v
              (lvl, HasFields l required_fields old_usage)
          where
            combineFields (_, HasFields l1 fs1 usage1) (_, HasFields l2 fs2 _) =
              (lvl, HasFields (l1 `min` l2) (M.union fs1 fs2) usage1)
            combineFields hasfs _ = hasfs
        _ ->
          unifyError usage mempty bcs $
            "Cannot instantiate"
              <+> dquotes (prettyName vn)
              <+> "with type"
              </> indent 2 (pretty tp)
              </> "as"
              <+> dquotes (prettyName vn)
              <+> "must be a record with fields"
              </> indent 2 (pretty (Record required_fields))
              </> "due to"
              <+> pretty old_usage
              <> "."
    -- See Note [Linking variables to sum types]
    Just (HasConstrs l required_cs old_usage) -> do
      when (l == Unlifted) $ arrayElemTypeWith usage (unliftedBcs old_usage) tp
      case tp of
        Scalar (Sum ts)
          | all (`M.member` ts) $ M.keys required_cs -> do
              let tp' = Scalar $ Sum $ required_cs <> ts -- Crucially left-biased.
                  ext = filter (`S.member` fvVars (freeInType tp')) bound
              modifyConstraints $
                M.insert vn (lvl, Constraint (RetType ext tp') usage)
              unifySharedConstructors onDims usage bound bcs required_cs ts
          | otherwise ->
              unsharedConstructors required_cs ts =<< typeVarNotes vn
        Scalar (TypeVar _ (QualName [] v) []) -> do
          case M.lookup v constraints of
            Just (_, HasConstrs _ v_cs _) ->
              unifySharedConstructors onDims usage bound bcs required_cs v_cs
            Just (_, NoConstraint {}) -> pure ()
            Just (_, Equality {}) -> pure ()
            _ -> do
              notes <- (<>) <$> typeVarNotes vn <*> typeVarNotes v
              noSumType notes
          link
          modifyConstraints $
            M.insertWith
              combineConstrs
              v
              (lvl, HasConstrs l required_cs old_usage)
          where
            combineConstrs (_, HasConstrs l1 cs1 usage1) (_, HasConstrs l2 cs2 _) =
              (lvl, HasConstrs (l1 `min` l2) (M.union cs1 cs2) usage1)
            combineConstrs hasCs _ = hasCs
        _ -> noSumType =<< typeVarNotes vn
    _ -> link
  where
    unsharedConstructors cs1 cs2 notes =
      unifyError
        usage
        notes
        bcs
        (unsharedConstructorsMsg cs1 cs2)
    noSumType notes =
      unifyError
        usage
        notes
        bcs
        "Cannot unify a sum type with a non-sum type."
    noRecordType notes =
      unifyError
        usage
        notes
        bcs
        "Cannot unify a record type with a non-record type."

-- | Instantiate the nonrigid size variable @vn@ (at level @lvl@) with
-- the size expression @e@.  Free variables of @e@ bound at level @lvl@
-- or deeper are moved to @lvl@.  If such a variable is a parameter
-- size, this is a scope violation.
linkVarToDim ::
  (MonadUnify m) =>
  Usage ->
  BreadCrumbs ->
  VName ->
  Level ->
  Exp ->
  m ()
linkVarToDim usage bcs vn lvl e = do
  constraints <- getConstraints

  mapM_ (checkVar constraints) $ fvVars $ freeInExp e

  modifyConstraints $ M.insert vn (lvl, Size (Just e) usage)
  where
    checkVar _ dim'
      | vn == dim' = do
          notes <- dimNotes usage e
          unifyError usage notes bcs $
            "Occurs check: cannot instantiate"
              <+> dquotes (prettyName vn)
              <+> "with"
              <+> dquotes (pretty e)
              <> "."
    checkVar constraints dim'
      | Just (dim_lvl, c) <- dim' `M.lookup` constraints,
        dim_lvl >= lvl =
          case c of
            ParamSize {} -> do
              notes <- dimNotes usage e
              unifyError usage notes bcs $
                "Cannot link size"
                  <+> dquotes (prettyName vn)
                  <+> "to"
                  <+> dquotes (pretty e)
                  <+> "(scope violation)."
                  </> "This is because"
                  <+> dquotes (pretty $ qualName dim')
                  <+> "is not in scope when"
                  <+> dquotes (prettyName vn)
                  <+> "is introduced."
            _ -> modifyConstraints $ M.insert dim' (lvl, c)
    checkVar _ _ = pure ()

-- | Assert that this type must be one of the given primitive types.
mustBeOneOf :: (MonadUnify m) => [PrimType] -> Usage -> StructType -> m ()
mustBeOneOf [req_t] usage t = unify usage (Scalar (Prim req_t)) t
mustBeOneOf ts usage t = do
  t' <- normType t
  constraints <- getConstraints
  let isRigid' v = isRigid v constraints

  case t' of
    Scalar (TypeVar _ (QualName [] v) [])
      | not $ isRigid' v -> linkVarToTypes usage v ts
    Scalar (Prim pt) | pt `elem` ts -> pure ()
    _ -> failure
  where
    failure =
      unifyError usage mempty noBreadCrumbs $
        "Cannot unify type"
          <+> dquotes (pretty t)
          <+> "with any of "
          <> commasep (map pretty ts)
          <> "."

-- | Constrain a type variable to be one of the given primitive types.
-- An existing overloading constraint is intersected with the new one.
linkVarToTypes :: (MonadUnify m) => Usage -> VName -> [PrimType] -> m ()
linkVarToTypes usage vn ts = do
  vn_constraint <- M.lookup vn <$> getConstraints
  case vn_constraint of
    Just (lvl, Overloaded vn_ts vn_usage) ->
      case ts `L.intersect` vn_ts of
        [] ->
          unifyError usage mempty noBreadCrumbs $
            "Type constrained to one of"
              <+> commasep (map pretty ts)
              <+> "but also one of"
              <+> commasep (map pretty vn_ts)
              <+> "due to"
              <+> pretty vn_usage
              <> "."
        ts' -> modifyConstraints $ M.insert vn (lvl, Overloaded ts' usage)
    Just (_, HasConstrs _ _ vn_usage) ->
      unifyError usage mempty noBreadCrumbs $
        "Type constrained to one of"
          <+> commasep (map pretty ts)
          <> ", but also inferred to be sum type due to"
          <+> pretty vn_usage
          <> "."
    Just (_, HasFields _ _ vn_usage) ->
      unifyError usage mempty noBreadCrumbs $
        "Type constrained to one of"
          <+> commasep (map pretty ts)
          <> ", but also inferred to be record due to"
          <+> pretty vn_usage
          <> "."
    Just (lvl, _) -> modifyConstraints $ M.insert vn (lvl, Overloaded ts usage)
    Nothing ->
      unifyError usage mempty noBreadCrumbs $
        "Cannot constrain type to one of" <+> commasep (map pretty ts)

-- | Assert that this type must support equality.
equalityType ::
  (MonadUnify m, Pretty (Shape dim), Pretty u) =>
  Usage ->
  TypeBase dim u ->
  m ()
equalityType usage t = do
  -- A functional type can never support equality.
  unless (orderZero t) $
    unifyError usage mempty noBreadCrumbs $
      "Type" <+> dquotes (pretty t) <+> "does not support equality (may contain function)."
  mapM_ mustBeEquality $ typeVars t
  where
    -- Follow variable-to-variable links; record an 'Equality'
    -- constraint on unconstrained variables; reject variables whose
    -- known instantiation is higher-order or that are otherwise
    -- constrained in a way that precludes equality.
    mustBeEquality vn = do
      constraints <- getConstraints
      case M.lookup vn constraints of
        Just (_, Constraint (RetType [] (Scalar (TypeVar _ (QualName [] vn') []))) _) ->
          mustBeEquality vn'
        Just (_, Constraint (RetType _ vn_t) cusage)
          | not $ orderZero vn_t ->
              unifyError usage mempty noBreadCrumbs $
                "Type"
                  <+> dquotes (pretty t)
                  <+> "does not support equality."
                  </> "Constrained to be higher-order due to"
                  <+> pretty cusage
                  <> "."
          | otherwise -> pure ()
        Just (lvl, NoConstraint _ _) ->
          modifyConstraints $ M.insert vn (lvl, Equality usage)
        Just (_, Overloaded _ _) ->
          pure () -- All primtypes support equality.
        Just (_, Equality {}) ->
          pure ()
        _ ->
          unifyError usage mempty noBreadCrumbs $
            "Type" <+> prettyName vn <+> "does not support equality."

-- | Like 'zeroOrderType', but with explicit breadcrumbs.  Fails if the
-- type is functional.  Unconstrained, record-constrained and
-- sum-constrained type variables in it are marked 'Unlifted'.  A lifted
-- type parameter is an error.
zeroOrderTypeWith ::
  (MonadUnify m) =>
  Usage ->
  BreadCrumbs ->
  StructType ->
  m ()
zeroOrderTypeWith usage bcs t = do
  unless (orderZero t) $
    unifyError usage mempty bcs $
      "Type" </> indent 2 (pretty t) </> "found to be functional."
  mapM_ mustBeZeroOrder . S.toList . typeVars =<< normType t
  where
    mustBeZeroOrder vn = do
      constraints <- getConstraints
      case M.lookup vn constraints of
        Just (lvl, NoConstraint _ _) ->
          modifyConstraints $ M.insert vn (lvl, NoConstraint Unlifted usage)
        Just (lvl, HasFields _ fs _) ->
          modifyConstraints $ M.insert vn (lvl, HasFields Unlifted fs usage)
        Just (lvl, HasConstrs _ cs _) ->
          modifyConstraints $ M.insert vn (lvl, HasConstrs Unlifted cs usage)
        Just (_, ParamType Lifted ploc) ->
          unifyError usage mempty bcs $
            "Type parameter"
              <+> dquotes (prettyName vn)
              <+> "at"
              <+> pretty (locStr ploc)
              <+> "may be a function."
        _ -> pure ()

-- | Assert that this type must be zero-order.
zeroOrderType ::
  (MonadUnify m) =>
  Usage ->
  T.Text ->
  StructType ->
  m ()
zeroOrderType usage desc =
  zeroOrderTypeWith usage $ breadCrumb bc noBreadCrumbs
  where
    bc = Matching $ "When checking" <+> textwrap desc

-- | Like 'arrayElemType', but with explicit breadcrumbs.  Fails if the
-- type is functional.  Unconstrained type variables in it are marked
-- 'Unlifted'.  Type parameters that are 'Lifted' or 'SizeLifted' are
-- rejected.
arrayElemTypeWith ::
  (MonadUnify m, Pretty (Shape dim), Pretty u) =>
  Usage ->
  BreadCrumbs ->
  TypeBase dim u ->
  m ()
arrayElemTypeWith usage bcs t = do
  unless (orderZero t) $
    unifyError usage mempty bcs $
      "Type" </> indent 2 (pretty t) </> "found to be functional."
  mapM_ mustBeZeroOrder . S.toList . typeVars $ t
  where
    mustBeZeroOrder vn = do
      constraints <- getConstraints
      case M.lookup vn constraints of
        Just (lvl, NoConstraint _ _) ->
          modifyConstraints $ M.insert vn (lvl, NoConstraint Unlifted usage)
        Just (_, ParamType l ploc)
          | l `elem` [Lifted, SizeLifted] ->
              unifyError usage mempty bcs $
                "Type parameter"
                  <+> dquotes (prettyName vn)
                  <+> "bound at"
                  <+> pretty (locStr ploc)
                  <+> "is lifted and cannot be an array element."
        _ -> pure ()

-- | Assert that this type must be valid as an array element.
arrayElemType ::
  (MonadUnify m, Pretty (Shape dim), Pretty u) =>
  Usage ->
  T.Text ->
  TypeBase dim u ->
  m ()
arrayElemType usage desc t =
  arrayElemTypeWith usage crumbs t
  where
    crumbs = breadCrumb (Matching ("When checking" <+> textwrap desc)) noBreadCrumbs

-- | Unify the types of the fields present in both records; fields
-- present in only one of them are ignored.
unifySharedFields ::
  (MonadUnify m) =>
  UnifySizes m ->
  Usage ->
  [VName] ->
  BreadCrumbs ->
  M.Map Name StructType ->
  M.Map Name StructType ->
  m ()
unifySharedFields onDims usage bound bcs fs1 fs2 =
  mapM_ unifyField $ M.toList $ M.intersectionWith (,) fs1 fs2
  where
    unifyField (f, (t1, t2)) =
      unifyWith onDims usage bound (breadCrumb (MatchingFields [f]) bcs) t1 t2

-- | Unify the payload types of the constructors present in both sum
-- types.  A shared constructor with different arity is an error.
unifySharedConstructors ::
  (MonadUnify m) =>
  UnifySizes m ->
  Usage ->
  [VName] ->
  BreadCrumbs ->
  M.Map Name [StructType] ->
  M.Map Name [StructType] ->
  m ()
unifySharedConstructors onDims usage bound bcs cs1 cs2 =
  mapM_ (\(c, (f1, f2)) -> unifyConstructor c f1 f2) shared
  where
    shared = M.toList $ M.intersectionWith (,) cs1 cs2

    unifyConstructor c f1 f2
      | length f1 /= length f2 =
          unifyError usage mempty bcs $
            "Cannot unify constructor" <+> dquotes (prettyName c) <> "."
      | otherwise =
          zipWithM_ (unifyWith onDims usage bound (breadCrumb (MatchingConstructor c) bcs)) f1 f2

-- | In @mustHaveConstr usage c t fs@, the type @t@ must have a
-- constructor named @c@ that takes arguments of types @fs@.
mustHaveConstr ::
  (MonadUnify m) =>
  Usage ->
  Name ->
  StructType ->
  [StructType] ->
  m ()
mustHaveConstr usage c t fs = do
  constraints <- getConstraints
  case t of
    Scalar (TypeVar _ (QualName _ tn) [])
      | Just (lvl, NoConstraint l _) <- M.lookup tn constraints -> do
          mapM_ (scopeCheck usage noBreadCrumbs tn lvl) fs
          modifyConstraints $ M.insert tn (lvl, HasConstrs l (M.singleton c fs) usage)
      | Just (lvl, HasConstrs l cs _) <- M.lookup tn constraints ->
          case M.lookup c cs of
            -- NOTE(review): unlike the NoConstraint branch above, no
            -- scopeCheck is performed on the new payload types here;
            -- confirm whether that is intentional.
            Nothing -> modifyConstraints $ M.insert tn (lvl, HasConstrs l (M.insert c fs cs) usage)
            Just fs'
              | length fs == length fs' -> zipWithM_ (unify usage) fs fs'
              | otherwise ->
                  unifyError usage mempty noBreadCrumbs $
                    "Different arity for constructor" <+> dquotes (pretty c) <> "."
    Scalar (Sum cs) ->
      case M.lookup c cs of
        Nothing ->
          unifyError usage mempty noBreadCrumbs $
            "Constructor" <+> dquotes (pretty c) <+> "not present in type."
        Just fs'
          | length fs == length fs' -> zipWithM_ (unify usage) fs fs'
          | otherwise ->
              unifyError usage mempty noBreadCrumbs $
                "Different arity for constructor" <+> dquotes (pretty c) <> "."
    -- Anything else (including type variables whose guards above did
    -- not match) must unify with a sum type with just this constructor.
    _ ->
      unify usage t $ Scalar $ Sum $ M.singleton c fs

-- | Like 'mustHaveField', but with an explicit size unifier, bound
-- names and breadcrumbs, which are used for every unification
-- performed here.
mustHaveFieldWith ::
  (MonadUnify m) =>
  UnifySizes m ->
  Usage ->
  [VName] ->
  BreadCrumbs ->
  Name ->
  StructType ->
  m StructType
mustHaveFieldWith onDims usage bound bcs l t = do
  constraints <- getConstraints
  l_type <- newTypeVar (locOf usage) "t"
  case t of
    Scalar (TypeVar _ (QualName _ tn) [])
      | Just (lvl, NoConstraint {}) <- M.lookup tn constraints -> do
          scopeCheck usage bcs tn lvl l_type
          modifyConstraints $ M.insert tn (lvl, HasFields Lifted (M.singleton l l_type) usage)
          pure l_type
      | Just (lvl, HasFields lifted fields _) <- M.lookup tn constraints -> do
          case M.lookup l fields of
            Just t' -> unifyWith onDims usage bound bcs l_type t'
            Nothing ->
              modifyConstraints $
                M.insert tn (lvl, HasFields lifted (M.insert l l_type fields) usage)
          pure l_type
    Scalar (Record fields)
      | Just t' <- M.lookup l fields -> do
          unifyWith onDims usage bound bcs l_type t'
          pure t'
      | otherwise ->
          unifyError usage mempty bcs $
            "Attempt to access field"
              <+> dquotes (pretty l)
              <+> "of value of type"
              <+> pretty (toStructural t)
              <> "."
    _ -> do
      unifyWith onDims usage bound bcs t $ Scalar $ Record $ M.singleton l l_type
      pure l_type

-- | Assert that some type must have a field with this name and type.
mustHaveField ::
  (MonadUnify m) =>
  Usage ->
  Name ->
  StructType ->
  m StructType
mustHaveField usage = mustHaveFieldWith (unifySizes usage) usage mempty noBreadCrumbs

-- | Merge two structurally equal types.  Wherever their sizes differ,
-- a fresh rigid size ('RigidCond') is used instead.  The same pair of
-- differing sizes always maps to the same fresh size.
newDimOnMismatch ::
  (MonadUnify m) =>
  Loc ->
  StructType ->
  StructType ->
  m (StructType, [VName])
newDimOnMismatch loc t1 t2 = do
  (t, seen) <- runStateT (matchDims onDims t1 t2) mempty
  pure (t, M.elems seen)
  where
    r = RigidCond t1 t2
    same (e1, e2) =
      maybe False (all same) $ similarExps e1 e2
    onDims _ d1 d2
      | same (d1, d2) = pure d1
      | otherwise = do
          -- Remember mismatches we have seen before and reuse the
          -- same new size.
          maybe_d <- gets $ M.lookup (d1, d2)
          case maybe_d of
            Just d -> pure $ sizeFromName (qualName d) $ srclocOf loc
            Nothing -> do
              d <- lift $ newRigidDim loc r "differ"
              modify $ M.insert (d1, d2) d
              pure $ sizeFromName (qualName d) $ srclocOf loc

-- | Like unification, but creates new size variables where mismatches
-- occur.  Returns the new dimensions thus created.
unifyMostCommon ::
  (MonadUnify m) =>
  Usage ->
  StructType ->
  StructType ->
  m (StructType, [VName])
unifyMostCommon usage t1 t2 = do
  -- We are ignoring the dimensions here, because any mismatches
  -- should be turned into fresh size variables.
  let allOK _ _ _ _ _ = pure ()
  unifyWith allOK usage mempty noBreadCrumbs t1 t2
  t1' <- normTypeFully t1
  t2' <- normTypeFully t2
  newDimOnMismatch (locOf usage) t1' t2'

-- Simple MonadUnify implementation.

-- | The constraints, plus a counter used to generate fresh names.
type UnifyMState = (Constraints, Int)

newtype UnifyM a = UnifyM (StateT UnifyMState (Except TypeError) a)
  deriving
    ( Monad,
      Functor,
      Applicative,
      MonadState UnifyMState,
      MonadError TypeError
    )

-- | Produce a fresh name by bumping the counter in the state.
newVar :: Name -> UnifyM VName
newVar name = do
  (x, i) <- get
  put (x, i + 1)
  pure $ VName (mkTypeVarName name i) i

instance MonadUnify UnifyM where
  getConstraints = gets fst
  putConstraints x = modify $ \(_, i) -> (x, i)

  -- Fresh type and size variables are registered at level 0.
  newTypeVar loc name = do
    v <- newVar name
    modifyConstraints $ M.insert v (0, NoConstraint Lifted $ Usage Nothing $ locOf loc)
    pure $ Scalar $ TypeVar mempty (qualName v) []

  newDimVar usage rigidity name = do
    dim <- newVar name
    case rigidity of
      Rigid src -> modifyConstraints $ M.insert dim (0, UnknownSize (locOf usage) src)
      Nonrigid -> modifyConstraints $ M.insert dim (0, Size Nothing usage)
    pure dim

  curLevel = pure 1

  unifyError loc notes bcs doc =
    throwError $ TypeError (locOf loc) notes $ doc <> pretty bcs

  matchError loc notes bcs t1 t2 =
    throwError $ TypeError (locOf loc) notes $ doc <> pretty bcs
    where
      doc =
        "Types"
          </> indent 2 (pretty t1)
          </> "and"
          </> indent 2 (pretty t2)
          </> "do not match."
runUnifyM :: [TypeParam] -> [TypeParam] -> UnifyM a -> Either TypeError a runUnifyM rigid_tparams nonrigid_tparams (UnifyM m) = runExcept $ evalStateT m (constraints, 0) where constraints = M.fromList $ map nonrigid nonrigid_tparams <> map rigid rigid_tparams nonrigid (TypeParamDim p loc) = (p, (1, Size Nothing $ Usage Nothing $ locOf loc)) nonrigid (TypeParamType l p loc) = (p, (1, NoConstraint l $ Usage Nothing $ locOf loc)) rigid (TypeParamDim p loc) = (p, (0, ParamSize $ locOf loc)) rigid (TypeParamType l p loc) = (p, (0, ParamType l $ locOf loc)) -- | Perform a unification of two types outside a monadic context. -- The first list of type parameters are rigid but may have liftedness -- constraints; the second list of type parameters are allowed to be -- instantiated. All other types are considered rigid with no -- constraints. doUnification :: Loc -> [TypeParam] -> [TypeParam] -> StructType -> StructType -> Either TypeError StructType doUnification loc rigid_tparams nonrigid_tparams t1 t2 = runUnifyM rigid_tparams nonrigid_tparams $ do unify (Usage Nothing (locOf loc)) t1 t2 normTypeFully t2 -- Note [Linking variables to sum types] -- -- Consider the case when unifying a result type -- -- i32 -> ?[n].(#foo [n]bool) -- -- with -- -- i32 -> ?[k].a -- -- where 'a' has a HasConstrs constraint saying that it must have at -- least a constructor of type '#foo [0]bool'. -- -- This unification should succeed, but we must not merely link 'a' to -- '#foo [n]bool', as 'n' is not free. Instead we should instantiate -- 'a' to be a concrete sum type (because now we know exactly which -- constructor labels it must have), and unify each of its constructor -- payloads with the corresponding expected payload. futhark-0.25.27/src/Language/Futhark/Warnings.hs000066400000000000000000000041411475065116200214040ustar00rootroot00000000000000-- | A very simple representation of collections of warnings. 
-- Warnings have a position (so they can be ordered), and their -- 'Pretty'-instance produces a human-readable string. module Language.Futhark.Warnings ( Warnings, anyWarnings, singleWarning, singleWarning', listWarnings, prettyWarnings, ) where import Data.List (sortOn) import Data.Monoid import Futhark.Util.Loc import Futhark.Util.Pretty import Language.Futhark.Core (locText, prettyStacktrace) import Prelude -- | The warnings produced by the compiler. The 'Show' instance -- produces a human-readable description. newtype Warnings = Warnings [(Loc, [Loc], Doc ())] instance Semigroup Warnings where Warnings ws1 <> Warnings ws2 = Warnings $ ws1 <> ws2 instance Monoid Warnings where mempty = Warnings mempty -- | Prettyprint warnings, making use of colours and such. prettyWarnings :: Warnings -> Doc AnsiStyle prettyWarnings (Warnings []) = mempty prettyWarnings (Warnings ws) = stack $ map ((<> hardline) . onWarning) $ sortOn (rep . wloc) ws where wloc (x, _, _) = locOf x rep NoLoc = ("", 0) rep (Loc p _) = (posFile p, posCoff p) onWarning (loc, [], w) = annotate (color Yellow) ("Warning at" <+> pretty (locText loc) <> ":") indent 2 (unAnnotate w) onWarning (loc, locs, w) = annotate (color Yellow) ("Warning at" pretty (prettyStacktrace 0 (map locText (loc : locs)))) indent 2 (unAnnotate w) -- | True if there are any warnings in the set. anyWarnings :: Warnings -> Bool anyWarnings (Warnings ws) = not $ null ws -- | A single warning at the given location. singleWarning :: Loc -> Doc () -> Warnings singleWarning loc = singleWarning' loc [] -- | A single warning at the given location, but also with a stack -- trace (sort of) to the location. singleWarning' :: Loc -> [Loc] -> Doc () -> Warnings singleWarning' loc locs problem = Warnings [(loc, locs, problem)] -- | Exports Warnings into a list of (location, problem). 
listWarnings :: Warnings -> [(Loc, Doc ())] listWarnings (Warnings ws) = map (\(loc, _, doc) -> (loc, doc)) ws futhark-0.25.27/src/main.hs000066400000000000000000000003141475065116200154070ustar00rootroot00000000000000-- | The *actual* @futhark@ command line program, as seen by cabal. module Main (main) where import Futhark.CLI.Main qualified -- | This is the main function. main :: IO () main = Futhark.CLI.Main.main futhark-0.25.27/tests/000077500000000000000000000000001475065116200145045ustar00rootroot00000000000000futhark-0.25.27/tests/.gitignore000066400000000000000000000002231475065116200164710ustar00rootroot00000000000000* !*/ !/**/*.* /**/data /**/*.c /**/*.expected /**/*.actual /**/*.wasm /**/*.json /**/*.ispc /**/*.cache /**/*.h /**/*.o /**/*.gch /**/*.out /**/*~futhark-0.25.27/tests/BabyBearFun.fut000066400000000000000000000034461475065116200173530ustar00rootroot00000000000000-- == -- input { -- } -- output { -- [[2,4,5],[1,5,3],[3,7,1]] -- } -------------------------------------------------- -- SAC VERSIOn -------------------------------------------------- --inline i32[.,.] floydSbs1(i32[.,.] d ) [ -- dT = transpose(d); -- res = with -- (. <= [#i,j] <= .) 
: -- min( d[i,j], minval( d[i] + dT[j])); -- : modarray(d); -- return( res); --] -------------------------------------------------- -- C VERSIOn -------------------------------------------------- --inline i32* floydSbs1( i32 n, i32* d ) [ -- do k = 1, n -- do i = 1, n -- do j = 1, n -- d[i,j] = min(d[i,j], d[i,k] + d[k,j]) -- enddo -- enddo -- enddo -------------------------------------------------- -- C VERSIOn -------------------------------------------------- --inline i32* floydSbs1( i32 n, i32* d ) [ -- do i = 1, n -- do j = 1, n -- minrow = 0; -- do k = 1, n -- minrow = min(minrow, d[i,k] + d[k,j]) -- enddo -- d[i,j] = min(d[i,j], minrow) -- enddo -- enddo def min1 [n] (a: [n]i32, b: [n]i32): [n]i32 = map (uncurry i32.min) (zip a b) def redmin1(a: []i32): i32 = reduce i32.min 1200 a def redmin2 [n][m] (a: [n][m]i32): [n]i32 = map redmin1 a def plus1 [n] (a: [n]i32, b: [n]i32): [n]i32 = map2 (+) a b def plus2 [n][m] (a: [n][m]i32, b: [n][m]i32): [n][m]i32 = map plus1 (zip a b) def replin [k] (len: i64) (a: [k]i32): [len][k]i32 = replicate len a def floydSbsFun (n: i64) (d: [n][n]i32 ): [][]i32 = let d3 = replicate n <| transpose d let d2 = map (replin(n)) d let abr = map plus2 (zip d3 d2) let partial = map redmin2 abr in map min1 (zip partial d ) def main: [][]i32 = let arr = [[2,4,5], [1,1000,3], [3,7,1]] in floydSbsFun 3 arr futhark-0.25.27/tests/BabyBearImp.fut000066400000000000000000000027541475065116200173510ustar00rootroot00000000000000-- == -- input { -- } -- output { -- [[2,4,5],[1,5,3],[3,7,1]] -- } -------------------------------------------------- -- SAC VERSIOn -------------------------------------------------- --inline i32[.,.] floydSbs1(i32[.,.] d ) [ -- dT = transpose(d); -- res = with -- (. <= [#i,j] <= .) 
: -- min( d[i,j], minval( d[i] + dT[j])); -- : modarray(d); -- return( res); --] -------------------------------------------------- -- C VERSIOn -------------------------------------------------- --inline i32* floydSbs1( i32 n, i32* d ) [ -- do k = 1, n -- do i = 1, n -- do j = 1, n -- d[i,j] = min(d[i,j], d[i,k] + d[k,j]) -- enddo -- enddo -- enddo -------------------------------------------------- -- C VERSIOn -------------------------------------------------- --inline i32* floydSbs1( i32 n, i32* d ) [ -- do i = 1, n -- do j = 1, n -- minrow = 0; -- do k = 1, n -- minrow = min(minrow, d[i,k] + d[k,j]) -- enddo -- d[i,j] = min(d[i,j], minrow) -- enddo -- enddo def floydSbsImp(n: i32, d: *[][]i32): [][]i32 = let dT = copy (transpose d) in loop d = d for i < n do loop d for j < n do let sumrow = map2 (+) d[i] dT[j] let minrow = reduce i32.min 1200 sumrow let minrow = i32.min d[i,j] minrow let d[i,j] = minrow in d def main: [][]i32 = let arr = [[2,4,5], [1,1000,3], [3,7,1]] in floydSbsImp(3, copy(arr)) futhark-0.25.27/tests/README.md000066400000000000000000000012401475065116200157600ustar00rootroot00000000000000Futhark Integration Test Suite ============================== This directory contains a large number of small programs that test the Futhark compiler itself. You should not assume that these programs are examples of good Futhark style, nor that they are written to be understandable. Many are from very early days, when the language was much different, and have been semi-mechanically updated. Others are intentionally ugly to test out some corner case of the compiler. Look at the [example programs](https://github.com/diku-dk/futhark/tree/master/examples) or the [benchmark suite](https://github.com/diku-dk/futhark-benchmarks/) if you want examples to learn from. 
futhark-0.25.27/tests/accs/000077500000000000000000000000001475065116200154155ustar00rootroot00000000000000futhark-0.25.27/tests/accs/dup.fut000066400000000000000000000007401475065116200167260ustar00rootroot00000000000000-- This does not make much sense, but should not crash the compiler. -- == -- input { [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] } -- output { [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9] } import "intrinsics" def fst (x,_) = x def f (acc: *acc ([]i32)) i = let acc = fst (acc, acc) let acc = write acc (i*2) (i32.i64 i) let acc = write acc (i*2+1) (i32.i64 i) in acc def main (xs: *[]i32) = scatter_stream xs f (iota 10) futhark-0.25.27/tests/accs/fusion0.fut000066400000000000000000000007071475065116200175240ustar00rootroot00000000000000-- == -- input { [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] } -- output { [0, 2, 4, 7, 8, 12, 13, 17, 16, 24, 11, 16, 15, 19, 16, 27, 16, 25, 26, 28] } -- structure { /WithAcc/SplitAcc/Screma/Screma 0 } import "intrinsics" def f (acc: *acc ([]i32)) i = let js = scan (+) 0 (map (+i) (iota 10)) in loop acc for j in js do write acc j (i32.i64 i) def main (xs: *[]i32) = reduce_by_index_stream xs (+) 0 f (iota 10) futhark-0.25.27/tests/accs/hist0.fut000066400000000000000000000006051475065116200171650ustar00rootroot00000000000000-- == -- input { [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] } -- output { [0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 19, 11, 12, 13, 14, 15, 16, 17, 18, 19] } import "intrinsics" def f (acc: *acc ([]i32)) i = let acc = write acc i (i32.i64 i) let acc = write acc (i+1) (i32.i64 i) in acc def main (xs: *[]i32) = reduce_by_index_stream xs (+) 0 f (iota 10) futhark-0.25.27/tests/accs/hist1.fut000066400000000000000000000010141475065116200171610ustar00rootroot00000000000000-- == -- input { [0f32, 1f32, 2f32, 3f32, 4f32, 5f32, 6f32, 7f32, 8f32, -- 9f32, 10f32, 11f32, 12f32, 13f32, 14f32, 15f32, 16f32, 17f32, -- 
18f32, 19f32] } -- output { [0f32, 1f32, 3f32, 4f32, 6f32, 7f32, 9f32, 10f32, 12f32, -- 13f32, 15f32, 16f32, 18f32, 19f32, 21f32, 22f32, 24f32, 25f32, -- 27f32, 28f32] } import "intrinsics" def f (acc: *acc ([]f32)) i = let acc = write acc (i*2) (f32.i64 i) let acc = write acc (i*2+1) (f32.i64 i) in acc def main (xs: *[]f32) = reduce_by_index_stream xs (+) 0 f (iota 10) futhark-0.25.27/tests/accs/hist2.fut000066400000000000000000000005671475065116200171760ustar00rootroot00000000000000-- Writing an array with a vector operator. -- == -- input { [[2],[3],[4]] } -- output { [[3i32], [3i32], [5i32]] } import "intrinsics" def f (acc: *acc ([][]i32)) (i, x) = let acc = write acc (i*2) x in acc def main [n] (xs: *[][n]i32) = reduce_by_index_stream xs (map2 (+)) (replicate n 0) f (zip (iota 10) (replicate 10 (replicate n 1))) futhark-0.25.27/tests/accs/hist3.fut000066400000000000000000000011261475065116200171670ustar00rootroot00000000000000-- Writing an array with a non-vector operator. -- == -- input { [[2],[3],[4]] } -- output { [[3i32], [3i32], [5i32]] } import "intrinsics" def vecadd [n] (xs: [n]i32) (ys: [n]i32) : [n]i32 = -- This is just map2 (+), but written in a way the compiler -- hopefully will not recognise. loop acc = replicate n 0 for i < n do acc with [i] = xs[i] + ys[i] def f (acc: *acc ([][]i32)) (i, x) = let acc = write acc (i*2) x in acc def main [n] (xs: *[][n]i32) = reduce_by_index_stream xs (map2 (+)) (replicate n 0) f (zip (iota 10) (replicate 10 (replicate n 1))) futhark-0.25.27/tests/accs/hist4.fut000066400000000000000000000012521475065116200171700ustar00rootroot00000000000000-- Tuple data. 
-- == -- input { -- [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] -- [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] -- } -- output { -- [0, 1, 3, 4, 6, 7, 9, 10, 12, 13, 15, 16, 18, 19, 21, 22, 24, 25, 27, 28] -- [1, 2, 4, 5, 7, 8, 10, 11, 13, 14, 16, 17, 19, 20, 22, 23, 25, 26, 28, 29] -- } import "intrinsics" def f (acc: *acc ([](i32,i32))) i = let acc = write acc (i*2) (i32.i64 i, i32.i64 (i+1)) let acc = write acc (i*2+1) (i32.i64 i, i32.i64 (i+1)) in acc def main (xs: *[]i32) (ys: *[]i32) = let op x y = (x.0 + y.0, x.1 + y.1) in reduce_by_index_stream (zip xs ys) op (0,0) f (iota 10) |> unzip futhark-0.25.27/tests/accs/hist5.fut000066400000000000000000000030631475065116200171730ustar00rootroot00000000000000-- Complex operator that will require explicit locking. -- == -- input { -- 10i64 -- [0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, 8.0f32, 9.0f32, 10.0f32, 11.0f32, 12.0f32, 13.0f32, 14.0f32, 15.0f32, 16.0f32, 17.0f32, 18.0f32, 19.0f32, 20.0f32, 21.0f32, 22.0f32, 23.0f32, 24.0f32, 25.0f32, 26.0f32, 27.0f32, 28.0f32, 29.0f32, 30.0f32, 31.0f32, 32.0f32, 33.0f32, 34.0f32, 35.0f32, 36.0f32, 37.0f32, 38.0f32, 39.0f32, 40.0f32, 41.0f32, 42.0f32, 43.0f32, 44.0f32, 45.0f32, 46.0f32, 47.0f32, 48.0f32, 49.0f32, 50.0f32, 51.0f32, 52.0f32, 53.0f32, 54.0f32, 55.0f32, 56.0f32, 57.0f32, 58.0f32, 59.0f32, 60.0f32, 61.0f32, 62.0f32, 63.0f32, 64.0f32, 65.0f32, 66.0f32, 67.0f32, 68.0f32, 69.0f32, 70.0f32, 71.0f32, 72.0f32, 73.0f32, 74.0f32, 75.0f32, 76.0f32, 77.0f32, 78.0f32, 79.0f32, 80.0f32, 81.0f32, 82.0f32, 83.0f32, 84.0f32, 85.0f32, 86.0f32, 87.0f32, 88.0f32, 89.0f32, 90.0f32, 91.0f32, 92.0f32, 93.0f32, 94.0f32, 95.0f32, 96.0f32, 97.0f32, 98.0f32, 99.0f32] -- } -- output { -- [90.0f32, 91.0f32, 92.0f32, 93.0f32, 94.0f32, 95.0f32, 96.0f32, 97.0f32, 98.0f32, 99.0f32] -- [90i64, 91i64, 92i64, 93i64, 94i64, 95i64, 96i64, 97i64, 98i64, 99i64] -- } import "intrinsics" def f n (acc: *acc ([](f32,i64))) (i, x) = 
write acc (i%n) (x, i) def main n (xs: []f32) = let op = (\(a,i) (b,j) -> if a < b then (b,j) else if b < a then (a,i) else if j < i then (b,j) else (a, i)) in reduce_by_index_stream (replicate n (0,0)) op (f32.lowest,-1) (f n) (zip (indices xs) xs) |> unzip futhark-0.25.27/tests/accs/id.fut000066400000000000000000000006001475065116200165250ustar00rootroot00000000000000-- If the accumulator isn't updated, the entire thing should go away. -- This is not because of user code (nobody would write this), but -- because the compiler may internally generate code like this -- (possibly after other simplifications). -- == -- structure { WithAcc 0 } import "intrinsics" def main (xs: *[]i32) = scatter_stream xs (\(acc: *acc ([]i32)) _ -> acc) (iota 10) futhark-0.25.27/tests/accs/intrinsics.fut000066400000000000000000000013561475065116200203270ustar00rootroot00000000000000-- We don't want to expose these constructs to users just yet, as they -- are not terribly stable. type~ acc 't = intrinsics.acc t def scatter_stream [k] 'a 'b (dest: *[k]a) (f: *acc ([k]a) -> b -> acc ([k]a)) (bs: []b) : *[k]a = intrinsics.scatter_stream dest f bs :> *[k]a def reduce_by_index_stream [k] 'a 'b (dest: *[k]a) (op: a -> a -> a) (ne: a) (f: *acc ([k]a) -> b -> acc ([k]a)) (bs: []b) : *[k]a = intrinsics.hist_stream dest op ne f bs :> *[k]a def write [n] 't (acc : *acc ([n]t)) (i: i64) (v: t) : *acc ([n]t) = intrinsics.acc_write acc i v futhark-0.25.27/tests/accs/neutral.fut000066400000000000000000000004431475065116200176100ustar00rootroot00000000000000-- Test that we can remove updates that are just the neutral element. 
-- == -- structure { WithAcc 0 } import "intrinsics" def f (acc: *acc ([]i32)) i = let acc = write acc i 1 let acc = write acc (i+1) 1 in acc def main (xs: *[]i32) = reduce_by_index_stream xs (*) 1 f (iota 10) futhark-0.25.27/tests/accs/outermap0.fut000066400000000000000000000006141475065116200200520ustar00rootroot00000000000000-- == -- input { 4i64 5i64 2i64 } -- output { [[0, 0, 1, 1, 0], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 1, 3]] } import "intrinsics" def f (acc: *acc ([]i32)) i = let acc = write acc (i*2) (i32.i64 i) let acc = write acc (i*2+1) (i32.i64 i) in acc def main n m k = tabulate n (\i -> let xs = replicate m (i32.i64 i) in scatter_stream xs f (iota k)) futhark-0.25.27/tests/accs/scatter0.fut000066400000000000000000000005551475065116200176670ustar00rootroot00000000000000-- == -- input { [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] } -- output { [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9] } import "intrinsics" def f (acc: *acc ([]i32)) i = let acc = write acc (i*2) (i32.i64 i) let acc = write acc (i*2+1) (i32.i64 i) in acc def main (xs: *[]i32) = scatter_stream xs f (iota 10) futhark-0.25.27/tests/accs/scatter1.fut000066400000000000000000000005001475065116200176560ustar00rootroot00000000000000-- == -- input { [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] } -- output { [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 18, 19] } import "intrinsics" def f (acc: *acc ([]i32)) i = loop acc for j < i do write acc (j+i) 1 def main (xs: *[]i32) = scatter_stream xs f (iota 10) futhark-0.25.27/tests/accs/scatter2.fut000066400000000000000000000005231475065116200176640ustar00rootroot00000000000000-- Scattering arrays. 
-- == -- input { [[0,0,0],[0,0,0],[0,0,0],[0,0,0]] [[1,-1], [0,3]] } -- output { [[1, 2, 3], [1, 2, 3], [0, 0, 0], [1, 2, 3]] } import "intrinsics" def f 't (x: t) (acc: *acc ([]t)) (is: []i32) = loop acc for i in is do write acc (i64.i32 i) x def main (xs: *[][]i32) is = scatter_stream xs (f [1,2,3]) is futhark-0.25.27/tests/accs/scatterhist.fut000066400000000000000000000012701475065116200204720ustar00rootroot00000000000000-- Dynamically pick whether we are doing a scatter- or histogram -- accumulation! -- == -- input { true [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] } -- output { [0, 1, 3, 4, 6, 7, 9, 10, 12, 13, 15, 16, 18, 19, 21, 22, 24, 25, 27, 28] } -- input { false [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] } -- output { [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9] } import "intrinsics" def f (acc: *acc ([]i32)) i = let acc = write acc (i*2) (i32.i64 i) let acc = write acc (i*2+1) (i32.i64 i) in acc def main b (xs: *[]i32) = if b then reduce_by_index_stream xs (+) 0 f (iota 10) else scatter_stream xs f (iota 10) futhark-0.25.27/tests/ad/000077500000000000000000000000001475065116200150705ustar00rootroot00000000000000futhark-0.25.27/tests/ad/arr0.fut000066400000000000000000000004051475065116200164530ustar00rootroot00000000000000def f (xs: [2]f64) = xs[0] * xs[1] -- == -- entry: f_jvp -- input { [5.0, 7.0] } -- output { 7.0 5.0 } entry f_jvp xs = (jvp f xs [1,0], jvp f xs [0,1]) -- == -- entry: f_vjp -- input { [5.0, 7.0] } -- output { [7.0, 5.0] } entry f_vjp xs = vjp f xs 1 futhark-0.25.27/tests/ad/arr1.fut000066400000000000000000000004571475065116200164630ustar00rootroot00000000000000def f (x, y) : [2]f64 = [x+y, x*y] -- == -- entry: f_vjp f_jvp -- input { 5.0 7.0 } -- output { [1.0,7.0] [1.0, 5.0] } entry f_jvp x y = (jvp f (x,y) (1,0), jvp f (x,y) (0,1)) entry f_vjp x y = let (dx1,dx2) = vjp f (x,y) [1,0] let (dy1,dy2) = vjp f (x,y) [0,1] in ([dx1, dy1], [dx2, dy2]) 
futhark-0.25.27/tests/ad/arr2.fut000066400000000000000000000005121475065116200164540ustar00rootroot00000000000000def f (x, y) : [2][1]f64 = [x, y] -- == -- entry: f_vjp f_jvp -- input { [5.0] [7.0] } -- output { [[1.0],[0.0]] [[0.0], [1.0]] } entry f_jvp x y = (jvp f (x,y) ([1],[0]), jvp f (x,y) ([0],[1])) entry f_vjp x y = let (dx1,dx2) = vjp f (x,y) [[1],[0]] let (dy1,dy2) = vjp f (x,y) [[0],[1]] in ([dx1, dy1], [dx2, dy2]) futhark-0.25.27/tests/ad/cert.fut000066400000000000000000000001751475065116200165500ustar00rootroot00000000000000-- Quick test that we don't crash in the presence of certificates used -- free in vjp. def main y = vjp (map (\x -> x / y)) futhark-0.25.27/tests/ad/clz.fut000066400000000000000000000002501475065116200163750ustar00rootroot00000000000000-- Tricky because clz has a different return type than input type. -- Not really differentiable, though. -- == entry vjp_u64 = vjp u64.clz entry jvp_u64 = jvp u64.clz futhark-0.25.27/tests/ad/cmp0.fut000066400000000000000000000006501475065116200164500ustar00rootroot00000000000000-- == -- entry: fwd -- compiled input { 1.0 2.0 } -- output { false false } -- compiled input { 1.0 1.0 } -- output { true true } entry fwd x y = (jvp (\(a, b) -> f64.(a == b)) (x,y) (1,0), jvp (\(a, b) -> f64.(a == b)) (x,y) (0,1)) -- == -- entry: rev -- compiled input { 1.0 2.0 } -- output { 1.0 1.0 } -- compiled input { 1.0 1.0 } -- output { 1.0 1.0 } entry rev x y = vjp (\(a, b) -> f64.(a == b)) (x,y) true futhark-0.25.27/tests/ad/concat0.fut000066400000000000000000000004701475065116200171400ustar00rootroot00000000000000-- == -- entry: f_jvp -- input { [1,2,3] [4,5,6] } -- output { [1,2,3,4,5,6] } entry f_jvp xs ys : []i32 = jvp (uncurry concat) (xs,ys) (xs, ys) -- == -- entry: f_vjp -- input { [1,2,3] [4,5,6] } -- output { [1,2,3] [4,5,6] } entry f_vjp xs ys : ([]i32, []i32) = vjp (uncurry concat) (xs,ys) (concat xs ys) 
futhark-0.25.27/tests/ad/confusion0.fut000066400000000000000000000003111475065116200176660ustar00rootroot00000000000000-- == -- entry: fwd rev -- input { 1 2 } output { 1 } def d f x = jvp f x 1 def drev f x = vjp f x 1 entry fwd x y = d (\x' -> (d (x'*) y)) x entry rev x y = drev (\x' -> (drev (x'*) y)) x futhark-0.25.27/tests/ad/consume0.fut000066400000000000000000000005551475065116200173460ustar00rootroot00000000000000-- == -- entry: rev fwd -- input { [1.0,2.0,3.0] } -- output { [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] } def f (xs: []f64) = copy xs with [0] = 0 entry fwd [n] (xs: *[n]f64) = #[unsafe] tabulate n (\i -> jvp f xs (replicate n 0 with [i] = 1)) entry rev [n] (xs: *[n]f64) = #[unsafe] tabulate n (\i -> vjp f xs (replicate n 0 with [i] = 1)) futhark-0.25.27/tests/ad/consume1.fut000066400000000000000000000006431475065116200173450ustar00rootroot00000000000000-- == -- entry: rev fwd -- input { true [1.0,2.0,3.0] } -- output { [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] } def f b (xs: []f64) = let ys = copy xs in if b then ys with [0] = 0 else ys entry fwd [n] b (xs: *[n]f64) = #[unsafe] tabulate n (\i -> jvp (f b) xs (replicate n 0 with [i] = 1)) entry rev [n] b (xs: *[n]f64) = #[unsafe] tabulate n (\i -> vjp (f b) xs (replicate n 0 with [i] = 1)) futhark-0.25.27/tests/ad/consume2.fut000066400000000000000000000006271475065116200173500ustar00rootroot00000000000000-- == -- entry: rev fwd -- input { [true] [[1.0,2.0,3.0]] [[0.0,1.0,0.0]] } -- output { [[0.000000f64, 1.000000f64, 0.000000f64]] } def f b (xs: []f64) = let ys = copy xs in if b then ys with [0] = 0 else ys def g bs (xss: [][]f64) = #[unsafe] map2 f bs (copy xss) entry fwd [n] bs (xss: *[n][]f64) = #[unsafe] jvp (g bs) xss entry rev [n] bs (xss: *[n][]f64) = #[unsafe] vjp (g bs) xss futhark-0.25.27/tests/ad/consume3.fut000066400000000000000000000005461475065116200173510ustar00rootroot00000000000000def test [n] (xs: [n]f64) = let xs' = copy xs let xs'' = copy xs in xs' with [1] 
= xs''[1] -- == -- entry: prim -- input { [5.0, 7.0, 9.0] } -- output { [5.0, 7.0, 9.0] } entry prim [n] (xs: [n]f64) = test xs -- == -- entry: f_vjp -- input { [5.0, 7.0, 9.0] } -- output { [1.0, 1.0, 1.0] } entry f_vjp [n] (xs: [n]f64) = vjp test xs (replicate n 1) futhark-0.25.27/tests/ad/consume4.fut000066400000000000000000000007251475065116200173510ustar00rootroot00000000000000def test [n] (xs: [n]i32) = let xs' = copy xs let xs'' = map (*2) xs' in xs' with [1] = xs''[1] -- == -- entry: prim -- input { [5, 7, 9] } -- output { [5, 14, 9] } entry prim [n] (xs: [n]i32) = test xs -- == -- entry: f_vjp -- input { [5, 7, 9] } -- output { [1, 2, 1] } entry f_vjp [n] (xs: [n]i32) = vjp test xs (replicate n 1) -- == -- entry: f_jvp -- input { [5, 7, 9] } -- output { [1, 2, 1] } entry f_jvp [n] (xs: [n]i32) = jvp test xs (replicate n 1) futhark-0.25.27/tests/ad/consume5.fut000066400000000000000000000003531475065116200173470ustar00rootroot00000000000000def test [n] (xs: [n]i32) = let xs' = copy xs let foo = xs' with [1] = i32.sum xs' in map (*2) foo -- == -- entry: f_vjp -- input { [1, 2, 3] } -- output { [4, 2, 4] } entry f_vjp [n] (xs: [n]i32) = vjp test xs (replicate n 1) futhark-0.25.27/tests/ad/consume6.fut000066400000000000000000000006431475065116200173520ustar00rootroot00000000000000def test [n] (xs: [n]i32) = let xs' = copy xs in loop xs'' = xs' for i < n do let foo = xs'' with [i] = 1 let m = map (\x -> x) foo in foo with [i] = m[i] -- == -- entry: prim -- input { [1,2,3,4,5] } output { [1,1,1,1,1] } entry prim [n] (xs: [n]i32) = test xs -- == -- entry: f_vjp -- input { [1,2,3,4,5] } output { [0,0,0,0,0] } entry f_vjp [n] (xs: [n]i32) = vjp test xs (replicate n 1) futhark-0.25.27/tests/ad/conv0.fut000066400000000000000000000002631475065116200166360ustar00rootroot00000000000000-- == -- entry: fwd -- input { 1.0 } -- output { 1f32 } entry fwd x = jvp f32.f64 x 1 -- == -- entry: rev -- input { 1.0 } -- output { 1f64 } entry rev x = vjp f32.f64 x 1 
futhark-0.25.27/tests/ad/conv1.fut000066400000000000000000000002651475065116200166410ustar00rootroot00000000000000-- == -- entry: fwd -- input { 1f64 } -- output { 1i32 } entry fwd x = jvp i32.f64 x 1 -- == -- entry: rev -- input { 1f64 } -- output { 2f64 } entry rev x = vjp i32.f64 x 2 futhark-0.25.27/tests/ad/fadd.fut000066400000000000000000000003041475065116200165030ustar00rootroot00000000000000-- == -- entry: f_jvp f_vjp -- input { 5.0 7.0 } -- output { 1.0 1.0 } def f (x,y) = x + y : f64 entry f_jvp x y = (jvp f (x,y) (1,0), jvp f (x,y) (0,1)) entry f_vjp x y = vjp f (x,y) 1 futhark-0.25.27/tests/ad/fdiv.fut000066400000000000000000000003161475065116200165400ustar00rootroot00000000000000-- == -- entry: f_jvp f_vjp -- input { 5.0 7.0 } -- output { 0.14285 -0.102041 } def f (x,y) = x / y : f64 entry f_jvp x y = (jvp f (x,y) (1,0), jvp f (x,y) (0,1)) entry f_vjp x y = vjp f (x,y) 1 futhark-0.25.27/tests/ad/fmul.fut000066400000000000000000000003041475065116200165500ustar00rootroot00000000000000-- == -- entry: f_jvp f_vjp -- input { 5.0 7.0 } -- output { 7.0 5.0 } def f (x,y) = x * y : f64 entry f_jvp x y = (jvp f (x,y) (1,0), jvp f (x,y) (0,1)) entry f_vjp x y = vjp f (x,y) 1 futhark-0.25.27/tests/ad/for0.fut000066400000000000000000000005221475065116200164550ustar00rootroot00000000000000def pow y x = loop acc = 1 for _i < y do acc * x -- == -- entry: prim -- input { 3 4 } output { 64 } -- input { 9 3 } output { 19683 } entry prim y x = pow y x -- == -- entry: f_jvp f_vjp -- input { 3 4 } output { 48 } -- input { 9 3 } output { 59049 } entry f_jvp y x = jvp (pow y) x 1 entry f_vjp y x = vjp (pow y) x 1 futhark-0.25.27/tests/ad/for1.fut000066400000000000000000000011001475065116200164470ustar00rootroot00000000000000def pow_list [n] y (xs :[n]i32) = loop accs = (replicate n 1) for _i < y do map2 (*) accs xs -- == -- entry: prim -- input { 3 [1,2,3] } output { [1,8,27] } entry prim y xs = pow_list y xs -- == -- entry: f_vjp f_jvp -- input { 3 [1,2,3] } -- output { 
[[3,0,0], -- [0,12,0], -- [0,0,27]] -- } entry f_jvp [n] y (xs :[n]i32) = tabulate n (\i -> jvp (pow_list y) xs (replicate n 0 with [i] = 1)) |> transpose entry f_vjp [n] y (xs :[n]i32) = tabulate n (\i -> vjp (pow_list y) xs (replicate n 0 with [i] = 1)) futhark-0.25.27/tests/ad/for2.fut000066400000000000000000000005541475065116200164640ustar00rootroot00000000000000def mult_list xs = loop start = 1 for x in xs do x * x -- == -- entry: prim -- input { [11,5,13] } output { 169 } entry prim = mult_list -- == -- entry: f_jvp f_vjp -- input { [11,5,13] } output { [0,0,26] } entry f_jvp [n] (xs :[n]i32) = tabulate n (\i -> jvp mult_list xs (replicate n 0 with [i] = 1)) entry f_vjp [n] (xs: [n]i32) = vjp mult_list xs 1 futhark-0.25.27/tests/ad/for3.fut000066400000000000000000000012131475065116200164560ustar00rootroot00000000000000def square [n] (xs: [n]i32) = let xs' = copy xs in loop xs'' = xs' for i < n do let a = xs''[i] in xs'' with [i] = a * a -- == -- entry: prim -- input { [1,2,3,4,5] } output { [1,4,9,16,25] } entry prim [n] (xs: [n]i32) = square xs -- == -- entry: f_jvp f_vjp -- input { [1,2,3,4,5] } -- output { [[2,0,0,0,0], -- [0,4,0,0,0], -- [0,0,6,0,0], -- [0,0,0,8,0], -- [0,0,0,0,10]] -- } entry f_jvp [n] (xs :[n]i32) = tabulate n (\i -> jvp square xs (replicate n 0 with [i] = 1)) |> transpose entry f_vjp [n] (xs :[n]i32) = tabulate n (\i -> vjp square xs (replicate n 0 with [i] = 1)) futhark-0.25.27/tests/ad/fwd/000077500000000000000000000000001475065116200156505ustar00rootroot00000000000000futhark-0.25.27/tests/ad/fwd/acc0.fut000066400000000000000000000010201475065116200171670ustar00rootroot00000000000000import "../../accs/intrinsics" def f (acc : *acc([]i32)) i = write acc i (i32.i64 i) -- square entries -- == -- entry: prim -- input { [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] } -- output { [0, 1, 4, 9, 16, 25, 36, 49, 64, 81] } entry prim [n] (xs: [n]i32) = let (xs' : *[n]i32) = copy xs in reduce_by_index_stream xs' (*) 1 f (map i64.i32 (xs :> [n]i32)) -- == -- 
entry: f_jvp -- input { [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] } -- output { [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] } entry f_jvp (xs: *[]i32) = jvp prim xs (replicate 10 1) futhark-0.25.27/tests/ad/fwd/for0.fut000066400000000000000000000004531475065116200172400ustar00rootroot00000000000000def pow y x = loop acc = 1 for i < y do acc * x -- == -- entry: prim -- input { 3 4 } output { 64 } -- input { 9 3 } output { 19683 } entry prim y x = pow y x -- == -- entry: f_jvp -- input { 3 4 } output { 48 } -- input { 9 3 } output { 59049 } entry f_jvp y x = jvp (pow y) x 1 futhark-0.25.27/tests/ad/fwd/for1.fut000066400000000000000000000005131475065116200172360ustar00rootroot00000000000000-- computes x^2*y^3 def pow y x = loop acc = 1 for i in [y, y*y] do acc * x * i -- == -- entry: prim -- input { 3 4 } output { 432 } -- input { 9 3 } output { 6561 } entry prim y x = pow y x -- == -- entry: f_jvp -- input { 3 4 } output { 216 } -- input { 9 3 } output { 4374 } entry f_jvp y x = jvp (pow y) x 1 futhark-0.25.27/tests/ad/fwd/map0.fut000066400000000000000000000002311475065116200172210ustar00rootroot00000000000000def f x = map (*(x*x)) [0,1,2] -- == -- entry: f_jvp -- input { 2 } output { [0, 4, 8] } -- input { 4 } output { [0, 8, 16] } entry f_jvp x = jvp f x 1 futhark-0.25.27/tests/ad/fwd/red0.fut000066400000000000000000000002131475065116200172160ustar00rootroot00000000000000def f x = reduce (*) 1 [1,2,x,4] -- == -- entry: f_jvp -- input { 3 } output { 8 } -- input { 10 } output { 8 } entry f_jvp x = jvp f x 1 futhark-0.25.27/tests/ad/fwd/scatter0.fut000066400000000000000000000002351475065116200201150ustar00rootroot00000000000000def f x = let vs = [x, x*x, x*x*x] in spread 5 1 [0,1,2] vs -- == -- entry: f_jvp -- input { 5 } output { [1, 10, 75, 0, 0] } entry f_jvp x = jvp f x 1 futhark-0.25.27/tests/ad/fwd/while0.fut000066400000000000000000000005561475065116200175660ustar00rootroot00000000000000def pow y x = let (_, res) = loop (i, acc) = (0, 1) while i < y do (i + 1, acc * x) in res -- 
== -- entry: prim -- input { 3 4 } output { 64 } -- input { 9 3 } output { 19683 } entry prim y x = pow y x -- == -- entry: f_jvp -- input { 3 4 } output { 48 } -- input { 9 3 } output { 59049 } entry f_jvp y x = jvp (pow y) x 1 futhark-0.25.27/tests/ad/gather0.fut000066400000000000000000000013421475065116200171420ustar00rootroot00000000000000-- == -- entry: fwd_J rev_J -- input { [4.0,3.0,2.0,1.0] [0i64,1i64,2i64,3i64] } -- output { [[1.0, 0.0, 0.0, 0.0], -- [0.0, 1.0, 0.0, 0.0], -- [0.0, 0.0, 1.0, 0.0], -- [0.0, 0.0, 0.0, 1.0]] -- } -- input { [4.0,3.0,2.0,1.0] [0i64,0i64,3i64,3i64] } -- output { [[1.0, 0.0, 0.0, 0.0], -- [1.0, 0.0, 0.0, 0.0], -- [0.0, 0.0, 0.0, 1.0], -- [0.0, 0.0, 0.0, 1.0]] -- } def gather xs is = map (\(i: i64) -> xs[i]) is entry fwd_J [n] [m] (xs: [n]f64) (is: [m]i64) = transpose (tabulate n (\j -> jvp (`gather` is) xs (replicate n 0 with [j] = 1))) entry rev_J [n] [m] (xs: [n]f64) (is: [m]i64) = tabulate m (\j -> vjp (`gather` is) xs (replicate m 0 with [j] = 1)) futhark-0.25.27/tests/ad/gather1.fut000066400000000000000000000023201475065116200171400ustar00rootroot00000000000000-- == -- entry: fwd_J rev_J -- input -- { -- [[1.0,2.0],[3.0,4.0]] [1i64, 0i64, 1i64, 1i64] -- } -- output -- { -- [[[[0.000000f64, 1.000000f64], -- [0.000000f64, 0.000000f64]], -- [[1.000000f64, 0.000000f64], -- [0.000000f64, 0.000000f64]], -- [[0.000000f64, 1.000000f64], -- [0.000000f64, 0.000000f64]], -- [[0.000000f64, 1.000000f64], -- [0.000000f64, 0.000000f64]]], -- [[[0.000000f64, 0.000000f64], -- [0.000000f64, 1.000000f64]], -- [[0.000000f64, 0.000000f64], -- [1.000000f64, 0.000000f64]], -- [[0.000000f64, 0.000000f64], -- [0.000000f64, 1.000000f64]], -- [[0.000000f64, 0.000000f64], -- [0.000000f64, 1.000000f64]]]] -- } def gather xs is = map (\(i: i64) -> xs[i]) is def mapgather xss is = map (`gather` is) xss def onehot n i : [n]f64 = tabulate n (\j -> f64.bool (i==j)) def onehot_2d n m p : [n][m]f64 = tabulate_2d n m (\i j -> f64.bool ((i,j)==p)) entry fwd_J 
[n][m][k] (xs: [n][m]f64) (is: [k]i64) = tabulate_2d n m (\i j -> jvp (`mapgather` is) xs (onehot_2d n m (i,j))) |> map transpose |> map (map transpose) |> map transpose entry rev_J [n][m][k] (xs: [n][m]f64) (is: [k]i64) = tabulate_2d n k (\i j -> vjp (`mapgather` is) xs (onehot_2d n k (i,j))) futhark-0.25.27/tests/ad/gather2.fut000066400000000000000000000016361475065116200171520ustar00rootroot00000000000000-- == -- entry: fwd_J rev_J -- input -- { -- [1.0,2.0,3.0,4.0] -- [[1i64, 3i64], [2i64, 2i64]] -- } -- output -- { -- [[[0.000000f64, 1.000000f64, 0.000000f64, 0.000000f64], -- [0.000000f64, 0.000000f64, 0.000000f64, 1.000000f64]], -- [[0.000000f64, 0.000000f64, 1.000000f64, 0.000000f64], -- [0.000000f64, 0.000000f64, 1.000000f64, 0.000000f64]]] -- } def gather xs is = map (\(i: i64) -> xs[i]) is def mapgather [k][n][m] (xs: [k]f64) (iss: [n][m]i64) : [n][m]f64 = map (gather xs) iss def onehot n i : [n]f64 = tabulate n (\j -> f64.bool (i==j)) entry fwd_J [k][n][m] (xs: [k]f64) (iss: [n][m]i64) = tabulate k (\i -> jvp (`mapgather` iss) xs (onehot k i)) |> transpose |> map transpose def onehot_2d n m p : [n][m]f64 = tabulate_2d n m (\i j -> f64.bool ((i,j)==p)) entry rev_J [k][n][m] (xs: [k]f64) (iss: [n][m]i64) = tabulate_2d n m (\i j -> vjp (`mapgather` iss) xs (onehot_2d n m (i,j))) futhark-0.25.27/tests/ad/genred-opt/000077500000000000000000000000001475065116200171345ustar00rootroot00000000000000futhark-0.25.27/tests/ad/genred-opt/bfast-mmm.fut000066400000000000000000000015161475065116200215420ustar00rootroot00000000000000 -- | dot-product but in which we filter-out the entries for which `vct[i]==NAN` let dotprod_filt [n] (vct: [n]f32) (xs: [n]f32) (ys: [n]f32) : f32 = f32.sum (map3 (\v x y -> x * y * if (v == 333.333) then 0.0 else 1.0) vct xs ys) -- f32.isnan v -- | matrix-matrix multiplication but with NAN-filtering on `vct` let matmul_filt [n][p][m] (xss: [n][p]f32) (yss: [p][m]f32) (vct: [p]f32) : [n][m]f32 = map (\xs -> map (dotprod_filt vct xs) 
(transpose yss)) xss -- | implementation is in this entry point -- the outer map is distributed directly entry main [m][N][k] (n: i64) (X: [k][N]f32) (images : [m][N]f32) (res_adj: [m][k][k]f32) = let Xt= copy (transpose X) let Xh = (X[:,:n]) let Xth = (Xt[:n,:]) let Yh = (images[:,:n]) let batchMMM (Z, Zt, Q) = map (matmul_filt Z Zt) Q in vjp batchMMM (X, Xt, images) res_adj futhark-0.25.27/tests/ad/genred-opt/gemm-simple.fut000066400000000000000000000011251475065116200220670ustar00rootroot00000000000000-- == -- entry: rev_J -- compiled random input { 0.5f32 0.7f32 [1024][1024]f32 [1024][1024]f32 [1024][1024]f32 [1024][1024]f32} auto output type real = f32 let real_sum = f32.sum let dotprod xs ys = real_sum (map2 (*) xs ys) let gemm [m][n][q] (alpha:real) (beta: real) (xss: [m][q]real, yss: [q][n]real, css: [m][n]real) = map2 (\xs cs -> map2 (\ys c -> c*alpha + beta*(dotprod xs ys)) (transpose yss) cs) xss css entry rev_J [n][m][q] (alpha: real) (beta: real) (xss: [m][q]real) (yss: [q][n]real) (css: [m][n]real) (res_adj: [m][n]real) = vjp (gemm alpha beta) (xss, yss, css) res_adj futhark-0.25.27/tests/ad/genred-opt/lud-mmm.fut000066400000000000000000000062211475065116200212250ustar00rootroot00000000000000-- == -- -- compiled random input {[128][32][32]f32 [128][32][32]f32 [128][128][32][32]f32 [128][128][32][32]f32} auto output -- Alternatively try this with default_tile_size=8 and default_reg_tile_size=2 -- compiled random input {[256][16][16]f32 [256][16][16]f32 [256][256][16][16]f32} auto output let ludMult [m][b] (top_per: [m][b][b]f32, lft_per: [m][b][b]f32, mat_slice: [m][m][b][b]f32) : *[m][m][b][b]f32 = -- let top_slice = map transpose top_per in map (\(mat_arr: [m][b][b]f32, lft: [b][b]f32): [m][b][b]f32 -> map (\ (mat_blk: [b][b]f32, top: [b][b]f32): [b][b]f32 -> map (\ (mat_row: [b]f32, lft_row: [b]f32): [b]f32 -> map2 (\mat_el top_row -> let prods = map2 (*) lft_row top_row let sum = f32.sum prods in mat_el - sum ) mat_row (transpose top) ) (zip 
(mat_blk) lft ) ) (zip (mat_arr) (top_per) ) ) (zip (mat_slice) (lft_per) ) let main [m][b] (top_per: [m][b][b]f32) (lft_per: [m][b][b]f32) (mat_slice: [m][m][b][b]f32 ) (res_adj: [m][m][b][b]f32) = vjp ludMult (top_per, lft_per, mat_slice) res_adj --- BIG COMMENT ------- --- The straigh compilation yields something like: --- --- segmap(thread; #groups=num_tblocks_10633; tblocksize=segmap_tblock_size_10632; virtualise) --- (gtid_9445 < m_9078, gtid_9446 < m_9078, gtid_9447 < b_9079, gtid_9448 < b_9079, gtid_9449 < b_9079) (~phys_tid_9450) : --- { acc(acc_cert_p_10557, [m_9078][m_9078][b_9079][b_9079], {f32}), --- acc(acc_cert_p_10590, [m_9078][m_9078][b_9079][b_9079], {f32}) --- } { --- --- let {r_adj_el : f32} = r_adj[gtid_9445, gtid_9446, gtid_9447, gtid_9448] --- let {lft_el : f32} = lft_per_9081[gtid_9445, gtid_9447, gtid_9449] --- let {top_el : f32} = top_per_coalesced_10752[gtid_9446, gtid_9449, gtid_9448] --- --- let {acc_10651 : acc(acc_cert_p_10590, [m_9078][m_9078][b_9079][b_9079], {f32})} = --- update_acc(acc_p_10591, {gtid_9445, gtid_9446, gtid_9447, gtid_9449}, {r_adj_el * top_el}) --- --- let {acc_10652 : acc(acc_cert_p_10557, [m_9078][m_9078][b_9079][b_9079], {f32})} = --- update_acc(acc_p_10558, {gtid_9445, gtid_9446, gtid_9448, gtid_9449}, {r_adj_el * lft_el}) --- --- return {returns acc_10652, returns acc_10651} --- } --- --- in {withacc_inter_10636, withacc_inter_10635}) --- --- --- segmap(thread; #groups=num_tblocks_10675; tblocksize=segmap_tblock_size_10674; virtualise) --- (gtid_9334 < m_9078, gtid_9335 < m_9078, gtid_9336 < b_9079, gtid_9337 < b_9079) (~phys_tid_9338) : --- { acc(acc_cert_p_10532, [m_9078][b_9079][b_9079], {f32}) } --- { --- let {r_adj_el : f32} = r_adj[gtid_9334, gtid_9335, gtid_9336, gtid_9337] --- let {acc_10683 : acc(acc_cert_p_10532, [m_9078][b_9079][b_9079], {f32})} = --- update_acc(acc_p_10533, {gtid_9334, gtid_9336, gtid_9337}, {r_adj_el}) --- return {returns acc_10683} --- } 
-------------------------------------------------- futhark-0.25.27/tests/ad/genred-opt/matmul-simple.fut000066400000000000000000000007411475065116200224440ustar00rootroot00000000000000-- == -- entry: rev_J -- compiled input { -- [[1.0,2.0],[3.0,4.0]] -- [[5.0,6.0],[7.0,8.0]] -- [[1.0,2.0],[3.0,4.0]] -- } -- output { -- [[17.0, 23.0], [39.0, 53.0]] -- [[10.0, 14.0], [14.0, 20.0]] -- } let dotprod xs ys = f64.sum (map2 (*) xs ys) let matmul [m][n][q] (xss: [m][q]f64, yss: [q][n]f64) = map (\xs -> map (dotprod xs) (transpose yss)) xss entry rev_J [n][m][q] (xss: [m][q]f64) (yss: [q][n]f64) (res_adj: [m][n]f64) = vjp matmul (xss, yss) res_adj futhark-0.25.27/tests/ad/genred-opt/matmul.fut000066400000000000000000000014661475065116200211620ustar00rootroot00000000000000-- == -- entry: fwd_J rev_J -- input -- { -- [[1.0,2.0],[3.0,4.0]] [[5.0,6.0],[7.0,8.0]] -- } -- output -- { -- [[[[1.0f64, 0.0f64], -- [2.0f64, 0.0f64]], -- [[0.0f64, 1.0f64], -- [0.0f64, 2.0f64]]], -- [[[3.0f64, 0.0f64], -- [4.0f64, 0.0f64]], -- [[0.0f64, 3.0f64], -- [0.0f64, 4.0f64]]]] -- } let dotprod xs ys = f64.sum (map2 (*) xs ys) let matmul xss yss = map (\xs -> map (dotprod xs) (transpose yss)) xss let onehot_2d n m p : [n][m]f64 = tabulate_2d n m (\i j -> f64.bool ((i,j)==p)) entry fwd_J [n][m][p] (xss: [n][m]f64) (yss: [m][p]f64) = tabulate_2d m p (\i j -> jvp (matmul xss) yss (onehot_2d m p (i,j))) |> transpose |> map transpose |> transpose entry rev_J [n][m][p] (xss: [n][m]f64) (yss: [m][p]f64) = tabulate_2d n p (\i j -> vjp (matmul xss) yss (onehot_2d n p (i,j))) futhark-0.25.27/tests/ad/genred-opt/matvec-simple.fut000066400000000000000000000013601475065116200224220ustar00rootroot00000000000000-- == -- entry: rev_J -- compiled random input { [1024][80]f32 [1024][80]f32 [1024][1024]f32} auto output -- compiled random input { [1024][1024]f32 [1024][1024]f32 [1024][1024]f32} auto output -- compiled random input { [2048][1024]f32 [2048][1024]f32 [2048][2048]f32} auto output -- compiled input -- { 
-- [[1.0,2.0],[3.0,4.0]] [[5.0,6.0],[7.0,8.0]] -- } type real = f32 let real_sum = f32.sum let dotprod xs ys = real_sum (map2 (*) xs ys) let matvec [n][q] (mat: [n][q]real) (vct: [q]real) : [n]real = map (dotprod vct) mat let matmat [m][n][q] (mat1: [m][q]real, mat2: [n][q]real) : [m][n]real = map (matvec mat2) mat1 entry rev_J [m][n][q] (mat1: [m][q]real) (mat2: [n][q]real) (res_adj: [m][n]real) = vjp matmat (mat1, mat2) res_adj futhark-0.25.27/tests/ad/if0.fut000066400000000000000000000006611475065116200162710ustar00rootroot00000000000000-- == -- entry: f_jvp -- input { true 5.0 7.0 } -- output { 7.0 5.0 } -- input { false 5.0 7.0 } -- output { 0.14285 -0.102041 } def f (b, x, y) : f64 = if b then x*y else x/y entry f_jvp b x y = (jvp f (b,x,y) (b,1,0), jvp f (b,x,y) (b,0,1)) -- == -- entry: f_vjp -- input { true 5.0 7.0 } -- output { false 7.0 5.0 } -- input { false 5.0 7.0 } -- output { false 0.14285 -0.102041 } entry f_vjp b x y = vjp f (b,x,y) 1 futhark-0.25.27/tests/ad/if1.fut000066400000000000000000000005651475065116200162750ustar00rootroot00000000000000-- == -- entry: f_jvp -- input { false 5.0 } -- output { 2.0 } -- input { true 5.0 } -- output { 11.0 } def f (b, x) : f64 = let y = if b then x*x else x let z = y + x in z entry f_jvp b x = (jvp f (b,x) (b,1)) -- == -- entry: f_vjp -- input { false 5.0 } -- output { false 2.0 } -- input { true 5.0 } -- output { false 11.0 } entry f_vjp b x = vjp f (b,x) 1 futhark-0.25.27/tests/ad/if2.fut000066400000000000000000000010131475065116200162630ustar00rootroot00000000000000-- == -- entry: f_jvp -- input { [1.0,2.0,3.0] } -- output { [0.0, 3.0, 2.0] } -- input { [-1.0,2.0,3.0] } -- output { [3.0, 0.0, -1.0] } -- structure { If/Replicate 0 } def f x : f64 = #[unsafe] let z = if x[0] < 0 then x[0] else x[1] let y = x[2] in y * z entry f_jvp x = map (\i -> jvp f x (map (const 0) x with [i] = 1)) (indices x) -- == -- entry: f_vjp -- input { [1.0,2.0,3.0] } -- output { [0.0, 3.0, 2.0] } -- input { [-1.0,2.0,3.0] } -- 
output { [3.0, 0.0, -1.0] } entry f_vjp x = vjp f x 1 futhark-0.25.27/tests/ad/imul.fut000066400000000000000000000002361475065116200165570ustar00rootroot00000000000000-- Check the absence of integer overflow. -- == -- input { 2000000000i32 2000000000i32 } output { -294967296i32 } def main x y : i32 = vjp (\x -> x * y) x 2 futhark-0.25.27/tests/ad/iota0.fut000066400000000000000000000001761475065116200166300ustar00rootroot00000000000000-- == -- entry: rev -- compiled input { 5i64 } -- output { 5i64 } def f (n: i64) = i64.sum (iota n) entry rev n = vjp f n 1 futhark-0.25.27/tests/ad/isnan.fut000066400000000000000000000003121475065116200167140ustar00rootroot00000000000000-- We should not crash on functions like isnan (or isinf for that -- matter) that have a differentiable domain, but a nondifferentiable -- codomain. entry fwd = jvp f32.isnan entry rev = vjp f32.isnan futhark-0.25.27/tests/ad/issue1473.fut000066400000000000000000000054421475065116200172640ustar00rootroot00000000000000-- test mpr sim with ad for params def pi = 3.141592653589793f32 -- some type abbreviations type mpr_pars = {G: f32, I: f32, Delta: f32, eta: f32, tau: f32, J: f32} type mpr_node = (f32, f32) type mpr_net [n] = [n]mpr_node -- this is tranposed from mpr-pdq to avoid tranposes in history update type mpr_hist [t] [n] = [t]mpr_net [n] type connectome [n] = {weights: [n][n]f32, idelays: [n][n]i64} -- do one time step w/ Euler def mpr_step [t] [n] (now: i64) (dt: f32) (buf: *mpr_hist [t] [n]) (conn: connectome [n]) (p: mpr_pars) : *mpr_hist [t] [n] = -- define individual derivatives as in mpr pdq let dr r V = 1 / p.tau * (p.Delta / (pi * p.tau) + 2 * V * r) let dV r V r_c = 1 / p.tau * (V ** 2 - pi ** 2 * p.tau ** 2 * r ** 2 + p.eta + p.J * p.tau * r + p.I + r_c) let dfun (r, V, c) = (dr r V, dV r V c) -- unpack current state for clarity let (r, V) = last buf |> unzip -- connectivity eval let r_c_i i w d = map2 (\wj dj -> wj * buf[now - dj, i].0) w d |> reduce (+) 0f32 |> (* p.G) let r_c = map3 
r_c_i (iota n) conn.weights conn.idelays -- Euler step let erV = map3 (\r V c -> (dr r V, dV r V c)) r V r_c |> map2 (\(r, V) (dr, dV) -> (r + dt * dr, V + dt * dV)) (last buf) |> map1 (\(r, V) -> (if r >= 0f32 then r else 0f32, V)) -- now for the Heun step let (er, eV) = unzip erV let hrV = map3 (\r V c -> (dr r V, dV r V c)) er eV r_c |> map2 (\(r, V) (dr, dV) -> (r + dt * dr, V + dt * dV)) (last buf) |> map1 (\(r, V) -> (if r >= 0f32 then r else 0f32, V)) -- return updated buffer in buf with [now + 1] = copy hrV def run_mpr [t] [n] (horizon: i64) (dt: f32) (buf: mpr_hist [t] [n]) (conn: connectome [n]) (p: mpr_pars) : mpr_hist [t] [n] = loop buf = copy buf for now < (t - horizon - 1) do mpr_step (now + horizon) dt buf conn p def mpr_pars_with_G (p: mpr_pars) (new_G: f32) : mpr_pars = let new_p = copy p in new_p with G = new_G def loss [t] [n] (x: mpr_hist [t] [n]) : f32 = let r = map unzip x[t - 10:] |> unzip |> (.0) let sum = map (reduce (+) 0f32) r |> reduce (+) 0f32 in sum def sweep [t] [n] (ng: i64) (horizon: i64) (dt: f32) (buf: mpr_hist [t] [n]) (conn: connectome [n]) (p: mpr_pars) : [ng]f32 = let Gs = tabulate ng (\i -> 0.0 + (f32.i64 i) * 0.1) let do_one G = run_mpr horizon dt buf conn (mpr_pars_with_G p G) |> loss in map (\g -> vjp do_one g 1f32) Gs -- == -- no_ispc compiled input { 1i64 5i64 10i64 7i64 } -- output { [0.000086f32] } def main (ng: i64) (nh: i64) (nt: i64) (nn: i64) = let dt = 0.01f32 let buf = tabulate_2d (nt + nh) nn (\i j -> (0.1f32, -2.0f32)) let conn = { weights = tabulate_2d nn nn (\i j -> 0.1f32) , idelays = tabulate_2d nn nn (\i j -> ((i * j) % nh)) } let p = {G = 0.1f32, I = 0.0f32, Delta = 0.7f32, eta = (-4.6f32), tau = 1.0f32, J = 14.5f32} in sweep ng nh dt buf conn p futhark-0.25.27/tests/ad/issue1564.fut000066400000000000000000000014121475065116200172560ustar00rootroot00000000000000-- = -- entry: kmeansSpAD let costSparse [nnz][np1][cols][k] (colidxs: [nnz]i64) (begrows: [np1]i64) (cluster_centers: [k][cols]f32) : f32 = let n 
= np1 - 1 -- partial_distss : [n][k]f32 let cluster = cluster_centers[0] let foo = let j = 0 let correction = 0 let (correction, _) = loop (correction, j) while j < 2 do let column = colidxs[j] let cluster_value = cluster[column] in (correction+cluster_value, j) in correction in foo entry kmeansSpAD [nnz][np1] (k: i64) (indices_data: [nnz]i64) (pointers: [np1]i64) = jvp2 (\x -> vjp (costSparse indices_data pointers) x 1) (replicate k (replicate 1 1)) (replicate k (replicate 1 1)) futhark-0.25.27/tests/ad/issue1577.fut000066400000000000000000000004041475065116200172620ustar00rootroot00000000000000-- == -- entry: main -- input { [1i64,1i64,3i64,3i64] [1,2,3,4] } -- output { [0,3,0,7,0] } let red [n] (is: [n]i64) (vs: [n]i32) = reduce_by_index (replicate 5 1) (*) 1 is vs let main [n] (is: [n]i64) (vs: [n]i32) = jvp (red is) vs (replicate n 1) futhark-0.25.27/tests/ad/issue1604.fut000066400000000000000000000003441475065116200172540ustar00rootroot00000000000000-- == -- entry: f_vjp -- input { [1, 2, 3] } -- output { [9, 9, 9] } def f [n] (xs: [n]i32) = loop (B,A) = (xs,xs) for i < n do (map (+1) B, map (*2) A) entry f_vjp [n] (xs: [n]i32) = vjp f xs (replicate n 1, replicate n 1) futhark-0.25.27/tests/ad/issue1613.fut000066400000000000000000000012041475065116200172500ustar00rootroot00000000000000def dotprod [n] (xs: [n]f32) (ys: [n]f32) = reduce (+) 0 (map2 (*) xs ys) def matvecmul_row [n][m] (xss: [n][m]f32) (ys: [m]f32) = map (dotprod ys) xss def helmholtz [n] (R: f32) (T: f32) (b: [n]f32) (A: [n][n]f32) (xs: [n]f32) : f32 = let bxs = dotprod b xs let term1 = map (\x -> f32.log (x / (1 - bxs))) xs |> f32.sum let term2 = dotprod xs (matvecmul_row A xs) / (f32.sqrt(8) * bxs) let term3 = (1 + (1 + f32.sqrt(2)) * bxs) / (1 + (1 - f32.sqrt(2)) * bxs) |> f32.log in R * T * term1 - term2 * term3 entry calculate_jacobian [n] (R: f32) (T: f32) (b: [n]f32) (A: [n][n]f32) (xs: [n]f32) = vjp (helmholtz R T b A) xs 1.0 
futhark-0.25.27/tests/ad/issue1879.fut000066400000000000000000000006131475065116200172710ustar00rootroot00000000000000-- == -- entry: main_ad -- input { [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] } -- output { [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] } -- input { [[1.0, 2.0, 3.0], [7.0, 8.0, 9.0]] } -- output { [[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]] } def f [n] (xs: [n][3]f64) = let p x = if x[0] < 5 then true else false let (res, sizes) = partition p xs in res |> flatten |> f64.sum entry main_ad xs = vjp f xs 1.0 futhark-0.25.27/tests/ad/issue2199.fut000066400000000000000000000006321475065116200172660ustar00rootroot00000000000000-- == -- entry: test_primal test_rev -- input { [1.0,2.0] [3.0,4.0] [5.0, 6.0] } -- output { 3.0 7.0 11.0 } def op (x0, y0, z0) (x1, y1, z1) : (f64, f64, f64) = (x0 + x1, y0 + y1, z0 + z1) def ne = (0f64, 0f64, 0f64) def primal xs = reduce_comm op ne xs entry test_primal as bs cs = primal (zip3 as bs cs) entry test_rev as bs cs = (vjp2 (\(as, bs, cs) -> test_primal as bs cs) (as, bs, cs) (1, 1, 1)).0 futhark-0.25.27/tests/ad/kmeans-cost-rev.fut000066400000000000000000000005701475065116200206300ustar00rootroot00000000000000def euclid_dist_2 [d] (pt1: [d]f32) (pt2: [d]f32): f32 = f32.sum (map (\x->x*x) (map2 (-) pt1 pt2)) def cost [n][k][d] (points: [n][d]f32) (centres: [k][d]f32) = points |> map (\p -> map (euclid_dist_2 p) centres) |> map f32.minimum |> f32.sum def grad f x = vjp f x 1f32 def main [n][d] cluster_centres (points: [n][d]f32) = grad (cost points) cluster_centres futhark-0.25.27/tests/ad/lighthouse.fut000066400000000000000000000023741475065116200177710ustar00rootroot00000000000000-- From [Griewank 2008]. 
-- == -- entry: lighthouse_jvp lighthouse_vjp -- input { 2.0 1.5 0.4 2.1 } -- output { 2.902513633461043f64 -15.102798701184362f64 95.71780341846966f64 18.23196255589898f64 -- 4.353770450191565f64 -16.849170784854458f64 143.57670512770449f64 27.347943833848472f64 -- } def lighthouse (nu, gamma, omega, t) = let y1 = (nu * f64.tan(omega * t)) / (gamma - f64.tan(omega * t)) let y2 = (gamma * nu * f64.tan(omega * t)) / (gamma - f64.tan(omega * t)) in (y1, y2) entry lighthouse_jvp nu gamma omega t = let (y1_dnu, y2_dnu) = jvp lighthouse (nu, gamma, omega, t) (1, 0, 0, 0) let (y1_dgamma, y2_dgamma) = jvp lighthouse (nu, gamma, omega, t) (0, 1, 0, 0) let (y1_domega, y2_domega) = jvp lighthouse (nu, gamma, omega, t) (0, 0, 1, 0) let (y1_dt, y2_dt) = jvp lighthouse (nu, gamma, omega, t) (0, 0, 0, 1) in (y1_dnu, y1_dgamma, y1_domega, y1_dt, y2_dnu, y2_dgamma, y2_domega, y2_dt) entry lighthouse_vjp nu gamma omega t = let (y1_dnu, y1_dgamma, y1_domega, y1_dt) = vjp lighthouse (nu, gamma, omega, t) (1, 0) let (y2_dnu, y2_dgamma, y2_domega, y2_dt) = vjp lighthouse (nu, gamma, omega, t) (0, 1) in (y1_dnu, y1_dgamma, y1_domega, y1_dt, y2_dnu, y2_dgamma, y2_domega, y2_dt) futhark-0.25.27/tests/ad/lotka_volterra.fut000066400000000000000000000071421475065116200206440ustar00rootroot00000000000000let lv_step (growth_prey: f32) (predation: f32) (growth_pred: f32) (decline_pred: f32) (prey: f32) (pred: f32) : (f32, f32) = let dprey = (growth_prey - predation*pred) * prey let dpred = (growth_pred * prey - decline_pred) * pred in (dprey, dpred) let euler_method (step_size: f32) (num_steps: i64) (init_prey: f32) (init_pred: f32) (growth_prey: f32) (predation: f32) (growth_pred: f32) (decline_pred: f32) : [](f32,f32) = let states = replicate (num_steps+1) (init_prey, init_pred) let (_, states) = loop (curr_state, states) = ((init_prey,init_pred), states) for i < num_steps do let (curr_prey, curr_pred) = curr_state let (dprey, dpred) = lv_step growth_prey predation growth_pred decline_pred 
curr_prey curr_pred let next_state = (curr_prey + step_size * dprey, curr_pred + step_size * dpred) let states[i+1] = next_state in (next_state, states) in states let runge_kutta (step_size : f32) (num_steps : i64) (init_prey: f32, init_pred: f32, growth_prey: f32, predation: f32, growth_pred: f32, decline_pred: f32) : [](f32,f32) = let fn = lv_step growth_prey predation growth_pred decline_pred let states = replicate (num_steps) (init_prey, init_pred) let (_, states) = loop (curr_state, states) = ((init_prey,init_pred), states) for i < num_steps do let (curr_prey, curr_pred) = curr_state let (k1_prey, k1_pred) = fn curr_prey curr_pred let (k2_prey, k2_pred) = fn (curr_prey + step_size/2 * k1_prey) (curr_pred + step_size/2 * k1_pred) let (k3_prey, k3_pred) = fn (curr_prey + step_size/2 * k2_prey) (curr_pred + step_size/2 * k2_pred) let (k4_prey, k4_pred) = fn (curr_prey + step_size * k3_prey) (curr_pred + step_size * k3_pred) let next_state = (curr_prey + step_size/6 * (k1_prey + 2*k2_prey + 2*k3_prey + k4_prey), curr_pred + step_size/6 * (k1_pred + 2*k2_pred + 2*k3_pred + k4_pred)) let states[i] = next_state in (next_state, states) in states let to_array ((v1: f32), (v2: f32)): [2]f32 = [v1, v2] entry main (step_size: f32) (num_steps: i64) (init_prey: f32) (init_pred: f32) (growth_prey: f32) (predation: f32) (growth_pred: f32) (decline_pred: f32) : [][2]f32 = map to_array (runge_kutta step_size num_steps (init_prey, init_pred, growth_prey, predation, growth_pred, decline_pred)) entry runge_kutta_fwd (step_size : f32) (num_steps : i64) (init_prey: f32) (init_pred: f32) (growth_prey: f32) (predation: f32) (growth_pred: f32) (decline_pred: f32) (init_prey_tan: f32) (init_pred_tan: f32) (growth_prey_tan: f32) (predation_tan: f32) (growth_pred_tan: f32) (decline_pred_tan: f32) : [][2]f32 = map to_array (jvp (runge_kutta step_size num_steps ) (init_prey, init_pred, growth_prey, predation, growth_pred, decline_pred) (init_prey_tan, init_pred_tan, growth_prey_tan, 
predation_tan, growth_pred_tan, decline_pred_tan)) futhark-0.25.27/tests/ad/map0.fut000066400000000000000000000001461475065116200164460ustar00rootroot00000000000000-- == -- entry: rev -- input { [1,2,3] [3,2,1] } -- output { [6,4,2] } entry rev = vjp (map (*2i32)) futhark-0.25.27/tests/ad/map1.fut000066400000000000000000000003031475065116200164420ustar00rootroot00000000000000-- -- == -- entry: rev -- input { [[1.0,2.0,3.0,4.0],[1.0,2.0,3.0,4.0]] [1.0,2.0] } -- output {[[24.0, 12.0, 8.0, 6.0], -- [48.0, 24.0, 16.0, 12.0]] } entry rev = vjp (map f64.product) futhark-0.25.27/tests/ad/map2.fut000066400000000000000000000005421475065116200164500ustar00rootroot00000000000000-- Map with free variable. -- == -- entry: fwd_J rev_J -- input { 2.0 [1.0,2.0,3.0] } -- output { [1.0,2.0,3.0] } def onehot n i : [n]f64 = tabulate n (\j -> f64.bool (i==j)) entry fwd_J [n] (c: f64) (xs: [n]f64) = jvp (\c' -> map (*c') xs) c 1 entry rev_J [n] (c: f64) (xs: [n]f64) = tabulate n (\i -> vjp (\c' -> map (*c') xs) c (onehot n i)) futhark-0.25.27/tests/ad/map3.fut000066400000000000000000000004421475065116200164500ustar00rootroot00000000000000-- == -- entry: fwd rev -- input { 1i32 [1i32,2i32,3i32] } -- output { [1i32,2i32,3i32] } entry fwd [n] (x: i32) (xs: [n]i32) = jvp (\x -> map (*x) xs) x 1 entry rev [n] (x: i32) (xs: [n]i32) = tabulate n (\i -> vjp (\x -> map (*x) xs) x (replicate n 0 with [i] = 1)) futhark-0.25.27/tests/ad/map4.fut000066400000000000000000000012201475065116200164440ustar00rootroot00000000000000-- An array is both a 'map' input and a free variable in the lambda. 
-- == -- entry: fwd_J rev_J -- input { [1,2,3] } -- output { -- [[[2, 0, 0], [1, 1, 0], [1, 0, 1]], [[1, 1, 0], [0, 2, 0], [0, 1, 1]], [[1, 0, 1], [0, 1, 1], [0, 0, 2]]] -- } def f (xs: []i32) = map (\x -> map (+x) xs) xs def onehot n i : [n]i32 = tabulate n (\j -> i32.bool (i==j)) def onehot_2d n m p : [n][m]i32 = tabulate_2d n m (\i j -> i32.bool ((i,j)==p)) entry fwd_J [n] (xs: [n]i32) = tabulate n (\i -> jvp f xs (onehot n i)) |> map transpose |> transpose |> map transpose entry rev_J [n] (xs: [n]i32) = tabulate_2d n n (\i j -> vjp f xs (onehot_2d n n (i,j))) futhark-0.25.27/tests/ad/map5.fut000066400000000000000000000007621475065116200164570ustar00rootroot00000000000000-- Map with free array variable. -- == -- entry: fwd_J rev_J -- input { [[1,2,3],[4,5,6]] [0,0] } -- output { [[1, 0], [0, 1]] } def onehot n i : [n]i32 = tabulate n (\j -> i32.bool (i==j)) def f [n][m] (free: [n][m]i32) (is: [n]i32) = map (\i -> foldl (+) 0 free[i]+i) is entry fwd_J [n][m] (free: [n][m]i32) (is: [n]i32) = tabulate n (\i -> jvp (f free) is (onehot n i)) |> transpose entry rev_J [n][m] (free: [n][m]i32) (is: [n]i32) = tabulate n (\i -> vjp (f free) is (onehot n i)) futhark-0.25.27/tests/ad/map6.fut000066400000000000000000000015151475065116200164550ustar00rootroot00000000000000-- #1878 -- == -- entry: fwd_J rev_J -- input { [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0] } -- output { [[0.0, 2.0, 3.0, 4.0], -- [0.0, 0.0, 1.0, 1.0], -- [0.0, 0.0, 0.0, 1.0], -- [0.0, 0.0, 0.0, 0.0], -- [-4.0, -6.0, -7.0, -8.0], -- [0.0, 0.0, -1.0, -1.0], -- [0.0, 0.0, 0.0, -1.0], -- [0.0, 0.0, 0.0, 0.0]] -- } def obj (x : [8]f64) = #[unsafe] -- For simplicity of generated code. 
let col_w_pre_red = tabulate_3d 4 2 4 (\ k i j -> x[k+j]*x[i+j]) let col_w_red = map (map f64.sum) col_w_pre_red let col_eq : [4]f64 = map (\w -> w[0] - w[1]) col_w_red in col_eq entry fwd_J (x : [8]f64) = tabulate 8 (\i -> jvp obj x (replicate 8 0 with [i] = 1)) entry rev_J (x : [8]f64) = transpose (tabulate 4 (\i -> vjp obj x (replicate 4 0 with [i] = 1))) futhark-0.25.27/tests/ad/map7.fut000066400000000000000000000012071475065116200164540ustar00rootroot00000000000000-- #1878. The interesting thing here is that the sparse adjoint also -- has active free variables. -- == -- entry: fwd_J rev_J -- input { [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0] } -- output { [0.0, 0.0, 0.0, 0.0, -4.0, 0.0, 0.0, 0.0] } def obj (x : [8]f64) = #[unsafe] -- For simplicity of generated code. let col_w_pre_red = tabulate_3d 4 2 4 (\ k i j -> x[k+j]*x[i+j]) let col_w_red = map (map f64.sum) col_w_pre_red let col_eq : [4]f64 = map (\w -> w[0] - w[1]) col_w_red in f64.maximum col_eq entry fwd_J (x : [8]f64) = tabulate 8 (\i -> jvp obj x (replicate 8 0 with [i] = 1)) entry rev_J (x : [8]f64) = vjp obj x 1 futhark-0.25.27/tests/ad/matmul.fut000066400000000000000000000016061475065116200171120ustar00rootroot00000000000000-- == -- entry: fwd_J rev_J -- input -- { -- [[1.0,2.0],[3.0,4.0]] [[5.0,6.0],[7.0,8.0]] -- } -- output -- { -- [[[[1.000000f64, 0.000000f64], -- [1.000000f64, 0.000000f64]], -- [[0.000000f64, 1.000000f64], -- [0.000000f64, 1.000000f64]]], -- [[[1.000000f64, 0.000000f64], -- [1.000000f64, 0.000000f64]], -- [[0.000000f64, 1.000000f64], -- [0.000000f64, 1.000000f64]]]] -- } def dotprod xs ys = f64.sum (map2 (+) xs ys) def matmul xss yss = map (\xs -> map (dotprod xs) (transpose yss)) xss def onehot_2d n m p : [n][m]f64 = tabulate_2d n m (\i j -> f64.bool ((i,j)==p)) entry fwd_J [n][m][p] (xss: [n][m]f64) (yss: [m][p]f64) = tabulate_2d m p (\i j -> jvp (matmul xss) yss (onehot_2d m p (i,j))) |> transpose |> map transpose |> transpose entry rev_J [n][m][p] (xss: [n][m]f64) (yss: 
[m][p]f64) = tabulate_2d n p (\i j -> vjp (matmul xss) yss (onehot_2d n p (i,j))) futhark-0.25.27/tests/ad/maximum.fut000066400000000000000000000005671475065116200172750ustar00rootroot00000000000000-- == -- entry: rev fwd -- input { [1.0, 2.0, 3.0, 4.0, 5.0, -6.0, 5.0] } -- output { [0.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0] } -- input { [1.0, 1.0] } -- output { [1.0, 0.0] } -- structure { /Screma 2 } def f = map f64.abs >-> f64.maximum entry rev [n] (xs: [n]f64) = vjp f xs 1 entry fwd [n] (xs: [n]f64) = tabulate n (\i -> jvp f xs (tabulate n ((==i) >-> f64.bool))) futhark-0.25.27/tests/ad/minimum.fut000066400000000000000000000005111475065116200172600ustar00rootroot00000000000000-- == -- entry: rev fwd -- input { [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 5.0] } -- output { [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] } -- input { [1.0, 1.0] } -- output { [1.0, 0.0] } entry rev [n] (xs: [n]f64) = vjp f64.minimum xs 1 entry fwd [n] (xs: [n]f64) = tabulate n (\i -> jvp f64.minimum xs (tabulate n ((==i) >-> f64.bool))) futhark-0.25.27/tests/ad/minmax.fut000066400000000000000000000007171475065116200171060ustar00rootroot00000000000000-- == -- entry: rev fwd -- input { [1.0, 2.0, 3.0, 4.0, 5.0, -6.0, 5.0] } -- output { [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] -- [0.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0] -- } -- structure { /Screma 2 } def f xs = let ys = map f64.abs xs in (f64.minimum ys, f64.maximum ys) entry rev [n] (xs: [n]f64) = (vjp f xs (1,0), vjp f xs (0,1)) entry fwd [n] (xs: [n]f64) = unzip (tabulate n (\i -> jvp f xs (tabulate n ((==i) >-> f64.bool)))) futhark-0.25.27/tests/ad/negate.fut000066400000000000000000000002431475065116200170520ustar00rootroot00000000000000-- == -- entry: fwd rev -- input { 1f32 } output { -1f32 } -- input { 3f32 } output { -1f32 } def f x : f32 = -x entry fwd x = jvp f x 1 entry rev x = jvp f x 1 futhark-0.25.27/tests/ad/nested0.fut000066400000000000000000000003701475065116200171520ustar00rootroot00000000000000-- == -- entry: f_vjp -- input { [1,2,3] } -- output { 
[24,48,72] } def f [n] (xs: [n]i32) = map (\x -> x * x * x * x) xs entry f_vjp [n] (xs: [n]i32) = vjp (\xs -> vjp (\xs -> vjp f xs (replicate n 1)) xs (replicate n 1)) xs (replicate n 1) futhark-0.25.27/tests/ad/nested1.fut000066400000000000000000000005011475065116200171470ustar00rootroot00000000000000-- == -- entry: f_vjp -- input { [1,2,3] [0,1,2] } -- output { [6,12,18] [0,0,0] } def f [n] (xsis: ([n]i32, [n]i32)) = let (xs, is) = xsis in map (\i -> xs[i] * xs[i] * xs[i]) is entry f_vjp [n] (xs: [n]i32) (is: [n]i32) = vjp (\(xs, is) -> vjp f (xs,is) (replicate n 1)) (xs,is) (replicate n 1, replicate n 1) futhark-0.25.27/tests/ad/nested2.fut000066400000000000000000000006041475065116200171540ustar00rootroot00000000000000-- == -- entry: f_vjp -- input { [1,2,3] [0,1,2] } -- output { [24,48,72] [0,0,0] } def f [n] (xsis: ([n]i32, [n]i32)) = let (xs, is) = xsis in map (\i -> xs[i] * xs[i] * xs[i] * xs[i]) is entry f_vjp [n] (xs: [n]i32) (is: [n]i32) = vjp (\(xs, is) -> vjp (\(xs, is) -> vjp f (xs,is) (replicate n 1)) (xs,is) (replicate n 1, replicate n 0)) (xs,is) (replicate n 1, replicate n 0) futhark-0.25.27/tests/ad/nested3.fut000066400000000000000000000005771475065116200171660ustar00rootroot00000000000000-- == -- entry: f_vjp -- input { [[1,2,3],[1,2,3],[1,2,3]] [1,2,3]} -- output { [[1,1,1],[1,1,1],[1,1,1]] [3,3,3] } def f [n] (xssys : ([n][n]i32, [n]i32)) = let (xss,ys) = xssys in i32.sum (map (\xs -> i32.sum (map2 (*) xs ys)) xss) entry f_vjp [n] (xss: [n][n]i32) (ys: [n]i32) = vjp (\(xss, ys) -> vjp f (xss, ys) 1 ) (xss, ys) (replicate n (replicate n 1), replicate n 1) futhark-0.25.27/tests/ad/nested4.fut000066400000000000000000000007621475065116200171630ustar00rootroot00000000000000-- == -- entry: f_vjp -- input { [[1,2,3],[1,2,3],[1,2,3]] [0,1,2] [0,1,2]} -- output { [[6,12,18],[6,12,18],[6,12,18]] [0,0,0] [0,0,0] } def f [n] (xssisjs: ([n][n]i32, [n]i32, [n]i32)) = let (xss, is, js) = xssisjs in map (\i -> map (\j -> xss[i,j] * xss[i,j] * xss[i,j]) js) is 
entry f_vjp [n] (xss: [n][n]i32) (is: [n]i32) (js: [n]i32) = vjp (\(xss,is,js) -> vjp f (xss, is, js) (replicate n (replicate n 1))) (xss, is, js) ((replicate n (replicate n 1)), replicate n 0, replicate n 0) futhark-0.25.27/tests/ad/not.fut000066400000000000000000000002061475065116200164060ustar00rootroot00000000000000-- == -- entry: fwd rev -- input { true } output { true } def f x : bool = !x entry fwd x = jvp f x true entry rev x = jvp f x true futhark-0.25.27/tests/ad/pow0.fut000066400000000000000000000004351475065116200164770ustar00rootroot00000000000000-- The power function has a dangerous kink for x==0. -- == -- entry: fwd -- input { 0.0 1.0 } output { 1.0 } -- == -- entry: rev -- input { 0.0 1.0 } output { 1.0 0.0 } entry fwd x y : f64 = jvp (\(x, y) -> x ** y) (x, y) (1, 1) entry rev x y = vjp (\(x, y) -> x ** y) (x, y) 1f64 futhark-0.25.27/tests/ad/rearrange0.fut000066400000000000000000000004051475065116200176350ustar00rootroot00000000000000-- == -- entry: f_jvp -- input { [[1,2],[3,4]] } -- output { [[1,3],[2,4]] } entry f_jvp (xss: [][]i32) = jvp transpose xss xss -- == -- entry: f_vjp -- input { [[1,2],[3,4]] } -- output { [[1,3],[2,4]] } entry f_vjp (xss: [][]i32) = vjp transpose xss xss futhark-0.25.27/tests/ad/reduce-vec-minmax0.fut000066400000000000000000000013771475065116200212110ustar00rootroot00000000000000-- == -- entry: main -- compiled random input { [50][66]f32 } output { true } -- compiled random input { [23][45]f32} output { true } let redmap [n][m] (arr: [m][n]f32) : [n]f32 = reduce (map2 f32.max) (replicate n f32.lowest) arr let forward [n][m] (arr: [m][n]f32) : [n][m][n]f32 = tabulate_2d m n (\i j -> jvp redmap arr (replicate m (replicate n 0) with [i] = (replicate n 0 with [j] = 1))) |> transpose let reverse [n][m] (arr: [m][n]f32) : [n][m][n]f32 = tabulate n (\i -> vjp redmap arr (replicate n 0 with [i] = 1)) let main [n][m] (arr : [m][n]f32) : bool = let l = n * m * n let fs = forward arr |> flatten_3d :> [l]f32 let rs = reverse arr 
|> flatten_3d :> [l]f32 in reduce (&&) true (map2 (\i j -> f32.abs(i - j) < 0.0001f32) fs rs) futhark-0.25.27/tests/ad/reduce0.fut000066400000000000000000000004501475065116200171360ustar00rootroot00000000000000-- Simple reduce with multiplication -- == -- entry: rev -- input { [1.0f32, 2.0f32, 3.0f32, 4.0f32] 1.0f32 } output { [24.0f32, 12.0f32, 8.0f32, 6.0f32] 24.0f32 } def red_mult [n] (xs: [n]f32, c: f32) : f32 = reduce (*) 1 xs * c entry rev [n] (xs: [n]f32) (c: f32) = vjp red_mult (xs,c) 1 futhark-0.25.27/tests/ad/reduce1.fut000066400000000000000000000011761475065116200171450ustar00rootroot00000000000000-- Reduce with a fancier operator. -- == -- entry: rev -- input { [1.0,2.0,3.0] [2.0,3.0,4.0] [3.0,4.0,5.0] [4.0,5.0,6.0] } -- output { [47.0, 28.0, 32.0] -- [83.0, 44.0, 32.0] -- [47.0, 42.0, 42.0] -- [83.0, 66.0, 42.0] } def mm2by2 (a1: f64, b1: f64, c1: f64, d1: f64) (a2: f64, b2: f64, c2: f64, d2: f64) = ( a1*a2 + b1*c2 , a1*b2 + b1*d2 , c1*a2 + d1*c2 , c1*b2 + d1*d2 ) def red_mm [n] (xs: [n](f64,f64,f64,f64)) = reduce mm2by2 (1, 0, 0, 1) xs entry rev [n] (xs1: [n]f64) (xs2: [n]f64) (xs3: [n]f64) (xs4: [n]f64) = vjp red_mm (zip4 xs1 xs2 xs3 xs4) (1, 1, 1, 1) |> unzip4 futhark-0.25.27/tests/ad/reduce2.fut000066400000000000000000000007231475065116200171430ustar00rootroot00000000000000-- Result of one reduction is used free in a map. 
-- == -- tags { no_ispc } -- entry: fwd rev -- input { [3f64, 1f64, 5f64] } output { [-1.000000f64, -1.000000f64, -1.000000f64] } def sumBy 'a (f : a -> f64) (xs : []a) : f64 = map f xs |> f64.sum def f (arr : []f64) = let mx = f64.sum arr let sumShiftedExp = sumBy (\x -> x - mx) arr in sumShiftedExp + mx entry fwd x = map (jvp f x) [[1.0,0.0,0.0],[0.0,1.0,0.0],[0.0,0.0,1.0]] entry rev x = vjp f x 1f64 futhark-0.25.27/tests/ad/reduce_by_index0.fut000066400000000000000000000005711475065116200210230ustar00rootroot00000000000000-- == -- entry: f_jvp -- input { [0i64,1i64,2i64,3i64] [1f64,2f64,3f64,4f64] } -- output { [[1f64,0f64,0f64,0f64],[0f64,1f64,0f64,0f64],[0f64,0f64,1f64,0f64],[0f64,0f64,0f64,1f64]] } def f [n] (is: [n]i64) (vs: [n]f64) = hist (+) 0 4 is (map (+2) vs) entry f_jvp [n] (is: [n]i64) (vs: [n]f64) = tabulate n (\i -> jvp (f is) vs (replicate n 0 with [i] = 1)) |> transpose futhark-0.25.27/tests/ad/reducebyindex0.fut000066400000000000000000000014541475065116200205260ustar00rootroot00000000000000-- == -- entry: main -- input { -- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] -- [1f32,2f32,0f32,4f32,5f32,0f32,9f32,0f32] -- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] -- [1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32] } -- output { -- [8f32,960f32,504f32,0f32,0f32,0f32,7f32,8f32] -- [0f32,0f32,0f32,0f32,0f32,0f32,4f32,8f32,0f32,0f32,384f32,0f32,640f32,576f32,0f32,480f32,0f32,0f32] -- [8f32,960f32,0f32,0f32,0f32,0f32,9f32,0f32] } def f [n][m] (is: [n]i64) (dst: [m]f32,vs: [n]f32,c: [m]f32) = let r = reduce_by_index (copy dst) (\x y -> x*y*2) 0.5 is vs in map2 (*) r c def main [n][m] (is: [n]i64) (dst: [m]f32) (vs: [n]f32) (c: [m]f32) = vjp (f is) (dst,vs,c) (replicate m 1)futhark-0.25.27/tests/ad/reducebyindex1.fut000066400000000000000000000044621475065116200205310ustar00rootroot00000000000000-- == -- entry: rev -- compiled input { -- 
[0i64,1i64,2i64,3i64,2i64,1i64,0i64,1i64,2i64] -- [0f32,1f32,2f32,3f32,4f32] -- [-1i64,-1i64,-1i64,-1i64,-1i64] -- [-5f32,0f32,2f32,5f32,4f32,1f32,-2f32,-8f32,4f32] -- [0i64,1i64,2i64,3i64,4i64,5i64,6i64,7i64,8i64] } -- output { -- [1f32,1f32,0f32,0f32,1f32] -- [1i64,1i64,0i64,0i64,1i64] -- [0f32,0f32,0f32,1f32,1f32,0f32,0f32,0f32,0f32] -- [0i64,0i64,0i64,1i64,1i64,0i64,0i64,0i64,0i64] } def argmax (x: f32,i: i64) (y: f32,j: i64) = if x == y then (x,i64.min i j) else if x > y then (x,i) else (y,j) def f [n][m] (is: [n]i64) (dst: [m](f32,i64), vs: [n](f32,i64)) = reduce_by_index (copy dst) argmax (f32.lowest,i64.highest) is vs entry rev [n][m] (is: [n]i64) (dst0: [m]f32) (dst1: [m]i64) (vs0: [n]f32) (vs1: [n]i64) = let (r1,r2) = vjp (f is) (zip dst0 dst1, zip vs0 vs1) (zip (replicate m 1) (replicate m 1)) let (v1,i1) = unzip r1 let (v2,i2) = unzip r2 in (v1,i1,v2,i2) def fvec [n][m][k] (is: [n]i64) (dst: [k][m](f32,i64), vs: [n][m](f32,i64)) = reduce_by_index (copy dst) (map2 argmax) (replicate m (f32.lowest,i64.highest)) is vs -- == -- entry: revvec -- compiled input { -- [0i64,1i64,2i64,1i64,0i64,1i64] -- [[1f32,2f32,3f32],[4f32,5f32,6f32],[7f32,8f32,9f32],[10f32,11f32,12f32]] -- [[-1i64,-1i64,-1i64],[-1i64,-1i64,-1i64],[-1i64,-1i64,-1i64],[-1i64,-1i64,-1i64]] -- [[4f32,2f32,2f32],[5f32,6f32,8f32],[9f32,8f32,6f32],[4f32,6f32,7f32],[4f32,0f32,3f32],[3f32,6f32,7f32]] -- [[1i64,2i64,3i64],[4i64,5i64,6i64],[7i64,8i64,9i64],[10i64,11i64,12i64],[13i64,14i64,15i64],[16i64,17i64,18i64]] } -- output { -- [[0f32,1f32,1f32],[0f32,0f32,0f32],[0f32,1f32,1f32],[1f32,1f32,1f32]] -- [[0i64,1i64,1i64],[0i64,0i64,0i64],[0i64,1i64,1i64],[1i64,1i64,1i64]] -- [[1f32,0f32,0f32],[1f32,1f32,1f32],[1f32,0f32,0f32],[0f32,0f32,0f32],[0f32,0f32,0f32],[0f32,0f32,0f32]] -- [[1i64,0i64,0i64],[1i64,1i64,1i64],[1i64,0i64,0i64],[0i64,0i64,0i64],[0i64,0i64,0i64],[0i64,0i64,0i64]] } entry revvec [n][m][k] (is: [n]i64) (dst0: [k][m]f32) (dst1: [k][m]i64) (vs0: [n][m]f32) (vs1: [n][m]i64) = let 
(r1,r2) = vjp (fvec is) (map2 zip dst0 dst1, map2 zip vs0 vs1) (map2 zip (replicate k (replicate m 1)) (replicate k (replicate m 1))) let (v1,i1) = map unzip r1 |> unzip let (v2,i2) = map unzip r2 |> unzip in (v1,i1,v2,i2)futhark-0.25.27/tests/ad/reducebyindex2.fut000066400000000000000000000011651475065116200205270ustar00rootroot00000000000000-- == -- input { -- [0i64,1i64,2i64,3i64,2i64,1i64,0i64,1i64,2i64] -- [0f64,1f64,2f64,3f64] -- [2f64,3f64,4f64,5f64,6f64,0f64,8f64,9f64,1f64] -- [1f64,2f64,3f64,4f64,5f64,6f64,7f64,8f64,9f64] } -- output { -- [112f64,0f64,3240f64,20f64] -- [0f64,0f64,1620f64,12f64,1080f64,2592f64,0f64,0f64,6480f64] -- [0f64,0f64,2160f64,15f64,1296f64,0f64,0f64,0f64,720f64] } def f [n][m] (is: [n]i64) (dst: [m]f64,vs: [n]f64,c: [n]f64) = let tmp = map2 (*) c vs in reduce_by_index (copy dst) (*) 1 is tmp def main [n][m] (is: [n]i64) (dst: [m]f64) (vs: [n]f64) (c: [n]f64) = vjp (f is) (dst,vs,c) (replicate m 1)futhark-0.25.27/tests/ad/reducebyindex3.fut000066400000000000000000000047251475065116200205350ustar00rootroot00000000000000-- == -- entry: rev -- input { -- [0i64,1i64,2i64,1i64,0i64,1i64,2i64] -- [1f64,2f64,3f64,4f64,5f64,6f64,7f64] } -- output { -- [[1f64, 0f64, 0f64, 0f64, 1f64, 0f64, 0f64], -- [5f64, 1f64, 0f64, 1f64, 1f64, 1f64, 0f64], -- [0f64, 24f64, 1f64, 12f64, 0f64, 8f64, 1f64], -- [0f64, 0f64, 7f64, 0f64, 0f64, 0f64, 3f64]] } entry f [n] (is: [n]i64) (vs: [n]f64) = let r1 = reduce_by_index (replicate 4 1) (*) 1 (map (+1) is) vs let r2 = reduce_by_index (replicate 4 0) (+) 0 is (map (+2) vs) in map2 (+) r1 r2 entry rev [n] (is: [n]i64) (vs: [n]f64) = tabulate 4 (\i -> vjp (f is) vs (replicate 4 0 with [i] = 1)) -- entry fwd [n] (is: [n]i64) (vs: [n]f64) = -- tabulate n (\i -> jvp (f is) vs (replicate n 0 with [i] = 1)) -- |> map (.1) |> transpose -- entry ftest [n] (is: [n]i64) (vs: [n]f64) = -- reduce_by_index (replicate 4 0) (+) 0 is (map (+2) vs) -- entry revtest [n] (is: [n]i64) (vs: [n]f64) = -- tabulate 4 (\i -> vjp (ftest 
is) vs (replicate 4 0 with [i] = 1)) -- entry fwdtest [n] (is: [n]i64) (vs: [n]f64) = -- tabulate n (\i -> jvp (ftest is) vs (replicate n 0 with [i] = 1)) -- |> transpose -- entry revmap [n] (vs: [n]f64) = -- tabulate n (\i -> vjp (map (+2)) vs (replicate n 0 with [i] = 1)) -- entry fwdmap [n] (vs: [n]f64) = -- tabulate n (\i -> jvp (map (+2)) vs (replicate n 0 with [i] = 1)) -- |> transpose -- entry revp [n] (is: [n]i64) (vs: [n]f64) = -- (vjp2 (f is) vs (replicate 4 1)) -- entry fwdp [n] (is: [n]i64) (vs: [n]f64) = -- (jvp2 (f is) vs (replicate n 1)) -- [[0f64, 0f64, 0f64, 0f64, 0f64, 0f64, 0f64], -- [5f64, 0f64, 0f64, 0f64, 1f64, 0f64, 0f64], -- [0f64, 24f64, 0f64, 12f64, 0f64, 8f64, 0f64], -- [0f64, 0f64, 7f64, 0f64, 0f64, 0f64, 3f64]] -- [[1f64, 0f64, 0f64, 0f64, 1f64, 0f64, 0f64], -- [5f64, 1f64, 0f64, 1f64, 1f64, 1f64, 0f64], -- [0f64, 24f64, 1f64, 12f64, 0f64, 8f64, 1f64], -- [0f64, 0f64, 7f64, 0f64, 0f64, 0f64, 3f64]] -- [[0f64, 0f64, 0f64, 0f64, 0f64, 0f64, 0f64], -- [5f64, 0f64, 0f64, 0f64, 1f64, 0f64, 0f64], -- [0f64, 24f64, 0f64, 12f64, 0f64, 8f64, 0f64], -- [0f64, 0f64, 7f64, 0f64, 0f64, 0f64, 3f64]] -- [[0f64, 0f64, 0f64, 0f64, 0f64, 0f64, 0f64], -- [5f64, 0f64, 0f64, 0f64, 1f64, 0f64, 0f64], -- [0f64, 24f64, 0f64, 12f64, 0f64, 8f64, 0f64], -- [0f64, 0f64, 7f64, 0f64, 0f64, 0f64, 3f64]] -- [[1f64, 0f64, 0f64, 0f64, 1f64, 0f64, 0f64], -- [0f64, 1f64, 0f64, 1f64, 0f64, 1f64, 0f64], -- [0f64, 0f64, 1f64, 0f64, 0f64, 0f64, 1f64], -- [0f64, 0f64, 0f64, 0f64, 0f64, 0f64, 0f64]]futhark-0.25.27/tests/ad/reducebyindex4.fut000066400000000000000000000015221475065116200205260ustar00rootroot00000000000000-- == -- entry: rev -- input { -- [ 0i64, 1i64, 2i64, 1i64, 0i64, 1i64, 2i64, 1i64, 0i64] -- [ 1f32, 2f32, 3f32, 4f32, 5f32, 6f32, 7f32, 8f32, 9f32] -- [10f32,11f32,12f32,13f32,14f32,15f32,16f32,17f32,18f32] } -- output { -- [252f32,3315f32,15f32,2805f32,180f32,2431f32,12f32,2145f32,140f32] -- [468f32,7221f32,22f32,5757f32,288f32,4765f32,15f32,4053f32,204f32] } 
def op (a1:f32,b1:f32) (a2:f32,b2:f32) : (f32,f32) = (b1*a2+b2*a1, b1*b2) def f [n] (is: [n]i64) (vs: [n](f32,f32)) = reduce_by_index (replicate 4 (0,1)) op (0,1) is vs entry rev [n] (is: [n]i64) (vs0: [n]f32) (vs1: [n]f32) = vjp (f is) (zip vs0 vs1) (replicate 4 (1,1)) |> unzip entry fwd [n] (is: [n]i64) (vs0: [n]f32) (vs1: [n]f32) = tabulate n (\i -> jvp (f is) (zip vs0 vs1) (replicate n (0,0) with [i] = (1,1))) |> transpose |> map unzip |> unzipfuthark-0.25.27/tests/ad/reducebyindex5.fut000066400000000000000000000023741475065116200205350ustar00rootroot00000000000000-- == -- entry: rev -- compiled input { -- [0i64,1i64,2i64,1i64,0i64,1i64,2i64,1i64] -- [15f32,1f32,2f32,15f32,16f32] -- [1f32,2f32,12f32,3f32,2f32,4f32,7f32,5f32] } -- output { -- [0f32,1f32,0f32,1f32,0f32] -- [0f32,1f32,0f32,1f32,0f32,1f32,0f32,1f32] } let sat_add_u4 (x: f32) (y: f32): f32 = let sat_val = f32.i32 ((1 << 4) - 1) in if sat_val - x < y then sat_val else x + y def f [n][m] (is: [n]i64) (dst: [m]f32, as: [n]f32) = reduce_by_index (copy dst) sat_add_u4 0 is as entry rev [n][m] (is: [n]i64) (dst: [m]f32) (as: [n]f32) = vjp (f is) (dst,as) (replicate m 1) -- == -- entry: revvec -- compiled input { -- [0i64,1i64,2i64,1i64,0i64] -- [[15f32,1f32,2f32],[0f32,16f32,3f32],[15f32,4f32,5f32]] -- [[1f32,3f32,10f32],[6f32,8f32,3f32],[0f32,11f32,14f32],[7f32,9f32,3f32],[2f32,4f32,5f32]] } -- output { -- [[0f32,1f32,0f32],[1f32,0f32,1f32],[1f32,1f32,0f32]] -- [[0f32,1f32,0f32],[1f32,0f32,1f32],[1f32,1f32,0f32],[1f32,0f32,1f32],[0f32,1f32,0f32]] } def fvec [k][n][m] (is: [n]i64) (dst: [k][m]f32, as: [n][m]f32) = reduce_by_index (copy dst) (map2 sat_add_u4) (replicate m 0) is as entry revvec [k][n][m] (is: [n]i64) (dst: [k][m]f32) (as: [n][m]f32) = vjp (fvec is) (dst,as) (replicate k (replicate m 1)) futhark-0.25.27/tests/ad/reducebyindex6.fut000066400000000000000000000006121475065116200205270ustar00rootroot00000000000000-- == -- entry: rev -- input { -- [1i64,-3i64,1i64,5i64,1i64,-3i64,1i64,5i64] -- 
[1f32, 2f32,3f32,4f32,5f32, 6f32,7f32,8f32] } -- output { -- [840f32,0f32,280f32,0f32,168f32,0f32,120f32,0f32] } def f [n] (is: [n]i64) (as: [n]f32) = reduce_by_index (replicate 2 0.5) (\x y -> x*y*2) 0.5 is as entry rev [n] (is: [n]i64) (as: [n]f32) = vjp (f is) as (replicate 2 0 with [1] = 1)futhark-0.25.27/tests/ad/reducebyindexadd0.fut000066400000000000000000000007241475065116200211760ustar00rootroot00000000000000-- == -- entry: rev -- input { [2f32,1f32,1f32,1f32,1f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,10f32] } output { [5f32,0f32,0f32,0f32,0f32] } -- checks original dst is used def red_add [n][m] (is: [n]i64) (vs: [n]f32) (dst: [m]f32) = let dst2 = copy dst let a = map (**2) dst2 let b = reduce_by_index dst2 (+) 0 is vs in map2 (+) a b entry rev [n][m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) = vjp (red_add is vs) dst (replicate m 0 with [0] = 1)futhark-0.25.27/tests/ad/reducebyindexadd1.fut000066400000000000000000000011761475065116200212010ustar00rootroot00000000000000-- == -- entry: main -- input { -- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] -- [1f32,2f32,0f32,4f32,5f32,0f32,9f32,0f32] -- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] } -- output { -- [1f32,1f32,1f32,1f32,1f32,1f32,1f32,1f32] -- [1f32,1f32,1f32,1f32,1f32,1f32,1f32,1f32,1f32,1f32,1f32,1f32,1f32,1f32,1f32,1f32,0f32,0f32] } def f [n][m] (is: [n]i64) (dst: [m]f32,vs: [n]f32) = reduce_by_index (copy dst) (+) 0 is vs def main [n][m] (is: [n]i64) (dst: [m]f32) (vs: [n]f32) = vjp (f is) (dst,vs) (replicate m 1)futhark-0.25.27/tests/ad/reducebyindexadd2.fut000066400000000000000000000024501475065116200211760ustar00rootroot00000000000000-- == -- entry: main -- input { -- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] -- [1f32,2f32,0f32,4f32,5f32,0f32,9f32,0f32] -- 
[11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] -- [1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32] } -- output { -- [1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32] -- [5f32,6f32,3f32,5f32,6f32,3f32,1f32,1f32,5f32,6f32,2f32,5f32,2f32,4f32,4f32,2f32,0f32,0f32] -- [4f32,14f32,13f32,13f32,29f32,30f32,9f32,0f32] } -- input { -- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] -- [1f32,2f32,0f32,4f32,5f32,0f32,9f32,0f32] -- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] -- [0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32] } -- output { -- [0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32] -- [0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32] -- [4f32,14f32,13f32,13f32,29f32,30f32,9f32,0f32] } def f [n][m] (is: [n]i64) (dst: [m]f32,vs: [n]f32,c: [m]f32) = let tmp = reduce_by_index (copy dst) (+) 0 is vs in map2 (*) tmp c def main [n][m] (is: [n]i64) (dst: [m]f32) (vs: [n]f32) (c: [m]f32) = vjp (f is) (dst,vs,c) (replicate m 1)futhark-0.25.27/tests/ad/reducebyindexadd3.fut000066400000000000000000000015631475065116200212030ustar00rootroot00000000000000-- == -- entry: main -- input { -- [0i64,0i64,0i64,1i64,1i64,2i64,2i64,2i64,2i64] -- [[1f32,2f32],[0f32,4f32],[5f32,0f32],[9f32,0f32]] -- [[1f32,3f32],[2f32,4f32],[18f32,5f32],[6f32,0f32],[7f32,9f32],[0f32,14f32],[11f32,0f32],[0f32,16f32],[13f32,17f32]] -- [[1f32,2f32],[3f32,4f32],[5f32,6f32],[7f32,8f32]] } -- output { -- [[1f32,2f32],[3f32,4f32],[5f32,6f32],[7f32,8f32]] -- [[1f32,2f32],[1f32,2f32],[1f32,2f32],[3f32,4f32],[3f32,4f32],[5f32,6f32],[5f32,6f32],[5f32,6f32],[5f32,6f32]] -- [[22f32,14f32],[13f32,13f32],[29f32,47f32],[9f32,0f32]] } def f [n][m][k] (is: [n]i64) (dst: [k][m]f32,vs: [n][m]f32,c: [k][m]f32) = let tmp = reduce_by_index (copy dst) (map2 (+)) (replicate m 0) is vs in map2 (map2 (*)) tmp c def main [n][m][k] (is: [n]i64) (dst: [k][m]f32) (vs: [n][m]f32) 
(c: [k][m]f32) = vjp (f is) (dst,vs,c) (replicate k (replicate m 1))futhark-0.25.27/tests/ad/reducebyindexadd4.fut000066400000000000000000000021421475065116200211760ustar00rootroot00000000000000-- == -- entry: main -- input { -- [0i64,0i64,0i64,1i64,1i64,1i64,1i64] -- [[[1f32,2f32],[0f32,4f32]],[[5f32,0f32],[9f32,0f32]]] -- [[[1f32,3f32],[6f32,0f32]],[[2f32,4f32],[7f32,9f32]],[[18f32,5f32],[19f32,20f32]], -- [[0f32,14f32],[1f32,1f32]],[[11f32,0f32],[1f32,1f32]],[[0f32,16f32],[1f32,1f32]],[[13f32,21f32],[1f32,1f32]]] -- [[[1f32,2f32],[3f32,4f32]],[[5f32,6f32],[7f32,8f32]]] } -- output { -- [[[1f32,2f32],[3f32,4f32]],[[5f32,6f32],[7f32,8f32]]] -- [[[1f32,2f32],[3f32,4f32]],[[1f32,2f32],[3f32,4f32]],[[1f32,2f32],[3f32,4f32]], -- [[5f32,6f32],[7f32,8f32]],[[5f32,6f32],[7f32,8f32]],[[5f32,6f32],[7f32,8f32]],[[5f32,6f32],[7f32,8f32]]] -- [[[22f32,14f32],[32f32,33f32]],[[29f32,51f32],[13f32,4f32]]] } def f [n][m][k][l] (is: [n]i64) (dst: [k][m][l]f32,vs: [n][m][l]f32,c: [k][m][l]f32) = let tmp = reduce_by_index (copy dst) (map2 (map2 (+))) (replicate m (replicate l 0)) is vs in map2 (map2 (map2 (*))) tmp c def main [n][m][k][l] (is: [n]i64) (dst: [k][m][l]f32) (vs: [n][m][l]f32) (c: [k][m][l]f32) = vjp (f is) (dst,vs,c) (replicate k (replicate m (replicate l 1)))futhark-0.25.27/tests/ad/reducebyindexgenbenchtests.fut000066400000000000000000000101521475065116200232160ustar00rootroot00000000000000-- == -- entry: satadd -- compiled random input { [1000]u64 [100]i32 [1000]i32 } output { true } -- == -- entry: argmax -- compiled random input { [500]u64 [50]i32 [50]i32 [500]i32 [500]i32 } output { true } let sat_add_u24 (x: i32) (y: i32): i32 = let sat_val = (1 << 24) - 1 in if sat_val - x < y then sat_val else x + y def satadd_p [n][m] (is: [n]i64) (dst: [m]i32,vs: [n]i32): [m]i32 = reduce_by_index (copy dst) sat_add_u24 0 is vs def satadd_fwd [n][m] (is: [n]i64) (dst: [m]i32,vs: [n]i32): ([m][m]i32, [m][n]i32) = let t1 = tabulate m (\i -> jvp (satadd_p is) (dst,vs) (replicate 
m 0 with [i] = 1,replicate n 0)) let t2 = tabulate n (\i -> jvp (satadd_p is) (dst,vs) (replicate m 0,replicate n 0 with [i] = 1)) in (transpose t1,transpose t2) def satadd_rev [n][m] (is: [n]i64) (dst: [m]i32,vs: [n]i32) : ([m][m]i32, [m][n]i32) = tabulate m (\i -> vjp (satadd_p is) (dst,vs) (replicate m 0 with [i] = 1)) |> unzip entry satadd [n][m] (is': [n]u64) (dst': [m]i32) (vs': [n]i32) = let is = map (\i -> i64.u64 (i%%u64.i64 m)) is' let dst = map (%%1000) dst' let vs = map (%%1000) vs' let (fwd1,fwd2) = satadd_fwd is (dst,vs) let (rev1,rev2) = satadd_rev is (dst,vs) let t1 = map2 (map2 (==)) fwd1 rev1 |> map (reduce (&&) true) |> reduce (&&) true let t2 = map2 (map2 (==)) fwd2 rev2 |> map (reduce (&&) true) |> reduce (&&) true in t1 && t2 let argmax_f (x:i32,i:i32) (y:i32,j:i32) : (i32,i32) = if x < y then (x, i) else if y < x then (y, j) else if i < j then (x, i) else (y, j) def argmax_p [n][m] (is: [n]i64) (dst_a: [m]i32,dst_b: [m]i32,vs_a: [n]i32,vs_b: [n]i32) = reduce_by_index (zip (copy dst_a) (copy dst_b)) argmax_f (i32.highest, i32.highest) is (zip vs_a vs_b) def argmax_fwd [n][m] (is: [n]i64) (dst_a: [m]i32) (dst_b: [m]i32) (vs_a: [n]i32) (vs_b: [n]i32): ([m][m]i32,[m][m]i32,[m][n]i32,[m][n]i32,[m][m]i32,[m][m]i32,[m][n]i32,[m][n]i32) = let (t1,t2) = tabulate m (\i -> jvp (argmax_p is) (dst_a,dst_b,vs_a,vs_b) (replicate m 0 with [i] = 1, replicate m 0, replicate n 0, replicate n 0)) |> map unzip |> unzip let (t3,t4) = tabulate m (\i -> jvp (argmax_p is) (dst_a,dst_b,vs_a,vs_b) (replicate m 0, replicate m 0 with [i] = 1, replicate n 0, replicate n 0)) |> map unzip |> unzip let (t5,t6) = tabulate n (\i -> jvp (argmax_p is) (dst_a,dst_b,vs_a,vs_b) (replicate m 0, replicate m 0, replicate n 0 with [i] = 1, replicate n 0)) |> map unzip |> unzip let (t7,t8) = tabulate n (\i -> jvp (argmax_p is) (dst_a,dst_b,vs_a,vs_b) (replicate m 0, replicate m 0, replicate n 0, replicate n 0 with [i] = 1)) |> map unzip |> unzip in (transpose t1,transpose t2,transpose 
t5,transpose t6,transpose t3,transpose t4,transpose t7,transpose t8) def argmax_rev [n][m] (is: [n]i64) (dst_a: [m]i32) (dst_b: [m]i32) (vs_a: [n]i32) (vs_b: [n]i32): ([m][m]i32,[m][m]i32,[m][n]i32,[m][n]i32,[m][m]i32,[m][m]i32,[m][n]i32,[m][n]i32) = let (t1,t2,t3,t4) = tabulate m (\i -> vjp (argmax_p is) (dst_a,dst_b,vs_a,vs_b) (zip (replicate m 0 with [i] = 1) (replicate m 0))) |> unzip4 let (t5,t6,t7,t8) = tabulate m (\i -> vjp (argmax_p is) (dst_a,dst_b,vs_a,vs_b) (zip (replicate m 0) (replicate m 0 with [i] = 1))) |> unzip4 in (t1,t2,t3,t4,t5,t6,t7,t8) entry argmax [n][m] (is': [n]u64) (dst_a: [m]i32) (dst_b: [m]i32) (vs_a: [n]i32) (vs_b: [n]i32) = let is = map (\i -> i64.u64 (i%%u64.i64 m)) is' let (f1,f2,f3,f4,f5,f6,f7,f8) = argmax_fwd is dst_a dst_b vs_a vs_b let (r1,r2,r3,r4,r5,r6,r7,r8) = argmax_rev is dst_a dst_b vs_a vs_b let t1 = map2 (map2 (==)) f1 r1 |> map (reduce (&&) true) |> reduce (&&) true let t2 = map2 (map2 (==)) f2 r2 |> map (reduce (&&) true) |> reduce (&&) true let t3 = map2 (map2 (==)) f3 r3 |> map (reduce (&&) true) |> reduce (&&) true let t4 = map2 (map2 (==)) f4 r4 |> map (reduce (&&) true) |> reduce (&&) true let t5 = map2 (map2 (==)) f5 r5 |> map (reduce (&&) true) |> reduce (&&) true let t6 = map2 (map2 (==)) f6 r6 |> map (reduce (&&) true) |> reduce (&&) true let t7 = map2 (map2 (==)) f7 r7 |> map (reduce (&&) true) |> reduce (&&) true let t8 = map2 (map2 (==)) f8 r8 |> map (reduce (&&) true) |> reduce (&&) true in t1 && t2 && t3 && t4 && t5 && t6 && t7 && t8 futhark-0.25.27/tests/ad/reducebyindexminmax0.fut000066400000000000000000000017431475065116200217410ustar00rootroot00000000000000-- == -- entry: rev -- input { [0i64, 1i64, 2i64, 3i64, 4i64] [0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32] } output { [0f32,0f32,0f32,0f32,0f32] } -- input { [0i64, 0i64, 0i64, 0i64, 0i64] [0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32] } output { [0f32,0f32,0f32,0f32,1f32] } -- input { [0i64, 0i64, 0i64, 0i64, 0i64] [0.0f32, 1.0f32, 2.0f32, 3.0f32, 3.0f32] } 
output { [0f32,0f32,0f32,1f32,0f32] } -- == -- entry: revp -- input { [0i64, 1i64, 2i64, 3i64, 4i64] [0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32] } output { [0i64,0i64,0i64,0i64,0i64] } -- input { [0i64, 0i64, 0i64, 0i64, 0i64] [0.0f32, 1.0f32, 2.0f32, 3.0f32, 4.0f32] } output { [0i64,0i64,0i64,0i64,0i64] } def red_max [n] (is: [n]i64, vs: [n]f32) = reduce_by_index (replicate 5 0) f32.max f32.lowest is vs entry rev [n] (is: [n]i64) (vs: [n]f32) = let (_, res) = vjp red_max (is,vs) (replicate 5 0 with [0] = 1) in res entry revp [n] (is: [n]i64) (vs: [n]f32) = let (res,_) = vjp red_max (is,vs) (replicate 5 0 with [0] = 1) in resfuthark-0.25.27/tests/ad/reducebyindexminmax1.fut000066400000000000000000000020211475065116200217300ustar00rootroot00000000000000-- == -- entry: rev -- input { [0f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,5f32] } output { [0f32,0f32,0f32,0f32,0f32] } -- input { [1f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,5f32] } output { [0f32,0f32,0f32,0f32,0f32] } -- input { [2f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,5f32] } output { [0f32,0f32,0f32,0f32,0f32] } -- input { [5f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,0f32] } output { [1f32,0f32,0f32,0f32,0f32] } -- input { [5f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,5f32] } output { [1f32,0f32,0f32,0f32,0f32] } -- input { [0f32,1f32,2f32,3f32,4f32] [1i64,2i64,3i64,2i64,1i64] [1f32,2f32,3f32,4f32,5f32] } output { [1f32,0f32,0f32,0f32,0f32] } def red_max [n][m] (is: [n]i64) (vs: [n]f32) (dst: [m]f32) = reduce_by_index (copy dst) f32.max f32.lowest is vs entry rev [n][m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) = vjp (red_max is vs) dst (replicate m 0 with [0] = 1)futhark-0.25.27/tests/ad/reducebyindexminmax10.fut000066400000000000000000000010451475065116200220150ustar00rootroot00000000000000-- == -- input { -- [0i64,1i64,0i64,1i64] -- [[1f32,2f32],[3f32,4f32]] -- 
[[1f32,0f32],[5f32,2f32],[-2f32,3f32],[4f32,6f32]] -- [[1f32,2f32],[3f32,4f32]] -- [[3f32,4f32],[5f32,6f32],[4f32,5f32],[6f32,7f32]] -- } -- output { [[3f32,4f32],[8f32,6f32],[4f32,7f32],[6f32,11f32]] } def primal3 [n][m][k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = (reduce_by_index (copy dst) (map2 f32.max) (replicate k f32.lowest) is vs,vs) def main [n][m][k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) bar1 bar2 = vjp (primal3 is dst) vs (bar1,bar2) futhark-0.25.27/tests/ad/reducebyindexminmax2.fut000066400000000000000000000012051475065116200217340ustar00rootroot00000000000000-- == -- entry: rev -- input { [5f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [4f32,3f32,2f32,1f32,0f32] } output { [0f32,0f32,0f32,0f32,0f32] } -- input { [4f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [5f32,4f32,3f32,2f32,1f32] } output { [1f32,0f32,0f32,0f32,0f32] } -- input { [5f32,1f32,2f32,3f32,4f32] [0i64,0i64,0i64,0i64,0i64] [5f32,4f32,3f32,2f32,1f32] } output { [0f32,0f32,0f32,0f32,0f32] } def red_max [n][m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) = reduce_by_index (copy dst) f32.max f32.lowest is vs entry rev [n][m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) = vjp (red_max dst is) vs (replicate m 0 with [0] = 1)futhark-0.25.27/tests/ad/reducebyindexminmax3.fut000066400000000000000000000013311475065116200217350ustar00rootroot00000000000000-- == -- entry: rev -- input { [5f32,1f32,2f32] [0i64,0i64,0i64] [4f32,3f32,2f32] 3f32 } -- output { [0f32,0f32,0f32] 5f32 } -- input { [5f32,1f32,2f32] [0i64,0i64,0i64] [10f32,3f32,2f32] 3f32 } -- output { [3f32,0f32,0f32] 10f32 } -- input { [5f32,1f32,2f32] [0i64,1i64,0i64] [10f32,30f32,2f32] 3f32 } -- output { [3f32,0f32,0f32] 10f32 } def red_max [n][m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32, c: f32) = let red = reduce_by_index (copy dst) f32.max f32.lowest is vs in map (*c) red entry rev [n][m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) (c: f32) = vjp (red_max dst is) (vs,c) (replicate m 0 with [0] = 1) --tabulate n (\i -> vjp 
(red_max dst is) (vs, c) (replicate n 0 with [i] = 1))futhark-0.25.27/tests/ad/reducebyindexminmax4.fut000066400000000000000000000013311475065116200217360ustar00rootroot00000000000000-- == -- entry: rev -- input { [5f32,1f32,2f32] [0i64,0i64,0i64] [4f32,3f32,2f32] 3f32 } -- output { [3f32,0f32,0f32] 5f32 } -- input { [5f32,1f32,2f32] [0i64,0i64,0i64] [10f32,3f32,2f32] 3f32 } -- output { [0f32,0f32,0f32] 10f32 } -- input { [5f32,1f32,2f32] [0i64,1i64,0i64] [10f32,30f32,2f32] 3f32 } -- output { [0f32,0f32,0f32] 10f32 } def red_max [n][m] (vs: [n]f32) (is: [n]i64) (dst: [m]f32, c: f32) = let red = reduce_by_index (copy dst) f32.max f32.lowest is vs in map (*c) red entry rev [n][m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) (c: f32) = vjp (red_max vs is) (dst,c) (replicate m 0 with [0] = 1) --tabulate n (\i -> vjp (red_max dst is) (vs, c) (replicate n 0 with [i] = 1))futhark-0.25.27/tests/ad/reducebyindexminmax5.fut000066400000000000000000000006151475065116200217430ustar00rootroot00000000000000-- == -- entry: rev -- input { [4f32,1f32,2f32] [0i64,0i64,0i64] [5f32,1f32,2f32]} -- output { [13f32,5f32,5f32] } def red_max [n][m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) = let red = reduce_by_index (copy dst) f32.max f32.lowest is vs in map (* reduce (+) 0 vs) red entry rev [n][m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) = vjp (red_max dst is) vs (replicate m 0 with [0] = 1)futhark-0.25.27/tests/ad/reducebyindexminmax6.fut000066400000000000000000000011471475065116200217450ustar00rootroot00000000000000-- == -- entry: rev -- input { [2f32,1f32,1f32,1f32,1f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,10f32] } output { [4f32,0f32,0f32,0f32,0f32] } -- input { [10f32,1f32,1f32,1f32,1f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,5f32] } output { [21f32,0f32,0f32,0f32,0f32] } -- checks original dst is used def red_max [n][m] (is: [n]i64) (vs: [n]f32) (dst: [m]f32) = let dst2 = copy dst let a = map (**2) dst2 let b = reduce_by_index dst2 f32.max f32.lowest is vs in map2 (+) 
a b entry rev [n][m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) = vjp (red_max is vs) dst (replicate m 0 with [0] = 1)futhark-0.25.27/tests/ad/reducebyindexminmax7.fut000066400000000000000000000014531475065116200217460ustar00rootroot00000000000000-- == -- compiled random input { [500]i64 [100][30]f32 [500][30]f32 } output { true } def primal [n][m][k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = reduce_by_index (copy dst) (map2 f32.max) (replicate k f32.lowest) is vs def rev [n][m][k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = tabulate m (\i -> vjp (primal is dst) vs (replicate m (replicate k 0) with [i] = replicate k 1)) def fwd [n][m][k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = tabulate n (\i -> jvp (primal is dst) vs (replicate n (replicate k 0) with [i] = replicate k 1)) |> transpose def main [n][m][k] (is': [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = let is = map (\i -> (i64.abs i) %% m) is' let r = rev is dst vs let f = fwd is dst vs in map2 (map2 (==)) r f |> map (reduce (&&) true) |> reduce (&&) true futhark-0.25.27/tests/ad/reducebyindexminmax8.fut000066400000000000000000000017211475065116200217450ustar00rootroot00000000000000-- == -- compiled random input { [100]i64 [50][30][20]f32 [100][30][20]f32 } output { true } def primal2 [n][m][k][l] (is: [n]i64) (dst: [m][k][l]f32) (vs: [n][k][l]f32) = reduce_by_index (copy dst) (map2 (map2 f32.max)) (replicate k (replicate l f32.lowest)) is vs def rev2 [n][m][k][l] (is: [n]i64) (dst: [m][k][l]f32) (vs: [n][k][l]f32) = tabulate m (\i -> vjp (primal2 is dst) vs (replicate m (replicate k (replicate l 0)) with [i] = replicate k (replicate l 1))) def fwd2 [n][m][k][l] (is: [n]i64) (dst: [m][k][l]f32) (vs: [n][k][l]f32) = tabulate n (\i -> jvp (primal2 is dst) vs (replicate n (replicate k (replicate l 0)) with [i] = replicate k (replicate l 1))) |> transpose def main [n][m][k][l] (is': [n]i64) (dst: [m][k][l]f32) (vs: [n][k][l]f32) = let is = map (\i -> (i64.abs i) %% m) is' let r = rev2 is dst vs let f = 
fwd2 is dst vs in map2 (map2 (map2 (==))) r f |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true futhark-0.25.27/tests/ad/reducebyindexminmax9.fut000066400000000000000000000015071475065116200217500ustar00rootroot00000000000000-- == -- input { -- [0i64,1i64,0i64,1i64] -- [[[1f32,2f32],[3f32,4f32]],[[5f32,6f32],[7f32,8f32]]] -- [ [[1f32,0f32],[5f32,2f32]], [[7f32,4f32],[9f32,7f32]], [[-2f32,3f32],[4f32,6f32]], [[1f32,2f32],[5f32,9f32]] ] -- [[[1f32,2f32],[3f32,4f32]],[[5f32,6f32],[7f32,8f32]]] -- [ [[3f32,4f32],[5f32,6f32]], [[7f32,8f32],[9f32,10f32]], [[4f32,5f32],[6f32,7f32]], [[8f32,9f32],[10f32,11f32]] ] -- } -- output { [ [[3f32,4f32], [8f32,6f32]], [[12f32,8f32], [16f32,10f32]], [[4f32,7f32], [6f32,11f32]], [[8f32,9f32],[10f32,19f32]] ] } def primal4 [n][m][k][l] (is: [n]i64) (dst: [m][k][l]f32) (vs: [n][k][l]f32) = (reduce_by_index (copy dst) (map2 (map2 f32.max)) (replicate k (replicate l f32.lowest)) is vs, vs) def main [n][m][k][l] (is: [n]i64) (dst: [m][k][l]f32) (vs: [n][k][l]f32) bar1 bar2 = vjp (primal4 is dst) vs (bar1,bar2) futhark-0.25.27/tests/ad/reducebyindexmul0.fut000066400000000000000000000021101475065116200212320ustar00rootroot00000000000000-- 0 zero - dst neutral / dst no neutral -- 1 zero - dst / vs -- 2 zero - both in vs / one in dst and one in vs -- bucket with no values - zero / not zero -- index out of bounds -- == -- entry: main -- input { -- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] -- [1f32,2f32,0f32,4f32,5f32,0f32,9f32,0f32] -- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] } -- output { -- [2f32,60f32,42f32,0f32,0f32,0f32,1f32,1f32] -- [0f32,0f32,0f32,0f32,0f32,0f32,1f32,2f32,0f32,0f32,24f32,0f32,40f32,36f32,0f32,30f32,0f32,0f32] } def f [n][m] (is: [n]i64) (dst: [m]f32,vs: [n]f32) = reduce_by_index (copy dst) (*) 1 is vs def main [n][m] (is: [n]i64) (dst: [m]f32) (vs: [n]f32) = vjp (f is) (dst,vs) 
(replicate m 1) -- [0i64,0i64,1i64,1i64,1i64,2i64,2i64,3i64,3i64,4i64,4i64,4i64,4i64,5i64,5i64,5i64] -- [1f32,2f32,3f32,4f32,5f32,6f32,7f32,0f32,9f32,0f32,11f32,0f32,13f32,14f32,0f32,16f32] -- [2f32,1f32,40f32,30f32,24f32,0f32,0f32,36f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32]futhark-0.25.27/tests/ad/reducebyindexmul1.fut000066400000000000000000000024531475065116200212450ustar00rootroot00000000000000-- == -- entry: main -- input { -- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] -- [1f32,2f32,0f32,4f32,5f32,0f32,9f32,0f32] -- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] -- [1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32] } -- output { -- [2f32,120f32,126f32,0f32,0f32,0f32,7f32,8f32] -- [0f32,0f32,0f32,0f32,0f32,0f32,1f32,2f32,0f32,0f32,48f32,0f32,80f32,144f32,0f32,60f32,0f32,0f32] -- [2f32,120f32,0f32,0f32,0f32,0f32,9f32,0f32] } -- input { -- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] -- [1f32,2f32,0f32,4f32,5f32,0f32,9f32,0f32] -- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] -- [0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32] } -- output { -- [0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32] -- [0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32,0f32] -- [2f32,120f32,0f32,0f32,0f32,0f32,9f32,0f32] } def f [n][m] (is: [n]i64) (dst: [m]f32,vs: [n]f32,c: [m]f32) = let tmp = reduce_by_index (copy dst) (*) 1 is vs in map2 (*) tmp c def main [n][m] (is: [n]i64) (dst: [m]f32) (vs: [n]f32) (c: [m]f32) = vjp (f is) (dst,vs,c) (replicate m 1)futhark-0.25.27/tests/ad/reducebyindexmul2.fut000066400000000000000000000015461475065116200212500ustar00rootroot00000000000000-- == -- entry: rev -- input { [2f32,1f32,1f32,1f32,1f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,10f32] } output { [244f32,0f32,0f32,0f32,0f32] } -- input { [10f32,1f32,1f32,1f32,1f32] 
[0i64,0i64,0i64,0i64,0i64] [0f32,2f32,3f32,4f32,5f32] } output { [20f32,0f32,0f32,0f32,0f32] } -- input { [0f32,1f32,1f32,1f32,1f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,3f32,4f32,5f32] } output { [120f32,0f32,0f32,0f32,0f32] } -- input { [0f32,1f32,1f32,1f32,1f32] [0i64,0i64,0i64,0i64,0i64] [1f32,2f32,0f32,4f32,5f32] } output { [0f32,0f32,0f32,0f32,0f32] } -- checks original dst is used def red_mul [n][m] (is: [n]i64) (vs: [n]f32) (dst: [m]f32) = let dst2 = copy dst let a = map (**2) dst2 let b = reduce_by_index dst2 (*) 1 is vs in map2 (+) a b entry rev [n][m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) = vjp (red_mul is vs) dst (replicate m 0 with [0] = 1)futhark-0.25.27/tests/ad/reducebyindexmul3.fut000066400000000000000000000015741475065116200212520ustar00rootroot00000000000000-- == -- entry: main -- input { -- [0i64,0i64,0i64,1i64,1i64,2i64,2i64,2i64,2i64] -- [[1f32,2f32],[0f32,4f32],[5f32,0f32],[9f32,0f32]] -- [[1f32,3f32],[2f32,4f32],[18f32,5f32],[6f32,0f32],[7f32,9f32],[0f32,14f32],[11f32,0f32],[0f32,16f32],[13f32,17f32]] -- [[1f32,2f32],[3f32,4f32],[5f32,6f32],[7f32,8f32]] } -- output { -- [[36f32,120f32],[126f32,0f32],[0f32,0f32],[7f32,8f32]] -- [[36f32,80f32],[18f32,60f32],[2f32,48f32],[0f32,144f32],[0f32,0f32],[0f32,0f32],[0f32,0f32],[0f32,0f32],[0f32,0f32]] -- [[36f32,120f32],[0f32,0f32],[0f32,0f32],[9f32,0f32]] } def f [n][m][k] (is: [n]i64) (dst: [k][m]f32,vs: [n][m]f32,c: [k][m]f32) = let tmp = reduce_by_index (copy dst) (map2 (*)) (replicate m 1) is vs in map2 (map2 (*)) tmp c def main [n][m][k] (is: [n]i64) (dst: [k][m]f32) (vs: [n][m]f32) (c: [k][m]f32) = vjp (f is) (dst,vs,c) (replicate k (replicate m 1))futhark-0.25.27/tests/ad/reducebyindexmul4.fut000066400000000000000000000021601475065116200212430ustar00rootroot00000000000000-- == -- entry: main -- input { -- [0i64,0i64,0i64,1i64,1i64,1i64,1i64] -- [[[1f32,2f32],[0f32,4f32]],[[5f32,0f32],[9f32,0f32]]] -- [[[1f32,3f32],[6f32,0f32]],[[2f32,4f32],[7f32,9f32]],[[18f32,5f32],[19f32,20f32]], -- 
[[0f32,14f32],[1f32,1f32]],[[11f32,0f32],[1f32,1f32]],[[0f32,16f32],[1f32,1f32]],[[13f32,21f32],[1f32,1f32]]] -- [[[1f32,2f32],[3f32,4f32]],[[5f32,6f32],[7f32,8f32]]] } -- output { -- [[[36f32,120f32],[2394f32,0f32]],[[0f32,0f32],[7f32,8f32]]] -- [[[36f32,80f32],[0f32,2880f32]],[[18f32,60f32],[0f32,0f32]],[[2f32,48f32],[0f32,0f32]], -- [[0f32,0f32],[63f32,0f32]],[[0f32,0f32],[63f32,0f32]],[[0f32,0f32],[63f32,0f32]],[[0f32,0f32],[63f32,0f32]]] -- [[[36f32,120f32],[0f32,0f32]],[[0f32,0f32],[9f32,0f32]]] } def f [n][m][k][l] (is: [n]i64) (dst: [k][m][l]f32,vs: [n][m][l]f32,c: [k][m][l]f32) = let tmp = reduce_by_index (copy dst) (map2 (map2 (*))) (replicate m (replicate l 1)) is vs in map2 (map2 (map2 (*))) tmp c def main [n][m][k][l] (is: [n]i64) (dst: [k][m][l]f32) (vs: [n][m][l]f32) (c: [k][m][l]f32) = vjp (f is) (dst,vs,c) (replicate k (replicate m (replicate l 1)))futhark-0.25.27/tests/ad/reducebyindexspbenchtests.fut000066400000000000000000000147221475065116200230760ustar00rootroot00000000000000-- == -- entry: add mul max -- compiled random input { [100]i32 [1000]u64 [1000]i32 } -- == -- entry: vecadd vecmul vecmax -- compiled random input { [10][100]i32 [100]u64 [100][100]i32 } output { true } def add_p [n][m] (is: [n]i64) (dst: [m]i32,vs: [n]i32): [m]i32 = reduce_by_index (copy dst) (+) 0 is vs def add_rev [n][m] (dst: [m]i32) (is: [n]i64) (vs: [n]i32) = tabulate m (\i -> vjp (add_p is) (dst,vs) (replicate m 0 with [i] = 1)) |> unzip def add_fwd [n][m] (dst: [m]i32) (is: [n]i64) (vs: [n]i32) = let t1 = tabulate m (\i -> jvp (add_p is) (dst,vs) (replicate m 0 with [i] = 1,replicate n 0)) let t2 = tabulate n (\i -> jvp (add_p is) (dst,vs) (replicate m 0,replicate n 0 with [i] = 1)) in (transpose t1,transpose t2) entry add [n][m] (dst: [m]i32) (is': [n]u64) (vs: [n]i32) = let is = map (\i -> i64.u64 (i%%u64.i64 m)) is' let (fwd1,fwd2) = add_fwd dst is vs let (rev1,rev2) = add_rev dst is vs let t1 = map2 (map2 (==)) fwd1 rev1 |> map (reduce (&&) true) |> reduce 
(&&) true let t2 = map2 (map2 (==)) fwd2 rev2 |> map (reduce (&&) true) |> reduce (&&) true in t1 && t2 def mul_p [n][m] (is: [n]i64) (dst: [m]i32,vs: [n]i32): [m]i32 = reduce_by_index (copy dst) (*) 1 is vs def mul_rev [n][m] (dst: [m]i32) (is: [n]i64) (vs: [n]i32) = tabulate m (\i -> vjp (mul_p is) (dst,vs) (replicate m 0 with [i] = 1)) |> unzip def mul_fwd [n][m] (dst: [m]i32) (is: [n]i64) (vs: [n]i32) = let t1 = tabulate m (\i -> jvp (mul_p is) (dst,vs) (replicate m 0 with [i] = 1,replicate n 0)) let t2 = tabulate n (\i -> jvp (mul_p is) (dst,vs) (replicate m 0,replicate n 0 with [i] = 1)) in (transpose t1,transpose t2) entry mul [n][m] (dst: [m]i32) (is': [n]u64) (vs: [n]i32) = let is = map (\i -> i64.u64 (i%%u64.i64 m)) is' let (fwd1,fwd2) = mul_fwd dst is vs let (rev1,rev2) = mul_rev dst is vs let t1 = map2 (map2 (==)) fwd1 rev1 |> map (reduce (&&) true) |> reduce (&&) true let t2 = map2 (map2 (==)) fwd2 rev2 |> map (reduce (&&) true) |> reduce (&&) true in t1 && t2 def max_p [n][m] (is: [n]i64) (dst: [m]i32,vs: [n]i32): [m]i32 = reduce_by_index (copy dst) i32.max i32.lowest is vs def max_rev [n][m] (dst: [m]i32) (is: [n]i64) (vs: [n]i32) = tabulate m (\i -> vjp (max_p is) (dst,vs) (replicate m 0 with [i] = 1)) |> unzip def max_fwd [n][m] (dst: [m]i32) (is: [n]i64) (vs: [n]i32) = let t1 = tabulate m (\i -> jvp (max_p is) (dst,vs) (replicate m 0 with [i] = 1,replicate n 0)) let t2 = tabulate n (\i -> jvp (max_p is) (dst,vs) (replicate m 0,replicate n 0 with [i] = 1)) in (transpose t1,transpose t2) entry max [n][m] (dst: [m]i32) (is': [n]u64) (vs: [n]i32) = let is = map (\i -> i64.u64 (i%%u64.i64 m)) is' let (fwd1,fwd2) = max_fwd dst is vs let (rev1,rev2) = max_rev dst is vs let t1 = map2 (map2 (==)) fwd1 rev1 |> map (reduce (&&) true) |> reduce (&&) true let t2 = map2 (map2 (==)) fwd2 rev2 |> map (reduce (&&) true) |> reduce (&&) true in t1 && t2 def vecadd_p [n][m][k] (is: [n]i64) (dst: [m][k]i32, vs: [n][k]i32): [m][k]i32 = reduce_by_index (copy dst) (map2 
(+)) (replicate k 0) is vs def vecadd_rev [n][m][k] (dst: [m][k]i32) (is: [n]i64) (vs: [n][k]i32) = tabulate m (\i -> vjp (vecadd_p is) (dst,vs) (replicate m (replicate k 0) with [i] = replicate k 1)) |> unzip def vecadd_fwd [n][m][k] (dst: [m][k]i32) (is: [n]i64) (vs: [n][k]i32) = let t1 = tabulate m (\i -> jvp (vecadd_p is) (dst,vs) (replicate m (replicate k 0) with [i] = replicate k 1,replicate n (replicate k 0))) let t2 = tabulate n (\i -> jvp (vecadd_p is) (dst,vs) (replicate m (replicate k 0),replicate n (replicate k 0) with [i] = replicate k 1)) in (transpose t1,transpose t2) entry vecadd [n][m][k] (dst: [m][k]i32) (is': [n]u64) (vs: [n][k]i32) = let is = map (\i -> i64.u64 (i%%u64.i64 m)) is' let (fwd1,fwd2) = vecadd_fwd dst is vs let (rev1,rev2) = vecadd_rev dst is vs let t1 = map2 (map2 (map2 (==))) fwd1 rev1 |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true let t2 = map2 (map2 (map2 (==))) fwd2 rev2 |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true in t1 && t2 def vecmul_p [n][m][k] (is: [n]i64) (dst: [m][k]i32, vs: [n][k]i32): [m][k]i32 = reduce_by_index (copy dst) (map2 (*)) (replicate k 1) is vs def vecmul_rev [n][m][k] (dst: [m][k]i32) (is: [n]i64) (vs: [n][k]i32) = tabulate m (\i -> vjp (vecmul_p is) (dst,vs) (replicate m (replicate k 0) with [i] = replicate k 1)) |> unzip def vecmul_fwd [n][m][k] (dst: [m][k]i32) (is: [n]i64) (vs: [n][k]i32) = let t1 = tabulate m (\i -> jvp (vecmul_p is) (dst,vs) (replicate m (replicate k 0) with [i] = replicate k 1,replicate n (replicate k 0))) let t2 = tabulate n (\i -> jvp (vecmul_p is) (dst,vs) (replicate m (replicate k 0),replicate n (replicate k 0) with [i] = replicate k 1)) in (transpose t1,transpose t2) entry vecmul [n][m][k] (dst': [m][k]i32) (is': [n]u64) (vs': [n][k]i32) = let is = map (\i -> i64.u64 (i%%u64.i64 m)) is' let dst = map (map (\x -> (x%%2) + 1)) dst' let vs = map (map (\x -> (x%%2) + 1)) vs' let (fwd1,fwd2) = vecmul_fwd dst is vs let 
(rev1,rev2) = vecmul_rev dst is vs let t1 = map2 (map2 (map2 (==))) fwd1 rev1 |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true let t2 = map2 (map2 (map2 (==))) fwd2 rev2 |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true in t1 && t2 def vecmax_p [n][m][k] (is: [n]i64) (dst: [m][k]i32, vs: [n][k]i32): [m][k]i32 = reduce_by_index (copy dst) (map2 i32.max) (replicate k i32.lowest) is vs def vecmax_rev [n][m][k] (dst: [m][k]i32) (is: [n]i64) (vs: [n][k]i32) = tabulate m (\i -> vjp (vecmax_p is) (dst,vs) (replicate m (replicate k 0) with [i] = replicate k 1)) |> unzip def vecmax_fwd [n][m][k] (dst: [m][k]i32) (is: [n]i64) (vs: [n][k]i32) = let t1 = tabulate m (\i -> jvp (vecmax_p is) (dst,vs) (replicate m (replicate k 0) with [i] = replicate k 1,replicate n (replicate k 0))) let t2 = tabulate n (\i -> jvp (vecmax_p is) (dst,vs) (replicate m (replicate k 0),replicate n (replicate k 0) with [i] = replicate k 1)) in (transpose t1,transpose t2) entry vecmax [n][m][k] (dst: [m][k]i32) (is': [n]u64) (vs: [n][k]i32) = let is = map (\i -> i64.u64 (i%%u64.i64 m)) is' let (fwd1,fwd2) = vecmax_fwd dst is vs let (rev1,rev2) = vecmax_rev dst is vs let t1 = map2 (map2 (map2 (==))) fwd1 rev1 |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true let t2 = map2 (map2 (map2 (==))) fwd2 rev2 |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true in t1 && t2 futhark-0.25.27/tests/ad/reducebyindexvecmin0.fut000066400000000000000000000030341475065116200217240ustar00rootroot00000000000000-- == -- entry: vecmin -- input { [5i64, 3i64, 2i64, 4i64, 3i64, 3i64, 4i64, 2i64, 2i64, 3i64] -- [[8i32, 5i32, -2i32, 4i32, 6i32], -- [12i32, 8i32, 7i32, 2i32, 6i32], -- [3i32, 9i32, -2i32, 11i32, 1i32], -- [7i32, 3i32, 12i32, 7i32, 10i32], -- [9i32, 12i32, 4i32, 1i32, 8i32], -- [7i32, -1i32, 11i32, 6i32, 10i32], -- [-2i32, 6i32, 7i32, 1i32, 12i32], -- [8i32, 0i32, 9i32, 6i32, 7i32], -- [7i32, 3i32, 6i32, 
7i32, 8i32], -- [1i32, 4i32, 2i32, 9i32, 9i32]] -- [[2i32, 2i32, 2i32, 2i32, 4i32], -- [2i32, 1i32, 3i32, 3i32, 1i32], -- [2i32, 5i32, 2i32, 4i32, 5i32], -- [1i32, 2i32, 1i32, 1i32, 4i32], -- [5i32, 1i32, 4i32, 4i32, 5i32], -- [4i32, 1i32, 4i32, 5i32, 3i32]] } -- output { [[4i32, 1i32, 4i32, 5i32, 3i32], -- [0i32, 0i32, 0i32, 0i32, 4i32], -- [2i32, 0i32, 2i32, 0i32, 5i32], -- [0i32, 1i32, 0i32, 0i32, 5i32], -- [0i32, 0i32, 0i32, 1i32, 0i32], -- [0i32, 2i32, 0i32, 0i32, 0i32], -- [5i32, 0i32, 4i32, 4i32, 0i32], -- [0i32, 5i32, 0i32, 4i32, 0i32], -- [0i32, 0i32, 0i32, 0i32, 0i32], -- [1i32, 0i32, 1i32, 0i32, 0i32]] } entry vecmin [n][d][bins] (is: [n]i64) (vs: [n][d]i32) (adj_out: [bins][d]i32) = vjp (hist (map2 i32.min) (replicate d i32.highest) bins is) vs adj_out futhark-0.25.27/tests/ad/reducebyindexvecmul0.fut000066400000000000000000000031251475065116200217370ustar00rootroot00000000000000-- == -- entry: vecmul -- input { [5i64, 3i64, 2i64, 4i64, 3i64, 3i64, 4i64, 2i64, 2i64, 3i64] -- [[8i32, 5i32, -2i32, 4i32, 6i32], -- [12i32, 8i32, 7i32, 2i32, 6i32], -- [3i32, 9i32, -2i32, 11i32, 1i32], -- [7i32, 3i32, 12i32, 7i32, 10i32], -- [9i32, 12i32, 4i32, 1i32, 8i32], -- [7i32, -1i32, 11i32, 6i32, 10i32], -- [-2i32, 6i32, 7i32, 1i32, 12i32], -- [8i32, 0i32, 9i32, 6i32, 7i32], -- [7i32, 3i32, 6i32, 7i32, 8i32], -- [1i32, 4i32, 2i32, 9i32, 9i32]] -- [[2i32, 2i32, 2i32, 2i32, 4i32], -- [2i32, 1i32, 3i32, 3i32, 1i32], -- [2i32, 5i32, 2i32, 4i32, 5i32], -- [1i32, 2i32, 1i32, 1i32, 4i32], -- [5i32, 1i32, 4i32, 4i32, 5i32], -- [4i32, 1i32, 4i32, 5i32, 3i32]] } -- output { [[4i32, 1i32, 4i32, 5i32, 3i32], -- [63i32, -96i32, 88i32, 54i32, 2880i32], -- [112i32, 0i32, 108i32, 168i32, 280i32], -- [-10i32, 6i32, 28i32, 4i32, 60i32], -- [84i32, -64i32, 154i32, 108i32, 2160i32], -- [108i32, 768i32, 56i32, 18i32, 1728i32], -- [35i32, 3i32, 48i32, 28i32, 50i32], -- [42i32, 135i32, -24i32, 308i32, 40i32], -- [48i32, 0i32, -36i32, 264i32, 35i32], -- [756i32, -192i32, 308i32, 12i32, 
1920i32]] } entry vecmul [n][d][bins] (is: [n]i64) (vs: [n][d]i32) (adj_out: [bins][d]i32) = vjp (hist (map2 (*)) (replicate d 1i32) bins is) vs adj_out futhark-0.25.27/tests/ad/reducemul0.fut000066400000000000000000000004771475065116200176650ustar00rootroot00000000000000-- == -- entry: rev fwd -- input { [0.0f32, 2.0f32, 0.0f32, 4.0f32] } output { [0.0f32, 0.0f32, 0.0f32, 0.0f32] } def red_mult [n] (xs: [n]f32) : f32 = reduce (*) 1 xs entry rev [n] (xs: [n]f32) = vjp red_mult (xs) 1 entry fwd [n] (xs: [n]f32) = tabulate n (\i -> jvp red_mult xs (replicate n 0 with [i] = 1)) futhark-0.25.27/tests/ad/reducemul1.fut000066400000000000000000000003561475065116200176620ustar00rootroot00000000000000-- == -- entry: rev -- input { [1f32, 0f32, 3f32, 4f32] 3.0f32 } output { [0f32, 36f32, 0f32, 0f32] 0f32 } def red_mult [n] (xs: [n]f32, c: f32) : f32 = reduce (*) 1 xs * c entry rev [n] (xs: [n]f32) (c: f32) = vjp red_mult (xs,c) 1futhark-0.25.27/tests/ad/reducemul2.fut000066400000000000000000000003551475065116200176620ustar00rootroot00000000000000-- == -- entry: rev -- input { [1f32, 0f32, 3f32, 0f32] 3.0f32 } output { [0f32, 0f32, 0f32, 0f32] 0f32 } def red_mult [n] (xs: [n]f32, c: f32) : f32 = reduce (*) 1 xs * c entry rev [n] (xs: [n]f32) (c: f32) = vjp red_mult (xs,c) 1futhark-0.25.27/tests/ad/reducemul3.fut000066400000000000000000000003621475065116200176610ustar00rootroot00000000000000-- == -- entry: rev -- input { [1f32, 2f32, 3f32, 4f32] 3.0f32 } output { [72f32, 36f32, 24f32, 18f32] 24f32 } def red_mult [n] (xs: [n]f32, c: f32) : f32 = reduce (*) 1 xs * c entry rev [n] (xs: [n]f32) (c: f32) = vjp red_mult (xs,c) 1futhark-0.25.27/tests/ad/reducemul4.fut000066400000000000000000000007131475065116200176620ustar00rootroot00000000000000-- == -- entry: fwd rev -- input { [1f32, 2f32, 3f32, 4f32] } output { [[48f32, 12f32, 8f32, 6f32], [48f32, 48f32, 16f32, 12f32], [72f32, 36f32, 48f32, 18f32], [96f32, 48f32, 32f32, 48f32]] } def fun [n] (as: [n]f32) = let x = reduce (*) 1 as 
in map (*x) as entry fwd [n] (as: [n]f32) = tabulate n (\i -> jvp fun as (replicate n 0 with [i] = 1)) |> transpose entry rev [n] (as: [n]f32) = tabulate n (\i -> vjp fun as (replicate n 0 with [i] = 1))futhark-0.25.27/tests/ad/reducevec0.fut000066400000000000000000000020511475065116200176330ustar00rootroot00000000000000-- == -- entry: rev fwd -- input { -- [[[0f32,1f32],[2f32,3f32]], -- [[5f32,1f32],[3f32,0f32]], -- [[0f32,1f32],[4f32,4f32]]] } -- output { [[[[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]], [[[0f32, 1f32], [0f32, 0f32]], [[0f32, 1f32], [0f32, 0f32]], [[0f32, 1f32], [0f32, 0f32]]]], [[[[0f32, 0f32], [12f32, 0f32]], [[0f32, 0f32], [8f32, 0f32]], [[0f32, 0f32], [6f32, 0f32]]], [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 12f32]], [[0f32, 0f32], [0f32, 0f32]]]]] } let f [n][m][k] (xs: [n][m][k]f32) : [m][k]f32 = reduce (map2 (map2 (*))) (replicate m (replicate k 1)) xs entry rev [n][m][k] (xs: [n][m][k]f32) : [m][k][n][m][k]f32 = tabulate_2d m k (\i j -> vjp f xs (replicate m (replicate k 0) with [i] = (replicate k 0 with [j] = 1))) entry fwd [n][m][k] (xs: [n][m][k]f32) : [m][k][n][m][k]f32 = tabulate_3d n m k (\i j l -> jvp f xs (replicate n (replicate m (replicate k 0)) with [i] = (replicate m (replicate k 0) with [j] = (replicate k 0 with [l] = 1)))) |> transpose |> map transposefuthark-0.25.27/tests/ad/reducevecmul0.fut000066400000000000000000000006301475065116200203520ustar00rootroot00000000000000-- == -- entry: rev -- input { [[0.0f32, 2.0f32, 0.0f32, 4.0f32], [4.0f32, 2.0f32, 0.0f32, 0.0f32]] } output { [[4.000000f32, 0.000000f32, 0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32, 0.000000f32, 0.000000f32]] } let red_mult [m][n] (xs: [m][n]f32) : [n]f32 = reduce (map2 (*)) (replicate n 1) xs entry rev [m][n] (xs: [m][n]f32) = vjp red_mult xs (replicate n 0 with [0] = 1) let main = 
revfuthark-0.25.27/tests/ad/reducevecmul1.fut000066400000000000000000000005601475065116200203550ustar00rootroot00000000000000-- == -- entry: rev -- input { [[1f32, 0f32, 3f32, 4f32], [2f32,2f32,2f32,2f32]] 3.0f32 } output { [[0f32, 0f32, 0f32, 6f32], [0f32, 0f32, 0f32,12f32]] 8f32 } def red_mult [m][n] (xs: [m][n]f32, c: f32) = reduce (map2 (*)) (replicate n 1) xs |> map (*c) entry rev [m][n] (xs: [m][n]f32) (c: f32) = vjp red_mult (xs,c) (replicate n 0 with [3] = 1) let main = revfuthark-0.25.27/tests/ad/reducevecmul2.fut000066400000000000000000000005601475065116200203560ustar00rootroot00000000000000-- == -- entry: rev -- input { [[1f32, 0f32, 3f32, 0f32], [1f32,2f32,3f32,4f32]] 3.0f32 } output { [[0f32, 0f32, 9f32, 0f32], [0f32, 0f32, 9f32, 0f32]] 9f32 } def red_mult [m][n] (xs: [m][n]f32, c: f32) = reduce (map2 (*)) (replicate n 1) xs |> map (*c) entry rev [m][n] (xs: [m][n]f32) (c: f32) = vjp red_mult (xs,c) (replicate n 0 with [2] = 1) let main = revfuthark-0.25.27/tests/ad/reducevecmul3.fut000066400000000000000000000005621475065116200203610ustar00rootroot00000000000000-- == -- entry: rev -- input { [[1f32, 2f32, 3f32, 4f32], [1f32, 1f32, 1f32, 1f32]] 3.0f32 } output { [[0f32, 3f32, 0f32, 0f32],[0f32, 6f32, 0f32, 0f32]] 2f32 } def red_mult [m][n] (xs: [m][n]f32, c: f32) = reduce (map2 (*)) (replicate n 1) xs |> map (*c) entry rev [m][n] (xs: [m][n]f32) (c: f32) = vjp red_mult (xs,c) (replicate n 0 with [1] = 1) let main = revfuthark-0.25.27/tests/ad/replicate0.fut000066400000000000000000000003441475065116200176410ustar00rootroot00000000000000-- == -- entry: f_jvp -- input { 3i64 2 } -- output { [1,1,1] } entry f_jvp n x : []i32 = jvp (replicate n) x 1 -- == -- entry: f_vjp -- input { 3i64 2i64 } -- output { 3i64 } entry f_vjp n x = vjp (replicate n) x (iota n) futhark-0.25.27/tests/ad/replicate1.fut000066400000000000000000000005701475065116200176430ustar00rootroot00000000000000-- Differentiating with respect to 'n' does not make much sense, but -- it should at 
least not crash. -- == -- entry: f_jvp -- input { 3i64 2i64 } -- output { [0i64,0i64,0i64] } entry f_jvp n x = jvp (\n' -> replicate n' x :> [n]i64) n 1 -- == -- entry: f_vjp -- input { 3i64 2i64 } -- output { 0i64 } entry f_vjp n x = vjp (\n' -> replicate n' x :> [n]i64) n (iota n) futhark-0.25.27/tests/ad/replicate2.fut000066400000000000000000000004541475065116200176450ustar00rootroot00000000000000-- == -- entry: f_jvp -- input { 2i64 3i64 2 } -- output { [[1,1,1],[1,1,1]] } entry f_jvp n m x : [][]i32 = jvp (replicate m >-> replicate n) x 1 -- == -- entry: f_vjp -- input { 2i64 3i64 2i64 } -- output { 6i64 } entry f_vjp n m x = vjp (replicate m >-> replicate n) x (replicate n (iota m)) futhark-0.25.27/tests/ad/reshape0.fut000066400000000000000000000004351475065116200173210ustar00rootroot00000000000000-- == -- entry: f_jvp -- input { 2i64 2i64 [1,2,3,4] } -- output { [[1,2],[3,4]] } entry f_jvp n m (xs: [n*m]i32) = jvp unflatten xs xs -- == -- entry: f_vjp -- input { 2i64 2i64 [1,2,3,4] } -- output { [1,2,3,4] } entry f_vjp n m (xs: [n*m]i32) = vjp unflatten xs (unflatten xs) futhark-0.25.27/tests/ad/rev_const.fut000066400000000000000000000002471475065116200176150ustar00rootroot00000000000000-- What happens if a result is constant? -- == -- input { 1f32 2f32 } output { 1f32 1f32 } def main (x: f32) (y: f32) = vjp (\(x',y') -> (x' + y', 0)) (x,y) (1, 0) futhark-0.25.27/tests/ad/rev_unused.fut000066400000000000000000000002461475065116200177710ustar00rootroot00000000000000-- What happens if not all the parameters are used? 
-- == -- input { 1f32 2f32 } output { 1f32 0f32 } def main (x: f32) (y: f32) = vjp (\(x',_) -> x' + 2) (x,y) 1 futhark-0.25.27/tests/ad/rotate0.fut000066400000000000000000000004061475065116200171660ustar00rootroot00000000000000-- == -- entry: f_jvp -- input { 1i64 [1,2,3,4] } -- output { [2,3,4,1] } entry f_jvp k (xs: []i32) = jvp (rotate k) xs xs -- == -- entry: f_vjp -- input { 1i64 [1,2,3,4] } -- output { [1,2,3,4] } entry f_vjp k (xs: []i32) = vjp (rotate k) xs (rotate k xs) futhark-0.25.27/tests/ad/scan0.fut000066400000000000000000000011751475065116200166200ustar00rootroot00000000000000-- Scan with multiplication. -- generic case -- == -- entry: fwd_J rev_J -- input { [1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32] } -- output { [[1.0f32, 0.0f32, 0.0f32, 0.0f32, 0.0f32], -- [2.0f32, 1.0f32, 0.0f32, 0.0f32, 0.0f32], -- [6.0f32, 3.0f32, 2.0f32, 0.0f32, 0.0f32], -- [24.0f32, 12.0f32, 8.0f32, 6.0f32, 0.0f32], -- [120.0f32, 60.0f32, 40.0f32, 30.0f32, 24.0f32]] -- } entry fwd_J [n] (a: [n]f32) = tabulate n (\i -> jvp (scan (*) 1) a (replicate n 0 with [i] = 1)) |> transpose entry rev_J [n] (a: [n]f32) = tabulate n (\i -> vjp (scan (*) 1) a (replicate n 0 with [i] = 1)) futhark-0.25.27/tests/ad/scan1.fut000066400000000000000000000011701475065116200166140ustar00rootroot00000000000000-- Scan with addition. 
-- addition special case -- == -- entry: fwd_J rev_J -- input { [1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32] } -- output { [[1.0f32, 0.0f32, 0.0f32, 0.0f32, 0.0f32], -- [1.0f32, 1.0f32, 0.0f32, 0.0f32, 0.0f32], -- [1.0f32, 1.0f32, 1.0f32, 0.0f32, 0.0f32], -- [1.0f32, 1.0f32, 1.0f32, 1.0f32, 0.0f32], -- [1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32]] -- } entry fwd_J [n] (a: [n]f32) = tabulate n (\i -> jvp (scan (+) 0) a (replicate n 0 with [i] = 1)) |> transpose entry rev_J [n] (a: [n]f32) = tabulate n (\i -> vjp (scan (+) 0) a (replicate n 0 with [i] = 1)) futhark-0.25.27/tests/ad/scan2.fut000066400000000000000000000023331475065116200166170ustar00rootroot00000000000000-- Scan with vectorised addition. -- special cases: vectorised and addition -- == -- entry: fwd_J rev_J -- input { [[1f32,1f32],[2f32,2f32],[3f32,3f32],[4f32,4f32],[5f32,5f32]] } -- output { [[[1.000000f32, 1.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32]], [[1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32]], [[1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32]], [[1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [0.000000f32, 0.000000f32]], [[1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32]]] } def primal [n][k] (a: [n][k]f32) = scan (map2 (+)) (replicate k 0) a entry fwd_J [n][k] (a: [n][k]f32) = tabulate n (\i -> jvp primal a (replicate n (replicate k 0) with [i] = replicate k 1)) |> transpose entry rev_J [n][k] (a: [n][k]f32) = tabulate n (\i -> vjp primal a (replicate n (replicate k 0) with [i] = replicate k 1)) 
futhark-0.25.27/tests/ad/scan3.fut000066400000000000000000000055551475065116200166310ustar00rootroot00000000000000-- Scan with 2x2 matrix multiplication. -- MatrixMul case -- == -- entry: fwd_J rev_J -- input { [[1f32,2f32,3f32,4f32], [4f32,3f32,2f32,1f32], [1f32,2f32,3f32,4f32], [4f32,3f32,2f32,1f32]] } -- output { -- [[[[1f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32]], -- [[0f32, 1f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32]], -- [[0f32, 0f32, 1f32, 0f32], [0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32]], -- [[0f32, 0f32, 0f32, 1f32], [0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32]]], -- [[[4f32, 2f32, 0f32, 0f32], [1f32, 0f32, 2f32, 0f32], [0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32]], -- [[3f32, 1f32, 0f32, 0f32], [0f32, 1f32, 0f32, 2f32], [0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32]], -- [[0f32, 0f32, 4f32, 2f32], [3f32, 0f32, 4f32, 0f32], [0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32]], -- [[0f32, 0f32, 3f32, 1f32], [0f32, 3f32, 0f32, 4f32], [0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32]]], -- [[[13f32, 5f32, 0f32, 0f32], [1f32, 3f32, 2f32, 6f32], [8f32, 0f32, 5f32, 0f32], [0f32, 0f32, 0f32, 0f32]], -- [[20f32, 8f32, 0f32, 0f32], [2f32, 4f32, 4f32, 8f32], [0f32, 8f32, 0f32, 5f32], [0f32, 0f32, 0f32, 0f32]], -- [[0f32, 0f32, 13f32, 5f32], [3f32, 9f32, 4f32, 12f32], [20f32, 0f32, 13f32, 0f32], [0f32, 0f32, 0f32, 0f32]], -- [[0f32, 0f32, 20f32, 8f32], [6f32, 12f32, 8f32, 16f32], [0f32, 20f32, 0f32, 13f32], [0f32, 0f32, 0f32, 0f32]]], -- [[[92f32, 36f32, 0f32, 0f32], [8f32, 20f32, 16f32, 40f32], [32f32, 16f32, 20f32, 10f32], [23f32, 0f32, 36f32, 0f32]], -- [[59f32, 23f32, 0f32, 0f32], [5f32, 13f32, 10f32, 26f32], [24f32, 8f32, 15f32, 5f32], [0f32, 23f32, 0f32, 36f32]], -- [[0f32, 0f32, 92f32, 36f32], [24f32, 60f32, 32f32, 80f32], [80f32, 40f32, 52f32, 26f32], [59f32, 0f32, 92f32, 
0f32]], -- [[0f32, 0f32, 59f32, 23f32], [15f32, 39f32, 20f32, 52f32], [60f32, 20f32, 39f32, 13f32], [0f32, 59f32, 0f32, 92f32]]]] -- } def mm2by2 (a1: f32, b1: f32, c1: f32, d1: f32) (a2: f32, b2: f32, c2: f32, d2: f32) = ( a1*a2 + b1*c2 , a1*b2 + b1*d2 , c1*a2 + d1*c2 , c1*b2 + d1*d2 ) def primal [n] (xs: [n](f32,f32,f32,f32)) = scan mm2by2 (1, 0, 0, 1) xs def fromarrs = map (\(x: [4]f32) -> (x[0],x[1],x[2],x[3])) def toarrs = map (\(a,b,c,d) -> [a,b,c,d]) def onehot_2d n m x y = tabulate_2d n m (\i j -> f32.bool((i,j) == (x,y))) entry fwd_J [n] (input: [n][4]f32) : [n][4][n][4]f32 = let input = fromarrs input in tabulate (n*4) (\i -> jvp primal input (fromarrs (onehot_2d n 4 (i/4) (i%4)))) |> map toarrs |> transpose |> map transpose |> map (map unflatten) entry rev_J [n] (input: [n][4]f32) : [n][4][n][4]f32 = let input = fromarrs input in tabulate (n*4) (\i -> vjp primal input (fromarrs (onehot_2d n 4 (i/4) (i%4)))) |> unflatten |> map (map toarrs) futhark-0.25.27/tests/ad/scan4.fut000066400000000000000000000021651475065116200166240ustar00rootroot00000000000000-- Scan with tuple operator. 
-- ZeroQuadrant case -- == -- entry: fwd_J rev_J -- input { [[1.0f32, 2.0f32, 3.0f32], [4.0f32, 3.0f32, 5.0f32], [3.0f32, 4.0f32, 2.0f32], [4.0f32, 2.0f32, 1.0f32]] } -- output { -- [[[1f32, 1f32, 1f32], [0f32, 0f32, 0f32], -- [0f32, 0f32, 0f32], [0f32, 0f32, 0f32]], -- [[1f32, 3f32, 0f32], [1f32, 2f32, 1f32], -- [0f32, 0f32, 0f32], [0f32, 0f32, 0f32]], -- [[1f32, 12f32, 0f32], [1f32, 8f32, 1f32], -- [1f32, 6f32, 0f32], [0f32, 0f32, 0f32]], -- [[1f32, 24f32, 0f32], [1f32, 16f32, 1f32], -- [1f32, 12f32, 0f32], [1f32, 24f32, 0f32]]] -- } def primal [n] (xs: [n](f32,f32,f32)) = scan (\(a1,b1,c1) (a2,b2,c2) -> (a1+a2, b1*b2, f32.max c1 c2)) (0,1,f32.lowest) xs def fromarrs = map (\x -> (x[0],x[1],x[2])) def toarrs = map (\(a,b,c) -> [a,b,c]) entry fwd_J [n] (input: [n][3]f32) = let input = fromarrs input in tabulate n (\i -> jvp primal input (replicate n (0,0,0) with [i] = (1,1,1))) |> map toarrs |> transpose entry rev_J [n] (input: [n][3]f32) = let input = fromarrs input in tabulate n (\i -> vjp primal input (replicate n (0,0,0) with [i] = (1,1,1))) |> map toarrs futhark-0.25.27/tests/ad/scan5.fut000066400000000000000000000016461475065116200166300ustar00rootroot00000000000000-- Scan with vectorised product. 
-- Vectorised special case + generic case -- == -- entry: fwd_J rev_J -- input { [[1f32,1f32],[2f32,2f32],[3f32,3f32],[4f32,4f32],[5f32,5f32]] } -- output { -- [[[1f32, 1f32], [0f32, 0f32], [0f32, 0f32], [0f32, 0f32], [0f32, 0f32]], -- [[2f32, 2f32], [1f32, 1f32], [0f32, 0f32], [0f32, 0f32], [0f32, 0f32]], -- [[6f32, 6f32], [3f32, 3f32], [2f32, 2f32], [0f32, 0f32], [0f32, 0f32]], -- [[24f32, 24f32], [12f32, 12f32], [8f32, 8f32], [6f32, 6f32], [0f32, 0f32]], -- [[120f32, 120f32], [60f32, 60f32], [40f32, 40f32], [30f32, 30f32], [24f32, 24f32]]] -- } def primal [n][k] (a: [n][k]f32) = scan (map2 (*)) (replicate k 1) a entry fwd_J [n][k] (a: [n][k]f32) = tabulate n (\i -> jvp primal a (replicate n (replicate k 0) with [i] = replicate k 1)) |> transpose entry rev_J [n][k] (a: [n][k]f32) = tabulate n (\i -> vjp primal a (replicate n (replicate k 0) with [i] = replicate k 1)) futhark-0.25.27/tests/ad/scan6.fut000066400000000000000000000144571475065116200166350ustar00rootroot00000000000000-- Scan with linear function composition. 
-- MatrixMul case -- == -- entry: fwd_J rev_J -- input { [[1f32, 2f32], [4f32, 3f32], [3f32, 4f32], [4f32, 2f32]] } -- output { -- [[[[1f32, 0f32], [0f32, 0f32], [0f32, 0f32], [0f32, 0f32]], -- [[0f32, 1f32], [0f32, 0f32], [0f32, 0f32], [0f32, 0f32]]], -- [[[3f32, 0f32], [1f32, 1f32], [0f32, 0f32], [0f32, 0f32]], -- [[0f32, 3f32], [0f32, 2f32], [0f32, 0f32], [0f32, 0f32]]], -- [[[12f32, 0f32], [4f32, 4f32], [1f32, 7f32], [0f32, 0f32]], -- [[0f32, 12f32], [0f32, 8f32], [0f32, 6f32], [0f32, 0f32]]], -- [[[24f32, 0f32], [8f32, 8f32], [2f32, 14f32], [1f32, 31f32]], -- [[0f32, 24f32], [0f32, 16f32], [0f32, 12f32], [0f32, 24f32]]]] -- } def primal [n] (xs: [n](f32,f32)) = scan (\(a1,b1) (a2,b2) -> (a2 + b2*a1, b1*b2)) (0,1) xs def fromarrs = map (\x -> (x[0],x[1])) def toarrs = map (\(a,b) -> [a,b]) def onehot_2d n m x y = tabulate_2d n m (\i j -> f32.bool((i,j) == (x,y))) entry fwd_J [n] (input: [n][2]f32) = let input = fromarrs input in tabulate (n*2) (\i -> jvp primal input (fromarrs (onehot_2d n 2 (i/2) (i%2)))) |> map toarrs |> transpose |> map transpose |> map (map unflatten) entry rev_J [n] (input: [n][2]f32) = let input = fromarrs input in tabulate (n*2) (\i -> vjp primal input (fromarrs (onehot_2d n 2 (i/2) (i%2)))) |> unflatten |> map (map toarrs) -- == -- entry: fwd_J2 rev_J2 -- no_oclgrind input { [[1f32,2f32,3f32,4f32,5f32,6f32],[6f32,5f32,4f32,3f32,2f32,1f32],[4f32,5f32,6f32,1f32,2f32,3f32],[3f32,2f32,1f32,6f32,5f32,4f32]] } -- output { [[[[1f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 1f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 1f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 1f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 
0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 1f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 1f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], [[[4f32, 3f32, 0f32, 0f32, 0f32, 0f32], [1f32, 0f32, 1f32, 2f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[2f32, 1f32, 0f32, 0f32, 0f32, 0f32], [0f32, 1f32, 0f32, 0f32, 1f32, 2f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 4f32, 0f32, 3f32, 0f32], [0f32, 0f32, 3f32, 5f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 4f32, 0f32, 3f32], [0f32, 0f32, 4f32, 6f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 2f32, 0f32, 1f32, 0f32], [0f32, 0f32, 0f32, 0f32, 3f32, 5f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 2f32, 0f32, 1f32], [0f32, 0f32, 0f32, 0f32, 4f32, 6f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], [[[26f32, 19f32, 0f32, 0f32, 0f32, 0f32], [6f32, 1f32, 6f32, 12f32, 1f32, 2f32], [1f32, 0f32, 16f32, 9f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[14f32, 9f32, 0f32, 0f32, 0f32, 0f32], [2f32, 3f32, 2f32, 4f32, 3f32, 6f32], [0f32, 1f32, 0f32, 0f32, 16f32, 9f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 26f32, 0f32, 19f32, 0f32], [0f32, 0f32, 18f32, 30f32, 3f32, 5f32], [0f32, 0f32, 27f32, 11f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 26f32, 0f32, 19f32], [0f32, 0f32, 24f32, 36f32, 4f32, 6f32], [0f32, 0f32, 34f32, 14f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 14f32, 0f32, 9f32, 0f32], 
[0f32, 0f32, 6f32, 10f32, 9f32, 15f32], [0f32, 0f32, 0f32, 0f32, 27f32, 11f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 14f32, 0f32, 9f32], [0f32, 0f32, 8f32, 12f32, 12f32, 18f32], [0f32, 0f32, 0f32, 0f32, 34f32, 14f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], [[[110f32, 73f32, 0f32, 0f32, 0f32, 0f32], [18f32, 19f32, 18f32, 36f32, 19f32, 38f32], [1f32, 6f32, 16f32, 9f32, 96f32, 54f32], [1f32, 0f32, 109f32, 64f32, 0f32, 0f32]], [[186f32, 131f32, 0f32, 0f32, 0f32, 0f32], [38f32, 17f32, 38f32, 76f32, 17f32, 34f32], [5f32, 4f32, 80f32, 45f32, 64f32, 36f32], [0f32, 1f32, 0f32, 0f32, 109f32, 64f32]], [[0f32, 0f32, 110f32, 0f32, 73f32, 0f32], [0f32, 0f32, 54f32, 90f32, 57f32, 95f32], [0f32, 0f32, 27f32, 11f32, 162f32, 66f32], [0f32, 0f32, 173f32, 87f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 110f32, 0f32, 73f32], [0f32, 0f32, 72f32, 108f32, 76f32, 114f32], [0f32, 0f32, 34f32, 14f32, 204f32, 84f32], [0f32, 0f32, 218f32, 110f32, 0f32, 0f32]], [[0f32, 0f32, 186f32, 0f32, 131f32, 0f32], [0f32, 0f32, 114f32, 190f32, 51f32, 85f32], [0f32, 0f32, 135f32, 55f32, 108f32, 44f32], [0f32, 0f32, 0f32, 0f32, 173f32, 87f32]], [[0f32, 0f32, 0f32, 186f32, 0f32, 131f32], [0f32, 0f32, 152f32, 228f32, 68f32, 102f32], [0f32, 0f32, 170f32, 70f32, 136f32, 56f32], [0f32, 0f32, 0f32, 0f32, 218f32, 110f32]]]] } def mm2by2 (a1, b1, c1, d1) (a2, b2, c2, d2) : (f32,f32,f32,f32) = ( a1*a2 + b1*c2 , a1*b2 + b1*d2 , c1*a2 + d1*c2 , c1*b2 + d1*d2 ) def mv2 (a, b, c, d) (e, f): (f32,f32) = ( a*e + b*f , c*e + d*f) def vv2 (a, b) (c, d): (f32,f32) = ( a+c , b+d) def lino2by2 (d1,c1) (d2,c2) : ((f32,f32), (f32,f32,f32,f32)) = (vv2 d2 (mv2 c2 d1),mm2by2 c2 c1) def primal2 [n] (as: [n]((f32,f32), (f32,f32,f32,f32))) = scan lino2by2 ((0,0),(1,0,0,1)) as def fromarrs2 = map (\x -> ((x[0],x[1]),(x[2],x[3],x[4],x[5]))) def toarrs2 = map (\((a,b),(c,d,e,f)) -> [a,b,c,d,e,f]) entry fwd_J2 [n] (input: [n][6]f32) : [n][6][n][6]f32 = let input = fromarrs2 input in tabulate (n*6) (\i -> jvp primal2 input 
(fromarrs2 (onehot_2d n 6 (i/6) (i%6)))) |> map toarrs2 |> transpose |> map transpose |> map (map unflatten) entry rev_J2 [n] (input: [n][6]f32) : [n][6][n][6]f32 = let input = fromarrs2 input in tabulate (n*6) (\i -> vjp primal2 input (fromarrs2 (onehot_2d n 6 (i/6) (i%6)))) |> unflatten |> map (map toarrs2) futhark-0.25.27/tests/ad/scan7.fut000066400000000000000000000173701475065116200166330ustar00rootroot00000000000000-- Scan with nested map -- vectorised special case, generic case -- == -- entry: fwd_J rev_J -- input { [[[1f32,2f32], [2f32,3f32]], [[4f32,5f32], [3f32,4f32]], -- [[3f32,4f32], [4f32,5f32]], [[4f32,5f32], [2f32,3f32]]] } -- output { --[[[[[[1f32, 0f32], [0f32, 0f32]], [[4f32, 0f32], [0f32, 0f32]], -- [[12f32, 0f32], [0f32, 0f32]], [[48f32, 0f32], [0f32, 0f32]]], -- [[[0f32, 1f32], [0f32, 0f32]], [[0f32, 5f32], [0f32, 0f32]], -- [[0f32, 20f32], [0f32, 0f32]], [[0f32, 100f32], [0f32, 0f32]]]], -- [[[[0f32, 0f32], [1f32, 0f32]], [[0f32, 0f32], [3f32, 0f32]], -- [[0f32, 0f32], [12f32, 0f32]], [[0f32, 0f32], [24f32, 0f32]]], -- [[[0f32, 0f32], [0f32, 1f32]], [[0f32, 0f32], [0f32, 4f32]], -- [[0f32, 0f32], [0f32, 20f32]], [[0f32, 0f32], [0f32, 60f32]]]]], -- [[[[[0f32, 0f32], [0f32, 0f32]], [[1f32, 0f32], [0f32, 0f32]], -- [[3f32, 0f32], [0f32, 0f32]], [[12f32, 0f32], [0f32, 0f32]]], -- [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 2f32], [0f32, 0f32]], -- [[0f32, 8f32], [0f32, 0f32]], [[0f32, 40f32], [0f32, 0f32]]]], -- [[[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [2f32, 0f32]], -- [[0f32, 0f32], [8f32, 0f32]], [[0f32, 0f32], [16f32, 0f32]]], -- [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 3f32]], -- [[0f32, 0f32], [0f32, 15f32]], [[0f32, 0f32], [0f32, 45f32]]]]], -- [[[[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], -- [[4f32, 0f32], [0f32, 0f32]], [[16f32, 0f32], [0f32, 0f32]]], -- [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], -- [[0f32, 10f32], [0f32, 0f32]], [[0f32, 50f32], [0f32, 0f32]]]], -- [[[[0f32, 0f32], [0f32, 
0f32]], [[0f32, 0f32], [0f32, 0f32]], -- [[0f32, 0f32], [6f32, 0f32]], [[0f32, 0f32], [12f32, 0f32]]], -- [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], -- [[0f32, 0f32], [0f32, 12f32]], [[0f32, 0f32], [0f32, 36f32]]]]], -- [[[[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], -- [[0f32, 0f32], [0f32, 0f32]], [[12f32, 0f32], [0f32, 0f32]]], -- [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], -- [[0f32, 0f32], [0f32, 0f32]], [[0f32, 40f32], [0f32, 0f32]]]], -- [[[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], -- [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [24f32, 0f32]]], -- [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], -- [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 60f32]]]]]] -- } def primal [n][m][k] (xs: [n][m][k]f32) = scan (map2 (map2 (*))) (replicate m (replicate k 1)) xs entry fwd_J [n][m][k] (input: [n][m][k]f32) = tabulate_3d n m k (\i j q -> jvp primal input (replicate n (replicate m (replicate k 0)) with [i] = (replicate m (replicate k 0) with [j] = (replicate k 0 with [q] = 1)))) entry rev_J [n][m][k] (input: [n][m][k]f32) = let res = tabulate_3d n m k (\i j q -> vjp primal input (replicate n (replicate m (replicate k 0)) with [i] = (replicate m (replicate k 0) with [j] = (replicate k 0 with [q] = 1)))) let a = res |> map (map transpose) |> map (map (map transpose)) |> map (map (map (map transpose))) let a2 = a |> map transpose |> map (map transpose) |> map (map (map transpose)) in a2 |> transpose |> map transpose |> (map (map transpose)) -- == -- entry: test -- input { [[[1f32,2f32], [2f32,3f32]], [[4f32,5f32], [3f32,4f32]], -- [[3f32,4f32], [4f32,5f32]], [[4f32,5f32], [2f32,3f32]]] -- [[[[[[1f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]], [[[0f32, 1f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]]], [[[[0f32, 0f32], [1f32, 0f32]], [[0f32, 0f32], [0f32, 
0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]], [[[0f32, 0f32], [0f32, 1f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]]]], [[[[[0f32, 0f32], [0f32, 0f32]], [[1f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]], [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 1f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]]], [[[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [1f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]], [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 1f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]]]], [[[[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[1f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]], [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 1f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]]], [[[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [1f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]]], [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 1f32]], [[0f32, 0f32], [0f32, 0f32]]]]], [[[[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[1f32, 0f32], [0f32, 0f32]]], [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 1f32], [0f32, 0f32]]]], [[[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [1f32, 0f32]]], [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 1f32]]]]]] -- } -- output { --[[[[[[1f32, 0f32], [0f32, 0f32]], [[4f32, 0f32], [0f32, 0f32]], -- [[12f32, 0f32], [0f32, 0f32]], [[48f32, 0f32], [0f32, 0f32]]], -- [[[0f32, 1f32], [0f32, 0f32]], [[0f32, 5f32], [0f32, 0f32]], -- [[0f32, 20f32], [0f32, 0f32]], [[0f32, 100f32], [0f32, 0f32]]]], -- [[[[0f32, 0f32], [1f32, 0f32]], [[0f32, 
0f32], [3f32, 0f32]], -- [[0f32, 0f32], [12f32, 0f32]], [[0f32, 0f32], [24f32, 0f32]]], -- [[[0f32, 0f32], [0f32, 1f32]], [[0f32, 0f32], [0f32, 4f32]], -- [[0f32, 0f32], [0f32, 20f32]], [[0f32, 0f32], [0f32, 60f32]]]]], -- [[[[[0f32, 0f32], [0f32, 0f32]], [[1f32, 0f32], [0f32, 0f32]], -- [[3f32, 0f32], [0f32, 0f32]], [[12f32, 0f32], [0f32, 0f32]]], -- [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 2f32], [0f32, 0f32]], -- [[0f32, 8f32], [0f32, 0f32]], [[0f32, 40f32], [0f32, 0f32]]]], -- [[[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [2f32, 0f32]], -- [[0f32, 0f32], [8f32, 0f32]], [[0f32, 0f32], [16f32, 0f32]]], -- [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 3f32]], -- [[0f32, 0f32], [0f32, 15f32]], [[0f32, 0f32], [0f32, 45f32]]]]], -- [[[[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], -- [[4f32, 0f32], [0f32, 0f32]], [[16f32, 0f32], [0f32, 0f32]]], -- [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], -- [[0f32, 10f32], [0f32, 0f32]], [[0f32, 50f32], [0f32, 0f32]]]], -- [[[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], -- [[0f32, 0f32], [6f32, 0f32]], [[0f32, 0f32], [12f32, 0f32]]], -- [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], -- [[0f32, 0f32], [0f32, 12f32]], [[0f32, 0f32], [0f32, 36f32]]]]], -- [[[[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], -- [[0f32, 0f32], [0f32, 0f32]], [[12f32, 0f32], [0f32, 0f32]]], -- [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], -- [[0f32, 0f32], [0f32, 0f32]], [[0f32, 40f32], [0f32, 0f32]]]], -- [[[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], -- [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [24f32, 0f32]]], -- [[[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 0f32]], -- [[0f32, 0f32], [0f32, 0f32]], [[0f32, 0f32], [0f32, 60f32]]]]]] -- } entry test [n][m][k] (input: [n][m][k]f32) bar = let res = map (map (map (vjp primal input))) bar let a = res |> map (map transpose) |> map (map (map transpose)) |> map (map (map (map transpose))) let a2 = 
a |> map transpose |> map (map transpose) |> map (map (map transpose)) in a2 |> transpose |> map transpose |> (map (map transpose))futhark-0.25.27/tests/ad/scan8.fut000066400000000000000000000245571475065116200166410ustar00rootroot00000000000000-- Scan with 3x3 matrix multiplication. -- == -- entry: fwd rev -- input { [[1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32,9f32], -- [9f32,8f32,7f32,6f32,5f32,4f32,3f32,2f32,1f32], -- [1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32,9f32], -- [9f32,8f32,7f32,6f32,5f32,4f32,3f32,2f32,1f32]] } -- output { [[[[1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], -- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 1f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], -- [[0f32, 0f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 1f32, 0f32, 0f32, -- 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], -- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, -- 0f32, 0f32, 0f32, 1f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 1f32, 0f32, 0f32, -- 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], -- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 
1f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 1f32, 0f32], -- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 1f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], -- [[[9f32, 6f32, 3f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [1f32, -- 0f32, 0f32, 2f32, 0f32, 0f32, 3f32, 0f32, 0f32], [0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32]], [[8f32, 5f32, 2f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32], [0f32, 1f32, 0f32, 0f32, 2f32, 0f32, 0f32, 3f32, -- 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], -- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[7f32, -- 4f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 1f32, -- 0f32, 0f32, 2f32, 0f32, 0f32, 3f32], [0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32]], [[0f32, 0f32, 0f32, 9f32, 6f32, 3f32, 0f32, 0f32, -- 0f32], [4f32, 0f32, 0f32, 5f32, 0f32, 0f32, 6f32, 0f32, 0f32], -- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, -- 0f32, 8f32, 5f32, 2f32, 0f32, 0f32, 0f32], [0f32, 4f32, 0f32, 0f32, -- 5f32, 0f32, 0f32, 6f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32]], [[0f32, 0f32, 0f32, 7f32, 4f32, 1f32, 0f32, 0f32, 0f32], -- [0f32, 0f32, 4f32, 0f32, 0f32, 5f32, 0f32, 0f32, 6f32], [0f32, -- 0f32, 0f32, 0f32, 
0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 9f32, 6f32, 3f32], [7f32, 0f32, 0f32, 8f32, 0f32, 0f32, -- 9f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], -- [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 8f32, 5f32, 2f32], [0f32, -- 7f32, 0f32, 0f32, 8f32, 0f32, 0f32, 9f32, 0f32], [0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 7f32, 4f32, 1f32], [0f32, 0f32, 7f32, 0f32, 0f32, 8f32, 0f32, 0f32, -- 9f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], -- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], [[[90f32, -- 54f32, 18f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [1f32, 4f32, -- 7f32, 2f32, 8f32, 14f32, 3f32, 12f32, 21f32], [30f32, 0f32, 0f32, -- 24f32, 0f32, 0f32, 18f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32]], [[114f32, 69f32, 24f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32], [2f32, 5f32, 8f32, 4f32, 10f32, 16f32, -- 6f32, 15f32, 24f32], [0f32, 30f32, 0f32, 0f32, 24f32, 0f32, 0f32, -- 18f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32]], [[138f32, 84f32, 30f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32], [3f32, 6f32, 9f32, 6f32, 12f32, 18f32, 9f32, 18f32, 27f32], -- [0f32, 0f32, 30f32, 0f32, 0f32, 24f32, 0f32, 0f32, 18f32], [0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, -- 0f32, 90f32, 54f32, 18f32, 0f32, 0f32, 0f32], [4f32, 16f32, 28f32, -- 5f32, 20f32, 35f32, 6f32, 24f32, 42f32], [84f32, 0f32, 0f32, 69f32, -- 0f32, 0f32, 54f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 114f32, 69f32, 24f32, -- 0f32, 0f32, 0f32], [8f32, 20f32, 32f32, 10f32, 25f32, 40f32, 12f32, -- 30f32, 48f32], [0f32, 84f32, 0f32, 0f32, 69f32, 0f32, 0f32, 54f32, -- 0f32], [0f32, 0f32, 
0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], -- [[0f32, 0f32, 0f32, 138f32, 84f32, 30f32, 0f32, 0f32, 0f32], -- [12f32, 24f32, 36f32, 15f32, 30f32, 45f32, 18f32, 36f32, 54f32], -- [0f32, 0f32, 84f32, 0f32, 0f32, 69f32, 0f32, 0f32, 54f32], [0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 90f32, 54f32, 18f32], [7f32, 28f32, 49f32, -- 8f32, 32f32, 56f32, 9f32, 36f32, 63f32], [138f32, 0f32, 0f32, -- 114f32, 0f32, 0f32, 90f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 114f32, 69f32, 24f32], [14f32, 35f32, 56f32, 16f32, 40f32, -- 64f32, 18f32, 45f32, 72f32], [0f32, 138f32, 0f32, 0f32, 114f32, -- 0f32, 0f32, 90f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 138f32, -- 84f32, 30f32], [21f32, 42f32, 63f32, 24f32, 48f32, 72f32, 27f32, -- 54f32, 81f32], [0f32, 0f32, 138f32, 0f32, 0f32, 114f32, 0f32, 0f32, -- 90f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], -- [[[1908f32, 1152f32, 396f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], -- [30f32, 84f32, 138f32, 60f32, 168f32, 276f32, 90f32, 252f32, -- 414f32], [270f32, 180f32, 90f32, 216f32, 144f32, 72f32, 162f32, -- 108f32, 54f32], [252f32, 0f32, 0f32, 324f32, 0f32, 0f32, 396f32, -- 0f32, 0f32]], [[1566f32, 945f32, 324f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32], [24f32, 69f32, 114f32, 48f32, 138f32, 228f32, 72f32, -- 207f32, 342f32], [240f32, 150f32, 60f32, 192f32, 120f32, 48f32, -- 144f32, 90f32, 36f32], [0f32, 252f32, 0f32, 0f32, 324f32, 0f32, -- 0f32, 396f32, 0f32]], [[1224f32, 738f32, 252f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32], [18f32, 54f32, 90f32, 36f32, 108f32, 180f32, -- 54f32, 162f32, 270f32], [210f32, 120f32, 30f32, 168f32, 96f32, -- 24f32, 126f32, 72f32, 18f32], [0f32, 0f32, 252f32, 0f32, 0f32, -- 324f32, 0f32, 0f32, 396f32]], [[0f32, 0f32, 0f32, 1908f32, 1152f32, -- 396f32, 0f32, 0f32, 0f32], [120f32, 336f32, 552f32, 
150f32, 420f32, -- 690f32, 180f32, 504f32, 828f32], [756f32, 504f32, 252f32, 621f32, -- 414f32, 207f32, 486f32, 324f32, 162f32], [738f32, 0f32, 0f32, -- 945f32, 0f32, 0f32, 1152f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, -- 1566f32, 945f32, 324f32, 0f32, 0f32, 0f32], [96f32, 276f32, 456f32, -- 120f32, 345f32, 570f32, 144f32, 414f32, 684f32], [672f32, 420f32, -- 168f32, 552f32, 345f32, 138f32, 432f32, 270f32, 108f32], [0f32, -- 738f32, 0f32, 0f32, 945f32, 0f32, 0f32, 1152f32, 0f32]], [[0f32, -- 0f32, 0f32, 1224f32, 738f32, 252f32, 0f32, 0f32, 0f32], [72f32, -- 216f32, 360f32, 90f32, 270f32, 450f32, 108f32, 324f32, 540f32], -- [588f32, 336f32, 84f32, 483f32, 276f32, 69f32, 378f32, 216f32, -- 54f32], [0f32, 0f32, 738f32, 0f32, 0f32, 945f32, 0f32, 0f32, -- 1152f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 1908f32, 1152f32, -- 396f32], [210f32, 588f32, 966f32, 240f32, 672f32, 1104f32, 270f32, -- 756f32, 1242f32], [1242f32, 828f32, 414f32, 1026f32, 684f32, -- 342f32, 810f32, 540f32, 270f32], [1224f32, 0f32, 0f32, 1566f32, -- 0f32, 0f32, 1908f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 1566f32, 945f32, 324f32], [168f32, 483f32, 798f32, 192f32, -- 552f32, 912f32, 216f32, 621f32, 1026f32], [1104f32, 690f32, 276f32, -- 912f32, 570f32, 228f32, 720f32, 450f32, 180f32], [0f32, 1224f32, -- 0f32, 0f32, 1566f32, 0f32, 0f32, 1908f32, 0f32]], [[0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 1224f32, 738f32, 252f32], [126f32, 378f32, -- 630f32, 144f32, 432f32, 720f32, 162f32, 486f32, 810f32], [966f32, -- 552f32, 138f32, 798f32, 456f32, 114f32, 630f32, 360f32, 90f32], -- [0f32, 0f32, 1224f32, 0f32, 0f32, 1566f32, 0f32, 0f32, 1908f32]]]] -- } def mm3by3 (a1: f32, b1: f32, c1: f32, d1: f32, e1: f32, f1: f32, g1: f32, h1: f32, i1: f32) (a2: f32, b2: f32, c2: f32, d2: f32, e2: f32, f2: f32, g2: f32, h2: f32, i2: f32) = ( a1*a2 + b1*d2 + c1*g2 , a1*b2 + b1*e2 + c1*h2 , a1*c2 + b1*f2 + c1*i2 , d1*a2 + e1*d2 + f1*g2 , d1*b2 + e1*e2 + f1*h2 , d1*c2 + e1*f2 + f1*i2 , g1*a2 + h1*d2 + i1*g2 , 
g1*b2 + h1*e2 + i1*h2 , g1*c2 + h1*f2 + i1*i2 ) def primal3 [n] (xs: [n](f32,f32,f32,f32,f32,f32,f32,f32,f32)) = scan mm3by3 (1,0,0, 0,1,0, 0,0,1) xs def fromarrs3 = map (\(x: [9]f32) -> (x[0],x[1],x[2],x[3],x[4],x[5],x[6],x[7],x[8])) def toarrs3 = map (\(a,b,c,d,e,f,g,h,i) -> [a,b,c,d,e,f,g,h,i]) def onehot_2d n m x y = tabulate_2d n m (\i j -> f32.bool((i,j) == (x,y))) entry fwd [n] (input: [n][9]f32) : [n][9][n][9]f32 = let input = fromarrs3 input in tabulate (n*9) (\i -> jvp primal3 input (fromarrs3 (onehot_2d n 9 (i/9) (i%9)))) |> map toarrs3 |> transpose |> map transpose |> map (map unflatten) entry rev [n] (input: [n][9]f32) : [n][9][n][9]f32 = let input = fromarrs3 input in tabulate (n*9) (\i -> vjp primal3 input (fromarrs3 (onehot_2d n 9 (i/9) (i%9)))) |> unflatten |> map (map toarrs3) futhark-0.25.27/tests/ad/scan9.fut000066400000000000000000000743261475065116200166410ustar00rootroot00000000000000-- Scan with 4x4 matrix multiplication. -- == -- entry: fwd rev -- compiled no_oclgrind input { -- [[1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32,9f32,10f32,11f32,12f32,13f32,14f32,15f32,16f32], -- [16f32,15f32,14f32,13f32,12f32,11f32,10f32,9f32,8f32,7f32,6f32,5f32,4f32,3f32,2f32,1f32], -- [1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32,9f32,10f32,11f32,12f32,13f32,14f32,15f32,16f32], -- [16f32,15f32,14f32,13f32,12f32,11f32,10f32,9f32,8f32,7f32,6f32,5f32,4f32,3f32,2f32,1f32]] -- } -- output { [[[[1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], -- [[0f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 
0f32, 0f32, 0f32, 0f32], [0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, -- 0f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, -- 0f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, -- 0f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, -- 0f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], -- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], -- [0f32, 0f32, 
0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], -- [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 
0f32, 0f32, -- 1f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 1f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 1f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 1f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 1f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], -- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 
0f32]]], [[[16f32, 12f32, 8f32, 4f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], -- [1f32, 0f32, 0f32, 0f32, 2f32, 0f32, 0f32, 0f32, 3f32, 0f32, 0f32, -- 0f32, 4f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32]], [[15f32, 11f32, 7f32, 3f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, -- 1f32, 0f32, 0f32, 0f32, 2f32, 0f32, 0f32, 0f32, 3f32, 0f32, 0f32, -- 0f32, 4f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32]], [[14f32, 10f32, 6f32, 2f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, -- 1f32, 0f32, 0f32, 0f32, 2f32, 0f32, 0f32, 0f32, 3f32, 0f32, 0f32, -- 0f32, 4f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32]], [[13f32, 9f32, 5f32, 1f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, -- 1f32, 0f32, 0f32, 0f32, 2f32, 0f32, 0f32, 0f32, 3f32, 0f32, 0f32, -- 0f32, 4f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32]], [[0f32, 0f32, 0f32, 0f32, 16f32, 12f32, 8f32, 4f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [5f32, 0f32, 0f32, 0f32, -- 6f32, 0f32, 0f32, 0f32, 7f32, 0f32, 0f32, 0f32, 8f32, 0f32, 0f32, -- 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 
0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], -- [[0f32, 0f32, 0f32, 0f32, 15f32, 11f32, 7f32, 3f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 5f32, 0f32, 0f32, 0f32, -- 6f32, 0f32, 0f32, 0f32, 7f32, 0f32, 0f32, 0f32, 8f32, 0f32, 0f32], -- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], -- [[0f32, 0f32, 0f32, 0f32, 14f32, 10f32, 6f32, 2f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 5f32, 0f32, 0f32, -- 0f32, 6f32, 0f32, 0f32, 0f32, 7f32, 0f32, 0f32, 0f32, 8f32, 0f32], -- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], -- [[0f32, 0f32, 0f32, 0f32, 13f32, 9f32, 5f32, 1f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 5f32, 0f32, -- 0f32, 0f32, 6f32, 0f32, 0f32, 0f32, 7f32, 0f32, 0f32, 0f32, 8f32], -- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], -- [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 16f32, 12f32, -- 8f32, 4f32, 0f32, 0f32, 0f32, 0f32], [9f32, 0f32, 0f32, 0f32, -- 10f32, 0f32, 0f32, 0f32, 11f32, 0f32, 0f32, 0f32, 12f32, 0f32, -- 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 15f32, -- 11f32, 7f32, 3f32, 0f32, 0f32, 0f32, 0f32], [0f32, 9f32, 0f32, -- 0f32, 0f32, 10f32, 0f32, 0f32, 0f32, 11f32, 0f32, 0f32, 0f32, -- 12f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 
0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 14f32, 10f32, 6f32, 2f32, 0f32, 0f32, 0f32, 0f32], [0f32, -- 0f32, 9f32, 0f32, 0f32, 0f32, 10f32, 0f32, 0f32, 0f32, 11f32, 0f32, -- 0f32, 0f32, 12f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 13f32, 9f32, 5f32, 1f32, 0f32, 0f32, 0f32, 0f32], -- [0f32, 0f32, 0f32, 9f32, 0f32, 0f32, 0f32, 10f32, 0f32, 0f32, 0f32, -- 11f32, 0f32, 0f32, 0f32, 12f32], [0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], -- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 16f32, 12f32, 8f32, -- 4f32], [13f32, 0f32, 0f32, 0f32, 14f32, 0f32, 0f32, 0f32, 15f32, -- 0f32, 0f32, 0f32, 16f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 15f32, 11f32, -- 7f32, 3f32], [0f32, 13f32, 0f32, 0f32, 0f32, 14f32, 0f32, 0f32, -- 0f32, 15f32, 0f32, 0f32, 0f32, 16f32, 0f32, 0f32], [0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 14f32, -- 10f32, 6f32, 2f32], [0f32, 0f32, 13f32, 0f32, 0f32, 0f32, 14f32, -- 0f32, 0f32, 
0f32, 15f32, 0f32, 0f32, 0f32, 16f32, 0f32], [0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 13f32, 9f32, 5f32, 1f32], [0f32, 0f32, 0f32, 13f32, 0f32, 0f32, -- 0f32, 14f32, 0f32, 0f32, 0f32, 15f32, 0f32, 0f32, 0f32, 16f32], -- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], -- [[[386f32, 274f32, 162f32, 50f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [1f32, 5f32, 9f32, -- 13f32, 2f32, 10f32, 18f32, 26f32, 3f32, 15f32, 27f32, 39f32, 4f32, -- 20f32, 36f32, 52f32], [80f32, 0f32, 0f32, 0f32, 70f32, 0f32, 0f32, -- 0f32, 60f32, 0f32, 0f32, 0f32, 50f32, 0f32, 0f32, 0f32], [0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32]], [[444f32, 316f32, 188f32, 60f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], -- [2f32, 6f32, 10f32, 14f32, 4f32, 12f32, 20f32, 28f32, 6f32, 18f32, -- 30f32, 42f32, 8f32, 24f32, 40f32, 56f32], [0f32, 80f32, 0f32, 0f32, -- 0f32, 70f32, 0f32, 0f32, 0f32, 60f32, 0f32, 0f32, 0f32, 50f32, -- 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[502f32, 358f32, -- 214f32, 70f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32], [3f32, 7f32, 11f32, 15f32, 6f32, 14f32, -- 22f32, 30f32, 9f32, 21f32, 33f32, 45f32, 12f32, 28f32, 44f32, -- 60f32], [0f32, 0f32, 80f32, 0f32, 0f32, 0f32, 70f32, 0f32, 0f32, -- 0f32, 60f32, 0f32, 0f32, 0f32, 50f32, 0f32], [0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32]], [[560f32, 400f32, 
240f32, 80f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [4f32, 8f32, -- 12f32, 16f32, 8f32, 16f32, 24f32, 32f32, 12f32, 24f32, 36f32, -- 48f32, 16f32, 32f32, 48f32, 64f32], [0f32, 0f32, 0f32, 80f32, 0f32, -- 0f32, 0f32, 70f32, 0f32, 0f32, 0f32, 60f32, 0f32, 0f32, 0f32, -- 50f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, -- 0f32, 386f32, 274f32, 162f32, 50f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32], [5f32, 25f32, 45f32, 65f32, 6f32, 30f32, 54f32, -- 78f32, 7f32, 35f32, 63f32, 91f32, 8f32, 40f32, 72f32, 104f32], -- [240f32, 0f32, 0f32, 0f32, 214f32, 0f32, 0f32, 0f32, 188f32, 0f32, -- 0f32, 0f32, 162f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32]], [[0f32, 0f32, 0f32, 0f32, 444f32, 316f32, 188f32, 60f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [10f32, 30f32, -- 50f32, 70f32, 12f32, 36f32, 60f32, 84f32, 14f32, 42f32, 70f32, -- 98f32, 16f32, 48f32, 80f32, 112f32], [0f32, 240f32, 0f32, 0f32, -- 0f32, 214f32, 0f32, 0f32, 0f32, 188f32, 0f32, 0f32, 0f32, 162f32, -- 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, -- 0f32, 502f32, 358f32, 214f32, 70f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32], [15f32, 35f32, 55f32, 75f32, 18f32, 42f32, -- 66f32, 90f32, 21f32, 49f32, 77f32, 105f32, 24f32, 56f32, 88f32, -- 120f32], [0f32, 0f32, 240f32, 0f32, 0f32, 0f32, 214f32, 0f32, 0f32, -- 0f32, 188f32, 0f32, 0f32, 0f32, 162f32, 0f32], [0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 560f32, 400f32, 240f32, -- 80f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [20f32, -- 40f32, 60f32, 80f32, 24f32, 48f32, 72f32, 96f32, 28f32, 56f32, -- 84f32, 112f32, 32f32, 64f32, 96f32, 128f32], [0f32, 0f32, 
0f32, -- 240f32, 0f32, 0f32, 0f32, 214f32, 0f32, 0f32, 0f32, 188f32, 0f32, -- 0f32, 0f32, 162f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 386f32, 274f32, 162f32, -- 50f32, 0f32, 0f32, 0f32, 0f32], [9f32, 45f32, 81f32, 117f32, 10f32, -- 50f32, 90f32, 130f32, 11f32, 55f32, 99f32, 143f32, 12f32, 60f32, -- 108f32, 156f32], [400f32, 0f32, 0f32, 0f32, 358f32, 0f32, 0f32, -- 0f32, 316f32, 0f32, 0f32, 0f32, 274f32, 0f32, 0f32, 0f32], [0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 444f32, 316f32, 188f32, 60f32, 0f32, 0f32, 0f32, 0f32], -- [18f32, 54f32, 90f32, 126f32, 20f32, 60f32, 100f32, 140f32, 22f32, -- 66f32, 110f32, 154f32, 24f32, 72f32, 120f32, 168f32], [0f32, -- 400f32, 0f32, 0f32, 0f32, 358f32, 0f32, 0f32, 0f32, 316f32, 0f32, -- 0f32, 0f32, 274f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], -- [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 502f32, 358f32, -- 214f32, 70f32, 0f32, 0f32, 0f32, 0f32], [27f32, 63f32, 99f32, -- 135f32, 30f32, 70f32, 110f32, 150f32, 33f32, 77f32, 121f32, 165f32, -- 36f32, 84f32, 132f32, 180f32], [0f32, 0f32, 400f32, 0f32, 0f32, -- 0f32, 358f32, 0f32, 0f32, 0f32, 316f32, 0f32, 0f32, 0f32, 274f32, -- 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 560f32, 400f32, 240f32, 80f32, 0f32, 0f32, -- 0f32, 0f32], [36f32, 72f32, 108f32, 144f32, 40f32, 80f32, 120f32, -- 160f32, 44f32, 88f32, 132f32, 176f32, 48f32, 96f32, 144f32, -- 192f32], [0f32, 0f32, 0f32, 400f32, 0f32, 0f32, 0f32, 358f32, 0f32, -- 0f32, 0f32, 316f32, 0f32, 0f32, 0f32, 274f32], [0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 
0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 386f32, 274f32, 162f32, 50f32], [13f32, -- 65f32, 117f32, 169f32, 14f32, 70f32, 126f32, 182f32, 15f32, 75f32, -- 135f32, 195f32, 16f32, 80f32, 144f32, 208f32], [560f32, 0f32, 0f32, -- 0f32, 502f32, 0f32, 0f32, 0f32, 444f32, 0f32, 0f32, 0f32, 386f32, -- 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 444f32, -- 316f32, 188f32, 60f32], [26f32, 78f32, 130f32, 182f32, 28f32, -- 84f32, 140f32, 196f32, 30f32, 90f32, 150f32, 210f32, 32f32, 96f32, -- 160f32, 224f32], [0f32, 560f32, 0f32, 0f32, 0f32, 502f32, 0f32, -- 0f32, 0f32, 444f32, 0f32, 0f32, 0f32, 386f32, 0f32, 0f32], [0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 502f32, 358f32, 214f32, 70f32], -- [39f32, 91f32, 143f32, 195f32, 42f32, 98f32, 154f32, 210f32, 45f32, -- 105f32, 165f32, 225f32, 48f32, 112f32, 176f32, 240f32], [0f32, -- 0f32, 560f32, 0f32, 0f32, 0f32, 502f32, 0f32, 0f32, 0f32, 444f32, -- 0f32, 0f32, 0f32, 386f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], -- [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 560f32, 400f32, 240f32, 80f32], [52f32, 104f32, 156f32, -- 208f32, 56f32, 112f32, 168f32, 224f32, 60f32, 120f32, 180f32, -- 240f32, 64f32, 128f32, 192f32, 256f32], [0f32, 0f32, 0f32, 560f32, -- 0f32, 0f32, 0f32, 502f32, 0f32, 0f32, 0f32, 444f32, 0f32, 0f32, -- 0f32, 386f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], [[[17760f32, -- 12640f32, 7520f32, 2400f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [80f32, 240f32, 400f32, -- 560f32, 160f32, 480f32, 
800f32, 1120f32, 240f32, 720f32, 1200f32, -- 1680f32, 320f32, 960f32, 1600f32, 2240f32], [1280f32, 960f32, -- 640f32, 320f32, 1120f32, 840f32, 560f32, 280f32, 960f32, 720f32, -- 480f32, 240f32, 800f32, 600f32, 400f32, 200f32], [1620f32, 0f32, -- 0f32, 0f32, 1880f32, 0f32, 0f32, 0f32, 2140f32, 0f32, 0f32, 0f32, -- 2400f32, 0f32, 0f32, 0f32]], [[15868f32, 11292f32, 6716f32, -- 2140f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32], [70f32, 214f32, 358f32, 502f32, 140f32, 428f32, -- 716f32, 1004f32, 210f32, 642f32, 1074f32, 1506f32, 280f32, 856f32, -- 1432f32, 2008f32], [1200f32, 880f32, 560f32, 240f32, 1050f32, -- 770f32, 490f32, 210f32, 900f32, 660f32, 420f32, 180f32, 750f32, -- 550f32, 350f32, 150f32], [0f32, 1620f32, 0f32, 0f32, 0f32, 1880f32, -- 0f32, 0f32, 0f32, 2140f32, 0f32, 0f32, 0f32, 2400f32, 0f32, 0f32]], -- [[13976f32, 9944f32, 5912f32, 1880f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [60f32, 188f32, -- 316f32, 444f32, 120f32, 376f32, 632f32, 888f32, 180f32, 564f32, -- 948f32, 1332f32, 240f32, 752f32, 1264f32, 1776f32], [1120f32, -- 800f32, 480f32, 160f32, 980f32, 700f32, 420f32, 140f32, 840f32, -- 600f32, 360f32, 120f32, 700f32, 500f32, 300f32, 100f32], [0f32, -- 0f32, 1620f32, 0f32, 0f32, 0f32, 1880f32, 0f32, 0f32, 0f32, -- 2140f32, 0f32, 0f32, 0f32, 2400f32, 0f32]], [[12084f32, 8596f32, -- 5108f32, 1620f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32], [50f32, 162f32, 274f32, 386f32, 100f32, -- 324f32, 548f32, 772f32, 150f32, 486f32, 822f32, 1158f32, 200f32, -- 648f32, 1096f32, 1544f32], [1040f32, 720f32, 400f32, 80f32, 910f32, -- 630f32, 350f32, 70f32, 780f32, 540f32, 300f32, 60f32, 650f32, -- 450f32, 250f32, 50f32], [0f32, 0f32, 0f32, 1620f32, 0f32, 0f32, -- 0f32, 1880f32, 0f32, 0f32, 0f32, 2140f32, 0f32, 0f32, 0f32, -- 2400f32]], [[0f32, 0f32, 0f32, 0f32, 17760f32, 12640f32, 7520f32, -- 2400f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [400f32, -- 
1200f32, 2000f32, 2800f32, 480f32, 1440f32, 2400f32, 3360f32, -- 560f32, 1680f32, 2800f32, 3920f32, 640f32, 1920f32, 3200f32, -- 4480f32], [3840f32, 2880f32, 1920f32, 960f32, 3424f32, 2568f32, -- 1712f32, 856f32, 3008f32, 2256f32, 1504f32, 752f32, 2592f32, -- 1944f32, 1296f32, 648f32], [5108f32, 0f32, 0f32, 0f32, 5912f32, -- 0f32, 0f32, 0f32, 6716f32, 0f32, 0f32, 0f32, 7520f32, 0f32, 0f32, -- 0f32]], [[0f32, 0f32, 0f32, 0f32, 15868f32, 11292f32, 6716f32, -- 2140f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [350f32, -- 1070f32, 1790f32, 2510f32, 420f32, 1284f32, 2148f32, 3012f32, -- 490f32, 1498f32, 2506f32, 3514f32, 560f32, 1712f32, 2864f32, -- 4016f32], [3600f32, 2640f32, 1680f32, 720f32, 3210f32, 2354f32, -- 1498f32, 642f32, 2820f32, 2068f32, 1316f32, 564f32, 2430f32, -- 1782f32, 1134f32, 486f32], [0f32, 5108f32, 0f32, 0f32, 0f32, -- 5912f32, 0f32, 0f32, 0f32, 6716f32, 0f32, 0f32, 0f32, 7520f32, -- 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 13976f32, 9944f32, 5912f32, -- 1880f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [300f32, -- 940f32, 1580f32, 2220f32, 360f32, 1128f32, 1896f32, 2664f32, -- 420f32, 1316f32, 2212f32, 3108f32, 480f32, 1504f32, 2528f32, -- 3552f32], [3360f32, 2400f32, 1440f32, 480f32, 2996f32, 2140f32, -- 1284f32, 428f32, 2632f32, 1880f32, 1128f32, 376f32, 2268f32, -- 1620f32, 972f32, 324f32], [0f32, 0f32, 5108f32, 0f32, 0f32, 0f32, -- 5912f32, 0f32, 0f32, 0f32, 6716f32, 0f32, 0f32, 0f32, 7520f32, -- 0f32]], [[0f32, 0f32, 0f32, 0f32, 12084f32, 8596f32, 5108f32, -- 1620f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [250f32, -- 810f32, 1370f32, 1930f32, 300f32, 972f32, 1644f32, 2316f32, 350f32, -- 1134f32, 1918f32, 2702f32, 400f32, 1296f32, 2192f32, 3088f32], -- [3120f32, 2160f32, 1200f32, 240f32, 2782f32, 1926f32, 1070f32, -- 214f32, 2444f32, 1692f32, 940f32, 188f32, 2106f32, 1458f32, 810f32, -- 162f32], [0f32, 0f32, 0f32, 5108f32, 0f32, 0f32, 0f32, 5912f32, -- 0f32, 0f32, 0f32, 6716f32, 0f32, 0f32, 0f32, 7520f32]], [[0f32, -- 
0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 17760f32, 12640f32, -- 7520f32, 2400f32, 0f32, 0f32, 0f32, 0f32], [720f32, 2160f32, -- 3600f32, 5040f32, 800f32, 2400f32, 4000f32, 5600f32, 880f32, -- 2640f32, 4400f32, 6160f32, 960f32, 2880f32, 4800f32, 6720f32], -- [6400f32, 4800f32, 3200f32, 1600f32, 5728f32, 4296f32, 2864f32, -- 1432f32, 5056f32, 3792f32, 2528f32, 1264f32, 4384f32, 3288f32, -- 2192f32, 1096f32], [8596f32, 0f32, 0f32, 0f32, 9944f32, 0f32, 0f32, -- 0f32, 11292f32, 0f32, 0f32, 0f32, 12640f32, 0f32, 0f32, 0f32]], -- [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 15868f32, -- 11292f32, 6716f32, 2140f32, 0f32, 0f32, 0f32, 0f32], [630f32, -- 1926f32, 3222f32, 4518f32, 700f32, 2140f32, 3580f32, 5020f32, -- 770f32, 2354f32, 3938f32, 5522f32, 840f32, 2568f32, 4296f32, -- 6024f32], [6000f32, 4400f32, 2800f32, 1200f32, 5370f32, 3938f32, -- 2506f32, 1074f32, 4740f32, 3476f32, 2212f32, 948f32, 4110f32, -- 3014f32, 1918f32, 822f32], [0f32, 8596f32, 0f32, 0f32, 0f32, -- 9944f32, 0f32, 0f32, 0f32, 11292f32, 0f32, 0f32, 0f32, 12640f32, -- 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 13976f32, 9944f32, 5912f32, 1880f32, 0f32, 0f32, 0f32, 0f32], -- [540f32, 1692f32, 2844f32, 3996f32, 600f32, 1880f32, 3160f32, -- 4440f32, 660f32, 2068f32, 3476f32, 4884f32, 720f32, 2256f32, -- 3792f32, 5328f32], [5600f32, 4000f32, 2400f32, 800f32, 5012f32, -- 3580f32, 2148f32, 716f32, 4424f32, 3160f32, 1896f32, 632f32, -- 3836f32, 2740f32, 1644f32, 548f32], [0f32, 0f32, 8596f32, 0f32, -- 0f32, 0f32, 9944f32, 0f32, 0f32, 0f32, 11292f32, 0f32, 0f32, 0f32, -- 12640f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 12084f32, 8596f32, 5108f32, 1620f32, 0f32, 0f32, 0f32, 0f32], -- [450f32, 1458f32, 2466f32, 3474f32, 500f32, 1620f32, 2740f32, -- 3860f32, 550f32, 1782f32, 3014f32, 4246f32, 600f32, 1944f32, -- 3288f32, 4632f32], [5200f32, 3600f32, 2000f32, 400f32, 4654f32, -- 3222f32, 1790f32, 358f32, 4108f32, 2844f32, 1580f32, 316f32, -- 3562f32, 2466f32, 
1370f32, 274f32], [0f32, 0f32, 0f32, 8596f32, -- 0f32, 0f32, 0f32, 9944f32, 0f32, 0f32, 0f32, 11292f32, 0f32, 0f32, -- 0f32, 12640f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 17760f32, 12640f32, 7520f32, 2400f32], -- [1040f32, 3120f32, 5200f32, 7280f32, 1120f32, 3360f32, 5600f32, -- 7840f32, 1200f32, 3600f32, 6000f32, 8400f32, 1280f32, 3840f32, -- 6400f32, 8960f32], [8960f32, 6720f32, 4480f32, 2240f32, 8032f32, -- 6024f32, 4016f32, 2008f32, 7104f32, 5328f32, 3552f32, 1776f32, -- 6176f32, 4632f32, 3088f32, 1544f32], [12084f32, 0f32, 0f32, 0f32, -- 13976f32, 0f32, 0f32, 0f32, 15868f32, 0f32, 0f32, 0f32, 17760f32, -- 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 15868f32, 11292f32, 6716f32, -- 2140f32], [910f32, 2782f32, 4654f32, 6526f32, 980f32, 2996f32, -- 5012f32, 7028f32, 1050f32, 3210f32, 5370f32, 7530f32, 1120f32, -- 3424f32, 5728f32, 8032f32], [8400f32, 6160f32, 3920f32, 1680f32, -- 7530f32, 5522f32, 3514f32, 1506f32, 6660f32, 4884f32, 3108f32, -- 1332f32, 5790f32, 4246f32, 2702f32, 1158f32], [0f32, 12084f32, -- 0f32, 0f32, 0f32, 13976f32, 0f32, 0f32, 0f32, 15868f32, 0f32, 0f32, -- 0f32, 17760f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 13976f32, 9944f32, 5912f32, -- 1880f32], [780f32, 2444f32, 4108f32, 5772f32, 840f32, 2632f32, -- 4424f32, 6216f32, 900f32, 2820f32, 4740f32, 6660f32, 960f32, -- 3008f32, 5056f32, 7104f32], [7840f32, 5600f32, 3360f32, 1120f32, -- 7028f32, 5020f32, 3012f32, 1004f32, 6216f32, 4440f32, 2664f32, -- 888f32, 5404f32, 3860f32, 2316f32, 772f32], [0f32, 0f32, 12084f32, -- 0f32, 0f32, 0f32, 13976f32, 0f32, 0f32, 0f32, 15868f32, 0f32, 0f32, -- 0f32, 17760f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, -- 0f32, 0f32, 0f32, 0f32, 0f32, 12084f32, 8596f32, 5108f32, 1620f32], -- [650f32, 2106f32, 3562f32, 5018f32, 700f32, 2268f32, 3836f32, -- 5404f32, 750f32, 2430f32, 4110f32, 5790f32, 800f32, 
2592f32, -- 4384f32, 6176f32], [7280f32, 5040f32, 2800f32, 560f32, 6526f32, -- 4518f32, 2510f32, 502f32, 5772f32, 3996f32, 2220f32, 444f32, -- 5018f32, 3474f32, 1930f32, 386f32], [0f32, 0f32, 0f32, 12084f32, -- 0f32, 0f32, 0f32, 13976f32, 0f32, 0f32, 0f32, 15868f32, 0f32, 0f32, -- 0f32, 17760f32]]]] } def mm4by4 (a0,b0,c0,d0,e0,f0,g0,h0,i0,j0,k0,l0,m0,n0,o0,p0) (a1,b1,c1,d1,e1,f1,g1,h1,i1,j1,k1,l1,m1,n1,o1,p1:f32) = ( a0*a1 + b0*e1 + c0*i1 + d0*m1 , a0*b1 + b0*f1 + c0*j1 + d0*n1 , a0*c1 + b0*g1 + c0*k1 + d0*o1 , a0*d1 + b0*h1 + c0*l1 + d0*p1 , e0*a1 + f0*e1 + g0*i1 + h0*m1 , e0*b1 + f0*f1 + g0*j1 + h0*n1 , e0*c1 + f0*g1 + g0*k1 + h0*o1 , e0*d1 + f0*h1 + g0*l1 + h0*p1 , i0*a1 + j0*e1 + k0*i1 + l0*m1 , i0*b1 + j0*f1 + k0*j1 + l0*n1 , i0*c1 + j0*g1 + k0*k1 + l0*o1 , i0*d1 + j0*h1 + k0*l1 + l0*p1 , m0*a1 + n0*e1 + o0*i1 + p0*m1 , m0*b1 + n0*f1 + o0*j1 + p0*n1 , m0*c1 + n0*g1 + o0*k1 + p0*o1 , m0*d1 + n0*h1 + o0*l1 + p0*p1 ) def primal2 [n] (xs: [n](f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32)) = scan mm4by4 (1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1) xs def fromarrs2 = map (\(x: [16]f32) -> (x[0],x[1],x[2],x[3],x[4],x[5],x[6],x[7],x[8],x[9],x[10],x[11],x[12],x[13],x[14],x[15])) def toarrs2 = map (\(a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p) -> [a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p]) def onehot_2d n m x y = tabulate_2d n m (\i j -> f32.bool((i,j) == (x,y))) entry fwd [n] (input: [n][16]f32) : [n][16][n][16]f32 = let input = fromarrs2 input in tabulate (n*16) (\i -> jvp primal2 input (fromarrs2 (onehot_2d n 16 (i/16) (i%16)))) |> map toarrs2 |> transpose |> map transpose |> map (map unflatten) entry rev [n] (input: [n][16]f32) : [n][16][n][16]f32 = let input = fromarrs2 input in tabulate (n*16) (\i -> vjp primal2 input (fromarrs2 (onehot_2d n 16 (i/16) (i%16)))) |> unflatten |> map (map toarrs2) futhark-0.25.27/tests/ad/scangenbenchtests.fut000066400000000000000000000146521475065116200213210ustar00rootroot00000000000000-- == -- entry: testmm2by2 -- compiled random input { 
[250][4]i32 } output { true } -- == -- entry: testmm3by3 -- compiled random input { [111][9]i32 } output { true } -- == -- entry: testmm4by4 -- compiled random input { [62][16]i32 } output { true } -- == -- entry: testlin -- compiled random input { [500][2]i32 } output { true } -- == -- entry: testlin2by2 -- compiled random input { [166][6]i32 } output { true } def mm2by2 (a1: i32, b1: i32, c1: i32, d1: i32) (a2: i32, b2: i32, c2: i32, d2: i32) = ( a1*a2 + b1*c2 , a1*b2 + b1*d2 , c1*a2 + d1*c2 , c1*b2 + d1*d2 ) def primal2 [n] (xs: [n](i32,i32,i32,i32)) = scan mm2by2 (1, 0, 0, 1) xs def fromarrs2 = map (\(x: [4]i32) -> (x[0],x[1],x[2],x[3])) def toarrs2 = map (\(a,b,c,d) -> [a,b,c,d]) def onehot_2d n m x y = tabulate_2d n m (\i j -> i32.bool((i,j) == (x,y))) def fwd_J2 [n] (input: [n][4]i32) : [n][4][n][4]i32 = let input = fromarrs2 input in tabulate (n*4) (\i -> jvp primal2 input (fromarrs2 (onehot_2d n 4 (i/4) (i%4)))) |> map toarrs2 |> transpose |> map transpose |> map (map unflatten) def rev_J2 [n] (input: [n][4]i32) : [n][4][n][4]i32 = let input = fromarrs2 input in tabulate (n*4) (\i -> vjp primal2 input (fromarrs2 (onehot_2d n 4 (i/4) (i%4)))) |> unflatten |> map (map toarrs2) entry testmm2by2 [n] (input: [n][4]i32) = let fwd = fwd_J2 input let rev = rev_J2 input in map2 (map2 (map2 (==))) rev fwd |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true def mm3by3 (a1: i32, b1: i32, c1: i32, d1: i32, e1: i32, f1: i32, g1: i32, h1: i32, i1: i32) (a2: i32, b2: i32, c2: i32, d2: i32, e2: i32, f2: i32, g2: i32, h2: i32, i2: i32) = ( a1*a2 + b1*d2 + c1*g2 , a1*b2 + b1*e2 + c1*h2 , a1*c2 + b1*f2 + c1*i2 , d1*a2 + e1*d2 + f1*g2 , d1*b2 + e1*e2 + f1*h2 , d1*c2 + e1*f2 + f1*i2 , g1*a2 + h1*d2 + i1*g2 , g1*b2 + h1*e2 + i1*h2 , g1*c2 + h1*f2 + i1*i2 ) def primal3 [n] (xs: [n](i32,i32,i32,i32,i32,i32,i32,i32,i32)) = scan mm3by3 (1,0,0, 0,1,0, 0,0,1) xs def fromarrs3 = map (\(x: [9]i32) -> (x[0],x[1],x[2],x[3],x[4],x[5],x[6],x[7],x[8])) def toarrs3 = 
map (\(a,b,c,d,e,f,g,h,i) -> [a,b,c,d,e,f,g,h,i]) def fwd_J3 [n] (input: [n][9]i32) : [n][9][n][9]i32 = let input = fromarrs3 input in tabulate (n*9) (\i -> jvp primal3 input (fromarrs3 (onehot_2d n 9 (i/9) (i%9)))) |> map toarrs3 |> transpose |> map transpose |> map (map unflatten) def rev_J3 [n] (input: [n][9]i32) : [n][9][n][9]i32 = let input = fromarrs3 input in tabulate (n*9) (\i -> vjp primal3 input (fromarrs3 (onehot_2d n 9 (i/9) (i%9)))) |> unflatten |> map (map toarrs3) entry testmm3by3 [n] (input: [n][9]i32) = let fwd = fwd_J3 input let rev = rev_J3 input in map2 (map2 (map2 (==))) rev fwd |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true def mm4by4 (a0,b0,c0,d0,e0,f0,g0,h0,i0,j0,k0,l0,m0,n0,o0,p0) (a1,b1,c1,d1,e1,f1,g1,h1,i1,j1,k1,l1,m1,n1,o1,p1:i32) = ( a0*a1 + b0*e1 + c0*i1 + d0*m1 , a0*b1 + b0*f1 + c0*j1 + d0*n1 , a0*c1 + b0*g1 + c0*k1 + d0*o1 , a0*d1 + b0*h1 + c0*l1 + d0*p1 , e0*a1 + f0*e1 + g0*i1 + h0*m1 , e0*b1 + f0*f1 + g0*j1 + h0*n1 , e0*c1 + f0*g1 + g0*k1 + h0*o1 , e0*d1 + f0*h1 + g0*l1 + h0*p1 , i0*a1 + j0*e1 + k0*i1 + l0*m1 , i0*b1 + j0*f1 + k0*j1 + l0*n1 , i0*c1 + j0*g1 + k0*k1 + l0*o1 , i0*d1 + j0*h1 + k0*l1 + l0*p1 , m0*a1 + n0*e1 + o0*i1 + p0*m1 , m0*b1 + n0*f1 + o0*j1 + p0*n1 , m0*c1 + n0*g1 + o0*k1 + p0*o1 , m0*d1 + n0*h1 + o0*l1 + p0*p1 ) def primal4 [n] (xs: [n](i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32)) = scan mm4by4 (1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1) xs def fromarrs4 = map (\(x: [16]i32) -> (x[0],x[1],x[2],x[3],x[4],x[5],x[6],x[7],x[8],x[9],x[10],x[11],x[12],x[13],x[14],x[15])) def toarrs4 = map (\(a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p) -> [a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p]) def fwd_J4 [n] (input: [n][16]i32) : [n][16][n][16]i32 = let input = fromarrs4 input in tabulate (n*16) (\i -> jvp primal4 input (fromarrs4 (onehot_2d n 16 (i/16) (i%16)))) |> map toarrs4 |> transpose |> map transpose |> map (map unflatten) def rev_J4 [n] (input: [n][16]i32) : [n][16][n][16]i32 = let input = fromarrs4 
input in tabulate (n*16) (\i -> vjp primal4 input (fromarrs4 (onehot_2d n 16 (i/16) (i%16)))) |> unflatten |> map (map toarrs4) entry testmm4by4 [n] (input: [n][16]i32) = let fwd = fwd_J4 input let rev = rev_J4 input in map2 (map2 (map2 (==))) rev fwd |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true def primallin [n] (xs: [n](i32,i32)) = scan (\(a1,b1) (a2,b2) -> (a2 + b2*a1, b1*b2)) (0,1) xs def fromarrslin = map (\x -> (x[0],x[1])) def toarrslin = map (\(a,b) -> [a,b]) def fwd_Jlin [n] (input: [n][2]i32) = let input = fromarrslin input in tabulate (n*2) (\i -> jvp primallin input (fromarrslin (onehot_2d n 2 (i/2) (i%2)))) |> map toarrslin |> transpose |> map transpose |> map (map unflatten) def rev_Jlin [n] (input: [n][2]i32) = let input = fromarrslin input in tabulate (n*2) (\i -> vjp primallin input (fromarrslin (onehot_2d n 2 (i/2) (i%2)))) |> unflatten |> map (map toarrslin) entry testlin [n] (input: [n][2]i32) = let fwd = fwd_Jlin input let rev = rev_Jlin input in map2 (map2 (map2 (==))) rev fwd |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true def mv2 (a, b, c, d) (e, f): (i32,i32) = ( a*e + b*f , c*e + d*f) def vv2 (a, b) (c, d): (i32,i32) = ( a+c , b+d) def lino2by2 (d1,c1) (d2,c2) : ((i32,i32), (i32,i32,i32,i32)) = (vv2 d2 (mv2 c2 d1),mm2by2 c2 c1) def primallin2 [n] (as: [n]((i32,i32), (i32,i32,i32,i32))) = scan lino2by2 ((0,0),(1,0,0,1)) as def fromarrslin2 = map (\x -> ((x[0],x[1]),(x[2],x[3],x[4],x[5]))) def toarrslin2 = map (\((a,b),(c,d,e,f)) -> [a,b,c,d,e,f]) def fwd_Jlin2 [n] (input: [n][6]i32) : [n][6][n][6]i32 = let input = fromarrslin2 input in tabulate (n*6) (\i -> jvp primallin2 input (fromarrslin2 (onehot_2d n 6 (i/6) (i%6)))) |> map toarrslin2 |> transpose |> map transpose |> map (map unflatten) def rev_Jlin2 [n] (input: [n][6]i32) : [n][6][n][6]i32 = let input = fromarrslin2 input in tabulate (n*6) (\i -> vjp primallin2 input (fromarrslin2 (onehot_2d n 6 (i/6) (i%6)))) |> unflatten 
|> map (map toarrslin2) entry testlin2by2 [n] (input: [n][6]i32) = let fwd = fwd_Jlin2 input let rev = rev_Jlin2 input in map2 (map2 (map2 (==))) rev fwd |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true futhark-0.25.27/tests/ad/scanvecbenchtests.fut000066400000000000000000000042311475065116200213150ustar00rootroot00000000000000-- == -- entry: add mul -- compiled random input { [1000]i32 } output { true } -- == -- entry: vecadd vecmul -- compiled random input { [100][100]i32 } output { true } def add_primal [n] (as: [n]i32) = scan (+) 0 as def add_fwd [n] (as: [n]i32) = tabulate n (\i -> jvp add_primal as (replicate n 0 with [i] = 1)) |> transpose def add_rev [n] (as: [n]i32) = tabulate n (\i -> vjp add_primal as (replicate n 0 with [i] = 1)) entry add [n] (as: [n]i32) = let rev = add_rev as let fwd = add_fwd as in map2 (map2 (==)) rev fwd |> map (reduce (&&) true) |> reduce (&&) true def vecadd_primal [n][m] (as: [n][m]i32) = scan (map2 (+)) (replicate m 0) as def vecadd_fwd [n][m] (as: [n][m]i32) = tabulate n (\i -> jvp vecadd_primal as (replicate n (replicate m 0) with [i] = replicate m 1)) |> transpose def vecadd_rev [n][m] (as: [n][m]i32) = tabulate n (\i -> vjp vecadd_primal as (replicate n (replicate m 0) with [i] = replicate m 1)) entry vecadd [n][m] (as: [n][m]i32) = let rev = vecadd_rev as let fwd = vecadd_fwd as in map2 (map2 (map2 (==))) rev fwd |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true def mul_primal [n] (as: [n]i32) = scan (*) 1 as def mul_fwd [n] (as: [n]i32)= tabulate n (\i -> jvp mul_primal as (replicate n 0 with [i] = 1)) |> transpose def mul_rev [n] (as: [n]i32)= tabulate n (\i -> vjp mul_primal as (replicate n 0 with [i] = 1)) entry mul [n] (as': [n]i32) = let as = map (\a -> i32.abs a % 2) as' let rev = mul_rev as let fwd = mul_fwd as in map2 (map2 (==)) rev fwd |> map (reduce (&&) true) |> reduce (&&) true def vecmul_primal [n][m] (as: [n][m]i32) = scan (map2 (*)) (replicate m 1) 
as def vecmul_fwd [n][m] (as: [n][m]i32) = tabulate n (\i -> jvp vecmul_primal as (replicate n (replicate m 0) with [i] = replicate m 1)) |> transpose def vecmul_rev [n][m] (as: [n][m]i32) = tabulate n (\i -> vjp vecmul_primal as (replicate n (replicate m 0) with [i] = replicate m 1)) entry vecmul [n][m] (as: [n][m]i32) = let rev = vecmul_rev as let fwd = vecmul_fwd as in map2 (map2 (map2 (==))) rev fwd |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true futhark-0.25.27/tests/ad/scatter0.fut000066400000000000000000000014201475065116200173320ustar00rootroot00000000000000-- Simple scatter, differentiating wrt. values. -- == -- entry: fwd rev -- input { [0f64, 0f64, 0f64, 0f64] [0i64, 1i64, 2i64, 3i64] [1f64, 2f64, 3f64, 0f64] } -- output { -- [[1.000000f64, 0.000000f64, 0.000000f64, 0.000000f64], -- [0.000000f64, 1.000000f64, 0.000000f64, 0.000000f64], -- [0.000000f64, 0.000000f64, 1.000000f64, 0.000000f64], -- [0.000000f64, 0.000000f64, 0.000000f64, 1.000000f64]] -- } def f [n][k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) = scatter (copy xs) is vs entry fwd [n][k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) = let g i = jvp (\vs -> f xs is vs) vs (replicate n 0 with [i] = 1) in tabulate n g entry rev [n][k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) = let g i = vjp (\vs -> f xs is vs) vs (replicate k 0 with [i] = 1) in tabulate n g futhark-0.25.27/tests/ad/scatter1.fut000066400000000000000000000013701475065116200173370ustar00rootroot00000000000000-- Simple scatter, differentiating wrt. target. 
-- == -- entry: fwd rev -- input { [0f64, 0f64, 0f64, 0f64] [0i64, 1i64] [1f64, 2f64] } -- output { -- [[0.000000f64, 0.000000f64, 0.000000f64, 0.000000f64], -- [0.000000f64, 0.000000f64, 0.000000f64, 0.000000f64], -- [0.000000f64, 0.000000f64, 1.000000f64, 0.000000f64], -- [0.000000f64, 0.000000f64, 0.000000f64, 1.000000f64]] -- } def f [n][k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) = scatter (copy xs) is vs entry fwd [n][k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) = let g i = jvp (\xs -> f xs is vs) xs (replicate k 0 with [i] = 1) in tabulate k g entry rev [n][k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) = let g i = vjp (\xs -> f xs is vs) xs (replicate k 0 with [i] = 1) in tabulate k g futhark-0.25.27/tests/ad/sdf.fut000066400000000000000000000040401475065116200163620ustar00rootroot00000000000000-- Signed Distance Functions, as you would find in a ray marcher. -- == -- entry: jvp_normal vjp_normal -- input { 0i32 0f64 1f64 0f64 } output { 1f64 0f64 0f64 } -- input { 1i32 0f64 1f64 0f64 } output { 0.412393f64 0.907265f64 -0.082479f64 } -- input { 2i32 0f64 1f64 0f64 } output { -0.375775f64 0.903687f64 -0.205287f64 } type Vec = {x:f64, y: f64, z: f64} type Angle = f64 type Position = Vec type Radiance = Vec type Direction = Vec type Distance = f64 type Radius = f64 type BlockHalfWidths = Vec def vmap (f: f64 -> f64) (v : Vec) = {x = f v.x, y = f v.y, z = f v.z} def (a: Vec) <-> (b: Vec) = {x = a.x-b.x, y = a.y-b.y, z = a.z-b.z} def rotateY (theta: f64) ({x,y,z} : Vec) = let cos_theta = f64.cos theta let sin_theta = f64.sin theta in { x = f64.(cos_theta * x - sin_theta * z) , y , z = f64.(sin_theta * x + cos_theta * z)} def dot (a: Vec) (b: Vec) : f64 = a.x * b.x + a.y * b.y + a.z * b.z def norm v = f64.sqrt (dot v v) type Object = #Wall Direction Distance | #Block Position BlockHalfWidths Angle | #Sphere Position Radius def sdObject (obj:Object) (pos:Position) : Distance = match obj case #Wall nor d -> f64.(d + dot nor pos) case #Block blockPos halfWidths angle -> let 
pos' = rotateY angle (pos <-> blockPos) in norm (vmap (f64.max 0) (vmap f64.abs pos' <-> halfWidths)) case #Sphere spherePos r -> let pos' = pos <-> spherePos in f64.(max (norm pos' - r) 0) def vec x y z : Vec = {x,y,z} def unvec ({x,y,z}: Vec) = (x,y,z) def wall : Object = #Wall (vec 1 0 0) 2 def sphere : Object = #Sphere (vec (-1.0) (-1.2) 0.2) 0.8 def block : Object = #Block (vec 1.0 (-1.6) 1.2) (vec 0.6 0.8 0.6) (-0.5) def get_obj (i: i32) = match i case 0 -> wall case 1 -> sphere case _ -> block entry jvp_normal (obji: i32) x y z = let f i = jvp (sdObject (get_obj obji)) (vec x y z) (vec (f64.bool(i == 0)) (f64.bool(i == 1)) (f64.bool(i == 2))) in (f 0, f 1, f 2) entry vjp_normal (obji: i32) x y z = unvec (vjp (sdObject (get_obj obji)) (vec x y z) 1) futhark-0.25.27/tests/ad/stripmine0.fut000066400000000000000000000004511475065116200177020ustar00rootroot00000000000000def pow y x = #[stripmine(3)] loop acc = 1 for _i < y do acc * x -- == -- entry: f_jvp f_vjp -- input { 3 4 } output { 48 } -- input { 9 3 } output { 59049 } -- compiled input { 1000000 1 } output { 1000000 } entry f_jvp y x = jvp (pow y) x 1 entry f_vjp y x = vjp (pow y) x 1 futhark-0.25.27/tests/ad/stripmine1.fut000066400000000000000000000012401475065116200177000ustar00rootroot00000000000000def square [n] (xs: [n]i32) = let xs' = copy xs in #[stripmine(2)] loop xs'' = xs' for i < n do let a = xs''[i] in xs'' with [i] = a * a -- == -- entry: prim -- input { [1,2,3,4,5] } output { [1,4,9,16,25] } entry prim [n] (xs: [n]i32) = square xs -- == -- entry: f_jvp f_vjp -- input { [1,2,3,4,5] } -- output { [[2,0,0,0,0], -- [0,4,0,0,0], -- [0,0,6,0,0], -- [0,0,0,8,0], -- [0,0,0,0,10]] -- } entry f_jvp [n] (xs :[n]i32) = tabulate n (\i -> jvp square xs (replicate n 0 with [i] = 1)) |> transpose entry f_vjp [n] (xs :[n]i32) = tabulate n (\i -> vjp square xs (replicate n 0 with [i] = 1)) futhark-0.25.27/tests/ad/stripmine2.fut000066400000000000000000000010651475065116200177060ustar00rootroot00000000000000def 
pow_list [n] y (xs :[n]i32) = #[stripmine(2)] loop accs = (replicate n 1) for _i < y do map2 (*) accs xs -- == -- entry: prim -- input { 3 [1,2,3] } output { [1,8,27] } entry prim y xs = pow_list y xs -- == -- entry: f_vjp f_jvp -- input { 3 [1,2,3] } -- output { [[3,0,0], -- [0,12,0], -- [0,0,27]] -- } entry f_jvp [n] y (xs :[n]i32) = tabulate n (\i -> jvp (pow_list y) xs (replicate n 0 with [i] = 1)) |> transpose entry f_vjp [n] y (xs :[n]i32) = tabulate n (\i -> vjp (pow_list y) xs (replicate n 0 with [i] = 1)) futhark-0.25.27/tests/ad/stripmine3.fut000066400000000000000000000006701475065116200177100ustar00rootroot00000000000000def test [n] (xs: [n]i32) = let xs' = copy xs in #[stripmine(2)] loop xs'' = xs' for i < n do let foo = xs'' with [i] = 1 let m = map (\x -> x) foo in foo with [i] = m[i] -- == -- entry: prim -- input { [1,2,3,4,5] } output { [1,1,1,1,1] } entry prim [n] (xs: [n]i32) = test xs -- == -- entry: f_vjp -- input { [1,2,3,4,5] } output { [0,0,0,0,0] } entry f_vjp [n] (xs: [n]i32) = vjp test xs (replicate n 1) futhark-0.25.27/tests/ad/sum.fut000066400000000000000000000005151475065116200164150ustar00rootroot00000000000000-- Simple reduce with summation. 
-- == -- entry: rev fwd -- input { [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] } -- output { [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] } def sum [n] (xs: [n]f64) = reduce (+) 0 xs entry rev [n] (xs: [n]f64) = vjp sum xs 1 entry fwd [n] (xs: [n]f64) = tabulate n (\i -> jvp sum xs (tabulate n ((==i) >-> f64.bool))) futhark-0.25.27/tests/ad/truedep0.fut000066400000000000000000000011641475065116200173420ustar00rootroot00000000000000def test [n] (xs: [n]i32) = loop #[true_dep] xs' = copy xs for i < (n - 1) do xs' with [i+1] = xs'[i] * xs'[i] -- == -- entry: prim -- input { [2,2,3,4,5] } output { [2,4,16,256,65536] } entry prim [n] (xs: [n]i32) = test xs -- == -- entry: f_jvp f_vjp -- input { [1,2,3,4,5] } -- output { [[1,0,0,0,0], -- [2,0,0,0,0], -- [4,0,0,0,0], -- [8,0,0,0,0], -- [16,0,0,0,0]] -- } entry f_jvp [n] (xs :[n]i32) = tabulate n (\i -> jvp test xs (replicate n 0 with [i] = 1)) |> transpose entry f_vjp [n] (xs :[n]i32) = tabulate n (\i -> vjp test xs (replicate n 0 with [i] = 1)) futhark-0.25.27/tests/ad/truedep1.fut000066400000000000000000000006141475065116200173420ustar00rootroot00000000000000entry test [n][m] (xss: [n][m]f32) = loop #[true_dep] xss' = copy xss for i < n do loop #[true_dep] xss'' = copy xss' for j < m do xss'' with [i,j] = xss''[i-1,j-1] * xss''[i-1,j] * xss''[i,j-1] -- == -- entry: prim entry prim [n][m] (xss: [n][m]f32) = test xss -- == -- entry: f_vjp entry f_vjp [n][m] (xss: [n][m]f32) = vjp test xss ((replicate n (replicate m 0)) with [0,0] = 1) futhark-0.25.27/tests/ad/while0.fut000066400000000000000000000006161475065116200170030ustar00rootroot00000000000000def pow y x = let (_, res) = #[bound(1000)] loop (i, acc) = (0, 1) while i < y do (i + 1, acc * x) in res -- == -- entry: prim -- input { 3 4 } output { 64 } -- input { 9 3 } output { 19683 } entry prim y x = pow y x -- == -- entry: f_jvp f_vjp -- input { 3 4 } output { 48 } -- input { 9 3 } output { 59049 } entry f_jvp y x = jvp (pow y) x 1 entry f_vjp y x = vjp (pow y) x 1 
futhark-0.25.27/tests/ad/while1.fut000066400000000000000000000005731475065116200170060ustar00rootroot00000000000000def pow y x = let (_, res) = loop (i, acc) = (0, 1) while i < y do (i + 1, acc * x) in res -- == -- entry: prim -- input { 3 4 } output { 64 } -- input { 9 3 } output { 19683 } entry prim y x = pow y x -- == -- entry: f_jvp f_vjp -- input { 3 4 } output { 48 } -- input { 9 3 } output { 59049 } entry f_jvp y x = jvp (pow y) x 1 entry f_vjp y x = vjp (pow y) x 1 futhark-0.25.27/tests/allocs.fut000066400000000000000000000006731475065116200165070ustar00rootroot00000000000000-- Nasty program that tries to leak memory. If we can run this -- without leaking, then we're doing well. -- == -- compiled input { [0, 1000, 42, 1001, 50000] } -- output { 1300103225i64 } def main [n] (a: [n]i32): i64 = let b = loop b = iota(10) for i < n do (let m = i64.i32 a[i] in if m < length b then b else map (\j -> j + b[j % length b]) ( iota(m))) in reduce (+) 0 b futhark-0.25.27/tests/american_option.fut000066400000000000000000000031661475065116200204010ustar00rootroot00000000000000-- Port of Ken Friis Larsens pricer for American Put Options: -- -- https://github.com/kfl/american-options. -- -- This implementation is a straightforward sequential port - it is -- fairly slow on the GPU. 
-- -- == -- tags { no_python } -- compiled input { 1 } output { 6.745048f32 } -- compiled input { 8 } output { 13.943413f32 } -- compiled input { 16 } output { 16.218975f32 } -- compiled input { 30 } output { 17.648781f32 } -- constants def strike(): i32 = 100 def bankDays(): i32 = 252 def s0(): i32 = 100 def r(): f32 = 0.03 def alpha(): f32 = 0.07 def sigma(): f32 = 0.20 def binom(expiry: i32): f32 = let n = i64.i32 (expiry * bankDays()) let dt = f32.i32(expiry) / f32.i64(n) let u = f32.exp(alpha()*dt+sigma()*f32.sqrt(dt)) let d = f32.exp(alpha()*dt-sigma()*f32.sqrt(dt)) let stepR = f32.exp(r()*dt) let q = (stepR-d)/(u-d) let qUR = q/stepR let qDR = (1.0-q)/stepR let np1 = n+1 let uPow = map (u**) (map f32.i64 (iota np1)) let dPow = map (d**) (map f32.i64 (map (n-) (iota np1))) let st = map (f32.i32(s0())*) (map2 (*) uPow dPow) let finalPut = map (f32.max(0.0)) (map (f32.i32(strike())-) st) in let put = loop put = finalPut for i in reverse (map (1+) (iota n)) do let uPow_start = take i uPow let dPow_end = drop (n+1-i) dPow :> [i]f32 let st = map (f32.i32(s0())*) (map2 (*) uPow_start dPow_end) let put_tail = tail put :> [i]f32 let put_init = init put :> [i]f32 in map (\(x,y) -> f32.max x y) (zip (map (f32.i32(strike())-) st) (map2 (+) (map (qUR*) (put_tail)) (map (qDR*) (put_init)))) in put[0] def main(expiry: i32): f32 = binom(expiry) futhark-0.25.27/tests/apply-or-index.fut000066400000000000000000000003521475065116200200740ustar00rootroot00000000000000-- Test that we can distinguish function application with literal -- array argument from array indexing. -- == -- input { 1 } output { 3 } def f(xs: []i32): i32 = xs[0] def a: []i32 = [1,2,3] def main(x: i32): i32 = f [x] + a[x] futhark-0.25.27/tests/array14-running-example.fut000066400000000000000000000007111475065116200216150ustar00rootroot00000000000000-- Example program from the ARRAY'14 paper. 
-- == def main [k][m][n] (xs: [k]i64, as: [m][n]f64): [][]f64 = map (\(e: (i64, []f64)) -> #[unsafe] let (i, a) = e in let a = loop a = copy a for j < n do let a[j] = a[ xs[j] ] * 2.0 in a in map (\(j: i64): f64 -> if (j < 2*i) && (xs[j] == j) then a[j*i] else 0.0 ) (iota(n)) ) (zip (iota(m)) as ) futhark-0.25.27/tests/arrayInTupleArray.fut000066400000000000000000000002321475065116200206370ustar00rootroot00000000000000-- == -- input { -- } -- output { -- [1, 2, 3] -- true -- } def main: ([]i32,bool) = let arr = [ ([1,2,3],true) , ([4,5,6],false) ] in arr[0] futhark-0.25.27/tests/arraylit.fut000066400000000000000000000004231475065116200170520ustar00rootroot00000000000000-- Array literals should work even if their safety cannot be -- determined until runtime. -- -- == -- input { 2i64 2i64 } output { [[0i64,1i64], [3i64, 3i64]] } -- input { 2i64 3i64 } error: Error def main (n: i64) (m: i64): [][]i64 = [iota n, replicate m 3i64 :> [n]i64] futhark-0.25.27/tests/arraylit1.fut000066400000000000000000000003001475065116200171250ustar00rootroot00000000000000-- Array literals inside parallel section and with an in-place update. -- == -- input { 3 } output { [[1,0,0],[1,1,0],[1,2,0]] } def main(x: i32) = map (\y -> [1,0,0] with [1] = y) (0.. [k2p2][N]f32 futhark-0.25.27/tests/ascription0.fut000066400000000000000000000002721475065116200174600ustar00rootroot00000000000000-- Make sure type errors due to invalid type ascriptions are caught. -- -- == -- error: match def main(x: i32, y:i32): i32 = let (((a): i32), b: i32) : (bool,bool) = (x,y) in (a,b) futhark-0.25.27/tests/ascription1.fut000066400000000000000000000001511475065116200174550ustar00rootroot00000000000000-- Basic expression-level type ascription error. -- -- == -- error: f64.*i32 def main(x: i32) = x : f64 futhark-0.25.27/tests/ascription2.fut000066400000000000000000000003701475065116200174610ustar00rootroot00000000000000-- Array type ascription. 
-- -- == -- input { [[1,2],[3,4]] 2i64 2i64 } output { [[1,2],[3,4]] } -- input { [[1,2],[3,4]] 1i64 4i64 } error: cannot match shape of type.*`\[1\]\[4\] def main [n][m] (x: [n][m]i32) (a: i64) (b: i64) = x :> [a][b]i32 futhark-0.25.27/tests/ascription3.fut000066400000000000000000000002431475065116200174610ustar00rootroot00000000000000-- Array type ascription cannot change the rank of an array. -- -- == -- error: Expression does not have expected type def main [n][m] (x: [n][m]i32) = x : []i32 futhark-0.25.27/tests/ascription4.fut000066400000000000000000000001301475065116200174550ustar00rootroot00000000000000-- Ascription inside a lambda applies to the lambda body. def main = \x -> x + 2 : i32 futhark-0.25.27/tests/assert0.fut000066400000000000000000000002031475065116200166000ustar00rootroot00000000000000-- Basic assertion -- == -- input { 2 } output { 1 } -- input { 3 } error: x % 2 == 0 def main (x: i32) = assert (x%2 == 0) (x/2) futhark-0.25.27/tests/assert1.fut000066400000000000000000000002441475065116200166060ustar00rootroot00000000000000-- Assertion of functional value. -- == -- input { 2 } output { 1 } -- input { 3 } error: x % 2 == 0 def main (x: i32) = let f = assert (x%2 == 0) (x/) in f 2 futhark-0.25.27/tests/assert2.fut000066400000000000000000000001431475065116200166050ustar00rootroot00000000000000-- Assertion condition must be a boolean. -- == -- error: bool def main (x: i32) = assert 0 (x/2) futhark-0.25.27/tests/assert3.fut000066400000000000000000000002501475065116200166050ustar00rootroot00000000000000-- unsafe can remove assertions. 
-- == -- compiled input { 2 } output { 1 } -- compiled input { 3 } output { 1 } def main (x: i32) = #[unsafe] assert (x%2 == 0) (x/2) futhark-0.25.27/tests/assert4.fut000066400000000000000000000003101475065116200166030ustar00rootroot00000000000000-- unsafe can only remove assertions on its sub-exp -- == -- compiled input { 2 } output { 1 } -- compiled input { 3 } error: Assertion is false def main (x: i32) = assert (x%2 == 0) (#[unsafe] x/2) futhark-0.25.27/tests/attributes/000077500000000000000000000000001475065116200166725ustar00rootroot00000000000000futhark-0.25.27/tests/attributes/noinline0.fut000066400000000000000000000001271475065116200213050ustar00rootroot00000000000000-- == -- structure { Apply 1 } def f (x: i32) = x + 2 def main x = #[noinline] f x futhark-0.25.27/tests/attributes/noinline1.fut000066400000000000000000000001541475065116200213060ustar00rootroot00000000000000-- == -- structure { Apply 1 } def f (x: i64) = x + 2 def main x = map (\i -> #[noinline] f i) (iota x) futhark-0.25.27/tests/attributes/noinline2.fut000066400000000000000000000001271475065116200213070ustar00rootroot00000000000000-- == -- structure { Apply 1 } #[noinline] def f (x: i32) = x + 2 def main x = f x futhark-0.25.27/tests/attributes/num0.fut000066400000000000000000000001341475065116200202670ustar00rootroot00000000000000-- Numeric attributes. Not used for anything here. 
#[1] #[how_cool(1337)] def main = true futhark-0.25.27/tests/attributes/params0.fut000066400000000000000000000000361475065116200207540ustar00rootroot00000000000000def main (#[foo] x: bool) = x futhark-0.25.27/tests/attributes/params1.fut000066400000000000000000000001641475065116200207570ustar00rootroot00000000000000-- == -- input { 10 } output { 20 } def main (x: i32) = loop (#[maxval(1337)] acc : i32) = x for i < 10 do x + x futhark-0.25.27/tests/attributes/sequential0.fut000066400000000000000000000002051475065116200216410ustar00rootroot00000000000000-- == -- random input { [10][10]i32 } auto output -- structure gpu { SegMap 0 Loop 2 } def main xss = #[sequential] map i32.sum xss futhark-0.25.27/tests/attributes/sequential1.fut000066400000000000000000000002321475065116200216420ustar00rootroot00000000000000-- == -- random input { [10][10]i32 } auto output -- structure gpu { /SegMap 1 /SegMap/Loop 1 } def main xss = map (\xs -> #[sequential] i32.sum xs) xss futhark-0.25.27/tests/attributes/sequential2.fut000066400000000000000000000003001475065116200216370ustar00rootroot00000000000000-- Test that attributes don't disappear after fusion. 
-- == -- structure { Screma 1 } -- structure gpu { /SegMap 0 /Loop 1 } def main (xs: []i32) = (#[sequential] map (+1) xs, map (*2) xs) futhark-0.25.27/tests/attributes/sequential_inner0.fut000066400000000000000000000002241475065116200230350ustar00rootroot00000000000000-- == -- random input { [10][10]i32 } auto output -- structure gpu { /SegMap 1 /SegMap/Loop 1 } def main xss = #[sequential_inner] map i32.sum xss futhark-0.25.27/tests/attributes/sequential_inner1.fut000066400000000000000000000002421475065116200230360ustar00rootroot00000000000000-- == -- random input { [10][10][10]i32 } auto output -- structure gpu { /SegMap 1 /SegMap/Loop 1 } def main = map (\xss -> #[sequential_inner] map i32.sum xss) futhark-0.25.27/tests/attributes/sequential_outer0.fut000066400000000000000000000002221475065116200230560ustar00rootroot00000000000000-- == -- random input { [10][10]i32 } auto output -- structure gpu { /Loop 1 /Loop/SegRed 1 } def main xss = #[sequential_outer] map i32.sum xss futhark-0.25.27/tests/attributes/sequential_outer1.fut000066400000000000000000000004401475065116200230610ustar00rootroot00000000000000-- Slightly odd result due to interchange. 
-- == -- random input { [10][10][10]i32 } auto output -- structure gpu { -- /Loop 1 -- /Loop/SegRed 1 -- /Loop/SegMap 1 -- } def main xsss = #[incremental_flattening(only_inner)] map (\xss -> #[sequential_outer] map i32.sum xss) xsss futhark-0.25.27/tests/attributes/unroll0.fut000066400000000000000000000003221475065116200210020ustar00rootroot00000000000000-- == -- input { [1,2,4,5,1,2,3,4,1,2,4,1,2] } -- output { [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 4, 1, 2] } -- structure { Loop 0 } def main (xs: *[]i32) = #[unroll] loop xs for i < 10 do let xs[i] = i in xs futhark-0.25.27/tests/attributes/unroll1.fut000066400000000000000000000003651475065116200210120ustar00rootroot00000000000000-- == -- input { 10 [1,2,4,5,1,2,3,4,1,2,4,1,2] } -- output { [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 4, 1, 2] } -- structure { Loop 1 } -- warning: #\[unroll\] def main (n: i32) (xs: *[]i32) = #[unroll] loop xs for i < n do let xs[i] = i in xs futhark-0.25.27/tests/attributes/unroll2.fut000066400000000000000000000002471475065116200210120ustar00rootroot00000000000000-- == -- input { [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] } -- output { 47 } -- structure { Loop 0 } def main (xs: [10]i32) = #[unroll] loop sum = 2 for x in xs do sum + x futhark-0.25.27/tests/attributes/unroll3.fut000066400000000000000000000002341475065116200210070ustar00rootroot00000000000000-- == -- structure { Screma 2 } def main = map (\(x: i64) -> #[unroll] loop arr = replicate 10 x for i < 20 do map2 (+) arr (indices arr)) futhark-0.25.27/tests/attributes/unsafe0.fut000066400000000000000000000002041475065116200207470ustar00rootroot00000000000000-- Using unsafe we can avoid a bounds check. -- -- == -- structure { Assert 0 } def main(a: []i32, i: i32): i32 = #[unsafe] a[i] futhark-0.25.27/tests/attributes/unsafe1.fut000066400000000000000000000002331475065116200207520ustar00rootroot00000000000000-- Only one of the accesses should be unsafe. 
-- -- == -- structure { Assert 1 } def main(a: []i32, i: i32, j: i32): (i32,i32) = (#[unsafe] a[i], a[j]) futhark-0.25.27/tests/attributes/unsafe2.fut000066400000000000000000000003071475065116200207550ustar00rootroot00000000000000-- Using unsafe we can also avoid assertions due to shape checks. -- -- == -- structure { Assert 0 } def main [n][m] (a: [n]i32, b: [m]i32): ([n]i32,[n]i32) = unzip(#[unsafe] zip a (b :> [n]i32)) futhark-0.25.27/tests/attributes/warn_on_safety_checks0.fut000066400000000000000000000001431475065116200240260ustar00rootroot00000000000000-- == -- warning: Safety check required def main (xs: []f32) i = #[warn(safety_checks)] xs[i] futhark-0.25.27/tests/attributes/warn_on_safety_checks1.fut000066400000000000000000000001741475065116200240330ustar00rootroot00000000000000-- == -- warning: Safety check required def main (xs: []f32) (is: []i32) = #[warn(safety_checks)] map (\i -> xs[i]) is futhark-0.25.27/tests/attributes/warn_on_safety_checks2.fut000066400000000000000000000002241475065116200240300ustar00rootroot00000000000000-- == -- warning: Safety check required def f (xs: []f32) i = xs[i] def main (xs: []f32) (is: []i32) = #[warn(safety_checks)] map (f xs) is futhark-0.25.27/tests/babysitter/000077500000000000000000000000001475065116200166545ustar00rootroot00000000000000futhark-0.25.27/tests/babysitter/no-manifest-1.fut000066400000000000000000000034531475065116200217570ustar00rootroot00000000000000-- this is a simplified version of batch matrix inversion: it should not introduce a transposition for A[i,j] -- == -- structure gpu {Manifest 0} def gauss_jordan [nm] (n:i64) (m:i64) (A: *[nm]f32): [nm]f32 = loop A for i < n do -- the loop is outside the kernel, and hence `i` is a free -- variable in the kernel; hence fixing coalescing will likely -- do more work then the simple access: you will transpose an -- entire row to then read one element from it. This should not -- fire coalescing! 
let v1 = A[i] let A' = map (\ind -> let (k, j) = (ind / m, ind % m) in if v1 == 0.0 then A[k*m+j] else let x = (A[j] / v1) in if k < n-1 -- Ap case then ( A[(k+1)*m+j] - A[(k+1)*m+i] * x ) else x -- irow case ) (iota nm) in scatter A (iota nm) A' def mat_inv [n] (A: [n][n]f32): [n][n]f32 = let m = 2*n -- Pad the matrix with the identity matrix. let Ap = map (\ind -> let (i, j) = (ind / m, ind % m) -- the innermost index `j` is variant to -- the innermost kernel dimension `ind`; -- hence "likely" already in coalesced form! in if j < n then ( A[i,j] ) else if j == n+i then 1.0 else 0.0 ) (iota (n*m)) let Ap' = unflatten (gauss_jordan n m Ap) -- Drop the identity matrix at the front. in Ap'[0:n,n:n * 2] :> [n][n]f32 def main [m][n] (X : [m][n][n]f32) : [m][n][n]f32 = #[incremental_flattening(only_inner)] map mat_inv X futhark-0.25.27/tests/babysitter/no-manifest-2.fut000066400000000000000000000005121475065116200217510ustar00rootroot00000000000000-- this is a code snippet from bfast; it should transpose `y_error` -- == -- structure gpu {Manifest 0} def main [m][n] (nss: [m]i64) (hs: [m]i64) (y_errors: [m][n]f32) : [m]f32 = zip3 y_errors nss hs |> map (\(y_error, ns, h) -> map (\i -> y_error[i + ns-h+1]) (iota h) |> reduce (+) 0.0 ) futhark-0.25.27/tests/backtick.fut000066400000000000000000000001411475065116200167730ustar00rootroot00000000000000-- == -- input { 2 3 } output { 5 } def plus = (i32.+) def main (x: i32) (y: i32) = x `plus` y futhark-0.25.27/tests/bad-reduce.fut000066400000000000000000000033261475065116200172230ustar00rootroot00000000000000-- A particular wrong reduction, but it should not crash the compiler. 
type int = i32 def max (x:int, y:int) = i32.max x y def pred1 (k : int, x: int) : bool = if k == 1 then x == 0 -- zeroes else if k == 2 then true -- sorted else if k == 3 then true -- same else true -- default def pred2(k : int, x: int, y: int): bool = if k == 1 then (x == 0) && (y == 0) -- zeroes else if k == 2 then x <= y -- sorted else if k == 3 then x == y -- same else true -- default -- the task is to implement this operator by filling in the blanks def redOp (pind : int) (x: (int,int,int,int,int,int)) (y: (int,int,int,int,int,int)) : (int,int,int,int,int,int) = let (lssx, lisx, lcsx, tlx, firstx, lastx) = x let (lssy, lisy, lcsy, tly, firsty, lasty) = y let connect = false -- ... fill in the blanks (rewrite this line) ... should call pred2 let newlss = 0 -- ... fill in the blanks (rewrite this line) let newlis = 0 -- ... fill in the blanks (rewrite this line) let newlcs = 0 -- ... fill in the blanks (rewrite this line) let first = if tlx == 0 then firsty else firstx let last = if tly == 0 then lastx else lasty in (newlss, newlis, newlcs, tlx+tly, first, last) def mapOp (pind : int) (x: int) : (int,int,int,int,int,int) = let xmatch = if pred1(pind, x) then 1 else 0 in (xmatch, xmatch, xmatch, 1, x, x) def lssp (pind : int) (xs : []int) : int = let (x,_,_,_,_,_) = reduce (redOp pind) (0,0,0,0,0,0) (map (mapOp pind) xs) in x def main(xs: []int): int = lssp 2 xs -- computes sorted -- you may also try with zeroes, i.e., lssp 1 xs, and same, i.e., lssp 3 xs futhark-0.25.27/tests/bad_names.fut000066400000000000000000000003771475065116200171440ustar00rootroot00000000000000-- Tests that the code generator does not choke on terrible names that -- are not valid in C. 
-- -- == -- input { false 2 } -- output { 17 } def main(r': bool) (x_: i32): i32 = if r' then 123 else (x_ + 1) * 2 + 6 + (loop x=1 for i' < x_ do (x<<1)^i') futhark-0.25.27/tests/badentry0.fut000066400000000000000000000001201475065116200171050ustar00rootroot00000000000000-- == -- warning: Entry point parameter def main xs: ([]i32, []i32) = unzip xs futhark-0.25.27/tests/badentry1.fut000066400000000000000000000001301475065116200171070ustar00rootroot00000000000000-- == -- warning: Entry point return type def main (xs: []i32) (ys: []i32) = zip xs ys futhark-0.25.27/tests/badentry10.fut000066400000000000000000000002221475065116200171710ustar00rootroot00000000000000-- == -- input { 1i64 [1,2] } -- output { 1i64 } -- compiled input { 1i64 [1,2,3] } -- error: invalid size entry main (x: i64) (_: [x+1]i32) = x futhark-0.25.27/tests/badentry2.fut000066400000000000000000000001571475065116200171210ustar00rootroot00000000000000-- This should not warn, even though it is partially applied. -- == -- warning: ^$ def main = reduce (i32.+) 0 futhark-0.25.27/tests/badentry3.fut000066400000000000000000000002401475065116200171130ustar00rootroot00000000000000-- It is OK for part of a returned tuple to be opaque. -- == -- warning: ^$ type opaque [n] = [n](i32, i32) def main (x: i32): (opaque [], i32) = ([(x,x)],x) futhark-0.25.27/tests/badentry4.fut000066400000000000000000000001731475065116200171210ustar00rootroot00000000000000-- It is OK to have an array of opaques. -- == -- warning: ^$ type opaque = {x:i32} def main (x: i32): [1]opaque = [{x}] futhark-0.25.27/tests/badentry5.fut000066400000000000000000000002401475065116200171150ustar00rootroot00000000000000-- Fully abstract parameter type. -- == -- warning: Entry point parameter module m : { type t val x: t} = { type t = i32 def x = 0 } def main (x: m.t) = 0i32 futhark-0.25.27/tests/badentry6.fut000066400000000000000000000002201475065116200171140ustar00rootroot00000000000000-- Fully abstract return type. 
-- == -- warning: Entry point return module m : { type t val x: t} = { type t = i32 def x = 0 } def main = m.x futhark-0.25.27/tests/badentry7.fut000066400000000000000000000004271475065116200171260ustar00rootroot00000000000000-- Don't throw away module qualifiers when looking at opaque types for -- entry points. -- == -- warning: Entry point parameter module m0 = { type state = {f: i32} } module m1 = { type state = {f: [1]f32} } entry g (p0: m0.state) (p1: m1.state) = f32.i32 p0.f + p1.f[0] futhark-0.25.27/tests/badentry8.fut000066400000000000000000000002261475065116200171240ustar00rootroot00000000000000-- It is OK to put the type annotations in a higher-order return type. -- == -- warning: ^$ type t1 = {x:i32} type t2 = t1 def main : t1 -> t2 = id futhark-0.25.27/tests/badentry9.fut000066400000000000000000000001711475065116200171240ustar00rootroot00000000000000-- Entry points must use all sizes constructively. -- == -- error: \[x\].*constructive entry main [x] (_: [x+1]i32) = x futhark-0.25.27/tests/big0.fut000066400000000000000000000007571475065116200160560ustar00rootroot00000000000000-- Testing big arrays. -- == -- tags { no_python no_pyopencl } -- no_python no_opencl no_cuda no_hip no_wasm no_ispc compiled input { 2i64 1100000000i64 1 1073741823 } output { -2i8 } -- no_python no_opencl no_cuda no_hip no_wasm no_ispc compiled input { 3i64 1073741824i64 2 1073741823 } output { -3i8 } -- structure gpu-mem { SegMap 1 } def main (n: i64) (m: i64) (i: i32) (j: i32) = -- The opaque is just to force manifestation. 
(opaque (tabulate_2d n m (\i j -> i8.i64 (i ^ j))))[i,j] futhark-0.25.27/tests/big1.fut000066400000000000000000000002341475065116200160450ustar00rootroot00000000000000-- main -- == -- tags { no_gtx780 } -- no_python no_ispc compiled random input {11264000000i64} auto output def main (n: i64): i64 = iota n |> reduce (+) 0 futhark-0.25.27/tests/big2.fut000066400000000000000000000004061475065116200160470ustar00rootroot00000000000000-- == -- tags { no_gtx780 } -- no_ispc no_python no_wasm compiled input {2147483748i64} output { 0i8 -5i8 86i8 } -- To avoid enormous output, we just sample the result. def main n = let res = scan (+) 0 (map i8.i64 (iota n)) in (res[0], res[n/2], res[n-1]) futhark-0.25.27/tests/binding-warn0.fut000066400000000000000000000002571475065116200176670ustar00rootroot00000000000000-- It is bad to give an argument with a binding that is used in size -- but it is accepted -- == -- warning: with binding def f [n] (ns: *[n]i64) = iota (let m = n+2 in m*m) futhark-0.25.27/tests/binding-warn1.fut000066400000000000000000000002261475065116200176640ustar00rootroot00000000000000-- Bad to bind in slices, but accepted -- == -- warning: with binding def main (n:i64) (xs:*[n]i64) = let t = iota n in t[:let m = n-4 in m*m/n] futhark-0.25.27/tests/bitwise.fut000066400000000000000000000015221475065116200166720ustar00rootroot00000000000000-- Bitwise operation stress test. -- -- Mostly to ensure that interpreter and code generator agree. -- Originally distilled from MD5 sum calculation. 
-- == -- input { -- 1732584193u32 -- -271733879u32 -- -1732584194u32 -- 271733878u32 -- 1 -- } -- output { -- 271733878u32 -- 757607282u32 -- -271733879u32 -- -1732584194u32 -- } def funF(x: u32, y: u32, z: u32): u32 = x & y | !x & z def rotateL (x: u32, i: u32): u32 = let post = x << i let pre = (x >> i) & (!(0xFFFFFFFF << i)) in post | pre def frob(a: u32, b: u32, c: u32, d: u32): (u32, u32, u32, u32) = let w = 0x97989910 let f' = funF(b,c,d) let a' = b + rotateL((a + f' + w + 0xd76aa478), 7) in (d, a', b, c) def main (a: u32) (b: u32) (c: u32) (d: u32) (n: i32): (u32, u32, u32, u32) = loop (a',b',c',d') = (a,b,c,d) for _i < n do frob(a',b',c',d') futhark-0.25.27/tests/blackscholes.fut000066400000000000000000000534731475065116200176750ustar00rootroot00000000000000-- -- -- input { -- 5 -- } -- output { -- [0.000000, 0.000000, 0.000006, 0.000077, 0.000386, 0.001183, 0.002707, 0.005143, -- 0.008608, 0.013162, 0.018818, 0.025560, 0.033350, 0.042137, 0.051866, 0.062475, -- 0.073906, 0.086099, 0.098998, 0.112550, 0.126705, 0.141416, 0.156641, 0.172339, -- 0.188473, 0.205010, 0.221917, 0.239166, 0.256731, 0.274587, 0.292711, 0.311083, -- 0.329684, 0.348495, 0.367501, 0.386687, 0.406039, 0.425543, 0.445188, 0.464964, -- 0.484859, 0.504864, 0.524971, 0.545171, 0.565457, 0.585822, 0.606259, 0.626761, -- 0.647324, 0.667942, 0.688610, 0.709324, 0.730078, 0.750870, 0.771694, 0.792548, -- 0.813429, 0.834333, 0.855257, 0.876199, 0.897156, 0.918125, 0.939105, 0.960093, -- 0.981087, 1.002086, 1.023088, 1.044090, 1.065092, 1.086092, 1.107089, 1.128081, -- 1.149067, 1.170047, 1.191018, 1.211981, 1.232933, 1.253875, 1.274805, 1.295722, -- 1.316626, 1.337517, 1.358392, 1.379253, 1.400098, 1.420927, 1.441739, 1.462533, -- 1.483310, 1.504069, 1.524810, 1.545531, 1.566233, 1.586916, 1.607579, 1.628222, -- 1.648844, 1.669446, 1.690027, 1.710587, 1.731126, 1.751643, 1.772139, 1.792613, -- 1.813065, 1.833496, 1.853904, 1.874289, 1.894653, 1.914994, 1.935313, 1.955609, -- 1.975882, 
1.996133, 2.016361, 2.036566, 2.056749, 2.076909, 2.097046, 2.117160, -- 2.137251, 2.157319, 2.177364, 2.197387, 2.217387, 2.237364, 2.257318, 2.277249, -- 2.297157, 2.317043, 2.336906, 2.356746, 2.376564, 2.396359, 2.416131, 2.435881, -- 2.455608, 2.475313, 2.494996, 2.514656, 2.534294, 2.553910, 2.573503, 2.593075, -- 2.612624, 2.632152, 2.651657, 2.671141, 2.690602, 2.710042, 2.729461, 2.748857, -- 2.768233, 2.787586, 2.806919, 2.826230, 2.845520, 2.864788, 2.884036, 2.903262, -- 2.922468, 2.941652, 2.960816, 2.979959, 2.999081, 3.018183, 3.037265, 3.056325, -- 3.075366, 3.094386, 3.113386, 3.132366, 3.151325, 3.170265, 3.189185, 3.208085, -- 3.226965, 3.245826, 3.264667, 3.283488, 3.302290, 3.321073, 3.339836, 3.358580, -- 3.377305, 3.396011, 3.414697, 3.433365, 3.452014, 3.470644, 3.489256, 3.507849, -- 3.526423, 3.544979, 3.563516, 3.582035, 3.600536, 3.619019, 3.637483, 3.655930, -- 3.674358, 3.692769, 3.711162, 3.729537, 3.747894, 3.766233, 3.784556, 3.802860, -- 3.821147, 3.839417, 3.857670, 3.875905, 3.894123, 3.912325, 3.930509, 3.948676, -- 3.966826, 3.984960, 4.003077, 4.021177, 4.039260, 4.057327, 4.075378, 4.093412, -- 4.111430, 4.129431, 4.147416, 4.165386, 4.183339, 4.201276, 4.219197, 4.237102, -- 4.254991, 4.272864, 4.290722, 4.308564, 4.326391, 4.344202, 4.361997, 4.379777, -- 4.397542, 4.415291, 4.433025, 4.450744, 4.468448, 4.486137, 4.503811, 4.521469, -- 4.539113, 4.556742, 4.574357, 4.591956, 4.609541, 4.627111, 4.644667, 4.662208, -- 4.679735, 4.697247, 4.714745, 4.732228, 4.749698, 4.767153, 4.784594, 4.802021, -- 4.819433, 4.836832, 4.854217, 4.871588, 4.888946, 4.906289, 4.923619, 4.940935, -- 4.958237, 4.975526, 4.992801, 5.010063, 5.027311, 5.044546, 5.061767, 5.078975, -- 5.096170, 5.113352, 5.130521, 5.147676, 5.164818, 5.181948, 5.199064, 5.216168, -- 5.233258, 5.250336, 5.267401, 5.284453, 5.301492, 5.318519, 5.335533, 5.352535, -- 5.369524, 5.386500, 5.403464, 5.420416, 5.437355, 5.454282, 5.471197, 5.488099, -- 5.504989, 
5.521867, 5.538733, 5.555587, 5.572428, 5.589258, 5.606076, 5.622881, -- 5.639675, 5.656457, 5.673227, 5.689986, 5.706732, 5.723467, 5.740190, 5.756902, -- 5.773602, 5.790290, 5.806967, 5.823633, 5.840287, 5.856929, 5.873560, 5.890180, -- 5.906789, 5.923386, 5.939972, 5.956547, 5.973111, 5.989663, 6.006205, 6.022735, -- 6.039254, 6.055763, 6.072260, 6.088747, 6.105223, 6.121687, 6.138141, 6.154585, -- 6.171017, 6.187439, 6.203850, 6.220250, 6.236640, 6.253019, 6.269387, 6.285746, -- 6.302093, 6.318430, 6.334757, 6.351073, 6.367379, 6.383674, 6.399960, 6.416235, -- 6.432499, 6.448754, 6.464998, 6.481232, 6.497456, 6.513670, 6.529874, 6.546067, -- 6.562251, 6.578425, 6.594589, 6.610742, 6.626886, 6.643020, 6.659145, 6.675259, -- 6.691363, 6.707458, 6.723543, 6.739619, 6.755684, 6.771740, 6.787787, 6.803824, -- 6.819851, 6.835869, 6.851877, 6.867875, 6.883865, 6.899844, 6.915815, 6.931776, -- 6.947727, 6.963669, 6.979602, 6.995526, 7.011440, 7.027345, 7.043241, 7.059128, -- 7.075005, 7.090874, 7.106733, 7.122583, 7.138424, 7.154256, 7.170079, 7.185893, -- 7.201698, 7.217494, 7.233281, 7.249060, 7.264829, 7.280589, 7.296341, 7.312084, -- 7.327818, 7.343543, 7.359260, 7.374968, 7.390667, 7.406357, 7.422039, 7.437712, -- 7.453376, 7.469032, 7.484680, 7.500318, 7.515949, 7.531570, 7.547184, 7.562789, -- 7.578385, 7.593973, 7.609552, 7.625124, 7.640686, 7.656241, 7.671787, 7.687325, -- 7.702854, 7.718376, 7.733889, 7.749393, 7.764890, 7.780378, 7.795859, 7.811331, -- 7.826795, 7.842251, 7.857699, 7.873138, 7.888570, 7.903994, 7.919410, 7.934817, -- 7.950217, 7.965609, 7.980993, 7.996369, 8.011737, 8.027097, 8.042449, 8.057794, -- 8.073130, 8.088459, 8.103780, 8.119094, 8.134399, 8.149697, 8.164987, 8.180270, -- 8.195544, 8.210811, 8.226071, 8.241323, 8.256567, 8.271803, 8.287032, 8.302254, -- 8.317468, 8.332674, 8.347873, 8.363064, 8.378248, 8.393425, 8.408594, 8.423755, -- 8.438910, 8.454056, 8.469196, 8.484328, 8.499453, 8.514570, 8.529680, 8.544783, -- 8.559878, 
8.574966, 8.590047, 8.605121, 8.620188, 8.635247, 8.650299, 8.665344, -- 8.680382, 8.695412, 8.710436, 8.725452, 8.740462, 8.755464, 8.770459, 8.785447, -- 8.800428, 8.815402, 8.830369, 8.845329, 8.860283, 8.875229, 8.890168, 8.905100, -- 8.920025, 8.934944, 8.949855, 8.964760, 8.979658, 8.994548, 9.009432, 9.024310, -- 9.039180, 9.054044, 9.068901, 9.083751, 9.098594, 9.113431, 9.128260, 9.143084, -- 9.157900, 9.172710, 9.187513, 9.202309, 9.217099, 9.231882, 9.246659, 9.261428, -- 9.276192, 9.290948, 9.305699, 9.320442, 9.335179, 9.349910, 9.364634, 9.379351, -- 9.394062, 9.408767, 9.423465, 9.438156, 9.452841, 9.467520, 9.482192, 9.496858, -- 9.511518, 9.526171, 9.540817, 9.555458, 9.570092, 9.584719, 9.599341, 9.613956, -- 9.628565, 9.643167, 9.657763, 9.672353, 9.686937, 9.701514, 9.716085, 9.730650, -- 9.745209, 9.759761, 9.774308, 9.788848, 9.803382, 9.817910, 9.832432, 9.846947, -- 9.861457, 9.875960, 9.890457, 9.904949, 9.919434, 9.933913, 9.948386, 9.962853, -- 9.977314, 9.991769, 10.006218, 10.020661, 10.035097, 10.049528, 10.063953, -- 10.078372, 10.092786, 10.107193, 10.121594, 10.135989, 10.150379, 10.164762, -- 10.179140, 10.193512, 10.207877, 10.222238, 10.236592, 10.250940, 10.265283, -- 10.279619, 10.293950, 10.308275, 10.322595, 10.336908, 10.351216, 10.365518, -- 10.379815, 10.394105, 10.408390, 10.422669, 10.436943, 10.451210, 10.465473, -- 10.479729, 10.493980, 10.508225, 10.522464, 10.536698, 10.550926, 10.565149, -- 10.579366, 10.593577, 10.607783, 10.621983, 10.636177, 10.650366, 10.664550, -- 10.678728, 10.692900, 10.707067, 10.721228, 10.735384, 10.749534, 10.763679, -- 10.777819, 10.791952, 10.806081, 10.820204, 10.834321, 10.848433, 10.862540, -- 10.876641, 10.890737, 10.904827, 10.918912, 10.932992, 10.947066, 10.961135, -- 10.975198, 10.989257, 11.003309, 11.017357, 11.031399, 11.045436, 11.059467, -- 11.073493, 11.087514, 11.101530, 11.115540, 11.129545, 11.143545, 11.157539, -- 11.171529, 11.185513, 11.199492, 11.213465, 11.227434, 
11.241397, 11.255355, -- 11.269307, 11.283255, 11.297197, 11.311135, 11.325067, 11.338994, 11.352915, -- 11.366832, 11.380744, 11.394650, 11.408551, 11.422447, 11.436338, 11.450224, -- 11.464105, 11.477981, 11.491852, 11.505718, 11.519578, 11.533434, 11.547284, -- 11.561130, 11.574970, 11.588806, 11.602636, 11.616462, 11.630282, 11.644098, -- 11.657908, 11.671714, 11.685514, 11.699310, 11.713101, 11.726886, 11.740667, -- 11.754443, 11.768214, 11.781980, 11.795741, 11.809497, 11.823248, 11.836995, -- 11.850736, 11.864473, 11.878205, 11.891932, 11.905654, 11.919371, 11.933083, -- 11.946791, 11.960493, 11.974191, 11.987884, 12.001573, 12.015256, 12.028935, -- 12.042609, 12.056278, 12.069942, 12.083602, 12.097256, 12.110907, 12.124552, -- 12.138192, 12.151828, 12.165459, 12.179086, 12.192707, 12.206324, 12.219936, -- 12.233544, 12.247147, 12.260745, 12.274338, 12.287927, 12.301511, 12.315091, -- 12.328666, 12.342236, 12.355801, 12.369362, 12.382919, 12.396470, 12.410017, -- 12.423560, 12.437098, 12.450631, 12.464159, 12.477684, 12.491203, 12.504718, -- 12.518228, 12.531734, 12.545235, 12.558732, 12.572224, 12.585712, 12.599195, -- 12.612673, 12.626147, 12.639617, 12.653082, 12.666542, 12.679998, 12.693450, -- 12.706897, 12.720339, 12.733777, 12.747211, 12.760640, 12.774064, 12.787485, -- 12.800900, 12.814312, 12.827719, 12.841121, 12.854519, 12.867913, 12.881302, -- 12.894687, 12.908067, 12.921443, 12.934815, 12.948182, 12.961545, 12.974903, -- 12.988257, 13.001607, 13.014952, 13.028293, 13.041630, 13.054962, 13.068290, -- 13.081614, 13.094933, 13.108248, 13.121559, 13.134865, 13.148167, 13.161465, -- 13.174759, 13.188048, 13.201333, 13.214613, 13.227890, 13.241162, 13.254430, -- 13.267693, 13.280952, 13.294207, 13.307458, 13.320705, 13.333947, 13.347185, -- 13.360419, 13.373649, 13.386874, 13.400095, 13.413313, 13.426525, 13.439734, -- 13.452938, 13.466139, 13.479335, 13.492527, 13.505714, 13.518898, 13.532077, -- 13.545253, 13.558424, 13.571591, 13.584754, 13.597912, 
13.611067, 13.624217, -- 13.637364, 13.650506, 13.663644, 13.676778, 13.689908, 13.703034, 13.716155, -- 13.729273, 13.742386, 13.755496, 13.768601, 13.781703, 13.794800, 13.807893, -- 13.820982, 13.834067, 13.847148, 13.860225, 13.873298, 13.886367, 13.899432, -- 13.912493, 13.925550, 13.938602, 13.951651, 13.964696, 13.977737, 13.990774, -- 14.003806, 14.016835, 14.029860, 14.042881, 14.055898, 14.068911, 14.081920, -- 14.094925, 14.107926, 14.120923, 14.133916, 14.146905, 14.159890, 14.172872, -- 14.185849, 14.198822, 14.211792, 14.224758, 14.237719, 14.250677, 14.263631, -- 14.276581, 14.289527, 14.302469, 14.315407, 14.328342, 14.341272, 14.354199, -- 14.367122, 14.380041, 14.392956, 14.405867, 14.418774, 14.431677, 14.444577, -- 14.457473, 14.470365, 14.483253, 14.496137, 14.509017, 14.521894, 14.534767, -- 14.547636, 14.560501, 14.573362, 14.586220, 14.599074, 14.611923, 14.624770, -- 14.637612, 14.650450, 14.663285, 14.676116, 14.688943, 14.701767, 14.714587, -- 14.727403, 14.740215, 14.753023, 14.765828, 14.778629, 14.791426, 14.804219, -- 14.817009, 14.829795, 14.842577, 14.855356, 14.868130, 14.880902, 14.893669, -- 14.906433, 14.919192, 14.931949, 14.944701, 14.957450, 14.970195, 14.982937, -- 14.995674, 15.008409, 15.021139, 15.033866, 15.046589, 15.059308, 15.072024, -- 15.084736, 15.097445, 15.110149, 15.122851, 15.135548, 15.148242, 15.160932, -- 15.173619, 15.186302, 15.198981, 15.211657, 15.224329, 15.236997, 15.249662, -- 15.262323, 15.274981, 15.287635, 15.300286, 15.312932, 15.325576, 15.338215, -- 15.350851, 15.363484, 15.376113, 15.388738, 15.401360, 15.413978, 15.426593, -- 15.439204, 15.451812, 15.464416, 15.477016, 15.489613, 15.502206, 15.514796, -- 15.527382, 15.539965, 15.552544, 15.565120, 15.577692, 15.590261, 15.602826, -- 15.615387, 15.627945, 15.640500, 15.653051, 15.665599, 15.678143, 15.690683, -- 15.703220, 15.715754, 15.728284, 15.740811, 15.753334, 15.765854, 15.778370, -- 15.790883, 15.803392, 15.815898, 15.828400, 15.840899, 
15.853394, 15.865886, -- 15.878375, 15.890860, 15.903342, 15.915820, 15.928295, 15.940766, 15.953234, -- 15.965698, 15.978159, 15.990617, 16.003071, 16.015522, 16.027970, 16.040414, -- 16.052854, 16.065292, 16.077725, 16.090156, 16.102583, 16.115006, 16.127427, -- 16.139844, 16.152257, 16.164667, 16.177074, 16.189477, 16.201877, 16.214274, -- 16.226667, 16.239057, 16.251444, 16.263827, 16.276207, 16.288583, 16.300956, -- 16.313326, 16.325693, 16.338056, 16.350416, 16.362772, 16.375125, 16.387475, -- 16.399822, 16.412165, 16.424505, 16.436841, 16.449175, 16.461505, 16.473831, -- 16.486155, 16.498475, 16.510791, 16.523105, 16.535415, 16.547722, 16.560026, -- 16.572326, 16.584623, 16.596917, 16.609207, 16.621494, 16.633778, 16.646059, -- 16.658337, 16.670611, 16.682882, 16.695149, 16.707414, 16.719675, 16.731933, -- 16.744187, 16.756439, 16.768687, 16.780932, 16.793174, 16.805412, 16.817647, -- 16.829880, 16.842108, 16.854334, 16.866556, 16.878776, 16.890991, 16.903204, -- 16.915414, 16.927620, 16.939823, 16.952023, 16.964220, 16.976414, 16.988604, -- 17.000791, 17.012975, 17.025156, 17.037334, 17.049508, 17.061679, 17.073847, -- 17.086012, 17.098174, 17.110333, 17.122488, 17.134640, 17.146790, 17.158936, -- 17.171078, 17.183218, 17.195355, 17.207488, 17.219618, 17.231745, 17.243869, -- 17.255990, 17.268108, 17.280222, 17.292334, 17.304442, 17.316547, 17.328649, -- 17.340748, 17.352844, 17.364937, 17.377027, 17.389113, 17.401197, 17.413277, -- 17.425354, 17.437428, 17.449499, 17.461567, 17.473632, 17.485694, 17.497752, -- 17.509808, 17.521860, 17.533910, 17.545956, 17.558000, 17.570040, 17.582077, -- 17.594111, 17.606142, 17.618170, 17.630195, 17.642217, 17.654235, 17.666251, -- 17.678264, 17.690273, 17.702280, 17.714284, 17.726284, 17.738282, 17.750276, -- 17.762267, 17.774256, 17.786241, 17.798223, 17.810203, 17.822179, 17.834152, -- 17.846122, 17.858090, 17.870054, 17.882015, 17.893973, 17.905928, 17.917880, -- 17.929830, 17.941776, 17.953719, 17.965659, 17.977596, 
17.989530, 18.001462, -- 18.013390, 18.025315, 18.037237, 18.049157, 18.061073, 18.072986, 18.084897, -- 18.096804, 18.108708, 18.120610, 18.132508, 18.144404, 18.156296, 18.168186, -- 18.180073, 18.191956, 18.203837, 18.215715, 18.227590, 18.239462, 18.251330, -- 18.263196, 18.275060, 18.286920, 18.298777, 18.310631, 18.322483, 18.334331, -- 18.346176, 18.358019, 18.369859, 18.381695, 18.393529, 18.405360, 18.417188, -- 18.429013, 18.440835, 18.452655, 18.464471, 18.476284, 18.488095, 18.499902, -- 18.511707, 18.523509, 18.535308, 18.547104, 18.558897, 18.570688, 18.582475, -- 18.594260, 18.606041, 18.617820, 18.629596, 18.641369, 18.653139, 18.664906, -- 18.676671, 18.688432, 18.700191, 18.711947, 18.723700, 18.735450, 18.747197, -- 18.758942, 18.770683, 18.782422, 18.794158, 18.805891, 18.817621, 18.829348, -- 18.841073, 18.852794, 18.864513, 18.876229, 18.887942, 18.899652, 18.911360, -- 18.923064, 18.934766, 18.946465, 18.958161, 18.969855, 18.981545, 18.993233, -- 19.004918, 19.016600, 19.028279, 19.039956, 19.051630, 19.063300, 19.074969, -- 19.086634, 19.098296, 19.109956, 19.121613, 19.133267, 19.144918, 19.156567, -- 19.168213, 19.179855, 19.191496, 19.203133, 19.214768, 19.226399, 19.238028, -- 19.249655, 19.261278, 19.272899, 19.284517, 19.296132, 19.307745, 19.319354, -- 19.330961, 19.342565, 19.354167, 19.365765, 19.377361, 19.388954, 19.400545, -- 19.412132, 19.423717, 19.435299, 19.446879, 19.458456, 19.470029, 19.481601, -- 19.493169, 19.504735, 19.516298, 19.527858, 19.539416, 19.550970, 19.562522, -- 19.574072, 19.585618, 19.597162, 19.608704, 19.620242, 19.631778, 19.643311, -- 19.654841, 19.666369, 19.677894, 19.689416, 19.700936, 19.712452, 19.723966, -- 19.735478, 19.746987, 19.758493, 19.769996, 19.781497, 19.792995, 19.804490, -- 19.815982, 19.827472, 19.838959, 19.850444, 19.861926, 19.873405, 19.884882, -- 19.896355, 19.907826, 19.919295, 19.930761, 19.942224, 19.953684, 19.965142, -- 19.976597, 19.988050, 19.999500, 20.010947, 20.022391, 
20.033833, 20.045272, -- 20.056709, 20.068143, 20.079574, 20.091003, 20.102429, 20.113852, 20.125273, -- 20.136691, 20.148106, 20.159519, 20.170929, 20.182336, 20.193741, 20.205144, -- 20.216543, 20.227940, 20.239334, 20.250726, 20.262115, 20.273502, 20.284886, -- 20.296267, 20.307646, 20.319022, 20.330395, 20.341766, 20.353134, 20.364500, -- 20.375863, 20.387223, 20.398581, 20.409936, 20.421288, 20.432638, 20.443986, -- 20.455331, 20.466673, 20.478012, 20.489349, 20.500684, 20.512016, 20.523345, -- 20.534672, 20.545996, 20.557317, 20.568636, 20.579952, 20.591266, 20.602577, -- 20.613886, 20.625192, 20.636496, 20.647796, 20.659095, 20.670391, 20.681684, -- 20.692974, 20.704262, 20.715548, 20.726831, 20.738111, 20.749389, 20.760665, -- 20.771937, 20.783208, 20.794475, 20.805740, 20.817003, 20.828263, 20.839520, -- 20.850775, 20.862028, 20.873278, 20.884525, 20.895770, 20.907012, 20.918252, -- 20.929489, 20.940723, 20.951956, 20.963185, 20.974412, 20.985637, 20.996859, -- 21.008078, 21.019295, 21.030510, 21.041722, 21.052931, 21.064138, 21.075343, -- 21.086544, 21.097744, 21.108941, 21.120135, 21.131327, 21.142516, 21.153703, -- 21.164888, 21.176070, 21.187249, 21.198426, 21.209600, 21.220772, 21.231941, -- 21.243108, 21.254273, 21.265435, 21.276594, 21.287751, 21.298906, 21.310058, -- 21.321207, 21.332354, 21.343499, 21.354641, 21.365780, 21.376918, 21.388052, -- 21.399184, 21.410314, 21.421441, 21.432566, 21.443689, 21.454808, 21.465926, -- 21.477041, 21.488153, 21.499263, 21.510371, 21.521476, 21.532579, 21.543679, -- 21.554776, 21.565872, 21.576965, 21.588055, 21.599143, 21.610228, 21.621312, -- 21.632392, 21.643470, 21.654546, 21.665619, 21.676690, 21.687759, 21.698825, -- 21.709888, 21.720949, 21.732008, 21.743064, 21.754118, 21.765170, 21.776219, -- 21.787265, 21.798309, 21.809351, 21.820390, 21.831427, 21.842462, 21.853494, -- 21.864524, 21.875551, 21.886576, 21.897598, 21.908618, 21.919636, 21.930651, -- 21.941664, 21.952674, 21.963682, 21.974687, 21.985691, 
21.996691, 22.007690, -- 22.018686, 22.029679, 22.040671, 22.051659, 22.062646, 22.073630, 22.084612, -- 22.095591, 22.106568, 22.117542, 22.128514, 22.139484, 22.150451, 22.161416, -- 22.172379, 22.183339, 22.194297, 22.205252, 22.216205, 22.227156, 22.238104, -- 22.249050, 22.259994, 22.270935, 22.281874, 22.292811, 22.303745, 22.314676, -- 22.325606, 22.336533, 22.347458, 22.358380, 22.369300, 22.380218, 22.391133, -- 22.402046, 22.412956, 22.423865, 22.434770, 22.445674, 22.456575, 22.467474, -- 22.478371, 22.489265, 22.500157, 22.511046, 22.521933, 22.532818, 22.543700, -- 22.554581, 22.565458, 22.576334, 22.587207, 22.598078, 22.608946, 22.619812, -- 22.630676, 22.641538, 22.652397, 22.663254, 22.674108, 22.684960, 22.695810, -- 22.706658, 22.717503, 22.728346, 22.739187, 22.750025, 22.760861, 22.771695, -- 22.782526, 22.793355, 22.804182, 22.815006, 22.825828, 22.836648, 22.847466, -- 22.858281, 22.869094, 22.879905, 22.890713, 22.901519, 22.912323, 22.923124, -- 22.933923, 22.944720, 22.955515, 22.966307, 22.977097, 22.987885, 22.998670, -- 23.009453, 23.020234, 23.031013, 23.041789, 23.052563, 23.063335, 23.074104, -- 23.084871, 23.095636, 23.106399, 23.117159, 23.127917, 23.138673, 23.149427, -- 23.160178, 23.170927, 23.181674, 23.192418, 23.203160, 23.213900, 23.224638, -- 23.235373, 23.246107, 23.256837, 23.267566, 23.278292, 23.289017, 23.299738, -- 23.310458, 23.321175, 23.331890, 23.342603, 23.353314, 23.364022, 23.374728, -- 23.385432, 23.396134, 23.406833, 23.417530, 23.428225, 23.438918, 23.449608, -- 23.460296, 23.470982, 23.481666, 23.492348, 23.503027, 23.513704, 23.524379, -- 23.535051, 23.545721, 23.556389, 23.567055, 23.577719, 23.588380, 23.599039, -- 23.609696, 23.620351, 23.631004, 23.641654, 23.652302, 23.662948, 23.673591, -- 23.684233, 23.694872, 23.705509, 23.716144, 23.726776, 23.737407, 23.748035, -- 23.758661, 23.769284, 23.779906, 23.790525, 23.801142, 23.811757, 23.822370, -- 23.832980, 23.843589, 23.854195, 23.864799, 23.875400, 
23.886000, 23.896597, -- 23.907192, 23.917785, 23.928376, 23.938965, 23.949551, 23.960135, 23.970717, -- 23.981297, 23.991875, 24.002450, 24.013023, 24.023594, 24.034163, 24.044730, -- 24.055294, 24.065857, 24.076417, 24.086975, 24.097531, 24.108084, 24.118636, -- 24.129185, 24.139732, 24.150277, 24.160820, 24.171360, 24.181899, 24.192435, -- 24.202969, 24.213501, 24.224031, 24.234559, 24.245084, 24.255607, 24.266128, -- 24.276647, 24.287164, 24.297679, 24.308191, 24.318702, 24.329210, 24.339716, -- 24.350220, 24.360722, 24.371221, 24.381719, 24.392214, 24.402707, 24.413198, -- 24.423687, 24.434174, 24.444658, 24.455141, 24.465621, 24.476099, 24.486575, -- 24.497049, 24.507521, 24.517991, 24.528458, 24.538923, 24.549387, 24.559848, -- 24.570307, 24.580763, 24.591218, 24.601671, 24.612121, 24.622569, 24.633016, -- 24.643460, 24.653902, 24.664341, 24.674779, 24.685215, 24.695648, 24.706079, -- 24.716509, 24.726936, 24.737361, 24.747784, 24.758204, 24.768623, 24.779040, -- 24.789454, 24.799866, 24.810277, 24.820685, 24.831091, 24.841495, 24.851896, -- 24.862296] -- } def horner (x: f64): f64 = let (c1,c2,c3,c4,c5) = (0.31938153,-0.356563782,1.781477937,-1.821255978,1.330274429) in x * (c1 + x * (c2 + x * (c3 + x * (c4 + x * c5)))) def fabs (x: f64): f64 = if x < 0.0 then -x else x def cnd0 (d: f64): f64 = let k = 1.0 / (1.0 + 0.2316419 * fabs(d)) let p = horner(k) let rsqrt2pi = 0.39894228040143267793994605993438 in rsqrt2pi * f64.exp(-0.5*d*d) * p def cnd (d: f64): f64 = let c = cnd0(d) in if 0.0 < d then 1.0 - c else c def go (x: (bool,f64,f64,f64)): f64 = let (call, price, strike, years) = x let r = 0.08 -- riskfree let v = 0.30 -- volatility let v_sqrtT = v * f64.sqrt(years) let d1 = (f64.log(price / strike) + (r + 0.5 * v * v) * years) / v_sqrtT let d2 = d1 - v_sqrtT let cndD1 = cnd(d1) let cndD2 = cnd(d2) let x_expRT = strike * f64.exp(-r * years) in if call then price * cndD1 - x_expRT * cndD2 else x_expRT * (1.0 - cndD2) - price * (1.0 - cndD1) def 
blackscholes (xs: [](bool,f64,f64,f64)): []f64 = map go xs def main (years: i64): []f64 = let days = years*365 let a = map (+1) (iota(days)) let a = map f64.i64 a let a = map (\x -> (true, 58.0 + 4.0 * x / f64.i64(days), 65.0, x / 365.0)) a in blackscholes(a) futhark-0.25.27/tests/bounds-elim0.fut000066400000000000000000000002451475065116200175230ustar00rootroot00000000000000-- Optimise away a particularly simple case of bounds checking. -- == -- structure { Assert 0 } def main [n] (xs: [n]i32) = loop acc = 0 for i < n do acc + xs[i] futhark-0.25.27/tests/bounds-elim1.fut000066400000000000000000000002551475065116200175250ustar00rootroot00000000000000-- Optimise away another particularly simple case of bounds checking. -- == -- structure gpu { SegMap/Assert 0 } def main [n] (xs: [n]i32) = tabulate n (\i -> xs[i] + 2) futhark-0.25.27/tests/bounds-elim2.fut000066400000000000000000000001031475065116200175160ustar00rootroot00000000000000-- == -- structure { Assert 0 } def main (xs: []i32) = indices xs futhark-0.25.27/tests/bounds-error0.fut000066400000000000000000000002541475065116200177260ustar00rootroot00000000000000-- Test that a trivial runtime out-of-bounds access is caught. -- == -- input { [1,2,3] 4 } -- error: Index \[4\] out of bounds def main (a: []i32) (i: i32): i32 = a[i] futhark-0.25.27/tests/bounds-error1.fut000066400000000000000000000004301475065116200177230ustar00rootroot00000000000000-- The bounds error message should not refer to more dimensions than -- are present in the source language program. -- == -- input { [[1,2]] 4 } -- error: Index \[4\] out of bounds for array of shape \[1\] def index xs i = xs[i] def main (xss: [][]i32) (i: i32) = index xss i futhark-0.25.27/tests/branch_array.fut000066400000000000000000000006061475065116200176610ustar00rootroot00000000000000-- Tests some nasty code generation/simplification details about -- removing existential contexts. 
-- -- == -- -- input { true 3i64 } -- output { [0i64,1i64,2i64] } -- input { false 3i64 } -- output { [1337i64,1337i64,1337i64] } def f [n] (a: [n]i64): []i64 = a def g(n: i64): []i64 = replicate n 1337 def main (b: bool) (n: i64): []i64 = let a = iota(n) in if b then f(a) else g(n) futhark-0.25.27/tests/closedform_loop.fut000066400000000000000000000001621475065116200204110ustar00rootroot00000000000000-- == -- structure { Screma 1 } -- structure gpu { SegRed 1 } entry main n = f64.sum (replicate n (1/f64.i64 n)) futhark-0.25.27/tests/coalescing/000077500000000000000000000000001475065116200166135ustar00rootroot00000000000000futhark-0.25.27/tests/coalescing/coalescing1.fut000066400000000000000000000003111475065116200215160ustar00rootroot00000000000000-- == -- input { [[1,2],[3,4],[5,6]] } -- output { [4i32, 10i32, 16i32] } def main [n][m] (a: [n][m]i32): []i32 = map (\(r: []i32): i32 -> loop x = 0 for i < m do x * 2 + r[i]) a futhark-0.25.27/tests/coalescing/coalescing2.fut000066400000000000000000000021021475065116200215170ustar00rootroot00000000000000-- == -- input { [[[7i32, 10i32, 2i32], -- [4i32, 3i32, 1i32], -- [8i32, 4i32, 4i32], -- [0i32, 9i32, 9i32], -- [0i32, 1i32, 3i32], -- [2i32, 5i32, 10i32], -- [0i32, 5i32, 0i32]], -- [[5i32, 7i32, 6i32], -- [2i32, 2i32, 3i32], -- [4i32, 4i32, 7i32], -- [6i32, 10i32, 10i32], -- [5i32, 6i32, 10i32], -- [5i32, 1i32, 6i32], -- [1i32, 3i32, 9i32]], -- [[3i32, 2i32, 9i32], -- [4i32, 0i32, 7i32], -- [4i32, 6i32, 5i32], -- [5i32, 8i32, 5i32], -- [10i32, 8i32, 7i32], -- [5i32, 8i32, 7i32], -- [3i32, 6i32, 8i32]]] } -- output { [[50i32, 23i32, 44i32, 27i32, 5i32, 28i32, 10i32], -- [40i32, 15i32, 31i32, 54i32, 42i32, 28i32, 19i32], -- [25i32, 23i32, 33i32, 41i32, 63i32, 43i32, 32i32]] } def main [k][n][m] (rss: [k][n][m]i32): [][]i32 = map (\(rs: [][]i32) -> map (\(r: []i32): i32 -> loop x = 0 for i < m do x * 2 + r[i]) rs) rss 
futhark-0.25.27/tests/coalescing/coalescing3.fut000066400000000000000000000010051475065116200215210ustar00rootroot00000000000000-- == -- input { [[7i32, 10i32, 2i32, 4i32, 3i32, 1i32, 8i32], -- [4i32, 4i32, 0i32, 9i32, 9i32, 0i32, 1i32], -- [3i32, 2i32, 5i32, 10i32, 0i32, 5i32, 0i32]] } -- output { [[8i32, 11i32, 3i32, 5i32, 4i32, 2i32, 9i32], -- [5i32, 5i32, 1i32, 10i32, 10i32, 1i32, 2i32], -- [4i32, 3i32, 6i32, 11i32, 1i32, 6i32, 1i32]] } def main [n][m] (rss: *[n][m]i32): [][]i32 = map (\(rs: []i32) -> loop rs = copy rs for i < m do let rs[i] = rs[i] + 1 in rs) rss futhark-0.25.27/tests/coalescing/coalescing4.fut000066400000000000000000000005651475065116200215340ustar00rootroot00000000000000-- == -- structure gpu { Manifest 1 } def smoothen [n] (xs: [n]f32) = let pick i = xs[i64.min (n-1) (i64.max 0 i)] in tabulate n (\i -> pick (i-2) + pick (i-1) *4 + pick i * 6 + pick (i+1) * 4 + pick (i+2)) def main xss = xss |> transpose |> map transpose |> map (map smoothen) |> map transpose |> transpose futhark-0.25.27/tests/collision.fut000066400000000000000000000004411475065116200172160ustar00rootroot00000000000000-- Even though isnan32 is also the name of an intrinsic function, it's -- not actually the same thing! def isnan32 (x: i32) = let exponent = (x >> 23) & 0b11111111 let significand = x & 0b11111111111111111111111 in exponent == 0b11111111 && significand != 0 def main x = isnan32 x futhark-0.25.27/tests/complement.fut000066400000000000000000000002451475065116200173700ustar00rootroot00000000000000-- Test that complement works properly. 
-- == -- input { -- [1, 255, 0] -- } -- output { -- [-2, -256, -1] -- } def main(a: []i32): []i32 = map (\x -> !x) a futhark-0.25.27/tests/concat.fut000066400000000000000000000002131475065116200164670ustar00rootroot00000000000000-- == -- input { -- } -- output { -- 6 -- } def main: i32 = let a = [(1,(2,3))] let c = concat a a let (x,(y,z)) = c[1] in x+y+z futhark-0.25.27/tests/concat1.fut000066400000000000000000000003501475065116200165520ustar00rootroot00000000000000-- Basic 1D concatting. -- == -- input { -- [1,2] -- [3,4] -- } -- output { -- [1,2,3,4] -- } -- input { -- empty([0]i32) -- [1,2,3] -- } -- output { -- [1,2,3] -- } def main (a: []i32) (b: []i32): []i32 = concat a b futhark-0.25.27/tests/concat10.fut000066400000000000000000000001641475065116200166350ustar00rootroot00000000000000-- Simplifying away a redundant concat. -- == -- structure { Concat 0 } def main (xs: []i32) : *[]i32 = xs ++ [] futhark-0.25.27/tests/concat11.fut000066400000000000000000000002641475065116200166370ustar00rootroot00000000000000-- #1796 -- input { [1,2] } -- output { [[1, 2], [1, 2], [1, 2], [1, 2], [1, 2]] } -- structure { Replicate 1 Concat 0 } def main (xs: []i32) = replicate 2 xs ++ replicate 3 xs futhark-0.25.27/tests/concat12.fut000066400000000000000000000003431475065116200166360ustar00rootroot00000000000000-- Not safe to fuse these two replicates, at least not without being -- more clever about it than we currently are. 
-- == -- structure {Replicate 2} def main n m = replicate n (replicate n 0) ++ replicate m (replicate n 0) futhark-0.25.27/tests/concat2.fut000066400000000000000000000005341475065116200165570ustar00rootroot00000000000000-- == -- input { -- [[1,2],[3,4]] -- [[5,6],[7,8]] -- } -- output { -- [[1,2],[3,4],[5,6],[7,8]] -- } -- input { empty([0][0]i32) empty([0][0]i32) } output { empty([0][0]i32) } -- input { empty([0][1]i32) [[1]] } output { [[1]] } -- input { [[1]] empty([0][1]i32) } output { [[1]] } def main(a: [][]i32) (b: [][]i32): [][]i32 = concat a b futhark-0.25.27/tests/concat3.fut000066400000000000000000000002551475065116200165600ustar00rootroot00000000000000-- Fusion of concats. -- == -- input { [1] [2] [3] } output { [1,2,3] } -- structure { Concat 1 } def main (xs: []i32) (ys: []i32) (zs: []i32) = concat xs (concat ys zs) futhark-0.25.27/tests/concat4.fut000066400000000000000000000007201475065116200165560ustar00rootroot00000000000000-- Indexing into a concat. The simplifier should remove the concat. -- == -- input { [1,2,3] [4,5,6] [7,8,9] 1 } output { 2 } -- input { [1,2,3] [4,5,6] [7,8,9] 4 } output { 5 } -- input { [1,2,3] [4,5,6] [7,8,9] 7 } output { 8 } -- input { [1,2,3] [4,5,6] [7,8,9] 9 } error: .* -- input { [1,2,3] [4,5,6] [7,8,9] -1 } error: .* -- structure { Concat 0 } def main(as: []i32) (bs: []i32) (cs: []i32) (i: i32): i32 = let ds = concat (concat as bs) cs in ds[i] futhark-0.25.27/tests/concat7.fut000066400000000000000000000004641475065116200165660ustar00rootroot00000000000000-- Indexing into a concat across inner dimension. The simplifier -- should remove the concat. 
-- -- == -- input { [[1,1],[2,2],[3,3]] [[4],[5],[6]] 1 2 } output { 5 } -- structure { Concat 0 } def main [n][m] (as: [][n]i32) (bs: [][m]i32) (i: i32) (j: i32): i32 = let cs = map2 concat as bs in cs[i,j] futhark-0.25.27/tests/concat8.fut000066400000000000000000000002701475065116200165620ustar00rootroot00000000000000-- Simplification of concatenations of replicates and array literals. -- == -- structure { Replicate 0 ArrayLit 1 Concat 0 } def main (a: i32) (b: i32) (c: i32) = [a] ++ [b] ++ [c] futhark-0.25.27/tests/concat9.fut000066400000000000000000000005331475065116200165650ustar00rootroot00000000000000-- Simplification of concatenations of replicates of the same value, -- interspersed with array literals. -- == -- input { 2i64 3i64 } -- output { [42i32, 42i32, 42i32, 42i32, 42i32, 1i32, 2i32, 3i32, 4i32, 5i32, 42i32, 42i32, 42i32] } def main (n: i64) (m: i64) = replicate n 42 ++ replicate m 42 ++ [1,2,3] ++ [4,5] ++ replicate n 42 ++ [42] futhark-0.25.27/tests/constant_folding0.fut000066400000000000000000000001441475065116200206360ustar00rootroot00000000000000-- == -- input { 2 } output { 2 } -- structure { BinOp 0 } def main (x: i32) = ((x - 2) + 1) + 1 futhark-0.25.27/tests/constants/000077500000000000000000000000001475065116200165205ustar00rootroot00000000000000futhark-0.25.27/tests/constants/const0-error.fut000066400000000000000000000001431475065116200215730ustar00rootroot00000000000000-- Constants are properly type checked. -- == -- error: i32 def x: i32 = 2.0 def main(): i32 = 2 futhark-0.25.27/tests/constants/const0.fut000066400000000000000000000001371475065116200204470ustar00rootroot00000000000000-- Constants work at all. -- -- == -- input {} output { 2 } def v: i32 = 2 def main: i32 = v futhark-0.25.27/tests/constants/const1.fut000066400000000000000000000002441475065116200204470ustar00rootroot00000000000000-- Can a constant be an array of tuples? 
-- -- == -- input {} output { 3 } def v: [](i32,i32) = [(1,2)] def main: i32 = let (x,y) = v[0] in x + y futhark-0.25.27/tests/constants/const2.fut000066400000000000000000000002101475065116200204410ustar00rootroot00000000000000-- Can value declarations refer to each other? -- -- == -- input { } output { 3 } def x: i32 = 2 def y: i32 = x + 1 def main: i32 = y futhark-0.25.27/tests/constants/const3.fut000066400000000000000000000001621475065116200204500ustar00rootroot00000000000000-- -- == -- input { } output { [0,0,0] } def n: i64 = 3 def f(): [n]i32 = replicate n 0 def main: []i32 = f () futhark-0.25.27/tests/constants/const4.fut000066400000000000000000000002701475065116200204510ustar00rootroot00000000000000-- You can use a constant as a shape declaration in another constant. -- -- == -- input { } output { [0,0,0] } def n: i64 = 3 def x: [n]i32 = replicate n 0 def main: []i32 = copy x futhark-0.25.27/tests/constants/const5.fut000066400000000000000000000001731475065116200204540ustar00rootroot00000000000000-- == -- structure { Screma 1 } def big_sum = i64.sum (iota 1000000) def main b = if b then big_sum - 1 else big_sum + 1 futhark-0.25.27/tests/constants/const6.fut000066400000000000000000000001371475065116200204550ustar00rootroot00000000000000def number = 123 + 456 : i64 def array = iota number def sum = i64.sum array def main = sum futhark-0.25.27/tests/constants/const7.fut000066400000000000000000000003401475065116200204520ustar00rootroot00000000000000-- Same function used for two constants. Inlining must take care not -- to duplicate names. 
-- == -- input {} -- output { 8 } def f (x: i32) (y: i32) = let z = x + y in z def a = f 1 2 def b = f 2 3 def main = a + b futhark-0.25.27/tests/constants/const8.fut000066400000000000000000000002171475065116200204560ustar00rootroot00000000000000-- Fusion must also happen to constants -- == -- structure { Screma 1 } def n = 1000 : i64 def x = map (+2) (map (+3) (iota n)) def main = x futhark-0.25.27/tests/constants/const9.fut000066400000000000000000000006331475065116200204610ustar00rootroot00000000000000-- Some intermediate constants may not be live in the functions. -- Handle this gracefully. We can't test whether the memory is -- actually deallocated during initialisation (except through manual -- inspection), but we can at least check that this isn't fused -- unexpectedly. -- == -- structure { Screma 2 } def xs = map (+3) (iota 1000) def ys = copy xs with [4] = 0 def v = i64.sum ys def main a = a + v futhark-0.25.27/tests/convert_id.fut000066400000000000000000000003321475065116200173560ustar00rootroot00000000000000-- Test that certain numeric conversions are simplified away. 
-- == -- structure { ConvOp 4 } def main (x: i32) (y: u32) = (f32.i64 (i64.i32 x), i8.i64 (i64.i32 x), f32.u64 (u64.i32 x), u8.u64 (u64.i32 x)) futhark-0.25.27/tests/copyPropTest1.fut000066400000000000000000000017671475065116200177730ustar00rootroot00000000000000-- == -- input { -- } -- output { -- 52i64 -- } -- structure { Replicate 0 } def getInt (): i64 = if((1-1)*3 + (3/3 - 1) == 0) then (15 / 3)*2 else 10000000 def plus1 [n] (x: [n]i64) = map (\(y: i64): i64->y+1) x def main: i64 = let n = getInt() -- Int let x = iota(n) -- [#n]Int let m = (n*1)+(n*0) -- n :: Int let y = replicate m x -- [#n][#n]Int let u = map plus1 y -- [#n][#n]Int let z = replicate (m+n) y -- [[#n][#n]Int,m+n] let v = u[m/2-1] -- [#n]Int let o = (m +(2-4/2))*1 -- n :: Int let q = z[m/2] -- [#n][#n]Int let t = q[m/3, m/4] -- n/4 :: Int let w = x[m-n/2] -- n-n/2 :: Int let s = v[3] in -- u[m+n/2,3] :: Int x[m*n/n - 1] + y[o-1,if(o*(3-3) == 0) then o-1 else m*n*o] + u[0, m-n] + z[(1+3)/m, if(false || o*(3+(-9)/3)==0) then 3/5 else (4+2)/3, (m*1)/2 ] + s + t + q[0,n*1-1] + o + v[2] + (m - n + 0) futhark-0.25.27/tests/copyPropTest2.fut000066400000000000000000000013731475065116200177650ustar00rootroot00000000000000-- == -- input { -- } -- output { -- 91i64 -- 126i64 -- } -- structure { Replicate 0 } def getInt (): i64 = 10 def plus1(x: []i32): []i32 = map (\(y: i32): i32->y+1) x def main: (i64,i64) = let n = getInt() -- Int let x = iota(n) -- [#n]Int let m = (n * (5-4)) let y = copy(replicate n x) -- [#n][#n]Int copy necessary as y otherwise aliases x. let z = copy(replicate (n+n) y) -- [[#n][#n]Int,m+n]; copy necessary as z otherwise aliases x. 
let q = z[n-2] in -- [#n][#n]Int let (m,x) = loop ((m,x)) for i < n-1 do let x[i] = (m*1) let m = m + x[i+1] let m = m + z[n-1,n-2,i] in (m, x) let qq = m*(2-1) in (qq, m + x[n/2]) futhark-0.25.27/tests/copyPropTest3.fut000066400000000000000000000007671475065116200177740ustar00rootroot00000000000000-- == -- input { -- } -- output { -- 70i64 -- } def getInt(): i64 = 10 def myfun(x: (i64,i64,(i64,i64)) ): i64 = let (a,b,(c,d)) = x in a + b + c + d def main: i64 = let n = getInt() let a = (n, n, (n*0+5,n)) let (x1, x2) = (replicate n a, [ a, a, a, a, a, a, a, a, a, a ]) let y1 = replicate n x1 let y2 = replicate n x2 let z = y1[n-1] let (b,c,(d,e)) = z[n/2] let m = y2[0, (5-1)/2] let p = m in b + c + d + e + myfun(p) futhark-0.25.27/tests/cse0.fut000066400000000000000000000002401475065116200160520ustar00rootroot00000000000000-- CSE of things that don't look equal at first glance. -- == -- input { 1 3 } output { 4 4 } -- structure { BinOp 1 } def main (x: i32) (y: i32) = (x+y, y+x) futhark-0.25.27/tests/curry0.fut000066400000000000000000000003411475065116200164460ustar00rootroot00000000000000-- Some simple currying of operators. -- == -- input { -- [-3,-2,-1,0,1,2,3] -- } -- output { -- [-2, -1, -1, 0, 0, 1, 1] -- [5, 4, 3, 2, 1, 0, -1] -- } def main(a: []i32): ([]i32,[]i32) = (map (/ 2) a, map (2 -) a) futhark-0.25.27/tests/curry1.fut000066400000000000000000000007471475065116200164610ustar00rootroot00000000000000-- Test that we can curry even "complex" arguments. 
-- == -- input { -- [1.0,6.0,3.0,4.0,1.0,0.0] -- } -- output { -- 252.000000 -- } def f(x: (i64, f64)) (y: f64): f64 = let (a,b) = x in y*f64.i64(a)+b def g(x: [](f64,f64)) (y: f64): f64 = let (a,b) = unzip(x) in y + reduce (+) (0.0) a + reduce (+) (0.0) b def main(a: []f64): f64 = let b = map (f ((5,6.0))) a let c = map (g (zip ([1.0,2.0,3.0]) ([4.0,5.0,6.0]))) a in reduce (+) (0.0) b + reduce (+) (0.0) c futhark-0.25.27/tests/curry2.fut000066400000000000000000000002761475065116200164570ustar00rootroot00000000000000-- Curry a simple function. -- == -- input { -- [8,5,4,3,2,1] -- } -- output { -- [9,6,5,4,3,2] -- } def add(x: i32) (y: i32): i32 = x + y def main(a: []i32): []i32 = map (add(1)) a futhark-0.25.27/tests/deadCodeElimTest1.fut000066400000000000000000000004601475065116200204440ustar00rootroot00000000000000-- == -- input { -- 10i64 -- } -- output { -- -1i64 -- } def neg(x: i64): i64 = -x def main(a: i64): i64 = let b = a + 100 let x = iota(a) let c = b + 200 let z = 3*2 - 6 --let y = map(op ~, x) in let y = map neg x let d = c + 300 in if(false) then d + y[1] else y[1] futhark-0.25.27/tests/deadCodeElimTest2.fut000066400000000000000000000004601475065116200204450ustar00rootroot00000000000000-- == -- input { -- 10i64 -- } -- output { -- -1i64 -- } def neg(x: i64): i64 = -x def main(a: i64): i64 = let b = a + 100 let x = iota(a) let c = b + 200 let z = 3*2 - 6 --let y = map(op ~, x) in let y = map neg x let d = c + 300 in if(false) then d + y[1] else y[1] futhark-0.25.27/tests/dependence-analysis/000077500000000000000000000000001475065116200204175ustar00rootroot00000000000000futhark-0.25.27/tests/dependence-analysis/ad_acc.fut000066400000000000000000000212571475065116200223400ustar00rootroot00000000000000-- See issue 1989. 
-- == -- structure { UpdateAcc 3 } def gather1D 't [m] (arr1D: [m]t) (inds: [m]i32) : *[m]t = map (\ind -> arr1D[ind] ) inds def gather2D 't [m][d] (arr2D: [m][d]t) (inds: [m]i32) : *[m][d]t = map (\ind -> map (\j -> arr2D[ind,j]) (iota d) ) inds def sumSqrsSeq [d] (xs: [d]f32) (ys: [d]f32) : f32 = loop (res) = (0.0f32) for (x,y) in (zip xs ys) do let z = x-y in res + z*z def log2 x = (loop (y,c) = (x,0i32) while y > 1i32 do (y >> 1, c+1)).1 def isLeaf (h: i32) (node_index: i32) = node_index >= ((1 << (h+1)) - 1) def findLeaf [q][d] (median_dims: [q]i32) (median_vals: [q]f32) (height: i32) (query: [d]f32) = let leaf = loop (node_index) = (0) while !(isLeaf height node_index) do if query[median_dims[node_index]] <= median_vals[node_index] then (node_index+1)*2-1 else (node_index+1)*2 in leaf - i32.i64 q def traverseOnce [q] [d] (radius: f32) (height: i32) (kd_tree: [q](i32,f32,i32)) (query: [d]f32) (last_leaf: i32, stack: i32, dist: f32) : (i32, i32, f32) = let (median_dims, median_vals, clanc_eqdim) = unzip3 kd_tree let last_leaf = last_leaf + i32.i64 q let no_leaf = 2*q + 1 let getPackedInd (stk: i32) (ind: i32) : bool = let b = stk & (1<> (ind+1)) << (ind+1) let mid = if v then (1 << ind) else 0 in ( (fst | snd) | mid ) let getLevel (node_idx: i32) : i32 = log2 (node_idx+1) let getAncSameDimContrib (q_m_i: f32) (node_stack: i32) (node: i32) : f32 = (loop (idx, res) = (node, 0.0f32) while (idx >= 0) do let anc = clanc_eqdim[idx] in if anc == (-1i32) then (-1i32, 0.0f32) else let anc_lev = getLevel anc let is_anc_visited = getPackedInd node_stack anc_lev in if !is_anc_visited then (anc, res) else (-1i32, median_vals[anc] - q_m_i) ).1 let (parent_rec, stack, count, dist, rec_node) = loop (node_index, stack, count, dist, rec_node) = (last_leaf, stack, height, dist, -1) for _i2 < height+1 do if (node_index != 0) && (rec_node < 0) then let parent = (node_index-1) / 2 let scnd_visited = getPackedInd stack count --stack[count] let q_m_d = query[median_dims[parent]] let 
cur_med_dst = median_vals[parent] - q_m_d let cur_med_sqr = cur_med_dst * cur_med_dst let prv_med_dst = getAncSameDimContrib q_m_d stack parent let prv_med_sqr = prv_med_dst * prv_med_dst let dist_minu = f32.abs(dist - cur_med_sqr + prv_med_sqr) let dist_plus = f32.abs(dist - prv_med_sqr + cur_med_sqr) in if scnd_visited then -- continue backing-up towards the root (parent, stack, count-1, dist_minu, -1) else -- the node_index is actually the `first` child of parent, let to_visit = dist_plus <= radius in if !to_visit then (parent, stack, count-1, dist, -1) else -- update the stack let fst_node = node_index let snd_node = if (fst_node % 2) == 0 then fst_node-1 else fst_node+1 let stack = setPackedInd stack count true in (parent, stack, count, dist_plus, snd_node) else (node_index, stack, count, dist, rec_node) let (new_leaf, new_stack, _) = if parent_rec == 0 && rec_node == -1 then -- we are done, we are at the root node (i32.i64 no_leaf, stack, 0) else -- now traverse downwards by computing `first` loop (node_index, stack, count) = (rec_node, stack, count) for _i3 < height+1 do if isLeaf height node_index then (node_index, stack, count) else let count = count+1 let stack = setPackedInd stack count false let node_index = if query[median_dims[node_index]] <= median_vals[node_index] then (node_index+1)*2-1 else (node_index+1)*2 in (node_index, stack, count) in (new_leaf-i32.i64 q, new_stack, dist) def sortQueriesByLeavesRadix [n] (leaves: [n]i32) : ([n]i32, [n]i32) = (leaves, map i32.i64 (iota n)) def bruteForce [m][d] (radius: f32) (query: [d]f32) (query_w: f32) (leaf_refs : [m][d]f32) (leaf_ws : [m]f32) : f32 = map2(\ref i -> let dist = sumSqrsSeq query ref in if dist <= radius then query_w * leaf_ws[i] else 0.0f32 ) leaf_refs (iota m) |> reduce (+) 0.0f32 def iterationSorted [q][n][d][num_leaves][ppl] (radius: f32) (h: i32) (kd_tree: [q](i32,f32,i32)) (leaves: [num_leaves][ppl][d]f32) (ws: [num_leaves][ppl]f32) (queries: [n][d]f32) (query_ws:[n]f32) (qleaves: 
[n]i32) (stacks: [n]i32) (dists: [n]f32) (query_inds: [n]i32) (res: f32) : ([n]i32, [n]i32, [n]f32, [n]i32, f32) = let queries_sorted = gather2D queries query_inds let query_ws_sorted= gather1D query_ws query_inds let new_res = map3 (\ query query_w leaf_ind -> if leaf_ind >= i32.i64 num_leaves then 0.0f32 else bruteForce radius query query_w (leaves[leaf_ind]) (ws[leaf_ind]) ) queries_sorted query_ws_sorted qleaves |> reduce (+) 0.0f32 |> opaque let (new_leaves, new_stacks, new_dists) = unzip3 <| map4 (\ query leaf_ind stack dist -> if leaf_ind >= i32.i64 num_leaves then (leaf_ind, stack, dist) else traverseOnce radius h kd_tree query (leaf_ind, stack, dist) ) queries_sorted qleaves stacks dists |> opaque let (qleaves', sort_inds) = sortQueriesByLeavesRadix new_leaves let stacks' = gather1D new_stacks sort_inds let dists' = gather1D new_dists sort_inds let query_inds' = gather1D query_inds sort_inds in (qleaves', stacks', dists', query_inds', res + new_res) def propagate [m1][m][q][d][n] (radius: f32) (ref_pts: [m][d]f32) (indir: [m]i32) (kd_tree: [q](i32,f32,i32)) (queries: [n][d]f32) (query_ws:[n]f32, ref_ws_orig: [m1]f32) : f32 = let kd_weights = map i64.i32 indir |> map (\ind -> if ind >= m1 then 1.0f32 else ref_ws_orig[ind]) let (median_dims, median_vals, _) = unzip3 kd_tree let num_nodes = q -- trace q let num_leaves = num_nodes + 1 let h = (log2 (i32.i64 num_leaves)) - 1 let ppl = m / num_leaves let leaves = unflatten (sized (num_leaves*ppl) ref_pts) let kd_ws_sort = unflatten (sized (num_leaves*ppl) kd_weights) let query_leaves = map (findLeaf median_dims median_vals h) queries let (qleaves, query_inds) = sortQueriesByLeavesRadix query_leaves let dists = replicate n 0.0f32 let stacks = replicate n 0i32 let res_ws = 0f32 let (_qleaves', _stacks', _dists', _query_inds', res_ws') = loop (qleaves : [n]i32, stacks : [n]i32, dists : [n]f32, query_inds : [n]i32, res_ws : f32) for _i < 8 do iterationSorted radius h kd_tree leaves kd_ws_sort queries query_ws 
qleaves stacks dists query_inds res_ws in res_ws' def rev_prop [m1][m][q][d][n] (radius: f32) (ref_pts: [m][d]f32) (indir: [m]i32) (kd_tree: [q](i32,f32,i32)) (queries: [n][d]f32) (query_ws:[n]f32, ref_ws_orig: [m1]f32) : (f32, ([n]f32, [m1]f32)) = let f = propagate radius ref_pts indir kd_tree queries in vjp2 f (query_ws, ref_ws_orig) 1.0f32 def main [d][n][m][m'][q] (sq_radius: f32) (queries: [n][d]f32) (query_ws: [n]f32) (ref_ws: [m]f32) (refs_pts : [m'][d]f32) (indir: [m']i32) (median_dims : [q]i32) (median_vals : [q]f32) (clanc_eqdim : [q]i32) = let (res, (query_ws_adj, ref_ws_adj)) = rev_prop sq_radius refs_pts indir (zip3 median_dims median_vals clanc_eqdim) queries (query_ws, ref_ws) in (res, query_ws_adj, ref_ws_adj) futhark-0.25.27/tests/dependence-analysis/hist0.fut000066400000000000000000000005651475065116200221740ustar00rootroot00000000000000-- == -- structure { Screma/Hist/BinOp 1 } -- The two reduce_by_index get fused into a single histogram operation. def main [m][n] (A: [m]([n]i32, [n]i32)) = let r = loop A for _i < n do map (\(a, b) -> (reduce_by_index (replicate n 0) (+) 0 (map i64.i32 a) a, reduce_by_index (replicate n 0) (+) 0 (map i64.i32 a) b)) A in map (.0) r futhark-0.25.27/tests/dependence-analysis/hist1.fut000066400000000000000000000005551475065116200221740ustar00rootroot00000000000000-- == -- structure { Screma/Hist 1 } -- The two reduce_by_index produce two separate histogram operations. 
def main [m][n] (A: [m]([n]i32, [n]i32)) = let r = loop A for _i < n do map (\(a, b) -> (reduce_by_index (replicate n 0) (+) 0 (map i64.i32 a) a, reduce_by_index (replicate n 0) (+) 0 (map i64.i32 b) b)) A in map (.0) r futhark-0.25.27/tests/dependence-analysis/jvp0.fut000066400000000000000000000002421475065116200220140ustar00rootroot00000000000000-- == -- structure { BinOp 2 } def main (A: [](i32,i32)) (n: i64) = let r = loop A for _i < n do jvp (map (\(a,b) -> (a*a,b*b))) A A in map (.0) r futhark-0.25.27/tests/dependence-analysis/map0.fut000066400000000000000000000002631475065116200217750ustar00rootroot00000000000000-- == -- structure { BinOp 1 } def main (A: [](i32,i32)) (n: i64) = let r = loop A for i < n do map (\(a,b) -> if i == 0 then (a,b) else (a+1,b+1)) A in map (.0) r futhark-0.25.27/tests/dependence-analysis/reduce0.fut000066400000000000000000000004331475065116200224660ustar00rootroot00000000000000-- == -- structure { BinOp 2 } def plus (a,b) (x,y): (i32,i32) = (a+x,b+y) def main (A: [](i32,i32)) (n: i64) = let r = loop r' = (0,0) for i < n do reduce (\(a,b) (x,y) -> if i == 0 then (a,b) else (a+x,b+y)) (0,0) (map (plus r') A) in r.0 futhark-0.25.27/tests/dependence-analysis/scan0.fut000066400000000000000000000003001475065116200221340ustar00rootroot00000000000000-- == -- structure { BinOp 1 } def main (A: [](i32,i32)) (n: i64) = let r = loop A for i < n do scan (\(a,b) (x,y) -> if i == 0 then (a,b) else (a+x,b+y)) (0,0) A in map (.0) r futhark-0.25.27/tests/dependence-analysis/scan1.fut000066400000000000000000000006361475065116200221510ustar00rootroot00000000000000-- A simple streamSeq; does not exercise the Stream case in opDependencies, -- but dead code is still removed by the Screma case. 
-- == -- structure { Stream/BinOp 2 } -- structure { Screma/BinOp 2 } def plus (a,b) (x,y): (i32,i32) = (a+x,b+y) def main (xs: [](i32,i32)) (n: i64) = let r = loop xs for i < n do map (\(x,y) -> if i == 0 then (x,y) else (x+1,y+2)) (scan plus (0,0) xs) in map (.0) r futhark-0.25.27/tests/dependence-analysis/scatter0.fut000066400000000000000000000002331475065116200226620ustar00rootroot00000000000000-- == -- structure { Replicate 1 } def main [n] (A: *[n](i32,i32)) = let r = loop A for _i < n do scatter A (iota n) (copy A) in map (.0) r futhark-0.25.27/tests/dependence-analysis/vjp0.fut000066400000000000000000000002421475065116200220140ustar00rootroot00000000000000-- == -- structure { BinOp 2 } def main (A: [](i32,i32)) (n: i64) = let r = loop A for _i < n do vjp (map (\(a,b) -> (a*a,b*b))) A A in map (.0) r futhark-0.25.27/tests/distribution/000077500000000000000000000000001475065116200172235ustar00rootroot00000000000000futhark-0.25.27/tests/distribution/branch0.fut000066400000000000000000000003061475065116200212570ustar00rootroot00000000000000-- Interchange map-invariant branches to exploit all parallelism. -- == -- structure gpu { Kernel/If 0 } def main (b: bool) (xs: []i32) (ys: []i32) = map (\x -> if b then map (+x) ys else ys) xs futhark-0.25.27/tests/distribution/distribution0.fut000066400000000000000000000014041475065116200225410ustar00rootroot00000000000000-- Expected distributed/sequentialised structure: -- -- map -- map -- sequential... 
-- -- == -- -- structure gpu { SegMap 1 Loop 2 } def fftmp (num_paths: i64) (md_c: [][]f64) (zi: []f64): [num_paths]f64 = #[incremental_flattening(only_outer)] map (\(j: i64): f64 -> let x = map2 (*) (take(j+1) zi) (take (j+1) md_c[j]) in reduce (+) (0.0) x ) (iota(num_paths) ) def correlateDeltas [n] (num_paths: i64) (md_c: [n][]f64) (zds: [][]f64): [n][num_paths]f64 = #[incremental_flattening(only_inner)] map (fftmp num_paths md_c) zds def main (num_paths: i64) (md_c: [][]f64) (bb_mat: [][][]f64): [][][]f64 = #[incremental_flattening(only_inner)] map (\bb_arr -> correlateDeltas num_paths md_c bb_arr) bb_mat futhark-0.25.27/tests/distribution/distribution1.fut000066400000000000000000000021161475065116200225430ustar00rootroot00000000000000-- Expected distributed/sequentialised structure: -- -- map -- map -- map -- -- map -- map -- map -- map -- map -- scan -- -- == -- structure gpu { SegMap 1 } def combineVs [n] (n_row: [n]f64, vol_row: [n]f64, dr_row: [n]f64): [n]f64 = map2 (+) dr_row (map2 (*) n_row vol_row) def mkPrices [num_und] [num_dates] (md_starts: [num_und]f64, md_vols: [num_dates][num_und]f64, md_drifts: [num_dates][num_und]f64, noises: [num_dates][num_und]f64): [num_dates][num_und]f64 = let e_rows = map (\(x: []f64) -> map f64.exp x ) (map combineVs (zip3 noises (md_vols) (md_drifts))) in scan (\(x: []f64) (y: []f64) -> map2 (*) x y) md_starts e_rows --[#num_dates, num_paths] def main(md_vols: [][]f64, md_drifts: [][]f64, md_starts: []f64, noises_mat: [][][]f64): [][][]f64 = #[incremental_flattening(only_inner)] map (\(noises: [][]f64) -> #[incremental_flattening(only_inner)] mkPrices(md_starts, md_vols, md_drifts, noises)) ( noises_mat) futhark-0.25.27/tests/distribution/distribution10.fut000066400000000000000000000012201475065116200226160ustar00rootroot00000000000000-- Once failed in kernel extraction. The problem was that the map and -- reduce are fused together into a redomap with a map-out array. 
-- This was not handled correctly when it was turned into a -- group-level stream. -- -- == -- structure gpu { SegMap 2 SegMap/Loop 1 } def indexOfMax8 ((x,i): (u8,i32)) ((y,j): (u8,i32)): (u8,i32) = if x < y then (y,j) else (x,i) def max8 (max_v: u8) (v: u8): u8 = if max_v < v then v else max_v def main [h][w] (frame : [h][w]i32) : [h][w]u8 = map (\row: [w]u8 -> let rs = map u8.i32 row let m = #[sequential] reduce max8 0u8 rs let rs' = map (max8 m) rs in rs') frame futhark-0.25.27/tests/distribution/distribution11.fut000066400000000000000000000006241475065116200226260ustar00rootroot00000000000000-- This one screwed up multi-versioning at one point. The problem was -- that loop interchange produced identity maps. def main [n][m] (xss: *[n][m]i32) = map (\(xs: []i32) -> let ys = copy xs let (xs, _) = loop (zs: [m]i32, ys: [m]i32) = (xs, ys) for i < n do let xs' = scatter (copy ys) (iota m) (rotate 1 zs) in (xs', zs) in xs) xss futhark-0.25.27/tests/distribution/distribution12.fut000066400000000000000000000002321475065116200226220ustar00rootroot00000000000000-- A triply nested map should not cause any multi-versioning. -- == -- structure gpu { SegMap 1 } def main (xsss: [][][]i32) = map (map (map (+1))) xsss futhark-0.25.27/tests/distribution/distribution2.fut000066400000000000000000000027111475065116200225450ustar00rootroot00000000000000-- A combination of distribution0.fut and distribution1.fut. AKA the -- blackScholes computation from GenericPricing. 
-- -- == -- structure gpu { -- SegMap 6 -- Loop 10 -- } def fftmp (num_paths: i64) (md_c: [][]f64) (zi: []f64): [num_paths]f64 = map (\(j: i64): f64 -> let x = map2 (*) (take (j+1) zi) (take (j+1) md_c[j]) in reduce (+) (0.0) x ) (iota num_paths) def correlateDeltas [n] (num_paths: i64) (md_c: [][]f64) (zds: [n][]f64): [n][num_paths]f64 = map (fftmp num_paths md_c) zds def combineVs [n] (n_row: [n]f64, vol_row: [n]f64, dr_row: [n]f64): [n]f64 = map2 (+) dr_row (map2 (*) n_row vol_row) def mkPrices [num_und][num_dates] (md_starts: [num_und]f64, md_vols: [num_dates][num_und]f64, md_drifts: [num_dates][num_und]f64, noises: [num_dates][num_und]f64): [num_dates][num_und]f64 = let e_rows = map (\(x: []f64) -> map f64.exp x) ( map combineVs (zip3 noises (md_vols) (md_drifts)) ) in scan (\(x: []f64) (y: []f64) -> map2 (*) x y) ( md_starts) (e_rows ) --[num_dates, num_paths] def main(num_paths: i64) (md_c: [][]f64) (md_vols: [][]f64) (md_drifts: [][]f64) (md_starts: []f64) (bb_mat: [][][]f64): [][][]f64 = map (\(bb_row: [][]f64) -> let noises = correlateDeltas num_paths md_c bb_row in mkPrices(md_starts, md_vols, md_drifts, noises)) bb_mat futhark-0.25.27/tests/distribution/distribution3.fut000066400000000000000000000007201475065116200225440ustar00rootroot00000000000000-- == -- compiled random input { [10][16][16]i32 } auto output -- compiled random input { [10][8][32]i32 } auto output -- structure gpu { SegScan 4 } def main [k][n][m] (a: [k][n][m]i32): [][][]i32 = map (\(a_row: [][]i32): [m][n]i32 -> let b = map (\(a_row_row: []i32) -> scan (+) 0 (a_row_row) ) (a_row) in map (\(b_col: []i32) -> scan (+) 0 (b_col)) (transpose b) ) a futhark-0.25.27/tests/distribution/distribution3.fut.tuning000066400000000000000000000000261475065116200240460ustar00rootroot00000000000000default_threshold=200 futhark-0.25.27/tests/distribution/distribution4.fut000066400000000000000000000005171475065116200225510ustar00rootroot00000000000000-- Expected distributed structure: -- -- map -- map -- 
map -- map -- -- == -- structure gpu { SegMap 5 } def main [n][an] [bn] (a: [n][an]i32, b: [n][bn]i32): ([][]i32,[][]i32) = unzip(map2 (\(a_row: []i32) (b_row: []i32): ([an]i32,[bn]i32) -> (map (+1) a_row, map (\x -> x-1) b_row)) a b) futhark-0.25.27/tests/distribution/distribution5.fut000066400000000000000000000013161475065116200225500ustar00rootroot00000000000000-- Expected distributed structure: -- -- map -- map -- map -- map -- map -- concat -- map -- map -- reduce (which becomes a segmented reduction) -- -- == -- structure gpu { -- SegMap 2 SegRed 1 -- } def main [k][n][an][bn] (a: [n][an][k]i32) (b: [n][bn]i32): ([][]i32,[][]i32) = #[incremental_flattening(only_inner)] unzip( #[incremental_flattening(only_inner)] map2 (\(a_row: [][]i32) (b_row: []i32): ([bn]i32,[an]i32) -> #[incremental_flattening(only_inner)] (map (\x -> x-1) (b_row), map (\(a_row_row: []i32): i32 -> let x = map (+1) (a_row_row) in reduce (+) 0 (concat x x) ) a_row)) a b) futhark-0.25.27/tests/distribution/distribution6.fut000066400000000000000000000003001475065116200225410ustar00rootroot00000000000000-- == -- structure gpu { SegMap 1 } -- def main(outer_loop_count: i64, a: []i64): [][]i64 = map (\(i: i64) -> let x = 10 * i in map (*x) a) (iota(outer_loop_count)) futhark-0.25.27/tests/distribution/distribution8.fut000066400000000000000000000023301475065116200225500ustar00rootroot00000000000000-- Like distribution2.fut, but with an outer sequential loop. Does -- not compute anything meaningful. 
-- -- map -- map -- map -- -- map -- map -- map -- map -- map -- scan -- -- == -- structure gpu { Loop/SegMap 1 Loop 2 } def combineVs [n] (n_row: [n]f64, vol_row: [n]f64, dr_row: [n]f64): [n]f64 = map2 (+) dr_row (map2 (*) n_row vol_row) def mkPrices [num_und][num_dates] (md_starts: [num_und]f64, md_vols: [num_dates][num_und]f64, md_drifts: [num_dates][num_und]f64, noises: [num_dates][num_und]f64): [num_dates][num_und]f64 = let e_rows = map (\(x: []f64) -> map f64.exp x ) (map combineVs (zip3 noises (md_vols) (md_drifts))) in scan (\(x: []f64) (y: []f64) -> map2 (*) x y) md_starts e_rows def main(n: i32, md_vols: [][]f64, md_drifts: [][]f64, md_starts: []f64, noises_mat: [][][]f64): [][][]f64 = loop (noises_mat) for i < n do #[incremental_flattening(only_inner)] map (\(noises: [][]f64) -> #[incremental_flattening(only_inner)] mkPrices(md_starts, md_vols, md_drifts, noises)) ( noises_mat) futhark-0.25.27/tests/distribution/distribution9.fut000066400000000000000000000006211475065116200225520ustar00rootroot00000000000000-- Test that we sequentialise the distributed kernels. Currently done -- by never having parallel constructs inside branches. If we ever -- start doing something clever with branches, this test may have to -- be revised. -- -- == -- structure gpu { If/Kernel 0 } def main(a: [][]i32): [][]i32 = map (\a_r -> if a_r[0] > 0 then map (*2) (a_r) else map (*3) (a_r) ) a futhark-0.25.27/tests/distribution/gather0.fut000066400000000000000000000003101475065116200212670ustar00rootroot00000000000000-- Sometimes it is quite important that the parallelism in an array -- indexing result is exploited. -- == -- structure gpu { SegMap 2 } def main (is: []i32) (xss: [][]f32) = map (\i -> xss[i]) is futhark-0.25.27/tests/distribution/icfp16-example.fut000066400000000000000000000030411475065116200224620ustar00rootroot00000000000000-- This is the program used as a demonstration example in our paper -- for ICFP 2016. 
-- -- == -- input { -- [[1,2,3],[3,2,1],[4,5,6]] -- } -- output { -- [[[1i32, 2i32, 3i32], -- [2i32, 3i32, 4i32], -- [5i32, 6i32, 7i32]], -- [[7i32, 6i32, 5i32], -- [4i32, 3i32, 2i32], -- [3i32, 2i32, 1i32]], -- [[14i32, 15i32, 16i32], -- [24i32, 25i32, 26i32], -- [39i32, 40i32, 41i32]]] -- [[92i32, 142i32, 276i32], -- [276i32, 142i32, 92i32], -- [662i32, 1090i32, 1728i32]] -- } -- structure gpu { -- Loop/SegMap 1 -- SegMap 2 -- SegRed 1 -- } def main [n][m] (pss: [n][m]i32): ([n][m][m]i32, [n][m]i32) = let (asss, bss) = #[incremental_flattening(only_inner)] unzip(map (\(ps: []i32): ([m][m]i32, [m]i32) -> #[incremental_flattening(only_inner)] let ass = map (\(p: i32): [m]i32 -> let cs = scan (+) 0 (0..1.. let d = reduce (+) 0 as let e = d + b let b' = 2 * e in b') ( zip ass bs) in bs' in (ass, bs')) pss) in (asss, bss) futhark-0.25.27/tests/distribution/inplace0.fut000066400000000000000000000005431475065116200214400ustar00rootroot00000000000000-- Distribution should not choke on a map that consumes its input. def main [m][n] (a: [m][n]i32) (is: [n]i32) (js: [n]i32): [][]i32 = map (\(a_r: []i32) -> let double = map (*2) (a_r) let triple = map (*3) (a_r) in loop a_r = copy a_r for i < n do let a_r[i] = double[is[i]] * triple[js[i]] in a_r ) a futhark-0.25.27/tests/distribution/inplace1.fut000066400000000000000000000004041475065116200214350ustar00rootroot00000000000000-- Ensure that in-place updates with invariant indexes/values are -- distributed sensibly. -- == -- input { [[1,2], [3,4]] } -- output { [[0,2], [0,4]] } -- structure { Replicate 0 } def main (xss: *[][]i32) = map (\(xs: []i32) -> copy xs with [0] = 0) xss futhark-0.25.27/tests/distribution/inplace2.fut000066400000000000000000000004711475065116200214420ustar00rootroot00000000000000-- Ensure that in-place updates with variant indexes/values are -- distributed sensibly. 
-- == -- input { [[1,2], [3,4]] [0,1] [42,1337] } -- output { [[42,2], [3,1337]] } -- structure { Replicate 0 } def main (xss: *[][]i32) (is: []i32) (vs: []i32) = map3 (\(xs: []i32) i v -> copy xs with [i] = v) xss is vs futhark-0.25.27/tests/distribution/inplace3.fut000066400000000000000000000004311475065116200214370ustar00rootroot00000000000000-- Good distribution of an in-place update of a slice. Should not -- produce a sequential Update statement. -- == -- random input { [2][12]i64 } auto output -- structure gpu { SegMap/Update 0 } def main [n][m] (xss: *[n][m]i64) = map (\xs -> copy xs with [0:10] = iota 10) xss futhark-0.25.27/tests/distribution/inplace4.fut000066400000000000000000000004731475065116200214460ustar00rootroot00000000000000-- Distributing an in-place update of slice with a bounds check. -- == -- input { [[1,2,3],[4,5,6]] [0i64,1i64] [42,1337] } -- output { [[42,1337,3],[4,42,1337]] } -- structure gpu { SegMap/Update 0 } def main [n][m] (xss: *[n][m]i32) (is: [n]i64) (ys: [2]i32) = map2 (\xs i -> copy xs with [i:i+2] = ys) xss is futhark-0.25.27/tests/distribution/inplace5.fut000066400000000000000000000004111475065116200214370ustar00rootroot00000000000000-- Distributed in-place update where slice is not final dimension. -- == -- random input { 1i64 [2][12][2]i64 } auto output -- structure gpu { SegMap/Update 0 } def main [n][m] (l: i64) (xsss: *[n][m][2]i64) = map (\xss -> copy xss with [0:10,l] = iota 10) xsss futhark-0.25.27/tests/distribution/inplace6.fut000066400000000000000000000004411475065116200214430ustar00rootroot00000000000000-- Distributed in-place update where slice is final dimension but there are more indexes. 
-- == -- random input { 1i64 [2][2][12]i64 } auto output -- structure gpu { SegMap/Update 0 } def main [n][m] (l: i64) (xsss: *[n][2][m]i64) = map (\xss -> copy xss with [l, 0:10] = iota 10) xsss futhark-0.25.27/tests/distribution/inplace7.fut000066400000000000000000000003221475065116200214420ustar00rootroot00000000000000-- == -- structure gpu { SegMap 1 } def main iss = map (\is -> let acc = replicate 3 0 let acc[is[0]] = 1 let acc[is[1]] = 2 let acc[is[2]] = 3 in acc) iss futhark-0.25.27/tests/distribution/irregular0.fut000066400000000000000000000005411475065116200220170ustar00rootroot00000000000000-- Very irregular parallelism. Should not distribute, but just -- sequentialise body. -- -- == -- input { [1,2,3,4,5,6,7,8,9] } -- output { [1, 3, 6, 10, 15, 21, 28, 36, 45] } -- structure gpu { -- SegMap 1 -- } def main(a: []i32): []i32 = #[incremental_flattening(only_inner)] map (\(i: i32): i32 -> reduce (+) 0 (map (+1) (0.. loop(acc) for i < m do map2 (+) acc (a_r[i]) ) a futhark-0.25.27/tests/distribution/loop1.fut000066400000000000000000000016561475065116200210050ustar00rootroot00000000000000-- Like loop0.fut, but the sequential loop also carries a scalar -- variable that is not mapped. -- -- == -- -- input { -- [[[1,7],[9,4],[8,6]], -- [[1,0],[6,4],[1,6]]] -- } -- output { -- [[18, 17], [8, 10]] -- [8, 8] -- } -- -- structure gpu { Map/Loop 0 } def main [n][m][k] (a: [n][m][k]i32): ([n][k]i32,[n]i32) = let acc = replicate k 0 let accnum = 1 in unzip(map (\a_r -> loop((acc,accnum)) for i < m do (map2 (+) acc (a_r[i]), accnum + accnum) ) a) -- Example of what we want - this is dead code. 
def main_distributed [n][m][k] (a: [n][m][k]i32): ([n][k]i32,[n]i32) = let acc_expanded = replicate n (replicate k 0) let accnum_expanded = replicate n 1 in loop((acc_expanded,accnum_expanded)) for i < m do unzip(map3 (\acc accnum a_r -> (map2 (+) acc (a_r[i]), accnum * accnum) ) (acc_expanded) (accnum_expanded) a) futhark-0.25.27/tests/distribution/loop2.fut000066400000000000000000000007451475065116200210040ustar00rootroot00000000000000-- More tricky variant of loop0.fut where expanding the initial merge -- parameter values is not so simple. -- -- == -- -- input { -- [[[1,7],[9,4],[8,6]], -- [[1,0],[6,4],[1,6]]] -- } -- output { -- [[19, 24], -- [9, 10]] -- } -- -- structure gpu { Map/Loop 0 } def main [n][m][k] (a: [n][m][k]i32): [n][k]i32 = map (\(a_r: [m][k]i32): [k]i32 -> let acc = a_r[0] in #[sequential] loop(acc) for i < m do map2 (+) acc (a_r[i]) ) a futhark-0.25.27/tests/distribution/loop3.fut000066400000000000000000000004111475065116200207730ustar00rootroot00000000000000-- Simplified variant of loop2.fut with lower-rank arrays. -- -- == -- -- structure gpu { Map/Loop 0 } def main [n][k] (m: i32, a: [n][k]i32): [n][k]i32 = map (\a_r -> let acc = a_r in loop(acc) for i < m do map2 (+) acc (a_r) ) a futhark-0.25.27/tests/distribution/loop4.fut000066400000000000000000000005231475065116200210000ustar00rootroot00000000000000-- Distribution with maps consuming their input. -- -- == -- -- structure gpu { Map/Loop 0 } def main [n][k] (m: i32) (a: [n][k]i32): [n][k]i32 = map (\a_r -> let a_r_copy = copy(a_r) in loop acc = a_r_copy for i < m do let acc' = copy(map2 (+) acc (a_r)) let acc'[0] = 0 in acc' ) a futhark-0.25.27/tests/distribution/loop5.fut000066400000000000000000000007301475065116200210010ustar00rootroot00000000000000-- More distribution with maps consuming their input. 
-- -- == -- -- structure gpu { Map/Loop 0 } def main [n][m][k] (a: *[n][m][k]i32): [n][m][k]i32 = map (\(a_r: [m][k]i32): [m][k]i32 -> loop(a_r) for i < m do map (\(a_r_r: [k]i32): *[k]i32 -> loop a_r_r = copy a_r_r for i < k-2 do let a_r_r[i+2] = a_r_r[i+2] + a_r_r[i] - a_r_r[i+1] in a_r_r) ( a_r) ) a futhark-0.25.27/tests/distribution/loop6.fut000066400000000000000000000006171475065116200210060ustar00rootroot00000000000000-- Interchange of a loop where some parts are dead after the loop. -- == -- structure gpu { /SegMap 0 /Loop 1 /Loop/SegMap 1 } def main [m] [n] (xss: *[m][n]i64) = #[incremental_flattening(only_inner)] map (\xs -> (loop (xs,out) = (xs, replicate n 0f32) for i < n do (let xs = map (+1) xs let out = map2 (+) (map f32.i64 xs) out in (xs, out))).1 ) xss futhark-0.25.27/tests/distribution/loop7.fut000066400000000000000000000004761475065116200210120ustar00rootroot00000000000000-- Must realise that the 'take (i+1)' is invariant to the 'map' after -- interchange. -- == -- structure gpu { Loop/SegRed 1 } def main [n] (xs: [n]i32) = #[incremental_flattening(only_inner)] map (\x -> loop acc = x for i < n-1 do #[unsafe] acc + i32.sum (take (i+1) xs)) xs futhark-0.25.27/tests/distribution/map-duplicate.fut000066400000000000000000000002611475065116200224670ustar00rootroot00000000000000-- A map with duplicate outputs should work. -- == -- structure gpu { SegMap 0 Replicate 2 } def main (n: i64) (m: i64) = map (\i -> (replicate m i, replicate m i)) (iota n) futhark-0.25.27/tests/distribution/map-replicate.fut000066400000000000000000000005171475065116200224710ustar00rootroot00000000000000-- Test that a map containing a (variant) replicate becomes a fully -- parallel kernel, with no sequential replicate. 
-- -- == -- input { [1,2,3] 2i64 } -- output { [[1,1], [2,2], [3,3]] } -- structure gpu { SegMap 0 Replicate 1 } def main [n] (xs: [n]i32) (m: i64): [n][m]i32 = map (\(x: i32): [m]i32 -> replicate m x) xs futhark-0.25.27/tests/distribution/redomap0.fut000066400000000000000000000002631475065116200214530ustar00rootroot00000000000000-- Distribute a redomap inside of a map. -- == -- structure gpu { SegRed 1 } def main(a: [][]i32): []i32 = map (\(a_r: []i32): i32 -> reduce (+) 0 (map (+1) (a_r))) a futhark-0.25.27/tests/distribution/scatter0.fut000066400000000000000000000003601475065116200214670ustar00rootroot00000000000000-- A mapped scatter should be parallelised. -- input { [[1,2,3],[4,5,6]] [2,0] [42,1337] } -- output { [[1337, 2, 42], [1337, 5, 42]] } def main (xss: *[][]i32) (is: []i64) (vs: []i32) = map (\(xs: []i32) -> scatter (copy xs) is vs) xss futhark-0.25.27/tests/distribution/scatter1.fut000066400000000000000000000003751475065116200214760ustar00rootroot00000000000000-- Scattering where elements are themselves arrays - see #2035. -- == -- input { [[1,2,3],[4,5,6]] [1i64,0i64,-1i64] [[9,8,7],[6,5,4],[9,8,7]] } -- output { [[6,5,4],[9,8,7]] } entry main (xss: *[][]i32) (is: []i64) (ys: [][]i32) = scatter xss is ys futhark-0.25.27/tests/distribution/scatter2.fut000066400000000000000000000005171475065116200214750ustar00rootroot00000000000000-- Scattering where elements are themselves arrays - see #2035. This -- one also has a map part. -- == -- input { [[1,2,3],[4,5,6]] [1i64,0i64,-1i64] [[9,8,7],[6,5,4],[4,5,6]] } -- output { [[8,7,6],[11,10,9]] } -- structure gpu { SegMap 2 } entry main (xss: *[][]i32) (is: []i64) (ys: [][]i32) = scatter xss is (map (map (+2)) ys) futhark-0.25.27/tests/distribution/scatter3.fut000066400000000000000000000005001475065116200214660ustar00rootroot00000000000000-- From #2089. 
-- == -- input { [[2i64,4i64],[1i64,3i64]] [[[true,false,true],[false,false,true]],[[false,true,true],[false,true,false]]] } -- output { [[[true, true, true], [true, true, true]], [[true, true, true], [false, true, true]]] } entry main = map2 (\is vs -> scatter (replicate 2 (replicate 3 true)) is vs) futhark-0.25.27/tests/distribution/segconcat0.fut000066400000000000000000000005361475065116200217750ustar00rootroot00000000000000-- Nested concatenation just becomes a concatenation along an inner dimension. -- == -- input { [[1,2,3],[4,5,6]] [[3,2,1],[6,5,4]] } -- output { [[1,2,3,3,2,1], -- [4,5,6,6,5,4]] } -- structure gpu { Kernel 0 } def main [n][m] (xss: [][n]i32) (yss: [][m]i32) = let k = n + m in map (\(xs, ys) -> concat xs ys :> [k]i32) (zip xss yss) futhark-0.25.27/tests/distribution/segconcat1.fut000066400000000000000000000005651475065116200220000ustar00rootroot00000000000000-- Nested concatenation with more arrays. -- == -- input { [[1,2],[3,4],[5,6]] [[1,2],[3,4],[5,6]] [[1,2],[3,4],[5,6]] } -- output { [[1,2,1,2,1,2], [3,4,3,4,3,4], [5,6,5,6,5,6]] } -- structure gpu { Kernel 0 } def main [a][b][c] (xss: [][a]i32) (yss: [][b]i32) (zss: [][c]i32) = let n = a + b + c in map3 (\xs ys zs -> concat xs (concat ys zs) :> [n]i32) xss yss zss futhark-0.25.27/tests/distribution/segconcat2.fut000066400000000000000000000004031475065116200217700ustar00rootroot00000000000000-- Nested concatenation with an invariant part. -- == -- input { [[1,2],[3,4]] } -- output { [[1, 2, 2, 2], [3, 4, 2, 2]]} -- structure gpu { Kernel 0 } def main [n] (xss: [n][]i32) = let m = n + 2 in map (\xs -> concat xs (replicate 2 2) :> [m]i32) xss futhark-0.25.27/tests/distribution/segreduce0.fut000066400000000000000000000012501475065116200217670ustar00rootroot00000000000000-- Contrived segmented reduction where the results are used in a -- different order than they are produced. 
-- -- == -- input { [[1,2],[3,4]] [[5.0f32,6.0f32],[7.0f32,8.0f32]] } -- output { [11.0f32, 15.0f32] [3i32, 7i32] } def main [n][m] (ass: [n][m]i32) (bss: [n][m]f32): ([]f32, []i32) = unzip(map2 (\(as: []i32) (bs: []f32): (f32, i32) -> let (asum, bsum) = reduce (\(x: (i32, f32)) (y: (i32, f32)): (i32, f32) -> let (x_a, x_b) = x let (y_a, y_b) = y in (x_a + y_a, x_b + y_b)) (0, 0f32) (zip as bs) in (bsum, asum)) ass bss) futhark-0.25.27/tests/distribution/segreduce1.fut000066400000000000000000000002641475065116200217740ustar00rootroot00000000000000-- Multi-versioning of nested segmented reductions. -- == -- structure gpu { SegRed 7 } def main = map (map (map (map i32.product >-> i32.sum) >-> i32.minimum) >-> i32.maximum) futhark-0.25.27/tests/distribution/segscan0.fut000066400000000000000000000011641475065116200214500ustar00rootroot00000000000000-- Contrived segmented scan where the results are used in a -- different order than they are produced. -- -- == -- input { [[1,2],[3,4]] [[5.0f32,6.0f32],[7.0f32,8.0f32]] } -- output { -- [[5.0f32, 11.0f32], -- [7.0f32, 15.0f32]] -- [[1i32, 3i32], -- [3i32, 7i32]] -- } def main [n][m] (ass: [n][m]i32) (bss: [n][m]f32): ([][]f32, [][]i32) = unzip(map2 (\(as: []i32) (bs: []f32): ([m]f32, [m]i32) -> let (asum, bsum) = unzip(scan (\(x_a,x_b) (y_a,y_b) -> (x_a + y_a, x_b + y_b)) (0, 0f32) (zip as bs)) in (bsum, asum)) ass bss) futhark-0.25.27/tests/doublebuffer.fut000066400000000000000000000003151475065116200176670ustar00rootroot00000000000000-- A simple program that needs double buffering to be compiled to -- OpenCL code. def main (n: i32) (xss: [][]i32) = map (\xs -> #[unsafe] #[sequential] iterate_while (\xs -> xs[0] < 10) (map (+1)) xs) futhark-0.25.27/tests/entrycopy0.fut000066400000000000000000000002621475065116200173400ustar00rootroot00000000000000-- Entry point result should not be copied. 
-- == -- structure gpu-mem { Replicate 1 } -- structure seq-mem { Replicate 1 } def main b n = if b then iota n else replicate n 0 futhark-0.25.27/tests/entrycopy1.fut000066400000000000000000000006531475065116200173450ustar00rootroot00000000000000-- Entry point result should not be copied. -- -- ...for the CPU backend, it is currently copied because of a -- conservative assumption when inserting allocations, that we do not -- optimise away properly later. That's not terribly important right -- now, but should be fixed some day. -- == -- structure gpu-mem { SegMap 1 Manifest 0 } -- structure seq-mem { Manifest 1 } def main A = flatten A |> map (+2i32) |> unflatten futhark-0.25.27/tests/entryexpr.fut000066400000000000000000000001501475065116200172600ustar00rootroot00000000000000-- == -- input { [1,2] [3] } -- output { [1,2,3] } def main [n] (xs: [n*2]i32) (ys: [n]i32) = xs ++ ys futhark-0.25.27/tests/entryval.fut000066400000000000000000000001531475065116200170670ustar00rootroot00000000000000-- It is OK for an entry point to not be a function. -- == -- input {} output {3.14} def main: f64 = 3.14 futhark-0.25.27/tests/enums/000077500000000000000000000000001475065116200156335ustar00rootroot00000000000000futhark-0.25.27/tests/enums/enum0.fut000066400000000000000000000003561475065116200174030ustar00rootroot00000000000000-- Basic enum types and matches. -- == -- input { } -- output { 5 } type animal = #dog | #cat | #mouse | #bird def main : i32 = match #mouse : animal case #dog -> 6 case #bird -> 9 case #mouse -> 5 case #cat -> 0 futhark-0.25.27/tests/enums/enum1.fut000066400000000000000000000003501475065116200173760ustar00rootroot00000000000000-- Matches on nested tuples 1. -- == -- input { } -- output { 3 } def main : i32 = match ((3,(1,10)), 2) case (_, 3) -> 1 case ((4,_), 2) -> 2 case ((3,(_, 10)), _) -> 3 case (_, _) -> 4 futhark-0.25.27/tests/enums/enum10.fut000066400000000000000000000001771475065116200174650ustar00rootroot00000000000000-- Match on a function. 
-- == -- input { } -- output { 1 } def main : i32 = match (\x -> x + 1) case y -> 1 futhark-0.25.27/tests/enums/enum11.fut000066400000000000000000000004001475065116200174530ustar00rootroot00000000000000-- Enum equality. -- == -- input { } -- output { 2 } type foobar = #foo | #bar def main : i32 = if (#foo : foobar) == #bar then 1 else if (#bar : foobar) == #bar then 2 else 3 futhark-0.25.27/tests/enums/enum12.fut000066400000000000000000000010501475065116200174560ustar00rootroot00000000000000-- Composition of functions on enums. -- == -- input { } -- output { 3 } type animal = #dog | #cat | #mouse | #bird type planet = #mercury | #venus | #earth | #mars def compose 'a 'b 'c (f : a -> b) (g : b -> c) : a -> c = \x -> g (f x) def f (x : animal) : planet = match x case #dog -> #mercury case #cat -> #venus case #mouse -> #earth case #bird -> #mars def g (x : planet) : i32 = match x case #mercury -> 1 case #venus -> 2 case #earth -> 3 case #mars -> 4 def main : i32 = compose f g #mouse futhark-0.25.27/tests/enums/enum13.fut000066400000000000000000000003361475065116200174650ustar00rootroot00000000000000-- Comparison of enums in records. -- == -- input { } -- output { [true, false] } type rec = {f1 : #foo | #bar, f2 : #vim | #emacs} def main = let (r : rec) = {f1 = #foo, f2 = #emacs} in [r.f1 == #foo, r.f1 == #bar] futhark-0.25.27/tests/enums/enum14.fut000066400000000000000000000002111475065116200174560ustar00rootroot00000000000000-- Constructor order shouldn't matter. -- == type foobar = #foo | #bar type barfoo = #bar | #foo def main (x : foobar) = #bar : barfoo futhark-0.25.27/tests/enums/enum15.fut000066400000000000000000000003021475065116200174600ustar00rootroot00000000000000-- Comparison of enums of different types in records. 
-- == -- error: type rec = {f1 : #foo | #bar, f2 : #vim | #emacs} def main = let (r : rec) = {f1 = #foo, f2 = #emacs} in r.f1 == r.f2 futhark-0.25.27/tests/enums/enum16.fut000066400000000000000000000006451475065116200174730ustar00rootroot00000000000000-- Enum in-place updates. -- == -- input { } -- output { [2, 2, 1, 1] } def swap_inplace (n : i64) : *[]#foo | #bar = let x = replicate n #foo ++ replicate n #bar in loop x for i < 2*n do x with [i] = match x[i] case #foo -> #bar case #bar -> #foo def f (x : #foo | #bar) : i32 = match x case #foo -> 1 case #bar -> 2 def main : []i32 = map f (swap_inplace 2) futhark-0.25.27/tests/enums/enum17.fut000066400000000000000000000004431475065116200174700ustar00rootroot00000000000000-- Enum swap constructors in array. -- == -- input { } -- output { [2, 1] } def f (x : #foo | #bar) : #foo | #bar = match x case #foo -> #bar case #bar -> #foo def g (x : #foo | #bar) : i32 = match x case #foo -> 1 case #bar -> 2 def main = map g (map f [#foo, #bar]) futhark-0.25.27/tests/enums/enum18.fut000066400000000000000000000001601475065116200174650ustar00rootroot00000000000000-- Enum identity function with uniqueness types. -- == def id_unique (x : *[]#foo | #bar) : *[]#foo | #bar = x futhark-0.25.27/tests/enums/enum19.fut000066400000000000000000000001111475065116200174620ustar00rootroot00000000000000-- Matches must have at least one case. -- == -- error: def x = match 5 futhark-0.25.27/tests/enums/enum2.fut000066400000000000000000000002621475065116200174010ustar00rootroot00000000000000-- Matches on nested tuples 2. -- == -- input { } -- output { 6 } def main : i32 = match ((1,2), 3) case ((5,2), 3) -> 5 case ((1,2), 3) -> 6 case _ -> 7 futhark-0.25.27/tests/enums/enum20.fut000066400000000000000000000003711475065116200174620ustar00rootroot00000000000000-- Non-exhaustive pattern match. 
-- == -- error: type planet = #mercury | #venus | #earth | #mars def main : i32 = match #mars : planet case #mercury -> 1 case #venus -> 2 case #earth -> 3 futhark-0.25.27/tests/enums/enum21.fut000066400000000000000000000005101475065116200174560ustar00rootroot00000000000000-- Non-exhaustive pattern match 2. -- == -- error: type planet = #mercury | #venus | #earth | #mars def main : i32 = match (1, #mars : planet, 5) case (1, #mercury, 5) -> 1 case (1, #venus, 5) -> 2 case (1, #earth, 5) -> 3 case (1, #earth, 6) -> 4 futhark-0.25.27/tests/enums/enum22.fut000066400000000000000000000003661475065116200174700ustar00rootroot00000000000000-- Non-exhaustive pattern match 3. -- == -- error: type planet = #mercury | #venus | #earth | #mars def f (x : planet) : i32 = 1 + match x case #mercury -> 1 case #venus -> 2 case #earth -> 3 def main : i32 = f #mars futhark-0.25.27/tests/enums/enum23.fut000066400000000000000000000006051475065116200174650ustar00rootroot00000000000000-- Matches on records. -- == -- input { } -- output { 12 } type foobar = {foo : i32, bar: i32} def main : i32 = match ({foo = 1, bar = 2} : foobar) case {foo = 3, bar = 4} -> 9 case {foo = 5, bar = 6} -> 10 case {foo = 7, bar = 8} -> 11 case {foo = 1, bar = 2} -> 12 case _ -> 12 futhark-0.25.27/tests/enums/enum24.fut000066400000000000000000000005251475065116200174670ustar00rootroot00000000000000-- Matches on records 2. -- == -- input { } -- output { 2 } type foobar = {foo : i32, bar: i32} def main : i32 = match ({foo = 1, bar = 2} : foobar) case {foo = 3, bar = 4} -> 9 case {foo = 5, bar = 6} -> 10 case {foo = 7, bar = 8} -> 11 case {foo = _, bar = x} -> x futhark-0.25.27/tests/enums/enum25.fut000066400000000000000000000007051475065116200174700ustar00rootroot00000000000000-- Matches on records 3. 
-- == -- input { } -- output { 3 } type foopdoop = #foop | #doop type editors = #emacs | #vim type foobar = {foo : foopdoop, bar: editors} def main : i32 = match ({foo = #foop, bar = #vim} : foobar) case {foo = #doop, bar = #vim} -> 1 case {foo = #foop, bar = #emacs} -> 2 case {foo = _, bar = #vim} -> 3 case {foo = #doop, bar=#emacs} -> 4 futhark-0.25.27/tests/enums/enum26.fut000066400000000000000000000006011475065116200174640ustar00rootroot00000000000000-- Enums in simple modules. -- == -- input { } -- output { 3 } module enum_module = { type foobar = #foo | #bar def f (x : foobar) : i32 = match x case #foo -> 1 case #bar -> 2 def foo = #foo : foobar } def main : i32 = match (enum_module.f (#foo : enum_module.foobar)) case 1 -> 3 case 2 -> 4 case _ -> 5 futhark-0.25.27/tests/enums/enum27.fut000066400000000000000000000007271475065116200174760ustar00rootroot00000000000000-- Enums in module types. -- == -- input { } -- output { 3 } module type foobar_mod = { type foobar val f : foobar -> i32 val foo : foobar } module enum_module : foobar_mod = { type foobar = #foo | #bar def f (x : foobar) : i32 = match x case #foo -> 1 case #bar -> 2 def foo = #foo : foobar } def main : i32 = match (enum_module.f enum_module.foo) case 1 -> 3 case 2 -> 4 case _ -> 5 futhark-0.25.27/tests/enums/enum28.fut000066400000000000000000000011151475065116200174670ustar00rootroot00000000000000-- Enums in parametric modules. 
-- == -- input { } -- output { 4 } module type foobar_mod = { type foobar val f : foobar -> i32 val foo : foobar val bar : foobar } module sum (M: foobar_mod) = { def sum (a: []M.foobar): i32 = reduce (+) 0 (map M.f a) } module enum_module : foobar_mod = { type foobar = #foo | #bar def f (x : foobar) : i32 = match x case #foo -> 1 case #bar -> 2 def foo = #foo : foobar def bar = #bar : foobar } module sum_enum = sum enum_module def main : i32 = sum_enum.sum [enum_module.foo, enum_module.bar, enum_module.foo] futhark-0.25.27/tests/enums/enum29.fut000066400000000000000000000005151475065116200174730ustar00rootroot00000000000000-- Local ambiguous enum. -- == -- input { } -- output { [2,2] } def f (x : #foo | #bar) : [](#foo | #bar) = let id 't (x : t) : t = x in match id x case #foo -> [#foo, #foo] case #bar -> [#bar, #bar] def g (x : #foo | #bar) : i32 = match x case #foo -> 1 case #bar -> 2 def main : []i32 = map g (f #bar) futhark-0.25.27/tests/enums/enum3.fut000066400000000000000000000005771475065116200174130ustar00rootroot00000000000000-- Matches on nested tuples 3. -- == -- input { } -- output { 3 } type animal = #dog | #cat | #mouse | #bird def main : i32 = match ((1, #dog : animal), 12, (#cat : animal, #mouse : animal)) case ((6, #dog), 12, (#cat , #mouse)) -> 1 case (_, 13, _) -> 2 case ((1, #dog), 12, (#cat, #mouse)) -> 3 case _ -> 4 futhark-0.25.27/tests/enums/enum30.fut000066400000000000000000000036741475065116200174740ustar00rootroot00000000000000-- 255 constructors allowed. 
-- == type big_enum = #c0 | #c1 | #c2 | #c3 | #c4 | #c5 | #c6 | #c7 | #c8 | #c9 | #c10 | #c11 | #c12 | #c13 | #c14 | #c15 | #c16 | #c17 | #c18 | #c19 | #c20 | #c21 | #c22 | #c23 | #c24 | #c25 | #c26 | #c27 | #c28 | #c29 | #c30 | #c31 | #c32 | #c33 | #c34 | #c35 | #c36 | #c37 | #c38 | #c39 | #c40 | #c41 | #c42 | #c43 | #c44 | #c45 | #c46 | #c47 | #c48 | #c49 | #c50 | #c51 | #c52 | #c53 | #c54 | #c55 | #c56 | #c57 | #c58 | #c59 | #c60 | #c61 | #c62 | #c63 | #c64 | #c65 | #c66 | #c67 | #c68 | #c69 | #c70 | #c71 | #c72 | #c73 | #c74 | #c75 | #c76 | #c77 | #c78 | #c79 | #c80 | #c81 | #c82 | #c83 | #c84 | #c85 | #c86 | #c87 | #c88 | #c89 | #c90 | #c91 | #c92 | #c93 | #c94 | #c95 | #c96 | #c97 | #c98 | #c99 | #c100 | #c101 | #c102 | #c103 | #c104 | #c105 | #c106 | #c107 | #c108 | #c109 | #c110 | #c111 | #c112 | #c113 | #c114 | #c115 | #c116 | #c117 | #c118 | #c119 | #c120 | #c121 | #c122 | #c123 | #c124 | #c125 | #c126 | #c127 | #c128 | #c129 | #c130 | #c131 | #c132 | #c133 | #c134 | #c135 | #c136 | #c137 | #c138 | #c139 | #c140 | #c141 | #c142 | #c143 | #c144 | #c145 | #c146 | #c147 | #c148 | #c149 | #c150 | #c151 | #c152 | #c153 | #c154 | #c155 | #c156 | #c157 | #c158 | #c159 | #c160 | #c161 | #c162 | #c163 | #c164 | #c165 | #c166 | #c167 | #c168 | #c169 | #c170 | #c171 | #c172 | #c173 | #c174 | #c175 | #c176 | #c177 | #c178 | #c179 | #c180 | #c181 | #c182 | #c183 | #c184 | #c185 | #c186 | #c187 | #c188 | #c189 | #c190 | #c191 | #c192 | #c193 | #c194 | #c195 | #c196 | #c197 | #c198 | #c199 | #c200 | #c201 | #c202 | #c203 | #c204 | #c205 | #c206 | #c207 | #c208 | #c209 | #c210 | #c211 | #c212 | #c213 | #c214 | #c215 | #c216 | #c217 | #c218 | #c219 | #c220 | #c221 | #c222 | #c223 | #c224 | #c225 | #c226 | #c227 | #c228 | #c229 | #c230 | #c231 | #c232 | #c233 | #c234 | #c235 | #c236 | #c237 | #c238 | #c239 | #c240 | #c241 | #c242 | #c243 | #c244 | #c245 | #c246 | #c247 | #c248 | #c249 | #c250 | #c251 | #c252 | #c253 | #c254 
futhark-0.25.27/tests/enums/enum31.fut000066400000000000000000000003301475065116200174570ustar00rootroot00000000000000-- Missing pattern warning 1. -- == -- error: Unmatched type planet = #mercury | #venus | #earth | #mars def g : i32 = match (#venus : planet) case #mercury -> 1 case #venus -> 2 case #earth -> 3 futhark-0.25.27/tests/enums/enum32.fut000066400000000000000000000004111475065116200174600ustar00rootroot00000000000000-- Missing pattern warning 2. -- == -- error: Unmatched type planet = #mercury | #venus | #earth | #mars type foobar = #foo | #bar type rec = {f1 : foobar, f2: planet} def g : i32 = match {f1 = #bar, f2 = #earth} : rec case {f1 = #bar, f2 = #venus} -> 1 futhark-0.25.27/tests/enums/enum33.fut000066400000000000000000000002151475065116200174630ustar00rootroot00000000000000-- Missing pattern warning 3. -- == -- error: Unmatched type foobar = #foo | #bar def f : i32 = match #foo case (#foo : foobar) -> 1 futhark-0.25.27/tests/enums/enum34.fut000066400000000000000000000003441475065116200174670ustar00rootroot00000000000000-- Missing pattern warning 4; intended behaviour is to print the warning without -- superfluous parentheses. -- == -- error: Unmatched type foobar = #foo | #bar def f : i32 = match #foo : foobar case (((((#foo))))) -> 1 futhark-0.25.27/tests/enums/enum35.fut000066400000000000000000000002041475065116200174630ustar00rootroot00000000000000-- Missing pattern warning 5 (integers). -- == -- error: Unmatched def f : i32 = match (1 : i32) case 1 -> 1 case 2 -> 2 futhark-0.25.27/tests/enums/enum36.fut000066400000000000000000000001761475065116200174740ustar00rootroot00000000000000-- Missing pattern warning 6 (floats). -- == -- error: def f : f32 = match (1.5 : f32) case 1.1 -> 1 case 2 -> 2 futhark-0.25.27/tests/enums/enum37.fut000066400000000000000000000002241475065116200174670ustar00rootroot00000000000000-- Missing pattern warning 7 (floats). 
-- == -- error: def f : i32 = match {foo = (3.6 : f32), bar = (1 : i32)} case {foo = 3, bar = y} -> y futhark-0.25.27/tests/enums/enum38.fut000066400000000000000000000002051475065116200174670ustar00rootroot00000000000000-- Missing pattern warning 8 (bool). -- == -- error: Unmatched def f : bool = match (true, false) case (false, true) -> false futhark-0.25.27/tests/enums/enum39.fut000066400000000000000000000003621475065116200174740ustar00rootroot00000000000000-- Missing pattern warning 9. -- == -- error: Unmatched type foobar = #foo | #bar type rec = {f1 : foobar, f2 : f32} def f : bool = match (true, 10 : i32, {f1 = #foo, f2 = 1.2} : rec) case (true, 10, {f1 = #foo, f2 = 1.2}) -> true futhark-0.25.27/tests/enums/enum4.fut000066400000000000000000000004221475065116200174010ustar00rootroot00000000000000-- Test matches on nested tuples 4. -- == -- input { } -- output { 3 } def main : i32 = match (4, (5,6)) case (_, (_,10)) -> 1 case (_, (_,7)) -> 2 case (_, (_,6)) -> 3 case _ -> 4 futhark-0.25.27/tests/enums/enum40.fut000066400000000000000000000004031475065116200174600ustar00rootroot00000000000000-- Missing pattern warning 10. -- (Checks that warnings are still triggered with ambiguous types) -- == -- error: type foobar = #foo | #bar def f : bool = match (true, 10, {f1 = #foo : foobar, f2 = 1.2}) case (true, 10, {f1 = #foo, f2 = 1.2}) -> true futhark-0.25.27/tests/enums/enum41.fut000066400000000000000000000037321475065116200174710ustar00rootroot00000000000000-- 257 constructors not allowed. 
-- == -- error: type big_enum = #c0 | #c1 | #c2 | #c3 | #c4 | #c5 | #c6 | #c7 | #c8 | #c9 | #c10 | #c11 | #c12 | #c13 | #c14 | #c15 | #c16 | #c17 | #c18 | #c19 | #c20 | #c21 | #c22 | #c23 | #c24 | #c25 | #c26 | #c27 | #c28 | #c29 | #c30 | #c31 | #c32 | #c33 | #c34 | #c35 | #c36 | #c37 | #c38 | #c39 | #c40 | #c41 | #c42 | #c43 | #c44 | #c45 | #c46 | #c47 | #c48 | #c49 | #c50 | #c51 | #c52 | #c53 | #c54 | #c55 | #c56 | #c57 | #c58 | #c59 | #c60 | #c61 | #c62 | #c63 | #c64 | #c65 | #c66 | #c67 | #c68 | #c69 | #c70 | #c71 | #c72 | #c73 | #c74 | #c75 | #c76 | #c77 | #c78 | #c79 | #c80 | #c81 | #c82 | #c83 | #c84 | #c85 | #c86 | #c87 | #c88 | #c89 | #c90 | #c91 | #c92 | #c93 | #c94 | #c95 | #c96 | #c97 | #c98 | #c99 | #c100 | #c101 | #c102 | #c103 | #c104 | #c105 | #c106 | #c107 | #c108 | #c109 | #c110 | #c111 | #c112 | #c113 | #c114 | #c115 | #c116 | #c117 | #c118 | #c119 | #c120 | #c121 | #c122 | #c123 | #c124 | #c125 | #c126 | #c127 | #c128 | #c129 | #c130 | #c131 | #c132 | #c133 | #c134 | #c135 | #c136 | #c137 | #c138 | #c139 | #c140 | #c141 | #c142 | #c143 | #c144 | #c145 | #c146 | #c147 | #c148 | #c149 | #c150 | #c151 | #c152 | #c153 | #c154 | #c155 | #c156 | #c157 | #c158 | #c159 | #c160 | #c161 | #c162 | #c163 | #c164 | #c165 | #c166 | #c167 | #c168 | #c169 | #c170 | #c171 | #c172 | #c173 | #c174 | #c175 | #c176 | #c177 | #c178 | #c179 | #c180 | #c181 | #c182 | #c183 | #c184 | #c185 | #c186 | #c187 | #c188 | #c189 | #c190 | #c191 | #c192 | #c193 | #c194 | #c195 | #c196 | #c197 | #c198 | #c199 | #c200 | #c201 | #c202 | #c203 | #c204 | #c205 | #c206 | #c207 | #c208 | #c209 | #c210 | #c211 | #c212 | #c213 | #c214 | #c215 | #c216 | #c217 | #c218 | #c219 | #c220 | #c221 | #c222 | #c223 | #c224 | #c225 | #c226 | #c227 | #c228 | #c229 | #c230 | #c231 | #c232 | #c233 | #c234 | #c235 | #c236 | #c237 | #c238 | #c239 | #c240 | #c241 | #c242 | #c243 | #c244 | #c245 | #c246 | #c247 | #c248 | #c249 | #c250 | #c251 | #c252 | #c253 | #c254 | #c255 | #c256 
futhark-0.25.27/tests/enums/enum42.fut000066400000000000000000000001771475065116200174720ustar00rootroot00000000000000-- Ambiguous enums should yield an error. -- == -- error: def f : bool = match #foo case #foo -> true case #bar -> true futhark-0.25.27/tests/enums/enum43.fut000066400000000000000000000003721475065116200174700ustar00rootroot00000000000000-- Matching type error 1. -- == -- error: type planet = #mercury | #venus | #earth | #mars def x = match (#mercury : planet) case #mercury -> 3 case #venus -> 1 case #earth -> true case #mars -> false futhark-0.25.27/tests/enums/enum44.fut000066400000000000000000000003501475065116200174650ustar00rootroot00000000000000-- Matching type error 2. -- == -- error: type planet = #mercury | #venus | #earth | #mars def x = match 2 case #mercury -> 3 case #venus -> 1 case #earth -> true case #mars -> false futhark-0.25.27/tests/enums/enum45.fut000066400000000000000000000005031475065116200174660ustar00rootroot00000000000000-- It is allowed to consume the same array in different branches of a -- pattern match (similar to 'if'-expressions). -- == -- input { [1,2,3] true } output { [0,2,3] } -- input { [1,2,3] false } output { [1,1,3] } def main (xs: *[]i32) (b: bool) = match b case true -> xs with [0] = 0 case false -> xs with [1] = 1 futhark-0.25.27/tests/enums/enum46.fut000066400000000000000000000002521475065116200174700ustar00rootroot00000000000000-- The scrutinee of a 'match' expression is fully evaluated before the branches. -- == def main (xs: *[]i32) = match xs[0] case 0 -> xs with [0] = 0 case _ -> [0] futhark-0.25.27/tests/enums/enum47.fut000066400000000000000000000002221475065116200174660ustar00rootroot00000000000000-- Do not evaluate branches unnecessarily... -- == -- input { 0 } output { 0 } def main (x: i32) = match x case 0 -> 0 case _ -> 2/x futhark-0.25.27/tests/enums/enum5.fut000066400000000000000000000002611475065116200174030ustar00rootroot00000000000000-- Enums as function arguments. 
-- == -- input { } -- output { 2 } def f (x : #foo | #bar) : i32 = match x case #foo -> 1 case #bar -> 2 def main : i32 = f #bar futhark-0.25.27/tests/enums/enum6.fut000066400000000000000000000003171475065116200174060ustar00rootroot00000000000000-- Enums as a type parameter. -- == -- input { } -- output { 2 } def id 'a (x : a) : a = x def f (x : #foo | #bar) : i32 = match x case #foo -> 1 case #bar -> 2 def main : i32 = f (id #bar) futhark-0.25.27/tests/enums/enum7.fut000066400000000000000000000002351475065116200174060ustar00rootroot00000000000000-- Invalid return type with overlapping constructor. -- == -- error: def g (x : #foo | #bar) : #foo = match x case #foo -> #foo case #bar -> #bar futhark-0.25.27/tests/enums/enum8.fut000066400000000000000000000001361475065116200174070ustar00rootroot00000000000000-- Invalid constructor format. -- == -- error: Unexpected token: 'bar' type foo = #foo | bar futhark-0.25.27/tests/enums/enum9.fut000066400000000000000000000004361475065116200174130ustar00rootroot00000000000000-- Arrays of enums. -- == -- input { } -- output { [1, 2, 3, 4] } type animal = #dog | #cat | #mouse | #bird def f (x : animal) : i32 = match x case #dog -> 1 case #cat -> 2 case #mouse -> 3 case #bird -> 4 def main : []i32 = map f [#dog, #cat, #mouse, #bird] futhark-0.25.27/tests/enums/issue663.fut000066400000000000000000000001631475065116200177420ustar00rootroot00000000000000-- Issue 663. x shouldn't need a type ascription. -- == def main: (bool, #l | #r) = let x = #l in (x == x, x) futhark-0.25.27/tests/eqarrays0.fut000066400000000000000000000007231475065116200171350ustar00rootroot00000000000000-- Equality checking of arrays. -- -- More subtle than it looks, as you also have to compare the dimensions. 
-- == -- input { empty([0]i32) empty([0]i32) } -- output { true } -- input { empty([0]i32) [1] } -- output { false } -- input { [1] empty([0]i32) } -- output { false } -- input { [1,2] [1,2] } -- output { true } -- input { [1,2] [3,4] } -- output { false } -- warning: deprecated def main [n][m] (xs: [n]i32) (ys: [m]i32) = n == m && xs == (ys :> [n]i32) futhark-0.25.27/tests/eqarrays1.fut000066400000000000000000000003711475065116200171350ustar00rootroot00000000000000-- Equality checking of arrays of two dimensions. -- == -- input { [[1,2],[3,4]] [[1,2],[3,4]] } -- output { true } -- input { [[1,2],[3,4]] [[1,2],[3,5]] } -- output { false } -- warning: deprecated def main (xs: [][]i32) (ys: [][]i32) = xs == ys futhark-0.25.27/tests/errors/000077500000000000000000000000001475065116200160205ustar00rootroot00000000000000futhark-0.25.27/tests/errors/duplicate-let-size.fut000066400000000000000000000001201475065116200222350ustar00rootroot00000000000000-- == -- error: also bound def f (x: []i32) = let [n] (n: [n]i32) = x in n futhark-0.25.27/tests/errors/duplicate-params.fut000066400000000000000000000001701475065116200217710ustar00rootroot00000000000000-- Using the same parameter name twice is forbidden. -- -- == -- error: also bound def main (x: i32) (x: i32): i32 = x futhark-0.25.27/tests/errors/duplicate-size-params.fut000066400000000000000000000002321475065116200227400ustar00rootroot00000000000000-- Using the same parameter name twice is forbidden, even when one use -- is as a size parameter. -- -- == -- error: also bound def f [m] (m:[m]i64) = m futhark-0.25.27/tests/errors/duplicate-type-params.fut000066400000000000000000000001601475065116200227470ustar00rootroot00000000000000-- Using the same type parameter name twice is forbidden. 
-- -- == -- error: also bound def f 'a 'a (x: a) = x futhark-0.25.27/tests/errors/duplicate-vars.fut000066400000000000000000000002411475065116200214600ustar00rootroot00000000000000-- Using the same name twice in a single pattern is forbidden. -- -- == -- error: also bound def main (x: i32): (i32,i32) = let (y,y) = (x-1, x+1) in (y,y) futhark-0.25.27/tests/errors/missing-in.fut000066400000000000000000000002201475065116200206070ustar00rootroot00000000000000-- == -- error: Unexpected end of file def main (x:i32): f32 = let f_ = map f (1.. x % 3 == 0 || x % 5 == 0) ( iota(bound))) futhark-0.25.27/tests/euler/euler2.fut000066400000000000000000000013241475065116200175360ustar00rootroot00000000000000-- By considering the terms in the Fibonacci sequence whose values do -- not exceed four million, find the sum of the even-valued terms. -- -- == -- input { 4000000 } -- output { 4613732 } -- One approach: sequentially construct an array containing the -- Fibonacci numbers, then filter and sum. If we knew the number of -- Fibonacci numbers we needed to generate (or even just an upper -- bound), we could do it in parallel. -- -- Our approach: simple sequential counting loop. 
def main(bound: i32): i32 = let (sum, _, _) = loop (sum, fib0, fib1) = (0, 1, 1) while fib1 < bound do let newsum = if fib1 % 2 == 0 then sum + fib1 else sum in (newsum, fib1, fib0 + fib1) in sum futhark-0.25.27/tests/existential-ifs/000077500000000000000000000000001475065116200176145ustar00rootroot00000000000000futhark-0.25.27/tests/existential-ifs/iota.fut000066400000000000000000000003601475065116200212670ustar00rootroot00000000000000-- == -- input { true 20i64 } -- output { [11i64, 12i64, 13i64, 14i64, 15i64, 16i64, 17i64, 18i64, 19i64] } -- -- input { false 20i64 } -- output { empty([0]i64) } def main (b: bool) (n: i64) = if b then filter (>10) (iota n) else [] futhark-0.25.27/tests/existential-ifs/ixfun-antiunif-1.fut000066400000000000000000000010501475065116200234320ustar00rootroot00000000000000-- A simple test for index-function anti-unification across an if-then-else -- == -- input { [-1.0f32, 3.0f32, 5.0f32, 7.0f32, 9.0f32, 11.0f32, 13.0f32, 15.0f32, 17.0f32, 19.0f32, 21.0f32, 23.0f32, 25.0f32]} -- output { [21.0f32, 23.0f32, 25.0f32] } -- -- input { [ 1.0f32, 3.0f32, 5.0f32, 7.0f32, 9.0f32, 11.0f32, 13.0f32, 15.0f32, 17.0f32, 19.0f32, 21.0f32, 23.0f32, 25.0f32]} -- output { [ 2.0f32, 6.0f32, 10.0f32] } def main [n] (arr: [n]f32) = let x = if(arr[0] < 0.0) then arr[10:n] else map (*2.0f32) arr[0:n-10] in x futhark-0.25.27/tests/existential-ifs/ixfun-antiunif-2.fut000066400000000000000000000012241475065116200234360ustar00rootroot00000000000000-- Another simple test for index-function anti-unification across an if-then-else -- This one returns the same memory block, only the offset is existentialized. 
-- == -- input { [-1.0f32, 3.0f32, 5.0f32, 7.0f32, 9.0f32, 11.0f32, 13.0f32, 15.0f32, 17.0f32, 19.0f32, 21.0f32, 23.0f32, 25.0f32]} -- output { [17.0f32, 19.0f32, 21.0f32, 23.0f32, 25.0f32] } -- -- input { [ 1.0f32, 3.0f32, 5.0f32, 7.0f32, 9.0f32, 11.0f32, 13.0f32, 15.0f32, 17.0f32, 19.0f32, 21.0f32, 23.0f32, 25.0f32]} -- output { [11.0f32, 13.0f32, 15.0f32, 17.0f32, 19.0f32, 21.0f32, 23.0f32, 25.0f32] } def main [n] (arr: [n]f32) = if (arr[0] < 0.0) then arr[2+n/2:n] else arr[2+n/4:n] futhark-0.25.27/tests/existential-ifs/ixfun-antiunif-5.fut000066400000000000000000000005611475065116200234440ustar00rootroot00000000000000-- Another simple test for index-function anti-unification across an if-then-else -- This one returns the same memory block, only the offset is existentialized. -- == -- random input { [30000]i32 } -- auto output -- structure gpu-mem { Manifest 0 } def main (a: []i32) = let xs = if a[0] > 0 then a[10:30:2] else a[5:20:3] in reduce (+) 0 xs futhark-0.25.27/tests/existential-ifs/loop-antiunif.fut000066400000000000000000000006001475065116200231140ustar00rootroot00000000000000-- Another simple test for index-function anti-unification across an if-then-else -- This one returns the same memory block, only the offset is existentialized. -- == -- input { [5, 3, 2, 1, 5] } -- output { 8 } -- -- input { [1, 2, 3, 4, 5, 6, 7] } -- output { 22 } def main [n] (arr: *[n]i32): i32 = let xs = loop arr for _i < n / 2 do arr[1:] in reduce (+) 0 xs futhark-0.25.27/tests/existential-ifs/merge_sort.fut000066400000000000000000000044231475065116200225050ustar00rootroot00000000000000-- | Bitonic merge sort. -- -- Runs in *O(n log²(n))* work and *O(log²(n))* span. Internally pads -- the array to the next power of two, so a poor fit for some array -- sizes. 
local def log2 (n: i64) : i64 = let r = 0 let (r, _) = loop (r,n) while 1 < n do let n = n / 2 let r = r + 1 in (r,n) in r local def ensure_pow_2 [n] 't ((<=): t -> t -> bool) (xs: [n]t): (*[]t, i64) = if n == 0 then (copy xs, 0) else let d = log2 n in if n == 2**d then (copy xs, d) else let largest = reduce (\x y -> if x <= y then y else x) xs[0] xs in (concat xs (replicate (2**(d+1) - n) largest), d+1) local def kernel_par [n] 't ((<=): t -> t -> bool) (a: *[n]t) (p: i64) (q: i64) : *[n]t = let d = 1 << (p-q) in map (\i -> let a_i = a[i] let up1 = ((i >> p) & 2) == 0 in if (i & d) == 0 then let a_iord = a[i | d] in if a_iord <= a_i == up1 then a_iord else a_i else let a_ixord = a[i ^ d] in if a_i <= a_ixord == up1 then a_ixord else a_i) (iota n) -- | Sort an array in increasing order. def merge_sort [n] 't ((<=): t -> t -> bool) (xs: [n]t): *[n]t = -- We need to pad the array so that its size is a power of 2. We do -- this by first finding the largest element in the input, and then -- using that for the padding. Then we know that the padding will -- all be at the end, so we can easily cut it off. let (xs, d) = ensure_pow_2 (<=) xs in (loop xs for i < d do loop xs for j < i+1 do kernel_par (<=) xs i j)[:n] -- | Like `merge_sort`, but sort based on key function. 
def merge_sort_by_key [n] 't 'k (key: t -> k) ((<=): k -> k -> bool) (xs: [n]t): [n]t = zip (map key xs) (iota n) |> merge_sort (\(x, _) (y, _) -> x <= y) |> map (\(_, i) -> xs[i]) -- == -- entry: sort_i32 -- input { empty([0]i32) } -- output { empty([0]i32) } -- input { [5,4,3,2,1] } -- output { [1,2,3,4,5] } -- input { [5,4,3,3,2,1] } -- output { [1,2,3,3,4,5] } entry sort_i32 (xs: []i32) = merge_sort (i32.<=) xs -- == -- entry: map_sort_i32 -- input { [[5,4,3,2,1,0],[5,4,3,3,2,1]] } -- output { [[0,1,2,3,4,5],[1,2,3,3,4,5]] } entry map_sort_i32 (xss: [][]i32) = #[unsafe] map (merge_sort (i32.<=)) xss futhark-0.25.27/tests/existential-ifs/merge_sort_minimized.fut000066400000000000000000000001761475065116200245530ustar00rootroot00000000000000entry ensure_pow_2 [n] (xs: [n]i64): []i64 = if n == 2 then xs else let largest = xs[0] in iota largest futhark-0.25.27/tests/existential-ifs/no-ext.fut000066400000000000000000000002411475065116200215430ustar00rootroot00000000000000-- == -- structure gpu-mem { Manifest 1 } entry ensure_pow_2 [n] [m] (xs: [n][m]i32): [][]i32 = if n == 2 then xs[0:m/2, 0:n/2] else xs[0:n/2,0:m/2] futhark-0.25.27/tests/existential-ifs/partition.fut000066400000000000000000000003211475065116200223410ustar00rootroot00000000000000-- == -- input { [1, 1, 1, 1, 1] } -- output { [0i64, 1i64, 2i64, 3i64, 4i64] empty([0]i64) } def main [n] (cost: *[n]i32) = if opaque(true) then partition (\_ -> (opaque true)) (iota n) else ([], []) futhark-0.25.27/tests/existential-ifs/two-exts.fut000066400000000000000000000002541475065116200221270ustar00rootroot00000000000000def main [n] (xs: [n]i64): [][]i64 = if n == 2 then map (\_ -> xs) (iota n) else let largest = xs[0] in map (\_ -> iota largest) (iota (largest - 1)) futhark-0.25.27/tests/existential-ifs/two-returns.fut000066400000000000000000000004111475065116200226410ustar00rootroot00000000000000def main [n] (xs: [n]i64): ([][]i64, [][]i64) = if n == 2 then (map (\_ -> xs) (iota n), map (\_ -> xs) (iota xs[0])) 
else let largest = xs[0] in (map (\_ -> iota largest) (iota (largest - 1)), map (\_ -> iota largest) (iota xs[1])) futhark-0.25.27/tests/existential-loop/000077500000000000000000000000001475065116200200045ustar00rootroot00000000000000futhark-0.25.27/tests/existential-loop/slice.fut000066400000000000000000000004431475065116200216240ustar00rootroot00000000000000-- A simple test for index-function generalization across a for loop -- == -- input { [0, 1000, 42, 1001, 50000] } -- output { 52043i32 } -- structure gpu-mem { Manifest 0 } def main [n] (a: [n]i32): i32 = let b = loop xs = a[1:] for i < n / 2 - 2 do xs[i:] in reduce (+) 0 b futhark-0.25.27/tests/f16kernel.fut000066400000000000000000000002611475065116200170200ustar00rootroot00000000000000-- Can we correctly have a free variable of type f16 inside a parallel -- construct? -- == -- input { [1f16,2f16] 3f16} auto output def main (xs: []f16) (y: f16) = map (+y) xs futhark-0.25.27/tests/fibfun.fut000066400000000000000000000011201475065116200164670ustar00rootroot00000000000000-- == -- input { -- 10 -- } -- output { -- [ 0 , 1 , 1 , 2 , 3 , 5 , 8 , 13 , 21 , 34 ] -- } def computefibs [n] (arr: *[n]i32): *[n]i32 = let arr[0] = 0 let arr[1] = 1 in loop (arr) for i < n-2 do let x = arr[i] let y = arr[i+1] let arr[i+2] = x + y in arr def fibs(arr: []i32, n: i32): *[][]i32 = map (\_ -> computefibs(copy(arr))) (0..1.. 
iota(2*x)) xs -- let arr's = map (\x arr -> reshape( (x,2), arr) $ zip xs arrs -- let res = map(\arr' -> reduce(op(+), 0, arr')) arr's -- == -- input { -- [ 1i64, 2i64, 3i64, 4i64] -- } -- output { -- [1i64, 6i64, 15i64, 28i64] -- } def main (xs: []i64): []i64 = map (\(x: i64) -> let arr = #[unsafe] 0..<(2 * x) let arr' = #[unsafe] unflatten arr in reduce (+) 0 (arr'[0]) + reduce (+) 0 (arr'[1]) ) xs futhark-0.25.27/tests/flattening/HighlyNestedMap.fut000066400000000000000000000016241475065116200224070ustar00rootroot00000000000000-- == -- input { -- [ [ [ [1,2,3], [4,5,6] ] -- , [ [6,7,8], [9,10,11] ] -- ] -- , [ [ [3,2,1], [4,5,6] ] -- , [ [8,7,6], [11,10,9] ] -- ] -- ] -- [ [ [ [4,5,6] , [1,2,3] ] -- , [ [9,10,11], [6,7,8] ] -- ] -- , [ [ [4,5,6] , [3,2,1] ] -- , [ [11,10,9], [8,7,6] ] -- ] -- ] -- } -- output { -- [[[[5, 7, 9], -- [5, 7, 9]], -- [[15, 17, 19], -- [15, 17, 19]]], -- [[[7, 7, 7], -- [7, 7, 7]], -- [[19, 17, 15], -- [19, 17, 15]]]] -- } def add1 [n] (xs: [n]i32, ys: [n]i32): [n]i32 = map2 (+) xs ys def add2 [n][m] (xs: [n][m]i32, ys: [n][m]i32): [n][m]i32 = map add1 (zip xs ys) def add3 [n][m][l] (xs: [n][m][l]i32, ys: [n][m][l]i32): [n][m][l]i32 = map add2 (zip xs ys) def add4 (xs: [][][][]i32, ys: [][][][]i32): [][][][]i32 = map add3 (zip xs ys) def main (a: [][][][]i32) (b: [][][][]i32): [][][][]i32 = add4(a,b) futhark-0.25.27/tests/flattening/IntmRes1.fut000066400000000000000000000007221475065116200210220ustar00rootroot00000000000000-- == -- input { -- [ [1,2,3], [4,5,6] -- , [6,7,8], [9,10,11] -- ] -- [1,2,3,4] -- 5 -- } -- output { -- [[7, 8, 9], -- [16, 17, 18], -- [24, 25, 26], -- [33, 34, 35]] -- } def addToRow [n] (xs: [n]i32, y: i32): [n]i32 = map (\(x: i32): i32 -> x+y) xs def main (xss: [][]i32) (cs: []i32) (y: i32): [][]i32 = map (\(xs: []i32, c: i32) -> let y' = y * c + c let zs = addToRow(xs,y') in zs ) (zip xss cs) 
futhark-0.25.27/tests/flattening/IntmRes2.fut000066400000000000000000000012311475065116200210170ustar00rootroot00000000000000-- == -- input { -- [ [ [1,2,3], [4,5,6] ] -- , [ [6,7,8], [9,10,11] ] -- , [ [3,2,1], [4,5,6] ] -- , [ [8,7,6], [11,10,9] ] -- ] -- [1,2,3,4] -- 5 -- } -- output { -- [[[7, 8, 9], -- [10, 11, 12]], -- [[18, 19, 20], -- [21, 22, 23]], -- [[21, 20, 19], -- [22, 23, 24]], -- [[32, 31, 30], -- [35, 34, 33]]] -- } def addToRow [n] (xs: [n]i32, y: i32): [n]i32 = map (\(x: i32): i32 -> x+y) xs def main (xsss: [][][]i32) (cs: []i32) (y: i32): [][][]i32 = map (\(xss: [][]i32, c: i32) -> let y' = y * c + c in map (\(xs: []i32) -> addToRow(xs,y') ) xss ) (zip xsss cs) futhark-0.25.27/tests/flattening/IntmRes3.fut000066400000000000000000000014471475065116200210310ustar00rootroot00000000000000-- == -- input { -- [ [ [ [1,2,3], [4,5,6] ] -- ] -- , [ [ [6,7,8], [9,10,11] ] -- ] -- , [ [ [3,2,1], [4,5,6] ] -- ] -- , [ [ [8,7,6], [11,10,9] ] -- ] -- ] -- [1,2,3,4] -- 5 -- } -- output { -- [[[[7, 8, 9], -- [10, 11, 12]]], -- [[[18, 19, 20], -- [21, 22, 23]]], -- [[[21, 20, 19], -- [22, 23, 24]]], -- [[[32, 31, 30], -- [35, 34, 33]]]] -- } def addToRow [n] (xs: [n]i32, y: i32): [n]i32 = map (\(x: i32): i32 -> x+y) xs def main (xssss: [][][][]i32) (cs: []i32) (y: i32): [][][][]i32 = map (\(xsss: [][][]i32, c: i32) -> let y' = y * c + c in map (\(xss: [][]i32) -> map (\(xs: []i32) -> addToRow(xs,y') ) xss ) xsss ) (zip xssss cs) futhark-0.25.27/tests/flattening/LoopInv1.fut000066400000000000000000000007151475065116200210310ustar00rootroot00000000000000-- == -- input { -- [ [1,2,3], [4,5,6] -- , [6,7,8], [9,10,11] -- , [3,2,1], [4,5,6] -- , [8,7,6], [11,10,9] -- ] -- [1,2,3] -- } -- output { -- [[2, 4, 6], -- [5, 7, 9], -- [7, 9, 11], -- [10, 12, 14], -- [4, 4, 4], -- [5, 7, 9], -- [9, 9, 9], -- [12, 12, 12]] -- } def addRows [n] (xs: [n]i32, ys: [n]i32): [n]i32 = map2 (+) xs ys def main (xss: [][]i32) (ys: []i32): [][]i32 = map (\(xs: []i32) -> addRows(xs,ys)) 
xss futhark-0.25.27/tests/flattening/LoopInv2.fut000066400000000000000000000010351475065116200210260ustar00rootroot00000000000000-- == -- input { -- [ [ [1,2,3], [4,5,6] ] -- , [ [6,7,8], [9,10,11] ] -- , [ [3,2,1], [4,5,6] ] -- , [ [8,7,6], [11,10,9] ] -- ] -- [1,2,3] -- } -- output { -- [[[2, 4, 6], -- [5, 7, 9]], -- [[7, 9, 11], -- [10, 12, 14]], -- [[4, 4, 4], -- [5, 7, 9]], -- [[9, 9, 9], -- [12, 12, 12]]] -- } def addRows [n] (xs: [n]i32, ys: [n]i32): [n]i32 = map2 (+) xs ys def main (xsss: [][][]i32) (ys: []i32): [][][]i32 = map (\(xss: [][]i32) -> map (\(xs: []i32) -> addRows(xs,ys)) xss ) xsss futhark-0.25.27/tests/flattening/LoopInv3.fut000066400000000000000000000013041475065116200210260ustar00rootroot00000000000000-- == -- input { -- [ [ [ [1,2,3], [4,5,6] ] -- ] -- , [ [ [6,7,8], [9,10,11] ] -- ] -- , [ [ [3,2,1], [4,5,6] ] -- ] -- , [ [ [8,7,6], [11,10,9] ] -- ] -- ] -- [1,2,3] -- } -- output { -- [[[[2, 4, 6], -- [5, 7, 9]]], -- [[[7, 9, 11], -- [10, 12, 14]]], -- [[[4, 4, 4], -- [5, 7, 9]]], -- [[[9, 9, 9], -- [12, 12, 12]]]] -- } def addRows [n] (xs: [n]i32, ys: [n]i32): [n]i32 = map2 (+) xs ys def main (xssss: [][][][]i32) (ys: []i32): [][][][]i32 = map (\(xsss: [][][]i32) -> map (\(xss: [][]i32) -> map (\(xs: []i32) -> addRows(xs,ys) ) xss ) xsss ) xssss futhark-0.25.27/tests/flattening/LoopInvReshape.fut000066400000000000000000000011061475065116200222530ustar00rootroot00000000000000-- This example presents difficulty for me right now, but also has a -- large potential for improvement later on. 
-- -- we could turn it into: -- -- let []i32 bettermain ([]i32 xs, [#n]i32 ys, [#n]i32 zs, [#n]i32 is, [#n]i32 js) = -- map (\i32 (i32 y, i32 z, i32 i, i32 j) -> -- xs[i*z + j] -- , zip(ys,zs,is,js)) def main [n][m] (xs: [m]i32, ys: [n]i64, zs: [n]i64, is: [n]i32, js: [n]i32): []i32 = map (\(y: i64, z: i64, i: i32, j: i32): i32 -> #[unsafe] let tmp = unflatten (xs :> [y*z]i32) in tmp[i,j] ) (zip4 ys zs is js) futhark-0.25.27/tests/flattening/Map-IotaMapReduce.fut000066400000000000000000000004171475065116200225560ustar00rootroot00000000000000-- == -- input { -- [2,3,4] -- [8,3,2] -- } -- output { -- [8,9,12] -- } def main [n] (xs: [n]i32) (ys: [n]i32): []i32 = map (\(x: i32, y: i32): i32 -> let tmp1 = 0.. map (\(x: i32): i32 -> let tmp1 = map i32.i64(iota(i64.i32 x)) let tmp2 = map (*y) tmp1 in reduce (+) 0 tmp2 ) xs ) (zip xss ys ) futhark-0.25.27/tests/flattening/MapIotaReduce.fut000066400000000000000000000003001475065116200220320ustar00rootroot00000000000000-- == -- input { -- [1,2,3,4] -- } -- output { -- [0, 1, 3, 6] -- } def main (xs: []i32): []i32 = map (\(x: i32): i32 -> let tmp = 0.. 
reduce (+) 0 xs ) xss futhark-0.25.27/tests/flattening/VectorAddition.fut000066400000000000000000000002211475065116200222700ustar00rootroot00000000000000-- == -- input { -- [1,2,3,4] -- [5,6,7,8] -- } -- output { -- [6,8,10,12] -- } def main (xs: []i32) (ys: []i32): []i32 = map2 (+) xs ys futhark-0.25.27/tests/flattening/flattening-pipeline000077500000000000000000000000521475065116200225200ustar00rootroot00000000000000#!/bin/sh futhark -s --flattening -i "$1" futhark-0.25.27/tests/flattening/flattening-test000077500000000000000000000002601475065116200216730ustar00rootroot00000000000000#!/bin/sh HERE=$(dirname "$0") if [ $# -lt 1 ]; then FILES="$HERE/"*.fut else FILES=$* fi futhark-test --only-interpret --interpreter="$HERE/flattening-pipeline" $FILES futhark-0.25.27/tests/flattening/map-nested-free.fut000066400000000000000000000004151475065116200223330ustar00rootroot00000000000000-- == -- input { [5i64,7i64] [5i64,7i64] } -- output { [5i64, 7i64] } def main = map2 (\n x -> #[unsafe] let A = #[opaque] replicate n x let B = #[opaque] map (\i -> A[i%x]) (iota n) in B[0]) futhark-0.25.27/tests/flattening/redomap1.fut000066400000000000000000000006251475065116200210720ustar00rootroot00000000000000-- == -- input { -- [[1,2,3],[1,2,3]] -- [[3,2,1],[6,7,8]] -- } -- output { -- [12, 27] -- } def main [m][n] (xss: [m][n]i32) (yss: [m][n]i32): [m]i32 = let final_res = map (\(xs: [n]i32, ys: [n]i32): i32 -> let tmp = map (\(x: i32, y: i32): i32 -> x+y ) (zip xs ys) in reduce (+) 0 tmp ) (zip xss yss) in final_res futhark-0.25.27/tests/flattening/redomap2.fut000066400000000000000000000003311475065116200210650ustar00rootroot00000000000000-- == -- input { -- [1,2,3] -- [6,7,8] -- } -- output { -- 27 -- } def main [n] (xs: [n]i32) (ys: [n]i32): i32 = let tmp = map (\(x: i32, y: i32): i32 -> x+y ) (zip xs ys) in reduce (+) 0 tmp futhark-0.25.27/tests/float_32_64.fut000066400000000000000000000003761475065116200171540ustar00rootroot00000000000000-- Test that float32s and float64s can 
both be used in a program. -- -- This program does not really test their semantics, but mostly that -- the parser permits them. -- -- == -- input { 3.14f64 } output { 3.0f32 } def main(x: f64): f32 = r32(t64(x)) futhark-0.25.27/tests/floatunderscores.fut000066400000000000000000000002121475065116200206010ustar00rootroot00000000000000-- Floats can contain underscores -- == -- input { 123_.456f32 } -- output { 100000.123456f32 } def main (_: f32) = 100_000.123_456f32 futhark-0.25.27/tests/fourier.fut000066400000000000000000000035631475065116200167060ustar00rootroot00000000000000-- A slow O(n**2) Fourier transform (SFT). -- -- Based on EWD807, but with a correction. When computing 'x', -- Dijkstra specifies a constant '2' in the exponents, while most -- formulations (like Wikipedia and Numpy) use -2. I have gone with -- the latter. This affects the sign of the imaginary part of the -- result. -- -- == -- input { [0f32,1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32,9f32 ] } -- output { -- [45f32, -5f32, -5f32, -5f32, -- -5f32, -5f32, -5f32, -5f32, -- -5f32, -5f32] -- [0.0f32, 15.388417f32, 6.881909f32, 3.632712f32, -- 1.624598f32, -5.510910e-15f32, -1.624598f32, -- -3.632712f32, -6.881909f32, -15.388417f32] -- } def pi: f32 = f32.acos 0.0 * 2.0 type complex = (f32, f32) def complexAdd ((a, b): complex) ((c, d): complex): complex = (a+c, b+d) def complexMult ((a,b): complex) ((c,d): complex): complex = (a*c - b*d, a*d + b*c) def toComplex (a: f32): complex = (a, 0f32) def complexExp ((a,b): complex): complex = complexMult (toComplex (f32.exp a)) (f32.cos b, f32.sin b) def toPolar ((a,b): complex): (f32, f32) = (f32.sqrt (a*a + b*b), f32.atan (b/a)) def fromPolar (r: f32, angle: f32): complex = (r * f32.cos angle, r * f32.sin angle) def complexPow (c: complex) (n: i32): complex = let (r, angle) = toPolar c let (r', angle') = (r ** f32.i32 n, f32.i32 n * angle) in fromPolar (r', angle') def f [n] (a: [n]f32) (j: i32): complex = let x = complexExp (complexMult (-2.0,0.0) 
(complexMult (toComplex pi) (complexMult (0.0, 1.0) (toComplex (1.0/f32.i64 n))))) in reduce complexAdd (0.0, 0.0) (map2 complexMult (map toComplex a) (map (complexPow x) (map (j*) (map i32.i64 (iota n))))) def sft [n] (a: [n]f32): [n]complex = map (f a) (map i32.i64 (iota n)) def main [n] (a: [n]f32): ([n]f32, [n]f32) = unzip (sft a) futhark-0.25.27/tests/funcall-error0.fut000066400000000000000000000001071475065116200200550ustar00rootroot00000000000000-- == -- error: Cannot apply expression as function def main = true 3 futhark-0.25.27/tests/funcall-error1.fut000066400000000000000000000002551475065116200200620ustar00rootroot00000000000000-- Test that functions accept only the right number of arguments. -- == -- error: Cannot apply "f" def f(x: i32) (y: f64): f64 = f64.i32 (x) + y def main: f64 = f 2 2.0 3 futhark-0.25.27/tests/fusion/000077500000000000000000000000001475065116200160075ustar00rootroot00000000000000futhark-0.25.27/tests/fusion/Vers2.0/000077500000000000000000000000001475065116200171465ustar00rootroot00000000000000futhark-0.25.27/tests/fusion/Vers2.0/bugCalib.fut000066400000000000000000000004211475065116200213730ustar00rootroot00000000000000-- == -- input { -- [1.0, 2.0, 3.0, 4.0, 5.0] -- } -- output { -- [2.0, 3.0, 4.0, 5.0, 0.0] -- } def main [m] (result: [m]f64 ): []f64 = -- 0 <= i < m AND 0 <= j < n tabulate m (\j -> if j < m-1 then result[j+1] else 0.0) futhark-0.25.27/tests/fusion/Vers2.0/hindrReshape0.fut000066400000000000000000000005441475065116200223650ustar00rootroot00000000000000-- == -- input { -- [0, 1, 2, 3, 4, 5, 6, 7, 8] -- } -- output { -- [1, 2, 3, 4, 5, 6, 7, 8, 9] -- [[2, 4, 6], [8, 10, 12], [14, 16, 18]] -- } def main (orig: [3*3]i32): ([]i32,[][]i32) = let a = map (+1) orig let b = unflatten a let c = map (\(row: []i32) -> map (\(x: i32): i32 -> x*2) row ) b in (a,c) futhark-0.25.27/tests/fusion/Vers2.0/mapomap0.fut000066400000000000000000000005001475065116200213730ustar00rootroot00000000000000-- == -- input { -- [1.0,-4.0,-2.4] 
-- } -- output { -- [6.0, 1.0, 2.6] -- [2.0, -3.0, -1.4] -- [3.0, -7.0, -3.8] -- } -- structure { -- Screma 1 -- } -- def main(arr: []f64): ([]f64,[]f64,[]f64) = let x = map (+ 1.0) arr let y = map2 (+) x arr let r = map (+ 5.0) arr in (r,x,y) futhark-0.25.27/tests/fusion/Vers2.0/mapomap1.fut000066400000000000000000000007531475065116200214060ustar00rootroot00000000000000-- == -- input { -- [1.0,-4.0,-2.4] -- } -- output { -- [4.0, -16.0, -9.6] -- [2.0, -3.0, -1.4 ] -- [3.0, -2.0, -0.3999999999999999] -- [4.0, -6.0, -2.8 ] -- [9.0, -6.0, -1.1999999999999997] -- } -- structure { -- Screma 1 -- } -- def main(arr: []f64): ([]f64,[]f64,[]f64,[]f64,[]f64) = let xy = map (\(a: f64): (f64,f64) -> (a+1.0,a+2.0)) arr let (x,y) = unzip(xy) let z = map (*2.0) x let w = map (*3.0) y let r = map (*4.0) arr in (r,x,y,z,w) futhark-0.25.27/tests/fusion/Vers2.0/mapored0.fut000066400000000000000000000003631475065116200213770ustar00rootroot00000000000000-- == -- input { -- [1.0,-4.0,-2.4] -- } -- output { -- -5.4 -- [2.0,-3.0,-1.4] -- } -- structure { -- Screma 1 -- } -- def main(arr: []f64): (f64,[]f64) = let r = reduce (+) (0.0) arr let x = map (+1.0) arr in (r,x) futhark-0.25.27/tests/fusion/Vers2.0/mapored1.fut000066400000000000000000000006571475065116200214060ustar00rootroot00000000000000-- == -- input { -- [1.0,-4.0,-2.4] -- } -- output { -- 3.6 -- [12.0, -3.0, 1.8000000000000003] -- [40.0, 15.0, 23.0] -- [0.7, -2.8, -1.68] -- } -- structure { -- Screma 1 -- } -- def main(arr: []f64): (f64,[]f64,[]f64,[]f64) = let a = map (+3.0) arr let b = map (+7.0) arr let s = reduce (+) (0.0) a let x1 = map (*3.0) a let x2 = map (*5.0) b let x3 = map (*0.7) arr in (s,x1,x2,x3) futhark-0.25.27/tests/fusion/Vers2.0/maposcan0.fut000066400000000000000000000004141475065116200215460ustar00rootroot00000000000000-- == -- input { -- [1.0,-4.0,-2.4] -- } -- output { -- [1.0,-3.0,-5.4] -- [2.0,-6.0,-10.8] -- } -- structure { -- /Stream 1 -- /Screma 0 -- } def main(arr: []f64): ([]f64,[]f64) = let sa = 
scan (+) (0.0) arr let b = map (*2.0) sa in (sa, b) futhark-0.25.27/tests/fusion/Vers2.0/maposcanomaposcan.fut000066400000000000000000000006741475065116200233770ustar00rootroot00000000000000-- == -- input { -- [1.0,-4.0,-2.4] -- } -- output { -- 129.6 -- [1.0,-3.0, -5.4] -- [2.0,-6.0,-10.8] -- [2.0,-4.0,-14.8] -- [7.0, 1.0, -9.8] -- } -- structure { -- /Stream 1 -- /Screma 0 -- } -- def main(arr: []f64): (f64,[]f64,[]f64,[]f64,[]f64) = let sa = scan (+) (0.0) arr let b = map (*2.0) sa let sb = scan (+) (0.0) b let c = map (+5.0) sb let r = reduce (*) (1.0) b in (r, sa, b, sb, c) futhark-0.25.27/tests/fusion/Vers2.0/rangeIndVar.fut000066400000000000000000000003461475065116200220710ustar00rootroot00000000000000-- == -- input { -- [1,2,2,3,33,4,5,6,7,8,9,0] -- } -- output { -- 33 -- } def main [m] (arr: [m]i32): i32 = let k = loop k = 0 for i < m-1 do if i % 3 == 0 then k + 1 else k in arr[k] futhark-0.25.27/tests/fusion/Vers2.0/redomap0.fut000066400000000000000000000005141475065116200213750ustar00rootroot00000000000000-- == -- input { -- [1.0f32,-4.0f32,-2.4f32] -- } -- output { -- -5.4f32 -- [2.0f32, -3.0f32, -1.4f32] -- [3.0f32, -7.0f32, -3.8f32] -- } -- structure { -- Screma 1 -- } -- def main(arr: []f32): (f32,[]f32,[]f32) = let x = map (+1.0) arr let y = map2 (+) x arr let r = reduce (+) (0.0) arr in (r,x,y) futhark-0.25.27/tests/fusion/Vers2.0/redomap1.fut000066400000000000000000000007631475065116200214040ustar00rootroot00000000000000-- == -- input { -- [1.0f32,-4.0f32,-2.4f32] -- } -- output { -- -5.4f32 -- [2.0f32, -3.0f32, -1.4f32] -- [3.0f32, -2.0f32, -0.4f32] -- [4.0f32, -6.0f32, -2.8f32] -- [9.0f32, -6.0f32, -1.2f32] -- } -- structure { -- Screma 1 -- } -- def main(arr: []f32): (f32,[]f32,[]f32,[]f32,[]f32) = let xy = map (\(a: f32): (f32,f32) -> (a+1.0,a+2.0)) arr let (x,y) = unzip(xy) let z = map (*2.0) x let w = map (*3.0) y let r = reduce (+) (0.0) arr in (r,x,y,z,w) 
futhark-0.25.27/tests/fusion/Vers2.0/redoredomapomap0.fut000066400000000000000000000010271475065116200231240ustar00rootroot00000000000000-- == -- input { -- [1.0,-4.0,-2.4] -- } -- output { -- -5.4 -- [2.0,-3.0,-1.4] -- 8.4 -- [4.0,-6.0,-2.8] -- 3.0 -- [0.0,1.0,2.0] -- } -- structure { -- /Screma 2 -- } -- def mul2(x: []f64) (i: i32): f64 = x[i]*2.0 def main [n] (arr: [n]f64): (f64,[]f64,f64,[]f64,f64,[]f64) = let r1 = reduce (+) (0.0) arr let x = map (+1.0) arr let r2 = reduce (*) (1.0) x let y = map (mul2(x)) (map i32.i64 (iota(n))) let z = map f64.i64 (iota(n)) let r3 = reduce (+) (0.0) z in (r1,x,r2,y,r3,z) futhark-0.25.27/tests/fusion/WithAccs/000077500000000000000000000000001475065116200175145ustar00rootroot00000000000000futhark-0.25.27/tests/fusion/WithAccs/ker2-radix.fut000066400000000000000000000063741475065116200222160ustar00rootroot00000000000000def mapIntra f as = #[incremental_flattening(only_intra)] #[seq_factor(22)] map f as def map3Intra f as bs cs = #[incremental_flattening(only_intra)] #[seq_factor(22)] map3 f as bs cs -- def mapIntra f as = map f as -- def map3Intra f as bs cs = map3 f as bs cs let partition2 't [n] (dummy: t) (cond: t -> bool) (X: [n]t) : (i64, *[n]t) = let cs = map cond X let tfs= map (\ f->if f then 1i64 else 0i64) cs let isT= scan (+) 0 tfs let i = isT[n-1] let ffs= map (\f->if f then 0 else 1) cs let isF= map (+i) <| scan (+) 0 ffs let inds=map (\(c,iT,iF) -> if c then iT-1 else iF-1 ) (zip3 cs isT isF) let tmp = replicate n dummy in (i, scatter tmp inds X) -- let main [n] (arr: [n]i32) : (i64,*[n]i32) = -- partition2 0i32 (\(x:i32) -> (x & 1i32) == 0i32) arr def getBits (bit_beg: u32) (num_bits: u32) (x: u32) : u32 = let mask = (1 << num_bits) - 1 in (x >> bit_beg) & mask def isBitUnset1 (bit_num: u32) (x: u32) : u32 = let shft = x >> bit_num in 1 - (shft & 1) def isBitUnset (bit_num: u32) (x: u32) : bool = let shft = x >> bit_num in 0 == (shft & 1) def ker1Blk [n] (bit_beg: u32) (lgH: u32) (xs: [n]u32) : [i64.u32 (1u32 << 
lgH)]u16 = let histo_len = i64.u32 (1u32 << lgH) let inds = map (getBits bit_beg lgH) xs |> map i64.u32 in hist (+) 0u32 histo_len inds (replicate n 1u32) |> map u16.u32 def ker2Blk [n] (bit_beg: u32) (lgH: u32) (histo_loc: [i64.u32 (1u32 << lgH)]u16) (histo_glb: [i64.u32 (1u32 << lgH)]i64) (xs: [n]u32) : (*[n]u32, [n]i64) = let xs' = loop (xs) = (copy xs) for i < i32.u32 lgH do (partition2 0u32 (isBitUnset (bit_beg + u32.i32 i)) xs).1 let histo_scn = tabulate (i64.u32 (1u32 << lgH)) (\j -> if j==0 then 0u16 else histo_loc[j-1]) |> scan (+) 0u16 let histo = map3 (\ a b c -> a - i64.u16 (b + c)) histo_glb histo_loc histo_scn let f x i = let bin = getBits bit_beg lgH x in i + (histo[i32.u32 bin]) let inds = map2 f xs' (iota n) in (xs', inds) -- Simple test for fusing scatter-flatten with the preceding -- map nest that produces its indices and values -- == -- entry: radixIt -- input { 3i64 8i64 0u32 -- [ 12u32,11u32,10u32, 9u32, 4u32, 3u32, 2u32, 1u32 -- , 8u32, 7u32, 6u32, 5u32,12u32,11u32,10u32, 9u32 -- , 4u32, 3u32, 2u32, 1u32, 8u32, 7u32, 6u32, 5u32 -- ] -- } -- -- output { [ 1u32, 1u32, 2u32, 2u32, 3u32, 3u32, 4u32, 4u32 -- , 5u32, 5u32, 6u32, 6u32, 7u32, 7u32, 8u32, 8u32 -- , 9u32, 9u32, 10u32, 10u32, 11u32, 11u32, 12u32, 12u32 -- ] -- } entry radixIt (m: i64) (bq: i64) (bit_beg: u32) -- (lgH: u32) (xs: *[m*bq]u32) : *[m*bq]u32 = let lgH = 8u32 let hist16 = mapIntra (ker1Blk bit_beg lgH) (unflatten xs) let hist64 = transpose hist16 |> flatten |> map i64.u16 |> scan (+) 0i64 let hist64T = unflatten hist64 |> transpose let (xs', inds') = unzip <| map3Intra (ker2Blk bit_beg lgH) hist16 hist64T (unflatten xs) in scatter (replicate (m*bq) 0u32) (flatten inds') (flatten xs') futhark-0.25.27/tests/fusion/WithAccs/map-flat-scat-0.fut000066400000000000000000000014771475065116200230330ustar00rootroot00000000000000-- Simple test for fusing scatter-flatten with the preceding -- map nest that produces its indices and values -- == -- entry: main -- input { 3i64 4i64 
[[12u32,11u32,10u32,9u32], [8u32,7u32,6u32,5u32], [4u32,3u32,2u32,1u32]] } -- output { [0u32, 1u32, 10u32, 45u32, 26u32, 62u32, 180u32, 81u32, 160u32, 405u32, 166u32, 302u32] } let i64sqrt x = f64.i64 x |> f64.sqrt |> i64.f64 entry main (m: i64) (b: i64) (xs: *[m][b]u32) : *[m*b]u32 = let inds = map (map2 (\ i x -> i64.u32 x * i64.u32 x + i ) (iota b)) xs let vals = map (map (\ x -> 5*x*x)) xs let inds' = flatten inds let vals' = flatten vals let inds'' = map2 (\i x -> i64sqrt (x - (i % b)) ) (iota (m*b)) inds' let vals'' = map2 (\x i -> x / u32.i64 (i % 3 + 1)) vals' (iota (m*b)) in scatter (replicate (m*b) 0u32) inds'' vals'' futhark-0.25.27/tests/fusion/WithAccs/map-flat-scat-1.fut000066400000000000000000000014771475065116200230340ustar00rootroot00000000000000-- Simple test for fusing scatter-flatten with the preceding -- map nest that produces its indices and values -- == -- entry: main -- input { 3i64 4i64 [[12u32,11u32,10u32,9u32], [8u32,7u32,6u32,5u32], [4u32,3u32,2u32,1u32]] } -- output { [0u32, 1u32, 10u32, 45u32, 26u32, 62u32, 180u32, 81u32, 160u32, 405u32, 166u32, 302u32] } let i64sqrt x = f64.i64 x |> f64.sqrt |> i64.f64 entry main (m: i64) (b: i64) (xs: *[m][b]u32) : *[m*b]u32 = let inds = map (map2 (\ i x -> i64.u32 x * i64.u32 x + i ) (iota b)) xs let vals = map (map (\ x -> 5*x*x)) xs let inds' = flatten inds let vals' = flatten vals let inds'' = map2 (\i x -> i64sqrt (x - (i % b)) ) (iota (m*b)) inds' let vals'' = map2 (\x i -> x / u32.i64 (i % 3 + 1)) vals' (iota (m*b)) in scatter (replicate (m*b) 0u32) inds'' vals'' futhark-0.25.27/tests/fusion/WithAccs/map-flat-scat-2.fut000066400000000000000000000013441475065116200230260ustar00rootroot00000000000000-- Simple test for fusing scatter-flatten with the preceding -- map nest that produces its indices and values -- == -- entry: main -- input { [[12i32,11i32,10i32,9i32], [8i32,7i32,6i32,5i32], [4i32,3i32,2i32,1i32]] } -- -- output { [12.0f32, 14.0f32, 18.0f32, 24.0f32, 32.0f32, 42.0f32, 54.0f32, 
68.0f32, 84.0f32, 102.0f32, 122.0f32, 144.0f32] } entry main [m] [b] (ass : [m][b]i32) = let indvals = map (map2 (\ i a -> (i64.i32 a + i, f32.i32 (a*a))) (iota b) ) ass let (finds, fvals) = flatten indvals |> unzip let tmp = replicate (m*b) 0 let finds' = map2 (\ ind i -> ind - (1 + i % b) ) finds (iota (m*b)) let fvals' = map2 (\ vla i -> vla + f32.i64 i ) fvals (iota (m*b)) let res = scatter tmp finds' fvals' in res futhark-0.25.27/tests/fusion/WithAccs/map-flat-scat-3.fut000066400000000000000000000016001475065116200230220ustar00rootroot00000000000000-- Simple test for fusing scatter-flatten with the preceding -- map nest that produces its indices and values -- == -- entry: main -- input { 3i64 1i64 4i64 [[[12u32,11u32,10u32,9u32]], [[8u32,7u32,6u32,5u32]], [[4u32,3u32,2u32,1u32]]] } -- output { [0u32, 1u32, 10u32, 45u32, 26u32, 62u32, 180u32, 81u32, 160u32, 405u32, 166u32, 302u32] } let i64sqrt x = f64.i64 x |> f64.sqrt |> i64.f64 entry main (m: i64) (n:i64) (b: i64) (xss: *[m][n][b]u32) : *[m*n*b]u32 = let inds = map (map (map2 (\ i x -> i64.u32 x * i64.u32 x + i ) (iota b))) xss let vals = map (map (map (\ x -> 5*x*x))) xss let inds' = flatten inds |> flatten let vals' = flatten vals |> flatten let inds'' = map2 (\i x -> i64sqrt (x - (i % b)) ) (iota (m*n*b)) inds' let vals'' = map2 (\x i -> x / u32.i64 (i % 3 + 1)) vals' (iota (m*n*b)) in scatter (replicate (m*n*b) 0u32) inds'' vals'' futhark-0.25.27/tests/fusion/WithAccs/map-wacc-1.fut000066400000000000000000000015571475065116200220720ustar00rootroot00000000000000-- Simple test for fusing a map soac within a following withacc -- == -- entry: main -- -- input { 3i64 4i64 -- [[12f32,11f32,10f32,9f32], [8f32,7f32,6f32,5f32], [4f32,3f32,2f32,1f32]] -- [7f32, 13f32, 17f32] -- [120f32,110f32,100f32,90f32,80f32,70f32,60f32,50f32,40f32,30f32,20f32,111f32] -- } -- output { [168.0f32, 208.0f32, 136.0f32, 154.0f32, 182.0f32, 102.0f32, 140.0f32, 156.0f32, 68.0f32, 126.0f32, 130.0f32, 34.0f32] } import 
"../../accs/intrinsics" def update (T:i64) (R:i64) (Oacc: *acc ([T*R]f32)) (tid: i64, a: f32, Qk: [R]f32) : *acc ([T*R]f32) = loop Oacc for i < R do let elm = Qk[i] * a let ind = i*T + tid in write Oacc ind elm entry main (T: i64) (R: i64) (Q: [T][R]f32) (A: [T]f32) (O: *[T*R]f32) : *[T*R]f32 = let A' = map (*2.0) A in let z3 = zip3 (iota T) A' Q in scatter_stream O (update T R) z3 futhark-0.25.27/tests/fusion/WithAccs/map-wacc-map-wacc.fut000066400000000000000000000050541475065116200234160ustar00rootroot00000000000000-- Simple test for vertical fusion of: map-withacc-map-withacc -- == -- entry: main -- -- input { 2i64 2i64 2i64 -- [ [ [12f32,11f32,10f32,9f32,11f32] -- , [8f32, 7f32, 6f32, 5f32, 6f32] -- , [4f32, 3f32, 2f32, 1f32, 4f32] -- , [12f32, 7f32, 3f32,5f32,11f32] -- ] -- , [ [12f32, 7f32, 1f32,5f32,11f32] -- , [ 5f32, 8f32, 7f32, 1f32,2f32] -- , [12f32,11f32, 5f32,9f32,15f32] -- , [6f32, 3f32, 7f32, 1f32, 4f32] -- ] -- ] -- [ 120f32,110f32,100f32,90f32,110f32 -- , 80f32, 70f32, 60f32, 50f32, 60f32 -- , 40f32, 30f32, 20f32, 10f32, 40f32 -- , 120f32, 70f32, 30f32,50f32,110f32 -- , 120f32, 70f32, 10f32,50f32,110f32 -- , 50f32, 80f32, 70f32, 10f32,20f32 -- , 120f32,110f32, 50f32,90f32,150f32 -- , 60f32, 30f32, 70f32, 10f32, 40f32 -- ] -- [7f32, 13f32, 17f32, 11f32] -- } -- output { [ 50.399998f32, 29.399998f32, 4.2f32, 21.0f32, 110.0f32 -- , 22.5f32, 36.0f32, 31.5f32, 4.5f32, 60.0f32 -- , 55.199997f32, 50.6f32, 23.0f32, 41.399998f32, 40.0f32 -- , 26.57143f32, 13.285715f32, 31.000002f32, 4.4285717f32, 110.0f32 -- , 120.0f32, 70.0f32, 10.0f32, 50.0f32, 110.0f32 -- , 50.0f32, 80.0f32, 70.0f32, 10.0f32, 20.0f32 -- , 120.0f32, 110.0f32, 50.0f32, 90.0f32, 150.0f32 -- , 60.0f32, 30.0f32, 70.0f32, 10.0f32, 40.0f32 -- ] -- } -- input { 3i64 4i64 [[12u32,11u32,10u32,9u32], [8u32,7u32,6u32,5u32], [4u32,3u32,2u32,1u32]] } -- output { [0u32, 1u32, 10u32, 45u32, 26u32, 62u32, 180u32, 81u32, 160u32, 405u32, 166u32, 302u32] } import "../../accs/intrinsics" -- Q = E * T*T * 
(R*R+1) def accUpdateO [Q] (T: i64) (R: i64) (k: i64) (Ok: [T*T][R*R+1]f32) (O: *[Q]f32) (A: [T*T]f32) : *[Q]f32 = let inner = R*R+1 let glb_offset = k * (T*T*inner) let f (Oacc: *acc ([Q]f32)) (tid,a) = let offset = glb_offset + tid*inner in loop Oacc for i < R do loop Oacc for j < R do let elm = Ok[tid][i*R+j] * a let ind = (offset + i*R) + j in write Oacc ind elm in scatter_stream O f (zip (iota (T*T)) A) def main (E: i64) (T: i64) (R: i64) (Ok: [E][T*T][R*R+1]f32) (O: *[E*(T*T)*(R*R+1)]f32) (A: [T*T]f32) : *[E*(T*T)*(R*R+1)]f32 = let e = f32.i64 E let (_,O) = loop (A,O) for k < E-1 do let A' = map (\x -> (x+1.0+e)/(x+e-1.0) ) A let O = accUpdateO T R k Ok[k] O A' let A''= map (\x -> (x+2.0+e)/(x+e-2.0) ) A' let O = accUpdateO T R k Ok[k+1] O A'' in (A'',O) in Ofuthark-0.25.27/tests/fusion/WithAccs/wacc-map-1.fut000066400000000000000000000017171475065116200220700ustar00rootroot00000000000000-- Simple test for fusing a map soac within a following withacc -- == -- entry: main -- -- input { 3i64 4i64 -- [[12f32,11f32,10f32,9f32], [8f32,7f32,6f32,5f32], [4f32,3f32,2f32,1f32]] -- [7f32, 13f32, 17f32] -- [120f32,110f32,100f32,90f32,80f32,70f32,60f32,50f32,40f32,30f32,20f32,111f32] -- } -- output { [108.0f32, 120.0f32, 76.0f32, 99.0f32, 105.0f32, 57.0f32, 90.0f32, 90.0f32, 38.0f32, 81.0f32, 75.0f32, 19.0f32] -- [11.0f32, 17.0f32, 21.0f32] -- } import "../../accs/intrinsics" def update (T:i64) (R:i64) (Oacc: *acc ([T*R]f32)) (tid: i64, a: f32, Qk: [R]f32) : *acc ([T*R]f32) = loop Oacc for i < R do let elm = Qk[i] * a let ind = i*T + tid in write Oacc ind elm entry main (T: i64) (R: i64) (Q: [T][R]f32) (A: [T]f32) (O: *[T*R]f32) : (*[T*R]f32, *[T]f32) = let A' = map (+2.0) A let z3 = zip3 (iota T) A' Q let r' = scatter_stream O (update T R) z3 let A''= map (+2.0) A' in (r', A'') futhark-0.25.27/tests/fusion/consumption0.fut000066400000000000000000000005711475065116200211700ustar00rootroot00000000000000-- After fusion, consumes a0s. See issue #224. 
-- -- == -- structure { /Screma 1 } def main [m][b] (d: i32, a0s: [m][b][b]f32): *[m][b][b]f32 = let a1s = map (\(x: [][]f32): [b][b]f32 -> transpose x) a0s in map (\(a1: [][]f32): *[b][b]f32 -> map (\(row: []f32): *[b]f32 -> copy row with [d] = 0f32 ) a1 ) a1s futhark-0.25.27/tests/fusion/consumption1.fut000066400000000000000000000005321475065116200211660ustar00rootroot00000000000000-- After fusion, consumes a free variable. Fixed with copy(). -- -- == -- structure { /Screma 1 } def main [n][m] (as: [n]i32, bs: [m]bool): [m][n]i32 = let css = map (\(b: bool): [n]i32 -> if b then map (+1) as else as) bs let dss = map (\(cs: [n]i32): [n]i32 -> copy cs with [0] = 42) css in dss futhark-0.25.27/tests/fusion/consumption2.fut000066400000000000000000000007131475065116200211700ustar00rootroot00000000000000-- After fusion, consumes a free variable. Fixed with copy(). -- -- == -- structure { /Screma 1 } def main [n][m] (as: [n]i32, bs: [m]bool): [n]i32 = let css = map (\(b: bool): [n]i32 -> if b then map i32.i64 (iota n) else as) bs let dss = map (\(cs: []i32): [n]i32 -> copy cs with [0] = 42) css in reduce (\(ds0: []i32) (ds1: []i32): [n]i32 -> map2 (+) ds0 ds1) ( replicate n 0) dss futhark-0.25.27/tests/fusion/consumption3.fut000066400000000000000000000002301475065116200211630ustar00rootroot00000000000000-- Not fusible. 
-- == -- structure { Screma 2 } def main (xs: *[]i32) = let ys = map (+2) xs let xs[0] = 2 let zs = map (*xs[1]) ys in (xs,zs) futhark-0.25.27/tests/fusion/filter-filter1.fut000066400000000000000000000006331475065116200213620ustar00rootroot00000000000000-- == -- input { -- [1,2,3,4,5,6,7,8,9] -- [10,20,30,40,50,60,70,80,90] -- } -- output { -- [20, 30, 40, 60, 80, 90] -- } def div2(x: i32): bool = x % 2 == 0 def div3(x: i32): bool = x % 3 == 0 def main(a: []i32) (b: []i32): []i32 = let (c1,c2) = unzip(filter (\(x: i32, y: i32): bool -> div2(x) || div3(y)) ( zip a b)) in filter div2 c2 futhark-0.25.27/tests/fusion/fuse-across-reshape-transpose.fut000066400000000000000000000004211475065116200244170ustar00rootroot00000000000000-- == -- input { -- } -- output { -- [[2, 8, 14], [4, 10, 16], [6, 12, 18]] -- } -- structure { /Screma 1 } def main: [][]i32 = let a = map (+1) (map i32.i64 (iota(3*3))) let b = unflatten a let c = transpose b in map (\(row: []i32) -> map (*2) row) c futhark-0.25.27/tests/fusion/fuse-across-reshape1.fut000066400000000000000000000004261475065116200224710ustar00rootroot00000000000000-- == -- input { -- } -- output { -- [[2, 4, 6], [8, 10, 12], [14, 16, 18]] -- } -- structure { -- /Screma 1 -- } def main: [][]i32 = let a = map (+1) (map i32.i64 (iota(3*3))) let b = unflatten a in map (\(row: []i32) -> map (\(x: i32): i32 -> x*2) row) b futhark-0.25.27/tests/fusion/fuse-across-reshape2.fut000066400000000000000000000004561475065116200224750ustar00rootroot00000000000000-- == -- input { -- } -- output { -- [[0, 9, 18], [27, 36, 45], [54, 63, 72]] -- } def main: [][]i32 = let a = map (\i -> replicate 9 (i32.i64 i)) (iota (3*3)) let b = unflatten_3d (flatten a) in map (\(row: [][]i32) -> map (\(x: []i32): i32 -> reduce (+) 0 x) row) b futhark-0.25.27/tests/fusion/fuse-across-reshape3.fut000066400000000000000000000004061475065116200224710ustar00rootroot00000000000000-- structure { Map 3 Map/Map/Map 1 Map/Map/Scan 1 } def main(n: i64, m: i64, k: i64): 
[][][]f32 = map (\(ar: [][]f32): [m][n]f32 -> map (\(arr: []f32): [n]f32 -> scan (+) 0f32 arr) ar) ( unflatten_3d (map f32.i64 (iota(k*m*n)))) futhark-0.25.27/tests/fusion/fuse-across-transpose1.fut000066400000000000000000000002661475065116200230620ustar00rootroot00000000000000-- == -- structure { Screma 2 } def main [n] (a: [][n]i32): [][]i32 = let b = map (\x1: [n]i32 -> map (+1) x1) a let c = map (\z1: [n]i32 -> map (*3) z1) (transpose b) in c futhark-0.25.27/tests/fusion/fuse-across-transpose2.fut000066400000000000000000000006531475065116200230630ustar00rootroot00000000000000-- == -- structure { /Screma 1 } def main (a_1: [][]i32, a_2: [][]i32): [][]i32 = let a = map2 (\(a_1r: []i32) (a_2r: []i32) -> zip a_1r a_2r) ( a_1) (a_2) let b = map (\(row: [](i32,i32)) -> map (\(x: i32, y: i32): (i32,i32) -> (x+y,x-y)) row) a let c = map (\(row: [](i32,i32)) -> map (\(x,y) -> x + y) row) (transpose b) in c futhark-0.25.27/tests/fusion/fuse-across-transpose3.fut000066400000000000000000000004321475065116200230570ustar00rootroot00000000000000-- == -- structure { Screma 2 } def main [n][m] (a: [n][m]i32): i32 = let b = map (\z1: [m]i32 -> map (*3) z1) a let ravgs = map (\r: i32 -> reduce (+) 0 r / i32.i64 n) (transpose b) let res = reduce (+) 0 ravgs in res futhark-0.25.27/tests/fusion/fuse-across-transpose4.fut000066400000000000000000000006221475065116200230610ustar00rootroot00000000000000-- == -- input { [[1,2,3],[4,5,6]] [[7,8,9],[1,2,3]] } -- output { [[10, 7], [12, 9], [14, 11]] } -- structure { /Screma 1 } def main [n][m] (a: [n][m]i32) (b: [n][m]i32): [][]i32 = let a2 = map (\r: [n]i32 -> map (+1) r) (transpose a) let b2 = map (\r: [n]i32 -> map (+1) r) (transpose b) let c = map (\(rx,ry): [n]i32 -> map2 (+) rx ry) (zip a2 b2) in c futhark-0.25.27/tests/fusion/fuse-across-transpose5.fut000066400000000000000000000007311475065116200230630ustar00rootroot00000000000000-- == -- input { -- [[1,2,3],[4,5,6],[7,8,9]] -- } -- output { -- [[0, 1, 2], [0, 2, 4], [0, 3, 6]] -- } 
def main [n][m] (a: [n][m]i32): [][]i32 = let foo = replicate m (map i32.i64 (iota n)) let bar = replicate m (map i32.i64 (iota n)) let b = replicate n (map i32.i64 (iota m)) let c = map (\(xs: []i32, ys: []i32,zs: []i32) -> map (\(x: i32, y: i32, z: i32): i32 -> x+y*z) (zip3 xs ys zs)) (zip3 foo bar (transpose b)) in c futhark-0.25.27/tests/fusion/fuse-across-transpose6.fut000066400000000000000000000065711475065116200230740ustar00rootroot00000000000000-- Inspired by the blackScholes function in OptionPricing. This -- program once malfunctioned because rearrange-pulling did not -- properly update the lambda indices. -- -- == -- -- input { -- [ -- [ 1.0000000f32, 0.6000000f32, 0.8000000f32 ], -- [ 0.6000000f32, 0.8000000f32, 0.1500000f32 ], -- [ 0.8000000f32, 0.1500000f32, 0.5809475f32 ] -- ] -- [ -- [ 0.1900000f32, 0.1900000f32, 0.1500000f32 ], -- [ 0.1900000f32, 0.1900000f32, 0.1500000f32 ], -- [ 0.1900000f32, 0.1900000f32, 0.1500000f32 ], -- [ 0.1900000f32, 0.1900000f32, 0.1500000f32 ], -- [ 0.1900000f32, 0.1900000f32, 0.1500000f32 ] -- ] -- [ -- [ -0.0283491736871803f32, 0.0178771081725381f32, 0.0043096808044729f32 ], -- [ -0.0183841413744211f32, -0.0044530897672834f32, 0.0024263805987983f32 ], -- [ -0.0172686581005089f32, 0.0125638544546015f32, 0.0094452810918001f32 ], -- [ -0.0144179417871814f32, 0.0157411263968213f32, 0.0125315353728014f32 ], -- [ -0.0121497422218761f32, 0.0182904634062437f32, 0.0151125070556484f32 ] -- ] -- [ [ 2.2372928847280580f32, 1.0960951589853829f32, 0.7075902730592357f32, 0.8166828043492210f32, 0.7075902730592357f32 ], -- [ 0.0000000000000000f32, 0.5998905309250137f32, 0.4993160054719562f32, 0.6669708029197080f32, 0.5006839945280438f32 ], -- [ 0.0000000000000000f32, 0.4001094690749863f32, 0.5006839945280438f32, 0.3330291970802919f32, 0.4993160054719562f32 ] -- ] -- } -- output { -- [[1.4869640253447387f32, 1.3138063004156617f32, 1.313617559344596f32], -- [1.7978839917383833f32, 1.6235475749754527f32, 1.5763413379252744f32], -- 
[2.021386995094619f32, 1.9227160475136045f32, 1.8300219186412707f32], -- [2.326897180853236f32, 2.372540501622419f32, 2.135902194722964f32], -- [2.6295904397262726f32, 2.8264486916069096f32, 2.4935049184863627f32]] -- } -- structure { /Screma 1 /Screma/Screma 1 /Screma/Screma/Screma 1 } def correlateDeltas [num_und][num_dates] (md_c: [num_und][num_und]f32, zds: [num_dates][num_und]f32): [num_dates][num_und]f32 = map (\(zi: [num_und]f32): [num_und]f32 -> map (\j: f32 -> let j' = j + 1 let x = map2 (*) (take j' zi) (take j' md_c[j]) in reduce (+) (0.0) x ) (iota(num_und) ) ) zds def combineVs [num_und] (n_row: [num_und]f32, vol_row: [num_und]f32, dr_row: [num_und]f32 ): [num_und]f32 = map2 (+) dr_row (map2 (*) n_row vol_row) def mkPrices [num_dates][num_und] (md_vols: [num_dates][num_und]f32, md_drifts: [num_dates][num_und]f32, noises: [num_dates][num_und]f32): [num_dates][num_und]f32 = let c_rows = map combineVs (zip3 noises (md_vols) (md_drifts) ) let e_rows = map (\(x: []f32): [num_und]f32 -> map f32.exp x ) (c_rows ) in scan (\x y -> map2 (*) x y) (replicate num_und 1.0) (e_rows ) -- Formerly blackScholes. def main [num_dates][num_und] (md_c: [num_und][num_und]f32) (md_vols: [num_dates][num_und]f32) (md_drifts: [num_dates][num_und]f32) (bb_arr: [num_und][num_dates]f32): [num_dates][num_und]f32 = -- I don't want to import the entire Brownian bridge, so we just -- transpose bb_arr. let bb_row = transpose bb_arr let noises = correlateDeltas(md_c, bb_row) in mkPrices(md_vols, md_drifts, noises) futhark-0.25.27/tests/fusion/fuse-across-transpose7.fut000066400000000000000000000007261475065116200230710ustar00rootroot00000000000000-- Careful not to fuse excessively. 
-- == -- structure { /Screma 2 } def matmul_diag_full [n][m] (ds: [n]f64) (A: [n][m]f64): [n][m]f64 = map2 (\d as -> map (*d) as) ds A def matmul_full_diag [n][m] (A: [n][m]f64) (ds: [m]f64): [n][m]f64 = transpose (map2 (\d as -> map (*d) as) ds (transpose A)) def main (k0: f64) (D: []f64) (W: [][]f64) = let X = map (\d -> f64.exp (f64.neg d * k0 * 1)) D let temp = X `matmul_diag_full` W `matmul_full_diag` X in temp futhark-0.25.27/tests/fusion/fuseComplex1.fut000066400000000000000000000020741475065116200211050ustar00rootroot00000000000000-- == -- input { [1.0, 2.0] [[1.0, 2.0], [-4.0, 1.5]] } -- output { 8.0 -1.0 0.0 -1.0 5.5 } -- structure { Screma 1 } def f1(p: (f64,f64) ): (f64,f64,f64) = let (a1, a2) = p in (a1 * a2, a1 + a2, a1 - a2) def f2(p: (f64,f64) ): (f64,f64) = let (a1, a2) = p in (a1 - a2, a1 + 2.0*a2) def g (p: (f64,f64,f64,f64) ): (f64,f64) = let (a1,a2,a3,a4) = p in (a1 * a2 - a3 * a4, a3 + a4 + a2 - a1) --let f64 myop ( (f64,f64,f64,f64,f64) p ) = -- let {a1,a2,a3,a4,a5} = p in a1+a2+a3+a4+a5 def myop (p: (f64,f64,f64,f64,f64)) (q: (f64,f64,f64,f64,f64)): (f64,f64,f64,f64,f64) = let (a1,a2,a3,a4,a5) = p let (b1,b2,b3,b4,b5) = q in (a1+b1,a2+b2,a3+b3,a4+b4,a5+b5) --let f64 def main(x1: []f64) (x2: [][]f64): (f64,f64,f64,f64,f64) = let (y1, y2, y3) = unzip3( map f1 (zip x1 (x2[1] ) ) ) let (z1, z2) = unzip2( map f2 (zip y1 y2 ) ) let (q1, q2) = unzip2( map g (zip4 y3 z1 y2 y3 ) ) in -- let res = map ( myop, zip(q1,q2,z2,y1,y3) ) in -- res[3] reduce myop (0.0,0.0,0.0,0.0,0.0) (zip5 q1 q2 z2 y1 y3 ) futhark-0.25.27/tests/fusion/fuseEasy1.fut000066400000000000000000000003761475065116200204020ustar00rootroot00000000000000-- == -- input { -- [1.0,-4.0,-2.4] -- } -- output { -- 36.000000 -- } def f(a: f64 ): f64 = a + 3.0 def g(a: f64 ): f64 = a * 3.0 def main(arr: []f64): f64 = let x = map f arr let y = map g x let z = map g y in z[0] 
futhark-0.25.27/tests/fusion/fuseEasy2.fut000066400000000000000000000005001475065116200203700ustar00rootroot00000000000000-- == -- input { -- [1.0,-4.0,-2.4] -- } -- output { -- 16.000000 -- } def f(a: f64 ): f64 = a + 3.0 def g(a: f64 ): f64 = a * 3.0 def h(a1: f64, a2: f64, a3: f64): f64 = a1 * a2 + a3 def main(arr: []f64): f64 = let x = map f arr let y = map g arr let z = map h (zip3 x y x) in z[0] futhark-0.25.27/tests/fusion/fuseEasy3.fut000066400000000000000000000010321475065116200203720ustar00rootroot00000000000000-- == -- input { [1.0, 2.0, -4.0, 1.5] } -- output { [13.0, 22.0, -2.0, 17.25] } def f(a: f64 ): f64 = a + 3.0 def g(a: f64 ): f64 = a * 3.0 def h1(a1: f64, a2: f64, a3: f64): f64 = a1 * a2 + a3 --let f64 h2((f64,f64) a23) = let (a2,a3) = a23 in a2 * a3 def h2(a1: f64) (a23: (f64,f64)): f64 = let (a2,a3) = a23 in a2 * a3 - a1 def main(arr: []f64): []f64 = let x = map f arr let y = map g arr in if arr[0] < 0.0 then map h1 (zip3 x y x) --else map(h2(1.0), zip(y,x)) else map (h2(y[0])) (zip x x) futhark-0.25.27/tests/fusion/fuseEasy4.fut000066400000000000000000000005031475065116200203750ustar00rootroot00000000000000-- == -- input { [4.0, 2.0, -4.0, 1.5] } -- output { 17.0 } def f(a: f64, b: f64): f64 = a + 3.0 def g(a: f64, b: f64): f64 = a * 3.0 def main (arr: []f64): f64 = let n = i64.f64 arr[0] let x = replicate n 2.0 let y = map f (zip x (arr :> [n]f64)) let z = map g (zip (arr :> [n]f64) x) in y[0] + z[0] futhark-0.25.27/tests/fusion/fuseFilter1.fut000066400000000000000000000003241475065116200207170ustar00rootroot00000000000000-- == -- input { -- [3,5,-2,3,4,-30] -- [-4,10,1,-8,2,4] -- } -- output { -- [1, 4] -- } def main (a: []i32) (b: []i32): []i32 = let (c,d) = unzip(filter (\(x,y) -> x+y < 0) (zip a b)) in filter (0<) d futhark-0.25.27/tests/fusion/fuseFilter2.fut000066400000000000000000000002541475065116200207220ustar00rootroot00000000000000-- == -- input { [1,2,3,4] [5,6,7,8] } -- output { 26 } def main(a: []i32) (b: []i32): i32 = let 
(a2,b2) = unzip(filter (\(x,y) -> x < y) (zip a b)) in reduce (+) 0 b2 futhark-0.25.27/tests/fusion/fusion1.fut000066400000000000000000000005441475065116200201160ustar00rootroot00000000000000-- == -- input { -- [1.0,2.0,3.0,4.0] -- } -- output { -- 65.000000 -- } def f(a: f64 ): f64 = a + 3.0 def g(a: f64 ): f64 = a * 3.0 def h(a: f64, b: f64): f64 = a * b - (a + b) def main(arr: []f64): f64 = let b = map f arr --let arr[1] = 3.33 in let x = map f b let y = map g b let z = map h (zip x y) in z[0] futhark-0.25.27/tests/fusion/fusion2.fut000066400000000000000000000005441475065116200201170ustar00rootroot00000000000000-- == -- input { -- [1.0,2.0,3.0,4.0] -- } -- output { -- 73.000000 -- } def f(a: f64): f64 = a + 3.0 def g(a: f64): f64 = a * 3.0 def h(x: f64) (y: (f64,f64)): f64 = let (a,b) = y in a * b - (a + b) + x def main(arr: []f64): f64 = let b = map f arr let x = map f b let y = map g b let z = map (h(x[1])) (zip x y) in z[0] --+ y[0] futhark-0.25.27/tests/fusion/fusion3.fut000066400000000000000000000010331475065116200201120ustar00rootroot00000000000000-- == -- input { -- [-2.0,3.0,9.0] -- } -- output { -- 19.0 -- } -- structure { Screma 1 } def f(a: f64 ): f64 = a + 3.0 def g(a: f64 ): f64 = a * 3.0 def h(x: f64, y: (f64,f64)): f64 = let (a,b) = y in a * b - (a + b) + x def opp(x: f64) (a: f64) (b: f64): f64 = x*(a+b) def main(arr: []f64): f64 = let arr2 = replicate 5 arr let y = map (\(x: []f64): f64 -> let a = map f x let b = reduce (opp(1.0)) (0.0) a in b ) arr2 in y[0] futhark-0.25.27/tests/fusion/fusion4.fut000066400000000000000000000003751475065116200201230ustar00rootroot00000000000000-- Test that filter can be fused into reduce. 
-- == -- input { -- [9,-3,5,2] -- } -- output { -- 6 -- } def divisibleBy(x: i32) (y: i32): bool = y % x == 0 def main(a: []i32): i32 = let threes = filter (divisibleBy 3) a in reduce (+) 0 threes futhark-0.25.27/tests/fusion/fusion5.fut000066400000000000000000000040241475065116200201170ustar00rootroot00000000000000-- Once failed in fusion. Derived from tail2futhark output. -- == -- input { [1, 2, -4, 1] [[1, 2], [-4, 1]] } -- output { -- [[true, false, false, false, false, false, false, false, false, false, false, -- false, false, false, false, false, false, false, false, false, false, false, -- false, false, false, false, false, false, false, false], -- [false, false, false, false, false, false, false, false, false, false, false, -- false, false, false, false, false, false, false, false, false, false, false, -- false, false, false, false, false, false, false, false], -- [true, false, false, false, false, false, false, false, false, false, false, -- false, false, false, false, false, false, false, false, false, false, false, -- false, false, false, false, false, false, false, false]] -- } -- structure { /Screma 3 /Screma/Screma 1 } def main(t_v1: []i32) (t_v3: [][]i32): [][]bool = let n = 3 let t_v6 = map (\(x: i32): i32 -> (x + 1)) (map i32.i64 (iota(n))) let t_v12 = map (\(x: i32): i32 -> (x + 1)) (map i32.i64 (iota(30))) let t_v18 = transpose (replicate 30 t_v6) let t_v19 = replicate n t_v12 let t_v27 = map (\(x: []i32,y: []i32) -> map2 (^) x y) ( zip (t_v18) ( map (\(x: []i32) -> map (<<1) x) (t_v18))) let t_v33 = map (\(x: []i32) -> map (\(t_v32: i32): bool -> ((0 != t_v32))) x) ( map (\(x: []i32,y: []i32) -> map2 (&) x y) ( zip (t_v27) ( map (\(x: []i32) -> map (\(t_v29: i32): i32 -> (1 >> t_v29)) x) ( map (\(x: []i32) -> map (\(t_v28: i32): i32 -> (t_v28 - 1)) x) ( t_v19))))) in t_v33 futhark-0.25.27/tests/fusion/fusion6.fut000066400000000000000000000007361475065116200201260ustar00rootroot00000000000000-- == -- structure { Screma 1 Scatter 1 } def main [n] 
(xs: [n]i32): [n]i32 = let num x = x&1 let pairwise op (a1,b1) (a2,b2) = (a1 `op` a2, b1 `op` b2) let bins = xs |> map num let flags = bins |> map (\x -> if x == 0 then (1,0) else (0,1)) let offsets = scan (pairwise (+)) (0,0) flags let f bin (a,b) = match bin case 0 -> a-1 case _ -> (last offsets).0+b-1 let is = map2 f bins offsets in scatter (copy xs) is xs futhark-0.25.27/tests/fusion/fusion7.fut000066400000000000000000000002161475065116200201200ustar00rootroot00000000000000-- == -- structure { Screma 1 } let main (is: []i32) (xs: []i32) = let foo = (map (+1) xs)[0] let bar = (map (+2) xs)[0] in (foo, bar) futhark-0.25.27/tests/fusion/horizontal0.fut000066400000000000000000000002331475065116200207760ustar00rootroot00000000000000-- No fusion of this, because the arrays are independent. -- == -- structure { Screma 2 } def main [n] (x: [n]i32) (y: [n]i32) = (map (+1) x, map (+2) y) futhark-0.25.27/tests/fusion/iswim1.fut000066400000000000000000000005411475065116200177400ustar00rootroot00000000000000-- == -- input { -- [[1,2,3],[4,5,6],[7,8,9]] -- } -- output { -- [[3, 4, 5], [7, 9, 11], [14, 17, 20]] -- } -- structure { /Screma 1 } def main(input: [][3]i32): [][]i32 = let x = scan (\(a: []i32) (b: []i32): [3]i32 -> map2 (+) a b) ( replicate 3 0) input in map (\(r: []i32): [3]i32 -> map (+2) r) x futhark-0.25.27/tests/fusion/iswim2.fut000066400000000000000000000014141475065116200177410ustar00rootroot00000000000000-- == -- input { -- [[1,2,3],[4,5,6],[7,8,9]] -- [[4,5,6],[7,8,9],[1,2,3]] -- } -- output { -- [[5, 7, 9], [16, 20, 24], [24, 30, 36]] -- } -- structure { /Screma 1 } def main(input1: [][]i32) (input2: [][]i32): [][]i32 = let input = map(\(r1: []i32, r2: []i32) -> zip r1 r2) (zip input1 input2) let x = scan(\(a: [](i32,i32)) (b: [](i32,i32)) -> let (a1, a2) = unzip(a) let (b1, b2) = unzip(b) in map(\(quad: (i32,i32,i32,i32)): (i32,i32) -> let (a1x,a2x,b1x,b2x) = quad in (a1x+b1x,a2x+b2x)) (zip4 a1 a2 b1 b2)) (zip (replicate 3 0) (replicate 3 0)) input in 
map (\(r: [](i32,i32)) -> map (\(x,y) -> x+y) r) x futhark-0.25.27/tests/fusion/iswim3.fut000066400000000000000000000027211475065116200177440ustar00rootroot00000000000000-- This test exposed a bug in map-nest creation. The program involves -- ISWIM with apparently more complex shapes than the other ISWIM -- tests. The bug happened whilst pulling a transpose before the -- producer. -- --== -- -- structure { Map 1 Redomap 1 Scanomap 1 } def correlateDeltas [num_und] [num_dates] (md_c: [num_und][num_und]f64, zds: [num_dates][num_und]f64 ): [num_dates][num_und]f64 = map (\(zi: [num_und]f64): [num_und]f64 -> map (\(j: i32): f64 -> let x = map2 (*) zi (md_c[j] ) in reduce (+) (0.0) x ) (map i32.i64 (iota(num_und))) ) zds def blackScholes [num_und][num_dates] (md_c:[num_und][num_und]f64, md_vols: [num_dates][num_und]f64, md_drifts: [num_dates][num_und]f64, md_starts: [num_und]f64, bb_arr: [num_dates][num_und]f64 ): [num_dates][num_und]f64 = let noises = correlateDeltas(md_c, bb_arr) in scan (\(x: []f64) (y: []f64) -> map2 (*) x y ) (md_starts) noises def main [num_und][num_dates] (md_cs: [num_und][num_und]f64, md_vols: [num_dates][num_und]f64, md_drifts: [num_dates][num_und]f64, md_sts: [num_und]f64, bb_row: [num_dates][num_und]f64 ): [][]f64 = let bd_row = blackScholes(md_cs, md_vols, md_drifts, md_sts, bb_row) in bd_row futhark-0.25.27/tests/fusion/map-scan1.fut000066400000000000000000000003031475065116200203030ustar00rootroot00000000000000-- == -- input { -- [1,2,3,4,5,6,7] -- } -- output { -- [3, 7, 12, 18, 25, 33, 42] -- } -- structure { -- Screma 1 -- } def main(a: []i32): []i32 = let b = scan (+) 0 (map (+2) a) in b futhark-0.25.27/tests/fusion/map-scan2.fut000066400000000000000000000003661475065116200203150ustar00rootroot00000000000000-- == -- input { -- [1,2,3,4,5,6,7] -- } -- output { -- [-1, -1, 0, 2, 5, 9, 14] -- } -- structure { -- Screma 1 -- } def main(a: []i32): []i32 = let (_,b) = unzip(map (\(x: i32): (i32,i32) -> (x+2,x-2)) a) let c = scan (+) 0 b in c 
futhark-0.25.27/tests/fusion/map-scan3.fut000066400000000000000000000014461475065116200203160ustar00rootroot00000000000000-- Mapping with a scanomap - segmented scanomap, but one that uses a -- distinct per-segment value in the fold part. -- -- The program is somewhat contrived to support large amount of work -- with only small input data sets. -- -- == -- input { 3i64 3i64 } -- output { 488i32 } -- input { 10i64 1000i64 } -- output { 1986778316i32 } -- compiled input { 10i64 10000i64 } -- output { -1772567048i32 } -- compiled input { 10000i64 10i64 } -- output { 1666665i32 } -- compiled input { 100000i64 10i64 } -- output { 16511385i32 } -- -- structure { -- /Screma/Stream 1 -- /Screma 1 -- } def main(n: i64) (m: i64): i32 = let factors = map (^123) (iota n) let res = map (\factor -> reduce (+) 0 (scan (+) 0 (map i32.i64 (map (*factor) (iota m))))) factors in res[n-2] futhark-0.25.27/tests/fusion/map-scatter-map.fut000066400000000000000000000030751475065116200215270ustar00rootroot00000000000000-- == -- compiled input { -- [3i64, -1i64, 1i64, 5i64, 2i64, -1i64, 7i64, 6i64] -- [7.0f32, 8.0f32, 9.0f32, 10.0f32, 12.0f32, 15.0f32, 18.0f32, 11.0f32] -- [1.0f32, 3.0f32, 2.0f32, 5.0f32, 4.0f32, 7.0f32, 9.0f32, 8.0f32] -- [5.0f32, 9.0f32, 22.0f32, 33.0f32, 27.0f32, 22.0f32, 17.0f32, 8.0f32] -- } -- output { -- [231735.0f32, 69984.0f32, 518400.0f32, 3626700.0f32, 528768.0f32, 2310000.0f32, 7840800.0f32, 1.28304e7f32] -- [811930.0f32, 93312.0f32, 1166400.0f32, 3626700.0f32, 886464.0f32, 2475000.0f32, 7840800.0f32, 8820900.0f32] -- } def main [n] (is: [n]i64) (vs: [n]f32) (xs: [n]f32) (ys_bar: *[n]f32) = let map_res_1 = map2 (*) xs vs let zip_copy = copy map_res_1 let map_res_2 = map2 (*) vs zip_copy let scatter_res_1 = scatter map_res_1 is map_res_2 let (map_adjs_1, map_adjs_2) = unzip <| map3 (\ x y lam_adj -> (y * lam_adj, x * lam_adj) ) map_res_2 scatter_res_1 ys_bar let scatter_res_adj_gather = map (\ is_elem -> if is_elem >= 0 && is_elem < n then map_adjs_2[is_elem] 
else 0.0f32 ) is let map_res_adj_1 = map2 (+) map_adjs_1 scatter_res_adj_gather let map_res_bar = scatter map_adjs_2 is (replicate n 0.0f32) let (map_adjs_3, map_adjs_4) = unzip <| map3 (\ x y lam_adj -> (y * lam_adj, x * lam_adj) ) vs zip_copy map_res_adj_1 let map_res_adj_2 = map2 (+) map_res_bar map_adjs_4 let (map_adjs_5, map_adjs_6) = unzip <| map3 (\ x y lam_adj -> (y * lam_adj, x * lam_adj) ) xs vs map_res_adj_2 let x_adj = map2 (+) map_adjs_3 map_adjs_6 in (x_adj, map_adjs_5) futhark-0.25.27/tests/fusion/noFusion1.fut000066400000000000000000000003451475065116200204120ustar00rootroot00000000000000-- == def f(a: f64): f64 = a + 3.0 def g(a: f64) (b: f64): f64 = a * b def main(arr: []f64): f64 = let n = t64(arr[0]) let x = map f arr in let arr = loop (arr) for i < n do map (g(arr[i])) x in arr[0] futhark-0.25.27/tests/fusion/noFusion2.fut000066400000000000000000000003001475065116200204020ustar00rootroot00000000000000-- == def f(a: f64 ): f64 = a + 3.0 def g(a: f64 ): f64 = a * 3.0 def main(arr: []f64): f64 = let x = map f arr let y = map f x let z = map g x in y[0] + z[0] futhark-0.25.27/tests/fusion/noFusion3.fut000066400000000000000000000004011475065116200204050ustar00rootroot00000000000000-- == -- structure { Screma 3 } def f(a: f64): f64 = a + 3.0 def g(a: []f64) (b: f64): f64 = a[0] * b def h(a: f64) (b: f64): f64 = a * b def main(arr: []f64): f64 = let x = map f arr let y = map (g(x)) x let z = map (h(y[0])) y in z[0] futhark-0.25.27/tests/fusion/noFusion4.fut000066400000000000000000000002171475065116200204130ustar00rootroot00000000000000-- == def main(arr: *[]f64): f64 = let x = map (+1.0) arr let arr[1] = 3.33 let y = map (*2.0) x in y[0] + arr[1] futhark-0.25.27/tests/fusion/red-red-fusion.fut000066400000000000000000000002031475065116200213450ustar00rootroot00000000000000-- Horizontal fusion of reductions. 
-- == -- structure { Screma 1 } def main (xs: []i32) = (i32.sum xs, f32.sum (map f32.i32 xs)) futhark-0.25.27/tests/fusion/scanomap-scanomap1.fut000066400000000000000000000005011475065116200222040ustar00rootroot00000000000000-- == -- input { -- [1,2,3,4,5,6,7] -- } -- output { -- [2, 3, 4, 5, 6, 7, 8] -- [2, 5, 9, 14, 20, 27, 35] -- [2, 6, 24, 120, 720, 5040, 40320] -- } -- structure { -- Screma 1 -- } def main(inp: []i32): ([]i32, []i32, []i32) = let a = map (+1) inp let b = scan (+) 0 a let c = scan (*) 1 a in (a, b, c) futhark-0.25.27/tests/fusion/scanomap-scanomap2.fut000066400000000000000000000005741475065116200222170ustar00rootroot00000000000000-- == -- input { -- [1,2,3,4,5,6,7] -- } -- output { -- [2, 3, 4, 5, 6, 7, 8] -- [2, 5, 9, 14, 20, 27, 35] -- [3, 4, 5, 6, 7, 8, 9] -- [3, 12, 60, 360,2520,20160,181440] -- } -- structure { -- Screma 1 -- } def main(inp: []i32): ([]i32, []i32, []i32, []i32) = let a = map (+1) inp let b = scan (+) 0 a let c = map (+1) a let d = scan (*) 1 c in (a, b, c, d) futhark-0.25.27/tests/fusion/scanreduce0.fut000066400000000000000000000003261475065116200207240ustar00rootroot00000000000000-- Horizontal fusion of scan and reduce. -- == -- input { [1,2,3,4] } output { [1, 3, 6, 10] 24 } -- structure { Screma 1 } -- structure gpu { SegScan 1 } def main (xs: []i32) = (scan (+) 0 xs, reduce (*) 1 xs) futhark-0.25.27/tests/fusion/scanreduce1.fut000066400000000000000000000005071475065116200207260ustar00rootroot00000000000000-- Complicated horizontal fusion between reductions and scans. 
-- == -- input { [1,2,-3,4,0,4] } -- output { 8i32 [false, false, false, false, false, false] 0i32 } -- structure { Screma 1 } -- structure gpu { SegRed 1 SegScan 1 } def main (xs: []i32) = (reduce (+) 0 xs, scan (&&) true (map (<0) xs), reduce (*) 0 xs) futhark-0.25.27/tests/fusion/scatter-not-fuse.fut000066400000000000000000000014071475065116200217340ustar00rootroot00000000000000-- == -- compiled input { -- [3i64, -1i64, 1i64, 5i64, 2i64, -1i64, 7i64, 6i64] -- [7.0f32, 8.0f32, 9.0f32, 10.0f32, 12.0f32, 15.0f32, 18.0f32, 11.0f32] -- } -- output { -- [50.0f32, 0.0f32, 40.0f32, 75.0f32, 45.0f32, 0.0f32, 55.0f32, 90.0f32] -- [21.0f32, 0.0f32, 0.0f32, 0.0f32, 36.0f32, 0.0f32, 0.0f32, 0.0f32] -- } def main [n] (is : [n]i64) (ys_bar: *[n]f32) = let scatter_res_adj_gather = map (\ is_elem -> if is_elem >= 0 && is_elem < n then ys_bar[is_elem] else 0 ) is let zeros = replicate n 0.0f32 let map_res_bar = scatter ys_bar is zeros let map_adjs_1 = map (\ lam_adj -> 5.0f32 * lam_adj ) scatter_res_adj_gather let map_adjs_2 = map (\ lam_adj -> 3.0f32 * lam_adj ) map_res_bar in (map_adjs_1, map_adjs_2) futhark-0.25.27/tests/fusion/scatter-o-maps.fut000066400000000000000000000012011475065116200213600ustar00rootroot00000000000000-- == -- compiled input { -- [3i64, -1i64, 1i64, -10i64] -- [2.0f32, 3.0f32, 3.0f32, 4.0f32] -- [7.0f32, 8.0f32, 9.0f32, 10.0f32, 12.0f32, 15.0f32, 18.0f32] -- } -- output { -- [7.0f32, 8.0f32, 9.0f32, 10.0f32, -3.0f32, 15.0f32, 3.0f32] -- [0.0f32, 0.0f32, 0.0f32, 0.0f32, -9.0f32, 0.0f32, 9.0f32] -- } -- -- structure { -- /Scatter 1 -- } def main [n][m] (is: [n]i64) (vs: [n]f32) (xs: *[m]f32) = let is' = map (+5) is let vs' = map2 (\ i v -> v * f32.i64 i) is vs let res1 = scatter xs is' vs' let vs'' = map2 (\ i v -> (v * v) / f32.i64 i) is vs' let res2 = scatter (replicate m 0.0f32) is' vs'' in (res1, res2) futhark-0.25.27/tests/fusion/slicemap0.fut000066400000000000000000000003171475065116200204050ustar00rootroot00000000000000-- The structure test is a bit 
iffy, as we cannot express a constraint -- on the ordering. -- == -- input { [1,2,3] } output { [3,5] } -- structure gpu { Index 1 } def main (xs: []i32) = (map (+2) xs)[::2] futhark-0.25.27/tests/fusion/slicemap1.fut000066400000000000000000000001301475065116200203770ustar00rootroot00000000000000-- == -- structure { Screma 1 } def main (xs: []i32) = map (+3) ((map (+2) xs)[::2]) futhark-0.25.27/tests/fusion/slicemap2.fut000066400000000000000000000003371475065116200204110ustar00rootroot00000000000000-- The structure test is a bit iffy, as we cannot express a constraint -- on the ordering. -- == -- input { [[1,2,3]] } output { [[3, 4, 5]] } -- structure gpu { Index 1 } def main (xs: [][]i32) = (map (map (+2)) xs)[::2] futhark-0.25.27/tests/fusion/slicemap3.fut000066400000000000000000000002611475065116200204060ustar00rootroot00000000000000-- == -- input { [[1,2,3],[4,5,6],[7,8,9]] } output { [[9, 15], [27, 33]] } -- structure gpu { Index 1 } def main (xs: [][]i32) = map (map (*3)) ((map (map (+2)) xs)[::2,::2]) futhark-0.25.27/tests/fusion/slicemap4.fut000066400000000000000000000001361475065116200204100ustar00rootroot00000000000000-- == -- structure { Screma 1 } def main (xs: []i32) = xs |> map (+2) |> reverse |> map (*3) futhark-0.25.27/tests/fusion/tabulate0.fut000066400000000000000000000004261475065116200204120ustar00rootroot00000000000000-- Indexing with iota elements is turned into proper mapping, thus -- permitting fusion. -- == -- input { [1,2,3] } output { [1,4,9] } -- structure { Screma 1 } def main [n] (xs: [n]i32) = let ys = map (\i -> #[unsafe] xs[i]) (iota n) in map (\i -> ys[i] * xs[i]) (iota n) futhark-0.25.27/tests/fusion/tabulate1.fut000066400000000000000000000005051475065116200204110ustar00rootroot00000000000000-- When turning a map-iota into a proper map, the array being indexed -- does not have to be of the same size as the map. 
-- == -- input { 3i64 [1,2,3] } output { [1,4,9] } -- structure { Screma 1 } def main [k] (n: i64) (xs: [k]i32) = let ys = map (\i -> #[unsafe] xs[i]) (iota n) in map (\i -> ys[i] * xs[i]) (iota n) futhark-0.25.27/tests/fusion/tabulate2.fut000066400000000000000000000003571475065116200204170ustar00rootroot00000000000000-- Indexing with iota elements, but the index is an inner dimension. -- == -- input { 1i64 [[1,2,3],[4,5,6]] } output { [8,10,12] } -- structure { Iota 0 } def main [n] (j: i64) (xs: [][n]i32) = map (\i -> #[unsafe] xs[j,i]*2) (iota n) futhark-0.25.27/tests/futlib_tests/000077500000000000000000000000001475065116200172135ustar00rootroot00000000000000futhark-0.25.27/tests/futlib_tests/array.fut000066400000000000000000000053651475065116200210620ustar00rootroot00000000000000-- Tests of various array functions from the basis library. -- == -- entry: test_length -- input { empty([0]i32) } output { 0i64 } -- input { [1,2,3] } output { 3i64 } entry test_length (x: []i32) = length x -- == -- entry: test_null -- input { empty([0]i32) } output { true } -- input { [1,2,3] } output { false } entry test_null (x: []i32) = null x -- == -- entry: test_head -- input { empty([0]bool) } error: Error -- input { [true,false] } output { true } entry test_head (x: []bool) = head x -- == -- entry: test_tail -- input { empty([0]bool) } error: Error -- input { [true] } output { empty([0]bool) } -- input { [true,false] } output { [false] } entry test_tail (x: []bool) = tail x -- == -- entry: test_init -- input { empty([0]bool) } error: Error -- input { [true] } output { empty([0]bool) } -- input { [true,false] } output { [true] } entry test_init (x: []bool) = init x -- == -- entry: test_last -- input { empty([0]bool) } error: Error -- input { [true] } output { true } -- input { [true,false] } output { false } entry test_last (x: []bool) = last x -- == -- entry: test_take -- input { 0 empty([0]bool) } output { empty([0]bool) } -- input { 1 empty([0]bool) } error: Error -- input { 
0 [true] } output { empty([0]bool) } -- input { 1 [true] } output { [true] } -- input { 1 [true,false] } output { [true] } -- input { 2 [true,false,true] } output { [true,false] } entry test_take (i: i32) (x: []bool) = take (i64.i32 i) x -- == -- entry: test_drop -- input { 0 empty([0]bool) } output { empty([0]bool) } -- input { 1 empty([0]bool) } error: Error -- input { 0 [true] } output { [true] } -- input { 1 [true] } output { empty([0]bool) } -- input { 1 [true,false] } output { [false] } -- input { 2 [true,false,true] } output { [true] } entry test_drop (i: i32) (x: []bool) = drop (i64.i32 i) x -- == -- entry: test_reverse -- input { [[1,2],[3,4],[5,6]] } output { [[5,6],[3,4],[1,2]] } entry test_reverse (x: [][]i32) = reverse x -- == -- entry: test_or -- input { [true, true] } -- output { true } -- input { [true, false] } -- output { true } -- input { [false, false] } -- output { false } -- input { empty([0]bool) } -- output { false } entry test_or (xs: []bool) = or xs -- == -- entry: test_and -- input { [true, true] } -- output { true } -- input { [true, false] } -- output { false } -- input { [false, false] } -- output { false } -- input { empty([0]bool) } -- output { true } entry test_and (xs: []bool) = and xs -- == -- entry: test_flatten -- input { [[1,2],[3,4]] } output { [1,2,3,4] } entry test_flatten (xs: [][]i32) = flatten xs -- == -- entry: test_foldl -- input { 10i64 } output { -45i64 } entry test_foldl n = foldl (-) 0 (iota n) -- == -- entry: test_foldr -- input { 10i64 } output { -5i64 } entry test_foldr n = foldr (-) 0 (iota n) futhark-0.25.27/tests/futlib_tests/math.fut000066400000000000000000000056161475065116200206740ustar00rootroot00000000000000 -- == -- entry: test_f64_ceil -- input { 3.4 } output { 4.0 } -- input { 3.8 } output { 4.0 } -- input { 4.0 } output { 4.0 } -- input { 0.0 } output { 0.0 } -- input { -3.9 } output { -3.0 } -- input { -3.1 } output { -3.0 } -- input { -4.0 } output { -4.0 } entry test_f64_ceil (x: f64) = f64.ceil x 
-- == -- entry: test_f64_floor -- input { 3.4 } output { 3.0 } -- input { 3.8 } output { 3.0 } -- input { 4.0 } output { 4.0 } -- input { 0.0 } output { 0.0 } -- input { -3.9 } output { -4.0 } -- input { -3.1 } output { -4.0 } -- input { -4.0 } output { -4.0 } entry test_f64_floor (x: f64) = f64.floor x -- == -- entry: test_f64_trunc -- input { 3.4 } output { 3.0 } -- input { 3.8 } output { 3.0 } -- input { 4.0 } output { 4.0 } -- input { 0.0 } output { 0.0 } -- input { -3.9 } output { -3.0 } -- input { -3.1 } output { -3.0 } -- input { -4.0 } output { -4.0 } entry test_f64_trunc (x: f64) = f64.trunc x -- == -- entry: test_f64_round -- input { 0.0 } output { 0.0 } -- input { 99.0 } output { 99.0 } -- input { -5.0 } output { -5.0 } -- input { 1.1 } output { 1.0 } -- input { -1.1 } output { -1.0 } -- input { 1.9 } output { 2.0 } -- input { -1.9 } output { -2.0 } -- input { 2.5 } output { 2.0 } -- input { -2.5 } output { -2.0 } -- input { 1000001.4999 } output { 1000001.0 } -- input { -1000001.4999 } output { -1000001.0 } entry test_f64_round (x: f64) = f64.round x -- == -- entry: test_f32_ceil -- input { 3.4f32 } output { 4.0f32 } -- input { 3.8f32 } output { 4.0f32 } -- input { 4.0f32 } output { 4.0f32 } -- input { 0.0f32 } output { 0.0f32 } -- input { -3.9f32 } output { -3.0f32 } -- input { -3.1f32 } output { -3.0f32 } -- input { -4.0f32 } output { -4.0f32 } entry test_f32_ceil (x: f32) = f32.ceil x -- == -- entry: test_f32_floor -- input { 3.4f32 } output { 3.0f32 } -- input { 3.8f32 } output { 3.0f32 } -- input { 4.0f32 } output { 4.0f32 } -- input { 0.0f32 } output { 0.0f32 } -- input { -3.9f32 } output { -4.0f32 } -- input { -3.1f32 } output { -4.0f32 } -- input { -4.0f32 } output { -4.0f32 } entry test_f32_floor (x: f32) = f32.floor x -- == -- entry: test_f32_trunc -- input { 3.4f32 } output { 3.0f32 } -- input { 3.8f32 } output { 3.0f32 } -- input { 4.0f32 } output { 4.0f32 } -- input { 0.0f32 } output { 0.0f32 } -- input { -3.9f32 } output { -3.0f32 } -- 
input { -3.1f32 } output { -3.0f32 } -- input { -4.0f32 } output { -4.0f32 } entry test_f32_trunc (x: f32) = f32.trunc x -- == -- entry: test_f32_round -- input { 0.0f32 } output { 0.0f32 } -- input { 99.0f32 } output { 99.0f32 } -- input { -5.0f32 } output { -5.0f32 } -- input { 1.1f32 } output { 1.0f32 } -- input { -1.1f32 } output { -1.0f32 } -- input { 1.9f32 } output { 2.0f32 } -- input { -1.9f32 } output { -2.0f32 } -- input { 2.5f32 } output { 2.0f32 } -- input { -2.5f32 } output { -2.0f32 } -- input { 1001.4999f32 } output { 1001.0f32 } -- input { -1001.4999f32 } output { -1001.0f32 } entry test_f32_round (x: f32) = f32.round x futhark-0.25.27/tests/gauss_jordan.fut000066400000000000000000000022121475065116200177000ustar00rootroot00000000000000-- Matrix inversion using Gauss-Jordan elimination without pivoting. -- -- Taken from https://www.cs.cmu.edu/~scandal/nesl/alg-numerical.html#inverse -- -- == -- input { [[1.0f32, 2.0f32, 1.0f32], [2.0f32, 1.0f32, 1.0f32], [1.0f32, 1.0f32, 2.0f32]] } -- output { [[-0.25f32, 0.75f32, -0.25f32], [0.75f32, -0.25f32, -0.25f32], [-0.25f32, -0.25f32, 0.75f32]] } def Gauss_Jordan [n][m] (A: [n][m]f32): [n][m]f32 = (loop A for i < n do let irow = A[0] let Ap = A[1:n] let v1 = irow[i] let irow = map (/v1) irow let Ap = map (\jrow -> let scale = jrow[i] in map2 (\x y -> y - scale * x) irow jrow) Ap in Ap ++ [irow]) :> [n][m]f32 def matrix_inverse [n] (A: [n][n]f32): [n][n]f32 = -- Pad the matrix with the identity matrix. let n2 = n + n let on_row row i = let padding = replicate n 0.0 let padding[i] = 1f32 in concat row padding :> [n2]f32 let Ap = map2 on_row A (iota n) let Ap' = Gauss_Jordan Ap -- Drop the identity matrix at the front. 
in Ap'[0:n,n:n*2] :> [n][n]f32 def main [n] (A: [n][n]f32): [n][n]f32 = matrix_inverse A futhark-0.25.27/tests/globalsize0.fut000066400000000000000000000001231475065116200174330ustar00rootroot00000000000000-- #1920 def n = 3i64 def bar f = f { xs = replicate n 0f32 } def main = bar id futhark-0.25.27/tests/gregorian.fut000066400000000000000000000033271475065116200172060ustar00rootroot00000000000000-- Date computations. Some complex scalar expressions and a branch. -- Once messed up the simplifier. def mod(x: i32, y: i32): i32 = x - (x/y)*y def hours_in_dayI: i32 = 24 def minutes_in_dayI: i32 = hours_in_dayI * 60 def minutes_to_noonI: i32 = (hours_in_dayI / 2) * 60 def minutes_in_day: f64 = 24.0*60.0 def date_of_gregorian(date: (i32,i32,i32,i32,i32)): i32 = let (year, month, day, hour, mins) = date let ym = if(month == 1 || month == 2) then ( 1461 * ( year + 4800 - 1 ) ) / 4 + ( 367 * ( month + 10 ) ) / 12 - ( 3 * ( ( year + 4900 - 1 ) / 100 ) ) / 4 else ( 1461 * ( year + 4800 ) ) / 4 + ( 367 * ( month - 2 ) ) / 12 - ( 3 * ( ( year + 4900 ) / 100 ) ) / 4 let tmp = ym + day - 32075 - 2444238 in tmp * minutes_in_dayI + hour * 60 + mins def gregorian_of_date (minutes_since_epoch: i32 ): (i32,i32,i32,i32,i32) = let jul = minutes_since_epoch / minutes_in_dayI let l = jul + 68569 + 2444238 let n = ( 4 * l ) / 146097 let l = l - ( 146097 * n + 3 ) / 4 let i = ( 4000 * ( l + 1 ) ) / 1461001 let l = l - ( 1461 * i ) / 4 + 31 let j = ( 80 * l ) / 2447 let d = l - ( 2447 * j ) / 80 let l = j / 11 let m = j + 2 - ( 12 * l ) let y = 100 * ( n - 49 ) + i + l --let daytime = minutes_since_epoch mod minutes_in_day in let daytime = mod( minutes_since_epoch, minutes_in_dayI ) in if ( daytime == minutes_to_noonI ) --then [#year = y; month = m; day = d; hour = 12; minute = 0] then (y, m, d, 12, 0) --else [#year = y; month = m; day = d; hour = daytime / 60; minute = daytime mod 60] else (y, m, d, daytime / 60, mod(daytime, 60) ) def main(x: i32): i32 = 
date_of_gregorian(gregorian_of_date(x)) futhark-0.25.27/tests/guysteele_sequential.fut000066400000000000000000000014631475065116200214700ustar00rootroot00000000000000-- This program is a crude implementation of the sequential -- implementation from Guy Steele's talk "Four Solutions to a Trivial -- Problem" https://www.youtube.com/watch?v=ftcIcn8AmSY -- -- It is probably not the nicest way to do this in Futhark, but it -- found a bug in fusion (related to the 'reverse' function). -- == -- input { [2,6,3,5,2,8,1,4,2,2,5,3,5,7,4,1] } -- output { 35 } def min(x: i32) (y: i32): i32 = if x < y then x else y def max(x: i32) (y: i32): i32 = if x < y then y else x def reverse [n] (a: [n]i32): [n]i32 = map (\(i: i64): i32 -> a[n-i-1]) (iota(n)) def main(a: []i32): i32 = let highestToTheLeft = scan max 0 a let highestToTheRight = reverse(scan max 0 (reverse(a))) let waterLevels = map2 min highestToTheLeft highestToTheRight in reduce (+) 0 (map2 (-) waterLevels a) futhark-0.25.27/tests/hexfloats.fut000066400000000000000000000006071475065116200172240ustar00rootroot00000000000000-- Futhark supports hexadecimal float literals -- == -- input {} -- output { -- [31.875f64, 31.875f64, 17.996094f64, 3.984375f64, -17.996094f64] -- [31.875f32, 17.996094f32, 3.984375f32, -17.996094f32, 0.9375f32] -- } def main: ([]f64, []f32) = ([0xf.fp1, 0xf.fp1f64, 0x11.ffp0f64, 0xf.fp-2f64, -0x11.ffp0_0f64], [0xf.fp1f32, 0x11.ffp0f32, 0xf.fp-2f32, -0x11.ffp0f32, 0x0.f0p0f32]) futhark-0.25.27/tests/higher-order-functions/000077500000000000000000000000001475065116200210715ustar00rootroot00000000000000futhark-0.25.27/tests/higher-order-functions/alias0.fut000066400000000000000000000002241475065116200227600ustar00rootroot00000000000000-- Yet another case of aliasing that can result in incorrect code -- generation. 
def main (w: i64) (h: i64) = ([1,2,3] :> [w*h]i32) |> unflatten futhark-0.25.27/tests/higher-order-functions/alias1.fut000066400000000000000000000002351475065116200227630ustar00rootroot00000000000000-- Yet another case of aliasing that can result in incorrect code -- generation. def main [m][n] (xi_0: [m][n]f32) (xi_1: [m][n]f32) = map2 zip xi_0 xi_1 futhark-0.25.27/tests/higher-order-functions/alias2.fut000066400000000000000000000002701475065116200227630ustar00rootroot00000000000000def main [h][w][n] (ether: [h][w]f32) (is: [n]i64): [][]f32 = let ether_flat = copy (flatten ether) let vs = map (\i -> ether_flat[i]) is in unflatten (scatter ether_flat is vs) futhark-0.25.27/tests/higher-order-functions/alias3.fut000066400000000000000000000003231475065116200227630ustar00rootroot00000000000000type pair = (f32,i32) def main [h][w][n] (ether: [h][w]pair) (is: [n]i64): [h][w]pair = let ether_flat = copy (flatten ether) let vs = map (\i -> ether_flat[i]) is in unflatten (scatter ether_flat is vs) futhark-0.25.27/tests/higher-order-functions/array-fun0.fut000066400000000000000000000002461475065116200235770ustar00rootroot00000000000000-- We cannot have an array literal that contains function variables. -- == -- error: functional def f (x:i32) : i32 = x+x def g (x:i32) : i32 = x+1 def arr = [f, g] futhark-0.25.27/tests/higher-order-functions/array-fun1.fut000066400000000000000000000002421475065116200235740ustar00rootroot00000000000000-- We cannot have a parameter with an array of functions. -- == -- error: Cannot .* array with elements of lifted type .* -> .* def f (arr : [](i32->i32)) = arr futhark-0.25.27/tests/higher-order-functions/array-fun2.fut000066400000000000000000000003231475065116200235750ustar00rootroot00000000000000-- We cannot map a function that returns a function over an array, -- since that would result in an array of functions. 
-- == -- error: functional def main (xs : []i32) = map (\(x:i32) -> \(y:i32) -> x+y) xs futhark-0.25.27/tests/higher-order-functions/array-lambda0.fut000066400000000000000000000002401475065116200242210ustar00rootroot00000000000000-- We cannot have an array containing a literal lambda-expression. -- == -- error: functional def main : i32 = let _ = [\(x:i32) -> x+1] in 42 futhark-0.25.27/tests/higher-order-functions/array-lambda1.fut000066400000000000000000000002671475065116200242330ustar00rootroot00000000000000-- We cannot have an array containing literal lambda-expressions. -- == -- error: functional def main () : i32 = let _ = [\(x:i32) -> x+1, \(x:i32) -> x+x] in 42 futhark-0.25.27/tests/higher-order-functions/binop0.fut000066400000000000000000000005271475065116200230040ustar00rootroot00000000000000-- Test of a higher-order infix operator that takes two functions as -- arguments and returns a function as result. -- == -- input { 7 12 } output { 8 24 } def (***) '^a '^b '^a' '^b' (f: a -> a') (g: b -> b') : (a,b) -> (a',b') = \(x: a, y: b) -> (f x, g y) def main (x: i32) (y: i32) = ((\(x:i32) -> x+1) *** (\(y:i32) -> y+y)) (x, y) futhark-0.25.27/tests/higher-order-functions/binop1.fut000066400000000000000000000003671475065116200230070ustar00rootroot00000000000000-- Test of an infix operator that takes arguments of order 0, but -- returns a function. -- == -- input { 7 5 } output { 35 } def (**) (x:i32) (y:i32) = \(f:i32->i32->i32) -> f x y def main (x:i32) (y:i32) = (x ** y) (\(a:i32) (b:i32) -> a*b) futhark-0.25.27/tests/higher-order-functions/binop2.fut000066400000000000000000000004321475065116200230010ustar00rootroot00000000000000-- Test of an infix operator that takes one zero order argument and -- one functional argument, and returns a function. 
-- == -- input { 7 5 } output { 19 } def (**) (x:i32) (f:i32->i32) : i32 -> i32 = \(y:i32) -> f x + y def main (x:i32) (y:i32) = (x ** (\(z:i32) -> z+z)) y futhark-0.25.27/tests/higher-order-functions/conditional-function0.fut000066400000000000000000000003131475065116200260140ustar00rootroot00000000000000-- We cannot return a function from a conditional. -- == -- error: returned from branch def f (x:i32) : i32 = x+x def g (x:i32) : i32 = x+1 def main (b : bool) (n : i32) : i32 = (if b then f else g) n futhark-0.25.27/tests/higher-order-functions/function-argument0.fut000066400000000000000000000005151475065116200253370ustar00rootroot00000000000000-- Simple monomorphic higher-order function that takes a function as argument. -- == -- input { 3 } output { 5 12 } -- input { 16 } output {18 64 } def twice (f : i32 -> i32) (x : i32) : i32 = f (f x) def double (x : i32) : i32 = x+x def add1 (x : i32) : i32 = x+1 def main (x : i32) : (i32, i32) = (twice add1 x, twice double x) futhark-0.25.27/tests/higher-order-functions/function-argument1.fut000066400000000000000000000005461475065116200253440ustar00rootroot00000000000000-- Simple polymorphic higher-order function that takes a function as argument. -- == -- input { 11 true } output { 44 true } -- input { 7 false } output { 28 false } def twice 'a (f : a -> a) (x : a) : a = f (f x) def double (x : i32) : i32 = x+x def not (b : bool) : bool = ! b def main (x : i32) (b : bool) : (i32, bool) = (twice double x, twice not b) futhark-0.25.27/tests/higher-order-functions/function-argument2.fut000066400000000000000000000004631475065116200253430ustar00rootroot00000000000000-- Polymorphic higher-order function with an argument lambda that -- closes over a local variable. 
-- == -- input { 5 true } output { 7 } -- input { 5 false } output { 9 } def twice 'a (f : a -> a) (x : a) : a = f (f x) def main (x : i32) (b : bool) : i32 = twice (\(y:i32) -> if b then y+1 else y+2) x futhark-0.25.27/tests/higher-order-functions/function-composition.fut000066400000000000000000000007231475065116200260010ustar00rootroot00000000000000-- Standard polymorphic function composition. -- == -- input { 5 } output { true [6,6,6] [false,false,false] } -- input { 2 } output { false [3,3,3] [true,true,true] } def compose 'a 'b 'c (f : b -> c) (g : a -> b) (x : a) : c = f (g x) def add1 (x : i32) : i32 = x+1 def isEven (x : i32) : bool = x % 2 == 0 def replicate3 'a (x : a) : [3]a = [x, x, x] def main (x : i32) = (compose isEven add1 x, compose replicate3 add1 x, compose replicate3 isEven x ) futhark-0.25.27/tests/higher-order-functions/function-result0.fut000066400000000000000000000004221475065116200250300ustar00rootroot00000000000000-- Appling a first-order function that returns a lambda which closes -- over a local variable. -- == -- input { 3 4 } output { 9 } -- input { 10 12 } output { 24 } def main (x : i32) (y : i32) = let f (x:i32) = let a = 2 in \(y:i32) -> x+y+a in f x y futhark-0.25.27/tests/higher-order-functions/higher-order-entry-point0.fut000066400000000000000000000002771475065116200265440ustar00rootroot00000000000000-- Entry point functions are not allowed to take functions as arguments. -- == -- error: Entry point functions may not be higher-order def main (x : i32) (f : i32 -> i32, n : i32) = f x + n futhark-0.25.27/tests/higher-order-functions/higher-order-entry-point1.fut000066400000000000000000000001741475065116200265410ustar00rootroot00000000000000-- Curried entry point. 
-- == -- input { 2 2 } output { 4 } def plus (x: i32) (y: i32) = x + y def main (x: i32) = plus x futhark-0.25.27/tests/higher-order-functions/higher-order-entry-point2.fut000066400000000000000000000002351475065116200265400ustar00rootroot00000000000000-- A first-order entry point need not be syntactically first-order. -- == -- input { 2 2 } output { 4 } def plus (x: i32) (y: i32) = x + y def main = plus futhark-0.25.27/tests/higher-order-functions/higher-order0.fut000066400000000000000000000001661475065116200242530ustar00rootroot00000000000000-- id id id ... -- == -- input { 378 } output { 378 } def id '^a (x : a) : a = x def main (x : i32) = id id id id x futhark-0.25.27/tests/higher-order-functions/higher-order1.fut000066400000000000000000000003421475065116200242500ustar00rootroot00000000000000-- Just because a type parameter *may* be function, it does not *have* -- to be a function. -- == -- input { 2 } output { [2] } def id '^a (x: a) = x def array 't (f: t -> i32) (t: t) = [f t] def main (x: i32) = array id x futhark-0.25.27/tests/higher-order-functions/issue1798.fut000066400000000000000000000004061475065116200232720ustar00rootroot00000000000000def splits [n] 'a (p: a -> bool) (s: [n]a) = let m = n+1 in (\(_, i, k) -> #[unsafe] s[i:i+k], map (\i -> (replicate m (),i,i+1)) (indices s)) def main (s: []u8) = let (get,fs) = splits (=='-') s let on_f (_, i, k) = length s[i:i+k] in map on_f fs futhark-0.25.27/tests/higher-order-functions/issue493.fut000066400000000000000000000005001475065116200231740ustar00rootroot00000000000000-- It should be possible for a partially applied function to refer to -- a first-order (dynamic) function in its definition. 
-- == -- input { 3i64 [[1,2],[3,4]] } -- output { [[[1,2],[3,4]],[[1,2],[3,4]],[[1,2],[3,4]]] } def apply 'a '^b (f: a -> b) (x: a) = f x def main (n: i64) (d: [][]i32) = apply (replicate n) d futhark-0.25.27/tests/higher-order-functions/localfunction0.fut000066400000000000000000000002221475065116200245250ustar00rootroot00000000000000-- The defunctionaliser once messed up local closures. def main (n: i64) = let scale (x: i64) (y: i64) = (x+y) / n in map (scale 1) (iota n) futhark-0.25.27/tests/higher-order-functions/loops0.fut000066400000000000000000000003021475065116200230200ustar00rootroot00000000000000-- The merge parameter in a loop cannot have function type. -- == -- error: used as loop variable def id 'a (x : a) : a = x def main (n : i32) = loop f = id for i < n do \(y:i32) -> f y futhark-0.25.27/tests/higher-order-functions/match-function0.fut000066400000000000000000000003251475065116200246100ustar00rootroot00000000000000-- We cannot return a function from a pattern match. -- == -- error: returned from pattern match def f (x:i32) : i32 = x+x def g (x:i32) : i32 = x+1 def main (b : bool) (n : i32) : i32 = (match b case _ -> f) n futhark-0.25.27/tests/higher-order-functions/nested-closures0.fut000066400000000000000000000006131475065116200250100ustar00rootroot00000000000000-- A local function with a free variable that itself is a function that closes -- over its own local variables, i.e., a closure inside a closure environment. -- == -- input { 12 } output { 17 } def main (x : i32) = let f = let b = 2 in let g = let a = 1 in let h = \(x:i32) -> x+a in \(z:i32) -> h b + z in \(y:i32) -> g y + b in f x futhark-0.25.27/tests/higher-order-functions/partial-application0.fut000066400000000000000000000003401475065116200256230ustar00rootroot00000000000000-- Basic partial application. 
-- == -- input { 3 7 2 } output { 12 } -- input { -2 5 1 } output { 4 } def f (x : i32) (y : i32) (z : i32) : i32 = x + y + z def main (x : i32) (y : i32) (z : i32) = let g = f x y in g z futhark-0.25.27/tests/higher-order-functions/partial-application1.fut000066400000000000000000000003101475065116200256210ustar00rootroot00000000000000-- Binding a function to a local variable before applying it. -- == -- input { 3 } output { 6 } -- input { 11 } output { 22 } def f (x : i32) : i32 = x + x def main (x : i32) = let g = f in g x futhark-0.25.27/tests/higher-order-functions/records0.fut000066400000000000000000000006271475065116200233370ustar00rootroot00000000000000-- Storing a function in a record and applying it. -- == -- input { 3 7 2 } output { 71 } -- input { 5 9 11 } output { 102 } def add (x:i32) (y:i32) : i32 = x + y def main (x : i32) (y : i32) (z : i32) = let n = 1 let t = (\(z:i32) -> n+z, 10) let r = { a = 42 , f = add , f1 = add n , g = \(z:i32) -> z+z+n } in t.0 z + t.1 + r.f r.a x + r.f1 y + r.g z futhark-0.25.27/tests/higher-order-functions/records1.fut000066400000000000000000000004511475065116200233330ustar00rootroot00000000000000-- Implicit record field referring to a local first-order function. -- == -- input { 5 10 12 } output { 37 } -- input { 11 3 9 } output { 26 } def main (k : i32) (m : i32) (n : i32) = let r = (let a = m let f (x:i32) : i32 = x+a in { a, n = k, f }) in r.f n + r.a + r.n futhark-0.25.27/tests/higher-order-functions/shape-params0.fut000066400000000000000000000007421475065116200242550ustar00rootroot00000000000000-- Defunctionalization should not leave the shape parameter on lifted -- functions whose parameters do not refer to the shape, and it should -- preserve (or introduce new shape parameters) on lifted functions -- when necessary. 
-- == -- input { [2,3,5,1] [6,5,2,6] } output { [8,8,7,7] } def map2 'a 'b 'c [m] (f: a -> b -> c) (xs: [m]a) (ys: [m]b): [m]c = map (\(x,y) -> f x y) (zip xs ys) def add (x: i32) (y: i32) = x + y def main (xs: []i32) (ys: []i32) = map2 add xs ys futhark-0.25.27/tests/higher-order-functions/shape-params1.fut000066400000000000000000000002521475065116200242520ustar00rootroot00000000000000-- We can close over shape parameters. -- == -- input { [5,8,9] 5i64 } output { 8i64 } def f [n] (_: [n]i32) = \(y:i64) -> y+n def main (xs: []i32) (x: i64) = f xs x futhark-0.25.27/tests/higher-order-functions/shape-params2.fut000066400000000000000000000007561475065116200242640ustar00rootroot00000000000000-- A higher-order function with a shape parameter that contains a -- local dynamic function which is used as a first class value and -- which refers to the outer shape parameter in its parameter type -- and in its body. -- == -- input { [2,3,5,1] [6,5,2,6] } output { [8,8,7,7] 4i64 } def map2 [n] (f: i32 -> i32 -> i32) (xs: [n]i32) = let g (ys: [n]i32) = (map (\(x,y) -> f x y) (zip xs ys), n) in g def add (x: i32) (y: i32) = x + y def main (xs: []i32) (ys: []i32) = map2 add xs ys futhark-0.25.27/tests/higher-order-functions/shape-params3.fut000066400000000000000000000004221475065116200242530ustar00rootroot00000000000000-- A higher-order function that uses the shape parameter as a value term. 
-- == -- input { [12,17,8,23] } output { [13,18,9,24] 4i64 } def map_length [n] (f: i32 -> i32) (xs: [n]i32) : ([n]i32, i64) = (map f xs, n) def main (xs: []i32) = map_length (\(x:i32) -> x+1) xs futhark-0.25.27/tests/higher-order-functions/shape-params4.fut000066400000000000000000000001001475065116200242450ustar00rootroot00000000000000type^ f = (n: i64) -> [n]i32 def main: f = \n -> replicate n 0 futhark-0.25.27/tests/higher-order-functions/shape-params5.fut000066400000000000000000000003701475065116200242570ustar00rootroot00000000000000type^ nn 'u = { f : u } def connect '^u (a: nn u) (b: nn u) : nn (u, u) = { f = (a.f, a.f) } def nn1 : nn ((n: i64) -> [n]i32 -> [n]i32) = { f = \n (xs: [n]i32) -> xs } def foo = connect nn1 nn1 def main [n] (xs: [n]i32) = foo.f.0 n xs futhark-0.25.27/tests/higher-order-functions/shape-params6.fut000066400000000000000000000002611475065116200242570ustar00rootroot00000000000000-- == -- input { [1,2] [3,4,5] } -- output { 5i64 } def f [n][m] (f: [n+m]i32 -> i64) (a: [n]i32) (b: [m]i32) = f (a ++ b) def g n m (_: [n+m]i32) = n+m def main = f (g 2 3) futhark-0.25.27/tests/higher-order-functions/shape-params7.fut000066400000000000000000000003461475065116200242640ustar00rootroot00000000000000-- == -- input { [1,2,3] } -- output { [1,2,3] } def inc (x: i64) = x + 1 def tail [n] 't (A: [inc n]t) = A[1:] :> [n]t def cons [n] 't (x: t) (A: [n]t): [inc n]t = [x] ++ A :> [inc n]t def main (xs: []i32) = tail (cons 2 xs) futhark-0.25.27/tests/higher-order-functions/soac0.fut000066400000000000000000000005121475065116200226140ustar00rootroot00000000000000-- We can use function composition in the functional argument of the map SOAC. 
-- == -- input { [2,4,5,1,7,5] } output { [5,9,11,3,15,11 ] } def compose 'a 'b 'c (f : b -> c) (g : a -> b) (x : a) : c = f (g x) def add1 (x : i32) : i32 = x+1 def double (x : i32) : i32 = x+x def main (xs : []i32) = map (compose add1 double) xs futhark-0.25.27/tests/higher-order-functions/soac1.fut000066400000000000000000000004771475065116200226270ustar00rootroot00000000000000-- We can use function composition in the functional argument of the reduce SOAC. -- == -- input { [8,7,12,9] } output { 36 } def compose '^a '^b '^c (f : b -> c) (g : a -> b) (x : a) : c = f (g x) def add (x : i32) (y:i32) : i32 = x+y def id '^a (x : a) : a = x def main (xs : []i32) = reduce (compose add id) 0 xs futhark-0.25.27/tests/higher-order-functions/uniqueness0.fut000066400000000000000000000002371475065116200240720ustar00rootroot00000000000000-- == -- error: consumption def update (xs: *[]i32) (i: i32) (y: i32) = xs with [i] = y def main (QUUX: *[]i32)= let f = update QUUX in (f 0 0, f 0 0) futhark-0.25.27/tests/higher-order-functions/uniqueness1.fut000066400000000000000000000002231475065116200240660ustar00rootroot00000000000000-- == -- error: consumption def update (xs: *[]i32) (i: i32) (y: i32) = xs with [i] = y def main (arr: *[]i32) = let f = update arr in arr futhark-0.25.27/tests/higher-order-functions/uniqueness10.fut000066400000000000000000000003511475065116200241500ustar00rootroot00000000000000-- Defunctionaliser generated wrong uniqueness for this one at one -- point. 
def main [n] (xs: [n]i32) (ys: [n]i32) (is: []i32) = let op (xs' : [n]i32, ys') i = (if i == 2 then xs else xs', ys') in foldl op (xs, ys) is futhark-0.25.27/tests/higher-order-functions/uniqueness11.fut000066400000000000000000000000661475065116200241540ustar00rootroot00000000000000def main = map (const (iota 3 with [0] = 1)) (iota 3) futhark-0.25.27/tests/higher-order-functions/uniqueness2.fut000066400000000000000000000002551475065116200240740ustar00rootroot00000000000000-- A lambda whose free variable has been consumed. -- == -- error: "QUUX".*consumed def main(y: i32, QUUX: *[]i32) = let f = \x -> x + QUUX[0] let QUUX[1] = 2 in f y futhark-0.25.27/tests/higher-order-functions/uniqueness3.fut000066400000000000000000000003071475065116200240730ustar00rootroot00000000000000-- A partially applied function whose closure has been consumed. -- == -- error: "QUUX".*consumed def const x _ = x def main(y: i32, QUUX: *[]i32) = let f = const QUUX let QUUX[1] = 2 in f y futhark-0.25.27/tests/higher-order-functions/uniqueness4.fut000066400000000000000000000003121475065116200240700ustar00rootroot00000000000000-- A partially applied function whose closure has been consumed. -- == -- error: "QUUX".*consumed def const x _ = x[0] def main(y: i32, QUUX: *[]i32) = let f = const QUUX let QUUX[1] = 2 in f y futhark-0.25.27/tests/higher-order-functions/uniqueness5.fut000066400000000000000000000003471475065116200241010ustar00rootroot00000000000000-- A consuming function must not be passed as a higher-order argument! -- == -- error: consumption def zero (xs: *[]i32) (i: i32) = xs with [i] = 0 def apply f x = f x def main (arr: *[]i32)= let f = zero arr in apply f 0 futhark-0.25.27/tests/higher-order-functions/uniqueness6.fut000066400000000000000000000004151475065116200240760ustar00rootroot00000000000000-- Nope, this one is also not OK (although it would be possible to -- change the type system so that it would be). 
-- == -- error: consumption def zero (xs: *[]i32) (i: i32) = xs with [i] = 0 def apply f x = f x def main (arr: *[]i32)= let f = zero arr in f 0 futhark-0.25.27/tests/higher-order-functions/uniqueness7.fut000066400000000000000000000002661475065116200241030ustar00rootroot00000000000000-- == -- input { 1 } output { [0,1,1] } def zero (xs: *[]i32) (i: i32) = xs with [i] = 0 def uniq (x: i32): *[]i32 = [x,x,x] def main (x: i32)= let f = zero (uniq x) in f 0 futhark-0.25.27/tests/higher-order-functions/uniqueness8.fut000066400000000000000000000005551475065116200241050ustar00rootroot00000000000000-- Do not let the uniqueness of a returned array affect the inferred -- uniqueness of a function parameter. def (>->) '^a '^b '^c (f: a -> b) (g: b -> c) (x: a): c = g (f x) def tabmap [n] 'a 'b (f: i32 -> a -> b) (xs: [n]a): *[1]b = [f 0 xs[0]] def main [n][m] (arr: [n][m]f32): [][]f32 = let f (i: i32) (j: i32) (x: f32) = x in tabmap (f >-> tabmap) arr futhark-0.25.27/tests/higher-order-functions/uniqueness9.fut000066400000000000000000000003401475065116200240760ustar00rootroot00000000000000-- This requires care to maintain the right uniqueness attributes. def singleton (f: i32 -> []i32): ([]i32, *[]i32) = let xs = f 1 in (xs, [1]) def main (xs: []i32) = singleton (\y -> if y >= 0 then [xs[y]] else xs) futhark-0.25.27/tests/higher-order-functions/value-type-function0.fut000066400000000000000000000006501475065116200256100ustar00rootroot00000000000000-- Since we distinguish between value types and lifted types, it -- should not be possible to instantiate a polymorphic funtion, that -- uses its (value type) polymorphic arguments in the branches of a -- conditional, with a function type. 
-- == -- error: functional def cond 'a (b : bool) (x : a) (y : a) : a = if b then x else y def main (b : bool) : i32 = let f = cond b (\(x:i32) -> x+x) (\(x:i32) -> x) in f 42 futhark-0.25.27/tests/higher-order-functions/value-type-function1.fut000066400000000000000000000003601475065116200256070ustar00rootroot00000000000000-- We should not be able to instantiate a value type parameter of a -- polymorphic function with a function type. -- == -- error: functional def mkArray 'a (x : a) : []a = [x] def main (x : i32) = let _ = mkArray (\(x:i32) -> x) in x futhark-0.25.27/tests/higher-order-functions/value-type-function2.fut000066400000000000000000000004221475065116200256070ustar00rootroot00000000000000-- A lifted type parameter cannot be used as the type of the branches -- of a conditional. -- == -- error: def cond '^a (b : bool) (x : a) (y : a) : a = if b then x else y def main (b : bool) (y : i32) : i32 = let f = cond b (\(x:i32) -> x+x) (\(x:i32) -> x) in f y futhark-0.25.27/tests/higher-order-functions/value-type-function3.fut000066400000000000000000000004311475065116200256100ustar00rootroot00000000000000-- We can not even use a lifted type parameter as the result type of a -- condition even if it's not actually instantiated with a function type. -- == -- error: def cond '^a (b : bool) (x : a) (y : a) : a = if b then x else y def main (b : bool) (x : i32) : i32 = cond b x 0 futhark-0.25.27/tests/higher-order-functions/value-type-function4.fut000066400000000000000000000004671475065116200256220ustar00rootroot00000000000000-- A value type parameter of a polymorphic function cannot be instantiated to a -- function type by passing the function as an argument to another function. 
-- == -- error: functional def app (f : (i32 -> i32) -> (i32 -> i32)) : i32 = f (\(x:i32) -> x+x) 42 def id 'a (x:a) : a = x def main : i32 = app id futhark-0.25.27/tests/higher-order-functions/value-type-function5.fut000066400000000000000000000006351475065116200256200ustar00rootroot00000000000000-- Lifted type parameters allows for the definition of general polymorphic -- function composition. Without them, we are limited in which functions can be -- composed. -- == -- error: functional def compose 'a 'b 'c (f : b -> c) (g : a -> b) : a -> c = \(x : a) -> f (g x) def add (x : i32) (y : i32) : i32 = x+y def double (x : i32) : i32 = x+x def main (x : i32) (y : i32) : i32 = compose add double 3 5 futhark-0.25.27/tests/hist/000077500000000000000000000000001475065116200154535ustar00rootroot00000000000000futhark-0.25.27/tests/hist/64and.fut000066400000000000000000000010221475065116200171020ustar00rootroot00000000000000-- Test with i64.&. -- == -- -- input { -- 5i64 -- [0, 1, 2, 3, 4] -- [1, 1, 1, 1, 1] -- } -- output { -- [1i64, 1i64, 1i64, 1i64, 1i64] -- } -- -- input { -- 5i64 -- [0, 0, 0, 0, 0] -- [6, 1, 4, 5, -1] -- } -- output { -- [0i64, -1i64, -1i64, -1i64, -1i64] -- } -- -- input { -- 5i64 -- [1, 2, 1, 4, 5] -- [1, 1, 4, 4, 4] -- } -- output { -- [-1i64, 0i64, 1i64, -1i64, 4i64] -- } def main [m] (n: i64) (is: [m]i32) (image: [m]i32) : [n]i64 = hist (i64.&) (-1) n (map i64.i32 is) (map i64.i32 image) futhark-0.25.27/tests/hist/64max.fut000066400000000000000000000013201475065116200171260ustar00rootroot00000000000000-- Test with i64.max/u64.max. 
-- == -- -- input { -- 5i64 -- [0, 1, 2, 3, 4] -- [1, 1, 1, 1, 1] -- } -- output { -- [1, 1, 1, 1, 1] -- [1, 1, 1, 1, 1] -- } -- -- input { -- 5i64 -- [0, 0, 0, 0, 0] -- [6, 1, 4, 5, -1] -- } -- output { -- [6, 0, 0, 0, 0] -- [-1, 0, 0, 0, 0] -- } -- -- input { -- 5i64 -- [1, 2, 1, 4, 5] -- [1, 1, 4, 4, 4] -- } -- output { -- [0, 4, 1, 0, 4] -- [0, 4, 1, 0, 4] -- } def main [m] (n: i64) (is: [m]i32) (image: [m]i32) : ([n]i32, [n]i32) = (reduce_by_index (replicate n 0) i32.max i32.lowest (map i64.i32 is) image, map i32.u32 (reduce_by_index (replicate n 0) u32.max u32.lowest (map i64.i32 is) (map u32.i32 image))) futhark-0.25.27/tests/hist/64xor.fut000066400000000000000000000010121475065116200171470ustar00rootroot00000000000000-- Test with i64.^. -- == -- -- input { -- 5i64 -- [0, 1, 2, 3, 4] -- [1, 1, 1, 1, 1] -- } -- output { -- [1i64, 1i64, 1i64, 1i64, 1i64] -- } -- -- input { -- 5i64 -- [0, 0, 0, 0, 0] -- [6, 1, 4, 5, -1] -- } -- output { -- [-7i64, 0i64, 0i64, 0i64, 0i64] -- } -- -- input { -- 5i64 -- [1, 2, 1, 4, 5] -- [1, 1, 4, 4, 4] -- } -- output { -- [0i64, 5i64, 1i64, 0i64, 4i64] -- } def main [m] (n: i64) (is: [m]i32) (image: [m]i32) : [n]i64 = hist (i64.^) 0 n (map i64.i32 is) (map i64.i32 image) futhark-0.25.27/tests/hist/and.fut000066400000000000000000000007271475065116200167430ustar00rootroot00000000000000-- Test with i32.&. 
-- == -- -- input { -- 5i64 -- [0, 1, 2, 3, 4] -- [1, 1, 1, 1, 1] -- } -- output { -- [1, 1, 1, 1, 1] -- } -- -- input { -- 5i64 -- [0, 0, 0, 0, 0] -- [6, 1, 4, 5, -1] -- } -- output { -- [0, -1, -1, -1, -1] -- } -- -- input { -- 5i64 -- [1, 2, 1, 4, 5] -- [1, 1, 4, 4, 4] -- } -- output { -- [-1, 0, 1, -1, 4] -- } def main [m] (n: i64) (is: [m]i32) (image: [m]i32) : [n]i32 = hist (i32.&) (-1) n (map i64.i32 is) image futhark-0.25.27/tests/hist/array.fut000066400000000000000000000005131475065116200173100ustar00rootroot00000000000000-- Test reduce_by_index on array of arrays -- == -- input { [[1,2,3],[4,5,6]] [0i64,0i64,2i64,1i64] } -- output { [[1, 4, 7], [4, 6, 8]] } def main [m][n][k] (xs : *[n][m]i32) (image : *[k]i64) : *[n][m]i32 = reduce_by_index xs (\x y -> map2 (+) x y) (replicate m 0) image (replicate k (map i32.i64 (iota m))) futhark-0.25.27/tests/hist/array_novec.fut000066400000000000000000000007111475065116200205020ustar00rootroot00000000000000-- Test reduce_by_index on array of arrays, where the operator is not -- recognisably vectorised. 
-- == -- input { [[1,2,3],[4,5,6]] [0i64,0i64,2i64,1i64] } -- output { [[1, 4, 7], [4, 6, 8]] } def main [m][n][k] (xs : *[n][m]i32) (image : *[k]i64) : *[n][m]i32 = reduce_by_index xs (\x y -> loop acc = copy x for i in iota m do acc with [i] = acc[i] + y[i]) (replicate m 0) image (replicate k (map i32.i64 (iota m))) futhark-0.25.27/tests/hist/equiv.fut000066400000000000000000000014321475065116200173240ustar00rootroot00000000000000-- Test -- == -- input { -- [[1, 2, 3], [1, 2, 3], [1, 2, 3]] -- [1, 1, 1] -- } -- output { -- [[1i32, 2i32, 3i32], [4i32, 8i32, 12i32], [1i32, 2i32, 3i32]] -- [[1i32, 2i32, 3i32], [4i32, 8i32, 12i32], [1i32, 2i32, 3i32]] -- } def hist_equiv [n][k] (xs : [n][3]i32) (image : [k]i32) : [n][3]i32 = let inds = image let vals = replicate k [1,2,3] let vals' = transpose vals let xs' = transpose xs let res = map2 (\row x -> reduce_by_index (copy x) (+) 0 (map i64.i32 inds) row) vals' xs' in transpose res def main [n][k] (xs : [n][3]i32) (image : [k]i32) = -- : *[n][3]i32 = let res1 = reduce_by_index (copy xs) (\x y -> map2 (+) x y) [0,0,0] (map i64.i32 image) (replicate k [1,2,3]) let res2 = hist_equiv (copy xs) image in (res1, res2) futhark-0.25.27/tests/hist/f16.fut000066400000000000000000000017011475065116200165660ustar00rootroot00000000000000-- Can we do operations on f16s, even though these are not natively supported? 
-- == -- -- input { -- [0f16, 0f16, 0f16, 0f16, 0f16] -- [1i16, 1i16, 1i16, 1i16, 1i16] -- [1f16, 1f16, 1f16, 1f16, 1f16] -- } -- output { -- [0f16, 5f16, 0f16, 0f16, 0f16] -- } -- -- input { -- [0f16, 0f16, 0f16, 0f16, 0f16] -- [1i16, 1i16, 4i16, 4i16, 4i16] -- [0.1f16, 0.1f16, 0.4f16, 0.4f16, 0.4f16] -- } -- output { -- [0f16, 0.2f16, 0f16, 0f16, 1.2f16] -- } -- -- input { -- [1f16, 2f16, 3f16, 4f16, 5f16] -- [1i16, 1i16, 4i16, 4i16, 4i16] -- [1f16, 1f16, 4f16, 4f16, 4f16] -- } -- output { -- [1f16, 4f16, 3f16, 4f16, 17f16] -- } -- -- input { -- [1f16, f16.nan, 3f16, 4f16, 5f16] -- [1i16, 1i16, 4i16, 4i16, 4i16] -- [1f16, 1f16, 4f16, 4f16, 4f16] -- } -- output { -- [1f16, f16.nan, 3f16, 4f16, 17f16] -- } def main [m][n] (hist : *[n]f16) (is: [m]i16) (image : [m]f16) : [n]f16 = reduce_by_index hist (+) 0f16 (map i64.i16 is) image futhark-0.25.27/tests/hist/f32.fut000066400000000000000000000016051475065116200165670ustar00rootroot00000000000000-- Can we do operations on f32s, even though these are not natively supported? 
-- == -- -- input { -- [0f32, 0f32, 0f32, 0f32, 0f32] -- [1, 1, 1, 1, 1] -- [1f32, 1f32, 1f32, 1f32, 1f32] -- } -- output { -- [0f32, 5f32, 0f32, 0f32, 0f32] -- } -- -- input { -- [0f32, 0f32, 0f32, 0f32, 0f32] -- [1, 1, 4, 4, 4] -- [0.1f32, 0.1f32, 0.4f32, 0.4f32, 0.4f32] -- } -- output { -- [0f32, 0.2f32, 0f32, 0f32, 1.2f32] -- } -- -- input { -- [1f32, 2f32, 3f32, 4f32, 5f32] -- [1, 1, 4, 4, 4] -- [1f32, 1f32, 4f32, 4f32, 4f32] -- } -- output { -- [1f32, 4f32, 3f32, 4f32, 17f32] -- } -- -- input { -- [1f32, f32.nan, 3f32, 4f32, 5f32] -- [1, 1, 4, 4, 4] -- [1f32, 1f32, 4f32, 4f32, 4f32] -- } -- output { -- [1f32, f32.nan, 3f32, 4f32, 17f32] -- } def main [m][n] (hist : *[n]f32) (is: [m]i32) (image : [m]f32) : [n]f32 = reduce_by_index hist (+) 0f32 (map i64.i32 is) image futhark-0.25.27/tests/hist/f64.fut000066400000000000000000000016171475065116200165770ustar00rootroot00000000000000-- Can we do operations on f64s, even though these are not always -- natively supported? -- == -- -- input { -- [0f64, 0f64, 0f64, 0f64, 0f64] -- [1, 1, 1, 1, 1] -- [1f64, 1f64, 1f64, 1f64, 1f64] -- } -- output { -- [0f64, 5f64, 0f64, 0f64, 0f64] -- } -- -- input { -- [0f64, 0f64, 0f64, 0f64, 0f64] -- [1, 1, 4, 4, 4] -- [0.1f64, 0.1f64, 0.4f64, 0.4f64, 0.4f64] -- } -- output { -- [0f64, 0.2f64, 0f64, 0f64, 1.2f64] -- } -- -- input { -- [1f64, 2f64, 3f64, 4f64, 5f64] -- [1, 1, 4, 4, 4] -- [1f64, 1f64, 4f64, 4f64, 4f64] -- } -- output { -- [1f64, 4f64, 3f64, 4f64, 17f64] -- } -- -- input { -- [1f64, f64.nan, 3f64, 4f64, 5f64] -- [1, 1, 4, 4, 4] -- [1f64, 1f64, 4f64, 4f64, 4f64] -- } -- output { -- [1f64, f64.nan, 3f64, 4f64, 17f64] -- } def main [m][n] (hist : *[n]f64) (is: [m]i32) (image : [m]f64) : [n]f64 = reduce_by_index hist (+) 0f64 (map i64.i32 is) image futhark-0.25.27/tests/hist/fusion.fut000066400000000000000000000002511475065116200174740ustar00rootroot00000000000000-- -- == -- structure { Screma 0 Hist 1 } def main [m][n] (hist : *[n]i32, image : [m]i32) : [n]i32 = 
reduce_by_index hist (+) 0 (map i64.i32 image) (map (+2) image) futhark-0.25.27/tests/hist/hist2d.fut000066400000000000000000000023371475065116200173750ustar00rootroot00000000000000-- == -- input { [[1,2,3],[4,5,6],[7,8,9]] [1i64, 1i64] [1i64, -1i64] [42, 1337] } -- output { [[1i32, 2i32, 3i32], [4i32, 47i32, 6i32], [7i32, 8i32, 9i32]] } -- input { [[1,2,3],[4,5,6],[7,8,9]] [-1i64] [-1i64] [1337] } -- output { [[1i32, 2i32, 3i32], [4i32, 5i32, 6i32], [7i32, 8i32, 9i32]] } -- input { [[1,2,3],[4,5,6],[7,8,9]] [3i64] [0i64] [1337] } -- output { [[1i32, 2i32, 3i32], [4i32, 5i32, 6i32], [7i32, 8i32, 9i32]] } -- input { [[1,2,3],[4,5,6],[7,8,9]] [0i64] [3i64] [1337] } -- output { [[1i32, 2i32, 3i32], [4i32, 5i32, 6i32], [7i32, 8i32, 9i32]] } -- input { [[1,2,3],[4,5,6],[7,8,9]] [-1i64] [0i64] [1337] } -- output { [[1i32, 2i32, 3i32], [4i32, 5i32, 6i32], [7i32, 8i32, 9i32]] } -- input { [[1,2,3],[4,5,6],[7,8,9]] [0i64] [-1i64] [1337] } -- output { [[1i32, 2i32, 3i32], [4i32, 5i32, 6i32], [7i32, 8i32, 9i32]] } -- input { [[1,2,3],[4,5,6],[7,8,9]] [0i64] [0i64] [1337] } -- output { [[1338i32, 2i32, 3i32], [4i32, 5i32, 6i32], [7i32, 8i32, 9i32]] } -- input { [[1,2,3],[4,5,6],[7,8,9]] [3i64] [3i64] [1337] } -- output { [[1i32, 2i32, 3i32], [4i32, 5i32, 6i32], [7i32, 8i32, 9i32]] } def main [n][m][l] (xss: *[n][m]i32) (is: [l]i64) (js: [l]i64) (vs: [l]i32): [n][m]i32 = reduce_by_index_2d xss (+) 0 (zip is js) vs futhark-0.25.27/tests/hist/hist3d.fut000066400000000000000000000005411475065116200173710ustar00rootroot00000000000000-- == -- input { [[[1,2,3],[4,5,6],[7,8,9]],[[10,20,30],[40,50,60],[70,80,90]]] } -- output { -- [[[1i32, 2i32, 3i32], [4i32, 5i32, 6i32], [7i32, 8i32, 9i32]], -- [[10i32, 20i32, 30i32], [40i32, 1387i32, 60i32], [70i32, 80i32, 90i32]]] -- } def main [n][m][o] (xss: *[n][m][o]i32) = reduce_by_index_3d xss (+) 0 [(1, 1, 1), (1,-1, 1)] [1337, 0] 
futhark-0.25.27/tests/hist/horizontal-fusion.fut000066400000000000000000000005401475065116200216640ustar00rootroot00000000000000-- -- == -- input { 2i64 [0, 1, 1] } output { [2, 6] [0f32, 0f32] } -- structure { Screma 0 Hist 1 } def main [m] (n: i64) (image : [m]i32) : ([n]i32, []f32) = let as = replicate n 0 let bs = replicate n 0 in (reduce_by_index as (+) 0 (map i64.i32 image) (map (+2) image), reduce_by_index bs (*) 1 (map i64.i32 image) (map f32.i32 image)) futhark-0.25.27/tests/hist/large.fut000066400000000000000000000006171475065116200172710ustar00rootroot00000000000000-- Some tests to try out very large/sparse histograms. -- == -- tags { no_python no_wasm } -- compiled input { 10000000i64 1000i64 } output { 499500i32 } -- compiled input { 100000000i64 10000i64 } output { 49995000i32 } -- compiled input { 100000000i64 1000000i64 } output { 1783293664i32 } def main (n: i64) (m: i64) = hist (+) 0 n (map (%n) (iota m)) (map i32.i64 (iota m)) |> i32.sum futhark-0.25.27/tests/hist/large2d.fut000066400000000000000000000007051475065116200175150ustar00rootroot00000000000000-- Some tests to try out very large/sparse 2D histograms. -- == -- tags { no_python no_wasm } -- compiled input { 100i64 100i64 1000000i64 } auto output -- compiled input { 1000i64 1000i64 1000000i64 } auto output def main (n: i64) (m: i64) (k: i64) = reduce_by_index_2d (replicate m (replicate n 0)) (+) 0 (zip (map (%n) (iota k)) (map (%m) (iota k))) (map i32.i64 (iota k)) |> flatten |> i32.sum futhark-0.25.27/tests/hist/max.fut000066400000000000000000000013201475065116200167540ustar00rootroot00000000000000-- Test with i32.max/u32.max. 
-- == -- -- input { -- 5i64 -- [0, 1, 2, 3, 4] -- [1, 1, 1, 1, 1] -- } -- output { -- [1, 1, 1, 1, 1] -- [1, 1, 1, 1, 1] -- } -- -- input { -- 5i64 -- [0, 0, 0, 0, 0] -- [6, 1, 4, 5, -1] -- } -- output { -- [6, 0, 0, 0, 0] -- [-1, 0, 0, 0, 0] -- } -- -- input { -- 5i64 -- [1, 2, 1, 4, 5] -- [1, 1, 4, 4, 4] -- } -- output { -- [0, 4, 1, 0, 4] -- [0, 4, 1, 0, 4] -- } def main [m] (n: i64) (is: [m]i32) (image: [m]i32) : ([n]i32, [n]i32) = (reduce_by_index (replicate n 0) i32.max i32.lowest (map i64.i32 is) image, map i32.u32 (reduce_by_index (replicate n 0) u32.max u32.lowest (map i64.i32 is) (map u32.i32 image))) futhark-0.25.27/tests/hist/min.fut000066400000000000000000000013231475065116200167550ustar00rootroot00000000000000-- Test with i32.min/u32.min. -- == -- -- input { -- 5i64 -- [0, 1, 2, 3, 4] -- [1, -1, 1, 1, 1] -- } -- output { -- [0, -1, 0, 0, 0] -- [0, 0, 0, 0, 0] -- } -- -- input { -- 5i64 -- [0, 0, 0, 0, 0] -- [6, 1, 4, 5, -1] -- } -- output { -- [-1, 0, 0, 0, 0] -- [0, 0, 0, 0, 0] -- } -- -- input { -- 5i64 -- [1, 2, 1, 4, 5] -- [1, 1, 4, 4, 4] -- } -- output { -- [0, 0, 0, 0, 0] -- [0, 0, 0, 0, 0] -- } def main [m] (n: i64) (is: [m]i32) (image: [m]i32) : ([n]i32, [n]i32) = (reduce_by_index (replicate n 0) i32.min i32.highest (map i64.i32 is) image, map i32.u32 (reduce_by_index (replicate n 0) u32.min u32.highest (map i64.i32 is) (map u32.i32 image))) futhark-0.25.27/tests/hist/or.fut000066400000000000000000000007361475065116200166210ustar00rootroot00000000000000-- Test with i32.|. 
-- == -- -- input { -- 5i64 -- [0, 1, 2, 3, 4] -- [1, 1, 1, 1, 1] -- } -- output { -- [1, 1, 1, 1, 1] -- } -- -- input { -- 5i64 -- [0, 0, 0, 0, 0] -- [6, 1, 4, 5, -1] -- } -- output { -- [-1, 0, 0, 0, 0] -- } -- -- input { -- 5i64 -- [1, 2, 1, 4, 5] -- [1, 1, 4, 4, 4] -- } -- output { -- [0i32, 5i32, 1i32, 0i32, 4i32] -- } def main [m] (n: i64) (is: [m]i32) (image: [m]i32) : [n]i32 = hist (i32.|) 0 n (map i64.i32 is) image futhark-0.25.27/tests/hist/segmented.fut000066400000000000000000000005371475065116200201530ustar00rootroot00000000000000-- == -- input { 10i64 [[1,2,3],[2,3,4],[3,4,5]] } -- output { -- [[0i32, 1i32, 1i32, 1i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32], -- [0i32, 0i32, 1i32, 1i32, 1i32, 0i32, 0i32, 0i32, 0i32, 0i32], -- [0i32, 0i32, 0i32, 1i32, 1i32, 1i32, 0i32, 0i32, 0i32, 0i32]] -- } def main (m: i64) = map (\xs -> hist (+) 0 m (map i64.i32 xs) (map (const 1) xs)) futhark-0.25.27/tests/hist/segmented_2d.fut000066400000000000000000000014161475065116200205350ustar00rootroot00000000000000-- == -- input { 4i64 5i64 [[0,0,0],[0,3,0],[3,4,0]] [[0,0,1],[0,3,4],[1,2,3]] } -- output { -- [[[2i32, 1i32, 0i32, 0i32, 0i32], -- [0i32, 0i32, 0i32, 0i32, 0i32], -- [0i32, 0i32, 0i32, 0i32, 0i32], -- [0i32, 0i32, 0i32, 0i32, 0i32]], -- [[1i32, 0i32, 0i32, 0i32, 1i32], -- [0i32, 0i32, 0i32, 0i32, 0i32], -- [0i32, 0i32, 0i32, 0i32, 0i32], -- [0i32, 0i32, 0i32, 1i32, 0i32]], -- [[0i32, 0i32, 0i32, 1i32, 0i32], -- [0i32, 0i32, 0i32, 0i32, 0i32], -- [0i32, 0i32, 0i32, 0i32, 0i32], -- [0i32, 1i32, 0i32, 0i32, 0i32]]] -- } def main (n: i64) (m: i64) = map2 (\xs ys -> reduce_by_index_2d (replicate n (replicate m 0)) (+) 0 (zip (map i64.i32 xs) (map i64.i32 ys)) (map (const 1) xs)) futhark-0.25.27/tests/hist/segmented_arr.fut000066400000000000000000000005201475065116200210070ustar00rootroot00000000000000-- == -- input { 4i64 [[0,1],[1,2],[2,3]] } -- output { -- [[[1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0]], -- [[0, 0, 0], [1, 1, 1], [1, 1, 1], [0, 0, 0]], -- [[0, 0, 0], [0, 
0, 0], [1, 1, 1], [1, 1, 1]]] -- } def main (m: i64) = map (\xs -> hist (map2 (+)) (replicate 3 0) m (map i64.i32 xs) (map (const (replicate 3 1)) xs)) futhark-0.25.27/tests/hist/simple.fut000066400000000000000000000013731475065116200174700ustar00rootroot00000000000000-- Test genred in simple cases with addition operator -- == -- -- input { -- [0, 0, 0, 0, 0] -- [1, 1, 1, 1, 1] -- } -- output { -- [0, 5, 0, 0, 0] -- } -- -- input { -- [0, 0, 0, 0, 0] -- [1, 1, 4, 4, 4] -- } -- output { -- [0, 2, 0, 0, 12] -- } -- -- input { -- [1, 2, 3, 4, 5] -- [1, 1, 4, 4, 4] -- } -- output { -- [1, 4, 3, 4, 17] -- } -- -- input { -- [0, 0, 0, 0, 0] -- [10000000] -- } -- output { -- [0, 0, 0, 0, 0] -- } -- -- input { -- [0, 0, 0, 0, 0] -- empty([0]i32) -- } -- output { -- [0, 0, 0, 0, 0] -- } -- -- input { -- empty([0]i32) -- empty([0]i32) -- } -- output { -- empty([0]i32) -- } def main [m][n] (hist : *[n]i32) (image : [m]i32) : [n]i32 = reduce_by_index hist (+) 0 (map i64.i32 image) image futhark-0.25.27/tests/hist/tuple.fut000066400000000000000000000006231475065116200173250ustar00rootroot00000000000000-- Test reduce_by_index on array of tuples -- == def bucket_function (x : i32) : (i64, (i32, i32)) = (i64.i32 x, (1, 2)) def operator ((x0, y0) : (i32, i32)) ((x1, y1) : (i32, i32)) : (i32, i32) = (x0 + x1, y0 + y1) def main [m][n] (xs : *[m](i32, i32)) (image : [n]i32) : ([m]i32, [m]i32) = let (is, vs) = unzip (map bucket_function image) in unzip (reduce_by_index xs operator (1,1) is vs) futhark-0.25.27/tests/hist/tuple_partial.fut000066400000000000000000000012451475065116200210420ustar00rootroot00000000000000-- Test reduce_by_index on array of tuples where part of the tuple is not -- recomputed. -- == -- input { -- 5i64 -- [1, 3, 1] -- [4, 1, 3] -- [5, 6, 7] -- } -- output { -- [-1, 3, -1, 1, -1] -- [-1, 7, -1, 6, -1] -- } -- This is 'min', but with auxiliary information carried along with the result. 
def operator ((x0, y0): (i32, i32)) ((x1, y1): (i32, i32)): (i32, i32) = if x0 != -1 && (x1 == -1 || x0 < x1) then (x0, y0) else (x1, y1) def main [n] (m: i64) (is: [n]i32) (vs0: [n]i32) (vs1: [n]i32): ([m]i32, [m]i32) = let ne = (-1, -1) let dest = replicate m ne let vs = zip vs0 vs1 in unzip (reduce_by_index dest operator ne (map i64.i32 is) vs) futhark-0.25.27/tests/hist/xor.fut000066400000000000000000000007361475065116200170110ustar00rootroot00000000000000-- Test with i32.^. -- == -- -- input { -- 5i64 -- [0, 1, 2, 3, 4] -- [1, 1, 1, 1, 1] -- } -- output { -- [1, 1, 1, 1, 1] -- } -- -- input { -- 5i64 -- [0, 0, 0, 0, 0] -- [6, 1, 4, 5, -1] -- } -- output { -- [-7, 0, 0, 0, 0] -- } -- -- input { -- 5i64 -- [1, 2, 1, 4, 5] -- [1, 1, 4, 4, 4] -- } -- output { -- [0i32, 5i32, 1i32, 0i32, 4i32] -- } def main [m] (n: i64) (is: [m]i32) (image: [m]i32) : [n]i32 = hist (i32.^) 0 n (map i64.i32 is) image futhark-0.25.27/tests/hoist-consume.fut000066400000000000000000000006611475065116200200240ustar00rootroot00000000000000-- This test can fail if the (consuming) calls to fib2 are lifted in an -- erroneous way. -- == -- input { -- 10i64 -- } -- output { -- [42, 42, 42, 42, 42, 42, 42, 42, 42, 42] -- } def fib2(a: *[]i32, i: i32, n: i32): *[]i32 = a def fib(a: *[]i32, i: i32, n: i32): *[]i32 = if i == n then a else if i < 2 then fib2(a,i+1,n) else fib2(a,i+1,n) def main(n: i64): []i32 = fib(replicate n 42,0,i32.i64 n) futhark-0.25.27/tests/hoist-map.fut000066400000000000000000000031771475065116200171350ustar00rootroot00000000000000-- If we don't hoist very aggressively here, this will not run -- correctly with the GPU backends. 
-- == -- compiled random input { [100][64]i32 } auto output local def log2 (n: i64) : i64 = let r = 0 let (r, _) = loop (r,n) while 1 < n do let n = n / 2 let r = r + 1 in (r,n) in r local def ensure_pow_2 [n] 't ((<=): t -> t -> bool) (xs: [n]t): (*[]t, i64) = if n == 0 then (copy xs, 0) else let d = log2 n in if n == 2**d then (copy xs, d) else let largest = reduce (\x y -> if x <= y then y else x) xs[0] xs in (concat xs (replicate (2**(d+1) - n) largest), d+1) local def kernel_par [n] 't ((<=): t -> t -> bool) (a: *[n]t) (p: i64) (q: i64) : *[n]t = let d = 1 << (p-q) in map (\i -> let a_i = a[i] let up1 = ((i >> p) & 2) == 0 in if (i & d) == 0 then let a_iord = a[i | d] in if a_iord <= a_i == up1 then a_iord else a_i else let a_ixord = a[i ^ d] in if a_i <= a_ixord == up1 then a_ixord else a_i) (iota n) -- | Sort an array in increasing order. def merge_sort [n] 't ((<=): t -> t -> bool) (xs: [n]t): *[n]t = -- We need to pad the array so that its size is a power of 2. We do -- this by first finding the largest element in the input, and then -- using that for the padding. Then we know that the padding will -- all be at the end, so we can easily cut it off. let (xs, d) = ensure_pow_2 (<=) xs in (loop xs for i < d do loop xs for j < i+1 do kernel_par (<=) xs i j)[:n] def main = map (merge_sort (i32.<=)) futhark-0.25.27/tests/hoist-unsafe.fut000066400000000000000000000002641475065116200176330ustar00rootroot00000000000000-- Test that we do not hoist dangerous things out of loops. -- == -- input { empty([0]i32) 2 } output { 2 } def main [n] (a: [n]i32) (m: i32) = loop x=m for i < n do x+a[2] futhark-0.25.27/tests/hoist-unsafe2.fut000066400000000000000000000004741475065116200177200ustar00rootroot00000000000000-- Test that we *do* hoist a potentially unsafe (but loop-invariant) -- expression out of a loop. 
-- == -- input { 4i64 [1i64,2i64,3i64] } output { 6i64 } -- input { 0i64 empty([0]i64) } output { 0i64 } -- structure { /Loop/BinOp 2 } def main [n] (a: i64) (xs: [n]i64) = loop acc = 0 for x in xs do acc + x*(a/n) futhark-0.25.27/tests/holes/000077500000000000000000000000001475065116200156165ustar00rootroot00000000000000futhark-0.25.27/tests/holes/hof0.fut000066400000000000000000000001061475065116200171670ustar00rootroot00000000000000-- == -- input {0i64} error: hole def main (x: i64) : bool = ??? x x futhark-0.25.27/tests/holes/hof1.fut000066400000000000000000000001111475065116200171640ustar00rootroot00000000000000-- == -- input {0i64} error: hole def main (x: i64) : bool = ??? (+1) x futhark-0.25.27/tests/holes/hof2.fut000066400000000000000000000002111475065116200171660ustar00rootroot00000000000000-- Based on #1738. -- == -- input { 0f32 } error: hole.*hof2.fut:7 def f a b : f32 = a + b entry main y : f32 = (let x = y in f x) ??? futhark-0.25.27/tests/holes/hof3.fut000066400000000000000000000001031475065116200171670ustar00rootroot00000000000000-- From #1885 def f : (x: i64) -> []i64 = ??? entry main x = f x futhark-0.25.27/tests/holes/loop0.fut000066400000000000000000000006201475065116200173650ustar00rootroot00000000000000def main [j] (jets: [j]u8) = let h = 20 let world = replicate h (replicate 7 0) let skipped = 0 let top = i64.i64 (length world) let jet_num = 0 let rock_num = 0 let target = 10 let (_, skipped, top, _, _) = loop (jet_num, skipped, top, world, rock_num) while rock_num < target do if ??? then ??? else (jet_num, skipped, top, world, rock_num) in skipped + top futhark-0.25.27/tests/holes/loop1.fut000066400000000000000000000002561475065116200173730ustar00rootroot00000000000000-- Loop optimisation should not eliminate holes. -- == -- input { 0i32 } error: hole def holey (x: i32) : i32 = ??? 
def main (x: i32) = loop acc = holey x for i < x do acc futhark-0.25.27/tests/holes/loop2.fut000066400000000000000000000003731475065116200173740ustar00rootroot00000000000000-- Loop optimisation should not eliminate holes, but also not -- propagate them needlessly. -- == -- input { 0i32 } output { 0i32 } -- input { 1i32 } error: hole def holey (x: i32) : i32 = ??? def main (x: i32) = loop acc = x for i < x do holey acc futhark-0.25.27/tests/holes/simple0.fut000066400000000000000000000001071475065116200177050ustar00rootroot00000000000000-- == -- input {0i32} error: hole def main (x: i32) : i32 = ??? : i32 futhark-0.25.27/tests/holes/simple1.fut000066400000000000000000000001051475065116200177040ustar00rootroot00000000000000-- == -- input {0i64} error: hole def main (x: i64) : *[x]i32 = ??? futhark-0.25.27/tests/holes/simple2.fut000066400000000000000000000001671475065116200177150ustar00rootroot00000000000000-- == -- input {true} output { [1,2,3] } -- input {false} error: hole def main (b: bool) = if b then [1,2,3] else ??? futhark-0.25.27/tests/holes/simple3.fut000066400000000000000000000001361475065116200177120ustar00rootroot00000000000000-- == -- error: causality def main (b: bool) (A: []i32) = if b then filter (>0) A else ??? futhark-0.25.27/tests/holes/simple4.fut000066400000000000000000000001411475065116200177070ustar00rootroot00000000000000-- == -- error: Ambiguous size.*instantiated size parameter of "f32.sum" def main = f32.sum ??? futhark-0.25.27/tests/holes/simple5.fut000066400000000000000000000001031475065116200177060ustar00rootroot00000000000000-- == -- error: size-polymorphic def main (b: bool) : []i32 = ??? futhark-0.25.27/tests/holes/simple6.fut000066400000000000000000000000641475065116200177150ustar00rootroot00000000000000def f (x: f32) : f32 = ??? 
def main (bs:f32) = f bs futhark-0.25.27/tests/if0.fut000066400000000000000000000001061475065116200156770ustar00rootroot00000000000000-- == -- error: bool def main (x: i32) = if x then true else false futhark-0.25.27/tests/if1.fut000066400000000000000000000001021475065116200156740ustar00rootroot00000000000000-- == -- error: i32.*i16 def main b = if b then 2i32 else 2i16 futhark-0.25.27/tests/implicit_method.fut000066400000000000000000000102251475065116200203760ustar00rootroot00000000000000-- This test has been distilled from CalibVolDiff and exposed a bug in -- the memory expander. -- -- == -- input { -- [[1.0f32, 1.5f32, 2.5f32], [3.0f32, 6.5f32, 0.5f32]] -- [[0.10f32, 0.15f32, 0.25f32], [0.30f32, 0.65f32, 0.05f32]] -- [[0.1f32, 1.75f32], [1.0f32, 17.5f32]] -- [[0.01f32, 1.705f32], [0.1f32, 17.05f32]] -- [[0.02f32, 0.05f32], [0.04f32, 0.07f32]] -- 0.1f32 -- 30i64 -- } -- output { [[[-1.350561f32, 0.615297f32], [-0.225855f32, 0.103073f32]], -- [[-1.776825f32, 0.812598f32], [-0.230401f32, 0.105177f32]], -- [[-2.596339f32, 1.191926f32], [-0.235133f32, 0.107367f32]], -- [[-4.819269f32, 2.220859f32], [-0.240064f32, 0.109650f32]], -- [[-33.523365f32, 15.507273f32], [-0.245206f32, 0.112030f32]], -- [[6.763457f32, -3.140509f32], [-0.250574f32, 0.114514f32]], -- [[3.071727f32, -1.431704f32], [-0.256181f32, 0.117110f32]], -- [[1.987047f32, -0.929638f32], [-0.262046f32, 0.119824f32]], -- [[1.468468f32, -0.689606f32], [-0.268185f32, 0.122666f32]], -- [[1.164528f32, -0.548925f32], [-0.274619f32, 0.125644f32]], -- [[0.964817f32, -0.456489f32], [-0.281369f32, 0.128768f32]], -- [[0.823568f32, -0.391114f32], [-0.288459f32, 0.132050f32]], -- [[0.718389f32, -0.342435f32], [-0.295916f32, 0.135502f32]], -- [[0.637028f32, -0.304780f32], [-0.303769f32, 0.139136f32]], -- [[0.572216f32, -0.274786f32], [-0.312050f32, 0.142969f32]], -- [[0.519370f32, -0.250331f32], [-0.320795f32, 0.147017f32]], -- [[0.475458f32, -0.230010f32], [-0.330045f32, 0.151299f32]], -- [[0.438389f32, -0.212858f32], 
[-0.339844f32, 0.155834f32]], -- [[0.406681f32, -0.198186f32], [-0.350242f32, 0.160647f32]], -- [[0.379247f32, -0.185493f32], [-0.361297f32, 0.165764f32]], -- [[0.355280f32, -0.174405f32], [-0.373073f32, 0.171215f32]], -- [[0.334161f32, -0.164635f32], [-0.385643f32, 0.177033f32]], -- [[0.315410f32, -0.155961f32], [-0.399088f32, 0.183257f32]], -- [[0.298649f32, -0.148209f32], [-0.413506f32, 0.189930f32]], -- [[0.283580f32, -0.141239f32], [-0.429004f32, 0.197103f32]], -- [[0.269957f32, -0.134939f32], [-0.445709f32, 0.204836f32]], -- [[0.257582f32, -0.129217f32], [-0.463768f32, 0.213195f32]], -- [[0.246292f32, -0.123996f32], [-0.483353f32, 0.222260f32]], -- [[0.235948f32, -0.119214f32], [-0.504664f32, 0.232124f32]], -- [[0.226438f32, -0.114818f32], [-0.527942f32, 0.242899f32]]] -- } def tridagSeq [n][m] (a: [n]f32) (b: *[m]f32) (c: [m]f32) (y: *[m]f32 ): *[m]f32 = let (y,b) = loop ((y, b)) for i < n-1 do let i = i + 1 let beta = a[i] / b[i-1] let b[i] = b[i] - beta*c[i-1] let y[i] = y[i] - beta*y[i-1] in (y, b) let y[n-1] = y[n-1]/b[n-1] in loop (y) for j < n - 1 do let i = n - 2 - j let y[i] = (y[i] - c[i]*y[i+1]) / b[i] in y def implicitMethod [n][m] (myD: [m][3]f32, myDD: [m][3]f32, myMu: [n][m]f32, myVar: [n][m]f32, u: [n][m]f32) (dtInv: f32): *[n][m]f32 = map (\(tup: ([]f32,[]f32,[]f32) ) -> let (mu_row,var_row,u_row) = tup let (a,b,c) = unzip3 (map (\(tup: (f32,f32,[]f32,[]f32)): (f32,f32,f32) -> let (mu, var, d, dd) = tup in ( 0.0 - 0.5*(mu*d[0] + 0.5*var*dd[0]) , dtInv - 0.5*(mu*d[1] + 0.5*var*dd[1]) , 0.0 - 0.5*(mu*d[2] + 0.5*var*dd[2]))) (zip4 (mu_row) (var_row) myD myDD)) in tridagSeq a (copy b) c (copy u_row)) (zip3 myMu myVar u) def main [m][n] (myD: [m][3]f32) (myDD: [m][3]f32) (myMu: [n][m]f32) (myVar: [n][m]f32) (u: *[n][m]f32) (dtInv: f32) (num_samples: i64): *[num_samples][n][m]f32 = map (implicitMethod(myD,myDD,myMu,myVar,u)) ( map (*dtInv) (map (/f32.i64(num_samples)) (map f32.i64 (map (+1) (iota(num_samples)))))) 
futhark-0.25.27/tests/in-place-distribute.fut000066400000000000000000000016261475065116200210750ustar00rootroot00000000000000-- Extraction from generic pricer. Uses shape declarations in ways -- that were at one point problematic. -- -- == -- input { -- [1.0, 4.0, 7.0, 10.0, 13.0] -- } -- output { -- [[1.000000, 1.000000, 1.000000, 1.000000, 1.000000], [4.000000, -- 16.000000, 256.000000, 65536.000000, 4294967296.000000], [7.000000, -- 49.000000, 2401.000000, 5764801.000000, 33232930569601.000000], -- [10.000000, 100.000000, 10000.000000, 100000000.000000, -- 10000000000000000.000000], [13.000000, 169.000000, 28561.000000, -- 815730721.000000, 665416609183179904.000000]] -- } def seqloop (num_dates: i64) (gauss: f64): [num_dates]f64 = let bbrow = replicate num_dates 0.0f64 let bbrow[ 0 ] = gauss in loop (bbrow) for i in map (+1) (iota (num_dates-1)) do let bbrow[i] = bbrow[i-1] * bbrow[i-1] in bbrow def main [num_dates] (gausses: [num_dates]f64): [][]f64 = map (seqloop(num_dates)) gausses futhark-0.25.27/tests/include_basic.fut000066400000000000000000000003061475065116200200070ustar00rootroot00000000000000-- This test shows how to import a file and use its function. -- == -- input { -- 7 -- } -- output { -- 29 -- } import "include_basic_includee" def main(s: i32): i32 = importe_function(s) + 1 futhark-0.25.27/tests/include_basic_includee.fut000066400000000000000000000000521475065116200216550ustar00rootroot00000000000000def importe_function(t: i32): i32 = t * 4 futhark-0.25.27/tests/include_many.fut000066400000000000000000000004501475065116200176720ustar00rootroot00000000000000-- This test shows how to include many file and use their functions. 
-- == -- input { -- 2 -- } -- output { -- -5 -- } import "include_many_includee0" import "include_many_includee1" def main(s: i32): i32 = includee0_function(s) + includee1_function(s) * includee0_includee_function(s) futhark-0.25.27/tests/include_many_includee0.fut000066400000000000000000000001721475065116200216230ustar00rootroot00000000000000open (import "include_many_includee0_includee") def includee0_function(x: i32): i32 = includee0_includee_function(x * 3) futhark-0.25.27/tests/include_many_includee0_includee.fut000066400000000000000000000000651475065116200234740ustar00rootroot00000000000000def includee0_includee_function(x: i32): i32 = x - 3 futhark-0.25.27/tests/include_many_includee1.fut000066400000000000000000000000541475065116200216230ustar00rootroot00000000000000def includee1_function(x: i32): i32 = x + 6 futhark-0.25.27/tests/index0.fut000066400000000000000000000002271475065116200164140ustar00rootroot00000000000000-- Test simple indexing of an array. -- == -- input { -- [4,3,2,1,0] -- 1 -- } -- output { -- 3 -- } def main (a: []i32) (i: i32): i32 = a[i] futhark-0.25.27/tests/index1.fut000066400000000000000000000002551475065116200164160ustar00rootroot00000000000000-- Test simple indexing of an array. -- == -- input { -- [[4,3],[3,2],[2,1],[1,0]] -- 1 -- } -- output { -- [3,2] -- } def main (a: [][]i32) (i: i32): []i32 = a[i] futhark-0.25.27/tests/index10.fut000066400000000000000000000004441475065116200164760ustar00rootroot00000000000000-- Complex indexing into reshape, replicate and iota should be simplified away. -- == -- input { 2i64 } output { 1i64 } -- input { 10i64 } output { 3i64 } -- structure { Iota 0 Replicate 0 Reshape 0 } def main(x: i64) = let a = iota x let b = replicate x a let c = flatten b in c[3] futhark-0.25.27/tests/index11.fut000066400000000000000000000004051475065116200164740ustar00rootroot00000000000000-- Index projection! 
-- == -- input { [1,2,3] [[true,false],[false,true]] } -- output { 1 [2,3] [true,false] [[false,true]] } def newhead = (.[0]) def newtail = (.[1:]) def main (xs: []i32) (ys: [][]bool) = (newhead xs, newtail xs, newhead ys, newtail ys) futhark-0.25.27/tests/index12.fut000066400000000000000000000002501475065116200164730ustar00rootroot00000000000000-- Simplifying away a slice of a rotate is not so simple. -- == -- input { [1,2,3] } -- output { [3,1] } def main (xs: []i32) = let ys = rotate (-1) xs in ys[0:2] futhark-0.25.27/tests/index13.fut000066400000000000000000000001041475065116200164720ustar00rootroot00000000000000def f (xs: []([]i32,[]i32)) = xs[0] def main xs ys = f (zip xs ys) futhark-0.25.27/tests/index14.fut000066400000000000000000000001711475065116200164770ustar00rootroot00000000000000def f (A: []([](i32,i32), [](i32,i32))) = A[0].1 entry main (A: *[]([](i32,i32), [](i32,i32))) = f A with [0] = (0,0) futhark-0.25.27/tests/index2.fut000066400000000000000000000003411475065116200164130ustar00rootroot00000000000000-- Test indexing of an array of tuples. -- == -- input { -- [1, 2, 3] -- [1.0, 2.0, 3.0] -- 1 -- } -- output { -- 2 -- 2.000000 -- } def main (a: []i32) (b: []f64) (i: i32): (i32,f64) = let c = zip a b in c[i] futhark-0.25.27/tests/index4.fut000066400000000000000000000003521475065116200164170ustar00rootroot00000000000000-- Test indexing of a high-dimension array! -- == -- input { -- [[[1,2,3], [4,5,6], [7,8,9]], [[2,1,3], [4,6,5], [8,7,9]]] -- 1 -- 1 -- } -- output { -- [4,6,5] -- } def main (a: [][][]i32) (i: i32) (j: i32): []i32 = a[i,j] futhark-0.25.27/tests/index5.fut000066400000000000000000000006051475065116200164210ustar00rootroot00000000000000-- See if we can access an array with a stride. 
-- -- == -- input { [0,1,2,3,4,5,6,7,8,9] 4 9 2 } output { [4,6,8] } -- input { [0,1,2,3,4,5,6,7,8,9] 9 4 -2 } output { [9,7,5] } -- input { [0,1,2,3,4,5,6,7,8,9] 9 -10 -2 } error: out of bounds -- input { [0,1,2,3,4,5,6,7] 7 9 2 } output { [7] } def main (as: []i32) (i: i32) (j: i32) (s: i32): []i32 = as[i64.i32 i:i64.i32 j:i64.i32 s] futhark-0.25.27/tests/index6.fut000066400000000000000000000003011475065116200164130ustar00rootroot00000000000000-- Test simple indexing of an array with a type that is not 32 bits. -- == -- input { -- [4i8,3i8,2i8,1i8,0i8] -- 1 -- } -- output { -- 3i8 -- } def main (a: []i8) (i: i32): i8 = a[i] futhark-0.25.27/tests/index7.fut000066400000000000000000000003471475065116200164260ustar00rootroot00000000000000-- Test simple indexing of a 2D array where the element type is not 32 bits. -- == -- input { -- [[4i8,3i8],[3i8,2i8],[2i8,1i8],[1i8,0i8]] -- 1 -- } -- output { -- [3i8,2i8] -- } def main (a: [][]i8) (i: i32): []i8 = a[i] futhark-0.25.27/tests/index8.fut000066400000000000000000000002621475065116200164230ustar00rootroot00000000000000-- Indexing an array literal with a constant should remove the -- indexing. -- -- == -- structure { Index 0 Assert 0 } def main(xs: []i32): []i32 = let xss = [xs] in xss[0] futhark-0.25.27/tests/index9.fut000066400000000000000000000003021475065116200164170ustar00rootroot00000000000000-- Slicing a replicate should work. 
-- -- == -- input { 3i64 [1,2] } output { [[1,2],[1,2]] } def main [b] (m: i64) (diag: [b]i32): [][]i32 = let top_per = replicate m diag in top_per[1:m] futhark-0.25.27/tests/inlineTest1.fut000066400000000000000000000004721475065116200174260ustar00rootroot00000000000000-- == -- input { -- 42 -- 1337 -- } -- output { -- 24730855 -- } def fun1(a: i32, b: i32): i32 = a + b def fun2(a: i32, b: i32): i32 = fun1(a,b) * (a+b) def fun3(a: i32, b: i32): i32 = fun2(a,b) + a + b def main (n: i32) (m: i32): i32 = fun1(n,m) + fun2(n+n,m+m) + fun3(3*n,3*m) + fun2(2,n) + fun3(n,3) futhark-0.25.27/tests/inplace-replicate.fut000066400000000000000000000002751475065116200206110ustar00rootroot00000000000000-- == -- input { [1,2,3,4] 2i64 42 } output { [1i32, 2i32, 42i32, 4i32] } -- structure { Replicate 0 Assert 1 } def main (xs: *[]i32) (i: i64) (v: i32) = xs with [i:i+1] = replicate 1 v futhark-0.25.27/tests/inplace0.fut000066400000000000000000000004551475065116200167230ustar00rootroot00000000000000-- Test lowering of an in-place update. -- == -- input { -- 3i64 -- 1 -- 2 -- 42 -- } -- output { -- [[0,0,0], [0,0,0], [0,42,0]] -- } def main (n: i64) (i: i32) (j: i32) (x: i32): [][]i32 = let a = replicate n (replicate n 0) let b = replicate n 0 let b[i] = x let a[j] = b in a futhark-0.25.27/tests/inplace1.fut000066400000000000000000000004011475065116200167130ustar00rootroot00000000000000-- Test an in-place update of an argument to main() -- == -- input { -- [[1],[2],[3],[4],[5]] -- 2 -- 42 -- } -- output { -- [[1],[2],[42],[4],[5]] -- } def main [n][m] (a: *[m][n]i32) (i: i32) (v: i32): [][]i32 = let a[i] = replicate n v in a futhark-0.25.27/tests/inplace2.fut000066400000000000000000000010201475065116200167120ustar00rootroot00000000000000-- In-place update with a slice. 
-- -- == -- input { [1,2,3,4,5] [8,9] 2i64 } -- output { [1,2,8,9,5] } -- input { [1,2,3,4,5] [5,6,7,8,9] 0i64 } -- output { [5,6,7,8,9] } -- input { [1,2,3,4,5] empty([0]i32) 0i64 } -- output { [1,2,3,4,5] } -- input { [1,2,3,4,5] empty([0]i32) 1i64 } -- output { [1,2,3,4,5] } -- input { [1,2,3,4,5] empty([0]i32) 5i64 } -- output { [1,2,3,4,5] } -- input { [1,2,3,4,5] [1,2,3] -1i64 } -- error: Error def main [n][m] (as: *[n]i32) (bs: [m]i32) (i: i64): []i32 = let as[i:i+m] = bs in as futhark-0.25.27/tests/inplace3.fut000066400000000000000000000003521475065116200167220ustar00rootroot00000000000000-- In-place update without 'let'. -- == -- input { -- [[1],[2],[3],[4],[5]] -- 2 -- 42 -- } -- output { -- [[1],[2],[42],[4],[5]] -- } def main [k][n] (a: *[k][n]i32) (i: i32) (v: i32): [][]i32 = a with [i] = replicate n v futhark-0.25.27/tests/inplace4.fut000066400000000000000000000003501475065116200167210ustar00rootroot00000000000000-- In-place update of a slice. -- == -- input { [0,1,2,3,4] [42,42] } -- output { [42,42,2,3,4] } -- input { [42,42] [0,1,2,3,4] } -- error: inplace4.fut:9 def main [n][m] (xs: *[n]i32) (ys: [m]i32) : [n]i32 = xs with [0:m] = ys futhark-0.25.27/tests/inplace5.fut000066400000000000000000000005261475065116200167270ustar00rootroot00000000000000-- In-place update of the middle of an array. -- == -- input { [0u8,1u8,2u8,3u8,4u8] 1i64 3i64 } -- output { [1u8, 2u8, 3u8, 128u8, 1u8, 2u8, 3u8, 0u8] } def main (bs: []u8) i k = let k = i64.min 8 k let one_bit = [0x80u8, 1u8, 2u8, 3u8] let block = replicate 8 0u8 let block[0:k] = bs[i:i+k] let block[k:k+4] = one_bit in block futhark-0.25.27/tests/inplace6.fut000066400000000000000000000004061475065116200167250ustar00rootroot00000000000000-- Simplify indexing of an in-place update where we just wrote to that -- index. 
-- == -- input { [1,2,3] 0 42 } -- output { 42 } -- input { [1,2,3] 6 42 } -- error: out of bounds -- structure { Update 0 } def main (xs: *[]i32) i v = let xs[i] = v in xs[i] futhark-0.25.27/tests/int.fut000066400000000000000000000005441475065116200160210ustar00rootroot00000000000000-- Test integer semantics - overflow and the like. -- -- This relies on integers being 32 bit and signed, and shifts doing -- sign extension. -- -- == -- input { -- 2147483647 -- -2147483648 -- } -- output { -- 2147483647 -- -2147483648 -- 2147483647 -- -1073741824 -- } def main(a: i32) (b: i32): (i32, i32, i32, i32) = (a, a+1, b-1, b>>1) futhark-0.25.27/tests/intragroup/000077500000000000000000000000001475065116200166765ustar00rootroot00000000000000futhark-0.25.27/tests/intragroup/big0.fut000066400000000000000000000003761475065116200202450ustar00rootroot00000000000000-- This needs more than 2**31 threads, but its input isn't that big. -- == -- tags { no_python } -- compiled random input { [10000000]f32 } auto output def main (xs: []f32) = map (\x -> iota 256 |> map f32.i64 |> map (+x) |> scan (+) 0 |> f32.sum) xs futhark-0.25.27/tests/intragroup/complex-screma.fut000066400000000000000000000005331475065116200223360ustar00rootroot00000000000000-- Parallelise even a complicated screma with horisontally fused scan -- and reduce. -- == -- structure gpu { /SegMap/SegScan 1 /SegMap/SegRed 1 } entry main [n] [m] (a: [m][n]f32) = #[incremental_flattening(only_intra)] map (\ row -> let row_scanned = scan (+) 0 row in (reduce (+) 0 row, row_scanned) ) a |> unzip futhark-0.25.27/tests/intragroup/expansion0.fut000066400000000000000000000005371475065116200215070ustar00rootroot00000000000000-- This uses way too much memory if memory expansion is done by the -- number of threads, not the number of groups. 
-- == -- random input { [30]bool [30][2]f32 } -- compiled random input { [30000]bool [30000][256]f32 } def main bs xss = #[incremental_flattening(only_intra)] map2 (\b xs -> if b then xs else iterate 10 (scan (+) 0f32) xs) bs xss futhark-0.25.27/tests/intragroup/if0.fut000066400000000000000000000005051475065116200200740ustar00rootroot00000000000000-- == -- compiled random input { [1023]bool [1023][256]i32 } auto output -- structure gpu { -- /If/True/SegMap/If/True/SegMap 1 -- /If/True/SegMap/If/False/SegMap 1 -- } def main (bs: []bool) (xss: [][]i32) = map2 (\b xs -> if b then map (+2) xs else map (+3) (scan (+) 0 xs)) bs xss futhark-0.25.27/tests/intragroup/inplace0.fut000066400000000000000000000006071475065116200211140ustar00rootroot00000000000000-- In-place updates at the group level are tricky, because we have to -- be sure only one thread writes. -- == -- random input { [1][256]f32 } auto output -- structure gpu { SegMap/SegScan 2 SegMap/Update 1 } def main (xss: [][]f32) = #[incremental_flattening(only_intra)] map (\xs -> let ys = scan (+) 0 xs let ys[0] = ys[0] + 1 in scan (+) 0 ys) xss futhark-0.25.27/tests/intragroup/reduce0.fut000066400000000000000000000003551475065116200207500ustar00rootroot00000000000000-- Simple intra-group reduction. -- == -- random input { [1][256]i32 } auto output -- random input { [10][256]i32 } auto output -- structure gpu { SegMap/SegRed 1 } def main xs = #[incremental_flattening(only_intra)] map i32.sum xs futhark-0.25.27/tests/intragroup/reduce1.fut000066400000000000000000000004711475065116200207500ustar00rootroot00000000000000-- Multiple intra-group reductions. 
-- == -- random input { [1][256]i32 } auto output -- compiled random input { [100][256]i32 } auto output -- structure gpu { SegMap/SegRed 1 SegMap/SegRed/SegBinOp 2 } def main xs = #[incremental_flattening(only_intra)] unzip (map (\xs -> (i32.sum xs, i32.product xs)) xs) futhark-0.25.27/tests/intragroup/reduce1.fut.tuning000066400000000000000000000000261475065116200222470ustar00rootroot00000000000000default_threshold=200 futhark-0.25.27/tests/intragroup/reduce2.fut000066400000000000000000000003731475065116200207520ustar00rootroot00000000000000-- Map-reduce inside group. -- == -- random input { [1][256]i32 } auto output -- random input { [100][256]i32 } auto output -- structure gpu { SegMap/SegRed 1 } def main xs = #[incremental_flattening(only_intra)] map (map i32.abs >-> i32.sum) xs futhark-0.25.27/tests/intragroup/reduce2.fut.tuning000066400000000000000000000000261475065116200222500ustar00rootroot00000000000000default_threshold=200 futhark-0.25.27/tests/intragroup/reduce3.fut000066400000000000000000000004131475065116200207460ustar00rootroot00000000000000-- Segmented intra-group reduction. -- == -- random input { [1][16][16]i32 } auto output -- random input { [10][16][16]i32 } auto output -- structure gpu { SegMap/SegRed 1 } def main xsss = #[incremental_flattening(only_intra)] map (map (i32.sum >-> (*2))) xsss futhark-0.25.27/tests/intragroup/reduce4.fut000066400000000000000000000006701475065116200207540ustar00rootroot00000000000000-- Intra-group reduction with array accumulator. 
-- == -- random input { [2][2][128]i32 } auto output -- structure gpu { SegMap/SegRed 1 } def badsum [n][m] (xss: [n][m]i32): []i32 = reduce_comm(\(xs: []i32) ys -> loop zs = replicate m 0 for i < m do let zs[i] = xs[i] + ys[i] in zs) (replicate m 0) xss def main xs = #[incremental_flattening(only_intra)] map badsum xs futhark-0.25.27/tests/intragroup/reduce_by_index0.fut000066400000000000000000000004121475065116200226230ustar00rootroot00000000000000-- == -- random input { 10i64 [1][256]i64 } auto output -- compiled random input { 10i64 [100][256]i64 } auto output def histogram k is = hist (+) 0 k (map (%k) is) (map (const 1i32) is) def main k is = #[incremental_flattening(only_intra)] map (histogram k) is futhark-0.25.27/tests/intragroup/reduce_by_index1.fut000066400000000000000000000004251475065116200226300ustar00rootroot00000000000000-- == -- random input { 10i64 [1][1][256]i64 } auto output -- compiled random input { 10i64 [10][1][256]i64 } auto output def histogram k is = hist (+) 0 k (map (%k) is) (map (const 1i32) is) def main k is = #[incremental_flattening(only_intra)] map (map (histogram k)) is futhark-0.25.27/tests/intragroup/reduce_by_index2.fut000066400000000000000000000005751475065116200226370ustar00rootroot00000000000000-- Nastier operator that requires locking. (If we ever get 64-bit -- float atomics, then maybe add another test.) -- == -- random input { 10i64 [1][256]i64 } auto output -- compiled random input { 10i64 [100][256]i64 } auto output def histogram k is = hist (+) 0 k (map (%k) is) (map (const 1f64) is) def main k is = #[incremental_flattening(only_intra)] map (histogram k) is futhark-0.25.27/tests/intragroup/replicate0.fut000066400000000000000000000004401475065116200214440ustar00rootroot00000000000000-- Written in a contrived way to make the replicate actually show up. 
-- == -- compiled random input { [1][256]i32 } auto output -- compiled random input { [100][256]i32 } auto output -- compiled random input { [100][512]i32 } auto output def main = map i32.sum >-> map (replicate 2000) futhark-0.25.27/tests/intragroup/replicate0.fut.tuning000066400000000000000000000000261475065116200227470ustar00rootroot00000000000000default_threshold=200 futhark-0.25.27/tests/intragroup/replicate1.fut000066400000000000000000000003711475065116200214500ustar00rootroot00000000000000-- Replication of array. -- == -- compiled random input { [1][256]i32 } auto output -- compiled random input { [100][256]i32 } auto output -- compiled random input { [100][512]i32 } auto output def main = map (scan (+) 0i32) >-> map (replicate 20) futhark-0.25.27/tests/intragroup/replicate1.fut.tuning000066400000000000000000000000261475065116200227500ustar00rootroot00000000000000default_threshold=200 futhark-0.25.27/tests/intragroup/scan0.fut000066400000000000000000000003611475065116200204220ustar00rootroot00000000000000-- Simple intra-group scan. -- == -- random input { [1][256]i32 } auto output -- random input { [10][256]i32 } auto output -- structure gpu { SegMap/SegScan 1 } def main xs = #[incremental_flattening(only_intra)] map (scan (+) 0i32) xs futhark-0.25.27/tests/intragroup/scan0.fut.tuning000066400000000000000000000000261475065116200217230ustar00rootroot00000000000000default_threshold=200 futhark-0.25.27/tests/intragroup/scan1.fut000066400000000000000000000004461475065116200204270ustar00rootroot00000000000000-- Multiple intra-group scans. 
-- == -- random input { [1][256]i32 } auto output -- compiled random input { [100][256]i32 } auto output -- structure gpu { SegMap/SegScan 1 } def main xss = #[incremental_flattening(only_intra)] unzip (map (\xs -> (scan (+) 0i32 xs, scan (*) 1i32 xs)) xss) futhark-0.25.27/tests/intragroup/scan1.fut.tuning000066400000000000000000000000261475065116200217240ustar00rootroot00000000000000default_threshold=200 futhark-0.25.27/tests/intragroup/scan2.fut000066400000000000000000000003751475065116200204310ustar00rootroot00000000000000-- Map-scan inside group. -- == -- random input { [1][256]i32 } auto output -- random input { [100][256]i32 } auto output -- structure gpu { SegMap/SegScan 1 } def main xs = #[incremental_flattening(only_intra)] map (map i32.abs >-> scan (+) 0) xs futhark-0.25.27/tests/intragroup/scan2.fut.tuning000066400000000000000000000000261475065116200217250ustar00rootroot00000000000000default_threshold=200 futhark-0.25.27/tests/intragroup/scan3.fut000066400000000000000000000003331475065116200204240ustar00rootroot00000000000000-- Intra-group scan horizontally fused with map. -- == -- random input { [256][256]f32 } auto output def main xss = #[incremental_flattening(only_intra)] map (\xs -> (scan (+) 0f32 xs, map (+2) xs)) xss |> unzip futhark-0.25.27/tests/intragroup/segreduce0.fut000066400000000000000000000005311475065116200214430ustar00rootroot00000000000000-- Simple intra-group reduction. 
-- == -- random input { [1][1][256]i32 } auto output -- compiled random input { [10][16][16]i32 } auto output -- compiled random input { [10][256][1]i32 } auto output -- structure gpu { SegMap/SegRed 1 SegMap 1 SegRed 1 } def main xss = #[incremental_flattening(only_intra)] map (\xs -> map i32.sum xs) xss futhark-0.25.27/tests/intragroup/segreduce0.fut.tuning000066400000000000000000000000261475065116200227450ustar00rootroot00000000000000default_threshold=200 futhark-0.25.27/tests/intragroup/segreduce1.fut000066400000000000000000000006331475065116200214470ustar00rootroot00000000000000-- Multiple intra-group reductions. -- == -- random input { [1][1][256]i32 } auto output -- compiled random input { [10][16][16]i32 } auto output -- compiled random input { [10][256][1]i32 } auto output -- structure gpu { /SegMap/SegRed 1 SegMap 1 SegRed 1 } def main xsss = #[incremental_flattening(only_intra)] xsss |> map (\xss ->map (\xs -> (i32.sum xs, i32.product xs)) xss) |> map unzip |> unzip futhark-0.25.27/tests/intragroup/segreduce1.fut.tuning000066400000000000000000000000261475065116200227460ustar00rootroot00000000000000default_threshold=200 futhark-0.25.27/tests/intragroup/segreduce2.fut000066400000000000000000000005531475065116200214510ustar00rootroot00000000000000-- Map-reduce inside group. -- == -- random input { [1][1][256]i32 } auto output -- compiled random input { [10][16][16]i32 } auto output -- compiled random input { [10][256][1]i32 } auto output -- structure gpu { /SegMap/SegRed 1 SegMap 1 SegRed 1 } def main xsss = #[incremental_flattening(only_intra)] map (\xss -> map (map i32.abs >-> i32.sum) xss) xsss futhark-0.25.27/tests/intragroup/segreduce2.fut.tuning000066400000000000000000000000261475065116200227470ustar00rootroot00000000000000default_threshold=200 futhark-0.25.27/tests/intragroup/segscan0.fut000066400000000000000000000005361475065116200211250ustar00rootroot00000000000000-- Simple intra-group scan. 
-- == -- random input { [1][1][256]i32 } auto output -- compiled random input { [10][16][16]i32 } auto output -- compiled random input { [10][256][1]i32 } auto output -- structure gpu { SegMap/SegScan 1 SegMap 1 SegScan 1 } def main xss = #[incremental_flattening(only_intra)] map (\xs -> map (scan (+) 0i32) xs) xss futhark-0.25.27/tests/intragroup/segscan1.fut000066400000000000000000000006421475065116200211240ustar00rootroot00000000000000-- Multiple intra-group scans. -- == -- random input { [1][1][256]i32 } auto output -- compiled random input { [10][16][16]i32 } auto output -- compiled random input { [10][256][1]i32 } auto output -- structure gpu { /SegMap/SegScan 1 SegMap 1 SegScan 1 } def main xsss = #[incremental_flattening(only_intra)] xsss |> map (\xss ->map (\xs -> (scan (i32.+) 0 xs, scan (i32.*) 1 xs)) xss) |> map unzip |> unzip futhark-0.25.27/tests/intragroup/segscan2.fut000066400000000000000000000005561475065116200211310ustar00rootroot00000000000000-- Map-scan inside group. -- == -- random input { [1][1][256]i32 } auto output -- compiled random input { [10][16][16]i32 } auto output -- compiled random input { [10][256][1]i32 } auto output -- structure gpu { /SegMap/SegScan 1 SegMap 1 SegScan 1 } def main xsss = #[incremental_flattening(only_intra)] map (\xss -> map (map i32.abs >-> scan (+) 0) xss) xsss futhark-0.25.27/tests/intragroup/stencil_1d.fut000066400000000000000000000003761475065116200214510ustar00rootroot00000000000000-- Intragroup relaxation. 
-- == -- compiled no_python random input { 100 [100][256]f32 } auto output def relax (xs: []f32) = map2 (+) xs (map2 (+) (rotate (-1) xs) (rotate 1 xs)) def main (steps: i32) (xss: [][]f32) = map (iterate steps relax) xss futhark-0.25.27/tests/intragroup/stencil_1d.fut.tuning000066400000000000000000000000261475065116200227440ustar00rootroot00000000000000default_threshold=200 futhark-0.25.27/tests/intragroup/stencil_2d.fut000066400000000000000000000032761475065116200214540ustar00rootroot00000000000000-- Intragroup Game of Life! -- == -- random no_python compiled input { 100 [100][16][16]bool } auto output -- random no_python compiled input { 1000 [100][16][16]bool } auto output -- random no_python compiled input { 3000 [100][16][16]bool } auto output def bint: bool -> i32 = i32.bool def all_neighbours [n][m] (world: [n][m]bool): [n][m]i32 = let ns = map (rotate (-1)) world let ss = map (rotate 1) world let ws = rotate (-1) world let es = rotate 1 world let nws = map (rotate (-1)) ws let nes = map (rotate (-1)) es let sws = map (rotate 1) ws let ses = map (rotate 1) es in map3 (\(nws_r, ns_r, nes_r) (ws_r, world_r, es_r) (sws_r, ss_r, ses_r) -> map3 (\(nw,n,ne) (w,_,e) (sw,s,se) -> bint nw + bint n + bint ne + bint w + bint e + bint sw + bint s + bint se) (zip3 nws_r ns_r nes_r) (zip3 ws_r world_r es_r) (zip3 sws_r ss_r ses_r)) (zip3 nws ns nes) (zip3 ws world es) (zip3 sws ss ses) def iteration [n][m] (board: [n][m]bool): [n][m]bool = let lives = all_neighbours(board) in map2 (\(lives_r: []i32) (board_r: []bool) -> map2 (\(neighbors: i32) (alive: bool): bool -> if neighbors < 2 then false else if neighbors == 3 then true else if alive && neighbors < 4 then true else false) lives_r board_r) lives board def life (iterations: i32) (board: [][]bool) = loop board for _i < iterations do iteration board def main (iterations: i32) (board: [][][]bool) = map (life iterations) board 
futhark-0.25.27/tests/intragroup/stencil_2d.fut.tuning000066400000000000000000000000261475065116200227450ustar00rootroot00000000000000default_threshold=200 futhark-0.25.27/tests/intragroup/toomuch.fut000066400000000000000000000012001475065116200210650ustar00rootroot00000000000000-- Intra-group for this one probably exceeds local memory -- availability, so should not be picked, even if it looks like a good -- idea! -- == -- compiled random input { [128][256]f64 [128][256]f64 [128][256]f64 [128][256]f64 } auto output type vec = (f64, f64, f64, f64) def vecadd (a: vec) (b: vec) = (a.0 + b.0, a.1 + b.1, a.2 + b.2, a.3 + b.3) def psum = scan vecadd (0,0,0,0) def main (xss: [][]f64) (yss: [][]f64) (zss: [][]f64) (vss: [][]f64) = #[incremental_flattening(no_outer)] map (psum >-> psum >-> psum >-> psum >-> psum >-> psum >-> psum >-> psum >-> psum) (map4 zip4 xss yss zss vss) |> map unzip4 |> unzip4 futhark-0.25.27/tests/intragroup/toomuch.fut.tuning000066400000000000000000000000261475065116200223750ustar00rootroot00000000000000default_threshold=200 futhark-0.25.27/tests/intunderscores.fut000066400000000000000000000002341475065116200202720ustar00rootroot00000000000000-- Integers can contain underscores -- == -- input { 123_456 } -- output { 101000i32 } def main (x: i32) = let x = 100_000i32 in x + i32.i16(1_000i16) futhark-0.25.27/tests/iota0.fut000066400000000000000000000002351475065116200162400ustar00rootroot00000000000000-- Does iota work at all? -- == -- input { 0i64 } -- output { empty([0]i64) } -- input { 2i64 } -- output { [0i64,1i64] } def main(n: i64): []i64 = iota(n) futhark-0.25.27/tests/ipl-bug.fut000066400000000000000000000003341475065116200165630ustar00rootroot00000000000000-- A contrived program that once made in-place lowering fail. 
def main [n] (xs: *[n][n]i32, ys0: [n]i32, i: i32): ([n][n]i32, i32) = let ys = map (+ 1) ys0 let zs = map (+ 1) xs[i] let xs[i] = ys in (xs, zs[i]) futhark-0.25.27/tests/issue1025.fut000066400000000000000000000053141475065116200166670ustar00rootroot00000000000000-- Derived from futracer module f32racer = { type t = f32 type point2D = {x: t, y: t} type point3D = {x: t, y: t, z: t} type angles = {x: t, y: t, z: t} } module i32racer = { type t = i32 type point2D = {x: t, y: t} type point3D = {x: t, y: t, z: t} type angles = {x: t, y: t, z: t} } type triangle = (f32racer.point3D, f32racer.point3D, f32racer.point3D) type point_projected = {x: i32, y: i32, z: f32} type point = i32racer.point2D type triangle_projected = (point_projected, point_projected, point_projected) type point_barycentric = (i32, i32racer.point3D, f32racer.point3D) type rectangle = (i32racer.point2D, i32racer.point2D) def barycentric_coordinates ({x, y}: i32racer.point2D) (triangle: triangle_projected) = let ({x=xp0, y=yp0, z=_}, {x=xp1, y=yp1, z=_}, {x=xp2, y=yp2, z=_}) = triangle in ((yp1 - yp2) * (x - xp2) + (xp2 - xp1) * (y - yp2)) def main [tn] (triangles_projected: [tn]triangle_projected) (w: i32) (h: i32) rects ((n_rects_x, n_rects_y): (i32, i32)) = let each_pixel [rtpn] (rect_triangles_projected: [rtpn]triangle_projected) (pixel_index: i32) = let p = {x=pixel_index / h, y=pixel_index % h} let each_triangle (t: triangle_projected) = barycentric_coordinates p t != 0 in all each_triangle rect_triangles_projected let rect_in_rect (({x=x0a, y=y0a}, {x=x1a, y=y1a}): rectangle) (({x=x0b, y=y0b}, {x=x1b, y=y1b}): rectangle): bool = ! 
(x1a <= x0b || x0a >= x1b || y1a <= y0b || y0a >= y1b) let bounding_box (({x=x0, y=y0, z=_}, {x=x1, y=y1, z=_}, {x=x2, y=y2, z=_}): triangle_projected): rectangle = ({x=i32.min (i32.min x0 x1) x2, y=i32.min (i32.min y0 y1) y2}, {x=i32.max (i32.max x0 x1) x2, y=i32.max (i32.max y0 y1) y2}) let triangle_in_rect (rect: rectangle) (tri: triangle_projected): bool = let rect1 = bounding_box tri in rect_in_rect rect1 rect || rect_in_rect rect rect1 let each_rect [bn] (rect: rectangle) (pixel_indices: [bn]i32) = let rect_triangles_projected = filter (triangle_in_rect rect) triangles_projected in map (each_pixel rect_triangles_projected) pixel_indices let rect_pixel_indices (totallen: i64) (({x=x0, y=y0}, {x=x1, y=y1}): rectangle) = let (xlen, ylen) = (i64.i32 (x1 - x0), i64.i32 (y1 - y0)) let xs = map (+ x0) (map i32.i64 (iota xlen)) let ys = map (+ y0) (map i32.i64 (iota ylen)) in flatten (map (\x -> map (\y -> x * h + y) ys) xs) :> [totallen]i32 let x_size = w / n_rects_x + i32.bool (w % n_rects_x > 0) let y_size = h / n_rects_y + i32.bool (h % n_rects_y > 0) let pixel_indicess = map (rect_pixel_indices (i64.i32 (x_size * y_size))) rects let pixelss = map2 each_rect rects pixel_indicess in pixelss futhark-0.25.27/tests/issue1043.fut000066400000000000000000000001701475065116200166620ustar00rootroot00000000000000-- == -- error: aliased to "xs" def f 'a 'b (f: a -> b) (xs: a) = f xs def main (xs: []i32) : *[]i32 = (`f`xs) id futhark-0.25.27/tests/issue1053.fut000066400000000000000000000106611475065116200166710ustar00rootroot00000000000000-- == -- error: trail_map type model_params = { -- environment parameters pct_pop: f32, --diffusion_ker_size: i32, decay: f32, -- agent parameters sensor_angle: f32, sensor_offset: f32, -- sensor_width: i32, rot_angle: f32, step_size: f32, deposit_amount: f32 -- angle_jitter: f32 } type agent = { loc: (f32,f32), ang: f32 } type env [grid_h][grid_w][n_agents] = { model_params: model_params, trail_map: [grid_h][grid_w]f32, agent_list: 
[n_agents]agent } def bounded (max: f32) (x: f32) : f32 = if x >= 0 && x < max then x else (x + max) f32.% max def loc2grid (grid_size: i64) (real_loc: f32) : i64 = let gs_f = f32.i64 grid_size in if real_loc >= 0 && real_loc < gs_f then i64.f32 real_loc else i64.f32 (bounded gs_f real_loc) def read_sensor [xn] [yn] (p: model_params) (trail_map: [yn][xn]f32) (x: f32, y: f32) (ang: f32) : f32 = let sx = f32.cos ang * p.sensor_offset + x |> loc2grid xn let sy = f32.sin ang * p.sensor_offset + y |> loc2grid yn in trail_map[sy,sx] def move_step (p: model_params) ({loc=(x: f32, y: f32), ang: f32} : agent) : agent = let x_ = x + p.step_size * f32.cos ang let y_ = y + p.step_size * f32.sin ang in {loc=(x_, y_), ang} def step_agent (p: model_params) (trail_map: [][]f32) ({loc,ang}: agent) : (agent, (i64, i64)) = let sl = read_sensor p trail_map loc (ang + p.sensor_angle) let sf = read_sensor p trail_map loc ang let sr = read_sensor p trail_map loc (ang - p.sensor_angle) let stepped = if sf >= sr && sf >= sl then move_step p {loc,ang} else (if sr >= sl then move_step p {loc, ang=ang - p.rot_angle} else move_step p {loc, ang=ang + p.rot_angle}) in (stepped, (i64.f32 loc.0, i64.f32 loc.1)) def step_agents [h][w][a] ({model_params, trail_map, agent_list}: env[h][w][a]) : env[h][w][a] = let (stepped, deposits) = unzip (map (step_agent model_params trail_map) agent_list) let flat_deposits = map (\(x,y) -> y*w+x) deposits let deposited = reduce_by_index (flatten trail_map) (+) 0 flat_deposits (replicate a model_params.deposit_amount) in {model_params, trail_map=unflatten deposited, agent_list=stepped} def disperse_cell [h][w] (p: model_params) (trail_map: [h][w]f32) (x: i64) (y: i64) : f32 = let neighbors = map (\(dx,dy) -> trail_map[(y+dy+h) i32.% h, (x+dx+w) i32.% w] ) [(-1, 1), ( 0, 1), ( 1, 1), (-1, 0), ( 1, 0), (-1,-1), ( 0,-1), ( 1,-1)] let sum = trail_map[x,y] + reduce (+) 0 neighbors in p.decay * sum / 9 def disperse_trail [h][w][a] ({model_params, trail_map, agent_list}: 
env[h][w][a]) : env[h][w][a] = {model_params, agent_list, trail_map=tabulate_2d h w (disperse_cell model_params trail_map)} def simulation_step [h][w][a] (e: env[h][w][a]) : env[h][w][a] = e |> step_agents |> disperse_trail def to_deg (rad: f32): i32 = 180 * rad / f32.pi |> f32.round |> i64.f32 def to_rad (deg: i64): f32 = f32.i64 deg * f32.pi / 180 def build_test_env [h][w][a] (trail_map: [h][w]f32) (agent_xs: [a]f32) (agent_ys: [a]f32) (agent_angs: [a]i64) : env[h][w][a] = let model_params = { pct_pop=0 , decay=0.5 , sensor_angle=to_rad 45 , sensor_offset=2 , rot_angle=to_rad 45 , step_size=1 , deposit_amount=9 } let agent_list = map3 (\x y ang -> {loc=(x,y), ang=to_rad ang}) agent_xs agent_ys agent_angs in {model_params, agent_list, trail_map} entry test_single_step_trail [h][w] (trail_map: [h][w]f32) (x: f32) (y: f32) (ang: i64) : [h][w]f32 = let e = simulation_step (build_test_env trail_map [x] [y] [ang]) in e.trail_map futhark-0.25.27/tests/issue1054.fut000066400000000000000000000010021475065116200166570ustar00rootroot00000000000000-- == -- input { [0x2b28ab09u32, 0x7eaef7cfu32, 0x15d2154fu32, 0x16a6883cu32] } -- auto output def blk_transpose (block: [4]u32) : [4]u32 = #[sequential] map (\i -> let offset = u32.i64 (3-i)<<3 in (((block[0] >> offset) & 0xFF) << 24) | (((block[1] >> offset) & 0xFF) << 16) | (((block[2] >> offset) & 0xFF) << 8) | ((block[3] >> offset) & 0xFF)) (iota 4) def main (key: [4]u32) : [11][4]u32 = map blk_transpose (loop w = [key] for _i < 10i32 do (w ++ [key])) :> [11][4]u32 futhark-0.25.27/tests/issue1068.fut000066400000000000000000000032121475065116200166710ustar00rootroot00000000000000type complex = {r:f32, i:f32} def complex r i : complex = {r, i} def real r = complex r 0 def imag i = complex 0 i def zero = complex 0 0 def addC (a:complex) (b:complex) : complex = {r=a.r+b.r, i=a.i+b.i} def subC (a:complex) (b:complex) : complex = {r=a.r-b.r, i=a.i-b.i} def mulC (a:complex) (b:complex) : complex = {r=a.r*b.r-a.i*b.i, 
i=a.r*b.i+a.i*b.r} def divC (a:complex) (b:complex) : complex = let d = b.r*b.r+b.i*b.i let r = (a.r*b.r+a.i*b.i)/d let i = (a.i*b.r-a.r*b.i)/d in {r, i} def pi:f32 = 3.141592653589793 def gfft [n] (inverse: bool) (xs:[n]complex) : [n]complex = let dir = 1 - 2*i64.bool inverse let (n', iter) = iterate_while (( (a << 1, b+1)) (1, 0) let iteration [l] ((xs:[l]complex), m, e, theta0) = let modc = (1 << e) - 1 let xs' = tabulate l (\i -> let i = i32.i64 i let q = i & modc let p'= i >> e let p = p'>> 1 let a = xs[q + (p << e)] let b = xs[q + (p + m << e)] let theta = theta0 * f32.i32 p in if bool.i32 (p' & 1) then mulC (complex (f32.cos theta) (-f32.sin theta)) (subC a b) else addC a b ) in (xs', m >> 1, e + 1, theta0 * 2) in (iterate iter iteration (xs, i32.i64 (n>>1), 0, pi*f32.from_fraction (dir*2) n) |> (.0)) def gfft3 [m][n][k] inverse (A:[m][n][k]complex) = tabulate_2d n k (\i j -> gfft inverse A[:,i,j]) def main testData = gfft3 false (map (map (map real)) testData) futhark-0.25.27/tests/issue1069.fut000066400000000000000000000003521475065116200166740ustar00rootroot00000000000000-- The 'unsafe' is just to make the code a little neater. It does not -- affect the bug. 
def main xs = #[unsafe] tabulate_2d 10 10 (\a i -> let x = xs[a+i] in loop xs = xs for j < 10 do map (+x) xs) futhark-0.25.27/tests/issue1073.fut000066400000000000000000000021411475065116200166650ustar00rootroot00000000000000def dot [n] (x: [n]bool) (y: [n]bool):i64 = i64.sum (map2 (\x y -> i64.bool <| x && y) x y) local def choose (n:i64) (k:i64):f32 = f32.product (map (\i -> f32.i64 (n+1-i)/f32.i64 i) (1...k)) def kc [m] (x: [m]bool) (y: [m]bool) (d:i64):f32 = f32.sum (map (choose (dot x y)) (1...d)) def norm [m] (k: [m]bool -> [m]bool -> f32) (x: [m]bool) (y: [m]bool) :f32 = let xy = k x y let xx = k x x let yy = k y y in xy / (f32.sqrt xx * f32.sqrt yy) def kcn x y d = norm (\x y -> kc x y d) x y type centroid [n][m] = {d: i64, w: [n]f32, trx: [n][m]bool, try: [n]bool} def predict [n][m][k] (c: centroid[n][m]) (xs: [k][m]bool): [k]f32 = map (\x -> f32.sum (map2 (\w x' -> w * kcn x x' c.d) c.w c.trx)) xs entry lto [n][m] (c:centroid[n][m]) = let zero i x = tabulate n (\j -> if j == i then 0 else x[j]) let cmod i c = {d=c.d, w=zero i c.w, trx=c.trx, try=c.try} let score i j = if c.try[i] || c.try[i] == c.try[j] then -1 else let c' = cmod i (cmod j c) let ys = predict c' [c.trx[i], c.trx[j]] in if ys[0] < ys[1] then 1 else if ys[0] == ys[1] then 0.5 else 0 in tabulate_2d n n score futhark-0.25.27/tests/issue1074.fut000066400000000000000000000027661475065116200167030ustar00rootroot00000000000000def dot [n] (x: [n]bool) (y: [n]bool):i64 = i64.sum (map2 (\x y -> i64.bool <| x && y) x y) local def choose (n:i64) (k:i64):f32 = f32.product (map (\i -> f32.i64 (n+1-i)/f32.i64 i) (1...k)) def kc [m] (x: [m]bool) (y: [m]bool) (d:i64):f32 = f32.sum (map (choose (dot x y)) (1...d)) def norm [m] (k: [m]bool -> [m]bool -> f32) (x: [m]bool) (y: [m]bool) :f32 = let xy = k x y let xx = k x x let yy = k y y in xy / (f32.sqrt xx * f32.sqrt yy) def kcn x y d = norm (\x y -> kc x y d) x y type centroid [n][m] = {d: i64, w: [n]f32, trx: [n][m]bool, try: [n]bool} def train [n][m] 
(x:[n][m]bool) (y:[n]bool) (d:i64): centroid [n][m] = let zeros = replicate n 0 let w = reduce (map2 (+)) zeros (map2 (\x' y -> (map (\x'' -> (2 * f32.bool y - 1) * kcn x' x'' d) x)) x y) in {d=d,w=w,trx=x,try=y} def predict [n][m][k] (c: centroid[n][m]) (xs: [k][m]bool): [k]f32 = map (\x -> f32.sum (map2 (\w x' -> w * kcn x x' c.d) c.w c.trx)) xs def lto [n][m] (c:centroid[n][m]) = let mean x = f32.sum x / f32.i64 (length x) let zero i x = tabulate n (\j -> if j == i then 0 else x[j]) let cmod i c = {d=c.d, w=zero i c.w, trx=c.trx, try=c.try} let score i j = if c.try[i] || c.try[i] == c.try[j] then -1 else let c' = cmod i (cmod j c) let ys = predict c' [c.trx[i], c.trx[j]] in if ys[0] < ys[1] then 1 else if ys[0] == ys[1] then 0.5 else 0 in tabulate_2d n n score |> flatten |> filter (>=0) |> mean entry test [n][m] (x:[n][m]bool) (y:[n]bool) = map (\d -> lto (train x y d)) (1...10) futhark-0.25.27/tests/issue1080.fut000066400000000000000000000031421475065116200166650ustar00rootroot00000000000000type complex = {r:f32, i:f32} def complex r i : complex = {r, i} def real r = complex r 0 def conjC (a:complex) : complex = {r = a.r, i = -a.i} def addC (a:complex) (b:complex) : complex = {r=a.r+b.r, i=a.i+b.i} def subC (a:complex) (b:complex) : complex = {r=a.r-b.r, i=a.i-b.i} def mulC (a:complex) (b:complex) : complex = {r=a.r*b.r-a.i*b.i, i=a.r*b.i+a.i*b.r} def pi:f32 = 3.141592653589793 def gfft [n] (inverse: bool) (xs:[n]complex) : [n]complex = let logN = assert (i64.popc n == 1) (i64.ctz n) let startTheta = pi * f32.from_fraction (2 - (i64.bool inverse << 2)) n let ms = n >> 1 let iteration [l] ((xs:[l]complex), e, theta0) = let modc = (1 << e) - 1 let xs' = tabulate l (\i -> let q = i & modc let p'= i >> e let p = p'>> 1 let ai = q + (p << e) let bi = ai + ms let a = xs[ai] let b = xs[bi] let theta = theta0 * f32.i64 p in if bool.i64 (p' & 1) then mulC (complex (f32.cos theta) (-f32.sin theta)) (subC a b) else addC a b ) in (xs', e + 1, theta0 * 2) in (iterate logN 
iteration (xs, 0, startTheta)).0 def gfft3 [m][n][k] inverse (A:[m][n][k]complex) = let A' = tabulate_2d n k (\i j -> gfft inverse A[:,i,j]) let A'' = tabulate_2d k m (\i j -> gfft inverse A'[:,i,j]) let A''' = tabulate_2d m n (\i j -> gfft inverse A''[:,i,j]) in A''' def ifft3 [m][n][k] (x:[m][n][k]complex) = let f = real (f32.from_fraction 1 (m*n*k)) in gfft3 true x |> map (map (map (mulC f))) def main = map ifft3 futhark-0.25.27/tests/issue1100.fut000066400000000000000000000002641475065116200166600ustar00rootroot00000000000000-- == -- input { [[1.0f32, 2.0f32], [3.0f32, 4.0f32]] } -- output { [[1.0f32, 2.0f32], [3.0f32, 4.0f32]] } def main xss = let scale (f: f32) = map (map (*f)) in scale 1.0 xss futhark-0.25.27/tests/issue1112.fut000066400000000000000000000044351475065116200166670ustar00rootroot00000000000000type triad 't = (t, t, t) def triadMap 'a 'b (f:a->b) (A:triad a) : triad b = (f A.0, f A.1, f A.2) def triadMap2 'a 'b 'c (f:a->b->c) (A:triad a) (B:triad b): triad c = (f A.0 B.0, f A.1 B.1, f A.2 B.2) def triadFold 'a (f:a->a->a) (A:triad a) : a = f A.0 <| f A.1 A.2 type v3 = triad f32 type m33 = triad v3 type quaternion = {r:f32, v:v3} def v3sum (v:v3) : f32 = triadFold (+) v def v3add (a:v3) (b:v3) : v3 = triadMap2 (+) a b def v3mul (a:v3) (b:v3) : v3 = triadMap2 (*) a b def v3dot (a:v3) (b:v3) : f32 = v3mul a b |> v3sum def gauss_jordan [m] [n] (A:[m][n]f32) = loop A for i < i64.min m n do let icol = map (\row -> row[i]) A let (j,_) = map f32.abs icol |> zip (iota m) |> drop i |> reduce_comm (\(i,a)(j,b) -> if a < b then (j,b) else if b < a then (i,a) else if i < j then (i,a) else (j,b)) (0,0) let f = (1-A[i,i]) / A[j,i] let irow = map2 (f32.fma f) A[j] A[i] in map (\j -> if j == i then irow else let f = f32.neg A[j,i] in map2 (f32.fma f) irow A[j]) (iota m) def hStack [m][n][l] (A:[m][n]f32) (B:[m][l]f32) = map2 concat A B def vStack [m][n][l] (A:[m][n]f32) (B:[l][n]f32) = concat A B def solveAB [m][n] (A:[m][m]f32) (B:[m][n]f32) : [m][n]f32 = let 
AB = hStack A B |> gauss_jordan in AB[0:m, m:(m+n)] :> [m][n]f32 def solveAb [m] (A:[m][m]f32) (b:[m]f32) = unflatten (b :> [m*1]f32) |> solveAB A |> flatten def main u_bs (points':[]v3) (forces:[](v3,v3)) = let C (x,y,z) = [ ( 1, 0, 0) , ( 0, 1, 0) , ( 0, 0, 1) , ( 0, -z, y) , ( z, 0, -x) , (-y, x, 0) ] let CC C = map (\a -> (\b -> map (v3dot a) b) C) C let singleParticle u_b (f, t) = let f_ext = [f.0,f.1,f.2,t.0,t.1,t.2] let C_ps = map C points' let CC_ps = map CC C_ps |> transpose |> map transpose |> map (map f32.sum) let f_flow = map2 (\C u -> map (v3dot u) C ) C_ps u_b |> map f32.sum let f_tot = map2 (+) f_flow f_ext let u = solveAb CC_ps f_tot in u in map2 singleParticle u_bs forces futhark-0.25.27/tests/issue1139.fut000066400000000000000000000025621475065116200166770ustar00rootroot00000000000000type LandCounts = [5]i8 type ColorRequirement = {masks: LandCounts, count: i8} type ProbTable = [16][9][4][18][18]f32 type ColorRequirements = #one_requirement i8 ColorRequirement | #two_requirements i8 ColorRequirement ColorRequirement def sum_with_mask (counts: LandCounts) (masks: LandCounts) = reduce (+) 0 (map2 (&) counts masks) def get_casting_probability_1 (lands: LandCounts) (cmc: i8) (requirement: ColorRequirement) (prob_to_cast: ProbTable) = prob_to_cast[cmc, requirement.count, 0, sum_with_mask lands requirement.masks, 0] def get_casting_probability_2 (lands: LandCounts) (cmc: i8) (req1: ColorRequirement) (req2: ColorRequirement) (prob_to_cast: ProbTable) = prob_to_cast[cmc, req1.count, req2.count, sum_with_mask lands req1.masks, sum_with_mask lands req2.masks] def get_casting_probability (lands: LandCounts) (requirements: ColorRequirements) (prob_to_cast: ProbTable) = match requirements case #one_requirement cmc req1 -> get_casting_probability_1 lands cmc req1 prob_to_cast case #two_requirements cmc req1 req2 -> get_casting_probability_2 lands cmc req1 req2 prob_to_cast entry get_score [picked_count] (lands: LandCounts) (prob_to_cast: ProbTable) 
(picked_color_requirements: [picked_count]ColorRequirements) = map (\req -> get_casting_probability lands req prob_to_cast) picked_color_requirements futhark-0.25.27/tests/issue1142.fut000066400000000000000000000003101475065116200166560ustar00rootroot00000000000000def main [n][m] (xsss: *[2][n][m]i32) = #[unsafe] let xss = xsss[0] let ys = loop acc = replicate m 0 for i < m do let acc[i] = xsss[0,0,i]+1 in acc in xss with [0] = ys futhark-0.25.27/tests/issue1143.fut000066400000000000000000000026671475065116200167000ustar00rootroot00000000000000def dotprod [n] (xs: [n]f32) (ys: [n]f32): f32 = reduce (+) 0.0 (map2 (*) xs ys) def house [d] (x: [d]f32): ([d]f32, f32) = let dot = dotprod x x let dot' = dot - x[0]**2 + x[0]**2 let beta = if dot' != 0 then 2.0/dot' else 0 in (x, beta) def matmul [n][p][m] (xss: [n][p]f32) (yss: [p][m]f32): [n][m]f32 = map (\xs -> map (dotprod xs) (transpose yss)) xss def outer [n][m] (xs: [n]f32) (ys: [m]f32): [n][m]f32 = matmul (map (\x -> [x]) xs) [ys] def matsub [m][n] (xss: [m][n]f32) (yss: [m][n]f32): *[m][n]f32 = map2 (\xs ys -> map2 (-) xs ys) xss yss def matadd [m][n] (xss: [m][n]f32) (yss: [m][n]f32): [m][n]f32 = map2 (\xs ys -> map2 (+) xs ys) xss yss def matmul_scalar [m][n] (xss: [m][n]f32) (k: f32): *[m][n]f32 = map (map (*k)) xss def block_householder [m][n] (A: [m][n]f32) (r: i64): ([][]f32, [][]f32) = #[unsafe] let Q = replicate m (replicate m 0) let (Q,A) = loop (Q,A) = (Q, copy A) for k in 0..<(n/r) do let s = k * r let V = replicate m (replicate r 0f32) let Bs = replicate r 0f32 let (A) = loop (A) for j in 0.. 
block_householder arr r) arrs futhark-0.25.27/tests/issue1153.fut000066400000000000000000000001211475065116200166600ustar00rootroot00000000000000-- == -- error: Default def example a b = a + b def example2 a = example a 1.0 futhark-0.25.27/tests/issue1155.fut000066400000000000000000000002251475065116200166670ustar00rootroot00000000000000-- == -- error: Default module type Addable = { type t val add: t -> t -> t } module Add_f32:Addable = { type t = f32 def add a b = a + b } futhark-0.25.27/tests/issue1168.fut000066400000000000000000000307151475065116200167020ustar00rootroot00000000000000type triad 't = (t, t, t) def triadMap 'a 'b (f:a->b) (A:triad a) : triad b = (f A.0, f A.1, f A.2) def triadMap2 'a 'b 'c (f:a->b->c) (A:triad a) (B:triad b): triad c = (f A.0 B.0, f A.1 B.1, f A.2 B.2) def triadZip 'a 'b (A: triad a) (B: triad b) : triad (a,b) = triadMap2 (\a b -> (a, b)) A B def triadUnzip 'a 'b (A:triad (a,b)) : (triad a, triad b) = (triadMap (.0) A, triadMap (.1) A) def triadFold 'a (f:a->a->a) (A:triad a) : a = f A.0 <| f A.1 A.2 def triadShiftR 'a (A:triad a) : triad a = (A.2, A.0, A.1) def triadShiftL 'a (A:triad a) : triad a = (A.1, A.2, A.0) def triplet a = (a, a, a) type v3 = triad f32 type m33 = triad v3 type quaternion = {r:f32, v:v3} def v3sum (v:v3) : f32 = triadFold (+) v def v3add (a:v3) (b:v3) : v3 = triadMap2 (+) a b def v3sub (a:v3) (b:v3) : v3 = triadMap2 (-) a b def v3mul (a:v3) (b:v3) : v3 = triadMap2 (*) a b def cross (a:v3) (b:v3) : v3 = (a.1*b.2-a.2*b.1, a.2*b.0-a.0*b.2, a.0*b.1-a.1*b.0) def v3dot (a:v3) (b:v3) : f32 = v3mul a b |> v3sum def scaleV3 (f:f32) = triadMap (*f) def v3negate = triadMap f32.neg def v3outer a b = triadMap (\f -> scaleV3 f b) a def sumV3s = reduce_comm v3add (triplet 0) def m33map2 f (A:m33) (B:m33) : m33 = triadMap2 (triadMap2 f) A B def m33add = m33map2 (+) def m33transpose (m:m33) = ( (m.0.0, m.1.0, m.2.0) , (m.0.1, m.1.1, m.2.1) , (m.0.2, m.1.2, m.2.2) ) def mvMult (m:m33) (v:v3) : v3 = triadMap (v3dot v) m def 
sumM33s = reduce_comm m33add (triplet <| triplet 0) def m33fromQuaternion (q:quaternion) : m33 = let a = q.r let b = q.v.0 let c = q.v.1 let d = q.v.2 let aa = a*a let ab = a*b let ac = a*c let ad = a*d let bb = b*b let bc = b*c let bd = b*d let cc = c*c let cd = c*d let dd = d*d in ( (aa+bb-cc-dd, 2*(bc-ad), 2*(bd+ac)) , (2*(bc+ad), aa-bb+cc-dd, 2*(cd-ab)) , (2*(bd-ac), 2*(cd+ab), aa-bb-cc+dd) ) def dot a b = map2 (*) a b |> f32.sum def outer [m][n] (as:[m]f32) (bs:[n]f32) = map (\a -> map (*a) bs) as def matVecMul [m][n] (A:[m][n]f32) (b:[n]f32) : [m]f32 = map (dot b) A def hash (x:u32) : u32 = let x = ((x >> 16) ^ x) * 0x45d9f3b let x = ((x >> 16) ^ x) * 0x45d9f3b let x = ((x >> 16) ^ x) in x -- xoshiro128** type PRNG = {state: (u32,u32,u32,u32)} def rotl (x:u32) (k:u32) = x << k | x >> 32-k def next (g:PRNG) : (u32, PRNG) = let (a,b,c,d) = g.state let res = (rotl (b * 5) 7) * 9 let t = b << 9 let c = c ^ a let d = d ^ b let b = b ^ c let a = a ^ d let c = c ^ t let d = rotl d 11 in (res, {state = (a,b,c,d)}) def split n (g:PRNG) = let (a, b, c, d) = g.state let (r, g') = next g let splitG i = let i' = u32.i64 i let r' = hash (r^i') let f a = hash (r'^a) in {state = (f a, f b, f c, f d)} in (tabulate n splitG, g') def newGen (seed:i32) : PRNG = let seed' = u32.i32 seed let h0 = hash seed' let h1 = hash (seed'+1) let h2 = hash (seed'+2) let h3 = hash (seed'+3) let a = h0 ^ h1 let b = h1 ^ h2 let c = h2 ^ h3 let d = h3 ^ h0 in {state = (a,b,c,d)} def genI32 g = let (r, g) = next g in (i32.u32 r, g) def genF32 g = let (i, g) = genI32 g in (f32.from_fraction (i64.i32 i) (i64.i32 i32.highest), g) def randomArray 'a (f:(PRNG -> (a, PRNG))) n g : ([]a, PRNG) = split n g |> (\(a, g) -> (map f a |> map (.0), g)) def sample [m] 'a (array: [m]a) (n: i64) (g: PRNG) : ([]a, PRNG) = let (indices, g') = randomArray genI32 n g let samples = map (\i -> array[i32.abs (i % i32.i64 m)]) indices in (samples, g') def normalPair g : ((f32, f32), PRNG) = let iteration (_,_,_,g) = let 
(x, g) = genF32 g let (y, g) = genF32 g let s = x*x + y*y in (s,x,y,g) let start = iteration (0:f32, 0:f32, 0:f32, g) let (s,x,y,g) = iterate_until (\(s,_,_,_) -> 0 x*x+y*y+z*z < 1) iteration start type^ network 'p 'i 'o = { parameter: p , propagation: p -> i -> (o, o -> (i, p)) , scale: f32 -> p -> p , add: p -> p -> p , zero: p , sum: (k:i64) -> [k]p -> p } def eval 'p 'i 'o (net:network p i o) (input:i) = (net.propagation net.parameter input).0 def gradient 'p 'i 'o (errF: o -> o -> o) (net:network p i o) (input:i, ref:o) = let (o, f) = net.propagation net.parameter input in (.1) <| f <| errF ref o def gradientDescentStep [n] 'p 'i 'o (lr:f32) (errF: o -> o -> o) (net: network p i o) (samples:[n](i,o)) : network p i o = let g = map (gradient errF net) samples |> net.sum n in { parameter = net.add net.parameter (net.scale (-lr/f32.i64 n) g) , propagation = net.propagation , scale = net.scale , add = net.add , zero = net.zero , sum = net.sum } def gradientDescent [n] 'p 'i 'o (lr: f32) (mf: f32) (batchSize: i64) (steps: i32) (errF: o -> o -> o) (net: network p i o) (samples:[n](i,o)) (momentum: p) (gen:PRNG) : (network p i o, p, PRNG) = let gradient p (input, ref) = let (o, f) = net.propagation p input in (.1) <| f <| errF ref o let iteration (p:p, m:p, gen: PRNG) : (p, p, PRNG) = let (batch, gen) = sample samples batchSize gen let grad = map (gradient p) batch |> net.sum batchSize let m = net.add (net.scale mf m) (net.scale (1-mf) grad) in (net.add p (net.scale (-lr/f32.i64 n) m), m, gen) let (parameter, momentum, gen) = iterate steps iteration (net.parameter, momentum, gen) in ( { parameter , propagation = net.propagation , scale = net.scale , add = net.add , zero = net.zero , sum = net.sum } , momentum , gen ) -- This is where the magic happens, the magic of function composition def chain 'p1 'p2 'i 'm 'o (a:network p1 i m) (b:network p2 m o) : network (p1, p2) i o = let parameter = (a.parameter, b.parameter) let propagation (ap, bp) i = let (m, af) = 
a.propagation ap i let (o, bf) = b.propagation bp m let f o' = let (m', bg) = bf o' let (i', ag) = af m' in (i', (ag, bg)) in (o, f) let scale f (ap, bp) = (a.scale f ap, b.scale f bp) let add (ap, bp) (ap', bp') = (a.add ap ap', b.add bp bp') let zero = (a.zero, b.zero) let sum k = unzip >-> (\(as, bs) -> (a.sum k as, b.sum k bs)) in {parameter, propagation, scale, add, zero, sum} def (<>) = chain def stateless 'i 'o (propagation': i -> (o, o -> i)) : network () i o = let parameter = () let propagation _ i = propagation' i |> (\(a,b) -> (a, b >-> (\i -> (i, ())))) let scale _ _ = () let add _ _ = () let zero = () let sum _ _ = () in {parameter, propagation, scale, add, zero, sum} --evaluates pairs of values in the same network def pairNetwork 'p 'i 'o (net:network p i o) : network p (i, i) (o, o) = let pairMap f (a, b) = (f a, f b) let parameter = net.parameter let propagation param i = let ((o1, f1),(o2, f2)) = pairMap (net.propagation param) i let f (o1, o2) = (f1 o1, f2 o2) |> (\((i1, g1), (i2, g2)) -> ((i1,i2), net.add g1 g2)) in ((o1, o2), f) let scale = net.scale let add = net.add let zero = net.zero let sum = net.sum in {parameter, propagation, scale, add, zero, sum} def linear [m][n] (weights:[m][n]f32) : network ([m][n]f32) ([n]f32) ([m]f32) = let parameter = weights let propagation ws i = let forward = matVecMul ws i let backward o = let i' = matVecMul (transpose ws) o let g = outer o i in (i', g) in (forward, backward) let scale f = map (map (*f)) let add = map2 (map2 (+)) let zero = replicate m (replicate n 0) let sum k (ps:[k][m][n]f32) = ps |> transpose |> map transpose |> map (map f32.sum) in {parameter, propagation, scale, add, zero, sum} def sum = let propagation [n] (is:[n]f32) = let forward = f32.sum is let backward = replicate n in (forward, backward) in stateless propagation def sumInv n = let propagation (i:f32) = let forward = replicate n i let backward os = f32.sum os in (forward, backward) in stateless propagation def genMap [n] 'p 'i 'o 
(ps:[n]p) (f: p -> i -> o) (df: p -> i -> o -> (i, p)) (scale': f32 -> p -> p) (add': p -> p -> p) (zero': p) : network ([n]p) ([n]i) ([n]o) = let parameter = ps let propagation ps is = (map2 f ps is, map3 df ps is >-> unzip) let scale f = map (scale' f) let add = map2 add' let zero = replicate n zero' let sum k (ps:[k][n]p) = ps |> transpose |> map (reduce_comm add' zero') in {parameter, propagation, scale, add, zero, sum} def statelessMap 'i 'o (n:i64) (f: i -> o) (df: i -> o -> i) = genMap (replicate n ()) (\_ -> f) (\_ i o -> (df i o, ())) (\_ _ -> ()) (\_ _ -> ()) () def bias [n] (biases:[n]f32) = genMap biases (+) (\_ _ o -> (o, o)) (*) (+) 0 def scale [n] (factors:[n]f32) = genMap factors (*) (\f i o -> (f*o, i*o)) (*) (+) 0 def smoothInvQuadMap [n] (cs:[n]f32) : network ([n]f32) ([n]v3) ([n]f32) = let f c v = f32.exp(-c*c * v3dot v v) let parameter = cs let propagation cs is = let forward = map2 f cs is let backward os = let ss = map2 (\f c -> -2*c*f) forward cs let is'= map4 (\c s o i -> scaleV3 (c*s*o) i) cs ss os is let gs = map3 (\s o i -> s*o*v3dot i i) ss os is in (is', gs) in (forward, backward) let scale f = map (*f) let add = map2 (+) let zero = replicate n 0 let sum k (ps:[k][n]f32) = ps |> transpose |> map f32.sum in {parameter, propagation, scale, add, zero, sum} def particlePairs [n] (pairs:[n](v3, v3)) : network ([n](v3,v3)) (v3, m33) ([n]v3) = let parameter = pairs let propagation ps (p, R) = let forward = map (\(u, v) -> p `v3add` mvMult R v `v3sub` u) ps let backward os = let i = let pg = sumV3s os let Rg = map2 (\(_, v) o -> v3outer o v) ps os |> sumM33s -- check outer!!! 
in (pg, Rg) let g = let R' = m33transpose R in map (\o -> (v3negate o, mvMult R' o)) os in (i, g) in (forward, backward) let scale f = map (\(u, v) -> (scaleV3 f u, scaleV3 f v)) let add = map2 (\(u, v) (u', v') -> (v3add u u', v3add v v')) let zero = replicate n ((0,0,0),(0,0,0)) let sum k (ps:[k][n](v3, v3)) = ps |> transpose |> map (reduce_comm (\(u,v) (u',v') -> (v3add u u',v3add v v')) ((0,0,0),(0,0,0))) in {parameter, propagation, scale, add, zero, sum} def atanMap n = statelessMap n (f32.atan) (\x e -> e/(x*x+1)) def testNet [m] (fs1:[m]f32) = sumInv m <> scale fs1 <> sum type^ interactionNet [n] 'p = {net: network p (v3, m33) f32, pairs: p -> [n](v3,v3)} type networkParameter [m][n] = ([m](v3, v3),((((((([m]f32),[n][m]f32),[n]()), [n][n]f32), [n]()),[n]f32),())) def interactionNet [m][n] (pairs:[m](v3,v3)) (cs:[m]f32) (ws1:[n][m]f32) (ws2:[n][n]f32) (fs:[n]f32) = let net = particlePairs pairs <> ( smoothInvQuadMap cs <> linear ws1 <> atanMap n <> linear ws2 <> atanMap n <> scale fs <> sum ) let pairs (x, _) = x in {net, pairs} type coordinate = (v3, quaternion) type sample = (coordinate, f32) type^ iNet [m][n] = interactionNet [m] (networkParameter [m][n]) def fromParameter [m][n] (pairs, cs, ws1, ws2, fs) (parameter:networkParameter [m][n]) : iNet [m][n] = let inet = interactionNet pairs cs ws1 ws2 fs let net = inet.net let net'= { parameter , add=net.add , scale=net.scale , sum=net.sum , propagation=net.propagation , zero=net.zero } in {net=net', pairs=inet.pairs} def main [m][n] lr mf batchSize steps stuff (netParameter:networkParameter[m][n]) (momentum:networkParameter[m][n]) (samples:[]sample) (gen:PRNG) = let inet = fromParameter stuff netParameter let samples' = map (\((p, o), r) -> ((p, m33fromQuaternion o), r)) samples let (_, momentum', _) = gradientDescent lr mf batchSize steps (\a b -> b - a) inet.net samples' momentum gen in momentum' 
futhark-0.25.27/tests/issue1173.fut000066400000000000000000000002741475065116200166730ustar00rootroot00000000000000-- == -- input { 10i64 } -- output { 130i32 } def main m = let f [n] m' v: ([m']i32, (os: [n]i32) -> i32) = (replicate m' (v+i32.i64 n), i32.sum) let (x, g) = f m 3 in g x futhark-0.25.27/tests/issue1174.fut000066400000000000000000000003271475065116200166730ustar00rootroot00000000000000-- == -- input { 0 } -- output { 1i64 2i64 } def delaylength [x] (arr: [x]i32) (y: i64) = length arr def main x = let f = delaylength [x] let g = delaylength [x,x] let (f', g') = id (f, g) in (f' 1, g' 2) futhark-0.25.27/tests/issue1177.fut000066400000000000000000000004661475065116200167020ustar00rootroot00000000000000-- We messed up because this loop is existential at *two* levels (the -- outer dimension and the nested dimension). Although the mess was -- in the let-binding, and due to how we constructed patterns for -- statements. def main n = loop acc = [([1], 1)] for i < n do replicate i (replicate (n-i) 1, i) futhark-0.25.27/tests/issue1192.fut000066400000000000000000000003461475065116200166740ustar00rootroot00000000000000-- == -- input { [1f32,2f32,3f32] } -- output { [[1.0f32, 2.0f32, 3.0f32], [3.0f32, 1.0f32, 2.0f32], [2.0f32, 3.0f32, 1.0f32]] } def main [n] (Irow1: [n]f32) = let In: [n][n]f32 = map (\i -> rotate (-i) Irow1) (iota n) in In futhark-0.25.27/tests/issue1194.fut000066400000000000000000000001701475065116200166710ustar00rootroot00000000000000-- == -- input { 8i64 } output { 511i64 } def main (n: i64) = loop i = 0 for d in n..n-1...0 do i + (1 << n - d) futhark-0.25.27/tests/issue1203.fut000066400000000000000000000004171475065116200166640ustar00rootroot00000000000000def main [n] (xss: [][n]i64) (x: i64) = map (\xs -> #[sequential] loop (xs, x) for _i < n do let res = opaque (scan (*) 0 xs) let tmp = map2 (-) res xs let x' = reduce (+) x tmp in (res, x')) xss 
futhark-0.25.27/tests/issue1209.fut000066400000000000000000000002741475065116200166730ustar00rootroot00000000000000type ObjectGeom = #Wall f64 | #Block ([3]f64) type Object = #PassiveObject ObjectGeom | #Light ([3]f64) def main (_: i32) : Object = #PassiveObject (#Block [ 1.0, -1.6, 1.2]) futhark-0.25.27/tests/issue1213.fut000066400000000000000000000003061475065116200166620ustar00rootroot00000000000000def foo x = #[sequential] let n = 2i64 let ys = replicate n 0 let ys[0] = x let bar = all (== 0i64) ys let baz = all (\i -> ys[i] == 0) (0.. sized (w-1) (row[:seam_idx] ++ row[seam_idx + 1:])) image minimum_seam def helper [h][w] (n: i64) (image: [h][w]u32) : [h][w]u32 = loop image = copy image for i < n do let w' = n - i in image with [:, :n-i-1] = remove_minimum_seam <| copy image[:, :w'] def main [m][h][w] (n: i64) (images: *[m][h][w]u32): [m][h][]u32 = let w' = w - n let res = #[incremental_flattening(only_intra)] map (helper n) images in res[:,:,:w'] futhark-0.25.27/tests/issue1225.fut000066400000000000000000000004321475065116200166650ustar00rootroot00000000000000def main [h][w] (pic: [h][w][3]u8) idxArr : [][][3]u8 = map2 (\(row: [][3]u8) i -> if i == 0 then sized (w-1) row[1:] else if i == w-1 then sized (w-1) row[:w-1] else sized (w-1) (concat (row[0:i]) (row[i+1:]))) pic idxArr futhark-0.25.27/tests/issue1231.fut000066400000000000000000000001321475065116200166570ustar00rootroot00000000000000-- == -- error: importe_function import "include_basic" def main x = importe_function x futhark-0.25.27/tests/issue1232.fut000066400000000000000000000004601475065116200166640ustar00rootroot00000000000000type option 'a = #some a | #none def bind 'a 'b (m: option a) (f: a -> option b): option b = match m case #none -> #none case #some a -> f a entry foo (n: i32): bool = match bind (if n == 0 then #some () else #none) (\() -> #some true) case #some res -> res case #none -> true 
futhark-0.25.27/tests/issue1237.fut000066400000000000000000000076401475065116200167000ustar00rootroot00000000000000-- Something about interchange and certificates. -- -- Adapted from nw.fut. def B0: i64 = 64 def fInd (B: i64) (y:i32) (x:i32): i32 = y*(i32.i64 B+1) + x def max3 (x:i32, y:i32, z:i32) = if x < y then if y < z then z else y else if x < z then z else x def mkVal [l2][l] (B: i64) (y:i32) (x:i32) (pen:i32) (inp_l:[l2]i32) (ref_l:[l][l]i32) : i32 = #[unsafe] max3( ( (inp_l[fInd B (y-1) (x-1)])) + ( ref_l[y-1, x-1]) , ( (inp_l[fInd B y (x-1)])) - pen , ( (inp_l[fInd B (y-1) x])) - pen ) def intraBlockPar [len] (B: i64) (penalty: i32) (inputsets: [len*len]i32) (reference2: [len][len]i32) (b_y: i64) (b_x: i64) : [B][B]i32 = let ref_l = reference2[b_y * B + 1: b_y * B + 1 + B, b_x * B + 1: b_x * B + 1 + B] :> [B][B]i32 let inputsets' = unflatten inputsets let inp_l' = (copy inputsets'[b_y * B : b_y * B + B + 1, b_x * B : b_x * B + B + 1]) :> *[B+1][B+1]i32 -- inp_l is the working memory let inp_l = replicate ((B+1)*(B+1)) 0i32 |> unflatten -- Initialize inp_l with the already processed the column to the left of this -- block let inp_l[0:B+1, 0] = inputsets'[b_y * B : b_y * B + B + 1, b_x * B] -- Initialize inp_l with the already processed the row to above this block let inp_l[0, 1:B+1] = inputsets'[b_y * B, b_x * B + 1 : b_x * B + B + 1] let inp_l = assert (inp_l' == inp_l) (flatten inp_l) -- Process the second half (anti-diagonally) of the block let inp_l = loop inp_l for m < B-1 do let m = B - 2 - m let (inds, vals) = unzip ( -- tabulate over the m'th anti-diagonal after the middle tabulate B (\tx -> ( if tx > m then (-1, 0) else let ind_x = i32.i64 (tx + B - m) let ind_y = i32.i64 (B - tx) let v = mkVal B ind_y ind_x penalty inp_l ref_l in (i64.i32 (fInd B ind_y ind_x), v) ) )) in scatter inp_l inds vals let inp_l2 = unflatten inp_l in inp_l2[1:B+1,1:B+1] :> [B][B]i32 def updateBlocks [q][lensq] (B: i64) (len: i32) (blk: i64) (mk_b_y: (i32 -> i32)) (mk_b_x: 
(i32 -> i32)) (block_inp: [q][B][B]i32) (inputsets: *[lensq]i32) = let (inds, vals) = unzip ( tabulate (blk*B*B) (\gid -> let B2 = i32.i64 (B*B) let gid = i32.i64 gid let (bx, lid2) = (gid / B2, gid % B2) let (ty, tx) = (lid2 / i32.i64 B, lid2 % i32.i64 B) let b_y = mk_b_y bx let b_x = mk_b_x bx let v = #[unsafe] block_inp[bx, ty, tx] let ind = (i32.i64 B*b_y + 1 + ty) * len + (i32.i64 B*b_x + tx + 1) in (i64.i32 ind, v))) in scatter inputsets inds vals def main [lensq] (penalty : i32) (inputsets : *[lensq]i32) (reference : *[lensq]i32) : *[lensq]i32 = let len = i64.f32 (f32.sqrt (f32.i64 lensq)) let inputsets = inputsets :> [len*len]i32 let reference = reference :> [len*len]i32 let worksize = len - 1 let B = i64.min worksize B0 let B = assert (worksize % B == 0) B let block_width = trace <| worksize / B let reference2 = unflatten reference let inputsets = loop inputsets for blk < block_width do let blk = blk + 1 let block_inp = tabulate blk (\b_x -> let b_y = blk-1-b_x in intraBlockPar B penalty inputsets reference2 b_y b_x ) let mkBY bx = i32.i64 (blk - 1) - bx let mkBX bx = bx in updateBlocks B (i32.i64 len) blk mkBY mkBX block_inp inputsets in inputsets :> [lensq]i32 futhark-0.25.27/tests/issue1239.fut000066400000000000000000000003511475065116200166720ustar00rootroot00000000000000def n = 2i64 def grid (i: i64): [n][n]i64 = let grid = unflatten (0..<(n * n)) in if i == 0 then unflatten (scatter (flatten grid) (0..> 18u64) ^ oldstate) >> 27u64) let rot = u32.u64 (oldstate >> 59u64) in ({state, inc}, (xorshifted >> rot) | (xorshifted << ((-rot) & 31u32))) def rng_from_seed (xs: []i32) = let initseq = 0xda3e39cb94b95bdbu64 -- Should expose this somehow. 
let state = 0u64 let inc = (initseq << 1u64) | 1u64 let {state, inc} = (rand {state, inc}).0 let state = loop state for x in xs do state + u64.i32 x in (rand {state, inc}).0 def dummy_rng (): rng = rng_from_seed [0] type tup = (i64, i64) def foo [n] (grid: *[n]tup): *[n]tup = grid def bar [n] (grid: *[n]tup): *[n]tup = grid with [0] = (1, 1) def foo_bar [n] (grid_foo: *[n]tup) (grid_bar: *[n]tup): (*[n]tup, *[n]tup) = (foo grid_foo, bar grid_bar) def create_tup (_: rng): tup = (0, 0) def dummy_grid (n: i64): [n]tup = replicate n ((create_tup (dummy_rng ()))) entry foo_bar_bar [n] (grid0: *[n]i64) (grid1: *[n]i64): ([n]i64, [n]i64) = unzip ((foo_bar (zip grid0 grid1) (dummy_grid n)).1) futhark-0.25.27/tests/issue1242.fut000066400000000000000000000001751475065116200166700ustar00rootroot00000000000000def hof [n] (f: i64 -> i64) (irf: [n]f32) (b: bool) = n def main [n][m] (irf: [n]f32) (bs: [m]bool) = map (hof id irf) bs futhark-0.25.27/tests/issue1243.fut000066400000000000000000000007741475065116200166760ustar00rootroot00000000000000def DistMatrix [m] (b : [m]f32) : [m][m]f32 = let initial = replicate m b let outside = (replicate (m+1) f32.inf) with [0] = 0 in (loop (D, column) = (initial, outside) for i < m do let next_row = loop cs = replicate (m+1) f32.inf for j in 1...m do cs with [j] = f32.minimum [cs[j-1], column[j]] in (D with [i] = (next_row[1:] :> [m]f32), next_row)) |> \(x,_) -> x def main [d] [n] (s : [d][n]f32) as = map (\a -> #[sequential] DistMatrix s[a]) as futhark-0.25.27/tests/issue1250.fut000066400000000000000000000005611475065116200166660ustar00rootroot00000000000000-- == -- input { } module type mt = { type t val to_i64 : t -> i64 } module i8mt = { type t = i8 def to_i64 = i8.to_i64 } module type a = { module b: mt module c: mt } module a_impl = { module b = i8mt module c = i8mt } module use_a (d: a) = { def b_to_i64 (b: d.b.t) = d.b.to_i64 b } module f = use_a a_impl def main = f.b_to_i64 10 
futhark-0.25.27/tests/issue1262.fut000066400000000000000000000007261475065116200166740ustar00rootroot00000000000000def reduce_by_index_stream [k] 'a 'b (dest: *[k]a) (f: *[k]a -> b -> *[k]a) : *[k]a = dest def pairInteraction [n] 'a 'b (ne: b) (add: b-> b -> b) (potential: a -> a -> b) (coordinates: [n]a) = let interaction [k] (acc: *[k]b) (i: i64, j: i64) : *([k]b) = let cI = coordinates[i] let cJ = coordinates[j] let v = potential cI cJ in acc in reduce_by_index_stream (replicate n ne) interaction futhark-0.25.27/tests/issue1268.fut000066400000000000000000000003631475065116200166770ustar00rootroot00000000000000type~ csc_mat = { col_offsets: []i64 , row_idxs: []i32 } def low (d: csc_mat) (j: i64): i64 = 0 entry foo (m: csc_mat): csc_mat = let n = length m.col_offsets - 1 let m' = copy m let lows' = map (\j -> low m' j) (iota n) in m' futhark-0.25.27/tests/issue1284.fut000066400000000000000000000006261475065116200166770ustar00rootroot00000000000000def one_scatter (n: i64) (m: i64) : [n][m]i32 = let res = tabulate_2d n m (\i j -> 0) in scatter_2d res [(0, 0)] [1] entry foo = one_scatter 2 2 def another_scatter [n] (inp: *[n][n]i32): *[n][n]i32 = scatter_2d inp [(0, 1)] [2] entry bar = another_scatter (one_scatter 2 2) -- == -- entry: foo -- input {} output { [[1, 0], [0, 0]] } -- == -- entry: bar -- input {} output { [[1, 2], [0, 0]] } futhark-0.25.27/tests/issue1291.fut000066400000000000000000000006041475065116200166710ustar00rootroot00000000000000def main [n] (is : [n]i64) (ys_bar: *[n]f32) = let scatter_res_adj_gather = map(\ is_elem -> ys_bar[is_elem] ) is let zeros = replicate n 0.0f32 let map_res_bar = scatter ys_bar is zeros let map_adjs_1 = map (\ lam_adj -> 5.0f32 * lam_adj ) scatter_res_adj_gather let map_adjs_2 = map (\ lam_adj -> 3.0f32 * lam_adj ) map_res_bar in (map_adjs_1, map_adjs_2) futhark-0.25.27/tests/issue1292.fut000066400000000000000000000004161475065116200166730ustar00rootroot00000000000000-- == -- input { -- [1i64,2i64,3i64,4i64,5i64] -- 
[1.0,2.0,3.0,4.0,5.0] -- [1.0,2.0,3.0,4.0,5.0] -- } -- auto output def main [n] (is: [n]i64) (vs: [n]f64) (xs: [n]f64) : [n]f64 = let xs' = map2 (*) vs xs let vs' = map2 (*) vs xs' let ys = scatter xs' is vs' in ys futhark-0.25.27/tests/issue1294.fut000066400000000000000000000002041475065116200166700ustar00rootroot00000000000000-- == -- error: array element def apply 'a '~b (f: i64 -> a -> b) (x: a) = [f 0 x, f 1 x] def main (x: i32) = apply replicate x futhark-0.25.27/tests/issue1296.fut000066400000000000000000000002541475065116200166770ustar00rootroot00000000000000-- == -- structure { Concat 0 Replicate 1 } def main(n: i64) = let xs = replicate n 0 let ys = replicate n 1 let xs' = map (+ 1) xs let zs = concat xs' ys in zs futhark-0.25.27/tests/issue1302.fut000066400000000000000000000000771475065116200166660ustar00rootroot00000000000000def arr = replicate 10 true def main x = copy arr with [0] = x futhark-0.25.27/tests/issue1310.fut000066400000000000000000000011631475065116200166620ustar00rootroot00000000000000-- This is a fragile test, in that it depends on subtle interchange to -- even produce the code that sequentialisation chokes on. def dotprod [n] (xs: [n]f64) (ys: [n]f64) = f64.sum (map2 (*) xs ys) def identity (n: i64): [n][n]f64 = tabulate_2d n n (\i j ->if j == i then 1 else 0) def back_substitution [n] (U: [n][n]f64) (y: [n]f64): [n]f64 = let x = replicate n 0 in loop x for j in 0.. 
map i32.bool |> scan (+) 0 |> map (+ -1) let is = data |> map i64.i32 let f = scatter (replicate n (-1i32)) is d let g = f |> map (!= -1) |> reduce (&&) true let new_data = scatter (copy data) is d in (g, new_data) def resolve_vars [n] (a: [n]bool) = let valid = reduce (&&) true a let data = map i32.bool a in (valid, data) entry oef [n] (a: [n]bool) = let (b, data) = resolve_vars a let (c, data) = resolve_fns a data in (b && c, data) futhark-0.25.27/tests/issue1325.fut000066400000000000000000000006351475065116200166730ustar00rootroot00000000000000-- == -- random input { [100]bool [100]i64 [10][10]i32 [10][10]i32 } auto output -- compiled random input { [1000]bool [1000]i64 [10][10]i32 [10][10]i32 } auto output -- structure gpu-mem { If/False/Replicate 0 If/True/Replicate 0 } let main [n][k] (bs: []bool) (is: []i64) (xs: [n][k]i32) (ys: [n][k]i32) = #[unsafe] map2 (\b i -> let j = i%n in if b then xs[j] else ys[j]) bs is futhark-0.25.27/tests/issue1326.fut000066400000000000000000000000621475065116200166660ustar00rootroot00000000000000def f n = iota (n + 1) entry a = f 10 entry b = a futhark-0.25.27/tests/issue1328.fut000066400000000000000000000012751475065116200166770ustar00rootroot00000000000000-- == -- compiled random input { [200][10]f32 [200][10]f32 [10]f32 } auto output def main [n1][n2][m] (X1: [n1][m]f32) (X2: [n2][m]f32) (Y: [m]f32): [][]f32 = let res = map ( \x1 -> ( map (\x2 -> #[sequential] let Y0 = Y let Y1 = map2 (\x y -> y + 3 * x) x1 Y let a = reduce (+) 0 x2 let Y2 = map2 (\x y -> y + a * x) x2 Y1 in [ f32.i64 0, f32.i64 1, ( Y2[0]),( Y2[1]),( Y2[2]), ( Y2[3]),( Y2[4]),( Y2[5]), ( Y2[6]),( Y2[7]),( Y2[8]), ( Y2[9]) ] ) X2 ) ) X1 in (flatten res) futhark-0.25.27/tests/issue1332.fut000066400000000000000000000000631475065116200166640ustar00rootroot00000000000000def main (n: i64) (xs: []i32) = (.[::n - 1]) xs futhark-0.25.27/tests/issue1333.fut000066400000000000000000000003411475065116200166640ustar00rootroot00000000000000-- == -- input {true} output { 
1i64 } module type mt = { type arr [n] val mk : bool -> arr [] } module m : mt = { type arr [n] = [n]bool def mk b = [b] } def main b = let f [n] (_: m.arr [n]) = n in f (m.mk b) futhark-0.25.27/tests/issue1345.fut000066400000000000000000000000331475065116200166650ustar00rootroot00000000000000entry sqrt32 (x: bool) = x futhark-0.25.27/tests/issue1350.fut000066400000000000000000000032431475065116200166670ustar00rootroot00000000000000def matmul A B = map (\a -> map (\b -> f64.sum (map2 (*) a b)) (transpose B)) A def identity n = tabulate_2d n n (\i j -> f64.bool(i == j)) def relatives_to_absolutes [n] (relatives: [][4][4]f64) (parents: [n]i64) : [n][4][4]f64 = loop absolutes = replicate n (identity 4) for (relative, parent, i) in zip3 relatives parents (iota n) do if parent == -1 then absolutes with [i] = relative else absolutes with [i] = copy (absolutes[parent] `matmul` relative) def euler_angles_to_rotation_matrix (xzy: [3]f64) : [4][4]f64 = let tx = xzy[0] let ty = xzy[2] let tz = xzy[1] let costx = f64.cos(tx) let sintx = f64.sin(tx) let costy = f64.cos(ty) let sinty = f64.sin(ty) let costz = f64.cos(tz) let sintz = f64.sin(tz) in [[costy * costz, -costx * sintz + sintx * sinty * costz, sintx * sintz + costx * sinty * costz, 0], [costy * sintz, costx * costz + sintx * sinty * sintz, -sintx * costz + costx * sinty * sintz, 0], [-sinty, sintx * costy, costx * costy, 0], [0, 0, 0, 1]] def get_posed_relatives (num_bones: i64) (base_relatives: [][][]f64) (pose_params: [][3]f64) = let offset = 3 let f i = matmul base_relatives[i] (euler_angles_to_rotation_matrix pose_params[i+offset]) in tabulate num_bones f entry get_skinned_vertex_positions (num_bones: i64) (base_relatives: [][][]f64) (parents: []i64) (pose_params: [][3]f64) = let relatives = get_posed_relatives num_bones base_relatives pose_params in relatives_to_absolutes relatives parents futhark-0.25.27/tests/issue1358.fut000066400000000000000000000001711475065116200166740ustar00rootroot00000000000000-- == def 
main [n] b (xs: *[n]i32) = let vals = map (+2) (if b then reverse xs else xs) in scatter xs (iota n) vals futhark-0.25.27/tests/issue1366.fut000066400000000000000000000002601475065116200166720ustar00rootroot00000000000000#[noinline] def update [n] (xs: [n]i32): [n]i32 = map (+ 1) xs def main [n][m] (xss: *[m][n]i32) = #[unsafe] loop xss for i < m do copy xss with [i] = update xss[i] futhark-0.25.27/tests/issue1384.fut000066400000000000000000000026211475065116200166750ustar00rootroot00000000000000def loess_proc [n_m] (q: i64) (m_fun: i64 -> i64) (max_dist: [n_m]f32) : ([n_m]f32, [n_m]f32) = (max_dist, max_dist) def loess_l [m] [n] [n_m] (xx_l: [m][n]f32) (yy_l: [m][n]f32) (q: i64) (m_fun: i64 -> i64) (ww_l: [m][n]f32) (l_idx_l: [m][n_m]i64) (max_dist_l: [m][n_m]f32) (n_nn_l: [m]i64) : ([m][n_m]f32, [m][n_m]f32) = let loess_l_fun_fun (q_pad: i64) = map3 (\xx yy max_dist -> loess_proc q m_fun max_dist) xx_l yy_l max_dist_l |> unzip in if q < 12 then loess_l_fun_fun 11 else if q < 32 then loess_l_fun_fun 31 else if q < 64 then loess_l_fun_fun 63 else loess_l_fun_fun 4095 entry main [m] [n] [n_m] (xx_l: [m][n]f32) (yy_l: [m][n]f32) (ww_l: [m][n]f32) (l_idx_l: [m][n_m]i64) (max_dist_l: [m][n_m]f32) (n_nn_l: [m]i64) (q: i64) (jump: i64) : ([m][n_m]f32, [m][n_m]f32) = let m_fun (x: i64): i64 = 2 + i64.min (x * jump) (n - 1) in loess_l xx_l yy_l q m_fun ww_l l_idx_l max_dist_l n_nn_l futhark-0.25.27/tests/issue1424.fut000066400000000000000000000053401475065116200166710ustar00rootroot00000000000000-- == -- tags { no_opencl no_cuda no_hip no_pyopencl } type index = {x: i64, y: i64, z: i64} def E :f64 = 1 def Emin :f64 = 1e-6 def indexIsInside (nelx :i64, nely :i64, nelz :i64) (idx :index) :bool = (idx.x >= 0 && idx.y >= 0 && idx.z >= 0 && idx.x < nelx && idx.y < nely && idx.z < nelz) def isOnBoundary(nodeIndex :index, nx :i64, ny :i64, nz :i64) :bool = (nodeIndex.x == 0) type nodalWeights = (f64, f64) -- utility methods for doing averages for the element edge, surface, and 
center def sumEdge (a :f64, b :f64) = 0.5*(a+b) def sumSurf (a :f64, b :f64, c :f64, d :f64) = 0.25*(a+b+c+d) def sumCent (w :nodalWeights) = 0.125*(w.0+w.1) def prolongateCell (cellIndex :index, w :nodalWeights) = let i = cellIndex let c = sumCent w let sxp = sumSurf (w.1,w.1,w.1,w.1) -- Surface X positive let sxn = sumSurf (w.0,w.1,w.1,w.1) let syp = sumSurf (w.1,w.0,w.1,w.1) let syn = sumSurf (w.1,w.1,w.1,w.1) let szp = sumSurf (w.1,w.1,w.1,w.1) let szn = sumSurf (w.0,w.1,w.1,w.1) in [({x=(2*i.x+0),y=(2*i.y+1),z=(2*i.z+0)}, (w.0, sumEdge(w.0,w.1))), ({x=(2*i.x+1),y=(2*i.y+1),z=(2*i.z+0)}, (sumEdge(w.1,w.0), w.1))] def generateLoad (o :i64) (w :nodalWeights) :[24]f64 = scatter (replicate 24 0) [0+o,3+o,6+o,9+o,12+o,15+o,18+o,21+o] [w.0, w.1, w.1, w.1, w.0, w.0, w.0, w.1] def applyBoundaryConditionsToWeightsInverse (elementIndex :index, w :nodalWeights) (nx :i64, ny :i64, nz :i64) :nodalWeights = let ei = elementIndex let setIfNotIndex (v :f64) (nodeIndex :index) :f64 = if (isOnBoundary (nodeIndex,nx,ny,nz)) then (v/8) else 0 in (setIfNotIndex w.0 {x=ei.x+0,y=ei.y+1,z=ei.z+0}, setIfNotIndex w.1 {x=ei.x+1,y=ei.y+1,z=ei.z+0}) def getFineValue [nelx][nely][nelz] (x :[nelx][nely][nelz]f32) (o :i64) (cellIndex :index, w :nodalWeights) :[24]f64 = let w_onBoundary = applyBoundaryConditionsToWeightsInverse (cellIndex, w) ((nelx+1),(nely+1),(nelz+1)) let loadVectorOnBoundary = generateLoad o w_onBoundary in loadVectorOnBoundary def restrictCell (vals :[8][24]f64) :[24]f64 = vals |> transpose |> map (\x -> reduce (+) 0 x) def getDiagonalCellContribution [nelx][nely][nelz] (l :u8) (x :[nelx][nely][nelz]f32) (cellIndex, w) = let fineCells = loop vals = [(cellIndex,w)] for i < (i64.u8 l) do vals |> map prolongateCell |> flatten let fineValuesX = map (getFineValue x 0) fineCells let coarseValues = loop vx = fineValuesX for i < (i64.u8 l) do let ii = (i64.u8 l) - i - 1 let xx = (vx :> [(8**ii)*8][24]f64) |> unflatten |> map restrictCell in xx let coarseX = flatten coarseValues 
in sized 24 coarseX entry getNodeDiagonalValues [nelx][nely][nelz] (l :u8) (x :[nelx][nely][nelz]f32) input = map (getDiagonalCellContribution l x) input futhark-0.25.27/tests/issue1424_tiny.fut000066400000000000000000000004211475065116200177270ustar00rootroot00000000000000-- Small program related to #1424. -- == -- input { 2 [1,2,3] [42] } auto output -- structure gpu-mem {SegMap 1} def main n (xs: []i32) (unit: [1]i32) = #[sequential_inner] map (\x -> let arr = replicate 1 x in (loop arr for i < n do unit)) xs futhark-0.25.27/tests/issue1435.fut000066400000000000000000000033171475065116200166750ustar00rootroot00000000000000-- == def segmented_scan [n] 't (op: t -> t -> t) (ne: t) (flags: [n]bool) (as: [n]t): [n]t = (unzip (scan (\(x_flag,x) (y_flag,y) -> (x_flag || y_flag, if y_flag then y else x `op` y)) (false, ne) (zip flags as))).1 def replicated_iota [n] (reps:[n]i64) : []i64 = let s1 = scan (+) 0 reps let s2 = map2 (\i x -> if i==0 then 0 else x) (iota n) (rotate (-1) s1) let tmp = reduce_by_index (replicate (reduce (+) 0 reps) 0) i64.max 0 s2 (iota n) let flags = map (>0) tmp in segmented_scan (+) 0 flags tmp def segmented_iota [n] (flags:[n]bool) : [n]i64 = let iotas = segmented_scan (+) 0 flags (replicate n 1) in map (\x -> x-1) iotas def expand 'a 'b (sz: a -> i64) (get: a -> i64 -> b) (arr:[]a) : []b = let szs = map sz arr let idxs = replicated_iota szs let iotas = segmented_iota (map2 (!=) idxs (rotate (-1) idxs)) in map2 (\i j -> get arr[i] j) idxs iotas def sub xs (i:i64) = xs[i] def flatMap 'a 'b [m] (n: i64) (f: a -> [n]b) (xs: [m]a): *[]b = flatten(map f xs) def sized 't n (xs: []t) = xs :> [n]t entry listmults2 = let xss = [[1, 2, 3], [2, 3, 4]] in let yss = [[4, 5, 6], [5, 6, 7]] in map (\(xs, ys, x, y) -> x * y) (expand (\(xs, ys, x) -> length ys) (\(xs, ys, x) -> \y -> (xs, ys, x, sub ys y)) (expand (\(xs, ys) -> length xs) (\(xs, ys) -> \x -> (xs, ys, sub xs x)) (flatMap (length yss) (\xs -> sized (length yss) (map (\ys -> (xs, ys)) yss)) 
xss))) futhark-0.25.27/tests/issue1455.fut000066400000000000000000000002001475065116200166630ustar00rootroot00000000000000entry test [n] (xs: *[n]i32) : (*[]i32, *[]i32) = let a = [0] let cp = loop a for i < n do copy xs[i:i+1] in (cp, xs) futhark-0.25.27/tests/issue1457.fut000066400000000000000000000002621475065116200166750ustar00rootroot00000000000000def main [n][m] (xs: *[n][m]f64) = #[unsafe] let (a, b) = unzip (tabulate m (\i -> (f64.sqrt (f64.i64 i), f64.cos (f64.i64 i)))) let xs[0] = a let xs[1] = b in (a, xs) futhark-0.25.27/tests/issue1462.fut000066400000000000000000000006441475065116200166750ustar00rootroot00000000000000def step [n] (buf: [n]f32) (r: i64): []f32 = let total = reduce (+) 0f32 buf in buf ++ [total * (f32.i64 r)] def run [n] (t: i64) (T: i64) (buf: [n]f32) (r: i64): []f32 = (loop (buf) for i < t do step buf r) :> [T]f32 def runs (t: i64): [][]f32 = let start = [1f32, 1f32] in map (run t (t + 2) start) (iota 5) -- == -- tags { no_opencl no_cuda no_hip no_pyopencl } -- input { } def main = runs 10 futhark-0.25.27/tests/issue1476.fut000066400000000000000000000012471475065116200167020ustar00rootroot00000000000000-- == -- tags { no_opencl no_cuda no_hip no_pyopencl } let Lmx [nlat] (m:i64) (n:i64) (amm:f32) (cx:[nlat]f32) (x:[nlat]f32) = let X = replicate n 0 let m' = f32.i64 m let Sx p = map2 (*) p x |> reduce (+) 0f32 let p0 = map (\cx -> amm*(1 - cx*cx)**(m'/2)*(-1)**m') cx let p1 = map2 (\cx p0 -> cx*p0) cx p0 let p2 n p1 p0 = map3 (\cx p1 p0 -> cx*p1) cx p1 p0 let (X, pn, _) = loop (X,p1,p0) for i < (n-m-1) do let pi = p2 (m'+2+f32.i64 i) p1 p0 let X[m+2+i] = Sx pi in (X, pi, p1) in X let main (lmax:i64) amm cx gr = #[unsafe] let f x = tabulate lmax (\m -> Lmx m lmax amm[m] cx x) in vjp f cx gr futhark-0.25.27/tests/issue1478.fut000066400000000000000000000001351475065116200166770ustar00rootroot00000000000000-- == -- error: Entry points module mmain = { entry f x = x + 1 } entry f x = mmain.f x 
futhark-0.25.27/tests/issue1481.fut000066400000000000000000000124031475065116200166720ustar00rootroot00000000000000module type field = { module R: real type t val zero : t val +: t -> t -> t val *: R.t -> t -> t -- dummy function to generate new non-zero values of for t val tab3: i64 -> i64 -> i64 -> t } module mk_scalar_field (R: real) = { module R = R type t = R.t def zero = R.i64 0 def (+) (x:t) (y:t): t = R.(x + y) def (*) (a:R.t) (x:t): t = R.(a * x) def tab3 i j k: t = R.((i64 i)+(i64 j)+(i64 k)) } module mk_lt (F: field) = { module R = F.R def rm1 = R.i64 (-1) def r0 = R.i64 0 def r1 = R.i64 1 def r2 = R.i64 2 def r3 = R.i64 3 def r4 = R.i64 4 def len_q (N:i64): i64 = (N + 1) * (N + 2) // 2 def gen_ml (N:i64): [](i64,i64) = loop ml = [(0,0)] for i < ((len_q N)-1) do let (m,l) = ml[i] let nl = if N == l then m + 1 else l + 1 let nm = if N == l then m + 1 else m in ml ++ [(nm,nl)] def all_amm (N:i64): []R.t = iota N |> map (\i -> R.i64(i + 1)) |> map (\k -> R.((r2*k+r1)/(r2*k))) |> ([r1]++) |> scan (R.*) r1 |> map (\el -> R.(sqrt(el/(r4*pi)))) def amn (m:R.t) (n:R.t): R.t = R.(sqrt((r4*n*n - r1)/(n*n - m*m))) def bmn (m:R.t) (n:R.t): R.t = let l = R.((r2*n + r1)/(r2*n - r3)) let r = R.(((n - r1)*(n - r1) - m*m)/(n*n - m*m)) in R.(rm1*sqrt(l * r)) def lat_grid (nlat:i64): []R.t = iota nlat |> map R.i64 |> map (\x -> R.(cos (x / (i64 nlat) * pi))) -- m F.zero) let m' = R.i64 m let Sx p = map2 (F.*) p x |> reduce (F.+) F.zero -- P^m_m let p0 = map (\cx -> R.(amm*(r1 - cx*cx)**(m'/r2)*(rm1)**m')) cx let X[m] = Sx p0 -- P^m_(m + 1) let p1 = map2 (\cx p0 -> R.((amn m' (m' + r1))*cx*p0)) cx p0 let X[m + 1] = Sx p1 -- P^m_n -> P^m_n+1 -> P^m_n+2 let p2 n p1 p0 = map3 (\cx p1 p0 -> R.((amn m' n)*cx*p1 + (bmn m' n)*p0)) cx p1 p0 -- P^m_n let (X, pn, _) = match (n-m) case 0 -> (X, p0, p0) case 1 -> (X, p1, p0) case _ -> loop (X,p1,p0) for i < (n - m - 1) do let pi = p2 R.(m'+r2+i64 i) p1 p0 let X[m+2+i] = Sx pi in (X, pi, p1) in X -- n==m and m F.zero) let m' = R.i64 m 
let Sx p = map2 (F.*) p x |> reduce (F.+) F.zero -- P^m_m let p0 = map (\cx -> R.(amm*(r1 - cx*cx)**(m'/r2)*(rm1)**m')) cx let X[m] = Sx p0 in if (n-m)==0 then X else Lmx' m np1 amm cx x def iLmX' [nlat] (m:i64) (np1:i64) (amm:R.t) (cx:[nlat]R.t) (X:[np1]F.t): [nlat]F.t = let n = np1 - 1 let x = tabulate nlat (\i -> F.zero) let m' = R.i64 m -- at each m we do x += X[m]P^_n let SX m x p = map2 (\xi pi -> xi F.+ (pi F.* X[m])) x p -- P^m_m let p0:[nlat]R.t = map (\cx -> R.(amm*(r1 - cx*cx)**(m'/r2)*(rm1)**m')) cx let x[:] = SX m x p0 -- P^m_(m + 1) let p1 = map2 (\cx p0 -> R.((amn m' (m' + r1))*cx*p0)) cx p0 let x[:] = SX m x p1 -- P^m_n -> P^m_n+1 -> P^m_n+2 let p2 n p1 p0 = map3 (\cx p1 p0 -> R.((amn m' n)*cx*p1 + (bmn m' n)*p0)) cx p1 p0 -- P^m_n let (x, pn, _) = match (n-m) case 0 -> (x, p0, p0) case 1 -> (x, p1, p0) case _ -> loop (x,p1,p0) for i < (n - m - 1) do let pi = p2 R.(m'+r2+i64 i) p1 p0 let x[:] = SX (m+2+i) x pi in (x, pi, p1) in x def iLmX [nlat] (m:i64) (np1:i64) (amm:R.t) (cx:[nlat]R.t) (X:[np1]F.t): [nlat]F.t = let n = np1 - 1 let x = tabulate nlat (\i -> F.zero) let m' = R.i64 m -- at each m we do x += X[m]P^_n let SX m x p = map2 (\xi pi -> xi F.+ (pi F.* X[m])) x p -- P^m_m let p0:[nlat]R.t = map (\cx -> R.(amm*(r1 - cx*cx)**(m'/r2)*(rm1)**m')) cx let x[:] = SX m x p0 in if (n-m)==0 then x else iLmX' m np1 amm cx X def lt [np1][nlon][nlat] (amm:[np1]R.t) (cx:[nlat]R.t) (x:[nlon][nlat]F.t): [np1][np1]F.t = map2 (\m x -> Lmx m np1 amm[m] cx x) (iota np1) x[:np1] :> [np1][np1]F.t def ilt [np1][nlon][nlat] (amm:[np1]R.t) (cx:[nlat]R.t) (X:[np1][np1]F.t): [nlon][nlat]F.t = let out = tabulate_2d nlon nlat (\_ _ -> F.zero) let out[:np1] = map2 (\m x -> iLmX m np1 amm[m] cx x) (iota np1) X in out :> [nlon][nlat]F.t def bench (nxfm:i64) (lmax:i64) (nlat:i64) (nlon:i64): [nxfm][nlon][nlat]F.t = -- lmax > nlat let amm = all_amm lmax let x = tabulate_3d nxfm nlon nlat F.tab3 let cx = lat_grid nlat let X = map (lt amm cx) x let x' = map (ilt amm cx) X in x' 
} module lts = mk_lt (mk_scalar_field f32) -- == -- compiled input { 1i64 20i64 128i64 256i64 } -- compiled input { 8i64 20i64 128i64 256i64 } entry main (nxfm:i64) (lmax:i64) (nlat:i64) (nlon:i64) = lts.bench nxfm lmax nlat nlon futhark-0.25.27/tests/issue1492.fut000066400000000000000000000001401475065116200166670ustar00rootroot00000000000000-- == -- error: Unexpected "def" def f x = let x' = x + 2 def g x = let x' = x * 2 in x futhark-0.25.27/tests/issue1498.fut000066400000000000000000000033131475065116200167020ustar00rootroot00000000000000let pi = 3.141592653589793f32 let mpr_node [N][T] (i: i64) (n: i64) (dt: f32) (nstep: i64) (i0: i64) (r: [N][T]f32) (V: [N][T]f32) (weights: [N][N]f32) (idelays: [N][N]i64) (G: f32) (I: f32) (Delta: f32) (eta: f32) (tau: f32) (J: f32) = let dr r V = 1/tau * ( Delta / (pi * tau) + 2 * V * r) let dV r V r_c = 1/tau * ( V**2 - pi**2 * tau**2 * r**2 + eta + J * tau * r + I + r_c) let r_bound r = if r >= 0f32 then r else 0f32 let r_c = iota N |> map (\m -> weights[n,m] * r[m,i - idelays[n,m] - 1]) |> reduce (+) 0f32 let r_c = r_c * G let dr_0 = dr r[n,i-1] V[n,i-1] let dV_0 = dV r[n,i-1] V[n,i-1] r_c let r_int = r[n,i-1] + dt*dr_0 let V_int = V[n,i-1] + dt*dV_0 let r_int = r_bound r_int in (r_int, V_int) let mpr_integrate_seq [N] [T] (dt: f32) (nstep: i64) (i0: i64) (r: *[N][T]f32) (V: *[N][T]f32) (weights: [N][N]f32) (idelays: [N][N]i64) (G: f32) (I: f32) (Delta: f32) (eta: f32) (tau: f32) (J: f32): (*[N][T]f32, *[N][T]f32) = loop (r, V) for i_ < nstep do let n = 0 let i = i_ + i0 let (rn, Vn) = mpr_node i n dt nstep i0 r V weights idelays G I Delta eta tau J let r[n,i] = rn let V[n,i] = Vn in (r, V) let sweep [N] [T] (g: i64) (dt: f32) (nstep: i64) (i0: i64) (r: [N][T]f32) (V: [N][T]f32) (weights: [N][N]f32) (idelays: [N][N]i64) = let Gs = tabulate g (\i -> 0.0 + (f32.i64 i) * 0.1) let do_one_seq G = mpr_integrate_seq dt nstep i0 (copy r) (copy V) weights idelays G 0.0f32 0.7f32 (-4.6f32) 1.0f32 14.5f32 in map do_one_seq Gs let 
main (ng: i64) (nh: i64) (nt: i64) (nn: i64) idelays weights r V = let dt = 0.01f32 let f dt = sweep ng dt nt nh r V weights idelays let x = f dt in vjp f dt x futhark-0.25.27/tests/issue1499.fut000066400000000000000000000002301475065116200166760ustar00rootroot00000000000000def f (xs: *[][]bool) = #[unsafe] let a = xs[0] let b = copy a let xs[0,1] = true in (b[0], xs[0,0]) def main A = map (\xs -> f (copy xs)) A futhark-0.25.27/tests/issue1500.fut000066400000000000000000000011571475065116200166660ustar00rootroot00000000000000def secant_method (f:f32 -> f32) (a:f32) (b:f32) (tol:f32) = (loop (x_1, x_2, f_1, f_2) = (a, b, f a, f b) while (f32.abs (x_1-x_2) > tol) do let x_0 = x_1 - f_1 * (x_1 - x_2) / (f_1 - f_2) let f_0 = f x_0 in (x_0, x_1, f_0, f_1)).0 def minimum_D (stress:f32 -> f32) = let f D = stress D in secant_method f 0.01 0.1 def running (turning:bool) (R:f32) (v:f32) (a:f32) = let n = 20000i64 let zs = tabulate n (\i -> f32.from_fraction i n) let maximum_stress D = D+2 let D = minimum_D maximum_stress in (zs, D) def turning = running true 3 2 0 def main = turning.1 futhark-0.25.27/tests/issue1501.fut000066400000000000000000000007271475065116200166710ustar00rootroot00000000000000def effective_stress (d:f32) (_F_x, _F_y, F_z) = 0 def running (turning:bool) (R:f32) (v:f32) (a:f32) = let n = 20000 let zs = tabulate n (\i -> f32.from_fraction i n) let (Fs, Ms) = unzip (map (\_ -> ((0,0,0),(0,0,0))) zs) in (zs, unzip3 Fs, unzip3 Ms) def straight = running false 0 0 0 def zs_str = straight.0 def uczip3 (a,b,c) = zip3 a b c def stress Fs Ms = map2 effective_stress zs_str (uczip3 Fs) entry stress_str = stress (straight.1) (straight.2) futhark-0.25.27/tests/issue1505.fut000066400000000000000000000002211475065116200166620ustar00rootroot00000000000000def main n b = let xs = replicate n 0 let xs' = if b then loop xs = copy xs for i < 10 do xs with [i] = 0 else xs in if b then xs else xs' 
futhark-0.25.27/tests/issue1510.fut000066400000000000000000000006361475065116200166700ustar00rootroot00000000000000module type MyModuleType = { val dummy: i32 } module MyModuleOps (thisMod: MyModuleType) = { def test = (copy (1...10i32)) } module MyModule: MyModuleType = { def dummy = 0i32 } module MyModule_Ops = MyModuleOps MyModule module MyModule2: MyModuleType = { def dummy = 1i32 } module MyModule2_Ops = MyModuleOps MyModule2 entry testfn = if true then MyModule_Ops.test else MyModule2_Ops.test futhark-0.25.27/tests/issue1512.fut000066400000000000000000000003641475065116200166700ustar00rootroot00000000000000module type Size = { val n: i64 } module SizeOps (size: Size) = { def test (arr: []i64) = arr :> [size.n]i64 } module mySize = { def n = 4i64 } : Size module mySizeOps = SizeOps mySize entry main (arr: []i64) = mySizeOps.test arr futhark-0.25.27/tests/issue1523.fut000066400000000000000000000046461475065116200167010ustar00rootroot00000000000000def dotprod [n] (a: [n]f32) (b: [n]f32): f32 = map2 (*) a b |> reduce (+) 0 def lud_diagonal [b] (a: [b][b]f32): *[b][b]f32 = let mat = copy a in #[unsafe] loop (mat: *[b][b]f32) for i < b-1 do let col = map (\j -> if j > i then (mat[j,i] - (dotprod mat[j,:i] mat[:i,i])) / mat[i,i] else mat[j,i]) (iota b) let mat[:,i] = col let row = map (\j -> if j > i then mat[i+1, j] - (dotprod mat[:i+1, j] mat[i+1, :i+1]) else mat[i+1, j]) (iota b) let mat[i+1] = row in mat def lud_perimeter_upper [m][b] (diag: [b][b]f32) (a0s: [m][b][b]f32): *[m][b][b]f32 = let a1s = map (\ (x: [b][b]f32): [b][b]f32 -> transpose(x)) a0s in let a2s = map (\a1 -> map (\row0 -> -- Upper #[unsafe] loop row = copy row0 for i < b do let sum = loop sum=0.0f32 for k < i do sum + diag[i,k] * row[k] let row[i] = row[i] - sum in row ) a1 ) a1s in map transpose a2s def lud_perimeter_lower [b][m] (diag: [b][b]f32) (mat: [m][b][b]f32): *[m][b][b]f32 = map (\blk -> map (\row0 -> -- Lower #[unsafe] loop row = copy row0 for j < b do let sum = loop sum=0.0f32 for k 
< j do sum + diag[k,j] * row[k] let row[j] = (row[j] - sum) / diag[j,j] in row ) blk ) mat def block_size: i64 = 32 def main [num_blocks][n] (matb: *[num_blocks][num_blocks][n][n]f32) b = #[unsafe] let matb = loop matb for step < (n / b) - 1 do let diag = lud_diagonal matb[step,step] in let row_slice = matb[step,step+1:num_blocks] let top_per_irreg = lud_perimeter_upper diag row_slice let col_slice = matb[step+1:num_blocks,step] let lft_per_irreg = lud_perimeter_lower diag col_slice let inner_slice = matb[step+1:num_blocks,step+1:num_blocks] let matb[step, step] = diag let matb[step, step+1:num_blocks] = top_per_irreg in matb in matb futhark-0.25.27/tests/issue1524.fut000066400000000000000000000015451475065116200166750ustar00rootroot00000000000000-- == -- compiled random input { [2][2][32][32]f32 [32][32]f32 } auto output let dotprod [n] (a: [n]f32) (b: [n]f32): f32 = map2 (*) a b |> reduce (+) 0 let lud_perimeter_lower [b][m] (diag: [b][b]f32) (mat: [m][b][b]f32): *[m][b][b]f32 = map (\blk -> #[incremental_flattening(only_intra)] map (\row0 -> -- Lower #[unsafe] loop row = copy row0 for j < b do let sum = loop sum=0.0f32 for k < j do sum + diag[k,j] * row[k] let row[j] = (row[j] - sum) / diag[j,j] in row ) blk ) mat let main [num_blocks][b] (matb: *[num_blocks][num_blocks][b][b]f32) (diag: [b][b]f32) = let step = 0 let col_slice = matb[step+1:num_blocks,step] let lft_per_irreg = lud_perimeter_lower diag col_slice in lft_per_irreg futhark-0.25.27/tests/issue1525.fut000066400000000000000000000001231475065116200166650ustar00rootroot00000000000000def main s = [0,1] |> map (\c -> (c...(c+1)) |> map(\k -> s[k]) |> reduce (+) 0) futhark-0.25.27/tests/issue1531.fut000066400000000000000000000002131475065116200166620ustar00rootroot00000000000000type~ t0 = ?[n][m].([n]i64, [m]i64) type~ t1 = (t0, t0) def main: t1 = let a = (iota 1, iota 2) let b = (iota 3, iota 4) in (a, b) 
futhark-0.25.27/tests/issue1533.fut000066400000000000000000000044241475065116200166740ustar00rootroot00000000000000-- In OpenCL backend: subtle bug caused by interchange not respecting -- permutations. -- In multicore backend: invalid hoisting after double buffering. -- == -- compiled random input { [20][20][20]f64 } auto output type complex = {r:f64, i:f64} def complex r i : complex = {r, i} def real r = complex r 0 def imag i = complex 0 i def zero = complex 0 0 def conjC (a:complex) : complex = {r = a.r, i = -a.i} def addC (a:complex) (b:complex) : complex = {r=a.r+b.r, i=a.i+b.i} def subC (a:complex) (b:complex) : complex = {r=a.r-b.r, i=a.i-b.i} def mulC (a:complex) (b:complex) : complex = {r=a.r*b.r-a.i*b.i, i=a.r*b.i+a.i*b.r} type triad 't = (t, t, t) def triadMap 'a 'b (f:a->b) (A:triad a) : triad b = (f A.0, f A.1, f A.2) def triadMap2 'a 'b 'c (f:a->b->c) (A:triad a) (B:triad b): triad c = (f A.0 B.0, f A.1 B.1, f A.2 B.2) def triadFold 'a (f:a->a->a) (A:triad a) : a = f A.0 <| f A.1 A.2 type v3 = triad f64 def v3sum (v:v3) : f64 = triadFold (+) v def v3add (a:v3) (b:v3) : v3 = triadMap2 (+) a b def v3sub (a:v3) (b:v3) : v3 = triadMap2 (-) a b def v3mul (a:v3) (b:v3) : v3 = triadMap2 (*) a b def v3dot (a:v3) (b:v3) : f64 = v3mul a b |> v3sum def scaleV3 (f:f64) = triadMap (*f) def v3abs a = f64.sqrt (v3dot a a) def fromReal : (f64 -> complex) = real def fromReal1d = map fromReal def fromReal2d = map fromReal1d def fromReal3d = map fromReal2d def toReal : (complex -> f64) = (.r) def toReal1d = map toReal def toReal2d = map toReal1d def toReal3d = map toReal2d def gfft [n] (xs:[n]complex) : [n]complex = let startTheta = f64.pi * f64.from_fraction 2 n let ms = n >> 1 let iteration ((xs:[n]complex), e, theta0) = let modc = (1 << e) - 1 let xs' = tabulate n (\i -> let q = i & modc let p'= i >> e let p = p'>> 1 let ai = q + (p << e) let bi = ai + ms let a = xs[ai] let b = xs[bi] let theta = theta0 * f64.i64 p in mulC (complex (f64.cos theta) (-f64.sin theta)) 
(subC a b)) in (xs', e + 1, theta0 * 2) in (iterate 2 iteration (xs, 0, startTheta)).0 def fft3 [m][n][k] (A:[m][n][k]complex) = #[unsafe] #[incremental_flattening(only_inner)] tabulate_2d n k (\i j -> gfft A[:,i,j]) entry main grid = fromReal3d grid |> fft3 |> toReal3d futhark-0.25.27/tests/issue1535.fut000066400000000000000000000002261475065116200166720ustar00rootroot00000000000000type sumType = #some ([0]i32) | #none entry main = (\(x: sumType) -> match x case (#some y) -> id y case _ -> []) (#none: sumType) futhark-0.25.27/tests/issue1537.fut000066400000000000000000000004761475065116200167030ustar00rootroot00000000000000module type scenario = { val numbers: () -> []i64 val n_numbers: i64 } module scenario: scenario = { def numbers (): []i64 = [] def n_numbers: i64 = loop i = 0 while length (numbers ()) != 0 do i + 1 } entry main = loop s = scenario.numbers () for _i < scenario.n_numbers - 1 do s ++ [1] futhark-0.25.27/tests/issue1545.fut000066400000000000000000000004621475065116200166750ustar00rootroot00000000000000-- == -- input { [[1,2,3],[4,5,6]] } -- output { [[11i32, 0i32, 0i32], [14i32, 0i32, 0i32]] } let main [n] (xss: [][n]i32) = #[sequential_inner] map (\xs -> let xs' = loop xs = copy xs for _i < 10 do map (+1) xs let ys = replicate n 0 let ys[0] = xs'[0] in ys) xss futhark-0.25.27/tests/issue1546.fut000066400000000000000000000001671475065116200167000ustar00rootroot00000000000000let main (b: bool) (i: i64) (xs: []i64) = #[unsafe] if b then (iota xs[i+1],xs[0::2]) else (xs[0::2],iota xs[i+1]) futhark-0.25.27/tests/issue1552.fut000066400000000000000000000000721475065116200166700ustar00rootroot00000000000000type T 'a = i64 let f 'a (x: T a): T a = x let main = f futhark-0.25.27/tests/issue1553.fut000066400000000000000000000021101475065116200166640ustar00rootroot00000000000000let argmax arr = reduce_comm (\(a,i) (b,j) -> if a < b then (b,j) else if b < a then (a,i) else if j < i then (b,j) else (a,i)) (0f32, 0) (zip arr (indices arr)) let 
gaussian_elimination [n] [m] (A: [m][n]f32): [m][n]f32 = loop A for i < i64.min m n do let value j x = if j >= i then f32.abs x else -f32.inf let j = A[:,i] |> map2 value (indices A) |> argmax |> (.1) let f = (1-A[i,i]) / A[j,i] let irow = map2 (f32.fma f) A[j] A[i] in tabulate m (\j -> let f = A[j,i] * -1 in map2 (\x y -> if j == i then x else f32.fma f x y) irow A[j]) let createIdentity n = tabulate n (\i -> replicate n 0f32 with [i] = 1f32) let matrix_inverse [n] (A: [n][n]f32): [n][n]f32 = let I = createIdentity n let AI = map2 concat A I let A1I= gaussian_elimination AI let A1 = A1I[:,n:] :> [n][n]f32 in A1 let main [k] [n] (As: [k][n][n]f32) :[k][n][n]f32 = map (\A -> matrix_inverse A) As futhark-0.25.27/tests/issue1557.fut000066400000000000000000000001221475065116200166710ustar00rootroot00000000000000-- == -- error: Cannot unify def get (cond: bool): #a = if cond then #a else 0 futhark-0.25.27/tests/issue1559.fut000066400000000000000000000004371475065116200167040ustar00rootroot00000000000000-- == -- random input { [100][10]i32 } auto output let opt [n][m] (arr: [m][n]i32) = reduce (map2 (\(x,i) (y,j) -> if x < y then (y, j) else (x, i)) ) (replicate n (i32.lowest, 0)) (map2 (\r i -> zip r (replicate n i)) arr (iota m)) let main arr = let (vs, is) = unzip (opt arr) in is futhark-0.25.27/tests/issue1569.fut000066400000000000000000000042661475065116200167110ustar00rootroot00000000000000let dotprod [n] (a: [n]f32) (b: [n]f32): f32 = map2 (*) a b |> reduce (+) 0 let lud_diagonal [b] (a: [b][b]f32): *[b][b]f32 = let mat = copy a in loop (mat: *[b][b]f32) for i < b-1 do let col = map (\j -> if j > i then #[unsafe] (mat[j,i] - (dotprod mat[j,:i] mat[:i,i])) / mat[i,i] else mat[j,i]) (iota b) let mat[:,i] = col let row = map (\j -> if j > i then mat[i+1, j] - (dotprod mat[:i+1, j] mat[i+1, :i+1]) else mat[i+1, j]) (iota b) let mat[i+1] = row in mat let lud_perimeter_upper [m] (diag: [16][16]f32, a0s: [m][16][16]f32): *[m][16][16]f32 = let a1s = map (\ (x: 
[16][16]f32): [16][16]f32 -> transpose(x)) a0s in let a2s = map (\a1: [16][16]f32 -> map (\row0: [16]f32 -> -- Upper loop row = copy row0 for i < 16 do let sum = (loop sum=0.0f32 for k < i do sum + diag[i,k] * row[k]) let row[i] = row[i] - sum in row ) a1 ) a1s in map (\x: [16][16]f32 -> transpose(x)) a2s let lud_perimeter_lower [m] (diag: [16][16]f32, mat: [m][16][16]f32): *[m][16][16]f32 = map (\blk: [16][16]f32 -> map (\ (row0: [16]f32): *[16]f32 -> -- Lower loop row = copy row0 for j < 16 do let sum = loop sum=0.0f32 for k < j do sum + diag[k,j] * row[k] let row[j] = (row[j] - sum) / diag[j,j] in row ) blk ) mat let main [num_blocks] (matb: *[num_blocks][num_blocks][16][16]f32): [num_blocks][num_blocks][16][16]f32 = let step = 0 let diag = lud_diagonal(matb[step,step]) in let matb[step,step] = diag let row_slice = matb[step,step+1:num_blocks] let top_per_irreg = lud_perimeter_upper(diag, row_slice) let matb[step, step+1:num_blocks] = top_per_irreg let col_slice = matb[step+1:num_blocks,step] let lft_per_irreg = lud_perimeter_lower(diag, col_slice) let matb[step+1:num_blocks, step] = lft_per_irreg in matb futhark-0.25.27/tests/issue1572.fut000066400000000000000000000002211475065116200166660ustar00rootroot00000000000000let main xss = map (\(xs: []i32) -> loop xs = zip (copy xs) (copy xs) for i < 10 do xs with [0] = (xs[0].1 + 1,2)) xss futhark-0.25.27/tests/issue1579.fut000066400000000000000000000171441475065116200167110ustar00rootroot00000000000000-- Issue was memory expansion did not handle an allocation inside -- nested SegOps in a Group SegOp. And _this_ was because the size -- was not hoisted all the way out, because there was an If -- surrounding the inner SegThread SegOps. And _this_ was because we -- lost the information that the size was actually used as a size. -- Tiling produced such a SegOp. In most cases that version will then -- be discarded, but the attributes ensure that only one version is -- produced. 
module loess_m = { module T = f64 type t = T.t let filterPadWithKeys [n] 't (p : (t -> bool)) (dummy : t) (arr : [n]t) : ([n]t, [n]i64, i64) = let tfs = map (\a -> if p a then 1i64 else 0i64) arr let isT = scan (+) 0i64 tfs let i = last isT let inds= map2 (\a iT -> if p a then iT - 1 else -1i64) arr isT let rs = scatter (replicate n dummy) inds arr let ks = scatter (replicate n (-1i64)) inds (iota n) in (rs, ks, i) -------------------------------------------------------------------------------- -- Main LOESS procedure - outer parallel version, with extra work -- -------------------------------------------------------------------------------- let loess_outer [n] [n_m] (xx: [n]i64) (yy: [n]t) (q: i64) (ww: [n]t) (l_idx: [n_m]i64) (lambda: [n_m]t) (n_nn: i64) = let q_slice 'a (arr: [n]a) (l_idx_i: i64) (v: a) (add: a -> a -> a) (zero: a): [q]a = #[unsafe] tabulate q (\j -> if j >= n_nn then zero else add arr[l_idx_i + j] v) -- need the duplicate to prevent manifestation let q_slice' 'a (arr: [n]a) (l_idx_i: i64) (v: a) (add: a -> a -> a) (zero: a): [q]a = #[unsafe] tabulate q (\j -> if j >= n_nn then zero else add arr[l_idx_i + j] v) in -- [n_m] #[sequential_inner] map2 (\l_idx_i lambda_i -> ----------------------------------- -- REDOMAP 1 ----------------------------------- #[unsafe] let xx_slice = q_slice xx l_idx_i 1 (+) 0 let ww_slice = q_slice ww l_idx_i 0 (+) 0 let (w, xw, x2w) = map2 (\xx_j ww_j -> -- let x_j = (xx_j - m_fun i) |> T.i64 let x_j = xx_j |> T.i64 -- tricube let r = T.abs x_j let tmp1 = r / lambda_i let tmp2 = 1.0 - tmp1 * tmp1 * tmp1 let tmp3 = tmp2 * tmp2 * tmp2 -- scale by user-defined weights let w_j = tmp3 * ww_j let xw_j = x_j * w_j let x2w_j = x_j * xw_j in (w_j, xw_j, x2w_j) ) xx_slice ww_slice |> unzip3 -- then, compute fit and slope based on polynomial degree let a = T.sum w + T.epsilon let b = T.sum xw + T.epsilon let c = T.sum x2w + T.epsilon -- degree 1 let det1 = 1 / (a * c - b * b) let a11 = c * det1 let b11 = -b * det1 
----------------------------------- -- REDOMAP 2 ----------------------------------- let w' = map2 (\xx_j ww_j -> let x_j = xx_j |> T.i64 -- tricube let r = T.abs x_j let tmp1 = r / lambda_i let tmp2 = 1.0 - tmp1 * tmp1 * tmp1 let tmp3 = tmp2 * tmp2 * tmp2 -- scale by user-defined weights in tmp3 * ww_j ) xx_slice ww_slice -- then, compute fit and slope based on polynomial degree let yy_slice = q_slice' yy l_idx_i 0 (+) 0 in map3 (\w_j yy_j xw_j -> (w_j * a11 + xw_j * b11) * yy_j) w' yy_slice xw |> T.sum ) l_idx lambda let loess_outer_l [m] [n] [n_m] (xx_l: [m][n]i64) (yy_l: [m][n]t) (q: i64) (ww_l: [m][n]t) (l_idx_l: [m][n_m]i64) (lambda_l: [m][n_m]t) (n_nn_l: [m]i64) = #[incremental_flattening(no_intra)] map5 (\xx yy ww l_idx (lambda, n_nn) -> loess_outer xx yy q ww l_idx lambda n_nn ) xx_l yy_l ww_l l_idx_l (zip lambda_l n_nn_l) let l_indexes [N] (nn_idx: [N]i64) (m_fun: i64 -> i64) (n_m: i64) (q: i64) (n_nn: i64): [n_m]i64 = -- [n_m] tabulate n_m (\i -> let x = m_fun i -- use binary search to find the nearest idx let (init_idx, _) = -- O(log N) loop (low, high) = (0i64, N - 1) while low <= high do let mid = (low + high) / 2 let mid_id = nn_idx[mid] let mid_idx = if mid_id < 0 then i64.highest else mid_id in if mid_idx >= x then (low, mid - 1) else (mid + 1, high) let (idx, _, _) = -- find the neighbor interval, starting at init_idx loop (l_idx, r_idx, span) = (init_idx, init_idx, 1) while span < q do -- O(q) let l_cand = i64.max (l_idx - 1) 0 let r_cand = i64.min (r_idx + 1) (n_nn - 1) let l_dist = i64.abs (nn_idx[l_cand] - x) let r_dist = i64.abs (nn_idx[r_cand] - x) in if l_cand == l_idx then (l_idx, r_idx, q) -- leftmost found, return else if l_dist < r_dist || r_cand == r_idx then (l_cand, r_idx, span + 1) -- expand to the left else (l_idx, r_cand, span + 1) -- expand to the right let res_idx = i64.max (i64.min (n_nn - q) idx) 0 in res_idx ) let find_lambda [n_m] (y_idx: []i64) (l_idx: [n_m]i64) (m_fun: i64 -> i64) (q: i64) (n_nn: i64) : [n_m]t= map2 (\l i 
-> let mv = m_fun i let q' = i64.min q n_nn let r = l + q' - 1 let md_i = i64.max (i64.abs (y_idx[l] - mv)) (i64.abs (y_idx[r] - mv)) |> T.i64 in md_i + T.max (((T.i64 q) - (T.i64 n_nn)) / 2) 0 ) l_idx (iota n_m) let loess_params [N] (q: i64) (m_fun: i64 -> i64) (n_m: i64) (y_idx: [N]i64) (n_nn: i64) : ([n_m]i64, [n_m]t) = let y_idx_p1 = (y_idx |> map (+1)) let q3 = i64.min q N -- [n_m] let l_idx = l_indexes y_idx_p1 (m_fun >-> (+1)) n_m q3 n_nn let lambda = find_lambda y_idx l_idx m_fun q n_nn in (l_idx, lambda) } entry main [m] [n] (Y: [m][n]f64) (q: i64) (jump: i64) = -- set up parameters for the low-pass filter smoothing let n_m = if jump == 1 then n else n / jump + 1 let m_fun (x: i64): i64 = i64.min (x * jump) (n - 1) -- filter nans and pad non-nan indices let (nn_y_l, nn_idx_l, n_nn_l) = map (loess_m.filterPadWithKeys (\i -> !(f64.isnan i)) 0) Y |> unzip3 -- calculate invariant arrays for the low-pass filter smoothing let (l_idx_l, lambda_l) = map2 (\nn_idx n_nn -> loess_m.loess_params q m_fun n_m nn_idx n_nn ) nn_idx_l n_nn_l |> unzip let weights_l = replicate (m * n) 1f64 |> unflatten in loess_m.loess_outer_l nn_idx_l nn_y_l q weights_l l_idx_l lambda_l n_nn_l futhark-0.25.27/tests/issue1593.fut000066400000000000000000000021661475065116200167030ustar00rootroot00000000000000-- == -- structure gpu { Replicate 0 } let lud_perimeter_upper [m][b] (diag: [b][b]f32, a0s: [m][b][b]f32): *[m][b][b]f32 = let a1s = map (\ (x: [b][b]f32): [b][b]f32 -> transpose(x)) a0s in let a2s = map (\a1: [b][b]f32 -> map (\row0: [b]f32 -> -- Upper loop row = copy row0 for i < b do let sum = (loop sum=0.0f32 for k < i do sum + diag[i,k] * row[k]) let row[i] = row[i] - sum in row ) a1 ) a1s in map (\x: [b][b]f32 -> transpose(x)) a2s let main [num_blocks] (matb: *[num_blocks][num_blocks][32][32]f32): *[num_blocks][num_blocks][32][32]f32 = #[unsafe] let matb = loop(matb) for step < num_blocks - 1 do -- 1. compute the current diagonal block let diag = matb[step,step] -- 2. 
compute the top perimeter let row_slice = matb[step,step+1:num_blocks] let top_per_irreg = lud_perimeter_upper(diag, row_slice) -- 5. update matrix in place let matb[step, step+1:num_blocks] = top_per_irreg in matb in matb futhark-0.25.27/tests/issue1599.fut000066400000000000000000000000541475065116200167030ustar00rootroot00000000000000-- == -- error: Occurs let bad a f = f a f futhark-0.25.27/tests/issue1609.fut000066400000000000000000000010761475065116200167000ustar00rootroot00000000000000-- == -- input { [0i64,0i64] [[1f32,2f32],[3f32,4f32]] [[0f32,0f32]] } -- output { [[3f32,4f32]] [[1i64,1i64]] } let argmax (x: f32, i: i64) (y: f32, j: i64) = if x == y then (x, i64.max i j) else if x > y then (x,i) else (y,j) let main [n][m][k] (is: [n]i64) (vs: [n][k]f32) (dst: [m][k]f32) = let dst_cpy = copy dst let res = reduce_by_index (map2 zip dst_cpy (replicate m (replicate k (-1)))) (map2 argmax) (replicate k (f32.lowest, -1)) is (map2 zip vs (map (replicate k) (iota n))) in unzip (map unzip res) futhark-0.25.27/tests/issue1610.fut000066400000000000000000000153611475065116200166720ustar00rootroot00000000000000-- Specific number of SegMaps in this test is not so important - we -- just shouldn't get rid of the versions. 
-- == -- structure gpu-mem { If/True/SegMap 3 If/False/If/True/SegMap 2 If/False/If/False/If/True/SegMap 1 } module type rand ={ type rng val init : i32 -> rng val rand : rng -> (rng, i32) val split : rng -> (rng,rng) val split_n : (n: i64) -> rng -> [n]rng } -- random module taken from Futharks webpage module lcg : rand = { type rng = u32 def addConst : u32 = 1103515245 def multConst : u32 = 12345 def modConst : u32 = 1<<31 def rand rng = let rng' = (addConst * rng + multConst) % modConst in (rng', i32.u32 rng') def init (x: i32) : u32 = let x = u32.i32 x let x =((x >> 16) ^ x) * 0x45d9f3b let x =((x >> 16) ^ x) * 0x45d9f3b let x =((x >> 16) ^ x) in x def split (rng: rng) = (init (i32.u32 (rand rng).0), init (i32.u32 rng)) def split_n n rng = tabulate n (\i -> init (i32.u32 rng ^ i32.i64 i)) } -- This function swaps the two edges that produce the lowest cost let swap [m] (i : i32) (j : i32) (tour : [m]i32) : [m]i32 = let minI = i+1 in map i32.i64 (iota m) |> map(\ind -> if ind < minI || ind > j then tour[ind] else tour[j - (ind - minI)] ) def rand_nonZero (rng: lcg.rng) (bound: i32) = let (rng,x) = lcg.rand rng in if (x % bound) > 0 then (rng, x % bound) else (rng, (x % bound) + 1) let mkRandomTour (offset:i64) (cities:i32) : []i32 = let rng = lcg.init (i32.i64 offset) --let randIndArr = map (\i -> -- if i == 0 then rand_i32 rng cities -- else rand_i32 (i-1).0 cities ) iota cities let initTour = map (\i -> if i == (i64.i32 cities) then 0 else i) (iota ((i64.i32 cities)+1)) |> map i32.i64 let intialI = rand_nonZero rng (cities) let intialJ = rand_nonZero intialI.0 (cities) let randomSwaps = rand_nonZero intialJ.0 100 let rs = loop (intialI,intialJ,initTour) for i < randomSwaps.1 do let intI = rand_nonZero intialI.0 (cities) let intJ = rand_nonZero intI.0 (cities) let swappedTour = swap intialI.1 intialJ.1 initTour in (intI, intJ, swappedTour) in rs.2 -- mkFlagArray is taken from PMPH lecture notes p. 
48 let mkFlagArray 't [m] (aoa_shp: [m]i32) (zero: t) (aoa_val: [m]t) : []t = let shp_rot = map (\i -> if i == 0 then 0 else aoa_shp[i-1] ) (iota m) let shp_scn = scan (+) 0 shp_rot let aoa_len = shp_scn[m-1]+ aoa_shp[m-1] |> i64.i32 let shp_ind = map2 (\shp ind -> if shp == 0 then -1 else ind ) aoa_shp shp_scn in scatter (replicate aoa_len zero) (map i64.i32 shp_ind) aoa_val -- segmented_scan is taken from PMPH Futhark code let segmented_scan [n] 't (op: t -> t -> t) (ne: t) (flags: [n]bool) (arr: [n]t) : [n]t = let (_, res) = unzip <| scan (\(x_flag,x) (y_flag,y) -> let fl = x_flag || y_flag let vl = if y_flag then y else op x y in (fl, vl) ) (false, ne) (zip flags arr) in res -- Comparator function used in forLoops in the reduce part let changeComparator (t1 : (i32, i32, i32)) (t2: (i32, i32, i32)) : (i32, i32, i32) = if t1.0 < t2.0 then t1 else if t1.0 == t2.0 then if t1.1 < t2.1 then t1 else if t1.1 == t2.1 then if t1.2 < t2.2 then t1 else t2 else t2 else t2 -- finds the best cost of two input costs let costComparator (cost1: i32) (cost2 :( i32)) : i32 = if cost1 < cost2 then cost1 else cost2 -- findMinChange is the parallel implementation of the two for loops -- in the 2-opt move algorithm let findMinChange [m] [n] [x] [y] (dist : [m]i32) (tour : [n]i32) (Iarr : [x]i32) (Jarr : [y]i32) (cities : i32) (totIter : i64) : (i32, i32, i32) = let changeArr = map (\ind -> let i = Iarr[ind] let iCity = tour[i] let iCityp1 = tour[i+1] let j = Jarr[ind] + i + 2 let jCity = tour[j] let jCityp1 = tour[j+1] in ((dist[iCity * cities + jCity] + dist[iCityp1 * cities + jCityp1] - (dist[iCity * cities + iCityp1] + dist[jCity * cities + jCityp1])), i, j) ) (iota totIter) in reduce changeComparator (2147483647, -1, -1) changeArr -- 2-opt algorithm let twoOptAlg [m] [n] [x] [y] (distM : [m]i32) (tour : [n]i32) (Iarr : [x]i32) (Jarr : [y]i32) (cities : i32) (totIter : i64) : []i32 = let twoOpt xs = let minChange = findMinChange distM xs Iarr Jarr cities totIter in if minChange.0 < 
0 then (swap minChange.1 minChange.2 xs, minChange.0) else (xs, minChange.0) let rs = loop (xs, cond) = (tour, -1) while cond < 0 do twoOpt xs in rs.0 -- Compute cost of tour let cost [n] [m] (tour : [n]i32) (distM : [m]i32) : i32 = map (\i -> distM[tour[i] * i32.i64(n-1) + tour[i+1]] ) (iota (n-1)) |> reduce (+) 0 -- Generate flagArray let flagArrayGen (cities : i32) : ([]i32)= let len = i64.i32 (cities-2) let aoa_val = replicate len 1i32 let temp = map (+1) (iota len) |> map i32.i64 let aoa_shp = map (\i -> temp[len - i - 1] ) (iota len) in mkFlagArray aoa_shp 0i32 aoa_val --[m] let main [m] (cities : i32) (numRestarts : i64) (distM : [m]i32) : i32 = --let cities = 5 let totIter = ((cities-1)*(cities-2))/2 |> i64.i32 --let initTour = (iota (i64.i32 cities+1)) |> --map(\i -> (i+1)*10) --let oldCost = cost tour distM --let distM = [0,4,6,8,3, -- 4,0,4,5,2, -- 6,4,0,2,3, -- 8,5,2,0,4, -- 3,2,3,4,0] let flagArr = flagArrayGen cities let Iarr = scan (+) 0i32 flagArr |> map (\x -> x-1) let Jarr = segmented_scan (+) 0i32 (map bool.i32 (flagArr :> [totIter]i32)) (replicate totIter 1i32) |> map (\x -> x-1) let allCosts = map(\ind -> let tour = mkRandomTour ((ind+1)*13) cities let minTour = twoOptAlg distM tour Iarr Jarr cities totIter in cost minTour distM )(iota numRestarts) in reduce costComparator 2147483647 allCosts --in mkRandomTour 125 cities futhark-0.25.27/tests/issue1615.fut000066400000000000000000000001541475065116200166710ustar00rootroot00000000000000-- == -- structure { Update 1 } def main (A: *[3]i32) : *[3]i32 = let x = (id A)[2] in A with [1] = x futhark-0.25.27/tests/issue1627.fut000066400000000000000000000052361475065116200167020ustar00rootroot00000000000000type isoCoordinates = {xi: f64, eta: f64, zeta: f64} type index = {x: i64, y: i64, z: i64} def pt :f64 = 1f64/f64.sqrt(3f64) def quadpoints :[8]isoCoordinates = [{xi=pt,eta=pt,zeta=pt},{xi=(-pt),eta=pt,zeta=pt}, {xi=pt,eta=(-pt),zeta=pt},{xi=(-pt),eta=(-pt),zeta=pt}, 
{xi=pt,eta=pt,zeta=(-pt)},{xi=(-pt),eta=pt,zeta=(-pt)}, {xi=pt,eta=(-pt),zeta=(-pt)},{xi=(-pt),eta=(-pt),zeta=(-pt)}] def matmul_f64 [n][m][p] (A: [n][m]f64) (B: [m][p]f64) :[n][p]f64 = map (\A_row -> map (\B_col -> f64.sum (map2 (*) A_row B_col)) (transpose B)) A def Cmat :[6][6]f64 = [[0.35e2 / 0.26e2,0.15e2 / 0.26e2,0.15e2 / 0.26e2,0,0,0],[0.15e2 / 0.26e2,0.35e2 / 0.26e2,0.15e2 / 0.26e2,0,0,0],[0.15e2 / 0.26e2,0.15e2 / 0.26e2,0.35e2 / 0.26e2,0,0,0],[0,0,0,0.5e1 / 0.13e2,0,0],[0,0,0,0,0.5e1 / 0.13e2,0],[0,0,0,0,0,0.5e1 / 0.13e2]] def getB0 (iso :isoCoordinates) :[6][24]f64 = let t1 = 1 - iso.zeta let t2 = 1 + iso.zeta let t3 = 0.1e1 / 0.4e1 - iso.eta / 4 let t4 = t3 * t1 let t5 = t3 * t2 let t6 = 0.1e1 / 0.4e1 + iso.eta / 4 let t7 = t6 * t1 let t8 = t6 * t2 let t9 = 1 - iso.xi let t10 = 1 + iso.xi let t11 = t9 / 4 let t12 = t11 * t2 let t13 = t10 / 4 let t14 = t13 * t1 let t2 = t13 * t2 let t1 = t11 * t1 let t11 = t3 * t10 let t10 = t6 * t10 let t3 = t3 * t9 let t6 = t6 * t9 in [[-t4,0,0,t4,0,0,t7,0,0,-t7,0,0,-t5,0,0,t5,0,0,t8,0,0,-t8,0,0], [0,-t1,0,0,-t14,0,0,t14,0,0,t1,0,0,-t12,0,0,-t2,0,0,t2,0,0,t12,0], [0,0,-t3,0,0,-t11,0,0,-t10,0,0,-t6,0,0,t3,0,0,t11,0,0,t10,0,0,t6], [-t1,-t4,0,-t14,t4,0,t14,t7,0,t1,-t7,0,-t12,-t5,0,-t2,t5,0,t2,t8,0,t12,-t8,0], [0,-t3,-t1,0,-t11,-t14,0,-t10,t14,0,-t6,t1,0,t3,-t12,0,t11,-t2,0,t10,t2,0,t6,t12], [-t3,0,-t4,-t11,0,t4,-t10,0,t7,-t6,0,-t7,t3,0,-t5,t11,0,t5,t10,0,t8,t6,0,-t8]] def getQuadraturePointStiffnessMatrix (youngsModule :f64) (iso :isoCoordinates) :[24][24]f64 = let B0 = getB0 iso let C = map (map (*youngsModule)) (copy Cmat) let inter_geom = matmul_f64 (transpose B0) C in matmul_f64 inter_geom B0 def assembleElementNonlinearStiffnessMatrix (youngsModule :f64) :[24][24]f64 = map (getQuadraturePointStiffnessMatrix youngsModule) quadpoints |> transpose |> map transpose |> map (map f64.sum) def getElementStiffnessDiagonal [nelx][nely][nelz] (x :[nelx][nely][nelz]f32) (elementIndex :index) = let xloc = f64.f32 
(x[elementIndex.x,elementIndex.y,elementIndex.z]) let kt = assembleElementNonlinearStiffnessMatrix xloc in kt[8,8] entry main [nelx][nely][nelz] (x :[nelx][nely][nelz]f32) = tabulate_3d nelx nely nelz (\i j k -> getElementStiffnessDiagonal x {x=i,y=j,z=k}) -- == -- entry: main -- input { [[[1f32, 1f32]]] } auto output futhark-0.25.27/tests/issue1628.fut000066400000000000000000000001361475065116200166750ustar00rootroot00000000000000def main b (xs: *[2][3]i64) = let v = if b then iota 3 else copy xs[1] in xs with [0] = v futhark-0.25.27/tests/issue1631.fut000066400000000000000000000002221475065116200166630ustar00rootroot00000000000000-- Problem was that this loop was mistakenly turned into a stream. let main = let go _ = loop _ = [] for _ in [0,1] do [0i32] in (go 0, go 1) futhark-0.25.27/tests/issue1653.fut000066400000000000000000000030311475065116200166700ustar00rootroot00000000000000module type mat = { type t type~ mat [n][m] val eye : (n:i64) -> (m:i64) -> mat[n][m] val dense [n][m] : mat[n][m] -> [n][m]t } module type sparse = { type t type~ csr [n][m] type~ csc [n][m] module csr : { include mat with t = t with mat [n][m] = csr[n][m] } module csc : { include mat with t = t with mat [n][m] = csc[n][m] } } module sparse (T : numeric) : sparse with t = T.t = { type t = T.t module csr = { type t = t type~ mat [n][m] = ?[nnz]. 
{dummy_m : [m](), row_off : [n]i64, col_idx : [nnz]i64, vals : [nnz]t} def eye (n:i64) (m:i64) : mat[n][m] = let e = i64.min n m let one = T.i64 1 let row_off = (map (+1) (iota e) ++ replicate (i64.max 0 (n-e)) e) :> [n]i64 in {dummy_m = replicate m (), row_off=row_off, col_idx=iota e, vals=replicate e one } def dense [n][m] (_csr: mat[n][m]) : [n][m]t = let arr : *[n][m]t = tabulate_2d n m (\ _ _ -> T.i64 0) in arr } module csc = { type t = t def eye (n:i64) (m:i64) : csr.mat[m][n] = csr.eye m n def dense [n][m] (mat: csr.mat[n][m]) : [m][n]t = csr.dense mat |> transpose type~ mat[n][m] = csr.mat[m][n] } type~ csr[n][m] = csr.mat[n][m] type~ csc[n][m] = csc.mat[n][m] } module spa = sparse i32 module csr = spa.csr def main (n:i64) (m:i64) : *[n][m]i32 = csr.eye n m |> csr.dense futhark-0.25.27/tests/issue1669.fut000066400000000000000000000047611475065116200167120ustar00rootroot00000000000000def map_4d 'a 'x [n][m][l][k] (f: a -> x) (as: [n][m][l][k]a): [n][m][l][k]x = map (map (map (map f))) as def map2_4d 'a 'b 'x [n][m][l][k] (f: a -> b -> x) (as: [n][m][l][k]a) (bs: [n][m][l][k]b): [n][m][l][k]x = map2 (map2 (map2 (map2 f))) as bs -- Bookkeeping def uoffsets: ([27]i64,[27]i64,[27]i64) = ([-1i64, -1i64, -1i64, -1i64, -1i64, -1i64, -1i64, -1i64, -1i64, 0i64, 0i64, 0i64, 0i64, 0i64, 0i64, 0i64, 0i64, 0i64, 1i64, 1i64, 1i64, 1i64, 1i64, 1i64, 1i64, 1i64, 1i64], [-1i64, -1i64, -1i64, 0i64, 0i64, 0i64, 1i64, 1i64, 1i64, -1i64, -1i64, -1i64, 0i64, 0i64, 0i64, 1i64, 1i64, 1i64, -1i64, -1i64, -1i64, 0i64, 0i64, 0i64, 1i64, 1i64, 1i64], [-1i64, 0i64, 1i64, -1i64, 0i64, 1i64, -1i64, 0i64, 1i64, -1i64, 0i64, 1i64, -1i64, 0i64, 1i64, -1i64, 0i64, 1i64, -1i64, 0i64, 1i64, -1i64, 0i64, 1i64, -1i64, 0i64, 1i64]) type index = {x: i64, y: i64, z: i64} def indexIsInside (nelx :i64, nely :i64, nelz :i64) (idx :index) :bool = (idx.x >= 0 && idx.y >= 0 && idx.z >= 0 && idx.x < nelx && idx.y < nely && idx.z < nelz) def getLocalU [nx][ny][nz] (u :[nx][ny][nz][3]f64) (nodeIndex :index) 
:*[27*3]f64 = let (nodeOffsetsX,nodeOffsetsY,nodeOffsetsZ) = uoffsets let ni = nodeIndex in (map3 (\i j k -> if (indexIsInside (nx,ny,nz) {x=ni.x+i,y=ni.y+j,z=ni.z+k}) then #[unsafe] (u[ni.x+i,ni.y+j,ni.z+k]) else [0,0,0]) nodeOffsetsX nodeOffsetsY nodeOffsetsZ) |> flatten -- SOR SWEEP FILE def omega :f64 = 0.6 def sorNodeAssembled (mat :[3][27*3]f64) (uStencil :*[27*3]f64) (f :[3]f64) :*[3]f64 = -- extract value of own node, and zero let ux_old = copy uStencil[39] let uy_old = copy uStencil[40] let uz_old = copy uStencil[41] let uStencil = uStencil with [39] = 0 let uStencil = uStencil with [40] = 0 let uStencil = uStencil with [41] = 0 let S = map (\row -> (f64.sum (map2 (*) uStencil row))) mat let M = mat[:3,39:42] let rx = M[0,1]*uy_old + M[0,2]*uz_old let ux_new = (1/M[0,0]) * (f[0]-S[0]-rx) let ry = M[1,0]*ux_new + M[1,2]*uz_old let uy_new = (1/M[1,1]) * (f[1]-S[1]-ry) let rz = M[2,0]*ux_new + M[2,1]*uy_new let uz_new = (1/M[2,2]) * (f[2]-S[2]-rz) let uold = [ux_old, uy_old, uz_old] let unew = [ux_new, uy_new, uz_new] in map2 (\un uo -> omega*un + (1-omega)*uo) unew uold entry sorSweepAssembled [nx][ny][nz] (mat :[nx][ny][nz][3][27*3]f64) (f :[nx][ny][nz][3]f64) (u :[nx][ny][nz][3]f64) = tabulate_3d nx ny nz (\i j k -> let uloc = getLocalU u {x=i,y=j,z=k} in #[unsafe] sorNodeAssembled mat[i,j,k] uloc f[i,j,k]) futhark-0.25.27/tests/issue1685.fut000066400000000000000000000007121475065116200167000ustar00rootroot00000000000000def max(x: i32, y: i32): i32 = if x > y then x else y def min(x: i32, y:i32): i32 = if x < y then x else y def mapOp (x: i32): (i32,i32,i32) = ( max(x,0), max(x,0), x) def redOp(y: (i32,i32,i32)) (z: (i32,i32,i32)): (i32,i32,i32) = let (x0, m0,s0) = y let (x1, _, s1) = z let s2 = s0+s1 in ( max(x0,max(x1,s2-m0)) , min(s2,m0) , s2) entry main(xs: []i32): i32 = let (x, _, _) = reduce redOp (0,0,0) (map mapOp xs) in x futhark-0.25.27/tests/issue1700.fut000066400000000000000000000001561475065116200166660ustar00rootroot00000000000000-- == -- 
input { empty([0][2]i32) } -- output { empty([0]i32) } def main (xs: [][]i32) = (transpose xs)[0] futhark-0.25.27/tests/issue1711.fut000066400000000000000000000001711475065116200166650ustar00rootroot00000000000000def test = let x = [1, 2, 3, 4, 5] in let y = x |> map (\x -> x + x) with [0] = 10 in y |> i64.sum |> (+ 10) futhark-0.25.27/tests/issue1739.fut000066400000000000000000000005551475065116200167050ustar00rootroot00000000000000entry main [n] [m] (xs: [3][3]f32) (as: *[n][m]f32) = let f i = let v = filter (> 0) as[i] ++ as[i] in map (\x -> map (\y -> #[sequential] map (* f32.sum y * f32.sum x) v |> f32.sum) xs) xs |> map (map (+(f32.i64 i))) let bs = #[incremental_flattening(no_intra)] tabulate n f in bs futhark-0.25.27/tests/issue1741.fut000066400000000000000000000007571475065116200167020ustar00rootroot00000000000000-- == -- error: type Op.t = i64 type L0 = i64 type L1 = [1]i64 module type Const = { type t val x: t } module L0: Const = { type t = L0 def x : t = 0 } module L1: Const = { type t = L1 def x : t = [1i64] } module type Op = (X: Const) -> Const module type Container = { module Op: Op } module L0Op: Container = { module Op (X: Const with t = L0): Const = { type t = L0 def x = X.x + 1 } } module L2 = L0Op.Op(L1) entry main = L2.x futhark-0.25.27/tests/issue1744.fut000066400000000000000000000004461475065116200167000ustar00rootroot00000000000000-- == -- tags { no_opencl no_cuda no_hip no_pyopencl } entry main (xs: [3][3]f32) : [3]f32 = map (\x -> let ss = map (map2 (*) x) xs let h = 0 let (_,h) = loop (ss, h) for _ in iota 3 do (tail ss, f32.sum (flatten ss)) in h) xs futhark-0.25.27/tests/issue1749.fut000066400000000000000000000002251475065116200167000ustar00rootroot00000000000000-- == -- error: Consuming.*"empty" module Test = { let empty = [] } let consume (arr: *[]i64) = arr entry main (_: bool) = consume Test.empty futhark-0.25.27/tests/issue1753.fut000066400000000000000000000001241475065116200166710ustar00rootroot00000000000000def main (xss: 
[][]i32) vs = map (\xs -> map (\v -> copy xs with [0] = v) vs) xss futhark-0.25.27/tests/issue1755.fut000066400000000000000000000037761475065116200167130ustar00rootroot00000000000000-- == -- input{} auto output type turn = bool module game = { type position = {score:i32, seed:i64} type move = #move i32 | #nomove let pos_size = 2i64 let move_size = 2i64 def legal_moves (position: position) turn : []move = let minmax = if turn then 1 else -1 let move_num = 2 + position.seed%2 let nums = map (\r -> (r+1,i32.i64 (r%10))) (map (^position.seed) (iota move_num)) in map (\(_,x) -> #move (minmax * x)) nums def make_move (position: position) (move: move) : position = match move case #nomove -> position case #move move -> {score = position.score + move, seed = position.seed + 100 + i64.i32 move} def neutral_position : position = {score = 0, seed = 0} def add_position (a: position) (b: position) : position = {score = a.score + b.score, seed = a.seed + b.seed} } def bf_ab (position: game.position) (turn: turn) (depth: i64) = let (poss, _t, _shapes, _SS) = loop (poss, t, shapes, SS) = ([position], turn, [], []) for _ in 0.. 
length (game.legal_moves p t)) poss let Sscan = scan (+) 0 S let total_moves = last Sscan let B = rotate (-1) Sscan with [0] = 0 let scatter_poss pos offset = #[sequential] let moves = game.legal_moves pos t let results = map (game.make_move pos) moves let is = map (+offset) <| indices results let empty = replicate total_moves game.neutral_position in scatter empty is results let chunks = map2 scatter_poss poss B let ne = replicate total_moves game.neutral_position let new_poss = reduce (map2 game.add_position) ne chunks in (new_poss, not t, S ++ shapes, [length S] ++ SS) in (map (.score) poss, map (.seed) poss) def brute_force (position: game.position) (turn: turn) (depth: i64) = bf_ab position turn depth def bf_eval (depth: i64) (turn: turn) (position: game.position) = brute_force position turn depth def main = bf_eval 2 true {score=0,seed=306} futhark-0.25.27/tests/issue1757.fut000066400000000000000000000001561475065116200167020ustar00rootroot00000000000000-- == -- input { -5 } error: y > 0 def main (x: i32) = let y = x + 2 let z = assert (y>0) (x+2) in y+z futhark-0.25.27/tests/issue1758.fut000066400000000000000000000002161475065116200167000ustar00rootroot00000000000000-- == -- input { 10i64 } -- error: issue1758.fut:8 let main (i: i64): [i]i64 = let a = iota i let b = replicate i 0 in a with [:4] = b futhark-0.25.27/tests/issue1774.fut000066400000000000000000000005241475065116200167000ustar00rootroot00000000000000-- == -- compiled random input { [3]f32 [100]f32 [10]f32 } auto output -- structure mc-mem { Alloc 3 } let f [n] [m] [l] (xs: [n]f32) (ys: [m]f32) (zs: *[l]f32) : [n][m]f32 = map (\x -> map (\y -> f32.sum (map (*y) (map (*x) zs))) ys) xs entry main [n] [m] [l] : [n]f32 -> [m]f32 -> *[l]f32 -> [n][m]f32 = f futhark-0.25.27/tests/issue1780.fut000066400000000000000000000003221475065116200166710ustar00rootroot00000000000000def (>->>) '^a '^b '^c (f: a -> b) (g: b -> c) (x: a): c = g (f x) def (<-<<) '^a '^b '^c (g: b -> c) (f: a -> b) (x: a): c = g (f x) 
def compose2 = ((>->) (<-<) <-<) (>->) entry main = compose2 (+) (* 2i32) futhark-0.25.27/tests/issue1783.fut000066400000000000000000000003211475065116200166730ustar00rootroot00000000000000-- == -- error: cannot match type surface = #asphere {curvature: f64} | #sphere {curvature: f64} entry sag (surf: surface) : f64 = match surf case #asphere -> 1 case #sphere -> 2 futhark-0.25.27/tests/issue1787.fut000066400000000000000000000002451475065116200167040ustar00rootroot00000000000000-- == -- error: found to be functional entry main: i32 -> i32 -> i32 = ((true, (.0)), (false, (.1))) |> (\p -> if p.0.0 then p.0 else p.1) |> (.1) |> curry futhark-0.25.27/tests/issue1791.fut000066400000000000000000000002441475065116200166760ustar00rootroot00000000000000-- == -- structure gpu { SegRed 1 } entry main [n] [d] (as: [n][d]f32) : [3]f32 = let bs = replicate n 0 let f i = f32.sum as[i] + 1 + bs[i] in tabulate 3 f futhark-0.25.27/tests/issue1794.fut000066400000000000000000000027561475065116200167130ustar00rootroot00000000000000let sgmscan 't [n] (op: t->t->t) (ne: t) (flg : [n]i64) (arr : [n]t) : [n]t = let flgs_vals = scan ( \ (f1, x1) (f2,x2) -> let f = f1 | f2 in if f2 != 0 then (f, x2) else (f, op x1 x2) ) (0,ne) (zip flg arr) let (_, vals) = unzip flgs_vals in vals let mkFlagArray 't [m] (aoa_shp: [m]i64) (zero: t) (aoa_val: [m]t ) : []t = let shp_rot = map (\i->if i==0 then 0 else aoa_shp[i-1]) (iota m) let shp_scn = scan (+) 0 shp_rot let aoa_len = shp_scn[m-1]+aoa_shp[m-1] let shp_ind = map2 (\shp ind -> if shp==0 then -1 else ind) aoa_shp shp_scn in scatter (replicate aoa_len zero) shp_ind aoa_val let partition2 [n] 't (conds: [n]bool) (dummy: t) (arr: [n]t) : (i64, [n]t) = let tflgs = map (\ c -> if c then 1 else 0) conds let fflgs = map (\ b -> 1 - b) tflgs let indsT = scan (+) 0 tflgs let tmp = scan (+) 0 fflgs let lst = if n > 0 then indsT[n-1] else -1 let indsF = map (+lst) tmp let inds = map3 (\ c indT indF -> if c then indT-1 else indF-1) conds indsT indsF let 
fltarr= scatter (replicate n dummy) inds arr in (lst, fltarr) def main [m] (bs: [m]bool) (S1_xss: [m]i64) = let (spl, iinds) = partition2 bs 0 (iota m) let F = mkFlagArray S1_xss 0 (map (+1) iinds) let II1_xss = sgmscan (+) 0 F F |> map (\x->x-1) let mask_xss = map (\sgmind -> bs[sgmind]) II1_xss in (spl, iinds,mask_xss) futhark-0.25.27/tests/issue1806.fut000066400000000000000000000001241475065116200166700ustar00rootroot00000000000000type record = {ctx: i64, foo': i64} entry main(r: record): record = r with ctx = 1 futhark-0.25.27/tests/issue1808.fut000066400000000000000000000002201475065116200166670ustar00rootroot00000000000000-- == -- input { [1,2,3] } -- output { [1,2,3,1,2,3] } def main (xs: []i32) = let [m] ys : [m]i32 = xs ++ xs in map (\i -> ys[i]) (iota m) futhark-0.25.27/tests/issue1816.fut000066400000000000000000000000771475065116200167000ustar00rootroot00000000000000type rec3 = {x: [2]i32, y: i32} let mod3 (r: *rec3): *rec3 = r futhark-0.25.27/tests/issue1824.fut000066400000000000000000000000741475065116200166740ustar00rootroot00000000000000def x !! y = x && y def x += y = x + y def x =+ y = x + y futhark-0.25.27/tests/issue1837.fut000066400000000000000000000005141475065116200166770ustar00rootroot00000000000000-- This crashed the compiler because the size slice produced by -- ExpandAllocations had a consumption inside of it. 
-- == -- structure gpu-mem { SegMap 2 SegRed 1 } entry main [n] (xs: [n]i64) = tabulate n (\_ -> let xs = scatter (copy xs) xs xs let xs = xs with [opaque 0] = opaque n in spread xs[0] 0 xs xs :> [n]i64) futhark-0.25.27/tests/issue1838.fut000066400000000000000000000004151475065116200167000ustar00rootroot00000000000000-- == -- structure gpu-mem { SegMap 5 SegRed 2 } entry main [n] (xs: [n]i64) = tabulate n (\_ -> let xs = loop xs = copy xs for i < 3 do scatter (copy xs) (scan (+) 0 xs) xs let xs = xs with [opaque 0] = n in spread xs[0] 0 xs xs :> [n]i64) futhark-0.25.27/tests/issue1841.fut000066400000000000000000000002441475065116200166720ustar00rootroot00000000000000-- == -- entry: f -- input { 1f32 } output { 1f32 } -- == -- entry: f' -- input { 1f32 } output { 2f32 } entry f (x: f32) : f32 = x entry f' (x: f32) : f32 = x+1 futhark-0.25.27/tests/issue1843.fut000077500000000000000000000014451475065116200167030ustar00rootroot00000000000000let N_coating_coefs = i64.i32 8 let N_coating_specs = i64.i32 26 type coating = #AntiReflective | #Mirror | #Absorbing | #PhaseGradient_IdealFocus {lam: f64, f: f64} | #PhaseGradient_RotSym_Dispersive {lam_d: f64, phi0: [N_coating_coefs]f64, GD: [N_coating_coefs]f64, GDD: [N_coating_coefs]f64} def parse_coatings (coating_list_array: [][N_coating_specs]f64): []coating = let parse (specs: [N_coating_specs]f64) : coating = let enum = i64.f64 specs[0] in match enum case 0 -> -- AntiReflective #AntiReflective case _ -> #AntiReflective in map parse coating_list_array entry RayTrace (coating_list_array: [][N_coating_specs]f64) : []coating = let coatings = parse_coatings coating_list_array in coatings futhark-0.25.27/tests/issue1847.fut000066400000000000000000000022321475065116200166770ustar00rootroot00000000000000type comp = f64 type Layer [nx][ny] = #l_inhomogeneous {epsilon: [nx][ny]comp, mu: [nx][ny]comp, thickness: f64} | #l_homogeneous {epsilon: comp, mu: comp, thickness: f64} type ConvLayer [NM] = #inhomogeneous {Kz: [NM][NM]comp} | 
#homogeneous {Kz: [NM][NM]comp} let makeConvLayer [nx][ny] (NM: i64) (layer : Layer[nx][ny]) : ConvLayer[NM] = match layer case #l_inhomogeneous {epsilon, mu, thickness} -> #inhomogeneous { Kz = tabulate_2d NM NM (\i j->0.0) } case #l_homogeneous {epsilon, mu, thickness} -> --BUG HERE: let Kz = tabulate_2d NM NM (\i j->0.0) in #homogeneous { Kz = Kz } --BUT OK THIS WAY: -- #homogeneous { Kz = tabulate_2d NM NM (\i j->0.0) } --API demo/examples entry main (x: i64) = --freespace problem setup let nx = 5 --x pixel count let ny = 5 --y pixel count let N = 5 --modes in x let M = 5 --modes in y let NM = N * M --total number of modes let layer = (#l_homogeneous {epsilon = 1.0, mu = 1.0, thickness = f64.inf} : Layer[nx][ny]) let c = makeConvLayer NM layer --run it in 1 futhark-0.25.27/tests/issue1853.fut000066400000000000000000000002651475065116200167000ustar00rootroot00000000000000-- Test parser-related quirks. type Thing = #this [2]f64 | #that [3]f64 type pt [n] 't = [n]t type foo = pt [2]i32 type bar = pt [2] i32 type pt2 't = t type baz = pt2 ([2]i32) futhark-0.25.27/tests/issue1855.fut000066400000000000000000000003561475065116200167030ustar00rootroot00000000000000-- Pattern matching on an incompletely known sum type should still -- respect aliases. 
type t [n] = #foo [n](i32,i32) | #bar def main (x: t []) = (\y -> match y case #foo arr -> take 2 arr case _ -> [(0,1),(1,2)]) x futhark-0.25.27/tests/issue1863.fut000066400000000000000000000001741475065116200167000ustar00rootroot00000000000000-- == -- input { empty([1][0]i32) } -- output { [0i32] } def main (foo: [1][0]i32): [1]i32 = map2 (\_ _ -> 0) [0] foo futhark-0.25.27/tests/issue1874.fut000066400000000000000000000007101475065116200166760ustar00rootroot00000000000000type~ state = [][]f32 entry init (m: i64): state = unflatten (replicate (m * m) 1f32) def step' [m] (cells: *[m][m]f32): *[m][m]f32 = let step_cell (cell: f32) ((y, x): (i64, i64)) = let x = if y > 0 then cells[y - 1, x] else 0 in cell * x in map2 (map2 step_cell) cells (tabulate_2d m m (\y x -> (y, x))) entry step (cells: state): state = step' (copy cells) -- == -- entry: step -- script input { init 10i64 } futhark-0.25.27/tests/issue1895.fut000066400000000000000000000025731475065116200167120ustar00rootroot00000000000000-- Setup from my program. module real = f32 type real = real.t type vec3 = {x:f32,y:f32,z:f32} def dot (a: vec3) (b: vec3) = (a.x*b.x + a.y*b.y + a.z*b.z) def vecsub (a: vec3) (b: vec3) = {x= a.x-b.x, y= a.y-b.y, z= a.y-b.z} def cross ({x=ax,y=ay,z=az}: vec3) ({x=bx,y=by,z=bz}: vec3): vec3 = ({x=ay*bz-az*by, y=az*bx-ax*bz, z=ax*by-ay*bx}) def quadrance v = dot v v def norm = quadrance >-> f32.sqrt def to_vec3 (xs: [3]real) = {x=xs[0], y=xs[1], z=xs[2]} def map_2d f = map (map f) def map_3d f = map (map_2d f) def grad 'a (f: a -> real) (primal: a) = vjp f primal (real.i64 1) let jacfwd [n][m] (f: [n][m]vec3 -> real) (x: [n][m]vec3): [n][m]vec3 = let v3 xs = unflatten_3d xs |> map_2d to_vec3 let tangent i = (replicate (n*m*3) 0 with [i] = 1) |> v3 in tabulate (n*m*3) (\i -> jvp f x (tangent i)) |> v3 -- The function. 
def fun atom_coords : real = let xs = flatten atom_coords -- "Random" reads are necessary (in the original program this reads an -- input of unknown pairs of indices, here the indexing is just nonsense): let dists = map2 (\i j -> norm (xs[i] `vecsub` xs[j])) (indices atom_coords) (indices atom_coords |> reverse) in f32.sum dists -- == -- compiled random input { [1][32][20][3]f32 } -- auto output entry main coords = map (grad fun) (map_3d to_vec3 coords) |> map_3d (\v -> [v.x,v.y,v.z]) futhark-0.25.27/tests/issue1903.fut000066400000000000000000000003711475065116200166720ustar00rootroot00000000000000-- == -- error: anonymous-nonconstructive module m2 : { type^ t val x : bool -> t val f : t -> i64 } = { type^ t = []bool -> bool def x b = \(_: [10]bool) : bool -> b def f [n] (_: [n]bool -> bool) = n } entry main2 = m2.f (m2.x true) futhark-0.25.27/tests/issue1926.fut000066400000000000000000000003621475065116200166770ustar00rootroot00000000000000-- == -- error: cannot match value type found = #found i32 | #not_found def main = let o = map (\x -> if (x > 3) then (#found x) else (#not_found)) [0, 1, 2, 3, 4] let u = match o case #found x -> x case #not_found -> -1 in u futhark-0.25.27/tests/issue1935.fut000066400000000000000000000002201475065116200166700ustar00rootroot00000000000000def get_name (i: i64) = match i case 0 -> "some name" case _ -> "" entry main = loop i = 0 while length (get_name i) != 0 do i + 1 futhark-0.25.27/tests/issue1936.fut000066400000000000000000000001511475065116200166740ustar00rootroot00000000000000-- == -- tags { no_wasm } -- compiled random input { f32 } def main (x: f32) = (x, (), replicate 10 ()) futhark-0.25.27/tests/issue1937.fut000066400000000000000000000001001475065116200166670ustar00rootroot00000000000000def iiota [n] : [n]i64 = 0..1.. 
loop dBins = replicate numBins2 0 for j < numBins do if dot > threshold then #[unsafe] let dBins[numBins+1] = dBins[numBins+1] + 1 in dBins else #[unsafe] let dBins[numBins] = dBins[numBins] + 1 in dBins ) points futhark-0.25.27/tests/issue1943.fut000066400000000000000000000001721475065116200166750ustar00rootroot00000000000000def windows k s = map (\i -> take k (drop i s)) (take (length s - k) (indices s)) entry main (s: []i32) = windows 14 s futhark-0.25.27/tests/issue1947.fut000066400000000000000000000003521475065116200167010ustar00rootroot00000000000000def consume (a: *[]i64) = a with [0] = 0 entry test (a: *[2]i64) (_: i64) n = let b = map id a :> [n]i64 let a_consumed = consume a let final = consume (copy (filter (\x -> x > 2) a_consumed)) in (all (\x -> x == 0) b, final) futhark-0.25.27/tests/issue1949.fut000066400000000000000000000001461475065116200167040ustar00rootroot00000000000000def fn (arr: *[](i32,i32)) = if true then opaque arr else opaque arr entry test = (fn [(0, 0)])[0] futhark-0.25.27/tests/issue1952.fut000066400000000000000000000001171475065116200166740ustar00rootroot00000000000000entry main (x: i64) = let is = flatten [[x], [x]] in is[999999999] futhark-0.25.27/tests/issue1978.fut000066400000000000000000000001421475065116200167020ustar00rootroot00000000000000entry main (gridDim: (i64,i64)) = tabulate_2d gridDim.0 gridDim.1 (\i j -> (i, j)) |> flatten futhark-0.25.27/tests/issue1984.fut000066400000000000000000000002371475065116200167040ustar00rootroot00000000000000-- #19 -- == -- input { true [1,2,3] } output { [1,2,3] } -- input { false [1,2,3] } output { [3,2,1] } def main b (xs: []i32) = if b then xs else reverse xs futhark-0.25.27/tests/issue1998.fut000066400000000000000000000025731475065116200167160ustar00rootroot00000000000000-- | file: error.fut module type bitset = { type bitset[n] val nbs : i64 val empty : (n : i64) -> bitset[(n - 1) / nbs + 1] val complement [n] : bitset[(n - 1) / nbs + 1] -> bitset[(n - 1) / nbs + 1] val size [n] : 
bitset[(n - 1) / nbs + 1] -> i64 } module mk_bitset (I: integral) : bitset = { def nbs = i64.i32 I.num_bits type bitset [n] = [n]I.t def zero : I.t = I.i64 0 def empty (n : i64) : bitset[(n - 1) / nbs + 1] = replicate ((n - 1) / nbs + 1) zero def set_front_bits_zero [n] (s : bitset[(n - 1) / nbs + 1]) : bitset[(n - 1) / nbs + 1] = let l = (n - 1) / nbs + 1 let start = 1 + (n - 1) % nbs let to_keep = I.i64 (i64.not (i64.not 0 << start)) in if l == 0 then s else copy s with [l - 1] = s[l - 1] I.& to_keep def complement [n] (s : bitset[(n - 1) / nbs + 1]) : bitset[(n - 1) / nbs + 1] = map I.not s |> set_front_bits_zero def size [n] (s : bitset[(n - 1) / nbs + 1]) : i64 = map (i64.i32 <-< I.popc) s |> i64.sum } module bitset_u8 = mk_bitset u8 -- == -- entry: test_complement -- input { 0u8 } output { 0i64 } -- input { 1u8 } output { 1i64 } -- input { 2u8 } output { 2i64 } -- input { 8u8 } output { 8i64 } entry test_complement (c : u8) : i64 = let c' = i64.u8 c let empty_set = bitset_u8.empty c' let full_set = bitset_u8.complement empty_set let result = bitset_u8.size full_set in result futhark-0.25.27/tests/issue2000.fut000066400000000000000000000001161475065116200166540ustar00rootroot00000000000000entry main (x: i32) (y: i32): bool = match x case x' -> (\z -> z == x') y futhark-0.25.27/tests/issue2011.fut000066400000000000000000000001641475065116200166610ustar00rootroot00000000000000-- == -- input {} -- output { 4u64 } module mod: {module x: integral} = { module x = u64 } def main = u64.i64 4 futhark-0.25.27/tests/issue2015.fut000066400000000000000000000003451475065116200166660ustar00rootroot00000000000000def main [n][m] (xss: *[n][m]i32) = map (\(xs: [m]i32) -> let xs = loop (zs: [m]i32) = xs for i < n do let xs' = scatter (copy xs) (iota m) (rotate 1 zs) in xs' in xs) xss futhark-0.25.27/tests/issue2016.fut000066400000000000000000000005341475065116200166670ustar00rootroot00000000000000-- == -- error: non-constructively module trouble : { type const 'a 'b val mk_const 
'a 'b : a -> b -> const a b } = { type const 'a 'b = a def mk_const x _ = x } entry f [n][m] (_: trouble.const ([n]i64) ([m]i64)) = m entry g (x: i64) = let [n][m] (_: trouble.const ([n]i64) ([m]i64)) = trouble.mk_const (iota (x+1)) (iota x) in m futhark-0.25.27/tests/issue2017.fut000066400000000000000000000000511475065116200166620ustar00rootroot00000000000000-- == -- error: Refutable def a - 1 = 2 futhark-0.25.27/tests/issue2018.fut000066400000000000000000000001611475065116200166650ustar00rootroot00000000000000def main (i: i64) (j: i64) (xss: *[][]i32) = let xs = xss[i] let xss[j] = copy (opaque (opaque xs)) in xss futhark-0.25.27/tests/issue2021.fut000066400000000000000000000001051475065116200166550ustar00rootroot00000000000000def arr = replicate 10 true entry main : [](bool,bool) = zip arr arr futhark-0.25.27/tests/issue2038.fut000066400000000000000000000010551475065116200166720ustar00rootroot00000000000000entry all_work_indices (n:i64) (m:i64) = let block i prog prog' = (i, prog, i64.max 0 (i64.min (n-i) (if i == 0 then n else prog'))) let size (_, a, b) = b-a >> 1 let (iter, _) = loop (iter, progress) = (0, replicate m 0) while not (all id (tabulate m (\i -> progress[i] >= n-i-1))) do let blockrow = map3 block (iota m) (progress) (rotate (-1) progress) let sizes = map size blockrow in (iter + 1i32, map2 (+) progress sizes) in iter futhark-0.25.27/tests/issue2040.fut000066400000000000000000000003731475065116200166650ustar00rootroot00000000000000def fun [k] (n:i64) (op:[k]f32 -> *[k]f32) (_:[k]f32) = let process (Q:*[n+1][k]f32) i = let q = op Q[i] in Q with [i+1] = q let Q = replicate (n+1) (replicate k 0) let Q = process Q 0 in Q entry test n k = fun n copy (replicate k 0) futhark-0.25.27/tests/issue2048.fut000066400000000000000000000002271475065116200166730ustar00rootroot00000000000000-- == -- input { [1,2,3] } -- output { 3i64 } def f [n] (xs: [n]i32) = let [m] (ys: [m]i32) = filter (>0) xs in ys entry main xs = length (f xs) 
futhark-0.25.27/tests/issue2053.fut000066400000000000000000000002671475065116200166730ustar00rootroot00000000000000-- == -- input { [1i64,2i64,3i64] } -- output { 6i64 } def f [n] (reps:[n]i64) : [reduce (+) 0 reps]i64 = iota (reduce (+) 0 reps) entry main arr = length (f (map (\x -> x) arr)) futhark-0.25.27/tests/issue2058.fut000066400000000000000000000002621475065116200166730ustar00rootroot00000000000000-- We neglected to mark the target arrays as consumed while -- simplifying the body. entry problem [n] (arr: *[n]i64) : [n]i64 = reduce_by_index arr (+) 0 (iota n) (copy arr) futhark-0.25.27/tests/issue2073.fut000066400000000000000000000004641475065116200166740ustar00rootroot00000000000000-- == -- error: "value" is not in scope def f [n] (dmax: i64) (depth: [n]i64) (value: [n]i32) (parent: [n]i64) : []i32 = loop value for d in dmax..dmax-1...1 do reduce_by_index (copy value) (+) 0i32 (map (\i -> if depth[i] == d then parent[i] else -1) (iota (length value))) value futhark-0.25.27/tests/issue2092.fut000066400000000000000000000005431475065116200166730ustar00rootroot00000000000000-- == -- input { 2i64 3i64 } -- output { [[[0i64, 0i64], [1i64, 1i64]], [[1i64, 1i64], [1i64, 1i64]], [[2i64, 2i64], [1i64, 1i64]]] } -- input { 10i64 256i64 } auto output entry main k n = #[incremental_flattening(only_intra)] tabulate n (\i -> let A = replicate k i in tabulate k (\j -> if j % 2 == 1 then replicate k j else A)) futhark-0.25.27/tests/issue2096.fut000066400000000000000000000004271475065116200167000ustar00rootroot00000000000000entry main (x: i32): (i32, bool) = let when pred action orig = if pred then action orig else orig let action = when true (\(x, _) -> (x, true)) in if true then action (x, false) else action (x, false) futhark-0.25.27/tests/issue2099.fut000066400000000000000000000002311475065116200166740ustar00rootroot00000000000000-- == -- input {} -- output { [100, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 11i64 } def numbers = concat [100] (0..<10i32) entry main = (numbers, length 
numbers) futhark-0.25.27/tests/issue2100.fut000066400000000000000000000003021475065116200166520ustar00rootroot00000000000000entry main (n: i64): () = let arr = reduce_by_index (replicate n ()) (\() () -> ()) () (replicate n 0) (replicate n ()) in arr[0] futhark-0.25.27/tests/issue2103.fut000066400000000000000000000001531475065116200166610ustar00rootroot00000000000000-- == -- error: cannot match def main = let [n] (A: [n]i64, B: [n]i64) = (iota 1, iota 2) in zip A B futhark-0.25.27/tests/issue2104.fut000066400000000000000000000012251475065116200166630ustar00rootroot00000000000000type v3= (f32,f32,f32) type v4= (f32,f32,f32,f32) type C4N3V3 = (v4,v3,v3) let serializeC4N3V3 (((r,g,b,a),(nx,ny,nz),(vx,vy,vz)):C4N3V3) = [r,g,b,a,nx,ny,nz,vx,vy,vz] let triangleMeshC4N3V3 [m][n] (vs:[m][n]C4N3V3) : []f32 = let triangle v0 v1 v2 = serializeC4N3V3 v0 ++ serializeC4N3V3 v1 ++ serializeC4N3V3 v2 let quad v0 v1 v2 v3 = triangle v0 v1 v2 ++ triangle v2 v3 v0 in tabulate_2d (m-1) (n-1) (\i j -> quad vs[i,j] vs[i+1,j] vs[i+1,j+1] vs[i,j+1]) |> flatten |> flatten entry particleSystemMesh [m][n][k] (vs:[m][n]C4N3V3) (coords:[k]v3) : []f32 = map (\_ -> vs) coords |> map triangleMeshC4N3V3 |> flatten futhark-0.25.27/tests/issue2106.fut000066400000000000000000000003751475065116200166720ustar00rootroot00000000000000-- == -- input { 2i64 } -- output { [0i64, 0i64] } def f'' 'a (n: i64) (f: a -> i64) (a: a): [n]i64 = replicate n (f a) def f (s: {n: i64}): [s.n]i64 = let f' 'a (f: a -> i64) (a: a): [s.n]i64 = f'' s.n f a in f' id 0 entry main (n: i64) = f {n} futhark-0.25.27/tests/issue2113.fut000066400000000000000000000001051475065116200166570ustar00rootroot00000000000000-- == -- error: "a" module type abc = { type x 'a val y : x a } futhark-0.25.27/tests/issue2114.fut000066400000000000000000000001071475065116200166620ustar00rootroot00000000000000-- == -- error: "t" module type A = { module R : { type t = t } } 
futhark-0.25.27/tests/issue2124.fut000066400000000000000000000002171475065116200166650ustar00rootroot00000000000000-- Ignore suffixes when computing differences when possible. def main b : [10]i64 = if b then iota 10 : [10]i64 else iota 10i64 : [10]i64 futhark-0.25.27/tests/issue2125.fut000066400000000000000000000010421475065116200166630ustar00rootroot00000000000000-- Two things are necessary here: -- -- 1) not generating unnecessary versions. -- -- 2) simplifying away the slices. -- -- == -- structure gpu { /If/True/SegMap 1 /If/False/SegRed 1 } entry example_tc5 [A][B][I][J] [Q] (xsss: [Q][A][I]f32) (ysss: [B][Q][J]f32) : [I][B][J][A]f32 = #[unsafe] map (\i -> -- dim 0 map (\b -> -- dim 1 map (\j -> -- dim 2 map (\a -> -- dim 3 map2 (*) xsss[:, a, i] ysss[b, :, j] |> f32.sum ) (iota A) ) (iota J) ) (iota B) ) (iota I) futhark-0.25.27/tests/issue2136.fut000066400000000000000000000002511475065116200166660ustar00rootroot00000000000000def test (a: i64) (b: i64) (f: [a*b]f32 -> [a*b]f32) (g: [a*b]f32 -> [a*b]f32) = g (f (replicate (a*b) 0)) entry main a b = let h = test a b in h reverse reverse futhark-0.25.27/tests/issue2184.fut000066400000000000000000000006521475065116200166760ustar00rootroot00000000000000-- Somewhat exotic case related to entry points that are also used as -- ordinary functions. 
entry calculate_objective [d] (xParam: [3][4*d]f64) (yParam: [3][d]f64) : f64 = 0 entry calculate_jacobian [d] (mainParams: [3][4*d]f64) (yParam: [3][d]f64) = vjp (\(x, y) -> calculate_objective x y) (mainParams, yParam) 1 futhark-0.25.27/tests/issue2193.fut000066400000000000000000000001531475065116200166720ustar00rootroot00000000000000def takefrom 't (xs: []t) (i: i64) : [i+1]t = take (i+1) xs entry main n (xs: []i32) = n |> takefrom xs futhark-0.25.27/tests/issue2197.fut000066400000000000000000000004121475065116200166740ustar00rootroot00000000000000def index_of_first p xs = loop i = 0 while i < length xs && !p xs[i] do i + 1 def span p xs = let i = index_of_first p xs in (take i xs, drop i xs) entry part1 [l] (ls: [][l]i32) = let blank (l: [l]i32) = null l in span blank ls |> \(x, y) -> (id x, tail y) futhark-0.25.27/tests/issue2209.fut000066400000000000000000000050351475065116200166740ustar00rootroot00000000000000-- Fancy size-dependent programming via the module system; the -- interpreter had a far too naive idea of how size expressions were -- handled. 
-- == -- entry: test_adam -- input { [42f32,43f32] } module type vspace = { module Scalar: real type scalar = Scalar.t type vector val zero : vector val scale : scalar -> vector -> vector val dot : vector -> vector -> scalar val + : vector -> vector -> vector val neg : vector -> vector val to_array : vector -> []scalar } module array_vec (T: real) (Size: {val n : i64}) = { module Scalar = T type scalar = Scalar.t type vector = [Size.n]scalar def zero : vector = rep (T.i32 0) def scale (s: scalar) (v: vector) : vector = map (s T.*) v def dot (v: vector) (u: vector) : scalar = T.sum (map2 (T.*) v u) def (+) (v: vector) (u: vector) : vector = map2 (T.+) v u def neg (v: vector) : vector = map T.neg v def from_array (x: [Size.n]scalar) : vector = x def to_array (x: vector) : [Size.n]scalar = x } module adam (V: vspace) = { local module S = V.Scalar local type a = V.scalar local type v = V.vector type state = { step: i32 , mw: v , vw: a , w: v } def result (x: state) : v = x.w def initial_state (x: v) : state = {step = 0, mw = V.zero, vw = S.i32 0, w = x} type params = { beta1: a , beta2: a , eta: a , epsilon: -- learning rate a } def def_params : params = { beta1 = S.f32 0.9 , beta2 = S.f32 0.999 , eta = S.f32 1e-4 , epsilon = S.f32 1e-5 } def adam (p: params) (obj: v -> a) (state: state) : state = let grad: v = vjp obj state.w (S.i32 1) let mw': v = V.(p.beta1 `scale` state.mw + (S.(i32 1 - p.beta1) `scale` grad)) let vw': a = S.(p.beta2 * state.vw + (i32 1 - p.beta2) * (grad `V.dot` grad)) let pow (x: a) (n: i32) = S.(exp (i32 n * log x)) let mw_hat: v = S.(i32 1 - p.beta1 `pow` state.step) `V.scale` mw' let vw_hat: a = S.((i32 1 - p.beta2 `pow` state.step) * vw') let s: a = S.(p.eta / (sqrt vw_hat + p.epsilon)) let w: v = V.(state.w + neg (s `scale` mw_hat)) in {mw = mw', vw = vw', step = state.step + 1, w} } module Size = {def n : i64 = 2} module V = array_vec f32 Size type v = [Size.n]f32 module Adam = adam V def func (v: v) : f32 = let vv = v :> [2]f32 let x: f32 
= vv[0] let y: f32 = vv[1] in f32.(2 * x * x + (y - 2) * (y - 2) + 3) entry test_adam (x0: [2]f32) = iterate 5 (Adam.adam Adam.def_params func) (Adam.initial_state (x0 :> [Size.n]f32)) |> \{mw, step, vw, w} -> (mw, step, vw, w) futhark-0.25.27/tests/issue2216.fut000066400000000000000000000007071475065116200166730ustar00rootroot00000000000000-- == -- input { [0.0, 0.0, 0.0] } -- output { [[0.0, 0.0, 0.0], -- [0.0, 2.0, 0.0], -- [0.0, 0.0, 2.0]] } def identity_mat n = tabulate_2d n n (\i j -> f64.bool (i == j)) def Jacobi [n] (f: [n]f64 -> [n]f64) (x: [n]f64) : [n][n]f64 = map (\i -> jvp f x i) (identity_mat n) def Hessian [n] (f: [n]f64 -> f64) (x: [n]f64) : [n][n]f64 = Jacobi (\x -> vjp f x 1) x entry main (x: [3]f64) = Hessian (\x -> x[1] ** 2 + x[2] ** 2) x futhark-0.25.27/tests/issue245.fut000066400000000000000000000014761475065116200166170ustar00rootroot00000000000000-- This was an issue with a simplification rule for -- rearrange-split-rearrange chains that sometimes occur in -- tail2futhark output. -- -- == -- input { 2i64 3i64 } -- output { [[1i32, 2i32], [4i32, 5i32]] } def take_arrint (l: i64) (x: [][]i32): [][]i32 = let v1 = take l x in v1 def reshape_int (l: i64) (x: []i32): []i32 = let roundUp = ((l + (length x - 1)) / length x) in let extend = flatten (replicate (roundUp) (x)) in let v1 = take l extend in v1 entry main (x: i64) (y: i64): [][]i32 = let t_v1 = unflatten (reshape_int ((x * y)) (map (\x -> (i32.i64 x + 1)) (iota (6)))) in let t_v2 = transpose (t_v1) in let t_v3 = take_arrint (x) (t_v2) in let t_v4 = transpose (t_v3) in t_v4 futhark-0.25.27/tests/issue246.fut000066400000000000000000000022721475065116200166130ustar00rootroot00000000000000-- We assigned overly complex (and wrong) index functions to splits. 
-- -- == -- input { 3i64 4i64 } -- output { [1i64, 2i64, 5i64, 6i64, 9i64, 10i64] } def dim_2 't [d0] [d1] (i: i64) (x: [d0][d1]t): i64 = if (i == 1) then d1 else d0 def take_arrint [k] (l: i64) (x: [][k]i64): [][]i64 = if (0 <= l) then if (l <= length x) then let v1 = take (l) (x) in v1 else concat (x) (replicate ((i64.abs (l) - length x)) (replicate (dim_2 1 x) (0) :> [k]i64)) else if (0 <= (l + length x)) then let v2 = drop ((l + length x)) (x) in v2 else concat (replicate ((i64.abs (l) - length x)) (replicate (dim_2 1 x) (0) :> [k]i64)) (x) def reshape_int (l: i64) (x: []i64): []i64 = let roundUp = ((l + (length x - 1)) / length x) in let extend = flatten (replicate (roundUp) (x)) in let v1 = take (l) (extend) in v1 entry main (n: i64) (m: i64): []i64 = let t_v1 = unflatten (reshape_int ((n * m)) ((map (\(x: i64): i64 -> (x + 1)) (iota (12))))) in let t_v2 = transpose (t_v1) in let t_v3 = take_arrint (2) (t_v2) in let t_v4 = transpose (t_v3) in flatten (t_v4) futhark-0.25.27/tests/issue248.fut000066400000000000000000000030121475065116200166060ustar00rootroot00000000000000-- This seems to fail in some places. Generated by tail2futhark (prettified a little bit). -- -- == -- input { [67,67,67,65,65,67,66,65,65,68,65,66,67,65,65,67,68,67,65,67,68,68,66,67,68,68,67,67,67,66,65,68,67,66,67,67,67,65,67,67,67,66,67,67,66,65,67,67] } -- output { true } def eqb (x: bool) (y: bool): bool = (! 
((x || y)) || (x && y)) def reshape_int (l: i64) (x: []i32): []i32 = let roundUp = ((l + (length x - 1)) / length x) in let extend = flatten (replicate (roundUp) (x)) in let v1 = take (l) (extend) in v1 entry main (nucleotides: []i32): bool = let t_v2 = unflatten (reshape_int (8*6) nucleotides) in let t_v8 = transpose (map transpose (unflatten_3d (reshape_int (8*6*4) (map i32.u8 "ABCD")))) in let t_v9 = unflatten_3d (reshape_int (4*8*6) (flatten t_v2)) in let t_v12 = let x = t_v8 in let y = t_v9 in map2 (\(x: [][]i32) (y: [][]i32) -> map2 (\(x: []i32) (y: []i32) -> map2 (==) (x) (y)) (x) (y)) (x) (y) in let t_v18 = map (\(x: [][]bool) -> map (\(x: []bool): bool -> reduce (||) (false) (x)) (x)) (t_v12) in let t_v21 = (map (\(x: []bool): bool -> reduce (&&) (true) (x)) (transpose (t_v18))) in let t_v26 = reduce (&&) (true) (let x = t_v21 in let y = [false, false, false, true, false, true, false, false] in map2 eqb (x) (y)) in t_v26 futhark-0.25.27/tests/issue304.fut000066400000000000000000000004421475065116200166030ustar00rootroot00000000000000-- Too aggressive hoisting/distribution can lead to a compile error -- here. 
-- == -- input { [[1,2],[3,4]] } output { [[1,2],[3,4]] } -- input { [[1,2,3],[3,4,5]] } error: out of bounds entry main [m][n] (xss : [m][n]i32): [n][m]i32 = map (\j -> map (\i -> xss[j,i]) (iota m)) (iota n) futhark-0.25.27/tests/issue354.fut000066400000000000000000000010051475065116200166040ustar00rootroot00000000000000-- == -- structure { Screma 6 } def linerp2D (image: [][]f32) (p: [2]i32): f32 = #[unsafe] let a = p[0] let b = p[1] in image[a,b] def f [n] (rotSlice: [n][n]f32): [n][n]f32 = let positions1D = iota n let positions2D = map (\x -> map (\y -> [i32.i64 x,i32.i64 y]) positions1D) positions1D in map (\row -> map (linerp2D rotSlice) row) positions2D def main [s][n] (proj: [s][n]f32): [s][n][n]f32 = let rotatedVol = map (\row -> map (\col -> replicate n col) row) proj in map (\x -> f x) rotatedVol futhark-0.25.27/tests/issue367.fut000066400000000000000000000002051475065116200166110ustar00rootroot00000000000000def main(n: i64) = let a = replicate n (replicate n 1) in map (\(xs: []i32, i) -> copy xs with [0] = i32.i64 i) (zip a (iota n)) futhark-0.25.27/tests/issue390.fut000066400000000000000000000003221475065116200166050ustar00rootroot00000000000000-- An error in the handling of reshape in the internaliser. The -- source-language array is of lesser rank than the corresponding -- core-language array(s). 
entry main n m (a: [n*m]([]i32,i32)) = unflatten a futhark-0.25.27/tests/issue392.fut000066400000000000000000000011111475065116200166040ustar00rootroot00000000000000def dotprod [n] (xs: [n]f64) (ys: [n]f64): f64 = reduce (+) 0.0 (map2 (*) xs ys) def matvecmul [n] [m] (xss: [n][m]f64) (ys: [m]f64) = map (dotprod ys) xss def cost_derivative [n] (output_activations:[n]f64) (y:[n]f64) : [n]f64 = map2 (-) output_activations y def outer_prod [m][n] (a:[m]f64) (b:[n]f64) : *[m][n]f64 = map (\x -> map (\y -> x * y) b) a def main [i] [j] [k] (w3:[k][j]f64) (x:[i]f64,y:[k]f64) (z2: []f64) (z3: [k]f64) = let delta3 = map2 (*) (cost_derivative z3 y) z3 let nabla_b3 = delta3 let nabla_w3 = outer_prod delta3 z2 in (nabla_b3,nabla_w3) futhark-0.25.27/tests/issue393.fut000066400000000000000000000007641475065116200166220ustar00rootroot00000000000000-- == -- structure { /Screma 1 /Screma/Screma 2 } def dotprod [n] (xs: [n]f64) (ys: [n]f64): f64 = reduce (+) 0.0 (map2 (*) xs ys) def matvecmul [n] [m] (xss: [n][m]f64) (ys: [m]f64) = map (dotprod ys) xss def outer_prod [m][n] (a:[m]f64) (b:[n]f64): [m][n]f64 = map (\x -> map (\y -> x * y) b) a def main [i] [j] [k] (b2: [j]f64) (b3: [k]f64) (w3: [k][j]f64) (x:[i]f64) = let delta2 = map2 (*) (matvecmul (transpose w3) b3) b2 let nabla_w2 = outer_prod delta2 x in (delta2,nabla_w2) futhark-0.25.27/tests/issue396.fut000066400000000000000000000003401475065116200166130ustar00rootroot00000000000000-- The problem is that a simplified index function is used. 
def main (b: bool) (xs: []i32) = map (\(x: i32) -> if b then (copy (transpose [[x,x],[x,x]])) with [0,0] = 7i32 else [[1,1],[1,1]]) xs futhark-0.25.27/tests/issue397.fut000066400000000000000000000004571475065116200166250ustar00rootroot00000000000000-- == -- input {0} error: def predict (a:[10]f64) : i64 = let (m,i) = reduce (\(a,i) (b,j) -> if a > b then (a,i) else (b,j)) (a[9],9) (zip (a[:8]) (iota 9 :> [8]i64)) in i def main (x: i32) : i64 = predict [0.2,0.3,0.1,0.5,0.6,0.2,0.3,0.1,0.7,0.1] futhark-0.25.27/tests/issue400.fut000066400000000000000000000003011475065116200165720ustar00rootroot00000000000000-- Consumption of loops with more certain patterns was not tracked -- correctly. def main (n: i64) (x: i32) = loop a = replicate n x for i < 10 do (loop (a) for j < i do a with [j] = 1) futhark-0.25.27/tests/issue403.fut000066400000000000000000000002231475065116200166000ustar00rootroot00000000000000-- == -- input { [1,2] 0 } output { true } -- input { [1,2] 1 } output { false } def main (xs: *[]i32) (i: i32) = let xs[0] = 0 in xs[i] == 0 futhark-0.25.27/tests/issue407.fut000066400000000000000000000006111475065116200166050ustar00rootroot00000000000000module edge_handling (mapper: {}) = { def handle (g: i32): f32 = let base (): f32 = f32.i32 g in base () } module edge_handling_project_top = edge_handling {} module edge_handling_project_bottom = edge_handling {} def main (x: i32) = let _unused = edge_handling_project_top.handle 0 let project_bottom () = edge_handling_project_bottom.handle x in project_bottom () futhark-0.25.27/tests/issue408.fut000066400000000000000000000002711475065116200166100ustar00rootroot00000000000000-- Bug in closed-form simplification. 
-- == -- input { true } output { true } -- input { false } output { false } -- structure { Reduce 0 } def main (x: bool) = reduce (&&) true [x] futhark-0.25.27/tests/issue410.fut000066400000000000000000000032111475065116200165760ustar00rootroot00000000000000def sgmScanSum [n] (vals:[n]i32) (flags:[n]bool) : [n]i32 = let pairs = scan ( \(v1,f1) (v2,f2) -> let f = f1 || f2 let v = if f2 then v2 else v1+v2 in (v,f) ) (0,false) (zip vals flags) let (res,_) = unzip pairs in res def sgmIota [n] (flags:[n]bool) : [n]i32 = let iotas = sgmScanSum (replicate n 1) flags in map (\x -> x-1) iotas type point = (i32,i32) type line = (point,point) def main [h][w][n] (grid:*[h][w]i32) (lines:[n]line) (nn: i64) (idxs: []i32) = #[unsafe] let iotan = iota n let nums = map (\i -> iotan[i]) idxs let flags = map (\i -> i != 0 && nums[i] != nums[i-1]) (map i32.i64 (iota nn)) let (ps1,ps2) = unzip lines let (xs1,ys1) = unzip ps1 let (xs2,ys2) = unzip ps2 let xs1 = map (\i -> xs1[i]) idxs let ys1 = map (\i -> ys1[i]) idxs let xs2 = map (\i -> xs2[i]) idxs let ys2 = map (\i -> ys2[i]) idxs let dirxs = map2 (\x1 x2 -> if x2 > x1 then 1 else if x1 > x2 then -1 else 0) xs1 xs2 let slops = map4 (\x1 y1 x2 y2 -> if x2 == x1 then if y2 > y1 then f32.i32(1) else f32.i32(-1) else f32.i32(y2-y1) / f32.abs(f32.i32(x2-x1))) xs1 ys1 xs2 ys2 let iotas = sgmIota flags let xs = map3 (\x1 dirx i -> x1+dirx*i) xs1 dirxs iotas let ys = map3 (\y1 slop i -> y1+i32.f32(slop*f32.i32(i))) ys1 slops iotas let is = map2 (\x y -> w*i64.i32 y+i64.i32 x) xs ys let flatgrid = flatten grid in scatter (copy flatgrid) is (replicate nn 1) futhark-0.25.27/tests/issue413.fut000066400000000000000000000001531475065116200166030ustar00rootroot00000000000000module vec: { type vec [x] } = { type vec [x] = [x]i32 } def main [n] ((x: vec.vec[n]): vec.vec[n]) = 0 futhark-0.25.27/tests/issue419.fut000066400000000000000000000103121475065116200166070ustar00rootroot00000000000000-- -- == -- compiled input { [1, 3, 1, 1, 1, 53, 2, 2, 1, 7, 
8, 1, -- 2, 1, 1, 41, 2, 2, 4, 1, 1, 37, 1, 1, 1, 2, -- 2, 40, 1, 1, 1, 63, 1, 9, 2, 1, 2, 1, 3, 35 ] -- } -- output { [5i32, 15i32, 21i32, 27i32, 31i32, 39i32, 10i32, 33i32, 0i32, 1i32, 2i32, 3i32, -- 4i32, 6i32, 7i32, 8i32, 9i32, 11i32, 12i32, 13i32, 14i32, 16i32, 17i32, 18i32, -- 19i32, 20i32, 22i32, 23i32, 24i32, 25i32, 26i32, 28i32, 29i32, 30i32, 32i32, -- 34i32, 35i32, 36i32, 37i32, 38i32] } def sgmPrefSum [n] (flags: [n]i32) (data: [n]i32) : [n]i32 = (unzip (scan (\(x_flag,x) (y_flag,y) -> let flag = x_flag | y_flag in if y_flag != 0 then (flag, y) else (flag, x + y)) (0, 0) (zip flags data))).1 def bin_packing_ffh [q] (w: i32) (all_perm : *[q]i32) (all_data0 : [q]i32) = let all_data = scatter (replicate q 0) (map i64.i32 all_perm) all_data0 let len = q let cur_shape = replicate 0 0 let goOn = true let count = 0 let (_,all_perm,_,_, _,_) = loop ((len,all_perm,all_data,cur_shape, goOn,count)) while goOn && count < 100 do let data = take (len) all_data let perm = take (len) all_perm -- 1. initial attempt by first fit heuristic let scan_data = scan (+) 0 data let ini_sgms = map (/w) scan_data let num_sgms = (last ini_sgms) + 1 -- OK let flags = map (\i -> if i == 0 then 1 else if ini_sgms[i-1] == ini_sgms[i] then 0 else 1 ) (map i32.i64 (iota len)) let ones = replicate len 1 let tmp = sgmPrefSum flags ones let (inds1,inds2,vals) = unzip3 ( map (\ i -> if (i == i32.i64 len-1) || (flags[i+1] == 1) -- end of segment then (i+1-tmp[i], ini_sgms[i], tmp[i]) else (-1,-1,0) ) (map i32.i64 (iota len)) ) let flags = scatter (replicate len 0) (map i64.i32 inds1) vals let shapes= scatter (replicate (i64.i32 num_sgms) 0) (map i64.i32 inds2) vals -- 2. 
try validate: whatever does not fit move it as a first segment let scan_data = sgmPrefSum flags data let ini_sgms = scan (+) 0 (map (\x -> if x > 0 then 1 else 0) flags) -- map (/w) scan_data let moves = map (\ i -> let sgm_len = flags[i] in if sgm_len > 0 then if scan_data[i+sgm_len-1] > w then 1 -- this start of segment should be moved else 0 else 0 ) (map i32.i64 (iota len)) let num_moves = reduce (+) 0 moves in -- if true -- then (num_moves, flags, all_data, concat shapes cur_shape, false) if num_moves == 0 then (num_moves, all_perm, all_data, concat shapes cur_shape, false, count) else -- reorder perm, data, and shape arrays let scan_moves = scan (+) 0 moves let (inds_s, lens, inds_v) = unzip3 ( map (\ i -> let offset = scan_moves[i] let (ind_s, ll) = if i > 0 && flags[i] == 0 && moves[i-1] > 0 -- new start of segment then (ini_sgms[i-1]-1, flags[i-1]-1) else (-1, 0) let ind_v = if moves[i] == 0 then (num_moves-offset+i) else offset-1 -- ??? in (ind_s, ll, ind_v) ) (iota len) ) let shapes' = scatter shapes inds_s lens let cur_shape= concat shapes' cur_shape let all_perm = scatter (copy all_perm) inds_v perm let all_data = scatter (copy all_data) inds_v data -- in (num_moves, all_perm, inds_v, cur_shape, false) in (num_moves, all_perm, all_data, cur_shape, true, count+1) in all_perm def main [arr_len] (arr : [arr_len]i32) = bin_packing_ffh 10 (map i32.i64 (iota arr_len)) arr futhark-0.25.27/tests/issue426.fut000066400000000000000000000001371475065116200166110ustar00rootroot00000000000000-- == -- input { [1,2] [2,1] } -- output { true } def main (xs: []i32) (ys: []i32) = xs != ys futhark-0.25.27/tests/issue431.fut000066400000000000000000000003011475065116200165760ustar00rootroot00000000000000-- Applying tiling inside of fused scanomaps. 
-- == -- structure gpu { /If/True/SegScan 1 /If/False/SegScan 1 } def main (xs: []i32) = scan (+) 0 (map (\x -> reduce (+) 0 (map (+x) xs)) xs) futhark-0.25.27/tests/issue433.fut000066400000000000000000000001311475065116200166010ustar00rootroot00000000000000-- The bug here was related to the replicate. def main = replicate 0 ([] : [](i32,i32)) futhark-0.25.27/tests/issue436.fut000066400000000000000000000003211475065116200166050ustar00rootroot00000000000000-- Fusion would sometimes eat certificates on reshapes. -- == -- input { 1i64 [1] } -- output { [4] } -- input { 2i64 [1] } -- error: def main [m] (n: i64) (xs: [m]i32) = map (+2) (map (+1) (xs :> [n]i32)) futhark-0.25.27/tests/issue437.fut000066400000000000000000000003541475065116200166140ustar00rootroot00000000000000-- Tragic problem with index functions. -- == -- input { true 1i64 2i64 [1,2,3] } output { [1,2] } -- input { false 1i64 2i64 [1,2,3] } output { [1] } def main (b: bool) (n: i64) (m: i64) (xs: []i32) = if b then xs[0:m] else xs[0:n] futhark-0.25.27/tests/issue455.fut000066400000000000000000000003111475065116200166050ustar00rootroot00000000000000def main (data: *[]i32) : []i32 = let old_data = copy data let (data, _) = loop (data, old_data) for i in [1,2,3] do let new_data = old_data with [0] = 1 in (new_data, data) in data futhark-0.25.27/tests/issue456.fut000066400000000000000000000012631475065116200166150ustar00rootroot00000000000000-- This program exposed a flaw in the kernel extractor, which was -- unable to handle identity mappings. These rarely occur normally, -- because the simplifier will have removed them, but they sometimes -- occur after loop interchange. 
-- == -- structure gpu { SegMap 1 } def main [n] (datas: *[][n]i32) (is: []i64) = #[incremental_flattening(only_inner)] map (\(data: [n]i32, old_data: [n]i32) -> let (data, _) = loop (data: *[n]i32, old_data: *[n]i32) = (copy data, copy old_data) for i in [1,2,3] do let new_data = scatter old_data is (replicate n data[0]) in (new_data : *[n]i32, data : *[n]i32) in data) (zip datas (copy datas)) futhark-0.25.27/tests/issue460.fut000066400000000000000000000003241475065116200166050ustar00rootroot00000000000000-- == -- input { [1,2] [3,4] 2 } -- output { [1, 2] } -- input { [1,2] [3,4] 7 } -- output { [3, 4] } def main (xs: []i32) (ys: []i32) (n: i32) = map (\(x,y) -> (loop (x,y) for _i < n do (y,x)).0) (zip xs ys) futhark-0.25.27/tests/issue473.fut000066400000000000000000000002061475065116200166100ustar00rootroot00000000000000-- Projecting an array index should be permitted. -- == -- input { 0 } -- output { 0 } def main (x: i32) = let a = [(x,x)] in a[0].0 futhark-0.25.27/tests/issue481.fut000066400000000000000000000002411475065116200166060ustar00rootroot00000000000000-- == -- input { [[[1,2], [3,4]], [[2,1], [4,3]]] } -- output { [[[1,3], [2,4]], [[2,4], [1,3]]] } def main (xsss: [][][]i32): *[][][]i32 = map transpose xsss futhark-0.25.27/tests/issue483.fut000066400000000000000000000003321475065116200166110ustar00rootroot00000000000000-- == -- input { 0i64 32i64 empty([0]i32) } -- output { empty([32][0]i32) } -- input { 32i64 0i64 empty([0]i32) } -- output { empty([0][32]i32) } entry main (n: i64) (m: i64) (xs: [n*m]i32) = transpose (unflatten xs) futhark-0.25.27/tests/issue485.fut000066400000000000000000000006711475065116200166210ustar00rootroot00000000000000-- Avoid fusing the map into the scatter, because the map is reading -- from the same array that the scatter is consuming. The -- complication here is that the scatter is actually writing to an -- *alias* of the array the map is reading from. 
def main (n: i64) (m: i32) = let xs = iota n let ys = xs : *[n]i64 -- now ys aliases xs let vs = map (\i -> xs[(i+2)%n]) (iota n) -- read from xss in scatter ys (iota n) vs -- consume xs futhark-0.25.27/tests/issue506.fut000066400000000000000000000005471475065116200166150ustar00rootroot00000000000000-- Issue with a generated variable name that matched the name of a -- function. This program does not compute anything interesting. -- == def map2 [n] 'a 'b 'x (f: a -> b -> x) (as: [n]a) (bs: [n]b): []x = map (\(a, b) -> f a b) (zip as bs) def main (n: i64) = let on_row (row: i64) (i: i64) = replicate row i let a = iota n in map (on_row a[0]) a futhark-0.25.27/tests/issue512.fut000066400000000000000000000002531475065116200166040ustar00rootroot00000000000000-- == -- input { [1i64,2i64,3i64] } output { 4i64 } def apply 'a (f: a -> a) (x: a) = f x def f [n] (xs: [n]i64) (x: i64) = n + x def main (xs: []i64) = apply (f xs) 1 futhark-0.25.27/tests/issue514.fut000066400000000000000000000001121475065116200166000ustar00rootroot00000000000000-- == -- error: issue514.fut:4:26-36 def main = (2.0 + 3.0) / (2 + 3i32) futhark-0.25.27/tests/issue522.fut000066400000000000000000000000701475065116200166020ustar00rootroot00000000000000def main (b: bool) (xs: []i32) = if b then xs else [] futhark-0.25.27/tests/issue525.fut000066400000000000000000000003571475065116200166150ustar00rootroot00000000000000-- Test that a unique array is properly considered non-unique inside a -- lambda body (mostly so that the type annotation is correct. def main [n][m] (x: i32, a: *[n][m]i32) = let b = transpose a in map1 (\x -> b[m - x - 1]) (iota m) futhark-0.25.27/tests/issue526.fut000066400000000000000000000005241475065116200166120ustar00rootroot00000000000000-- Issue with aliasing analysis that failed to produce enough values. -- Crashed in fusion, but that wasn't where the problem actually was. 
def main [n][m] (t_v1: [n][m]i32) = let t_v4 = (loop t_v2 = t_v1 for _i < 100 do transpose t_v2) :> [n][m]i32 let y = map1 (\x -> t_v1[x]) (iota n) let t_v9 = map2 (map2 (==)) t_v4 y in t_v9 futhark-0.25.27/tests/issue527.fut000066400000000000000000000001751475065116200166150ustar00rootroot00000000000000-- == -- input { 2 } output { 2 } def id 'a (x: a) : a = x def main (x: i32) = let r = { id } in r.id x futhark-0.25.27/tests/issue531.fut000066400000000000000000000003601475065116200166040ustar00rootroot00000000000000-- Defunctionalisation should also look at dimensions used in type -- ascriptions. def f (ys: []i32) = map (+1) ys def main [n] (xs: [n]i32) = let g (h : []i32 -> []i32) = copy (h xs) let g' = \x -> g (\y -> (x y : [n]i32)) in g' f futhark-0.25.27/tests/issue538.fut000066400000000000000000000003301475065116200166100ustar00rootroot00000000000000-- For some reason, having a module type of the same name as a module -- causes the 'n' name to disappear. -- == module sobol_dir = { def n: i32 = 1 } module type sobol_dir = {} module A: {val n: i32} = sobol_dir futhark-0.25.27/tests/issue541.fut000066400000000000000000000001211475065116200166000ustar00rootroot00000000000000def f = \x -> let h y = y in h x entry main1 (x: i32) = f x + f x futhark-0.25.27/tests/issue544.fut000066400000000000000000000002051475065116200166060ustar00rootroot00000000000000-- This used to produce an unnecessarily unique return type on a -- lifted function. def main = ((\x -> x) <-< (\x -> x)) [1,2,3] futhark-0.25.27/tests/issue545.fut000066400000000000000000000002761475065116200166170ustar00rootroot00000000000000-- == -- error: consumption def update (xs: *[]i32) (x: i32) : *[]i32 = xs with [0] = x def apply (f: i32->[]i32) (x: i32) : []i32 = f x def main (xs: *[]i32) = apply (update xs) 2 futhark-0.25.27/tests/issue553.fut000066400000000000000000000004461475065116200166150ustar00rootroot00000000000000-- Can we handle really large tuples correctly? 
-- == -- input { 1 2 3 4 5 6 7 8 9 10 } -- output { 1 2 3 } def main (x1: i32) (x2: i32) (x3: i32) (x4: i32) (x5: i32) (x6: i32) (x7: i32) (x8: i32) (x9: i32) (x10: i32) = let t = (x1, x2, x3, x4, x5, x6, x7, x8, x9, x10) in (t.0, t.1, t.2) futhark-0.25.27/tests/issue558.fut000066400000000000000000000003351475065116200166170ustar00rootroot00000000000000-- This file is *intentionally* written with DOS linebreaks (\r\n). -- Don't change it to Unix linebreaks (\n)! -- == -- input { [1,2,3] 4 } -- error: Index \[4\] out of bounds def main (a: []i32) (i: i32): i32 = a[i] futhark-0.25.27/tests/issue560.fut000066400000000000000000000002521475065116200166060ustar00rootroot00000000000000-- == -- input { [[[1,2], [3,4]],[[5,6],[7,8]]] } -- output { [[[1i32, 2i32], [5i32, 6i32]], [[3i32, 4i32], [7i32, 8i32]]] } def main (matb: [][][]i32) = transpose matb futhark-0.25.27/tests/issue561.fut000066400000000000000000000013401475065116200166060ustar00rootroot00000000000000-- == -- structure { Scatter 1 Screma 1 } def main [n_indices] (scan_num_edges: [n_indices]i64, write_inds: [n_indices]i64, active_starts: [n_indices]i32) = let flat_len = scan_num_edges[n_indices-1] let (tmp1, tmp2, tmp3) = (replicate flat_len false, replicate flat_len 0i32, replicate flat_len 1i32) let active_flags = scatter tmp1 write_inds (replicate n_indices true) let track_nodes_tmp= scatter tmp2 write_inds (map i32.i64 (iota n_indices)) let track_index_tmp= scatter tmp3 write_inds active_starts in scan (\(x,a,b) (y,c,d) -> (x || y, a+c,b+d)) (false,0,0) (zip3 active_flags track_nodes_tmp track_index_tmp) futhark-0.25.27/tests/issue567.fut000066400000000000000000000001371475065116200166170ustar00rootroot00000000000000-- Infinite loops should not crash the compiler. 
def main (x: i32) = loop x while true do x+1 futhark-0.25.27/tests/issue572.fut000066400000000000000000000004061475065116200166120ustar00rootroot00000000000000-- The issue was applying tiling inside a loop that does not run the -- same number of iterations for each thread in a workgroup. -- == -- input { [1,2,3,4,5] } -- output { 150 } def main (xs: []i32) = reduce (+) 0 (map (\x -> reduce (+) 0 (map (+x) xs)) xs) futhark-0.25.27/tests/issue573.fut000066400000000000000000000000701475065116200166100ustar00rootroot00000000000000def main (a: *[8]f32) : *[8]f32 = a with [:] = copy a futhark-0.25.27/tests/issue582.fut000066400000000000000000000005421475065116200166140ustar00rootroot00000000000000-- Turned out to be a bug in the fuse-across-transpose optimisation. -- == -- input { [[[1f32], [2f32]], [[3f32], [4f32]]] } -- output { [[[4.0f32, 8.0f32]], [[12.0f32, 16.0f32]]] } type two 't = [2]t def main (a: [][][]f32) : [][][]f32 = let a = map (map (map (*2))) a let a = map (map (map (*2))) <| transpose <| map transpose a in transpose a futhark-0.25.27/tests/issue589.fut000066400000000000000000000005561475065116200166300ustar00rootroot00000000000000-- The problem is that we currently do not alias xs_i to xs, but they -- really do alias once we replace the polymorphic type t with []i32. -- Required extra conservatism in the type checker. -- == -- error: in-place def swap 't (i: i32) (j: i32) (xs: *[]t) = let xs_i = xs[i] let xs[i] = xs[j] let xs[j] = xs_i in xs def main (xs: *[][]i32) = swap 0 1 xs futhark-0.25.27/tests/issue593.fut000066400000000000000000000002661475065116200166210ustar00rootroot00000000000000-- The usual problems with identity mapping. 
def main (xss: [][]i32) (ys: []i32) = let (as, bs) = unzip2 (map2 (\xs y -> (y, i32.sum (map (+y) xs))) xss ys) in (i32.sum as, bs) futhark-0.25.27/tests/issue596.fut000066400000000000000000000001661475065116200166230ustar00rootroot00000000000000-- == -- error: Consuming.*"xs" def consume (xs: *[]i32) = xs def main (xss: [][]i32) = map (\xs -> consume xs) xss futhark-0.25.27/tests/issue604.fut000066400000000000000000000003551475065116200166110ustar00rootroot00000000000000-- == -- input { [[0], [1]] } -- output { [[1], [0]] } def swap [n] 't (i: i32) (j: i32) (xs: *[n]t) = let xs_i = copy xs[i] let xs_j = copy xs[j] let xs[i] = xs_j let xs[j] = xs_i in xs def main (xs: *[][]i32) = swap 0 1 xs futhark-0.25.27/tests/issue605.fut000066400000000000000000000001151475065116200166040ustar00rootroot00000000000000def main (xs: *[][]i32) = let xs_1 = copy xs[1] let xs[0] = xs_1 in xs futhark-0.25.27/tests/issue608.fut000066400000000000000000000002101475065116200166030ustar00rootroot00000000000000-- == -- input { empty([0][3]i32) } output { empty([0][3]i32) } -- compiled input { [[1,2]] } error: def main (xs: [][3]i32) = xs futhark-0.25.27/tests/issue622.fut000066400000000000000000000003061475065116200166050ustar00rootroot00000000000000-- The kernel extractor tried to distribute "irregular rotations", or -- whatever you want to call them. def main [n][m][k] (xsss: [n][m][k]i32) = map (map2 (\r xs -> rotate r xs) (iota m)) xsss futhark-0.25.27/tests/issue624.fut000066400000000000000000000004041475065116200166060ustar00rootroot00000000000000-- The problem was incorrect type substitution in the monomorphiser -- which removed a uniqueness attribute. 
module type m = { type^ t val r: *t -> *t } module m: m = { type^ t = []f32 def r (t: *t): *t = t } entry r (t: *m.t): *m.t = m.r t futhark-0.25.27/tests/issue626.fut000066400000000000000000000003411475065116200166100ustar00rootroot00000000000000-- == -- error: consuming module type m1 = { type t val get: *t -> f32 -> f32 } module m1 : m1 = { type t = [1]f32 def get (t: *t) (v: f32): f32 = t[0] } entry read (t: *[1]m1.t): []f32 = map2 m1.get t [0] futhark-0.25.27/tests/issue627.fut000066400000000000000000000002521475065116200166120ustar00rootroot00000000000000-- == -- error: val f module type m = { type t val f: t -> () } module m: m = { type t = [1]f32 def f (_t: *t): () = () } entry f (t: m.t): () = m.f t futhark-0.25.27/tests/issue643.fut000066400000000000000000000001341475065116200166070ustar00rootroot00000000000000-- == -- input { empty([0][0]i32) } -- output { 0i64 } def main [n][m] (xs: [n][m]i32) = m futhark-0.25.27/tests/issue649.fut000066400000000000000000000002441475065116200166170ustar00rootroot00000000000000-- The problem was that invalid size parameters were passed to the -- 'length' function after internalisation. def main = loop xs = [0] while length xs == 0 do xs futhark-0.25.27/tests/issue656.fut000066400000000000000000000012221475065116200166120ustar00rootroot00000000000000-- This program fuses into (among other things) a streamSeq, and we -- had a bug where the result of the first-order-transformed streamSeq -- might alias some other array (specifically 'xs'), and the result -- then went on to be consumed in the 'scatter's. 
def main [n] (xs:[n]i32) (is:[n]i32) = let bits1 = map (&1) xs let bits0 = map (1-) bits1 let idxs0 = map2 (*) bits0 (scan (+) 0 bits0) let idxs1 = scan (+) 0 bits1 let offs = reduce (+) 0 bits0 let idxs1 = map2 (*) bits1 (map (+offs) idxs1) let idxs = map (\x->x-1) (map2 (+) idxs0 idxs1) in (scatter (copy xs) (map i64.i32 idxs) xs, scatter (copy is) (map i64.i32 idxs) is) futhark-0.25.27/tests/issue661.fut000066400000000000000000000001311475065116200166040ustar00rootroot00000000000000def f (b: f32): (i32, []f32) = (0, [b]) def main (i: i32) = (\i -> f (f32.i32 i)) i futhark-0.25.27/tests/issue667.fut000066400000000000000000000005161475065116200166210ustar00rootroot00000000000000type point = (f32, f32) def f (p1: point) (p2: point) = let points = [p1,p2] let isingrid = [(f32.abs(p1.1) <= 1), (f32.abs(p2.1) <= 2)] let (truepoints, _) = unzip( filter (\(_, x) -> x) (zip points isingrid)) in truepoints[0] entry main (p1s: []point) (p2s: []point) = map2 (\p1 p2 -> f p1 p2) p1s p2s futhark-0.25.27/tests/issue672.fut000066400000000000000000000001241475065116200166100ustar00rootroot00000000000000def main (xs: *[]i32) = let x = xs[0] let xs[0] = xs[1] let xs[0] = x in xs futhark-0.25.27/tests/issue679.fut000066400000000000000000000000771475065116200166260ustar00rootroot00000000000000def main (xs: [](i32,i32)) = let y = xs[0] with 1 = 0 in y futhark-0.25.27/tests/issue680.fut000066400000000000000000000001011475065116200166020ustar00rootroot00000000000000def main (xs: *[]i32) = (xs with [1] = xs[0]) with [1] = xs[1] futhark-0.25.27/tests/issue681.fut000066400000000000000000000002161475065116200166120ustar00rootroot00000000000000type dir = #up | #down def move (x: i32) (d: dir) = match d case #down -> x+1 case #up -> x-1 def main x = (move x #up, move x #down) futhark-0.25.27/tests/issue682.fut000066400000000000000000000001371475065116200166150ustar00rootroot00000000000000def main (i: i32) = let xs = [0] let a = xs[0] let xs[i] = a let xs[i] = xs[0] in xs 
futhark-0.25.27/tests/issue706.fut000066400000000000000000000001601475065116200166060ustar00rootroot00000000000000-- == -- input { [true, false] } -- output { [1f32, f32.nan] } def main = map (\x -> if x then 1 else f32.nan) futhark-0.25.27/tests/issue708.fut000066400000000000000000000010061475065116200166100ustar00rootroot00000000000000-- The internaliser logic for flattening out multidimensional array -- literals was not reconstructing the original dimensions properly. def insert [n] 't (np1: i64) (x: t) (a: [n]t) (i: i64): [np1]t = let (b,c) = (take i a, drop i a) in b ++ [x] ++ c :> [np1]t def list_insertions [n] 't (np1: i64) (x: t) (a: [n]t): [n][np1]t = map (insert np1 x a) (iota n) def main [n] (a: [n][3]u8): [][n][3]u8 = (loop p = [[head a]] for i in (1...n-1) do flatten (map (list_insertions (n+1) a[i]) p)) :> [][n][3]u8 futhark-0.25.27/tests/issue709.fut000066400000000000000000000005341475065116200166160ustar00rootroot00000000000000-- == -- input { 0 } output { [[[0]]] } def insert [n] 't (np1: i64) (x: t) (a: [n]t) (i: i64): [np1]t = let (b,c) = (take i a, drop i a) in b ++ [x] ++ c :> [np1]t def list_insertions [n] 't (np1: i64) (x: t) (a: [n]t): [np1][np1]t = map (insert np1 x a) (0...(length a)) :> [np1][np1]t def main (x: i32) = map (list_insertions 1 x) [[]] futhark-0.25.27/tests/issue712.fut000066400000000000000000000001131475065116200166010ustar00rootroot00000000000000def main (x: i32) (y: i32) = let t = (x,y) let f g = g t.0 in f (+2) futhark-0.25.27/tests/issue715.fut000066400000000000000000000004401475065116200166070ustar00rootroot00000000000000-- == -- input { true } output { [42,0] [0,0] } -- input { false } output { [42,1] [1,1] } def main (b: bool): ([]i32, []i32) = let (xs, ys) = if b then (replicate 2 0, replicate 2 0) else (replicate 2 1, replicate 2 1) let xs[0] = 42 in (xs, ys) futhark-0.25.27/tests/issue728.fut000066400000000000000000000017421475065116200166210ustar00rootroot00000000000000def expand_with_flags [n] 'a 'b (b: b) (sz: a 
-> i32) (get: a -> i32 -> b) (arr:[n]a) : ([]b, []bool) = ([], []) type^ csr 't = {row_off: []i32, col_idx: []i32, vals: []t} def expand_reduce 'a 'b [n] (sz: a -> i32) (get: a -> i32 -> b) (f: b -> b -> b) (ne:b) (arr:[n]a) : *[]b = let (vals, flags) = expand_with_flags ne sz get arr in [] def smvm ({row_off,col_idx,vals} : csr i32) (v:[]i32) = let rows = map (\i -> (i, row_off[i], row_off[i+1]-row_off[i])) (iota(length row_off - 1)) let sz r = r.2 let get r i = vals[r.1+i] * v[col_idx[r.1+i]] in expand_reduce sz get (+) 0 rows def m_csr : csr i32 = {row_off=[0,3,5,8,9,11], -- size 6 col_idx=[0,1,3,1,2,1,2,3,3,3,4], -- size 11 vals=[1,2,11,3,4,5,6,7,8,9,10]} -- size 11 def v : []i32 = [3,1,2,6,5] def main (_ : i32) : []i32 = smvm m_csr (copy v) futhark-0.25.27/tests/issue743.fut000066400000000000000000000003601475065116200166110ustar00rootroot00000000000000-- Spurious size annotations maintained by defunctionaliser. -- == def get xs (i: i64) = xs[i] def test (xs: []i64) (l: i64): [l]i64 = let get_at xs indices = map (get xs) indices in get_at xs (iota l) def main = test (iota 4) 2 futhark-0.25.27/tests/issue750.fut000066400000000000000000000014721475065116200166140ustar00rootroot00000000000000def flatten_to [n][m] 't (k: i64) (xs: [n][m]t): [k]t = flatten xs :> [k]t def main [n] (as: [100]i32) (bs: [100]i32) (is: [4]i32) (xsss : [][n][]f32) = let m = 9 * n in #[unsafe] map(\xss -> let (ysss, zsss) = unzip <| map(\xs -> let foo = reduce (\i j -> if xs[i] < xs[j] then i else j) 0 is in (replicate 12 (replicate 12 foo), replicate 12 (replicate 12 xs[0]))) xss let vss = map2 (\a b -> map (\zss -> zss[a:a+3, b:b+3] |> flatten_to 9) zsss |> flatten_to m) (map i64.i32 as) (map i64.i32 bs) in (ysss, vss)) xsss futhark-0.25.27/tests/issue762.fut000066400000000000000000000010101475065116200166030ustar00rootroot00000000000000def main [N] [D] [K] [triD] (x: [N][D]f64) (means: [K][D]f64) (qs: [K][D]f64) (ls: [K][triD]f64) = let xs = map (\x' -> unzip3 (tabulate K (\k -> (map2 
(-) x' means[k], qs[k], ls[k])))) x let a = map (.0) xs let b = reduce (map2 (map2 (+))) (map (map (const 0)) means) (map (.1) xs) let c = reduce (map2 (map2 (+))) (map (map (const 0)) ls) (map (.2) xs) in (a, map (map (+2)) b, c) futhark-0.25.27/tests/issue763.fut000066400000000000000000000021621475065116200166150ustar00rootroot00000000000000type vector = (f64,f64) def add(v1: vector)(v2: vector): vector = let(a,b) = v1 let(c,d) = v2 in (a+c, b+d) def mult(f: f64)(v: vector): vector = let (a,b) = v in (f*a, f*b) def dotprod(v1: vector, v2: vector): f64 = let(a,b) = v1 let(c,d) = v2 in a*c + b*d def square(v: vector): f64 = dotprod(v,v) def init_matrix 't (nx: i64)(ny: i64)(x: t): [nx][ny]t = map( \(_) -> map( \(_):t -> x ) (0.. map( \(y) -> replicate 9 (f_eq(rho[x,y], u[x,y], g, tau)) ) (0.. map (\xs -> loop z = i for _p < 10 do i32.sum (map (+z) xs)) xss) is futhark-0.25.27/tests/issue774.fut000066400000000000000000000005451475065116200166220ustar00rootroot00000000000000-- In-place lowering should be careful about updates where the new -- value potentially aliases the old one. type t = [8]u32 def pack [n] (zero: t) (xs: [n]bool): t = loop ret = zero for i in 0.. 
replicate 2 (x >= 4) |> pack zero) (iota 10) futhark-0.25.27/tests/issue780.fut000066400000000000000000000001141475065116200166070ustar00rootroot00000000000000-- == -- input { 2 3 } output { 6 } def main (x: i32) (y: i32) = ((*) x) y futhark-0.25.27/tests/issue782.fut000066400000000000000000000007111475065116200166140ustar00rootroot00000000000000module type bar = { type bar 'f } module bar_f32 = { type bar 'f = f32 } module type foo = { type foo_in val foo: foo_in -> f32 } module foo_f32 = { type foo_in = f32 def foo (x: foo_in): f32 = x } type some_type = i8 module wrapper (bar: bar) (foo: foo with foo_in = bar.bar some_type) = { def baz (x: bar.bar some_type): f32 = foo.foo x } module wrapped = wrapper bar_f32 foo_f32 def main (s: f32): f32 = wrapped.baz s futhark-0.25.27/tests/issue793.fut000066400000000000000000000062441475065116200166250ustar00rootroot00000000000000-- Hit a bug in the tiling logic for splitting the loop prelude. -- types type Sphere = {pos: [3]f32, radius: f32, color: [4]u8} type Intersection = {t: f32, index: i64, prim: u8} -- constants def DROP_OFF = 100f32 -- ray intersection primitive cases def P_NONE = 0:u8 def P_SPHERE = 1:u8 def P_LIGHT = 2:u8 def P_POLYGON = 3:u8 -- render functions: def dot [n] (a: [n]f32) (b: [n]f32): f32 = reduce (+) 0 (map2 (*) a b) def sphereIntersect (rayO: [3]f32) (rayD: [3]f32) (s: Sphere): f32 = let d = map2 (-) s.pos rayO let b = dot d rayD let c = (dot d d) - s.radius * s.radius let disc = b * b - c in if (disc < 0) then DROP_OFF else let t = b - f32.sqrt disc in if (0 < t) then t else DROP_OFF -- render function def render [nspheres] [nlights] (dim: [2]i64) (spheres: [nspheres]Sphere) (lights: [nlights]Sphere) : [][4]u8 = -- return a color for each pixel let pixIndices = iota (dim[0] * dim[1]) in map (\i -> -- for each pixel let coord = [i %% dim[0], i // dim[0]] let rayD: [3]f32 = [f32.i64 dim[0], f32.i64 (coord[0] - dim[0] / 2), f32.i64 (dim[1] / 2 - coord[1])] let rayO: [3]f32 = [0, 0, 0] -- sphere 
intersections let sInts: []Intersection = map3 (\t index prim -> {t, index, prim}) -- using instead of zip to create a record (map (\sphere -> sphereIntersect rayO rayD sphere ) spheres) (iota nspheres) (replicate nspheres P_SPHERE) -- light intersections let lInts: []Intersection = map3 (\t index prim -> {t, index, prim}) (map (\light -> sphereIntersect rayO rayD light ) lights) (iota nlights) (replicate nlights P_LIGHT) -- closest intersection and corresponding primitive index let min: Intersection = reduce (\min x-> if x.t < min.t then x else min ) {t = DROP_OFF, index = 0i64, prim = P_NONE} (concat sInts lInts) -- return color in if (min.prim == P_SPHERE) then (spheres[min.index].color) else if (min.prim == P_LIGHT) then (lights[min.index].color) else [0:u8, 0:u8, 0:u8, 0:u8] ) pixIndices -- entry point def main [s] (width: i64) (height: i64) -- spheres and lights (numS: i64) (numL: i64) (sPositions: [s][3]f32) (sRadii: [s]f32) (sColors: [s][4]u8) -- return pixel color : [][4]u8 = -- combine data for render function let totalS = numS + numL let k = totalS - numS let spheres = map3 (\p r c -> {pos = p, radius = r, color = c}) sPositions[0 : numS] sRadii[0 : numS] sColors[0 : numS] let lights = map3 (\p r c -> {pos = p, radius = r, color = c}) (sPositions[numS : totalS] :> [k][3]f32) (sRadii[numS : totalS] :> [k]f32) (sColors[numS : totalS] :> [k][4]u8) in render [width, height] spheres lights futhark-0.25.27/tests/issue795.fut000066400000000000000000000010511475065116200166160ustar00rootroot00000000000000def main (r_sigma: f32) (I_tiled: [][][]f32) = let nz' = i64.f32 (1/r_sigma + 0.5) let bin v = i64.f32 (v/r_sigma + 0.5) let intensity cell = reduce_by_index (replicate nz' 0) (+) 0 (cell |> map bin) (map ((*256) >-> i64.f32) cell) |> map (f32.i64 >-> (/256)) let count cell = reduce_by_index (replicate nz' 0) (+) 0 (cell |> map bin) (map (const 1) cell) in map2 (map2 zip) (map (map intensity) I_tiled) (map (map count) I_tiled) 
futhark-0.25.27/tests/issue798.fut000066400000000000000000000001551475065116200166250ustar00rootroot00000000000000-- Variables should not be in scope in their own type. module m = { type t = i32 } def main (m: m.t) = m futhark-0.25.27/tests/issue811.fut000066400000000000000000000003261475065116200166070ustar00rootroot00000000000000-- == -- random input { [2][3][4]f32 } auto output def foo [n][m] (A: [n][m]f32): [n][m]f32 = (loop A for _i < n do let irow = A[0] let Ap = A[1:n] in concat Ap [irow]) :> [n][m]f32 def main = map foo futhark-0.25.27/tests/issue812.fut000066400000000000000000000004061475065116200166070ustar00rootroot00000000000000def foo [n] (m: i64) (A: [n][n]i32) = let on_row row i = let padding = replicate n 0 let padding[i] = 10 in concat row padding :> [m]i32 in map2 on_row A (iota n) def main [n] (As: [][n][n]i32) = map (foo (n*2)) As futhark-0.25.27/tests/issue813.fut000066400000000000000000000002701475065116200166070ustar00rootroot00000000000000-- == -- input { 4 } output { [3] } def ilog2 (x: i32) = 31 - i32.clz x def main (n: i32) = let m = ilog2 n let id = 1<<(m-1) let indexes = id-1..id*2-1...n-1 in indexes[1:] futhark-0.25.27/tests/issue814.fut000066400000000000000000000000511475065116200166050ustar00rootroot00000000000000def main (n: i64) = map ((-) n) (iota n) futhark-0.25.27/tests/issue815.fut000066400000000000000000000000571475065116200166140ustar00rootroot00000000000000def main : (i32) = loop _ = 0 for i < 1 do i futhark-0.25.27/tests/issue816.fut000066400000000000000000000011001475065116200166030ustar00rootroot00000000000000-- == -- input { [[[1f32, 2f32, 3f32], [4f32, 5f32, 6f32], [7f32, 8f32, 9f32]]] } -- output { [[[1.0f32, 2.0f32, 3.0f32], [1.0f32, 2.0f32, 3.0f32], [2.0f32, 4.0f32, 6.0f32]]] } def main [m][b] (peri_batch_mat: [m][b][b]f32) = map (\peri_mat -> let mat = copy peri_mat in loop mat for im1 < (b-1) do #[unsafe] let i = im1 + 1 let row_sum = loop row_sum = replicate b 0 for j < i do map2 (+) row_sum mat[j] let mat[i] = 
row_sum in mat ) peri_batch_mat futhark-0.25.27/tests/issue825.fut000066400000000000000000000001131475065116200166060ustar00rootroot00000000000000-- == -- entry: a -- random input { 0u8 } auto output entry a (i: u8) = 0 futhark-0.25.27/tests/issue826.fut000066400000000000000000000002601475065116200166120ustar00rootroot00000000000000def main (xss: [][]i32) = map (\xs -> let sum = i32.sum xs let xs' = copy xs let xs'[0] = sum let xs'[1] = sum in xs') xss futhark-0.25.27/tests/issue829.fut000066400000000000000000000001111475065116200166100ustar00rootroot00000000000000def main (xs: *[1][1]i32) (ys: [1]i32) : *[1][1]i32 = xs with [0] = ys futhark-0.25.27/tests/issue844.fut000066400000000000000000000004231475065116200166130ustar00rootroot00000000000000-- == -- error: Consuming.*"xs" module type mt = { type~ t val mk : i32 -> *t val f : *t -> *t } module m : mt = { type~ t = []i32 def mk (x: i32) = [x] def f (xs: *[]i32) = xs with [0] = xs[0] + 1 } def main (x: i32) = let f = \xs -> m.f xs in f (m.mk x) futhark-0.25.27/tests/issue845.fut000066400000000000000000000003721475065116200166170ustar00rootroot00000000000000module type mt = { type~ t val mk : i32 -> *t val f : *t -> *t } module m : mt = { type~ t = []i32 def mk (x: i32) = [x] def f (xs: *[]i32) = xs with [0] = xs[0] + 1 } def main (x: i32) = let f = \(xs: *m.t) -> m.f xs in f (m.mk x) futhark-0.25.27/tests/issue847.fut000066400000000000000000000002441475065116200166170ustar00rootroot00000000000000-- Tiling bug. 
def main (acc: []i64) (c: i64) (n:i64) = let is = map (+c) (iota n) let fs = map (\i -> reduce (+) 0 (map (+(i+c)) acc)) (iota n) in (fs, is) futhark-0.25.27/tests/issue848.fut000066400000000000000000000020541475065116200166210ustar00rootroot00000000000000type vector = {x:f32, y:f32, z:f32} type triangle = (vector, vector, vector) entry generate_terrain [depth] [width] (points: [depth][width]vector) = let n = width - 1 let n2 = n * 2 let triangles = let m = depth - 1 in map3 (\row0 row1 i -> let (row0', row1') = if i % 2 == 0 then (row0, row1) else (row1, row0) in flatten (map4 (\c0 c0t c1 c1t -> let tri0 = (c0, c0t, c1) let tri1 = (c1, c1t, c0t) in [tri0, tri1]) (row0'[:width-1] :> [n]vector) (row0'[1:] :> [n]vector) (row1'[:width-1] :> [n]vector) (row1'[1:] :> [n]vector)) :> [n2]triangle) (points[:depth-1] :> [m][width]vector) (points[1:] :> [m][width]vector) ((0.. [m]i64) in triangles futhark-0.25.27/tests/issue869.fut000066400000000000000000000006551475065116200166310ustar00rootroot00000000000000-- The "fix" for this in the internaliser was actually a workaround -- for a type checker bug (#1565). -- == -- error: Initial loop values do not have expected type. 
def matmult [n][m][p] (x: [n][m]f32) (y: [m][p]f32) : [n][p]f32 = map (\xr -> map (\yc -> reduce (+) 0 (map2 (*) xr yc) ) (transpose y) ) x def main [n][m] (x: [n][m]f32) : [][]f32 = loop x for i < m-1 do matmult x[1:,:] x[:,1:] futhark-0.25.27/tests/issue872.fut000066400000000000000000000002201475065116200166070ustar00rootroot00000000000000type t = #foo | #bar def main (x:i32) = match #foo : t case #foo -> let xs = filter (>x) [1,2,3] in length xs case #bar -> 0 futhark-0.25.27/tests/issue873.fut000066400000000000000000000001661475065116200166210ustar00rootroot00000000000000type^ myType = #myVal (i32 -> i32) def main = match ((\m -> #myVal (\_ -> m)) 0 : myType) case #myVal m -> m 1 futhark-0.25.27/tests/issue879.fut000066400000000000000000000002551475065116200166260ustar00rootroot00000000000000-- == -- error: Consuming.*"s" def f (xs: [10]i32) : [10]i32 = xs def main (s: [10]i32) : *[10]i32 = let s = f s let s = loop s for _i < 10 do f s in s with [0] = 0 futhark-0.25.27/tests/issue880.fut000066400000000000000000000003401475065116200166110ustar00rootroot00000000000000-- == -- error: Consuming.*"xs" type t = {xs: [10]i32} def f ({xs}: t) : t = {xs = xs} def g (s: t) = let s = f s let s = loop s = f s for _i < 10 do f s in s def main xs = let {xs} = g {xs} in xs with [0] = 0 futhark-0.25.27/tests/issue895.fut000066400000000000000000000001201475065116200166130ustar00rootroot00000000000000entry a = let scan' op ne as = scan op ne as in scan' (+) 0 [] entry b = a futhark-0.25.27/tests/issue921.fut000066400000000000000000000005261475065116200166130ustar00rootroot00000000000000-- == -- structure gpu { SegMap 2 } def main (b1: bool) (b2: bool) (xs: [3]i32) (ys: [3]i32) = map (\x -> if b1 then map (\y -> if b2 then (map (+(x+y)) ys, xs) else (xs, ys)) ys else replicate 3 (ys, xs)) xs |> map unzip |> unzip futhark-0.25.27/tests/issue931.fut000066400000000000000000000001441475065116200166100ustar00rootroot00000000000000type~ g2 = #g2 ([]i32) | #nog def foo2 (x: g2): g2 = 
match x case #nog -> #g2 [] case _ -> x futhark-0.25.27/tests/issue941.fut000066400000000000000000000004441475065116200166140ustar00rootroot00000000000000type sometype 't = #someval t def geni32 (maxsize : i64) : sometype i64 = #someval maxsize def genarr 'elm (genelm: i64 -> sometype elm) (ownsize : i64) : sometype ([ownsize](sometype elm)) = #someval (tabulate ownsize genelm) def main = genarr geni32 1 futhark-0.25.27/tests/issue942.fut000066400000000000000000000003771475065116200166220ustar00rootroot00000000000000-- == -- input {} output { [0i64] } type sometype 't = #someval t def f (size : i64) (_ : i32) : sometype ([size]i64) = #someval (iota size) def apply '^a '^b (f: a -> b) (x: a) = f x def main : [1]i64 = match apply (f 1) 0 case #someval x -> x futhark-0.25.27/tests/issue992.fut000066400000000000000000000001401475065116200166130ustar00rootroot00000000000000-- == -- input { [1,2,3] } output { [1,2,3] } def main [n] (xs: [n]i32) = reverse (reverse xs) futhark-0.25.27/tests/issue995.fut000066400000000000000000000003611475065116200166230ustar00rootroot00000000000000def render (color_fun : i64 -> i32) (h : i64) (w: i64) : []i32 = tabulate h (\i -> color_fun i) def get [n] (arr: [n][n]i32) (i : i64) : i32 = arr[i,i] def main [n] mode (arr: [n][n]i32) = if mode then [] else render (get arr) n n futhark-0.25.27/tests/linear_solve.fut000066400000000000000000000016641475065116200177150ustar00rootroot00000000000000-- Solving a linear system using Gauss-Jordan elimination without pivoting. 
-- -- Taken from https://www.cs.cmu.edu/~scandal/nesl/alg-numerical.html#solve -- -- == -- input { [[1.0f32, 2.0f32, 1.0f32], [2.0f32, 1.0f32, 1.0f32], [1.0f32, 1.0f32, 2.0f32]] -- [1.0f32, 2.0f32, 3.0f32] } -- output { [0.5f32, -0.5f32, 1.5f32] } def Gauss_Jordan [n][m] (A: [n][m]f32): [n][m]f32 = (loop (A) for i < n do let irow = A[0] let Ap = A[1:n] let v1 = irow[i] let irow = map (/v1) irow let Ap = map (\jrow -> let scale = jrow[i] in map2 (\x y -> y - scale * x) irow jrow) Ap in Ap ++ [irow]) :> [n][m]f32 def linear_solve [n][m] (A: [n][m]f32) (b: [n]f32): [n]f32 = -- Pad the matrix with b. let Ap = map2 concat A (transpose [b]) let Ap' = Gauss_Jordan Ap -- Extract last column. in Ap'[0:n,m] def main [n][m] (A: [n][m]f32) (b: [n]f32): [n]f32 = linear_solve A b futhark-0.25.27/tests/localfunction0.fut000066400000000000000000000002421475065116200201420ustar00rootroot00000000000000-- A simple locally defined function. -- == -- input { [1,2,3] } output { [3,4,5] } def main [n] (a: [n]i32) = let add_two (x: i32) = x + 2 in map add_two a futhark-0.25.27/tests/localfunction1.fut000066400000000000000000000003071475065116200201450ustar00rootroot00000000000000-- A simple locally defined function. This one has free variables. -- == -- input { 3 [1,2,3] } output { [4,5,6] } def main [n] (x: i32) (a: [n]i32) = let add_x (y: i32) = x + y in map add_x a futhark-0.25.27/tests/localfunction10.fut000066400000000000000000000003671475065116200202330ustar00rootroot00000000000000-- Polymorphic local function with size annotations - that's easy to -- mess up! -- == -- input { [[1f32]] } output { [[1f32]] } def main (kr: [][]f32): [][]f32 = let f 'a [r][c] (arr: [r][c]a): [r][c]a = flatten arr |> unflatten in f kr futhark-0.25.27/tests/localfunction11.fut000066400000000000000000000002451475065116200202270ustar00rootroot00000000000000-- Local functions should not affect aliasing. 
def main (ops: []i32) (exs: []i32) = let correct op ex = op == ex in ops |> filter (\op -> all (correct op) exs) futhark-0.25.27/tests/localfunction12.fut000066400000000000000000000001741475065116200202310ustar00rootroot00000000000000-- Local function used in operator section. def main (x: i32) (y: i32) (z: i32) = let add x y = x + y + z in x `add` y futhark-0.25.27/tests/localfunction2.fut000066400000000000000000000004021475065116200201420ustar00rootroot00000000000000-- The same name for a local function in two places should not cause -- trouble. -- == -- input { 3 } output { 6 0 } def f1 (x: i32) = let g (y: i32) = x + y in g x def f2 (x: i32) = let g (y: i32) = x - y in g x def main(x: i32) = (f1 x, f2 x) futhark-0.25.27/tests/localfunction3.fut000066400000000000000000000002601475065116200201450ustar00rootroot00000000000000-- Local functions can be shadowed. -- == -- input { 3 } output { 10 } def main(x: i32) = let f (y: i32) = y + 2 let x = f x let f (y: i32) = y * 2 let x = f x in x futhark-0.25.27/tests/localfunction4.fut000066400000000000000000000005371475065116200201550ustar00rootroot00000000000000-- A local function whose closure refers to an array whose size is -- *not* used inside the local function. -- == -- input { 2i64 0 } output { 1i64 } def main(n: i64) (x: i32) = let a = map (1+) (iota n) let f (i: i32) = #[unsafe] a[i] -- 'unsafe' to prevent an assertion -- that uses the array length. in f x futhark-0.25.27/tests/localfunction5.fut000066400000000000000000000004131475065116200201470ustar00rootroot00000000000000-- Shape-bound variables used inside a local function, but where the -- array itself is not used. 
def f(n: i64) = replicate n 0 def main [n] (lower_bounds: [n]f64) = let rs = f n let init_i [n] (rs: [n]i32) = map (\j -> lower_bounds[j]) (iota n) in init_i rs futhark-0.25.27/tests/localfunction6.fut000066400000000000000000000001151475065116200201470ustar00rootroot00000000000000def main(n: i32) = let f (i: i32) = (loop (i) while i < n do i+1) in f 2 futhark-0.25.27/tests/localfunction7.fut000066400000000000000000000002621475065116200201530ustar00rootroot00000000000000-- A local function can refer to a global variable. -- == -- input { 1 } output { 5 } def two = 2 def main (x: i32) = let add_two (y: i32) = y + two in add_two (add_two x) futhark-0.25.27/tests/localfunction8.fut000066400000000000000000000003621475065116200201550ustar00rootroot00000000000000-- It is OK for a local function to alias a local array. -- == -- input { true } output { [1,2,3] } -- input { false } output { empty([0]i32) } def main b = let global: []i32 = [1,2,3] let f (b: bool) = if b then global else [] in f b futhark-0.25.27/tests/localfunction9.fut000066400000000000000000000003071475065116200201550ustar00rootroot00000000000000-- Test some subtle things about aliasing when abstract types are involved. def divergence 'real (op: real -> real -> real) (c0: real) = let next (c, i) = ((c `op` c), i + 1) in next (c0, 0i32) futhark-0.25.27/tests/loops/000077500000000000000000000000001475065116200156405ustar00rootroot00000000000000futhark-0.25.27/tests/loops/for-in0.fut000066400000000000000000000001751475065116200176350ustar00rootroot00000000000000-- Basic for-in loop. -- == -- input { [1,2,3,4,5] } -- output { 15 } def main(xs: []i32) = loop a=0 for x in xs do a + x futhark-0.25.27/tests/loops/for-in1.fut000066400000000000000000000003001475065116200176240ustar00rootroot00000000000000-- For-in loop where iota should be optimised away. 
-- == -- input { 5i64 } -- output { 4i64 } -- structure { Iota 0 } def main(n: i64) = let xs = iota n in loop a=0 for x in xs do a ^ x futhark-0.25.27/tests/loops/for-in2.fut000066400000000000000000000003271475065116200176360ustar00rootroot00000000000000-- For-in loop where replicate should be optimised away. -- == -- input { 5i64 } -- output { 99i64 } -- structure { Replicate 0 } def main(n: i64) = let xs = replicate n n in loop a=0 for x in xs do (a<<1) ^ x futhark-0.25.27/tests/loops/for-in3.fut000066400000000000000000000003441475065116200176360ustar00rootroot00000000000000-- For-in loop where map and iota should be optimised away. -- == -- input { 5i64 } -- output { 2i64 } -- structure { Iota 0 Map 0 } def main(n: i64) = let xs = map (2*) (map (1+) (iota n)) in loop a=0 for x in xs do a ^ x futhark-0.25.27/tests/loops/for-in4.fut000066400000000000000000000002151475065116200176340ustar00rootroot00000000000000-- For-in over 2D array. -- == -- input { [[1],[2],[3]] } -- output { 6 } def main (xss: [][]i32) = loop a = 0 for xs in xss do a + xs[0] futhark-0.25.27/tests/loops/loop-error0.fut000066400000000000000000000006201475065116200205360ustar00rootroot00000000000000-- == -- error: aliases previously returned value def main(): ([]f64,[][]f64) = let e_rows = [] let arr = copy(e_rows) let acc = copy([1.0]) in loop (acc, arr) = (acc, arr) for i < length arr do let arr[i] = let y = arr[i] let x = acc in [2.0] -- Error, because 'arr[i]' and 'arr' are aliased, yet the latter -- is consumed. in (arr[i], arr) futhark-0.25.27/tests/loops/loop0.fut000066400000000000000000000002561475065116200174140ustar00rootroot00000000000000-- Simplest interesteing loop - factorial function. 
-- == -- input { -- 10 -- } -- output { -- 3628800 -- } def main(n: i32): i32 = loop x = 1 for i < n do x * (i + 1) futhark-0.25.27/tests/loops/loop1.fut000066400000000000000000000004411475065116200174110ustar00rootroot00000000000000-- == -- input { -- } -- output { -- [1, 5, 9] -- } def main: []i32 = let arr = [(0,1), (2,3), (4,5)] let n = length arr let outarr = replicate n (0,0) in let outarr = loop outarr = outarr for i < n do let outarr[i] = arr[i] in outarr in map (\(x,y) -> x + y) outarr futhark-0.25.27/tests/loops/loop10.fut000066400000000000000000000010111475065116200174630ustar00rootroot00000000000000-- This program exposed an obscure problem in the type-checking of -- loops with existential sizes (before inlining). def step [n] (cost: *[n]i32, updating_graph_mask : *[n]bool) : (*[n]i32, *[n]bool) = (cost, updating_graph_mask) def main (n: i32, cost: *[]i32, updating_graph_mask: *[]bool) = loop (cost, updating_graph_mask) = (cost, updating_graph_mask) while updating_graph_mask[0] do let (cost', updating_graph_mask') = step(cost, updating_graph_mask) in (cost', updating_graph_mask') futhark-0.25.27/tests/loops/loop11.fut000066400000000000000000000004641475065116200174770ustar00rootroot00000000000000-- Tests a loop that simplification once messed up royally. -- == -- input { 50 1.1 } -- output { 14 } def p(c: f64): bool = c < 4.0 def f(x: f64, y: f64): f64 = x * y def main (depth: i32) (a: f64): i32 = let (c,i) = loop (c, i) = (a, 0) while i < depth && p(c) do (f(a, c), i + 1) in i futhark-0.25.27/tests/loops/loop12.fut000066400000000000000000000006601475065116200174760ustar00rootroot00000000000000-- This loop is interesting because only the final size is actually -- used. The simplifier can easily mess this up. Contrived? Yes, -- but code generators sometimes do this. 
-- -- == -- input { 0 [1] } output { 1i64 } -- input { 1 [1] } output { 2i64 } -- input { 2 [1] } output { 4i64 } -- input { 3 [1] } output { 8i64 } def main (n: i32) (as: []i32): i64 = let as = loop (as) for _i < n do concat as as in length as futhark-0.25.27/tests/loops/loop13.fut000066400000000000000000000004431475065116200174760ustar00rootroot00000000000000-- Loops can use any signed integral type def main (x: i8): i64 = let x = loop x = x for i < 0x80i8 do x + i let x = loop x = i16.i8 x for i < 0x80i16 do x + i let x = loop x = i32.i16 x for i < 0x80i32 do x + i let x = loop x = i64.i32 x for i < 0x80i64 do x + i in x futhark-0.25.27/tests/loops/loop14.fut000066400000000000000000000002651475065116200175010ustar00rootroot00000000000000-- We should be able to handle a for-loop with negative bound. -- == -- input { 1 } output { 1 } -- input { -1 } output { 0 } def main (n: i32) = loop x = 0i32 for _i < n do x + 1 futhark-0.25.27/tests/loops/loop15.fut000066400000000000000000000003051475065116200174750ustar00rootroot00000000000000-- Simple case; simplify away the loops. -- == -- input { 10 2 } output { 2 } -- structure { Loop 0 } def main (n: i32) (a: i32) = loop x = a for _i < n do loop _y = x for _j < n do a futhark-0.25.27/tests/loops/loop16.fut000066400000000000000000000004651475065116200175050ustar00rootroot00000000000000-- Complex case; simplify away the loops. -- == -- input { 10 2i64 [1,2,3] } -- output { [1,2] } -- structure { Loop 0 } def main (n: i32) (a: i64) (arr: []i32) = #[unsafe] -- Just to make the IR cleaner. loop x = take a arr for _i < n do loop _y = take (length x) arr for _j < n do take a arr futhark-0.25.27/tests/loops/loop18.fut000066400000000000000000000002631475065116200175030ustar00rootroot00000000000000-- Constant-forming a loop with a non-default type. 
-- == -- input { 10i16 } -- output { 100i16 } -- structure { Loop 0 } def main (x: i16) = loop acc = 0 for i < x do acc + x futhark-0.25.27/tests/loops/loop2.fut000066400000000000000000000002611475065116200174120ustar00rootroot00000000000000-- A loop that doesn't involve in-place updates or even arrays. -- == -- input { -- 42 -- } -- output { -- 861 -- } def main(n: i32): i32 = loop x = 0 for i < n do x + i futhark-0.25.27/tests/loops/loop3.fut000066400000000000000000000004271475065116200174170ustar00rootroot00000000000000-- == -- input { -- 42i64 -- } -- output { -- 820i64 -- } def main(n: i64): i64 = let a = iota(1) in let a = loop a for i < n do let b = replicate n 0 in -- Error if hoisted outside loop. loop c=b for j < i do let c[0] = c[0] + j in c in a[0] futhark-0.25.27/tests/loops/loop4.fut000066400000000000000000000005301475065116200174130ustar00rootroot00000000000000-- Nasty loop whose size cannot be predicted in advance. -- == -- input { -- [1,2,3] -- 4 -- } -- output { -- [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3] -- } def main (xs: []i32) (n: i32): []i32 = loop (xs) for i < n do concat xs xs futhark-0.25.27/tests/loops/loop5.fut000066400000000000000000000003661475065116200174230ustar00rootroot00000000000000-- == -- input { -- } -- output { -- [0i64, 1i64, 3i64, 6i64, 10i64, 15i64, 21i64, 28i64, 36i64, 45i64] -- } def main: []i64 = let n = 10 let x = iota(n) in loop (x) for i < n-1 do let x[i+1] = x[i+1] + x[i] in x futhark-0.25.27/tests/loops/loop6.fut000066400000000000000000000011561475065116200174220ustar00rootroot00000000000000-- This loop nest was derived from an LU factorisation program, and -- exposed a bug in a simplification rule. It does not compute -- anything interesting. 
-- -- Specifically, the bug was in the detection of loop-invariant -- variables - an array might be considered loop-invariant, even -- though some of its existential parameters (specifically shape -- arguments) are not considered loop-invariant (due to missing copy -- propagation). -- == def main [n] (a: *[n][]f64, u: *[][]f64): ([][]f64, [][]f64) = loop ((a,u)) for k < n do let u[k,k] = a[k,k] in let a = loop (a) for i < n-k do a in (a,u) futhark-0.25.27/tests/loops/loop7.fut000066400000000000000000000007171475065116200174250ustar00rootroot00000000000000-- This loop is intended to trigger a bug in the in-place-lowering -- optimisation. It requires proper maintaining of the loop result -- ordering. -- == def main(n: i64, i: i32, x: f64): [][]f64 = let res = replicate n (replicate n 0.0) let (u, uu) = (replicate n 0.0, replicate n 0.0) in let (u,x) = loop ((u, x)) for i < n-1 do let y = x + 1.0 let u[i] = u[i] * y in (u, y) let res[i] = u in res futhark-0.25.27/tests/loops/loop8.fut000066400000000000000000000003551475065116200174240ustar00rootroot00000000000000-- Test that we do not clobber loop merge parameters while they may -- still be used. -- == -- input { 1 2 0 3 } -- output { 2 1 3 } def main(x: i32) (y: i32) (z: i32) (n: i32): (i32, i32, i32) = loop ((x,y,z)) for i < n do (y,x,z+1) futhark-0.25.27/tests/loops/loop9.fut000066400000000000000000000002211475065116200174150ustar00rootroot00000000000000-- Test that we can remove a single-iteration loop. -- == -- structure { Loop 0 } def main(x: i32, y: i32): i32 = loop (x) for i < 1 do x + y futhark-0.25.27/tests/loops/pow2reduce.fut000066400000000000000000000007121475065116200204370ustar00rootroot00000000000000-- Tree reduction that only works on input that is a power of two. 
-- == -- input { [1,2,3,4] } -- output { 10 } def step [k] (xs: [2**k]i32) : [2**(k-1)]i32 = tabulate (2**(k-1)) (\i -> xs[i*2] + xs[i*2+1]) def sum [k] (xs: [2**k]i32) : i32 = head (loop xs for i in reverse (iota k) do step (xs :> [2**(i+1)]i32)) def ilog2 (n: i64) : i64 = i64.i32 (63 - i64.clz n) def main [n] (xs: [n]i32) = let k = ilog2 n in sum (xs :> [2**k]i32) futhark-0.25.27/tests/loops/while-loop0.fut000066400000000000000000000002511475065116200205150ustar00rootroot00000000000000-- Simplest test of while loops. -- == -- input { -- 1 -- 9 -- } -- output { -- 16 -- } def main (x: i32) (bound: i32): i32 = loop (x) while x < bound do x * 2 futhark-0.25.27/tests/loops/while-loop1.fut000066400000000000000000000004351475065116200205220ustar00rootroot00000000000000-- Test a while loop that has an array merge variable and checks it in -- its condition. -- == -- input { -- [1,2,3,4,5,6] -- 3 -- 10 -- } -- output { -- [7, 8, 9, 10, 11, 12] -- } def main (a: []i32) (i: i32) (bound: i32): []i32 = loop (a) while a[i] < bound do map (+1) a futhark-0.25.27/tests/loops/while-loop2.fut000066400000000000000000000006451475065116200205260ustar00rootroot00000000000000-- While-loop with a condition that consumes something that it has allocated itself. 
-- == -- input { -- [5i64,4i64,2i64,8i64,1i64,9i64,9i64] -- 4i64 -- } -- output { -- [5i64, 4i64, 2i64, 8i64, 6i64, 9i64, 9i64] -- } def pointlessly_consume(x: i64, a: *[]i64): bool = x < reduce (+) 0 a def main (a: *[]i64) (i: i64): []i64 = loop (a) while pointlessly_consume(a[i], iota(i)) do let a[i] = a[i] + 1 in a futhark-0.25.27/tests/loops/while-loop3.fut000066400000000000000000000002031475065116200205150ustar00rootroot00000000000000-- == -- input { 1 } output { 11 false } def main (x: i32) = loop (x, continue) = (x, true) while continue do (x+1, x < 10) futhark-0.25.27/tests/lss.fut000066400000000000000000000024101475065116200160220ustar00rootroot00000000000000-- Parallel longest satisfying segment -- -- Written as a function that is parameterisered over the satisfaction -- property. Cannot handle empty input arrays. -- -- == -- input { [1, -2, -2, 3, 4, -1, 5, -6, 1] } -- output { 4 } -- input { [5, 4, 3, 2, 1] } -- output { 1 } -- input { [1, 2, 3, 4, 5] } -- output { 5 } -- The two relations must describe a transitive relation. 
def lss [n] 't (pred1: t -> bool) (pred2: t -> t -> bool) (xs: [n]t) = let max = i32.max let redOp (lssx, lisx, lcsx, tlx, firstx, lastx) (lssy, lisy, lcsy, tly, firsty, lasty) = let connect = pred2 lastx firsty || tlx == 0 || tly == 0 let newlss = if connect then max (lcsx + lisy) (max lssx lssy) else max lssx lssy let newlis = if lisx == tlx && connect then lisx + lisy else lisx let newlcs = if lcsy == tly && connect then lcsy + lcsx else lcsy let first = if tlx == 0 then firsty else firstx let last = if tly == 0 then lastx else lasty in (newlss, newlis, newlcs, tlx+tly, first, last) let mapOp x = let xmatch = if pred1 x then 1 else 0 in (xmatch, xmatch, xmatch, 1, x, x) in (reduce redOp (0,0,0,0,xs[0],xs[0]) (map mapOp xs)).0 def main (xs: []i32): i32 = lss (\_ -> true) (\(x: i32) y -> x <= y) xs futhark-0.25.27/tests/lu-factorisation.fut000066400000000000000000000015521475065116200205120ustar00rootroot00000000000000-- Compute LU-factorisation of matrix. -- == -- input { -- [[4.0,3.0],[6.0,3.0]] -- } -- output { -- [[1.000000, 0.000000], -- [1.500000, 1.000000]] -- [[4.000000, 3.000000], -- [0.000000, -1.500000]] -- } def lu_inplace [n] (a: *[n][]f64): (*[][]f64, *[][]f64) = let (_,l,u) = loop (a,l,u) = (a, replicate n (replicate n 0.0), replicate n (replicate n 0.0)) for k < n do let u[k,k] = a[k,k] in let (l,u) = loop (l,u) for i < n-k do let l[i+k,k] = a[i+k,k]/u[k,k] let u[k,i+k] = a[k,i+k] in (l,u) let a = loop a for i < n-k do loop a for j < n-k do let a[i+k,j+k] = a[i+k,j+k] - l[i+k,k] * u[k,j+k] in a in (a,l,u) in (l,u) def main(a: [][]f64): ([][]f64, [][]f64) = lu_inplace(copy(a)) futhark-0.25.27/tests/man/000077500000000000000000000000001475065116200152575ustar00rootroot00000000000000futhark-0.25.27/tests/man/bench/000077500000000000000000000000001475065116200163365ustar00rootroot00000000000000futhark-0.25.27/tests/man/bench/example1.fut000066400000000000000000000004561475065116200205770ustar00rootroot00000000000000-- How quickly can we reduce arrays? 
-- -- == -- nobench input { 0i64 } -- output { 0i64 } -- input { 100i64 } -- output { 4950i64 } -- compiled input { 10000i64 } -- output { 49995000i64 } -- compiled input { 1000000i64 } -- output { 499999500000i64 } def main(n: i64): i64 = reduce (+) 0 (iota n) futhark-0.25.27/tests/man/test/000077500000000000000000000000001475065116200162365ustar00rootroot00000000000000futhark-0.25.27/tests/man/test/example1.fut000066400000000000000000000003171475065116200204730ustar00rootroot00000000000000-- Test simple indexing of an array. -- == -- tags { firsttag secondtag } -- input { [4,3,2,1] 1i64 } -- output { 3 } -- input { [4,3,2,1] 5i64 } -- error: Error* def main (a: []i32) (i: i64): i32 = a[i] futhark-0.25.27/tests/man/test/example2.fut000066400000000000000000000004141475065116200204720ustar00rootroot00000000000000def add (x: i32) (y: i32): i32 = x + y -- Test the add1 function. -- == -- entry: add1 -- input { 1 } output { 2 } entry add1 (x: i32): i32 = add x 1 -- Test the sub1 function. -- == -- entry: sub1 -- input { 1 } output { 0 } entry sub1 (x: i32): i32 = add x (-1) futhark-0.25.27/tests/man/test/example3.fut000066400000000000000000000002331475065116200204720ustar00rootroot00000000000000-- == -- random input { [100]i32 [100]i32 } auto output -- random input { [1000]i32 [1000]i32 } auto output def main xs ys = i32.product (map2 (*) xs ys) futhark-0.25.27/tests/manifest.fut000066400000000000000000000003211475065116200170260ustar00rootroot00000000000000-- Test that manifest is not optimised away. -- == -- input { [[1,2,3], [4,5,6], [7,8,9]] } output { 8 } -- structure { Manifest 1 } entry main (xs: [][]i32) = let ys = manifest (transpose xs) in ys[1,2] futhark-0.25.27/tests/manylet.fut000066400000000000000000000004751475065116200167030ustar00rootroot00000000000000-- Parser test. 'in' is optional except at the end of a chain of -- let-bindings. 
def main [n] (a: *[n]i32, x: i32): [n]i32 = let y = x + 2 let z = y + 3 + x let (a,_) = loop ((a,z)) for i < n do let tmp = z * z let a[i] = tmp let x = [a[i]-1] let b = scatter a [i] x in (b, tmp+2) in a futhark-0.25.27/tests/map_tridag_par.fut000066400000000000000000000123501475065116200201760ustar00rootroot00000000000000-- A map of a parallel tridag. This is intended to test -- parallelisation of the inner scans and maps. The real test for -- this is LocVolCalib. -- -- == -- compiled input { 1000i64 256i64 } -- -- output { [0.010000f32, 0.790000f32, 2.660000f32, -- 21474836.000000f32, 21474836.000000f32, 21474836.000000f32, -- 21474836.000000f32, 21474836.000000f32, 21474836.000000f32, -- 5625167.000000f32] } -- -- no_python compiled input { 100i64 2560i64 } -- -- output { [0.000000f32, 0.120000f32, 0.260000f32, 0.430000f32, -- 0.620000f32, 0.840000f32, 1.110000f32, 1.440000f32, 1.840000f32, -- 2.360000f32] } -- -- no_python compiled input { 10i64 25600i64 } -- -- output { [0.000000f32, 0.110000f32, 0.250000f32, 0.410000f32, -- 0.590000f32, 0.800000f32, 1.040000f32, 1.340000f32, 1.710000f32, -- 2.170000f32] } def tridagPar [n] (a: [n]f32, b: []f32, c: []f32, y: []f32 ): *[n]f32 = ---------------------------------------------------- -- Recurrence 1: b[i] = b[i] - a[i]*c[i-1]/b[i-1] -- -- solved by scan with 2x2 matrix mult operator -- ---------------------------------------------------- let b0 = b[0] let mats = map (\(i: i32): (f32,f32,f32,f32) -> if 0 < i then (b[i], 0.0-a[i]*c[i-1], 1.0, 0.0) else (1.0, 0.0, 0.0, 1.0) ) (map i32.i64 (iota n)) let scmt = scan (\(a: (f32,f32,f32,f32)) (b: (f32,f32,f32,f32)): (f32,f32,f32,f32) -> let (a0,a1,a2,a3) = a let (b0,b1,b2,b3) = b let value = 1.0/(a0*b0) in ( (b0*a0 + b1*a2)*value, (b0*a1 + b1*a3)*value, (b2*a0 + b3*a2)*value, (b2*a1 + b3*a3)*value ) ) (1.0, 0.0, 0.0, 1.0) mats let b = map (\(tup: (f32,f32,f32,f32)): f32 -> let (t0,t1,t2,t3) = tup in (t0*b0 + t1) / (t2*b0 + t3) ) scmt 
------------------------------------------------------ -- Recurrence 2: y[i] = y[i] - (a[i]/b[i-1])*y[i-1] -- -- solved by scan with linear func comp operator -- ------------------------------------------------------ let y0 = y[0] let lfuns= map (\(i: i32): (f32,f32) -> if 0 < i then (y[i], 0.0-a[i]/b[i-1]) else (0.0, 1.0 ) ) (map i32.i64 (iota n)) let cfuns= scan (\(a: (f32,f32)) (b: (f32,f32)): (f32,f32) -> let (a0,a1) = a let (b0,b1) = b in ( b0 + b1*a0, a1*b1 ) ) (0.0, 1.0) lfuns let y = map (\(tup: (f32,f32)): f32 -> let (a,b) = tup in a + b*y0 ) cfuns ------------------------------------------------------ -- Recurrence 3: backward recurrence solved via -- -- scan with linear func comp operator -- ------------------------------------------------------ let yn = y[n-1]/b[n-1] let lfuns= map (\(k: i32): (f32,f32) -> let i = i32.i64 n-k-1 in if 0 < k then (y[i]/b[i], 0.0-c[i]/b[i]) else (0.0, 1.0 ) ) (map i32.i64 (iota n)) let cfuns= scan (\(a: (f32,f32)) (b: (f32,f32)): (f32,f32) -> let (a0,a1) = a let (b0,b1) = b in (b0 + b1*a0, a1*b1) ) (0.0, 1.0) lfuns let y = map (\(tup: (f32,f32)): f32 -> let (a,b) = tup in a + b*yn ) cfuns let y = map (\i: f32 -> y[n-i-1]) (iota n) in y def map_tridag_par [inner][outer] (myD: [inner][3]f32, myDD: [inner][3]f32, myMu: [outer][inner]f32, myVar: [outer][inner]f32, u: [outer][inner]f32, dtInv: f32 ): *[][]f32 = map3 (\mu_row var_row u_row -> let (a,b,c) = unzip3 (map4 (\mu var d dd: (f32,f32,f32) -> ( 0.0 - 0.5*(mu*d[0] + 0.5*var*dd[0]) , dtInv - 0.5*(mu*d[1] + 0.5*var*dd[1]) , 0.0 - 0.5*(mu*d[2] + 0.5*var*dd[2]) ) ) mu_row var_row myD myDD) in tridagPar( a, b, c, u_row ) ) myMu myVar u -- To avoid floating-point jitter. 
def trunc2dec (x: f32) = f32.abs (f32.i32 (i32.f32 (x*100.0))/100.0) def main (outer: i64) (inner: i64) = let myD = replicate inner [0.10, 0.20, 0.30] let myDD = replicate inner [0.20, 0.30, 0.40] let scale (s: i64) (x: i64) = f32.i64 (s+x) / f32.i64 inner let scale_row (s: i64) (i: i64) (row: [inner]i64) = map (scale (s+i)) row let myMu = map2 (scale_row 1) (iota outer) (replicate outer (iota inner)) let myVar = map2 (scale_row 2) (iota outer) (replicate outer (iota inner)) let u = map2 (scale_row 3) (iota outer) (replicate outer (iota inner)) let dtInv = 0.8874528f32 let res = map_tridag_par (myD, myDD, myMu, myVar, u, dtInv) in map (\i -> trunc2dec (res[i*(outer/10), i*(inner/10)])) (iota 10) futhark-0.25.27/tests/mapconcat.fut000066400000000000000000000007521475065116200171750ustar00rootroot00000000000000-- Mapping a concatenation is turned into a single concat. -- == -- input { [[1,2],[4,5],[7,8]] [[3],[6],[9]] [[3],[6],[9]] } -- output { [[1,2,3,3],[4,5,6,6],[7,8,9,9]] } -- -- input { [[1,2],[4,5],[7,8]] [[3,2,1],[6,5,4],[9,8,7]] [[0],[3],[6]] } -- output { [[1,2,3,2,1,0],[4,5,6,5,4,3],[7,8,9,8,7,6]] } -- structure { Map 0 Concat 1 } def main [a][b][c] (xss: [][a]i32) (yss: [][b]i32) (zss: [][c]i32) = let n = a + b + c in map3 (\xs ys zs -> xs ++ ys ++ zs :> [n]i32) xss yss zss futhark-0.25.27/tests/mapiota.fut000066400000000000000000000001551475065116200166570ustar00rootroot00000000000000-- iota cannot be mapped. -- == -- error: type containing anonymous sizes def main(ns: []i64) = map iota ns futhark-0.25.27/tests/mapmatmultfun.fut000066400000000000000000000017231475065116200201210ustar00rootroot00000000000000-- Mapping matrix multiplication, written in a functional style. -- This is primarily a test of tiling. 
-- == -- input { -- [[ [1,2], [3,4] ], -- [ [5,6], [7,8] ]] -- [[ [1,2], [3,4] ], -- [ [5,6], [7,8] ]] -- } -- output { -- [[[7i32, 10i32], -- [15i32, 22i32]], -- [[67i32, 78i32], -- [91i32, 106i32]]] -- } -- compiled random input { [1][16][16]i32 [1][16][16]i32 } auto output -- compiled random input { [2][16][32]i32 [2][32][16]i32 } auto output -- compiled random input { [3][32][16]i32 [3][16][32]i32 } auto output -- compiled random input { [4][128][17]i32 [4][17][128]i32 } auto output -- structure { /Screma 1 /Screma/Screma 1 /Screma/Screma/Screma 1 /Screma/Screma/Screma/Screma 1 } def matmult [n][m][p] (x: [n][m]i32) (y: [m][p]i32): [n][p]i32 = map (\xr -> map (\yc -> reduce (+) 0 (map2 (*) xr yc)) (transpose y)) x def main [k][n][m][p] (xs: [k][n][m]i32) (ys: [k][m][p]i32): [k][n][p]i32 = map2 matmult xs ys futhark-0.25.27/tests/mapreplicate.fut000066400000000000000000000002371475065116200176740ustar00rootroot00000000000000-- replicate can be mapped. -- == -- input { 2i64 [true,false] } output { [[true,true],[false,false]] } def main (n: i64) (xs: []bool) = map (replicate n) xs futhark-0.25.27/tests/mapslice.fut000066400000000000000000000004251475065116200170220ustar00rootroot00000000000000-- == -- input { 2i64 [1,2,3,4,5,6,7,8,9] } -- output { [[1i32, 2i32, 3i32], [3i32, 4i32, 5i32]] } -- structure gpu { SegMap 1 } def main (n: i64) (xs: []i32) = tabulate n (\i -> let ys = #[unsafe] xs[i:i+3] :> [3]i32 in map (+i32.i64 i) ys) futhark-0.25.27/tests/matmultimp.fut000066400000000000000000000011011475065116200174060ustar00rootroot00000000000000-- Matrix multiplication written imperatively. Very slow when using -- GPU backends. 
-- -- == -- input { -- [ [1,2], [3,4] ] -- [ [5,6], [7,8] ] -- } -- output { -- [ [ 19 , 22 ] , [ 43 , 50 ] ] -- } def matmult [m][o][n] (a: [m][o]i32, b: [o][n]i32): [m][n]i32 = let res = replicate m (replicate n 0) in loop res for i < m do loop res for j < n do let partsum = loop partsum = 0 for k < o do partsum + a[i,k] * b[k,j] let res[i,j] = partsum in res def main (x: [][]i32) (y: [][]i32): [][]i32 = matmult(x, y) futhark-0.25.27/tests/matmultrepa.fut000066400000000000000000000022151475065116200175570ustar00rootroot00000000000000-- Matrix multiplication written in a Repa-like style. -- == -- input { -- [ [1,2], [3,4] ] -- [ [5,6], [7,8] ] -- } -- output { -- [ [ 19 , 22 ] , [ 43 , 50 ] ] -- } -- compiled random input { [16][16]i32 [16][16]i32 } auto output -- compiled random input { [16][32]i32 [32][16]i32 } auto output -- compiled random input { [32][16]i32 [16][32]i32 } auto output -- compiled random input { [31][32]i32 [32][31]i32 } auto output -- compiled random input { [128][17]i32 [17][128]i32 } auto output -- structure { /Screma 1 /Screma/Screma 1 Screma/Screma/Screma 1 } def redplus1(a: []i32): i32 = reduce (+) 0 a def redplus2 [n][m] (a: [n][m]i32): [n]i32 = map redplus1 a def mul1 [m] (a: [m]i32, b: [m]i32): [m]i32 = map2 (*) a b def mul2 [n][m] (a: [n][m]i32, b: [n][m]i32): [n][m]i32 = map mul1 (zip a b) def replin [m] (n: i64) (a: [m]i32): [n][m]i32 = replicate n a def matmultFun [n][m] (a: [n][m]i32, b: [m][n]i32 ): [n][n]i32 = let br = replicate n (transpose b) let ar = map (replin n) a let abr = map mul2 (zip ar br) in map redplus2 abr def main [n][m] (x: [n][m]i32) (y: [m][n]i32): [n][n]i32 = matmultFun(x, y) 
futhark-0.25.27/tests/memory-block-merging/000077500000000000000000000000001475065116200205325ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/coalescing/000077500000000000000000000000001475065116200226415ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/coalescing/chain/000077500000000000000000000000001475065116200237235ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/coalescing/chain/blk-chain2.fut000066400000000000000000000021661475065116200263620ustar00rootroot00000000000000-- Example of chaining coalescing inside a block -- == -- input { 0 -- 0 -- 0 -- [1, 2, 3, 4] -- [[1, 2, 4, 5], [3, 4, 5 ,6]] -- [ [ [1, 2, 3, 4], [3, 4, 5, 6] ] -- , [ [5, 6, 7, 8], [7, 8, 9, 9] ] -- , [ [9, 9, 9, 9], [8, 8, 8, 8] ] -- , [ [7, 7, 7, 7], [6, 6, 6, 6] ] -- ] -- } -- output { [ [ [1, 2, 3, 4], [3, 4, 5, 6] ] -- , [ [5, 6, 7, 8], [7, 8, 9, 9] ] -- , [ [9, 9, 9, 9], [8, 8, 8, 8] ] -- , [ [1, 2, 3, 4], [4, 5, 6, 7] ] -- ] -- } -- structure seq-mem { Alloc 2 } -- structure gpu-mem { Alloc 3 } -- Needs two-dimensional overlap checking let main [m] [n] (i1: i32) (i2: i32) (k: i32) (a: [n]i32) (v: [m][n]i32) (z: *[n][m][n]i32) : *[n][m][n]i32 = let u = map (\x -> map (+1) x) v let b = map (+i1) a let ind1 = z[k,i1,i2] - b[0] let u[ind1] = b -- This should not coalesce let c = map (+i2) a let u[i1+i2] = c -- Coalescing -- let z[i1+i2+k] = u let z[k + i32.i64 m + 1] = u -- Coalescing in z futhark-0.25.27/tests/memory-block-merging/coalescing/chain/blk-chain3.fut000066400000000000000000000015401475065116200263560ustar00rootroot00000000000000-- Example of chaining coalescing inside a block -- == -- input { 0 -- 0 -- 0 -- [1, 2] -- [[1, 2], [3, 4], [5, 6]] -- [ [ [1, 2], [3, 4], [5, 6] ] -- , [ [5, 6], [7, 8], [9, 0] ] ] -- } -- output { [ [ [2, 3], [1, 2], [1, 2] ] -- , [ [5, 6], [7, 8], [9, 0] ] -- ] -- } -- structure seq-mem { Alloc 2 } -- structure gpu-mem { Alloc 1 } -- This is blk-chain2.fut with the alternative 
ind1 let main [m] [n] (i1: i32) (i2: i32) (k: i32) (a: [n]i32) (v: [m][n]i32) (z: *[n][m][n]i32) : *[n][m][n]i32 = let u = map (\x -> map (+1) x) v let b = map (+i1) a let ind1 = i1 + 1 let u[ind1] = b -- Coalescing. let c = map (+i2) a let u[i1+2] = c -- Coalescing let z[i1+i2+k] = u in z futhark-0.25.27/tests/memory-block-merging/coalescing/chain/blk-chain4.fut000066400000000000000000000013111475065116200263530ustar00rootroot00000000000000-- Test coalescing with 'Index' index functions where the first slice dimension -- index is fixed ('i' in the code below). This needs to be handled separately -- from 'DimSlice'. -- == -- input { [0, 3, 5] } -- output { [[ 0, 0, 0], -- [ 6, 6, 6], -- [10, 10, 10], -- [10, 10, 10], -- [13, 13, 13], -- [15, 15, 15]] -- } -- structure seq-mem { Alloc 3 } -- structure gpu-mem { Alloc 5 } -- The GPU pipeline has additional allocations for the two 'replicate' -- expressions. let main [n] (a: [n]i32): [][n]i32 = let x = map (\i -> replicate n (i + 10)) a |> opaque let a2 = map (\i -> replicate n (2 * i)) a let y = concat a2 x in y futhark-0.25.27/tests/memory-block-merging/coalescing/chain/chain0.fut000066400000000000000000000011311475065116200256010ustar00rootroot00000000000000-- Memory block merging with a chain of two copies (the second copy is -- technically a concat, but mostly acts as a copy). Requires allocation -- hoisting. -- == -- input { [7, 0, 7] } -- output { [7, 0, 7, 8, 1, 8] } -- structure seq-mem { Alloc 1 } -- structure gpu-mem { Alloc 1 } let main [n] (ns: [n]i32): []i32 = -- Will initially be set to use the memory of t1. Will end up using the -- memory of t2 through t1. let t0 = map (+ 1) ns -- Will use the second part of the memory of t2. let t1 = copy t0 -- Will be the only remaining memory block. 
let t2 = concat ns t1 in t2 futhark-0.25.27/tests/memory-block-merging/coalescing/chain/chain1.fut000066400000000000000000000015651475065116200256150ustar00rootroot00000000000000-- Here is a chain of two in-place coalescings. The compiler needs to keep -- track of both slices to properly coalesce everything. -- == -- input { [2, 4] -- 0 -- 1 -- } -- output { [[[2, 2], -- [3, 5]], -- [[1, 1], -- [1, 1]]] -- } -- structure seq-mem { Alloc 1 } -- structure gpu-mem { Alloc 1 } let main [n] (ns: [n]i32) (i: i32) (j: i32): [n][n][n]i32 = -- Create arrays into which other arrays can be coalesced. Two allocations, -- but both will use the same allocation after coalescing. let wsss = replicate n (replicate n (replicate n 1)) let vss = replicate n (replicate n 2) -- Create the "base" array. let xs = map (+ 1) ns -- xs can be coalesced into vss[j]. let vss[j] = xs -- vss (and thereby xs) can be coalesced into wsss[i]. let wsss[i] = vss in wsss futhark-0.25.27/tests/memory-block-merging/coalescing/chain/chain2.fut000066400000000000000000000025611475065116200256130ustar00rootroot00000000000000-- Here is a chain of potential coalescings, but only the first one can work, so -- we can only remove one allocation. -- == -- input { [[[1, 1], [1, 1]], [[1, 1], [1, 1]]] -- [2, 4] -- 0 -- 1 -- } -- output { [[[2, 2], [3, 5]], [[2, 2], [2, 2]]] -- [[[5, 5], [5, 5]], [[5, 5], [5, 5]]] -- } -- structure seq-mem { Alloc 5 } -- structure gpu-mem { Alloc 5 } let main [n] (wsss0: [n][n][n]i32) (ns: [n]i32) (i: i32) (j: i32): ([n][n][n]i32, [n][n][n]i32) = -- Create arrays into which other arrays can be coalesced. let wsss = map (\wss -> map (\ws -> map (+ 1) ws) wss) wsss0 let vss = replicate n (replicate n 2) -- Create the "base" array. let xs = map (+ 1) ns -- Use wsss after the creation of vss. To ensure that use_wsss does not get -- hoisted (who knows?), also use a value from the previous array in the -- expression. 
let k = xs[0] let use_wsss = map (\wss -> map (\ws -> map (+ k) ws) wss) wsss -- xs can be coalesced into vss[j]. let vss[j] = xs -- vss (and xs) cannot be coalesced into wsss[i], since wsss is used (in -- use_wsss) after the creation of vss and before this expression. If we -- performed this coalescing, the use_wsss expression would use some values -- from vss instead of only values from wsss. let wsss[i] = vss in (wsss, use_wsss) futhark-0.25.27/tests/memory-block-merging/coalescing/chain/chain3.fut000066400000000000000000000024301475065116200256070ustar00rootroot00000000000000-- Here is a chain of potential coalescings, but only the second one can work, so -- we can only remove one allocation. -- == -- input { [[2, 2], [2, 2]] -- [2, 4] -- 0 -- 1 -- } -- output { [[[4, 4], [3, 5]], [[1, 1], [1, 1]]] -- [[7, 7], [7, 7]] -- } -- structure seq-mem { Alloc 3 } -- structure gpu-mem { Alloc 4 } let main [n] (vss0: [n][n]i32) (ns: [n]i32) (i: i32) (j: i32) : ([n][n][n]i32, [n][n]i32) = -- Create arrays into which other arrays can be coalesced. let wsss = replicate n (replicate n (replicate n 1)) let vss = map (\vs -> map (+ 2) vs) vss0 -- Create the "base" array. let xs = map (+ 1) ns -- Use vss after the creation of xs. To ensure that use_vss does not get -- hoisted (who knows?), also use a value from the previous array in the -- expression. let k = xs[0] let use_vss = map (\vs -> map (+ k) vs) vss -- xs cannot be coalesced into vss[j], since vss is used (in use_vss) after -- the creation of xs and before this expression. If we performed this -- coalescing, the use_vss expression would use some values from xs instead of -- only values from vss. let vss[j] = xs -- vss can be coalesced into wsss[i]. let wsss[i] = vss in (wsss, use_vss) futhark-0.25.27/tests/memory-block-merging/coalescing/chain/chain4.fut000066400000000000000000000055461475065116200256230ustar00rootroot00000000000000-- A more complex example. 
-- -- This is an example where too aggressive allocation hoisting is a bad thing. -- If implemented without the current limiter, this would happen: To enable an -- eventual coalescing into vss, both the allocation for the memory of vss, -- *and* the vss array creation statement itself, are hoisted upwards as much as -- possible. This hinders the later coalescing into wsss, and was never useful -- to begin with, since there can be no coalescing into vss regardless. -- -- To fix this in a nicer way, we could perform allocation hoisting while we do -- coalescing instead of before we do it, which would help at least for this -- program. Maybe? -- == -- input { [[[1, 1], [1, 1]], [[1, 1], [1, 1]]] -- [2, 4] -- 0 -- 1 -- } -- output { [[[2, 2], [3, 5]], [[2, 2], [2, 2]]] -- [[[5, 5], [5, 5]], [[5, 5], [5, 5]]] -- } -- structure seq-mem { Alloc 5 } -- structure gpu-mem { Alloc 5 } let main [n] (wsss0: [n][n][n]i32) (ns: [n]i32) (i: i32) (j: i32): ([n][n][n]i32, [n][n][n]i32) = -- Create an array into which other arrays can be coalesced. let wsss = map (\wss -> map (\ws -> map (+ 1) ws) wss) wsss0 -- Create the "base" array. let xs = map (+ 1) ns -- Use wsss after the creation of xs, but *before* the creation of vss. We -- want to consider the case where either xs can be coalesced into vss[j], -- *or* vss can be coalesced into wsss[i], but that both are not possible, so -- that the compiler will have to make a choice. -- -- If xs is coalesced into vss[j], vss will then alias xs, and so wsss will be -- used after an array creation in the memory of vss -- this will not happen -- if xs is not coalesced, since then the first array creation in the memory -- of vss will occur *after* use_wsss. let k = xs[0] let use_wsss = map (\wss -> map (\ws -> map (+ k) ws) wss) wsss -- Create another coalescing-enabling array. let vss = replicate n (replicate n 2) -- xs cannot be coalesced into vss[j], since vss is used (in the previous -- statement) after the creation of xs. 
This is where we see that we cannot -- end up with the ambigious case that we were trying to provoke: Since we -- want the coalescing into wsss[i] to fail only if the coalescing into vss[j] -- succeeds (and to otherwise succeed), we have to create wsss between xs and -- vss, in which case vss will be created (and thus used) after xs, which -- means that xs cannot even be coalesced into vss[j] in the first place. let vss[j] = xs -- vss can be coalesced into wsss[i]. let wsss[i] = vss in (wsss, use_wsss) -- For safety condition 3: This shows that it is always okay (FIXME: proper -- proof) to optimistically coalesce two memory blocks through a top-down -- traversal, since that will never restrict later coalescings. Try proof by -- contradiction? futhark-0.25.27/tests/memory-block-merging/coalescing/chain/chain5.fut000066400000000000000000000013141475065116200256110ustar00rootroot00000000000000-- Memory block merging with a chain that uses all three coalescing-enabled -- constructs. -- == -- input { [7, 0, 7] } -- output { [[0, 0, 0, 0, 0, 0], [7, 0, 7, 8, 1, 8]] } -- structure seq-mem { Alloc 1 } -- structure gpu-mem { Alloc 1 } let main [n] (ns: [n]i32): [][]i32 = -- Will be the only remaining memory block. let t3 = replicate 2 (replicate (n+n) 0) -- Will initially be set to use the memory of t1. Will end up using the -- memory of t3 through t2 through t1. let t0 = map (+ 1) ns -- Will use the second part of index 1 of the memory of t3 through the memory -- of t2. let t1 = copy t0 -- Will use index 1 of the memory of t3. 
let t2 = concat ns t1 let t3[1] = t2 in t3 futhark-0.25.27/tests/memory-block-merging/coalescing/concat/000077500000000000000000000000001475065116200241105ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/coalescing/concat/iotas.fut000066400000000000000000000003651475065116200257530ustar00rootroot00000000000000-- == -- entry: concat_iotas -- input { 2i64 4i64 } -- output { [ 0i64, 1i64, 0i64, 1i64, 2i64, 3i64 ] } -- structure gpu-mem { Alloc 1 } -- structure seq-mem { Alloc 1 } entry concat_iotas (i: i64) (j: i64): []i64 = concat (iota i) (iota j) futhark-0.25.27/tests/memory-block-merging/coalescing/concat/neg0.fut000066400000000000000000000010731475065116200254620ustar00rootroot00000000000000-- Should not short circuit because t3 is created after the others -- == -- input { [5, 15] -- 0 -- } -- output { [[6, 16, 10, 30, 1, 5], -- [0, 0, 0, 0, 0, 0]] -- } -- structure seq-mem { Alloc 4 } -- structure gpu-mem { Alloc 4 } let main [n] (ns: [n]i32) (i: i32): [][]i32 = let t0 = map (+ 1) ns -- Will use the memory of t3. let t1 = map (* 2) ns -- Will use the memory of t3. let t2 = map (/ 3) ns -- Will use the memory of t3. let t3 = unflatten (replicate (n * (n+n+n)) 0) let t3[i] = concat (concat t0 t1) t2 in t3 futhark-0.25.27/tests/memory-block-merging/coalescing/concat/pos0.fut000066400000000000000000000010211475065116200255030ustar00rootroot00000000000000-- Memory block merging with a concat of multiple arrays. Requires allocation -- hoisting of the memory block for 't3'. -- == -- input { [5, 15] } -- output { [6, 16, 10, 30, 1, 5] } -- structure seq-mem { Alloc 3 } -- structure gpu-mem { Alloc 1 } let main (ns: []i32): []i32 = let t0 = map (+ 1) ns -- Will use the memory of t4. let t1 = map (* 2) ns -- Will use the memory of t4. let t2 = map (/ 3) ns -- Will use the memory of t4. let t3 = concat t0 t1 -- Will use the memory of t4. 
let t4 = concat t3 t2 in t4 futhark-0.25.27/tests/memory-block-merging/coalescing/concat/pos1.fut000066400000000000000000000012221475065116200255070ustar00rootroot00000000000000-- Memory block merging with a concat of multiple arrays into a multidimensional -- array. Requires allocation hoisting of the memory block for 't3'. -- == -- input { [5, 15] -- 0i64 -- } -- output { [[6, 16, 10, 30, 1, 5], -- [0, 0, 0, 0, 0, 0]] -- } -- structure seq-mem { Alloc 3 } -- structure gpu-mem { Alloc 1 } let main [n] (ns: [n]i32) (i: i64): [][]i32 = let t3 = unflatten (replicate (n * (n+n+n)) 0) let t0 = map (+ 1) ns -- Will use the memory of t3. let t1 = map (* 2) ns -- Will use the memory of t3. let t2 = map (/ 3) ns -- Will use the memory of t3. let t3[i] = concat (concat t0 t1) t2 in t3 futhark-0.25.27/tests/memory-block-merging/coalescing/concat/pos2.fut000066400000000000000000000016021475065116200255120ustar00rootroot00000000000000-- Memory block merging with a concat of two arrays of different sources into a -- multidimensional array. -- == -- input { [ [1i32, 1i32, 1i32, 1i32] -- , [1i32, 1i32, 1i32, 1i32] -- ] -- [3, 7] -- [8, 9] -- } -- output { -- [ [4i32, 8i32, 9i32, 10i32] -- , [1i32, 1i32, 1i32, 1i32] -- ] -- } -- structure seq-mem { Alloc 0 } -- structure gpu-mem { Alloc 0 } let main [n][m][k] (y: *[n][m+k]i32) (a: [m]i32) (b: [k]i32): *[n][m+k]i32 = let a1 = map (+1) a -- Will use the memory of z, and thereby y[0]. let b1 = map (+1) b -- Will use the memory of z, and thereby y[0]. let z = concat a1 b1 -- Will use the memory of y[0]. -- There will be inserted a safety reshape here. let y[0] = z -- y is not allocated in this body, so there will be no -- allocations left after the optimisation. in y futhark-0.25.27/tests/memory-block-merging/coalescing/concat/pos3.fut000066400000000000000000000005251475065116200255160ustar00rootroot00000000000000-- Memory block merging with a concat of two arrays, but in the different order -- than they were created. 
-- == -- input { [5, 15] } -- output { [10, 30, 6, 16] } -- structure seq-mem { Alloc 2 } -- structure gpu-mem { Alloc 1 } let main (ns: []i32): []i32 = let t0 = map (+ 1) ns let t1 = map (* 2) ns let t2 = concat t1 t0 in t2 futhark-0.25.27/tests/memory-block-merging/coalescing/concat/same.fut000066400000000000000000000010421475065116200255520ustar00rootroot00000000000000-- Memory block merging with a concat of the same array twice. This should -- coalesce without problems. However the compiler cannot remove both copy -- instructions from the final imperative intermediate representation, since -- either the first part or the second part will need to be inserted by a copy -- from the other part. -- == -- input { [13, 4] } -- output { [130, 40, 130, 40] } -- structure seq-mem { Alloc 1 } -- structure gpu-mem { Alloc 1 } let main (xs: []i32): []i32 = let ys = map (* 10) xs let zs = concat ys ys in zs futhark-0.25.27/tests/memory-block-merging/coalescing/copy/000077500000000000000000000000001475065116200236135ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/coalescing/copy/neg0.fut000066400000000000000000000023401475065116200251630ustar00rootroot00000000000000-- Memory block merging with a copy into a multidimensional array. Very similar -- to pos3.fut, but this should not coalesce in the CPU pipeline (see body -- comment). -- -- NOTE: Due to a regression in the fusion engine, the seq-mem code is not -- actually being fused, which means that it can be short-circuited: -- https://github.com/diku-dk/futhark/issues/1733 -- -- When that bug has been fixed, there should be two seq-mem allocations instead -- of one. 
-- == -- input { 1i64 -- [6, 0, 7] -- [[-1, 0, 1], -- [-1, 0, 1], -- [-1, 0, 1]] -- } -- output { [[0, 1, 2], -- [7, 1, 8], -- [0, 1, 2]] -- } -- structure seq-mem { Alloc 1 } -- structure gpu-mem { Alloc 1 } let main [n] (i: i64) (ns: [n]i32) (mss: [n][n]i32): [n][n]i32 = -- For the CPU pipeline, t1 and t0 can be fused into a single outer map. This -- makes it impossible to coalesce, since mem_t1 is used after the creation of -- t0 through its use in the same map body as t0. -- -- The fusion does not happen in the GPU pipeline, so in that case it is the -- same as pos3.fut, meaning it gets a coalescing. let t1 = map (map (+ 1)) mss let k = 1 let t0 = map (+ k) ns let t1[i] = t0 in t1 futhark-0.25.27/tests/memory-block-merging/coalescing/copy/neg1.fut000066400000000000000000000011641475065116200251670ustar00rootroot00000000000000-- Memory block merging with a copy into a multidimensional array given as a -- function parameter. -- == -- input { [[0, 1, 2], -- [3, 4, 5], -- [6, 7, 8]] -- 1i64 -- [7, 1, 5] -- } -- output { [[0, 1, 2], -- [14, 10, 13], -- [6, 7, 8]] -- } -- structure seq-mem { Alloc 1 } -- structure gpu-mem { Alloc 1 } let main [n] (t1: *[n][n]i32) (i: i64) (ns: [n]i32): [n][n]i32 = let t0 = map3 (\x y z -> x + y + z) t1[i] ns (rotate 1 t1[i]) -- Will use the memory of t1[i]. -- This is the basis array in which everything will be put. let t1[i] = t0 in t1 futhark-0.25.27/tests/memory-block-merging/coalescing/copy/pos0.fut000066400000000000000000000005231475065116200252140ustar00rootroot00000000000000-- Memory block merging with a copy. Requires allocation hoisting of the memory -- block for 't1'. -- == -- input { [7, 0, 7] } -- output { [8, 1, 8] } -- structure seq-mem { Alloc 1 } -- structure gpu-mem { Alloc 0 } let main [n] (ns: *[n]i32): *[n]i32 = let t0 = map (+ 1) ns -- Will use the memory of t1. 
let t1 = copy t0 in t1 futhark-0.25.27/tests/memory-block-merging/coalescing/copy/pos1.fut000066400000000000000000000011751475065116200252210ustar00rootroot00000000000000-- Memory block merging with a copy into a multidimensional array. Requires -- allocation hoisting of the memory block for 't1'. -- == -- input { 1i64 -- [7i64, 0i64, 7i64] -- } -- output { [[0i64, 1i64, 2i64], -- [8i64, 1i64, 8i64], -- [0i64, 1i64, 2i64]] -- } -- structure seq-mem { Alloc 2 } -- structure gpu-mem { Alloc 2 } let main [n] (i: i64) (ns: [n]i64): [n][n]i64 = let t1 = replicate n (iota n) let t0 = map (+ 1) ns -- Will use the memory of t1[i]. -- This is the basis array in which everything will be put. Its creation uses -- two allocations. let t1[i] = t0 in t1 futhark-0.25.27/tests/memory-block-merging/coalescing/copy/pos2.fut000066400000000000000000000011101475065116200252070ustar00rootroot00000000000000-- Memory block merging with a copy into a multidimensional array given as a -- function parameter. -- == -- input { [[0, 1, 2], -- [0, 1, 2], -- [0, 1, 2]] -- 1i64 -- [7, 0, 7] -- } -- output { [[0, 1, 2], -- [8, 1, 8], -- [0, 1, 2]] -- } -- structure seq-mem { Alloc 0 } -- structure gpu-mem { Alloc 0 } let main [n] (t1: *[n][n]i32) (i: i64) (ns: [n]i32): [n][n]i32 = let t0 = map (+ 1) ns -- Will use the memory of t1[i]. -- This is the basis array in which everything will be put. let t1[i] = t0 in t1 futhark-0.25.27/tests/memory-block-merging/coalescing/copy/pos3.fut000066400000000000000000000011471475065116200252220ustar00rootroot00000000000000-- Memory block merging with a copy into a multidimensional array. -- == -- input { 1i64 -- [6, 0, 7] -- [[-1, 0, 1], -- [-1, 0, 1], -- [-1, 0, 1]] -- } -- output { [[0, 1, 2], -- [7, 1, 8], -- [0, 1, 2]] -- } -- structure seq-mem { Alloc 1 } -- structure gpu-mem { Alloc 2 } let main [n] (i: i64) (ns: [n]i32) (mss: [n][n]i32): [n][n]i32 = -- This is the basis array in which everything will be put. 
let t1 = map (\ms -> map (+ 1) ms) mss let k = t1[0, 1] let t0 = map (+ k) ns -- Will use the memory of t1[i]. let t1[i] = t0 in t1 futhark-0.25.27/tests/memory-block-merging/coalescing/copy/pos4.fut000066400000000000000000000011161475065116200252170ustar00rootroot00000000000000-- Memory block merging with a copy into a multidimensional array given as a -- function parameter. -- == -- input { [[0, 1, 2], -- [0, 1, 2], -- [0, 1, 2]] -- 1i64 -- [7, 0, 7] -- } -- output { [[0, 1, 2], -- [7, 1, 9], -- [0, 1, 2]] -- } -- structure seq-mem { Alloc 0 } -- structure gpu-mem { Alloc 0 } let main [n] (t1: *[n][n]i32) (i: i64) (ns: [n]i32): [n][n]i32 = let t0 = map2 (+) t1[i] ns -- Will use the memory of t1[i]. -- This is the basis array in which everything will be put. let t1[i] = t0 in t1 futhark-0.25.27/tests/memory-block-merging/coalescing/copy/pos5.fut000066400000000000000000000021521475065116200252210ustar00rootroot00000000000000-- Memory block merging with a copy into a multidimensional array given as a -- function parameter. -- == -- input { [[[0, 1, 2], -- [0, 1, 2], -- [0, 1, 2]], -- [[1, 2, 3], -- [4, 5, 6], -- [7, 8, 9]], -- [[0, 0, 0], -- [0, 0, 0], -- [0, 0, 0]]] -- 1i64 -- [7, 0, 7] -- } -- output { [[[0, 1, 2], -- [0, 1, 2], -- [0, 1, 2]], -- [[77, 77, 77], -- [0, 0, 0], -- [77, 77, 77]], -- [[0, 0, 0], -- [0, 0, 0], -- [0, 0, 0]]] -- } -- compiled random input { [256][256][256]i32 1i64 [256]i32 } -- auto output -- structure seq-mem { /Alloc 1 } -- structure gpu-mem { /Alloc 0 } let main [n] (t1: *[n][n][n]i32) (i: i64) (xs: [n]i32): *[n][n][n]i32 = #[incremental_flattening(only_intra)] let t0 = map (\x -> loop res = replicate n x for j < 10 do map (+ x) res) xs -- Will use the memory of t1[i]. -- This is the basis array in which everything will be put. 
let t1[i] = t0 in t1 futhark-0.25.27/tests/memory-block-merging/coalescing/cosmin-tests/000077500000000000000000000000001475065116200252715ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/coalescing/cosmin-tests/test-suc-alias.fut000066400000000000000000000005551475065116200306540ustar00rootroot00000000000000-- == -- structure gpu-mem { Alloc 4 } -- structure seq-mem { Alloc 1 } let main [n] (ind: i64) (ass: [n][n]i64) = let np1 = n+1 let yss = replicate np1 (replicate np1 2) let as = map (reduce (+) 0) ass let bs = opaque as let cs = as[2:] let ds = bs[:n-2] let r1 = reduce (+) 0 cs let r2 = reduce (+) 0 ds let yss[ind, 1:] = bs in (r1,r2,yss) futhark-0.25.27/tests/memory-block-merging/coalescing/cosmin-tests/test-suc-concat.fut000066400000000000000000000004341475065116200310260ustar00rootroot00000000000000-- == -- structure gpu-mem { Alloc 1 } -- structure seq-mem { Alloc 3 } let main [n] (ind: i64) (as: [n]i64) = let tmp1 = map (*2) as |> opaque let tmp2 = map (*3) as |> opaque let tmp3 = map (*4) as |> opaque let tmp = concat tmp2 tmp3 let res = concat tmp1 tmp in res futhark-0.25.27/tests/memory-block-merging/coalescing/cosmin-tests/test-suc-if-1.fut000066400000000000000000000006661475065116200303220ustar00rootroot00000000000000-- == -- structure gpu-mem { Alloc 1 } -- structure seq-mem { Alloc 1 } let main [n] (ind: i64) (q: f32) (ass: [n][n]f32) (as: [n]f32) = let yss = map (map (+1)) ass -- replicate n (replicate n 2) -- map (map (+1)) ass let bs = if (q > 0) then let b1s = map (+2) as -- replicate n 2 -- map (+2) as in b1s else let b2s = map (+3) as -- replicate n 3 -- map (+3) as in b2s let yss[ind] = bs in yss futhark-0.25.27/tests/memory-block-merging/coalescing/cosmin-tests/test-suc-loop-1.fut000066400000000000000000000004501475065116200306640ustar00rootroot00000000000000-- == -- structure gpu-mem { Alloc 3 } -- structure seq-mem { Alloc 3 } let main [n] (ind: i64) (ass: [n][n]i64) (as: [n]i64) (inds: [n]i64) = let yss = 
replicate n (replicate n 33) let bs = iota n let bs = loop bs for i < n do map (*i) bs let yss[ind] = bs in yss futhark-0.25.27/tests/memory-block-merging/coalescing/cosmin-tests/test-suc-loop-2.fut000066400000000000000000000006121475065116200306650ustar00rootroot00000000000000-- == -- structure gpu-mem { Alloc 6 } -- structure seq-mem { Alloc 4 } let main [n] (ind: i64) (ass: [n][n]i64) (as: [n]i64) = let yss = map (map (+1)) ass |> opaque -- replicate n (replicate n 2 let bs = map (*3) as |> opaque let bs = loop (bs) for i < n do let cs = map (*2) bs let s = reduce (+) 0 cs in map (+s) cs let yss[ind] = bs in yss futhark-0.25.27/tests/memory-block-merging/coalescing/cosmin-tests/test-suc-loop-3.fut000066400000000000000000000006361475065116200306740ustar00rootroot00000000000000-- == -- structure gpu-mem { Alloc 4 } -- structure seq-mem { Alloc 2 } let main [n] (ind: i64) (ass: [n][n]i64) = let yss = map (map (+1)) ass |> opaque -- replicate n (replicate n 2) let cs = map (reduce (+) 0) yss let xs = replicate n 0i64 |> opaque let xs[0] = cs[0] let xs = loop (xs) for i < n-1 do let xs[i+1] = 3*xs[i] + cs[i]*i in xs let yss[ind] = xs in yss futhark-0.25.27/tests/memory-block-merging/coalescing/hoisting/000077500000000000000000000000001475065116200244655ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/coalescing/hoisting/alloc-hinder0.fut000066400000000000000000000012671475065116200276340ustar00rootroot00000000000000-- An example of a program with silly limits. -- == -- input { [1i64, 2i64] } -- output { [2i64, 3i64, 0i64, 1i64] } -- structure seq-mem { Alloc 2 } -- structure gpu-mem { Alloc 2 } let main (ns: []i64): []i64 = let t0 = map (+ 1) ns -- Create an array whose memory block allocation depends on the *value* of t0, -- not the *shape* of t0. This makes it impossible to hoist the alloc up -- before the t0 creation. let annoying = iota t0[0] -- Try to coalesce t0 and annoying into t2. Only annoying can be coalesced. 
-- t0 cannot be coalesced, since the allocation of the memory of t2 can only -- occur after knowing the value of t0[0]. let t2 = concat t0 annoying in t2 futhark-0.25.27/tests/memory-block-merging/coalescing/hoisting/concat.fut000066400000000000000000000007711475065116200264610ustar00rootroot00000000000000-- Optimally, we would test just the the expresions are hoisted as we expect, -- but currently we just test that the final number of allocations match (it is -- the same for the other hoisting tests). -- == -- structure seq-mem { Alloc 1 } -- structure gpu-mem { Alloc 1 } let main (length0: i64, length1: i64): []i32 = let temp0 = replicate length0 1i32 let temp1 = replicate length1 1i32 -- Will be moved up to before temp0. let with_hoistable_mem = concat temp0 temp1 in with_hoistable_mem futhark-0.25.27/tests/memory-block-merging/coalescing/hoisting/copy-in-if.fut000066400000000000000000000007601475065116200271620ustar00rootroot00000000000000-- In this case the nested body needs allocation hoisting to enable one array -- coalescing. -- -- It is perhaps a pretty far-out case. -- == -- structure seq-mem { Alloc 2 } -- structure gpu-mem { Alloc 2 } let main (cond: bool, lengths: []i64, index: i64): []i64 = if cond then let lengths' = map (+1) lengths let temp = replicate lengths'[index] 1i64 -- Will be moved up to before temp. let with_hoistable_mem = copy temp in with_hoistable_mem else lengths futhark-0.25.27/tests/memory-block-merging/coalescing/hoisting/copy.fut000066400000000000000000000004121475065116200261540ustar00rootroot00000000000000-- Small copy test. -- == -- structure seq-mem { Alloc 1 } -- structure gpu-mem { Alloc 1 } let main (length: i64): [length]i32 = let temp = replicate length 1i32 -- Will be moved up to before temp. 
let with_hoistable_mem = copy temp in with_hoistable_mem futhark-0.25.27/tests/memory-block-merging/coalescing/hoisting/dependent.fut000066400000000000000000000005661475065116200271620ustar00rootroot00000000000000-- A small chain of coalescings. -- == -- structure seq-mem { Alloc 1 } -- structure gpu-mem { Alloc 1 } let main (length: i64): []i32 = let temp = replicate length 1i32 -- Will be moved up to before temp. let with_hoistable_mem0 = copy temp -- Will be moved up to before temp. let with_hoistable_mem1 = concat temp with_hoistable_mem0 in with_hoistable_mem1 futhark-0.25.27/tests/memory-block-merging/coalescing/if/000077500000000000000000000000001475065116200232375ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/coalescing/if/both-inside.fut000066400000000000000000000014561475065116200261720ustar00rootroot00000000000000-- An if expression where both branch arrays are defined inside the 'if'. This -- should be okay as long as the usual safety conditions are kept, since 'ys0' -- and 'ys1' exist independently of each other. -- == -- input { [[1i64, 4i64], [9i64, 16i64]] -- false -- 1i64 -- } -- output { [[1i64, 4i64], [1i64, 2i64]] -- } -- structure seq-mem { Alloc 0 } -- structure gpu-mem { Alloc 0 } let main [n] (xs: *[n][n]i64) (cond: bool) (i: i64): [n][n]i64 = -- Both branches will use the memory of ys, which will use the memory of -- xs[i]. let ys = if cond then let ys0 = iota n in ys0 else let ys1 = map (+ 1) (iota n) in ys1 -- xs is not allocated in this body, so we end up with zero allocations. let xs[i] = ys in xs futhark-0.25.27/tests/memory-block-merging/coalescing/if/both-outside.fut000066400000000000000000000011361475065116200263660ustar00rootroot00000000000000-- An if expression where both branch arrays are defined outside the 'if'. This -- should *not* be okay, since 'ys0' and 'ys1' sharing the same memory means -- that one of them gets overwritten. 
-- == -- input { [[1i64, 4i64], [9i64, 16i64]] -- false -- 1i64 -- } -- output { [[1i64, 4i64], [1i64, 2i64]] -- } -- structure seq-mem { Alloc 2 } -- structure gpu-mem { Alloc 2 } let main [n] (xs: *[n][n]i64) (cond: bool) (i: i64): [n][n]i64 = let ys0 = iota n let ys1 = map (+ 1) (iota n) let ys = if cond then ys0 else ys1 let xs[i] = ys in xs futhark-0.25.27/tests/memory-block-merging/coalescing/if/if-neg-2.fut000066400000000000000000000012561475065116200252670ustar00rootroot00000000000000-- Negative Example of If-Coalescing. -- == -- input { [[1,2], [3,4]] -- [1,2] -- [3,4] -- } -- output { -- [ [1i32, 2i32], [2i32, 3i32] ] -- 3i32 -- } -- structure seq-mem { Alloc 2 } -- structure gpu-mem { Alloc 2 } -- There should be no coalescing here because `x` is -- used during the lifetime of `r`, which also prevents -- coalescing of the `z` in `x`! let main [n] (x: *[n][n]i32) (a: [n]i32) (b: [n]i32): (*[n][n]i32, i32) = let (z,s) = if (x[0,0]) > 0 then let r = map (+1) a let q = x[x[0,0],0] in (r, q) else (map (*2) b, 2) let x[n/2] = z in (x,s) futhark-0.25.27/tests/memory-block-merging/coalescing/if/if-neg-3-pos.fut000066400000000000000000000016721475065116200260710ustar00rootroot00000000000000-- Positive variant of if-neg-3.fut: b[0] is used in the then-branch, but before -- y0 is created, so it is okay to coalesce both b and y0 into x[0] (through the -- existential memory of y). -- -- This also shows that it might make sense to do the memory block merging pass -- in a different representation: This program calculates the exact same thing -- as if-neg-3.fut, and yet this one is able to do two more coalescings just by -- moving a statement. -- -- However, existentials in ifs are not supported yet. 
-- == -- input { true -- [[9, 9], [9, 9]] -- [1, 4] -- } -- output { [[2, 5], [9, 9]] -- 2 -- } -- structure seq-mem { Alloc 2 } -- structure gpu-mem { Alloc 2 } let main [n] (cond: bool) (x: *[n][n]i32) (a: [n]i32): (*[n][n]i32, i32) = let b = map (+ 1) a let (y, r) = if cond then let k = b[0] let y0 = map (+ 1) a in (y0, k) else (b, 0) let x[0] = y in (x, r) futhark-0.25.27/tests/memory-block-merging/coalescing/if/if-neg-3.fut000066400000000000000000000017301475065116200252650ustar00rootroot00000000000000-- The compiler will try to coalesce y into x[0]. Since y is an if-expression, -- the compiler will then try to coalesce y0 and b into y (the branch results). -- At first this looks okay, since only one of them (b) is created outside the -- if, so making y0 and b use the same memory should not conflict (like in -- one-inside-one-outside.fut). However, b is used *after* the creation of y0 -- (in b[0]), so if they are set to use the same memory block, the b[0] -- expression will actually be y0[0], which is wrong. -- -- This needs to fail at any coalescing. -- == -- input { true -- [[9, 9], [9, 9]] -- [1, 4] -- } -- output { [[2, 5], [9, 9]] -- 2 -- } -- structure seq-mem { Alloc 2 } -- structure gpu-mem { Alloc 2 } let main [n] (cond: bool) (x: *[n][n]i32) (a: [n]i32): (*[n][n]i32, i32) = let b = map (+ 1) a let (y, r) = if cond then let y0 = map (+ 1) a in (y0, b[0]) else (b, 0) let x[0] = y in (x, r) futhark-0.25.27/tests/memory-block-merging/coalescing/if/if-neg-4.fut000066400000000000000000000013201475065116200252610ustar00rootroot00000000000000-- Same as if-neg-3.fut, but with an extra nested if. This should not produce -- any coalescings either. 
-- == -- input { true true -- [9, 9] [9, 9] -- [[0, 0], [0, 0]] -- } -- output { [[1, 2], [0, 0]] -- 10 10 -- } -- structure seq-mem { Alloc 3 } -- structure gpu-mem { Alloc 3 } let main [n] (cond0: bool) (cond1: bool) (y0: [n]i32) (z0: [n]i32) (x: *[n][n]i32): (*[n][n]i32, i32, i32) = let y = map (+ 1) y0 let z = map (+ 1) z0 let (a, b, c) = if cond0 then if cond1 then let y1 = iota n |> map i32.i64 |> map (+ 1) in (y1, y[0], z[0]) else (y, 0, 0) else (z, 0, 0) let x[0] = a in (x, b, c) futhark-0.25.27/tests/memory-block-merging/coalescing/if/if-nonexist.fut000066400000000000000000000025011475065116200262200ustar00rootroot00000000000000-- Tricky if-coalescing -- == -- input { [ [ [1,2], [3,4] ] -- , [ [5,6], [7,8] ] -- ] -- [1,2] -- } -- output { -- [ [ [1i32, 2i32], [3i32, 4i32] ] -- , [ [0i32, 0i32], [2i32, 4i32] ] -- ] -- } -- structure seq-mem { Alloc 1 } -- structure gpu-mem { Alloc 1 } -- Number of coalescing is 1, but corresponds to 4 coalescing -- operations on the same memory block, i.e., -- (i) `y[1] = z2`, at the very bottom -- (ii) `z2 = z0` where `z0` is the result of the then branch, -- (iii) `z2 = z1` where `z1` is the result of the else branch, -- (iv) and finaly the creation of `z` (transitive closure -- added by `z2` and `z1`). -- Basically, since the memory block of the if-result is not -- existensial then we can track the creation of `z` outside -- the branches. Note that `z[0] = x2` and `z[1] = x2` are not -- coalesced. let main [n] (y: *[n][n][n]i32) (a : [n]i32): *[n][n][n]i32 = let z = replicate n (replicate n 0) let x2 = map (*2) a -- The sole allocation. This could be stored in either -- z[0] or z[1], but both might need it, so we do not -- merge memory. 
let z2 = if (n > 3) then let z[0] = x2 in z else let z[1] = x2 in z let y[1] = z2 in y futhark-0.25.27/tests/memory-block-merging/coalescing/if/one-inside-one-outside.fut000066400000000000000000000013111475065116200302360ustar00rootroot00000000000000-- An if expression where one branch array is defined outside the 'if', and one -- is defined inside the 'if'. This should be okay as long as the usual safety -- conditions are kept, since 'ys0' and 'ys1' can use the same memory block -- without 'ys0' being overwritten (it seems). However, we cannot handle this yet. -- == -- input { [[1i64, 4i64], [9i64, 16i64]] -- false -- 1i64 -- } -- output { [[1i64, 4i64], [1i64, 2i64]] -- } -- structure gpu-mem { Alloc 2 } let main [n] (xs: *[n][n]i64) (cond: bool) (i: i64): [n][n]i64 = let ys0 = iota n let ys = if cond then ys0 else let ys1 = map (+ 1) (iota n) in ys1 let xs[i] = ys in xs futhark-0.25.27/tests/memory-block-merging/coalescing/inplace-updates/000077500000000000000000000000001475065116200257175ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/coalescing/inplace-updates/fail.fut000066400000000000000000000005651475065116200273600ustar00rootroot00000000000000-- == -- entry: fail -- input { [0i64, 1i64, 2i64, 3i64, 4i64, 5i64, 6i64, 7i64, 8i64, 9i64] } -- output { [0i64, 2i64, 3i64, 0i64, 4i64, 5i64, 6i64, 7i64, 8i64, 9i64] } -- structure gpu-mem { Alloc 1 } -- structure seq-mem { Alloc 1 } entry fail [n] (xs: *[n]i64): *[n]i64 = let b = replicate 4 0 -- let b[1:3] = xs[6:8] let b[1:3] = xs[2:4] in xs with [:4] = b futhark-0.25.27/tests/memory-block-merging/coalescing/inplace-updates/success.fut000066400000000000000000000005401475065116200301060ustar00rootroot00000000000000-- == -- entry: success -- input { [0i64, 1i64, 2i64, 3i64, 4i64, 5i64, 6i64, 7i64, 8i64, 9i64] } -- output { [0i64, 6i64, 7i64, 0i64, 4i64, 5i64, 6i64, 7i64, 8i64, 9i64] } -- structure gpu-mem { Alloc 0 } -- structure seq-mem { Alloc 0 } entry success [n] (xs: *[n]i64): *[n]i64 = 
let b = replicate 4 0 let b[1:3] = xs[6:8] in xs with [:4] = b futhark-0.25.27/tests/memory-block-merging/coalescing/iota-one-row.fut000066400000000000000000000006161475065116200257040ustar00rootroot00000000000000-- == -- entry: iota_one_row -- input { 1i64 [ [10i64, 11i64, 12i64, 13i64, 14i64], -- [15i64, 16i64, 17i64, 18i64, 19i64] ] } -- output { [ [10i64, 11i64, 12i64, 13i64, 14i64], -- [0i64, 1i64, 2i64, 3i64, 4i64] ] } -- structure gpu-mem { Alloc 0 } -- structure seq-mem { Alloc 0 } entry iota_one_row [n][m] (i: i64) (xs: *[n][m]i64): *[n][m]i64 = xs with [i] = iota m futhark-0.25.27/tests/memory-block-merging/coalescing/issue-1789.fut000066400000000000000000000003031475065116200251130ustar00rootroot00000000000000-- == -- input { [1,2,3,4,5,6,7,8,9] } -- output { [4,5,6,7,8,9,7,8,9] } -- structure gpu-mem { Alloc 1 } entry main [n] (xs: *[n]i32) : [n]i32 = let i = n/3 in xs with [:2*i] = copy xs[i:] futhark-0.25.27/tests/memory-block-merging/coalescing/issue-1927.fut000066400000000000000000000002641475065116200251130ustar00rootroot00000000000000-- == -- input { 3i64 } -- output { [0i64, 0i64, 1i64, 2i64, 4i64] } -- structure gpu-mem { Alloc 2 } def main k = let src = iota k let dst = iota 5 in dst with [1:4] = src futhark-0.25.27/tests/memory-block-merging/coalescing/issue-1930.fut000066400000000000000000000002011475065116200250740ustar00rootroot00000000000000-- == -- input { 5000i64 [1i64] } -- error: def main [n] (k: i64) (dst: *[n]i64) = let src = iota k in dst with [1:4] = src futhark-0.25.27/tests/memory-block-merging/coalescing/issue-2010.fut000066400000000000000000000014211475065116200250670ustar00rootroot00000000000000-- partition2.fut -- == -- input { [1i32,2i32,3i32,4i32,5i32,6i32,7i32] -- } -- output { -- 3i64 -- [2i32,4i32,6i32,1i32,3i32,5i32,7i32] -- } let partition2 't [n] (dummy: t) (cond: t -> bool) (X: [n]t) : (i64, *[n]t) = let cs = map cond X let tfs= map (\ f->if f then 1i64 else 0i64) cs let isT= scan (+) 0 tfs let i = isT[n-1] let ffs= 
map (\f->if f then 0 else 1) cs let isF= map (+i) <| scan (+) 0 ffs let inds=map (\(c,iT,iF) -> if c then iT-1 else iF-1 ) (zip3 cs isT isF) let tmp = replicate n dummy in (i, scatter tmp inds X) let main [n] (arr: *[n]i32) : (i64,*[n]i32) = partition2 0i32 (\(x:i32) -> (x & 1i32) == 0i32) arr futhark-0.25.27/tests/memory-block-merging/coalescing/loop/000077500000000000000000000000001475065116200236125ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/coalescing/loop/loop-ip.fut000066400000000000000000000010361475065116200257110ustar00rootroot00000000000000-- Very Simple Example of Loop Coalescing. -- == -- input { [ [1,2], [3,4] ] -- [1,2] -- } -- output { -- [ [1i32, 9i32], [1i32, 3i32] ] -- } -- structure seq-mem { Alloc 0 } -- structure gpu-mem { Alloc 0 } -- Code below should result in 1 mem-block coalescing, -- corresponding to 4 coalesced variables. let main [n] [m] (y: *[n][m]i32) (a: [m]i32): *[n][m]i32 = let y[0,1] = 9 let a0 = copy a let a1 = loop a1 = a0 for i < m do let a1[i] = i32.i64 i + a1[i] in a1 let y[n/2] = a1 in y futhark-0.25.27/tests/memory-block-merging/coalescing/loop/replicate-in-loop.fut000066400000000000000000000022541475065116200276600ustar00rootroot00000000000000-- A replicate in a loop, with a subsequent loop after the replicate. The -- coalescing transformation must make sure *not* to coalesce the loop into the -- return value. -- -- This problem originated from the decision to allow allocations inside loops. -- == -- input { [3, 6] -- 2i64 -- } -- output { [5, 8] -- } -- structure seq-mem { Alloc 2 } -- structure gpu-mem { Alloc 2 } let main [n] (a0: [n]i32) (n_iter: i64): []i32 = let a2 = loop a = a0 for _i < n_iter do -- If we coalesce a2 into a3, we end up coalescing the actual memory that -- the existential memory of a2 points to: the memory of b0. 
But that -- memory is allocated inside the loop, and it needs to stay that way to -- ensure that a loop iteration can read from the memory of the previous -- iteration and write to the memory of the current iteration. If a -- coalescing occurs, both iterations will use the same globally-created -- memory, and the replicate will write over everything written by the -- previous iteration. let b0 = replicate n 0i32 let a' = loop b = b0 for j < n do let b[j] = a[j] + 1 in b in a' let a3 = copy a2 in a3 futhark-0.25.27/tests/memory-block-merging/coalescing/lud/000077500000000000000000000000001475065116200234255ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/coalescing/lud/lud.fut000066400000000000000000000100071475065116200247270ustar00rootroot00000000000000-- Parallel blocked LU-decomposition. -- -- == -- structure gpu-mem { Alloc 20 } -- structure seq-mem { Alloc 15 } def block_size: i64 = 32 def dotprod [n] (a: [n]f32) (b: [n]f32): f32 = map2 (*) a b |> reduce (+) 0 def lud_diagonal [b] (a: [b][b]f32): *[b][b]f32 = #[incremental_flattening(only_intra)] map (\mat -> let mat = copy mat in #[unsafe] loop (mat: *[b][b]f32) for i < b-1 do let col = map (\j -> if j > i then (mat[j,i] - (dotprod mat[j,:i] mat[:i,i])) / mat[i,i] else mat[j,i]) (iota b) let mat[:,i] = col let row = map (\j -> if j > i then mat[i+1, j] - (dotprod mat[:i+1, j] mat[i+1, :i+1]) else mat[i+1, j]) (iota b) let mat[i+1] = row in mat ) (unflatten (a :> [opaque 1*b][b]f32)) |> head def lud_perimeter_upper [m][b] (diag: [b][b]f32) (a0s: [m][b][b]f32): *[m][b][b]f32 = let a1s = map (\ (x: [b][b]f32): [b][b]f32 -> transpose(x)) a0s in let a2s = map (\a1 -> map (\row0 -> -- Upper #[unsafe] loop row = copy row0 for i < b do let sum = loop sum=0.0f32 for k < i do sum + diag[i,k] * row[k] let row[i] = row[i] - sum in row ) a1 ) a1s in map transpose a2s def lud_perimeter_lower [b][m] (diag: [b][b]f32) (mat: [m][b][b]f32): *[m][b][b]f32 = map (\blk -> map (\row0 -> -- Lower #[unsafe] 
loop row = copy row0 for j < b do let sum = loop sum=0.0f32 for k < j do sum + diag[k,j] * row[k] let row[j] = (row[j] - sum) / diag[j,j] in row ) blk ) mat def lud_internal [m][b] (top_per: [m][b][b]f32) (lft_per: [m][b][b]f32) (mat_slice: [m][m][b][b]f32): *[m][m][b][b]f32 = let top_slice = map transpose top_per in #[incremental_flattening(only_inner)] map2 (\mat_arr lft -> #[incremental_flattening(only_inner)] map2 (\mat_blk top -> #[incremental_flattening(only_inner)] map2 (\mat_row lft_row -> #[sequential_inner] map2 (\mat_el top_row -> let prods = map2 (*) lft_row top_row let sum = f32.sum prods in mat_el - sum ) mat_row top ) mat_blk lft ) mat_arr top_slice ) mat_slice lft_per def main [num_blocks] (matb: *[num_blocks][num_blocks][32][32]f32): *[num_blocks][num_blocks][32][32]f32 = #[unsafe] let matb = loop matb for step < num_blocks - 1 do -- 1. compute the current diagonal block let diag = lud_diagonal matb[step,step] in -- 2. compute the top perimeter let row_slice = matb[step,step+1:num_blocks] let top_per_irreg = lud_perimeter_upper diag row_slice -- 3. compute the left perimeter and update matrix let col_slice = matb[step+1:num_blocks,step] let lft_per_irreg = lud_perimeter_lower diag col_slice |> opaque -- 4. compute the internal blocks let inner_slice = matb[step+1:num_blocks,step+1:num_blocks] let internal = lud_internal top_per_irreg lft_per_irreg inner_slice -- 5. 
update matrix in place let matb[step, step] = diag let matb[step, step+1:num_blocks] = top_per_irreg let matb[step+1:num_blocks, step] = lft_per_irreg let matb[step+1:num_blocks, step+1:num_blocks] = internal in matb let last_step = num_blocks - 1 in let matb[last_step,last_step] = lud_diagonal matb[last_step, last_step] in matb futhark-0.25.27/tests/memory-block-merging/coalescing/lud/lud_internal1-16.fut000066400000000000000000000020571475065116200271360ustar00rootroot00000000000000-- == -- structure gpu-mem { Alloc 2 } -- structure seq-mem { Alloc 3 } let lud_internal [m] (top_per: [m][16][16]f32) (lft_per: [m][16][16]f32) (mat_slice: [m][m][16][16]f32): *[m][m][16][16]f32 = let top_slice = map transpose top_per in #[incremental_flattening(only_intra)] map2 (\mat_arr lft -> map2 (\mat_blk top -> map2 (\mat_row lft_row -> map2 (\mat_el top_row -> let prods = map2 (*) lft_row top_row let sum = f32.sum prods in mat_el - sum ) mat_row top ) mat_blk lft ) mat_arr top_slice ) mat_slice lft_per let main [num_blocks] (matb: *[num_blocks][num_blocks][16][16]f32) = let top_per_irreg = matb[0,1:num_blocks] let col_slice = matb[1:num_blocks,0] let inner_slice = matb[1:num_blocks,1:num_blocks] let internal = lud_internal top_per_irreg col_slice inner_slice let matb[1:, 1:] = internal in matb futhark-0.25.27/tests/memory-block-merging/coalescing/lud/lud_internal1.fut000066400000000000000000000020531475065116200267060ustar00rootroot00000000000000-- == -- structure gpu-mem { Alloc 2 } -- structure seq-mem { Alloc 1 } let lud_internal [m][b] (top_per: [m][b][b]f32) (lft_per: [m][b][b]f32) (mat_slice: [m][m][b][b]f32): *[m][m][b][b]f32 = let top_slice = map transpose top_per in #[incremental_flattening(only_intra)] map2 (\mat_arr lft -> map2 (\mat_blk top -> map2 (\mat_row lft_row -> map2 (\mat_el top_row -> let prods = map2 (*) lft_row top_row let sum = f32.sum prods in mat_el - sum ) mat_row top ) mat_blk lft ) mat_arr top_slice ) mat_slice lft_per let main [num_blocks][b] 
(matb: *[num_blocks][num_blocks][b][b]f32) = let top_per_irreg = matb[0,1:num_blocks] let col_slice = matb[1:num_blocks,0] let inner_slice = matb[1:num_blocks,1:num_blocks] let internal = lud_internal top_per_irreg col_slice inner_slice let matb[1:, 1:] = internal in matb futhark-0.25.27/tests/memory-block-merging/coalescing/lud/lud_internal2.fut000066400000000000000000000020631475065116200267100ustar00rootroot00000000000000-- == -- structure gpu-mem { Alloc 7 } -- structure seq-mem { Alloc 1 } let lud_internal [m][b] (top_per: [m][b][b]f32) (lft_per: [m][b][b]f32) (mat_slice: [m][m][b][b]f32): *[m][m][b][b]f32 = let top_slice = map transpose top_per in map2 (\mat_arr lft -> map2 (\mat_blk top -> map2 (\mat_row lft_row -> map2 (\mat_el top_row -> #[sequential] let prods = map2 (*) lft_row top_row let sum = f32.sum prods in mat_el - sum ) mat_row top ) mat_blk lft ) mat_arr top_slice ) mat_slice lft_per let main [num_blocks][b] (matb: *[num_blocks][num_blocks][b][b]f32) = let top_per_irreg = matb[0,1:num_blocks] let col_slice = matb[1:num_blocks,0] let inner_slice = matb[1:num_blocks,1:num_blocks] let internal = lud_internal top_per_irreg col_slice inner_slice let matb[1:, 1:] = internal in matb futhark-0.25.27/tests/memory-block-merging/coalescing/lud/lud_internal3.fut000066400000000000000000000020031475065116200267030ustar00rootroot00000000000000-- == -- structure gpu-mem { Alloc 8 } -- structure seq-mem { Alloc 1 } let lud_internal [m][b] (top_per: [m][b][b]f32) (lft_per: [m][b][b]f32) (mat_slice: [m][m][b][b]f32): *[m][m][b][b]f32 = let top_slice = map transpose top_per in map2 (\mat_arr lft -> map2 (\mat_blk top -> map2 (\mat_row lft_row -> map2 (\mat_el top_row -> let prods = map2 (*) lft_row top_row let sum = f32.sum prods in mat_el - sum ) mat_row top ) mat_blk lft ) mat_arr top_slice ) mat_slice lft_per let main [num_blocks][b] (matb: *[num_blocks][num_blocks][b][b]f32) = let top_per_irreg = matb[0,1:num_blocks] let col_slice = matb[1:num_blocks,0] let 
inner_slice = matb[1:num_blocks,1:num_blocks] let internal = lud_internal top_per_irreg col_slice inner_slice let matb[1:, 1:] = internal in matb futhark-0.25.27/tests/memory-block-merging/coalescing/lud/lud_internal4.fut000066400000000000000000000022531475065116200267130ustar00rootroot00000000000000-- == -- compiled random input { [16][16][8][8]f32 } -- auto output -- structure gpu-mem { Alloc 2 } let lud_internal [b][m] (top_per: [m][b][b]f32) (lft_per: [m][b][b]f32) (mat_slice: [m][m][b][b]f32): *[m][m][b][b]f32 = let top_slice = map transpose top_per in #[incremental_flattening(no_outer)] map2 (\mat_arr lft -> #[incremental_flattening(no_outer)] map2 (\mat_blk top -> #[incremental_flattening(only_intra)] map2 (\mat_row lft_row -> map2 (\mat_el top_row -> let prods = map2 (*) lft_row top_row let sum = f32.sum prods in mat_el - sum ) mat_row top ) mat_blk lft ) mat_arr top_slice ) mat_slice lft_per let main [b][num_blocks] (matb: *[num_blocks][num_blocks][b][b]f32) = let top_per_irreg = matb[0,1:num_blocks] let col_slice = matb[1:num_blocks,0] let inner_slice = matb[1:num_blocks,1:num_blocks] let internal = lud_internal top_per_irreg col_slice inner_slice let matb[1:, 1:] = internal in matb futhark-0.25.27/tests/memory-block-merging/coalescing/lud/lud_internal5.fut000066400000000000000000000021611475065116200267120ustar00rootroot00000000000000-- == -- structure gpu-mem { Alloc 3 } let lud_internal [m] (top_per: [m][16][16]f32) (lft_per: [m][16][16]f32) (mat_slice: [m][m][16][16]f32): *[m][m][16][16]f32 = let top_slice = map transpose top_per in #[incremental_flattening(no_outer)] map2 (\mat_arr lft -> #[incremental_flattening(no_outer)] map2 (\mat_blk top -> #[incremental_flattening(only_intra)] map2 (\mat_row lft_row -> map2 (\mat_el top_row -> let prods = map2 (*) lft_row top_row let sum = f32.sum prods in mat_el - sum ) mat_row top ) mat_blk lft ) mat_arr top_slice ) mat_slice lft_per let main [num_blocks] (matb: *[num_blocks][num_blocks][16][16]f32) = let 
top_per_irreg = matb[0,1:num_blocks] let col_slice = matb[1:num_blocks,0] let inner_slice = matb[1:num_blocks,1:num_blocks] let internal = lud_internal top_per_irreg col_slice inner_slice let matb[1:, 1:] = internal in matb futhark-0.25.27/tests/memory-block-merging/coalescing/lud/lud_internal6.fut000066400000000000000000000017671475065116200267260ustar00rootroot00000000000000-- == -- structure gpu-mem { Alloc 2 } let lud_internal [m][b] (top_per: [m][b][b]f32) (lft_per: [m][b][b]f32) (mat_slice: [m][m][b][b]f32): *[m][m][b][b]f32 = let top_slice = map transpose top_per in #[sequential_inner] map2 (\mat_arr lft -> map2 (\mat_blk top -> map2 (\mat_row lft_row -> map2 (\mat_el top_row -> let prods = map2 (*) lft_row top_row let sum = f32.sum prods in mat_el - sum ) mat_row top ) mat_blk lft ) mat_arr top_slice ) mat_slice lft_per let main [num_blocks][b] (matb: *[num_blocks][num_blocks][b][b]f32) = let top_per_irreg = matb[0,1:num_blocks] let col_slice = matb[1:num_blocks,0] let inner_slice = matb[1:num_blocks,1:num_blocks] let internal = lud_internal top_per_irreg col_slice inner_slice let matb[1:, 1:] = internal in matb futhark-0.25.27/tests/memory-block-merging/coalescing/map/000077500000000000000000000000001475065116200234165ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/coalescing/map/map.fut000066400000000000000000000002501475065116200247100ustar00rootroot00000000000000-- == -- input { [1,2,3] } -- output { [2,3,4] } -- structure gpu-mem { Alloc 0 } -- structure seq-mem { Alloc 1 } let main [n] (xs: *[n]i32): *[n]i32 = map (+1) xs futhark-0.25.27/tests/memory-block-merging/coalescing/map/map0.fut000066400000000000000000000004171475065116200247750ustar00rootroot00000000000000-- == -- input { [[1,2,3], [4,5,6], [7,8,9]] 1i64 } -- output { [[1,2,3], [5,6,7], [7,8,9]] } -- structure gpu-mem { Alloc 0 } -- structure seq-mem { Alloc 0 } let main [n] (xss: *[n][n]i32) (i: i64) = let xs = map (+1) xss[i] let xss' = xss with [i] = xs in xss' 
futhark-0.25.27/tests/memory-block-merging/coalescing/map/map1.fut000066400000000000000000000004331475065116200247740ustar00rootroot00000000000000-- == -- input { [[1,2,3], [4,5,6], [7,8,9]] 1i64 } -- output { [[1,2,3], [11,13,15], [7,8,9]] } -- structure gpu-mem { Alloc 1 } -- structure seq-mem { Alloc 0 } let main [n] (xss: *[n][n]i32) (i: i64) = let xs = map2 (+) xss[i] xss[i+1] let xss' = xss with [i] = xs in xss' futhark-0.25.27/tests/memory-block-merging/coalescing/map/map10.fut000066400000000000000000000003411475065116200250520ustar00rootroot00000000000000-- == -- input { [0, 1, 2, 3, 4] } -- output { [0f32, 1f32, 2f32, 3f32, 4f32] } -- structure gpu-mem { Alloc 0 } -- structure mc-mem { Alloc 0 } -- structure seq-mem { Alloc 1 } def main [n] (xs: *[n]i32) = map f32.i32 xs futhark-0.25.27/tests/memory-block-merging/coalescing/map/map11.fut000066400000000000000000000002461475065116200250570ustar00rootroot00000000000000-- == -- input { [1,2,3,4] } auto output -- structure gpu { SegMap 2 } def main (xs: []i32) = loop xs for _i < 10 do map (+2) (opaque (map (+1) (reverse xs))) futhark-0.25.27/tests/memory-block-merging/coalescing/map/map12.fut000066400000000000000000000004221475065116200250540ustar00rootroot00000000000000-- == -- input { [[1,2,3,4]] } -- output { [[11, 7, 4, 2]] } -- structure gpu-mem { Alloc 2 } def main [n][m] (xs: [n][m]i32) = #[incremental_flattening(only_intra)] map (\row -> let ys = scan (+) 0 row let zs = map (+1) ys in reverse zs) xs futhark-0.25.27/tests/memory-block-merging/coalescing/map/map13.fut000066400000000000000000000002051475065116200250540ustar00rootroot00000000000000-- == -- random input { [2000]i32 } -- auto output -- structure gpu-mem { Alloc 1 } def main (xs: *[]i32) = take 1000 (map (+1) xs) futhark-0.25.27/tests/memory-block-merging/coalescing/map/map14.fut000066400000000000000000000006041475065116200250600ustar00rootroot00000000000000-- This is testing that duplicate names for memory blocks in different -- functions 
don't cause trouble. -- == -- entry: foo bar -- input { [1,2,3] [2,3,4] } -- output { [3,5,7] } -- structure gpu-mem { Alloc 0 } -- structure mc-mem { Alloc 0 } -- structure seq-mem { Alloc 2 } entry foo (xs: *[]i32) (ys: []i32) = map2 (+) xs ys entry bar (xs: *[]i32) (ys: []i32) = map2 (+) xs ys futhark-0.25.27/tests/memory-block-merging/coalescing/map/map15.fut000066400000000000000000000003161475065116200250610ustar00rootroot00000000000000-- == -- input { [1,2,3] } -- output { [2.0f32, 3.0f32, 4.0f32] } -- structure gpu-mem { Alloc 0 } let main [n] (xs: *[n]i32): *[n]f32 = let xs = opaque (map (+1) xs) let ys = map (f32.i32) xs in ys futhark-0.25.27/tests/memory-block-merging/coalescing/map/map3.fut000066400000000000000000000007061475065116200250010ustar00rootroot00000000000000-- == -- input { [[[1,2,3], [4,5,6], [7,8,9]], -- [[0,0,0], [1,1,1], [2,2,2]], -- [[3,3,3], [4,4,4], [5,5,5]]] -- 1i64 } -- output { [[[1,2,3], [4,5,6], [7,8,9]], -- [[1,1,1], [2,2,2], [3,3,3]], -- [[3,3,3], [4,4,4], [5,5,5]]] } -- structure gpu-mem { Alloc 1 } -- structure seq-mem { Alloc 0 } let main [n] (xsss: *[n][n][n]i32) (i: i64) = let xss = map (map (+ 1)) xsss[i] let xsss[i] = xss in xsss futhark-0.25.27/tests/memory-block-merging/coalescing/map/map4.fut000066400000000000000000000010471475065116200250010ustar00rootroot00000000000000-- Even though there is an opportunity to short-circuit a memory allocation on -- the GPU, we shouldn't because it will hurt coalesced access. 
-- == -- input { [7i64, 8i64, 9i64] } -- output { [[0i64, 8i64, 18i64], -- [1i64, 9i64, 19i64], -- [2i64, 10i64, 20i64]] } -- compiled random input { [1024]i64 } -- auto output -- structure gpu-mem { Alloc 2 } -- structure seq-mem { Alloc 1 } let main [n] (xs: [n]i64): [n][n]i64 = map (\j -> loop xs' = copy xs for i < n do let xs'[i] = xs'[i] * i + j in xs' ) (iota n) futhark-0.25.27/tests/memory-block-merging/coalescing/map/map5.fut000066400000000000000000000005211475065116200247760ustar00rootroot00000000000000-- == -- compiled random input { [1024][1024]i64 } -- auto output -- structure gpu-mem { Alloc 3 } -- structure seq-mem { Alloc 1 } let main [m][n] (xss: [m][n]i64): [n][m]i64 = let xss' = transpose xss in map (\xs -> loop xs' = copy xs for i < m do let xs'[i] = xs'[i] * i in xs' ) xss' futhark-0.25.27/tests/memory-block-merging/coalescing/map/map6.fut000066400000000000000000000002471475065116200250040ustar00rootroot00000000000000-- == -- structure gpu-mem { Alloc 1 } -- structure seq-mem { Alloc 2 } let main [n] (xs: [n]i64) = let ys = map (+ 1) xs let zs = map (* 2) xs in concat ys zs futhark-0.25.27/tests/memory-block-merging/coalescing/map/map7.fut000066400000000000000000000004301475065116200247770ustar00rootroot00000000000000-- == -- input { [[1,2,3], [4,5,6], [7,8,9]] 1i64 } -- output { [[1,2,3], [5,7,9], [7,8,9]] } -- structure gpu-mem { Alloc 0 } -- structure seq-mem { Alloc 0 } let main [n] (xss: *[n][n]i32) (i: i64) = let xs = map2 (+) xss[i] xss[i-1] let xss' = xss with [i] = xs in xss' futhark-0.25.27/tests/memory-block-merging/coalescing/map/map8.fut000066400000000000000000000003211475065116200247770ustar00rootroot00000000000000-- == -- input { [1,2,3] [2,3,4] } -- output { [3,5,7] } -- structure gpu-mem { Alloc 0 } -- structure mc-mem { Alloc 0 } -- structure seq-mem { Alloc 1 } def main (xs: *[]i32) (ys: []i32) = map2 (+) xs ys 
futhark-0.25.27/tests/memory-block-merging/coalescing/map/map9.fut000066400000000000000000000003061475065116200250030ustar00rootroot00000000000000-- == -- input { [1,2,3] } -- output { [2,3,4] } -- structure gpu-mem { Alloc 1 } -- structure mc-mem { Alloc 1 } -- structure seq-mem { Alloc 1 } let main [n] (xs: [n]i32): [n]i32 = map (+1) xs futhark-0.25.27/tests/memory-block-merging/coalescing/misc/000077500000000000000000000000001475065116200235745ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/coalescing/misc/choice0.fut000066400000000000000000000020211475065116200256210ustar00rootroot00000000000000-- An example of a program where there is a coalescing choice. -- == -- input { [1i64, 2i64] } -- output { [2i64, 3i64, 0i64, 1i64] } -- structure seq-mem { Alloc 2 } -- structure gpu-mem { Alloc 2 } let main [n] (ns: [n]i64): []i64 = let t0 = map (+ 1) ns -- Create an array whose memory block allocation depends on the *value* of t0, -- not the *shape* of t0. This makes it impossible to hoist the alloc up -- before the t0 creation. let annoying = iota t0[0] -- Either coalesce t0 into t1... let t1 = copy t0 -- ... or coalesce t1 into t2. Both will not work: -- -- + If t0 is coalesced into t1, the allocation of t2 needs to be hoisted way -- up to before the creation of t0. This is not possible due to the -- allocation size depending on the size of annoying. -- -- + Else, the allocation of t2 just needs to be hoisted up to before the -- creation of t1, which is doable. -- -- Either will work on their own. annoying can always be coalesced. let t2 = concat t1 annoying in t2 futhark-0.25.27/tests/memory-block-merging/coalescing/misc/two-dim-ker.fut000066400000000000000000000016541475065116200264610ustar00rootroot00000000000000-- Test2 Memory-Block Merging -- -- For the CPU pipeline there is no coalescing to do in this program. The -- compiler makes sure there is only a single alloc before we even get to memory -- block merging. 
-- -- == -- input { [ [ [0i64, 1i64], [2i64, 3i64] ], [ [4i64, 5i64], [6i64, 7i64] ] ] } -- output { [[[0i64, 9i64], [0i64, 13i64]]]} -- compiled random input { [128][128][128]i64 } -- structure seq-mem { Alloc 2 } -- structure gpu-mem { Alloc 3 } let main [n] (xsss: [n][n][n]i64): [][n][n]i64 = let asss = drop 1 xsss in map (\ass -> map (\as -> let r = loop r = 0 for i < n do let r = r + as[i] in r in loop bs = iota n for j < n do let bs[j] = bs[j]*r in bs ) ass ) asss futhark-0.25.27/tests/memory-block-merging/coalescing/safety-condition-1/000077500000000000000000000000001475065116200262565ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/coalescing/safety-condition-1/neg0.fut000066400000000000000000000015331475065116200276310ustar00rootroot00000000000000-- Negative test. We cannot fulfill safety condition 1, since 'ys' is used -- after the coalescing-enabling line (in 'zs'). -- -- However, the new fusion engine fuses the two maps (ys and zs), meaning that -- it actually is short-circuited on the seq-mem backend. -- == -- input { 3i64 -- [0i64, 1i64, 2i64, 3i64] -- } -- output { [[0i64, 1i64, 2i64, 3i64], -- [4i64, 5i64, 6i64, 7i64], -- [8i64, 9i64, 10i64, 11i64], -- [1i64, 2i64, 3i64, 4i64]] -- [2i64, 3i64, 4i64, 5i64] -- } -- structure seq-mem { Alloc 2 } -- structure gpu-mem { Alloc 2 } let main [n] (i: i64) (ys0: [n]i64): ([n][n]i64, [n]i64) = let xs = tabulate_2d n n (\i j -> i*n + j) |> opaque let ys = map (+ 1) ys0 let xs[i] = ys let zs = map (+ 1) ys -- This could also be a short-circuit point in SeqMem in (xs, zs) futhark-0.25.27/tests/memory-block-merging/coalescing/safety-condition-2/000077500000000000000000000000001475065116200262575ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/coalescing/safety-condition-2/neg0.fut000066400000000000000000000010621475065116200276270ustar00rootroot00000000000000-- Negative test. 
We cannot fulfill safety condition 2, since 'ys' is allocated -- before the function body is run, so 'xs', which is created in the body, can -- never be allocated before 'ys'. -- == -- input { 0i64 -- [10i64, 20i64, 30i64] -- } -- output { [[10i64, 20i64, 30i64], -- [3i64, 4i64, 5i64], -- [6i64, 7i64, 8i64]] -- } -- structure seq-mem { Alloc 1 } -- structure gpu-mem { Alloc 1 } let main [n] (i: i64) (ys: [n]i64): [n][n]i64 = let xs = tabulate_2d n n (\i j -> i * n + j) let xs[i] = ys in xs futhark-0.25.27/tests/memory-block-merging/coalescing/safety-condition-2/pos0.fut000066400000000000000000000010371475065116200276610ustar00rootroot00000000000000-- Positive test. We can fulfill safety condition 2, since the memory for 'xs' -- is allocated before 'ys' is created. -- == -- input { 2 -- [[1, 1, 1], -- [1, 1, 1], -- [1, 1, 1]] -- [5, 7, 9] -- } -- output { [[1, 1, 1], -- [1, 1, 1], -- [10, 14, 18]] -- } -- structure seq-mem { Alloc 0 } -- structure gpu-mem { Alloc 0 } let main [n] (i: i32) (xs: *[n][n]i32) (ys0: [n]i32): [n][n]i32 = let ys = map (* 2) ys0 -- Will use the memory of xs[i]. let xs[i] = ys in xs futhark-0.25.27/tests/memory-block-merging/coalescing/safety-condition-3/000077500000000000000000000000001475065116200262605ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/coalescing/safety-condition-3/neg0.fut000066400000000000000000000012721475065116200276330ustar00rootroot00000000000000-- Negative test. 'xs' is used while 'ys' is live, so we cannot merge their -- memory blocks, since 'zs' would then map over the contents of 'ys' instead of -- the original contents of 'xs[i]'. 
-- == -- input { [[2, 2], -- [2, 2]] -- [3, 4] -- 1i64 -- } -- output { [[2, 2], -- [4, 5]] -- 6 -- } -- structure seq-mem { Alloc 2 } -- structure gpu-mem { Alloc 3 } let main [n] (xs: *[n][n]i32) (ys0: [n]i32) (i: i64): ([n][n]i32, i32) = let ys = map (+ 1) ys0 let zs = map (+ ys[0]) xs[i] -- Cannot be hoisted to exist before 'ys', which -- would have solved the problem. let xs[i] = ys in (xs, zs[i]) futhark-0.25.27/tests/memory-block-merging/coalescing/safety-condition-3/pos0.fut000066400000000000000000000007741475065116200276710ustar00rootroot00000000000000-- Positive test. 'xs' is not used while 'ys' is live, so we can merge their -- memory blocks. -- == -- input { [[2, 2], -- [2, 2]] -- [3, 4] -- 1 -- } -- output { [[2, 2], -- [4, 5]] -- 6 -- } -- structure seq-mem { Alloc 1 } -- structure gpu-mem { Alloc 1 } let main [n] (xs: *[n][n]i32) (ys0: [n]i32) (i: i32): ([n][n]i32, i32) = let ys = map (+ 1) ys0 -- Will use the memory of xs[i]. let zs = map (+ 1) ys let xs[i] = ys in (xs, zs[i]) futhark-0.25.27/tests/memory-block-merging/coalescing/safety-condition-4/000077500000000000000000000000001475065116200262615ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/coalescing/safety-condition-4/neg0.fut000066400000000000000000000007731475065116200276410ustar00rootroot00000000000000-- Negative test. 'ys' aliases 'zs', so it already occupies another memory -- block than "its own". 
-- == -- input { [[2, 2], -- [2, 2]] -- [[3, 4], -- [10, 20]] -- 0 -- 0 -- } -- output { [[9, 12], -- [2, 2]] -- } -- structure seq-mem { Alloc 1 } -- structure gpu-mem { Alloc 1 } let main [n] (xs: *[n][n]i32) (zs0: [n][n]i32) (i: i32) (j: i32): [n][n]i32 = let zs = map (\z -> map (* 3) z) zs0 let ys = zs[i] let xs[j] = ys in xs futhark-0.25.27/tests/memory-block-merging/coalescing/safety-condition-5/000077500000000000000000000000001475065116200262625ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/coalescing/safety-condition-5/neg0.fut000066400000000000000000000006551475065116200276410ustar00rootroot00000000000000-- t0 cannot be coalesced into t1, since the index function includes i1, which -- depends on the result of t0. -- == -- input { [0, 1] -- [[5, 5], [5, 5]] -- 0 -- } -- output { [[5, 5], [1, 2]] -- } -- structure seq-mem { Alloc 1 } -- structure gpu-mem { Alloc 1 } let main [n] (ns: [n]i32) (t1: *[n][n]i32) (i0: i32): [][]i32 = let t0 = map (+ 1) ns let i1 = t0[i0] let t1[i1] = t0 in t1 futhark-0.25.27/tests/memory-block-merging/coalescing/test-impossible-asserts/000077500000000000000000000000001475065116200274465ustar00rootroot00000000000000iota-inlined-perfbug.fut000066400000000000000000000003741475065116200341170ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/coalescing/test-impossible-asserts-- == -- structure gpu-mem { Alloc 3 } -- structure seq-mem { Alloc 2 } let main [n] (xs: *[n][n]i64) (i: i64): (i64, [n][n]i64) = let a = iota n let b = if i > 5 then a[i:n] else a[0:n-i] let s = reduce (+) 0i64 b let xs[i] = a in (s, xs) futhark-0.25.27/tests/memory-block-merging/coalescing/test-impossible-asserts/test1.fut000066400000000000000000000010501475065116200312220ustar00rootroot00000000000000-- == -- input { [[0i64, 1i64, 2i64], [3i64, 4i64, 5i64], [7i64, 8i64, 9i64]] -- [42i64, 1337i64, 0i64] -- 1i64 -- } -- output { 596444i64 -- [[0i64, 1i64, 2i64], -- [588i64, 595856i64, 0i64], -- [7i64, 8i64, 9i64]] -- 
} -- structure gpu-mem { Alloc 2 } -- structure seq-mem { Alloc 1 } let main [n] (xs: *[n][n]i64) (a0: [n]i64) (i: i64): (i64, [n][n]i64) = let a = map (\e -> e*e/3) a0 let b = if i > 5 then a[i:n] else a[0:n-i] let s = reduce (+) 0i64 b let xs[i] = a in (s, xs) futhark-0.25.27/tests/memory-block-merging/coalescing/test-impossible-asserts/test2.fut000066400000000000000000000003441475065116200312300ustar00rootroot00000000000000-- == -- structure gpu-mem { Alloc 1 } -- structure seq-mem { Alloc 1 } let main [n] (xs: *[n][n]i64) (a0: [n]i64) (i: i64): ([n]i64, [n][n]i64) = let a = map (\e -> e*e/3) a0 let a[i] = 33i64 let xs[i] = a in (a, xs) futhark-0.25.27/tests/memory-block-merging/coalescing/test-impossible-asserts/test3.fut000066400000000000000000000004371475065116200312340ustar00rootroot00000000000000-- == -- structure gpu-mem { Alloc 0 } -- structure seq-mem { Alloc 0 } let main [n] (xsss: *[n][n][n]i64) (ass0: [n][n]i64) (as0: [n]i64) (i: i64): [n][n][n]i64 = let ass = map (map (*5i64)) ass0 |> opaque let as = map (+3i64) as0 let ass[i] = as let xsss[i] = ass in xsss futhark-0.25.27/tests/memory-block-merging/coalescing/test-impossible-asserts/test4.fut000066400000000000000000000003721475065116200312330ustar00rootroot00000000000000-- == -- structure gpu-mem { Alloc 1 } -- structure seq-mem { Alloc 1 } let main [n] (zss: *[n][n]f64) (x: f64) (y: f64) (i: i64) (j: i64): *[n][n]f64 = let ys = replicate n y let xs = replicate n x let zss[i] = xs let zss[j] = ys in zss futhark-0.25.27/tests/memory-block-merging/coalescing/weird.fut000066400000000000000000000010231475065116200244670ustar00rootroot00000000000000-- == -- input { [[1i64,2i64,3i64],[4i64,5i64,6i64],[7i64,8i64,9i64]] 1i64 } -- output { [[1i64,1i64,1i64], [5i64,6i64,7i64], [8i64,9i64,10i64]] } -- structure gpu-mem { Alloc 2 } -- structure seq-mem { Alloc 2 } let main [n] (xss: [n][n]i64) (i: i64) = -- The basis array let xss = map (map (+ 1)) xss -- There's also an allocation here let xs = replicate n i 
-- This loop and the result inside should be short-circuited let xss = loop xss for j < i do xss with [j] = xs let xss = copy xss in xss futhark-0.25.27/tests/memory-block-merging/misc/000077500000000000000000000000001475065116200214655ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/misc/brown-bridge-simple.fut000066400000000000000000000014641475065116200260620ustar00rootroot00000000000000-- Test inspired from Option Pricing's brownian bridge -- == -- input { [ [1.0, 3.0, 5.0, 7.0, 9.0, 11.0], [0.0, 2.0, 4.0, 6.0, 8.0, 10.0] ] } -- output { [[3.0f64, 2.0f64, 2.0f64, 2.0f64, 2.75f64, -11.25f64], [2.0f64, 2.0f64, 2.0f64, 2.0f64, 2.0f64, -10.0f64]] } def brownian_bridge [num_dates] (gauss: [num_dates]f64): [num_dates]f64 = let bbrow = replicate num_dates 0.0 let bbrow[ num_dates-1 ] = 0.5 * gauss[0] let bbrow = loop (bbrow) for i < num_dates-1 do #[unsafe] let bbrow[i] = bbrow[i+1]*1.5 + gauss[i+1] in bbrow let bbrow = loop (bbrow) for ii < num_dates-1 do let i = num_dates - (ii+1) let bbrow[i] = bbrow[i] - bbrow[i-1] in bbrow in bbrow def main [m] [num_dates] (gausses: [m][num_dates]f64) : [m][num_dates]f64 = map brownian_bridge gausses futhark-0.25.27/tests/memory-block-merging/misc/ixfun-loop.fut000066400000000000000000000004501475065116200243040ustar00rootroot00000000000000-- A simple test for index-function generalization across a for loop -- == -- input { [0i64, 1000i64, 42i64, 1001i64, 50000i64] } -- output { 1249975000i64 } def main [n] (a: [n]i64): i64 = let b = loop b = iota(10) for i < n do let m = a[i] in iota(m) in reduce (+) 0 b futhark-0.25.27/tests/memory-block-merging/reuse/000077500000000000000000000000001475065116200216555ustar00rootroot00000000000000futhark-0.25.27/tests/memory-block-merging/reuse/map-reduce-map.fut000066400000000000000000000006101475065116200251670ustar00rootroot00000000000000-- == -- input { [[ 0, 1, 2, 3], [ 4, 5, 6, 7]] } -- output { [[10, 11, 12, 13], [30, 31, 32, 33]] } -- random input { [100][10]i32 } 
-- auto output -- structure gpu-mem { Alloc 2 } def main (xss: [][]i32) = #[incremental_flattening(only_intra)] map (\xs -> let as = map (+1) xs |> opaque let a = reduce (+) 0 as let bs = map (+a) xs in bs) xss futhark-0.25.27/tests/migration/000077500000000000000000000000001475065116200164755ustar00rootroot00000000000000futhark-0.25.27/tests/migration/array0.fut000066400000000000000000000003521475065116200204130ustar00rootroot00000000000000-- Array literals are migrated if they contain free scalar variables. -- == -- structure gpu { -- /GPUBody 1 -- /GPUBody/ArrayLit 1 -- /ArrayLit 0 -- } def main (i: i64) (v: i32) : i32 = let xs = [0, 1, 2, 3, v] in xs[i%5] futhark-0.25.27/tests/migration/array1.fut000066400000000000000000000003061475065116200204130ustar00rootroot00000000000000-- Array literals are migrated if they contain free scalar variables. -- == -- structure gpu { -- GPUBody 0 -- ArrayLit 1 -- } def main (i: i64) : i32 = let xs = [0, 1, 2, 3, 4] in xs[i%5] futhark-0.25.27/tests/migration/array2.fut000066400000000000000000000004251475065116200204160ustar00rootroot00000000000000-- Array literals are migrated if they contain free scalar variables. -- -- Arrays with non-primitive rows are not be migrated. -- == -- structure gpu { -- GPUBody 0 -- ArrayLit 3 -- } def main (i: i64) (v: [2]i32) : i32 = let xs = [[0, 1], [2, 3], v] in xs[i%3, i%2] futhark-0.25.27/tests/migration/array3.fut000066400000000000000000000004531475065116200204200ustar00rootroot00000000000000-- Array literals are migrated if they contain free scalar variables. -- -- Arrays with non-primitive rows are not be migrated. This is to avoid turning -- a parallel device copy into a sequential operation. -- == -- structure gpu { -- GPUBody 0 -- } def main [n] (A: [n]i32) : [1][n]i32 = [A] futhark-0.25.27/tests/migration/array4.fut000066400000000000000000000003631475065116200204210ustar00rootroot00000000000000-- Array literals are migrated if they contain free scalar variables. 
-- -- [n] is equivalent to gpu { n }. -- == -- structure gpu { -- /GPUBody 1 -- ArrayLit 0 -- Replicate 0 -- /Index 0 -- } def main (x: [1]i32) : *[1]i32 = [x[0]]futhark-0.25.27/tests/migration/assert0.fut000066400000000000000000000003741475065116200206020ustar00rootroot00000000000000-- Assertions are considered free to migrate but are not migrated needlessly. -- == -- structure gpu { -- /GPUBody 1 -- /GPUBody/Index 1 -- /GPUBody/CmpOp 1 -- /GPUBody/Assert 1 -- } def main (arr: [1]i32) : i32 = assert (arr[0] == 42) 1337 futhark-0.25.27/tests/migration/assert1.fut000066400000000000000000000004671475065116200206060ustar00rootroot00000000000000-- Assertions are considered free to migrate but are not migrated needlessly. -- == -- structure gpu { -- /GPUBody 1 -- /GPUBody/Index 1 -- /GPUBody/CmpOp 2 -- /GPUBody/Assert 2 -- } def main (arr: [1]i32) : i32 = let v = arr[0] let x = assert (v != 7) 1007 let y = assert (v != 21) 330 in x+y futhark-0.25.27/tests/migration/assert2.fut000066400000000000000000000002731475065116200206020ustar00rootroot00000000000000-- Assertions are considered free to migrate but are not migrated needlessly. -- == -- structure gpu { -- GPUBody 0 -- /Assert 1 -- } def main (fail: bool) : i32 = assert fail 800 futhark-0.25.27/tests/migration/assert3.fut000066400000000000000000000002621475065116200206010ustar00rootroot00000000000000-- Assertions are considered free to migrate but are not migrated needlessly. -- == -- structure gpu { -- GPUBody 0 -- /Assert 1 -- } def main (arr: []i64) : i64 = arr[0] futhark-0.25.27/tests/migration/assert4.fut000066400000000000000000000004731475065116200206060ustar00rootroot00000000000000-- Assertions are considered free to migrate but are not migrated needlessly. -- -- This test is a variation of the 'index0.fut' test. 
-- == -- structure gpu { -- /Assert 1 -- /GPUBody 1 -- /GPUBody/Assert 1 -- /GPUBody/Index 2 -- /Index 1 -- } def main (arr: []i64) : i64 = let i = arr[0] in arr[i] futhark-0.25.27/tests/migration/blocking0_hostonly.fut000066400000000000000000000010071475065116200230220ustar00rootroot00000000000000-- Host-only operations block the migration of whole statements. -- == -- structure gpu { -- GPUBody 0 -- } #[noinline] def hostonly 'a (x: a) : a = -- This function can only be run on host. let arr = opaque [x] in arr[0] entry case_if (A: [5]i64) : i64 = if A[0] == 0 then hostonly 42 else A[1] entry case_while (A: [5]i64) : i64 = loop x = A[0] while x < 1000 do x * (A[hostonly x % 5] + 1) entry case_for (A: [5]i64) : i64 = loop x = 0 for i < A[0] do x * (A[hostonly x % 5] + 1) futhark-0.25.27/tests/migration/blocking1_array.fut000066400000000000000000000012041475065116200222610ustar00rootroot00000000000000-- Arrays with non-primitive rows block the migration of whole statements. -- This is to avoid turning a parallel device copy into a sequential operation. -- == -- structure gpu { -- /If/True/ArrayLit 1 -- /Loop/ArrayLit 2 -- } entry case_if (A: [5]i64) (x: i64) : i64 = if A[0] == 0 then let B = [A, opaque A] in #[unsafe] (opaque B)[x%2, 2] else A[1] entry case_while (A: [5]i64) : i64 = loop x = A[0] while x < 1000 do let B = [A, opaque A] in #[unsafe] (opaque B)[x%2, 2] entry case_for (A: [5]i64) : i64 = loop x = 0 for i < A[0] do let B = [A, opaque A] in #[unsafe] (opaque B)[x%2, 2] futhark-0.25.27/tests/migration/blocking2_concat.fut000066400000000000000000000012021475065116200224110ustar00rootroot00000000000000-- Array concatenation blocks the migration of whole statements. -- This is to avoid turning a parallel device copy into a sequential operation. 
-- == -- structure gpu { -- /If/True/Concat 1 -- /Loop/Concat 2 -- } entry case_if (A: [5]i64) (x: i64) : i64 = if A[0] == 0 then let B = concat A (opaque A) in #[unsafe] (opaque B)[x%10] else A[1] entry case_while (A: [5]i64) : i64 = loop x = A[0] while x < 1000 do let B = concat A (opaque A) in #[unsafe] (opaque B)[x%10] entry case_for (A: [5]i64) : i64 = loop x = 0 for i < A[0] do let B = concat A (opaque A) in #[unsafe] (opaque B)[x%10] futhark-0.25.27/tests/migration/blocking3_copy.fut000066400000000000000000000005301475065116200221200ustar00rootroot00000000000000-- Array copying blocks the migration of whole statements. -- This is to avoid turning a parallel device copy into a sequential operation. -- == -- structure gpu { -- /If/True/Replicate 1 -- } def main (A: [5]i64) : [1]i64 = if A[0] == 0 then let B = copy (opaque A) in #[unsafe] (opaque B)[0:1] else A[1:2] :> [1]i64 futhark-0.25.27/tests/migration/blocking4_iota.fut000066400000000000000000000011661475065116200221110ustar00rootroot00000000000000-- Iotas blocks the migration of whole statements. -- This is to avoid turning a parallel computation into a sequential operation. -- == -- structure gpu { -- /If/True/Iota 1 -- /Loop/Iota 2 -- } entry case_if (A: [5]i64) (x: i64) : i64 = if A[0] == 0 then let B = iota A[1] in #[unsafe] (opaque B)[x % length B] else A[1] entry case_while (A: [5]i64) (x: i64) : i64 = loop y = A[0] while y < 1000 do let B = iota y in #[unsafe] (opaque B)[x % length B] entry case_for (A: [5]i64) (x: i64) : i64 = loop y = 0 for i < A[0] do let B = iota y in #[unsafe] (opaque B)[x % length B] futhark-0.25.27/tests/migration/blocking5_replicate.fut000066400000000000000000000012431475065116200231220ustar00rootroot00000000000000-- Replicates blocks the migration of whole statements. -- This is to avoid turning a parallel computation into a sequential operation. 
-- == -- structure gpu { -- /If/True/Replicate 1 -- /Loop/Replicate 2 -- } entry case_if (A: [5]i64) (x: i64) : i64 = if A[0] == 0 then let B = replicate A[1] 1337 in #[unsafe] (opaque B)[x % length B] else A[1] entry case_while (A: [5]i64) (x: i64) : i64 = loop y = A[0] while y < 1000 do let B = replicate y 1337 in #[unsafe] (opaque B)[x % length B] entry case_for (A: [5]i64) (x: i64) : i64 = loop y = 0 for i < A[0] do let B = replicate y 1337 in #[unsafe] (opaque B)[x % length B] futhark-0.25.27/tests/migration/blocking6_update.fut000066400000000000000000000005701475065116200224370ustar00rootroot00000000000000-- Multi-element array slice updates block the migration of whole statements. -- This is to avoid turning a parallel device copy into a sequential operation. -- == -- structure gpu { -- /If/True/Update 1 -- } entry case_if (A: *[5]i64) (x: i64) : i64 = if A[0] == 0 then let B = A with [0:2] = [4, 2] in #[unsafe] (opaque B)[x % length B] else A[1] futhark-0.25.27/tests/migration/blocking7_exception.fut000066400000000000000000000005421475065116200231530ustar00rootroot00000000000000-- Any operation that normally is parallel does not block migration of parent -- statements if it produces an array of just a single element. -- == -- structure gpu { -- GPUBody/If/True/Replicate 1 -- } -- This fails due to a memory allocation error. -- def main (A: [1]i64) : [1]i64 = -- if A[4] == 42 -- then copy (opaque A) -- else A futhark-0.25.27/tests/migration/cse0.fut000066400000000000000000000003361475065116200200510ustar00rootroot00000000000000-- Duplicate scalar migrations are eliminated. -- == -- structure gpu { -- Replicate 1 -- } def main (A: [5]i32) (x: i32) : i32 = let (a, b) = if x == 42 then (0, 0) else (A[0], A[1]) in a+b futhark-0.25.27/tests/migration/cse1.fut000066400000000000000000000005241475065116200200510ustar00rootroot00000000000000-- Duplicate scalar migrations are eliminated. -- -- Both x's are migrated and should be reduced to a single GPUBody. 
-- == -- structure gpu { -- /GPUBody 1 -- /Loop/If/True/GPUBody 0 -- } def main (A: [5]i32) (x: i32) : i32 = let (_, res) = loop (c, _) = (0, x) for z in A do if c == 3 then (c, x) else (c+1, z) in res futhark-0.25.27/tests/migration/cse2.fut000066400000000000000000000007421475065116200200540ustar00rootroot00000000000000-- Duplicate scalar migrations are eliminated. -- -- Both x's are migrated and should be reduced to a single array. -- They are migrated with the GPUBody that computes 'A[1]+1'. -- If merging occurs before CSE then this test will fail. -- == -- structure gpu { -- /GPUBody 1 -- /Loop/If/True/GPUBody 0 -- } def main (A: *[5]i32) (x: i32) : i32 = let A[0] = A[1]+1 -- let (_, res) = loop (c, _) = (0, x) for z in A do if c == 3 then (c, x) else (c+1, z) in res futhark-0.25.27/tests/migration/flatindex.fut000066400000000000000000000005461475065116200212000ustar00rootroot00000000000000-- Flat indexes may not be migrated on their own as results from GPUBody -- constructs are copied, which would change the asymptotic cost of the -- operation. -- == -- structure gpu { -- GPUBody 0 -- FlatIndex 1 -- } import "intrinsics" -- This fails due to a memory allocation error. def main (A: [5]i64) : [2][2]i64 = flat_index_2d A 0 2 2 2 A[0] futhark-0.25.27/tests/migration/flatupdate.fut000066400000000000000000000006061475065116200213500ustar00rootroot00000000000000-- Flat updates may not be migrated on their own as results from GPUBody -- constructs are copied, which would change the asymptotic cost of the -- operation. -- == -- structure gpu { -- GPUBody 0 -- FlatUpdate 1 -- } import "intrinsics" -- This fails due to a memory allocation error. let v = [[1i64]] def main (A: *[5]i64) : *[5]i64 = let x = A[0] in flat_update_2d A x 1 1 v futhark-0.25.27/tests/migration/fun0_hostonly.fut000066400000000000000000000007321475065116200220260ustar00rootroot00000000000000-- Do not migrate the application of functions that work with arrays. 
-- These are subject to compiler limitations. -- == -- structure gpu { -- GPUBody 0 -- /Index 3 -- } #[noinline] def fun'' (a: i32) (b: i32) : i32 = let xs = scan (+) 0 (replicate 10 a) in xs[b % 10] #[noinline] def fun' (a: i32) (b: i32) : i32 = fun'' a b #[noinline] def fun (a: i32) (b: i32) : i32 = fun' a b def main (arr: [2]i32) : i32 = let (a, b) = (arr[0], arr[1]) in fun a b futhark-0.25.27/tests/migration/fun1_2-1.fut000066400000000000000000000004561475065116200204520ustar00rootroot00000000000000-- Function calls that do not work with arrays can be migrated. -- == -- structure gpu { -- /GPUBody 1 -- /GPUBody/Index 2 -- /GPUBody/Apply 1 -- /Index 1 -- } #[noinline] def plus (a: i32) (b: i32) : i32 = a + b def main (arr: [2]i32) : i32 = let (a, b) = (arr[0], arr[1]) in plus a b futhark-0.25.27/tests/migration/fun2_2-1_unused.fut000066400000000000000000000005001475065116200220240ustar00rootroot00000000000000-- Function calls that do not work with arrays can be migrated. -- == -- structure gpu { -- /GPUBody 1 -- /GPUBody/Index 2 -- /GPUBody/Apply 1 -- /Index 1 -- } #[noinline] def id2 'a (x: a) (y: a) : (a, a) = (x, y) def main (arr: [3]i32) : i32 = let (a, b) = (arr[0], arr[1]) let (_, y) = id2 a b in y futhark-0.25.27/tests/migration/fun3_2-2-1.fut000066400000000000000000000005321475065116200206060ustar00rootroot00000000000000-- Function calls that do not work with arrays can be migrated. -- == -- structure gpu { -- /GPUBody 1 -- /GPUBody/Index 2 -- /GPUBody/Apply 1 -- /GPUBody/BinOp 1 -- /Index 1 -- } #[noinline] def id2 'a (x: a) (y: a) : (a, a) = (x, y) def main (arr: [3]i32) : i32 = let (a, b) = (arr[0], arr[1]) let (x, y) = id2 a b in x + y futhark-0.25.27/tests/migration/fun4_2-2.fut000066400000000000000000000004101475065116200204440ustar00rootroot00000000000000-- Only migrate array reads when a reduction can be obtained. 
-- == -- structure gpu { -- GPUBody 0 -- /Index 2 -- } #[noinline] def id2 'a (x: a) (y: a) : (a, a) = (x, y) def main (arr: [3]i32) : (i32, i32) = let (a, b) = (arr[0], arr[1]) in id2 a b futhark-0.25.27/tests/migration/fun5_2-3.fut000066400000000000000000000004331475065116200204530ustar00rootroot00000000000000-- Only migrate array reads when a reduction can be obtained. -- == -- structure gpu { -- GPUBody 0 -- /Index 2 -- } #[noinline] def id3 'a (x: a) (y: a) (z: a) : (a, a, a) = (x, y, z) def main (arr: [3]i32) : (i32, i32, i32) = let (a, b) = (arr[0], arr[1]) in id3 a b a futhark-0.25.27/tests/migration/fun6_3-2.fut000066400000000000000000000005141475065116200204540ustar00rootroot00000000000000-- Only migrate array reads when a reduction can be obtained. -- == -- structure gpu { -- /GPUBody 1 -- /GPUBody/Index 3 -- /GPUBody/Apply 1 -- /Index 2 -- } #[noinline] def pick2 'a (x: a) (y: a) (z: a) : (a, a) = (x, z) def main (arr: [3]i32) : (i32, i32) = let (a, b, c) = (arr[0], arr[1], arr[2]) in pick2 a b c futhark-0.25.27/tests/migration/fun7_3-2.fut000066400000000000000000000004251475065116200204560ustar00rootroot00000000000000-- Only migrate array reads when a reduction can be obtained. -- == -- structure gpu { -- /GPUBody 0 -- /Index 2 -- } #[noinline] def pick2 'a (x: a) (y: a) (z: a) : (a, a) = (x, y) def main (arr: [3]i32) : (i32, i32) = let (a, b) = (arr[1], arr[2]) in pick2 a b a futhark-0.25.27/tests/migration/fun8_2-1_hostonly.fut000066400000000000000000000011141475065116200224100ustar00rootroot00000000000000-- Only migrate array reads when a reduction can be obtained. -- == -- structure gpu { -- /GPUBody 1 -- /GPUBody/Index 2 -- /GPUBody/Apply 1 -- /Index 1 -- } #[noinline] def hostonly (x: i32) : i32 = -- This function can only be run on host and thus requires -- its argument to be made available there. 
if x == 42 then (opaque [x])[0] else 42 #[noinline] def id2 'a (x: a) (y: a) : (a, a) = (x, y) def main (arr: [3]i32) : [3]i32 = let (a, b) = (arr[1], arr[2]) let (x, y) = id2 a b let i = hostonly x in map (\j -> if j == y then j/i else j) arr futhark-0.25.27/tests/migration/hist0.fut000066400000000000000000000003721475065116200202460ustar00rootroot00000000000000-- Neutral elements are made available on host. -- == -- structure gpu { -- /Index 1 -- /SegHist 1 -- } def main (A: *[10]i32) : *[10]i32 = let A = A with [0] = 0 let B = opaque A let ne = B[0] in reduce_by_index B (+) ne [4, 2] [1, 0] futhark-0.25.27/tests/migration/hist1.fut000066400000000000000000000007231475065116200202470ustar00rootroot00000000000000-- Reads can be delayed into kernel bodies and combining operators. -- == -- structure gpu { -- /Index 0 -- /SegHist 1 -- /SegMap 0 -- } def main (A: *[10]i64) : *[10]i64 = let A = A with [0] = 0 let A = A with [1] = 1 let B = opaque A let x = B[0] -- This read can be delayed into op let y = B[1] -- This read can be delayed into the kernel body let op = \a b -> a+b+x let is = [4, 2] let as = map (+y) is in reduce_by_index B op 0 is as futhark-0.25.27/tests/migration/hoisted0.fut000066400000000000000000000005351475065116200207370ustar00rootroot00000000000000-- GPUBody kernels are hoisted out of loops they are invariant of. -- == -- structure gpu { -- GPUBody 2 -- /GPUBody 1 -- /Loop/Loop/GPUBody/BinOp 1 -- } def main [n] (A: *[n]i64) (x: i64) : *[n]i64 = loop A for i < n do -- Storage of x should be hoisted. let sum = loop x for y in #[unsafe] A[i:] do x + y in A with [i] = sum futhark-0.25.27/tests/migration/hoisted1.fut000066400000000000000000000007421475065116200207400ustar00rootroot00000000000000-- GPUBody kernels are hoisted out of loops they are invariant of. -- -- If merging occurs before hoisting then this test will fail. 
-- == -- structure gpu { -- GPUBody 1 -- /Loop/GPUBody/Index 2 -- /Loop/GPUBody/Loop/BinOp 2 -- } def main [n] (A: *[n]i64) (x: i64) : *[n]i64 = loop A for i < n-1 do let j = #[unsafe] (A[i] + A[i+1]%n) % n -- Storage of x should be hoisted. let sum = loop x for y in #[unsafe] A[j:] do x + y in #[unsafe] A with [i] = sum futhark-0.25.27/tests/migration/if0.fut000066400000000000000000000003611475065116200176730ustar00rootroot00000000000000-- If statements may be migrated as a whole but only if reads are reduced. -- == -- structure gpu { -- /GPUBody 1 -- /GPUBody/CmpOp 1 -- /GPUBody/If 1 -- /Index 1 -- } def main (A: [5]i64) : i64 = if A[0] == 0 then A[1] else A[2] futhark-0.25.27/tests/migration/if1.fut000066400000000000000000000005641475065116200177010ustar00rootroot00000000000000-- If statements may be migrated as a whole but only if reads are reduced. -- -- Reads can be delayed from any branch, even if only one branch performs a -- read. This reduces the worst-case number of reads. -- == -- structure gpu { -- /GPUBody 1 -- /GPUBody/CmpOp 1 -- /GPUBody/If 1 -- /Index 1 -- } def main (A: [5]i64) : i64 = if A[0] == 0 then 42 else A[2] futhark-0.25.27/tests/migration/if10.fut000066400000000000000000000014231475065116200177540ustar00rootroot00000000000000-- If statements that return arrays can also be migrated if doing so does not -- introduce any additional array copying. -- -- In this case the outer 'if' can be migrated because all possible return -- values originate from array literals declared within its branches. These -- array literals would all be migrated to a GPUBody anyway, so migrating the -- if statement introduces no additional copying. -- == -- structure gpu { -- ArrayLit 3 -- /GPUBody/If 1 -- /ArrayLit 0 -- /If 0 -- } -- This fails due to a memory allocation error. 
-- def main (A: [2]i64) (x: i64) : [2]i64 = -- if A[0] == 0 -- then [opaque x, 1] -- else let B = opaque [opaque x, 2, 0] -- in if A[1] == 1 -- then loop _ = B[0:2] for i < x do [i, 3] -- else [opaque x, 4] futhark-0.25.27/tests/migration/if2.fut000066400000000000000000000004571475065116200177030ustar00rootroot00000000000000-- If statements may be migrated as a whole but only if reads are reduced. -- -- Migrating the whole if statement in this case saves no reads as it would -- introduce a read of the return value. -- == -- structure gpu { -- GPUBody 0 -- } def main (A: [5]i64) : i64 = if A[0] == 0 then 42 else 1337 futhark-0.25.27/tests/migration/if3.fut000066400000000000000000000005111475065116200176730ustar00rootroot00000000000000-- If statements may be migrated as a whole but only if reads are reduced. -- -- Migrating the whole if statement in this case saves no reads as it would -- introduce a read of the second tuple value. -- == -- structure gpu { -- GPUBody 0 -- } def main (A: [5]i64) : (i64, i64) = if A[0] == 0 then (7, 42) else (A[2], 1337) futhark-0.25.27/tests/migration/if4.fut000066400000000000000000000006421475065116200177010ustar00rootroot00000000000000-- If statements may be migrated as a whole but only if reads are reduced. -- -- A read is saved by migrating the 'if' as the second tuple value is not used -- on host. -- == -- structure gpu { -- /GPUBody/CmpOp 1 -- /GPUBody/If/False/Index 1 -- /GPUBody/BinOp 1 -- /Index 1 -- } def main (A: *[5]i64) : i64 = let (x, y) = if A[0] == 0 then (A[1], 42) else (A[2], 1337) let A' = map (+y) A in x + A'[4] futhark-0.25.27/tests/migration/if5.fut000066400000000000000000000004011475065116200176730ustar00rootroot00000000000000-- Reads can be delayed out of if statements. 
-- == -- structure gpu { -- /If 1 -- /If/True/Index 0 -- /If/False/Index 0 -- /GPUBody/BinOp 1 -- /Index 1 -- } def main (A: [5]i64) (c: bool) : i64 = let x = if c then A[1] else A[2] in x + A[3] futhark-0.25.27/tests/migration/if6.fut000066400000000000000000000005011475065116200176750ustar00rootroot00000000000000-- Reads can be delayed out of if statements, even if only one branch performs -- a read. This reduces the worst-case number of reads. -- == -- structure gpu { -- /If 1 -- /If/True/Index 0 -- /GPUBody/BinOp 1 -- /Index 1 -- } def main (A: [5]i64) (c: bool) : i64 = let x = if c then A[1] else 42 in x + A[3] futhark-0.25.27/tests/migration/if7.fut000066400000000000000000000003661475065116200177070ustar00rootroot00000000000000-- Reads can be delayed into if statements. -- == -- structure gpu { -- /Index 0 -- /If/True/GPUBody/BinOp 1 -- } def main (A: *[5]i64) (c: bool) : *[5]i64 = let x = A[0] let A' = if c then A with [0] = x+1 else A in A' with [2] = x futhark-0.25.27/tests/migration/if8.fut000066400000000000000000000003701475065116200177030ustar00rootroot00000000000000-- Reads can be delayed through if statements. -- == -- structure gpu { -- /If/True/GPUBody/BinOp 1 -- /GPUBody/BinOp 2 -- /Index 1 -- } def main (A: [5]i64) (c: bool) : i64 = let x = A[0] let y = if c then x+3 else 42 in x + y + A[1] futhark-0.25.27/tests/migration/if9.fut000066400000000000000000000004661475065116200177120ustar00rootroot00000000000000-- If statements that return arrays can also be migrated if doing so does not -- introduce any additional array copying. -- -- In this case the 'if' cannot be migrated. -- == -- structure gpu { -- GPUBody 0 -- } def main [n] (A: [n]i64) (m: i64) : [m]i64 = if A[0] == 0 then A[:m] else A[1:m+1] :> [m]i64 futhark-0.25.27/tests/migration/index0.fut000066400000000000000000000003771475065116200204130ustar00rootroot00000000000000-- Non-reducible array reads should not be migrated unless to prevent the -- reading of their indices. 
-- == -- structure gpu { -- /GPUBody 1 -- /GPUBody/Index 2 -- /Index 1 -- } def main (arr: [3][3]i32) : i32 = let i = arr[1, 0] in arr[0, i] futhark-0.25.27/tests/migration/index1.fut000066400000000000000000000003211475065116200204010ustar00rootroot00000000000000-- Non-reducible array reads should not be migrated unless to prevent the -- reading of their indices. -- == -- structure gpu { -- GPUBody 0 -- /Index 1 -- } def main (arr: [3][3]i32) : i32 = arr[1, 1] futhark-0.25.27/tests/migration/index2.fut000066400000000000000000000004431475065116200204070ustar00rootroot00000000000000-- Slice indexes may not be migrated on their own as results from GPUBody -- constructs are copied, which would change the asymptotic cost of the -- operation. -- == -- structure gpu { -- GPUBody 0 -- Index 2 -- } def main (arr: [3][3]i32) : [3]i32 = let i = arr[1, 0] in arr[:, i] futhark-0.25.27/tests/migration/intrinsics.fut000066400000000000000000000033541475065116200214070ustar00rootroot00000000000000def flat_index_2d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) : [n1][n2]a = intrinsics.flat_index_2d as offset n1 s1 n2 s2 :> [n1][n2]a def flat_update_2d [n][k][l] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (asss: [k][l]a) : *[n]a = intrinsics.flat_update_2d as offset s1 s2 asss def flat_index_3d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) (n3: i64) (s3: i64) : [n1][n2][n3]a = intrinsics.flat_index_3d as offset n1 s1 n2 s2 n3 s3 :> [n1][n2][n3]a def flat_update_3d [n][k][l][p] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (s3: i64) (asss: [k][l][p]a) : *[n]a = intrinsics.flat_update_3d as offset s1 s2 s3 asss def flat_index_4d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) (n3: i64) (s3: i64) (n4: i64) (s4: i64) : [n1][n2][n3][n4]a = intrinsics.flat_index_4d as offset n1 s1 n2 s2 n3 s3 n4 s4 :> [n1][n2][n3][n4]a def flat_update_4d [n][k][l][p][q] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (s3: 
i64) (s4: i64) (asss: [k][l][p][q]a) : *[n]a = intrinsics.flat_update_4d as offset s1 s2 s3 s4 asss type~ acc 't = intrinsics.acc t def scatter_stream [k] 'a 'b (dest: *[k]a) (f: *acc ([k]a) -> b -> acc ([k]a)) (bs: []b) : *[k]a = intrinsics.scatter_stream dest f bs :> *[k]a def reduce_by_index_stream [k] 'a 'b (dest: *[k]a) (op: a -> a -> a) (ne: a) (f: *acc ([k]a) -> b -> acc ([k]a)) (bs: []b) : *[k]a = intrinsics.hist_stream dest op ne f bs :> *[k]a def write [n] 't (acc : *acc ([n]t)) (i: i64) (v: t) : *acc ([n]t) = intrinsics.acc_write acc i v futhark-0.25.27/tests/migration/iota.fut000066400000000000000000000003651475065116200201550ustar00rootroot00000000000000-- In general iotas are not migrated in order not to turn a parallel operation -- into a sequential one (GPUBody kernels are single-threaded). -- == -- structure gpu { -- GPUBody 0 -- Iota 1 -- } def main (A: [5]i64) : []i64 = iota A[0] futhark-0.25.27/tests/migration/kernels_hostonly.fut000066400000000000000000000005521475065116200226210ustar00rootroot00000000000000-- Parallel kernels cannot be migrated into sequential kernels and thus blocks -- the migration of parent statements. -- == -- structure gpu { -- /Loop/Loop/If 1 -- } def main (A: [10][5]i64): i64 = loop x = 1 for i < A[0, 0] do loop x for B in A do if B[0] != 42 then let sum = reduce (+) 0 B in sum%x + 1 else 42futhark-0.25.27/tests/migration/loop0_inside.fut000066400000000000000000000006111475065116200215770ustar00rootroot00000000000000-- Reads can be delayed and reduced inside loops. 
-- == -- structure gpu { -- /Loop/GPUBody 1 -- /Loop/GPUBody/Apply 1 -- /Loop/GPUBody/Index 1 -- } def main [n] [m] (A: *[n][m]f32) : *[n][m]f32 = loop A = A for i < n do let B = A[i, :] let B' = map (\x -> x*x) B let len = reduce (+) 0 B' |> f32.sqrt let B' = map (/len) B' let A' = A with [i, :] = B' in A' futhark-0.25.27/tests/migration/loop10_blocked.fut000066400000000000000000000004511475065116200220120ustar00rootroot00000000000000-- Whole loops are not migrated if the asymptotic cost would change. -- -- In this case migrating the whole loop could cause a copy of A. -- == -- structure gpu { -- /Loop 1 -- } def main (A: [5]i64) : [5]i64 = loop (B: [5]i64) = A for i < A[0] do if i%4 != 0 then B else [1, 2, i, 4, 5] futhark-0.25.27/tests/migration/loop11_into.fut000066400000000000000000000004451475065116200213640ustar00rootroot00000000000000-- Reads can be delayed into loops given that the number of reads done by the -- loop remains unchanged. -- == -- structure gpu { -- /Index 0 -- /Loop/GPUBody/BinOp 1 -- } def main (A: [10]i64) : [10]i64 = let x = A[0] in loop A for y in A do let z = x+y in map (+z) A futhark-0.25.27/tests/migration/loop12_into.fut000066400000000000000000000010761475065116200213660ustar00rootroot00000000000000-- Reads can be delayed into loops given that the number of reads done by the -- loop remains unchanged. -- -- In this case the reads are not delayed as the worst case number of reads per -- iteration would increase by one. -- == -- structure gpu { -- /Index 2 -- GPUBody 0 -- } #[noinline] def hostonly (x: i64) : i64 = -- This function can only be run on host. 
if x == 42 then (opaque [x])[0] else 42 def main (A: [10]i64) : i64 = let x = A[0] let y = A[1] in loop z = 0 for i < 10 do hostonly (if z != 0 then 42 else id (x+z) + y) futhark-0.25.27/tests/migration/loop13_into.fut000066400000000000000000000005251475065116200213650ustar00rootroot00000000000000-- Reads can be delayed into loops given that the number of reads done by the -- loop remains unchanged. -- == -- structure gpu { -- /Index 0 -- /Loop/GPUBody/BinOp 1 -- } def main (A: [10]i64) : [10]i64 = let x = A[0] let (_, A') = loop (x, A) = (x, A) for y in A do let z = x+y in (opaque x, map (+z) A) in A' futhark-0.25.27/tests/migration/loop14_into.fut000066400000000000000000000005161475065116200213660ustar00rootroot00000000000000-- Reads can be delayed into loops given that the number of reads done by the -- loop remains unchanged. -- == -- structure gpu { -- /Index 0 -- /Loop/GPUBody/BinOp 1 -- } def main (A: [10]i64) : [10]i64 = let x = A[0] let (_, A') = loop (x, A) = (x, A) for y in A do let z = x+y in (0, map (+z) A) in A' futhark-0.25.27/tests/migration/loop15_into.fut000066400000000000000000000007631475065116200213730ustar00rootroot00000000000000-- Reads can be delayed into loops given that the number of reads done by the -- loop remains unchanged. -- == -- structure gpu { -- /GPUBody 1 -- /Loop/GPUBody 1 -- } #[noinline] def hostonly (x: i64) : i64 = -- This function can only be run on host. if x == 42 then (opaque [x])[0] else x def main (A: [10]i64) : [10]i64 = let x = A[0] let (_, A') = loop (x, A) = (x, A) for i < 10 do let y = x+1 let z = hostonly y in (A[i], map (+z) A) in A' futhark-0.25.27/tests/migration/loop16_into.fut000066400000000000000000000010721475065116200213660ustar00rootroot00000000000000-- Reads can be delayed into loops given that the number of reads done by the -- loop remains unchanged. -- -- In this case the reads are not delayed as the number of reads per iteration -- would increase by one. 
-- == -- structure gpu { -- /Index 2 -- GPUBody 0 -- } #[noinline] def hostonly (x: i64) : i64 = -- This function can only be run on host. if x == 42 then (opaque [x])[0] else 42 def main (A: [10]i64) : i64 = let (a, b) = loop (x, y) = (A[0], A[1]) for i < 10 do let z = hostonly (x+y) in (z%22, z*z) in a+b futhark-0.25.27/tests/migration/loop17_through.fut000066400000000000000000000003301475065116200220720ustar00rootroot00000000000000-- Reads can be delayed through loops. -- == -- structure gpu { -- /Index 1 -- /Loop/GPUBody/BinOp 2 -- } def main (A: [10]i64) : i64 = let x = A[0] let y = A[1] in loop z = 0 for i < 10 do (x+z)+y futhark-0.25.27/tests/migration/loop18_through.fut000066400000000000000000000003261475065116200221000ustar00rootroot00000000000000-- Reads can be delayed through loops. -- == -- structure gpu { -- /Index 1 -- /GPUBody/Loop 1 -- } def main (A: [10]i64) : i64 = let x = A[0] let y = A[1] in loop z = 0 while z < 1000 do (x+z)+y futhark-0.25.27/tests/migration/loop19_through.fut000066400000000000000000000005331475065116200221010ustar00rootroot00000000000000-- Reads can be delayed through loops. -- == -- structure gpu { -- /Index 0 -- /Loop/If/True/GPUBody/BinOp 1 -- /Loop/If/False/GPUBody/BinOp 1 -- } def main (A: *[10]i64) : *[10]i64 = let (x, A) = loop (x, A) = (A[0], A) for i < 10 do if i%4 == 0 then (x-1, A) else (x+1, A with [i] = 42) in A with [6] = x futhark-0.25.27/tests/migration/loop1_tonext.fut000066400000000000000000000005271475065116200216540ustar00rootroot00000000000000-- Reads can be reduced from one iteration to the next. 
-- == -- structure gpu { -- /Loop/Index 0 -- } def main [n] (A: *[n]f32) : *[n]f32 = let (A', _) = loop (A, x) = (A, 0) for i < n do let A' = map (+x) A let y = A'[i] -- is delayed into next iteration let A' = A' with [i*i % n] = 42 in (A', y) in A' futhark-0.25.27/tests/migration/loop20_through.fut000066400000000000000000000004131475065116200220660ustar00rootroot00000000000000-- Reads can be delayed through loops. -- == -- structure gpu { -- /GPUBody 2 -- /Loop/GPUBody/BinOp 3 -- /Index 1 -- } def main (A: [10]i64) : i64 = let (a, b) = loop (x, y) = (A[0], A[1]) for i < 10 do let z = x+y in (z%22, z*z) in a+b futhark-0.25.27/tests/migration/loop2_out.fut000066400000000000000000000003031475065116200211330ustar00rootroot00000000000000-- Reads can be delayed out of loops. -- == -- structure gpu { -- /Loop/Index 0 -- /Loop/GPUBody 1 -- /Index 1 -- } def main [n] (A: [n]f32) : f32 = loop x = 0 for i < n do x + A[i] futhark-0.25.27/tests/migration/loop3_forin.fut000066400000000000000000000004071475065116200214470ustar00rootroot00000000000000-- Reads introduced by a for-in loop can be delayed. -- == -- structure gpu { -- /Loop/Index 0 -- /Loop/GPUBody 1 -- /GPUBody 1 -- /Index 1 -- } def main [n] (A: [n]i64) : i64 = let sum = loop x = 0 for y in A do x + y in sum + A[sum % n] futhark-0.25.27/tests/migration/loop4_wholefor.fut000066400000000000000000000003641475065116200221620ustar00rootroot00000000000000-- Whole for loops can be migrated to avoid reading the bound but only if -- doing so can save a read. -- == -- structure gpu { -- /GPUBody/Loop 1 -- /Index 1 -- } def main (A: [10]i64) : i64 = loop x = 1 for i < A[0] do x * A[x%10] futhark-0.25.27/tests/migration/loop5_wholefor.fut000066400000000000000000000003301475065116200221540ustar00rootroot00000000000000-- Whole for loops can be migrated to avoid reading the bound but only if -- doing so can save a read. 
-- == -- structure gpu { -- GPUBody 0 -- } def main (A: [10]i64) : i64 = loop x = 1 for i < A[0] do x*x futhark-0.25.27/tests/migration/loop6_wholewhile.fut000066400000000000000000000004041475065116200225010ustar00rootroot00000000000000-- Whole while loops can be migrated to avoid reading the loop condition but -- only if doing so can save a read. -- == -- structure gpu { -- /GPUBody/Loop 1 -- /Index 1 -- } def main (A: [10]i64) : i64 = loop x = A[0] while x < 1000 do x * A[x%10] futhark-0.25.27/tests/migration/loop7_wholewhile.fut000066400000000000000000000003521475065116200225040ustar00rootroot00000000000000-- Whole while loops can be migrated to avoid reading the loop condition but -- only if doing so can save a read. -- == -- structure gpu { -- GPUBody 0 -- } def main (A: [10]i64) : i64 = loop x = A[0] while x < 1000 do x * x futhark-0.25.27/tests/migration/loop8_wholewhile.fut000066400000000000000000000003651475065116200225110ustar00rootroot00000000000000-- Whole while loops can be migrated to avoid reading the loop condition but -- only if doing so can save a read. -- == -- structure gpu { -- /GPUBody/Loop 1 -- } def main [n] (A: [n]i64) : i64 = loop x = 0 while x < 1000 do x + A[x%n] futhark-0.25.27/tests/migration/loop9_wholewhile.fut000066400000000000000000000004151475065116200225060ustar00rootroot00000000000000-- Whole while loops can be migrated to avoid reading the loop condition but -- only if doing so can save a read. -- == -- structure gpu { -- /GPUBody/Loop 1 -- } def main [n] (A: [n]i64) : (i64, i64) = loop (x, y) = (0, 0) while x < 1000 do (x + A[x%n], y+1) futhark-0.25.27/tests/migration/map-reduce.fut000066400000000000000000000005441475065116200212420ustar00rootroot00000000000000-- The array read associated with a reduce can be eliminated, and reads can be -- delayed into map kernels. 
-- == -- structure gpu { -- /GPUBody 1 -- /GPUBody/Apply 1 -- /Index 0 -- /SegRed 1 -- /SegMap 1 -- } entry vector_norm [n] (A: [n]f32): [n]f32 = let pow2 = map (\x -> x*x) A let len = reduce (+) 0 pow2 |> f32.sqrt in map (/len) Afuthark-0.25.27/tests/migration/merge0.fut000066400000000000000000000005371475065116200204010ustar00rootroot00000000000000-- Migrated statements are moved into GPUBody statements that are combined. -- -- Can merge adjacent GPUBody statements. -- == -- structure gpu { -- /BinOp 1 -- /GPUBody 1 -- /GPUBody/CmpOp 1 -- /GPUBody/If 1 -- } def main (A: *[1]i32) (a: i32) : *[1]i32 = let b = a*2 let x = A[0] let y = if x == b then b+3 else b in A with [0] = y futhark-0.25.27/tests/migration/merge1.fut000066400000000000000000000005431475065116200203770ustar00rootroot00000000000000-- Migrated statements are moved into GPUBody statements that are combined. -- -- Can merge non-adjacent GPUBody statements. -- == -- structure gpu { -- /BinOp 1 -- /GPUBody 1 -- /GPUBody/CmpOp 1 -- /GPUBody/If 1 -- } def main (A: *[1]i32) (a: i32) : *[1]i32 = let x = A[0] let b = a*2 let y = if x == b then b+3 else b in A with [0] = y futhark-0.25.27/tests/migration/merge2.fut000066400000000000000000000012171475065116200203770ustar00rootroot00000000000000-- Migrated statements are moved into GPUBody statements that are combined. -- -- Can merge multiple non-adjacent GPUBody statements. -- == -- structure gpu { -- GPUBody 1 -- } #[noinline] def hostonly 'a (x: a) : i32 = -- This function can only be run on host and thus requires -- its argument to be made available there. 
let arr = opaque [42] in arr[0] #[noinline] def id 'a (x: a) : a = x def main (A: *[4]i32) : *[4]i32 = let (a, b) = id (A[0], A[1]) -- gpu let c = hostonly a -- host let (x, y) = id (A[2], A[3]) -- gpu let z = hostonly y -- host let B = A with [0] = b+x -- gpu in map (+c+z) B futhark-0.25.27/tests/migration/merge3.fut000066400000000000000000000013151475065116200203770ustar00rootroot00000000000000-- Migrated statements are moved into GPUBody statements that are combined. -- -- Can merge multiple non-adjacent GPUBody statements. -- == -- structure gpu { -- GPUBody 2 -- } #[noinline] def hostonly 'a (x: a) : i32 = -- This function can only be run on host and thus requires -- its argument to be made available there. let arr = opaque [42] in arr[0] #[noinline] def id 'a (x: a) : a = x def main (A: *[4]i32) : *[4]i32 = let (a1, a2) = id (A[0], A[1]) -- gpu 1 let b = hostonly a1 -- host let c = b*a2 -- gpu 2 let (d1, d2) = id (a2*2, A[2]) -- gpu 1 let e = hostonly d1 -- host let f = e*d2 -- gpu 2 in A with [2] = c+f futhark-0.25.27/tests/migration/merge4.fut000066400000000000000000000011761475065116200204050ustar00rootroot00000000000000-- Migrated statements are moved into GPUBody statements that are combined. -- -- Can merge non-adjacent GPUBody statements but only if the latter does not -- depend upon the prior through a non-GPUBody statement. -- == -- structure gpu { -- GPUBody 2 -- } #[noinline] def hostonly 'a (x: a) : i32 = -- This function can only be run on host and thus requires -- its argument to be made available there. let arr = opaque [42] in arr[0] #[noinline] def id 'a (x: a) : a = x def main (A: *[3]i32) : *[3]i32 = let (x, y) = id (A[0], A[1]) -- gpu 1 let z = hostonly x -- host in A with [0] = y+z -- gpu 2 futhark-0.25.27/tests/migration/merge5.fut000066400000000000000000000011311475065116200203750ustar00rootroot00000000000000-- Migrated statements are moved into GPUBody statements that are combined. 
-- -- The consumption of dependencies are considered when statements are reordered. -- == -- structure gpu { -- GPUBody 2 -- } def main (A: *[5]i32) (x: i32) : *[3]i32 = let A1 = A[0:3] -- alias #1 of A let A2 = A[2:5] -- alias #2 of A let i = x%3 let y = #[unsafe] A1[i]+1 -- gpu 1, observes A through A1 let B = A2 with [0] = 42 -- consumes A through A2 let z = B[2]+y -- gpu 2, depends on B, cannot merge with gpu 1 let C = B with [1] = z in C :> *[3]i32 futhark-0.25.27/tests/migration/merge6.fut000066400000000000000000000014341475065116200204040ustar00rootroot00000000000000-- Migrated statements are moved into GPUBody statements that are combined. -- -- The consumption of dependencies are considered when statements are reordered. -- == -- structure gpu { -- /GPUBody 2 -- /GPUBody/If/True/Update 1 -- } let one = opaque 1i64 let two = one + 1 def main (A: *[5]i64) (x: i64) : *[1]i64 = let A1 = A[0:3] -- alias #1 of A let A2 = A[2:5] -- alias #2 of A let y = A1[0]+A1[1] -- gpu 1, observes A through A1 let z = if x == 0 -- observes A through A2 then reduce (+) 0 A2 else y let C = if A1[0] == 0 -- gpu 2, consumes A through A1 then let A1' = A1 with [1] = z in #[unsafe] A1'[one:two] else #[unsafe] A1[one:two] in #[unsafe] C :> *[1]i64 futhark-0.25.27/tests/migration/merge7.fut000066400000000000000000000010261475065116200204020ustar00rootroot00000000000000-- Migrated statements are moved into GPUBody statements that are combined. -- -- The consumption of dependencies are considered when statements are merged. 
-- == -- structure gpu { -- GPUBody/ArrayLit 1 -- GPUBody/If/True/Update 1 -- } let one = opaque 1i64 let two = one + 1 def main (A: *[5]i32) (x: i32) : *[5]i32 = let B = [0, x, 2] let y = #[unsafe] B[one] let C = if A[0] == y then let B' = B with [1] = 3 in #[unsafe] B'[one:two] else #[unsafe] B[one:two] in A with [1:2] = C futhark-0.25.27/tests/migration/reduce0.fut000066400000000000000000000003401475065116200205410ustar00rootroot00000000000000-- Neutral elements are made available on host. -- == -- structure gpu { -- /Index 1 -- /SegRed 1 -- } def main (A: *[10]f32): [1]f32 = let A = A with [0] = 0 let B = opaque A let ne = B[0] in [reduce (+) ne B]futhark-0.25.27/tests/migration/reduce1.fut000066400000000000000000000006501475065116200205460ustar00rootroot00000000000000-- Reads can be delayed into kernel bodies and combining operators. -- == -- structure gpu { -- /Index 0 -- /SegRed 1 -- /SegMap 0 -- } def main (A: *[10]i64): [1]i64 = let A = A with [0] = 0 let A = A with [1] = 0 let B = opaque A let x = B[0] -- This read can be delayed into op let y = B[1] -- This read can be delayed into the kernel body let op = \a b -> a+b+x in [reduce_comm op 0 (map (+y) B)]futhark-0.25.27/tests/migration/reduction0_dup.fut000066400000000000000000000003431475065116200221410ustar00rootroot00000000000000-- Common subexpression elimination should eliminate duplicate array reads -- so no migration reduction is necessary. -- == -- structure gpu { -- GPUBody 0 -- /Index 1 -- } def main (arr: [3]f32) : f32 = arr[0] + arr[0] futhark-0.25.27/tests/migration/reduction10_mixed.fut000066400000000000000000000004321475065116200225370ustar00rootroot00000000000000-- Only migrate array reads when a reduction can be obtained. 
-- == -- structure gpu { -- /GPUBody 1 -- /Index 1 -- /BinOp 0 -- } def main (arr: [3]f32) : f32 = let (a, b) = (arr[0], arr[1]) let c = a*b + a let d = -b/2 let e = c / arr[2] let f = 10 * e in f * d futhark-0.25.27/tests/migration/reduction1_2-1.fut000066400000000000000000000003461475065116200216540ustar00rootroot00000000000000-- Only migrate array reads when a reduction can be obtained. -- == -- structure gpu { -- /GPUBody 1 -- /GPUBody/Index 2 -- /GPUBody/BinOp 1 -- /Index 1 -- /BinOp 0 -- } def main (arr: [3]f32) : f32 = arr[0] + arr[1] futhark-0.25.27/tests/migration/reduction2_3-1.fut000066400000000000000000000003311475065116200216500ustar00rootroot00000000000000-- Only migrate array reads when a reduction can be obtained. -- == -- structure gpu { -- /GPUBody 1 -- /GPUBody/BinOp 2 -- /Index 1 -- /BinOp 0 -- } def main (arr: [3]f32) : f32 = arr[0] + arr[1] + arr[2] futhark-0.25.27/tests/migration/reduction3_2-1-2.fut000066400000000000000000000004031475065116200220070ustar00rootroot00000000000000-- Only migrate array reads when a reduction can be obtained. -- == -- structure gpu { -- /GPUBody 1 -- /GPUBody/BinOp 1 -- /Index 1 -- /BinOp 2 -- } def main (arr: [3]f32) : (f32, f32) = let (a, b) = (arr[0], arr[1]) let x = a*b in (x+1, x-1) futhark-0.25.27/tests/migration/reduction4_2-2-2.fut000066400000000000000000000003521475065116200220140ustar00rootroot00000000000000-- Only migrate array reads when a reduction can be obtained. -- == -- structure gpu { -- GPUBody 0 -- /Index 2 -- } def main (arr: [3]f32) : (f32, f32) = let (a, b) = (arr[0], arr[1]) let (x, y) = (a*b, a+b) in (x*y, x/y) futhark-0.25.27/tests/migration/reduction5_2-3-2.fut000066400000000000000000000003621475065116200220170ustar00rootroot00000000000000-- Only migrate array reads when a reduction can be obtained. 
-- == -- structure gpu { -- GPUBody 0 -- /Index 2 -- } def main (arr: [3]f32) : (f32, f32) = let (a, b) = (arr[0], arr[1]) let (x, y, z) = (a*b, a+b, a-b) in (x*y, y*z) futhark-0.25.27/tests/migration/reduction6_2-3-1.fut000066400000000000000000000004131475065116200220140ustar00rootroot00000000000000-- Only migrate array reads when a reduction can be obtained. -- == -- structure gpu { -- /GPUBody 1 -- /GPUBody/BinOp 5 -- /Index 1 -- /BinOp 0 -- } def main (arr: [3]f32) : f32 = let (a, b) = (arr[0], arr[1]) let (x, y, z) = (a*b, a+b, a-b) in x*y*z futhark-0.25.27/tests/migration/reduction7_3-2-4-2.fut000066400000000000000000000005361475065116200221650ustar00rootroot00000000000000-- Only migrate array reads when a reduction can be obtained. -- == -- structure gpu { -- /GPUBody 1 -- /GPUBody/Index 3 -- /GPUBody/BinOp 4 -- /Index 2 -- /BinOp 6 -- } def main (arr: [3]f32) : (f32, f32) = let (a, b, c) = (arr[0], arr[1], arr[2]) let (x, y) = (a+b+c, a-b-c) let (d, e, f, g) = (x+y, x-y, x*y, x/y) in (d%e, f+g) futhark-0.25.27/tests/migration/reduction8_3-2-4-1.fut000066400000000000000000000004501475065116200221600ustar00rootroot00000000000000-- Only migrate array reads when a reduction can be obtained. -- == -- structure gpu { -- /GPUBody 1 -- /Index 1 -- /BinOp 0 -- } def main (arr: [3]f32) : f32 = let (a, b, c) = (arr[0], arr[1], arr[2]) let (x, y) = (a+b+c, a-b-c) let (d, e, f, g) = (x+y, x-y, x*y, x/y) in d*e*f*g futhark-0.25.27/tests/migration/reduction9_deep2-1.fut000066400000000000000000000004151475065116200225170ustar00rootroot00000000000000-- Only migrate array reads when a reduction can be obtained. 
-- == -- structure gpu { -- /GPUBody 1 -- /Index 1 -- /BinOp 0 -- } def main (arr: [3]f32) : f32 = let a = arr[0] let b = a+2 let c = 4*b let x = arr[1] let y = x+2 let z = 3*y in c/z futhark-0.25.27/tests/migration/replicate0.fut000066400000000000000000000003661475065116200212520ustar00rootroot00000000000000-- Replicated arrays can be migrated if they replicate a value once. -- -- replicate 1 n is equivalent to gpu { n }. -- == -- structure gpu { -- /GPUBody 1 -- Replicate 0 -- /Index 0 -- } def main (A: [1]i32) : *[1]i32 = replicate 1 A[0]futhark-0.25.27/tests/migration/replicate1.fut000066400000000000000000000005651475065116200212540ustar00rootroot00000000000000-- Replicated arrays can be migrated if they replicate a value once. -- -- The replicates are combined into one and migrated with their outermost -- dimension dropped, which is reintroduced by the GPUBody construct. -- == -- structure gpu { -- /GPUBody/Replicate 1 -- /Replicate 0 -- /Index 0 -- } def main (A: [1]i32) : *[1][1]i32 = replicate 1 (replicate 1 A[0])futhark-0.25.27/tests/migration/replicate2.fut000066400000000000000000000006071475065116200212520ustar00rootroot00000000000000-- In general replicates are not migrated in order not to turn a parallel -- operation into a sequential one (GPUBody kernels are single-threaded). -- -- They can however be rewritten to allow the computation of their replicated -- value to be migrated. -- == -- structure gpu { -- /GPUBody/BinOp 1 -- /Replicate 1 -- } def main (A: [1]i32) (n: i64) : *[n]i32 = replicate n (A[0] + 1)futhark-0.25.27/tests/migration/replicate3.fut000066400000000000000000000006321475065116200212510ustar00rootroot00000000000000-- In general replicates are not migrated in order not to turn a parallel -- operation into a sequential one (GPUBody kernels are single-threaded). -- -- They can however be rewritten to allow the computation of their replicated -- value to be migrated. 
-- == -- structure gpu { -- /GPUBody 1 -- /Replicate 1 -- /Index 0 -- } def main (A: [1]i32) (n: i64) : *[n][1]i32 = replicate n (replicate 1 A[0])futhark-0.25.27/tests/migration/reshape.fut000066400000000000000000000011261475065116200206440ustar00rootroot00000000000000-- Reshapes may not be migrated on their own as results from GPUBody constructs -- are copied, which would change the asymptotic cost of the operation. -- -- In general any scalar used within a type must be made available on host -- before use. -- == -- structure gpu { -- GPUBody 0 -- Reshape 1 -- } #[noinline] def alter [n] (A: [n]i64) : *[]i64 = let l = n%10 in [l+1] ++ A[:l] ++ A #[noinline] def modify [m] (A: [m]i64) (x: i64) : *[m]i64 = map (+x) A def main [n] (A: [n]i64) : *[]i64 = let A' = alter A let m = A'[0] -- must be read let B = A' :> [m]i64 in modify B m futhark-0.25.27/tests/migration/reuse0_index.fut000066400000000000000000000016041475065116200216100ustar00rootroot00000000000000-- A statement that reuses memory can be migrated as part of a parent body -- if none but single elements of the reused memory (updated or aliased) are -- returned or if the memory source is migrated into the same kernel. 
-- == -- structure gpu { -- /GPUBody/If/True/Index 2 -- /GPUBody/Loop/Index 5 -- } entry case_if (A: *[5]i64) (x: i64) : [1]i64 = if A[0] == 42 then let B = #[unsafe] A[x%3 : x%3+2] in #[unsafe] (opaque B)[0:1] :> [1]i64 else A[0:1] :> [1]i64 entry case_while (A: [5]i64) : [1]i64 = let (_, C) = loop (x, A') = (0, A[0:1]) while A'[0] != x do let B = #[unsafe] A[x%3 : x%3+2] in (x+1, #[unsafe] (opaque B)[0:1] :> [1]i64) in C entry case_for (A: [5]i64) : [1]i64 = let (_, C) = loop (x, _) = (0, A[0:1]) for i < A[0] do let B = #[unsafe] A[x%3 : x%3+2] in (x+1, #[unsafe] (opaque B)[0:1] :> [1]i64) in C futhark-0.25.27/tests/migration/reuse1_update.fut000066400000000000000000000025341475065116200217670ustar00rootroot00000000000000-- A statement that reuses memory can be migrated as part of a parent body -- if none but single elements of the reused memory (updated or aliased) are -- returned or if the memory source is migrated into the same kernel. -- == -- structure gpu { -- /GPUBody/If/True/Update 1 -- /GPUBody/Loop/Update 0 -- } entry case_if (A: *[5]i64) (x: i64) : [1]i64 = if A[0] == 42 then let B = #[unsafe] A with [x%5] = 0 in #[unsafe] (opaque B)[0:1] :> [1]i64 else A[0:1] :> [1]i64 -- Compiler limitations prevent these cases from being validated. -- -- We cannot consume something outside the loops; allocating inside the -- (migrated) loops causes allocation errors; introducing a multi-element -- array as a loop parameter disables the optimization; and updating a -- single-element array loop parameter is replaced with a replicate, which -- also causes an allocation error. 
-- entry case_while (A: *[5]i64) : [1]i64 = -- let (_, C) = -- loop (x, A') = (0, A[0:1]) while A'[1] != x do -- let B = #[unsafe] [0, 1, x, 3, 4] with [x%5] = 0 -- in (x+1, #[unsafe] (opaque B)[0:1] :> [1]i64) -- in C -- entry case_for (A: *[5]i64) : [1]i64 = -- let (_, C) = -- loop (x, _) = (0, A[0:1]) for i < A[1] do -- let B = #[unsafe] [0, 1, x, 3, 4] with [x%5] = 0 -- in (x+1, #[unsafe] (opaque B)[0:1] :> [1]i64) -- in C futhark-0.25.27/tests/migration/reuse2_flatindex.fut000066400000000000000000000010601475065116200224550ustar00rootroot00000000000000-- A statement that reuses memory can be migrated as part of a parent body -- if none but single elements of the reused memory (updated or aliased) are -- returned or if the memory source is migrated into the same kernel. -- == -- structure gpu { -- /GPUBody/If/True/FlatIndex 1 -- } import "intrinsics" -- This fails due to a memory allocation error. -- entry case_if (A: *[5]i64) (x: i64) : [1]i64 = -- if A[4] == 42 -- then let B = flat_index_2d A 0 2 2 2 1 -- in #[unsafe] (opaque B)[0:1, 0] :> [1]i64 -- else A[0:1] :> [1]i64 futhark-0.25.27/tests/migration/reuse3_flatupdate.fut000066400000000000000000000011541475065116200226350ustar00rootroot00000000000000-- A statement that reuses memory can be migrated as part of a parent body -- if none but single elements of the reused memory (updated or aliased) are -- returned or if the memory source is migrated into the same kernel. -- == -- structure gpu { -- /GPUBody/If/True/FlatUpdate 1 -- } import "intrinsics" let v = [[1i64]] entry case_if (A: *[5]i64) (x: i64) : [1]i64 = if A[4] == 42 then let B = flat_update_2d A 0 1 1 v in #[unsafe] (opaque B)[0:1] :> [1]i64 else A[0:1] :> [1]i64 -- Compiler limitations prevent the loop cases from being validated. -- See 'reuse1_update.fut' for details. 
futhark-0.25.27/tests/migration/reuse4_scratch.fut000066400000000000000000000013721475065116200221360ustar00rootroot00000000000000-- A statement that reuses memory can be migrated as part of a parent body -- if none but single elements of the reused memory (updated or aliased) are -- returned or if the memory source is migrated into the same kernel. -- == -- structure gpu { -- /GPUBody/If/True/Scratch 1 -- /GPUBody/Loop/Scratch 1 -- } -- These fail due to memory allocation errors. -- entry case_if (A: *[5]i64) : [1]i64 = -- if A[0] == 42 -- then let B = #[sequential] map (+1) A -- in #[unsafe] (opaque B)[0:1] :> [1]i64 -- else A[0:1] :> [1]i64 -- entry case_for (A: *[5]i64) : [1]i64 = -- loop A' = A[0:1] for i < A[1] do -- let B = [0, 1, 2, 3, A'[0]] -- let C = #[sequential] map (+i) B -- let idx = i%5 -- in C[idx:idx+1] :> [1]i64futhark-0.25.27/tests/migration/reuse5_reshape.fut000066400000000000000000000016111475065116200221330ustar00rootroot00000000000000-- A statement that reuses memory can be migrated as part of a parent body -- if none but single elements of the reused memory (updated or aliased) are -- returned or if the memory source is migrated into the same kernel. -- == -- structure gpu { -- /If/True/GPUBody/If/True/Reshape 1 -- /GPUBody/Loop/Reshape 2 -- } -- These programs are artificial. -- Most natural alternatives cannot be validated due to compiler limitations. 
entry case_if [n] (A: *[n]i64) : [1]i64 = if n > 0 then if #[unsafe] A[0] == 42 then (opaque A) :> [1]i64 else #[unsafe] A[0:1] :> [1]i64 else [42] entry case_while (A: []i64) : [1]i64 = let (_, C) = loop (x, A') = (0, A[0:1]) while A'[0] != x do (x+1, (opaque A) :> [1]i64) in C entry case_for (A: [5]i64) : [1]i64 = let (_, C) = loop (x, _) = (0, A[0:1]) for i < A[0] do (x+1, (opaque A) :> [1]i64) in C futhark-0.25.27/tests/migration/reuse6_rearrange.fut000066400000000000000000000007341475065116200224600ustar00rootroot00000000000000-- A statement that reuses memory can be migrated as part of a parent body -- if none but single elements of the reused memory (updated or aliased) are -- returned or if the memory source is migrated into the same kernel. -- == -- structure gpu { -- /GPUBody/If/True/Rearrange 1 -- } entry case_if (A: [3][2]i64) (x: i64) : [1]i64 = if A[0,0] == 42 then let B = transpose (opaque A) in #[unsafe] (opaque B)[0, 0:1] :> [1]i64 else A[0, 0:1] :> [1]i64 futhark-0.25.27/tests/migration/scalar_ops.fut000066400000000000000000000024351475065116200213470ustar00rootroot00000000000000-- This test verifies that a host-only usage can be found through all major -- scalar expression types, indicating that migration analysis covers all of -- them. -- == -- structure gpu { -- GPUBody 0 -- } #[noinline] def hostonly 'a (x: a) : bool = -- This function can only be run on host and thus requires -- its argument to be made available there. let arr = opaque [true] in arr[0] #[noinline] def join 'a (x: a) (y: a) : a = x #[noinline] def join3 'a (x: a) (y: a) (z: a) : a = x -- 'SubExp' cannot be tested due to elimination by the simplifier. 
entry opaque (A: [2]i32) : i32 = let (a, b) = (A[0], A[1]) let x = opaque b in if hostonly x then join a b else 0 entry unOp (A: [2]i32) : i32 = let (a, b) = (A[0], A[1]) let x = i32.abs b in if hostonly x then join a b else 0 entry binOp (A: [3]i32) : i32 = let (a, b, c) = (A[0], A[1], A[2]) let x = a + 4 let y = 2 + b in if hostonly x || hostonly y then join3 a b c else 0 entry cmpOp (A: [3]i32) : i32 = let (a, b, c) = (A[0], A[1], A[2]) let x = a == 4 let y = 2 == b in if hostonly x || hostonly y then join3 a b c else 0 entry convOp (A: [2]i32) : i32 = let (a, b) = (A[0], A[1]) let x = i32.to_i64 b in if hostonly x then join a b else 0 futhark-0.25.27/tests/migration/scan0.fut000066400000000000000000000003361475065116200202230ustar00rootroot00000000000000-- Neutral elements are made available on host. -- == -- structure gpu { -- /Index 1 -- /SegScan 1 -- } def main (A: *[10]f32): [10]f32 = let A = A with [0] = 0 let B = opaque A let ne = B[0] in scan (+) ne Bfuthark-0.25.27/tests/migration/scan1.fut000066400000000000000000000006411475065116200202230ustar00rootroot00000000000000-- Reads can be delayed into kernel bodies and combining operators. -- == -- structure gpu { -- /Index 0 -- /SegScan 1 -- /SegMap 0 -- } def main (A: *[10]i64): [10]i64 = let A = A with [0] = 0 let A = A with [1] = 0 let B = opaque A let x = B[0] -- This read can be delayed into op let y = B[1] -- This read can be delayed into the kernel body let op = \a b -> a+b+x in scan op 0 (map (+y) B)futhark-0.25.27/tests/migration/sizevar0_reshape.fut000066400000000000000000000015761475065116200225000ustar00rootroot00000000000000-- Size variables must be made available on host before use and thus block the -- migration of any parent statements. -- == -- structure gpu { -- /If 1 -- /If/True/GPUBody 1 -- } def main (A: [5]i64) : [1]i64 = if A[0] == 0 -- blocked. then let x = if A[1] == 0 -- blocked. then let n = A[2] -- required on host. in if A[3] == 0 -- not blocked. 
then let B = A :> [n]i64 -- n used as a size variable. in (opaque B)[0:1] else let i = n%5 in A[i:i+1] :> [1]i64 else A[0:1] :> [1]i64 in if A[4] == x[0] -- not blocked. then A[1:2] :> [1]i64 else A[2:3] :> [1]i64 else A[3:4] :> [1]i64 futhark-0.25.27/tests/migration/sizevar1_index.fut000066400000000000000000000016721475065116200221560ustar00rootroot00000000000000-- Size variables must be made available on host before use and thus block the -- migration of any parent statements. -- == -- structure gpu { -- /If 1 -- /If/True/If 1 -- /If/True/If/True/GPUBody/If 1 -- /If/True/GPUBody/If 1 -- } def main (A: [5]i64) : [1]i64 = if A[0] == 0 -- blocked. then let x = if A[1] == 0 -- blocked. then let n = A[2] -- required on host. in if A[3] == 0 -- not blocked. then let B = A[:n] -- n used as a size variable. in opaque B :> [1]i64 else let i = n%5 in A[i:i+1] :> [1]i64 else A[0:1] :> [1]i64 in if A[4] == x[0] -- not blocked. then A[1:2] :> [1]i64 else A[2:3] :> [1]i64 else A[3:4] :> [1]i64 futhark-0.25.27/tests/migration/sunk0.fut000066400000000000000000000006751475065116200202650ustar00rootroot00000000000000-- Array reads are sunk to their deepest possible branch. -- == -- structure gpu { -- GPUBody 1 -- /Index 0 -- /If/True/Index 1 -- /If/True/GPUBody/Index 1 -- /If/True/GPUBody/If/True/Index 1 -- /If/True/GPUBody/If/False/Index 1 -- } def main (A: [5]i32) (x: i32) : i32 = let y = A[0] in if x == 7 then let z = A[1] in if A[2] == 42 then y else z else 14 futhark-0.25.27/tests/migration/sunk1.fut000066400000000000000000000010701475065116200202540ustar00rootroot00000000000000-- Array reads are sunk to their deepest possible branch. -- == -- structure gpu { -- GPUBody 1 -- /Loop/Index 0 -- /Loop/If/True/GPUBody 1 -- /Loop/If/True/Index 1 -- } #[noinline] def hostonly 'a (x: a) : i32 = -- This function can only be run on host and thus requires -- its argument to be made available there. 
let arr = opaque [7] in arr[0] def main (A: [5](i32, i32)) (x: i32) : (i32, i32) = loop (res, c) = (0, x) for (y, z) in A do if c == 3 then (y+1, hostonly z) -- reads of y and z should be sunk here else (res, c+1) futhark-0.25.27/tests/migration/top-level0.fut000066400000000000000000000005531475065116200212070ustar00rootroot00000000000000-- Migration analysis is also done for top-level constants. -- -- Only constants that are used by functions need to be read to host. -- == -- structure gpu { -- /GPUBody 1 -- /GPUBody/Index 2 -- /GPUBody/BinOp 1 -- /Index 1 -- /CmpOp 1 -- } let arr = opaque [4i32, 2i32] let a = arr[0] let b = arr[1] let c = a + b def main (x : i32) : bool = c == xfuthark-0.25.27/tests/migration/top-level1.fut000066400000000000000000000004721475065116200212100ustar00rootroot00000000000000-- Migration analysis is also done for top-level constants. -- -- Only constants that are used by functions need to be read to host. -- == -- structure gpu { -- GPUBody 0 -- /Index 2 -- } let arr = opaque [4i32, 2i32] let a = arr[0] let b = arr[1] let c = a + b def main (x : i32) : bool = c == x || b == xfuthark-0.25.27/tests/migration/update0.fut000066400000000000000000000004271475065116200205620ustar00rootroot00000000000000-- Values written by Updates are not required to be available on host. -- == -- structure gpu { -- /GPUBody 1 -- /GPUBody/Index 1 -- /GPUBody/BinOp 1 -- /Index 0 -- /Update 1 -- } def main (A: *[9]f32) : *[9]f32 = let x = A[4] let x' = x / 7 in A with [4] = x' futhark-0.25.27/tests/migration/update1.fut000066400000000000000000000006111475065116200205560ustar00rootroot00000000000000-- Updates may not be migrated on their own as results from GPUBody constructs -- are copied, which would change the asymptotic cost of the operation. -- -- If their index depends on an array read, that read cannot be prevented. 
-- == -- structure gpu { -- GPUBody 0 -- /Index 1 -- /BinOp 1 -- /Update 1 -- } def main [n] (A: *[n]i64) : *[n]i64 = let i = A[4] in A with [i] = 42 futhark-0.25.27/tests/migration/withacc0.fut000066400000000000000000000005441475065116200207220ustar00rootroot00000000000000-- Neutral elements are made available on host. -- == -- structure gpu { -- /Index 1 -- } import "intrinsics" def f (acc: *acc ([]i32)) i = let acc = write acc i 1 let acc = write acc (i+1) 1 in acc def main (A: *[10]i32) : *[10]i32 = let A = A with [0] = 0 let B = opaque A let ne = B[0] in reduce_by_index_stream B (+) ne f (iota 10) futhark-0.25.27/tests/migration/withacc1.fut000066400000000000000000000007211475065116200207200ustar00rootroot00000000000000-- Reads can be delayed into kernel bodies and combining operators. -- == -- structure gpu { -- /Index 0 -- } import "intrinsics" def main (A: *[10]i32) : *[10]i32 = let A = A with [0] = 0 let A = A with [1] = 1 let B = opaque A let x = B[0] -- This read can be delayed into op let y = B[1] -- This read can be delayed into f let op = \a b -> a+b+x let f = \(acc: *acc ([]i32)) i -> write acc i y in reduce_by_index_stream B op 0 f (iota 10) futhark-0.25.27/tests/modules/000077500000000000000000000000001475065116200161545ustar00rootroot00000000000000futhark-0.25.27/tests/modules/Vec3.fut000066400000000000000000000022571475065116200175020ustar00rootroot00000000000000module Vec3 = { module f32 = { type t = ( f32 , f32 , f32 ) def add(a: t , b: t): t = let (a1, a2, a3) = a let (b1, b2, b3) = b in (a1 + b1, a2 + b2 , a3 + b3) def subtract(a: t , b: t): t = let (a1, a2, a3) = a let (b1, b2, b3) = b in (a1 - b1, a2 - b2 , a3 - b3) def scale(k: f32 , a: t): t = let (a1, a2, a3) = a in (a1 * k, a2 * k , a3 * k) def dot(a: t , b: t): f32 = let (a1, a2, a3) = a let (b1, b2, b3) = b in a1*b1 + a2*b2 + a3*b3 } module Int = { type t = ( i32 , i32 , i32 ) def add(a: t , b: t): t = let (a1, a2, a3) = a let (b1, b2, b3) = b in (a1 + b1, a2 + b2 , a3 + b3) def 
subtract(a: t , b: t): t = let (a1, a2, a3) = a let (b1, b2, b3) = b in (a1 - b1, a2 - b2 , a3 - b3) def scale(k: i32 , a: t): t = let (a1, a2, a3) = a in (a1 * k, a2 * k , a3 * k) def dot(a: t , b: t): i32 = let (a1, a2, a3) = a let (b1, b2, b3) = b in a1*b1 + a2*b2 + a3*b3 } } futhark-0.25.27/tests/modules/anonymous-signature.fut000066400000000000000000000002621475065116200227230ustar00rootroot00000000000000-- Can we match a module with an unnamed signature? -- == -- input { 5 } output { 7 } module M: {val x: i32} = { def x: i32 = 2 def y: i32 = 3 } def main(x: i32) = M.x + x futhark-0.25.27/tests/modules/ascription-error0.fut000066400000000000000000000004501475065116200222550ustar00rootroot00000000000000-- Abstract types must be abstract. -- == -- error: Function body does not have module type SIG = { type t val inject: i32 -> t val extract: t -> i32 } module Struct: SIG = { type t = i32 def inject (x: i32): i32 = x def extract (x: i32): i32 = x } def main(x: i32): i32 = Struct.inject x futhark-0.25.27/tests/modules/ascription-error1.fut000066400000000000000000000004001475065116200222510ustar00rootroot00000000000000-- We may not access structure members not part of the signature. -- == -- error: Struct.g module type SIG = { val f: i32 -> i32 } module Struct: SIG = { def f (x: i32): i32 = x + 2 def g (x: i32): i32 = x + 3 } def main(x: i32): i32 = Struct.g x futhark-0.25.27/tests/modules/ascription-error2.fut000066400000000000000000000003611475065116200222600ustar00rootroot00000000000000-- Opaque signature ascription must hide equality of type. -- == -- error: type module type S = { type t val a : t val f : t -> i32 } module B : S = { type t = i32 def a:t = 3 def f (a:t):t = a } module C : S = B def main() : i32 = C.f B.a futhark-0.25.27/tests/modules/ascription-error3.fut000066400000000000000000000002551475065116200222630ustar00rootroot00000000000000-- Ascription must respect uniqueness. 
-- == -- error: \*\[d\]i32 module type mt = { val f [n] : [n]i32 -> ?[d].*[d]i32 } module m = { def f (ns: []i32): []i32 = ns } : mt futhark-0.25.27/tests/modules/ascription-error4.fut000066400000000000000000000004301475065116200222570ustar00rootroot00000000000000-- Check that value mismatches in nested modules use qualified names. -- == -- error: bar.x module type mt = { module foo : { val x : i32 } module bar : { val x : bool } } module m : mt = { module foo = { def x : i32 = 1 } module bar = { def x : i32 = 1 } } futhark-0.25.27/tests/modules/ascription-error5.fut000066400000000000000000000004231475065116200222620ustar00rootroot00000000000000-- Check that type mismatches in nested modules use qualified names. -- == -- error: bar.t module type mt = { module foo : { type t = i32 } module bar : { type t = bool } } module m : mt = { module foo = { type t = i32 } module bar = { type t = i32 } } futhark-0.25.27/tests/modules/ascription-error6.fut000066400000000000000000000002261475065116200222640ustar00rootroot00000000000000-- Uniqueness stuff. 
-- == -- error: \*\[1\]f32 module M : { val f : [1]f32 -> bool } = { def f (_: *[1]f32) = true } def main (_: []f32) = M.f futhark-0.25.27/tests/modules/ascription-error7.fut000066400000000000000000000001431475065116200222630ustar00rootroot00000000000000-- == -- error: constructive module m : { type sum [n][m] } = { type sum [n][m] = [n+m]bool } futhark-0.25.27/tests/modules/ascription-sizelifted0.fut000066400000000000000000000002271475065116200232700ustar00rootroot00000000000000-- == -- error: "empty" module type dict = { type~ dict val empty : dict } module naive_dict : dict = { type~ dict = []bool def empty = [] } futhark-0.25.27/tests/modules/ascription-sizelifted1.fut000066400000000000000000000002561475065116200232730ustar00rootroot00000000000000-- == -- error: "empties" module type dict = { type~ dict val empties : (dict,dict) } module naive_dict : dict = { type~ dict = []bool def empties = ([], [true]) } futhark-0.25.27/tests/modules/ascription-sizelifted2.fut000066400000000000000000000002221475065116200232650ustar00rootroot00000000000000module type dict = { type~ dict val mk : () -> (dict,dict) } module naive_dict : dict = { type~ dict = []bool def mk () = ([], [true]) } futhark-0.25.27/tests/modules/ascription0.fut000066400000000000000000000006151475065116200211310ustar00rootroot00000000000000-- Basic signature matching without abstract types. -- == -- input { [1,2,3] [4,5,6] } -- output { 6 15 } module type SIG = { type t = (i32, i32) val x: t val f [n]: [n]t -> t } module Struct: SIG = { type t = (i32,i32) def x: (i32, i32) = (2,2) def f (as: []t): t = reduce (\(a,b) (c,d) -> (a+c,b+d)) (0,0) as } def main(xs: []i32) (ys: []i32) = Struct.f (zip xs ys) : Struct.t futhark-0.25.27/tests/modules/ascription1.fut000066400000000000000000000010311475065116200211230ustar00rootroot00000000000000-- Signature matching with a single abstract type. 
-- == -- input { [1,2,3] [4,5,6] } -- output { 6 15 } module type SIG = { type t val inject: i32 -> i32 -> t val extract: t -> (i32,i32) val f [n]: [n]t -> t } module Struct: SIG = { type t = (i32,i32) def x: (i32, i32) = (2,2) def inject (x: i32) (y: i32): t = (x, y) def extract (v:t): t = v def f (as: []t): t = reduce (\(a,b) (c,d) -> (a+c,b+d)) (0,0) as } def main (xs: []i32) (ys: []i32): (i32,i32) = Struct.extract (Struct.f (map2 Struct.inject xs ys)) futhark-0.25.27/tests/modules/ascription10.fut000066400000000000000000000006541475065116200212150ustar00rootroot00000000000000-- An abstract type that is realised by a nested module. module type number = { type t val i32: i32 -> t } module has_number: number with t = i32 = { type t = i32 def i32 (x: i32) = x } module type optimizable = { module loss: number } module opaque : optimizable = { module loss = has_number } module mt (optable: optimizable) = { module loss = optable.loss } module m = mt opaque def main (x: i32) = m.loss.i32 x futhark-0.25.27/tests/modules/ascription11.fut000066400000000000000000000007541475065116200212170ustar00rootroot00000000000000module type number = { type t val i32: i32 -> t } module has_number: number with t = i32 = { type t = i32 def i32 (x: i32) = x } module type optimizable = { module loss: number } module stochastic_gradient_descent (optable: optimizable) = { module loss = optable.loss } module logistic_regression (dummy: {}) : optimizable = { module loss = has_number } module logreg_m = logistic_regression {} module sgd = stochastic_gradient_descent logreg_m def main (x: i32) = sgd.loss.i32 x futhark-0.25.27/tests/modules/ascription12.fut000066400000000000000000000003201475065116200212050ustar00rootroot00000000000000module type sized = { val len: i64 } module arr (S: sized): { type t = [S.len]i32 } = { type t = [S.len]i32 } module nine = { def len = 9i64 } module arr_nine : { type t = [nine.len]i32 } = arr nine 
futhark-0.25.27/tests/modules/ascription13.fut000066400000000000000000000003711475065116200212140ustar00rootroot00000000000000module type sparse = { type csr module csr : { type mat = csr } } module sparse (T : {}) : sparse = { module csr = { type mat = bool } type csr = csr.mat } module spa = sparse {} module csr = spa.csr def main (x: csr.mat) = x futhark-0.25.27/tests/modules/ascription14.fut000066400000000000000000000002011475065116200212050ustar00rootroot00000000000000-- Uniqueness stuff. -- == module M : { val f : *[1]f32 -> bool } = { def f (_: [1]f32) = true } def main (_: []f32) = M.f futhark-0.25.27/tests/modules/ascription15.fut000066400000000000000000000002161475065116200212140ustar00rootroot00000000000000module type mt = { type sum [n][m] = ([n]bool, [m]bool, [n+m]bool) } module m : mt = { type sum [n][m] = ([n]bool, [m]bool, [n+m]bool) } futhark-0.25.27/tests/modules/ascription2.fut000066400000000000000000000004021475065116200211250ustar00rootroot00000000000000-- Test ascriptions of signatures with type abbreviations referencing -- abstract types. -- == -- input {} output { 3 } module type T1 = { type t type s = t val a : s } module X : T1 = { type t = i32 type s = i32 def a : s = 3 } -- ok def main : i32 = 3 futhark-0.25.27/tests/modules/ascription3.fut000066400000000000000000000003311475065116200211270ustar00rootroot00000000000000-- Ascription only needs a subtype. -- == -- input { 2 } output { [0,0] } module type S = { val f: i32 -> []i32 } module M: S = { def f(x: i32): *[]i32 = replicate (i64.i32 x) 0 } def main(n: i32): []i32 = M.f n futhark-0.25.27/tests/modules/ascription4.fut000066400000000000000000000003571475065116200211400ustar00rootroot00000000000000-- Ascription can happen anywhere in a module expression. 
-- == -- input { 2 } output { [0,0] } module type S = { val f: i32 -> []i32 } module M = { def f(x: i32): *[]i32 = replicate (i64.i32 x) 0 }: S def main(n: i32): []i32 = M.f n futhark-0.25.27/tests/modules/ascription5.fut000066400000000000000000000005751475065116200211430ustar00rootroot00000000000000-- Multiple abstract types with the same unqualified name. This will -- not work if the type checker simply flattens out all abstract -- types. -- == -- input { } output { 2 } module type MT = { module A: { type a type b } module B: { type a type b } } module M: MT = { module A = { type a = i32 type b = bool } module B = { type a = bool type b = i32 } } def main = 2 futhark-0.25.27/tests/modules/ascription6.fut000066400000000000000000000003631475065116200211370ustar00rootroot00000000000000-- Multiple levels of ascription. -- == -- input {} output { 4 2 } module outer: { val x: i32 module inner: { val y: i32 } } = { module inner: { val y: i32} = { def y = 2 } def x = inner.y + 2 } def main = (outer.x, outer.inner.y) futhark-0.25.27/tests/modules/ascription7.fut000066400000000000000000000002121475065116200211310ustar00rootroot00000000000000-- Basic/naive use of ascription. -- == -- input {} output { 2 } module m = { def x = 2 } module m': { val x: i32 } = m def main = m'.x futhark-0.25.27/tests/modules/ascription8.fut000066400000000000000000000010601475065116200211340ustar00rootroot00000000000000-- Ascription of a module containing a parametric module whose -- parameter contains an abstract type. 
-- == -- input {} output {0.0} module type sobol = { module Reduce : (X : { type t val ne : t }) -> { val run : i32 -> X.t } } module Sobol: sobol = { module Reduce (X : { type t val ne : t }) : { val run : i32 -> X.t } = { def run (N:i32) : X.t = copy X.ne } } module R = Sobol.Reduce { type t = f64 def ne = 0f64 } def main : f64 = R.run 100000 futhark-0.25.27/tests/modules/ascription9.fut000066400000000000000000000005451475065116200211440ustar00rootroot00000000000000-- Size annotations should not mess up module ascription. module type vector = { type vector 'a } module mk_kmeans (D: {}) (V: vector) (R: {}): { type point = V.vector f32 val kmeans [n] : (points: [n]point) -> ([n]point, i32) } = { type point = V.vector f32 def kmeans [n] (points: [n]point): ([n]point, i32) = (points, 0) } futhark-0.25.27/tests/modules/calling_nested_module.fut000066400000000000000000000005471475065116200232220ustar00rootroot00000000000000-- == -- input { -- 10 21 -- } -- output { -- 6 -- } type t = i32 module NumLib = { def plus(a: t, b: t): t = a + b module BestNumbers = { def four(): t = 4 def seven(): t = 42 def six(): t = 41 } } def localplus(a: i32, b: i32): i32 = NumLib.plus (a,b) def main (a: i32) (b: i32): i32 = localplus(NumLib.BestNumbers.four() , 2) futhark-0.25.27/tests/modules/duplicate_def.fut000066400000000000000000000011571475065116200214700ustar00rootroot00000000000000-- testing that variable shadowing and chunking -- doesn't allow for duplicate definitions -- == -- error: Duplicate type foo = (i32, f32) module M0 = { type foo = foo -- the type is defined from l. 1 type bar = f32 } module M1 = { type foo = f32 type bar = M0.bar -- type is defined from l.5 module M0 = { type foo = M0.foo -- is defined at l. 5 type bar = (i32, i32, i32) } type foo = f32 -- REDEFINITION OF foo IN Struct M1 type baz = M0.bar -- defined at line 17 } type baz = M1.baz -- is defined at l. 
13 def main(a: i32, b: float): baz = (1,2,3) futhark-0.25.27/tests/modules/duplicate_def0.fut000066400000000000000000000004721475065116200215470ustar00rootroot00000000000000-- This test is written to ensure that the same name can be used -- for different declarations in the same local environment, as long as their types does not overlap. -- == -- input { 4 } -- output { 4 } type foo = i32 def foo(a: i32): foo = a + a module Foo = { def one(): i32 = 1 } def main(x: i32): i32 = x futhark-0.25.27/tests/modules/duplicate_def1.fut000066400000000000000000000005111475065116200215420ustar00rootroot00000000000000-- The module opens a new environment, which lets us use names again, which were used -- in a previous scope. -- == -- input { } -- output { 1 2.0 } type foo = i32 module Foo = { def foo(): i32 = 1 module Foo = { type foo = f64 def foo(): foo = 2.0 } } def main: (foo, Foo.Foo.foo) = ( Foo.foo() , Foo.Foo.foo()) futhark-0.25.27/tests/modules/duplicate_error0.fut000066400000000000000000000002571475065116200221430ustar00rootroot00000000000000-- This test fails with a DuplicateDefinition error. -- == -- error: .*Dup.* module Foo = { def foo(): i32 = 1 } def bar(): i32 = 1 def bar(): i32 = 2 def main(): i32 = 0 futhark-0.25.27/tests/modules/duplicate_error1.fut000066400000000000000000000002511475065116200221360ustar00rootroot00000000000000-- This test fails with a DuplicateDefinition error. -- == -- error: .*Dup.* module Foo = { def foo(): foo = 1 } type foo = i32 type foo = float def main(): i32 = 0 futhark-0.25.27/tests/modules/entry.fut000066400000000000000000000002471475065116200200400ustar00rootroot00000000000000-- OK to use module type in an entry point (although perhaps not -- useful). 
module M = { type t = bool } : { type t } entry main (x: (M.t,M.t)) : (M.t,M.t) = x futhark-0.25.27/tests/modules/fun_call_test.fut000066400000000000000000000016671475065116200215300ustar00rootroot00000000000000-- == -- input { -- } -- output { -- [[2,4,5],[1,5,3],[3,7,1]] -- } def min(a: i32) (b: i32): i32 = if(a num -> num val mult : num -> num -> num val one : num val zero : num } module Int = { type num = i32 def plus (x: i32) (y: i32): i32 = x + y def mult (x: i32) (y: i32): i32 = x * y def one: i32 = 1 def zero: i32 = 0 } module Float32 = { type num = f32 def plus (x: f32) (y: f32): f32 = x + y def mult (x: f32) (y: f32): f32 = x * y def one: f32 = 1f32 def zero: f32 = 0f32 } module DotProd(T: NUMERIC) = { def dotprod [n] (xs: [n]T.num) (ys: [n]T.num): T.num = reduce T.mult T.one (map2 T.plus xs ys) } module IntDotProd = DotProd(Int) module Float32DotProd = DotProd(Float32) def main [n] (xs: [n]i32) (ys: [n]i32) (as: [n]f32) (bs: [n]f32): (i32, f32) = (IntDotProd.dotprod xs ys, Float32DotProd.dotprod as bs) futhark-0.25.27/tests/modules/functor10.fut000066400000000000000000000005101475065116200205110ustar00rootroot00000000000000-- Three nested functors with the same named module type parameter. -- == -- input { true } output { true } module type mt = { type t } module f1(R: mt) = { type t = R.t } module f2(R: mt) = { module L = f1(R) open L } module f3(R: mt) = { open (f2 R) } module m = f3({type t = bool}) def main(x: m.t): m.t = x futhark-0.25.27/tests/modules/functor11.fut000066400000000000000000000010671475065116200205220ustar00rootroot00000000000000-- Complicated nested modules in parametric modules ought to work. 
-- == -- input { [true,false] } output { [true,false] [0,0] } module type mt = { type cell val init: bool -> cell } module f1(R: mt) = { type cell = R.cell def init [n] (bs: [n]bool): [n]cell = map R.init bs } module f2(R: mt) = { module m = { type cell = (R.cell, i32) def init (b: bool) = (R.init b, 0) } module m' = f1(m) open m' } module m1 = { type cell = bool def init (b: bool) = b } module m2 = f2(m1) def main [n] (bs: [n]bool) = unzip (m2.init bs) futhark-0.25.27/tests/modules/functor12.fut000066400000000000000000000010501475065116200205130ustar00rootroot00000000000000-- Even more Complicated nested modules in parametric modules ought to -- work. -- -- == -- input { [true,false] } output { [true,false] [0,0] } module type mt = { type cell val init: bool -> cell } module f1(R: mt) = { type cell = R.cell def init [n] (bs: [n]bool): [n]cell = map R.init bs } module f2(R: mt) = { open (f1 { type cell = (R.cell, i32) def init (b: bool) = (R.init b, 0) }) } module m1 = { type cell = bool def init (b: bool) = b } module m2 = f2(m1) def main [n] (bs: [n]bool) = unzip (m2.init bs) futhark-0.25.27/tests/modules/functor13.fut000066400000000000000000000007131475065116200205210ustar00rootroot00000000000000-- Complex multiple applications of a parametric module must work. -- -- == -- input { 1 } output { 9 } module type mt = { val f: i32 -> i32 } module pm1(R: mt): {val g: i32->i32} = { def helper(x: i32) = R.f (R.f x) def g(x: i32): i32 = helper x } module pm2(R: mt) = { module tmp = pm1(R) def h(x: i32): i32 = tmp.g (tmp.g x) } module m1 = { def f (x: i32) = x + 1 } module m2 = pm2(m1) module m3 = pm2(m1) def main(x: i32) = m2.h (m3.h x) futhark-0.25.27/tests/modules/functor14.fut000066400000000000000000000013611475065116200205220ustar00rootroot00000000000000-- Deep multiple applications of a parametric module must work. 
-- -- == -- input { 1 } output { 129 } module type mt = { val f: i32 -> i32 } module pm1(R: mt): {val g1:i32->i32} = { def h(x: i32): i32 = R.f (R.f x) def g1(x: i32): i32 = h x } module pm2(R: mt) = { open (pm1 R) def g2(x: i32): i32 = g1 (g1 x) } module pm3(R: mt) = { open (pm2 R) def g3(x: i32): i32 = g2 (g2 x) } module pm4(R: {val f:i32->i32}) = { open (pm3 R) def g4(x: i32): i32 = g3 (g3 x) } module pm5(R: mt) = { open (pm4 R) def g5(x: i32): i32 = g4 (g4 x) } module pm6(R: mt) = { open (pm5 R) def g6(x: i32): i32 = g5 (g5 x) } module m1 = { def f (x: i32) = x + 1 } module m2 = pm6(m1) module m3 = pm6(m1) def main(x: i32) = m2.g6 (m3.g6 x) futhark-0.25.27/tests/modules/functor15.fut000066400000000000000000000004151475065116200205220ustar00rootroot00000000000000-- Can we nest a parametric module? -- == -- input { 3 } output { 6 } module type MT1 = { val f: i32 -> i32 -> i32 } module M = { module T(P: MT1) = { def g(x: i32) = P.f x x } } module T = M.T({def f (x: i32) (y: i32) = x + y}) def main (x: i32) = T.g x futhark-0.25.27/tests/modules/functor16.fut000066400000000000000000000010421475065116200205200ustar00rootroot00000000000000-- Can we nest a parametric module inside another parametric module? -- This is also a tricky test of shadowing, and we actually got it -- wrong at first. -- -- == -- -- input { 3 } output { 18 } module type MT1 = { val f: i32 -> i32 -> i32 } module type MT2 = { val g: i32 -> i32 } module M = { module T1(P1: MT1) = { module T2(P2: MT1): MT2 = { def g(x: i32) = P2.f (P1.f x x) x } } } module T1a = M.T1({def f (x: i32) (y: i32) = x + y}) module T = T1a.T2({def f (x: i32) (y: i32) = x * y}) def main (x: i32) = T.g x futhark-0.25.27/tests/modules/functor17.fut000066400000000000000000000004431475065116200205250ustar00rootroot00000000000000-- Test for shape declarations inside a parametric module. (We used -- to have a bug here.) 
-- == -- input { } output { [1.0,2.0] [1] } module PM(P: {type^ r}) = { type t = i32 def f [n] (r: P.r) (a: [n]t) = (r,a) } module PMI = PM {type^ r = []f64} def main = PMI.f [1.0,2.0] [1] futhark-0.25.27/tests/modules/functor18.fut000066400000000000000000000012521475065116200205250ustar00rootroot00000000000000-- A functor whose parameter is itself a functor. -- == -- input { 1 } output { 3 4 7 } module F = \(P_f: (P_f_a: {type a val f: a -> a}) -> {val f: P_f_a.a -> P_f_a.a}) -> \(P_x: {type a val f: a -> a}) -> {type a = P_x.a open (P_f P_x)} module twice = \(twice_P: {type a val f: a -> a}) -> { def f (x: twice_P.a) = twice_P.f (twice_P.f x) } module thrice = \(thrice_P: {type a val f: a -> a}) -> { def f (x: thrice_P.a) = thrice_P.f (thrice_P.f (thrice_P.f x)) } module add_one = {type a = i32 def f(x: i32) = x + 1} module F_2 = F twice add_one module F_3 = F thrice add_one module F_6 = F twice F_3 def main(x: i32) = (F_2.f x, F_3.f x, F_6.f x) futhark-0.25.27/tests/modules/functor19.fut000066400000000000000000000004771475065116200205360ustar00rootroot00000000000000-- Using a curried module twice should work, and notably not result in -- anything being defined twice. -- == -- input {} output {7} module pm (A: {val x: i32}) (B: {val y: i32}) = { def z = A.x + B.y } module cm = pm { def x = 2 } module m1 = cm { def y = 1 } module m2 = cm { def y = 2 } def main = m1.z + m2.z futhark-0.25.27/tests/modules/functor2.fut000066400000000000000000000004631475065116200204410ustar00rootroot00000000000000-- Referring to a parameter-defined type in a functor return signature. 
-- == -- input { 2 } output { 4 } module F(P:{type t val f:t->t}): {type t = P.t val f2:t->t} = { type t = P.t def f2(x: t): t = P.f (P.f x) } module F' = F({type t = i32 def f (x: i32): i32 = x+1}) def main(x: i32): F'.t = F'.f2 x futhark-0.25.27/tests/modules/functor20.fut000066400000000000000000000010241475065116200205130ustar00rootroot00000000000000-- Parametric module operating on a module with nested modules. -- == -- input { 2 } output { 3 } module type integral = { type t val frob: t -> t } module quux = { type t = i32 def frob (x: i32) = x + 1 } module type has_int = { module int: integral } module mk_has_int (T: integral): has_int with int.t = T.t = { module int = T } module has_quux = mk_has_int quux module frob_int (E: has_int) = { def really_frob (x: E.int.t) = E.int.frob x } module m = frob_int has_quux def main (x: i32) = m.really_frob x futhark-0.25.27/tests/modules/functor21.fut000066400000000000000000000005111475065116200205140ustar00rootroot00000000000000-- Will the abstract type defined in the argument to a parametric -- module also be accessible in the resulting module? module type has_cell = { type cell } module mk_has_cell (V: has_cell) = { type cell = V.cell } module i8_cell = { type cell = i8 } : has_cell module m = mk_has_cell i8_cell def main (x: m.cell) = x futhark-0.25.27/tests/modules/functor22.fut000066400000000000000000000007311475065116200205210ustar00rootroot00000000000000-- Another Parametric module operating on a module with nested -- modules. 
-- == -- input { 2 } output { 3 } module type to_i32 = { type t val to_i32: t -> i32 } module i32 = { type t = i32 def to_i32 (x: i32) = x } module type engine = { module int: to_i32 val min: int.t } module an_engine = { module int = i32 def min = 1 } module mk_has_y (E: engine) = { def y = E.int.to_i32 E.min } module m1 = mk_has_y an_engine def main (x: i32) = m1.y + x futhark-0.25.27/tests/modules/functor23.fut000066400000000000000000000007341475065116200205250ustar00rootroot00000000000000-- Multiple includes of the same thing is harmless, and should work. -- == -- input { 1 } output { 2 3 } module type has_a = { type a val f : a -> a } module pm(M: has_a): { type a = (M.a, M.a) include has_a with a = (M.a, M.a) include has_a with a = (M.a, M.a) } = { type a = (M.a, M.a) def f (x, y) = (M.f x, M.f (M.f y)) } module M_a: has_a with a = i32 = { type a = i32 def f = (+1) } module M_a_a = pm M_a def main (x: i32) = M_a_a.f (x, x) futhark-0.25.27/tests/modules/functor24.fut000066400000000000000000000003621475065116200205230ustar00rootroot00000000000000module type abs = { type abs } module abs = { type abs = i32 } module fieldtype (P: abs): abs = { type abs = i32 } module big_field (M: abs) = { type t = M.abs } module mod = big_field (fieldtype abs) def main (a: mod.t) = a futhark-0.25.27/tests/modules/functor25.fut000066400000000000000000000004151475065116200205230ustar00rootroot00000000000000module mk_m1 (R: {}) (S: {type s}) = { def f1 (k_m: S.s) = 0i32 } module mk_m2 (S: {type s}) = { module solve = mk_m1 {} S def f2 (k_m: S.s) = solve.f1 k_m } module mk_sm (R: {}) = { type s = {} } module sm = mk_sm {} module m2 = mk_m2 sm def main = m2.f2 futhark-0.25.27/tests/modules/functor26.fut000066400000000000000000000007351475065116200205310ustar00rootroot00000000000000module type rng_engine = { module int: { type t val to_i64 : t -> i64 } val min: int.t } module pcg32: rng_engine with int.t = u32 = { module int = { type t = u32 def to_i64 = u32.to_i64 } : { type t = u32 
val to_i64 : t -> i64 } def min = 0u32 } module uniform_int_distribution (E: { module int: { type t val to_i64 : t -> i64 } val min: int.t }) = { def v = E.int.to_i64 E.min } module dist = uniform_int_distribution pcg32 def main = dist.v futhark-0.25.27/tests/modules/functor27.fut000066400000000000000000000004561475065116200205320ustar00rootroot00000000000000-- Ensuring the right references to sizes even through tricky -- indirections. module mk (P: {val n: i64}) : { val mk 't : t -> [P.n]t } = { def mk = replicate P.n } module mk2 (P: {val n: i64}) = { module m = mk P def f (t: bool) = m.mk t } module m = mk2 { def n = 10i64 } def main = m.f futhark-0.25.27/tests/modules/functor28.fut000066400000000000000000000012251475065116200205260ustar00rootroot00000000000000module type newreal = { type t val f32 : f32 -> t } module type newint = { type t val f32 : f32 -> t } module newf32 : newreal with t = f32 = { type t = f32 def f32 = f32.f32 } module newi32 : newint with t = i32 = { type t = i32 def f32 = i32.f32 } module type mixture = { module V : newreal module I : newint } module em (P: mixture) = { module mixture = P } module k_means_mixture (P: mixture) = { module V = P.V module I = P.I } module foo = { module V = newf32 module I = newi32 } module bar = k_means_mixture foo module baz = bar : mixture module k_means_em = em baz def main (x: k_means_em.mixture.V.t) = x futhark-0.25.27/tests/modules/functor29.fut000066400000000000000000000003661475065116200205340ustar00rootroot00000000000000module newf32 = { type t = f32 } module type mixture = { module V : {type t} } module pm (P: mixture) = { module V = P.V } module foo = { module V = newf32 } module k_means_em = pm (pm foo : mixture) def main (x: k_means_em.V.t) = x futhark-0.25.27/tests/modules/functor3.fut000066400000000000000000000005061475065116200204400ustar00rootroot00000000000000-- Parametric module where the argument contains an abstract type. 
-- == -- input {} output {2} module type colour = { type colour } module rgba_colour: colour = { type colour = i32 } module colourspace(C: colour) = { open C def frob (x: colour): colour = x } module rgba = colourspace(rgba_colour) def main = 2 futhark-0.25.27/tests/modules/functor30.fut000066400000000000000000000003541475065116200205210ustar00rootroot00000000000000-- Based on #1741 -- == -- error: type t = i64 module Op = (\(X: {type t = i64 val x : t}) -> {def x = X.x}) : (X: {type t val x: t}) -> {val x : X.t} module L2 = Op { type t = bool def x : t = true } def main = L2.x futhark-0.25.27/tests/modules/functor4.fut000066400000000000000000000005401475065116200204370ustar00rootroot00000000000000-- Another parametric module where the argument contains an abstract type. -- == -- input {} output {2} module type foo = { type foo val mkfoo: i32 -> foo } module rgba_foo: foo = { type foo = i32 def mkfoo (x: i32) = x } module foospace(C: foo) = { open C def frob (x: foo): foo = x } module rgba = foospace(rgba_foo) def main = 2 futhark-0.25.27/tests/modules/functor5.fut000066400000000000000000000003411475065116200204370ustar00rootroot00000000000000-- Open and functors must work together. -- == -- input {} output {6} module type mt = { val x: i32 } module m1: mt = { def x = 2 } module f(M: mt) = { open M def y = x + 2 } module m2 = f(m1) def main = m2.x + m2.y futhark-0.25.27/tests/modules/functor6.fut000066400000000000000000000004661475065116200204500ustar00rootroot00000000000000-- Applying a parametric module inside another parametric module. -- == -- input { } -- output { 2 } module f1(R: { type cell }) = { type cell = R.cell } module f2(R: { type cell }) = { module L = f1(R) type cell = R.cell } module m1 = { type cell = i32 } module m2 = f2(m1) def main: m2.cell = 2 futhark-0.25.27/tests/modules/functor7.fut000066400000000000000000000005071475065116200204450ustar00rootroot00000000000000-- Make sure a type from an opened module inside a functor is not -- abstract. 
-- -- == -- input { } -- output { 2 } module f1(R0: { type cell }) = { type cell = R0.cell } module f2(R1: { type cell }) = { module L = f1(R1) open L def id (x: cell) = x } module m2 = f2({type cell = i32}) def main: m2.cell = m2.id 2 futhark-0.25.27/tests/modules/functor8.fut000066400000000000000000000005271475065116200204500ustar00rootroot00000000000000-- Using the same signature in multiple places should not cause -- trouble. -- == -- input { true } output { true } module type rules = { type cell } module f1(R1: rules) = { type cell = R1.cell } module f2(R2: rules) = { module L = f1(R2) open L } module conway = f2({type cell = bool}) def main(x: conway.cell): conway.cell = x futhark-0.25.27/tests/modules/functor9.fut000066400000000000000000000004141475065116200204440ustar00rootroot00000000000000-- Instantiating a parametric module twice should go well. -- == -- input { 2 true } output { 2 true } module f(P: {type t}) = { def id (x: P.t) = x } module m1 = f({type t = i32}) module m2 = f({type t = bool}) def main (x: i32) (y: bool) = (m1.id x, m2.id y) futhark-0.25.27/tests/modules/hof0.fut000066400000000000000000000002301475065116200175230ustar00rootroot00000000000000-- OK because the module defines a higher-order type and the module -- type specifies a lifted type. module m = { type^ t = i32 -> i32 } : { type ^t } futhark-0.25.27/tests/modules/hof1.fut000066400000000000000000000002721475065116200175320ustar00rootroot00000000000000-- Not OK because the module defines a higher-order type but the -- module type specifies a zero-order type. -- == -- error: non-lifted module m = { type^ t = i32 -> i32 } : { type t } futhark-0.25.27/tests/modules/hof2.fut000066400000000000000000000001041475065116200175250ustar00rootroot00000000000000-- OK; perfect match. 
module m = { type t 'a = a } : { type t 'a } futhark-0.25.27/tests/modules/hof3.fut000066400000000000000000000003301475065116200175270ustar00rootroot00000000000000-- Not OK, because the module type specifies a more liberal type than -- defined by the module. -- == -- error: Module type module m = { def f 'a (x: a) = ([x])[0] } : { val f '^a : a -> a } let main = m.f id 0i32 futhark-0.25.27/tests/modules/hof4.fut000066400000000000000000000001131475065116200175270ustar00rootroot00000000000000-- == -- error: non-lifted module m = { type^ t '^a = a } : { type t 'a } futhark-0.25.27/tests/modules/hof5.fut000066400000000000000000000002761475065116200175420ustar00rootroot00000000000000-- We require that some type is non-functional, but that type refers -- to a lifted abstract type! -- == -- error: non-lifted module m = \(p: {type ^a}) -> ({ type^ t = p.a } : { type t }) futhark-0.25.27/tests/modules/hof6.fut000066400000000000000000000002441475065116200175360ustar00rootroot00000000000000-- Higher-order abstract types may not be array elements! -- == -- error: Cannot create array module m = { type^ t = i32 -> i32 } : { type ^t } def x: []m.t = [] futhark-0.25.27/tests/modules/hof7.fut000066400000000000000000000005441475065116200175420ustar00rootroot00000000000000-- Just one test that actually computes something. -- == -- input { 3 } output { 5 4 } module type has_t = { type ^t val v: t val ap: t -> i32 -> i32 } module m1: has_t = { type^ t = i32 -> i32 def v = (+2) def ap f x = f x } module m2: has_t = { type t = i32 def v = 1 def ap = (+) } def main (x: i32) = (m1.ap m1.v x, m2.ap m2.v x) futhark-0.25.27/tests/modules/hof8.fut000066400000000000000000000002521475065116200175370ustar00rootroot00000000000000-- Lifted abstract types from a module parameter cannot be array -- elements! 
-- == -- error: might contain function module m = \(p: {type ^a}) -> { def v: []p.a = [] } futhark-0.25.27/tests/modules/hof9.fut000066400000000000000000000003741475065116200175450ustar00rootroot00000000000000-- == -- input { 2 } output { 2 } module type mt = { type^ abs val mk : i32 -> abs val len : abs -> i32 } module m : mt = { type^ abs = bool -> i32 def mk (n: i32) = \_ -> n def len (f: abs) = f true } def main (x: i32) = m.len (m.mk x) futhark-0.25.27/tests/modules/import-qualified.fut000066400000000000000000000001631475065116200221470ustar00rootroot00000000000000-- == -- input { 1 } -- output { 3 } module M = import "importee-qualified" def main(a: i32): i32 = M.whatever 1 futhark-0.25.27/tests/modules/importee-qualified.fut000066400000000000000000000000351475065116200224570ustar00rootroot00000000000000def whatever(x: i32) = x + 2 futhark-0.25.27/tests/modules/index_qual_array.fut000066400000000000000000000001571475065116200222260ustar00rootroot00000000000000-- == -- input { 4 } output { 5 } module M = { def a: []i32 = [1,2,3] } def main(x: i32): i32 = M.a[0] + x futhark-0.25.27/tests/modules/int_mod.fut000066400000000000000000000003771475065116200203340ustar00rootroot00000000000000-- == -- input { -- 10 21 -- } -- output { -- 31 -- } module IntLib = { def plus(a: i32, b: i32): i32 = a + b def numberFour(): i32 = 4 } def localplus(a: i32, b: i32): i32 = IntLib.plus (a,b) def main (a: i32) (b: i32): i32 = localplus(a,b) futhark-0.25.27/tests/modules/intrinsics-error.fut000066400000000000000000000003011475065116200222020ustar00rootroot00000000000000-- You are not allowed to use the intrinsics module in module expressions. -- == -- error: The 'intrinsics' module may not be used in module expressions. module M = intrinsics def main() = 0 futhark-0.25.27/tests/modules/lambda0.fut000066400000000000000000000010451475065116200201740ustar00rootroot00000000000000-- Module-level lambdas, i.e., anonymous functors. 
-- == -- input { 2 } output { 32 } module type operation = { type t val f: t -> t } module type repeater = (P:operation) -> operation with t = P.t module twice: repeater = \(P: operation) -> { type t = P.t def f (x: P.t) = P.f (P.f x) } module type i32_operation = operation with t = i32 module times_2: i32_operation = { type t = i32 def f (x: i32) = x * 2 } module times_4: i32_operation = twice(times_2) module times_16: i32_operation = twice(times_4) def main (x: i32) = times_16.f x futhark-0.25.27/tests/modules/lambda1.fut000066400000000000000000000013411475065116200201740ustar00rootroot00000000000000-- More fancy use of module-level lambdas. -- == -- input { 9 } output { 3 } module type operation = {type a type b val f: a -> b} module compose = \(F: operation) -> \(G: operation with a = F.b) -> { type a = F.a type b = G.b def f(x: a) = G.f (F.f x) } module i32_to_f64: operation with a = i32 with b = f64 = { type a = i32 type b = f64 def f(x: a) = f64.i32 x } module f64_to_i32: operation with a = f64 with b = i32 = { type a = f64 type b = i32 def f(x: a) = i32.f64 x } module f64_sqrt: operation with a = f64 with b = f64 = { type a = f64 type b = f64 def f(x: a) = f64.sqrt x } module i32_sqrt = compose (compose i32_to_f64 f64_sqrt) f64_to_i32 def main(x: i32) = i32_sqrt.f x futhark-0.25.27/tests/modules/lambda2.fut000066400000000000000000000013151475065116200201760ustar00rootroot00000000000000-- Another clever use of module lambdas, and keyword-like application. 
-- == -- input { 9 } output { 3.0 } module type operation = {type a type b val f: a -> b} module compose = \(P: {module F: operation module G: operation with a = F.b}): (operation with a = P.F.a with b = P.G.b) -> { type a = P.F.a type b = P.G.b def f(x: a) = P.G.f (P.F.f x) } module i32_to_f64: operation with a = i32 with b = f64 = { type a = i32 type b = f64 def f(x: a) = f64.i32 x } module f64_sqrt: operation with a = f64 with b = f64 = { type a = f64 type b = f64 def f(x: a) = f64.sqrt x } module mysqrt = compose { module F = i32_to_f64 module G = f64_sqrt } def main(x: i32) = mysqrt.f x futhark-0.25.27/tests/modules/liftedness0.fut000066400000000000000000000002461475065116200211160ustar00rootroot00000000000000-- Abstract type must be at most as lifted as in the module type. -- == -- error: vector module type mt = { type~ vector } module m = { type^ vector = []i32 } : mt futhark-0.25.27/tests/modules/liftedness1.fut000066400000000000000000000003211475065116200211110ustar00rootroot00000000000000-- Trying to sneak more constraints through a size-lifted type! -- == -- error: val f module type mt = { type~ t val f : t -> i32 } module m : mt = { type~ t = [][]i32 def f [n] (_: [n][n]i32) = n } futhark-0.25.27/tests/modules/local0.fut000066400000000000000000000002421475065116200200440ustar00rootroot00000000000000-- Test that something defined with local is not accessible outside the module. -- == -- input {} output { 1 } def x = 1 open { local def x = 2 } def main = x futhark-0.25.27/tests/modules/local_open0.fut000066400000000000000000000002101475065116200210600ustar00rootroot00000000000000-- Does local open work at all? -- == -- input { 1 } output { 6 } module m = { def x = 2 def y = 3 } def main(x: i32) = x + m.(x + y) futhark-0.25.27/tests/modules/local_open1.fut000066400000000000000000000002431475065116200210670ustar00rootroot00000000000000-- Local open with a nested module. 
-- == -- input { 1 } output { 6 } module m0 = { module m1 = { def x = 2 } def x = 3 } def main(x: i32) = x + m0.(x + m1.(x)) futhark-0.25.27/tests/modules/local_open2.fut000066400000000000000000000004111475065116200210650ustar00rootroot00000000000000-- Local open that involves values of an abstract type. -- == module type has_t = { type t val f: i32 -> t } module pm (num: has_t) = { def test (x: i32) = num.(f x) } module m = pm { type t = i32 def f (x: i32) = x } def main (x: i32) = m.test x futhark-0.25.27/tests/modules/local_open3.fut000066400000000000000000000002561475065116200210750ustar00rootroot00000000000000-- Local open with nested modules, defined elsewhere. -- == module has_x = { def x = 1i32 } module has_has_x = { module has_x = has_x } def main = has_has_x.(has_x.x) futhark-0.25.27/tests/modules/local_open4.fut000066400000000000000000000003211475065116200210670ustar00rootroot00000000000000-- Deeper local open! -- == module has_x = { def x = 1i32 } module has_has_x = { module has_x = has_x } module has_has_has_x = { module has_x = has_x } def main = has_has_has_x.(has_has_x.(has_x.x)) futhark-0.25.27/tests/modules/map_with_structure0.fut000066400000000000000000000004271475065116200227070ustar00rootroot00000000000000-- Testing whether it is possible to use a function -- from a module in a curry function (map) -- == -- input { -- [1, 2, 3 ,4, 5, 6, 7, 8, 9, 10] -- } -- output { -- 55 -- } module F = { def plus(a: i32) (b: i32): i32 = a+b } def main(a: []i32): i32 = reduce F.plus 0 a futhark-0.25.27/tests/modules/map_with_structure1.fut000066400000000000000000000004271475065116200227100ustar00rootroot00000000000000-- Testing whether it is possible to use a function -- from a module in a curry function (map) -- == -- input { -- [1, 2, 3 ,4, 5, 6, 7, 8, 9, 10] -- } -- output { -- 55 -- } module F = { def plus(a: i32) (b: i32): i32 = a+b } def main(a: []i32): i32 = reduce F.plus 0 a 
futhark-0.25.27/tests/modules/module-spec-error0.fut000066400000000000000000000002551475065116200223220ustar00rootroot00000000000000-- A mismatched module spec. -- == -- error: QUUX module type MT = { module M: {val QUUX: i32} } module M1: MT = { module M = { def QUUX2 = 2 } } def main() = M1.M.x futhark-0.25.27/tests/modules/module-spec-error1.fut000066400000000000000000000003261475065116200223220ustar00rootroot00000000000000-- A parametric module does not match a module spec. -- == -- error: parametric module type MT = { module M: {val x: i32} } module M1: MT = { module M(P: {val y:i32}) = { def x = P.y } } def main() = M1.M.x futhark-0.25.27/tests/modules/module-spec-error2.fut000066400000000000000000000003641475065116200223250ustar00rootroot00000000000000-- == -- error: M1.M.t.*M0.M.t module type MT = { module M: {type t val x: t val f: t -> t} } module M0: MT = { module M = { type t = i32 def x = 0 def f (y: t) = y + 1 } } module M1: MT = M0 def main() = M1.M.f (M0.M.x) futhark-0.25.27/tests/modules/module-spec0.fut000066400000000000000000000002641475065116200211730ustar00rootroot00000000000000-- A module spec in a module type. -- == -- input {} output { 2 } module type MT = { module M: {val x: i32} } module M1: MT = { module M = { def x = 2 } } def main = M1.M.x futhark-0.25.27/tests/modules/module-spec1.fut000066400000000000000000000006001475065116200211660ustar00rootroot00000000000000-- A module spec in a module type, used for a parametric module, with -- some shadowing too. -- == -- input { 10 } output { 10 } module PM(P: {type t val x: t module PM: {val f: t -> t}}) = { def iterate(n: i32) = loop x = copy P.x for i < n do P.PM.f x } module M = PM({ type t = i32 def x = 0 module PM = { def f(a: i32) = a + 1 } }) def main(n: i32) = M.iterate n futhark-0.25.27/tests/modules/module-spec2.fut000066400000000000000000000010311475065116200211660ustar00rootroot00000000000000-- Higher-order module spec. 
-- == -- input { 3 } output { 12 } module type MT1 = { val f: i32 -> i32 -> i32 } module type MT2 = { val g: i32 -> i32 } module type MT3 = { module T: MT1 -> MT2 } module MT3_twice: MT3 = { module T(P: MT1): MT2 = { def g (x: i32) = P.f x x } } module MT1_plus: MT1 = { def f (x: i32) (y: i32) = x + y } module M = { module T(P: MT3) = { module P_T_I = P.T MT1_plus def g(x: i32) = P_T_I.g x } } module MT_I = M.T MT3_twice def main(x: i32) = MT1_plus.f x x + MT_I.g x futhark-0.25.27/tests/modules/module-spec3.fut000066400000000000000000000005611475065116200211760ustar00rootroot00000000000000-- Higher-order module specs with abstract types. -- == -- input { 1 } output { 4 } module type repeater = (P:{type t val f: t -> t}) -> {val g: P.t -> P.t} module twice_(P: {type t val f: t -> t}) = { def g (x: P.t) = P.f (P.f x) } module twice = twice_ : repeater module twice_mult = twice {type t = i32 def f (x: t) = x * 2} def main(x: i32) = twice_mult.g x futhark-0.25.27/tests/modules/open0.fut000066400000000000000000000002141475065116200177120ustar00rootroot00000000000000-- Does the open declaration work at all? -- == -- input { } output { 4 } module M = { def the_value = 4 } open M def main = the_value futhark-0.25.27/tests/modules/open1.fut000066400000000000000000000002241475065116200177140ustar00rootroot00000000000000-- Does open shadow correctly? -- == -- input { } output { 4 } def the_value = 2 module M = { def the_value = 4 } open M def main = the_value futhark-0.25.27/tests/modules/open4.fut000066400000000000000000000005131475065116200177200ustar00rootroot00000000000000-- Opening a module defining a type is not the same as defining a -- type. Confused the type checker at one point. 
-- == -- input { 2 } output { 2 } module pm (P: {type t}) (X: {}) = { open P def id_t (x: t): t = x } module p_is_i32 = { type t = i32 } module pm_i32 = pm p_is_i32 {} def main (x: i32): i32 = pm_i32.id_t x futhark-0.25.27/tests/modules/open5.fut000066400000000000000000000004151475065116200177220ustar00rootroot00000000000000module type ModuleType = { val someVal: i32 } module moduleinst: ModuleType = { def someVal = 0i32 } module ModuleTypeOps (x: ModuleType) = { def mySomeVal = x.someVal } open ModuleTypeOps moduleinst open ModuleTypeOps moduleinst entry main = mySomeVal futhark-0.25.27/tests/modules/open6.fut000066400000000000000000000006621475065116200177270ustar00rootroot00000000000000module type dummy = {} module dummyinst: dummy = {} module type ModuleType = { val someVal: i32 } module moduleinst: ModuleType = { def someVal = 0i32 } module ModuleTypeOps (x: ModuleType) = { def valGetter = x.someVal } module HigherModule (unused: dummy) = { open ModuleTypeOps moduleinst def myGet = valGetter } open ModuleTypeOps moduleinst module test = HigherModule dummyinst entry main = test.myGet futhark-0.25.27/tests/modules/polymorphic-error0.fut000066400000000000000000000002631475065116200224510ustar00rootroot00000000000000-- Insufficient polymorphism. -- == -- error: pair module type has_pair = { val pair 'a 'b: a -> b -> (a,b) } module with_pair: has_pair = { def pair 'a (x: a) (y: a) = (x,y) } futhark-0.25.27/tests/modules/polymorphic0.fut000066400000000000000000000004341475065116200213220ustar00rootroot00000000000000-- A simple polymorphic function in a module type. -- == -- input { 1 false } output { false 1 } module type has_identity = { val id 't: t -> t } module with_identity: has_identity = { def id 't (x: t) = x } def main (x: i32) (y: bool) = (with_identity.id y, with_identity.id x) futhark-0.25.27/tests/modules/polymorphic1.fut000066400000000000000000000003701475065116200213220ustar00rootroot00000000000000-- Being more polymorphic is OK. 
-- == -- input { 1 2 } output { 1 2 } module type has_pair = { val pair 'a: a -> a -> (a,a) } module with_pair: has_pair = { def pair 'a 'b (x: a) (y: b) = (x,y) } def main (x: i32) (y: i32) = with_pair.pair x y futhark-0.25.27/tests/modules/polymorphic2.fut000066400000000000000000000010021475065116200213140ustar00rootroot00000000000000-- Polymorphic function in module parameter. -- == -- input { [1,2] [true,false] } -- output { [1,2,1,2] [true,false,true,false] [2,1] [false,true] } module pm (P: { val frob 'a [n]: [n]a -> []a }) = { def frob_two 'a 'b (xs: []a) (ys: []b) = (P.frob xs, P.frob ys) } module double = pm { def frob 'a (xs: []a) = concat xs xs } module reverse = pm { def frob 'a (xs: []a) = xs[::-1] } def main (xs: []i32) (ys: []bool) = let (a,b) = double.frob_two xs ys let (c,d) = reverse.frob_two xs ys in (a,b,c,d) futhark-0.25.27/tests/modules/polymorphic3.fut000066400000000000000000000010201475065116200213150ustar00rootroot00000000000000-- Polymorphic function using polymorphic type in parametric module. -- == -- input { 2 3 } output { [1i64,0i64] [2.0,1.0,0.0] } module pm (P: { type~ vector 't val reverse 't: vector t -> vector t }) = { def reverse_pair 'a 'b ((xs,ys): (P.vector a, P.vector b)) = (P.reverse xs, P.reverse ys) } module m = pm { type~ vector 't = ?[k].[k]t def reverse 't (xs: []t) = xs[::-1] } def main (x: i32) (y: i32) = m.reverse_pair (iota (i64.i32 x), map f64.i64 (iota (i64.i32 y))) futhark-0.25.27/tests/modules/polymorphic4.fut000066400000000000000000000007051475065116200213270ustar00rootroot00000000000000-- Array of tuples polymorphism. 
-- == -- input { 2i64 } output { [1i64,0i64] [1.0,0.0] [1i64,0i64] } module pm (P: { type vector [n] 't val reverse [n] 't: vector [n] t -> vector [n] t }) = { def reverse_triple [n] 'a 'b (xs: (P.vector [n] (a,b,a))) = P.reverse xs } module m = pm { type vector [n] 't = [n]t def reverse 't (xs: []t) = xs[::-1] } def main (x: i64) = unzip3 (m.reverse_triple (zip3 (iota x) (map f64.i64 (iota x)) (iota x))) futhark-0.25.27/tests/modules/polymorphic5.fut000066400000000000000000000002431475065116200213250ustar00rootroot00000000000000-- Removing polymorphism with an ascription. -- == -- input { 2 } output { 2 } module m: { val id : i32 -> i32 } = { def id x = x } def main (x: i32) = m.id x futhark-0.25.27/tests/modules/polymorphic6.fut000066400000000000000000000004201475065116200213230ustar00rootroot00000000000000-- Removing polymorphism with an ascription, but higher order! -- == -- input { 2 } output { 2 } module mk_m (P: { type t val f: t -> i32 }) = { def g (x: P.t) = P.f x } module m = mk_m { type t = (i32, i32) def f (x, _) = x } def main (x: i32) = m.g (x,x) futhark-0.25.27/tests/modules/polymorphic7.fut000066400000000000000000000004171475065116200213320ustar00rootroot00000000000000-- Being more polymorphic inside a tuple is OK. 
-- == -- input { 1 2 } output { 1 2 } module type has_pair = { val fs 'a: (a -> a, a -> a) } module with_pair: has_pair = { def fs = (\x -> x, \y -> y) } def main (x: i32) (y: i32) = (with_pair.fs.0 x, with_pair.fs.0 y) futhark-0.25.27/tests/modules/scope_behaviour.fut000066400000000000000000000007071475065116200220550ustar00rootroot00000000000000-- == -- input { -- 10 3.0 -- } -- output { -- 1 2 3 -- } type foo = (i32, f64) module M0 = { type foo = (f64, i32) type bar = foo } module M1 = { type foo = f64 type bar = M0.bar -- type is defined at line 13 module M0 = { type foo = M0.foo -- is defined at line 12 type bar = (i32, i32, i32) } type baz = M0.bar -- defined at line 24 } type baz = M1.baz -- is defined at line 27 def main (a: i32) (b: f64) = (1,2,3) : baz futhark-0.25.27/tests/modules/shadowing0.fut000066400000000000000000000005401475065116200207360ustar00rootroot00000000000000-- M1.foo() calls the most recent declaration of number, due to M0.number() -- being brought into scope of M1, overshadowing the top level definition of -- number() -- == -- input { -- } -- output { -- 2 -- } def number(): i32 = 1 module M0 = { def number(): i32 = 2 module M1 = { def foo(): i32 = number() } } def main: i32 = M0.M1.foo() futhark-0.25.27/tests/modules/shadowing1.fut000066400000000000000000000007521475065116200207440ustar00rootroot00000000000000-- M1.foo() calls the most recent declaration of number, due to M0.number() -- being brought into scope of M1, overshadowing the top level definition of -- number() -- == -- input { -- } -- output { -- 6.0 6 6 -- } type best_type = f64 def best_number(): best_type = 6.0 module M0 = { type best_type = i32 def best_number(): best_type = 6 module M1 = { def best_number(): best_type = 6 } } def main: (f64, i32, i32) = (best_number() , M0.best_number() , M0.M1.best_number()) futhark-0.25.27/tests/modules/shadowing2.fut000066400000000000000000000005541475065116200207450ustar00rootroot00000000000000-- M0.foo() changes meaning inside M1, 
after the previous declaration of M0 -- is overshadowed. -- -- == -- input { -- } -- output { -- 1 1 10 -- } module M0 = { def foo(): i32 = 1 } module M1 = { def bar(): i32 = M0.foo() module M0 = { def foo(): i32 = 10 } def baz(): i32 = M0.foo() } def main: (i32, i32, i32) = (M0.foo(), M1.bar(), M1.baz()) futhark-0.25.27/tests/modules/shadowing3.fut000066400000000000000000000003121475065116200207360ustar00rootroot00000000000000-- Shadowing of ordinary names should work as expected. -- == -- input { 3 } output { 5 } def plus2 x = x + 2 module m = { def plus2 = plus2 -- Should refer to the global one. } def main = m.plus2 futhark-0.25.27/tests/modules/shadowing4.fut000066400000000000000000000002131475065116200207370ustar00rootroot00000000000000-- Shadowing of types should work as expected. -- == -- input { } output { 2 } type t = i32 module m = { type t = t } def main: t = 2 futhark-0.25.27/tests/modules/sig-error0.fut000066400000000000000000000001121475065116200206570ustar00rootroot00000000000000-- == -- error: non-anonymous module type mt = { val f: []i32 -> i32 } futhark-0.25.27/tests/modules/sig-error1.fut000066400000000000000000000001261475065116200206650ustar00rootroot00000000000000-- == -- error: quux module type mt = { val f: (quux: i32) -> (quux: i32) -> i32 } futhark-0.25.27/tests/modules/sig1.fut000066400000000000000000000002011475065116200175300ustar00rootroot00000000000000-- Signature with abstract type. 
module type MONOID = { type t val neutral: t val op: t -> t -> t } def main(): i32 = 0 futhark-0.25.27/tests/modules/sig3.fut000066400000000000000000000004071475065116200175420ustar00rootroot00000000000000-- == -- input { 2 true } -- output { [true,true] } module type mt = { val replicate 't: (n: i64) -> t -> [n]t } module m: mt = { def replicate 't (n: i64) (x: t): [n]t = map (\_ -> x) (iota n) } def main (n: i32) (x: bool) = m.replicate (i64.i32 n) x futhark-0.25.27/tests/modules/simple_nested_test.fut000066400000000000000000000005471475065116200225740ustar00rootroot00000000000000-- == -- input { -- 10 21 -- } -- output { -- 6 -- } type t = i32 module NumLib = { def plus(a: t, b: t): t = a + b module BestNumbers = { def four(): t = 4 def seven(): t = 42 def six(): t = 41 } } def localplus(a: i32, b: i32): i32 = NumLib.plus (a,b) def main (a: i32) (b: i32): i32 = localplus(NumLib.BestNumbers.four() , 2) futhark-0.25.27/tests/modules/simple_number_module.fut000066400000000000000000000005471475065116200231100ustar00rootroot00000000000000-- == -- input { -- 10 21 -- } -- output { -- 6 -- } type t = i32 module NumLib = { def plus(a: t, b: t): t = a + b module BestNumbers = { def four(): t = 4 def seven(): t = 42 def six(): t = 41 } } def localplus(a: i32, b: i32): i32 = NumLib.plus (a,b) def main (a: i32) (b: i32): i32 = localplus(NumLib.BestNumbers.four() , 2) futhark-0.25.27/tests/modules/sizeparams-error0.fut000066400000000000000000000005761475065116200222710ustar00rootroot00000000000000-- A module with abstract types containing size parameters, instantiated incorrectly. 
-- == -- error: intvec module type MT = { type intvec[n] val singleton: i32 -> intvec [1] val first [n]: intvec [n] -> i32 } module M0: MT = { type intvec = [3]i32 def singleton (x: i32) = [x,x,x] def first (x: intvec) = x[0] } def main(x: i32): i32 = M0.first (M0.singleton x) futhark-0.25.27/tests/modules/sizeparams-error1.fut000066400000000000000000000004221475065116200222600ustar00rootroot00000000000000-- A dimension parameter using a name bound in the module type. -- == -- error: k_ints type ints [n] = [n]i32 module type MT = { val k: i64 type k_ints = ints [k] } module M_k2: MT = { def k = 2i64 type k_ints = ints [2] } def main(n: i32): M_k2.k_ints = iota n futhark-0.25.27/tests/modules/sizeparams-error2.fut000066400000000000000000000001501475065116200222570ustar00rootroot00000000000000-- Size parameters may not be duplicated. -- == -- error: n module type mt = { type matrix [n] [n] } futhark-0.25.27/tests/modules/sizeparams-error3.fut000066400000000000000000000004261475065116200222660ustar00rootroot00000000000000-- == -- error: type vector module pm (P: { type vector [n] 't val reverse [n] 't: vector [n] t -> vector [n] t }) = { def reverse_triple [n] 'a 'b (xs: (P.vector [n] (a,b,a))) = P.reverse xs } module m = pm { type vector 't = [2]t def reverse 't (xs: []t) = xs[::-1] } futhark-0.25.27/tests/modules/sizeparams0.fut000066400000000000000000000006211475065116200211310ustar00rootroot00000000000000-- A module with abstract types containing size parameters. -- == -- input { 3 } output { 3 } module type MT = { type intvec[n] val singleton: i32 -> intvec [1] val first [n]: intvec [n] -> i32 } module M0: MT = { type intvec [n] = [n]i32 def singleton (x: i32) = [x] def first [n] (x: intvec[n]) = x[0] } def main(x: i32): i32 = let y: M0.intvec[1] = M0.singleton x in M0.first y futhark-0.25.27/tests/modules/sizeparams1.fut000066400000000000000000000005061475065116200211340ustar00rootroot00000000000000-- A dimension parameter using a name bound in the module type. 
-- == -- input { 2i64 } output { [0i64,1i64] } -- input { 1i64 } error: type ints [n] = [n]i64 module type MT = { val k: i64 type k_ints = ints [k] } module M_k2: MT = { def k = 2i64 type k_ints = ints [k] } def main (n: i64) = iota n :> M_k2.k_ints futhark-0.25.27/tests/modules/sizeparams2.fut000066400000000000000000000005621475065116200211370ustar00rootroot00000000000000-- Size parameters in a parametric type. -- == -- input { 1 2 } output { [[0,0]] } module PM(P: { type vec [n] val mk_a: (n: i64) -> vec [n] }) = { def mk_b (m: i64) (n: i64): [m](P.vec [n]) = replicate m (P.mk_a n) } module intmat = PM { type vec [n] = [n]i32 def mk_a (n: i64) = replicate n 0 } def main (m: i32) (n: i32) = intmat.mk_b (i64.i32 m) (i64.i32 n) futhark-0.25.27/tests/modules/sizeparams3.fut000066400000000000000000000006001475065116200211310ustar00rootroot00000000000000-- More size parameters in a parametric type. -- == -- input { 1 1 } output { [0] } -- input { 1 2 } error: module PM(P: { type vec [n] val mk: (n: i64) -> vec [n] }) = { def can_be_bad (n: i64) (x: i64) = P.mk x :> P.vec [n] } module intmat = PM { type vec [n] = [n]i32 def mk (n: i64) = replicate n 0 } def main (n: i32) (x: i32) = intmat.can_be_bad (i64.i32 n) (i64.i32 x) futhark-0.25.27/tests/modules/sizeparams4.fut000066400000000000000000000004111475065116200211320ustar00rootroot00000000000000-- == -- input { 2 } output { 2 } module type mt = { type^ abs val mk : i32 -> abs val len : abs -> i32 } module m : mt = { type~ abs = []i64 def mk (n: i32) = iota (i64.i32 n) def len [n] (_: [n]i64) = i32.i64 n } def main (x: i32) = m.len (m.mk x) futhark-0.25.27/tests/modules/sizeparams5.fut000066400000000000000000000002071475065116200211360ustar00rootroot00000000000000module m : { type~ t '~a val mk '~a : () -> t a } = { type~ t '~a = () def mk () = () } def f '~a (b: bool) : m.t a = m.mk () futhark-0.25.27/tests/modules/sizeparams6.fut000066400000000000000000000002601475065116200211360ustar00rootroot00000000000000module 
type lys = { type~ state val event : i64 -> state -> state } module lys : lys = { type~ state = {arr: []i64} def event x (s: state) = s with arr = iota x } futhark-0.25.27/tests/modules/sizeparams7.fut000066400000000000000000000003551475065116200211440ustar00rootroot00000000000000-- Does it work to have a definition that only has a size parameter? -- == -- input { 3i64 } output { [0i64,1i64,2i64] } module m : { val iota [n] : [n]i64 } = { def iota [n] : [n]i64 = 0..1.. [b+a]t } = { def plus_comm [a][b]'t (xs: [a+b]t): [b+a]t = xs :> [b+a]t } futhark-0.25.27/tests/modules/sizes0.fut000066400000000000000000000002231475065116200201060ustar00rootroot00000000000000module type sized = { val len: i64 } module arr (S: sized) = { type t = [S.len]i32 } module mt (P: sized): { type t = [P.len]i32 } = arr P futhark-0.25.27/tests/modules/sizes1.fut000066400000000000000000000002011475065116200201030ustar00rootroot00000000000000module type withvec_mt = { val n : i64 val xs : [n]i64 } module withvec : withvec_mt = { def n = 3i64 def xs = iota n } futhark-0.25.27/tests/modules/sizes2.fut000066400000000000000000000002501475065116200201100ustar00rootroot00000000000000-- == -- error: Sizes "n" module type withvec_mt = { val n : i64 val xs : [n]i64 } module withvec : withvec_mt = { def n = 3i64 def xs : []i64 = iota (n+1) } futhark-0.25.27/tests/modules/sizes3.fut000066400000000000000000000002331475065116200201120ustar00rootroot00000000000000module type mod_b = { type t val n : i64 val f: [n]t -> t } module type c = { module B: mod_b } module make_c (B: mod_b): c = { module B = B } futhark-0.25.27/tests/modules/sizes4.fut000066400000000000000000000003011475065116200201070ustar00rootroot00000000000000type state_sized [n] = {arr: [n]i32} module m : { type~ state val step : bool -> state -> state } = { type~ state = ?[n].state_sized [n] def step b (s: state) = s with arr = [1,2,3] } 
futhark-0.25.27/tests/modules/sizes5.fut000066400000000000000000000001711475065116200201150ustar00rootroot00000000000000local module type sparse = { type~ mat = ?[nnz].[nnz]i64 } module sparse : sparse = { type~ mat = ?[nnz].[nnz]i64 } futhark-0.25.27/tests/modules/sizes6.fut000066400000000000000000000002071475065116200201160ustar00rootroot00000000000000local module type sparse = { type^ mat = (nnz: i64) -> [nnz]i64 } module sparse : sparse = { type^ mat = (nnz: i64) -> [nnz]i64 } futhark-0.25.27/tests/modules/sizes7.fut000066400000000000000000000006031475065116200201170ustar00rootroot00000000000000-- Based on #1992, #1997. -- == -- input { [1u8,2u8,3u8] } module make(G: { val size : i64 val const : u8 -> [size]u8 }) = { def const_map (str : []u8) : [][G.size]u8 = map G.const str } module thing = make ({ def size : i64 = 3 def const (_ : u8) : [size]u8 = sized size [1, 2, 3] } : { val size : i64 val const : u8 -> [size]u8}) entry main str = id (thing.const_map str) futhark-0.25.27/tests/modules/sizes8.fut000066400000000000000000000005321475065116200201210ustar00rootroot00000000000000module type thing = { val n : i64 val f : i64 -> [n / 2]i64 } module mk_m(G: thing) = { def lex [m] (i : [m]i64) : [m][G.n / 2]i64 = map G.f i } module m = mk_m { def n : i64 = 2 def f i = replicate (n / 2) i } def x = m.lex [1, 2, 3] -- == -- entry: test -- random input { } output { 6i64 } entry test = flatten x |> i64.sum futhark-0.25.27/tests/modules/struct-var.fut000066400000000000000000000002611475065116200210050ustar00rootroot00000000000000-- Defining a structure via the name of some other structure. 
-- == -- input { 2 } output { 3 } module M1 = { def x: i32 = 1 } module M2 = M1 def main(x: i32): i32 = x + M2.x futhark-0.25.27/tests/modules/triangles.fut000066400000000000000000000004321475065116200206630ustar00rootroot00000000000000-- == -- input { -- 10 21 21 -- 19 12 5 -- } -- output { -- 547 -- } import "Vec3" type vec3 = Vec3.Int.t def f(a: vec3, b: vec3): i32 = Vec3.Int.dot(a , b) def main (a1: i32) (a2: i32) (a3: i32) (b1: i32) (b2: i32) (b3: i32): i32 = Vec3.Int.dot((a1,a2,a2) , (b1,b2,b3)) futhark-0.25.27/tests/modules/tst1.fut000066400000000000000000000002711475065116200175670ustar00rootroot00000000000000module type T1 = { type t type s = t val a : s val f : s -> i32 } module X : T1 = { type t = i32 type s = i32 def a : s = 3 def f (x:s) : i32 = x } -- ok def main () : i32 = X.f X.a futhark-0.25.27/tests/modules/tst2.fut000066400000000000000000000003601475065116200175670ustar00rootroot00000000000000-- Incompatible types check -- == -- error: i32.*f32 module type T1 = { type t type s = t val a : s val f : s -> i32 } module X : T1 = { type t = f32 type s = i32 def a : s = 3 def f (x:s) : i32 = x } -- err def main () : i32 = X.f X.a futhark-0.25.27/tests/modules/tst3.fut000066400000000000000000000004031475065116200175660ustar00rootroot00000000000000-- A functor that swaps types -- == -- input {} -- output { 3 } module F (X : { type t type s }) : { type t = X.s type s = X.t } = { type t = X.s type s = X.t } module A = { type t = f32 type s = i32 } module B = F(A) module C = F(B) def main : i32 = 3 futhark-0.25.27/tests/modules/tst4.fut000066400000000000000000000004651475065116200175770ustar00rootroot00000000000000-- A functor with a closure -- == -- input {} -- output { 26 } module A = { type t = i32 def x : t = 3 } module F (X : { val b : i32 }) : { type t = i32 val c : t } = { type t = i32 def c = A.x + X.b } module C = { module A = { type t = f32 } module B = F( { def b = 23 } ) } def main : i32 = C.B.c 
futhark-0.25.27/tests/modules/typeparams-error0.fut000066400000000000000000000001461475065116200222710ustar00rootroot00000000000000-- == -- error: vector module type MT = { type vector 'a } module M: MT = { type vector = []i32 } futhark-0.25.27/tests/modules/typeparams-error1.fut000066400000000000000000000003701475065116200222710ustar00rootroot00000000000000-- Erroneous use of a parametric type inside a module type. -- == -- error: i32matrix module type MT = { type vector 'a type i32matrix = [](vector i32) } module M1: MT = { type vector 'a = [2]a type i32matrix = [][2]f32 } def main() = 2 futhark-0.25.27/tests/modules/typeparams-error2.fut000066400000000000000000000001601475065116200222670ustar00rootroot00000000000000-- Type parameters may not be duplicated. -- == -- error: previously module type mt = { type whatevs 't 't } futhark-0.25.27/tests/modules/typeparams0.fut000066400000000000000000000010261475065116200211400ustar00rootroot00000000000000-- Type-parametric type in module type. -- == -- input { 1 2 } output { [1,2] } module type MT = { type t 'a val pack [n]: [n]i32 -> t i32 val unpack: t i32 -> []i32 } module M0: MT = { type t 'a = (a,a) def pack (xs: []i32) = (xs[0], xs[1]) def unpack (x: i32, y: i32) = [x,y] } module M1: MT = { type t 'a = [2]a def pack (xs: []i32) = [xs[0], xs[1]] def unpack (xs: t i32) = xs } def main (x: i32) (y: i32): []i32 = let a: M0.t i32 = M0.pack [x,y] let b: M1.t i32 = M1.pack (M0.unpack a) in M1.unpack b futhark-0.25.27/tests/modules/typeparams1.fut000066400000000000000000000005341475065116200211440ustar00rootroot00000000000000-- Use of a parametric type inside a module type. -- == -- input {} output {2} module type MT = { type vector 'a type i32matrix = [2](vector i32) } module M0: MT = { type vector 'a = [2]a type i32matrix = [2](vector i32) } -- And now an inlined one. 
module M1: MT = { type vector 'a = [2]a type i32matrix = [2][2]i32 } def main = 2 futhark-0.25.27/tests/modules/undefined_structure_err0.fut000066400000000000000000000002741475065116200237100ustar00rootroot00000000000000-- We can not access a module before it has been defined. -- == -- error: .*Unknown.* def try_me(): i32 = M0.number() module M0 = { def number(): i32 = 42 } def main(): i32 = try_me() futhark-0.25.27/tests/modules/warn_type_shadow.fut000066400000000000000000000002551475065116200222530ustar00rootroot00000000000000-- == -- warning: Inclusion shadows type module type has_t = { type t } module type also_has_t = { type t } module type mt = { include has_t include also_has_t } futhark-0.25.27/tests/modules/with-error0.fut000066400000000000000000000004061475065116200210560ustar00rootroot00000000000000-- With of a type in an inner module. -- == -- error: module type requires module type has_t = { type t } module type has_inner = { module inner: has_t } module m: has_inner with inner.t = bool = { module inner = { type t = i32 } } def main(): m.inner.t = 2 futhark-0.25.27/tests/modules/with-error1.fut000066400000000000000000000002621475065116200210570ustar00rootroot00000000000000-- Refinement of a non-parametric with a parametric type is not OK. -- == -- error: Cannot refine a type module type has_t = { type t } module type has_t' = has_t with t 'a = a futhark-0.25.27/tests/modules/with-error2.fut000066400000000000000000000002641475065116200210620ustar00rootroot00000000000000-- Refinement of a parametric with a non-parametric type is not OK. -- == -- error: Cannot refine a type module type has_t = { type t 'a } module type has_t' = has_t with t = i32 futhark-0.25.27/tests/modules/with-error3.fut000066400000000000000000000003621475065116200210620ustar00rootroot00000000000000-- Refinement of a parametric type must not narrow it. This rule may -- be too restrictive and we may loosen it in the future. 
-- == -- error: Cannot refine a type module type has_t = { type t '^a } module type has_t' = has_t with t 'a = a futhark-0.25.27/tests/modules/with-error4.fut000066400000000000000000000002471475065116200210650ustar00rootroot00000000000000-- Cannot refine a type twice. -- == -- error: not an abstract type module type mt = {type t} module type mt' = mt with t = i64 module type mt'' = mt' with t = i64 futhark-0.25.27/tests/modules/with0.fut000066400000000000000000000003311475065116200177240ustar00rootroot00000000000000-- Does 'with' work? -- == -- input { 2 } output { 42 } module type constant = { type t val x: t } module intconstant: (constant with t = i32) = { type t = i32 def x = 40 } def main(y: i32) = intconstant.x + y futhark-0.25.27/tests/modules/with1.fut000066400000000000000000000006261475065116200177340ustar00rootroot00000000000000-- == -- input { 2 } output { 2 } module type has_cell = { type cell } module type same_cell_twice = { type cell include has_cell with cell = cell } module functor (V: {type cell}): same_cell_twice = { type cell = V.cell } module applied = functor { type cell = bool } -- We can't create a value of this type, but let's just refer to it. entry quux (x: applied.cell) = x def main(x: i32) = x futhark-0.25.27/tests/modules/with2.fut000066400000000000000000000003731475065116200177340ustar00rootroot00000000000000-- With of a type in an inner module. -- == -- input {} output {2} module type has_t = { type t } module type has_inner = { module inner: has_t } module m: has_inner with inner.t = i32 = { module inner = { type t = i32 } } def main: m.inner.t = 2 futhark-0.25.27/tests/modules/with3.fut000066400000000000000000000004701475065116200177330ustar00rootroot00000000000000-- Keeping track of when names have to be qualified. 
-- == -- input { 1 } output { 3 } module type has_t = { type t } module type has_x = { type t val x: t } module pm (R: has_t) (V: has_x with t = R.t) = { def x = V.x } module m = pm { type t = i32} { type t = i32 def x = 2 } def main (y: i32) = y + m.x futhark-0.25.27/tests/modules/with4.fut000066400000000000000000000003001475065116200177240ustar00rootroot00000000000000-- Refinement of a parametric type. -- == -- input { 2 } output { 2 } module type has_t = { type t 'a } module id : (has_t with t 'a = a) = { type t 'a = a } def main (x: i32): id.t i32 = x futhark-0.25.27/tests/modules/with5.fut000066400000000000000000000002001475065116200177240ustar00rootroot00000000000000-- Refinement of a parametric may expand it. -- == module type has_t = { type t 'a } module type has_t' = has_t with t '^a = a futhark-0.25.27/tests/mss.fut000066400000000000000000000006551475065116200160340ustar00rootroot00000000000000-- Parallel maximum segment sum -- == -- input { [1, -2, 3, 4, -1, 5, -6, 1] } -- output { 11 } def main(xs: []i32): i32 = let max = i32.max let redOp (mssx, misx, mcsx, tsx) (mssy, misy, mcsy, tsy) = ( max mssx (max mssy (mcsx + misy)) , max misx (tsx+misy) , max mcsy (mcsx+tsy) , tsx + tsy) let mapOp x = ( max x 0 , max x 0 , max x 0 , x) in (reduce redOp (0,0,0,0) (map mapOp xs)).0 futhark-0.25.27/tests/mul0.fut000066400000000000000000000001131475065116200160740ustar00rootroot00000000000000-- == -- input { f64.nan } output { f64.nan } entry main (x: f64) = 0 * x futhark-0.25.27/tests/multinv16.fut000066400000000000000000000012101475065116200170630ustar00rootroot00000000000000-- Multiplicative inverse on 16-bit numbers. Returned as a 32-bit -- number to print better (because we do not print unsigned types). -- At one point the compiler missimplified the convergence loop. 
-- -- == -- input { 2u16 } output { 32769u32 } -- input { 33799u16 } output { 28110u32 } def main(a: u16): u32 = let b = 0x10001u32 let u = 0i32 let v = 1i32 in let (_,_,u,_) = loop ((a,b,u,v)) while a > 0u16 do let q = b / u32.u16(a) let r = b % u32.u16(a) let b = u32.u16(a) let a = u16.u32(r) let t = v let v = u - i32.u32 (q) * v let u = t in (a,b,u,v) in u32.i32(if u < 0 then u + 0x10001 else u) futhark-0.25.27/tests/negate.fut000066400000000000000000000003661475065116200164740ustar00rootroot00000000000000-- Test that negation works for both integers and f64s. -- == -- input { -- [1,2,3] -- } -- output { -- [-1, -2, -3] -- [-1.000000, -2.000000, -3.000000] -- } def main(a: []i32): ([]i32,[]f64) = (map (0-) a, map (0.0-) (map f64.i32 a)) futhark-0.25.27/tests/nestedmain.fut000066400000000000000000000001571475065116200173560ustar00rootroot00000000000000-- == -- input { true } output { false } module m = { def main (x: i32) = x + 2 } def main b : bool = !b futhark-0.25.27/tests/noinline/000077500000000000000000000000001475065116200163175ustar00rootroot00000000000000futhark-0.25.27/tests/noinline/noinline0.fut000066400000000000000000000002221475065116200207260ustar00rootroot00000000000000-- == -- input { [1,2,3] } output { [3,4,5] } -- structure gpu { SegMap/Apply 1 } def f (x: i32) = x + 2 def main = map (\x -> #[noinline] f x) futhark-0.25.27/tests/noinline/noinline1.fut000066400000000000000000000002411475065116200207300ustar00rootroot00000000000000-- == -- input { 3 [1,2,3] } output { [4,5,6] } -- structure gpu { SegMap/Apply 1 } def f (x: i32) (y: i32) = x + y def main y = map (\x -> #[noinline] f x y) futhark-0.25.27/tests/noinline/noinline2.fut000066400000000000000000000003311475065116200207310ustar00rootroot00000000000000-- == -- input { 3 [1,2,3] } output { [0,0,1] } -- compiled input { 0 [1,2,3] } error: division by zero -- structure gpu { SegMap/Apply 1 } def f (x: i32) (y: i32) = x / y def main y = map (\x -> #[noinline] f x y) 
futhark-0.25.27/tests/noinline/noinline3.fut000066400000000000000000000003341475065116200207350ustar00rootroot00000000000000-- == -- tags { no_opencl no_cuda no_hip no_pyopencl } -- input { [1,2,3] [0,0,1] } output { [1,1,2] } -- structure gpu { SegMap/Apply 1 } def f (xs: []i32) i = xs[i] def main xs is = map (\i -> #[noinline] f xs i) is futhark-0.25.27/tests/noinline/noinline4.fut000066400000000000000000000004251475065116200207370ustar00rootroot00000000000000-- == -- input { 3 [1,2,3] } output { 1 } -- compiled input { 0 [1,2,3] } error: division by zero -- structure gpu { SegRed/Apply 1 /Apply 1 } def f (x: i32) (y: i32) = x / y def g (x: i32) (y: i32) = #[noinline] f x y def main y = map (\x -> #[noinline] g x y) >-> i32.sum futhark-0.25.27/tests/noinline/noinline5.fut000066400000000000000000000003071475065116200207370ustar00rootroot00000000000000-- From issue #1634. -- == -- input { [1i64,2i64] } -- output { [1i64, 2i64, 1i64, 2i64] } #[noinline] def double [n] (A: [n]i64) : *[]i64 = A ++ A def main [n] (A: [n]i64) : *[]i64 = double A futhark-0.25.27/tests/noinline/noinline6.fut000066400000000000000000000002741475065116200207430ustar00rootroot00000000000000-- == -- input { [1,2,3] [4,5,6] } output { [5, 6, 7] [7, 9, 11] } -- structure gpu { SegMap/Apply 2 } #[noinline] def f y x = x + y + 2i32 def main xs ys = (map (f 2) xs, map2 f xs ys) futhark-0.25.27/tests/normalizeTest1.fut000066400000000000000000000006171475065116200201510ustar00rootroot00000000000000-- == -- input { -- 1 -- 2.0 -- 3 -- 4 -- 5.0 -- 6 -- } -- output { -- 5 -- } def tupfun(x: (i32,(f64,i32)), y: (i32,(f64,i32)) ): i32 = let (x1, x2) = x let (y1, y2) = y in x1 + y1 --let (x0, (x1,x2)) = x in --let (y0, (y1,y2)) = y in --33 def main (x1: i32) (y1: f64) (z1: i32) (x2: i32) (y2: f64) (z2: i32): i32 = tupfun((x1,(y1,z1)),(x2,(y2,z2))) futhark-0.25.27/tests/opaque.fut000066400000000000000000000002111475065116200165100ustar00rootroot00000000000000-- Test that 'opaque' prevents constant-folding. 
-- == -- input {} output {4} -- structure { BinOp 1 Opaque 1 } def main = opaque 2 + 2 futhark-0.25.27/tests/operator/000077500000000000000000000000001475065116200163375ustar00rootroot00000000000000futhark-0.25.27/tests/operator/section0.fut000066400000000000000000000003271475065116200206050ustar00rootroot00000000000000-- We can use operator sections anywhere, just like other functions. -- == -- input { 5 3 } output { 2 2 -2 } def (-^) (x: i32) (y: i32) = x - y def main (x: i32) (y: i32) = ( (-^) x y , (x -^) y , (-^ x) y) futhark-0.25.27/tests/operator/section1.fut000066400000000000000000000002431475065116200206030ustar00rootroot00000000000000-- Operator section as argument to map2 library function. -- == -- input { [4,2,1] [5,6,3] } output { [9,8,4] } def main (xs: []i32) (ys: []i32) = map2 (+) xs ys futhark-0.25.27/tests/operator/section2.fut000066400000000000000000000003101475065116200205770ustar00rootroot00000000000000-- Backticked operator sections work. -- == -- input { 5 3 } output { 2 2 -2 } def minus (x: i32) (y: i32) = x - y def main (x: i32) (y: i32) = ( (`minus`) x y , (x `minus`) y , (`minus` x) y) futhark-0.25.27/tests/operator/section3.fut000066400000000000000000000003011475065116200206000ustar00rootroot00000000000000-- Operator section with interesting type to the left. -- == -- input { [-1,1] } output { [1] } def (<*>) 'a 'b (f: a -> b) (xs: a) = f xs def main (xs: []i32) = (id<*>) (filter (>0) xs) futhark-0.25.27/tests/operator/section4.fut000066400000000000000000000002761475065116200206140ustar00rootroot00000000000000-- Operator section with interesting type to the left. -- == -- input { [-1,1] } output { [1,3] } def (<*>) 'a 'b (f: a -> b) (xs: a) = f xs def main (xs: []i32) = (<*>map (2+) xs) id futhark-0.25.27/tests/operator/section5.fut000066400000000000000000000001401475065116200206030ustar00rootroot00000000000000-- Test that parameter names are not lost in sections. 
def main n (x: i32) = (`replicate` x) n futhark-0.25.27/tests/operator/section6.fut000066400000000000000000000002131475065116200206050ustar00rootroot00000000000000-- Test that parameter names are not lost in sections. def flipreplicate x n = replicate n x def main n (x: i32) = (x `flipreplicate`) n futhark-0.25.27/tests/operator/size-section0.fut000066400000000000000000000002341475065116200215520ustar00rootroot00000000000000-- Check that sizes are well calculated in left section def main [n][m][l] (xs : [n]i64) (ys: [m]i64) (mat: [l][m]i64) = [xs ++ ys] ++ map (xs ++) mat futhark-0.25.27/tests/operator/size-section1.fut000066400000000000000000000002351475065116200215540ustar00rootroot00000000000000-- Check that sizes are well calculated in right section def main [n][m][l] (xs : [n]i64) (ys: [m]i64) (mat: [l][n]i64) = [xs ++ ys] ++ map (++ ys) mat futhark-0.25.27/tests/operator/size-section2.fut000066400000000000000000000003431475065116200215550ustar00rootroot00000000000000-- Check that sizes are well calculated in left section, even with bounded existential sizes def main [n][m][l] (xs : [n]i64) (ys: [m]i64) (mat: [l][m]i64) = let xs' = filter (>0) xs in [xs' ++ ys] ++ map (xs' ++) mat futhark-0.25.27/tests/operator/size-section3.fut000066400000000000000000000003031475065116200215520ustar00rootroot00000000000000-- Check that sizes are well calculated in left section, with complex sizes -- == def main [n][m][l] (xs : [n]i64) (ys: [m]i64) (mat: [l][m]i64) = [(xs ++ xs) ++ ys] ++ map (xs ++ xs ++) mat futhark-0.25.27/tests/operator/size-section4.fut000066400000000000000000000003041475065116200215540ustar00rootroot00000000000000-- Check that sizes are well calculated in right section, with complex sizes -- == def main [n][m][l] (xs : [n]i64) (ys: [m]i64) (mat: [l][m]i64) = [ys ++ (xs ++ xs)] ++ map (++ xs ++ xs) mat futhark-0.25.27/tests/operator/userdef-error0.fut000066400000000000000000000001511475065116200217200ustar00rootroot00000000000000-- You can't override &&. 
-- == -- error: && def (x: bool) && (y: bool) = x def main(x: bool) = x && x futhark-0.25.27/tests/operator/userdef-error1.fut000066400000000000000000000001531475065116200217230ustar00rootroot00000000000000-- You can't override ||. -- == -- error: \|\| def (x: bool) || (y: bool) = x def main(x: bool) = x || x futhark-0.25.27/tests/operator/userdef0.fut000066400000000000000000000002331475065116200205720ustar00rootroot00000000000000-- Can we define a user-defined operator at all? -- == -- input { 2 3 } output { -1 } def (x: i32) + (y: i32) = x - y def main (x: i32) (y: i32) = x + y futhark-0.25.27/tests/operator/userdef1.fut000066400000000000000000000003261475065116200205760ustar00rootroot00000000000000-- Do user-defined operators have the right precedence? -- == -- input { 2 3 4 } output { 14 } def (x: i32) +* (y: i32) = x * y def (x: i32) *+ (y: i32) = x + y def main (x: i32) (y: i32) (z: i32) = x +* y *+ z futhark-0.25.27/tests/operator/userdef2.fut000066400000000000000000000003311475065116200205730ustar00rootroot00000000000000-- Do user-defined operators have the right associativity? -- == -- input { 1 2 3 } output { true } def (x: i32) &-& (y: i32) = x - y def main (x: i32) (y: i32) (z: i32): bool = x &-& y &-& z == (x &-& y) &-& z futhark-0.25.27/tests/operator/userdef3.fut000066400000000000000000000002711475065116200205770ustar00rootroot00000000000000-- Can we overload the minus operator (might conflict with prefix -- negation). -- == -- input { 2 3 } output { 5 } def (x: i32) - (y: i32) = x + y def main (x: i32) (y: i32) = x - y futhark-0.25.27/tests/operator/userdef4.fut000066400000000000000000000002531475065116200206000ustar00rootroot00000000000000-- Can we define a user-defined operator with prefix notation? 
-- == -- input { 2 3 } output { -1 } def (+) (x: i32) (y: i32) = x - y def main (x: i32) (y: i32) = x + y futhark-0.25.27/tests/operator/userdef5.fut000066400000000000000000000004161475065116200206020ustar00rootroot00000000000000-- Polymorphic infix operators ought to work. -- == -- input { [1,2,3] [4,5,6] [true] [false] } -- output { [1,2,3,4,5,6] [true,false] } def (++) 't (xs: []t) (ys: []t) = concat xs ys def main (xs: []i32) (ys: []i32) (as: []bool) (bs: []bool) = (xs ++ ys, as ++ bs) futhark-0.25.27/tests/operator/userdef6.fut000066400000000000000000000006131475065116200206020ustar00rootroot00000000000000-- Polymorphic infix operators ought to work, even in sections. -- == -- input { [[1],[2],[3]] [[4],[5],[6]] [[true]] [[false]] } -- output { [[1],[2],[3]] [[true]] -- [[1],[2],[3]] [[true]] } def (++) 't (xs: []t) (ys: []t) = xs def main (xss: [][]i32) (yss: [][]i32) (ass: [][]bool) (bss: [][]bool) = (map2 (++) xss yss, map2 (++) ass bss, map (++[1]) xss, map ([true]++) bss) futhark-0.25.27/tests/operators.fut000066400000000000000000000003131475065116200172370ustar00rootroot00000000000000-- Test that sophisticated operators (such as "greater than") work. 
-- == -- input { -- 2 -- 2 -- } -- output { -- false -- true -- } def main (x: i32) (y: i32): (bool,bool) = (x > y, x >= y) futhark-0.25.27/tests/overflowing/000077500000000000000000000000001475065116200170455ustar00rootroot00000000000000futhark-0.25.27/tests/overflowing/edgecases.fut000066400000000000000000000002131475065116200215040ustar00rootroot00000000000000-- Some edge cases of literals that don't overflow, but are close -- -- == entry main : (i8, i8, u16, f64) = (-128, 127, 65535, 1.79e308) futhark-0.25.27/tests/overflowing/f32high.fut000066400000000000000000000000731475065116200210170ustar00rootroot00000000000000-- == -- error: (out of bounds.*) def main : f32 = 9.7e42 futhark-0.25.27/tests/overflowing/f32low.fut000066400000000000000000000000721475065116200207000ustar00rootroot00000000000000-- == -- error: (out of bounds.*) def main : f32 = -1e40 futhark-0.25.27/tests/overflowing/f64low.fut000066400000000000000000000000741475065116200207070ustar00rootroot00000000000000-- == -- error: (out of bounds.*) def main : f64 = 1.8e308 futhark-0.25.27/tests/overflowing/i16low.fut000066400000000000000000000000761475065116200207110ustar00rootroot00000000000000-- == -- error: (out of bounds.*) def main : i16 = -10000000 futhark-0.25.27/tests/overflowing/i32high.fut000066400000000000000000000001021475065116200210130ustar00rootroot00000000000000-- == -- error: (out of bounds.*) def main : i32 = 1000000000000 futhark-0.25.27/tests/overflowing/i8high.fut000066400000000000000000000000671475065116200207500ustar00rootroot00000000000000-- == -- error: (out of bounds.*) def main : i8 = 128 futhark-0.25.27/tests/overflowing/i8low.fut000066400000000000000000000000701475065116200206240ustar00rootroot00000000000000-- == -- error: (out of bounds.*) def main : i8 = -129 futhark-0.25.27/tests/overflowing/u16high.fut000066400000000000000000000000731475065116200210400ustar00rootroot00000000000000-- == -- error: (out of bounds.*) def main : u16 = 100000 
futhark-0.25.27/tests/overflowing/u32low.fut000066400000000000000000000000671475065116200207230ustar00rootroot00000000000000-- == -- error: (out of bounds.*) def main : u32 = -4 futhark-0.25.27/tests/overflowing/u8high.fut000066400000000000000000000000671475065116200207640ustar00rootroot00000000000000-- == -- error: (out of bounds.*) def main : u8 = 256 futhark-0.25.27/tests/overflowing/u8low.fut000066400000000000000000000000661475065116200206450ustar00rootroot00000000000000-- == -- error: (out of bounds.*) def main : u8 = -4 futhark-0.25.27/tests/parbits.fut000066400000000000000000000007301475065116200166700ustar00rootroot00000000000000-- Test some bit operations in parallel context (e.g. on GPU). It is -- assumed that the tests in primitive/ validate the sequential -- reference results. -- == -- compiled random input { [100]u8 [100]u16 [100]u32 [100]u64 } auto output def main (u8s: []u8) (u16s: []u16) (u32s: []u32) (u64s: []u64) = (map u8.popc u8s, map u16.popc u16s, map u32.popc u32s, map u64.popc u64s, map u8.clz u8s, map u16.clz u16s, map u32.clz u32s, map u64.clz u64s) futhark-0.25.27/tests/paths/000077500000000000000000000000001475065116200156235ustar00rootroot00000000000000futhark-0.25.27/tests/paths/has_abs.fut000066400000000000000000000002441475065116200177430ustar00rootroot00000000000000-- File containing an abstract type. module m = { type t = i32 def x = 0i32 def eq = (i32.==) } open (m : { type t val x: t val eq: t -> t -> bool }) futhark-0.25.27/tests/paths/subdir/000077500000000000000000000000001475065116200171135ustar00rootroot00000000000000futhark-0.25.27/tests/paths/subdir/has_abs.fut000066400000000000000000000000311475065116200212250ustar00rootroot00000000000000open import "../has_abs" futhark-0.25.27/tests/paths/use_abs.fut000066400000000000000000000004701475065116200177650ustar00rootroot00000000000000-- Test that indirectly including a file with '..' does not result in -- that file being imported twice (verified by using an abstract -- type). 
-- == -- input {} output { true } module has_abs = import "has_abs" module subdir_has_abs = import "subdir/has_abs" def main = has_abs.eq has_abs.x subdir_has_abs.x futhark-0.25.27/tests/perceptron.fut000066400000000000000000000052171475065116200174120ustar00rootroot00000000000000-- An implementation of the venerable Perceptron algorithm. -- -- For clarity, uses very few library functions; in practice the -- /futlib/linalg module would probably be useful here. -- -- == -- input { -- -- [1.0f32, 1.0f32, 1.0f32] -- -- [[1.0f32, 0.6492f32, 10.5492f32], [1.0f32, 5.0576f32, -1.9462f32], -- [1.0f32, -5.9590f32, 7.8897f32], [1.0f32, 2.9614f32, 1.3547f32], -- [1.0f32, 3.6815f32, 1.6019f32], [1.0f32, 5.3024f32, 3.9243f32], -- [1.0f32, 1.9835f32, 2.3669f32], [1.0f32, -3.4360f32, 8.0828f32], -- [1.0f32, 6.1168f32, 2.3159f32], [1.0f32, 6.2850f32, -0.4685f32], -- [1.0f32, 4.4086f32, 1.3710f32], [1.0f32, -3.7105f32, 8.4309f32], -- [1.0f32, -2.3741f32, 6.1648f32], [1.0f32, 0.4221f32, 8.5627f32], -- [1.0f32, -3.5980f32, 9.2361f32], [1.0f32, -4.5349f32, 9.6428f32], -- [1.0f32, 1.6828f32, 0.5335f32], [1.0f32, 5.3271f32, -1.5529f32], -- [1.0f32, 3.2860f32, 3.1965f32], [1.0f32, 5.2880f32, 1.2030f32], -- [1.0f32, -3.7126f32, 12.7188f32], [1.0f32, -2.5362f32, 6.8989f32], -- [1.0f32, -2.0253f32, 5.1877f32], [1.0f32, 6.7019f32, 3.8357f32], -- [1.0f32, -2.9775f32, 8.5460f32], [1.0f32, 2.4272f32, -0.4192f32], -- [1.0f32, 3.7186f32, 4.0874f32], [1.0f32, -4.3252f32, 6.1897f32], -- [1.0f32, -4.8112f32, 9.7657f32], [1.0f32, -3.4481f32, 10.0994f32]] -- -- [-1f32, 1f32, -1f32, 1f32, 1f32, 1f32, 1f32, -1f32, 1f32, 1f32, -- 1f32, -1f32, -1f32, -1f32, -1f32, -1f32, 1f32, 1f32, 1f32, 1f32, -- -1f32, -1f32, -1f32, 1f32, -1f32, 1f32, 1f32, -1f32, -1f32, -1f32] -- -- 100 -- -- 1f32 -- } -- output { -- 6i32 -- [2.000000f32, 8.614600f32, -4.270200f32] -- 1.000000f32 -- } def dotV [d] (x: [d]f32) (y: [d]f32): f32 = reduce (+) 0.0 (map2 (*) x y) def addV [d] (x: [d]f32) (y: [d]f32): [d]f32 = map2 (+) x y 
def scaleV [d] (x: [d]f32) (a: f32): [d]f32 = map (*a) x def checkClass [d] (w: [d]f32) (x: [d]f32): f32 = if dotV x w > 0.0 then 1.0 else -1.0 def checkList [d][m] (w: [d]f32) (xs: [m][d]f32) (ys: [m]f32): bool = reduce (&&) true (map2 (\x y -> checkClass w x * y != -1.0) xs ys) def accuracy [d][m] (w: [d]f32) (xs: [m][d]f32) (ys: [m]f32): f32 = reduce (+) 0.0 (map2 (\x y -> f32.bool (checkClass w x * y != -1.0)) xs ys) def train [d] (w: [d]f32) (x: [d]f32) (y: f32) (eta: f32): [d]f32 = if checkClass w x == y then w else addV w (scaleV (scaleV x eta) y) -- Returns: #iterations, final 'w', accuracy from 0-1. def main [d][m] (w: [d]f32) (xd: [m][d]f32) (yd: [m]f32) (limit: i32) (eta: f32): (i32, [d]f32, f32) = let (w,i) = loop (w, i) = (w, 0) while i < limit && !(checkList w xd yd) do -- Find data for this iteration. let x = xd[i%i32.i64 m] let y = yd[i%i32.i64 m] in (train w x y eta, i+1) in (i, w, accuracy w xd yd / f32.i64(m)) futhark-0.25.27/tests/phantomsizes.fut000066400000000000000000000004601475065116200177500ustar00rootroot00000000000000-- == -- input { [1,2,3] } output { [0,1,2] [1,2,3] } type size [n] = [n]() def size n = replicate n () def iota' [n] (_: size [n]) : [n]i32 = 0..1.. [n]i32 def length' [n] 'a (_: [n]a) : size [n] = size n def f xs = zip (iota' (length' xs)) xs def main (xs: []i32) = unzip (f xs) futhark-0.25.27/tests/pow.fut000066400000000000000000000005151475065116200160320ustar00rootroot00000000000000-- Integer power test program -- == -- input { -- 0 0 -- } -- output { -- 1 -- } -- input { -- 1 0 -- } -- output { -- 1 -- } -- input { -- 0 10 -- } -- output { -- 0 -- } -- input { -- 2 3 -- } -- output { -- 8 -- } -- input { -- 2 16 -- } -- output { -- 65536 -- } def main (x: i32) (y: i32): i32 = x ** y futhark-0.25.27/tests/powneg.fut000066400000000000000000000002261475065116200165230ustar00rootroot00000000000000-- Do not crash during constant folding if encountering a negative -- exponent. 
-- == -- input { true } error: def main b = if b then 2 ** -1 else 0 futhark-0.25.27/tests/prefix_error.fut000066400000000000000000000000731475065116200177320ustar00rootroot00000000000000-- == -- error: cannot be used as infix def x ! y = x + y futhark-0.25.27/tests/prefix_prec.fut000066400000000000000000000001711475065116200175310ustar00rootroot00000000000000-- Test that prefix operators have the right precedence. -- -- == -- input {1 2} output { 1 } def main x y = -x%y : i32 futhark-0.25.27/tests/primitive/000077500000000000000000000000001475065116200165145ustar00rootroot00000000000000futhark-0.25.27/tests/primitive/README.md000066400000000000000000000010211475065116200177650ustar00rootroot00000000000000This directory contains tests for the primtive types in Futhark, their operators, and some built-in functions. It is a good place to start looking when implementing a new code generator backend, because if you don't add integers correctly, you're unlikely to get anywhere quickly. We would get a combinatorial explosion if we had one test program for every combination of type and operator, so instead we try to test several things in the same program (via different return values, and perhaps operations predicated on input). futhark-0.25.27/tests/primitive/acos32.fut000066400000000000000000000002231475065116200203230ustar00rootroot00000000000000-- Does the acos32 function work? -- == -- input { [1f32, 0.5403023f32, -1f32] } -- output { [0f32, 1f32, 3.1415927f32] } def main = map f32.acos futhark-0.25.27/tests/primitive/acos64.fut000066400000000000000000000002231475065116200203300ustar00rootroot00000000000000-- Does the acos64 function work? -- == -- input { [1f64, 0.5403023f64, -1f64] } -- output { [0f64, 1f64, 3.1415927f64] } def main = map f64.acos futhark-0.25.27/tests/primitive/acosh16.fut000066400000000000000000000002461475065116200205020ustar00rootroot00000000000000-- Does the f16.acosh function work? 
-- == -- input { [1f16, 0.5403023f16, 3.14f16] } -- output { [0f16, f16.nan, 1.810991348900196f16 ] } def main = map f16.acosh futhark-0.25.27/tests/primitive/acosh32.fut000066400000000000000000000002461475065116200205000ustar00rootroot00000000000000-- Does the f32.acosh function work? -- == -- input { [1f32, 0.5403023f32, 3.14f32] } -- output { [0f32, f32.nan, 1.810991348900196f32 ] } def main = map f32.acosh futhark-0.25.27/tests/primitive/acosh64.fut000066400000000000000000000002461475065116200205050ustar00rootroot00000000000000-- Does the f64.acosh function work? -- == -- input { [1f64, 0.5403023f64, 3.14f64] } -- output { [0f64, f64.nan, 1.810991348900196f64 ] } def main = map f64.acosh futhark-0.25.27/tests/primitive/asin32.fut000066400000000000000000000003031475065116200203270ustar00rootroot00000000000000-- Does the asin32 function work? -- == -- input { [0f32, -0.84147096f32, -8.742278e-8f32, 8.742278e-8f32] } -- output { [0f32, -1f32, -8.742278e-8f32, 8.742278e-8f32] } def main = map f32.asin futhark-0.25.27/tests/primitive/asin64.fut000066400000000000000000000002601475065116200203360ustar00rootroot00000000000000-- Does the sin64 function work? -- == -- input { [0.0, -0.84147096, -8.742278e-8, 8.742278e-8] } -- output { [0.0, -1.0, -8.742278e-8, 8.742278e-8] } def main = map f64.asin futhark-0.25.27/tests/primitive/asinh16.fut000066400000000000000000000003541475065116200205070ustar00rootroot00000000000000-- Does the f16.asinh function work? -- == -- input { [0f16, -0.84147096f16, -8.742278e-8f16, 8.742278e-8f16] } -- output { [0f16, -0.7647251350294384f16, -8.742277999999989e-08f16, 8.742277999999989e-08f16] } def main = map f16.asinh futhark-0.25.27/tests/primitive/asinh32.fut000066400000000000000000000003541475065116200205050ustar00rootroot00000000000000-- Does the f32.asinh function work? 
-- == -- input { [0f32, -0.84147096f32, -8.742278e-8f32, 8.742278e-8f32] } -- output { [0f32, -0.7647251350294384f32, -8.742277999999989e-08f32, 8.742277999999989e-08f32] } def main = map f32.asinh futhark-0.25.27/tests/primitive/asinh64.fut000066400000000000000000000003541475065116200205120ustar00rootroot00000000000000-- Does the f64.asinh function work? -- == -- input { [0f64, -0.84147096f64, -8.742278e-8f64, 8.742278e-8f64] } -- output { [0f64, -0.7647251350294384f64, -8.742277999999989e-08f64, 8.742277999999989e-08f64] } def main = map f64.asinh futhark-0.25.27/tests/primitive/atan2_32.fut000066400000000000000000000004211475065116200205420ustar00rootroot00000000000000-- Does the atan2_32 function work? -- == -- input { [0f32, 1f32, 1f32, 0f32, -1f32, 1f32, -1f32] [0f32, 0f32, 1f32, 1f32, 1f32, -1f32, -1f32] } -- output { [0f32, 1.570796f32, 0.785398f32, 0.000000f32, -0.785398f32, 2.356194f32, -2.356194f32] } def main = map2 f32.atan2 futhark-0.25.27/tests/primitive/atan2_64.fut000066400000000000000000000004211475065116200205470ustar00rootroot00000000000000-- Does the atan2_64 function work? -- == -- input { [0f64, 1f64, 1f64, 0f64, -1f64, 1f64, -1f64] [0f64, 0f64, 1f64, 1f64, 1f64, -1f64, -1f64] } -- output { [0f64, 1.570796f64, 0.785398f64, 0.000000f64, -0.785398f64, 2.356194f64, -2.356194f64] } def main = map2 f64.atan2 futhark-0.25.27/tests/primitive/atan32.fut000066400000000000000000000002261475065116200203240ustar00rootroot00000000000000-- Does the atan32 function work? -- == -- input { [0f32, 1f32, -1f32] } -- output { [0f32, 0.78539819f32, -0.78539819f32] } def main = map f32.atan futhark-0.25.27/tests/primitive/atan64.fut000066400000000000000000000002261475065116200203310ustar00rootroot00000000000000-- Does the atan32 function work? 
-- == -- input { [0f64, 1f64, -1f64] } -- output { [0f64, 0.78539819f64, -0.78539819f64] } def main = map f64.atan futhark-0.25.27/tests/primitive/atanh16.fut000066400000000000000000000002551475065116200205000ustar00rootroot00000000000000-- Does the f16.atanh function work? -- == -- input { [0f16, 0.5f16, 1f16, -1f16] } -- output { [0f16, 0.5493061443340548f16, f16.inf, -f16.inf] } def main = map f16.atanh futhark-0.25.27/tests/primitive/atanh32.fut000066400000000000000000000002551475065116200204760ustar00rootroot00000000000000-- Does the f32.atanh function work? -- == -- input { [0f32, 0.5f32, 1f32, -1f32] } -- output { [0f32, 0.5493061443340548f32, f32.inf, -f32.inf] } def main = map f32.atanh futhark-0.25.27/tests/primitive/atanh64.fut000066400000000000000000000002551475065116200205030ustar00rootroot00000000000000-- Does the f64.atanh function work? -- == -- input { [0f64, 0.5f64, 1f64, -1f64] } -- output { [0f64, 0.5493061443340548f64, f64.inf, -f64.inf] } def main = map f64.atanh futhark-0.25.27/tests/primitive/bool_cmpop.fut000066400000000000000000000016251475065116200213710ustar00rootroot00000000000000-- Test comparison of boolean values. 
-- == -- entry: lt -- input { [false, false, true, true ] [false, true, false, true] } -- output { [false, true, false, false] } -- == -- entry: gt -- input { [false, false, true, true ] [false, true, false, true] } -- output { [false, false, true, false] } -- == -- entry: eq -- input { [false, false, true, true ] [false, true, false, true] } -- output { [true, false, false, true] } -- == -- entry: lte -- input { [false, false, true, true ] [false, true, false, true] } -- output { [true, true, false, true] } -- == -- entry: gte -- input { [false, false, true, true ] [false, true, false, true] } -- output { [true, false, true, true] } entry lt (x:[]bool) (y:[]bool)= map2 (<) x y entry gt (x:[]bool) (y:[]bool)= map2 (>) x y entry eq (x:[]bool) (y:[]bool)= map2 (==) x y entry lte (x:[]bool) (y:[]bool)= map2 (<=) x y entry gte (x:[]bool) (y:[]bool)= map2 (>=) x y futhark-0.25.27/tests/primitive/bool_convop.fut000066400000000000000000000016011475065116200215510ustar00rootroot00000000000000-- Convert booleans to different types. 
-- == -- entry: castB -- input { [false, true] } -- output { [false, true] } -- == -- entry: castI8 -- input { [false, true] } -- output { [0i8, 1i8] } -- == -- entry: castI16 -- input { [false, true] } -- output { [0i16, 1i16] } -- == -- entry: castI32 -- input { [false, true] } -- output { [0i32, 1i32] } -- == -- entry: castI64 -- input { [false, true] } -- output { [0i64, 1i64] } -- == -- entry: castF16 -- input { [false, true] } -- output { [0f16, 1f16] } -- == -- entry: castF32 -- input { [false, true] } -- output { [0f32, 1f32] } -- == -- entry: castF64 -- input { [false, true] } -- output { [0f64, 1f64] } entry castB = map bool.bool entry castI8 = map i8.bool entry castI16 = map i16.bool entry castI32 = map i32.bool entry castI64 = map i64.bool entry castF16 = map f16.bool entry castF32 = map f32.bool entry castF64 = map f64.bool futhark-0.25.27/tests/primitive/cbrt.fut000066400000000000000000000004731475065116200201720ustar00rootroot00000000000000-- == -- entry: cbrt64 -- input { [0f64, 8f64] } -- output { [0f64, 2f64] } -- == -- entry: cbrt32 -- input { [0f32, 8f32] } -- output { [0f32, 2f32] } -- == -- entry: cbrt16 -- input { [0f16, 8f16] } -- output { [0f16, 2f16] } entry cbrt64 = map f64.cbrt entry cbrt32 = map f32.cbrt entry cbrt16 = map f16.cbrt futhark-0.25.27/tests/primitive/ceil32.fut000066400000000000000000000005321475065116200203150ustar00rootroot00000000000000-- Rounding up floats to whole numbers. -- == -- input { [1.0000001192092896f32, -0.9999999403953552f32, -0f32, 0.49999999999999994f32, 0.5f32, 0.5000000000000001f32, -- 1.18e-38f32, -f32.inf, f32.inf, f32.nan, -0f32] } -- output { [2f32, 0f32, -0f32, 1f32, 1f32, 1f32, 1f32, -f32.inf, f32.inf, f32.nan, -0f32] } def main = map f32.ceil futhark-0.25.27/tests/primitive/ceil64.fut000066400000000000000000000010231475065116200203160ustar00rootroot00000000000000-- Rounding floats to whole numbers. 
-- == -- input { [1.0000000000000002f64, -0.9999999999999999f64, -0.5000000000000001f64, -0f64, 0.49999999999999994f64, 0.5f64, 0.5000000000000001f64, -- 1.390671161567e-309f64, 2.2517998136852485e+15f64, 4.503599627370497e+15f64, -- -f64.inf, f64.inf, f64.nan, -0f64] } -- output { [2f64, 0f64, -0f64, -0f64, 1f64, 1f64, 1f64, -- 1f64, 2.251799813685249e+15f64, 4.503599627370497e+15f64, -- -f64.inf, f64.inf, f64.nan, -0f64] } def main = map f64.ceil futhark-0.25.27/tests/primitive/clz.fut000066400000000000000000000030011475065116200200160ustar00rootroot00000000000000-- == -- entry: clzi8 -- input { [0u64, 255u64, 65535u64, 4294967295u64, 18446744073709551615u64] } output { [8i32, 0i32, 0i32, 0i32, 0i32] } -- == -- entry: clzi16 -- input { [0u64, 255u64, 65535u64, 4294967295u64, 18446744073709551615u64] } output { [16i32, 8i32, 0i32, 0i32, 0i32] } -- == -- entry: clzi32 -- input { [0u64, 255u64, 65535u64, 4294967295u64, 18446744073709551615u64] } output { [32i32, 24i32, 16i32, 0i32, 0i32] } -- == -- entry: clzi64 -- input { [0u64, 255u64, 65535u64, 4294967295u64, 18446744073709551615u64] } output { [64i32, 56i32, 48i32, 32i32, 0i32] } -- == -- entry: clzu8 -- input { [0u64, 255u64, 65535u64, 4294967295u64, 18446744073709551615u64] } output { [8i32, 0i32, 0i32, 0i32, 0i32] } -- == -- entry: clzu16 -- input { [0u64, 255u64, 65535u64, 4294967295u64, 18446744073709551615u64] } output { [16i32, 8i32, 0i32, 0i32, 0i32] } -- == -- entry: clzu32 -- input { [0u64, 255u64, 65535u64, 4294967295u64, 18446744073709551615u64] } output { [32i32, 24i32, 16i32, 0i32, 0i32] } -- == -- entry: clzu64 -- input { [0u64, 255u64, 65535u64, 4294967295u64, 18446744073709551615u64] } output { [64i32, 56i32, 48i32, 32i32, 0i32] } entry clzi8 = map (\x -> i8.clz (i8.u64 x)) entry clzi16 = map (\x -> i16.clz (i16.u64 x)) entry clzi32 = map (\x -> i32.clz (i32.u64 x)) entry clzi64 = map (\x -> i64.clz (i64.u64 x)) entry clzu8 = map (\x -> u8.clz (u8.u64 x)) entry clzu16 = map (\x -> u16.clz 
(u16.u64 x)) entry clzu32 = map (\x -> u32.clz (u32.u64 x)) entry clzu64 = map (\x -> u64.clz (u64.u64 x)) futhark-0.25.27/tests/primitive/copysign.fut000066400000000000000000000011301475065116200210620ustar00rootroot00000000000000-- == -- entry: test_f16 -- input { [1f16, 1f16, f16.nan, f16.inf] -- [2f16, -2f16, -1f16, -1f16] } -- output { [1f16, -1f16, f16.nan, -f16.inf] } -- == -- entry: test_f32 -- input { [1f32, 1f32, f32.nan, f32.inf] -- [2f32, -2f32, -1f32, -1f32] } -- output { [1f32, -1f32, f32.nan, -f32.inf] } -- == -- entry: test_f64 -- input { [1f64, 1f64, f64.nan, f64.inf] -- [2f64, -2f64, -1f64, -1f64] } -- output { [1f64, -1f64, f64.nan, -f64.inf] } entry test_f16 = map2 f16.copysign entry test_f32 = map2 f32.copysign entry test_f64 = map2 f64.copysign futhark-0.25.27/tests/primitive/cos32.fut000066400000000000000000000002501475065116200201620ustar00rootroot00000000000000-- Does the cos32 function work? -- == -- input { [0f32, -1f32, 3.1415927f32, -3.1415927f32] } -- output { [1f32, 0.5403023f32, -1f32, -1f32] } def main = map f32.cos futhark-0.25.27/tests/primitive/cos64.fut000066400000000000000000000002321475065116200201670ustar00rootroot00000000000000-- Does the cos64 function work? -- == -- input { [0.0, -1.0, 3.1415927, -3.1415927] } -- output { [1.0, 0.5403023, -1.0, -1.0] } def main = map f64.cos futhark-0.25.27/tests/primitive/cosh16.fut000066400000000000000000000003301475065116200203330ustar00rootroot00000000000000-- Does the f16.cosh function work? -- == -- input { [0f16, -1f16, 3.1415927f16, -3.1415927f16] } -- output { [1.0f16, 1.5430806348152437f16, 11.591951675521519f16, 11.591951675521519f16] } def main = map f16.cosh futhark-0.25.27/tests/primitive/cosh32.fut000066400000000000000000000003301475065116200203310ustar00rootroot00000000000000-- Does the f32.cosh function work? 
-- == -- input { [0f32, -1f32, 3.1415927f32, -3.1415927f32] } -- output { [1.0f32, 1.5430806348152437f32, 11.591953275521519f32, 11.591953275521519f32] } def main = map f32.cosh futhark-0.25.27/tests/primitive/cosh64.fut000066400000000000000000000003301475065116200203360ustar00rootroot00000000000000-- Does the f64.cosh function work? -- == -- input { [0f64, -1f64, 3.1415927f64, -3.1415927f64] } -- output { [1.0f64, 1.5430806348152437f64, 11.591953275521519f64, 11.591953275521519f64] } def main = map f64.cosh futhark-0.25.27/tests/primitive/ctz.fut000066400000000000000000000007431475065116200200400ustar00rootroot00000000000000-- == -- entry: ctzi8 -- input { [0i8, 255i8, 128i8] } output { [8, 0, 7] } -- == -- entry: ctzi16 -- input { [0i16, 65535i16, 32768i16] } output { [16, 0, 15] } -- == -- entry: ctzi32 -- input { [0i32, 4294967295i32, 2147483648i32] } output { [32, 0, 31] } -- == -- entry: ctzi64 -- input { [0i64, 18446744073709551615i64, 9223372036854775808i64] } output { [64, 0, 63] } entry ctzi8 = map i8.ctz entry ctzi16 = map i16.ctz entry ctzi32 = map i32.ctz entry ctzi64 = map i64.ctz futhark-0.25.27/tests/primitive/erf.fut000066400000000000000000000005451475065116200200140ustar00rootroot00000000000000-- == -- entry: erf64 -- input { [0f64, 1f64] } -- output { [0f64, 0.8427007929497149f64] } -- == -- entry: erf32 -- input { [0f32, 1f32] } -- output { [0f32, 0.8427007929497149f32] } -- == -- entry: erf16 -- input { [0f16, 1f16] } -- output { [0f16, 0.8427007929497149f16] } entry erf64 = map f64.erf entry erf32 = map f32.erf entry erf16 = map f16.erf futhark-0.25.27/tests/primitive/erfc.fut000066400000000000000000000005611475065116200201550ustar00rootroot00000000000000-- == -- entry: erfc64 -- input { [0f64, 1f64] } -- output { [1f64, 0.15729920705028513f64] } -- == -- entry: erfc32 -- input { [0f32, 1f32] } -- output { [1f32, 0.15729920705028513f32] } -- == -- entry: erfc16 -- input { [0f16, 1f16] } -- output { [1f16, 0.15729920705028513f16] } entry erfc64 
= map f64.erfc entry erfc32 = map f32.erfc entry erfc16 = map f16.erfc futhark-0.25.27/tests/primitive/f16.fut000066400000000000000000000020351475065116200176300ustar00rootroot00000000000000-- Test ad-hoc properties and utility functions for f16. -- == -- entry: testInf -- input { [1f16, -1f16, -1f16] [0f16, 0f16, 1f16] } -- output { [true, true, false] } -- == -- entry: testNaN -- input { [1f16, -1f16, -1f16] } -- output { [false, true, true] } -- == -- entry: testToBits -- input { [1f16, -1f16, -1f16] } -- output { [0x3c00u16, 0xbc00u16, 0xbc00u16] } -- == -- entry: testFromBits -- input { [1f16, -1f16, -1f16] } -- output { [1f16, -1f16, -1f16] } -- == -- entry: testNeg -- input { [1f16, f16.inf, -f16.inf, f16.nan] } -- output { [-1f16, -f16.inf, f16.inf, f16.nan] } -- == -- entry: testNegBits -- input { [0u16, 0x8000u16] } -- output { [0x8000u16, 0u16] } entry testInf (xs: []f16) (ys: []f16) = map2 (\x y -> f16.isinf(x/y)) xs ys entry testNaN (xs: []f16) = map (\x -> f16.isnan(f16.sqrt(x))) xs entry testToBits (xs: []f16) = map f16.to_bits xs entry testFromBits (xs: []f16) = map (\x -> f16.from_bits(f16.to_bits(x))) xs entry testNeg = map f16.neg entry testNegBits = map (f16.from_bits >-> f16.neg >-> f16.to_bits) futhark-0.25.27/tests/primitive/f16_binop.fut000066400000000000000000000023621475065116200210220ustar00rootroot00000000000000entry add = map2 (f16.+) entry sub = map2 (f16.-) entry mul = map2 (f16.*) entry div = map2 (f16./) entry mod = map2 (f16.%) entry pow = map2 (f16.**) -- == -- entry: add -- input { [0.0f16, 1.0f16, 1.0f16, -1.0f16, 3.402823e38f16, 0f16, 0f16] -- [0.0f16, 0.0f16, 0.0f16, 0.0f16, 10f16, f16.nan, f16.inf] } -- output { [0.0f16, 1.0f16, 1.0f16, -1.0f16, f16.inf, f16.nan, f16.inf] } -- == -- entry: sub -- input { [0.0f16, 0.0f16, 0.0f16, -3.402823e38f16] -- [0.0f16, 1.0f16, -1.0f16, 10f16] } -- output { [0.0f16, -1.0f16, 1.0f16, -f16.inf] } -- == -- entry: mul -- input { [0.0f16, 0.0f16, 0.0f16, 1.0f16, 2.0f16] -- [0.0f16, 1.0f16, 
-1.0f16, -1.0f16, 1.5f16] } -- output { [0.0f16, 0.0f16, 0.0f16, -1.0f16, 3.0f16] } -- == -- entry: div -- input { [0.0f16, 0.0f16, 1.0f16, 2.0f16] -- [1.0f16, -1.0f16, -1.0f16, 1.5f16] } -- output { [0.0f16, 0.0f16, -1.0f16, 1.3333333333333f16] } -- == -- entry: mod -- input { [0.0f16, 0.0f16, 1.0f16, 2.0f16] -- [1.0f16, -1.0f16, -1.0f16, 1.5f16] } -- == -- entry: pow -- input { [0.0f16, 1.0f16, 2.0f16, 2.0f16] -- [1.0f16, -1.0f16, 1.5f16, 0f16] } -- output { [0.0f16, 1.0f16, 2.8284271247461903f16, 1f16] } futhark-0.25.27/tests/primitive/f16_minmax.fut000066400000000000000000000011131475065116200211750ustar00rootroot00000000000000-- == -- entry: testMax -- input { [0f16, 1f16, -1f16, 1f16, f16.nan, -1f16, f16.nan, -1f16, -1f16 ] -- [1f16, 1f16, 1f16, -1f16, -1f16, f16.nan, f16.nan, f16.inf, -f16.inf] } -- output { [1f16, 1f16, 1f16, 1f16, -1f16, -1f16, f16.nan, f16.inf, -1f16] } -- == -- entry: testMin -- input { [0f16, 1f16, -1f16, 1f16, f16.nan, -1f16, f16.nan, -1f16, -1f16 ] -- [1f16, 1f16, 1f16, -1f16, -1f16, f16.nan, f16.nan, f16.inf, -f16.inf] } -- output { [0f16, 1f16, -1f16, -1f16, -1f16, -1f16, f16.nan, -1f16, -f16.inf] } entry testMax = map2 f16.max entry testMin = map2 f16.min futhark-0.25.27/tests/primitive/f32.fut000066400000000000000000000020611475065116200176250ustar00rootroot00000000000000-- Test ad-hoc properties and utility functions for f32. 
-- == -- entry: testInf -- input { [1f32, -1f32, -1f32] [0f32, 0f32, 1f32] } -- output { [true, true, false] } -- == -- entry: testNaN -- input { [1f32, -1f32, -1f32] } -- output { [false, true, true] } -- == -- entry: testToBits -- input { [1f32, -1f32, -1f32] } -- output { [0x3f800000u32, 0xbf800000u32, 0xbf800000u32] } -- == -- entry: testFromBits -- input { [1f32, -1f32, -1f32] } -- output { [1f32, -1f32, -1f32] } -- == -- entry: testNeg -- input { [1f32, f32.inf, -f32.inf, f32.nan] } -- output { [-1f32, -f32.inf, f32.inf, f32.nan] } -- == -- entry: testNegBits -- input { [0u32, 0x80000000u32] } -- output { [0x80000000u32, 0u32] } entry testInf (xs: []f32) (ys: []f32) = map2 (\x y -> f32.isinf(x/y)) xs ys entry testNaN (xs: []f32) = map (\x -> f32.isnan(f32.sqrt(x))) xs entry testToBits (xs: []f32) = map f32.to_bits xs entry testFromBits (xs: []f32) = map (\x -> f32.from_bits(f32.to_bits(x))) xs entry testNeg = map f32.neg entry testNegBits = map (f32.from_bits >-> f32.neg >-> f32.to_bits) futhark-0.25.27/tests/primitive/f32_binop.fut000066400000000000000000000027541475065116200210250ustar00rootroot00000000000000-- f32 test. Does not test for infinity/NaN as we have no way of writing -- that in Futhark yet. Does test for overflow. 
entry add = map2 (f32.+) entry sub = map2 (f32.-) entry mul = map2 (f32.*) entry div = map2 (f32./) entry mod = map2 (f32.%) entry pow = map2 (f32.**) -- == -- entry: add -- input { [0.0f32, 1.0f32, 1.0f32, -1.0f32, 3.402823e38f32, 0f32, 0f32] -- [0.0f32, 0.0f32, 0.0f32, 0.0f32, 10f32, f32.nan, f32.inf] } -- output { [0.0f32, 1.0f32, 1.0f32, -1.0f32, 340282306073709652508363335590014353408.000000f32, f32.nan, f32.inf] } -- == -- entry: sub -- input { [0.0f32, 0.0f32, 0.0f32, -3.402823e38f32] -- [0.0f32, 1.0f32, -1.0f32, 10f32] } -- output { [0.0f32, -1.0f32, 1.0f32, -340282306073709652508363335590014353408.000000f32] } -- == -- entry: mul -- input { [0.0f32, 0.0f32, 0.0f32, 1.0f32, 2.0f32] -- [0.0f32, 1.0f32, -1.0f32, -1.0f32, 1.5f32] } -- output { [0.0f32, 0.0f32, 0.0f32, -1.0f32, 3.0f32] } -- == -- entry: div -- input { [0.0f32, 0.0f32, 1.0f32, 2.0f32] -- [1.0f32, -1.0f32, -1.0f32, 1.5f32] } -- output { [0.0f32, 0.0f32, -1.0f32, 1.3333333333333f32] } -- == -- entry: mod -- input { [0.0f32, 0.0f32, 1.0f32, 2.0f32] -- [1.0f32, -1.0f32, -1.0f32, 1.5f32] } -- output { [0.0f32, -0.0f32, -0.0f32, 0.5f32] } -- == -- entry: pow -- input { [0.0f32, 1.0f32, 2.0f32, 2.0f32] -- [1.0f32, -1.0f32, 1.5f32, 0f32] } -- output { [0.0f32, 1.0f32, 2.8284271247461903f32, 1f32] } futhark-0.25.27/tests/primitive/f32_minmax.fut000066400000000000000000000011121475065116200211720ustar00rootroot00000000000000-- == -- entry: testMax -- input { [0f32, 1f32, -1f32, 1f32, f32.nan, -1f32, f32.nan, -1f32, -1f32 ] -- [1f32, 1f32, 1f32, -1f32, -1f32, f32.nan, f32.nan, f32.inf, -f32.inf] } -- output { [1f32, 1f32, 1f32, 1f32, -1f32, -1f32, f32.nan, f32.inf, -1f32] } -- == -- entry: testMin -- input { [0f32, 1f32, -1f32, 1f32, f32.nan, -1f32, f32.nan, -1f32, -1f32 ] -- [1f32, 1f32, 1f32, -1f32, -1f32, f32.nan, f32.nan, f32.inf, -f32.inf] } -- output { [0f32, 1f32, -1f32, -1f32, -1f32, -1f32, f32.nan, -1f32, -f32.inf] } entry testMax = map2 f32.max entry testMin = map2 
f32.minfuthark-0.25.27/tests/primitive/f64.fut000066400000000000000000000021321475065116200176310ustar00rootroot00000000000000-- Test ad-hoc properties and utility functions for f64. -- == -- entry: testInf -- input { [1f64, -1f64, -1f64] [0f64, 0f64, 1f64] } -- output { [true, true, false] } -- == -- entry: testNaN -- input { [1f64, -1f64, -1f64] } -- output { [false, true, true] } -- == -- entry: testToBits -- input { [1f64, -1f64, -1f64] } -- output { [0x3ff0000000000000u64, 0xbff0000000000000u64, 0xbff0000000000000u64] } -- == -- entry: testFromBits -- input { [1f64, -1f64, -1f64] } -- output { [1f64, -1f64, -1f64] } -- == -- entry: testNeg -- input { [1f64, f64.inf, -f64.inf, f64.nan] } -- output { [-1f64, -f64.inf, f64.inf, f64.nan] } -- == -- entry: testNegBits -- input { [0u64, 0x8000000000000000u64] } -- output { [0x8000000000000000u64, 0u64] } entry testInf (xs: []f64) (ys: []f64) = map2 (\x y -> f64.isinf(x/y)) xs ys entry testNaN (xs: []f64) = map (\x -> f64.isnan(f64.sqrt(x))) xs entry testToBits (xs: []f64) = map f64.to_bits xs entry testFromBits (xs: []f64) = map (\x -> f64.from_bits(f64.to_bits(x))) xs entry testNeg = map f64.neg entry testNegBits = map (f64.from_bits >-> f64.neg >-> f64.to_bits) futhark-0.25.27/tests/primitive/f64_binop.fut000066400000000000000000000025461475065116200210310ustar00rootroot00000000000000-- f64 test. Does not test for infinity/NaN as we have no way of writing -- that in Futhark yet. Does test for overflow. 
entry add = map2 (f64.+) entry sub = map2 (f64.-) entry mul = map2 (f64.*) entry div = map2 (f64./) entry mod = map2 (f64.%) entry pow = map2 (f64.**) -- == -- entry: add -- input { [0.0f64, 1.0f64, 1.0f64, -1.0f64, 1.79769e308, 0f64, 0f64] -- [0.0f64, 0.0f64, 0.0f64, 0.0f64, 10f64, f64.nan, f64.inf] } -- output { [0.0f64, 1.0f64, 1.0f64, -1.0f64, 1.79769e308, f64.nan, f64.inf] } -- == -- entry: sub -- input { [0.0f64, 0.0f64, 0.0f64, -1.79769e308] -- [0.0f64, 1.0f64, -1.0f64, 10f64] } -- output { [0.0f64, -1.0f64, 1.0f64, -1.79769e308] } -- == -- entry: mul -- input { [0.0f64, 0.0f64, 0.0f64, 1.0f64, 2.0f64] -- [0.0f64, 1.0f64, -1.0f64, -1.0f64, 1.5f64] } -- output { [0.0f64, 0.0f64, 0.0f64, -1.0f64, 3.0f64] } -- == -- entry: div -- input { [0.0f64, 0.0f64, 1.0f64, 2.0f64] -- [1.0f64, -1.0f64, -1.0f64, 1.5f64] } -- output { [0.0f64, 0.0f64, -1.0f64, 1.3333333333333f64] } -- == -- entry: mod -- input { [0.0f64, 0.0f64, 1.0f64, 2.0f64] -- [1.0f64, -1.0f64, -1.0f64, 1.5f64] } -- == -- entry: pow -- input { [0.0f64, 1.0f64, 2.0f64, 2.0f64] -- [1.0f64, -1.0f64, 1.5f64, 0f64] } -- output { [0.0f64, 1.0f64, 2.8284271247461903f64, 1f64] } futhark-0.25.27/tests/primitive/f64_minmax.fut000066400000000000000000000011121475065116200211770ustar00rootroot00000000000000-- == -- entry: testMax -- input { [0f64, 1f64, -1f64, 1f64, f64.nan, -1f64, f64.nan, -1f64, -1f64 ] -- [1f64, 1f64, 1f64, -1f64, -1f64, f64.nan, f64.nan, f64.inf, -f64.inf] } -- output { [1f64, 1f64, 1f64, 1f64, -1f64, -1f64, f64.nan, f64.inf, -1f64] } -- == -- entry: testMin -- input { [0f64, 1f64, -1f64, 1f64, f64.nan, -1f64, f64.nan, -1f64, -1f64 ] -- [1f64, 1f64, 1f64, -1f64, -1f64, f64.nan, f64.nan, f64.inf, -f64.inf] } -- output { [0f64, 1f64, -1f64, -1f64, -1f64, -1f64, f64.nan, -1f64, -f64.inf] } entry testMax = map2 f64.max entry testMin = map2 f64.minfuthark-0.25.27/tests/primitive/fabs.fut000066400000000000000000000006661475065116200201570ustar00rootroot00000000000000-- == -- entry: test_f16 -- input { 
[1f16, -1f16, -f16.inf, f16.nan] } -- output { [1f16, 1f16, f16.inf, f16.nan] } -- == -- entry: test_f32 -- input { [1f32, -1f32, -f32.inf, f32.nan] } -- output { [1f32, 1f32, f32.inf, f32.nan] } -- == -- entry: test_f64 -- input { [1f64, -1f64, -f64.inf, f64.nan] } -- output { [1f64, 1f64, f64.inf, f64.nan] } entry test_f16 = map f16.abs entry test_f32 = map f32.abs entry test_f64 = map f64.abs futhark-0.25.27/tests/primitive/float_convop.fut000066400000000000000000000060521475065116200217300ustar00rootroot00000000000000-- Converting back and forth between different float- and integer types. entry f16_to_f32 = map f32.f16 entry f16_to_f64 = map f64.f16 entry f32_to_f16 = map f16.f32 entry f32_to_f64 = map f64.f32 entry f64_to_f16 = map f16.f64 entry f64_to_f32 = map f32.f64 entry u16_to_f16 = map f16.from_bits entry u32_to_f32 = map f32.from_bits entry u64_to_f64 = map f64.from_bits entry f32_to_i32 = map i32.f32 entry f32_to_u32 = map u32.f32 entry f64_to_i32 = map i32.f64 entry f64_to_u32 = map u32.f64 entry f16_to_bool = map bool.f16 entry f32_to_bool = map bool.f32 entry f64_to_bool = map bool.f64 -- == -- entry: f16_to_f32 -- input { [f16.inf, -f16.inf, f16.nan, -1f16, 1f16] } -- output { [f32.inf, -f32.inf, f32.nan, -1.0f32, 1.0f32] } -- == -- entry: f16_to_f64 -- input { [f16.inf, -f16.inf, f16.nan, -1f16, 1f16] } -- output { [f64.inf, -f64.inf, f64.nan, -1.0f64, 1.0f64] } -- == -- entry: f32_to_f16 -- input { [f32.inf, -f32.inf, -1f32, 1f32, 100000f32, -100000f32] } -- output { [f16.inf, -f16.inf, -1.0f16, 1.0f16, f16.inf, -f16.inf] } -- == -- entry: f32_to_f64 -- input { [f32.inf, -f32.inf, -1f32, 1f32, 100000f32, -100000f32] } -- output { [f64.inf, -f64.inf, -1.0f64, 1.0f64, 100000.0f64, -100000.0f64] } -- == -- entry: f64_to_f16 -- input { [f64.inf, -f64.inf, -1f64, 1f64, 3.5028234664e38f64, -3.5028234664e38f64] } -- output { [f16.inf, -f16.inf, -1.0f16, 1.0f16, f16.inf, -f16.inf] } -- == -- entry: f64_to_f32 -- input { [f64.inf, -f64.inf, -1f64, 1f64, 
3.5028234664e38f64, -3.5028234664e38f64] } -- output { [f32.inf, -f32.inf, -1.0f32, 1.0f32, f32.inf, -f32.inf] } -- == -- entry: u16_to_f16 -- input { [31744u16, 65024u16, 48128u16, 15360u16] } -- output { [f16.inf, f16.nan, -1.0f16, 1.0f16] } -- == -- entry: u32_to_f32 -- input { [2139095040u32, 4290772992u32, 3212836864u32, 1065353216u32] } -- output { [f32.inf, f32.nan, -1f32, 1f32] } -- == -- entry: u64_to_f64 -- input { [9218868437227405312u64, 18444492273895866368u64, 13830554455654793216u64, 4607182418800017408u64] } -- output { [f64.inf, f64.nan, -1.0f64, 1.0f64] } -- == -- entry: f32_to_i32 -- input { [f32.nan, f32.inf, -f32.inf, -1f32, 1f32, 3.5f32, -3.5f32] } -- output { [0, 0, 0, -1, 1, 3, -3] } -- == -- entry: f32_to_u32 -- input { [f32.nan, f32.inf, -f32.inf, -1f32, 1f32, 3.5f32, -3.5f32] } -- output { [0u32, 0u32, 0u32, 4294967295u32, 1u32, 3u32, 4294967293u32] } -- == -- entry: f64_to_i32 -- input { [f64.nan, f64.inf, -f64.inf, -1f64, 1f64, 3.5f64, -3.5f64] } -- output { [0, 0, 0, -1, 1, 3, -3] } -- == -- entry: f64_to_u32 -- input { [f64.nan, f64.inf, -f64.inf, -1f64, 1f64, 3.5f64, -3.5f64] } -- output { [0u32, 0u32, 0u32, 4294967295u32, 1u32, 3u32, 4294967293u32] } -- == -- entry: f16_to_bool -- input { [f16.nan, f16.inf, -f16.inf, -1f16, 1f16, 0f16] } -- output { [true, true, true, true, true, false] } -- == -- entry: f32_to_bool -- input { [f32.nan, f32.inf, -f32.inf, -1f32, 1f32, 0f32] } -- output { [true, true, true, true, true, false] } -- == -- entry: f64_to_bool -- input { [f64.nan, f64.inf, -f64.inf, -1f64, 1f64, 0f64] } -- output { [true, true, true, true, true, false] } futhark-0.25.27/tests/primitive/floor32.fut000066400000000000000000000005741475065116200205300ustar00rootroot00000000000000-- Rounding down floats to whole numbers. 
-- == -- input { [1.0000001192092896f32, -0.9999999403953552f32, 0f32, -0f32, 0.49999999999999994f32, 0.5f32, 0.5000000000000001f32] } -- output { [1f32, -1f32, 0f32, -0f32, 0f32, 0f32, 0f32] } -- input { [1.18e-38f32, -f32.inf, f32.inf, f32.nan, -0f32] } -- output { [0f32, -f32.inf, f32.inf, f32.nan, -0f32] } def main = map f32.floor futhark-0.25.27/tests/primitive/floor64.fut000066400000000000000000000007471475065116200205370ustar00rootroot00000000000000-- Rounding down to whole numbers. -- == -- input { [1.0000000000000002f64, -0.9999999999999999f64, -0.5000000000000001f64, 0f64, -0f64, 0.49999999999999994f64, 0.5f64, -- 1.390671161567e-309f64, 2.2517998136852485e+15f64, 4.503599627370497e+15f64, -f64.inf, f64.inf, f64.nan, -0f64] } -- output { [1f64, -1f64, -1f64, 0f64, -0f64, 0f64, 0f64, 0f64, 2.251799813685249e+15f64, 4.503599627370497e+15f64, -- -f64.inf, f64.inf, f64.nan, -0f64] } def main = map f64.floor futhark-0.25.27/tests/primitive/fma_mad16.fut000066400000000000000000000006351475065116200207730ustar00rootroot00000000000000-- Test f16.(fma,mad). The test values here are very crude and do not -- actually test the numerical properties we hope to get. -- == -- input { [1f16, 2f16, 3f16 ] -- [3f16, 2f16, 1f16 ] -- [2f16, 3f16, 1f16 ] -- } -- output { [5f16, 7f16, 4f16] -- [5f16, 7f16, 4f16] -- } def main (as: []f16) (bs: []f16) (cs: []f16) = (map3 f16.fma as bs cs, map3 f16.mad as bs cs) futhark-0.25.27/tests/primitive/fma_mad32.fut000066400000000000000000000006351475065116200207710ustar00rootroot00000000000000-- Test f32.(fma,mad). The test values here are very crude and do not -- actually test the numerical properties we hope to get. 
-- == -- input { [1f32, 2f32, 3f32 ] -- [3f32, 2f32, 1f32 ] -- [2f32, 3f32, 1f32 ] -- } -- output { [5f32, 7f32, 4f32] -- [5f32, 7f32, 4f32] -- } def main (as: []f32) (bs: []f32) (cs: []f32) = (map3 f32.fma as bs cs, map3 f32.mad as bs cs) futhark-0.25.27/tests/primitive/fma_mad64.fut000066400000000000000000000006351475065116200207760ustar00rootroot00000000000000-- Test f64.(fma,mad). The test values here are very crude and do not -- actually test the numerical properties we hope to get. -- == -- input { [1f64, 2f64, 3f64 ] -- [3f64, 2f64, 1f64 ] -- [2f64, 3f64, 1f64 ] -- } -- output { [5f64, 7f64, 4f64] -- [5f64, 7f64, 4f64] -- } def main (as: []f64) (bs: []f64) (cs: []f64) = (map3 f64.fma as bs cs, map3 f64.mad as bs cs) futhark-0.25.27/tests/primitive/fsignum.fut000066400000000000000000000005351475065116200207070ustar00rootroot00000000000000-- == -- entry: test_f32 -- input { [0f32, 10f32, -10f32, f32.inf, -f32.inf, f32.nan] } -- output { [0f32, 1f32, -1f32, 1f32, -1f32, f32.nan] } entry test_f32 = map f32.sgn -- == -- entry: test_f64 -- input { [0f64, 10f64, -10f64, f64.inf, -f64.inf, f64.nan] } -- output { [0f64, 1f64, -1f64, 1f64, -1f64, f64.nan] } entry test_f64 = map f64.sgn futhark-0.25.27/tests/primitive/gamma.fut000066400000000000000000000003261475065116200203170ustar00rootroot00000000000000-- == -- entry: gamma64 -- input { [1.0, 4.0] } -- output { [1f64, 6f64] } -- == -- entry: gamma32 -- input { [1f32, 4f32] } -- output { [1f32, 6f32] } entry gamma64 = map f64.gamma entry gamma32 = map f32.gamma futhark-0.25.27/tests/primitive/hypot.fut000066400000000000000000000007651475065116200204070ustar00rootroot00000000000000-- == -- entry: hypotf16 -- input { [0f16, 3f16, 1.8446744e19f16] [0f16, 4f16, 0f16] } output { [0f16, 5f16, 1.8446744e19f16] } -- == -- entry: hypotf32 -- input { [0f32, 3f32, 1.8446744e19f32] [0f32, 4f32, 0f32] } output { [0f32, 5f32, 1.8446744e19f32] } -- == -- entry: hypotf64 -- input { [0f64, 3f64, 4.149515568880993e180f64] [0f64, 4f64, 
0f64] } output { [0f64, 5f64, 4.149515568880993e180f64] } entry hypotf16 = map2 f16.hypot entry hypotf32 = map2 f32.hypot entry hypotf64 = map2 f64.hypot futhark-0.25.27/tests/primitive/i16_binop.fut000066400000000000000000000014741475065116200210300ustar00rootroot00000000000000-- i16 test. entry add = map2 (i16.+) entry sub = map2 (i16.-) entry mul = map2 (i16.*) entry pow = map2 (i16.**) -- == -- entry: add -- input { [0i16, 2i16, 32767i16, 32767i16] -- [0i16, 2i16, 32767i16, -2i16] } -- output { [0i16, 4i16, -2i16, 32765i16] } -- == -- entry: sub -- input { [2i16, 0i16, 32767i16] -- [2i16, 32767i16, -2i16] } -- output { [0i16, -32767i16, -32767i16] } -- == -- entry: mul -- input { [2i16, 2i16, -2i16, -2i16, 128i16] -- [3i16, -3i16, 3i16, -3i16, 512i16] } -- output { [6i16, -6i16, -6i16, 6i16, 0i16] } -- == -- entry: pow -- input { [2i16, 2i16, 11i16, 11i16, 11i16, 11i16, 11i16, 11i16, 11i16] -- [3i16, 0i16, 1i16, 2i16, 3i16, 4i16, 5i16, 6i16, 7i16] } -- output { [8i16, 1i16, 11i16, 121i16, 1331i16, 14641i16, 29979i16, 2089i16, 22979i16] }futhark-0.25.27/tests/primitive/i16_bitop.fut000066400000000000000000000024001475065116200210240ustar00rootroot00000000000000-- Bitwise operations on i16 values. 
-- -- == -- entry: land -- input { [0i16, 0i16, 0i16, 1i16, 1i16, 1i16, -1i16, -1i16, -1i16] -- [0i16, 1i16, -1i16, 0i16, 1i16, -1i16, 0i16, 1i16, -1i16] } -- output { [0i16, 0i16, 0i16, 0i16, 1i16, 1i16, 0i16, 1i16, -1i16] } -- == -- entry: lor -- input { [0i16, 0i16, 0i16, 1i16, 1i16, 1i16, -1i16, -1i16, -1i16, 64i16] -- [0i16, 1i16, -1i16, 0i16, 1i16, -1i16, 0i16, 1i16, -1i16, 32i16]} -- output { [0i16, 1i16, -1i16, 1i16, 1i16, -1i16, -1i16, -1i16, -1i16, 96i16] } -- == -- entry: lxor -- input { [0i16, 0i16, 0i16, 1i16, 1i16, 1i16, -1i16, -1i16, -1i16, 64i16] -- [0i16, 1i16, -1i16, 0i16, 1i16, -1i16, 0i16, 1i16, -1i16, 32i16]} -- output { [0i16, 1i16, -1i16, 1i16, 0i16, -2i16, -1i16, -2i16, 0i16, 96i16]} -- == -- entry: left -- input { [0i16, 0i16, 1i16, 1i16, -1i16, -1i16] -- [0i16, 1i16, 0i16, 1i16, 0i16, 1i16] } -- output { [0i16, 0i16, 1i16, 2i16, -1i16, -2i16] } -- == -- entry: right -- input { [0i16, 0i16, 1i16, 1i16, 2i16, -1i16, -1i16] -- [0i16, 1i16, 0i16, 1i16, 1i16, 0i16, 1i16] } -- output { [0i16, 0i16, 1i16, 0i16, 1i16, -1i16, -1i16]} entry land = map2 (i16.&) entry lor = map2 (i16.|) entry lxor = map2 (i16.^) entry left = map2 (i16.<<) entry right = map2 (i16.>>)futhark-0.25.27/tests/primitive/i16_cmpop.fut000066400000000000000000000012061475065116200210300ustar00rootroot00000000000000-- Test comparison of i16 values. 
-- -- == -- entry: lt -- input { [0i16, 1i16, -1i16, 1i16, -2i16 ] -- [0i16, 2i16, 1i16, -1i16, -1i16] } -- output { [false, true, true, false, true] } -- == -- entry: eq -- input { [0i16, 1i16, -1i16, 1i16, -2i16 ] -- [0i16, 2i16, 1i16, -1i16, -1i16] } -- output { [true, false, false, false, false] } -- == -- entry: lte -- input { [0i16, 1i16, -1i16, 1i16, -2i16 ] -- [0i16, 2i16, 1i16, -1i16, -1i16] } -- output { [true, true, true, false, true] } entry lt (x:[]i16) (y:[]i16)= map2 (<) x y entry eq (x:[]i16) (y:[]i16)= map2 (==) x y entry lte (x:[]i16) (y:[]i16)= map2 (<=) x yfuthark-0.25.27/tests/primitive/i16_convop.fut000066400000000000000000000026041475065116200212210ustar00rootroot00000000000000-- Convert back and forth between different integer types. -- -- == -- entry: i16tobool -- input { [0i16, 64i16, 32767i16, -32768i16] } -- output { [false, true, true, true] } -- == -- entry: i16toi8 -- input { [0i16, 64i16, 32767i16, -32768i16] } -- output { [0i8, 64i8, -1i8, 0i8] } -- == -- entry: i16toi16 -- input { [0i16, 64i16, 32767i16, -32768i16] } -- output { [0i16, 64i16, 32767i16, -32768i16] } -- == -- entry: i16toi32 -- input { [0i16, 64i16, 32767i16, -32768i16] } -- output { [0i32, 64i32, 32767i32, -32768i32] } -- == -- entry: i16toi64 -- input { [0i16, 64i16, 32767i16, -32768i16] } -- output { [0i64, 64i64, 32767i64, -32768i64] } -- == -- entry: i16tou8 -- input { [0i16, 64i16, 32767i16, -32768i16] } -- output { [0u8, 64u8, 255u8, 0u8] } -- == -- entry: i16tou16 -- input { [0i16, 64i16, 32767i16, -32768i16] } -- output { [0u16, 64u16, 32767u16, 32768u16] } -- == -- entry: i16tou32 -- input { [0i16, 64i16, 32767i16, -32768i16] } -- output { [0u32, 64u32, 32767u32, 32768u32] } -- == -- entry: i16tou64 -- input { [0i16, 64i16, 32767i16, -32768i16] } -- output { [0u64, 64u64, 32767u64, 32768u64] } entry i16tobool = map (bool.i16) entry i16toi8 = map (i8.i16) entry i16toi16 = map (i16.i16) entry i16toi32 = map (i32.i16) entry i16toi64 = map (i64.i16) entry 
i16tou8 = map (u8.i16) entry i16tou16 = map (u16.i16) entry i16tou32 = map (u32.i16) entry i16tou64 = map (u64.i16) futhark-0.25.27/tests/primitive/i16_division.fut000066400000000000000000000013211475065116200215340ustar00rootroot00000000000000-- Test of division-like operators for i16 values. -- == -- entry: divide -- input { [7i16, -7i16, 7i16, -7i16] -- [3i16, 3i16, -3i16, -3i16] } -- output { [2i16, -3i16, -3i16, 2i16] } -- == -- entry: mod -- input { [7i16, -7i16, 7i16, -7i16] -- [3i16, 3i16, -3i16, -3i16] } -- output { [1i16, 2i16, -2i16, -1i16] } -- == -- entry: quot -- input { [7i16, -7i16, 7i16, -7i16] -- [3i16, 3i16, -3i16, -3i16] } -- output { [2i16, -2i16, -2i16, 2i16] } -- == -- entry: rem -- input { [7i16, -7i16, 7i16, -7i16] -- [3i16, 3i16, -3i16, -3i16] } -- output { [1i16, -1i16, 1i16, -1i16] } entry divide = map2 (i16./) entry mod = map2 (i16.%) entry quot = map2 (i16.//) entry rem = map2 (i16.%%)futhark-0.25.27/tests/primitive/i16_minmax.fut000066400000000000000000000005201475065116200212010ustar00rootroot00000000000000-- == -- entry: testMax -- input { [0i16, 1i16, -1i16, 1i16] -- [1i16, 1i16, 1i16, -1i16]} -- output { [1i16, 1i16, 1i16, 1i16] } -- == -- entry: testMin -- input { [0i16, 1i16, -1i16, 1i16] -- [1i16, 1i16, 1i16, -1i16]} -- output { [0i16, 1i16, -1i16, -1i16] } entry testMax = map2 i16.max entry testMin = map2 i16.minfuthark-0.25.27/tests/primitive/i16_unop.fut000066400000000000000000000007431475065116200207000ustar00rootroot00000000000000-- Test unary operators for i16. 
-- == -- entry: negatei16 -- input { [0i16, 1i16, -1i16, 8i16, -8i16] } -- output { [0i16, -1i16, 1i16, -8i16, 8i16] } -- == -- entry: absi16 -- input { [0i16, 1i16, -1i16, 8i16, -8i16] } -- output { [0i16, 1i16, 1i16, 8i16, 8i16] } -- == -- entry: sgni16 -- input { [0i16, 1i16, -1i16, 8i16, -8i16] } -- output { [0i16, 1i16, -1i16, 1i16, -1i16] } entry negatei16 = map (\x : i16 -> -x) entry absi16 = map (i16.abs) entry sgni16 = map (i16.sgn) futhark-0.25.27/tests/primitive/i32_binop.fut000066400000000000000000000016061475065116200210230ustar00rootroot00000000000000-- i32 test. entry add = map2 (i32.+) entry sub = map2 (i32.-) entry mul = map2 (i32.*) entry pow = map2 (i32.**) -- == -- entry: add -- input { [0i32, 2i32, 2147483647i32, 2147483647i32] -- [0i32, 2i32, 2147483647i32, -2i32] } -- output { [0i32, 4i32, -2i32, 2147483645i32] } -- == -- entry: sub -- input { [2i32, 0i32, 2147483647i32] -- [2i32, 2147483647i32, -2i32] } -- output { [0i32, -2147483647i32, -2147483647i32] } -- == -- entry: mul -- input { [2i32, 2i32, -2i32, -2i32, 262144i32] -- [3i32, -3i32, 3i32, -3i32, 262144i32] } -- output { [6i32, -6i32, -6i32, 6i32, 0i32] } -- == -- entry: pow -- input { [2i32, 47i32, 47i32, 47i32, 47i32, 47i32, 47i32, 47i32, 47i32] -- [3i32, 0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32] } -- output { [8i32, 1i32, 47i32, 2209i32, 103823i32, 4879681i32, 229345007i32, -2105686559i32, -183020465i32] } futhark-0.25.27/tests/primitive/i32_bitop.fut000066400000000000000000000024001475065116200210220ustar00rootroot00000000000000-- Bitwise operations on i32 values. 
-- -- == -- entry: land -- input { [0i32, 0i32, 0i32, 1i32, 1i32, 1i32, -1i32, -1i32, -1i32] -- [0i32, 1i32, -1i32, 0i32, 1i32, -1i32, 0i32, 1i32, -1i32] } -- output { [0i32, 0i32, 0i32, 0i32, 1i32, 1i32, 0i32, 1i32, -1i32] } -- == -- entry: lor -- input { [0i32, 0i32, 0i32, 1i32, 1i32, 1i32, -1i32, -1i32, -1i32, 64i32] -- [0i32, 1i32, -1i32, 0i32, 1i32, -1i32, 0i32, 1i32, -1i32, 32i32]} -- output { [0i32, 1i32, -1i32, 1i32, 1i32, -1i32, -1i32, -1i32, -1i32, 96i32] } -- == -- entry: lxor -- input { [0i32, 0i32, 0i32, 1i32, 1i32, 1i32, -1i32, -1i32, -1i32, 64i32] -- [0i32, 1i32, -1i32, 0i32, 1i32, -1i32, 0i32, 1i32, -1i32, 32i32]} -- output { [0i32, 1i32, -1i32, 1i32, 0i32, -2i32, -1i32, -2i32, 0i32, 96i32]} -- == -- entry: left -- input { [0i32, 0i32, 1i32, 1i32, -1i32, -1i32] -- [0i32, 1i32, 0i32, 1i32, 0i32, 1i32] } -- output { [0i32, 0i32, 1i32, 2i32, -1i32, -2i32] } -- == -- entry: right -- input { [0i32, 0i32, 1i32, 1i32, 2i32, -1i32, -1i32] -- [0i32, 1i32, 0i32, 1i32, 1i32, 0i32, 1i32] } -- output { [0i32, 0i32, 1i32, 0i32, 1i32, -1i32, -1i32]} entry land = map2 (i32.&) entry lor = map2 (i32.|) entry lxor = map2 (i32.^) entry left = map2 (i32.<<) entry right = map2 (i32.>>)futhark-0.25.27/tests/primitive/i32_cmpop.fut000066400000000000000000000012061475065116200210260ustar00rootroot00000000000000-- Test comparison of i32 values. 
-- -- == -- entry: lt -- input { [0i32, 1i32, -1i32, 1i32, -2i32 ] -- [0i32, 2i32, 1i32, -1i32, -1i32] } -- output { [false, true, true, false, true] } -- == -- entry: eq -- input { [0i32, 1i32, -1i32, 1i32, -2i32 ] -- [0i32, 2i32, 1i32, -1i32, -1i32] } -- output { [true, false, false, false, false] } -- == -- entry: lte -- input { [0i32, 1i32, -1i32, 1i32, -2i32 ] -- [0i32, 2i32, 1i32, -1i32, -1i32] } -- output { [true, true, true, false, true] } entry lt (x:[]i32) (y:[]i32)= map2 (<) x y entry eq (x:[]i32) (y:[]i32)= map2 (==) x y entry lte (x:[]i32) (y:[]i32)= map2 (<=) x yfuthark-0.25.27/tests/primitive/i32_convop.fut000066400000000000000000000027711475065116200212240ustar00rootroot00000000000000-- Convert back and forth between different integer types. -- == -- entry: i32tobool -- input { [0i32, 64i32, 2147483647i32, -2147483648i32] } -- output { [false, true, true, true] } -- == -- entry: i32toi8 -- input { [0i32, 64i32, 2147483647i32, -2147483648i32] } -- output { [0i8, 64i8, -1i8, -0i8] } -- == -- entry: i32toi16 -- input { [0i32, 64i32, 2147483647i32, -2147483648i32] } -- output { [0i16, 64i16, -1i16, 0i16] } -- == -- entry: i32toi32 -- input { [0i32, 64i32, 2147483647i32, -2147483648i32] } -- output { [0i32, 64i32, 2147483647i32, -2147483648i32] } -- == -- entry: i32toi64 -- input { [0i32, 64i32, 2147483647i32, -2147483648i32] } -- output { [0i64, 64i64, 2147483647i64, -2147483648i64] } -- == -- entry: i32tou8 -- input { [0i32, 64i32, 2147483647i32, -2147483648i32] } -- output { [0u8, 64u8, 255u8, 0u8] } -- == -- entry: i32tou16 -- input { [0i32, 64i32, 2147483647i32, -2147483648i32] } -- output { [0u16, 64u16, 65535u16, 0u16] } -- == -- entry: i32tou32 -- input { [0i32, 64i32, 2147483647i32, -2147483648i32] } -- output { [0u32, 64u32, 2147483647u32, 2147483648u32] } -- == -- entry: i32tou64 -- input { [0i32, 64i32, 2147483647i32, -2147483648i32] } -- output { [0u64, 64u64, 2147483647u64, 2147483648u64] } entry i32tobool = map (bool.i32) entry i32toi8 = 
map (i8.i32) entry i32toi16 = map (i16.i32) entry i32toi32 = map (i32.i32) entry i32toi64 = map (i64.i32) entry i32tou8 = map (u8.i32) entry i32tou16 = map (u16.i32) entry i32tou32 = map (u32.i32) entry i32tou64 = map (u64.i32) futhark-0.25.27/tests/primitive/i32_division.fut000066400000000000000000000013221475065116200215330ustar00rootroot00000000000000-- Test of division-like operators for i32 values. -- == -- entry: divide -- input { [7i32, -7i32, 7i32, -7i32] -- [3i32, 3i32, -3i32, -3i32] } -- output { [2i32, -3i32, -3i32, 2i32] } -- == -- entry: mod -- input { [7i32, -7i32, 7i32, -7i32] -- [3i32, 3i32, -3i32, -3i32] } -- output { [1i32, 2i32, -2i32, -1i32] } -- == -- entry: quot -- input { [7i32, -7i32, 7i32, -7i32] -- [3i32, 3i32, -3i32, -3i32] } -- output { [2i32, -2i32, -2i32, 2i32] } -- == -- entry: rem -- input { [7i32, -7i32, 7i32, -7i32] -- [3i32, 3i32, -3i32, -3i32] } -- output { [1i32, -1i32, 1i32, -1i32] } entry divide = map2 (i32./) entry mod = map2 (i32.%) entry quot = map2 (i32.//) entry rem = map2 (i32.%%) futhark-0.25.27/tests/primitive/i32_minmax.fut000066400000000000000000000005201475065116200211770ustar00rootroot00000000000000-- == -- entry: testMax -- input { [0i32, 1i32, -1i32, 1i32] -- [1i32, 1i32, 1i32, -1i32]} -- output { [1i32, 1i32, 1i32, 1i32] } -- == -- entry: testMin -- input { [0i32, 1i32, -1i32, 1i32] -- [1i32, 1i32, 1i32, -1i32]} -- output { [0i32, 1i32, -1i32, -1i32] } entry testMax = map2 i32.max entry testMin = map2 i32.minfuthark-0.25.27/tests/primitive/i32_unop.fut000066400000000000000000000007431475065116200206760ustar00rootroot00000000000000-- Test unary operators for i32. 
-- == -- entry: negatei32 -- input { [0i32, 1i32, -1i32, 8i32, -8i32] } -- output { [0i32, -1i32, 1i32, -8i32, 8i32] } -- == -- entry: absi32 -- input { [0i32, 1i32, -1i32, 8i32, -8i32] } -- output { [0i32, 1i32, 1i32, 8i32, 8i32] } -- == -- entry: sgni32 -- input { [0i32, 1i32, -1i32, 8i32, -8i32] } -- output { [0i32, 1i32, -1i32, 1i32, -1i32] } entry negatei32 = map (\x : i32 -> -x) entry absi32 = map (i32.abs) entry sgni32 = map (i32.sgn) futhark-0.25.27/tests/primitive/i64_binop.fut000066400000000000000000000020761475065116200210320ustar00rootroot00000000000000-- i64 test. entry add = map2 (i64.+) entry sub = map2 (i64.-) entry mul = map2 (i64.*) entry pow = map2 (i64.**) -- == -- entry: add -- input { [0i64, 2i64, 9223372036854775807i64, 9223372036854775807i64] -- [0i64, 2i64, 9223372036854775807i64, -2i64] } -- output { [0i64, 4i64, -2i64, 9223372036854775805i64] } -- == -- entry: sub -- input { [2i64, 0i64, 9223372036854775807i64] -- [2i64, 9223372036854775807i64, -2i64] } -- output { [0i64, -9223372036854775807i64, -9223372036854775807i64] } -- == -- entry: mul -- input { [2i64, 2i64, -2i64, -2i64, 6442450941i64] -- [3i64, -3i64, 3i64, -3i64, 2147483647i64] } -- output { [6i64, -6i64, -6i64, 6i64, -4611686031312289789i64] } -- == -- entry: pow -- input { [2i64, 4021i64, 4021i64, 4021i64, 4021i64, 4021i64, 4021i64, 4021i64, 4021i64] -- [3i64, 0i64, 1i64, 2i64, 3i64, 4i64, 5i64, 6i64, 7i64] } -- output { [8i64, 1i64, 4021i64, 16168441i64, 65013301261i64, 261418484370481i64, -- 1051163725653704101i64, 2424947974056870057i64, -7611811309678305667i64] } futhark-0.25.27/tests/primitive/i64_bitop.fut000066400000000000000000000024001475065116200210270ustar00rootroot00000000000000-- Bitwise operations on i64 values. 
-- -- == -- entry: land -- input { [0i64, 0i64, 0i64, 1i64, 1i64, 1i64, -1i64, -1i64, -1i64] -- [0i64, 1i64, -1i64, 0i64, 1i64, -1i64, 0i64, 1i64, -1i64] } -- output { [0i64, 0i64, 0i64, 0i64, 1i64, 1i64, 0i64, 1i64, -1i64] } -- == -- entry: lor -- input { [0i64, 0i64, 0i64, 1i64, 1i64, 1i64, -1i64, -1i64, -1i64, 64i64] -- [0i64, 1i64, -1i64, 0i64, 1i64, -1i64, 0i64, 1i64, -1i64, 32i64]} -- output { [0i64, 1i64, -1i64, 1i64, 1i64, -1i64, -1i64, -1i64, -1i64, 96i64] } -- == -- entry: lxor -- input { [0i64, 0i64, 0i64, 1i64, 1i64, 1i64, -1i64, -1i64, -1i64, 64i64] -- [0i64, 1i64, -1i64, 0i64, 1i64, -1i64, 0i64, 1i64, -1i64, 32i64]} -- output { [0i64, 1i64, -1i64, 1i64, 0i64, -2i64, -1i64, -2i64, 0i64, 96i64]} -- == -- entry: left -- input { [0i64, 0i64, 1i64, 1i64, -1i64, -1i64] -- [0i64, 1i64, 0i64, 1i64, 0i64, 1i64] } -- output { [0i64, 0i64, 1i64, 2i64, -1i64, -2i64] } -- == -- entry: right -- input { [0i64, 0i64, 1i64, 1i64, 2i64, -1i64, -1i64] -- [0i64, 1i64, 0i64, 1i64, 1i64, 0i64, 1i64] } -- output { [0i64, 0i64, 1i64, 0i64, 1i64, -1i64, -1i64]} entry land = map2 (i64.&) entry lor = map2 (i64.|) entry lxor = map2 (i64.^) entry left = map2 (i64.<<) entry right = map2 (i64.>>)futhark-0.25.27/tests/primitive/i64_cmpop.fut000066400000000000000000000012061475065116200210330ustar00rootroot00000000000000-- Test comparison of i64 values. 
-- -- == -- entry: lt -- input { [0i64, 1i64, -1i64, 1i64, -2i64 ] -- [0i64, 2i64, 1i64, -1i64, -1i64] } -- output { [false, true, true, false, true] } -- == -- entry: eq -- input { [0i64, 1i64, -1i64, 1i64, -2i64 ] -- [0i64, 2i64, 1i64, -1i64, -1i64] } -- output { [true, false, false, false, false] } -- == -- entry: lte -- input { [0i64, 1i64, -1i64, 1i64, -2i64 ] -- [0i64, 2i64, 1i64, -1i64, -1i64] } -- output { [true, true, true, false, true] } entry lt (x:[]i64) (y:[]i64)= map2 (<) x y entry eq (x:[]i64) (y:[]i64)= map2 (==) x y entry lte (x:[]i64) (y:[]i64)= map2 (<=) x yfuthark-0.25.27/tests/primitive/i64_convop.fut000066400000000000000000000032431475065116200212240ustar00rootroot00000000000000-- Convert back and forth between different integer types. -- == -- entry: i64tobool -- input { [0i64, 64i64, 9223372036854775807i64, -9223372036854775808i64] } -- output { [false, true, true, true] } -- == -- entry: i64toi8 -- input { [0i64, 64i64, 9223372036854775807i64, -9223372036854775808i64] } -- output { [0i8, 64i8, -1i8, 0i8] } -- == -- entry: i64toi16 -- input { [0i64, 64i64, 9223372036854775807i64, -9223372036854775808i64] } -- output { [0i16, 64i16, -1i16, 0i16] } -- == -- entry: i64toi32 -- input { [0i64, 64i64, 9223372036854775807i64, -9223372036854775808i64] } -- output { [0i32, 64i32, -1i32, 0i32] } -- == -- entry: i64toi64 -- input { [0i64, 64i64, 9223372036854775807i64, -9223372036854775808i64] } -- output { [0i64, 64i64, 9223372036854775807i64, -9223372036854775808i64] } -- == -- entry: i64tou8 -- input { [0i64, 64i64, 9223372036854775807i64, -9223372036854775808i64] } -- output { [0u8, 64u8, 255u8, 0u8] } -- == -- entry: i64tou16 -- input { [0i64, 64i64, 9223372036854775807i64, -9223372036854775808i64] } -- output { [0u16, 64u16, 65535u16, 0u16] } -- == -- entry: i64tou32 -- input { [0i64, 64i64, 9223372036854775807i64, -9223372036854775808i64] } -- output { [0u32, 64u32, 4294967295u32, 0u32] } -- == -- entry: i64tou64 -- input { [0i64, 64i64, 
9223372036854775807i64, -9223372036854775808i64] } -- output { [0u64, 64u64, 9223372036854775807u64, 9223372036854775808u64] } entry i64tobool = map (bool.i64) entry i64toi8 = map (i8.i64) entry i64toi16 = map (i16.i64) entry i64toi32 = map (i32.i64) entry i64toi64 = map (i64.i64) entry i64tou8 = map (u8.i64) entry i64tou16 = map (u16.i64) entry i64tou32 = map (u32.i64) entry i64tou64 = map (u64.i64) futhark-0.25.27/tests/primitive/i64_division.fut000066400000000000000000000013221475065116200215400ustar00rootroot00000000000000-- Test of division-like operators for i64 values. -- == -- entry: divide -- input { [7i64, -7i64, 7i64, -7i64] -- [3i64, 3i64, -3i64, -3i64] } -- output { [2i64, -3i64, -3i64, 2i64] } -- == -- entry: mod -- input { [7i64, -7i64, 7i64, -7i64] -- [3i64, 3i64, -3i64, -3i64] } -- output { [1i64, 2i64, -2i64, -1i64] } -- == -- entry: quot -- input { [7i64, -7i64, 7i64, -7i64] -- [3i64, 3i64, -3i64, -3i64] } -- output { [2i64, -2i64, -2i64, 2i64] } -- == -- entry: rem -- input { [7i64, -7i64, 7i64, -7i64] -- [3i64, 3i64, -3i64, -3i64] } -- output { [1i64, -1i64, 1i64, -1i64] } entry divide = map2 (i64./) entry mod = map2 (i64.%) entry quot = map2 (i64.//) entry rem = map2 (i64.%%) futhark-0.25.27/tests/primitive/i64_minmax.fut000066400000000000000000000005201475065116200212040ustar00rootroot00000000000000-- == -- entry: testMax -- input { [0i64, 1i64, -1i64, 1i64] -- [1i64, 1i64, 1i64, -1i64]} -- output { [1i64, 1i64, 1i64, 1i64] } -- == -- entry: testMin -- input { [0i64, 1i64, -1i64, 1i64] -- [1i64, 1i64, 1i64, -1i64]} -- output { [0i64, 1i64, -1i64, -1i64] } entry testMax = map2 i64.max entry testMin = map2 i64.minfuthark-0.25.27/tests/primitive/i64_unop.fut000066400000000000000000000007541475065116200207050ustar00rootroot00000000000000-- Test unary operators for i64. 
-- -- == -- entry: neg -- input { [0i64, 1i64, -1i64, 8i64, -8i64]} -- output { [0i64,-1i64,1i64,-8i64,8i64] } -- == -- entry: abs -- input { [0i64, 1i64, -1i64, 8i64, -8i64, 5000000000i64, -5000000000i64] } -- output { [0i64,1i64,1i64,8i64,8i64, 5000000000i64, 5000000000i64] } -- == -- entry: sgn -- input { [0i64,1i64,-1i64,8i64,-8i64] } -- output { [0i64,1i64,-1i64,1i64,-1i64] } entry neg = map i64.neg entry abs = map i64.abs entry sgn = map i64.sgn futhark-0.25.27/tests/primitive/i8_binop.fut000066400000000000000000000013211475065116200207400ustar00rootroot00000000000000-- i8 test. entry add = map2 (i8.+) entry sub = map2 (i8.-) entry mul = map2 (i8.*) entry pow = map2 (i8.**) -- == -- entry: add -- input { [0i8, 2i8, 127i8, 127i8] -- [0i8, 2i8, 127i8, -2i8] } -- output { [0i8, 4i8, -2i8, 125i8] } -- == -- entry: sub -- input { [2i8, 0i8, 127i8] -- [2i8, 127i8, -2i8] } -- output { [0i8, -127i8, -127i8] } -- == -- entry: mul -- input { [2i8, 2i8, -2i8, -2i8, 4i8] -- [3i8, -3i8, 3i8, -3i8, 64i8] } -- output { [6i8, -6i8, -6i8, 6i8, 0i8] } -- == -- entry: pow -- input { [2i8, 2i8, 3i8, 3i8, 3i8, 3i8, 3i8, 3i8, 3i8] -- [3i8, 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 7i8] } -- output { [8i8, 1i8, 3i8, 9i8, 27i8, 81i8, -13i8, -39i8, -117i8] }futhark-0.25.27/tests/primitive/i8_bitop.fut000066400000000000000000000021741475065116200207550ustar00rootroot00000000000000-- Bitwise operations on i8 values. 
-- -- == -- entry: land -- input { [0i8, 0i8, 0i8, 1i8, 1i8, 1i8, -1i8, -1i8, -1i8] -- [0i8, 1i8, -1i8, 0i8, 1i8, -1i8, 0i8, 1i8, -1i8] } -- output { [0i8, 0i8, 0i8, 0i8, 1i8, 1i8, 0i8, 1i8, -1i8] } -- == -- entry: lor -- input { [0i8, 0i8, 0i8, 1i8, 1i8, 1i8, -1i8, -1i8, -1i8, 64i8] -- [0i8, 1i8, -1i8, 0i8, 1i8, -1i8, 0i8, 1i8, -1i8, 32i8]} -- output { [0i8, 1i8, -1i8, 1i8, 1i8, -1i8, -1i8, -1i8, -1i8, 96i8] } -- == -- entry: lxor -- input { [0i8, 0i8, 0i8, 1i8, 1i8, 1i8, -1i8, -1i8, -1i8, 64i8] -- [0i8, 1i8, -1i8, 0i8, 1i8, -1i8, 0i8, 1i8, -1i8, 32i8]} -- output { [0i8, 1i8, -1i8, 1i8, 0i8, -2i8, -1i8, -2i8, 0i8, 96i8]} -- == -- entry: left -- input { [0i8, 0i8, 1i8, 1i8, -1i8, -1i8] -- [0i8, 1i8, 0i8, 1i8, 0i8, 1i8] } -- output { [0i8, 0i8, 1i8, 2i8, -1i8, -2i8] } -- == -- entry: right -- input { [0i8, 0i8, 1i8, 1i8, 2i8, -1i8, -1i8] -- [0i8, 1i8, 0i8, 1i8, 1i8, 0i8, 1i8] } -- output { [0i8, 0i8, 1i8, 0i8, 1i8, -1i8, -1i8]} entry land = map2 (i8.&) entry lor = map2 (i8.|) entry lxor = map2 (i8.^) entry left = map2 (i8.<<) entry right = map2 (i8.>>)futhark-0.25.27/tests/primitive/i8_cmpop.fut000066400000000000000000000011411475065116200207470ustar00rootroot00000000000000-- Test comparison of i8 values. -- -- == -- entry: lt -- input { [0i8, 1i8, -1i8, 1i8, -2i8 ] -- [0i8, 2i8, 1i8, -1i8, -1i8] } -- output { [false, true, true, false, true] } -- == -- entry: eq -- input { [0i8, 1i8, -1i8, 1i8, -2i8 ] -- [0i8, 2i8, 1i8, -1i8, -1i8] } -- output { [true, false, false, false, false] } -- == -- entry: lte -- input { [0i8, 1i8, -1i8, 1i8, -2i8 ] -- [0i8, 2i8, 1i8, -1i8, -1i8] } -- output { [true, true, true, false, true] } entry lt (x:[]i8) (y:[]i8)= map2 (<) x y entry eq (x:[]i8) (y:[]i8)= map2 (==) x y entry lte (x:[]i8) (y:[]i8)= map2 (<=) x yfuthark-0.25.27/tests/primitive/i8_convop.fut000066400000000000000000000024141475065116200211410ustar00rootroot00000000000000-- Convert back and forth between different integer types. 
-- == -- entry: i8tobool -- input { [0i8, 64i8, 127i8, -128i8] } -- output { [false, true, true, true] } -- == -- entry: i8toi8 -- input { [0i8, 64i8, 127i8, -128i8] } -- output { [0i8, 64i8, 127i8, -128i8] } -- == -- entry: i8toi16 -- input { [0i8, 64i8, 127i8, -128i8] } -- output { [0i16, 64i16, 127i16, -128i16] } -- == -- entry: i8toi32 -- input { [0i8, 64i8, 127i8, -128i8] } -- output { [0i32, 64i32, 127i32, -128i32] } -- == -- entry: i8toi64 -- input { [0i8, 64i8, 127i8, -128i8] } -- output { [0i64, 64i64, 127i64, -128i64] } -- == -- entry: i8tou8 -- input { [0i8, 64i8, 127i8, -128i8] } -- output { [0u8, 64u8, 127u8, 128u8] } -- == -- entry: i8tou16 -- input { [0i8, 64i8, 127i8, -128i8] } -- output { [0u16, 64u16, 127u16, 128u16] } -- == -- entry: i8tou32 -- input { [0i8, 64i8, 127i8, -128i8] } -- output { [0u32, 64u32, 127u32, 128u32] } -- == -- entry: i8tou64 -- input { [0i8, 64i8, 127i8, -128i8] } -- output { [0u64, 64u64, 127u64, 128u64] } entry i8tobool = map (bool.i8) entry i8toi8 = map (i8.i8) entry i8toi16 = map (i16.i8) entry i8toi32 = map (i32.i8) entry i8toi64 = map (i64.i8) entry i8tou8 = map (u8.i8) entry i8tou16 = map (u16.i8) entry i8tou32 = map (u32.i8) entry i8tou64 = map (u64.i8) futhark-0.25.27/tests/primitive/i8_division.fut000066400000000000000000000012351475065116200214610ustar00rootroot00000000000000-- Test of division-like operators for i8 values. 
-- == -- entry: divide -- input { [7i8, -7i8, 7i8, -7i8] -- [3i8, 3i8, -3i8, -3i8] } -- output { [2i8, -3i8, -3i8, 2i8] } -- == -- entry: mod -- input { [7i8, -7i8, 7i8, -7i8] -- [3i8, 3i8, -3i8, -3i8] } -- output { [1i8, 2i8, -2i8, -1i8] } -- == -- entry: quot -- input { [7i8, -7i8, 7i8, -7i8] -- [3i8, 3i8, -3i8, -3i8] } -- output { [2i8, -2i8, -2i8, 2i8] } -- == -- entry: rem -- input { [7i8, -7i8, 7i8, -7i8] -- [3i8, 3i8, -3i8, -3i8] } -- output { [1i8, -1i8, 1i8, -1i8] } entry divide = map2 (i8./) entry mod = map2 (i8.%) entry quot = map2 (i8.//) entry rem = map2 (i8.%%) futhark-0.25.27/tests/primitive/i8_minmax.fut000066400000000000000000000004671475065116200211340ustar00rootroot00000000000000-- == -- entry: testMax -- input { [0i8, 1i8, -1i8, 1i8] -- [1i8, 1i8, 1i8, -1i8]} -- output { [1i8, 1i8, 1i8, 1i8] } -- == -- entry: testMin -- input { [0i8, 1i8, -1i8, 1i8] -- [1i8, 1i8, 1i8, -1i8]} -- output { [0i8, 1i8, -1i8, -1i8] } entry testMax = map2 i8.max entry testMin = map2 i8.min futhark-0.25.27/tests/primitive/i8_unop.fut000066400000000000000000000006731475065116200206230ustar00rootroot00000000000000-- Test unary operators for i8. 
-- == -- entry: negatei8 -- input { [0i8, 1i8, -1i8, 8i8, -8i8] } -- output { [0i8, -1i8, 1i8, -8i8, 8i8] } -- == -- entry: absi8 -- input { [0i8, 1i8, -1i8, 8i8, -8i8] } -- output { [0i8, 1i8, 1i8, 8i8, 8i8] } -- == -- entry: sgni8 -- input { [0i8, 1i8, -1i8, 8i8, -8i8] } -- output { [0i8, 1i8, -1i8, 1i8, -1i8] } entry negatei8 = map (\x : i8 -> -x) entry absi8 = map (i8.abs) entry sgni8 = map (i8.sgn) futhark-0.25.27/tests/primitive/ldexp.fut000066400000000000000000000010701475065116200203460ustar00rootroot00000000000000-- == -- entry: test_f16 -- input { [7f16, 7f16, -0f16, f16.inf, 1f16] [-4, 4, 10, -1, 1000] } -- output { [0.437500f16, 112f16, -0f16, f16.inf, f16.inf] } -- == -- entry: test_f32 -- input { [7f32, 7f32, -0f32, f32.inf, 1f32] [-4, 4, 10, -1, 1000] } -- output { [0.437500f32, 112f32, -0f32, f32.inf, f32.inf] } -- == -- entry: test_f64 -- input { [7f64, 7f64, -0f64, f64.inf, 1f64] [-4, 4, 10, -1, 10000] } -- output { [0.437500f64, 112f64, -0f64, f64.inf, f64.inf] } entry test_f16 = map2 f16.ldexp entry test_f32 = map2 f32.ldexp entry test_f64 = map2 f64.ldexp futhark-0.25.27/tests/primitive/lerp.fut000066400000000000000000000006721475065116200202030ustar00rootroot00000000000000-- == -- entry: lerpf64 -- input { [0.0, 0.0, 0.0, 0.0] -- [1.0, 10.0, 10.0, 10.0] -- [0.0, 0.25, 0.5, 0.75] } -- output { [0f64, 2.5f64, 5.0f64, 7.5f64] } -- == -- entry: lerpf32 -- input { [0.0f32, 0.0f32, 0.0f32, 0.0f32] -- [1.0f32, 10.0f32, 10.0f32, 10.0f32] -- [0.0f32, 0.25f32, 0.5f32, 0.75f32] } -- output { [0f32, 2.5f32, 5.0f32, 7.5f32] } entry lerpf64 = map3 f64.lerp entry lerpf32 = map3 f32.lerp futhark-0.25.27/tests/primitive/lgamma.fut000066400000000000000000000004071475065116200204730ustar00rootroot00000000000000-- == -- entry: lgammaf32 -- input { [1.0f32, 4.0f32] } -- output { [0f32, 1.7917594692280554f32] } -- == -- entry: lgammaf64 -- input { [1.0, 4.0] } -- output { [0f64, 1.7917594692280554f64] } entry lgammaf32 = map f32.lgamma entry lgammaf64 = map 
f64.lgamma futhark-0.25.27/tests/primitive/log32.fut000066400000000000000000000015301475065116200201610ustar00rootroot00000000000000-- == -- entry: logf32 -- input { [0.0f32, 2.718281828459045f32, 2f32, 10f32, f32.inf] } -- output { [-f32.inf, 1f32, 0.6931471805599453f32, 2.302585092994046f32, f32.inf] } -- == -- entry: log2f32 -- input { [0.0f32, 2.718281828459045f32, 2f32, 10f32, f32.inf] } -- output { [-f32.inf, 1.4426950408889634f32, 1f32, 3.321928094887362f32, f32.inf] } -- == -- entry: log10f32 -- input { [0.0f32, 2.718281828459045f32, 2f32, 10f32, f32.inf] } -- output { [-f32.inf, 0.4342944819032518f32, 0.3010299956639812f32, 1f32, f32.inf] } -- == -- entry: log1pf32 -- input { [-1.0f32, -1e-12f32, 0.0f32, 1e-23f32, 1.718281828459045f32, 1f32, f32.inf] } -- output { [-f32.inf, -1e-12f32, 0.0f32, 1e-23f32, 1.0f32, 0.6931471805599453f32, f32.inf] } entry logf32 = map f32.log entry log2f32 = map f32.log2 entry log10f32 = map f32.log10 entry log1pf32 = map f32.log1p futhark-0.25.27/tests/primitive/log64.fut000066400000000000000000000015341475065116200201720ustar00rootroot00000000000000-- == -- entry: logf64 -- input { [0.0f64, 2.718281828459045f64, 2f64, 10f64, f64.inf] } -- output { [-f64.inf, 1f64, 0.6931471805599453f64, 2.302585092994046f64, f64.inf] } -- == -- entry: log2f64 -- input { [0.0f64, 2.718281828459045f64, 2f64, 10f64, f64.inf] } -- output { [-f64.inf, 1.4426950408889634f64, 1f64, 3.321928094887362f64, f64.inf] } -- == -- entry: log10f64 -- input { [0.0f64, 2.718281828459045f64, 2f64, 10f64, f64.inf] } -- output { [-f64.inf, 0.4342944819032518f64, 0.3010299956639812f64, 1f64, f64.inf] } -- == -- entry: log1pf64 -- input { [-1.0f64, -1e-123f64, 0.0f64, 1e-234f64, 1.718281828459045f64, 1f64, f64.inf] } -- output { [-f64.inf, -1e-123f64, 0.0f64, 1e-234f64, 1.0f64, 0.6931471805599453f64, f64.inf] } entry logf64 = map f64.log entry log2f64 = map f64.log2 entry log10f64 = map f64.log10 entry log1pf64 = map f64.log1p 
futhark-0.25.27/tests/primitive/mad_hi.fut000066400000000000000000000103251475065116200204560ustar00rootroot00000000000000-- Test u8.mad_hi -- == -- entry: test_u8_mad_hi -- input { [10u8, 20u8, 2u8, 1u8, 2u8, 3u8, 2u8, 8u8 ] -- [10u8, 20u8, 127u8, 255u8, 255u8, 127u8, 128u8, 128u8] -- [0u8, 1u8, 2u8, 3u8, 4u8, 5u8, 7u8, 255u8] } -- output { [0u8, 2u8, 2u8, 3u8, 5u8, 6u8, 8u8, 3u8] } entry test_u8_mad_hi = map3 u8.mad_hi -- Test i8.mad_hi -- == -- entry: test_i8_mad_hi -- input { [10i8, 20i8, 2i8, 1i8, 2i8, 3i8, 2i8, 13i8] -- [10i8, 20i8, 127i8, -1i8, -1i8, 127i8, 128i8, 128i8] -- [ 0i8, 1i8, 2i8, 3i8, 4i8, 5i8, 6i8, 255i8] } -- output { [0i8, 2i8, 2i8, 2i8, 3i8, 6i8, 5i8, -8i8] } entry test_i8_mad_hi = map3 i8.mad_hi -- Test u16.mad_hi -- == -- entry: test_u16_mad_hi -- input { [10u16, 20u16, 2u16, 3u16, 2u16, 1u16, 2u16, 2u16, 2u16, 3u16, 65535u16] -- [10u16, 20u16, 127u16, 127u16, 128u16, 255u16, 255u16, 32768u16, 65535u16, 65535u16, 65535u16] -- [1u16, 2u16, 3u16, 4u16, 5u16, 6u16, 7u16, 8u16, 9u16, 10u16, 11u16] } -- output { [1u16, 2u16, 3u16, 4u16, 5u16, 6u16, 7u16, 9u16, 10u16, 12u16, 9u16] } entry test_u16_mad_hi = map3 u16.mad_hi -- Test i16.mad_hi -- == -- entry: test_i16_mad_hi -- input { [ 10i16, 20i16, 2i16, 3i16, 2i16, 1i16, 2i16, 2i16, 2i16, 3i16, -1i16] -- [ 10i16, 20i16, 127i16, 127i16, 128i16, 255i16, 255i16, 32768i16, -1i16, -1i16, -1i16] -- [250i16, 251i16, 252i16, 253i16, 254i16, 255i16, 256i16, 257i16, 258i16, 259i16, 260i16] } -- output { [250i16, 251i16, 252i16, 253i16, 254i16, 255i16, 256i16, 256i16, 257i16, 258i16, 260i16] } entry test_i16_mad_hi = map3 i16.mad_hi -- Test u32.mad_hi -- == -- entry: test_u32_mad_hi -- input { [10u32, 20u32, 2u32, 3u32, 2u32, 1u32,2u32, 2u32, 2u32, 3u32, 65535u32, 1u32, 2u32, 5u32, 4294967295u32] -- [10u32, 20u32, 127u32, 127u32, 128u32, 255u32, 255u32, 32768u32, 65535u32, 65535u32, 65535u32, 4294967295u32, 4294967295u32, 4294967295u32, 4294967295u32] -- [1u32, 2u32, 3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 
9u32, 10u32, 11u32, 12u32, 13u32, 14u32, 15u32] } -- output { [1u32, 2u32, 3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 9u32, 10u32, 11u32, 12u32, 14u32, 18u32, 13u32] } entry test_u32_mad_hi = map3 u32.mad_hi -- Test i32.mad_hi -- == -- entry: test_i32_mad_hi -- input { [10i32, 20i32, 2i32, 3i32, 2i32, 1i32, 2i32, 2i32, 2i32, 3i32, 65535i32, 1i32, 2i32, 5i32, -1i32] -- [10i32, 20i32, 127i32, 127i32, 128i32, 255i32, 255i32, 32768i32, 65535i32, 65535i32, 65535i32, -1i32, -1i32, -1i32, -1i32] -- [0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32] } -- output { [0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, -1i32, -1i32,-1i32, 0i32] } entry test_i32_mad_hi = map3 i32.mad_hi -- Test u64.mad_hi -- == -- entry: test_u64_mad_hi -- input { [10u64, 20u64, 2u64, 3u64, 2u64, 1u64, 2u64, 2u64, 2u64, 3u64, 65535u64, 1u64, 2u64, 5u64, 4294967295u64, 1u64, 2u64, 18446744073709551615u64] -- [10u64, 20u64, 127u64, 127u64, 128u64, 255u64, 255u64, 32768u64, 65535u64, 65535u64, 65535u64, 4294967295u64, 4294967295u64, 4294967295u64, 4294967295u64, 18446744073709551615u64,18446744073709551615u64, 18446744073709551615u64] -- [1u64, 2u64, 3u64, 4u64, 5u64, 6u64, 7u64, 8u64, 9u64, 10u64, 11u64, 12u64, 13u64, 14u64, 15u64, 16u64, 17u64, 18u64] } -- output { [1u64, 2u64, 3u64, 4u64, 5u64, 6u64, 7u64, 8u64, 9u64, 10u64, 11u64, 12u64, 13u64, 14u64, 15u64, 16u64, 18u64, 16u64] } entry test_u64_mad_hi = map3 u64.mad_hi -- Test i64.mad_hi -- == -- entry: test_i64_mad_hi -- input { [10i64, 20i64, 2i64, 3i64, 2i64, 1i64, 2i64, 2i64, 2i64, 3i64, 65535i64, 1i64, 2i64, 5i64, 4294967295i64, 1i64, 2i64, -1i64] -- [10i64, 20i64, 127i64, 127i64, 128i64, 255i64, 255i64, 32768i64, 65535i64, 65535i64, 65535i64, 4294967295i64, 4294967295i64, 4294967295i64, 4294967295i64, -1i64, -1i64, -1i64] -- [1i64, 2i64, 3i64, 4i64, 5i64, 6i64, 7i64, 8i64, 9i64, 10i64, 11i64, 12i64, 13i64, 14i64, 15i64, 16i64, 17i64, 18i64] } -- output { [1i64, 2i64, 3i64, 4i64, 5i64, 6i64, 
7i64, 8i64, 9i64, 10i64, 11i64, 12i64, 13i64, 14i64, 15i64, 15i64, 16i64, 18i64] } entry test_i64_mad_hi = map3 i64.mad_hi futhark-0.25.27/tests/primitive/mul_hi.fut000066400000000000000000000073071475065116200205200ustar00rootroot00000000000000-- Test u8.mul_hi -- == -- entry: test_u8_mul_hi -- input { [10u8, 20u8, 2u8, 1u8, 2u8, 3u8, 2u8] -- [10u8, 20u8, 127u8, 255u8, 255u8, 127u8, 128u8] } -- output { [0u8, 1u8, 0u8, 0u8, 1u8, 1u8, 1u8] } entry test_u8_mul_hi = map2 u8.mul_hi -- Test i8.mul_hi -- == -- entry: test_i8_mul_hi -- input { [10i8, 20i8, 2i8, 1i8, 2i8, 3i8, 2i8] -- [10i8, 20i8, 127i8, -1i8, -1i8, 127i8, 128i8] } -- output { [0i8, 1i8, 0i8, -1i8, -1i8, 1i8, -1i8] } entry test_i8_mul_hi = map2 i8.mul_hi -- Test u16.mul_hi -- == -- entry: test_u16_mul_hi -- input { [10u16, 20u16, 2u16, 3u16, 2u16, 1u16, 2u16, 2u16, 2u16, 3u16, 65535u16] -- [10u16, 20u16, 127u16, 127u16, 128u16, 255u16, 255u16, 32768u16, 65535u16, 65535u16, 65535u16] } -- output { [0u16, 0u16, 0u16, 0u16, 0u16, 0u16, 0u16, 1u16, 1u16, 2u16, 65534u16] } entry test_u16_mul_hi = map2 u16.mul_hi -- Test i16.mul_hi -- == -- entry: test_i16_mul_hi -- input { [10i16, 20i16, 2i16, 3i16, 2i16, 1i16, 2i16, 2i16, 2i16, 3i16, -1i16] -- [10i16, 20i16, 127i16, 127i16, 128i16, 255i16, 255i16, 32768i16, -1i16, -1i16, -1i16] } -- output { [0i16, 0i16, 0i16, 0i16, 0i16, 0i16, 0i16, -1i16, -1i16, -1i16, 0i16] } entry test_i16_mul_hi = map2 i16.mul_hi -- Test u32.mul_hi -- == -- entry: test_u32_mul_hi -- input { [10u32, 20u32, 2u32, 3u32, 2u32, 1u32, 2u32, 2u32, 2u32, 3u32, 65535u32, 1u32, 2u32, 5u32, 4294967295u32] -- [10u32, 20u32, 127u32, 127u32, 128u32, 255u32, 255u32, 32768u32, 65535u32, 65535u32, 65535u32, 4294967295u32, 4294967295u32, 4294967295u32, 4294967295u32] } -- output { [0u32, 0u32, 0u32, 0u32, 0u32, 0u32, 0u32, 0u32, 0u32, 0u32, 0u32, 0u32, 1u32, 4u32, 4294967294u32] } entry test_u32_mul_hi = map2 u32.mul_hi -- Test i32.mul_hi -- == -- entry: test_i32_mul_hi -- input { [10i32, 20i32, 2i32, 
3i32, 2i32, 1i32, 2i32, 2i32, 2i32, 3i32, 65535i32, 1i32, 2i32, 5i32, -1i32] -- [10i32, 20i32, 127i32, 127i32, 128i32, 255i32, 255i32, 32768i32, 65535i32, 65535i32, 65535i32, -1i32, -1i32, -1i32, -1i32] } -- output { [0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, -1i32, -1i32, -1i32, 0i32] } entry test_i32_mul_hi = map2 i32.mul_hi -- Test u64.mul_hi -- == -- entry: test_u64_mul_hi -- input { [10u64, 20u64, 2u64, 3u64, 2u64, 1u64, 2u64, 2u64, 2u64, 3u64, 65535u64, 1u64, 2u64 5u64, 4294967295u64, 1u64, 2u64, 18446744073709551615u64] -- [10u64, 20u64, 127u64, 127u64, 128u64, 255u64, 255u64, 32768u64, 65535u64, 65535u64, 65535u64, 4294967295u64, 4294967295u64, 4294967295u64, 4294967295u64, 18446744073709551615u64, 18446744073709551615u64, 18446744073709551615u64] } -- output { [0u64, 0u64, 0u64, 0u64, 0u64, 0u64, 0u64, 0u64, 0u64, 0u64, 0u64, 0u64, 0u64, 0u64, 0u64, 0u64, 1u64, 18446744073709551614u64] } entry test_u64_mul_hi = map2 u64.mul_hi -- Test i64.mul_hi -- == -- entry: test_i64_mul_hi -- input { [10i64, 20i64, 2i64, 3i64, 2i64, 1i64, 2i64, 2i64, 2i64, 3i64, 65535i64, 1i64, 2i64, 5i64, 4294967295i64, 1i64, 2i64, -1i64] -- [10i64, 20i64, 127i64, 127i64, 128i64, 255i64, 255i64, 32768i64, 65535i64, 65535i64, 65535i64, 4294967295i64, 4294967295i64, 4294967295i64, 4294967295i64, -1i64, -1i64, -1i64] } -- output { [0i64, 0i64, 0i64, 0i64, 0i64, 0i64, 0i64, 0i64, 0i64, 0i64, 0i64, 0i64, 0i64, 0i64, 0i64, -1i64, -1i64, 0i64] } entry test_i64_mul_hi = map2 i64.mul_hi futhark-0.25.27/tests/primitive/naninf32.fut000066400000000000000000000034111475065116200206510ustar00rootroot00000000000000-- NaN and inf must work. 
-- == -- entry: eqNaN -- input { [2f32, f32.nan, f32.inf, -f32.inf] } -- output { [false, false, false, false] } -- == -- entry: ltNaN -- input { [2f32, f32.nan, f32.inf, -f32.inf] } -- output { [false, false, false, false] } -- == -- entry: lteNaN -- input { [2f32, f32.nan, f32.inf, -f32.inf] } -- output { [false, false, false, false] } -- == -- entry: ltInf -- input { [2f32, f32.nan, f32.inf, -f32.inf] } -- output { [true, false, false, true] } -- == -- entry: lteInf -- input { [2f32, f32.nan, f32.inf, -f32.inf] } -- output { [true, false, true, true] } -- == -- entry: diffInf -- input { [2f32, f32.nan, f32.inf, -f32.inf] } -- output { [true, false, false, false] } -- == -- entry: sumNaN -- input { [2f32, f32.nan, f32.inf, -f32.inf] } -- output { [true, true, true, true] } -- == -- entry: sumInf -- input { [2f32, f32.nan, f32.inf, -f32.inf] } -- output { [true, false, true, false] } -- == -- entry: log2 -- input { [2f32, f32.nan, f32.inf, -f32.inf] } -- output { [false, true, false, true] } -- == -- entry: log10 -- input { [10f32, f32.nan, f32.inf, -f32.inf] } -- output { [false, true, false, true] } -- == -- entry: log1p -- input { [-2f32, -1f32, 2f32, f32.nan, f32.inf, -f32.inf] } -- output { [true, false, false, true, false, true] } entry eqNaN = map (\x -> x == f32.nan) entry ltNaN = map (\x -> x < f32.nan) entry lteNaN = map (\x -> x <= f32.nan) entry ltInf = map (\x -> x < f32.inf) entry lteInf = map (\x -> x <= f32.inf) entry diffInf = map (\x -> x - f32.inf < x + f32.inf) entry sumNaN = map (\x -> f32.isnan (x + f32.nan)) entry sumInf = map (\x -> f32.isinf (x + f32.inf)) entry log2 = map (\x -> f32.isnan (f32.log2 (x))) entry log10 = map (\x -> f32.isnan (f32.log10 (x))) entry log1p = map (\x -> f32.isnan (f32.log1p (x))) futhark-0.25.27/tests/primitive/naninf64.fut000066400000000000000000000034111475065116200206560ustar00rootroot00000000000000-- NaN and inf must work. 
-- == -- entry: eqNaN -- input { [2f64, f64.nan, f64.inf, -f64.inf] } -- output { [false, false, false, false] } -- == -- entry: ltNaN -- input { [2f64, f64.nan, f64.inf, -f64.inf] } -- output { [false, false, false, false] } -- == -- entry: lteNaN -- input { [2f64, f64.nan, f64.inf, -f64.inf] } -- output { [false, false, false, false] } -- == -- entry: ltInf -- input { [2f64, f64.nan, f64.inf, -f64.inf] } -- output { [true, false, false, true] } -- == -- entry: lteInf -- input { [2f64, f64.nan, f64.inf, -f64.inf] } -- output { [true, false, true, true] } -- == -- entry: diffInf -- input { [2f64, f64.nan, f64.inf, -f64.inf] } -- output { [true, false, false, false] } -- == -- entry: sumNaN -- input { [2f64, f64.nan, f64.inf, -f64.inf] } -- output { [true, true, true, true] } -- == -- entry: sumInf -- input { [2f64, f64.nan, f64.inf, -f64.inf] } -- output { [true, false, true, false] } -- == -- entry: log2 -- input { [2f64, f64.nan, f64.inf, -f64.inf] } -- output { [false, true, false, true] } -- == -- entry: log10 -- input { [10f64, f64.nan, f64.inf, -f64.inf] } -- output { [false, true, false, true] } -- == -- entry: log1p -- input { [-2f64, -1f64, 2f64, f64.nan, f64.inf, -f64.inf] } -- output { [true, false, false, true, false, true] } entry eqNaN = map (\x -> x == f64.nan) entry ltNaN = map (\x -> x < f64.nan) entry lteNaN = map (\x -> x <= f64.nan) entry ltInf = map (\x -> x < f64.inf) entry lteInf = map (\x -> x <= f64.inf) entry diffInf = map (\x -> x - f64.inf < x + f64.inf) entry sumNaN = map (\x -> f64.isnan (x + f64.nan)) entry sumInf = map (\x -> f64.isinf (x + f64.inf)) entry log2 = map (\x -> f64.isnan (f64.log2 (x))) entry log10 = map (\x -> f64.isnan (f64.log10 (x))) entry log1p = map (\x -> f64.isnan (f64.log1p (x))) futhark-0.25.27/tests/primitive/nextafter.fut000066400000000000000000000010661475065116200212370ustar00rootroot00000000000000-- == -- entry: test_f16 -- input { [0f16, 0f16, -0f16, f16.nan] [1f16, -1f16, 0f16, f16.inf] } -- output { 
[0.0f16, -0.0f16, 0.0f16, f16.nan] } entry test_f16 = map2 f16.nextafter -- == -- entry: test_f32 -- input { [0f32, 0f32, -0f32, f32.nan] [1f32, -1f32, 0f32, f32.inf] } -- output { [1.0e-45f32, -1.0e-45f32, 0.0f32, f32.nan] } entry test_f32 = map2 f32.nextafter -- == -- entry: test_f64 -- input { [0f64, 0f64, -0f64, f64.nan] [1f64, -1f64, 0f64, f64.inf] } -- output { [5.0e-324f64, -5.0e-324f64, 0.0f64, f64.nan] } entry test_f64 = map2 f64.nextafter futhark-0.25.27/tests/primitive/popc.fut000066400000000000000000000030631475065116200201770ustar00rootroot00000000000000-- == -- entry: popci8 -- input { [0u64, 255u64, 65535u64, 4294967295u64, 18446744073709551615u64] } -- output { [0i32, 8i32, 8i32, 8i32, 8i32] } -- == -- entry: popci16 -- input { [0u64, 255u64, 65535u64, 4294967295u64, 18446744073709551615u64] } -- output { [0i32, 8i32, 16i32, 16i32, 16i32] } -- == -- entry: popci32 -- input { [0u64, 255u64, 65535u64, 4294967295u64, 18446744073709551615u64] } -- output { [0i32, 8i32, 16i32, 32i32, 32i32] } -- == -- entry: popci64 -- input { [0u64, 255u64, 65535u64, 4294967295u64, 18446744073709551615u64] } -- output { [0i32, 8i32, 16i32, 32i32, 64i32] } -- == -- entry: popcu8 -- input { [0u64, 255u64, 65535u64, 4294967295u64, 18446744073709551615u64] } -- output { [0i32, 8i32, 8i32, 8i32, 8i32] } -- == -- entry: popcu16 -- input { [0u64, 255u64, 65535u64, 4294967295u64, 18446744073709551615u64] } -- output { [0i32, 8i32, 16i32, 16i32, 16i32] } -- == -- entry: popcu32 -- input { [0u64, 255u64, 65535u64, 4294967295u64, 18446744073709551615u64] } -- output { [0i32, 8i32, 16i32, 32i32, 32i32] } -- == -- entry: popcu64 -- input { [0u64, 255u64, 65535u64, 4294967295u64, 18446744073709551615u64] } -- output { [0i32, 8i32, 16i32, 32i32, 64i32] } entry popci8 = map (\x -> i8.popc (i8.u64 x)) entry popci16 = map (\x -> i16.popc (i16.u64 x)) entry popci32 = map (\x -> i32.popc (i32.u64 x)) entry popci64 = map (\x -> i64.popc (i64.u64 x)) entry popcu8 = map (\x -> u8.popc (u8.u64 
x)) entry popcu16 = map (\x -> u16.popc (u16.u64 x)) entry popcu32 = map (\x -> u32.popc (u32.u64 x)) entry popcu64 = map (\x -> u64.popc (u64.u64 x)) futhark-0.25.27/tests/primitive/round32.fut000066400000000000000000000010301475065116200205220ustar00rootroot00000000000000-- Rounding floats to whole numbers. -- == -- input { [-0.4999999701976776123046875f32, -0.5f32, -0.500000059604644775390625f32, 0f32, 0.4999999701976776123046875f32, -- 0.5f32, 0.500000059604644775390625f32, 1.390671161567e-309f32, 2.2517998136852485e+15f32, 4.503599627370497e+15f32, -- -f32.inf, f32.inf, f32.nan, -0f32] } -- output { [-0f32, -0f32, -1f32, 0f32, 0f32, 0f32, 1f32, 0f32, 2.251799813685249e+15f32, 4.503599627370497e+15f32, -- -f32.inf, f32.inf, f32.nan, -0f32] } def main = map f32.round futhark-0.25.27/tests/primitive/round64.fut000066400000000000000000000007701475065116200205410ustar00rootroot00000000000000-- Rounding floats to whole numbers. -- == -- input { [-0.49999999999999994f64, -0.5f64, -0.5000000000000001f64, 0f64, 0.49999999999999994f64, -- 0.5f64, 0.5000000000000001f64, 1.390671161567e-309f64, 2.2517998136852485e+15f64, 4.503599627370497e+15f64, -- -f64.inf, f64.inf, f64.nan, -0f64] } -- output { [-0f64, -0f64, -1f64, 0f64, 0f64, 0f64, 1f64, 0f64, 2.251799813685249e+15f64, 4.503599627370497e+15f64, -- -f64.inf, f64.inf, f64.nan, -0f64] } def main = map f64.round futhark-0.25.27/tests/primitive/signed_get_set_bit.fut000066400000000000000000000062411475065116200230600ustar00rootroot00000000000000-- Test the set_bit and get_bit functions for signed integers. 
-- == -- entry: test_i8_get -- input { [8i8, 8i8, 24i8, 0b010101i8, 0b11111111i8] -- [3, 2, 3, 3, 7] } -- output { [1, 0, 1, 0, 1] } -- == -- entry: test_i8_set0 -- input { [8i8, 8i8, 24i8, 0b010101i8, 0b11111111i8] -- [3, 2, 3, 3, 7] } -- output { [0i8, 8i8, 16i8, 0b010101i8, 0b01111111i8] } -- == -- entry: test_i8_set1 -- input { [8i8, 8i8, 24i8, 0b010101i8, 0b11111111i8] -- [3, 2, 3, 3, 7] } -- output { [8i8, 12i8, 24i8, 0b011101i8, 0b11111111i8] } entry test_i8_get = map2 (\a bit -> i8.get_bit bit a) entry test_i8_set0 = map2 (\a bit -> i8.set_bit bit a 0) entry test_i8_set1 = map2 (\a bit -> i8.set_bit bit a 1) -- == -- entry: test_i16_get -- input { [8i16, 8i16, 24i16, 0b0011001001010101i16, 0b1011011010010010i16] -- [3, 2, 3, 11, 13] } -- output { [1, 0, 1, 0, 1] } -- == -- entry: test_i16_set0 -- input { [8i16, 8i16, 24i16, 0b0011001001010101i16, 0b1011011010010010i16] -- [3, 2, 3, 11, 13] } -- output { [0i16, 8i16, 16i16, 0b0011001001010101i16, 0b1001011010010010i16] } -- == -- entry: test_i16_set1 -- input { [8i16, 8i16, 24i16, 0b0011001001010101i16, 0b1011011010010010i16] -- [3, 2, 3, 11, 13] } -- output { [8i16, 12i16, 24i16, 0b0011101001010101i16, 0b1011011010010010i16] } entry test_i16_get = map2 (\a bit -> i16.get_bit bit a) entry test_i16_set0 = map2 (\a bit -> i16.set_bit bit a 0) entry test_i16_set1 = map2 (\a bit -> i16.set_bit bit a 1) -- == -- entry: test_i32_get -- input { [8i32, 8i32, 24i32, 214783648i32, 214783648i32] -- [3, 2, 3, 5, 11] } -- output { [1, 0, 1, 1, 0] } -- == -- entry: test_i32_set0 -- input { [8i32, 8i32, 24i32, 214783648i32, 214783648i32] -- [3, 2, 3, 5, 11] } -- output { [0i32, 8i32, 16i32, 214783616i32, 214783648i32] } -- == -- entry: test_i32_set1 -- input { [8i32, 8i32, 24i32, 214783648i32, 214783648i32] -- [3, 2, 3, 5, 11] } -- output { [8i32, 12i32, 24i32, 214783648i32, 214785696i32] } entry test_i32_get = map2 (\a bit -> i32.get_bit bit a) entry test_i32_set0 = map2 (\a bit -> i32.set_bit bit a 0) entry test_i32_set1 
= map2 (\a bit -> i32.set_bit bit a 1) -- == -- entry: test_i64_get -- input { [8i64, 8i64, 24i64, 4294967295i64, 4294967295i64] -- [3, 2, 3, 31, 30] } -- output { [1, 0, 1, 1, 1] } -- == -- entry: test_i64_set0 -- input { [8i64, 8i64, 24i64, 4294967295i64, 4294967295i64] -- [3, 2, 3, 31, 30] } -- output { [0i64, 8i64, 16i64, 2147483647i64, 3221225471i64] } -- == -- entry: test_i64_set1 -- input { [8i64, 8i64, 24i64, 4294967295i64, 4294967295i64] -- [3, 2, 3, 31, 30] } -- output { [8i64, 12i64, 24i64, 4294967295i64, 4294967295i64] } entry test_i64_get = map2 (\a bit -> i64.get_bit bit a) entry test_i64_set0 = map2 (\a bit -> i64.set_bit bit a 0) entry test_i64_set1 = map2 (\a bit -> i64.set_bit bit a 1) futhark-0.25.27/tests/primitive/sin32.fut000066400000000000000000000002761475065116200201770ustar00rootroot00000000000000-- Does the sin32 function work? -- == -- input { [0f32, -1f32, 3.1415927f32, -3.1415927f32] } -- output { [0f32, -0.84147096f32, -8.742278e-8f32, 8.742278e-8f32] } def main = map f32.sin futhark-0.25.27/tests/primitive/sin64.fut000066400000000000000000000002761475065116200202040ustar00rootroot00000000000000-- Does the sin64 function work? -- == -- input { [0f64, -1f64, 3.1415927f64, -3.1415927f64] } -- output { [0f64, -0.84147096f64, -8.742278e-8f64, 8.742278e-8f64] } def main = map f64.sin futhark-0.25.27/tests/primitive/sinh16.fut000066400000000000000000000003301475065116200203400ustar00rootroot00000000000000-- Does the f16.sinh function work? -- == -- input { [0f16, -1f16, 3.1415927f16, -3.1415927f16] } -- output { [0f16, -1.1752011936438014f16, 11.548739357257748f16, -11.548739357257748f16] } def main = map f16.sinh futhark-0.25.27/tests/primitive/sinh32.fut000066400000000000000000000003301475065116200203360ustar00rootroot00000000000000-- Does the f32.sinh function work? 
-- == -- input { [0f32, -1f32, 3.1415927f32, -3.1415927f32] } -- output { [0f32, -1.1752011936438014f32, 11.548739357257748f32, -11.548739357257748f32] } def main = map f32.sinh futhark-0.25.27/tests/primitive/sinh64.fut000066400000000000000000000003301475065116200203430ustar00rootroot00000000000000-- Does the f64.sinh function work? -- == -- input { [0f64, -1f64, 3.1415927f64, -3.1415927f64] } -- output { [0f64, -1.1752011936438014f64, 11.548739357257748f64, -11.548739357257748f64] } def main = map f64.sinh futhark-0.25.27/tests/primitive/tan32.fut000066400000000000000000000002241475065116200201610ustar00rootroot00000000000000-- Does the tan32 function work? -- == -- input { [0f32, 0.78539819f32, -0.78539819f32] } -- output { [0f32, 1f32, -1f32] } def main = map f32.tan futhark-0.25.27/tests/primitive/tan64.fut000066400000000000000000000002241475065116200201660ustar00rootroot00000000000000-- Does the tan64 function work? -- == -- input { [0f64, 0.78539819f64, -0.78539819f64] } -- output { [0f64, 1f64, -1f64] } def main = map f64.tan futhark-0.25.27/tests/primitive/tanh16.fut000066400000000000000000000002731475065116200203370ustar00rootroot00000000000000-- Does the f16.tanh function work? -- == -- input { [0f16, 0.78539819f16, -0.78539819f16] } -- output { [0f16, 0.6557942177943699f16, -0.6557942177943699f16] } def main = map f16.tanh futhark-0.25.27/tests/primitive/tanh32.fut000066400000000000000000000002731475065116200203350ustar00rootroot00000000000000-- Does the f32.tanh function work? -- == -- input { [0f32, 0.78539819f32, -0.78539819f32] } -- output { [0f32, 0.6557942177943699f32, -0.6557942177943699f32] } def main = map f32.tanh futhark-0.25.27/tests/primitive/tanh64.fut000066400000000000000000000002731475065116200203420ustar00rootroot00000000000000-- Does the f64.tanh function work? 
-- == -- input { [0f64, 0.78539819f64, -0.78539819f64] } -- output { [0f64, 0.6557942177943699f64, -0.6557942177943699f64] } def main = map f64.tanh futhark-0.25.27/tests/primitive/u16_binop.fut000066400000000000000000000014331475065116200210370ustar00rootroot00000000000000-- u16 test. entry add = map2 (u16.+) entry sub = map2 (u16.-) entry mul = map2 (u16.*) entry pow = map2 (u16.**) -- == -- entry: add -- input { [0u16, 2u16, 32767u16, 65535u16] -- [0u16, 2u16, 32767u16, 1u16] } -- output { [0u16, 4u16, 65534u16, 0u16] } -- == -- entry: sub -- input { [2u16, 0u16, 32767u16] -- [2u16, 127u16, 65534u16] } -- output { [0u16, 65409u16, 32769u16] } -- == -- entry: mul -- input { [2u16, 2u16, 256u16, 257u16] -- [3u16, 0u16, 256u16, 256u16] } -- output { [6u16, 0u16, 0u16, 256u16] } -- == -- entry: pow -- input { [2u16, 7u16, 7u16, 7u16, 7u16, 7u16, 7u16, 7u16, 7u16] -- [3u16, 0u16, 1u16, 2u16, 3u16, 4u16, 5u16, 6u16, 7u16] } -- output { [8u16, 1u16, 7u16, 49u16, 343u16, 2401u16, 16807u16, 52113u16, 37111u16] } futhark-0.25.27/tests/primitive/u16_cmpop.fut000066400000000000000000000011271475065116200210460ustar00rootroot00000000000000-- Test comparison of u16 values. -- == -- entry: lt -- input { [0u16, 1u16, 65535u16, 1u16] -- [0u16, 2u16, 1u16, 65535u16] } -- output { [false, true, false, true] } -- == -- entry: eq -- input { [0u16, 1u16, 65535u16, 1u16] -- [0u16, 2u16, 1u16, 65535u16] } -- output { [true, false, false, false] } -- == -- entry: lte -- input { [0u16, 1u16, 65535u16, 1u16] -- [0u16, 2u16, 1u16, 65535u16] } -- output { [true, true, false, true] } entry lt (x:[]u16) (y:[]u16)= map2 (<) x y entry eq (x:[]u16) (y:[]u16)= map2 (==) x y entry lte (x:[]u16) (y:[]u16)= map2 (<=) x y futhark-0.25.27/tests/primitive/u16_convop.fut000066400000000000000000000025721475065116200212410ustar00rootroot00000000000000-- Convert back and forth between different integer types. 
-- == -- entry: u16tobool -- input { [0u16, 64u16, 32767u16, 65535u16] } -- output { [false, true, true, true] } -- == -- entry: u16toi8 -- input { [0u16, 64u16, 32767u16, 65535u16] } -- output { [0i8, 64i8, 32767i8, -1i8] } -- == -- entry: u16toi16 -- input { [0u16, 64u16, 32767u16, 65535u16] } -- output { [0i16, 64i16, 32767i16, -1i16] } -- == -- entry: u16toi32 -- input { [0u16, 64u16, 32767u16, 65535u16] } -- output { [0i32, 64i32, 32767i32, 65535i32] } -- == -- entry: u16toi64 -- input { [0u16, 64u16, 32767u16, 65535u16] } -- output { [0i64, 64i64, 32767i64, 65535i64] } -- == -- entry: u16tou8 -- input { [0u16, 64u16, 32767u16, 65535u16] } -- output { [0u8, 64u8, 32767u8, 255u8] } -- == -- entry: u16tou16 -- input { [0u16, 64u16, 32767u16, 65535u16] } -- output { [0u16, 64u16, 32767u16, 65535u16] } -- == -- entry: u16tou32 -- input { [0u16, 64u16, 32767u16, 65535u16] } -- output { [0u32, 64u32, 32767u32, 65535u32] } -- == -- entry: u16tou64 -- input { [0u16, 64u16, 32767u16, 65535u16] } -- output { [0u64, 64u64, 32767u64, 65535u64] } entry u16tobool = map (bool.u16) entry u16toi8 = map (i8.u16) entry u16toi16 = map (i16.u16) entry u16toi32 = map (i32.u16) entry u16toi64 = map (i64.u16) entry u16tou8 = map (u8.u16) entry u16tou16 = map (u16.u16) entry u16tou32 = map (u32.u16) entry u16tou64 = map (u64.u16) futhark-0.25.27/tests/primitive/u16_division.fut000066400000000000000000000011001475065116200215430ustar00rootroot00000000000000-- Test of division-like operators for u16 values. 
-- == -- entry: divide -- input { [7u16, 32768u16] -- [3u16, 9u16] } -- output { [2u16, 3640u16] } -- == -- entry: mod -- input { [7u16, 32768u16] -- [3u16, 9u16] } -- output { [1u16, 8u16] } -- == -- entry: quot -- input { [7u16, 32768u16] -- [3u16, 9u16] } -- output { [2u16, 3640u16] } -- == -- entry: rem -- input { [7u16, 32768u16] -- [3u16, 9u16] } -- output { [1u16, 8u16] } entry divide = map2 (u16./) entry mod = map2 (u16.%) entry quot = map2 (u16.//) entry rem = map2 (u16.%%) futhark-0.25.27/tests/primitive/u16_minmax.fut000066400000000000000000000005451475065116200212240ustar00rootroot00000000000000-- == -- entry: testMax -- input { [0u16, 1u16, 65535u16, 1u16] -- [1u16, 1u16, 1u16, 65535u16]} -- output { [1u16, 1u16, 65535u16, 65535u16] } -- == -- entry: testMin -- input { [0u16, 1u16, 65535u16, 1u16] -- [1u16, 1u16, 1u16, 65535u16]} -- output { [0u16, 1u16, 1u16, 1u16] } entry testMax = map2 u16.max entry testMin = map2 u16.min futhark-0.25.27/tests/primitive/u16_unop.fut000066400000000000000000000010011475065116200207000ustar00rootroot00000000000000-- Test unary operators for u16. -- == -- entry: negateu16 -- input { [0u16, 1u16, 65535u16, 8u16, 65528u16] } -- output { [0u16, 65535u16, 1u16, 65528u16, 8u16] } -- == -- entry: absu16 -- input { [0u16, 1u16, 65535u16, 8u16, 65528u16] } -- output { [0u16, 1u16, 65535u16, 8u16, 65528u16] } -- == -- entry: sgnu16 -- input { [0u16, 1u16, 65535u16, 8u16, 65528u16] } -- output { [0u16, 1u16, 1u16, 1u16, 1u16] } entry negateu16 = map (\x : u16 -> -x) entry absu16 = map (u16.abs) entry sgnu16 = map (u16.sgn) futhark-0.25.27/tests/primitive/u32_binop.fut000066400000000000000000000015611475065116200210370ustar00rootroot00000000000000-- u32 test. 
entry add = map2 (u32.+) entry sub = map2 (u32.-) entry mul = map2 (u32.*) entry pow = map2 (u32.**) -- == -- entry: add -- input { [0u32, 2u32, 2147483647u32, 4294967295u32] -- [0u32, 2u32, 2147483647u32, 1u32] } -- output { [0u32, 4u32, 4294967294u32, 0u32] } -- == -- entry: sub -- input { [2u32, 0u32, 2147483647u32] -- [2u32, 127u32, 4294967295u32] } -- output { [0u32, 4294967169u32, 2147483648u32] } -- == -- entry: mul -- input { [2u32, 2u32, 262144u32, 262145u32] -- [3u32, 0u32, 262144u32, 262144u32] } -- output { [6u32, 0u32, 0u32, 262144u32] } -- == -- entry: pow -- input { [2u32, 47u32, 47u32, 47u32, 47u32, 47u32, 47u32, 47u32, 47u32] -- [3u32, 0u32, 1u32, 2u32, 3u32, 4u32, 5u32, 6u32, 7u32] } -- output { [8u32, 1u32, 47u32, 2209u32, 103823u32, 4879681u32, 229345007u32, 2189280737u32, 4111946831u32] } futhark-0.25.27/tests/primitive/u32_cmpop.fut000066400000000000000000000011651475065116200210460ustar00rootroot00000000000000-- Test comparison of u32 values. -- == -- entry: lt -- input { [0u32, 1u32, 4294967295u32, 1u32] -- [0u32, 2u32, 1u32, 4294967295u32] } -- output { [false, true, false, true] } -- == -- entry: eq -- input { [0u32, 1u32, 4294967295u32, 1u32] -- [0u32, 2u32, 1u32, 4294967295u32] } -- output { [true, false, false, false] } -- == -- entry: lte -- input { [0u32, 1u32, 4294967295u32, 1u32] -- [0u32, 2u32, 1u32, 4294967295u32] } -- output { [true, true, false, true] } entry lt (x:[]u32) (y:[]u32)= map2 (<) x y entry eq (x:[]u32) (y:[]u32)= map2 (==) x y entry lte (x:[]u32) (y:[]u32)= map2 (<=) x y futhark-0.25.27/tests/primitive/u32_convop.fut000066400000000000000000000027551475065116200212420ustar00rootroot00000000000000-- Convert back and forth between different integer types. 
-- == -- entry: u32tobool -- input { [0u32, 64u32, 2147483647u32, 4294967295u32] } -- output { [false, true, true, true] } -- == -- entry: u32toi8 -- input { [0u32, 64u32, 2147483647u32, 4294967295u32] } -- output { [0i8, 64i8, -1i8, -1i8] } -- == -- entry: u32toi16 -- input { [0u32, 64u32, 2147483647u32, 4294967295u32] } -- output { [0i16, 64i16, -1i16, -1i16] } -- == -- entry: u32toi32 -- input { [0u32, 64u32, 2147483647u32, 4294967295u32] } -- output { [0i32, 64i32, 2147483647i32, -1i32] } -- == -- entry: u32toi64 -- input { [0u32, 64u32, 2147483647u32, 4294967295u32] } -- output { [0i64, 64i64, 2147483647i64, 4294967295i64] } -- == -- entry: u32tou8 -- input { [0u32, 64u32, 2147483647u32, 4294967295u32] } -- output { [0u8, 64u8, 255u8, 255u8] } -- == -- entry: u32tou16 -- input { [0u32, 64u32, 2147483647u32, 4294967295u32] } -- output { [0u16, 64u16, 65535u16, 65535u16] } -- == -- entry: u32tou32 -- input { [0u32, 64u32, 2147483647u32, 4294967295u32] } -- output { [0u32, 64u32, 2147483647u32, 4294967295u32] } -- == -- entry: u32tou64 -- input { [0u32, 64u32, 2147483647u32, 4294967295u32] } -- output { [0u64, 64u64, 2147483647u64, 4294967295u64] } entry u32tobool = map (bool.u32) entry u32toi8 = map (i8.u32) entry u32toi16 = map (i16.u32) entry u32toi32 = map (i32.u32) entry u32toi64 = map (i64.u32) entry u32tou8 = map (u8.u32) entry u32tou16 = map (u16.u32) entry u32tou32 = map (u32.u32) entry u32tou64 = map (u64.u32) futhark-0.25.27/tests/primitive/u32_division.fut000066400000000000000000000011371475065116200215530ustar00rootroot00000000000000-- Test of division-like operators for u32 values. 
-- == -- entry: divide -- input { [7u32, 2147483648u32] -- [3u32, 9u32] } -- output { [2u32, 238609294u32] } -- == -- entry: mod -- input { [7u32, 2147483648u32] -- [3u32, 9u32] } -- output { [1u32, 2u32] } -- == -- entry: quot -- input { [7u32, 2147483648u32] -- [3u32, 9u32] } -- output { [2u32, 238609294u32] } -- == -- entry: rem -- input { [7u32, 2147483648u32] -- [3u32, 9u32] } -- output { [1u32, 2u32] } entry divide = map2 (u32./) entry mod = map2 (u32.%) entry quot = map2 (u32.//) entry rem = map2 (u32.%%) futhark-0.25.27/tests/primitive/u32_minmax.fut000066400000000000000000000006041475065116200212160ustar00rootroot00000000000000-- == -- entry: testMax -- input { [0u32, 1u32, 4294967295u32, 1u32] -- [1u32, 1u32, 1u32, 4294967295u32]} -- output { [1u32, 1u32, 4294967295u32, 4294967295u32] } -- == -- entry: testMin -- input { [0u32, 1u32, 4294967295u32, 1u32] -- [1u32, 1u32, 1u32, 4294967295u32]} -- output { [0u32, 1u32, 1u32, 1u32] } entry testMax = map2 u32.max entry testMin = map2 u32.min futhark-0.25.27/tests/primitive/u32_unop.fut000066400000000000000000000010641475065116200207070ustar00rootroot00000000000000-- Test unary operators for u32. -- == -- entry: negateu32 -- input { [0u32, 1u32, 4294967295u32, 8u32, 4294967288u32] } -- output { [0u32, 4294967295u32, 1u32, 4294967288u32, 8u32] } -- == -- entry: absu32 -- input { [0u32, 1u32, 4294967295u32, 8u32, 4294967288u32] } -- output { [0u32, 1u32, 4294967295u32, 8u32, 4294967288u32] } -- == -- entry: sgnu32 -- input { [0u32, 1u32, 4294967295u32, 8u32, 4294967288u32] } -- output { [0u32, 1u32, 1u32, 1u32, 1u32] } entry negateu32 = map (\x : u32 -> -x) entry absu32 = map (u32.abs) entry sgnu32 = map (u32.sgn) futhark-0.25.27/tests/primitive/u64_binop.fut000066400000000000000000000020031475065116200210340ustar00rootroot00000000000000-- u64 test. 
entry add = map2 (u64.+) entry sub = map2 (u64.-) entry mul = map2 (u64.*) entry pow = map2 (u64.**) -- == -- entry: add -- input { [0u64, 2u64, 9223372036854775807u64, 18446744073709551615u64] -- [0u64, 2u64, 9223372036854775807u64, 1u64] } -- output { [0u64, 4u64, 18446744073709551614u64, 0u64] } -- == -- entry: sub -- input { [2u64, 0u64, 9223372036854775808u64] -- [2u64, 127u64, 18446744073709551615u64] } -- output { [0u64, 18446744073709551489u64, 9223372036854775809u64] } -- == -- entry: mul -- input { [2u64, 2u64, 6442450941u64] -- [3u64, 0u64, 2147483647u64] } -- output { [6u64, 0u64, 13835058042397261827u64] } -- == -- entry: pow -- input { [2u64, 4021u64, 4021u64, 4021u64, 4021u64, 4021u64, 4021u64, 4021u64, 4021u64] -- [3u64, 0u64, 1u64, 2u64, 3u64, 4u64, 5u64, 6u64, 7u64] } -- output { [8u64, 1u64, 4021u64, 16168441u64, 65013301261u64, 261418484370481u64, -- 1051163725653704101u64, 2424947974056870057u64, 10834932764031245949u64] } futhark-0.25.27/tests/primitive/u64_cmpop.fut000066400000000000000000000012611475065116200210500ustar00rootroot00000000000000-- Test comparison of u64 values. -- == -- entry: lt -- input { [0u64, 1u64, 18446744073709551615u64, 1u64] -- [0u64, 2u64, 1u64, 18446744073709551615u64] } -- output { [false, true, false, true] } -- == -- entry: eq -- input { [0u64, 1u64, 18446744073709551615u64, 1u64] -- [0u64, 2u64, 1u64, 18446744073709551615u64] } -- output { [true, false, false, false] } -- == -- entry: lte -- input { [0u64, 1u64, 18446744073709551615u64, 1u64] -- [0u64, 2u64, 1u64, 18446744073709551615u64] } -- output { [true, true, false, true] } entry lt (x:[]u64) (y:[]u64)= map2 (<) x y entry eq (x:[]u64) (y:[]u64)= map2 (==) x y entry lte (x:[]u64) (y:[]u64)= map2 (<=) x y futhark-0.25.27/tests/primitive/u64_convop.fut000066400000000000000000000032451475065116200212420ustar00rootroot00000000000000-- Convert back and forth between different integer types. 
-- == -- entry: u64tobool -- input { [0u64, 64u64, 9223372036854775807u64, 18446744073709551615u64] } -- output { [false, true, true, true] } -- == -- entry: u64toi8 -- input { [0u64, 64u64, 9223372036854775807u64, 18446744073709551615u64] } -- output { [0i8, 64i8, -1i8, -1i8] } -- == -- entry: u64toi16 -- input { [0u64, 64u64, 9223372036854775807u64, 18446744073709551615u64] } -- output { [0i16, 64i16, -1i16, -1i16] } -- == -- entry: u64toi32 -- input { [0u64, 64u64, 9223372036854775807u64, 18446744073709551615u64] } -- output { [0i32, 64i32, -1i32, -1i32] } -- == -- entry: u64toi64 -- input { [0u64, 64u64, 9223372036854775807u64, 18446744073709551615u64] } -- output { [0i64, 64i64, 9223372036854775807i64, -1i64] } -- == -- entry: u64tou8 -- input { [0u64, 64u64, 9223372036854775807u64, 18446744073709551615u64] } -- output { [0u8, 64u8, 255u8, 255u8] } -- == -- entry: u64tou16 -- input { [0u64, 64u64, 9223372036854775807u64, 18446744073709551615u64] } -- output { [0u16, 64u16, 65535u16, 65535u16] } -- == -- entry: u64tou32 -- input { [0u64, 64u64, 9223372036854775807u64, 18446744073709551615u64] } -- output { [0u32, 64u32, 4294967295u32, 4294967295u32] } -- == -- entry: u64tou64 -- input { [0u64, 64u64, 9223372036854775807u64, 18446744073709551615u64] } -- output { [0u64, 64u64, 9223372036854775807u64, 18446744073709551615u64] } entry u64tobool = map (bool.u64) entry u64toi8 = map (i8.u64) entry u64toi16 = map (i16.u64) entry u64toi32 = map (i32.u64) entry u64toi64 = map (i64.u64) entry u64tou8 = map (u8.u64) entry u64tou16 = map (u16.u64) entry u64tou32 = map (u32.u64) entry u64tou64 = map (u64.u64) futhark-0.25.27/tests/primitive/u64_division.fut000066400000000000000000000012301475065116200215520ustar00rootroot00000000000000-- Test of division-like operators for u64 values. 
-- == -- entry: divide -- input { [7u64, 9223372036854775808u64] -- [3u64, 9u64] } -- output { [2u64, 1024819115206086200u64] } -- == -- entry: mod -- input { [7u64, 9223372036854775808u64] -- [3u64, 9u64] } -- output { [1u64, 8u64] } -- == -- entry: quot -- input { [7u64, 9223372036854775808u64] -- [3u64, 9u64] } -- output { [2u64, 1024819115206086200u64] } -- == -- entry: rem -- input { [7u64, 9223372036854775808u64] -- [3u64, 9u64] } -- output { [1u64, 8u64] } entry divide = map2 (u64./) entry mod = map2 (u64.%) entry quot = map2 (u64.//) entry rem = map2 (u64.%%) futhark-0.25.27/tests/primitive/u64_minmax.fut000066400000000000000000000007001475065116200212200ustar00rootroot00000000000000-- == -- entry: testMax -- input { [0u64, 1u64, 18446744073709551615u64, 1u64] -- [1u64, 1u64, 1u64, 18446744073709551615u64]} -- output { [1u64, 1u64, 18446744073709551615u64, 18446744073709551615u64] } -- == -- entry: testMin -- input { [0u64, 1u64, 18446744073709551615u64, 1u64] -- [1u64, 1u64, 1u64, 18446744073709551615u64]} -- output { [0u64, 1u64, 1u64, 1u64] } entry testMax = map2 u64.max entry testMin = map2 u64.min futhark-0.25.27/tests/primitive/u64_unop.fut000066400000000000000000000012311475065116200207100ustar00rootroot00000000000000-- Test unary operators for u64. 
-- == -- entry: negateu64 -- input { [0u64, 1u64, 18446744073709551615u64, 8u64, 18446744073709551608u64] } -- output { [0u64, 18446744073709551615u64, 1u64, 18446744073709551608u64, 8u64] } -- == -- entry: absu64 -- input { [0u64, 1u64, 18446744073709551615u64, 8u64, 18446744073709551608u64] } -- output { [0u64, 1u64, 18446744073709551615u64, 8u64, 18446744073709551608u64] } -- == -- entry: sgnu64 -- input { [0u64, 1u64, 18446744073709551615u64, 8u64, 18446744073709551608u64] } -- output { [0u64, 1u64, 1u64, 1u64, 1u64] } entry negateu64 = map (\x : u64 -> -x) entry absu64 = map (u64.abs) entry sgnu64 = map (u64.sgn) futhark-0.25.27/tests/primitive/u8_binop.fut000066400000000000000000000012711475065116200207600ustar00rootroot00000000000000-- u8 test. entry add = map2 (u8.+) entry sub = map2 (u8.-) entry mul = map2 (u8.*) entry pow = map2 (u8.**) -- == -- entry: add -- input { [0u8, 2u8, 127u8, 255u8] -- [0u8, 2u8, 127u8, 1u8] } -- output { [0u8, 4u8, 254u8, 0u8] } -- == -- entry: sub -- input { [2u8, 0u8, 127u8] -- [2u8, 127u8, 254u8] } -- output { [0u8, 129u8, 129u8] } -- == -- entry: mul -- input { [2u8, 2u8, 4u8, 5u8] -- [3u8, 0u8, 64u8, 64u8] } -- output { [6u8, 0u8, 0u8, 64u8] } -- == -- entry: pow -- input { [2u8, 2u8, 3u8, 3u8, 3u8, 3u8, 3u8, 3u8, 3u8] -- [3u8, 0u8, 1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8] } -- output { [8u8, 1u8, 3u8, 9u8, 27u8, 81u8, 243u8, 217u8, 139u8] } futhark-0.25.27/tests/primitive/u8_cmpop.fut000066400000000000000000000010541475065116200207660ustar00rootroot00000000000000-- Test comparison of u8 values. 
-- == -- entry: lt -- input { [0u8, 1u8, 255u8, 1u8] -- [0u8, 2u8, 1u8, 255u8] } -- output { [false, true, false, true] } -- == -- entry: eq -- input { [0u8, 1u8, 255u8, 1u8] -- [0u8, 2u8, 1u8, 255u8] } -- output { [true, false, false, false] } -- == -- entry: lte -- input { [0u8, 1u8, 255u8, 1u8] -- [0u8, 2u8, 1u8, 255u8] } -- output { [true, true, false, true] } entry lt (x:[]u8) (y:[]u8)= map2 (<) x y entry eq (x:[]u8) (y:[]u8)= map2 (==) x y entry lte (x:[]u8) (y:[]u8)= map2 (<=) x y futhark-0.25.27/tests/primitive/u8_convop.fut000066400000000000000000000023751475065116200211630ustar00rootroot00000000000000-- Convert back and forth between different integer types. -- == -- entry: u8tobool -- input { [0u8, 64u8, 127u8, 255u8] } -- output { [false, true, true, true] } -- == -- entry: u8toi8 -- input { [0u8, 64u8, 127u8, 255u8] } -- output { [0i8, 64i8, 127i8, -1i8] } -- == -- entry: u8toi16 -- input { [0u8, 64u8, 127u8, 255u8] } -- output { [0i16, 64i16, 127i16, 255i16] } -- == -- entry: u8toi32 -- input { [0u8, 64u8, 127u8, 255u8] } -- output { [0i32, 64i32, 127i32, 255i32] } -- == -- entry: u8toi64 -- input { [0u8, 64u8, 127u8, 255u8] } -- output { [0i64, 64i64, 127i64, 255i64] } -- == -- entry: u8tou8 -- input { [0u8, 64u8, 127u8, 255u8] } -- output { [0u8, 64u8, 127u8, 255u8] } -- == -- entry: u8tou16 -- input { [0u8, 64u8, 127u8, 255u8] } -- output { [0u16, 64u16, 127u16, 255u16] } -- == -- entry: u8tou32 -- input { [0u8, 64u8, 127u8, 255u8] } -- output { [0u32, 64u32, 127u32, 255u32] } -- == -- entry: u8tou64 -- input { [0u8, 64u8, 127u8, 255u8] } -- output { [0u64, 64u64, 127u64, 255u64] } entry u8tobool = map (bool.u8) entry u8toi8 = map (i8.u8) entry u8toi16 = map (i16.u8) entry u8toi32 = map (i32.u8) entry u8toi64 = map (i64.u8) entry u8tou8 = map (u8.u8) entry u8tou16 = map (u16.u8) entry u8tou32 = map (u32.u8) entry u8tou64 = map (u64.u8) 
futhark-0.25.27/tests/primitive/u8_division.fut000066400000000000000000000010271475065116200214740ustar00rootroot00000000000000-- Test of division-like operators for u8 values. -- == -- entry: divide -- input { [7u8, 128u8] -- [3u8, 9u8] } -- output { [2u8, 14u8] } -- == -- entry: mod -- input { [7u8, 128u8] -- [3u8, 9u8] } -- output { [1u8, 2u8] } -- == -- entry: quot -- input { [7u8, 128u8] -- [3u8, 9u8] } -- output { [2u8, 14u8] } -- == -- entry: rem -- input { [7u8, 128u8] -- [3u8, 9u8] } -- output { [1u8, 2u8] } entry divide = map2 (u8./) entry mod = map2 (u8.%) entry quot = map2 (u8.//) entry rem = map2 (u8.%%) futhark-0.25.27/tests/primitive/u8_minmax.fut000066400000000000000000000004761475065116200211500ustar00rootroot00000000000000-- == -- entry: testMax -- input { [0u8, 1u8, 255u8, 1u8] -- [1u8, 1u8, 1u8, 255u8]} -- output { [1u8, 1u8, 255u8, 255u8] } -- == -- entry: testMin -- input { [0u8, 1u8, 255u8, 1u8] -- [1u8, 1u8, 1u8, 255u8]} -- output { [0u8, 1u8, 1u8, 1u8] } entry testMax = map2 u8.max entry testMin = map2 u8.min futhark-0.25.27/tests/primitive/u8_unop.fut000066400000000000000000000007051475065116200206330ustar00rootroot00000000000000-- Test unary operators for u8. -- == -- entry: negateu8 -- input { [0u8, 1u8, 255u8, 8u8, 248u8] } -- output { [0u8, 255u8, 1u8, 248u8, 8u8] } -- == -- entry: absu8 -- input { [0u8, 1u8, 255u8, 8u8, 248u8] } -- output { [0u8, 1u8, 255u8, 8u8, 248u8] } -- == -- entry: sgnu8 -- input { [0u8, 1u8, 255u8, 8u8, 248u8] } -- output { [0u8, 1u8, 1u8, 1u8, 1u8] } entry negateu8 = map (\x : u8 -> -x) entry absu8 = map (u8.abs) entry sgnu8 = map (u8.sgn) futhark-0.25.27/tests/primitive/unsigned_get_set_bit.fut000066400000000000000000000071131475065116200234220ustar00rootroot00000000000000-- Test the set_bit and get_bit functions for unsigned integers. 
-- == -- entry: test_u8_get -- input { [8u8, 8u8, 24u8, 0b010101u8, 0b11111111u8] -- [3, 2, 3, 3, 7] } -- output { [1, 0, 1, 0, 1] } -- == -- entry: test_u8_set0 -- input { [8u8, 8u8, 24u8, 0b010101u8, 0b11111111u8] -- [3, 2, 3, 3, 7] } -- output { [0u8, 8u8, 16u8, 0b010101u8, 0b01111111u8] } -- == -- entry: test_u8_set1 -- input { [8u8, 8u8, 24u8, 0b010101u8, 0b11111111u8] -- [3, 2, 3, 3, 7] } -- output { [8u8, 12u8, 24u8, 0b011101u8, 0b11111111u8] } entry test_u8_get = map2 (\a bit -> u8.get_bit bit a) entry test_u8_set0 = map2 (\a bit -> u8.set_bit bit a 0) entry test_u8_set1 = map2 (\a bit -> u8.set_bit bit a 1) -- == -- entry: test_u16_get -- input { [8u16, 8u16, 24u16, 0b0011001001010101u16, 0b1011011010010010u16] -- [3, 2, 3, 11, 13] } -- output { [1, 0, 1, 0, 1] } -- == -- entry: test_u16_set0 -- input { [8u16, 8u16, 24u16, 0b0011001001010101u16, 0b1011011010010010u16] -- [3, 2, 3, 11, 13] } -- output { [0u16, 8u16, 16u16, 0b0011001001010101u16, 0b1001011010010010u16] } -- == -- entry: test_u16_set1 -- input { [8u16, 8u16, 24u16, 0b0011001001010101u16, 0b1011011010010010u16] -- [3, 2, 3, 11, 13] } -- output { [8u16, 12u16, 24u16, 0b0011101001010101u16, 0b1011011010010010u16] } entry test_u16_get = map2 (\a bit -> u16.get_bit bit a) entry test_u16_set0 = map2 (\a bit -> u16.set_bit bit a 0) entry test_u16_set1 = map2 (\a bit -> u16.set_bit bit a 1) -- == -- entry: test_u32_get -- input { [8u32, 8u32, 24u32, 0b0011001001010101u32, 0b11111111u32] -- [3, 2, 3, 11, 7] } -- output { [1, 0, 1, 0, 1] } -- == -- entry: test_u32_set0 -- input { [8u32, 8u32, 24u32, 0b0011001001010101u32, 0b11111111u32] -- [3, 2, 3, 11, 7] } -- output { [0u32, 8u32, 16u32, 0b0011001001010101u32, 0b01111111u32] } -- == -- entry: test_u32_set1 -- input { [8u32, 8u32, 24u32, 0b0011001001010101u32, 0b11111111u32] -- [3, 2, 3, 11, 7] } -- output { [8u32, 12u32, 24u32, 0b0011101001010101u32, 0b11111111u32] } entry test_u32_get = map2 (\a bit -> u32.get_bit bit a) entry test_u32_set0 = map2 
(\a bit -> u32.set_bit bit a 0) entry test_u32_set1 = map2 (\a bit -> u32.set_bit bit a 1) -- == -- entry: test_u64_get -- input { [8u64, 8u64, 24u64, 0b0011001001010101u64, 0b11111111u64, 4294967295u64, 4294967295u64] -- [3, 2, 3, 11, 7, 31, 30] } -- output { [1, 0, 1, 0, 1, 1, 1] } -- == -- entry: test_u64_set0 -- input { [8u64, 8u64, 24u64, 0b0011001001010101u64, 0b11111111u64, 4294967295u64, 4294967295u64] -- [3, 2, 3, 11, 7, 31, 30] } -- output { [0u64, 8u64, 16u64, 0b0011001001010101u64, 0b01111111u64, 2147483647u64, 3221225471u64] } -- == -- entry: test_u64_set1 -- input { [8u64, 8u64, 24u64, 0b0011001001010101u64, 0b11111111u64, 4294967295u64, 4294967295u64] -- [3, 2, 3, 11, 7, 31, 30] } -- output { [8u64, 12u64, 24u64, 0b0011101001010101u64, 0b11111111u64, 4294967295u64, 4294967295u64] } entry test_u64_get = map2 (\a bit -> u64.get_bit bit a) entry test_u64_set0 = map2 (\a bit -> u64.set_bit bit a 0) entry test_u64_set1 = map2 (\a bit -> u64.set_bit bit a 1)futhark-0.25.27/tests/proj-error0.fut000066400000000000000000000001261475065116200174040ustar00rootroot00000000000000-- -- == -- error: field def main(x: (i32,i8,i16)): (i8,i16,i32) = (x.0, x.2, x.0) futhark-0.25.27/tests/proj-error1.fut000066400000000000000000000001011475065116200173760ustar00rootroot00000000000000-- Ambiguous projection. -- == -- error: ambiguous def f = (.a) futhark-0.25.27/tests/proj0.fut000066400000000000000000000003221475065116200162530ustar00rootroot00000000000000-- Does simple tuple projection work? -- -- == -- compiled input { 1i32 2i8 3i16 } -- output { 2i8 3i16 1i32 } def main (x0: i32) (x1: i8) (x2: i16): (i8,i16,i32) = let x = (x0, x1, x2) in (x.1, x.2, x.0) futhark-0.25.27/tests/proj1.fut000066400000000000000000000002271475065116200162600ustar00rootroot00000000000000-- Can we map a tuple projection? 
-- == -- input { [1,2] [3,4] } -- output { [1,2] } def main (xs: []i32) (ys: []i32): []i32 = map (.0) (zip xs ys) futhark-0.25.27/tests/proj2.fut000066400000000000000000000002521475065116200162570ustar00rootroot00000000000000-- Can we map a record projection? -- == -- input { [1,2] [3,4] } -- output { [1,2] } def main (xs: []i32) (ys: []i32): []i32 = map (.x) (map2 (\x y -> {x, y}) xs ys) futhark-0.25.27/tests/proj3.fut000066400000000000000000000002711475065116200162610ustar00rootroot00000000000000-- Can we map a deeper record projection? -- == -- input { [1,2] [3,4] } -- output { [1,2] } def main (xs: []i32) (ys: []i32): []i32 = map (.x.a) (map2 (\x y -> {x={a=x}, y}) xs ys) futhark-0.25.27/tests/proj4.fut000066400000000000000000000002661475065116200162660ustar00rootroot00000000000000-- Can we map a deeper tuple projection? -- == -- input { [1,2] [3,4] } -- output { [1,2] } def main (xs: []i32) (ys: []i32): []i32 = map (.0.1) (map2 (\x y -> ((x,x), y)) xs ys) futhark-0.25.27/tests/psums.fut000066400000000000000000000005151475065116200163740ustar00rootroot00000000000000-- == -- input { [[ 0, 1, 2, 3], [ 4, 5, 6, 7]] } -- output { [[0i32, 1i32, 5i32, 15i32], [4i32, 17i32, 45i32, 95i32]] } -- random input { [100][10]i32 } -- auto output -- structure gpu-mem { Alloc 3 } def psum = scan (+) 0 def main (xss: [][]i32) = #[incremental_flattening(only_intra)] map (psum >-> psum >-> psum) xss futhark-0.25.27/tests/quickmedian.fut000066400000000000000000000020131475065116200175120ustar00rootroot00000000000000-- Computing a median in Futhark using a parallel adaptation of -- Hoare's quickmedian algorithm. The pivot selection is naive, which -- may lead to O(n**2) behaviour. -- -- The implementation is not particularly fast; mostly due to the -- final iterations that operate on very small arrays, at which point -- the overhead of parallel execution becomes dominant. 
An -- improvement would be to switch to a sorting-based approach once the -- input size drops beneath some threshold. -- -- Oh, and it cannot handle empty inputs. -- -- == -- tags { no_csharp } -- input { [1] } -- output { 1 } -- input { [4, -8, 2, 2, 0, 0, 5, 9, -6, 2] } -- output { 0 } def quickmedian [n] (xs: [n]i32): i32 = let (_, ys) = loop (i, ys : []i32) = (0, xs) while length ys > 1 do let pivot = ys[length ys/2] let (lt, gte) = partition ( n/2 then (i, lt) else (i + length lt, gte) in ys[0] def main [n] (xs: [n]i32): i32 = quickmedian xs futhark-0.25.27/tests/quickselect.fut000066400000000000000000000014621475065116200175430ustar00rootroot00000000000000-- A port of the quick-select NESL implementation. As quick-sort, this -- algorithm uses a divide-and-conquerer approach, but it needs only -- recurse on one of the partitions as it will know in which one the -- looked-for value resides. -- -- Oh, and it cannot handle non-meaningful inputs. -- -- == -- tags { no_csharp } -- input { [1] 0i64 } output { 1 } -- input { [4, -8, 2, 2, 0, 0, 5, 9, -6, 2] 7i64 } output { 4 } def quickselect [n] (s: [n]i32) (k:i64): i32 = let (_, s) = loop (k, s) while length s > 1 do let pivot = s[length s/2] let (lt, gt, _) = partition2 (pivot) s in if k < length lt then (k, lt) else if k >= length s - length gt then (k - (length s - length gt), gt) else (0,[pivot]) in s[0] def main (s:[]i32) (k:i64) : i32 = quickselect s k futhark-0.25.27/tests/rand0.fut000066400000000000000000000021221475065116200162250ustar00rootroot00000000000000-- This test program demonstrates how to simply generate pseudorandom -- numbers in Futhark. This is useful for deterministically -- generating large amounts of test data for a program without -- actually passing in large arrays from outside. Note that the -- quality of the random numbers is very poor, but it's fast to -- execute and the code is simple. 
-- -- == -- input { 1i64 -50 50 } -- output { [26] } -- -- input { 10i64 -50 50 } -- output { [10, 38, 31, 12, 12, 0, 0, 23, -15, 37] } -- -- input { 10i64 0 1 } -- output { [0, 0, 0, 0, 1, 1, 0, 1, 0, 0] } -- From http://stackoverflow.com/a/12996028 def hash(x: i32): i32 = let x = ((x >> 16) ^ x) * 0x45d9f3b let x = ((x >> 16) ^ x) * 0x45d9f3b let x = ((x >> 16) ^ x) in x def rand_array (n: i64) (lower: i32) (upper: i32): [n]i32 = map (\(i: i64): i32 -> -- We hash i+n to ensure that a random length-n array is not a -- prefix of a random length-(n+m) array. hash(i32.i64 (i + n)) % (upper-lower+1) + lower) ( iota(n)) def main (x: i64) (lower: i32) (upper: i32): []i32 = rand_array x lower upper futhark-0.25.27/tests/random_test.fut000066400000000000000000000003371475065116200175460ustar00rootroot00000000000000-- Just a quick test whether futhark-test can generate random data. -- == -- random input { [100]i32 [100]i32 } auto output -- random input { [1000]i32 [1000]i32 } auto output def main xs ys = i32.product (map2 (*) xs ys) futhark-0.25.27/tests/random_test_bool.fut000066400000000000000000000002751475065116200205620ustar00rootroot00000000000000-- Can we read random boolean data? (Bools have strange -- representations for some backends.) -- == -- random input { [100]bool } auto output def main (bs: []bool) = reduce (&&) true bs futhark-0.25.27/tests/random_test_float.fut000066400000000000000000000001701475065116200207260ustar00rootroot00000000000000-- Testing random float constants. -- == -- random input { [1]f32 f32 1.5f32 } def main arr x y : f32 = arr[0] + x + y futhark-0.25.27/tests/range0.fut000066400000000000000000000034251475065116200164040ustar00rootroot00000000000000-- Basic ranges. 
-- == -- entry: test0 -- input { 0 5 } output { [0i32, 1i32, 2i32, 3i32, 4i32, 5i32] } -- input { 1 5 } output { [1i32, 2i32, 3i32, 4i32, 5i32] } -- input { 5 1 } error: 5...1 -- input { 5 0 } error: 5...0 -- == -- entry: test1 -- input { 0 5 } output { [0i32, 1i32, 2i32, 3i32, 4i32] } -- input { 1 5 } output { [1i32, 2i32, 3i32, 4i32] } -- input { 5 1 } error: 5..<1 -- input { 5 0 } error: 5..<0 -- == -- entry: test2 -- input { 0 5 } error: 0..>5 -- input { 1 5 } error: 1..>5 -- input { 5 1 } output { [5i32, 4i32, 3i32, 2i32] } -- input { 5 0 } output { [5i32, 4i32, 3i32, 2i32, 1i32] } -- == -- entry: test3 -- input { 0 1 5 } output { [0i32, 1i32, 2i32, 3i32, 4i32, 5i32] } -- input { 1 0 5 } error: 1..0...5 -- input { 5 4 1 } output { [5i32, 4i32, 3i32, 2i32, 1i32] } -- input { 5 0 0 } output { [5i32, 0i32] } -- input { 0 2 5 } output { [0i32, 2i32, 4i32] } -- input { 1 1 5 } error: 1..1...5 -- input { 1 0 1 } output { [1i32] } -- input { 1 2 1 } output { [1i32] } -- == -- entry: test4 -- input { 0 1 5 } output { [0i32, 1i32, 2i32, 3i32, 4i32] } -- input { 1 0 5 } error: 1..0..<5 -- input { 5 4 1 } error: 5..4..<1 -- input { 5 0 0 } error: 5..0..<0 -- input { 0 2 5 } output { [0i32, 2i32, 4i32] } -- == -- entry: test5 -- input { 0 1 5 } error: 0..1..>5 -- input { 1 0 5 } error: 1..0..>5 -- input { 5 4 1 } output { [5i32, 4i32, 3i32, 2i32] } -- input { 5 0 0 } output { [5i32] } -- input { 0 2 5 } error: 0..2..>5 entry test0 (start: i32) (end: i32) = start...end entry test1 (start: i32) (end: i32) = start..end entry test3 (start: i32) (step: i32) (end: i32) = start..step...end entry test4 (start: i32) (step: i32) (end: i32) = start..step..end futhark-0.25.27/tests/range1.fut000066400000000000000000000016161475065116200164050ustar00rootroot00000000000000-- Range exclusion -- == -- entry: test0 -- input { 10 20 } error: 10..10...20 -- input { 1 2 } error: 1..1...2 -- input { 20 10 } error: 20..20...10 -- input { 20 -10 } error: 20..20...-10 -- input { 5 0 } error: 
5..5...0 -- input { 5 -1 } error: 5..5...-1 -- == -- entry: test1 -- input { 10 20 } error: 10..10..<20 -- input { 1 2 } error: 1..1..<2 -- input { 20 10 } error: 20..20..<10 -- input { 20 -10 } error: 20..20..<-10 -- input { 5 0 } error: 5..5..<0 -- input { 5 -1 } error: 5..5..<-1 -- == -- entry: test2 -- input { 10 20 } error: 10..10..>20 -- input { 1 2 } error: 1..1..>2 -- input { 20 10 } error: 20..20..>10 -- input { 20 -10 } error: 20..20..>-10 -- input { 5 0 } error: 5..5..>0 -- input { 5 -1 } error: 5..5..>-1 entry test0 (start: i32) (end: i32) = start..start...end entry test1 (start: i32) (end: i32) = start..start..end futhark-0.25.27/tests/rearrange0.fut000066400000000000000000000033331475065116200172540ustar00rootroot00000000000000-- == -- input { -- [[[0i32, 1i32], [2i32, 3i32], [4i32, 5i32], [6i32, 7i32]], -- [[8i32, 9i32], [10i32, 11i32], [12i32, 13i32], [14i32, 15i32]]] -- } -- output { -- [[[0, 2, 4, 6], [8, 10, 12, 14]], [[1, 3, 5, 7], [9, 11, 13, 15]]] -- } -- input { -- [[[7i32, 10i32, 2i32], -- [4i32, 3i32, 1i32], -- [8i32, 4i32, 4i32], -- [0i32, 9i32, 9i32], -- [0i32, 1i32, 3i32], -- [2i32, 5i32, 10i32], -- [0i32, 5i32, 0i32]], -- [[5i32, 7i32, 6i32], -- [2i32, 2i32, 3i32], -- [4i32, 4i32, 7i32], -- [6i32, 10i32, 10i32], -- [5i32, 6i32, 10i32], -- [5i32, 1i32, 6i32], -- [1i32, 3i32, 9i32]], -- [[3i32, 2i32, 9i32], -- [4i32, 0i32, 7i32], -- [4i32, 6i32, 5i32], -- [5i32, 8i32, 5i32], -- [10i32, 8i32, 7i32], -- [5i32, 8i32, 7i32], -- [3i32, 6i32, 8i32]]] -- } -- output { -- [[[7i32, 4i32, 8i32, 0i32, 0i32, 2i32, 0i32], [5i32, 2i32, -- 4i32, 6i32, 5i32, 5i32, 1i32], [3i32, 4i32, 4i32, 5i32, 10i32, -- 5i32, 3i32]], [[10i32, 3i32, 4i32, 9i32, 1i32, 5i32, 5i32], -- [7i32, 2i32, 4i32, 10i32, 6i32, 1i32, 3i32], [2i32, 0i32, 6i32, -- 8i32, 8i32, 8i32, 6i32]], [[2i32, 1i32, 4i32, 9i32, 3i32, 10i32, -- 0i32], [6i32, 3i32, 7i32, 10i32, 10i32, 6i32, 9i32], [9i32, 7i32, -- 5i32, 5i32, 7i32, 7i32, 8i32]]] -- } -- compiled random input { [2][10][10]i32 } auto 
output -- compiled random input { [2][64][4]i32 } auto output -- compiled random input { [2][4][64]i32 } auto output -- compiled random input { [2][64][64]i32 } auto output -- compiled random input { [64][2][64]i32 } auto output -- compiled random input { [64][64][2]i32 } auto output -- compiled random input { [128][128][128]i32 } auto output def main xss: [][][]i32 = xss |> map transpose |> transpose futhark-0.25.27/tests/rearrange1.fut000066400000000000000000000006001475065116200172470ustar00rootroot00000000000000-- This program tests that a transposition is carried out correctly -- even when the destination has an offset. The OpenCL code generator -- once messed this up. -- -- == -- input { [[1,2,3]] [[4,7],[5,8],[6,9]] [[10,11,12]] } -- output { [[1,2,3], [4,5,6],[7,8,9], [10,11,12]] } def main [n] (a: [][n]i32) (b: [n][]i32) (c: [][n]i32): [][]i32 = concat a (concat (transpose b) c) futhark-0.25.27/tests/rearrange2.fut000066400000000000000000000006421475065116200172560ustar00rootroot00000000000000-- == -- input { -- [[[1,10,100],[2,20,200],[3,30,300]],[[4,40,400],[5,50,500],[6,60,600]]] -- } -- output { -- [[[1, 2, 3], [10, 20, 30], [100, 200, 300]], [[4, 5, 6], [40, 50, 60], [400, -- 500, -- 600]]] -- } def main(a: [][][]i32): [][][]i32 = map transpose a futhark-0.25.27/tests/record-update0.fut000066400000000000000000000002451475065116200200430ustar00rootroot00000000000000-- Basic record update. -- == -- input { 0 0 } output { 2 0 } def main (x: i32) (y: i32): (i32, i32) = let r0 = {x, y} let r1 = r0 with x = 2 in (r1.x, r1.y) futhark-0.25.27/tests/record-update1.fut000066400000000000000000000002451475065116200200440ustar00rootroot00000000000000-- Type-changing record update. -- == -- error: i32.*bool def main (x: i32) (y: i32): (bool, i32) = let r0 = {x, y} let r1 = r0 with x = true in (r1.x, r1.y) futhark-0.25.27/tests/record-update2.fut000066400000000000000000000003261475065116200200450ustar00rootroot00000000000000-- Nested record update. 
-- == -- input { 0 0 } output { 1 0 0 0 } def main (x: i32) (y: i32): (i32, i32, i32, i32) = let r0 = {a={x,y}, b={x,y}} let r1 = r0 with a.x = 1 in (r1.a.x, r1.a.y, r1.b.x, r1.b.y) futhark-0.25.27/tests/record-update3.fut000066400000000000000000000003221475065116200200420ustar00rootroot00000000000000-- Record update of functional value. -- == -- input { 1 2 } output { 3 2 } def main (x: i32) (y: i32): (i32, i32) = let r0 = {a=(+x), b=(+y)} let r1 = r0 with a = (\v -> r0.a v + y) in (r1.a 0, r1.b 0) futhark-0.25.27/tests/record-update4.fut000066400000000000000000000003731475065116200200510ustar00rootroot00000000000000-- Chained record update in local function. -- == -- input { 0 1 } output { 0 1 } type record = {x: i32, y: i32} def main (x: i32) (y: i32) = let f b (r: record) = if b then r else r with x = 1 with y = 2 let r' = f true {x,y} in (r'.x, r'.y) futhark-0.25.27/tests/record-update5.fut000066400000000000000000000003111475065116200200420ustar00rootroot00000000000000-- Record updates of array fields is allowed to change the size. -- == def main [n][m] (xs: [n]i32) (ys: [m]i32): ([m]i32, [m]i32) = let r0 = {xs, ys} let r1 = r0 with xs = ys in (r1.xs, r1.ys) futhark-0.25.27/tests/record-update6.fut000066400000000000000000000003241475065116200200470ustar00rootroot00000000000000-- Inference of record in lambda. -- == -- error: Full type of type octnode = {body: i32} def f (octree: []octnode) (i: i32) = map (\n -> if n.body != i then n else n with body = 0) octree futhark-0.25.27/tests/records-error0.fut000066400000000000000000000001471475065116200200760ustar00rootroot00000000000000-- Duplicate fields in a record type is an error. -- == -- error: Duplicate type t = {x: i32, x: i32} futhark-0.25.27/tests/records-error2.fut000066400000000000000000000001671475065116200201020ustar00rootroot00000000000000-- Error if we try to access a non-existent field. 
-- == -- error: field.*c def main() = let r = {a=1,b=2} in r.c futhark-0.25.27/tests/records-error3.fut000066400000000000000000000002271475065116200201000ustar00rootroot00000000000000-- A record value must have at least the fields of its corresponding -- type. -- == -- error: match def main() = let r:{a:i32,b:i32} = {a=0} in 0 futhark-0.25.27/tests/records-error4.fut000066400000000000000000000002231475065116200200750ustar00rootroot00000000000000-- A record value must not have more fields than its corresponding -- type. -- == -- error: match def main() = let r:{a:i32} = {a=0,b=0} in 0 futhark-0.25.27/tests/records-error5.fut000066400000000000000000000002251475065116200201000ustar00rootroot00000000000000-- It is not OK to define the same field twice. -- == -- error: previously defined def main(x: i32) = let r = {a=x, b=x+1, a=x+2} in (r.a, r.b) futhark-0.25.27/tests/records-error6.fut000066400000000000000000000002041475065116200200760ustar00rootroot00000000000000-- It is not OK to match the same field twice. -- == -- error: Duplicate fields def main(x: i32) = let {x=a, x=b} = {x} in a+b futhark-0.25.27/tests/records-error7.fut000066400000000000000000000002411475065116200201000ustar00rootroot00000000000000-- Specific error message on record field mismatches. -- == -- error: Unshared fields: d, c. def f (v: {a: i32, b: i32, c: i32}) : {a: i32, b: i32, d: i32} = v futhark-0.25.27/tests/records-error8.fut000066400000000000000000000003671475065116200201120ustar00rootroot00000000000000-- Unification of variables with incompletely known and distinct fields. -- == -- error: must be a record with fields def sameconst '^a (_: a) (y: a) = y def main (x: i64) = let elems = sameconst (\s -> s.x < 0) (\s -> s.y > 0) in elems {x} futhark-0.25.27/tests/records-error9.fut000066400000000000000000000003771475065116200201140ustar00rootroot00000000000000-- Unification of incomplete record variable with non-record. 
-- == -- error: Cannot unify a record type with a non-record type def sameconst '^a (_: a) (y: a) = y def main (x: i64) = let elems = sameconst (\s -> s.x < 0) (\s -> s > 0) in elems {x} futhark-0.25.27/tests/records0.fut000066400000000000000000000002201475065116200167370ustar00rootroot00000000000000-- Do records work at all? -- == -- input { 2 } output { 1 3 } def f(x: i32) = {y=x+1,x=x-1} def main(x: i32) = let r = f x in (r.x, r.y) futhark-0.25.27/tests/records1.fut000066400000000000000000000002251475065116200167450ustar00rootroot00000000000000-- Tuples can be used like records. -- == -- input { 2 } output { 3 1 } def f(x: i32) = (x+1,x-1) def main(x: i32) = let r = f x in (r.0, r.1) futhark-0.25.27/tests/records10.fut000066400000000000000000000002621475065116200170260ustar00rootroot00000000000000def main (a: [4]u64) (b: [4]u64) (c: [4]u64) = (\p -> let x = length(p.0) let y = length(p.1.0) let z = length(p.1.1) in x < y && y < z) (a,(b,c)) futhark-0.25.27/tests/records11.fut000066400000000000000000000002251475065116200170260ustar00rootroot00000000000000-- We can index the result of a projection. -- == -- input { [1,2,3] 4 } -- output { 2 } def main (x: []i32) (y: i32) = let t = (x,y) in t.0[1] futhark-0.25.27/tests/records2.fut000066400000000000000000000001771475065116200167540ustar00rootroot00000000000000-- Records can be used like tuples. -- == -- input { 2 } output { 3 1 } def f(x: i32) = {0=x+1,1=x-1} def main(x: i32) = f x futhark-0.25.27/tests/records3.fut000066400000000000000000000002171475065116200167500ustar00rootroot00000000000000-- Test tuple patterns. -- == -- input { 2 } output { 3 1 } def f(x: i32) = {a=x+1,b=x-1} def main(x: i32) = let {a, b=c} = f x in (a,c) futhark-0.25.27/tests/records4.fut000066400000000000000000000003661475065116200167560ustar00rootroot00000000000000-- Record pattern in function parameter. -- == -- input { 2 } output { 1 3 } def f {a:i32, b=c:i32} = (a,c) -- And with a little fancier ascription. 
type t = {c: i32, d: i32} def g ({c,d}:t) = f {a=c,b=d} def main(x: i32) = g {c=x-1,d=x+1} futhark-0.25.27/tests/records5.fut000066400000000000000000000002031475065116200167450ustar00rootroot00000000000000-- Implicit field expressions. -- == -- input { 1 2 } output { 1 2 } def main (x: i32) (y: i32) = let r = {y,x} in (r.x, r.y) futhark-0.25.27/tests/records6.fut000066400000000000000000000002751475065116200167570ustar00rootroot00000000000000-- Unification of variables with incompletely known fields. def sameconst '^a (_: a) (y: a) = y def main (x: i64) = let elems = sameconst (\s -> s.x < 0) (\s -> s.x > 0) in elems {x} futhark-0.25.27/tests/records7.fut000066400000000000000000000004111475065116200167500ustar00rootroot00000000000000-- Even a large tuple works as a record. -- -- == -- input { 0 } output { 5 } def main(x: i32) = let (a,b,c,d,e,f,g,h,i,j,k,l,m,n) = {0=x+1, 1=x+2, 2=x+3, 3=x+4, 4=x+5, 5=x+6, 6=x+7, 7=x+8, 8=x+9, 9=x+10, 10=x+11, 11=x+12, 12=x+13, 13=x+14} in e futhark-0.25.27/tests/records8.fut000066400000000000000000000002271475065116200167560ustar00rootroot00000000000000-- Record field access can be nested. -- == -- input { 2 } output { 3 } def main(x: i32) = let r = {a=x,b=x+1,c={b=x,a={a={b=x+1}}}} in r.c.a.a.b futhark-0.25.27/tests/records9.fut000066400000000000000000000002151475065116200167540ustar00rootroot00000000000000-- Access a record field inside a module. 
-- == -- input { 1 } output { 3 } module m = { def r = { x = 2 } } def main (x: i32) = m.r.x + x futhark-0.25.27/tests/redomapNew.fut000066400000000000000000000015431475065116200173300ustar00rootroot00000000000000-- == -- input { -- [1,2,3,4,5] -- } -- output { -- [0, 30, 60] -- [[[0, 0, 0, 0, 0], -- [2, 2, 2, 2, 2], -- [4, 4, 4, 4, 4]], -- [[0, 0, 0, 0, 0], -- [4, 4, 4, 4, 4], -- [8, 8, 8, 8, 8]], -- [[0, 0, 0, 0, 0], -- [6, 6, 6, 6, 6], -- [12, 12, 12, 12, 12]], -- [[0, 0, 0, 0, 0], -- [8, 8, 8, 8, 8], -- [16, 16, 16, 16, 16]], -- [[0, 0, 0, 0, 0], -- [10, 10, 10, 10, 10], -- [20, 20, 20, 20, 20]]] -- } def main(arr: []i32): ([]i32,[][][]i32) = let vs = map (\(a: i32) -> map (\x: i32 -> 2*i32.i64 x*a ) (iota(3) ) ) arr in (reduce (\a b -> map2 (+) a b) ( replicate 3 0) vs, map (\(r: []i32) -> transpose (replicate 5 r)) vs) def main0(arr: []i32): i32 = reduce (+) 0 (map (2*) arr) futhark-0.25.27/tests/redundant-merge-variable.fut000066400000000000000000000004261475065116200220720ustar00rootroot00000000000000-- Test that we can remove an unused loop result as well as the -- computation that creates it. 
-- -- == -- structure { Loop/Negate 0 } def main(a: *[]i32, b: *[]i32, n: i32): []i32 = (loop ((a,b)) for i < n do let a[i] = a[i] + 1 let b[i] = -b[i] in (a,b)).0 futhark-0.25.27/tests/reg-tiling/000077500000000000000000000000001475065116200165455ustar00rootroot00000000000000futhark-0.25.27/tests/reg-tiling/batch-mm-lud.fut000066400000000000000000000015201475065116200215350ustar00rootroot00000000000000-- Batched matrix multiplication as it appears in Rodinia's LUD -- == -- no_python compiled random input { [128][16][16]f32 [128][16][16]f32 [128][128][16][16]f32 } auto output def main [m][b] (Bs: [m][b][b]f32) (As: [m][b][b]f32) (Css: [m][m][b][b]f32): *[m][m][b][b]f32 = let Btrs = map transpose Bs in map2(\(Cs: [m][b][b]f32) (A: [b][b]f32) : [m][b][b]f32 -> map2(\ (C: [b][b]f32) (Btr: [b][b]f32) : [b][b]f32 -> map2(\ (C_row: [b]f32) (A_row: [b]f32) : [b]f32 -> map2 (\ c B_col -> let prods = map2 (*) A_row B_col let sum = f32.sum prods in c - sum ) C_row Btr ) C A ) Cs Btrs ) Css As futhark-0.25.27/tests/reg-tiling/batch-mm-lud.fut.tuning000066400000000000000000000000361475065116200230410ustar00rootroot00000000000000main.suff_outer_par_0=4194304 futhark-0.25.27/tests/reg-tiling/mmm-tr.fut000066400000000000000000000007651475065116200205060ustar00rootroot00000000000000-- == -- -- compiled random input {[2001][4037]f32 [2021][4037]f32} auto output -- compiled random input {[1024][1024]f32 [1024][1024]f32} auto output -- compiled random input {[2048][4096]f32 [2048][4096]f32} auto output -- compiled random input {[2011][4011]f32 [1011][4011]f32} auto output -- let dotproduct [n] (x: [n]f32) (y: [n]f32) = map2 (*) x y |> reduce (+) 0 let main [m][n][q] (A: [m][q]f32) (B: [n][q]f32) : [m][n]f32 = map (\ Arow -> map (\Brow -> dotproduct Arow Brow) B) A futhark-0.25.27/tests/reg-tiling/mmm.fut000066400000000000000000000013111475065116200200470ustar00rootroot00000000000000-- == -- compiled random input {[2011][4011]f32 [4011][1011]f32} auto output -- compiled random input 
{[128][1024]f32 [1024][128]f32} auto output -- compiled random input {[128][4096]f32 [4096][128]f32} auto output -- -- input { -- [ [1.0f32, 2.0f32, 3.0f32], [3.0f32, 4.0f32, 5.0f32] ] -- [ [1.0f32, 2.0f32], [3.0f32, 4.0f32], [5.0f32, 6.0f32] ] -- } -- output { -- [ [22.0f32, 28.0f32], [40.0f32, 52.0f32] ] -- } -- -- compiled random input {[2048][4096]f32 [4096][2048]f32} auto output let dotproduct [n] (x: [n]f32) (y: [n]f32) = map2 (*) x y |> reduce (+) 0 let main [m][n][q] (A: [m][q]f32) (B: [q][n]f32) : [m][n]f32 = map (\ Arow -> map (\Bcol -> dotproduct Arow Bcol) (transpose B)) A futhark-0.25.27/tests/reg-tiling/reg3d-test1.fut000066400000000000000000000020441475065116200213270ustar00rootroot00000000000000-- Test register tiling when all input arrays are invariant to one parallel dimension -- This is a simple case in which there is no code after stream. -- -- == -- input { -- [ [1.0f32, 3.0f32], [2.0f32, 4.0f32] ] -- [ [5.0f32, 8.0f32], [6.0f32, 7.0f32] ] -- [ [1.0f32, 1.0f32], [9.0f32, 1.0f32] ] -- } -- output { -- [ [ [23.0f32, 29.0f32], [34.0f32, 44.0f32] ] -- , [ [18.0f32, 21.0f32], [24.0f32, 28.0f32] ] -- ] -- } -- no_python no_wasm compiled random input { [16][512]f32 [512][16]f32 [65536][512]f32 } auto output def pred (x : f32) : bool = x < 9.0 def dotprod_filt [n] (vct: [n]f32) (xs: [n]f32) (ys: [n]f32) : f32 = f32.sum (map3 (\v x y -> let z = x*y in let f = f32.bool (pred v) in z*f) vct xs ys) def matmul_filt [n][p][m] (xss: [n][p]f32) (yss: [p][m]f32) (vct: [p]f32) : [n][m]f32 = map (\xs -> map (dotprod_filt vct xs) (transpose yss)) xss def main [m][n][u] (ass: [n][u]f32) (bss: [u][n]f32) (fss: [m][u]f32) : [m][n][n]f32 = map (matmul_filt ass bss) fss futhark-0.25.27/tests/reg-tiling/reg3d-test1.fut.tuning000066400000000000000000000000321475065116200226250ustar00rootroot00000000000000main.suff_outer_par_0=200 futhark-0.25.27/tests/reg-tiling/reg3d-test2.fut000066400000000000000000000023341475065116200213320ustar00rootroot00000000000000-- Test 
register tiling when all input arrays are invariant to one parallel dimension. -- This is a more complex code, in which the code after the stream has both variant -- and invariant parts. -- -- == -- input { -- [ [1.0f32, 3.0f32], [2.0f32, 4.0f32] ] -- [ [5.0f32, 8.0f32], [6.0f32, 7.0f32] ] -- [ [1.0f32, 1.0f32], [9.0f32, 1.0f32] ] -- } -- output { -- [ [ [28.0f32, 40.0f32], [42.0f32, 58.0f32] ] -- , [ [39.0f32, 32.0f32], [48.0f32, 42.0f32] ] -- ] -- } -- -- no_python no_wasm compiled random input { [16][512]f32 [512][16]f32 [65536][512]f32 } auto output def pred (x : f32) : bool = x < 9.0 def dotprod_filt [n] (vct: [n]f32) (xs: [n]f32) (ys: [n]f32) (k : i64) : f32 = let s = f32.sum (map3 (\v x y -> let z = x*y in let f = f32.bool (pred v) in z*f) vct xs ys) let var_term = 2.0 * #[unsafe] vct[k] let inv_term = 3.0 * #[unsafe] xs[k] in s + inv_term + var_term def matmul_filt [n][p][m] (xss: [n][p]f32) (yss: [p][m]f32) (vct: [p]f32) : [n][m]f32 = map (\xs -> map2 (dotprod_filt vct xs) (transpose yss) (iota m)) xss def main [m][n][u] (ass: [n][u]f32) (bss: [u][n]f32) (fss: [m][u]f32) : [m][n][n]f32 = map (matmul_filt ass bss) fss futhark-0.25.27/tests/reg-tiling/reg3d-test2.fut.tuning000066400000000000000000000000321475065116200226260ustar00rootroot00000000000000main.suff_outer_par_0=200 futhark-0.25.27/tests/reg-tiling/reg3d-test3.fut000066400000000000000000000027721475065116200213410ustar00rootroot00000000000000-- Test register tiling when all input arrays are invariant to one parallel dimension. -- This is a more complex code, in which the code after the stream has both variant -- and invariant parts, and double return. 
-- -- == -- input { -- [ [1.0f32, 3.0f32], [2.0f32, 4.0f32] ] -- [ [5.0f32, 8.0f32], [6.0f32, 7.0f32] ] -- [ [1.0f32, 1.0f32], [9.0f32, 1.0f32] ] -- } -- output { -- [[[18.0f32, 18.0f32], [26.0f32, 30.0f32]], [[-3.0f32, 10.0f32], [ 0.0f32, 14.0f32]]] -- [[[28.0f32, 40.0f32], [42.0f32, 58.0f32]], [[39.0f32, 32.0f32], [48.0f32, 42.0f32]]] -- } -- -- no_python no_wasm compiled random input { [16][512]f32 [512][16]f32 [65536][512]f32 } auto output def pred (x : f32) : bool = x < 9.0 def dotprod_filt [n] (vct: [n]f32) (xs: [n]f32) (ys: [n]f32) (k : i64) : (f32,f32) = -- let s = f32.sum (map3 (\v x y -> let z = x*y in let f = f32.bool (pred v) in z*f) vct xs ys) let s = f32.sum (map3 (\y x v -> let z = x*y in let f = f32.bool (pred v) in z*f) ys xs vct) let var_term = 2.0 * #[unsafe] vct[k] let inv_term = 3.0 * #[unsafe] xs[k] let term = var_term + inv_term in (s - term, s + term) def matmul_filt [n][p][m] (xss: [n][p]f32) (yss: [p][m]f32) (vct: [p]f32) : [n][m](f32,f32) = map (\xs -> map2 (dotprod_filt vct xs) (transpose yss) (iota m)) xss def main [m][n][u] (ass: [n][u]f32) (bss: [u][n]f32) (fss: [m][u]f32) : ([m][n][n]f32,[m][n][n]f32) = map (matmul_filt ass bss) fss |> map (map unzip) |> map unzip |> unzip futhark-0.25.27/tests/reg-tiling/reg3d-test3.fut.tuning000066400000000000000000000000321475065116200226270ustar00rootroot00000000000000main.suff_outer_par_0=100 futhark-0.25.27/tests/reg-tiling/sgemm.fut000066400000000000000000000026371475065116200204050ustar00rootroot00000000000000-- SGEMM performs the matrix-matrix operation: -- C := alpha * A * B + beta * C -- == -- entry: main1 main2 -- no_python compiled random input { [1024][1024]f32 [1024][1024]f32 [1024][1024]f32 0.5f32 0.75f32} auto output entry main1 [n][m][q] (A: [n][q]f32) (B: [q][m]f32) (C: [n][m]f32) (alpha: f32) (beta: f32) : [n][m]f32 = map2(\Arow Crow -> map2(\Bcol c -> let x = map2 (*) Arow Bcol |> f32.sum in alpha * x + beta * c ) (transpose B) Crow ) A C entry main2 [n][m][q] (A: [n][q]f32) (B: 
[q][m]f32) (C: [n][m]f32) (alpha: f32) (beta: f32) : [n][m]f32 = map2(\Arow i -> map2(\Bcol j -> let y = beta * #[unsafe] C[i/2, j/2+1] let x = map2 (*) Bcol Arow |> f32.sum in alpha * x + y ) (transpose B) (iota m) ) A (iota n) -- == -- entry: main3 -- compiled random input { [1024][1024]f32 [1024][1024]f32 [1024][1024]f32 [1024][1024]f32 0.5f32 0.75f32} auto output entry main3 [n][m][q] (A: [n][q]f32) (B: [q][m]f32) (C: [n][m]f32) (D: [n][q]f32) (alpha: f32) (beta: f32) : [n][m]f32 = map3(\Arow Drow Crow -> map2(\Bcol c -> let x = map3 (\a d b -> a * d * b) Arow Drow Bcol |> f32.sum in alpha * x + beta * c ) (transpose B) Crow ) A D C futhark-0.25.27/tests/reg-tiling/sgemm.fut.tuning000066400000000000000000000001351475065116200216770ustar00rootroot00000000000000main1.suff_outer_par_0=1048576 main2.suff_outer_par_0=1048576 main3.suff_outer_par_0=1048576 futhark-0.25.27/tests/renameTest.fut000066400000000000000000000003331475065116200173320ustar00rootroot00000000000000-- == -- input { -- } -- output { -- 8 -- 11 -- } def f(x: (i32,i32), y: i32, z: i32): (i32,i32) = x def main: (i32,i32) = let x = 1 + 2 let x = (x + 5, 4+7) let (x, (y,z)) = (x, x) in f(x,y,z) futhark-0.25.27/tests/replicate0.fut000066400000000000000000000011141475065116200172510ustar00rootroot00000000000000-- Simple test to see whether we can properly replicate arrays. -- == -- input { -- 10i64 -- } -- output { -- [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], -- [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], -- [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], -- [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], -- [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], -- [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], -- [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], -- [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], -- [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], -- [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]] -- } def main(n: i64): [][]i32 = let x = 0..1.. map i32.i64) futhark-0.25.27/tests/replicate3.fut000066400000000000000000000003131475065116200172540ustar00rootroot00000000000000-- Reshape/replicate simplification test. 
-- == -- structure { Reshape 1 } def main [n] (b: [n]i32, m: i64) = let x = n * m let c = b :> [x]i32 let d = replicate (2*5*(n*m)) c in unflatten_3d d futhark-0.25.27/tests/reshape1.fut000066400000000000000000000005701475065116200167360ustar00rootroot00000000000000-- == -- input { -- [1i64,2i64,3i64,4i64,5i64,6i64,7i64,8i64,9i64] -- } -- output { -- [[1i64, 2i64, 3i64], [4i64, 5i64, 6i64], [7i64, 8i64, 9i64]] -- } -- input { [1i64,2i64,3i64] } -- error: (3)*cannot match shape.*\[1\]i64 def intsqrt(x: i64): i64 = i64.f32(f32.sqrt(f32.i64(x))) def main [n] (a: [n]i64): [][]i64 = unflatten (a :> [intsqrt n*intsqrt n]i64) futhark-0.25.27/tests/reshape3.fut000066400000000000000000000002571475065116200167420ustar00rootroot00000000000000-- Reshape with a polymorphic type, where only the outer dimensions -- are reshaped. def main [n][m][k] (A: [n][m][k]f32): []i32 = let A' = flatten A in map (\_ -> 0) A' futhark-0.25.27/tests/reshape4.fut000066400000000000000000000003261475065116200167400ustar00rootroot00000000000000-- Reshaping a concatenation. -- == -- input { 3i64 [1] [2,3] } output { [1, 2, 3] } -- input { 4i64 [1] [2,3] } error: cannot match shape -- structure { Reshape 0 } def main n xs ys: []i32 = xs ++ ys :> [n]i32 futhark-0.25.27/tests/returntype-error1.fut000066400000000000000000000003341475065116200206550ustar00rootroot00000000000000-- Test that the subtype property works for function return types. -- == -- error: def f(a: *[]i32): []i32 = a -- OK, because unique is a subtype of nonunique def g(a: []i32): *[]i32 = a -- Wrong! def main(): i32 = 0 futhark-0.25.27/tests/returntype-error2.fut000066400000000000000000000001371475065116200206570ustar00rootroot00000000000000-- Test basic detection of wrong function return types. 
-- == -- error: def main(): i32 = 2.0 futhark-0.25.27/tests/returntype-error3.fut000066400000000000000000000006171475065116200206630ustar00rootroot00000000000000-- This test demonstrates a limitation caused by the conservativity of -- the aliasing analyser. -- == -- error: -- The two arrays must not alias each other, because they are unique. def main(): (*[]i32, *[]i32) = let n = 10 let a = iota(n) in if 1 == 2 then (a, iota(n)) else (iota(n), a) -- The type checker decides that both components of the tuple may -- alias a, so we get an error. futhark-0.25.27/tests/returntype-error4.fut000066400000000000000000000002711475065116200206600ustar00rootroot00000000000000-- == -- error: Cannot generalise def foo n = let (m,_) = (n+1,true) in (iota ((m+1)+1), zip (iota (m+1)), zip (iota m)) def main n = let (xs, _, _) = foo n in xs futhark-0.25.27/tests/revdims.fut000066400000000000000000000007401475065116200166760ustar00rootroot00000000000000-- Reverse the dimensions of a 3D array. At one point the code -- generator thought this was a transposition. -- == -- input { [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, -- 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] } -- -- output { [[[0, 12], [4, 16], [8, 20]], [[1, 13], [5, 17], [9, 21]], -- [[2, 14], [6, 18], [10, 22]], [[3, 15], [7, 19], [11, 23]]] } def main [a][b][c] (A: [a][b][c]i32): [c][b][a]i32 = A |> map transpose |> transpose |> map transpose futhark-0.25.27/tests/reverse0.fut000066400000000000000000000003171475065116200167600ustar00rootroot00000000000000-- Reverse an array using indexing. -- -- == -- input { [1,2,3,4] } output { [4,3,2,1] } -- input { empty([0]i32) } output { empty([0]i32) } -- structure { Assert 0 } def main(as: []i32): []i32 = as[::-1] futhark-0.25.27/tests/reverse1.fut000066400000000000000000000002251475065116200167570ustar00rootroot00000000000000-- Reverse an inner array using indexing. 
-- -- == -- input { [[1,2],[3,4]] } output { [[2,1],[4,3]] } def main(as: [][]i32): [][]i32 = as[:,::-1] futhark-0.25.27/tests/rosettacode/000077500000000000000000000000001475065116200170205ustar00rootroot00000000000000futhark-0.25.27/tests/rosettacode/100doors.fut000066400000000000000000000013761475065116200211160ustar00rootroot00000000000000-- http://rosettacode.org/wiki/100_doors -- -- This is the "unoptimised" version, because the optimised one is too -- boring. Parametrised on number of doors. One optimisation done is -- to use write instead of a naive map. This allows us to only touch -- the doors we care about, while still remaining parallel. 0-indexes the doors. -- -- == -- input { 10i64 } -- output { [false, true, false, false, true, false, false, false, false, true] } def main(n: i64): [n]bool = loop is_open = replicate n false for i < n do let js = map (*i+1) (iota n) let flips = map (\j -> if j < n then !is_open[j] else true -- Doesn't matter. ) js in scatter is_open js flips futhark-0.25.27/tests/rosettacode/README.md000066400000000000000000000011241475065116200202750ustar00rootroot00000000000000Rosetta Code Implementations ============================ This directory contains Futhark implementations of tasks from [Rosetta Code](http://rosettacode.org). We try to keep this Git repository and the [list of implementations](http://rosettacode.org/wiki/Category:Futhark) in synch, but this is a manual process. Not all of these implementations are particularly good. Furthermore, there are many problems that Futhark is not able to solve well or at all, e.g. anything involving IO or most things about string processing. The latter is possible, but painful and not really in our domain. 
futhark-0.25.27/tests/rosettacode/agm.fut000066400000000000000000000004751475065116200203120ustar00rootroot00000000000000-- http://rosettacode.org/wiki/Arithmetic-geometric_mean -- -- == -- input { 1.0f64 2.0f64 } -- output { 1.456791f64 } def agm(a: f64, g: f64): f64 = let eps = 1.0E-16 let (a,_) = loop (a, g) while f64.abs(a-g) > eps do ((a+g) / 2.0, f64.sqrt (a*g)) in a def main (x: f64) (y: f64): f64 = agm(x,y) futhark-0.25.27/tests/rosettacode/almostprime.fut000066400000000000000000000013301475065116200220710ustar00rootroot00000000000000-- http://rosettacode.org/wiki/Almost_prime -- -- == -- input { 2 } -- output { [[2i32, 3i32, 5i32, 7i32, 11i32, 13i32, 17i32, 19i32, 23i32, 29i32], -- [4i32, 6i32, 9i32, 10i32, 14i32, 15i32, 21i32, 22i32, 25i32, 26i32]] } def kprime(n: i32, k: i32): bool = let (p,f) = (2, 0) let (n,_,f) = loop (n, p, f) while f < k && p*p <= n do let (n,f) = loop (n, f) while 0 == n % p do (n/p, f+1) in (n, p+1, f) in f + (if n > 1 then 1 else 0) == k def main(m: i32): [][]i32 = let f k = let ps = replicate 10 0 let (_,_,ps) = loop (i,c,ps) = (2,0,ps) while c < 10 do if kprime(i,k) then let ps[c] = i in (i+1, c+1, ps) else (i+1, c, ps) in ps in map f (1...m) futhark-0.25.27/tests/rosettacode/amicablepairs.fut000066400000000000000000000015571475065116200223440ustar00rootroot00000000000000-- http://rosettacode.org/wiki/Amicable_pairs -- -- This program is way too parallel and manifests all the pairs, which -- requires a giant amount of memory. Oh well. 
-- -- == -- tags { no_ispc } -- compiled input { 300i64 } -- output { [[220i32, 284i32]] } def divisors(n: i32): []i32 = filter (\x -> n%x == 0) (1...n/2+1) def amicable((n: i32, nd: i32), (m: i32, md: i32)): bool = n < m && nd == m && md == n def getPair [upper] (divs: [upper](i32, i32)) (flat_i: i64): ((i32,i32), (i32,i32)) = let i = flat_i / upper let j = flat_i % upper in (divs[i], divs[j]) def main(upper: i64): [][2]i32 = let range = map (1+) (iota upper) let divs = zip (map i32.i64 range) (map (\n -> reduce (+) 0 (divisors (i32.i64 n))) range) let amicable = filter amicable (map (getPair divs) (iota (upper*upper))) in map (\((x,_),(y,_)) -> [x, y]) amicable futhark-0.25.27/tests/rosettacode/arithmetic_means.fut000066400000000000000000000003611475065116200230540ustar00rootroot00000000000000-- http://rosettacode.org/wiki/Averages/Arithmetic_mean -- -- == -- input { [1.0,2.0,3.0,1.0] } -- output { 1.75f64 } -- Divide first to improve numerical behaviour. def main [n] (as: [n]f64): f64 = reduce (+) 0f64 (map (/f64.i64(n)) as) futhark-0.25.27/tests/rosettacode/binarydigits.fut000066400000000000000000000007731475065116200222370ustar00rootroot00000000000000-- https://rosettacode.org/wiki/Binary_digits -- -- We produce the binary number as a 64-bit integer whose digits are -- all 0s and 1s - this is because Futhark does not have any way to -- print, nor strings for that matter. -- -- == -- input { 5 } -- output { 101i64 } -- input { 50 } -- output { 110010i64 } -- input { 9000 } -- output { 10001100101000i64 } def main(x: i32): i64 = loop out = 0i64 for i < 32 do let digit = (x >> (31-i)) & 1 let out = (out * 10i64) + i64.i32(digit) in out futhark-0.25.27/tests/rosettacode/binarysearch.fut000066400000000000000000000011531475065116200222120ustar00rootroot00000000000000-- https://rosettacode.org/wiki/Binary_search -- -- This is a straightforward translation of the imperative iterative solution. 
-- -- == -- input { [1,2,3,4,5,6,8,9] 2 } -- output { 1i64 } def main [n] (as: [n]i32) (value: i32): i64 = let low = 0 let high = n-1 let (low, _) = loop ((low,high)) while low <= high do -- invariants: value > as[i] for all i < low -- value < as[i] for all i > high let mid = (low+high) / 2 in if as[mid] > value then (low, mid - 1) else if as[mid] < value then (mid + 1, high) else (mid, mid-1) -- Force termination. in low futhark-0.25.27/tests/rosettacode/complex.fut000066400000000000000000000021551475065116200212120ustar00rootroot00000000000000-- http://rosettacode.org/wiki/Arithmetic/Complex -- -- We implement a complex number as a pair of floats. This would be -- nicer with operator overloading. -- -- input { 0 1.0 1.0 3.14159 1.2 } -- output { 3.14159 2.2 } -- input { 1 1.0 1.0 3.14159 1.2 } -- output { 1.94159f64 4.34159f64 } -- input { 2 1.0 1.0 3.14159 1.2 } -- output { 0.5f64 -0.5f64 } -- input { 3 1.0 1.0 3.14159 1.2 } -- output { -1.0f64 -1.0f64 } -- input { 4 1.0 1.0 3.14159 1.2 } -- output { 1.0f64 -1.0f64 } type complex = (f64,f64) def complexAdd((a,b): complex) ((c,d): complex): complex = (a + c, b + d) def complexMult((a,b): complex) ((c,d): complex): complex = (a*c - b * d, a*d + b * c) def complexInv((r,i): complex): complex = let denom = r*r + i * i in (r / denom, -i / denom) def complexNeg((r,i): complex): complex = (-r, -i) def complexConj((r,i): complex): complex = (r, -i) def main (o: i32) (a: complex) (b: complex): complex = if o == 0 then complexAdd a b else if o == 1 then complexMult a b else if o == 2 then complexInv a else if o == 3 then complexNeg a else complexConj a futhark-0.25.27/tests/rosettacode/count_in_octal.fut000066400000000000000000000011301475065116200225330ustar00rootroot00000000000000-- https://rosettacode.org/wiki/Count_in_octal -- -- Futhark cannot print. Instead we produce an array of integers that -- look like octal numbers when printed in decimal. 
-- -- == -- input { 20i64 } -- output { [0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32, 10i32, 11i32, -- 12i32, 13i32, 14i32, 15i32, 16i32, 17i32, 20i32, 21i32, 22i32, 23i32] } def octal(x: i64): i32 = let (out,_,_) = loop (out,mult,x) = (0,1,i32.i64 x) while x > 0 do let digit = x % 8 let out = out + digit * mult in (out, mult * 10, x / 8) in out def main(n: i64): [n]i32 = map octal (iota n) futhark-0.25.27/tests/rosettacode/eulermethod.fut000066400000000000000000000054241475065116200220620ustar00rootroot00000000000000-- https://rosettacode.org/wiki/Euler_method#Common_Lisp -- -- Specialised to the cooling function. We produce an array of the -- temperature at each step subtracted from the analytically -- determined temperature (so we are computing the error). -- -- == -- input { 100.0 0.0 100.0 2.0 } -- output { -- [0.0f64, 0.7486588319044785f64, 1.2946993164580505f64, 1.679265585204547f64, -- 1.9360723079051994f64, 2.0926628953127633f64, 2.1714630634463745f64, -- 2.1906621307551717f64, 2.1649494013163135f64, 2.106128735048543f64, -- 2.0236308042860642f64, 1.924939486661728f64, 1.8159462635058006f64, -- 1.7012443098140402f64, 1.5843721158726787f64, 1.4680149205211634f64, -- 1.354170918183268f64, 1.2442880891724784f64, 1.139376563939308f64, -- 1.0401006401924064f64, 0.9468539045506148f64, 0.859820348366565f64, -- 0.7790238943202752f64, 0.7043683525135087f64, 0.6356694904077109f64, -- 0.572680620138101f64, 0.5151128711182906f64, 0.46265111832983763f64, -- 0.4149663712668854f64, 0.37172529010454625f64, 0.3325973799923041f64, -- 0.29726031781953566f64, 0.2654037853016753f64, 0.23673211521330373f64, -- 0.21096600187360082f64, 0.1878434807373779f64, 0.16712034361770378f64, -- 0.14857012436523576f64, 0.13198376366226583f64, 0.11716904003828077f64, -- 0.10394983650827427f64, 9.216529772764304e-2f64, 8.166892070760667e-2f64, -- 7.232761248797459e-2f64, 6.402074034058458e-2f64, 5.663919376000237e-2f64, -- 5.008447242299141e-2f64, 4.426781024452353e-2f64, 
3.9109342442220196e-2f64, -- 3.453731999022125e-2f64] -- -- } -- -- input { 100.0 0.0 100.0 5.0 } -- output { -- [0.0f64, 4.375047177497066f64, 5.926824303312763f64, 6.025019928892426f64, -- 5.447257115328519f64, 4.619590476035608f64, 3.763003010238556f64, -- 2.9817046074496396f64, 2.315646506892435f64, 1.7712171223319615f64, -- 1.338771206215167f64, 1.0023162611494705f64, 0.7446054205897603f64, -- 0.5495998794267152f64, 0.4034719393409745f64, 0.2948359128180762f64, -- 0.21460148389295242f64, 0.15566929274598706f64, 0.11258571548996343f64, -- 8.121463510024185e-2f64] -- } -- -- input { 100.0 0.0 100.0 10.0 } -- output { -- [0.0f64, 15.726824303312767f64, 12.52775711532852f64, 7.636514260238556f64, -- 4.216805010017435f64, 2.221390673785482f64, 1.1413261456382173f64, -- 0.5782306456739441f64, 0.290580297318634f64, 0.14532974216231054f64] -- } def analytic(t0: f64) (time: f64): f64 = 20.0 + (t0 - 20.0) * f64.exp(-0.07*time) def cooling(_time: f64) (temperature: f64): f64 = -0.07 * (temperature-20.0) def main(t0: f64) (a: f64) (b: f64) (h: f64): []f64 = let steps = i64.f64 ((b-a)/h) let temps = replicate steps 0.0 let (_,temps) = loop (t,temps)=(t0,temps) for i < steps do let x = a + f64.i64 i * h let temps[i] = f64.abs(t-analytic t0 x) in (t + h * cooling x t, temps) in temps futhark-0.25.27/tests/rosettacode/even_or_odd.fut000066400000000000000000000003551475065116200220260ustar00rootroot00000000000000-- https://rosettacode.org/wiki/Even_or_odd -- -- true if even. 
-- == -- input { 0 } output { true } -- input { 1 } output { false } -- input { 10 } output { true } -- input { 11 } output { false } def main(x: i32): bool = (x & 1) == 0 futhark-0.25.27/tests/rosettacode/fact_iterative.fut000066400000000000000000000002431475065116200225300ustar00rootroot00000000000000-- == -- input { -- 10 -- } -- output { -- 3628800 -- } def fact(n: i32): i32 = loop out = 1 for i < n do out * (i+1) def main(n: i32): i32 = fact(n) futhark-0.25.27/tests/rosettacode/fibonacci_iterative.fut000066400000000000000000000004501475065116200235300ustar00rootroot00000000000000-- https://rosettacode.org/wiki/Fibonacci_sequence -- == -- input { 0 } output { 0 } -- input { 1 } output { 1 } -- input { 2 } output { 1 } -- input { 3 } output { 2 } -- input { 40 } output { 102334155 } def main(n: i32): i32 = let (a,_) = loop (a,b) = (0,1) for _i < n do (b, a + b) in a futhark-0.25.27/tests/rosettacode/filter.fut000066400000000000000000000005031475065116200210230ustar00rootroot00000000000000-- https://rosettacode.org/wiki/Filter -- -- Selects all even numbers from an array. -- -- == -- input { [1, 2, 3, 4, 5, 6, 7, 8, 9] } -- output { [2, 4, 6, 8] } -- input { empty([0]i32) } -- output { empty([0]i32) } -- input { [1,3] } -- output { empty([0]i32) } def main(as: []i32): []i32 = filter (\x -> x%2 == 0) as futhark-0.25.27/tests/rosettacode/for.fut000066400000000000000000000005461475065116200203330ustar00rootroot00000000000000-- https://rosettacode.org/wiki/Loops/For -- -- Futhark does not have I/O, so this program simply counts in the -- inner loop. 
-- == -- input { 10i64 } -- output { [0i64, 1i64, 3i64, 6i64, 10i64, 15i64, 21i64, 28i64, 36i64, 45i64] } def main(n: i64): [n]i64 = loop a = replicate n 0 for i < n do (let a[i] = loop s = 0 for j < i+1 do s + j in a) futhark-0.25.27/tests/rosettacode/greatest_element_of_list.fut000066400000000000000000000000641475065116200246060ustar00rootroot00000000000000def main (xs: []f64) = reduce f64.max (-f64.inf) xs futhark-0.25.27/tests/rosettacode/hailstone.fut000066400000000000000000000037551475065116200215400ustar00rootroot00000000000000-- https://rosettacode.org/wiki/Hailstone_sequence -- -- Does not use memoization, but is instead parallel for task #3. -- -- == -- compiled input { 27 100000 } -- output { -- [27i32, 82i32, 41i32, 124i32, 62i32, 31i32, 94i32, 47i32, 142i32, -- 71i32, 214i32, 107i32, 322i32, 161i32, 484i32, 242i32, 121i32, -- 364i32, 182i32, 91i32, 274i32, 137i32, 412i32, 206i32, 103i32, -- 310i32, 155i32, 466i32, 233i32, 700i32, 350i32, 175i32, 526i32, -- 263i32, 790i32, 395i32, 1186i32, 593i32, 1780i32, 890i32, -- 445i32, 1336i32, 668i32, 334i32, 167i32, 502i32, 251i32, 754i32, -- 377i32, 1132i32, 566i32, 283i32, 850i32, 425i32, 1276i32, -- 638i32, 319i32, 958i32, 479i32, 1438i32, 719i32, 2158i32, -- 1079i32, 3238i32, 1619i32, 4858i32, 2429i32, 7288i32, 3644i32, -- 1822i32, 911i32, 2734i32, 1367i32, 4102i32, 2051i32, 6154i32, -- 3077i32, 9232i32, 4616i32, 2308i32, 1154i32, 577i32, 1732i32, -- 866i32, 433i32, 1300i32, 650i32, 325i32, 976i32, 488i32, 244i32, -- 122i32, 61i32, 184i32, 92i32, 46i32, 23i32, 70i32, 35i32, -- 106i32, 53i32, 160i32, 80i32, 40i32, 20i32, 10i32, 5i32, 16i32, -- 8i32, 4i32, 2i32, 1i32] -- -- 351i32 -- } def hailstone_step(x: i32): i32 = if (x % 2) == 0 then x/2 else (3*x) + 1 def hailstone_seq(x: i32): []i32 = let capacity = 100 let i = 1 let steps = replicate capacity (-1) let steps[0] = x let (_,i,steps,_) = loop ((capacity,i,steps,x)) while x != 1 do let (steps, capacity) = if i == capacity then (concat steps (replicate 
capacity (-1)), capacity * 2) else (steps, capacity) let x = hailstone_step x let steps[i] = x in (capacity, i+1, steps, x) in take i steps def hailstone_len(x: i32): i32 = (loop (i,x)=(1,x) while x != 1 do (i+1, hailstone_step x)).0 def max (x: i32) (y: i32): i32 = if x < y then y else x def main (x: i32) (n: i32): ([]i32, i32) = (hailstone_seq x, reduce max 0 (map hailstone_len (map (1+) (map i32.i64 (iota (i64.i32 n-1)))))) futhark-0.25.27/tests/rosettacode/integer_sequence.fut000066400000000000000000000004651475065116200230720ustar00rootroot00000000000000-- https://rosettacode.org/wiki/Integer_sequence -- -- Infinite loops cannot produce results in Futhark, so this program -- accepts an input indicating how many integers to generate. -- -- == -- input { 10i64 } output { [0i64,1i64,2i64,3i64,4i64,5i64,6i64,7i64,8i64,9i64] } def main(n: i64): [n]i64 = iota n futhark-0.25.27/tests/rosettacode/life.fut000066400000000000000000000046501475065116200204640ustar00rootroot00000000000000-- Simple game of life implementation with a donut world. Tested with -- a glider running for four iterations. 
-- -- http://rosettacode.org/wiki/Conway's_Game_of_Life -- -- == -- input { -- [[0, 0, 0, 0, 0], -- [0, 0, 1, 0, 0], -- [0, 0, 0, 1, 0], -- [0, 1, 1, 1, 0], -- [0, 0, 0, 0, 0]] -- 4 -- } -- output { -- [[0, 0, 0, 0, 0], -- [0, 0, 0, 0, 0], -- [0, 0, 0, 1, 0], -- [0, 0, 0, 0, 1], -- [0, 0, 1, 1, 1]] -- } -- input { -- [[0, 0, 0, 0, 0], -- [0, 0, 1, 0, 0], -- [0, 0, 0, 1, 0], -- [0, 1, 1, 1, 0], -- [0, 0, 0, 0, 0]] -- 8 -- } -- output { -- [[1, 0, 0, 1, 1], -- [0, 0, 0, 0, 0], -- [0, 0, 0, 0, 0], -- [0, 0, 0, 0, 1], -- [1, 0, 0, 0, 0]] -- } def bint: bool -> i32 = i32.bool def intb : i32 -> bool = bool.i32 def to_bool_board(board: [][]i32): [][]bool = map (\r -> map intb r) board def to_int_board(board: [][]bool): [][]i32 = map (\r -> map bint r) board def all_neighbours [n][m] (world: [n][m]bool): [n][m]i32 = let ns = map (rotate (-1)) world let ss = map (rotate 1) world let ws = rotate (-1) world let es = rotate 1 world let nws = map (rotate (-1)) ws let nes = map (rotate (-1)) es let sws = map (rotate 1) ws let ses = map (rotate 1) es in map3 (\(nws_r, ns_r, nes_r) (ws_r, world_r, es_r) (sws_r, ss_r, ses_r) -> map3 (\(nw,n,ne) (w,_,e) (sw,s,se) -> bint nw + bint n + bint ne + bint w + bint e + bint sw + bint s + bint se) (zip3 nws_r ns_r nes_r) (zip3 ws_r world_r es_r) (zip3 sws_r ss_r ses_r)) (zip3 nws ns nes) (zip3 ws world es) (zip3 sws ss ses) def iteration [n][m] (board: [n][m]bool): [n][m]bool = let lives = all_neighbours(board) in map2 (\(lives_r: []i32) (board_r: []bool) -> map2 (\(neighbors: i32) (alive: bool): bool -> if neighbors < 2 then false else if neighbors == 3 then true else if alive && neighbors < 4 then true else false) lives_r board_r) lives board def main (int_board: [][]i32) (iterations: i32): [][]i32 = -- We accept the board as integers for convenience, and then we -- convert to booleans here. 
let board = to_bool_board int_board in to_int_board (loop board for _i < iterations do iteration board) futhark-0.25.27/tests/rosettacode/mandelbrot.fut000066400000000000000000000035101475065116200216660ustar00rootroot00000000000000-- Computes escapes for each pixel, but not the colour. -- == -- compiled input { 10i64 10i64 100 0.0f32 0.0f32 1.0f32 1.0f32 } -- output { -- [[100i32, 100i32, 100i32, 100i32, 100i32, 100i32, 100i32, 12i32, 17i32, 7i32], -- [100i32, 100i32, 100i32, 100i32, 100i32, 100i32, 100i32, 8i32, 5i32, 4i32], -- [100i32, 100i32, 100i32, 100i32, 100i32, 100i32, 11i32, 5i32, 4i32, 3i32], -- [11i32, 100i32, 100i32, 100i32, 100i32, 100i32, 14i32, 5i32, 4i32, 3i32], -- [6i32, 7i32, 30i32, 14i32, 8i32, 6i32, 14i32, 4i32, 3i32, 2i32], -- [4i32, 4i32, 4i32, 5i32, 4i32, 4i32, 3i32, 3i32, 2i32, 2i32], -- [3i32, 3i32, 3i32, 3i32, 3i32, 2i32, 2i32, 2i32, 2i32, 2i32], -- [2i32, 2i32, 2i32, 2i32, 2i32, 2i32, 2i32, 2i32, 2i32, 1i32], -- [2i32, 2i32, 2i32, 2i32, 2i32, 2i32, 2i32, 1i32, 1i32, 1i32], -- [2i32, 2i32, 2i32, 2i32, 2i32, 1i32, 1i32, 1i32, 1i32, 1i32]] -- } type complex = (f32, f32) def dot(c: complex): f32 = let (r, i) = c in r * r + i * i def multComplex(x: complex, y: complex): complex = let (a, b) = x let (c, d) = y in (a*c - b * d, a*d + b * c) def addComplex(x: complex, y: complex): complex = let (a, b) = x let (c, d) = y in (a + c, b + d) def divergence(depth: i32, c0: complex): i32 = (loop (c, i) = (c0, 0) while i < depth && dot(c) < 4.0 do (addComplex(c0, multComplex(c, c)), i + 1)).1 def main (screenX: i64) (screenY: i64) (depth: i32) (xmin: f32) (ymin: f32) (xmax: f32) (ymax: f32): [screenX][screenY]i32 = let sizex = xmax - xmin let sizey = ymax - ymin in map (\x: [screenY]i32 -> map (\y: i32 -> let c0 = (xmin + (f32.i64(x) * sizex) / f32.i64(screenX), ymin + (f32.i64(y) * sizey) / f32.i64(screenY)) in divergence(depth, c0)) (iota screenY)) (iota screenX) 
futhark-0.25.27/tests/rosettacode/matrixmultiplication.fut000066400000000000000000000007041475065116200240230ustar00rootroot00000000000000-- https://rosettacode.org/wiki/Matrix_multiplication -- -- Matrix multiplication written in a functional style. -- -- == -- input { -- [ [1,2], [3,4] ] -- [ [5,6], [7,8] ] -- } -- output { -- [ [ 19 , 22 ] , [ 43 , 50 ] ] -- } -- structure { /Screma 1 /Screma/Screma 1 /Screma/Screma/Screma 1 } def main [n][m][p] (x: [n][m]i32) (y: [m][p]i32): [n][p]i32 = map (\xr -> map (\yc -> reduce (+) 0 (map2 (*) xr yc)) (transpose y)) x futhark-0.25.27/tests/rosettacode/md5.fut000066400000000000000000000072431475065116200202330ustar00rootroot00000000000000-- MD5 implementation. Based on Accelerate and other sources. -- -- Easy to get wrong if you forget that MD5 is little-endian. -- -- Ignored on ISPC backend since the ISPC compiler miscompiles the program -- when using 64-bit addressing mode. Memory reads from valid addresses -- cause a segfault if there are inactive lanes with invalid addresses. 
-- == -- no_ispc input { empty([0]u8) } -- output { [0xd4u8,0x1du8,0x8cu8,0xd9u8,0x8fu8,0x00u8,0xb2u8,0x04u8,0xe9u8,0x80u8,0x09u8,0x98u8,0xecu8,0xf8u8,0x42u8,0x7eu8] } -- no_ispc input { [0u8] } -- output { [0x93u8, 0xb8u8, 0x85u8, 0xadu8, 0xfeu8, 0x0du8, 0xa0u8, 0x89u8, 0xcdu8, 0xf6u8, 0x34u8, 0x90u8, 0x4fu8, 0xd5u8, 0x9fu8, 0x71u8] } type md5 = (u32, u32, u32, u32) def us32 (x: i32) = u32.i32 x def rs: [64]u32 = map us32 [ 7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22, 5, 9, 14, 20, 5, 9, 14, 20, 5, 9, 14, 20, 5, 9, 14, 20, 4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23, 6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21 ] def ks: [64]u32 = [ 0xd76aa478, 0xe8c7b756, 0x242070db, 0xc1bdceee , 0xf57c0faf, 0x4787c62a, 0xa8304613, 0xfd469501 , 0x698098d8, 0x8b44f7af, 0xffff5bb1, 0x895cd7be , 0x6b901122, 0xfd987193, 0xa679438e, 0x49b40821 , 0xf61e2562, 0xc040b340, 0x265e5a51, 0xe9b6c7aa , 0xd62f105d, 0x02441453, 0xd8a1e681, 0xe7d3fbc8 , 0x21e1cde6, 0xc33707d6, 0xf4d50d87, 0x455a14ed , 0xa9e3e905, 0xfcefa3f8, 0x676f02d9, 0x8d2a4c8a , 0xfffa3942, 0x8771f681, 0x6d9d6122, 0xfde5380c , 0xa4beea44, 0x4bdecfa9, 0xf6bb4b60, 0xbebfbc70 , 0x289b7ec6, 0xeaa127fa, 0xd4ef3085, 0x04881d05 , 0xd9d4d039, 0xe6db99e5, 0x1fa27cf8, 0xc4ac5665 , 0xf4292244, 0x432aff97, 0xab9423a7, 0xfc93a039 , 0x655b59c3, 0x8f0ccc92, 0xffeff47d, 0x85845dd1 , 0x6fa87e4f, 0xfe2ce6e0, 0xa3014314, 0x4e0811a1 , 0xf7537e82, 0xbd3af235, 0x2ad7d2bb, 0xeb86d391 ] def rotate_left(x: u32, c: u32): u32 = (x << c) | (x >> (32u32 - c)) def bytes(x: u32): [4]u8 = [u8.u32(x), u8.u32(x/0x100u32), u8.u32(x/0x10000u32), u8.u32(x/0x1000000u32)] def unbytes(bs: [4]u8): u32 = u32.u8(bs[0]) + u32.u8(bs[1]) * 0x100u32 + u32.u8(bs[2]) * 0x10000u32 + u32.u8(bs[3]) * 0x1000000u32 def unbytes_block(block: [16*4]u8): [16]u32 = map unbytes (unflatten block) -- Process 512 bits of the input. 
def md5_chunk ((a0,b0,c0,d0): md5) (m: [16]u32): md5 = loop (a,b,c,d) = (a0,b0,c0,d0) for i < 64 do let (f,g) = if i < 16 then ((b & c) | (!b & d), u32.i32 i) else if i < 32 then ((d & b) | (!d & c), (5u32*u32.i32 i + 1u32) % 16u32) else if i < 48 then (b ^ c ^ d, (3u32*u32.i32 i + 5u32) % 16u32) else (c ^ (b | !d), (7u32*u32.i32 i) % 16u32) in (d, b + rotate_left(a + f + ks[i] + m[i32.u32 g], rs[i]), b, c) def md5 [n] (ms: [n][16]u32): md5 = let a0 = 0x67452301_u32 let b0 = 0xefcdab89_u32 let c0 = 0x98badcfe_u32 let d0 = 0x10325476_u32 in loop ((a0,b0,c0,d0)) for i < n do let (a,b,c,d) = md5_chunk (a0,b0,c0,d0) ms[i] in (a0+a, b0+b, c0+c, d0+d) def main [n] (ms: [n]u8): [16]u8 = let padding = 64 - (n % 64) let n_padded = n + padding let num_blocks = n_padded / 64 let ms_padded = ms ++ bytes 0x80u32 ++ replicate (padding-12) 0x0u8 ++ bytes (u32.i64(n*8)) ++ [0u8,0u8,0u8,0u8] :> [num_blocks*(16*4)]u8 let (a,b,c,d) = md5 (map unbytes_block (unflatten ms_padded)) in flatten (map bytes [a,b,c,d]) :> [16]u8 futhark-0.25.27/tests/rosettacode/monte_carlo_methods.fut000066400000000000000000000045161475065116200235730ustar00rootroot00000000000000-- https://rosettacode.org/wiki/Monte_Carlo_methods -- -- Using Sobol sequences for random numbers. 
-- == -- tags { no_python } -- input { 1 } output { 4f32 } -- input { 10 } output { 3.2f32 } -- compiled input { 100 } output { 3.16f32 } -- compiled input { 1000 } output { 3.144f32 } -- compiled input { 10000 } output { 3.142f32 } -- compiled input { 100000 } output { 3.14184f32 } -- compiled input { 1000000 } output { 3.141696f32 } -- compiled input { 10000000 } output { 3.141595f32 } def dirvcts(): [2][30]i32 = [ [ 536870912, 268435456, 134217728, 67108864, 33554432, 16777216, 8388608, 4194304, 2097152, 1048576, 524288, 262144, 131072, 65536, 32768, 16384, 8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1 ], [ 536870912, 805306368, 671088640, 1006632960, 570425344, 855638016, 713031680, 1069547520, 538968064, 808452096, 673710080, 1010565120, 572653568, 858980352, 715816960, 1073725440, 536879104, 805318656, 671098880, 1006648320, 570434048, 855651072, 713042560, 1069563840, 538976288, 808464432, 673720360, 1010580540, 572662306, 858993459 ] ] def grayCode(x: i32): i32 = (x >> 1) ^ x ---------------------------------------- --- Sobol Generator ---------------------------------------- def testBit(n: i32, ind: i32): bool = let t = (1 << ind) in (n & t) == t def xorInds [num_bits] (n: i32) (dir_vs: [num_bits]i32): i32 = let reldv_vals = map2 (\ dv i -> if testBit(grayCode n,i32.i64 i) then dv else 0) dir_vs (iota num_bits) in reduce (^) 0 reldv_vals def sobolIndI [m] [num_bits] (dir_vs: [m][num_bits]i32, n: i64): [m]i32 = map (xorInds (i32.i64 n)) dir_vs def sobolIndR [m] [num_bits] (dir_vs: [m][num_bits]i32) (n: i64): [m]f32 = let divisor = 2.0 ** f32.i64(num_bits) let arri = sobolIndI( dir_vs, n ) in map (\x -> f32.i32 x / divisor) arri def main(n: i32): f32 = let rand_nums = map (sobolIndR (dirvcts())) (iota (i64.i32 n)) let dists = map (\xy -> let (x,y) = (xy[0],xy[1]) in f32.sqrt(x*x + y*y)) rand_nums let bs = map (\d -> if d <= 1.0f32 then 1 else 0) dists let inside = reduce (+) 0 bs in 4.0f32*f32.i64(inside)/f32.i32(n) 
futhark-0.25.27/tests/rosettacode/pythagorean_means.fut000066400000000000000000000010531475065116200232430ustar00rootroot00000000000000-- http://rosettacode.org/wiki/Averages/Pythagorean_means -- -- == -- input { [1.0,2.0,3.0,1.0] } -- output { 1.75f64 1.565f64 1.412f64 } -- Divide first to improve numerical behaviour. def arithmetic_mean [n] (as: [n]f64): f64 = reduce (+) 0.0 (map (/f64.i64(n)) as) def geometric_mean [n] (as: [n]f64): f64 = reduce (*) 1.0 (map (**(1.0/f64.i64(n))) as) def harmonic_mean [n] (as: [n]f64): f64 = f64.i64(n) / reduce (+) 0.0 (map (1.0/) as) def main(as: []f64): (f64,f64,f64) = (arithmetic_mean as, geometric_mean as, harmonic_mean as) futhark-0.25.27/tests/rosettacode/reverse_a_string.fut000066400000000000000000000003101475065116200230730ustar00rootroot00000000000000-- Futhark has no real strings beyond a little bit of syntactic sugar, -- so this is the same as reversing an array. -- == -- input { [1,2,3,4] } -- output { [4,3,2,1] } def main(s: []i32) = s[::-1] futhark-0.25.27/tests/rosettacode/rms.fut000066400000000000000000000003171475065116200203420ustar00rootroot00000000000000-- http://rosettacode.org/wiki/Averages/Root_mean_square -- -- input { [1.0,2.0,3.0,1.0] } -- output { 1.936f64 } def main [n] (as: [n]f64): f64 = f64.sqrt ((reduce (+) 0.0 (map (**2.0) as)) / f64.i64 n) futhark-0.25.27/tests/safety/000077500000000000000000000000001475065116200157775ustar00rootroot00000000000000futhark-0.25.27/tests/safety/div0.fut000066400000000000000000000001721475065116200173610ustar00rootroot00000000000000-- Division by zero, and in a parallel context at that! 
-- == -- input { [0] } error: def main (xs: []i32) = map (2/) xs futhark-0.25.27/tests/safety/map0.fut000066400000000000000000000002151475065116200173520ustar00rootroot00000000000000-- == -- random input { [20000]f32 } error: Index \[-1\] def main [n] (xs: [n]f32) = map (\i -> xs[if i == 1000 then -1 else i]) (iota n) futhark-0.25.27/tests/safety/powneg.fut000066400000000000000000000002201475065116200200100ustar00rootroot00000000000000-- Negative integer exponent, and in a parallel context at that! -- == -- input { 2 [-1] } error: def main (b: i32) (xs: []i32) = map (b**) xs futhark-0.25.27/tests/safety/reduce0.fut000066400000000000000000000002271475065116200200470ustar00rootroot00000000000000-- == -- random input { [20000]f32 } error: Index \[-1\] def main [n] (xs: [n]f32) = f32.sum (map (\i -> xs[if i == 1000 then -1 else i]) (iota n)) futhark-0.25.27/tests/safety/reduce1.fut000066400000000000000000000003511475065116200200460ustar00rootroot00000000000000-- == -- compiled random input { [20000]f32 } error: l != 1337 def main [n] (xs: *[n]f32) = let xs[1337] = f32.lowest let op i j = let l = if xs[i] < xs[j] then i else j in assert (l != 1337) l in reduce op 0 (iota n) futhark-0.25.27/tests/safety/reduce_by_index0.fut000066400000000000000000000002761475065116200217340ustar00rootroot00000000000000-- == -- random input { [20000]f32 } error: Index \[-1\] def main [n] (xs: [n]f32) = hist (+) 0 3 (map (%3) (iota n)) (map (\i -> xs[if i == 1000 then -1 else i]) (iota n)) futhark-0.25.27/tests/safety/reduce_by_index1.fut000066400000000000000000000003751475065116200217350ustar00rootroot00000000000000-- == -- compiled random input { [20000]f32 } error: l != 1337 def main [n] (xs: *[n]f32) = let xs[1337] = f32.lowest let op i j = let l = if xs[i] < xs[j] then i else j in assert (l != 1337) l in hist op 0 3 (map (%3) (iota n)) (iota n) futhark-0.25.27/tests/safety/scan0.fut000066400000000000000000000002321475065116200175200ustar00rootroot00000000000000-- == -- random input { 
[20000]f32 } error: Index \[-1\] def main [n] (xs: [n]f32) = scan (+) 0 (map (\i -> xs[if i == 1000 then -1 else i]) (iota n)) futhark-0.25.27/tests/safety/scan1.fut000066400000000000000000000003471475065116200175300ustar00rootroot00000000000000-- == -- compiled random input { [20000]f32 } error: l != 1337 def main [n] (xs: *[n]f32) = let xs[1337] = f32.lowest let op i j = let l = if xs[i] < xs[j] then i else j in assert (l != 1337) l in scan op 0 (iota n) futhark-0.25.27/tests/safety/scatter0.fut000066400000000000000000000002431475065116200202430ustar00rootroot00000000000000-- == -- random input { [20000]f32 } error: Index \[-1\] def main [n] (xs: [n]f32) = spread 5 3 (iota n) (map (\i -> xs[if i == 1000 then -1 else i]) (iota n)) futhark-0.25.27/tests/saxpy.fut000066400000000000000000000003711475065116200163710ustar00rootroot00000000000000-- Single-Precision A·X Plus Y -- -- == -- input { -- 2.0f32 -- [1.0f32,2.0f32,3.0f32] -- [4.0f32,5.0f32,6.0f32] -- } -- output { -- [6.0f32, 9.0f32, 12.0f32] -- } def main (a: f32) (x: []f32) (y: []f32): []f32 = map2 (+) (map (a*) x) y futhark-0.25.27/tests/scatter/000077500000000000000000000000001475065116200161515ustar00rootroot00000000000000futhark-0.25.27/tests/scatter/data/000077500000000000000000000000001475065116200170625ustar00rootroot00000000000000futhark-0.25.27/tests/scatter/data/tiny.in000066400000000000000000000036671475065116200204110ustar00rootroot0000000000000010i32 [0i32, -10, -20, -30, -40, -50, -60, -70, -80, -90, -100, -110, -120, -130, -140, -150, -160, -10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -30, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -50, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -60, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -70, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, -100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -110, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -120, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -130, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -140, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -150, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -160, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] [0i32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -3, -2, 0, -2, 0, -3, -3, 0, -3, 0, 1, -1, 8, 0, -3, 0, 0, 2, -4, -2, -4, -2, -1, 2, -2, 4, -2, -3, -4, -3, -2, -1, -2, 0, 2, -4, -2, -4, -2, -1, 2, -2, 4, -2, -3, -4, -3, -2, -1, -2, 0, -3, 0, 0, 0, 0, -3, -3, 0, -3, 0, 6, 1, 1, 0, -3, 0, 0, -3, -2, 0, -2, 2, -4, -3, 2, -3, 2, 0, 2, 0, 0, -4, 2, 0, -1, -3, -3, -3, -3, 9, -1, -3, -1, -3, -3, -3, -3, -3, 9, -3, 0, -4, 6, -2, 6, -2, -3, -4, -2, -4, -2, 0, -1, -2, -2, -3, -2, 0, -3, -2, 0, -2, 0, -3, -3, 0, -3, 0, 1, -1, 8, 0, -3, 0, 0, -3, -2, 5, -2, 1, -3, -3, 1, -2, 1, 0, -2, 0, 5, -3, 1, 0, -1, -3, -3, -3, -3, 9, -1, -3, -1, -3, -3, -3, -3, -3, 9, -3, 0, -3, -2, 0, -2, 2, -4, -3, 2, -3, 2, 0, 2, 0, 0, -4, 2, 0, -3, 0, 0, 0, 0, -3, -3, 0, -3, 0, 6, 1, 1, 0, -3, 0, 0, -3, 0, 0, 0, 0, -3, -3, 0, -3, 0, 6, 1, 1, 0, -3, 0, 0, -4, 6, -2, 6, -2, -3, -4, -2, -4, -2, 0, -1, -2, -2, -3, -2, 0, -3, -2, 0, -2, 2, -4, -3, 2, -3, 2, 0, 2, 0, 0, -4, 2, 0, -3, -2, 5, -2, 1, -3, -3, 1, -2, 1, 0, -2, 0, 5, -3, 1] futhark-0.25.27/tests/scatter/data/tiny.out000066400000000000000000000022231475065116200205750ustar00rootroot00000000000000b i32!~tj`|r~tjv`lfuthark-0.25.27/tests/scatter/elimination/000077500000000000000000000000001475065116200204615ustar00rootroot00000000000000futhark-0.25.27/tests/scatter/elimination/write-iota0.fut000066400000000000000000000004711475065116200233470ustar00rootroot00000000000000-- Test that an iota can be eliminated in a write. Contrived example. 
-- == -- input { -- [100, 200, 300] -- [5, 10, 15, 20, 25, 30] -- } -- output { -- [100, 200, 300, 20, 25, 30] -- } -- structure { Scatter 1 } def main [k][n] (values: [k]i32) (array: *[n]i32): [n]i32 = scatter array (iota k) values futhark-0.25.27/tests/scatter/elimination/write-iota1.fut000066400000000000000000000004651475065116200233530ustar00rootroot00000000000000-- Test that multiple iotas can be eliminated in a write. -- == -- input { -- 4i64 -- [5i64, 10i64, 15i64, 20i64, 25i64, 30i64] -- } -- output { -- [0i64, 1i64, 2i64, 3i64, 25i64, 30i64] -- } -- structure { Scatter 1 } def main [n] (k: i64) (array: *[n]i64): [n]i64 = scatter array (iota k) (iota k) futhark-0.25.27/tests/scatter/elimination/write-iota2.fut000066400000000000000000000005161475065116200233510ustar00rootroot00000000000000-- Test that multiple iotas with different start values can be eliminated in a -- write. -- == -- input { -- 5i64 -- [5, 10, 15, 20, 25, 30] -- } -- output { -- [-9, -8, -7, -6, -5, 30] -- } -- structure { Scatter 1 } def main [n] (k: i64) (array: *[n]i32): [n]i32 = scatter array (iota k) (map (\x -> i32.i64 x-9) (iota k)) futhark-0.25.27/tests/scatter/elimination/write-replicate0.fut000066400000000000000000000004461475065116200243650ustar00rootroot00000000000000-- Test that a replicate can be eliminated in a write. -- == -- input { -- [0i64, 3i64, 1i64] -- [9, 8, -3, 90, 41] -- } -- output { -- [5, 5, -3, 5, 41] -- } -- structure { Scatter 1 } def main [k][n] (indexes: [k]i64) (array: *[n]i32): [n]i32 = scatter array indexes (replicate k 5) futhark-0.25.27/tests/scatter/fusion/000077500000000000000000000000001475065116200174545ustar00rootroot00000000000000futhark-0.25.27/tests/scatter/fusion/concat-scatter-fusion0.fut000066400000000000000000000007251475065116200244730ustar00rootroot00000000000000-- If both the indexes and values come from a concatenation of arrays -- of the same size, that concatenation should be fused away. 
-- == -- input { [0,0,0,0,0,0,0,0] [0, 2, 6] } -- output { [1, 2, 1, 2, 0, 0, 1, 2] } -- structure { Concat 0 Scatter 1 } def main [k][n] (arr: *[k]i32) (xs: [n]i32) = let (is0, vs0, is1, vs1) = unzip4 (map (\x -> (i64.i32 x,1,i64.i32 x+1,2)) xs) let m = n + n in scatter arr (concat is0 is1 :> [m]i64) (concat vs0 vs1 :> [m]i32) futhark-0.25.27/tests/scatter/fusion/concat-scatter-fusion1.fut000066400000000000000000000007151475065116200244730ustar00rootroot00000000000000-- Concat-scatter fusion in a more complicated session. -- == -- input { [0,5] } -- output { [3, 5, 1, 1, 1, 3, 5, 1, 1, 1] -- [4, 6, 2, 2, 2, 4, 6, 2, 2, 2] } -- structure { Concat 0 Scatter 1 } def main [n] (xs: [n]i32) = let dest = replicate 10 (1,2) let (is0, vs0, is1, vs1) = unzip4 (map (\x -> (i64.i32 x,(3,4),i64.i32 x+1,(5,6))) xs) let m = n + n in unzip (scatter dest (concat is0 is1 :> [m]i64) (concat vs0 vs1 :> [m](i32,i32))) futhark-0.25.27/tests/scatter/fusion/concat-scatter-fusion2.fut000066400000000000000000000011711475065116200244710ustar00rootroot00000000000000-- If both the indexes and values come from a concatenation of arrays -- of the same size, that concatenation should be fused away. -- == -- input { [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0] [0, 5, 10] } -- output { [1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 0] } -- structure { Concat 0 Scatter 1 } def main [k][n] (arr: *[k]i32) (xs: [n]i32) = let (a, b) = unzip (map (\x -> ((i64.i32 x,1,i64.i32 x+1,2),(i64.i32 x+2,i64.i32 x+3,3,4))) xs) let m = n + n + n + n let ((is0, vs0, is1, vs1), (is2, is3, vs2, vs3)) = (unzip4 a, unzip4 b) in scatter arr (is0 ++ is1 ++ is2 ++ is3 :> [m]i64) (vs0 ++ vs1 ++ vs2 ++ vs3 :> [m]i32) futhark-0.25.27/tests/scatter/fusion/map-write-fusion-not-possible0.fut000066400000000000000000000007211475065116200260760ustar00rootroot00000000000000-- Test that map-write fusion is *not* applied when not all of the map outputs -- are used in the write. 
-- == -- structure { Screma 1 Scatter 1 } def main [k][n] (indexes: [k]i64, values: [k]i32, array: *[n]i32): ([n]i32, [k]i32) = let (indexes', baggage) = unzip(map (\(i, v) -> (i + 1, v + 1)) (zip indexes values)) let array' = scatter array indexes' values in (array', baggage) futhark-0.25.27/tests/scatter/fusion/map-write-fusion-not-possible1.fut000066400000000000000000000010561475065116200261010ustar00rootroot00000000000000-- Test that map-scatter fusion is *not* applied when one of the I/O arrays (which -- are not part of the fusable indexes and values arrays) are used in the map -- *and* in the scatter. If this was fused into a single scatter, the I/O array -- would ultimately be written and read in the same kernel. -- == -- structure { Screma 1 Scatter 1 } def main [k][n] (indexes: [k]i64, values: [k]i64, array: *[n]i64): [n]i64 = let indexes' = map (\i -> array[i]) indexes let array' = scatter array indexes' values in array' futhark-0.25.27/tests/scatter/fusion/map-write-fusion0.fut000066400000000000000000000006771475065116200234740ustar00rootroot00000000000000-- Test that map-write fusion works in a simple case. -- == -- input { -- [2i64, 0i64] -- [100i64, 200i64] -- [0i64, 2i64, 4i64, 6i64, 9i64] -- } -- output { -- [0i64, 200i64, 4i64, 100i64, 9i64] -- } -- structure { Screma 0 Scatter 1 } def main [k][n] (indexes: [k]i64) (values: [k]i64) (array: *[n]i64): [n]i64 = let indexes' = map (+1) indexes let array' = scatter array indexes' values in array' futhark-0.25.27/tests/scatter/fusion/map-write-fusion1.fut000066400000000000000000000007251475065116200234670ustar00rootroot00000000000000-- Test that map-scatter fusion works in a slightly less simple case. 
-- == -- input { -- [2i64, 0i64] -- [100i64, 200i64] -- [0i64, 2i64, 4i64, 6i64, 9i64] -- } -- output { -- [200i64, 2i64, 102i64, 6i64, 9i64] -- } -- structure { Screma 0 Scatter 1 } def main [k][n] (indexes: [k]i64) (values: [k]i64) (array: *[n]i64): [n]i64 = let values' = map2 (+) indexes values let array' = scatter array indexes values' in array' futhark-0.25.27/tests/scatter/fusion/scatter-scatter-not-possible.fut000066400000000000000000000001431475065116200257160ustar00rootroot00000000000000def main (xs: *[]i32) (ys: *[]i32) = let xs' = scatter xs [0] [1] in (xs', scatter ys [0] xs') futhark-0.25.27/tests/scatter/fusion/write-fusion-mix0.fut000066400000000000000000000012741475065116200235060ustar00rootroot00000000000000-- Test that map-scatter fusion and scatter-scatter fusion work together. -- == -- input { -- [2i64, 0i64] -- [1i64, 0i64] -- [100, 80] -- [0, 2, 4, 6, 9] -- [10, 12, 14, 16, 19] -- } -- output { -- [84i32, 2i32, 104i32, 6i32, 9i32] -- [240i32, 300i32, 14i32, 16i32, 19i32] -- } -- structure { Scatter 1 } def main [k][n] (indexes0: [k]i64) (indexes1: [k]i64) (values: [k]i32) (array0: *[n]i32) (array1: *[n]i32): ([n]i32, [n]i32) = let values0' = map (+4) values let values1' = map (*3) values let array0' = scatter array0 indexes0 values0' let array1' = scatter array1 indexes1 values1' in (array0', array1') futhark-0.25.27/tests/scatter/fusion/write-fusion-mix1.fut000066400000000000000000000013211475065116200235000ustar00rootroot00000000000000-- Test that map-scatter fusion and scatter-scatter fusion work together. 
-- == -- input { -- [0i64, 1i64, 3i64] -- [3i64, 2i64, 4i64, 6i64, 9i64, 14i64] -- [13i64, 12i64, 14i64, 16i64, 19i64, 114i64] -- } -- output { -- [3i64, 3i64, 4i64, 6i64, 6i64, 14i64] -- [13i64, 12i64, 4i64, 5i64, 19i64, 7i64] -- } -- structure { Scatter 1 } def main [k][n] (numbers: [k]i64) (array0: *[n]i64) (array1: *[n]i64): ([n]i64, [n]i64) = let indexes0 = map (+1) numbers let indexes1 = map (+2) numbers let values0 = map (+3) numbers let values1 = map (+4) numbers let array0' = scatter array0 indexes0 values0 let array1' = scatter array1 indexes1 values1 in (array0', array1') futhark-0.25.27/tests/scatter/fusion/write-write-fusion-not-possible0.fut000066400000000000000000000006051475065116200264540ustar00rootroot00000000000000-- Test that write-write fusion is *not* applied when one write uses the output -- of another write. -- == -- structure { Scatter 2 } def main [k] [n] (indexes: [k]i64, values1: [k]i32, values2: [k]i32, array: *[n]i32): [n]i32 = let array' = scatter array indexes values1 let array'' = scatter array' indexes values2 in array'' futhark-0.25.27/tests/scatter/fusion/write-write-fusion-not-possible1.fut000066400000000000000000000006271475065116200264610ustar00rootroot00000000000000-- Test that write-write fusion is *not* applied when one write uses the output -- of another write. -- == -- structure { Scatter 2 } def main [k] (indexes: [k]i64, values1: [k]i64, values2: [k]i64, array1: *[k]i64, array2: *[k]i64): [k]i64 = let array1' = scatter array1 indexes values1 let array2' = scatter array2 array1' values2 in array2' futhark-0.25.27/tests/scatter/fusion/write-write-fusion0.fut000066400000000000000000000011051475065116200240340ustar00rootroot00000000000000-- Test that write-write fusion works in a simple case. 
-- == -- input { -- [1i64, 0i64] -- [8, 2] -- [5, 3] -- [10, 20, 30, 40, 50] -- [100, 200, 300, 400, 500] -- } -- output { -- [2, 8, 30, 40, 50] -- [3, 5, 300, 400, 500] -- } -- structure { Scatter 1 } def main [n][k] (indexes: [k]i64) (values1: [k]i32) (values2: [k]i32) (array1: *[n]i32) (array2: *[n]i32): ([n]i32, [n]i32) = let array1' = scatter array1 indexes values1 let array2' = scatter array2 indexes values2 in (array1', array2') futhark-0.25.27/tests/scatter/fusion/write-write-fusion1.fut000066400000000000000000000013051475065116200240370ustar00rootroot00000000000000-- Test that scatter-scatter fusion works with more than two arrays. -- == -- input { -- [0i64] -- [99] -- [10, 20, 30, 40, 50] -- [100, 200, 300, 400, 500] -- [1000, 2000, 3000, 4000, 5000] -- } -- output { -- [99, 20, 30, 40, 50] -- [99, 200, 300, 400, 500] -- [99, 2000, 3000, 4000, 5000] -- } -- structure { Scatter 1 } def main [k][n] (indexes: [k]i64) (values: [k]i32) (array1: *[n]i32) (array2: *[n]i32) (array3: *[n]i32): ([n]i32, [n]i32, [n]i32) = let array1' = scatter array1 indexes values let array2' = scatter array2 indexes values let array3' = scatter array3 indexes values in (array1', array2', array3') futhark-0.25.27/tests/scatter/mapscatter.fut000066400000000000000000000004051475065116200210330ustar00rootroot00000000000000-- You can map a scatter (sort of). 
-- == -- input { [[1,2,3],[4,5,6]] [[1,-1,-1],[-1,0,1]] [[0,0,0],[0,0,0]] } -- output { [[1,0,3],[0,0,6]] } def main (as: [][]i32) (is: [][]i32) (vs: [][]i32) = map3 (\x y z -> scatter (copy x) (map i64.i32 y) z) as is vs futhark-0.25.27/tests/scatter/nw.fut000066400000000000000000000071541475065116200173240ustar00rootroot00000000000000-- Stripped version of nw from futhark-benchmarks -- == -- no_wasm compiled input @ data/tiny.in -- output @ data/tiny.out def B0: i64 = 64 def fInd (B: i64) (y:i32) (x:i32): i32 = y*(i32.i64 B+1) + x def max3 (x:i32, y:i32, z:i32) = if x < y then if y < z then z else y else if x < z then z else x def mkVal [l2][l] (y:i32) (x:i32) (pen:i32) (inp_l:[l2][l2]i32) (ref_l:[l][l]i32) : i32 = #[unsafe] max3( ( (inp_l[y - 1, x - 1])) + ( ref_l[y-1, x-1]) , ( (inp_l[y, x - 1])) - pen , ( (inp_l[y - 1, x])) - pen ) def intraBlockPar [len] (B: i64) (penalty: i32) (inputsets: [len*len]i32) (reference2: [len][len]i32) (b_y: i64) (b_x: i64) : [B][B]i32 = let ref_l = reference2[b_y * B + 1: b_y * B + 1 + B, b_x * B + 1: b_x * B + 1 + B] :> [B][B]i32 let inputsets' = unflatten inputsets let Bp1 = B + 1 -- Initialize inp_l with the already processed the column to the left of this -- block let inp_l = copy(inputsets'[b_y * B : b_y * B + B + 1, b_x * B : b_x * B + B + 1]) :> *[Bp1][Bp1]i32 -- Process the first half (anti-diagonally) of the block let inp_l = loop inp_l for m < B do let (inds, vals) = unzip ( -- tabulate over the m'th anti-diagonal before the middle tabulate B (\tx -> ( if tx > m then ((-1, -1), 0) else let ind_x = i32.i64 (tx + 1) let ind_y = i32.i64 (m - tx + 1) let v = mkVal ind_y ind_x penalty inp_l ref_l in ((i64.i32 ind_y, i64.i32 ind_x), v)))) in scatter_2d inp_l inds vals in inp_l[1:B+1,1:B+1] :> [B][B]i32 def updateBlocks [q] (B: i64) (len: i64) (blk: i64) (mk_b_y: (i32 -> i32)) (mk_b_x: (i32 -> i32)) (block_inp: [q][B][B]i32) (inputsets: *[len*len]i32) = let (inds, vals) = unzip ( tabulate (blk*B*B) (\gid -> let B2 = i32.i64 
(B*B) let gid = i32.i64 gid let (bx, lid2) = (gid / B2, gid % B2) let (ty, tx) = (lid2 / i32.i64 B, lid2 % i32.i64 B) let b_y = mk_b_y bx let b_x = mk_b_x bx let v = #[unsafe] block_inp[bx, ty, tx] let ind = (i32.i64 B*b_y + 1 + ty) * i32.i64 len + (i32.i64 B*b_x + tx + 1) in (i64.i32 ind, v))) in scatter inputsets inds vals def main [lensq] (penalty : i32) (inputsets : *[lensq]i32) (reference : *[lensq]i32) : *[lensq]i32 = let len = i64.f32 (f32.sqrt (f32.i64 lensq)) let worksize = len - 1 let B = i64.min worksize B0 -- worksize should be a multiple of B0 let B = assert (worksize % B == 0) B let block_width = trace <| worksize / B let reference2 = unflatten (reference :> [len*len]i32) let inputsets = (inputsets :> [len*len]i32) -- First anti-diagonal half of the entire input matrix let inputsets = loop inputsets for blk < block_width do let blk = blk + 1 let block_inp = -- Process an anti-diagonal of independent blocks tabulate blk (\b_x -> let b_y = blk-1-b_x in intraBlockPar B penalty inputsets reference2 b_y b_x ) let mkBY bx = i32.i64 (blk - 1) - bx let mkBX bx = bx in updateBlocks B len blk mkBY mkBX block_inp inputsets in inputsets :> [lensq]i32 futhark-0.25.27/tests/scatter/scatter2d-2.fut000066400000000000000000000017141475065116200207260ustar00rootroot00000000000000-- Test that write works for arrays of tuples. 
-- == -- -- input { -- [1] -- [1] -- [1337] -- [[1,2,3],[4,5,6],[7,8,9]] -- [[10,20,30],[40,50,60],[70,80,90]] -- } -- output { -- [[1,2,3],[4,1337,6],[7,8,9]] -- [[10,20,30],[40,1337,60],[70,80,90]] -- } -- let scatter_blab 't [n] [h] [w] (dest: *[h][w]t) (is: [n](i64, i64)) (vs: [n]t): *[h][w]t = -- let len = h * w -- let is' = map (\(i, j) -> if i < 0 || i >= h || j < 0 || j >= w -- then -1 -- else i * w + j) -- is -- let flattened = flatten_to len dest -- in scatter flattened is' vs |> unflatten h w def main [k][n][m] (indexes1: [k]i32) (indexes2: [k]i32) (values: [k]i32) (array1: *[n][m]i32) (array2: *[n][m]i32): ([n][m]i32, [n][m]i32) = unzip (map unzip (scatter_2d (copy (map2 zip array1 array2)) (zip (map i64.i32 indexes1) (map i64.i32 indexes2)) (zip values values))) futhark-0.25.27/tests/scatter/scatter2d.fut000066400000000000000000000016701475065116200205700ustar00rootroot00000000000000-- == -- input { [[1,2,3],[4,5,6],[7,8,9]] [1i64, 1i64] [1i64, -1i64] [42, 1337] } -- output { [[1,2,3],[4,42,6],[7,8,9]] } -- input { [[1,2,3],[4,5,6],[7,8,9]] [-1i64] [-1i64] [1337] } -- output { [[1,2,3],[4,5,6],[7,8,9]] } -- input { [[1,2,3],[4,5,6],[7,8,9]] [3i64] [0i64] [1337] } -- output { [[1,2,3],[4,5,6],[7,8,9]] } -- input { [[1,2,3],[4,5,6],[7,8,9]] [0i64] [3i64] [1337] } -- output { [[1,2,3],[4,5,6],[7,8,9]] } -- input { [[1,2,3],[4,5,6],[7,8,9]] [-1i64] [0i64] [1337] } -- output { [[1,2,3],[4,5,6],[7,8,9]] } -- input { [[1,2,3],[4,5,6],[7,8,9]] [0i64] [-1i64] [1337] } -- output { [[1,2,3],[4,5,6],[7,8,9]] } -- input { [[1,2,3],[4,5,6],[7,8,9]] [0i64] [0i64] [1337] } -- output { [[1337,2,3],[4,5,6],[7,8,9]] } -- input { [[1,2,3],[4,5,6],[7,8,9]] [3i64] [3i64] [1337] } -- output { [[1,2,3],[4,5,6],[7,8,9]] } def main [n][m][l] (xss: *[n][m]i32) (is: [l]i64) (js: [l]i64) (vs: [l]i32): [n][m]i32 = scatter_2d xss (zip is js) vs futhark-0.25.27/tests/scatter/scatter3d.fut000066400000000000000000000003751475065116200205720ustar00rootroot00000000000000-- == -- input { 
[[[1,2,3],[4,5,6],[7,8,9]],[[10,20,30],[40,50,60],[70,80,90]]] } -- output { [[[1,2,3],[4,5,6],[7,8,9]],[[10,20,30],[40,1337,60],[70,80,90]]] } def main [n][m][o] (xss: *[n][m][o]i32) = scatter_3d xss [(1, 1, 1), (1,-1, 1)] [1337, 0] futhark-0.25.27/tests/scatter/singleton.fut000066400000000000000000000002351475065116200206730ustar00rootroot00000000000000-- #1793 -- == -- input { [false] 0i64 } output { [true] } -- input { [false] 1i64 } output { [false] } def main (xs: *[1]bool) i = scatter xs [i] [true] futhark-0.25.27/tests/scatter/write-error0.fut000066400000000000000000000001761475065116200212360ustar00rootroot00000000000000-- Fail if the argument to write is not unique. -- -- == -- error: Consuming def main(a: []i32): []i32 = scatter a [0] [1] futhark-0.25.27/tests/scatter/write0.fut000066400000000000000000000011541475065116200201040ustar00rootroot00000000000000-- Test that write works in its simplest uses. -- == -- -- input { -- [0] -- [9] -- [3] -- } -- output { -- [9] -- } -- -- input { -- [-1] -- [0] -- [5] -- } -- output { -- [5] -- } -- -- input { -- [0, 1] -- [5, 6] -- [3, 4] -- } -- output { -- [5, 6] -- } -- -- input { -- [0, 2, -1] -- [9, 7, 0] -- [3, 4, 5] -- } -- output { -- [9, 4, 7] -- } -- -- input { -- [4, -1] -- [77, 0] -- [8, -4, 9, 1, 2, 100] -- } -- output { -- [8, -4, 9, 1, 77, 100] -- } def main [k][n] (indexes: [k]i32) (values: [k]i32) (array: *[n]i32): [n]i32 = scatter array (map i64.i32 indexes) values futhark-0.25.27/tests/scatter/write1.fut000066400000000000000000000006241475065116200201060ustar00rootroot00000000000000-- Test that write works in non-trivial cases. 
-- == -- -- input { -- [1, -1] -- [[5.0f32, 4.3f32], [0.0f32, 0.0f32]] -- [[1.0f32, 1.2f32], [2.3f32, -11.6f32], [4.0f32, 44.2f32]] -- } -- output { -- [[1.0f32, 1.2f32], [5.0f32, 4.3f32], [4.0f32, 44.2f32]] -- } def main [k][m][n] (indexes: [k]i32) (values: [k][m]f32) (array: *[n][m]f32): [n][m]f32 = scatter array (map i64.i32 indexes) values futhark-0.25.27/tests/scatter/write2.fut000066400000000000000000000011301475065116200201000ustar00rootroot00000000000000-- Test that write works in even more non-trivial cases. -- == -- -- input { -- [2, -1, 0] -- [[[0, 0, 1], [5, 6, 7]], -- [[0, 0, 0], [0, 0, 0]], -- [[1, 1, 14], [15, 16, 17]]] -- [[[1, 2, 3], [10, 20, 30]], -- [[1, 2, 3], [10, 20, 30]], -- [[14, 24, 34], [11, 21, 31]]] -- } -- output { -- [[[1, 1, 14], [15, 16, 17]], -- [[1, 2, 3], [10, 20, 30]], -- [[0, 0, 1], [5, 6, 7]]] -- } def main [k][t][m][n] (indexes: [k]i32) (values: [k][t][m]i32) (array: *[n][t][m]i32): [n][t][m]i32 = scatter array (map i64.i32 indexes) values futhark-0.25.27/tests/scatter/write3.fut000066400000000000000000000005141475065116200201060ustar00rootroot00000000000000-- Test that write works for large indexes and values arrays. -- == -- -- input { -- 9337i64 -- } -- output { -- true -- } def main(n: i64): bool = let indexes = iota(n) let values = map (+2) indexes let array = map (+5) indexes let array' = scatter array indexes values in reduce (&&) true (map2 (==) array' values) futhark-0.25.27/tests/scatter/write4.fut000066400000000000000000000006021475065116200201050ustar00rootroot00000000000000-- Test that write works for arrays of tuples. 
-- == -- -- input { -- [0] -- [9] -- [1,2,3] -- [4,5,6] -- } -- output { -- [9,2,3] -- [9,5,6] -- } def main [k][n] (indexes: [k]i32) (values: [k]i32) (array1: *[n]i32) (array2: *[n]i32): ([n]i32, [n]i32) = unzip (scatter (copy (zip array1 array2)) (map i64.i32 indexes) (zip values values)) futhark-0.25.27/tests/scatter/write5.fut000066400000000000000000000002141475065116200201050ustar00rootroot00000000000000 def main [n] ((a: [n]f32, ja: []i32)): ([]f32, []i32) = let res = zip a ja let idxs = iota n in unzip (scatter (copy res) idxs res) futhark-0.25.27/tests/script/000077500000000000000000000000001475065116200160105ustar00rootroot00000000000000futhark-0.25.27/tests/script/data/000077500000000000000000000000001475065116200167215ustar00rootroot00000000000000futhark-0.25.27/tests/script/data/input.in000066400000000000000000000000331475065116200204040ustar00rootroot00000000000000b f32?@@@futhark-0.25.27/tests/script/opaques.fut000066400000000000000000000003111475065116200202000ustar00rootroot00000000000000-- == -- entry: bar -- script input { foo 10 } type bools [n] = #foo [n]bool entry foo (n: i64) : {x:[n]bool,y:bool} = {x=replicate n true,y=false} entry bar [m] (b: {x:[m]bool,y:bool}) : bool = b.y futhark-0.25.27/tests/script/script0.fut000066400000000000000000000003651475065116200201200ustar00rootroot00000000000000-- == -- entry: doeswork -- script input { mkdata 100i64 } output { 5050.0f32 } -- script input { mkdata 10000i64 } -- script input { mkdata 1000000i64 } entry mkdata n = (n,map f32.i64 (iota n)) entry doeswork n arr = f32.sum arr + f32.i64 n futhark-0.25.27/tests/script/script1.fut000066400000000000000000000005041475065116200201140ustar00rootroot00000000000000-- == -- entry: doeswork_named -- script input { mkdata_named 100i64 } output { 5050.0f32 } -- script input { mkdata_named 10000i64 } -- script input { mkdata_named 1000000i64 } type~ data = (i64, []f32) entry mkdata_named n : data = (n,map f32.i64 (iota n)) entry doeswork_named ((n,arr): data) = 
f32.sum arr + f32.i64 n futhark-0.25.27/tests/script/script2.fut000066400000000000000000000001121475065116200201100ustar00rootroot00000000000000-- == -- script input { 1i32 } output { 3i32 } def main (x: i32) = x + 2 futhark-0.25.27/tests/script/script3.fut000066400000000000000000000002021475065116200201110ustar00rootroot00000000000000-- == -- entry: dotprod -- script input @ script3.futharkscript entry dotprod (xs: []f64) (ys: []f64) = f64.sum (map2 (*) xs ys) futhark-0.25.27/tests/script/script3.futharkscript000066400000000000000000000000401475065116200222040ustar00rootroot00000000000000let A = [1.0,2.0,3.0] in (A, A) futhark-0.25.27/tests/script/script4.fut000066400000000000000000000002321475065116200201150ustar00rootroot00000000000000-- == -- entry: doeswork -- script input { let x = 2 in let y = 3 in mkdata x y } output { 10 } entry mkdata x y = x + y : i32 entry doeswork z = z * 2 futhark-0.25.27/tests/script/script5.fut000066400000000000000000000003461475065116200201240ustar00rootroot00000000000000-- == -- script input { (2f32, $loaddata "data/input.in") } -- output { [3f32,4f32,5f32] } -- "the other one" script input { (3f32, $loaddata "data/input.in") } -- output { [4f32,5f32,6f32] } def main (x: f32) arr = map (+x) arr futhark-0.25.27/tests/script/script6.fut000066400000000000000000000002131475065116200201160ustar00rootroot00000000000000-- Can we read our own source code? -- == -- script input { $loadbytes "script6.fut" } -- output { 139i64 } def main (s: []u8) = length s futhark-0.25.27/tests/segredomap/000077500000000000000000000000001475065116200166325ustar00rootroot00000000000000futhark-0.25.27/tests/segredomap/add-then-reduce.fut000066400000000000000000000005451475065116200223070ustar00rootroot00000000000000-- Add y to all elements of all inner arrays in xss, then we sum all inner -- arrays in xss. 
-- -- This is an interesting example only because the two expressions will be fused -- into a single segmented-redomap def main [m][n] (xss : [m][n]f32, y : f32): [m]f32 = let xss' = map (\xs -> map (y+) xs) xss in map (\xs -> reduce_comm (+) 0.0f32 xs) xss' futhark-0.25.27/tests/segredomap/ex1-comm.fut000066400000000000000000000010701475065116200207760ustar00rootroot00000000000000-- A simple example of a redomap within a map, that uses different types for the -- different parts of the redomap -- == -- input { -- [[1.0f32, 2.0f32, 3.0f32], [4.0f32, 5.0f32, 6.0f32]] -- } -- output { -- [6i64, 15i64] -- [[-1.000000f64, -2.000000f64, -3.000000f64], [-4.000000f64, -5.000000f64, -6.000000f64]] -- } def main [m][n] (xss : [m][n]f32): ([m]i64, [m][n]f64) = unzip (map( \(xs : [n]f32) : (i64, [n]f64) -> let (xs_int, xs_neg) = unzip (map(\x -> (i64.f32 x, f64.f32(-x))) xs) in (reduce_comm (+) 0 xs_int, xs_neg) ) xss) futhark-0.25.27/tests/segredomap/ex1-nocomm.fut000066400000000000000000000013641475065116200213410ustar00rootroot00000000000000-- A simple example of a redomap within a map, that uses different types for the -- different parts of the redomap -- == -- input { -- true -- [[1.0f32, 2.0f32, 3.0f32], [4.0f32, 5.0f32, 6.0f32]] -- } -- output { -- [6i64, 15i64] -- [[-1.000000f64, -2.000000f64, -3.000000f64], [-4.000000f64, -5.000000f64, -6.000000f64]] -- } -- Add a data-driven branch to prevent the compiler from noticing that -- this is commutative. 
def add (b: bool) (x : i64) (y : i64): i64 = if b then x + y else x - y def main [m][n] (b: bool) (xss : [m][n]f32): ([m]i64, [m][n]f64) = unzip (map( \(xs : [n]f32) : (i64, [n]f64) -> let (xs_int, xs_neg) = unzip (map(\x -> (i64.f32 x, f64.f32(-x))) xs) in (reduce (add b) 0 xs_int, xs_neg) ) xss) futhark-0.25.27/tests/segredomap/ex2.fut000066400000000000000000000015361475065116200200550ustar00rootroot00000000000000-- A simple example of a redomap within a map, that uses different types for the -- different parts of the redomap -- == -- input { -- [ [ [1.0f32, 2.0f32, 3.0f32], [4.0f32, 5.0f32, 6.0f32] ] -- , [ [1.0f32, 2.0f32, 3.0f32], [4.0f32, 5.0f32, 6.0f32] ] -- ] -- } -- output { -- [ [6i64, 15i64], [6i64, 15i64] ] -- [ [ [-1.000000f64, -2.000000f64, -3.000000f64], [-4.000000f64, -5.000000f64, -6.000000f64] ] -- , [ [-1.000000f64, -2.000000f64, -3.000000f64], [-4.000000f64, -5.000000f64, -6.000000f64] ] -- ] -- } def main [l][m][n] (xsss : [l][m][n]f32): ([l][m]i64, [l][m][n]f64) = unzip (map (\xss -> unzip (map(\(xs : [n]f32) : (i64, [n]f64) -> let (xs_int, xs_neg) = unzip (map(\x -> (i64.f32 x, f64.f32(-x))) xs) in (reduce (+) 0 xs_int, xs_neg) ) xss) ) xsss) futhark-0.25.27/tests/segredomap/ex3.fut000066400000000000000000000007341475065116200200550ustar00rootroot00000000000000-- A redomap, where the "map" part turns a 1D list into a 2D list -- == -- input { -- [[1, 2, 3], [4, 5, 6]] -- [-5, 10] -- } -- output { -- [6, 15] -- [ [ [-4, 11], [-3, 12], [-2, 13] ] -- , [ [-1, 14], [ 0, 15], [ 1, 16] ] -- ] -- } -- def main [m][n][l] (xss : [m][n]i32) (ys : [l]i32): ([m]i32, [m][n][l]i32) = unzip (map(\(xs : [n]i32) : (i32, [n][l]i32) -> let zs = map (\x -> map (\y -> x+y) ys) xs in (reduce (+) 0 xs, zs) ) xss) futhark-0.25.27/tests/segredomap/ex4.fut000066400000000000000000000004361475065116200200550ustar00rootroot00000000000000-- A segmented-redomap using the same array in both maps -- == -- input { -- [1i32, 2i32, 3i32, 4i32] -- } -- output { -- [14i32, 18i32, 
22i32, 26i32] -- } entry main [n] (xs : [n]i32) : [n]i32 = map (\y -> let zs = map (\x -> x + y) xs in reduce (+) 0 zs ) xs futhark-0.25.27/tests/segredomap/ex6.fut000066400000000000000000000005241475065116200200550ustar00rootroot00000000000000-- Will compute the sum of the inner arrays, and sum these together. We do -- this for a 3D dimensional array, so end up with an array of results def main [l][m][n] (xsss : [l][m][n]i32): [l]i32 = map (\(xss : [m][n]i32): i32 -> let xss_sums : []i32 = map (\xs -> reduce (+) 0 xs) xss in reduce (+) 0 xss_sums ) xsss futhark-0.25.27/tests/segredomap/invariant-in-reduce.fut000066400000000000000000000006401475065116200232160ustar00rootroot00000000000000-- This example explores what happens when we use an invariant variable in the -- reduction operator of a segmented reduction. -- -- Currently this is not turned into a segmented reduction def add_if_smaller (const : i32) (acc : i32) (x : i32) : i32 = if x < const then acc + x else acc def main [m][n] (xss : [m][n]i32, consts : [m]i32): [m]i32 = map2 (\c xs -> reduce (add_if_smaller c) 0 xs) consts xss futhark-0.25.27/tests/segredomap/map-add-then-reduce.fut000066400000000000000000000011161475065116200230550ustar00rootroot00000000000000-- a 3D version of 'add-then-reduce.fut' -- -- As before, we will get a `map (map (redomap) )` expression, which can be -- turned into a single segmented redomap kernel. -- -- This example is interesting, because the value `y` is bound in the outermost -- map, and special care must be taken to handle such a map-invariant variable. 
def add_then_reduce [m][n] (xss : [m][n]f32) (y : f32): [m]f32 = let xss' = map (\xs -> map (y+) xs) xss in map (\xs -> reduce_comm (+) 0.0f32 xs) xss' def main [l][m][n] (xsss : [l][m][n]f32, ys : [l]f32): [l][m]f32 = map2 add_then_reduce xsss ys futhark-0.25.27/tests/segredomap/map-segsum-comm.fut000066400000000000000000000004001475065116200223530ustar00rootroot00000000000000-- a 3D version of 'segsum-comm' - we compute a sum over the inner arrays of a -- 3D-array def segsum [m][n] (xss : [m][n]f32): [m]f32 = map (\xs -> reduce_comm (+) 0.0f32 xs) xss def main [l][m][n] (xsss : [l][m][n]f32): [l][m]f32 = map segsum xsss futhark-0.25.27/tests/segredomap/real-nocomm.fut000066400000000000000000000007201475065116200215620ustar00rootroot00000000000000-- A redomap using a real non-commutative reduction -- -- Some of the other examples uses `+`, and make sure the compiler doesn't -- realize it is actually commutative def foo (x1:i32, x2:i32) (y1:i32, y2:i32) : (i32,i32) = if x1 > 0 then (x1, x2) else (y1, x2+y2) def main [m][n] (xss : [m][n]i32): ([m]i32, [m]i32) = unzip (map (\xs -> let ys = map (\x -> (x,x)) xs in reduce_comm foo (0,0) ys ) xss ) futhark-0.25.27/tests/segredomap/reduction-on-list.fut000066400000000000000000000004321475065116200227300ustar00rootroot00000000000000-- The reduction operator works on lists def vec_add [k] (xs : [k]i32) (ys : [k]i32) : [k]i32 = map2 (\x y -> x + y) xs ys def main [l][m][n] (xsss : [l][m][n]i32): [l][n]i32 = let zeros = replicate n 0 in map (\(xss : [m][n]i32) : [n]i32 -> reduce vec_add zeros xss) xsss futhark-0.25.27/tests/segredomap/segsum-comm.fut000066400000000000000000000003421475065116200216050ustar00rootroot00000000000000-- Segmented sum -- == -- input { -- [[1.0f32, 2.0f32, 3.0f32], [4.0f32, 5.0f32, 6.0f32]] -- } -- output { -- [6.0f32, 15.0f32] -- } def main [m][n] (xss : [m][n]f32): [m]f32 = map (\xs -> reduce_comm (+) 0.0f32 xs) xss 
futhark-0.25.27/tests/segredomap/segsum-nocomm.fut000066400000000000000000000006571475065116200221530ustar00rootroot00000000000000-- Segmented sum, non commutative -- == -- input { -- true -- [[1.0f32, 2.0f32, 3.0f32], [4.0f32, 5.0f32, 6.0f32]] -- } -- output { -- [6.0f32, 15.0f32] -- } -- Add a data-driven branch to prevent the compiler from noticing that -- this is commutative. def add (b: bool) (x : f32) (y : f32): f32 = if b then x + y else x - y def main [m][n] (b: bool) (xss : [m][n]f32): [m]f32 = map (\xs -> reduce (add b) 0.0f32 xs) xss futhark-0.25.27/tests/shapes/000077500000000000000000000000001475065116200157675ustar00rootroot00000000000000futhark-0.25.27/tests/shapes/apply2.fut000066400000000000000000000004031475065116200177130ustar00rootroot00000000000000-- Careful not to require that the two applications return the same size. -- == -- input { 1i64 2i64 } -- output { [0i64, 1i64] [0i64] } def apply2 '^a '^b (f: a -> b) (x: a) (y: a) = let a = f x let b = f y in (b, a) def main n m = apply2 iota n m futhark-0.25.27/tests/shapes/argdims0.fut000066400000000000000000000005471475065116200202230ustar00rootroot00000000000000-- If a size is produced by similar arguments in different places in -- the program, those should be considered distint. 
-- == -- input { true [1,2,3] } output { [0i64,1i64,2i64] } -- input { false [1,2,3] } output { [0i64,1i64,2i64] } def main (b: bool) (xs: []i32) = if b then let arr = iota (length xs) in arr else let arr = iota (length xs) in arr futhark-0.25.27/tests/shapes/argdims1.fut000066400000000000000000000002171475065116200202160ustar00rootroot00000000000000-- == -- input { 2i64 } -- output { [0i64] [-1] } def main (n: i64) = let foo = iota (n-1) let bar = replicate (n-1) (-1) in (foo, bar) futhark-0.25.27/tests/shapes/ascript-existential.fut000066400000000000000000000001741475065116200225050ustar00rootroot00000000000000-- == -- input { 0i64 } output { 1i64 } -- input { 1i64 } output { 2i64 } def main (n: i64) = length (iota (n+1): []i64) futhark-0.25.27/tests/shapes/assert0.fut000066400000000000000000000000551475065116200200700ustar00rootroot00000000000000def main n : [n]i64 = iota (assert (n>10) n) futhark-0.25.27/tests/shapes/attr0.fut000066400000000000000000000000511475065116200175350ustar00rootroot00000000000000def main n : [n]i64 = iota (#[unsafe] n) futhark-0.25.27/tests/shapes/cached.fut000066400000000000000000000004241475065116200177160ustar00rootroot00000000000000-- Can we recognise an invented size when it occurs later? 
-- == -- input { 2i64 } -- output { [[0i64, 0i64, 0i64, 0i64], [1i64, 1i64, 1i64, 1i64], [2i64, 2i64, 2i64, 2i64], [3i64, 3i64, 3i64, 3i64]] } def main n = let is = iota (n+n) in map (\x -> replicate (n+n) x) is futhark-0.25.27/tests/shapes/coerce0.fut000066400000000000000000000004601475065116200200270ustar00rootroot00000000000000type~ sized_state [n] = { xs: [n][n]i64, ys: []i32 } type~ state = sized_state [] def state v : state = {xs = [[v,2],[3,4]], ys = [1,2,3]} def size [n] (_: sized_state [n]) = n def f v (arg: state) = size (arg :> sized_state [v]) -- == -- input { 2i64 } output { 2i64 } def main v = f v (state v) futhark-0.25.27/tests/shapes/coerce1.fut000066400000000000000000000002401475065116200200240ustar00rootroot00000000000000-- == -- input { [[1,2],[4,5],[7,8]] } -- output { [[1i32, 4i32, 7i32], [2i32, 5i32, 8i32]] } def main [n] [m] (xss: [n][m]i32) = transpose xss :> [2][3]i32 futhark-0.25.27/tests/shapes/compose.fut000066400000000000000000000004501475065116200201530ustar00rootroot00000000000000-- An argument for why we should not permit composition of functions -- with anonymous return sizes. -- == -- error: "compose".*"iota" def compose '^a '^b '^c (f: a -> b) (g: b -> c) (x: a) (y: a): (c, c) = (g (f x), g (f y)) def main = let foo = compose iota in foo (\x -> length x) 1 2 futhark-0.25.27/tests/shapes/concatmap.fut000066400000000000000000000003411475065116200204520ustar00rootroot00000000000000-- == -- input { [1i64,2i64,3i64] } output { [0i64,0i64,1i64,0i64,1i64,2i64] } def concatmap [n] 'a 'b (f: a -> []b) (as: [n]a) : []b = loop acc = [] for a in as do acc ++ f a def main (xs: []i64) = concatmap iota xs futhark-0.25.27/tests/shapes/curry-shapes.fut000066400000000000000000000007151475065116200211370ustar00rootroot00000000000000-- Test that shape declarations are taken into account even when the -- function is curried. 
-- -- At the time this test was written, the only way to determine the -- success of this is to inspect the result of internalisation by the -- compiler. -- == -- input { -- [[6,5,2,1], -- [4,5,9,-1]] -- } -- output { -- [[7,6,3,2], -- [5,6,10,0]] -- } def oneToEach [n] (r: [n]i32): [n]i32 = map (+1) r def main(a: [][]i32): [][]i32 = map oneToEach a futhark-0.25.27/tests/shapes/duplicate-shapes-error.fut000066400000000000000000000002631475065116200230720ustar00rootroot00000000000000-- Test that a variable shape annotation in a binding position may not -- be the same as another parameter. -- == -- error: "n" def main(n: f64, a: [n]i32): []i32 = map (+2) a futhark-0.25.27/tests/shapes/emptydim0.fut000066400000000000000000000004431475065116200204200ustar00rootroot00000000000000-- == -- input { empty([0][1]i32) [[1]] } output { empty([0][1]i32) [[1]] } -- input { [[1]] empty([0][1]i32) } output { [[1]] empty([0][1]i32) } -- compiled input { [[1]] [[1,2]] } error: . -- compiled input { [[1,2]] [[1]] } error: . def main [n] (xs: [][n]i32) (ys: [][n]i32) = (xs, ys) futhark-0.25.27/tests/shapes/emptydim1.fut000066400000000000000000000005351475065116200204230ustar00rootroot00000000000000-- == -- input { empty([0][1]i32) [[1]] } output { empty([0][1]i32) [[1]] } -- input { [[1]] empty([0][1]i32) } output { [[1]] empty([0][1]i32) } -- compiled input { [[1]] [[1,2]] } error: . 
-- input { [[1]] [[2]] } output { [[1]] [[2]] } def f [n] (xs: [][n]i32) = \(ys: [][n]i32) -> (xs, ys) def main [n] (xs: [][n]i32) (ys: [][n]i32) = f xs ys futhark-0.25.27/tests/shapes/emptydim2.fut000066400000000000000000000003521475065116200204210ustar00rootroot00000000000000-- == -- input { 1i64 empty([0]i32) } output { empty([1][0]i32) } -- input { 0i64 [1] } output { empty([0][1]i32) } -- input { 0i64 empty([0]i32) } output { empty([0][0]i32) } def main (n: i64) (xs: []i32) = replicate n xs futhark-0.25.27/tests/shapes/emptydim3.fut000066400000000000000000000002111475065116200204140ustar00rootroot00000000000000-- == -- input { 2i64 } output { 2i64 empty([0][2]i32) } def empty 'a (x: i64) = (x, [] : [0]a) def main x : (i64, [][x]i32) = empty x futhark-0.25.27/tests/shapes/entry-constants-sum.fut000066400000000000000000000001051475065116200224600ustar00rootroot00000000000000type t = #foo ([5]i32) | #bar ([5]i32) def main (x: t) (y: i32) = 0 futhark-0.25.27/tests/shapes/entry-constants.fut000066400000000000000000000005201475065116200216570ustar00rootroot00000000000000-- Dimension declarations on entry points can refer to constants. -- == -- input { [1i64,2i64,3i64] } output { [0i64,1i64] } -- compiled input { [1i64,2i64] } error: invalid|match -- compiled input { [1i64,3i64,2i64] } error: invalid|match def three: i64 = 3 def two: i64 = 2 def main(a: [three]i64): [two]i64 = iota a[1] :> [two]i64 futhark-0.25.27/tests/shapes/entry-lifted.fut000066400000000000000000000002401475065116200211110ustar00rootroot00000000000000module m : { type~ t val mk : i64 -> *t } = { type~ t = #foo ([]i32) | #bar ([]i32) def mk (n: i64) : *t = #foo (replicate n 0) } def main n = m.mk n futhark-0.25.27/tests/shapes/error0.fut000066400000000000000000000002031475065116200177130ustar00rootroot00000000000000-- Actually check against the function return type, too. 
-- == -- error: 10 def main (xs: []i32) (ys: []i32) : [10]i32 = xs ++ ys futhark-0.25.27/tests/shapes/error1.fut000066400000000000000000000001461475065116200177220ustar00rootroot00000000000000-- Cannot magically change sizes. -- == -- error: \[10\]i32 def main [n] (xs: [n]i32) : [10]i32 = xs futhark-0.25.27/tests/shapes/error10.fut000066400000000000000000000002061475065116200177770ustar00rootroot00000000000000-- == -- error: Causality check def main (b: bool) (xs: []i32) = let a = [] : [][]i32 let b = [filter (>0) xs] in a[0] == b[0] futhark-0.25.27/tests/shapes/error12.fut000066400000000000000000000005711475065116200200060ustar00rootroot00000000000000-- No hiding sizes inside sum types! -- == -- error: Cannot apply "genarr" type sometype 't = #someval t def geni64 (maxsize : i64) : sometype i64 = #someval maxsize def genarr 'elm (genelm: i64 -> sometype elm) (ownsize : i64) : sometype ([ownsize](sometype elm)) = #someval (tabulate ownsize genelm) def main = genarr (genarr geni64) 1 futhark-0.25.27/tests/shapes/error13.fut000066400000000000000000000001331475065116200200010ustar00rootroot00000000000000-- No hiding a size inside a function! -- == -- error: anonymous type^ t = []bool -> bool futhark-0.25.27/tests/shapes/error2.fut000066400000000000000000000002521475065116200177210ustar00rootroot00000000000000-- An implied/unknown size can still only be inferred to have one concrete size. -- == -- error: (\[1\]i32, \[2\]i32) def main (xs: []i32) : ([1]i32, [2]i32) = (xs, xs) futhark-0.25.27/tests/shapes/error3.fut000066400000000000000000000003311475065116200177200ustar00rootroot00000000000000-- The sizes of a lambda parameter can percolate out to a let-binding. 
-- == -- error: "n" and "m" do not match def f [n] (xs: [n]i32) = \(ys: [n]i32) -> (xs, ys) def main [n][m] (xs: [n]i32) (ys: [m]i32) = f xs ys futhark-0.25.27/tests/shapes/error4.fut000066400000000000000000000003251475065116200177240ustar00rootroot00000000000000-- We cannot just ignore constraints imposed by a higher-order function. -- == -- error: Sizes.*"n".*do not match def f (g: (n: i64) -> [n]i32) (l: i64): i32 = (g l)[0] def main = f (\n : []i64 -> iota (n+1)) futhark-0.25.27/tests/shapes/error5.fut000066400000000000000000000002541475065116200177260ustar00rootroot00000000000000-- A function 'a -> a' must be size-preserving. -- == -- error: Occurs check def ap 'a (f: a -> a) (x: a) = f x def main [n] (arr: [n]i32) = ap (\xs -> xs ++ xs) arr futhark-0.25.27/tests/shapes/error6.fut000066400000000000000000000002371475065116200177300ustar00rootroot00000000000000-- Respect sizes based on named parameters. -- == -- error: "n" def ap (f: (n: i64) -> [n]i32) (k: i64) : [k]i32 = f k def main = ap (\n -> iota (n+1)) 10 futhark-0.25.27/tests/shapes/error7.fut000066400000000000000000000002761475065116200177340ustar00rootroot00000000000000-- Ambiguous size of sum type. -- == -- error: Ambiguous.*anonymous size in type expression type~ sum = #foo ([]i32) | #bar ([]i32) def main (xs: *[]i32) = let v : sum = #foo xs in xs futhark-0.25.27/tests/shapes/error8.fut000066400000000000000000000001531475065116200177270ustar00rootroot00000000000000-- == -- error: Entry point def empty 'a (x: i32) = (x, [] : [0]a) def main x : (i32, [][]i32) = empty x futhark-0.25.27/tests/shapes/error9.fut000066400000000000000000000004651475065116200177360ustar00rootroot00000000000000-- A function with constraints based on named parameters cannot be -- passed to a higher-order function that does not obey those -- constraints. 
-- == -- error: do not match def ap (f: i64 -> []i32 -> i32) (k: i32) : i32 = f 0 [k] def g (n: i64) (xs: [n]i32) : i32 = xs[n-1] def main (k: i32) = ap g k futhark-0.25.27/tests/shapes/eta-expand0.fut000066400000000000000000000003341475065116200206150ustar00rootroot00000000000000-- == -- input { 2 2 [1,2,3,4] } -- output { 2i64 2i64 [[1,2],[3,4]] } def unpack_2d 't n m : [i64.i32 n*i64.i32 m]t -> [][]t = unflatten def main n m xs = let [x][y] (ys: [x][y]i32) = unpack_2d n m xs in (x,y,ys) futhark-0.25.27/tests/shapes/eta-expand1.fut000066400000000000000000000002731475065116200206200ustar00rootroot00000000000000-- == -- input { 2 2 [1,2,3,4] } -- output { [[1,3],[2,4]] } def unpack_2d 't n m : [i64.i32 n*i64.i32 m]t -> [][]t = unflatten def main n m xs: [][]i32 = unpack_2d n m xs |> transpose futhark-0.25.27/tests/shapes/existential-apply-error.fut000066400000000000000000000003011475065116200233040ustar00rootroot00000000000000-- An existential size in an apply function returning an unlifted type is not fine. -- == -- error: existential def apply 'a 'b (f: a -> b) (x: a): b = f x def main (n: i32) = apply iota n futhark-0.25.27/tests/shapes/existential-apply.fut000066400000000000000000000003161475065116200221630ustar00rootroot00000000000000-- An existential size in an apply function returning a lifted type is fine. 
-- == -- input { 2i64 } output { [0i64,1i64] } def apply 'a '^b (f: a -> b) (x: a): b = f x def main (n: i64) = apply iota n futhark-0.25.27/tests/shapes/existential-argument.fut000066400000000000000000000003771475065116200226670ustar00rootroot00000000000000-- Sizes obtained with existantialy bounded sizes should not be calculated -- == -- input { 2i64 } output { [0i64, 1i64, 0i64, 1i64] } def double_eval 't (f : () -> []t) : []t = f () ++ f () def main (n:i64) : []i64 = double_eval (\_ -> iota n) futhark-0.25.27/tests/shapes/existential-hof.fut000066400000000000000000000002551475065116200216140ustar00rootroot00000000000000-- An existential produced through a higher-order function. -- == -- input { [0, 1, 2] } output { 2i64 } def main (xs: []i32) = let ys = xs |> filter (>0) in length ys futhark-0.25.27/tests/shapes/existential-ret.fut000066400000000000000000000003251475065116200216300ustar00rootroot00000000000000-- Two distinct applications of the function should not interfere with -- each other. -- == -- input { [1,2,3] [4,-2,1] } -- output { 3i64 2i64 } def f xs = length (filter (>0) xs) def main xs ys = (f xs, f ys) futhark-0.25.27/tests/shapes/explicit-shapes-error1.fut000066400000000000000000000001561475065116200230230ustar00rootroot00000000000000-- An explicitly quantified size must be used where it is bound. -- == -- error: n def main [n] (x: i32) = n futhark-0.25.27/tests/shapes/explicit-shapes0.fut000066400000000000000000000001471475065116200216730ustar00rootroot00000000000000-- Explicit shape quantification. -- = -- input { [1,2,3] } output { 3 } def main [n] (x: [n]i32) = n futhark-0.25.27/tests/shapes/extlet0.fut000066400000000000000000000002251475065116200200730ustar00rootroot00000000000000-- A type becomes existential because a name goes out of scope. 
-- == -- input { 1i64 } output { 1i64 } def main n = length (let m = n in iota m) futhark-0.25.27/tests/shapes/extlet1.fut000066400000000000000000000002441475065116200200750ustar00rootroot00000000000000-- A type becomes existential because a name goes out of scope, -- trickier. -- == -- input { 1i64 } output { 2i64 } def main n = length (let m = n+1 in iota m) futhark-0.25.27/tests/shapes/field-in-size.fut000066400000000000000000000001611475065116200211440ustar00rootroot00000000000000-- Allow to access argument field as size for return type -- == def f (p: {a:i64,b:bool}) : [p.a]i64 = iota p.a futhark-0.25.27/tests/shapes/flatmap.fut000066400000000000000000000002631475065116200201340ustar00rootroot00000000000000-- == -- input { [1,2] } output { [1,2,2,3] } def flatmap [n] [m] 'a 'b (f: a -> [m]b) (as: [n]a) : []b = flatten (map f as) def main (xs: []i32) = flatmap (\x -> [x,x+1]) xs futhark-0.25.27/tests/shapes/funshape0.fut000066400000000000000000000002461475065116200204020ustar00rootroot00000000000000-- == -- input { [1,-2,3] } output { 3i64 } def f [n] (_: [n]i32 -> i32) : [n]i32 -> i64 = let m = n + 1 in \_ -> m def main xs = filter (>0) xs |> f (\_ -> 0) futhark-0.25.27/tests/shapes/funshape1.fut000066400000000000000000000002321475065116200203760ustar00rootroot00000000000000-- == -- error: Causality check def f [n] (_: [n]i32 -> i32) : [n]i32 -> i64 = let m = n + 1 in \_ -> m def main xs = f (\_ -> 0) <| filter (>0) xs futhark-0.25.27/tests/shapes/funshape2.fut000066400000000000000000000001261475065116200204010ustar00rootroot00000000000000-- == -- error: scope violation def main xs = (\f' -> f' (filter (>0) xs)) (\_ -> 0) futhark-0.25.27/tests/shapes/funshape3.fut000066400000000000000000000002261475065116200204030ustar00rootroot00000000000000-- == -- input { 5i64 } output { 7i64 } def f [n] (_: [n]i64) (_: [n]i64 -> i32, _: [n]i64) = n def main x = f (iota (x+2)) (\_ -> 0, iota (x+2)) 
futhark-0.25.27/tests/shapes/funshape4.fut000066400000000000000000000003361475065116200204060ustar00rootroot00000000000000-- Left-side operands should be evaluated before before right-hand -- operands. -- == -- input { 2i64 } output { [[2i64,2i64,2i64]] } def f (x: i64) : [][]i64 = [replicate (x+1) 0] def main x = f x |> map (map (+2)) futhark-0.25.27/tests/shapes/funshape5.fut000066400000000000000000000002311475065116200204010ustar00rootroot00000000000000-- == -- error: Entry point functions may not be polymorphic def main indices (cs: *[](i32,i32)) j = map (\k -> (indices[j],k)) <| drop (j+1) indices futhark-0.25.27/tests/shapes/funshape6.fut000066400000000000000000000002101475065116200203770ustar00rootroot00000000000000-- Based on issue 1351. -- == -- input { [[1.0,2.0,3.0],[4.0,5.0,6.0]] 0i64 4i64 } def main (xs: [][]f64) i j = (.[i:j]) <| iota (i+j) futhark-0.25.27/tests/shapes/funshape7.fut000066400000000000000000000001211475065116200204010ustar00rootroot00000000000000-- == -- error: Causality check entry main xs mat = map (filter (>0) xs ++) mat futhark-0.25.27/tests/shapes/hof0.fut000066400000000000000000000005151475065116200173440ustar00rootroot00000000000000-- A dubious test - what we want to ensure is an absence of too many -- dynamic casts just after internalisation. 
-- == -- structure internalised { Assert 2 } def f [k] 'a (dest: [k]a) (f: [k]a -> [k]a) : [k]a = f dest def operation [n][m] (b: [m]i32) (as: [n][m]i32) = copy as with [0] = b def main xs b = f xs (operation b) futhark-0.25.27/tests/shapes/hof1.fut000066400000000000000000000011141475065116200173410ustar00rootroot00000000000000type^ network 'p = { zero: p , sum: (k:i64) -> [k]p -> p } def chain 'p1 'p2 (a:network p1) (b:network p2) : network (p1, p2) = let zero = (a.zero, b.zero) let sum k ps = let (as, bs) = unzip ps in (a.sum k as, b.sum k bs) in {zero, sum} def linear [m][n] (weights:[m][n]f32) : network ([m][n]f32) = let zero = replicate m (replicate n 0) let sum k (ps:[k][m][n]f32) = map (map (reduce (+) 0)) (map transpose (transpose ps)) in {zero, sum} def main [m][n] (ws1: [n][m]f32) (ws2: [n][n]f32) = let inet = chain (linear ws1) (linear ws2) in inet.zero futhark-0.25.27/tests/shapes/if0.fut000066400000000000000000000003141475065116200171630ustar00rootroot00000000000000-- Inferring an invariant size for a branch. -- == -- input { [1,2,3] [2,3,4] } def main [n] (xs: [n]i32) (ys: *[n]i32) : [n]i32 = if true then xs else loop ys for i < n do ys with [i] = ys[i] + 1 futhark-0.25.27/tests/shapes/if1.fut000066400000000000000000000001751475065116200171710ustar00rootroot00000000000000def main [n][m] (world: [n][m]i32): [n][m]i32 = map2 (\(c: [m]i32) i -> if i == n-1 then world[n-1] else c) world (iota n) futhark-0.25.27/tests/shapes/if2.fut000066400000000000000000000003131475065116200171640ustar00rootroot00000000000000-- Looking at the size of an existential branch. -- == -- input { true 1i64 2i64 } output { 1i64 } -- input { false 1i64 2i64 } output { 2i64 } def main b n m = length (if b then iota n else iota m) futhark-0.25.27/tests/shapes/if3.fut000066400000000000000000000002401475065116200171640ustar00rootroot00000000000000-- Size-variant branches don't have just any size. 
-- == -- error: \[n\].*\[m\] def main (b: bool) (n: i64) (m: i64) : [2]i64 = if b then iota n else iota m futhark-0.25.27/tests/shapes/if4.fut000066400000000000000000000003721475065116200171730ustar00rootroot00000000000000-- Differing sizes, but the same across a single branch. -- == -- input { false } output { [1,2] [3,4] } def main (b: bool) = let (xs, ys) = if b then ([1,2,3], [4,5,6]) else ([1,2], [3,4]) in unzip (zip xs ys) futhark-0.25.27/tests/shapes/implicit-shape-use.fut000066400000000000000000000042661475065116200222210ustar00rootroot00000000000000-- Extraction from generic pricer. Uses shape declarations in ways -- that were at one point problematic. -- -- == -- input { -- 3i64 -- [[1,1,1,1,1],[1,1,1,1,1],[1,1,1,1,1]] -- [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0] -- } -- output { -- [[109.0, 140.0, 171.0], -- [-109.0, -140.0, -171.0], -- [0.0, 0.0, 0.0], -- [0.0, 0.0, 0.0], -- [0.0, 0.0, 0.0]] -- } def brownianBridgeDates [num_dates] (bb_inds: [3][num_dates]i32) (bb_data: [3][num_dates]f64) (gauss: [num_dates]f64): [num_dates]f64 = let bi = bb_inds[0] let li = bb_inds[1] let ri = bb_inds[2] let sd = bb_data[0] let lw = bb_data[1] let rw = bb_data[2] let bbrow = replicate num_dates 0.0 let bbrow[ bi[0]-1 ] = sd[0] * gauss[0] in let bbrow = loop (bbrow) for i < num_dates-1 do -- use i+1 since i in 1 .. num_dates-1 let j = li[i+1] - 1 let k = ri[i+1] - 1 let l = bi[i+1] - 1 let wk = bbrow[k] let zi = gauss[i+1] let tmp= rw[i+1] * wk + sd[i+1] * zi let bbrow[ l ] = if( j == -1) then tmp else tmp + lw[i+1] * bbrow[j] in bbrow -- This can be written as map-reduce, but it -- needs delayed arrays to be mapped nicely! 
in loop (bbrow) for ii < num_dates-1 do let i = num_dates - (ii+1) let bbrow[i] = bbrow[i] - bbrow[i-1] in bbrow def brownianBridge [num_dates] (bb_inds: [3][num_dates]i32) (bb_data: [3][num_dates]f64) (gaussian_arr: []f64) = let gauss2d = unflatten gaussian_arr let gauss2dT = transpose gauss2d in transpose ( map (brownianBridgeDates bb_inds bb_data) gauss2dT ) def main [num_dates] (num_und: i64) (bb_inds: [3][num_dates]i32) (arr: [num_dates*num_und]f64): [][]f64 = let bb_data= map (\(row: []i32) -> map f64.i32 row ) (bb_inds ) let bb_mat = brownianBridge bb_inds bb_data arr in bb_mat futhark-0.25.27/tests/shapes/implicit-shape-use2.fut000066400000000000000000000012441475065116200222740ustar00rootroot00000000000000-- == -- input { -- [[1.0,1.0,1.0,1.0,1.0],[1.0,1.0,1.0,1.0,1.0],[1.0,1.0,1.0,1.0,1.0]] -- [[2.0,2.0,2.0,2.0,2.0],[2.0,2.0,2.0,2.0,2.0],[2.0,2.0,2.0,2.0,2.0]] -- } -- output { -- [[2.0,2.0,2.0,2.0,2.0],[2.0,2.0,2.0,2.0,2.0],[2.0,2.0,2.0,2.0,2.0]] -- } def combineVs [num_und] (n_row: [num_und]f64) (vol_row: [num_und]f64): [num_und]f64 = map2 (*) n_row vol_row def mkPrices [num_dates][num_und] (md_vols: [num_dates][num_und]f64, noises: [num_dates][num_und]f64 ): [num_dates][num_und]f64 = map2 combineVs noises md_vols def main (vol: [][]f64) (noises: [][]f64): [][]f64 = mkPrices(vol,noises) futhark-0.25.27/tests/shapes/inference6.fut000066400000000000000000000003411475065116200205310ustar00rootroot00000000000000-- It is OK to infer stricter constraints than what is -- provided by explicit size parameters, if present. -- == -- input { [1,2] } output { [1,2] } -- compiled input { [1,2,3] } error: def main (xs: []i32) : [2]i32 = xs futhark-0.25.27/tests/shapes/inference7.fut000066400000000000000000000003431475065116200205340ustar00rootroot00000000000000-- Just because a top-level binding tries to hide its size, that does -- not mean it gets to have a blank size. 
-- == -- input { 2i64 } output { [0i64,1i64] } def arr : []i64 = iota 10 def main (n: i64) = copy (take n arr) futhark-0.25.27/tests/shapes/inference8.fut000066400000000000000000000003761475065116200205430ustar00rootroot00000000000000-- Just because a top-level binding tries to hide its size (which is -- existential), that does not mean it gets to have a blank size. -- == -- input { 2i64 } output { [0i64,1i64] } def arr : []i64 = iota (10+2) def main (n: i64) = copy (take n arr) futhark-0.25.27/tests/shapes/inference9.fut000066400000000000000000000004011475065116200205310ustar00rootroot00000000000000-- Inferring a functional parameter type that refers to a size that is -- not constructively provided until a later parameter. -- == -- input { [1,2,3,4,5,6,7,8] } def f [n] sorter (xs: [n]i32) : [n]i32 = sorter xs def main [n] (xs: [n]i32) = f id xs futhark-0.25.27/tests/shapes/irregular0.fut000066400000000000000000000002071475065116200205620ustar00rootroot00000000000000-- Irregularity must be detected! -- == -- input {0} error: def main (x: i32) = ([([1], [2,3]), ([2,3], [1]) :> ([1]i32, [2]i32)])[1] futhark-0.25.27/tests/shapes/known-shape.fut000066400000000000000000000010561475065116200207430ustar00rootroot00000000000000-- An existing variable can be used as a shape declaration. -- == -- input { -- 5i64 -- 4i64 -- 8i64 -- } -- output { -- [[6, 7, 8, 9, 10, 11, 12, 13], -- [7, 8, 9, 10, 11, 12, 13, 14], -- [8, 9, 10, 11, 12, 13, 14, 15], -- [9, 10, 11, 12, 13, 14, 15, 16], -- [10, 11, 12, 13, 14, 15, 16, 17]] -- } def main (n: i64) (m: i64) (k: i64): [n][k]i32 = let a = replicate n (iota m) in map2 (\(i: i64) (r: [m]i64): [k]i32 -> let x = reduce (+) 0 r in map i32.i64 (map (+i) (map (+x) (iota(k))))) (iota n) a futhark-0.25.27/tests/shapes/lambda-return.fut000066400000000000000000000011341475065116200212430ustar00rootroot00000000000000-- Shape annotation in lambda return type. -- -- This is intended to avoid shape slices. 
-- == -- tags { no_opencl no_cuda no_hip no_pyopencl } -- input { -- [[1,2,3], -- [4,5,6], -- [7,8,9]] -- 3i64 -- } -- output { -- [[1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3], -- [4, 5, 6, 4, 5, 6, 4, 5, 6, 4, 5, 6], -- [7, 8, 9, 7, 8, 9, 7, 8, 9, 7, 8, 9]] -- } def multiply (a: []i32) (x: i64) (n: i64): [n]i32 = (loop (a) for i < x-1 do concat a a) :> [n]i32 def main [m] (a: [m][]i32) (x: i64): [][]i32 = let n = m * (2 ** (x-1)) in map (\(r: []i32): [n]i32 -> multiply r x n) a futhark-0.25.27/tests/shapes/lambda-return2.fut000066400000000000000000000005541475065116200213320ustar00rootroot00000000000000-- This failed at one point during type-checking because the k was not -- visible in the map return type. def main [n][m][k] (a: [n][m][k]i32): [n][k]i32 = let acc_expanded = replicate n (replicate k 0) in loop (acc_expanded) for i < m do map2 (\(acc: [k]i32) (a_r: [m][k]i32): [k]i32 -> map2 (+) acc (a_r[i]) ) (acc_expanded) a futhark-0.25.27/tests/shapes/letshape0.fut000066400000000000000000000001641475065116200203750ustar00rootroot00000000000000-- == -- input { [1,2,-3] } output { 2i64 } def main (xs: []i32) = let [n] (xs': [n]i32) = filter (>0) xs in n futhark-0.25.27/tests/shapes/letshape1.fut000066400000000000000000000001401475065116200203700ustar00rootroot00000000000000-- == -- error: Size \[n\] unused def main (xs: []i32) = let [n] xs' = filter (>0) xs in n futhark-0.25.27/tests/shapes/letshape10.fut000066400000000000000000000003331475065116200204540ustar00rootroot00000000000000-- == -- input {true} output { 1i64 } module type mt = { type arr [n] val mk : bool -> arr [] } module m : mt = { type arr [n] = [n]bool def mk b = [b] } def main b = let [n] (_: m.arr [n]) = m.mk b in n futhark-0.25.27/tests/shapes/letshape11.fut000066400000000000000000000001671475065116200204620ustar00rootroot00000000000000-- == -- input { [1,2,-3] } output { [1,2] } def main (xs: []i32) = let [n] (xs': [n]i32) = filter (>0) xs in xs' 
futhark-0.25.27/tests/shapes/letshape12.fut000066400000000000000000000002111475065116200204510ustar00rootroot00000000000000-- #2210 -- == -- error: Unknown name "n" def f : [42]f32 = let [n] turtle: f32 -> [n]f32 = \(x: f32) -> replicate n x in turtle 42 futhark-0.25.27/tests/shapes/letshape2.fut000066400000000000000000000001621475065116200203750ustar00rootroot00000000000000-- == -- error: "n" and "m" do not match def main n m = let [k] (xss: [k][k]i64) = replicate n (iota m) in k futhark-0.25.27/tests/shapes/letshape3.fut000066400000000000000000000001261475065116200203760ustar00rootroot00000000000000-- == -- input { } output { 0i64 } def main = let [k] (xss: [k][k]i64) = [] in k futhark-0.25.27/tests/shapes/letshape4.fut000066400000000000000000000002341475065116200203770ustar00rootroot00000000000000-- The monomorphiser forgot to keep around the 'n' in this program at -- one point. def n = 1i64 def vec 't arr = arr : [n]t def main (xs: []i32) = vec xs futhark-0.25.27/tests/shapes/letshape5.fut000066400000000000000000000002201475065116200203730ustar00rootroot00000000000000-- A size goes out of scope. -- == -- input { 2i64 } -- output { [0i64,1i64,2i64] } def main (n: i64) : [n+1]i64 = let m = n + 1 in iota m futhark-0.25.27/tests/shapes/letshape6.fut000066400000000000000000000003611475065116200204020ustar00rootroot00000000000000-- A size goes out of scope and the defunctionaliser doesn't mess it -- up. Inspired by issue #848. -- == -- input { [1,2,3] } output { [2,3] } def main [n] (xs: [n]i32) = let res = let m = n - 1 in map (+1) (take m xs) in res futhark-0.25.27/tests/shapes/letshape7.fut000066400000000000000000000001371475065116200204040ustar00rootroot00000000000000-- == -- error: Ambiguous size "m". 
def main = let [n][m] (xss: [n][m]i64) = [] in (n, m) futhark-0.25.27/tests/shapes/letshape8.fut000066400000000000000000000001611475065116200204020ustar00rootroot00000000000000-- == -- error: Cannot bind \[n\] def main = let [n] (f: [n]bool -> [n]bool) = (\(xs: [10]bool) -> xs) in n futhark-0.25.27/tests/shapes/letshape9.fut000066400000000000000000000003611475065116200204050ustar00rootroot00000000000000-- == -- input {} output { 1i64 2i64 } def main = let [n] (_: [n]i32, f: [n]bool -> [n]bool) = (replicate 1 0, \(xs: [1]bool) -> xs) let [m] (f: [m]bool -> [m]bool, _: [m]i32) = (\(xs: [2]bool) -> xs, replicate 2 0) in (n, m) futhark-0.25.27/tests/shapes/local0.fut000066400000000000000000000004441475065116200176630ustar00rootroot00000000000000-- A location function with some shape stuff. Crashed the -- defunctionaliser once. def getneighbors (_: i32): []f64 = [] def main (x: i32) = let objxy = getneighbors x let flikelihood (_: i32) : []i64 = let ind = map i64.f64 objxy in ind let res = flikelihood x in res futhark-0.25.27/tests/shapes/loop0.fut000066400000000000000000000002611475065116200175370ustar00rootroot00000000000000-- Inferring an invariant size for a loop. -- == -- input { [1,2,3] } -- input { [2,3,4] } def main [n] (xs: *[n]i32) : [n]i32 = loop xs for i < n do xs with [i] = xs[i] + 1 futhark-0.25.27/tests/shapes/loop1.fut000066400000000000000000000002101475065116200175320ustar00rootroot00000000000000-- Loops impose sizes. -- == -- error: \[10\]i32 def main [n] (xs: *[n]i32) : [10]i32 = loop xs for i < n do xs with [i] = xs[i] + 1 futhark-0.25.27/tests/shapes/loop10.fut000066400000000000000000000002121475065116200176140ustar00rootroot00000000000000-- This crashed lambda lifting at one point. 
def main n = (\n -> let res = loop xs = replicate n true for i < 10 do xs ++ xs in res) n futhark-0.25.27/tests/shapes/loop11.fut000066400000000000000000000002631475065116200176230ustar00rootroot00000000000000-- == -- error: Loop body does not have expected type.*"n" and "3" do not match def main n: [n]bool = loop (arr: [n]bool) = replicate n true for i < 10 do replicate 3 true futhark-0.25.27/tests/shapes/loop12.fut000066400000000000000000000005211475065116200176210ustar00rootroot00000000000000-- Type-checking the body may induce extra constraints on the -- parameters that makes the return type invalid wrt. the parameter. -- -- Based on issue 1565. -- == -- error: Loop body does not have expected type let main [n] (xs: [n]i32) (ys: [n]i32) = loop (xs,ys) for i < 10 do let zs = filter (>0) (map2 (+) xs ys) in (xs, zs) futhark-0.25.27/tests/shapes/loop13.fut000066400000000000000000000002041475065116200176200ustar00rootroot00000000000000-- Nonsense, but ought to type-check. def main (xs : [10]i64) = loop xs for i < 10 do let n = 10 + i in (xs :> [n]i64) ++ [i] futhark-0.25.27/tests/shapes/loop14.fut000066400000000000000000000001101475065116200176150ustar00rootroot00000000000000def main (n: i64) = loop (xs: []i64) = iota n for i < n do (xs ++ xs) futhark-0.25.27/tests/shapes/loop15.fut000066400000000000000000000005621475065116200176310ustar00rootroot00000000000000-- Based on #2144. -- == -- error: inside the loop body def hide [m] (aoa_shp: [m]i32) : ?[EXT].[EXT]i32 = filter (!=0) aoa_shp entry main [n] (iv : [n]bool) = loop (wrk, wrk_shp) = (iv, [(i32.i64 n)]) for _i < 2 do let flags = hide wrk_shp let min = zip flags wrk -- There is no way 'wrk' can have the right shape here. in (wrk ++ wrk, wrk_shp ++ wrk_shp) futhark-0.25.27/tests/shapes/loop2.fut000066400000000000000000000003101475065116200175340ustar00rootroot00000000000000-- The initial values for merge parameters must have the right size. 
-- == -- error: \[n\]i32 def main [m] (xs: [m]i32) (n: i64) = loop (ys: [n]i32) = xs for _i < 3i32 do replicate n (ys[0]+1) futhark-0.25.27/tests/shapes/loop3.fut000066400000000000000000000004011475065116200175360ustar00rootroot00000000000000-- Just because a loop has an unpredictable size, that does not mean -- it has *any* size! And in particular, not the size of the initial value. -- == -- error: \[1\]i32 def main (n: i32) : [1]i32 = loop xs = replicate 1 0i32 for _i < n do xs ++ xs futhark-0.25.27/tests/shapes/loop4.fut000066400000000000000000000004201475065116200175400ustar00rootroot00000000000000-- It is legitimate for a loop to have an undefined size. -- == -- input { 4 } -- output { [0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32, 0i32] } def main (n: i32) : []i32 = loop xs = replicate 1 0 for _i < n do xs ++ xs futhark-0.25.27/tests/shapes/loop5.fut000066400000000000000000000003121475065116200175410ustar00rootroot00000000000000-- More complex loop. We must infer that the size here is unchanging. -- == -- input { [1,2,3] } output { [1,2,3] } def main [n] (s: [n]i32) : [n]i32 = loop s for _i < 10 do tabulate n (\i -> s[i]) futhark-0.25.27/tests/shapes/loop6.fut000066400000000000000000000004121475065116200175430ustar00rootroot00000000000000-- Make sure loop parameters are not existential while checking the -- loop. def main [n][m] (A: [n][m]f32): [n][m]f32 = loop A for i < n do let irow = A[0] -- Keep these let v1 = irow[i] -- separate. in map (\k -> map (\j -> v1) (iota m)) (iota n) futhark-0.25.27/tests/shapes/loop7.fut000066400000000000000000000005201475065116200175440ustar00rootroot00000000000000-- Infer correctly that the loop parameter 'ys' has a variant size. 
-- == -- input { [0i64,1i64] } output { 2i64 [0i64] } def first_nonempty (f: i64 -> []i64) xs = loop (i, ys) = (0, [] : []i64) while null ys && i < length xs do let i' = i+1 let ys' = f xs[i] in (i', ys') def main [n] (xs: [n]i64) = first_nonempty iota xs futhark-0.25.27/tests/shapes/loop8.fut000066400000000000000000000001641475065116200175510ustar00rootroot00000000000000-- == -- error: Causality def main : []i32 = (.0) <| loop (xs: []i32, j) = ([], 0) for i < 10 do (xs ++ xs, j+1) futhark-0.25.27/tests/shapes/loop9.fut000066400000000000000000000001521475065116200175470ustar00rootroot00000000000000-- == -- error: Causality def main : []i32 = ([1]++) <| loop (xs: []i32) = [] for i < 10 do (xs ++ xs) futhark-0.25.27/tests/shapes/match0.fut000066400000000000000000000004771475065116200176730ustar00rootroot00000000000000-- Looking at the size of an existential match. -- == -- input { 0 } output { 1i64 } -- input { 1 } output { 2i64 } -- input { 2 } output { 3i64 } -- input { 3 } output { 9i64 } def main i = length (match i case 0 -> iota 1 case 1 -> iota 2 case 2 -> iota 3 case _ -> iota 9) futhark-0.25.27/tests/shapes/match1.fut000066400000000000000000000004171475065116200176660ustar00rootroot00000000000000-- Looking at the size of an existential pattern match. -- == -- input { true 1i64 2i64 } output { 1i64 } -- input { false 1i64 2i64 } output { 2i64 } def main b n m = let arr = match b case true -> iota n case false -> iota m in length arr futhark-0.25.27/tests/shapes/match2.fut000066400000000000000000000002271475065116200176660ustar00rootroot00000000000000-- Size hidden by match. -- == -- input { 2i64 } output { 2i64 } def main (n: i64) = let arr = match n case m -> iota m in length arr futhark-0.25.27/tests/shapes/modules0.fut000066400000000000000000000003261475065116200202400ustar00rootroot00000000000000-- Abstract types may not have anonymous sizes. 
-- == -- error: info module edge_handling (mapper: {type info}) = { def handle (i: i32) (info: mapper.info) = info } module m = edge_handling {type info = []f32} futhark-0.25.27/tests/shapes/modules1.fut000066400000000000000000000006001475065116200202340ustar00rootroot00000000000000-- It is not allowed to create an opaque type whose size parameters -- are not used in array dimensions. -- == -- error: is not used constructively module m = { type^ t [n] = [n]i32 -> i64 def f [n] (_: t [n]) = 0 def mk (n: i64) : t [n] = \(xs: [n]i32) -> n } : { type^ t [n] val f [n] : (x: t [n]) -> i32 val mk : (n: i64) -> t [n] } def main x = (x+2) |> m.mk |> m.f futhark-0.25.27/tests/shapes/modules2.fut000066400000000000000000000001461475065116200202420ustar00rootroot00000000000000module type mt = { type^ t [n] = [n]i32 -> i32 } module m : mt = { type^ t [n] = [n]i32 -> i32 } futhark-0.25.27/tests/shapes/negative-position-shape0.fut000066400000000000000000000003331475065116200233300ustar00rootroot00000000000000-- It should be allowed to have a shape parameter that is only used in -- negative position in the parameter types. -- == -- input {} output { 3i64 } def f [n] (_g: i32 -> [n]i32) : i64 = n def main = f (replicate 3) futhark-0.25.27/tests/shapes/negative-position-shape1.fut000066400000000000000000000004511475065116200233320ustar00rootroot00000000000000-- It should not be allowed to have a shape parameter that is only -- used in negative position in the parameter types, but only if that -- size is unambiguous. -- == -- error: Ambiguous size.*instantiated size parameter of "f" def f [n] (g: [n]i64 -> i64) : i64 = n def main = f (\xs -> xs[0]) futhark-0.25.27/tests/shapes/negative-position-shape2.fut000066400000000000000000000004531475065116200233350ustar00rootroot00000000000000-- A shape parameter may be used before it has been in positive -- position at least once! 
-- == -- input { [1,2,3] } output { [3i64,3i64,3i64] 3i64 } def f [n] (g: i64 -> [n]i64) (xs: [n]i32) = let g' (x: i64) = g x : [n]i64 in (g' (length xs), n) def main xs = f (\x -> map (const x) xs) xs futhark-0.25.27/tests/shapes/negative-position-shape3.fut000066400000000000000000000001751475065116200233370ustar00rootroot00000000000000-- Entry points may not be return-polymorphic. -- == -- error: Entry point entry main [n] (x: i32) : [n]i32 = replicate n x futhark-0.25.27/tests/shapes/negative-position-shape4.fut000066400000000000000000000001771475065116200233420ustar00rootroot00000000000000-- == -- input { 2i64 } output { [2i64, 2i64] } def f [n] (x: i64) : [n]i64 = replicate n x def main (x: i64) : [x]i64 = f x futhark-0.25.27/tests/shapes/negative-position-shape5.fut000066400000000000000000000005731475065116200233430ustar00rootroot00000000000000-- Do not invent size variables for things that occur only in negative -- position. -- -- (This program is somewhat contrived to trigger unfortunate -- behaviour in the type checker.) def split_rng n rng = replicate n rng def shuffle' rngs xs = (rngs, xs) def main (rng: i32) (xs: []i32) = let rngs = split_rng (length xs) rng let (rngs', xs') = shuffle' rngs xs in xs' futhark-0.25.27/tests/shapes/nonint-shape-error.fut000066400000000000000000000002671475065116200222460ustar00rootroot00000000000000-- A shape declaration referring to a non-integer value should be an -- error. 
-- -- == -- error: bool def main(as: []i32, b: bool): [][]i32 = map (\i: [b]i32 -> replicate 3 i) as futhark-0.25.27/tests/shapes/opaque0.fut000066400000000000000000000005241475065116200200620ustar00rootroot00000000000000-- == -- error: do not match module num: { type t[n] val mk : (x: i64) -> t[x] val un [n] : t[n] -> i64 val comb [n] : t[n] -> t[n] -> i64 } = { type t[n] = [n]() def mk x = replicate x () def un x = length x def comb x y = length (zip x y) } def f x = let y = x + 1 in num.mk y def main a b = num.comb (f a) (f b) futhark-0.25.27/tests/shapes/paramsize0.fut000066400000000000000000000002511475065116200205600ustar00rootroot00000000000000-- == -- input { [1,2,3] } -- output { 3i64 } type^ f = (k: i64) -> [k]i32 -> i64 def f : f = \n (xs: [n]i32) -> length xs def main [K] (input: [K]i32) = f K input futhark-0.25.27/tests/shapes/paramsize1.fut000066400000000000000000000002051475065116200205600ustar00rootroot00000000000000-- == -- error: "k" type^ f = (k: i64) -> [k]i32 -> i64 def f : f = \_ xs -> length xs def main [K] (input: [K]i32) = f K input futhark-0.25.27/tests/shapes/partial-apply.fut000066400000000000000000000001631475065116200212660ustar00rootroot00000000000000-- Size for something that is partially applied. def f [n] (x: [n]f32): [n]f32 = x def main : []f32 -> []f32 = f futhark-0.25.27/tests/shapes/pointfree0.fut000066400000000000000000000002571475065116200205660ustar00rootroot00000000000000-- Defunctionaliser didn't handle this point-free definition properly. def f32id [n] (xs: [n]f32) : [n]f32 = xs def median = f32id >-> id entry median_entry xs = median xs futhark-0.25.27/tests/shapes/pointfree1.fut000066400000000000000000000002661475065116200205670ustar00rootroot00000000000000-- Defunctionaliser didn't handle this point-free definition properly. 
entry median_entry xs = let f32id [n] (xs: [n]f32) : [n]f32 = xs let median = f32id >-> id in median xs futhark-0.25.27/tests/shapes/polymorphic1.fut000066400000000000000000000004011475065116200211300ustar00rootroot00000000000000-- Arrays passed for polymorphic parameters of the same type must have -- the same size. -- == -- input { [1] [2] } output { [1] [2] } -- compiled input { [1] [2,3] } error: def pair 't (x: t) (y: t) = (x, y) def main (xs: []i32) (ys: []i32) = pair xs ys futhark-0.25.27/tests/shapes/polymorphic2.fut000066400000000000000000000002351475065116200211360ustar00rootroot00000000000000-- == -- input { 2 } output { 2 empty([0][1]i32) } def empty (d: i64) (x: i32) : (i32, [0][d]i32) = (x, []) def main (x: i32): (i32, [][1]i32) = empty 1 x futhark-0.25.27/tests/shapes/polymorphic3.fut000066400000000000000000000002321475065116200211340ustar00rootroot00000000000000-- We must be able to infer size-preserving function types. def set i v arr = copy arr with [i] = v def main [n] (xs: [n]i32) : [n]i32 = set 0 0 xs futhark-0.25.27/tests/shapes/polymorphic4.fut000066400000000000000000000002251475065116200211370ustar00rootroot00000000000000-- No hiding sizes behind type inference. -- == -- error: do not match def foo f x : [1]i32 = let r = if true then f x : []i32 else [1i32] in r futhark-0.25.27/tests/shapes/range0.fut000066400000000000000000000002101475065116200176540ustar00rootroot00000000000000-- Some ranges have known sizes. def main (n: i64) : ([n]i64, [n]i64, [n]i64, [n + 1 - 1]i64) = (0.. [m]i64 futhark-0.25.27/tests/shapes/shape-annot-is-param.fut000066400000000000000000000002571475065116200224370ustar00rootroot00000000000000-- == -- input { 2i64 [1,2] } -- output { [1,2] } -- compiled input { 1i64 [1,2] } -- error: def f (n: i64) (xs: [n]i32): [n]i32 = xs def main (n: i64) (xs: []i32) = f n xs futhark-0.25.27/tests/shapes/shape-inside-tuple.fut000066400000000000000000000001761475065116200222130ustar00rootroot00000000000000-- Issue #125 test program. 
-- -- == -- input { [[1,2],[3,4],[5,6]] } output { 3i64 } def main [n][m] (arg: [n][m]i32) = n futhark-0.25.27/tests/shapes/shape_duplicate.fut000066400000000000000000000003721475065116200216430ustar00rootroot00000000000000-- It is an error to impose two different names on the same dimension -- in a function parameter. -- -- == -- error: do not match def f [n][m] ((_, elems: [n]i32): (i32,[m]i32)): i32 = n + m + elems[0] def main (x: i32, y: []i32): i32 = f (x, y) futhark-0.25.27/tests/shapes/shape_in_ascription.fut000066400000000000000000000003241475065116200225270ustar00rootroot00000000000000-- Make sure ascribed names are available. -- -- == -- input { 2 [1i64,2i64,3i64] } -- output { 4i64 } def f [n] ((_, elems: []i64): (i32,[n]i64)) = n + elems[0] def main [n] (x: i32) (y: [n]i64) = f (x,y) futhark-0.25.27/tests/shapes/shape_in_tuple.fut000066400000000000000000000004301475065116200215030ustar00rootroot00000000000000-- Make sure inner shape names are available, even if they are -- "shadowed" by an outer type ascription. -- -- == -- input { 2 [1i64,2i64,3i64] } -- output { 4i64 } def f [n] ((_, elems: [n]i64): (i32,[]i64)): i64 = n + elems[0] def main (x: i32) (y: []i64): i64 = f (x,y) futhark-0.25.27/tests/shapes/size-inference0.fut000066400000000000000000000002211475065116200214700ustar00rootroot00000000000000-- Inference of return size. def get_at xs indices = map (\(i: i64) -> xs[i]) indices def main [l] (xs: [l]i32): [l]i32 = get_at xs (iota l) futhark-0.25.27/tests/shapes/size-inference1.fut000066400000000000000000000003351475065116200214770ustar00rootroot00000000000000-- Inference of return size (which then causes a type error). 
-- == -- error: "10" and "l" do not match def get_at xs indices = map (\(i: i64) -> xs[i]) indices def main [l] (xs: [l]i32): [10]i32 = get_at xs (iota l) futhark-0.25.27/tests/shapes/size-inference2.fut000066400000000000000000000002251475065116200214760ustar00rootroot00000000000000-- Sometimes we should infer that a size cannot be typed. -- == -- error: Sizes.*do not match def main [n] (xs: [n]i32) : [n]i32 = iota (length xs) futhark-0.25.27/tests/shapes/size-inference3.fut000066400000000000000000000002631475065116200215010ustar00rootroot00000000000000-- We cannot constrain the inferred size of an array parameter to a -- size that will not be visible in the function signature. -- == -- error: def main xs = zip xs (iota xs[0]) futhark-0.25.27/tests/shapes/size-inference4.fut000066400000000000000000000004451475065116200215040ustar00rootroot00000000000000-- It is an error if the size of an array parameter depends on a later -- parameter. Written in a convoluted way to ensure this is checked -- even for lambdas that are never let-generalised. -- == -- error: scope violation def f : i32 = const 2 ((\xs n -> (zip xs (iota n) : [](i64, i64)))) futhark-0.25.27/tests/shapes/size-inference5.fut000066400000000000000000000001631475065116200215020ustar00rootroot00000000000000-- Like size-inference4.fut, but with a let-binding. -- == -- error: scope violation def f xs n = zip xs (iota n) futhark-0.25.27/tests/shapes/size-inference6.fut000066400000000000000000000003711475065116200215040ustar00rootroot00000000000000-- Permit inference of a type with non-constructive size parameters. -- == -- input { 0i64 2i64 } output { empty([0]i64) [1i64,0i64] } def r = let f = reverse let g = reverse in {f, g} def main x y = (\p -> (p.f (iota x), p.g (iota y))) r futhark-0.25.27/tests/shapes/size-inference7.fut000066400000000000000000000004071475065116200215050ustar00rootroot00000000000000-- Does it work to have a definition that only has a size parameter? 
-- == -- input { 3i64 } -- output { [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] -- [2, 2, 2, 2, 2, 2] } def iiota [n] : [n]i64 = iota n def main x = unzip (zip iiota (replicate (2*x) 2i32)) futhark-0.25.27/tests/shapes/slice0.fut000066400000000000000000000003461475065116200176710ustar00rootroot00000000000000-- Multiple slices with the same operands produce things that have the -- same size. def f (x: i64) = x + 2 def g (x: i64) = x * 2 def main [n] (xs: [n]i32) (ys: [n]i32) (i: i64) (j: i64) = zip xs[(f i):(g j)] ys[(f i):(g j)] futhark-0.25.27/tests/shapes/slice1.fut000066400000000000000000000001371475065116200176700ustar00rootroot00000000000000-- == -- error: do not match def main [n] [m] (xs: [n]i32) (ys: [m]i32) = zip xs[1:] ys[1:] futhark-0.25.27/tests/shapes/toplevel0.fut000066400000000000000000000001271475065116200204210ustar00rootroot00000000000000-- Important that size is existential. def values = [1] ++ [2] def main = copy values futhark-0.25.27/tests/shapes/toplevel1.fut000066400000000000000000000002561475065116200204250ustar00rootroot00000000000000-- Using a top level size. -- When this program failed, the problem was actually in the array literal. def n: i64 = 20 def main (xs: []i32) = let ys = take n xs in [ys] futhark-0.25.27/tests/shapes/toplevel2.fut000066400000000000000000000002621475065116200204230ustar00rootroot00000000000000-- #1993, do not allow size-polymorphic non-function bindings with -- unknown sizes. -- == -- error: size-polymorphic value binding def foo [n] = (iota n, filter (>5) (iota n)) futhark-0.25.27/tests/shapes/unknown-param.fut000066400000000000000000000003261475065116200213050ustar00rootroot00000000000000-- Existential sizes must not be (exclusively) used as a parameter -- type. 
-- == -- error: Unknown size.*in parameter def f (x: bool) = let (n,_) = if x then (10,true) else (20,true) in \(_: [n]bool) -> true futhark-0.25.27/tests/shapes/use-shapes.fut000066400000000000000000000003621475065116200205650ustar00rootroot00000000000000-- Test that a variable shape annotation is actually bound. -- == -- input { -- [42i64,1337i64,5i64,4i64,3i64,2i64,1i64] -- } -- output { -- [49i64,1344i64,12i64,11i64,10i64,9i64,8i64] -- } def main [n] (a: [n]i64): []i64 = map (+n) a futhark-0.25.27/tests/shortcircuit-and.fut000066400000000000000000000003651475065116200205120ustar00rootroot00000000000000-- && must be short-circuiting. -- -- == -- input { 0i64 [true, true] } output { true } -- input { 1i64 [true, true] } output { true } -- input { 2i64 [true, true] } output { false } def main [n] (i: i64) (bs: [n]bool): bool = i < n && bs[i] futhark-0.25.27/tests/shortcircuit-or.fut000066400000000000000000000003751475065116200203710ustar00rootroot00000000000000-- || must be short-circuiting. -- -- == -- input { 0i64 [false, false] } output { false } -- input { 1i64 [false, false] } output { false } -- input { 2i64 [false, false] } output { true } def main [n] (i: i64) (bs: [n]bool): bool = i >= n || bs[i] futhark-0.25.27/tests/simplify_primexp.fut000066400000000000000000000004001475065116200206160ustar00rootroot00000000000000-- The map should be simplified away entirely, even though it is a -- call to a built-in function. 
-- == -- structure gpu { SegMap 1 } def main (n: i64) (accs: []i64) = let ys = map (2**) (iota n) in map (\acc -> loop acc for y in ys do acc * y) accs futhark-0.25.27/tests/sinking0.fut000066400000000000000000000002271475065116200167470ustar00rootroot00000000000000-- == -- structure gpu { /Index 1 } def main (arr: [](i32, i32, i32, i32, i32)) = let (a,b,c,d,e) = arr[0] in if a == 0 then 0 else b + c + d + e futhark-0.25.27/tests/sinking1.fut000066400000000000000000000002751475065116200167530ustar00rootroot00000000000000-- == -- structure gpu { /SegMap/Index 1 } def main (as: []i32) (bs: []i32) (cs: []i32) (ds: []i32) (es: []i32) = map5 (\a b c d e -> if a == 0 then 0 else b + c + d + e) as bs cs ds es futhark-0.25.27/tests/sinking2.fut000066400000000000000000000005711475065116200167530ustar00rootroot00000000000000-- Sinking can be safe even in the presence of in-place updates. -- == -- structure gpu { /SegMap/Index 1 } def main (n: i64) (as: []i32) (bs: []i32) (cs: []i32) (ds: []i32) (es: []i32) = map5 (\a b c d e -> let arr = loop arr = replicate n 0 for i < n do arr with [i] = a in if a != 1337 then arr else replicate n (b + c + d + e)) as bs cs ds es futhark-0.25.27/tests/sinking4.fut000066400000000000000000000003551475065116200167550ustar00rootroot00000000000000-- Avoid sinking when the value is actually used afterwards in a -- result. See issue #858. -- == def main (arr: [](i32, i32, i32, i32, i32)) = let (a,b,c,d,e) = arr[0] let x = if a == 0 then 0 else b + c + d + e in (x,a,b,c,d,e) futhark-0.25.27/tests/sinking5.fut000066400000000000000000000006141475065116200167540ustar00rootroot00000000000000-- Sinking should be as deep as possible. 
-- == -- structure gpu { -- /GPUBody/Index 1 -- /GPUBody/If/False/If/True/Index 1 -- /GPUBody/If/False/If/False/If/True/Index 1 -- /GPUBody/If/False/If/False/If/False/Index 2 -- } def main (arr: [](i32, i32, i32, i32, i32)) = let (a,b,c,d,e) = arr[0] in if a == 0 then 0 else if a == 1 then b else if a == 2 then c else d + e futhark-0.25.27/tests/sinking6.fut000066400000000000000000000006041475065116200167540ustar00rootroot00000000000000-- At one point this did an incorrect sinking due to not looking -- properly at WithAccs. def pointToRoller (model: [2]f32) (p: f32) : f32 = let radius = model[0] let alpha = model[1] let p' = map (p *) [f32.sin alpha, 0.0] in (p'[0] - radius) ** 2 entry pointsToRollerGrad [n] (model: [2]f32) (pcd: [n]f32) : [2]f32 = vjp (\m -> map (pointToRoller m) pcd |> f32.sum) model 1 futhark-0.25.27/tests/size-expr-for-in.fut000066400000000000000000000002271475065116200203430ustar00rootroot00000000000000-- ForIn pattern must transform its size expressions -- == def main (n:i64) = loop x = 0 for t in replicate n (iota (n+1)) do x + reduce (+) 0 t futhark-0.25.27/tests/size-from-division.fut000066400000000000000000000004151475065116200207610ustar00rootroot00000000000000-- The array size is the result of a division. -- -- This was a problem with futhark-py and futhark-pyopencl due to the magic '/' -- Python 3 division operator. 
-- == -- input { 5i64 2i64 } -- output { [0i64, 1i64] } def main (x: i64) (y: i64): []i64 = iota (x / y) futhark-0.25.27/tests/slice-lmads/000077500000000000000000000000001475065116200167015ustar00rootroot00000000000000futhark-0.25.27/tests/slice-lmads/bounds.fut000066400000000000000000000032241475065116200207140ustar00rootroot00000000000000-- == -- entry: index -- compiled random input { [100]i32 0i64 10i64 -20i64 10i64 20i64 10i64 1i64 } -- error: out of bounds -- compiled random input { [27]i32 0i64 10i64 1i64 10i64 1i64 10i64 1i64 } -- error: out of bounds -- random input { [30]i32 0i64 10i64 1i64 10i64 1i64 10i64 1i64 } -- compiled random input { [27]i32 2i64 10i64 1i64 10i64 1i64 10i64 1i64 } -- error: out of bounds -- compiled random input { [1001]i32 1i64 10i64 100i64 10i64 10i64 10i64 1i64 } -- compiled random input { [1000]i32 1i64 10i64 100i64 10i64 10i64 10i64 1i64 } -- error: out of bounds -- == -- entry: index_2d -- compiled random input { [91]i32 0i64 5i64 20i64 11i64 1i64 } -- == -- entry: update -- compiled random input { [100]i32 0i64 -20i64 20i64 1i64 [10][10][10]i32 } -- error: out of bounds -- compiled random input { [27]i32 0i64 1i64 1i64 1i64 [10][10][10]i32 } -- error: out of bounds -- random input { [30]i32 0i64 1i64 1i64 1i64 [1][1][1]i32 } -- compiled random input { [27]i32 2i64 1i64 1i64 1i64 [10][10][10]i32 } -- error: out of bounds -- compiled random input { [1001]i32 1i64 100i64 10i64 1i64 [10][10][10]i32 } -- compiled random input { [1000]i32 1i64 100i64 10i64 1i64 [10][10][10]i32 } -- error: out of bounds import "intrinsics" entry index [n] (xs: [n]i32) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) (n3: i64) (s3: i64): [][][]i32 = flat_index_3d xs offset n1 s1 n2 s2 n3 s3 entry index_2d [n] (xs: [n]i32) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64): [][]i32 = flat_index_2d xs offset n1 s1 n2 s2 entry update [n] (xs: *[n]i32) (offset: i64) (s1: i64) (s2: i64) (s3: i64) (vs: [][][]i32): [n]i32 = flat_update_3d xs 
offset s1 s2 s3 vs futhark-0.25.27/tests/slice-lmads/flat.fut000066400000000000000000000032721475065116200203530ustar00rootroot00000000000000-- == -- entry: update_antidiag -- script input { my_iota 100i64 } -- output { [ 0i64, 1i64, 2i64, 3i64, 4i64, 5i64, 6i64, 7i64, 8i64, 9i64, -- 10i64, 11i64, 12i64, 13i64, 14i64, 15i64, 16i64, 0i64, 1i64, 2i64, -- 20i64, 21i64, 22i64, 23i64, 24i64, 25i64, 26i64, 3i64, 4i64, 5i64, -- 30i64, 31i64, 32i64, 33i64, 34i64, 35i64, 36i64, 6i64, 7i64, 8i64, -- 40i64, 41i64, 42i64, 43i64, 9i64, 10i64, 11i64, 47i64, 48i64, 49i64, -- 50i64, 51i64, 52i64, 53i64, 12i64, 13i64, 14i64, 57i64, 58i64, 59i64, -- 60i64, 61i64, 62i64, 63i64, 15i64, 16i64, 17i64, 67i64, 68i64, 69i64, -- 70i64, 18i64, 19i64, 20i64, 74i64, 75i64, 76i64, 77i64, 78i64, 79i64, -- 80i64, 21i64, 22i64, 23i64, 84i64, 85i64, 86i64, 87i64, 88i64, 89i64, -- 90i64, 24i64, 25i64, 26i64, 94i64, 95i64, 96i64, 97i64, 98i64, 99i64] } -- == -- entry: index_antidiag -- script input { my_iota 100i64 } -- output { [ [ [ 17i64, 18i64, 19i64 ], -- [ 27i64, 28i64, 29i64 ], -- [ 37i64, 38i64, 39i64 ] ], -- [ [ 44i64, 45i64, 46i64 ], -- [ 54i64, 55i64, 56i64 ], -- [ 64i64, 65i64, 66i64 ] ], -- [ [ 71i64, 72i64, 73i64 ], -- [ 81i64, 82i64, 83i64 ], -- [ 91i64, 92i64, 93i64 ] ] ] } import "intrinsics" entry my_iota (n: i64) = iota n entry update_antidiag [n] (xs: *[n]i64): [n]i64 = let vs = iota (3*3*3) |> unflatten |> unflatten let zs = flat_update_3d xs 17 27 10 1 vs in zs entry index_antidiag [n] (xs: [n]i64): [][][]i64 = flat_index_3d xs 17 3 27 3 10 3 1 -- We need to test weird inner strides as well -- And I guess negative strides? 
futhark-0.25.27/tests/slice-lmads/intrinsics.fut000066400000000000000000000022351475065116200216100ustar00rootroot00000000000000#[inline] def flat_index_2d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) : [n1][n2]a = intrinsics.flat_index_2d as offset n1 s1 n2 s2 :> [n1][n2]a #[inline] def flat_update_2d [n][k][l] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (asss: [k][l]a) : *[n]a = intrinsics.flat_update_2d as offset s1 s2 asss #[inline] def flat_index_3d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) (n3: i64) (s3: i64) : [n1][n2][n3]a = intrinsics.flat_index_3d as offset n1 s1 n2 s2 n3 s3 :> [n1][n2][n3]a #[inline] def flat_update_3d [n][k][l][p] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (s3: i64) (asss: [k][l][p]a) : *[n]a = intrinsics.flat_update_3d as offset s1 s2 s3 asss #[inline] def flat_index_4d [n] 'a (as: [n]a) (offset: i64) (n1: i64) (s1: i64) (n2: i64) (s2: i64) (n3: i64) (s3: i64) (n4: i64) (s4: i64) : [n1][n2][n3][n4]a = intrinsics.flat_index_4d as offset n1 s1 n2 s2 n3 s3 n4 s4 :> [n1][n2][n3][n4]a #[inline] def flat_update_4d [n][k][l][p][q] 'a (as: *[n]a) (offset: i64) (s1: i64) (s2: i64) (s3: i64) (s4: i64) (asss: [k][l][p][q]a) : *[n]a = intrinsics.flat_update_4d as offset s1 s2 s3 s4 asss futhark-0.25.27/tests/slice-lmads/lud.fut000066400000000000000000000135041475065116200202100ustar00rootroot00000000000000-- Parallel blocked LU-decomposition. 
-- -- == -- entry: lud -- random input { 32i64 [1024]f32 } -- compiled random input { 32i64 [16384]f32 } -- compiled random input { 32i64 [4194304]f32 } import "intrinsics" def dotprod [n] (a: [n]f32) (b: [n]f32): f32 = map2 (*) a b |> reduce (+) 0 def lud_diagonal [b] (a: [b][b]f32): *[b][b]f32 = map1 (\mat -> let mat = copy mat in loop (mat: *[b][b]f32) for i < b-1 do let col = map (\j -> if j > i then #[unsafe] (mat[j,i] - (dotprod mat[j,:i] mat[:i,i])) / mat[i,i] else mat[j,i]) (iota b) let mat[:,i] = col let row = map (\j -> if j > i then mat[i+1, j] - (dotprod mat[:i+1, j] mat[i+1, :i+1]) else mat[i+1, j]) (iota b) let mat[i+1] = row in mat ) (unflatten (a :> [opaque 1*b][b]f32)) |> head def lud_perimeter_upper [m][b] (diag: [b][b]f32) (a0s: [m][b][b]f32): *[m][b][b]f32 = let a1s = map (\ (x: [b][b]f32): [b][b]f32 -> transpose(x)) a0s in let a2s = map (\a1: [b][b]f32 -> map (\row0: [b]f32 -> -- Upper loop row = copy row0 for i < b do let sum = (loop sum=0.0f32 for k < i do sum + diag[i,k] * row[k]) let row[i] = row[i] - sum in row ) a1 ) a1s in map (\x: [b][b]f32 -> transpose(x)) a2s def lud_perimeter_lower [b][m] (diag: [b][b]f32) (mat: [m][b][b]f32): *[m][b][b]f32 = map (\blk: [b][b]f32 -> map (\ (row0: [b]f32): *[b]f32 -> -- Lower loop row = copy row0 for j < b do let sum = loop sum=0.0f32 for k < j do sum + diag[k,j] * row[k] let row[j] = (row[j] - sum) / diag[j,j] in row ) blk ) mat def lud_internal [m][b] (top_per: [m][b][b]f32) (lft_per: [m][b][b]f32) (mat_slice: [m][m][b][b]f32): *[m][m][b][b]f32 = let top_slice = map transpose top_per in map (\(mat_arr: [m][b][b]f32, lft: [b][b]f32): [m][b][b]f32 -> map (\ (mat_blk: [b][b]f32, top: [b][b]f32): [b][b]f32 -> map (\ (mat_row: [b]f32, lft_row: [b]f32): [b]f32 -> map (\(mat_el, top_row) -> let prods = map2 (*) lft_row top_row let sum = f32.sum prods in mat_el - sum ) (zip (mat_row) top) ) (zip (mat_blk) lft ) ) (zip (mat_arr) (top_slice) ) ) (zip (mat_slice) (lft_per) ) entry lud [n] (block_size: i64) 
(mat: *[n]f32): [n]f32 = let row_length = i64.f64 <| f64.sqrt <| f64.i64 n let row_length = assert (row_length ** 2 == n) row_length let block_size = i64.min block_size row_length let num_blocks = assert (row_length % block_size == 0) (row_length / block_size) let mat = loop mat for step < num_blocks - 1 do -- 1. compute the current diagonal block let diag = lud_diagonal (flat_index_2d mat (row_length * block_size * step + block_size * step) block_size row_length block_size 1) let mat = flat_update_2d mat (row_length * block_size * step + block_size * step) row_length 1 diag -- 2. compute the top perimeter let top_per_irreg = lud_perimeter_upper diag (flat_index_3d mat (row_length * block_size * step + block_size * (step + 1)) (num_blocks - step - 1) block_size block_size row_length block_size 1) let mat = flat_update_3d mat (row_length * block_size * step + block_size * (step + 1)) block_size row_length 1 top_per_irreg -- 3. compute the left perimeter let lft_per_irreg = lud_perimeter_lower diag (flat_index_3d mat (row_length * block_size * (step + 1) + block_size * step) (num_blocks - step - 1) (row_length * block_size) block_size row_length block_size 1) let mat = flat_update_3d mat (row_length * block_size * (step + 1) + block_size * step) (row_length * block_size) row_length 1 lft_per_irreg -- 4. 
compute the internal blocks let internal = lud_internal top_per_irreg lft_per_irreg (flat_index_4d mat (row_length * block_size * (step + 1) + block_size * (step + 1)) (num_blocks - step - 1) (row_length * block_size) (num_blocks - step - 1) block_size block_size row_length block_size 1) let mat = flat_update_4d mat (row_length * block_size * (step + 1) + block_size * (step + 1)) (row_length * block_size) block_size row_length 1 internal in mat let last_step = num_blocks - 1 let last_offset = last_step * block_size * row_length + last_step * block_size let v = lud_diagonal (flat_index_2d mat last_offset block_size row_length block_size 1) let mat = flat_update_2d mat last_offset row_length 1 v in mat entry lud_2d [m] (mat: *[m][m]f32): [m][m]f32 = let mat = flatten mat let mat = lud 32 mat in unflatten mat futhark-0.25.27/tests/slice-lmads/nw-cosmin.fut000066400000000000000000000163361475065116200213440ustar00rootroot00000000000000-- Code and comments based on -- https://github.com/kkushagra/rodinia/blob/master/openmp/nw -- -- == -- entry: nw_flat -- compiled random input { 16i64 10i32 [2362369]i32 [2362369]i32 } auto output -- compiled random input { 32i64 10i32 [2362369]i32 [2362369]i32 } auto output -- compiled random input { 64i64 10i32 [2362369]i32 [2362369]i32 } auto output -- compiled input { 3i64 -- 10i32 -- [4i32, 2i32, 4i32, 9i32, 2i32, 1i32, 7i32, 1i32, 9i32, 8i32, -- 7i32, 6i32, 7i32, 5i32, 0i32, 9i32, 0i32, 0i32, 2i32, 5i32, -- 6i32, 9i32, 3i32, 3i32, 7i32, 6i32, 6i32, 5i32, 4i32, 7i32, -- 1i32, 9i32, 5i32, 4i32, 4i32, 5i32, 9i32, 6i32, 7i32, 2i32, -- 2i32, 9i32, 6i32, 6i32, 8i32, 4i32, 4i32, 8i32, 0i32, 4i32, -- 5i32, 5i32, 5i32, 1i32, 3i32, 1i32, 1i32, 7i32, 2i32, 8i32, -- 5i32, 3i32, 9i32, 4i32, 2i32, 8i32, 1i32, 1i32, 0i32, 5i32, -- 8i32, 7i32, 0i32, 7i32, 5i32, 9i32, 1i32, 5i32, 5i32, 1i32, -- 6i32, 1i32, 8i32, 9i32, 3i32, 4i32, 6i32, 0i32, 2i32, 5i32, -- 4i32, 8i32, 7i32, 7i32, 2i32, 0i32, 5i32, 0i32, 1i32, 3i32] -- [4i32, 2i32, 4i32, 9i32, 2i32, 
1i32, 7i32, 1i32, 9i32, 8i32, -- 7i32, 6i32, 7i32, 5i32, 0i32, 9i32, 0i32, 0i32, 2i32, 5i32, -- 6i32, 9i32, 3i32, 3i32, 7i32, 6i32, 6i32, 5i32, 4i32, 7i32, -- 1i32, 9i32, 5i32, 4i32, 4i32, 5i32, 9i32, 6i32, 7i32, 2i32, -- 2i32, 9i32, 6i32, 6i32, 8i32, 4i32, 4i32, 8i32, 0i32, 4i32, -- 5i32, 5i32, 5i32, 1i32, 3i32, 1i32, 1i32, 7i32, 2i32, 8i32, -- 5i32, 3i32, 9i32, 4i32, 2i32, 8i32, 1i32, 1i32, 0i32, 5i32, -- 8i32, 7i32, 0i32, 7i32, 5i32, 9i32, 1i32, 5i32, 5i32, 1i32, -- 6i32, 1i32, 8i32, 9i32, 3i32, 4i32, 6i32, 0i32, 2i32, 5i32, -- 4i32, 8i32, 7i32, 7i32, 2i32, 0i32, 5i32, 0i32, 1i32, 3i32] } -- output { [ 4i32, 2i32, 4i32, 9i32, 2i32, 1i32, 7i32, 1i32, 9i32, 8i32, -- 7i32, 10i32, 9i32, 9i32, 9i32, 11i32, 1i32, 7i32, 3i32, 14i32, -- 6i32, 16i32, 13i32, 12i32, 16i32, 15i32, 17i32, 7i32, 11i32, 10i32, -- 1i32, 15i32, 21i32, 17i32, 16i32, 21i32, 24i32, 23i32, 14i32, 13i32, -- 2i32, 10i32, 21i32, 27i32, 25i32, 20i32, 25i32, 32i32, 23i32, 18i32, -- 5i32, 7i32, 15i32, 22i32, 30i32, 26i32, 21i32, 32i32, 34i32, 31i32, -- 5i32, 8i32, 16i32, 19i32, 24i32, 38i32, 28i32, 22i32, 32i32, 39i32, -- 8i32, 12i32, 8i32, 23i32, 24i32, 33i32, 39i32, 33i32, 27i32, 33i32, -- 6i32, 9i32, 20i32, 17i32, 26i32, 28i32, 39i32, 39i32, 35i32, 32i32, -- 4i32, 14i32, 16i32, 27i32, 19i32, 26i32, 33i32, 39i32, 40i32, 38i32] } -- structure gpu-mem { Alloc 6 } import "intrinsics" let mkVal [bp1][b] (y:i32) (x:i32) (pen:i32) (block:[bp1][bp1]i32) (ref:[b][b]i32) : i32 = #[unsafe] i32.max (block[y, x - 1] - pen) (block[y - 1, x] - pen) |> i32.max (block[y - 1, x - 1] + ref[y - 1, x - 1]) let process_block [b][bp1] (penalty: i32) (above: [bp1]i32) (left: [b]i32) (ref: [b][b]i32): *[b][b]i32 = let block = assert (b + 1 == bp1) (tabulate_2d bp1 bp1 (\_ _ -> 0)) let block[0, 0:] = above let block[1:, 0] = left -- Process the first half (anti-diagonally) of the block let block = loop block for m < b do let inds = tabulate b (\tx -> if tx > m then (-1, -1) else let ind_x = i32.i64 (tx + 1) let ind_y = i32.i64 
(m - tx + 1) in (i64.i32 ind_y, i64.i32 ind_x)) let vals = -- tabulate over the m'th anti-diagonal before the middle tabulate b (\tx -> if tx > m then 0 else let ind_x = i32.i64 (tx + 1) let ind_y = i32.i64 (m - tx + 1) let v = mkVal ind_y ind_x penalty block ref in v) in scatter_2d block inds vals -- Process the second half (anti-diagonally) of the block let block = loop block for m < b-1 do let m = b - 2 - m let inds = tabulate b (\tx -> ( if tx > m then (-1, -1) else let ind_x = i32.i64 (tx + b - m) let ind_y = i32.i64 (b - tx) in ((i64.i32 ind_y, i64.i32 ind_x)) ) ) let vals = -- tabulate over the m'th anti-diagonal after the middle tabulate b (\tx -> ( if tx > m then (0) else let ind_x = i32.i64 (tx + b - m) let ind_y = i32.i64 (b - tx) let v = mkVal ind_y ind_x penalty block ref in v )) in scatter_2d block inds vals in block[1:, 1:] :> *[b][b]i32 entry nw_flat [n] (block_size: i64) (penalty: i32) (input: *[n]i32) (refs: [n]i32) : *[n]i32 = let row_length = i64.f64 <| f64.sqrt <| f64.i64 n let num_blocks = assert ((row_length - 1) % block_size == 0) ((row_length - 1) / block_size) let bp1 = assert (row_length > 3) (assert (2 * block_size < row_length) (block_size + 1)) let input = loop input for i < num_blocks do let ip1 = i + 1 let v = #[incremental_flattening(only_intra)] map3 (process_block penalty) (flat_index_2d input (i * block_size) ip1 (row_length * block_size - block_size) bp1 1) (flat_index_2d input (row_length + i * block_size) ip1 (row_length * block_size - block_size) block_size row_length) (flat_index_3d refs (row_length + 1 + i * block_size) ip1 (row_length * block_size - block_size) block_size row_length block_size 1i64) in flat_update_3d input (row_length + 1 + i * block_size) (row_length * block_size - block_size) (row_length) 1 v let input = loop input for i < num_blocks - 1 do let v = #[incremental_flattening(only_intra)] map3 (process_block penalty) (flat_index_2d input (((i + 1) * block_size + 1) * row_length - block_size - 1) (num_blocks 
- i - 1) (row_length * block_size - block_size) bp1 1i64) (flat_index_2d input (((i + 1) * block_size + 1) * row_length - block_size - 1 + row_length) (num_blocks - i - 1) (row_length * block_size - block_size) block_size row_length) (flat_index_3d refs (((i + 1) * block_size + 2) * row_length - block_size) (num_blocks - i - 1) (row_length * block_size - block_size) block_size row_length block_size 1i64) in flat_update_3d input (((i + 1) * block_size + 2) * row_length - block_size) (row_length * block_size - block_size) (row_length) 1 v in input futhark-0.25.27/tests/slice-lmads/nw.fut000066400000000000000000000154611475065116200200540ustar00rootroot00000000000000-- Code and comments based on -- https://github.com/kkushagra/rodinia/blob/master/openmp/nw -- -- == -- entry: nw_flat -- compiled random input { 16i64 10i32 [2362369]i32 [2362369]i32 } auto output -- compiled random input { 32i64 10i32 [2362369]i32 [2362369]i32 } auto output -- compiled random input { 64i64 10i32 [2362369]i32 [2362369]i32 } auto output -- compiled input { 3i64 -- 10i32 -- [4i32, 2i32, 4i32, 9i32, 2i32, 1i32, 7i32, 1i32, 9i32, 8i32, -- 7i32, 6i32, 7i32, 5i32, 0i32, 9i32, 0i32, 0i32, 2i32, 5i32, -- 6i32, 9i32, 3i32, 3i32, 7i32, 6i32, 6i32, 5i32, 4i32, 7i32, -- 1i32, 9i32, 5i32, 4i32, 4i32, 5i32, 9i32, 6i32, 7i32, 2i32, -- 2i32, 9i32, 6i32, 6i32, 8i32, 4i32, 4i32, 8i32, 0i32, 4i32, -- 5i32, 5i32, 5i32, 1i32, 3i32, 1i32, 1i32, 7i32, 2i32, 8i32, -- 5i32, 3i32, 9i32, 4i32, 2i32, 8i32, 1i32, 1i32, 0i32, 5i32, -- 8i32, 7i32, 0i32, 7i32, 5i32, 9i32, 1i32, 5i32, 5i32, 1i32, -- 6i32, 1i32, 8i32, 9i32, 3i32, 4i32, 6i32, 0i32, 2i32, 5i32, -- 4i32, 8i32, 7i32, 7i32, 2i32, 0i32, 5i32, 0i32, 1i32, 3i32] -- [4i32, 2i32, 4i32, 9i32, 2i32, 1i32, 7i32, 1i32, 9i32, 8i32, -- 7i32, 6i32, 7i32, 5i32, 0i32, 9i32, 0i32, 0i32, 2i32, 5i32, -- 6i32, 9i32, 3i32, 3i32, 7i32, 6i32, 6i32, 5i32, 4i32, 7i32, -- 1i32, 9i32, 5i32, 4i32, 4i32, 5i32, 9i32, 6i32, 7i32, 2i32, -- 2i32, 9i32, 6i32, 6i32, 8i32, 4i32, 4i32, 8i32, 
0i32, 4i32, -- 5i32, 5i32, 5i32, 1i32, 3i32, 1i32, 1i32, 7i32, 2i32, 8i32, -- 5i32, 3i32, 9i32, 4i32, 2i32, 8i32, 1i32, 1i32, 0i32, 5i32, -- 8i32, 7i32, 0i32, 7i32, 5i32, 9i32, 1i32, 5i32, 5i32, 1i32, -- 6i32, 1i32, 8i32, 9i32, 3i32, 4i32, 6i32, 0i32, 2i32, 5i32, -- 4i32, 8i32, 7i32, 7i32, 2i32, 0i32, 5i32, 0i32, 1i32, 3i32] } -- output { [ 4i32, 2i32, 4i32, 9i32, 2i32, 1i32, 7i32, 1i32, 9i32, 8i32, -- 7i32, 10i32, 9i32, 9i32, 9i32, 11i32, 1i32, 7i32, 3i32, 14i32, -- 6i32, 16i32, 13i32, 12i32, 16i32, 15i32, 17i32, 7i32, 11i32, 10i32, -- 1i32, 15i32, 21i32, 17i32, 16i32, 21i32, 24i32, 23i32, 14i32, 13i32, -- 2i32, 10i32, 21i32, 27i32, 25i32, 20i32, 25i32, 32i32, 23i32, 18i32, -- 5i32, 7i32, 15i32, 22i32, 30i32, 26i32, 21i32, 32i32, 34i32, 31i32, -- 5i32, 8i32, 16i32, 19i32, 24i32, 38i32, 28i32, 22i32, 32i32, 39i32, -- 8i32, 12i32, 8i32, 23i32, 24i32, 33i32, 39i32, 33i32, 27i32, 33i32, -- 6i32, 9i32, 20i32, 17i32, 26i32, 28i32, 39i32, 39i32, 35i32, 32i32, -- 4i32, 14i32, 16i32, 27i32, 19i32, 26i32, 33i32, 39i32, 40i32, 38i32] } -- structure gpu-mem { Alloc 6 } -- structure seq-mem { Alloc 8 } import "intrinsics" def mkVal [bp1][b] (y:i32) (x:i32) (pen:i32) (block:[bp1][bp1]i32) (ref:[b][b]i32) : i32 = #[unsafe] i32.max (block[y, x - 1] - pen) (block[y - 1, x] - pen) |> i32.max (block[y - 1, x - 1] + ref[y - 1, x - 1]) def process_block [b][bp1] (penalty: i32) (block: [bp1][bp1]i32) (ref: [b][b]i32): *[b][b]i32 = -- let bp1 = assert (bp1 = b + 1) bp1 -- Process the first half (anti-diagonally) of the block let block = loop block = copy block for m < b do let inds = tabulate b (\tx -> if tx > m then (-1, -1) else let ind_x = i32.i64 (tx + 1) let ind_y = i32.i64 (m - tx + 1) in (i64.i32 ind_y, i64.i32 ind_x)) let vals = -- tabulate over the m'th anti-diagonal before the middle tabulate b (\tx -> if tx > m then 0 else let ind_x = i32.i64 (tx + 1) let ind_y = i32.i64 (m - tx + 1) let v = mkVal ind_y ind_x penalty block ref in v) in scatter_2d block inds vals -- Process 
the second half (anti-diagonally) of the block let block = loop block for m < b-1 do let m = b - 2 - m let inds = tabulate b (\tx -> ( if tx > m then (-1, -1) else let ind_x = i32.i64 (tx + b - m) let ind_y = i32.i64 (b - tx) in ((i64.i32 ind_y, i64.i32 ind_x)) ) ) let vals = -- tabulate over the m'th anti-diagonal after the middle tabulate b (\tx -> ( if tx > m then (0) else let ind_x = i32.i64 (tx + b - m) let ind_y = i32.i64 (b - tx) let v = mkVal ind_y ind_x penalty block ref in v )) in scatter_2d block inds vals in block[1:, 1:] :> *[b][b]i32 entry nw_flat [n] (block_size: i64) (penalty: i32) (input: *[n]i32) (refs: [n]i32) : *[n]i32 = let row_length = i64.f64 <| f64.sqrt <| f64.i64 n let num_blocks = -- assert ((row_length - 1) % b == 0) <| (row_length - 1) / block_size let bp1 = block_size + 1 let input = loop input for i < num_blocks do let ip1 = i + 1 let v = #[incremental_flattening(only_intra)] map2 (process_block penalty) (flat_index_3d input (i * block_size) ip1 (row_length * block_size - block_size) bp1 row_length bp1 1i64) (flat_index_3d refs (row_length + 1 + i * block_size) ip1 (row_length * block_size - block_size) block_size row_length block_size 1i64) in flat_update_3d input (row_length + 1 + i * block_size) (row_length * block_size - block_size) (row_length) 1 v let input = loop input for i < num_blocks - 1 do let v = #[incremental_flattening(only_intra)] map2 (process_block penalty) (flat_index_3d input (((i + 1) * block_size + 1) * row_length - block_size - 1) (num_blocks - i - 1) (row_length * block_size - block_size) bp1 row_length bp1 1i64) (flat_index_3d refs (((i + 1) * block_size + 2) * row_length - block_size) (num_blocks - i - 1) (row_length * block_size - block_size) block_size row_length block_size 1i64) in flat_update_3d input (((i + 1) * block_size + 2) * row_length - block_size) (row_length * block_size - block_size) (row_length) 1 v in input 
futhark-0.25.27/tests/slice-lmads/small.fut000066400000000000000000000020111475065116200205230ustar00rootroot00000000000000-- == -- entry: index_antidiag -- input { [ 0i64, 1i64, 2i64, 3i64, -- 4i64, 5i64, 6i64, 7i64, -- 8i64, 9i64, 10i64, 11i64, -- 12i64, 13i64, 14i64, 15i64 ] } -- output { [[[2i64, 3i64], [6i64, 7i64]], -- [[5i64, 6i64], [9i64, 10i64]], -- [[8i64, 9i64], [12i64, 13i64]]] } -- input { [ 0i64 ] } -- error: out of bounds -- == -- entry: update_antidiag -- input { [ 0i64, 1i64, 2i64, 3i64, -- 4i64, 5i64, 6i64, 7i64, -- 8i64, 9i64, 10i64, 11i64, -- 12i64, 13i64, 14i64, 15i64 ] } -- output { [ 0i64, 0i64, 1i64, 3i64, -- 4i64, 2i64, 3i64, 7i64, -- 8i64, 9i64, 4i64, 5i64, -- 12i64, 13i64, 6i64, 7i64 ] } -- input { [ 0i64 ] } -- error: out of bounds import "intrinsics" entry index_antidiag [n] (xs: [n]i64): [][][]i64 = flat_index_3d xs 2 3 3 2 4 2 1 entry update_antidiag [n] (xs: *[n]i64): *[n]i64 = let vs = iota (2*2*2) |> unflatten |> unflatten in flat_update_3d xs 1 9 4 1 vs futhark-0.25.27/tests/slice-lmads/small_2d.fut000066400000000000000000000015211475065116200211150ustar00rootroot00000000000000-- == -- entry: index_antidiag -- input { [ 0i64, 1i64, 2i64, 3i64, -- 4i64, 5i64, 6i64, 7i64, -- 8i64, 9i64, 10i64, 11i64, -- 12i64, 13i64, 14i64, 15i64] } -- output { [[ 5i64, 6i64], -- [ 14i64, 15i64]] } -- == -- entry: update_antidiag -- input { [ 0i64, 1i64, 2i64, 3i64, -- 4i64, 5i64, 6i64, 7i64, -- 8i64, 9i64, 10i64, 11i64, -- 12i64, 13i64, 14i64, 15i64] } -- output { [ 0i64, 1i64, 2i64, 0i64, -- 4i64, 5i64, 6i64, 1i64, -- 8i64, 2i64, 10i64, 11i64, -- 12i64, 3i64, 14i64, 15i64] } import "intrinsics" entry index_antidiag [n] (xs: [n]i64): [][]i64 = flat_index_2d xs 5 2 9 2 1 entry update_antidiag [n] (xs: *[n]i64): *[n]i64 = let vs = iota (2 * 2) |> unflatten in flat_update_2d xs 3 6 4 vs futhark-0.25.27/tests/slice-lmads/small_4d.fut000066400000000000000000000021651475065116200211240ustar00rootroot00000000000000-- == -- entry: index_antidiag -- input { 
[ 0i64, 1i64, 2i64, 3i64, -- 4i64, 5i64, 6i64, 7i64, -- 8i64, 9i64, 10i64, 11i64, -- 12i64, 13i64, 14i64, 15i64, -- 16i64, 17i64, 18i64, 19i64] } -- output { [[[[2i64, 3i64], [6i64, 7i64]], -- [[4i64, 5i64], [8i64, 9i64]]], -- [[[10i64, 11i64], [14i64, 15i64]], -- [[12i64, 13i64], [16i64, 17i64]]]] } -- == -- entry: update_antidiag -- input { [ 0i64, 1i64, 2i64, 3i64, -- 4i64, 5i64, 6i64, 7i64, -- 8i64, 9i64, 10i64, 11i64, -- 12i64, 13i64, 14i64, 15i64, -- 16i64, 17i64, 18i64, 19i64] } -- output { [ 0i64, 1i64, 0i64, 1i64, -- 4i64, 5i64, 2i64, 3i64, -- 6i64, 7i64, 8i64, 9i64, -- 12i64, 13i64, 10i64, 11i64, -- 14i64, 15i64, 18i64, 19i64] } import "intrinsics" entry index_antidiag [n] (xs: [n]i64): [][][][]i64 = flat_index_4d xs 2 2 8 2 2 2 4 2 1 entry update_antidiag [n] (xs: *[n]i64): *[n]i64 = let vs = iota (2*2*2*2) |> unflatten |> unflatten |> unflatten in flat_update_4d xs 2 8 2 4 1 vs futhark-0.25.27/tests/slice0.fut000066400000000000000000000007111475065116200164020ustar00rootroot00000000000000-- Test of basic slicing. -- -- == -- input { [1,2,3,4,5] 1 3 } -- output { [2,3] } -- input { [1,2,3,4,5] 0 5 } -- output { [1,2,3,4,5] } -- input { [1,2,3,4,5] 1 1 } -- output { empty([0]i32) } -- input { [1,2,3,4,5] 1 0 } -- error: Index \[1:0\] out of bounds for array of shape \[5\] -- input { empty([0]i32) 0 1 } -- error: Index \[0:1\] out of bounds for array of shape \[0\] def main (as: []i32) (i: i32) (j: i32): []i32 = as[i64.i32 i:i64.i32 j] futhark-0.25.27/tests/slice1.fut000066400000000000000000000007271475065116200164120ustar00rootroot00000000000000-- Slicing a multidimensional array across the outer dimension. -- -- == -- input { [[1,2,3],[4,5,6]] 1 3 } -- output { [[2,3],[5,6]] } -- input { [[1,2,3],[4,5,6]] 0 3 } -- output { [[1,2,3],[4,5,6]] } -- input { [[1,2,3],[4,5,6]] 1 1 } -- output { empty([2][0]i32) } -- input { [[1,2,3],[4,5,6]] 1 0 } -- error: Index \[0:2, 1:0\] out of bounds for array of shape \[2\]\[3\]. 
def main [n][m] (as: [n][m]i32) (i: i32) (j: i32): [n][]i32 = as[0:n,i64.i32 i:i64.i32 j] futhark-0.25.27/tests/slice2.fut000066400000000000000000000036641475065116200164160ustar00rootroot00000000000000-- An index inside of a map should be turned into a slice, rather than -- a kernel by itself. -- == -- -- input { [[[[0.74242944f32, 0.1323092f32], [1.3599575e-2f32, -- 0.42590684f32]], [[0.28189754f32, 0.71788645f32], [0.120514154f32, -- 0.3523355f32]], [[0.97101444f32, 0.8475803f32], [0.88611674f32, -- 0.9148224f32]], [[0.94415265f32, 0.14399022f32], [0.5325674f32, -- 0.659268f32]]], [[[0.7296194f32, 0.6609876f32], [6.526101e-2f32, -- 6.5751016e-2f32]], [[ 0.95010173f32, 0.14800721f32], -- [0.94630295f32, 0.53180677f32]], [[0.50352955f32, 0.8683887f32], -- [0.52372944f32, 0.56981534f32]], [[0.89906573f32, 0.28717548f32], -- [0.33396137f32, 0.1774621f32]]], [[[0.38886482f32, 0.9896543f32], -- [0.46158296f32, 0.3661f32]], [[0.3473122f32, 0.3432145f32], -- [0.8394218f32, 0.99296236f32]], [[0.121897876f32, 9.7216845e-2f32], -- [0.9392534f32, 0.21994972f32]], [[0.48229688f32, 0.655326f32], -- [0.7612596f32, 0.87178886f32]]], [[[0.6982439f32, 0.3648432f32], -- [0.2956829f32, 0.64948434f32]], [[0.9514074f32, 0.5657658f32], -- [0.96731836f32, 0.2870463f32]], [[0.24546045f32, 0.5121502f32], -- [2.8573096e-2f32, 0.8905163f32]], [[0.11413413f32, 0.758343f32], [ -- 6.598133e-2f32, 0.34899563f32]]]] } -- output { -- [[[-0.12878528f32, -0.4338454f32], -- [-0.4932002f32, -0.28704658f32]], -- [[-0.13519031f32, -0.16950619f32], -- [-0.4673695f32, -0.4671245f32]], -- [[-0.3055676f32, -5.1728487e-3f32], -- [-0.26920852f32, -0.31695f32]], -- [[-0.15087804f32, -0.3175784f32], -- [-0.35215855f32, -0.17525783f32]]] -- } -- structure gpu { SegMap 1 } def main [m][b] (mat: [m][m][b][b]f32): [m][b][b]f32 = let mat_rows = map (\(mat_row: [m][b][b]f32): [b][b]f32 -> mat_row[0]) mat in map (\(blk: [b][b]f32): [b][b]f32 -> map (\(row0: [b]f32): [b]f32 -> loop row=copy row0 for j < b do let 
row[j] = (row[j] - 1.0f32) / 2.0f32 in row) blk) ( mat_rows) futhark-0.25.27/tests/slice3.fut000066400000000000000000000002371475065116200164100ustar00rootroot00000000000000-- Slicing produces a size that we can obtain. -- == -- input { [1,2,3] 0i64 1i64 } output { 1i64 } def main (xs: []i32) (i: i64) (j: i64) = length xs[i:j] futhark-0.25.27/tests/slice4.fut000066400000000000000000000002161475065116200164060ustar00rootroot00000000000000-- Zero strides are detected. -- == -- input { [1,2,3,4,5] 0i64 1i64 0i64 } -- error: out of bounds def main (xs: []i32) a b c = xs[a:b:c] futhark-0.25.27/tests/slice5.fut000066400000000000000000000003001475065116200164010ustar00rootroot00000000000000-- This might look like a transposition to the naive, but it is not! -- == -- input { [[1, 2], [3, 4]] } output { [[1, 2], [3, 4]] } entry main [n][m] (A: [n][m]i32) : [m][n]i32 = A[0:m,0:n] futhark-0.25.27/tests/slice6.fut000066400000000000000000000004141475065116200164100ustar00rootroot00000000000000-- Really careful when copying these slices! -- == -- input { [1,2,3,4,5,6,7,8,9,10,11,12,13] [13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1] } -- output { [1, 12, 3, 10, 5, 8, 7, 6, 9, 4, 11, 12, 13] } def main (xs: *[]i32) (ys: []i32) = xs with [1:10:2] = ys[1:10:2] futhark-0.25.27/tests/soacs/000077500000000000000000000000001475065116200156145ustar00rootroot00000000000000futhark-0.25.27/tests/soacs/filter-error0.fut000066400000000000000000000002651475065116200210330ustar00rootroot00000000000000-- Test that the filter function must take nonunique arguments. 
-- == -- error: def main(a: *[][]i32): [][]i32 = let _ = filter (\(r: *[]i32): bool -> true) a in empty([]i32) futhark-0.25.27/tests/soacs/filter1.fut000066400000000000000000000002741475065116200177050ustar00rootroot00000000000000-- == -- input { -- [1,0,2,-5,3,-1] -- } -- output { -- [1, 2, 3] -- } -- input { -- empty([0]i32) -- } -- output { -- empty([0]i32) -- } def main(a: []i32): []i32 = filter (0<) a futhark-0.25.27/tests/soacs/filter2.fut000066400000000000000000000003701475065116200177030ustar00rootroot00000000000000-- == def main(a0: []f64, a1: []i32, oks: []bool): []f64 = let (b, _) = unzip(filter (\(x: (f64,i32)): bool -> let (_,i) = x in oks[i]) ( zip a0 a1)) in b futhark-0.25.27/tests/soacs/filter3.fut000066400000000000000000000006441475065116200177100ustar00rootroot00000000000000-- == -- input { -- [0, 1, -2, 5, 42] -- [false, true, true, false, true] -- } -- output { -- [true, true, true] -- [1, -2, 42] -- } def main (xs1: []i32) (xs2: []bool): ([]bool,[]i32) = let tmp = filter (\(x: (i32,bool)): bool -> let (i,b) = x in b ) (zip xs1 xs2) in unzip(map (\(x: (i32,bool)): (bool,i32) -> let (i,b) = x in (b,i) ) tmp) futhark-0.25.27/tests/soacs/filter4.fut000066400000000000000000000002061475065116200177030ustar00rootroot00000000000000-- The array produced by filter should be unique. def main (xs: *[]i32) = let xs' = filter (\x -> x>0) xs let xs[0] = 0 in xs' futhark-0.25.27/tests/soacs/map1-error.fut000066400000000000000000000002751475065116200203250ustar00rootroot00000000000000-- == -- error: def main(a: *[][]i32): *[]i32 = -- Should be an error, because all of 'a' is consumed at the point -- the map is invoked. 
map (\(r: *[]i32): i32 -> a[0,0]) a futhark-0.25.27/tests/soacs/map1.fut000066400000000000000000000004321475065116200171710ustar00rootroot00000000000000-- == -- input { -- [1,2,3,4,5,6,7,8,9] -- [1,2,3,4,5,6,7,8,9] -- } -- output { -- [1, 2, 3, 4, 5, 6, 7, 8, 9] -- [1, 2, 3, 4, 5, 6, 7, 8, 9] -- } def main (a: []i32) (b: []i32): ([]i32,[]i32) = let arr = zip a b in unzip(map (\(x: (i32,i32)): (i32,i32) -> x) arr) futhark-0.25.27/tests/soacs/map10.fut000066400000000000000000000002261475065116200172520ustar00rootroot00000000000000-- Test that a simple consuming map produces an error. -- == -- error: def main(a: *[][]f64): [][]f64 = map (\(r: *[]f64): *[]f64 -> r) a futhark-0.25.27/tests/soacs/map14.fut000066400000000000000000000003521475065116200172560ustar00rootroot00000000000000-- This program broke the simplifier at one point. def main(x: i32, y: i32, a: []i32, b: []i32): []i32 = let c = map (\(av: i32): (i32,i32) -> let v = x + y in (v, 2*av)) a in map (\(x,y)->x+y) c futhark-0.25.27/tests/soacs/map15.fut000066400000000000000000000003741475065116200172630ustar00rootroot00000000000000-- Test that a map not using its parameters can be turned into a -- replicate. -- -- == -- input { 2 [1,2,3] } -- output { [4, 4, 4] } -- structure { Map 0 Replicate 1 } def main (x: i32) (a: []i32): []i32 = map (\(y: i32): i32 -> x + 2) a futhark-0.25.27/tests/soacs/map16.fut000066400000000000000000000003641475065116200172630ustar00rootroot00000000000000-- Map returning an array predicated on the index variable. -- -- == -- input { 2i64 } -- output { [[0], [1]] } def main(chunk: i64): [][]i32 = map (\(k: i32): [1]i32 -> if k==0 then [0] else [1] ) (map i32.i64 (iota(chunk))) futhark-0.25.27/tests/soacs/map18.fut000066400000000000000000000002511475065116200172600ustar00rootroot00000000000000-- Single-iteration maps should be simplified away. 
-- -- == -- input { 2 } output { [4] } -- structure { Map 0 } def main(x: i32): [1]i32 = map (+x) (replicate 1 x) futhark-0.25.27/tests/soacs/map19.fut000066400000000000000000000004001475065116200172550ustar00rootroot00000000000000-- The interesting thing here is that the compiler should simplify -- away the copy. -- == -- input { [[1,2,3],[4,5,6]] } -- output { [[0, 2, 3], [0, 5, 6]] } -- structure { Replicate 0 } def main (xss: *[][]i32) = map (\xs -> copy xs with [0] = 0) xss futhark-0.25.27/tests/soacs/map2.fut000066400000000000000000000002041475065116200171670ustar00rootroot00000000000000-- == -- input { -- [1,2,3,4,5,6,7,8] -- } -- output { -- [3, 4, 5, 6, 7, 8, 9, 10] -- } def main(a: []i32): []i32 = map (+2) a futhark-0.25.27/tests/soacs/map20.fut000066400000000000000000000004541475065116200172560ustar00rootroot00000000000000-- The interesting thing here is that the compiler should simplify -- away the copy. -- == -- input { [[1,2,3],[4,5,6]] } -- output { [[0.0f32, 2.0f32, 3.0f32], [0.0f32, 5.0f32, 6.0f32]] } -- structure { Replicate 0 } def main (xss: *[][]i32) = map (\xs -> map f32.i32 (copy xs with [0] = 0)) xss futhark-0.25.27/tests/soacs/map3.fut000066400000000000000000000002701475065116200171730ustar00rootroot00000000000000-- Test a lambda with a free variable. -- == -- input { -- [1,2,3] -- 1 -- } -- output { -- [2, 3, 4] -- } def main (a: []i32) (y: i32): []i32 = map (\(x: i32): i32 -> (x+y)) a futhark-0.25.27/tests/soacs/map4.fut000066400000000000000000000007371475065116200172040ustar00rootroot00000000000000-- Test a tricky case involving rewriting lambda arguments in the -- tuple transformer. 
-- == -- input { -- [[1,5],[8,9],[2,4]] -- [[5,1],[9,2],[4,8]] -- } -- output { -- [6, 17, 6] -- } def inner(a: [][](i32,i32)): []i32 = map (\(r: [](i32,i32)): i32 -> let (x,y) = r[0] in x+y) a def main (a1: [][]i32) (a2: [][]i32): []i32 = let a = map (\(p: ([]i32,[]i32)) -> let (p1,p2) = p in zip p1 p2) ( zip a1 a2) in inner(a) futhark-0.25.27/tests/soacs/map5.fut000066400000000000000000000007571475065116200172070ustar00rootroot00000000000000-- Test a tricky case involving rewriting lambda arguments in the -- tuple transformer. -- == -- input { -- [[1,5],[8,9],[2,4]] -- [[5,1],[9,2],[4,8]] -- } -- output { -- [6, 17, 6] -- } def inner(a: [][](i32,i32)): []i32 = map (\(r1: [](i32,i32)): i32 -> let r2 = r1 let (x,y) = r2[0] in x+y) a def main(a1: [][]i32) (a2: [][]i32): []i32 = inner(map (\(r: ([]i32,[]i32)) -> let (r1,r2) = r in zip r1 r2) ( zip a1 a2)) futhark-0.25.27/tests/soacs/map6.fut000066400000000000000000000010561475065116200172010ustar00rootroot00000000000000-- == -- input { -- [[1,2,3],[4,5,6]] -- [[6,5,4],[3,2,1]] -- } -- output { -- [[7, 7, 7], [7, 7, 7]] -- [[-5, -3, -1], [1, 3, 5]] -- } def inner(a: [][](i32,i32)): [][](i32,i32) = map (\(row: [](i32,i32)) -> map (\(x: i32, y: i32) -> (x+y,x-y)) row) a def main(a1: [][]i32) (a2: [][]i32): ([][]i32, [][]i32) = let a = map (\(p: ([]i32,[]i32)) -> let (p1,p2) = p in zip p1 p2) ( zip a1 a2) in unzip(map (\(r: [](i32,i32)) -> unzip(r)) ( inner(a))) futhark-0.25.27/tests/soacs/map7.fut000066400000000000000000000004571475065116200172060ustar00rootroot00000000000000-- == -- input { -- [[1,2,3], [4,5,6]] -- [[2,1,3], [4,6,5]] -- } -- output { -- [[3, 3, 6], [8, 11, 11]] -- } def main (a1: [][]i32) (a2: [][]i32): [][]i32 = let b = map (\(row: ([]i32,[]i32)) -> let (x,y) = row in map2 (+) x y) (zip a1 a2) in b futhark-0.25.27/tests/soacs/map9.fut000066400000000000000000000011031475065116200171750ustar00rootroot00000000000000-- This test fails if the ISWIM transformation messes up the size -- annotations. 
-- == -- input { -- [1,1,1] -- [[1,2,3],[4,5,6],[7,8,9],[0,1,2],[3,4,5]] -- } -- output { -- [[3i32, 6i32, 11i32], [54i32, 162i32, 418i32], [2754i32, 10692i32, 34694i32], [5508i32, 32076i32, 208164i32], [60588i32, 577368i32, 5620428i32]] -- } def combineVs [n] (n_row: [n]i32): [n]i32 = map2 (*) n_row n_row def main [n][m] (md_starts: [m]i32) (md_vols: [n][m]i32): [][]i32 = let e_rows = map (\x -> map (+2) x) (map combineVs md_vols) in scan (\x y -> map2 (*) x y) md_starts e_rows futhark-0.25.27/tests/soacs/mapreduce.fut000066400000000000000000000005641475065116200203060ustar00rootroot00000000000000-- Mapping with a reduction. -- -- == -- tags { no_python } -- compiled input { 10i64 10i64 } -- output { [45i64, 145i64, 245i64, 345i64, 445i64, 545i64, 645i64, 745i64, 845i64, 945i64] } -- compiled input { 5i64 50i64 } auto output -- structure gpu { SegRed 1 } def main (n: i64) (m: i64): [n]i64 = let a = unflatten (iota (n*m)) in map (\a_r -> reduce (+) 0 a_r) a futhark-0.25.27/tests/soacs/mapscan.fut000066400000000000000000000006661475065116200177660ustar00rootroot00000000000000-- == -- tags { no_python no_wasm } -- input { 100i64 1000i64 } output { 870104 } -- compiled input { 400i64 1000i64} output { 985824 } -- compiled input { 100000i64 100i64} output { 15799424 } -- def main (n: i64) (m: i64): i32 = let a = map (\i -> map i32.i64 (map (+i) (iota(m)))) (iota(n)) let b = map (\(a_r: [m]i32): [m]i32 -> scan (+) 0 (a_r)) a in reduce (^) 0 (flatten b) futhark-0.25.27/tests/soacs/partition1.fut000066400000000000000000000005341475065116200204300ustar00rootroot00000000000000-- Simple test of the partition SOAC. 
-- == -- input { -- [0,1,2,3,4,5,6,7,8,9] -- } -- output { -- [0, 2, 4, 6, 8] -- [3, 9] -- [1, 5, 7] -- } def divisible_by_two(x: i32): bool = x % 2 == 0 def divisible_by_three(x: i32): bool = x % 3 == 0 def main(a: []i32): ([]i32, []i32, []i32) = partition2 divisible_by_two divisible_by_three a futhark-0.25.27/tests/soacs/partition2.fut000066400000000000000000000007701475065116200204330ustar00rootroot00000000000000-- Slightly more complicated test involving arrays of tuples. -- == -- input { -- [2,7,1,2,8,9,1,5,0,2] -- [6,1,0,9,8,7,9,4,2,1] -- } -- output { -- [2, 2, 1, 0] -- [6, 9, 9, 2] -- [8] -- [8] -- [7, 1, 9, 5, 2] -- [1, 0, 7, 4, 1] -- } def main (xs: []i32) (ys: []i32): ([]i32, []i32, []i32, []i32, []i32, []i32) = let (x,y,z) = partition2 (\(x,y)->xx==y) (zip xs ys) let (x1,x2) = unzip(x) let (y1,y2) = unzip(y) let (z1,z2) = unzip(z) in (x1, x2, y1, y2, z1, z2) futhark-0.25.27/tests/soacs/redomap0.fut000066400000000000000000000010051475065116200200370ustar00rootroot00000000000000-- This program does not contain a redomap on its own, but fusion will -- give rise to one. def grayCode(x: i32): i32 = (x >> 1) ^ x def testBit(n: i32, ind: i32): bool = let t = (1 << ind) in (n & t) == t def main [num_bits] (n: i64, dir_vs: [num_bits]i32): i32 = let reldv_vals = map (\(dv,i): i32 -> if testBit(grayCode(i32.i64 n),i) then dv else 0 ) (zip (dir_vs) (map i32.i64 (iota(num_bits))) ) in reduce (^) 0 (reldv_vals ) futhark-0.25.27/tests/soacs/redomap1.fut000066400000000000000000000012131475065116200200410ustar00rootroot00000000000000-- Test a redomap with map-out where each element is also an array. 
-- -- == -- input { 5i64 2i64 } -- output { [[0i32, 1i32], -- [2i32, 3i32], -- [4i32, 5i32], -- [6i32, 7i32], -- [8i32, 9i32]] -- false -- } -- input { 0i64 1i64 } -- output { empty([0][1]i32) true } def main (n: i64) (m: i64): ([][]i32, bool) = let ass = map (\l: [m]i32 -> map i32.i64 (map (+l*m) (iota(m)))) (iota(n)) let ps = map2 (\(as: []i32) (i: i32): bool -> as[i] % 2 == 0) ass (map i32.i64 (map (%m) (iota(n)))) in (ass, reduce (&&) true ps) futhark-0.25.27/tests/soacs/reduce0.fut000066400000000000000000000005001475065116200176560ustar00rootroot00000000000000-- How quickly can we reduce arrays? -- -- == -- tags { no_python } -- input { 0 } -- output { 0 } -- input { 100 } -- output { 4950 } -- compiled input { 100000 } -- output { 704982704 } -- compiled input { 100000000 } -- output { 887459712 } -- structure gpu { Iota 0 } def main(n: i32): i32 = reduce (+) 0 (0.. (accx && x, y)) (false,0) ( zip (replicate n true) (replicate n 1)) in (a,b) futhark-0.25.27/tests/soacs/reduce4.fut000066400000000000000000000003031475065116200176630ustar00rootroot00000000000000-- Reduction where the accumulator is an array. -- == -- input { [[1,2],[3,4],[5,6]] } -- output { [9, 12] } def main [n][m] (as: [n][m]i32): []i32 = reduce_comm (map2 (+)) (replicate m 0) as futhark-0.25.27/tests/soacs/reduce5.fut000066400000000000000000000006341475065116200176730ustar00rootroot00000000000000-- Reduction with an array accumulator, where the compiler (probably) -- cannot do a interchange. -- -- == -- input { [[1,2,3], [4,5,6], [6,7,8]] } -- output { [11i32, 14i32, 17i32] } def main [n][m] (xss: [n][m]i32): []i32 = reduce_comm(\(xs: []i32) ys -> loop zs = replicate m 0 for i < m do let zs[i] = xs[i] + ys[i] in zs) (replicate m 0) xss futhark-0.25.27/tests/soacs/reduce6.fut000066400000000000000000000004041475065116200176670ustar00rootroot00000000000000-- A reduction whose operator could cause existential memory. 
-- == -- random input { [40][4]i32 } auto output def main [n] (xs: [n][4]i32) = let op (x: [4]i32) (y: [4]i32) : [4]i32 = if x[0] < y[0] then x else y in reduce op [i32.lowest, 0, 0, 0] xs futhark-0.25.27/tests/soacs/reduce7.fut000066400000000000000000000003421475065116200176710ustar00rootroot00000000000000-- A funky reduce with vectorised operator (and not interchangeable). -- == -- compiled random input { [100][100]i32 } auto output def main [n][m] (xss: [n][m]i32) = reduce (map2 (+)) (replicate m 0) (map (scan (+) 0) xss) futhark-0.25.27/tests/soacs/reduce8.fut000066400000000000000000000002551475065116200176750ustar00rootroot00000000000000-- #1800 -- == -- input { 100i64 } auto output -- compiled input { 100000i64 } auto output def main n = i64.sum (map (\i -> loop i = i+1 while i < 1000 do i * 3) (iota n)) futhark-0.25.27/tests/soacs/reduce9.fut000066400000000000000000000005441475065116200176770ustar00rootroot00000000000000-- == -- compiled random input { [512][32]i32 } auto output -- compiled random input { [1024][1024]i32 } auto output def main [n][m] (xss: [n][m]i32): []i32 = reduce (\(xs: []i32) ys -> loop zs = replicate m 0 for i < m do let zs[i] = xs[i] + ys[i] in zs ) (replicate m 0) xss futhark-0.25.27/tests/soacs/scan-with-map.fut000066400000000000000000000011251475065116200210030ustar00rootroot00000000000000-- This one is tricky to get to run without too many memory copies. -- When it was added, in-place-lowering couldn't get it to work right. -- -- Now it also functions as a test for whether scans with higher-order -- operators work. Note that it is possible the scan is interchanged -- with the inner map during fusion or kernel extraction. 
-- -- == -- tags { no_python } -- compiled input { [0,0,0] [1,2,3] 100001i64 } output { 233120i32 } def main [n] (a: [n]i32) (b: [n]i32) (m: i64): i32 = let contribs = replicate m b let res = scan (map2 (+)) a contribs in reduce (^) 0 (flatten res) futhark-0.25.27/tests/soacs/scan0.fut000066400000000000000000000005241475065116200173410ustar00rootroot00000000000000-- Big prefix sum on iota. -- -- For result simplicity, we only return the last element. -- -- == -- tags { no_python } -- input { 100i64 } output { 4950 } -- compiled input { 1000000i64 } output { 1783293664i32 } -- structure gpu { SegScan 1 Iota 0 } def main(n: i64): i32 = let a = scan (+) 0 (map i32.i64 (iota(n))) in a[n-1] futhark-0.25.27/tests/soacs/scan1.fut000066400000000000000000000003061475065116200173400ustar00rootroot00000000000000-- == -- input { -- [1,2,3,4,5,6,7,8,9] -- } -- output { -- [1, 3, 6, 10, 15, 21, 28, 36, 45] -- } -- compiled random input { [1000000]i32 } auto output def main(a: []i32): []i32 = scan (+) 0 a futhark-0.25.27/tests/soacs/scan2.fut000066400000000000000000000011121475065116200173350ustar00rootroot00000000000000-- A segmented scan of a two-dimensional array. -- -- == -- random input { [20]bool [20][2]i32 } auto output -- compiled random input { [2000]bool [2000][10]i32 } auto output def segmented_scan [n] 't (op: t -> t -> t) (ne: t) (flags: [n]bool) (as: [n]t): [n]t = (unzip (scan (\(x_flag,x) (y_flag,y) -> (x_flag || y_flag, if y_flag then y else x `op` y)) (false, ne) (zip flags as))).1 def main [n][m] (flags: [n]bool) (xss: [n][m]i32): [n][m]i32 = segmented_scan (map2 (+)) (replicate m 0) flags xss futhark-0.25.27/tests/soacs/scan3.fut000066400000000000000000000017421475065116200173470ustar00rootroot00000000000000-- Test scanomap fusion and scan kernel generation with map-outs. 
-- -- == -- input { -- 10 -- [0i32, 1i32, 2i32, 3i32, 4i32, 5i32, 6i32, 7i32, 8i32, 9i32, 10i32, 11i32, -- 12i32, 13i32, 14i32, 15i32, 16i32, 17i32, 18i32, 19i32, 20i32, 21i32, 22i32, -- 23i32, 24i32, 25i32, 26i32, 27i32, 28i32, 29i32, 30i32, 31i32, 32i32, 33i32, -- 34i32, 35i32, 36i32, 37i32, 38i32, 39i32, 40i32, 41i32, 42i32, 43i32, 44i32, -- 45i32, 46i32, 47i32, 48i32, 49i32, 50i32, 51i32, 52i32, 53i32, 54i32, 55i32, -- 56i32, 57i32, 58i32, 59i32, 60i32, 61i32, 62i32, 63i32, 64i32, 65i32, 66i32, -- 67i32, 68i32, 69i32, 70i32, 71i32, 72i32, 73i32, 74i32, 75i32, 76i32, 77i32, -- 78i32, 79i32, 80i32, 81i32, 82i32, 83i32, 84i32, 85i32, 86i32, 87i32, 88i32, -- 89i32, 90i32, 91i32, 92i32, 93i32, 94i32, 95i32, 96i32, 97i32, 98i32, 99i32] -- } -- output { 12 55 } -- structure { Screma 1 } def main(i: i32) (a: []i32): (i32, i32) = let b = map (+2) a let c = scan (+) 0 a in (b[i], c[i]) futhark-0.25.27/tests/soacs/scan4.fut000066400000000000000000000007251475065116200173500ustar00rootroot00000000000000-- More complicated scanomap example. Distilled from radix sort. -- == -- -- input { -- [83, 1, 4, 99, 33, 0, 6, 5] -- } -- output { -- [4339, 4586, 4929, 5654, 6120, 6554, 7046, 7535] -- } def step [n] (xs: [n]i32): [n]i32 = let bits = map (+1) xs let ps1 = scan (+) 0 bits let bits_sum = reduce (+) 0 bits let ps1' = map (+bits_sum) ps1 let xs' = map2 (+) (ps1') xs in xs' def main [n] (xs: [n]i32): [n]i32 = loop (xs) for i < 2 do step(xs) futhark-0.25.27/tests/soacs/scan5.fut000066400000000000000000000005101475065116200173410ustar00rootroot00000000000000-- A funky scan with vectorised operator (and not interchangeable). 
-- == -- compiled random input { [1][100]i32 } auto output -- compiled random input { [100][1]i32 } auto output -- compiled random input { [100][100]i32 } auto output def main [n][m] (xss: [n][m]i32) = scan (map2 (+)) (replicate m 0) (map (scan (+) 0) xss) futhark-0.25.27/tests/soacs/scan6.fut000066400000000000000000000010221475065116200173410ustar00rootroot00000000000000-- Even more funky scan with vectorised operator on arrays (and not interchangeable). -- == -- compiled random input { [1][10][100]i32 } auto output -- compiled random input { [100][10][1]i32 } auto output -- compiled random input { [100][10][100]i32 } auto output def vecadd [n] (xs: [n]i32) (ys: [n]i32) : [n]i32 = loop res = replicate n 0 for i < n do res with [i] = xs[i] + ys[i] def main [n][m][k] (xsss: [n][m][k]i32) = scan (map2 vecadd) (replicate m (replicate k 0)) (map (scan (map2 (+)) (replicate k 0)) xsss) futhark-0.25.27/tests/soacs/scan7.fut000066400000000000000000000007351475065116200173540ustar00rootroot00000000000000-- Segmented scan with array operator (and not interchangeable). -- == -- random input { [10][1][10]i32 } auto output -- random input { [10][10][1]i32 } auto output -- random input { [10][10][10]i32 } auto output -- structure gpu { /SegScan 1 /SegScan/SegBinOp/Loop 1 } def vecadd [m] (xs: [m]i32) (ys: [m]i32): [m]i32 = loop xs = copy xs for i < m do let xs[i] = xs[i] + ys[i] in xs def main [n][m][k] (xss: [n][m][k]i32) = map (scan vecadd (replicate k 0)) xss futhark-0.25.27/tests/soacs/scan8.fut000066400000000000000000000005231475065116200173500ustar00rootroot00000000000000-- Segmented scan with array operator (interchangeable). 
-- == -- random input { [10][1][10]i32 } auto output -- random input { [10][10][1]i32 } auto output -- random input { [10][10][10]i32 } auto output -- structure gpu { /SegScan 1 /SegScan/Loop 0 } def main [n][m][k] (xss: [n][m][k]i32) = map (scan (map2 (+)) (replicate k 0)) xss futhark-0.25.27/tests/soacs/segreduce-iota.fut000066400000000000000000000006471475065116200212430ustar00rootroot00000000000000-- == -- random input { 2i64 10i64 } output { [0,10] } -- random input { 2i64 1000i64 } output { [0,1000] } -- random input { 0i64 2i64 } output { empty([0]i32) } -- random input { 0i64 1000i64 } output { empty([0]i32) } -- random input { 1000i64 2i64 } auto output -- random input { 1000i64 0i64 } auto output def array n m = map (\i -> replicate m (i32.i64 i)) (iota n) entry main n m: []i32 = array n m |> map i32.sum futhark-0.25.27/tests/stacktrace.fut000066400000000000000000000005341475065116200173520ustar00rootroot00000000000000-- We can produce useful stack traces on errors, at least for -- non-recursive functions. -- == -- input { [1] 1 } -- error: stacktrace.fut:7.*stacktrace.fut:9.*stacktrace.fut:11.*stacktrace.fut:13 def f (xs: []i32) (i: i32) = xs[i] def g (xs: []i32) (i: i32) = f xs i def h (xs: []i32) (i: i32) = g xs i def main (xs: []i32) (i: i32) = h xs i futhark-0.25.27/tests/stencil-1.fut000066400000000000000000000014431475065116200170250ustar00rootroot00000000000000-- Simple rank-1 one-dimensional stencil computation. Eventually -- smooths out all differences. 
-- -- == -- input { 1i64 [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] } -- output { [1.3333333333333333, 2.0, 3.0, 3.9999999999999996, 5.0, 5.666666666666666] } -- input { 2i64 [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] } -- output { -- [1.5555555555555554, -- 2.111111111111111, -- 2.9999999999999996, -- 3.9999999999999996, -- 4.888888888888888, -- 5.444444444444444] } def main [n] (num_iterations: i64) (a: [n]f64): []f64 = loop (a) for i < num_iterations do map (\(i: i64): f64 -> let x = if i == 0 then a[i] else a[i-1] let y = a[i] let z = if i == n-1 then a[i] else a[i+1] let factor = 1.0/3.0 in factor*x + factor*y + factor*z ) (iota(n)) futhark-0.25.27/tests/stencil-2.fut000066400000000000000000000025341475065116200170300ustar00rootroot00000000000000-- Simple rank-1 two-dimensional stencil computation. Eventually -- smooths out all differences. -- == -- input { -- 0 -- [[1.0,2.0,3.0], -- [4.0,5.0,6.0], -- [7.0,8.0,9.0]] -- } -- output { -- [[1.0,2.0,3.0], -- [4.0,5.0,6.0], -- [7.0,8.0,9.0]] -- } -- input { -- 1 -- [[1.0,2.0,3.0], -- [4.0,5.0,6.0], -- [7.0,8.0,9.0]] -- } -- output { -- [[1.8, 2.6000000000000005, 3.4000000000000004], -- [4.2, 5.0, 5.800000000000001], -- [6.6000000000000005, 7.4, 8.2]] -- } -- input { -- 2 -- [[1.0,2.0,3.0], -- [4.0,5.0,6.0], -- [7.0,8.0,9.0]] -- } -- output { -- [[2.44, 3.0800000000000005, 3.7200000000000006], -- [4.36, 5.0, 5.640000000000001], -- [6.280000000000001, 6.920000000000001, 7.56]] -- } def main [n][m] (num_iterations: i32) (a: [n][m]f64): [][]f64 = loop (a) for i < num_iterations do map (\i -> map (\j -> let center = a[i,j] let north = if i == 0 then center else a[i-1,j] let east = if j == m-1 then center else a[i,j+1] let south = if i == n-1 then center else a[i+1,j] let west = if j == 0 then center else a[i,j-1] let factor = 1.0/5.0 in factor*center + factor*north + factor*east + factor*south + factor*west ) (iota(m)) ) (iota(n)) 
futhark-0.25.27/tests/sumtypes/000077500000000000000000000000001475065116200163755ustar00rootroot00000000000000futhark-0.25.27/tests/sumtypes/coerce0.fut000066400000000000000000000005051475065116200204350ustar00rootroot00000000000000-- == -- input { true 3i64 } -- output { [0i64,1i64,2i64] } -- input { false 3i64 } -- output { [0i64] } type opt 't = #some t | #none def f b (x: i64) = if b then #some (iota x) else (#none : opt ([1]i64)) :> opt ([]i64) def main b x = let v = f b x in match v case #some arr -> arr case #none -> [0] futhark-0.25.27/tests/sumtypes/coerce1.fut000066400000000000000000000002331475065116200204340ustar00rootroot00000000000000-- == -- error: Ambiguous size.*anonymous size type opt 't = #some t | #none def f b (x: i64) = if b then #some (iota x) else #none :> opt ([]i64) futhark-0.25.27/tests/sumtypes/coerce2.fut000066400000000000000000000002261475065116200204370ustar00rootroot00000000000000-- == -- error: Ambiguous.*size coercion type opt 't = #some t | #none def f b (x: i64) = if b then #some (iota x) else #none :> opt ([2]i64) futhark-0.25.27/tests/sumtypes/existential-match.fut000066400000000000000000000003031475065116200225340ustar00rootroot00000000000000type~ sumT = #foo []i64 | #bar i64 def thing xs : sumT = #foo (filter (>0) xs) def main (xs: []i64) : []i64 = match thing xs case #foo xs' -> xs ++ xs' case #bar i -> xs ++ [i] futhark-0.25.27/tests/sumtypes/sumtype0.fut000066400000000000000000000003251475065116200207030ustar00rootroot00000000000000-- Basic sum type. -- == -- input { } -- output { 5 } type foobar = #foo i32 | #bar i16 def main : i32 = match (#foo 5) : foobar case #bar 5 -> 1 case #foo 4 -> 2 case (#foo x) -> x case _ -> 3 futhark-0.25.27/tests/sumtypes/sumtype1.fut000066400000000000000000000006321475065116200207050ustar00rootroot00000000000000-- Matches on nested tuples within sumtypes. 
-- == -- input { } -- output { 4 } type foobar = #foo ((i32, (i32, i32)), i32) | #bar i32 def main : i32 = match (#foo ((3,(1,10)), 2)) : foobar case (#foo (_, 3)) -> 1 case (#foo ((4,_), 2)) -> 2 case (#bar 5) -> 3 case (#foo ((3,(_, 10)), _)) -> 4 case (#bar _) -> 5 case (#foo _) -> 6 futhark-0.25.27/tests/sumtypes/sumtype10.fut000066400000000000000000000002571475065116200207700ustar00rootroot00000000000000-- Inexhaustive sumtype pattern match. -- == -- error: Unmatched cases type foobar = #foo i32 | #bar i32 def main : i32 = match ((#bar 12) : foobar) case (#foo _) -> 1 futhark-0.25.27/tests/sumtypes/sumtype11.fut000066400000000000000000000002271475065116200207660ustar00rootroot00000000000000-- Fail if lacking a type annotation. -- == -- error: Type is ambiguous def main : i32 = match (#bar 12) case (#foo _) -> 1 case (#bar _) -> 2 futhark-0.25.27/tests/sumtypes/sumtype12.fut000066400000000000000000000006241475065116200207700ustar00rootroot00000000000000-- Sumtypes with function payloads. -- == -- input { } -- output { 7 } type^ mooboo = #moo (i32 -> i32) | #boo i32 type^ foobar = #foo (mooboo -> i32 -> i32) | #bar def main : i32 = let f (mb : mooboo) x = match mb case (#moo g) -> g x case (#boo _) -> 0 in match (#foo f, #moo (+5)) : (foobar, mooboo) case ((#foo h), mb) -> h mb 2 case _ -> 1 futhark-0.25.27/tests/sumtypes/sumtype13.fut000066400000000000000000000010151475065116200207640ustar00rootroot00000000000000-- Sumtype in-place updates. 
-- == -- input { } -- output { [-1, -2, -3, -4, 2, 4, 6, 8] } type mooboo = #moo i32 | #boo i32 def swap_inplace (ns : []i32) : *[]mooboo = let x = map (\n -> #moo n) ns ++ map (\n -> #boo n) ns in loop x for i < 2*(length ns) do x with [i] = match x[i] case (#moo x) -> #boo (-x) case (#boo x) -> #moo (2 * x) def f (x : mooboo) : i32 = match x case (#moo x) -> x case (#boo x) -> x def main : []i32 = map f (swap_inplace [1,2,3,4]) futhark-0.25.27/tests/sumtypes/sumtype14.fut000066400000000000000000000006631475065116200207750ustar00rootroot00000000000000-- Sumtypes in module types. -- == -- input { } -- output { 15 } module type foobar_mod = { type^ foobar val f : foobar -> i32 -> i32 val bar : foobar } module sum_module : foobar_mod = { type^ foobar = #foo i32 | #bar (i32 -> i32) def f (fb : foobar) (x : i32) : i32 = match fb case (#foo y) -> x + y case (#bar f) -> f x def bar = (#bar (+5)) : foobar } def main : i32 = sum_module.f sum_module.bar 10 futhark-0.25.27/tests/sumtypes/sumtype15.fut000066400000000000000000000004341475065116200207720ustar00rootroot00000000000000-- Lists as payloads -- == -- input { } -- output { [1, 6] } -- Note: this test currently fails. 
type maybe 'a = #none | #some a def f (x : maybe ([]i32)) : i32 = match x case #none -> 1 case (#some xs) -> foldl (+) 0 xs def main : []i32 = map f [#none, #some [1,2,3]] futhark-0.25.27/tests/sumtypes/sumtype16.fut000066400000000000000000000002671475065116200207770ustar00rootroot00000000000000-- Type abbreviations -- == -- input { } -- output { 1 } type foobar 't = #foo t | #bar i32 def main : i32 = match (#foo 1) : foobar i32 case (#foo x) -> x case _ -> 0 futhark-0.25.27/tests/sumtypes/sumtype17.fut000066400000000000000000000003201475065116200207660ustar00rootroot00000000000000-- Lifted type parameters -- == -- input { } -- output { 1 } type^ foobar 't '^s = #foo t | #bar s def main : i32 = match (#bar (+1)) : foobar i32 (i32 -> i32) case (#foo x) -> x case (#bar f) -> f 0 futhark-0.25.27/tests/sumtypes/sumtype18.fut000066400000000000000000000010241475065116200207710ustar00rootroot00000000000000-- Missing pattern error. -- -- Note that there are multiple ways this error can be reported, so if -- you are fiddling with the match checker, feel free to change the -- expected error here. -- == -- error: Unmatched cases.*#foo \(#moo _ #none\) type some 't = #none | #some t type^ mooboo '^t = #moo t (some i32) | #boo type^ foobar = #foo (mooboo (i32 -> i32)) | #bar (i32 -> i32) def main : i32 = match (#foo (#moo (+1) #none)) : foobar case #bar f -> 0 case #foo #boo -> 2 case #foo (#moo f (#some 5)) -> 2 futhark-0.25.27/tests/sumtypes/sumtype19.fut000066400000000000000000000006011475065116200207720ustar00rootroot00000000000000-- Missing pattern error 2. 
-- == -- error: Unmatched cases type foobar = #foo (i32, (i32, i32), i32) | #bar type moo = #moo i32 foobar i32 type boo = #boo (i32, moo) type blah = #blah i32 def main : i32 = match (#boo (7, #moo 5 (#foo (1, (2, 3), 4)) 6)) : boo case (#boo (_, (#moo _ (#foo (_, (_, 3), _)) 6))) -> 2 case (#boo (_, (#moo _ (#bar ) _))) -> 2 futhark-0.25.27/tests/sumtypes/sumtype2.fut000066400000000000000000000003361475065116200207070ustar00rootroot00000000000000-- Sumtypes as function arguments. -- == -- input { } -- output { 2 } type foobar = #foo i32 | #bar i32 def f (x : foobar) : i32 = match x case (#foo _) -> 1 case (#bar _) -> 2 def main : i32 = f (#bar 12) futhark-0.25.27/tests/sumtypes/sumtype20.fut000066400000000000000000000004601475065116200207650ustar00rootroot00000000000000-- Missing pattern warning error. -- == -- error: Unmatched cases type some = #none | #some i32 i32 type foobar = #foo i32 some i32 def main : i32 = -- match (#foo 1 (#some 2) 3) : foobar -- case (#foo 1 (#some 2) 3) -> 1 match (#some 1 2) : some case (#none) -> 1 case (#some 1 2) -> 2 futhark-0.25.27/tests/sumtypes/sumtype21.fut000066400000000000000000000003001475065116200207570ustar00rootroot00000000000000-- Sumtype aliasing. -- == -- error: "xs".*consumed type sum [n] = #foo ([n]i32) | #bar ([n]i32) def main [n] (xs: *[n]i32) = let v : sum [n] = #foo xs let xs[0] = 0 let v' = v in 0 futhark-0.25.27/tests/sumtypes/sumtype22.fut000066400000000000000000000002171475065116200207670ustar00rootroot00000000000000-- Sumtype aliasing. -- == type sum = #foo ([3]i32) | #bar ([2]i32) def main (xs: *[3]i32) = let v : sum = #foo xs let xs[0] = 0 in xs futhark-0.25.27/tests/sumtypes/sumtype23.fut000066400000000000000000000002361475065116200207710ustar00rootroot00000000000000-- Sumtype consumption. 
-- == type^ sum = #foo ([]i32) | #bar ([]i32) def main (v: *sum) : *[]i32 = match v case #foo arr -> arr case #bar arr -> arr futhark-0.25.27/tests/sumtypes/sumtype24.fut000066400000000000000000000003231475065116200207670ustar00rootroot00000000000000-- Sumtype consumption. -- == -- error: "v".*consumed type^ sum = #foo (*[]i32) | #bar (*[]i32) def main (v: *sum) = let x = match v case #foo v -> v with [0] = 0 case #bar v -> v in v futhark-0.25.27/tests/sumtypes/sumtype25.fut000066400000000000000000000001441475065116200207710ustar00rootroot00000000000000-- Issue 785 type mbpd = #Just {pos:i32} def main (pd: mbpd) = match pd case #Just x -> x.pos futhark-0.25.27/tests/sumtypes/sumtype26.fut000066400000000000000000000001461475065116200207740ustar00rootroot00000000000000-- Issue 785 type mbpd = #Just {pos:i32} def main (pd: mbpd) = match pd case #Just {pos} -> pos futhark-0.25.27/tests/sumtypes/sumtype27.fut000066400000000000000000000006571475065116200210040ustar00rootroot00000000000000-- == -- input { [0,-1,3,4,-2] } -- output { 1 } type opt 'a = #some a | #none def opt 'a 'b (b: b) (f: a -> b) (x: opt a) : b = match x case #some x' -> f x' case #none -> b def cat_opt [n] (xs: [n](opt i32)) = let either (x : opt i32) (y : opt i32) = opt x (\y' -> #some y') y in match reduce either #none xs case #none -> 0 case #some _ -> 1 def main (xs: []i32) = cat_opt (map (\x -> #some x) xs) futhark-0.25.27/tests/sumtypes/sumtype28.fut000066400000000000000000000001771475065116200210020ustar00rootroot00000000000000-- Arrays literals of sum types. -- == -- error: bool type t = #c i32 def main = let ts = [#c 1, #c false] : []t in 0i32 futhark-0.25.27/tests/sumtypes/sumtype29.fut000066400000000000000000000005041475065116200207750ustar00rootroot00000000000000-- Optimise representation where the components must be flattened out -- first. 
type twod = {x: f32, y: f32} type threed = {x: f32, y: f32, z: f32} type point = #twod twod | #threed threed def main (p: point) : point = match p case #twod ds -> #threed {x=ds.x, y=ds.y, z=0} case #threed {x, y, z=_} -> #twod {x, y} futhark-0.25.27/tests/sumtypes/sumtype3.fut000066400000000000000000000003431475065116200207060ustar00rootroot00000000000000-- Sumtype as a type parameter. -- == -- input { } -- output { 2 } def id 'a (x : a) : a = x def f (x : #foo i32 | #bar i32) : i32 = match x case (#foo y) -> y case (#bar y) -> y def main : i32 = f (id (#bar 2)) futhark-0.25.27/tests/sumtypes/sumtype30.fut000066400000000000000000000004221475065116200207640ustar00rootroot00000000000000-- Deduplication for nested sum types. type either 'a 'b = #left a | #right b type t = either bool (either (either i32 i32) i32) def main (x: i32) = match (#right (#left (#left x))) : t case #right (#right x) -> x-1 case #right (#left (#left x)) -> x case _ -> 0 futhark-0.25.27/tests/sumtypes/sumtype31.fut000066400000000000000000000001631475065116200207670ustar00rootroot00000000000000-- Ordering is not defined for sum types. -- == -- error: sum type def main (x: f32) (y: f32) = #foo x > #foo y futhark-0.25.27/tests/sumtypes/sumtype32.fut000066400000000000000000000002501475065116200207650ustar00rootroot00000000000000-- Specific error message on constructor mismatches. -- == -- error: Unshared constructors: #d, #c. def f (v: #a i32 | #b i32 | #c i32) : #a i32 | #b i32 | #d i32 = v futhark-0.25.27/tests/sumtypes/sumtype33.fut000066400000000000000000000003331475065116200207700ustar00rootroot00000000000000-- Proper inference of size in sum type in negative position, even -- when not all the constructors of the sum type are known yet. 
type sometype 'a = #someval a def error : i32 -> sometype ([]i32) = \_ -> #someval [] futhark-0.25.27/tests/sumtypes/sumtype34.fut000066400000000000000000000001121475065116200207640ustar00rootroot00000000000000type sometype 't = #someval t def main : sometype (*[]i32) = #someval [1] futhark-0.25.27/tests/sumtypes/sumtype35.fut000066400000000000000000000001251475065116200207710ustar00rootroot00000000000000type t = #foo ([2]i32) | #bar ([2]i32) ([2]i32) def main (x: t) (y: i32) = 2 futhark-0.25.27/tests/sumtypes/sumtype36.fut000066400000000000000000000001141475065116200207700ustar00rootroot00000000000000type t = #foo ([2]i32) | #bar ([2]i32) ([2]i32) def main (x: t) = x futhark-0.25.27/tests/sumtypes/sumtype37.fut000066400000000000000000000001341475065116200207730ustar00rootroot00000000000000type t [n] = #foo ([n]i32) | #bar ([n]i32) ([n]i32) def main [n] (x: t [n]) = x futhark-0.25.27/tests/sumtypes/sumtype38.fut000066400000000000000000000002231475065116200207730ustar00rootroot00000000000000-- == -- error: Unmatched type r = {f0: bool, f1: bool} def f (x: r) = match x case {f0=false, f1=false} -> 0 case {f0=true, f1=true} -> 0 futhark-0.25.27/tests/sumtypes/sumtype39.fut000066400000000000000000000001631475065116200207770ustar00rootroot00000000000000-- == -- error: Unmatched def f (x: (bool, bool)) = match x case (false, false) -> 0 case (true, true) -> 0 futhark-0.25.27/tests/sumtypes/sumtype4.fut000066400000000000000000000002331475065116200207050ustar00rootroot00000000000000-- Constructors with different fields should be different. 
-- == -- error: #foo i16.*#foo i32 def g (x : #foo i32) : #foo i16 = match x case y -> y futhark-0.25.27/tests/sumtypes/sumtype40.fut000066400000000000000000000001431475065116200207650ustar00rootroot00000000000000-- == -- error: Unmatched def f (x: (i32, i32)) = match x case (0, _) -> 0 case (_, 1) -> 0 futhark-0.25.27/tests/sumtypes/sumtype41.fut000066400000000000000000000002231475065116200207650ustar00rootroot00000000000000-- == -- error: Unmatched type tuple = #tuple bool bool def f (x: tuple) = match x case #tuple false false -> 0 case #tuple true true -> 0 futhark-0.25.27/tests/sumtypes/sumtype42.fut000066400000000000000000000002261475065116200207710ustar00rootroot00000000000000-- == -- input { -1i32 } output { true } -- input { 1i32 } output { false } def main (x: i32) = match x case -1 -> true case _ -> false futhark-0.25.27/tests/sumtypes/sumtype43.fut000066400000000000000000000002311475065116200207660ustar00rootroot00000000000000-- == -- input { -1f32 } output { true } -- input { 1f32 } output { false } def main (x: f32) = match x case -1f32 -> true case _ -> false futhark-0.25.27/tests/sumtypes/sumtype44.fut000066400000000000000000000003201475065116200207660ustar00rootroot00000000000000-- == -- error: Unmatched type inst = #foo | #bar | #baz def exec (inst: inst) = #[unsafe] let x = 0i32 in match inst case #foo -> x + 1 case #baz -> x - 1 futhark-0.25.27/tests/sumtypes/sumtype45.fut000066400000000000000000000004211475065116200207710ustar00rootroot00000000000000-- From #1748 def main (a: bool) (b: bool) (c: bool): () = match (a, b, c) case (false, false, false) -> () case (false, false, true) -> () case (false, true, false) -> () case (_, true, true) -> () case (true, _, false) -> () case (true, false, true) -> () futhark-0.25.27/tests/sumtypes/sumtype46.fut000066400000000000000000000001471475065116200207770ustar00rootroot00000000000000-- == -- error: cannot match type t = #foo f64 def main (surf: t) = match surf case #foo -> true 
futhark-0.25.27/tests/sumtypes/sumtype47.fut000066400000000000000000000001501475065116200207720ustar00rootroot00000000000000-- == -- error: cannot match type t = #foo f64 def main (surf: t) = match surf case #foo x y -> x futhark-0.25.27/tests/sumtypes/sumtype48.fut000066400000000000000000000001451475065116200207770ustar00rootroot00000000000000-- == -- error: Unshared constructors type t = #foo | #bar let f b : t = if b then #foo else #baar futhark-0.25.27/tests/sumtypes/sumtype49.fut000066400000000000000000000003501475065116200207760ustar00rootroot00000000000000-- Based on #1917 -- == -- input {} output { 3 2 } type^ t = #foo (i32 -> i32) | #bar (f32 -> i32) def use_p (p: t) = match p case #foo f -> f 2 case #bar f -> f 2 def main = (use_p (#foo (+1)), use_p (#bar i32.f32)) futhark-0.25.27/tests/sumtypes/sumtype5.fut000066400000000000000000000004031475065116200207050ustar00rootroot00000000000000-- Arrays of sumtypes. -- == -- input { } -- output { [1, -2, 3, -4] } type foobar = #foo i32 | #bar i32 def f (x : foobar) : i32 = match x case (#foo y) -> y case (#bar y) -> -y def main : []i32 = map f ([#foo 1, #bar 2, #foo 3, #bar 4] : []foobar) futhark-0.25.27/tests/sumtypes/sumtype50.fut000066400000000000000000000002451475065116200207710ustar00rootroot00000000000000type^ t = #foo (i32 -> i32) | #bar i32 def use_p (p: t) = match p case #foo f -> f 2 case #bar x -> x def main = (use_p (#foo (+1)), use_p (#bar 1)) futhark-0.25.27/tests/sumtypes/sumtype51.fut000066400000000000000000000003471475065116200207750ustar00rootroot00000000000000-- Based on #1950. -- == -- error: Causality check type option 'a = #None | #Some a def gen () : ?[n].[n]i32 = let (n,_) = (0,true) in replicate n 0i32 entry main b: option ([]i32) = if b then #None else #Some(gen ()) futhark-0.25.27/tests/sumtypes/sumtype52.fut000066400000000000000000000004241475065116200207720ustar00rootroot00000000000000-- Based on #1950. 
-- == -- error: Causality check type option 'a = #None | #Some a def gen () : ?[n].[n]i32 = let (n,_) = (0,true) in replicate n 0i32 def ite b t f = if b then t() else f() entry main b: option ([]i32) = ite b (\() -> #None) (\() -> #Some(gen ())) futhark-0.25.27/tests/sumtypes/sumtype6.fut000066400000000000000000000004271475065116200207140ustar00rootroot00000000000000-- Sumtype equality. -- == -- input { } -- output { 2 } type foobar = #foo i32 | #bar i32 def main : i32 = if ((#foo 5) : foobar) == #foo 4 then 1 else if ((#bar 1) : foobar) == #bar 1 then 2 else 3 futhark-0.25.27/tests/sumtypes/sumtype7.fut000066400000000000000000000007201475065116200207110ustar00rootroot00000000000000-- N-ary sumtypes. -- == -- input { } -- output { 26 } type^ foobar = #foo i32 | #bar i32 type^ boomoo = #boo foobar i32 {field1: i32, field2: []i32} | #moo i32 foobar def main : i32 = match ((#boo (#bar 5) 10 {field1 = 1, field2 = [1,2,3,4]}) : boomoo) case (#boo (#bar 5) 10 {field1 = 2, field2 = _}) -> 1 case (#boo (#bar 1) 10 {field1 = 1, field2 = _}) -> 2 case (#boo (#bar x) y {field1 = w, field2 = v}) -> x + y + w + foldl (+) 0 v case _ -> 3 futhark-0.25.27/tests/sumtypes/sumtype8.fut000066400000000000000000000002351475065116200207130ustar00rootroot00000000000000-- Constructor order shouldn't matter. -- == type foobar = #foo i32 | #bar i32 type barfoo = #bar i32 | #foo i32 def main (x : foobar) = (#bar 5) : barfoo futhark-0.25.27/tests/sumtypes/sumtype9.fut000066400000000000000000000003051475065116200207120ustar00rootroot00000000000000-- Sumtype matches on wildcards. -- == -- input { } -- output { 2 } type foobar = #foo i32 | #bar i16 def main : i32 = match ((#bar 1) : foobar) case (#foo _) -> 1 case (#bar _) -> 2 futhark-0.25.27/tests/three_way_partition.fut000066400000000000000000000021431475065116200213040ustar00rootroot00000000000000-- A manually implemented partitioning, hardcoded for 3 equivalence classes. 
-- -- == -- input { [1f32, 2f32, 3f32, 4f32, 5f32, 6f32, 7f32, 8f32, 9f32] -- [0i64, 1i64, 2i64, 3i64, 0i64, 1i64, 2i64, 3i64, 0i64] } -- output { 3i64 2i64 2i64 [1f32, 5f32, 9f32, 2f32, 6f32, 3f32, 7f32] } def main [n] (vs: [n]f32) (classes: [n]i64): (i64, i64, i64, []f32) = let flags = map (\c -> if c == 0 then (1, 0, 0) else if c == 1 then (0, 1, 0) else if c == 2 then (0, 0, 1) else (0, 0, 0)) classes let is0 = scan (\(a0,b0,c0) (a1,b1,c1) -> (a0+a1,b0+b1,c0+c1)) (0,0,0) flags let (size_0, size_1, size_2) = is0[n-1] let filter_size = size_0 + size_1 + size_2 let is1 = map2 (\(ai,bi,ci) c -> if c == 0 then ai - 1 else if c == 1 then size_0 + bi - 1 else if c == 2 then size_0 + size_1 + ci - 1 else -1) is0 classes in (size_0, size_1, size_2, spread filter_size 0 is1 vs) futhark-0.25.27/tests/tiling/000077500000000000000000000000001475065116200157725ustar00rootroot00000000000000futhark-0.25.27/tests/tiling/issue1933.fut000066400000000000000000000151411475065116200201640ustar00rootroot00000000000000-- Heavily mangled from original test case. This one was quite -- hellish to shrink. -- -- We really should not tile here at all, due to how much else is -- going on in the kernel. Feel free to remove this program if one -- day we improve tiling to detect that it is not profitable. 
local let hash(x: i32): i32 = let x = u32.i32 x let x = ((x >> 16) ^ x) * 0x45d9f3b let x = ((x >> 16) ^ x) * 0x45d9f3b let x = ((x >> 16) ^ x) in i32.u32 x type rng = {state: u64, inc: u64} let rand ({state, inc}: rng) = let oldstate = state let state = oldstate * 6364136223846793005u64 + (inc|1u64) let xorshifted = u32.u64 (((oldstate >> 18u64) ^ oldstate) >> 27u64) let rot = u32.u64 (oldstate >> 59u64) in ({state, inc}, (xorshifted >> rot) | (xorshifted << ((-rot) & 31u32))) type^ distribution [n] 'rng 'a 'b = { sample: rng -> (rng, [n]a), transform: [n]a -> [n]a, log_prob: [n]a -> b, log_prob_base: [n]a -> b } module normal_diag = { let log_sqrt_2pi = f32.log <| f32.sqrt <| 2 * f32.pi def sample n rng: (rng, [n]f32) = let rngs = replicate n rng let (rngs, xs) = map rand rngs |> unzip in (rngs[0], map f32.u32 xs) def transform n means stds (xs: [n]f32) = map3 (\mean std x -> mean + x * std) means stds xs def log_prob n means stds (xs: [n]f32) = map3 (\mean std x -> (-1) * ((x-mean)**2)/(2*std**2) - f32.log std - log_sqrt_2pi ) means stds xs |> f32.sum def log_prob_base n (xs: [n]f32) = map (\x -> (-1) * (x**2)/2 - log_sqrt_2pi) xs |> f32.sum def mk_dist [n] (means: [n]f32) stds: distribution [n] rng f32 f32 = { sample = sample n, transform = transform n means stds, log_prob = log_prob n means stds, log_prob_base = log_prob_base n } } def mean [n] (xs: [n]f32) = f32.sum xs / f32.i64 n def MAX_NUM_RV = 3i64 type addr = i64 type state [n] [n2] = { rng: rng, log_like: f32, trace: [MAX_NUM_RV](bool, [n]f32), conditioned: [MAX_NUM_RV](bool, [n]f32), log_probs: [MAX_NUM_RV](bool, f32), fix_unknown_size_n2: [n2]f32 } def mk_empty_state n n2 rng_init = { rng=rng_init, log_like = 0f32, trace = replicate MAX_NUM_RV (false, replicate n 0f32), conditioned = replicate MAX_NUM_RV (false, replicate n 0f32), log_probs = replicate MAX_NUM_RV (false, 0f32), fix_unknown_size_n2 = replicate n2 0f32 } -- Handlers. 
type^ message [n] [n2] 'c 't2 = #sample (state [n] [n2]) addr (distribution [n] rng c f32) | #observe (state [n] [n2]) addr (distribution [n2] rng t2 f32) ([n]c) [n2]t2 | #return (state [n] [n2]) [n]c type^ handler [n] [n2] 'c 't2 = message [n] [n2] c t2 -> (state [n] [n2], [n]c) def default_handler [n] [n2] 'c 't2 (req: message [n] [n2] c t2) = match req case #sample s _a d -> let (rng', c) = d.sample s.rng in (s with rng = rng', c) case #observe s _a _d c _obs -> (s, c) case #return s c -> (s, c) def store_log_probs_and_transform [n] [n2] 'c 't2 (parent_handler: handler [n] [n2] c t2) (req: message [n] [n2] c t2) = match req case #sample _s a d -> let (s: state [n] [n2], c) = parent_handler req let log_probs' = (copy s.log_probs) with [a] = (true, d.log_prob_base c) in (s with log_probs = log_probs', d.transform c) case #observe _s a d _c obs -> let (s: state [n] [n2], c) = parent_handler req let log_probs' = (copy s.log_probs) with [a] = (true, d.log_prob obs) in (s with log_probs = log_probs', c) case _ -> parent_handler req def store_trace [n] [n2] 'c 't2 parent_handler (req: message [n] [n2] c t2) = match req case #sample _s a _d -> let (s: state [n] [n2], c) = parent_handler req let trace' = (copy s.trace) with [a] = (true, c) in (s with trace = trace', c) case _ -> parent_handler req def reuse_conditioned [n] [n2] 'c 't2 parent_handler (req: message [n] [n2] c t2) = match req case #sample s a _d -> let (conditioned, x) = s.conditioned[a] let (s, c) = if conditioned then (s, x) else parent_handler req let trace' = copy s.trace with [a] = (true, c) in (s with trace = trace', c) case _ -> parent_handler req def sample [n] [n2] 'c 't2 (h: handler [n] [n2] c t2) s a d = h (#sample s a d) def observe [n] [n2] 'c 't2 (h: handler [n] [n2] c t2) s a d c obs = h (#observe s a d c obs) def log_weight [n] [n2] (state_p: state [n] [n2]) (state_q: state [n] [n2]) = let log_ps = state_p.log_probs let log_qs = state_q.log_probs let log_p_div_q = map2 (\(in_p, log_p) (in_q, 
log_q) -> if in_p && in_q then log_p - log_q else 0 ) log_ps log_qs |> reduce (+) 0f32 in state_p.log_like + log_p_div_q def importance_sampling latent_dim n2 rng y p theta q phi = let sempty = mk_empty_state latent_dim n2 rng let handler = store_trace (store_log_probs_and_transform default_handler) let sq = q handler sempty y phi let state = (sempty with conditioned = sq.trace) with rng = sq.rng let handler = reuse_conditioned handler let sp = p handler state y theta let log_w = log_weight sp sq let log_p = (reduce (+) 0 (map (.1) sp.log_probs)) in (sp.rng, log_w, log_p) def grad 'a (f: a -> f32) (primal: a) = vjp f primal (f32.i32 1) def normal (means, stddevs) = normal_diag.mk_dist means stddevs def dot (xs: []f32) (ys: []f32) = f32.sum (map2 (*) xs ys) def matvecmul [n][m] (xss: [n][m]f32) (ys: [m]f32): *[n]f32 = map (dot ys) xss def dense (in_dim: i64) (out_dim: i64) = let init (initfn: (n: i64) -> rng -> (rng, [n]f32)) rng = let (rng, weights) = initfn (in_dim * out_dim) rng let (_, bias) = initfn out_dim rng in (unflatten weights, bias) let apply (params: ([out_dim][in_dim]f32, [out_dim]f32)) (xs: [in_dim]f32): [out_dim]f32 = let (weightsT, bias) = params in map2 (+) (matvecmul weightsT xs) bias in apply def LATENT_DIM = 2i64 def INPUT_DIM = 8i64 def model decoder handler s y theta = let (mu, sigma) = (replicate LATENT_DIM 0f32, replicate LATENT_DIM 1f32) let (s, z) = sample handler s 0 (normal (mu, sigma)) let y_probs = (decoder theta z) let (s, _) = observe handler s 1 (normal (y_probs, y_probs)) z y in s entry main [batch_sz] rng theta phi (ys: [batch_sz][INPUT_DIM]f32) = let (f_apply) = dense LATENT_DIM INPUT_DIM let fake_init n rng = (rng, replicate n 1f32) let (model, guide) = (model f_apply, model f_apply) let elbo_loss rngs ys (theta, phi) = let elbos = map2 (\rng y -> let (_rng, log_w, log_p) = importance_sampling LATENT_DIM INPUT_DIM rng y model theta guide phi in log_w * log_p ) rngs ys in mean elbos in grad (elbo_loss rng ys) (theta, phi) 
futhark-0.25.27/tests/tiling/issue1940.fut000066400000000000000000000011511475065116200201560ustar00rootroot00000000000000-- The #[unsafe] is just to avoid noisy assertions. The bug here is -- the handling of the loop. -- == -- structure gpu { /SegMap/Loop/SegMap 2 } def dotprod [n] (a: [n]f64) (b: [n]f64): f64 = map2 (*) a b |> reduce (+) 0 def newton_equality [n][m] (_: [m][n]f64) hessian_val = #[unsafe] let KKT = replicate (n + m) (replicate (n + m) 0) let KKT[:n, :n] = hessian_val in loop y=replicate n 0 for i in 0.. loop acc = 0 for i < x do i32.sum (map (+acc) xs)) xs futhark-0.25.27/tests/tiling/seqloop_1d.fut000066400000000000000000000006261475065116200205640ustar00rootroot00000000000000-- Two dimensions in the kernel, but we are only tiling along the -- innermost one. -- == -- input { [1,2,3] [[1,2,3],[4,5,6],[7,8,9]] [1,2,3] } auto output -- structure gpu { SegMap/Loop/Loop/SegMap 2 } def main (ns: []i32) (xs: [][]i32) (ys: []i32) = map (\n -> map (\y -> loop y for i < n do #[sequential] i32.sum (map (+y) (#[unsafe] xs[i]))) ys) ns futhark-0.25.27/tests/tiling/seqloop_1d_postlude.fut000066400000000000000000000005741475065116200225050ustar00rootroot00000000000000-- A prelude value is used both within the tiled loop and the -- postlude. -- == -- random input { 3 [10][20]i32 [10]i32 } auto output -- structure gpu { SegMap/Loop/Loop/SegMap 2 } def main (n: i32) (xs: [][]i32) (ys: []i32) = map (\y -> let y' = loop y for i < n do #[sequential] i32.sum (map (+y) (#[unsafe] xs[i])) in y + y') ys futhark-0.25.27/tests/tiling/seqloop_2d.fut000066400000000000000000000010061475065116200205560ustar00rootroot00000000000000-- 2D tiling where the loop to tile is inside a sequential loop, that -- is variant to an outermost dimension. 
-- == -- input { [1,2,3] [[1,2,3],[4,5,6]] [[1,2,3],[4,5,6]] } auto output -- structure gpu { SegMap/Loop/Loop/SegMap 2 } def main [k] (ns: []i32) (xss: [][k]i32) (yss: [][k]i32) = map (\n -> map (\xs' -> map (\ys' -> loop z = 0 for _p < n do #[sequential] i32.sum (map (+z) (map2 (*) xs' ys'))) yss) xss) ns futhark-0.25.27/tests/tiling/tiling2.fut000066400000000000000000000003631475065116200200640ustar00rootroot00000000000000-- Simple 2D tiling -- == -- structure gpu { SegMap/SegMap 4 -- SegMap/Loop/SegMap 3 -- SegMap/SegMap 4 -- SegMap/Loop/Loop/SegMap/Loop 0 } def main (xs: [][]i32) (ys: [][]i32) = map (\xs' -> map (\ys' -> i32.sum (map2 (*) xs' ys')) ys) xs futhark-0.25.27/tests/tiling/tiling_1d.fut000066400000000000000000000003031475065116200203600ustar00rootroot00000000000000-- Simple 1D tiling -- == -- compiled random input { [100]i32 } auto output -- structure gpu { SegMap/Loop/SegMap 2 } def main (xs: []i32) = map (\x -> #[sequential] i32.sum (map (+x) xs)) xs futhark-0.25.27/tests/tiling/tiling_1d_complex.fut000066400000000000000000000017041475065116200221150ustar00rootroot00000000000000-- More stuff that can go wrong with a larger tiling prelude, but -- still just 1D tiling. 
-- == -- no_ispc compiled random input { [2000]f32 [2000]f32 } auto output -- structure gpu { SegMap/Loop/SegMap 2 } type point = (f32, f32) def add_points ((x1, y1): point) ((x2, y2): point) : point = (x1 + x2, y1 + y2) def euclid_dist_2 ((x1, y1): point) ((x2, y2): point) : f32 = (x2 - x1) ** 2.0f32 + (y2 - y1) ** 2.0f32 def closest_point (p1: (i32, f32)) (p2: (i32, f32)) : (i32, f32) = if p1.1 < p2.1 then p1 else p2 def find_nearest_point [k] (pts: [k]point) (pt: point) : i32 = let (i, _) = reduce_comm closest_point (0, euclid_dist_2 pt pts[0]) (zip (map i32.i64 (iota k)) (map (euclid_dist_2 pt) pts)) in i def main [n] (xs: [n]f32) (ys: [n]f32) = let points = zip xs ys let cluster_centres = take 10 points in #[sequential_inner] map (find_nearest_point cluster_centres) points futhark-0.25.27/tests/tiling/tiling_1d_partial.fut000066400000000000000000000010621475065116200220770ustar00rootroot00000000000000-- Tiling a redomap when not all of the arrays are invariant. -- == -- compiled random input { [256][256]f32 [256]f32 } auto output -- compiled random input { [256][10]f32 [256]f32 } auto output -- structure gpu { SegMap/Loop/SegMap 2 } def dotprod [n] (xs: [n]f32) (ys: [n]f32): f32 = reduce (+) 0.0 (map2 (*) xs ys) def main [n][m] (xss: [m][n]f32) (ys: [m]f32) = -- The transpose is to avoid a coalescing optimisation that would -- otherwise make manifestation of a transpose the bottleneck. 
map (\xs -> #[sequential] dotprod ys xs) (transpose xss) futhark-0.25.27/tests/tiling/tiling_2d.fut000066400000000000000000000004011475065116200203600ustar00rootroot00000000000000-- Simple 2D tiling -- == -- structure gpu { SegMap/SegMap 4 -- SegMap/Loop/SegMap 3 -- SegMap/SegMap 4 -- SegMap/Loop/Loop/SegMap/Loop 0 } def main (xs: [][]i32) (ys: [][]i32) = map (\xs' -> map (\ys' -> #[sequential] i32.sum (map2 (*) xs' ys')) ys) xs futhark-0.25.27/tests/tiling/tiling_2d_indirect.fut000066400000000000000000000006031475065116200222450ustar00rootroot00000000000000-- 2D tiling, but where the arrays are variant to an outermost third dimension. -- == -- structure gpu { SegMap/SegMap 4 -- SegMap/Loop/SegMap 3 -- SegMap/SegMap 4 -- SegMap/Loop/Loop/SegMap/Loop 0 } def main (is: []i32) (js: []i32) (xss: [][][]i32) (yss: [][][]i32) = map2 (\i j -> map (\xs' -> map (\ys' -> i32.sum (map2 (*) xs' ys')) (#[unsafe] yss[j])) (#[unsafe] xss[i])) is js futhark-0.25.27/tests/tiling/tiling_2d_inner.fut000066400000000000000000000006001475065116200215540ustar00rootroot00000000000000-- 2D tiling with extra dimensions on top. -- == -- compiled random input { [2][40][40]i32 [2][40][40]i32 } auto output -- structure gpu { SegMap/SegMap 4 -- SegMap/Loop/SegMap 3 -- SegMap/SegMap 4 -- SegMap/Loop/Loop/SegMap/Loop 0 } def main [a][b][c] (xss: [a][b][c]i32) (yss: [a][b][c]i32) = map2 (\xs ys -> map (\xs' -> map (\ys' -> i32.sum (map2 (*) xs' ys')) ys) xs) xss yss futhark-0.25.27/tests/tiling/tiling_2d_partial.fut000066400000000000000000000013221475065116200220770ustar00rootroot00000000000000-- 2D tiling when not all of the arrays are invariant. 
-- == -- compiled random input { -- [256][256]i32 -- [256][256]i32 -- [256][256][256]i32 -- [256]i32 -- } auto output -- compiled random input { -- [256][10]i32 -- [10][256]i32 -- [256][10][256]i32 -- [10]i32 -- } auto output -- structure gpu { SegMap/Loop/SegMap 2 } def main [a][b][c] (xs: [a][c]i32) (ys: [c][b]i32) (zsss: [b][c][a]i32) (vs: [c]i32) = map2 (\xs' zss -> map2 (\ys' zs -> #[sequential] i32.sum (map2 (+) vs (map2 (*) zs (map2 (*) xs' ys')))) (transpose ys) zss) xs (transpose (map transpose zsss)) futhark-0.25.27/tests/tiling/tricky_prelude0.fut000066400000000000000000000012311475065116200216140ustar00rootroot00000000000000-- A case of tiling with a complex dependency. -- == -- compiled random input { [100]i32 } auto output -- structure gpu { SegMap/Loop/SegMap 2 } def main (xs: []i32) = map (\x -> let y = x + 2 -- Used in postlude. let z = loop z=0 while z < 1337 do z * 2 + y -- Uses 'y', but -- cannot itself -- be inlined in -- the postlude. let loopres = #[sequential] i32.sum (map (+z) xs) -- Uses z. in loopres + i32.clz y -- 'y' must be made available here. ) xs futhark-0.25.27/tests/tiling/tricky_prelude1.fut000066400000000000000000000006731475065116200216260ustar00rootroot00000000000000-- Make sure not to share the in-place updated array too much in the -- prelude. 
-- == -- random input { [10]f64 [100]f64 } auto output def seqmap [n] 'a 'b (f: a -> b) (xs: [n]a) (dummy: b) : [n]b = loop out = replicate n dummy for i < n do out with [i] = f xs[i] def main [n] (xs: [n]f64) is = #[sequential_inner] map (\i -> let arr = seqmap (+i) (map f64.i64 (iota n)) 0 in f64.sum (map2 (+) xs arr)) is futhark-0.25.27/tests/tokens.fut000066400000000000000000000021171475065116200165300ustar00rootroot00000000000000-- == -- input { [102u8, 111u8, 111u8, 32u8, 98u8, 97u8, 114u8, 32u8, 98u8, 97u8, 122u8] 1 } -- output { [98u8, 97u8, 114u8] } def segmented_scan 't [n] (g:t->t->t) (ne: t) (flags: [n]bool) (vals: [n]t): [n]t = let pairs = scan ( \ (v1,f1) (v2,f2) -> let f = f1 || f2 let v = if f2 then v2 else g v1 v2 in (v,f) ) (ne,false) (zip vals flags) let (res,_) = unzip pairs in res def is_space (x: u8) = x == ' ' def isnt_space x = !(is_space x) def f &&& g = \x -> (f x, g x) module tokens : { type token [w] val tokens [n] : [n]u8 -> ?[k][w].(token [w] -> ?[m].[m]u8, [k](token [w])) } = { type token [w] = ([w](), i64, i64) def tokens [n] (s: [n]u8) = let rep = replicate 0 () in (\(_, i, k) -> #[unsafe] s[i:i+k], segmented_scan (+) 0 (map is_space s) (map (isnt_space >-> i64.bool) s) |> (id &&& rotate 1) |> uncurry zip |> zip (indices s) |> filter (\(_,(x,y)) -> x > y) |> map (\(i,(x,_)) -> (rep,i-x+1,x))) } def main xs i = let (f, ts) = tokens.tokens xs in f ts[i] futhark-0.25.27/tests/trailing-commas.fut000066400000000000000000000003361475065116200203140ustar00rootroot00000000000000def tuple : (i32,bool,) = (1,true,) def record : {x:i32,y:bool,} = {x=0,y=true,} def array = [1,2,3,] def index = (unflatten (iota (2*3)))[0,1,] def attr = #[trail(a,b,)] true def tpat (x,y,) = true def rpat {x,y,} = true futhark-0.25.27/tests/transpose1.fut000066400000000000000000000006751475065116200173330ustar00rootroot00000000000000-- == -- input { -- [[1,2,3],[4,5,6]] -- } -- output { -- [[1, 4], [2, 5], [3, 6]] -- } -- input { -- empty([0][1]i32) -- } -- output 
{ -- empty([1][0]i32) -- } -- compiled random input { [10][10]i32 } auto output -- compiled random input { [1024][4]i32 } auto output -- compiled random input { [4][1024]i32 } auto output -- compiled random input { [1024][1024]i32 } auto output def main [n][m] (a: [n][m]i32): [m][n]i32 = transpose a futhark-0.25.27/tests/transpose2.fut000066400000000000000000000002541475065116200173250ustar00rootroot00000000000000-- Verify that we can perform an in-place update on the result of a -- transposition. -- == def main (xss: *[][]i32) = let xss' = transpose xss in xss' with [0,1] = 2 futhark-0.25.27/tests/tridag.fut000066400000000000000000000033601475065116200165000ustar00rootroot00000000000000-------------------------------------------- --subroutine tridag(a,b,c,d,nn) -- -- dimension a(nn),b(nn),c(nn),d(nn) -- -- if(nn .eq. 1) then -- d(1)=d(1)/b(1) -- return -- end if -- -- do 10 k = 2,nn -- xm = a(k)/b(k-1) -- b(k) = b(k) - xm*c(k-1) -- d(k) = d(k) - xm*d(k-1) --10 continue -- -- d(nn) = d(nn)/b(nn) -- k = nn -- do 20 i = 2,nn -- k = nn + 1 - i -- d(k) = (d(k) - c(k)*d(k+1))/b(k) --20 continue -- return --end --------------------------------------------/ -- == -- input { -- } -- output { -- [1.000000, 0.335000, -13.003881, 4.696531, 2.284400, -1.201104, 23.773315, -- 6.997077, 5.064199, 3.832115] -- [-3.039058, 6.578116, -1.103211, -6.002021, 8.084014, -3.267883, -0.332640, -- 2.726331, -1.618253, 1.472878] -- } def tridag(nn: i32, b: *[]f64, d: *[]f64, a: []f64, c: []f64 ): ([]f64,[]f64) = if (nn == 1) --then ( b, map(\f64 (f64 x, f64 y) -> x / y, d, b) ) then (b, [d[0]/b[0]]) else let (b,d) = loop((b, d)) for i < (nn-1) do let xm = a[i+1] / b[i] let b[i+1] = b[i+1] - xm*c[i] let d[i+1] = d[i+1] - xm*d[i] in (b, d) let d[nn-1] = d[nn-1] / b[nn-1] in let d = loop(d) for i < (nn-1) do let k = nn - 2 - i let d[k] = ( d[k] - c[k]*d[k+1] ) / b[k] in d in (b, d) def main: ([]f64,[]f64) = let nn = reduce (+) 0 ([1,2,3,4]) let a = replicate nn 3.33 let b = map (\x -> 
f64.i64(x) + 1.0) (iota(nn)) let c = map (\x -> 1.11*f64.i64(x) + 0.5) (iota(nn)) let d = map (\x -> 1.01*f64.i64(x) + 0.25) (iota(nn)) in tridag(i32.i64 nn, b, d, a, c) futhark-0.25.27/tests/tupleTest.fut000066400000000000000000000004431475065116200172160ustar00rootroot00000000000000-- Test various abuse of tuples - specifically, the flattening done by -- internalisation. -- == -- input { -- } -- output { -- 8 -- 11 -- } def f(x: (i32,i32)): (i32,i32) = x def main: (i32,i32) = let x = 1 + 2 let y = (x + 5, 4+7) let (z, (t,q)) = (x, y) in f(y) futhark-0.25.27/tests/tupleindex.fut000066400000000000000000000003311475065116200174020ustar00rootroot00000000000000-- == -- input { -- } -- output { -- 0 -- 1 -- } def main: (i32, i32) = let arr = [(0,1), (2,3), (4,5)] let n = length arr let outarr = replicate n (0,0) let i = 0 let outarr[i] = arr[i] in outarr[0] futhark-0.25.27/tests/types/000077500000000000000000000000001475065116200156505ustar00rootroot00000000000000futhark-0.25.27/tests/types/README.md000066400000000000000000000001101475065116200171170ustar00rootroot00000000000000Type system tests. The programs here should be simple and not do much. futhark-0.25.27/tests/types/alias-error0.fut000066400000000000000000000001311475065116200206630ustar00rootroot00000000000000-- No circular types! -- -- == -- error: Unknown type type t = t def main(x: t): t = x futhark-0.25.27/tests/types/alias-error2.fut000066400000000000000000000001471475065116200206740ustar00rootroot00000000000000-- Warn on unique non-arrays -- -- == -- warning: has no effect* type t = i32 def main(x: *t): t = x futhark-0.25.27/tests/types/alias-error3.fut000066400000000000000000000002161475065116200206720ustar00rootroot00000000000000-- You may not define the same alias twice. 
-- -- == -- error: Duplicate.*mydup type mydup = i32 type mydup = f32 def main(x: i32): i32 = x futhark-0.25.27/tests/types/alias-error4.fut000066400000000000000000000001421475065116200206710ustar00rootroot00000000000000-- No undefined types! -- -- == -- error: Unknown type type foo = bar def main(x: foo): foo = x futhark-0.25.27/tests/types/alias-error5.fut000066400000000000000000000001721475065116200206750ustar00rootroot00000000000000-- A type abbreviation must use all of its size parameters. -- == -- error: Size parameter "\[n\]" type matrix [n] = i32 futhark-0.25.27/tests/types/alias0.fut000066400000000000000000000000661475065116200175430ustar00rootroot00000000000000type best_type = i32 def main(x: i32): best_type = x futhark-0.25.27/tests/types/alias1.fut000066400000000000000000000001221475065116200175350ustar00rootroot00000000000000type t = i32 type ts [n] = [n]t def main(xs: ts [], x: t): ts [] = map (+x) xs futhark-0.25.27/tests/types/alias2.fut000066400000000000000000000002331475065116200175410ustar00rootroot00000000000000-- Can we put type aliases in lambdas too? type t = i32 type ts [n] = [n]t def main(xs: ts []): [](ts []) = map (\(x: t): [10]t -> replicate 10 x) xs futhark-0.25.27/tests/types/alias3.fut000066400000000000000000000002031475065116200175370ustar00rootroot00000000000000-- Nest array type aliases type t = i32 type ts [n] = [n]t type tss [n][m] = [n](ts [m]) def main(xss: tss [][]): tss [][] = xss futhark-0.25.27/tests/types/alias4.fut000066400000000000000000000002061475065116200175430ustar00rootroot00000000000000-- An array type alias can be unique. type matrix [n][m] = [n][m]i32 def main(m: *matrix [][]): matrix [][] = let m[0,0] = 0 in m futhark-0.25.27/tests/types/alias5.fut000066400000000000000000000001511475065116200175430ustar00rootroot00000000000000-- Uniqueness goes outside-in. 
type uniqlist [n] = *[n]i32 def main(p: [][]i32): [](uniqlist []) = p futhark-0.25.27/tests/types/badsquare-lam.fut000066400000000000000000000002701475065116200211050ustar00rootroot00000000000000-- The error here could be better. -- == -- error: scope violation type square [n] 't = [n][n]t def ext_square : i64 -> square [] i64 = \n -> tabulate_2d (n+1) (n+2) (\i j -> i + j) futhark-0.25.27/tests/types/badsquare.fut000066400000000000000000000002161475065116200203360ustar00rootroot00000000000000-- == -- error: Sizes.*do not match type square [n] 't = [n][n]t def ext_square n : square [] i64 = tabulate_2d (n+1) (n+2) (\i j -> i + j) futhark-0.25.27/tests/types/error0.fut000066400000000000000000000003731475065116200176040ustar00rootroot00000000000000-- Based on https://stackoverflow.com/questions/56376512/why-do-i-get-cannot-unify-t%e2%82%82-with-type-f32-when-compiling-and-how-do-i-sol -- == -- error: "\^" def hit_register (x : f32) (y : f32) : bool = ((x - 1.0)^2.0 + (y - 1.0)^2.0) <= 1.0 futhark-0.25.27/tests/types/ext0.fut000066400000000000000000000002221475065116200172440ustar00rootroot00000000000000-- == -- input { 0i64 } -- output { [[true,true],[true,true]] } def main x : ?[n].[n][n]bool = let n = x+2 in replicate n (replicate n true) futhark-0.25.27/tests/types/ext1.fut000066400000000000000000000004041475065116200172470ustar00rootroot00000000000000-- == -- input { 1i64 } -- output { [true, true] [false, false] } type^ t = ?[n].([n]bool, bool -> [n]bool) def v x : t = let x' = x + 1 in (replicate x' true, \b -> replicate x' b) def main x = let (arr, f) = v x in unzip (zip arr (f false)) futhark-0.25.27/tests/types/ext2.fut000066400000000000000000000002451475065116200172530ustar00rootroot00000000000000-- == -- error: Sizes .* do not match type^ t = ?[n].([n]bool, bool -> [n]bool) def v x : t = let x' = x + 1 in (replicate x' true, \b -> replicate x b) 
futhark-0.25.27/tests/types/ext3.fut000066400000000000000000000001421475065116200172500ustar00rootroot00000000000000-- == -- input { [true,false] } -- output { 2i64 } def main (xs: ?[n].[n]bool) : i64 = length xs futhark-0.25.27/tests/types/ext4.fut000066400000000000000000000001641475065116200172550ustar00rootroot00000000000000-- == -- input { [true] [2] } -- output { 1i64 } def main [n] (xs: ?[m].[m]bool) (ys: [n]i32) = length (zip xs ys) futhark-0.25.27/tests/types/ext5.fut000066400000000000000000000002611475065116200172540ustar00rootroot00000000000000-- == -- error: Existential size "n" def f : (i64, i64) -> ?[n][m].(b: bool) -> *[n][m]bool = \(n,m) (b: bool) -> replicate n (replicate m b) def main x = map id (f x true) futhark-0.25.27/tests/types/ext6.fut000066400000000000000000000002641475065116200172600ustar00rootroot00000000000000-- == -- input { 1i64 2i64 } -- output { [[true, true]] } def f (n: i64) (m: i64) (b: bool) = replicate n (replicate m b) def g = uncurry f def main a b = map id (g (a,b) true) futhark-0.25.27/tests/types/ext7.fut000066400000000000000000000002111475065116200172510ustar00rootroot00000000000000-- == -- error: Existential size would appear def f (n: i64) (m: i64) (b: [n][m]bool) = b[0,0] def g = uncurry f def main x y = g x y futhark-0.25.27/tests/types/ext8.fut000066400000000000000000000003661475065116200172650ustar00rootroot00000000000000-- == -- input { 2i64 3i64 } -- output { true } def f (n: i64) (m: i64) : ([n][m](), [n][m]bool -> bool) = (replicate n (replicate m ()), \b -> b[0,0]) def g = uncurry f def main x y = let (a,f) = g (x,y) in f (map (map (const true)) a) futhark-0.25.27/tests/types/ext9.fut000066400000000000000000000002501475065116200172560ustar00rootroot00000000000000-- == -- input { [1,2,3] } -- output { 2i64 3i64 } def f [n] [m] 't (_: [n+m]t) = (n, m) def main [n] (xs: [n]i32) = f (let xs' = filter (<3) xs in xs' ++ xs) 
futhark-0.25.27/tests/types/function-error0.fut000066400000000000000000000002061475065116200214220ustar00rootroot00000000000000-- Polymorphic function called incorrectly. -- == -- error: Cannot apply "f" def f 't (x: t) (y: t) = (x,y) def main () = f 1 false futhark-0.25.27/tests/types/function-error2.fut000066400000000000000000000002471475065116200214310ustar00rootroot00000000000000-- Anonymous array element type misused. -- == -- error: Cannot apply "reverse" to "x" def reverse [n] [m] 't (a: [m][n]t) = a[::-1] def main (x: []i32) = reverse x futhark-0.25.27/tests/types/function-error3.fut000066400000000000000000000001341475065116200214250ustar00rootroot00000000000000-- Entry points may not be polymorphic. -- == -- error: polymorphic def main 't (x: t) = x futhark-0.25.27/tests/types/function0.fut000066400000000000000000000001661475065116200203000ustar00rootroot00000000000000-- Simplest polymorphic function. -- == -- input { 1 } output { 1 } def id 't (x: t): t = x def main(x: i32) = id x futhark-0.25.27/tests/types/function1.fut000066400000000000000000000003001475065116200202670ustar00rootroot00000000000000-- Polymorphic function used with multiple different types. -- == -- input { 1 true 2 } output { 1 true 2 } def id 't (x: t): t = x def main (x: i32) (y: bool) (z: i32) = (id x, id y, id z) futhark-0.25.27/tests/types/function2.fut000066400000000000000000000002361475065116200203000ustar00rootroot00000000000000-- Anonymous array element type. -- == -- input { [1,2,3] } output { [3,2,1] } def reverse [n] 't (a: [n]t): [n]t = a[::-1] def main (x: []i32) = reverse x futhark-0.25.27/tests/types/function3.fut000066400000000000000000000003571475065116200203050ustar00rootroot00000000000000-- Anonymous array tuple element type. 
-- == -- input { [1,2,3] [false,true,true] } output { [3,2,1] [true,true,false] } def reverse [n] 'a 'b (a: [n](a,b)): [n](a,b) = a[::-1] def main (x: []i32) (y: []bool) = unzip (reverse (zip x y)) futhark-0.25.27/tests/types/function4.fut000066400000000000000000000003031475065116200202750ustar00rootroot00000000000000-- A parametric type can be instantiated with an array. -- == -- input { [[1],[2],[3]] } output { [[3],[2],[1]] } def reverse [n] 't (a: [n]t): [n]t = a[::-1] def main (x: [][]i32) = reverse x futhark-0.25.27/tests/types/function5.fut000066400000000000000000000003261475065116200203030ustar00rootroot00000000000000-- Locally defined polymorphic function used with multiple different types. -- == -- input { 1 true 2 } output { 1 true 2 } def main (x: i32) (y: bool) (z: i32) = let id 't (x: t): t = x in (id x, id y, id z) futhark-0.25.27/tests/types/function6.fut000066400000000000000000000002241475065116200203010ustar00rootroot00000000000000-- A polymorphic function can be used curried. -- == -- input { [1,2,3] } output { [1,2,3] } def id 't (x: t) = x def main(xs: []i32) = map id xs futhark-0.25.27/tests/types/function7.fut000066400000000000000000000001621475065116200203030ustar00rootroot00000000000000-- Array dimensions in function type may refer to previous named parameters. def f (g: (n: i64) -> [n]i32) = g 0 futhark-0.25.27/tests/types/inference-error1.fut000066400000000000000000000001721475065116200215360ustar00rootroot00000000000000-- No switcharoos. -- == -- error: Function body does not have expected type def id 'a 'b (x: a) (y: b): (a, b) = (y, x) futhark-0.25.27/tests/types/inference-error10.fut000066400000000000000000000002621475065116200216160ustar00rootroot00000000000000-- Lambda-binding freezes an otherwise general function. 
-- == -- error: Cannot apply "g" to "y" def main (x: i32) (y: bool) = let f x y = (y,x) in (\g -> (g x y, g y x)) f futhark-0.25.27/tests/types/inference-error12.fut000066400000000000000000000001771475065116200216250ustar00rootroot00000000000000-- A record turns out to be missing a field. -- == -- error: expected type def f r = let y = r.l2 in (r: {l1: i32}) futhark-0.25.27/tests/types/inference-error13.fut000066400000000000000000000001761475065116200216250ustar00rootroot00000000000000-- A record turns out to have the wrong type of field. -- == -- error: def f r = let y: f32 = r.l in (r: {l: i32}) futhark-0.25.27/tests/types/inference-error14.fut000066400000000000000000000001521475065116200216200ustar00rootroot00000000000000-- A record must have an unambiguous type (no row polymorphism). -- == -- error: ambiguous def f x = x.l futhark-0.25.27/tests/types/inference-error2.fut000066400000000000000000000002031475065116200215320ustar00rootroot00000000000000-- If something is put in an array, it cannot later be inferred as a -- function. -- == -- error: functional def f x = ([x], x 2) futhark-0.25.27/tests/types/inference-error3.fut000066400000000000000000000001531475065116200215370ustar00rootroot00000000000000-- If something is applied, it cannot later be put in an array. -- == -- error: -> b def f x = (x 2, [x]) futhark-0.25.27/tests/types/inference-error4.fut000066400000000000000000000002311475065116200215350ustar00rootroot00000000000000-- If something is used in a loop, it cannot later be inferred as a -- function. -- == -- error: functional def f x = (loop x = x for i < 10 do x, x 2) futhark-0.25.27/tests/types/inference-error5.fut000066400000000000000000000002011475065116200215330ustar00rootroot00000000000000-- A type parameter cannot be inferred as a specific type. 
-- == -- error: does not have expected type def f 't (x: i32): t = x futhark-0.25.27/tests/types/inference-error7.fut000066400000000000000000000001141475065116200215400ustar00rootroot00000000000000-- Ambiguous equality type. -- == -- error: ambiguous def add x y = x == y futhark-0.25.27/tests/types/inference-error8.fut000066400000000000000000000002061475065116200215430ustar00rootroot00000000000000-- If something is used for equality, it cannot later be inferred as a -- function. -- == -- error: equality def f x = (x == x, x 2) futhark-0.25.27/tests/types/inference-error9.fut000066400000000000000000000001431475065116200215440ustar00rootroot00000000000000-- Equality not defined for type parameters. -- == -- error: equality def main 't (x: t) = x == x futhark-0.25.27/tests/types/inference0.fut000066400000000000000000000001661475065116200204110ustar00rootroot00000000000000-- The simplest conceivable type inference. -- == -- input { 2 } output { 2 } def id x = x def main (x: i32) = id x futhark-0.25.27/tests/types/inference1.fut000066400000000000000000000002501475065116200204040ustar00rootroot00000000000000-- Inferred-polymorphic function instantiated twice. -- == -- input { 2 true } output { true 2 } def id x = x def main (x: i32) (y: bool): (bool, i32) = (id y, id x) futhark-0.25.27/tests/types/inference10.fut000066400000000000000000000001421475065116200204640ustar00rootroot00000000000000-- Disambiguating with array elements. -- == -- input { 0 } output { [0,1] } def main x = [x, 1] futhark-0.25.27/tests/types/inference11.fut000066400000000000000000000000711475065116200204660ustar00rootroot00000000000000-- An empty array literal is fine! def main: []i32 = [] futhark-0.25.27/tests/types/inference12.fut000066400000000000000000000001741475065116200204730ustar00rootroot00000000000000-- Type inference based on SOAC usage. 
-- == -- input { [true, false] } output { false } def main xs = reduce (&&) true xs futhark-0.25.27/tests/types/inference13.fut000066400000000000000000000001311475065116200204650ustar00rootroot00000000000000-- Inference from indexing. -- == def f xsss = xsss[0,1,2] def main xsss: i32 = f xsss futhark-0.25.27/tests/types/inference14.fut000066400000000000000000000001411475065116200204670ustar00rootroot00000000000000-- Inference when overloading is involved. -- == -- input { 1 } output { 3 } def main x = x + 2 futhark-0.25.27/tests/types/inference15.fut000066400000000000000000000001601475065116200204710ustar00rootroot00000000000000-- Inference with both overloading and lambdas. -- == -- input { 1 } output { 3 } def main x = (\y -> x + y) 2 futhark-0.25.27/tests/types/inference16.fut000066400000000000000000000001231475065116200204710ustar00rootroot00000000000000-- Equality overloading! -- == -- input { 2 } output { true } def main x = x == 2 futhark-0.25.27/tests/types/inference17.fut000066400000000000000000000001701475065116200204740ustar00rootroot00000000000000-- Both arithmetic and equality overloading at once! -- == -- input { 2 3 } output { false } def main x y = y == x + 2 futhark-0.25.27/tests/types/inference18.fut000066400000000000000000000003041475065116200204740ustar00rootroot00000000000000-- Currying overloaded operators. def eq1 = (==1) def eq2 = (==) : (i32 -> i32 -> bool) def add1 = (+1) def add2 = (+) : (i32 -> i32 -> i32) def main (x: i32) = eq1 x && eq2 (add1 x) (add2 x x) futhark-0.25.27/tests/types/inference19.fut000066400000000000000000000001151475065116200204750ustar00rootroot00000000000000-- Loop bound inference def main x = (loop y = 2 for i < x do y*2, x + 1i8) futhark-0.25.27/tests/types/inference2.fut000066400000000000000000000001721475065116200204100ustar00rootroot00000000000000-- Higher-order inference. 
-- == -- input { 2 } output { 4 } def apply f x = f x def main x = apply (apply (i32.+) x) x futhark-0.25.27/tests/types/inference20.fut000066400000000000000000000002041475065116200204640ustar00rootroot00000000000000-- For-in loop variable inference. -- == -- input { [1,2,3] } output { 9 } def main xs = (loop y = 1 for x in xs do y * 2) + xs[0] futhark-0.25.27/tests/types/inference21.fut000066400000000000000000000002561475065116200204740ustar00rootroot00000000000000-- Zip-inference. Note that unzip-inference is not supported (and -- does not make much sense). def f xs ys = zip xs ys def main (xs: []i32) (ys: []i32) = unzip (f xs ys) futhark-0.25.27/tests/types/inference22.fut000066400000000000000000000002161475065116200204710ustar00rootroot00000000000000-- A let-bound function can be instantiated for different types. -- == def main (x: i32) (y: bool) = let f x y = (y,x) in (f x y, f y x) futhark-0.25.27/tests/types/inference23.fut000066400000000000000000000003001475065116200204640ustar00rootroot00000000000000-- Inferring a function parameter into a tuple in an interesting way. -- == -- input { 1 2 } output { 1 2 } def curry f x y = f (x, y) def id x = x def main (x: i32) (y: i32) = curry id x y futhark-0.25.27/tests/types/inference24.fut000066400000000000000000000002231475065116200204710ustar00rootroot00000000000000-- Use of a variable free in a loop can affect inference. -- == -- input { 3 [1,2,3] } output { 3 } def main m xs = loop y = 0 for i < m do xs[i] futhark-0.25.27/tests/types/inference25.fut000066400000000000000000000001211475065116200204670ustar00rootroot00000000000000-- Tuple inference via let binding. def main x = let (a,b) = x in a + 1 + b futhark-0.25.27/tests/types/inference26.fut000066400000000000000000000001221475065116200204710ustar00rootroot00000000000000-- Record inference via let binding. 
def main x = let {a,b} = x in a + 1 + b futhark-0.25.27/tests/types/inference27.fut000066400000000000000000000002271475065116200205000ustar00rootroot00000000000000-- Inference of record projection. -- == -- input { 2 } output { 2 } def f r = let _y = r.l in (r: {l: i32}) def main (l: i32) = (f {l}).l futhark-0.25.27/tests/types/inference28.fut000066400000000000000000000002771475065116200205060ustar00rootroot00000000000000-- An inferred record suddenly has its fields determined. -- == -- input { 2 } output { 2 } def f r = let y: i32 = r.l let _: {l:i32} = r in y def main (l: i32) = f {l} futhark-0.25.27/tests/types/inference29.fut000066400000000000000000000002001475065116200204710ustar00rootroot00000000000000-- Unification of inferred type with ascription in pattern. -- == -- input { 2 } output { 2 } def main x = let y: i32 = x in y futhark-0.25.27/tests/types/inference3.fut000066400000000000000000000001751475065116200204140ustar00rootroot00000000000000-- An inferred parameter can be put in an array. -- == -- input { 2 } output { [2] } def f x = [x] def main (x: i32) = f x futhark-0.25.27/tests/types/inference30.fut000066400000000000000000000001751475065116200204740ustar00rootroot00000000000000-- Field projection inference for a lambda. -- == -- input { 1 } output { [1] } def main (x: i32) = map (\r -> r.l) [{l=x}] futhark-0.25.27/tests/types/inference31.fut000066400000000000000000000001601475065116200204670ustar00rootroot00000000000000type^ img 'a = f32 -> a def f r: (img bool -> img bool) = \reg -> let g d x = reg (d+x:f32) in g r futhark-0.25.27/tests/types/inference32.fut000066400000000000000000000003071475065116200204730ustar00rootroot00000000000000-- Inferring a unique type is never allowed - it must always be put -- there explicitly! 
-- == -- error: Consuming.*"xs" def consume (xs: *[]i32) = xs def main (xs: []i32) = (\xs -> consume xs) xs futhark-0.25.27/tests/types/inference33.fut000066400000000000000000000003651475065116200205000ustar00rootroot00000000000000-- Local functions should not have their (overloaded) type fixed -- immediately, but rather wait until the top-level function has to be -- assigned a type. -- == -- input { 2u8 } output { 3u8 } def main (x: u8) = let inc y = y + 1 in inc x futhark-0.25.27/tests/types/inference34.fut000066400000000000000000000002671475065116200205020ustar00rootroot00000000000000-- Local functions should not have their (overloaded) record type fixed -- immediately. -- == -- input { 2 3 } output { 2 } def main (x: i32) (y: i32) = let f v = v.0 in f (x,y) futhark-0.25.27/tests/types/inference35.fut000066400000000000000000000004071475065116200204770ustar00rootroot00000000000000-- Type inference for infix operators. -- == -- input { 1f32 0f32 } output { 1f32 false true } -- input { 1f32 2f32 } output { 3f32 true true } def f op x y = x `op` y def g (+) x y = x + y def main (x: f32) (y: f32) = (f (+) x y, g (<) x y, g (<) false true) futhark-0.25.27/tests/types/inference36.fut000066400000000000000000000001771475065116200205040ustar00rootroot00000000000000-- Inference through a record/tuple type. def f b x (y: (i32, i32)) = if x.0 == y.0 then (x.0, 0) else if b then y else x futhark-0.25.27/tests/types/inference37.fut000066400000000000000000000002351475065116200205000ustar00rootroot00000000000000def I_mult (n: i64) (x: i64) (a: i64) : [n][n]i64 = let elem i j = i64.bool(i == j) * (if i == x then a else 1) in tabulate_2d n n elem futhark-0.25.27/tests/types/inference4.fut000066400000000000000000000002241475065116200204100ustar00rootroot00000000000000-- An inferred parameter can be returned from a branch. 
-- == -- input { 2 } output { 2 } def f x = if true then x else x def main (x: i32) = f x futhark-0.25.27/tests/types/inference5.fut000066400000000000000000000002101475065116200204040ustar00rootroot00000000000000-- Inference for a local function. -- == -- input { 2 } output { 4 } def main x = let apply f x = f x in apply (apply (i32.+) x) x futhark-0.25.27/tests/types/inference6.fut000066400000000000000000000001411475065116200204100ustar00rootroot00000000000000-- Disambiguating with type ascription. -- == -- input { 1 } output { 1 } def main x = (x: i32) futhark-0.25.27/tests/types/inference7.fut000066400000000000000000000001331475065116200204120ustar00rootroot00000000000000-- Disambiguating with return type. -- == -- input { 1 } output { 1 } def main x: i32 = x futhark-0.25.27/tests/types/inference8.fut000066400000000000000000000001521475065116200204140ustar00rootroot00000000000000-- Disambiguating with branch type. -- == -- input { 1 } output { 1 } def main x = if true then x else 0 futhark-0.25.27/tests/types/inference9.fut000066400000000000000000000001711475065116200204160ustar00rootroot00000000000000-- Disambiguating with conditional type. -- == -- input { true } output { true } def main x = if x then true else false futhark-0.25.27/tests/types/level0.fut000066400000000000000000000002351475065116200175570ustar00rootroot00000000000000-- Cannot unify a type parameter with a type bound in an outer scope. -- == -- error: "b".*scope def f x = let g 'b (y: b) = if true then y else x in g futhark-0.25.27/tests/types/level1.fut000066400000000000000000000002651475065116200175630ustar00rootroot00000000000000-- Cannot unify a type parameter with a type bound in an outer scope. 
-- == -- error: "b".*scope def f x = let g 'b (y: b) = if true then y else x.0 let (_, _: i32) = x in g futhark-0.25.27/tests/types/level2.fut000066400000000000000000000003101475065116200175530ustar00rootroot00000000000000-- A size restriction imposed by a lambda parameter may not affect -- anything free in the lambda. -- == -- error: "n".*scope violation def main (ys: []i32) = (\(n: i64) (xs: [n]i32) -> zip xs ys) futhark-0.25.27/tests/types/level3.fut000066400000000000000000000003331475065116200175610ustar00rootroot00000000000000-- A size restriction imposed by a local function parameter may not affect -- anything free in the function. -- == -- error: "n".*scope violation def main (ys: []i32) = let f (n: i64) (xs: [n]i32) = zip xs ys in f futhark-0.25.27/tests/types/level4.fut000066400000000000000000000005511475065116200175640ustar00rootroot00000000000000-- A size restriction imposed by a local function parameter may not affect -- a constructor of anything free in the function. -- == -- error: "n".*scope violation def main x = let f (n: i64) (xs: [n]i32) = zip xs (match x case #ys (ys: [n]i32) -> ys case _ -> xs) let x' = (x : (#ys ([]i32) | #null)) in f futhark-0.25.27/tests/types/level5.fut000066400000000000000000000002421475065116200175620ustar00rootroot00000000000000-- A local binding may not affect the type of an outer parameter. -- == -- error: "n".*scope violation def main (xs: []i32) = let n = 2+3 in zip (iota n) xs futhark-0.25.27/tests/types/level6.fut000066400000000000000000000003221475065116200175620ustar00rootroot00000000000000-- A size restriction imposed by a local size parameter may not affect -- anything free in the function. 
-- == -- error: "n".*scope violation def main (ys: []i32) = let f [n] (xs: [n]i32) = zip xs ys in f futhark-0.25.27/tests/types/lifted-abbrev.fut000066400000000000000000000001551475065116200210770ustar00rootroot00000000000000-- == -- error: "arr" type^ arr = [2]i32 type bad = [3]arr -- Bad, because we declared 'arr' to be lifted. futhark-0.25.27/tests/types/metasizes.fut000066400000000000000000000027301475065116200203760ustar00rootroot00000000000000-- A tricky test of type-level programming. -- == -- input { [1,2,3] [4,5,6] [7,8,9] } -- output { [1, 2, 3, 4, 5, 6, 7, 8, 9] -- [4, 5, 6, 1, 2, 3, 7, 8, 9] -- } module meta: { type eq[n][m] val coerce [n][m]'t : eq[n][m] -> [n]t -> [m]t val coerce_inner [n][m]'t [k] : eq[n][m] -> [k][n]t -> [k][m]t val refl [n] : eq[n][n] val comm [n][m] : eq[n][m] -> eq[m][n] val trans [n][m][k] : eq[n][m] -> eq[m][k] -> eq[n][k] val plus_comm [a][b] : eq[a+b][b+a] val plus_assoc [a][b][c] : eq[(a+b)+c][a+(b+c)] val plus_lhs [a][b][c] : eq[a][b] -> eq[a+c][b+c] val plus_rhs [a][b][c] : eq[c][b] -> eq[a+c][a+b] val mult_comm [a][b] : eq[a*b][b*a] val mult_assoc [a][b][c] : eq[(a*b)*c][a*(b*c)] val mult_lhs [a][b][c] : eq[a][b] -> eq[a+c][b+c] val mult_rhs [a][b][c] : eq[c][b] -> eq[a+c][a+b] } = { type eq[n][m] = [0][n][m]() def coerce [n][m]'t (_: eq[n][m]) (a: [n]t) = a :> [m]t def coerce_inner [n][m]'t [k] (_: eq[n][m]) (a: [k][n]t) = a :> [k][m]t def refl = [] def comm _ = [] def trans _ _ = [] def plus_comm = [] def plus_assoc = [] def plus_lhs _ = [] def plus_rhs _ = [] def mult_comm = [] def mult_assoc = [] def mult_lhs _ = [] def mult_rhs _ = [] } def main [n][m][l] (xs: [n]i32) (ys: [m]i32) (zs: [l]i32) = let proof : meta.eq[m+(n+l)][(n+m)+l] = meta.comm meta.plus_assoc `meta.trans` meta.plus_lhs meta.plus_comm in zip ((xs ++ ys) ++ zs) (meta.coerce proof (ys ++ (xs ++ zs))) |> unzip futhark-0.25.27/tests/types/overloaded0.fut000066400000000000000000000002261475065116200205740ustar00rootroot00000000000000-- Test of overloaded 
numeric types. -- == -- input { 10f32 } output { 0.01f32 } def main (x: f32) = let y = 0.001 let f (z: f32) = y*z in f x futhark-0.25.27/tests/types/overloaded1.fut000066400000000000000000000001451475065116200205750ustar00rootroot00000000000000-- Warn when top-level definitions are overloaded. -- == -- warning: ambiguous def plus x y = x + y futhark-0.25.27/tests/types/size-lifted-abbrev.fut000066400000000000000000000001551475065116200220470ustar00rootroot00000000000000-- == -- error: "arr" type~ arr = [2]i32 type bad = [3]arr -- Bad, because we declared 'arr' to be lifted. futhark-0.25.27/tests/types/size-lifted0.fut000066400000000000000000000001521475065116200206650ustar00rootroot00000000000000-- Size-lifted types can be returned from 'if'. def f '~a (b: bool) (x: a) (y: a) = if b then x else y futhark-0.25.27/tests/types/size-lifted1.fut000066400000000000000000000005221475065116200206670ustar00rootroot00000000000000-- Size-lifted types can be returned from loop. -- -- Pretty ridiculous that we have to write it this way! -- == -- input { 1 [1] } output { [1,1] } def iterate '~a (n: i32) (f: (() -> a) -> a) (x: () -> a) = loop x = x () for _i < n do f (\() -> x) def main n (xs: []i32) = iterate n (\(p : () -> []i32) -> p () ++ p ()) (\() -> xs) futhark-0.25.27/tests/types/size-lifted2.fut000066400000000000000000000004041475065116200206670ustar00rootroot00000000000000-- Size-lifted types must not unify with nonlifted types. -- == -- error: array element def packunpack 'a (f: i64 -> a) : a = (tabulate 10 f)[0] def apply 'a '~b (f: i64 -> a -> b) (x: a) = packunpack (\i -> f i x) def main (x: i32) = apply replicate x futhark-0.25.27/tests/types/sizeparams-error0.fut000066400000000000000000000001441475065116200217540ustar00rootroot00000000000000-- Too few arguments. 
-- == -- error: ints type ints [n] = [n]i32 def main(n: i32): ints = iota n futhark-0.25.27/tests/types/sizeparams-error1.fut000066400000000000000000000001541475065116200217560ustar00rootroot00000000000000-- Too many arguments. -- == -- error: ints type ints [n] = [n]i32 def main(n: i32): ints [1][2] = iota n futhark-0.25.27/tests/types/sizeparams-error2.fut000066400000000000000000000001351475065116200217560ustar00rootroot00000000000000-- Size parameters may not be duplicated. -- == -- error: n type matrix [n] [n] = [n][n]i32 futhark-0.25.27/tests/types/sizeparams0.fut000066400000000000000000000002761475065116200206330ustar00rootroot00000000000000-- Basic size-parameterised type. -- == -- input { 0i64 } output { empty([0]i64) } -- input { 3i64 } output { [0i64,1i64,2i64] } type ints [n] = [n]i64 def main(n: i64): ints [n] = iota n futhark-0.25.27/tests/types/sizeparams1.fut000066400000000000000000000002751475065116200206330ustar00rootroot00000000000000-- Size-parameterised type in parameter. -- == -- input { empty([0]i32) } output { 0i64 } -- input { [1,2,3] } output { 3i64 } type ints [n] = [n]i32 def main [n] (_: ints [n]) : i64 = n futhark-0.25.27/tests/types/sizeparams10.fut000066400000000000000000000005631475065116200207130ustar00rootroot00000000000000-- What about size parameters that are only known in complex expressions? -- == -- input { [1,2] [3,4] } -- output { [3,4,1,2] } type eq[n][m] = [n][m]() def coerce [n][m]'t (_: eq[n][m]) (a: [n]t) = a :> [m]t def plus_comm [a][b]'t : eq[a+b] [b+a] = tabulate_2d (a+b) (b+a) (\_ _ -> ()) def main [n][m] (xs: [n]i32) (ys: [m]i32) = copy (coerce plus_comm (ys ++ xs)) futhark-0.25.27/tests/types/sizeparams11.fut000066400000000000000000000005011475065116200207040ustar00rootroot00000000000000-- Another complicated case. 
-- == -- input { 1i64 2i64 } -- output { [[true, true, true], [true, true, true], [true, true, true]] } def plus a b : i64 = a + b def plus_comm [a][b]'t : [plus a b][plus b a]bool = tabulate_2d (plus a b) (plus b a) (\_ _ -> true) def main a b = copy plus_comm : [plus a b][plus b a]bool futhark-0.25.27/tests/types/sizeparams2.fut000066400000000000000000000005071475065116200206320ustar00rootroot00000000000000-- Multiple uses of same size in parameterised type. -- == -- input { empty([0]i32) empty([0]i32) } output { empty([0]i32) } -- input { [1,2,3] [1,2,3] } output { [1,2,3,1,2,3] } -- input { [1,2,3] [1,2,3,4] } error: type ints [n] = [n]i32 def main [n][m] (a: ints [n]) (b: ints [m]) = concat (a :> ints [n]) (b :> ints [n]) futhark-0.25.27/tests/types/sizeparams3.fut000066400000000000000000000006421475065116200206330ustar00rootroot00000000000000-- One size-parameterised type refers to another. -- == -- input { empty([0]i32) empty([0]i32) } output { empty([0]i32) empty([0]i32) } -- input { [1,2,3] [1,2,3] } output { [1,2,3] [1,2,3] } -- input { [1,2,3] [1,2,3,4] } output { [1,2,3] [1,2,3] } type ints [n] = [n]i32 type pairints [n] [m] = (ints [n], ints [m]) def main [n][m] (a: ints [n]) (b: ints [m]) = let b' = take n b in (a,b') : pairints [n] [n] futhark-0.25.27/tests/types/sizeparams4.fut000066400000000000000000000003131475065116200206270ustar00rootroot00000000000000-- Shadowing of size parameters. -- == -- input { 0i64 } output { empty([0]i64) } -- input { 3i64 } output { [0i64,1i64,2i64] } def n = 2i64 type ints [n] = [n]i64 def main(n: i64): ints [n] = iota n futhark-0.25.27/tests/types/sizeparams5.fut000066400000000000000000000002771475065116200206410ustar00rootroot00000000000000-- A size parameter can be a constant type. 
-- == -- input { 0i64 } error: Error -- input { 3i64 } output { [0i64,1i64,2i64] } type ints [n] = [n]i64 def main (n: i64) = iota n :> ints [3] futhark-0.25.27/tests/types/sizeparams6.fut000066400000000000000000000003201475065116200206270ustar00rootroot00000000000000-- Arrays of tuples work, too. -- == -- input { 2i64 3 } output { [3,3,3,3] } type pairvec [m] = [m](i32,i32) def main (n:i64) (e: i32): []i32 = let a: pairvec [] = replicate (2*n) (e,e) in (unzip a).0 futhark-0.25.27/tests/types/sizeparams7.fut000066400000000000000000000002321475065116200206320ustar00rootroot00000000000000-- No space is needed before the size argument. -- == -- input { 2i64 } output { [0i64,1i64] } type ints[n] = [n]i64 def main (n:i64): ints[n] = iota n futhark-0.25.27/tests/types/sizeparams8.fut000066400000000000000000000002141475065116200206330ustar00rootroot00000000000000-- If a name is used as a size, then it's probably an i32! -- == -- input { 3i64 [1,2,3] } output { [1,2,3] } def main n (xs: [n]i32) = xs futhark-0.25.27/tests/types/sizeparams9.fut000066400000000000000000000001671475065116200206430ustar00rootroot00000000000000-- From #1573. Make sure we can handle missing whitespace in the parser. type a [n] = [n]f32 let f (a:a[]) : a[] = a futhark-0.25.27/tests/types/square.fut000066400000000000000000000006151475065116200176720ustar00rootroot00000000000000-- Test that we can use type abbreviations to encode tricky existential situations. 
-- == -- input { 2i64 } -- output { [[0i64, 1i64, 2i64], [1i64, 2i64, 3i64], [2i64, 3i64, 4i64]] } type square [n] 't = [n][n]t def ext_square n : square [] i64 = tabulate_2d (n+1) (n+1) (\i j -> i + j) def tr_square [n] 't (s: square [n] t) : square [n] t = transpose s def main n = tr_square (ext_square n) futhark-0.25.27/tests/types/tricky.fut000066400000000000000000000003361475065116200176770ustar00rootroot00000000000000-- == -- input { 4 } -- output { [1i32, 1i32, 1i32] [4i32, 4i32, 4i32] } type^ t [n] = ([n]i32, i32 -> [n]i32) def v : t [] = let three = 3 in (replicate three 1, \i -> replicate three i) def main x = (copy v.0, v.1 x) futhark-0.25.27/tests/types/typeparams-error0.fut000066400000000000000000000001731475065116200217650ustar00rootroot00000000000000-- Missing parameter to a parametric type. -- == -- error: vector type vector 't = []t def main(n: i32): vector = iota n futhark-0.25.27/tests/types/typeparams-error1.fut000066400000000000000000000002271475065116200217660ustar00rootroot00000000000000-- Missing parameter to a multi-parameter parametric tpye. -- == -- error: pair type pair 'a 'b = (a,b) def main (x: i32) (y: f64): pair f64 = (y,x) futhark-0.25.27/tests/types/typeparams-error2.fut000066400000000000000000000001351475065116200217650ustar00rootroot00000000000000-- Type parameters may not be duplicated. -- == -- error: previously type whatevs 't 't = t futhark-0.25.27/tests/types/typeparams-error5.fut000066400000000000000000000002001475065116200217610ustar00rootroot00000000000000-- Cannot create an array containing elements of a lifted type parameter. -- == -- error: Cannot create array type t '^a = []a futhark-0.25.27/tests/types/typeparams0.fut000066400000000000000000000002241475065116200206330ustar00rootroot00000000000000-- A simple case of a parametric type. 
-- == -- input { 2i64 } output { [0i64,1i64] } type~ vector 't = []t def main(n: i64): vector i64 = iota n futhark-0.25.27/tests/types/typeparams1.fut000066400000000000000000000002321475065116200206330ustar00rootroot00000000000000-- Multi-parameter parametric type. -- == -- input { 1 2.0 } output { 2.0 1 } type pair 'a 'b = (a,b) def main (x: i32) (y: f64) = (y,x) : pair f64 i32 futhark-0.25.27/tests/types/typeparams2.fut000066400000000000000000000001551475065116200206400ustar00rootroot00000000000000-- Shadowing in type. -- == -- input { 2.0 } output { 2.0 } type t 'int = int def main (x: f64): t f64 = x futhark-0.25.27/tests/unflatten0.fut000066400000000000000000000002611475065116200173030ustar00rootroot00000000000000-- == -- input { 2i64 2i64 } -- output { [[0i64, 1i64], [2i64, 3i64]] } -- input { -2i64 -2i64 } -- error: Cannot unflatten.*\[-2\]\[-2\] def main n m = unflatten (iota (n*m)) futhark-0.25.27/tests/uniqueness/000077500000000000000000000000001475065116200167035ustar00rootroot00000000000000futhark-0.25.27/tests/uniqueness/uniqueness-error0.fut000066400000000000000000000002341475065116200230300ustar00rootroot00000000000000-- Type ascription should not hide aliases. -- == -- error: "a".*consumed def main(): i64 = let a = iota(10) let b:*[]i64 = a let b[0] = 1 in a[0] futhark-0.25.27/tests/uniqueness/uniqueness-error1.fut000066400000000000000000000004111475065116200230260ustar00rootroot00000000000000-- Test whether multiple references within the same sequence are -- detected. -- == -- error: "a".*consumed def main(): i32 = let n = 10 let a = iota(n) let b = iota(n) let (i,j) = (2,5) in (let a[i]=b[j] in 1) + (let b[j]=a[i] in 2) -- Error! futhark-0.25.27/tests/uniqueness/uniqueness-error10.fut000066400000000000000000000004451475065116200231150ustar00rootroot00000000000000-- Don't let occurences clash just because they're function arguments. 
-- == -- error: Cannot apply "f" def f(a: *[]i32): []i32 = a def main(): ([]i32, []i32) = let n = 10 let a = iota(n) let b = iota(n) let (i,j) = (2,5) in (f(let a[i]=b[j] in a),f(let b[j]=a[i] in b)) futhark-0.25.27/tests/uniqueness/uniqueness-error11.fut000066400000000000000000000004171475065116200231150ustar00rootroot00000000000000-- Make sure occurences are checked inside function parameters as well. -- == -- error: consumed def f(x: i32): i32 = x def main(): i32 = let n = 10 let a = iota(n) let b = iota(n) let (i,j) = (2,5) in f((let a[i]=b[j] in 1) + (let b[j]=a[i] in 2)) futhark-0.25.27/tests/uniqueness/uniqueness-error12.fut000066400000000000000000000004011475065116200231070ustar00rootroot00000000000000-- Don't let curried mapees consume more than once. -- == -- error: consumption def f(a: *[]i64) (i: i64): []i64 = let a[i] = 0 in a def main n = let a = iota(n) let b = iota(n) in map (f (a)) b -- Bad, because a may be consumed many times. futhark-0.25.27/tests/uniqueness/uniqueness-error13.fut000066400000000000000000000002131475065116200231110ustar00rootroot00000000000000-- No cheating uniqueness with tuple shenanigans. -- == -- error: aliased def main (x: (*[]i32, *[]i32)): (*[]i32, *[]i32) = (x.0, x.0) futhark-0.25.27/tests/uniqueness/uniqueness-error14.fut000066400000000000000000000006661475065116200231260ustar00rootroot00000000000000-- This program tests whether the compiler catches some of the more -- nasty side cases of aliasing in loops. -- == -- error: "arr" aliases "barr" def main(): i64 = let arr = copy(iota(10)) let barr = copy(iota(10)) in let arr = loop arr for i < 10 do let arr[i] = 0 in -- Consume arr and its aliases... barr -- Because of this, arr should be aliased to barr. in barr[0] -- Error, barr has been consumed! futhark-0.25.27/tests/uniqueness/uniqueness-error15.fut000066400000000000000000000003411475065116200231150ustar00rootroot00000000000000-- Test that shadowing does not break alias analysis. 
-- == -- error: consumed def main(): *[]i64 = let n = 10 let a = iota(n) let c = let a = a let a[0] = 42 in a in a -- Should be an error, because a was consumed. futhark-0.25.27/tests/uniqueness/uniqueness-error16.fut000066400000000000000000000003631475065116200231220ustar00rootroot00000000000000-- Test that complex shadowing does not break alias analysis. -- == -- error: consumed def main(): *[]i64 = let n = 10 let a = iota(n) let c = let (a, b) = (2, a) let b[0] = 42 in b in a -- Should be an error, because a was consumed. futhark-0.25.27/tests/uniqueness/uniqueness-error17.fut000066400000000000000000000004051475065116200231200ustar00rootroot00000000000000-- Test that aliasing is found, even if hidden inside a -- branch. -- == -- error: .*consumed.* def main n = let a = iota(n) let c = if 2==2 then iota(n) else a -- c aliases a. let c[0] = 4 in -- Consume c and a. a[0] -- Error, because a was consumed. futhark-0.25.27/tests/uniqueness/uniqueness-error18.fut000066400000000000000000000002511475065116200231200ustar00rootroot00000000000000-- Check that unique components of a return tuple do not alias each -- other. -- == -- error: unique def main(n: i64): (*[]i64, *[]i64) = let a = iota(n) in (a, a) futhark-0.25.27/tests/uniqueness/uniqueness-error19.fut000066400000000000000000000004031475065116200231200ustar00rootroot00000000000000-- Test that you cannot consume free variables in a loop. -- == -- error: not consumable def main = let n = 10 let a = iota(n) let b = iota(n) in loop b for i < n do let a[i] = i -- Error, because a is free and should not be consumed. in b futhark-0.25.27/tests/uniqueness/uniqueness-error2.fut000066400000000000000000000005101475065116200230270ustar00rootroot00000000000000-- Test that non-basic aliasing of an array results in an aliased -- array. -- == -- error: .*consumed.* def main(): []i64 = let n = 10 let a = replicate n (iota n) -- Note that a is 2-dimensional let b = a[0] -- Now b aliases a. 
let a[1] = replicate n 8 in -- Consume a, thus also consuming b. b -- Error! futhark-0.25.27/tests/uniqueness/uniqueness-error20.fut000066400000000000000000000003071475065116200231130ustar00rootroot00000000000000-- Test that you can't consume a free variable in a lambda. -- == -- error: unique def main(n: i64): i32 = let a = iota(n) let b = map (\(x: i32): i32 -> let a[x] = 4 in a[x]) (iota(n)) in 0 futhark-0.25.27/tests/uniqueness/uniqueness-error21.fut000066400000000000000000000005271475065116200231200ustar00rootroot00000000000000-- This benchmark tests whether aliasing is tracked even deep inside -- loops. -- == -- error: "DT".*consumed def floydSbsImp(N: i32, D: *[][]i32): [][]i32 = let DT = transpose(D) -- DT aliases D. in loop D for i < N do loop D for j < N do let D[i,j] = DT[j,i] in D -- Consume D and DT, but bad, because DT is used. futhark-0.25.27/tests/uniqueness/uniqueness-error22.fut000066400000000000000000000003341475065116200231150ustar00rootroot00000000000000-- Test that we cannot consume anything inside an anonymous function. -- == -- error: Consuming variable "a" def f(a: *[]i64) = a[0] def main(n: i64) = let a = iota(n) in foldl (\sum i -> sum + f(a)) 0 (iota(10)) futhark-0.25.27/tests/uniqueness/uniqueness-error23.fut000066400000000000000000000005041475065116200231150ustar00rootroot00000000000000-- == -- error: self-aliased def g(ar: *[]i64, a: *[][]i64): i64 = ar[0] def f(ar: *[]i64, a: *[][]i64): i64 = g(a[0], a) -- Should be a type error, as both are supposed to be -- unique yet they alias each other. def main(n: i64): i64 = let a = replicate n (iota n) let ar = copy(a[0]) in f(ar, a) futhark-0.25.27/tests/uniqueness/uniqueness-error24.fut000066400000000000000000000003671475065116200231250ustar00rootroot00000000000000-- Test that consumption checking is done even with no meaningful -- bindings. -- == -- error: consumed def consume(a: *[]i32): i32 = 0 -- OK. 
def main(a: *[]i32): []i32 = let _ = consume(a) in a -- Should fail, because a has been consumed! futhark-0.25.27/tests/uniqueness/uniqueness-error25.fut000066400000000000000000000004211475065116200231150ustar00rootroot00000000000000-- == -- error: "t", but this was consumed def f(t: ([]i32,*[]i32)): i32 = let (a,b) = t let b[0] = 1337 in a[0] def main(b: *[]i32): i32 = let a = b in -- Should fail, because 'a' and 'b' are aliased, yet the 'b' part of -- the tuple is consumed. f((a,b)) futhark-0.25.27/tests/uniqueness/uniqueness-error26.fut000066400000000000000000000004641475065116200231250ustar00rootroot00000000000000-- At one point, usage of SOAC array arguments when mapping with an -- operator was not registered properly. -- -- == -- error: "row".*consumed def main [w] (row : *[w]i32) : [w]u8 = let b = row -- b now aliases row let row[0] = 2 -- consume row in map u8.i32 b -- fail, because row has been consumed futhark-0.25.27/tests/uniqueness/uniqueness-error27.fut000066400000000000000000000003121475065116200231160ustar00rootroot00000000000000-- You may not consume a free variable inside of a lambda. -- -- == -- error: Consuming variable "a" def consume(a: *[]i32): []i32 = a def main(a: *[]i32): [][]i32 = map (\i -> consume a) (iota 10) futhark-0.25.27/tests/uniqueness/uniqueness-error28.fut000066400000000000000000000003151475065116200231220ustar00rootroot00000000000000-- Catch consumption even within curried expressions. -- == -- error: QUUX.*consumed def main () : []i32 = let QUUX = replicate 1 0 let y = scatter QUUX [0] [2] let xs = map (+ QUUX[0]) [1] in xs futhark-0.25.27/tests/uniqueness/uniqueness-error29.fut000066400000000000000000000002661475065116200231300ustar00rootroot00000000000000-- A local function whose free variable has been consumed. 
-- == -- error: QUUX.*consumed def main(y: i32, QUUX: *[]i32) = let f (x: i32) = x + QUUX[0] let QUUX[1] = 2 in f y futhark-0.25.27/tests/uniqueness/uniqueness-error3.fut000066400000000000000000000003241475065116200230330ustar00rootroot00000000000000-- == -- error: "a".*consumed def main n = let a = iota(n) let b = a -- b and a alias each other. let (i,j) = (2,5) in (let a[i]=b[j] in a[i]) -- Consume a, and thus also b. + b[j] -- Error! futhark-0.25.27/tests/uniqueness/uniqueness-error30.fut000066400000000000000000000002471475065116200231170ustar00rootroot00000000000000-- A local function may not consume anything free. -- == -- error: QUUX.*consumable def main(y: i32, QUUX: *[]i32) = let f (x: i32) = let QUUX[0] = x in x in f y futhark-0.25.27/tests/uniqueness/uniqueness-error31.fut000066400000000000000000000002231475065116200231120ustar00rootroot00000000000000-- In-place updates with 'with' can also have errors. -- == -- error: in-place def main [n] (a: *[][n]i32, i: i32): [][]i32 = a with [i] = a[0] futhark-0.25.27/tests/uniqueness/uniqueness-error33.fut000066400000000000000000000001671475065116200231230ustar00rootroot00000000000000-- Type ascriptions must respect uniqueness. -- == -- error: aliased to "x" def main (x: []i32) : *[]i32 = x : *[]i32 futhark-0.25.27/tests/uniqueness/uniqueness-error34.fut000066400000000000000000000002041475065116200231140ustar00rootroot00000000000000-- Pattern bindings must respect uniqueness. -- == -- error: aliased to "x" def main (x: []i32) : *[]i32 = let y : *[]i32 = x in y futhark-0.25.27/tests/uniqueness/uniqueness-error35.fut000066400000000000000000000002031475065116200231140ustar00rootroot00000000000000-- Loop parameters must respect uniqueness. 
-- == -- error: Consuming.*"x" def main (x: []i32) = loop (x: *[]i32) for i < 10 do x futhark-0.25.27/tests/uniqueness/uniqueness-error36.fut000066400000000000000000000003471475065116200231260ustar00rootroot00000000000000-- Ensure that we cannot cheat uniqueness typing with higher-order -- functions. -- == -- error: consuming def apply 'a 'b (f: a -> b) (x: a) = (f x, f x) def consume (xs: *[]i32) = 0 def main (arr: *[]i32) = apply consume arr futhark-0.25.27/tests/uniqueness/uniqueness-error37.fut000066400000000000000000000002371475065116200231250ustar00rootroot00000000000000-- No cheating uniqueness just by using a clever name for the array. -- == -- error: consumed def main ((++): *[]i32) = let _ = (++) with [0] = 0 in (++) futhark-0.25.27/tests/uniqueness/uniqueness-error38.fut000066400000000000000000000002551475065116200231260ustar00rootroot00000000000000-- A function must not return a global array. -- == -- error: aliases the free variable "global" def global: []i32 = [1,2,3] def main (b: bool) = if b then global else [] futhark-0.25.27/tests/uniqueness/uniqueness-error39.fut000066400000000000000000000003001475065116200231160ustar00rootroot00000000000000-- This is not OK, because it would imply consuming the original -- non-unique array. -- == -- error: Unique-typed return value def f (x: []i32): []i32 = x def main (a: []i32): *[]i32 = f a futhark-0.25.27/tests/uniqueness/uniqueness-error4.fut000066400000000000000000000004321475065116200230340ustar00rootroot00000000000000-- == -- error: "a".*consumed def f(a: *[]i64, i: i32, v: i64): i64 = let a[i]=v in a[i] def main(): i64 = let n = 10 let a = iota(n) let b = a -- a and b are aliases. let (i,j) = (2,5) in f(a,i,42) -- Consumes a (and b through the alias) + b[j] -- Error! futhark-0.25.27/tests/uniqueness/uniqueness-error40.fut000066400000000000000000000003211475065116200231110ustar00rootroot00000000000000-- This is not OK, because it would imply consuming the original -- non-unique array. 
-- == -- error: consumable def polyid 't (x: t) = x def main (xs: []i32) = let ys = polyid xs let ys[0] = 42 in ys futhark-0.25.27/tests/uniqueness/uniqueness-error41.fut000066400000000000000000000002111475065116200231100ustar00rootroot00000000000000-- Global variables may not be unique! -- == -- error: constant def global: *[]i32 = [1,2,3] def main (x: i32) = global with [0] = x futhark-0.25.27/tests/uniqueness/uniqueness-error42.fut000066400000000000000000000003621475065116200231200ustar00rootroot00000000000000-- When returning unique values from a loop, they must not alias each other. -- == -- error: aliases other consumed loop parameter def main (n: i64) = loop (xs: *[]i32, ys: *[]i32) = (replicate n 0, replicate n 0) for i < 10 do (xs, xs) futhark-0.25.27/tests/uniqueness/uniqueness-error43.fut000066400000000000000000000003761475065116200231260ustar00rootroot00000000000000-- When returning unique values from a loop, they must not alias each other. -- == -- error: aliases other consumed loop parameter def main (n: i64) = loop {xs: *[]i32, ys: *[]i32} = {xs=replicate n 0, ys=replicate n 0} for i < 10 do {xs=xs, ys=xs} futhark-0.25.27/tests/uniqueness/uniqueness-error44.fut000066400000000000000000000002701475065116200231200ustar00rootroot00000000000000-- A lambda function must not return a global array. -- == -- error: aliases the free variable "global" def global: []i32 = [1,2,3] def main = \(b: bool) -> if b then global else [] futhark-0.25.27/tests/uniqueness/uniqueness-error45.fut000066400000000000000000000003041475065116200231170ustar00rootroot00000000000000-- A local function must not return a global array. -- == -- error: aliases the free variable "global" def global: []i32 = [1,2,3] def main = let f (b: bool) = if b then global else [] in f futhark-0.25.27/tests/uniqueness/uniqueness-error46.fut000066400000000000000000000003161475065116200231230ustar00rootroot00000000000000-- The result of a functional argument may alias anything (unless it's -- unique). 
-- == -- error: "f".*which is not consumable def f (f: i32 -> []i32): i32 = let xs = f 1 let xs[0] = xs[0] + 2 in 2 futhark-0.25.27/tests/uniqueness/uniqueness-error47.fut000066400000000000000000000002521475065116200231230ustar00rootroot00000000000000-- Maintain aliases through record updates. -- == -- error: "ys".*consumed def main (xs: []i32) (ys: *[]i32) = let tup = (xs, ys) with 0 = xs let ys[0] = 0 in tup futhark-0.25.27/tests/uniqueness/uniqueness-error48.fut000066400000000000000000000004461475065116200231310ustar00rootroot00000000000000-- Record updates should respect uniqueness and aliases. -- == -- error: "s", which is not consumable type^ state = { size: i64, world: []i32 } def init (size: i64): state = {size, world = replicate size 0} def main (size: i64) (s: state) : *[]i32 = (init size with world = s.world).world futhark-0.25.27/tests/uniqueness/uniqueness-error49.fut000066400000000000000000000002361475065116200231270ustar00rootroot00000000000000-- Do not let ascription screw up uniqueness/aliasing. -- == -- error: Consuming.*"xs" def f 't (x: t) = id (x : t) def main (xs: []i32) = f xs with [0] = 0 futhark-0.25.27/tests/uniqueness/uniqueness-error5.fut000066400000000000000000000003421475065116200230350ustar00rootroot00000000000000-- == -- error: .*consumed.* def f(a: *[][]i64): i64 = a[0,0] def main n = let a = replicate n (iota n) let c = transpose a in -- Rearrange creates an alias. f(a) + c[0,0] -- f(a) consumes both a and c, so error. futhark-0.25.27/tests/uniqueness/uniqueness-error50.fut000066400000000000000000000003721475065116200231200ustar00rootroot00000000000000-- The result of a loop can have aliases. 
-- == -- error: "chunk" def vecadd [m] (xs: [m]i32) (ys: [m]i32): [m]i32 = ys def main [m] chunk_sz (chunk: [chunk_sz][m]i32): *[m]i32 = loop acc = replicate m 0 for i < chunk_sz do vecadd acc chunk[i] futhark-0.25.27/tests/uniqueness/uniqueness-error51.fut000066400000000000000000000002771475065116200231250ustar00rootroot00000000000000-- Type inference should not eliminate uniqueness checking. -- == -- error: Consuming.*"xs" def f {xs: []i32} : {xs: []i32} = {xs} def main xs = let {xs=ys} = f {xs} in ys with [0] = 0 futhark-0.25.27/tests/uniqueness/uniqueness-error52.fut000066400000000000000000000002441475065116200231200ustar00rootroot00000000000000-- Do not hide aliases with flatten. -- == -- error: Cannot apply def main [n] (xss: [n][n]i32): *[]i32 = let xs = flatten xss in scatter xs (iota n) (iota n) futhark-0.25.27/tests/uniqueness/uniqueness-error53.fut000066400000000000000000000003101475065116200231130ustar00rootroot00000000000000-- Do not hide global variables with flatten. -- == -- error: Cannot apply def xss : [][]i32 = [[1,2,3],[4,5,6]] def main (n: i32): *[]i32 = let xs = flatten xss in scatter xs (iota n) (iota n) futhark-0.25.27/tests/uniqueness/uniqueness-error54.fut000066400000000000000000000001771475065116200231270ustar00rootroot00000000000000-- == -- error: aliased to some other component. def dup x = (x,x) def main (xs: []i32) : (*[]i32, *[]i32) = dup (copy xs) futhark-0.25.27/tests/uniqueness/uniqueness-error55.fut000066400000000000000000000002111475065116200231150ustar00rootroot00000000000000-- == -- error: "y".*consumed def main n = let (a,b) = let y = iota n in (y,y) let a[0] = 0 let b[0] = 0 in (a,b) futhark-0.25.27/tests/uniqueness/uniqueness-error56.fut000066400000000000000000000003501475065116200231220ustar00rootroot00000000000000-- Properly check uniqueness of return values when calling -- higher-order functions. 
-- == -- error: def cons (f: () -> *[2]i32) : *[2]i32 = f () with [0] = 1 def main (x: [2]i32) : *[2]i32 = let f () : []i32 = x in cons f futhark-0.25.27/tests/uniqueness/uniqueness-error57.fut000066400000000000000000000004311475065116200231230ustar00rootroot00000000000000-- Track aliases properly in arrays in sum types. -- == -- error: "x".*consumed type t = #some ([2]i32) | #none def consume (x: *t): *[2]i32 = match x case #some arr -> arr with [0] = 1 case _ -> [0,0] def main (x: *t) = let a = consume x let b = consume x in (a,b) futhark-0.25.27/tests/uniqueness/uniqueness-error58.fut000066400000000000000000000003251475065116200231260ustar00rootroot00000000000000-- Derived from #1535. -- == -- error: aliased to "x" type sumType = #some ([0]i32) | #none entry main = (\(x: sumType): *[]i32 -> match x case (#some y) -> id y case _ -> []) (#none: sumType) futhark-0.25.27/tests/uniqueness/uniqueness-error59.fut000066400000000000000000000002351475065116200231270ustar00rootroot00000000000000-- == -- error: aliases the free variable "global" def global = ([1,2,3], 0) def return_global () = global def main i = (return_global ()).0 with [i] = 0 futhark-0.25.27/tests/uniqueness/uniqueness-error6.fut000066400000000000000000000003001475065116200230300ustar00rootroot00000000000000-- == -- error: "a".*consumed def f(t: (i32, *[]i64)): i32 = let (x, a) = t in x def main(): i64 = let n = 10 let a = iota(n) let t = (5, a) let c = f(t) in a[0] futhark-0.25.27/tests/uniqueness/uniqueness-error60.fut000066400000000000000000000002521475065116200231160ustar00rootroot00000000000000-- == -- error: result of applying "f".*consumed def f (n: i64) : ([]i64, []i64) = let a = iota n in (a,a) def main n = let (a,b) = f n let a[0] = 0 in (a,b) futhark-0.25.27/tests/uniqueness/uniqueness-error61.fut000066400000000000000000000003461475065116200231230ustar00rootroot00000000000000-- Based on #1842 -- -- The problem was that type ascription of a function type did not -- check that the uniqueness 
matched. -- == -- error: does not have expected type def f (xs: *[]f32) : f32 = 0f32 entry g : []f32 -> f32 = f futhark-0.25.27/tests/uniqueness/uniqueness-error62.fut000066400000000000000000000002541475065116200231220ustar00rootroot00000000000000-- Size expression should be non-consuming -- == -- error: "ns".*not consumable def consume (xs: *[]i64): i64 = xs[0] def f [n] (ns: *[n]i64) (xs: [consume ns]f32) = xs[0] futhark-0.25.27/tests/uniqueness/uniqueness-error63.fut000066400000000000000000000003241475065116200231210ustar00rootroot00000000000000-- Issue #1975 -- == -- error: aliased to some other component def main n : (*[]i64, *[]i64) = let (foo,bar) = loop _ = (iota 10,iota 10) for i < n do let arr = iota 10 in (arr,arr) in (foo,bar) futhark-0.25.27/tests/uniqueness/uniqueness-error64.fut000066400000000000000000000001351475065116200231220ustar00rootroot00000000000000-- From #2007 -- == -- error: Argument is consumed def main (xs: *[]i64) = scatter xs xs xs futhark-0.25.27/tests/uniqueness/uniqueness-error65.fut000066400000000000000000000002031475065116200231170ustar00rootroot00000000000000-- From #2067. -- == -- error: "xs".*not consumable def main (xs: *[]i32) = loop xs : []i32 for i < 10 do xs with [i] = i+1 futhark-0.25.27/tests/uniqueness/uniqueness-error7.fut000066400000000000000000000003141475065116200230360ustar00rootroot00000000000000-- == -- error: "a".*consumed def main n = let a = iota(n) let b = iota(n) let i = 0 in (let a[i]=b[i] in a[i]) + (let b=a in b[i]) -- Bad because of parallel consume-observe collision. futhark-0.25.27/tests/uniqueness/uniqueness-error8.fut000066400000000000000000000003321475065116200230370ustar00rootroot00000000000000-- == -- error: "a".*consumed def main(): i32 = let n = 10 let a = iota(n) let (i,j) = (2,5) let (c, a) = (let a[i] = 0 in 1, a[i]) in -- Error: consumes and observes a in same sequence. 5 -- Bad. 
futhark-0.25.27/tests/uniqueness/uniqueness-error9.fut000066400000000000000000000006041475065116200230420ustar00rootroot00000000000000-- This test tracks whether aliasing is propagated properly when -- tuples of differing dimensions is used as function parameters. -- == -- error: "a".*consumed def f(x: (i32, i32), t: (i32, i32, []i64)): []i64 = let (x, y, a) = t in a def main n = let a = iota(n) let t = (3, 4, a) let b = f((1,2), t) let a[0] = 2 in b -- Error, because b is aliased to t. futhark-0.25.27/tests/uniqueness/uniqueness0.fut000066400000000000000000000003041475065116200216770ustar00rootroot00000000000000-- Simplest possible in-place operation. -- == -- input { -- [1,2,3,4] -- 2 -- 10 -- } -- output { -- [1,2,10,4] -- } def main (a: *[]i32) (i: i32) (x: i32): []i32 = let a[i] = x in a futhark-0.25.27/tests/uniqueness/uniqueness1.fut000066400000000000000000000004551475065116200217070ustar00rootroot00000000000000-- Test that simple function argument consumption works. -- == -- input { -- } -- output { -- 0 -- } def f(a: *[]i64): i64 = a[0] def main: i32 = let n = 10 let b = iota(n) let a = b -- Alias a to b. let x = f(b) in -- Consumes both b and a because a is aliased to b. 0 -- OK! futhark-0.25.27/tests/uniqueness/uniqueness10.fut000066400000000000000000000004761475065116200217720ustar00rootroot00000000000000-- Test that complex shadowing does not break alias analysis. -- == -- input { 10i64 } -- output { -- [0i64, 1i64, 2i64, 3i64, 4i64, 5i64, 6i64, 7i64, 8i64, 9i64] -- } def main n: []i64 = let a = iota(n) let c = let (a, b) = (iota(n), a) let a[0] = 42 in a in a -- OK, because the outer a was never consumed. futhark-0.25.27/tests/uniqueness/uniqueness11.fut000066400000000000000000000004361475065116200217670ustar00rootroot00000000000000-- Test that map does not introduce aliasing when the row type is a -- basic type. 
-- == -- input { -- } -- output { -- 0i64 -- } def f (x: i64) = x def g (x: i64) = x def main: i64 = let a = iota(10) let x = map f a let a[1] = 3 let y = map g x in y[0] futhark-0.25.27/tests/uniqueness/uniqueness12.fut000066400000000000000000000004321475065116200217640ustar00rootroot00000000000000-- == -- input { -- } -- output { -- 0 -- } def main: i32 = let n = 10 let a = iota(n) let b = iota(n) let c = (a,b) let (a_,unused_b) = (a,b) let a[0] = 0 -- Only a_ and a are consumed. let (unused_a,b_) = (a,b) let b[0] = 1 in -- Only b_ and b are consumed. 0 futhark-0.25.27/tests/uniqueness/uniqueness13.fut000066400000000000000000000003221475065116200217630ustar00rootroot00000000000000-- == -- input { -- 42i64 -- } -- output { -- [1.000000] -- [2.000000] -- } def f(b_1: *[]i64): ([]f64,[]f64) = ([1.0],[2.0]) def main(n: i64): ([]f64, []f64) = let a = iota(n) let x = f(a) in x futhark-0.25.27/tests/uniqueness/uniqueness14.fut000066400000000000000000000003651475065116200217730ustar00rootroot00000000000000-- == -- input { -- 42i64 -- } -- output { -- [0i64, 1i64, 2i64, 3i64, 4i64, 5i64, 6i64, 7i64, 8i64, 9i64] -- } def f(b_1: *[]i64): *[]i64 = iota(10) def main(n: i64): []i64 = let a = iota(n) let x = if n == 0 then a else f(a) in x futhark-0.25.27/tests/uniqueness/uniqueness17.fut000066400000000000000000000011531475065116200217720ustar00rootroot00000000000000-- CSE once messed this one up after kernel extraction, because the -- same input array is used in two reduce kernels, and reduce-kernels -- may consume some input arrays (specifically, the scratch input -- array). 
-- -- == -- input { [1f32,2f32,3f32] 4f32 5f32 } -- output { 24.0f32 } -- structure gpu { SegRed 1 } def max(a: f32) (b: f32): f32 = if(a < b) then b else a def exactYhat(xs: []f32, x: f32): f32 = let ups = map (+x) xs let lo = reduce max (0.0) ups in lo + ups[0] def main (xs: []f32) (mux: f32) (eps: f32): f32 = let g = exactYhat(xs, mux + eps) let h = exactYhat(xs, mux - eps) in g + h futhark-0.25.27/tests/uniqueness/uniqueness18.fut000066400000000000000000000002721475065116200217740ustar00rootroot00000000000000-- When the map is simplified away, it must turn into a copy, as the -- result is consumed. -- -- == -- structure { Map 0 Replicate 1 } def main(as: []i32): *[]i32 = map (\x -> x) as futhark-0.25.27/tests/uniqueness/uniqueness19.fut000066400000000000000000000003441475065116200217750ustar00rootroot00000000000000-- A local function whose free variable has been consumed, but the -- function is never called! -- == -- input { 2 [1,2,3] } output { 2} def main(y: i32) (QUUX: *[]i32) = let f (x: i32) = x + QUUX[0] let QUUX[1] = 2 in y futhark-0.25.27/tests/uniqueness/uniqueness2.fut000066400000000000000000000002621475065116200217040ustar00rootroot00000000000000-- == -- input { -- } -- output { -- 3 -- } def main: i32 = let n = 10 let (a, b) = (replicate n 0, replicate n 0) let a[0] = 1 let b[0] = 2 in a[0] + b[0] futhark-0.25.27/tests/uniqueness/uniqueness20.fut000066400000000000000000000004361475065116200217670ustar00rootroot00000000000000-- It is fine to do an in-place update on something returned by a -- function that has aliases. -- == -- input { [1, 2, 3] } output { [42, 2, 3] } def id (xs: []i32) = xs def polyid 't (x: t) = x def main (xs: *[]i32) = let ys = id xs let ys = polyid xs let ys[0] = 42 in ys futhark-0.25.27/tests/uniqueness/uniqueness21.fut000066400000000000000000000001561475065116200217670ustar00rootroot00000000000000-- The array magically becomes unique! 
-- == def f (x: []i32): []i32 = x def main (a: *[]i32): *[]i32 = f a futhark-0.25.27/tests/uniqueness/uniqueness22.fut000066400000000000000000000001741475065116200217700ustar00rootroot00000000000000-- Once failed after internalisation. def main (arr: *[](i32, i32)) = let arr' = rotate 1 arr in arr' with [0] = (0,0) futhark-0.25.27/tests/uniqueness/uniqueness23.fut000066400000000000000000000002311475065116200217630ustar00rootroot00000000000000-- Could fail after internalisation. def consume (arr: *[](i32, i32)) = arr def main (arr: *[](i32, i32)) = let arr' = rotate 1 arr in consume arr futhark-0.25.27/tests/uniqueness/uniqueness24.fut000066400000000000000000000001621475065116200217670ustar00rootroot00000000000000def main (arr: *[]i32) = let a = arr[0] let arr' = rotate 1 arr let arr'[0] = 0 let arr'[1] = a in arr' futhark-0.25.27/tests/uniqueness/uniqueness25.fut000066400000000000000000000002051475065116200217660ustar00rootroot00000000000000def main [n] (m: i32) (xs: [n]i32) : [n]i32 = let foo = loop xs = copy xs for _d < m do xs let foo[n-1] = 0 in foo futhark-0.25.27/tests/uniqueness/uniqueness26.fut000066400000000000000000000006711475065116200217760ustar00rootroot00000000000000-- Carefully handle the case where there is a lot of alias overlapping -- going on between the target of an in-place update and the source, -- including inside a section! It was actually the section that we -- mishandled in lambda-lifting. 
-- == -- input { 0i64 3i64 [3f32,4f32,5f32,6f32] } -- output { [1.0f32, 1.3333334f32, 1.6666666f32, 6.0f32] } def main [n] (i: i64) (j: i64) (A: *[n]f32): []f32 = A with [i:j] = map (/A[0]) A[i:j] futhark-0.25.27/tests/uniqueness/uniqueness27.fut000066400000000000000000000004111475065116200217670ustar00rootroot00000000000000def main b1 b2 (A: *[]i32) (B: *[]i32) (C: *[]i32) = let X = if b1 then A else B let Y = if b2 then B else C let X[0] = 0 -- This is OK because while X aliases {A,B} and Y aliases {B,C}, -- there is no way for the consumption of X to touch C. in (X, C) futhark-0.25.27/tests/uniqueness/uniqueness3.fut000066400000000000000000000003631475065116200217070ustar00rootroot00000000000000-- == -- input { -- } -- output { -- [1.000000, 0.000000] -- [2.000000, 0.000000] -- } def main: ([]f64,[]f64) = let n = 2 let (arr1, arr2) = (replicate n 0.0, replicate n 0.0) let arr1[0] = 1.0 let arr2[0] = 2.0 in (arr1, arr2) futhark-0.25.27/tests/uniqueness/uniqueness4.fut000066400000000000000000000010261475065116200217050ustar00rootroot00000000000000-- This test inspired by code often created by -- arrays-of-tuples-to-tuple-of-arrays transformation. -- == -- input { -- } -- output { -- [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000] -- [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000] -- } def main: ([]f64, []f64) = let n = 10 in loop (a, b) = (replicate n 0.0, replicate n 0.0) for i < n do let a[ i ] = 0.0 let b[ i ] = 0.0 in (a, b) futhark-0.25.27/tests/uniqueness/uniqueness51.fut000066400000000000000000000001671475065116200217740ustar00rootroot00000000000000-- Overloaded numbers should not track aliases. 
def main = let arr = [3,7] let a = arr[0] let arr[0] = 0 in a futhark-0.25.27/tests/uniqueness/uniqueness52.fut000066400000000000000000000002541475065116200217720ustar00rootroot00000000000000type constants 'p = {v: [10]p} type~ t = {i32s: []i32, constants: constants i32} def f ({i32s, constants}: *t) : *t = let i32s[0] = 0 in {i32s, constants} futhark-0.25.27/tests/uniqueness/uniqueness53.fut000066400000000000000000000001141475065116200217660ustar00rootroot00000000000000def main (xs: *[]i32) = let xs' = zip xs xs let xs'[0] = (0,0) in xs' futhark-0.25.27/tests/uniqueness/uniqueness54.fut000066400000000000000000000001551475065116200217740ustar00rootroot00000000000000def dupcopy (xs: []i32) : (*[]i32, *[]i32) = (copy xs, copy xs) def main xs : (*[]i32, *[]i32) = dupcopy xs futhark-0.25.27/tests/uniqueness/uniqueness55.fut000066400000000000000000000002031475065116200217670ustar00rootroot00000000000000-- Proper aliasing inference for function. let f xs = map (+1i32) xs def main (xss: *[][]i32) = let xss[0] = f xss[0] in xss futhark-0.25.27/tests/uniqueness/uniqueness56.fut000066400000000000000000000001251475065116200217730ustar00rootroot00000000000000-- #1687 type t = {a:[10]i32, b:bool} def recUpdate (rec: *t) = rec.a with [0] = 1 futhark-0.25.27/tests/uniqueness/uniqueness57.fut000066400000000000000000000001771475065116200220030ustar00rootroot00000000000000-- == -- error: "f".*consumed def main [n] (xs: *[n]i32) = let f i = xs[i] in loop ys = xs for i < n do ys with [i] = f i futhark-0.25.27/tests/uniqueness/uniqueness58.fut000066400000000000000000000002441475065116200217770ustar00rootroot00000000000000def main to from (counts: *[]i32) (state: *[][]u8) = let state[to,counts[to]] = state[from,counts[from]-1] let counts[to] = counts[to] + 1 in (counts, state) futhark-0.25.27/tests/uniqueness/uniqueness59.fut000066400000000000000000000002241475065116200217760ustar00rootroot00000000000000-- It is ok to consuming non-free variables -- == def consume (xs: *[]i64): i64 = 
xs[0] def f [n] (ns: [n]i64) (xs: [consume (iota 10)]f32) = xs[0] futhark-0.25.27/tests/uniqueness/uniqueness6.fut000066400000000000000000000003111475065116200217030ustar00rootroot00000000000000-- == -- input { -- [1,2,3] -- } -- output { -- [1,2,3] -- } def main [n] (arr: [n]i32): []i32 = let newarr = (let notused = arr in replicate n 0) let newarr[0] = 0 in arr futhark-0.25.27/tests/uniqueness/uniqueness60.fut000066400000000000000000000002231475065116200217650ustar00rootroot00000000000000-- If consumption is on bounded var, no problem -- == -- warning: ^$ def consume (xs: *[]i64): i64 = xs[0] def f (n:i64) = iota (consume (iota n)) futhark-0.25.27/tests/uniqueness/uniqueness7.fut000066400000000000000000000005541475065116200217150ustar00rootroot00000000000000-- == -- input { -- } -- output { -- 0i64 -- } def f(a: *[][]i64) = a[0,0] def main: i64 = let n = 10 let a = replicate n (iota n) let b = replicate n (iota n) in let a = loop (a) for i < n do let a[i] = b[i] in a -- Does not alias a to b, because let-with is in-place! let x = f(b) in -- Consumes only b. a[0,0] -- OK futhark-0.25.27/tests/uniqueness/uniqueness8.fut000066400000000000000000000004411475065116200217110ustar00rootroot00000000000000-- == -- input { -- } -- output { -- 0i64 -- } def f(a: *[]i64) = a[0] def g(a: []i64) = a[0] def main: i64 = let n = 10 let a = iota(n) let b = a in if 1 == 2 then let c = g(b) in f(a) + c else let c = g(a) in f(b) -- OK Because only one branch is taken. futhark-0.25.27/tests/uniqueness/uniqueness9.fut000066400000000000000000000005611475065116200217150ustar00rootroot00000000000000-- This test checks whether we can consume something in a loop, -- without causing an error just because it's aliased outside the loop. 
-- == -- input { -- } -- output { -- 0 -- } def main: i32 = let n = 10 let inarr = replicate n 0 in let _ = loop outarr = inarr for i < n do if i == 0 then outarr else let outarr[i] = i in outarr in 0 futhark-0.25.27/tests/unscoping.fut000066400000000000000000000003441475065116200172320ustar00rootroot00000000000000-- == -- error: Cannot apply "bar" to "xs" def foo n = let (m,_) = (n+1,true) in (iota ((m+1)+1), \_ -> iota (m+1), \_ -> iota m) def bar [n] (_:[n+1]i64) = n def main n = let (xs, _, _) = foo n in bar xs futhark-0.25.27/tests/unused-reduce.fut000066400000000000000000000003601475065116200177730ustar00rootroot00000000000000-- Can we remove the unused parts of a reduction? -- == -- structure { Screma/BinOp 1 } def main (xs: []i32) (ys: []i32) = let (x', _) = reduce (\(x1, y1) (x2, y2) -> (x1 + x2, y1 * y2)) (0, 1) (zip xs ys) in x' futhark-0.25.27/tests/vasicek/000077500000000000000000000000001475065116200161315ustar00rootroot00000000000000futhark-0.25.27/tests/vasicek/iobound-mc2.fut000066400000000000000000000026231475065116200207720ustar00rootroot00000000000000-- An I/O-bound mc2 implementation. -- -- Useful for verification with a "sequential" R implementation. -- -- == -- Some useful (for mc2) Futhark extensions. def sum(xs: []f32): f32 = reduce (+) (0.0) xs def mean [n] (xs: [n]f32): f32 = sum(map (/f32.i64(n)) xs) -- Vasicek model parameters. 
def r0(): f32 = 0.03 -- initial interest rate def thetaP(): f32 = 0.03 -- short-rate mean def thetaQ(): f32 = 0.045 -- long-rate mean def kappa(): f32 = 0.1 -- speed of mean reversion def sigma(): f32 = 0.01 -- interest rate volatility def nextrP(lastr: f32, wp: f32): f32 = lastr + kappa() * (thetaP() - lastr) + sigma() * wp def nextrQ(lastr: f32, wq: f32): f32 = lastr + kappa() * (thetaQ() - lastr) + sigma() * wq def seqRedSumP [n] (lastr: f32, ws: [n]f32): f32 = loop (lastr) for i < n do nextrP(lastr, ws[i]) def seqRedSumQ [n] (lastr: f32, ws: [n]f32): f32 = loop (lastr) for i < n do nextrQ(lastr, ws[i]) def mc1step(wps: []f32): f32 = seqRedSumP(r0(), wps) def mc1 [n] (wpss: [n][]f32): [n]f32 = map mc1step wpss def mc2step (wqs: []f32) (r1: f32): f32 = seqRedSumQ(r1, wqs) def mc2sim [tn][k] (arg: ([tn][k]f32, f32)): f32 = let ( wqss, r1 ) = arg let sum_r = map2 mc2step wqss (replicate tn r1) in mean(sum_r) def mc2(wqsss: [][][]f32, r1s: []f32): []f32 = map mc2sim (zip wqsss r1s) def main(wpss: [][]f32, wqsss: [][][]f32): []f32 = --mc1(wpss) mc2(wqsss, mc1(wpss)) futhark-0.25.27/tests/zip0.fut000066400000000000000000000003411475065116200161040ustar00rootroot00000000000000-- Basic test that zip doesn't totally mess up everything. 
-- == -- input { -- [1,2,3] -- [4,5,6] -- } -- output { -- [1, 2, 3] -- [4, 5, 6] -- } def main [n] (a: [n]i32) (b: [n]i32): ([]i32,[]i32) = unzip(zip a b) futhark-0.25.27/tests_adhoc/000077500000000000000000000000001475065116200156425ustar00rootroot00000000000000futhark-0.25.27/tests_adhoc/bad_input/000077500000000000000000000000001475065116200176075ustar00rootroot00000000000000futhark-0.25.27/tests_adhoc/bad_input/.gitignore000066400000000000000000000000351475065116200215750ustar00rootroot00000000000000stderr bad_input bad_input.c futhark-0.25.27/tests_adhoc/bad_input/bad_input.fut000066400000000000000000000000571475065116200222760ustar00rootroot00000000000000-- == -- input { 1i32 } let main (x: i64) = x futhark-0.25.27/tests_adhoc/bad_input/test.sh000077500000000000000000000003161475065116200211250ustar00rootroot00000000000000#!/bin/sh set -e ! futhark test bad_input.fut > stderr if !(fgrep -q 'Expected input of types: i64' stderr) || !(fgrep -q 'Provided input of types: i32' stderr); then cat stderr exit 1 fi futhark-0.25.27/tests_adhoc/eval/000077500000000000000000000000001475065116200165715ustar00rootroot00000000000000futhark-0.25.27/tests_adhoc/eval/test.fut000066400000000000000000000000561475065116200202710ustar00rootroot00000000000000local def plus2_iota x = map (+ 2) (iota x) futhark-0.25.27/tests_adhoc/eval/test.sh000077500000000000000000000001051475065116200201030ustar00rootroot00000000000000#!/bin/sh futhark eval -f test.fut '1 + 1' 'iota 4' 'plus2_iota 15' futhark-0.25.27/tests_adhoc/hash/000077500000000000000000000000001475065116200165655ustar00rootroot00000000000000futhark-0.25.27/tests_adhoc/hash/test.fut000066400000000000000000000000271475065116200202630ustar00rootroot00000000000000let main (x: bool) = x futhark-0.25.27/tests_adhoc/hash/test.sh000077500000000000000000000001131475065116200200760ustar00rootroot00000000000000#!/bin/sh [ $(futhark hash test.fut) = $(futhark hash ../hash/test.fut) ] 
futhark-0.25.27/tests_adhoc/literate_trace/000077500000000000000000000000001475065116200206315ustar00rootroot00000000000000futhark-0.25.27/tests_adhoc/literate_trace/test.sh000077500000000000000000000002001475065116200221370ustar00rootroot00000000000000#!/bin/sh set -e rm -rf trace-img futhark literate trace.fut -v | tee | fgrep 'trace.fut:1:22-32: 1.000000 2.000000 3.000000' futhark-0.25.27/tests_adhoc/literate_trace/trace.fut000066400000000000000000000000761475065116200224520ustar00rootroot00000000000000def foo x = f32.sum (#[trace] x) -- > foo [1f32, 2f32, 3f32] futhark-0.25.27/tests_adhoc/profile/000077500000000000000000000000001475065116200173025ustar00rootroot00000000000000futhark-0.25.27/tests_adhoc/profile/.gitignore000066400000000000000000000000401475065116200212640ustar00rootroot00000000000000prog prog.c prog.json prog.prof futhark-0.25.27/tests_adhoc/profile/prog.fut000066400000000000000000000001741475065116200207730ustar00rootroot00000000000000-- == -- entry: inc1 inc2 -- input { [1,2,3,4] } entry inc1 (xs: []i32) = map (+1) xs entry inc2 (xs: []i32) = map (+2) xs futhark-0.25.27/tests_adhoc/profile/test.sh000077500000000000000000000003521475065116200206200ustar00rootroot00000000000000#!/bin/sh set -e with_backend() { echo echo "Trying backend $1" futhark bench --backend=$1 --profile --json prog.json prog.fut rm -rf prog.prof futhark profile prog.json } with_backend c with_backend multicore futhark-0.25.27/tests_adhoc/script/000077500000000000000000000000001475065116200171465ustar00rootroot00000000000000futhark-0.25.27/tests_adhoc/script/.gitignore000066400000000000000000000000301475065116200211270ustar00rootroot00000000000000testout double double.c futhark-0.25.27/tests_adhoc/script/double.fut000066400000000000000000000000321475065116200211330ustar00rootroot00000000000000entry main = map (f32.*2) futhark-0.25.27/tests_adhoc/script/test.sh000077500000000000000000000002211475065116200204570ustar00rootroot00000000000000#!/bin/sh set -e futhark script 
./double.fut '$store "testout" (main [1,2,3])' [ "$(futhark dataset x + 2) def r = { x = \x -> x + 2 } def h = map3 (\x y z -> x + y + z) (iota 10) (iota 10) (iota 10) futhark-0.25.27/tests_fmt/expected/lastarg.fut000066400000000000000000000001231475065116200213240ustar00rootroot00000000000000def f = tabulate_2d 2 3 \i j -> i + j def g = tabulate_2d 2 3 \i j -> i + j futhark-0.25.27/tests_fmt/expected/let_linebreak.fut000066400000000000000000000001421475065116200224700ustar00rootroot00000000000000-- If the binding is after a linebreak, then keep it there. def f x = let y = x + 2 in y futhark-0.25.27/tests_fmt/expected/local.fut000066400000000000000000000001061475065116200207620ustar00rootroot00000000000000-- | top level. local -- | I am a doc comment. module type mt = { } futhark-0.25.27/tests_fmt/expected/loop.fut000066400000000000000000000003021475065116200206370ustar00rootroot00000000000000def l1 x = loop x for i < 10 do x + i def l2 x = loop x for i < 10 do x + i def l3 x = loop x = x for i < 10 do x + i def l4 x = loop x = x for i < 10 do x + i futhark-0.25.27/tests_fmt/expected/modules.fut000066400000000000000000000011621475065116200213430ustar00rootroot00000000000000module type mt1 = { type a type b type c } module type mt2 = mt1 with a = i32 module type mt3 = mt1 with a = i32 with b = bool with c = f32 module m : mt1 with a = i32 with b = bool with c = f32 = { type a = i32 type b = bool type c = f32 def x = 123 def y = 321 } module type mt4 = { val f [n] : [n]i32 -> [n]i32 val g [n] : [n]i32 -> [n]i32 val g [n] : [n]i32 -> [n]i32 val block [m1] [m2] [n1] [n2] : (A: [m1][n1]i32) -> (B: [m1][n2]i32) -> (C: [m2][n1]i32) -> (D: [m2][n2]i32) -> [m1 + m2][n1 + n2]i32 } module pm1 (P: {}) : {} = { } futhark-0.25.27/tests_fmt/expected/nested_letIn.fut000066400000000000000000000001151475065116200223050ustar00rootroot00000000000000-- Extra 'in' should be removed def main = let n = 10 let m = 20 in 0 
futhark-0.25.27/tests_fmt/expected/nested_letWith.fut000066400000000000000000000001471475065116200226570ustar00rootroot00000000000000-- extra 'in' should also be removed from with terms def main = let x = [1, 2] let x[0] = 2 in x futhark-0.25.27/tests_fmt/expected/numbers.fut000066400000000000000000000005351475065116200213510ustar00rootroot00000000000000-- Is the precise spelling of numbers maintained? -- -- Put in some Unicode for good measure: ᚴᚬᛏᛏᛅ ᚴᚬ ᚠᛅᛋᛏ def a = 1_2_3 def b = 0x123 def c = 0rXIV def d = 100f32 def e = 1_2i32 def h = [ (0xf.fp1f32, 0x11.ffp0f32, 0xf.fp-2f32, -0x11.ffp0f32, 0x0.f0p0f32) , (0xf.fp1, 0x11.ffp0, 0xf.fp-2, -0x11.ffp0, 0x0.f0p0) ] futhark-0.25.27/tests_fmt/expected/records.fut000066400000000000000000000005561475065116200213420ustar00rootroot00000000000000type a = { a: i32 , b: i32 } type b = {a: i32, b: i32} type c = { a: i32 , b: i32 , c: ( bool , -- comment here bool , bool ) } def main = let a = 0i32 let b = 0i32 let c = 0i32 let x = {a, b, c} let {a, b, c} = x let {a = a, b = b, c = c} = x let x = { a = a , b = b , c } in x futhark-0.25.27/tests_fmt/expected/singlelines.fut000066400000000000000000000000641475065116200222070ustar00rootroot00000000000000def vecadd = map2 (f64.+) def vecmul = map2 (f64.*) futhark-0.25.27/tests_fmt/expected/sumtype.fut000066400000000000000000000001031475065116200213730ustar00rootroot00000000000000type t1 = #foo | #bar | #baz type t2 = #foo | #bar | #baz futhark-0.25.27/tests_fmt/expected/trailingComments1.fut000066400000000000000000000004341475065116200232740ustar00rootroot00000000000000-- Here is one comment -- Now I'll add some code type test = (i32, i32) -- here we have a trailing comments -- lets add some more code def record = { a = 1 , -- trying trailing b = 2 , -- in multiline comment -- also a test comment here c = 3 } -- one last comment futhark-0.25.27/tests_fmt/expected/trailingComments2.fut000066400000000000000000000000741475065116200232750ustar00rootroot00000000000000def a = ( 0 , -- 
Test 0 1 ) -- Test 1 def b = 1 futhark-0.25.27/tests_fmt/expected/with.fut000066400000000000000000000001761475065116200206520ustar00rootroot00000000000000def record (x: (f32, f32)) = x with 0 = 42 with 1 = 1337 def array (x: *[]f32) = x with [0] = 42 with [1] = 1337 futhark-0.25.27/tests_fmt/header_comment.fut000066400000000000000000000000741475065116200210450ustar00rootroot00000000000000-- # This comment is the first and only thing in this file. futhark-0.25.27/tests_fmt/if.fut000066400000000000000000000003261475065116200164710ustar00rootroot00000000000000def a = if true then 1 else 2 def b = if true then 1 + 3 + 5 else 2 def c = if true then 1 else if true then 2 else 3 def d = if true then -- foo true else -- bar true futhark-0.25.27/tests_fmt/lambda.fut000066400000000000000000000002321475065116200173070ustar00rootroot00000000000000def f = (\x -> x + 2) def r = { x = \x -> x + 2 } def h = map3 (\x y z -> x + y + z) (iota 10) (iota 10) (iota 10) futhark-0.25.27/tests_fmt/lastarg.fut000066400000000000000000000001471475065116200175310ustar00rootroot00000000000000def f = tabulate_2d 2 3 \i j -> i + j def g = tabulate_2d 2 3 \i j -> i + j futhark-0.25.27/tests_fmt/let_linebreak.fut000066400000000000000000000001421475065116200206670ustar00rootroot00000000000000-- If the binding is after a linebreak, then keep it there. def f x = let y = x + 2 in y futhark-0.25.27/tests_fmt/local.fut000066400000000000000000000001061475065116200171610ustar00rootroot00000000000000-- | top level. local -- | I am a doc comment. 
module type mt = { } futhark-0.25.27/tests_fmt/loop.fut000066400000000000000000000002711475065116200170430ustar00rootroot00000000000000def l1 x = loop x for i < 10 do x + i def l2 x = loop x for i < 10 do x + i def l3 x = loop x = x for i < 10 do x + i def l4 x = loop x = x for i < 10 do x + i futhark-0.25.27/tests_fmt/modules.fut000066400000000000000000000012221475065116200175370ustar00rootroot00000000000000module type mt1 = { type a type b type c } module type mt2 = mt1 with a = i32 module type mt3 = mt1 with a = i32 with b = bool with c = f32 module m : mt1 with a = i32 with b = bool with c = f32 = { type a = i32 type b = bool type c = f32 def x = 123 def y = 321 } module type mt4 = { val f [n] : [n]i32 -> [n]i32 val g [n] : [n]i32 -> [n]i32 val g [n] : [n]i32 -> [n]i32 val block [m1] [m2] [n1] [n2] : (A: [m1][n1]i32) -> (B: [m1][n2]i32) -> (C: [m2][n1]i32) -> (D: [m2][n2]i32) -> [m1 + m2][n1 + n2]i32 } module pm1 (P: {}) : {} = { } futhark-0.25.27/tests_fmt/nested_letIn.fut000066400000000000000000000001201475065116200205000ustar00rootroot00000000000000-- Extra 'in' should be removed def main = let n = 10 in let m = 20 in 0 futhark-0.25.27/tests_fmt/nested_letWith.fut000066400000000000000000000001551475065116200210550ustar00rootroot00000000000000-- extra 'in' should also be removed from with terms def main = let x = [1, 2] let x[0] = 2 in x futhark-0.25.27/tests_fmt/numbers.fut000066400000000000000000000005351475065116200175500ustar00rootroot00000000000000-- Is the precise spelling of numbers maintained? 
-- -- Put in some Unicode for good measure: ᚴᚬᛏᛏᛅ ᚴᚬ ᚠᛅᛋᛏ def a = 1_2_3 def b = 0x123 def c = 0rXIV def d = 100f32 def e = 1_2i32 def h = [ (0xf.fp1f32, 0x11.ffp0f32, 0xf.fp-2f32, -0x11.ffp0f32, 0x0.f0p0f32) , (0xf.fp1, 0x11.ffp0, 0xf.fp-2, -0x11.ffp0, 0x0.f0p0) ] futhark-0.25.27/tests_fmt/records.fut000066400000000000000000000005231475065116200175330ustar00rootroot00000000000000type a = {a: i32 , b: i32} type b = {a: i32 , b: i32} type c = {a: i32 , b: i32, c: (bool, -- comment here bool, bool)} def main = let a = 0i32 let b = 0i32 let c = 0i32 let x = {a, b, c} let {a,b,c} = x let {a=a,b=b,c=c} = x let x = {a = a , b = b, c} in x futhark-0.25.27/tests_fmt/singlelines.fut000066400000000000000000000000641475065116200204060ustar00rootroot00000000000000def vecadd = map2 (f64.+) def vecmul = map2 (f64.*) futhark-0.25.27/tests_fmt/sumtype.fut000066400000000000000000000001031475065116200175720ustar00rootroot00000000000000type t1 = #foo | #bar | #baz type t2 = #foo | #bar | #baz futhark-0.25.27/tests_fmt/test.sh000077500000000000000000000007071475065116200166740ustar00rootroot00000000000000#!/bin/sh test_dir="TEMP" diff_error=0 rm -rf "$test_dir" && mkdir "$test_dir" for file in *.fut; do fmtFile=$test_dir/$(basename -s .fut $file).fmt.fut if futhark fmt --check expected/$file; then futhark fmt < $file > $fmtFile if ! 
cmp --silent expected/$file $fmtFile; then echo "$file didn't format as expected" diff_error=1 else rm $fmtFile fi fi done exit $diff_error futhark-0.25.27/tests_fmt/trailingComments1.fut000066400000000000000000000004211475065116200214670ustar00rootroot00000000000000-- Here is one comment -- Now I'll add some code type test = (i32, i32) -- here we have a trailing comments -- lets add some more code let record = { a=1, -- trying trailing b=2, -- in multiline comment -- also a test comment here c=3} -- one last commentfuthark-0.25.27/tests_fmt/trailingComments2.fut000066400000000000000000000000631475065116200214720ustar00rootroot00000000000000def a = (0 -- Test 0 ,1) -- Test 1 def b = 1futhark-0.25.27/tests_fmt/with.fut000066400000000000000000000001771475065116200170520ustar00rootroot00000000000000def record (x: (f32, f32)) = x with 0 = 42 with 1 = 1337 def array (x: *[]f32) = x with [0] = 42 with [1] = 1337 futhark-0.25.27/tests_lib/000077500000000000000000000000001475065116200153325ustar00rootroot00000000000000futhark-0.25.27/tests_lib/README.md000066400000000000000000000005641475065116200166160ustar00rootroot00000000000000# Tests of Futhark's library backends These tests are written in an ad hoc fashion, as the usual `futhark test` tool only handles executables. Since executables are to a large extent built using the exact same library code, we only need to test library-specific concerns. Specifically, the handling of opaque types (which don't work in executables anyway) is important. 
futhark-0.25.27/tests_lib/c/000077500000000000000000000000001475065116200155545ustar00rootroot00000000000000futhark-0.25.27/tests_lib/c/.gitignore000066400000000000000000000000231475065116200175370ustar00rootroot00000000000000* !*.fut !test_*.c futhark-0.25.27/tests_lib/c/Makefile000066400000000000000000000022671475065116200172230ustar00rootroot00000000000000FUTHARK_BACKEND ?= c ifeq ($(FUTHARK_BACKEND),opencl) CFLAGS=-O3 -std=c99 LDFLAGS=-lm -lOpenCL else ifeq ($(FUTHARK_BACKEND),multicore) CFLAGS=-O3 -std=c99 LDFLAGS=-lm else ifeq ($(FUTHARK_BACKEND),ispc) CFLAGS=-O3 -std=c99 LDFLAGS=-lm ISPCFLAGS=--woff --pic --addressing=64 else ifeq ($(FUTHARK_BACKEND),cuda) CFLAGS=-O3 -std=c99 LDFLAGS=-lm -lcuda -lcudart -lnvrtc else ifeq ($(FUTHARK_BACKEND),hip) CFLAGS=-O3 -std=c99 LDFLAGS=-lm -lamdhip64 -lhiprtc else CFLAGS=-O3 -std=c99 LDFLAGS=-lm endif .SECONDARY: .PHONY: test clean test: $(patsubst %.fut, do_test_%, $(wildcard *.fut)) do_test_%: test_% ./validatemanifest.py ../../docs/manifest.schema.json $*.json ./test_$* test_%: test_%.c %.o ifeq ($(FUTHARK_BACKEND),ispc) gcc -o $@ $^ $(patsubst test_%, %_ispc.o, $@) -Wall -Wextra -pedantic -std=c99 $(LDFLAGS) -lpthread else gcc -o $@ $^ -Wall -Wextra -pedantic -std=c99 $(LDFLAGS) -lpthread endif %.o: %.c ifeq ($(FUTHARK_BACKEND),ispc) ispc -o $*_ispc.o $*.kernels.ispc $(ISPCFLAGS) endif gcc $*.c -c $(CFLAGS) %.c: %.fut futhark $(FUTHARK_BACKEND) --library $^ clean: rm -rf $(patsubst %.c, %, $(wildcard test_*.c)) *.h *.o *.ispc $(patsubst %.fut, %.c, $(wildcard *.fut)) futhark-0.25.27/tests_lib/c/README.md000066400000000000000000000003141475065116200170310ustar00rootroot00000000000000# Tests of Futhark's C backends In practice, the different C backends are very alike in their implementation, so we hopefully don't need to test them all individually. To run the tests, execute `make`. 
futhark-0.25.27/tests_lib/c/a.fut000066400000000000000000000002521475065116200165130ustar00rootroot00000000000000type s = i32 entry a (x: s) : s = x + 2 type t1 = {x:[1]i32} type t2 = t1 entry b (x: i32) : t1 = {x=[x + 3]} entry c : t1 -> t2 = id entry d ({x}: t2) : i32 = x[0] futhark-0.25.27/tests_lib/c/b.fut000066400000000000000000000000751475065116200165170ustar00rootroot00000000000000type r = {x: bool} entry a (x: bool) : (i32, r) = (32, {x}) futhark-0.25.27/tests_lib/c/c.fut000066400000000000000000000000371475065116200165160ustar00rootroot00000000000000let main n = map (2+) (iota n) futhark-0.25.27/tests_lib/c/const_error.fut000066400000000000000000000002031475065116200206260ustar00rootroot00000000000000-- Errors in constants must be detectable. -- == -- input { 2 } -- error: false let bad = assert false 0i32 let main x = x + bad futhark-0.25.27/tests_lib/c/const_error_par.fut000066400000000000000000000003171475065116200214760ustar00rootroot00000000000000-- Errors in parallelism in constants. -- == -- input { 2i64 } -- error: out of bounds let n = 10i64 let arr = iota n let bad = map (\i -> arr[if i == 0 then -1 else i]) (iota n) let main x = map (+x) bad futhark-0.25.27/tests_lib/c/d.fut000066400000000000000000000004251475065116200165200ustar00rootroot00000000000000type triad 'a = (a, a, a) type v3 = triad f32 type m33 = triad v3 entry toM33 a0 a1 a2 b0 b1 b2 c0 c1 c2 : m33 = ( (a0, a1, a2) , (b0, b1, b2) , (c0, c1, c2) ) entry fromM33 (m:m33) = ( m.0.0, m.0.1, m.0.2 , m.1.0, m.1.1, m.1.2 , m.2.0, m.2.1, m.2.2 ) futhark-0.25.27/tests_lib/c/e.fut000066400000000000000000000002241475065116200165160ustar00rootroot00000000000000-- Test that error states are not sticky from one call of an entry -- point to the next. 
let main (xs: []f32) (is: []i32) = map (\i -> xs[i]) is futhark-0.25.27/tests_lib/c/f.fut000066400000000000000000000000431475065116200165160ustar00rootroot00000000000000let main (xs: []i32) = map (+2) xs futhark-0.25.27/tests_lib/c/g.fut000066400000000000000000000003311475065116200165170ustar00rootroot00000000000000-- Test that size constraints on opaque types are respected. type vec [n] = {vec: [n]i64} entry mk_vec (n: i64) : vec [n] = {vec=iota n} entry use_vec [n] (x: vec [n]) (y: vec [n]) = i64.sum (map2 (+) x.vec y.vec) futhark-0.25.27/tests_lib/c/index.fut000066400000000000000000000001051475065116200173770ustar00rootroot00000000000000entry fun1 n = iota n entry fun2 n = tabulate_2d n n (\i j -> i + j) futhark-0.25.27/tests_lib/c/opaque_array.fut000066400000000000000000000005251475065116200207660ustar00rootroot00000000000000module m : { type t val mk : i32 -> [2]f32 -> t val unmk : t -> (i32,[2]f32) } = { type t = #foo (i32, [2]f32) def mk x y : t = #foo (x,y) def unmk (x: t) = match x case #foo (x,y) -> (x,y) } type t = m.t entry mk (x: i32) (y: [2]f32) : t = m.mk x y entry unmk (x: t): (i32, [2]f32) = m.unmk x entry arr (x: t) : [2]t = [x, x] futhark-0.25.27/tests_lib/c/phantomsize.fut000066400000000000000000000002441475065116200206350ustar00rootroot00000000000000type size [n] = [0][n]() type~ state = size [] entry construct (n: i64) : state = [] : [][n]() entry destruct (s: state) : i64 = let [n] (_: size [n]) = s in n futhark-0.25.27/tests_lib/c/project.fut000066400000000000000000000002561475065116200177450ustar00rootroot00000000000000type sum = #foo i32 | #bar i32 type t0 [n] = ([n]u32,f16,sum) type t1 [n] = (t0 [n],[3]f32) entry main0 [n] (p: *t1 [n]): t1 [] = p entry main1 [n] (y: t0 [n]) : t0 [] = y futhark-0.25.27/tests_lib/c/raw.fut000066400000000000000000000000451475065116200170640ustar00rootroot00000000000000entry main (xs: []i32) = map (+2) xs 
futhark-0.25.27/tests_lib/c/record_array.fut000066400000000000000000000001531475065116200207470ustar00rootroot00000000000000entry main (xs: []i32) (ys: []f32) = zip xs (zip ys ys) entry main2 (xs: [][]i32) (ys: []f32) = zip xs ys futhark-0.25.27/tests_lib/c/restore.fut000066400000000000000000000004131475065116200177550ustar00rootroot00000000000000-- Test that we can store and restore opaques. type whatever [n] = {vec: [n](i64, bool), b:bool} entry mk (b: bool) : whatever [] = {vec = [(1,b),(2, !b)], b} entry unmk ({vec, b}: whatever []) : ([]i64, []bool, bool) = let (xs,ys) = unzip vec in (xs, ys, b) futhark-0.25.27/tests_lib/c/sum.fut000066400000000000000000000006451475065116200171050ustar00rootroot00000000000000-- Tests various complicated sum types and destructing/constructing -- them. type~ contrived = #foo ([]i32) bool | #bar bool ([]u32) | #baz ([]i32) ([]i32) entry next (c: contrived) : contrived = match c case #foo arr b -> #bar (!b) (map u32.i32 (map (+1) arr)) case #bar b arr -> #baz (map i32.u32 (map (*2) arr)) (map i32.u32 (map (+u32.bool b) arr)) case #baz x y -> #foo (map2 (+) x y) (i32.sum x % 2 == 0) futhark-0.25.27/tests_lib/c/test_a.c000066400000000000000000000013411475065116200171760ustar00rootroot00000000000000#include "a.h" #include int main() { struct futhark_context_config *cfg = futhark_context_config_new(); struct futhark_context *ctx = futhark_context_new(cfg); int err; int x, y; x = 3; err = futhark_entry_a(ctx, &y, x); assert(err == 0); assert(y==x+2); struct futhark_opaque_t1 *t1; err = futhark_entry_b(ctx, &t1, x); assert(err == 0); struct futhark_opaque_t2 *t2; err = futhark_entry_c(ctx, &t2, t1); assert(err == 0); err = futhark_free_opaque_t1(ctx, t1); assert(err == 0); err = futhark_entry_d(ctx, &y, t2); assert(err == 0); assert(y==x+3); err = futhark_free_opaque_t2(ctx, t2); assert(err == 0); futhark_context_free(ctx); futhark_context_config_free(cfg); } 
futhark-0.25.27/tests_lib/c/test_b.c000066400000000000000000000006671475065116200172110ustar00rootroot00000000000000#include "b.h" #include int main() { struct futhark_context_config *cfg = futhark_context_config_new(); struct futhark_context *ctx = futhark_context_new(cfg); int err; struct futhark_opaque_r *r; int x; err = futhark_entry_a(ctx, &x, &r, true); assert(err == 0); assert(x == 32); err = futhark_free_opaque_r(ctx, r); assert(err == 0); futhark_context_free(ctx); futhark_context_config_free(cfg); } futhark-0.25.27/tests_lib/c/test_c.c000066400000000000000000000013671475065116200172100ustar00rootroot00000000000000#include "c.h" #include #include // Test repeated creations and destructions of context. If the // context does not clean up properly after itself, then it is likely // that this test will fail to run. static const int runs = 100; static const int alloc_per_run = 1024*1024*1024; // 1GiB int main() { for (int i = 0; i < runs; i++) { struct futhark_context_config *cfg = futhark_context_config_new(); struct futhark_context *ctx = futhark_context_new(cfg); int err; struct futhark_i64_1d *arr; err = futhark_entry_main(ctx, &arr, alloc_per_run/8); assert(err == 0); err = futhark_free_i64_1d(ctx, arr); assert(err == 0); futhark_context_free(ctx); futhark_context_config_free(cfg); } } futhark-0.25.27/tests_lib/c/test_const_error.c000066400000000000000000000006421475065116200213200ustar00rootroot00000000000000#include "const_error.h" #include #include #include int main() { struct futhark_context_config *cfg = futhark_context_config_new(); struct futhark_context *ctx = futhark_context_new(cfg); char *err = futhark_context_get_error(ctx); assert(err != NULL); assert(strstr(err, "false") != NULL); free(err); futhark_context_free(ctx); futhark_context_config_free(cfg); } futhark-0.25.27/tests_lib/c/test_const_error_par.c000066400000000000000000000006561475065116200221670ustar00rootroot00000000000000#include "const_error_par.h" #include #include #include int 
main() { struct futhark_context_config *cfg = futhark_context_config_new(); struct futhark_context *ctx = futhark_context_new(cfg); char *err = futhark_context_get_error(ctx); assert(err != NULL); assert(strstr(err, "out of bounds") != NULL); free(err); futhark_context_free(ctx); futhark_context_config_free(cfg); } futhark-0.25.27/tests_lib/c/test_d.c000066400000000000000000000016501475065116200172040ustar00rootroot00000000000000#include "d.h" #include #include int main() { struct futhark_context_config *cfg = futhark_context_config_new(); struct futhark_context *ctx = futhark_context_new(cfg); int err; float xs[9] = { 0, 1, 2, 3, 4, 5, 6, 7, 8 }; float ys[9]; struct futhark_opaque_m33 *m; err = futhark_entry_toM33(ctx, &m, xs[0], xs[1], xs[2], xs[3], xs[4], xs[5], xs[6], xs[7], xs[8]); assert(err == 0); err = futhark_entry_fromM33(ctx, &ys[0], &ys[1], &ys[2], &ys[3], &ys[4], &ys[5], &ys[6], &ys[7], &ys[8], m); assert(err == 0); assert(memcmp(xs, ys, sizeof(xs)) == 0); err = futhark_free_opaque_m33(ctx, m); assert(err == 0); futhark_context_free(ctx); futhark_context_config_free(cfg); } futhark-0.25.27/tests_lib/c/test_e.c000066400000000000000000000034561475065116200172130ustar00rootroot00000000000000#include "e.h" #include #include #include int main() { struct futhark_context_config *cfg = futhark_context_config_new(); struct futhark_context *ctx = futhark_context_new(cfg); int err; float xs[] = { 42, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; int is0[] = { -1 }; int is1[] = { 0 }; struct futhark_f32_1d *xs_fut = futhark_new_f32_1d(ctx, xs, 10); struct futhark_i32_1d *is0_fut = futhark_new_i32_1d(ctx, is0, 1); struct futhark_i32_1d *is1_fut = futhark_new_i32_1d(ctx, is1, 1); assert(xs_fut != NULL); assert(is0_fut != NULL); assert(is1_fut != NULL); float out[1]; struct futhark_f32_1d *out_fut = NULL; err = futhark_entry_main(ctx, &out_fut, xs_fut, is0_fut); #if defined(FUTHARK_BACKEND_c) || defined(FUTHARK_BACKEND_multicore) || defined(FUTHARK_BACKEND_ispc) assert(err == 
FUTHARK_PROGRAM_ERROR); err = futhark_context_sync(ctx); assert(err == 0); #else assert(err == 0); err = futhark_context_sync(ctx); assert(err == FUTHARK_PROGRAM_ERROR); #endif char *error = futhark_context_get_error(ctx); assert(strstr(error, "Index [-1] out of bounds") != NULL); free(error); if (out_fut != NULL) { futhark_free_f32_1d(ctx, out_fut); } err = futhark_entry_main(ctx, &out_fut, xs_fut, is1_fut); assert(err == 0); err = futhark_context_sync(ctx); assert(err == 0); err = futhark_values_f32_1d(ctx, out_fut, out); assert(err == 0); err = futhark_context_sync(ctx); assert(err == 0); assert(out[0] == xs[is1[0]]); err = futhark_free_f32_1d(ctx, xs_fut); assert(err == 0); err = futhark_free_i32_1d(ctx, is0_fut); assert(err == 0); err = futhark_free_i32_1d(ctx, is1_fut); assert(err == 0); err = futhark_free_f32_1d(ctx, out_fut); assert(err == 0); futhark_context_free(ctx); futhark_context_config_free(cfg); } futhark-0.25.27/tests_lib/c/test_f.c000066400000000000000000000025711475065116200172110ustar00rootroot00000000000000#include "f.h" #include #include #include #include #include int xs[] = { 42, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; const int n = sizeof(xs)/sizeof(int); struct thread_args { struct futhark_context *ctx; struct futhark_i32_1d *xs_fut; }; void* thread_fn(void* arg) { struct thread_args args = *(struct thread_args*)arg; struct futhark_context *ctx = args.ctx; struct futhark_i32_1d *out_fut = NULL; int err; err = futhark_entry_main(ctx, &out_fut, args.xs_fut); assert(err == 0); int out[n]; err = futhark_values_i32_1d(ctx, out_fut, out); assert(err == 0); err = futhark_context_sync(ctx); assert(err == 0); for (int i = 0; i < n; i++) { assert(out[i] == xs[i]+2); } err = futhark_free_i32_1d(ctx, out_fut); assert(err == 0); err = futhark_free_i32_1d(ctx, args.xs_fut); assert(err == 0); return NULL; } int main() { struct futhark_context_config *cfg = futhark_context_config_new(); struct futhark_context *ctx = futhark_context_new(cfg); int err; struct 
futhark_i32_1d *xs_fut = futhark_new_i32_1d(ctx, xs, n); struct thread_args args; args.ctx = ctx; args.xs_fut = xs_fut; pthread_t tid; err = pthread_create(&tid, NULL, thread_fn, &args); assert(err == 0); err = pthread_join(tid, NULL); assert(err == 0); futhark_context_free(ctx); futhark_context_config_free(cfg); } futhark-0.25.27/tests_lib/c/test_g.c000066400000000000000000000016371475065116200172140ustar00rootroot00000000000000#include "g.h" #include #include #include int main() { struct futhark_context_config *cfg = futhark_context_config_new(); struct futhark_context *ctx = futhark_context_new(cfg); int err; int64_t n = 1; int64_t m = 1000; struct futhark_opaque_vec *vec_n, *vec_m; err = futhark_entry_mk_vec(ctx, &vec_n, n); assert(err == 0); err = futhark_entry_mk_vec(ctx, &vec_m, m); assert(err == 0); int64_t out = 42; err = futhark_entry_use_vec(ctx, &out, vec_n, vec_n); assert(err == 0); assert(out == 0); err = futhark_entry_use_vec(ctx, &out, vec_n, vec_m); assert(err != 0); char *err_s = futhark_context_get_error(ctx); assert(err_s != NULL); free(err_s); err = futhark_free_opaque_vec(ctx, vec_n); assert(err == 0); err = futhark_free_opaque_vec(ctx, vec_m); assert(err == 0); futhark_context_free(ctx); futhark_context_config_free(cfg); } futhark-0.25.27/tests_lib/c/test_index.c000066400000000000000000000026201475065116200200660ustar00rootroot00000000000000#include "index.h" #include #include #include int main() { struct futhark_context_config *cfg = futhark_context_config_new(); struct futhark_context *ctx = futhark_context_new(cfg); assert(futhark_context_get_error(ctx) == NULL); char* err; struct futhark_i64_1d *a; assert(futhark_entry_fun1(ctx, &a, 10) == 0); // Error case: index too high. assert(futhark_index_i64_1d(ctx, NULL, a, 100) == 1); assert((err = futhark_context_get_error(ctx)) != NULL); free(err); // Error case: index negative. 
assert(futhark_index_i64_1d(ctx, NULL, a, -1) == 1); assert((err = futhark_context_get_error(ctx)) != NULL); free(err); // Correct indexing. int64_t x; assert(futhark_index_i64_1d(ctx, &x, a, 5) == 0); assert(futhark_context_sync(ctx) == 0); assert(x == 5); assert(futhark_free_i64_1d(ctx, a) == 0); struct futhark_i64_2d *b; assert(futhark_entry_fun2(ctx, &b, 10) == 0); // Error case: index too high along one dimension, but not the // other. assert(futhark_index_i64_2d(ctx, NULL, b, 0,10) == 1); assert((err = futhark_context_get_error(ctx)) != NULL); free(err); // Correct indexing. int64_t y; assert(futhark_index_i64_2d(ctx, &y, b, 3, 4) == 0); assert(futhark_context_sync(ctx) == 0); assert(y == 3+4); assert(futhark_free_i64_2d(ctx, b) == 0); futhark_context_free(ctx); futhark_context_config_free(cfg); } futhark-0.25.27/tests_lib/c/test_opaque_array.c000066400000000000000000000035601475065116200214530ustar00rootroot00000000000000#include "opaque_array.h" #include #include #include #include void test1(struct futhark_context *ctx) { int32_t a = 42; float b[2] = {1,2}; char* err; struct futhark_f32_1d *b_fut = futhark_new_f32_1d(ctx, b, 2); assert(b_fut != NULL); struct futhark_opaque_tup2_i32_arr1d_f32* a_b_fut; assert(futhark_new_opaque_tup2_i32_arr1d_f32(ctx, &a_b_fut, a, b_fut) == 0); struct futhark_opaque_t* t; assert(futhark_entry_mk(ctx, &t, a, b_fut) == 0); struct futhark_opaque_arr1d_t* arr_t; assert(futhark_entry_arr(ctx, &arr_t, t) == 0); // Test shape. assert(futhark_shape_opaque_arr1d_t(ctx, arr_t)[0] == 2); // Test index out of bounds. assert(futhark_index_opaque_arr1d_t(ctx, NULL, arr_t, 2) != 0); assert((err = futhark_context_get_error(ctx)) != NULL); free(err); // Test correct indexing. 
{ struct futhark_opaque_t* out; assert(futhark_index_opaque_arr1d_t(ctx, &out, arr_t, 1) == 0); int32_t out0; struct futhark_f32_1d *out1; assert(futhark_entry_unmk(ctx, &out0, &out1, out) == 0); assert(out0 == a); float out1_host[2]; assert(futhark_values_f32_1d(ctx, out1, out1_host) == 0); assert(futhark_context_sync(ctx) == 0); assert(memcmp(out1_host, b, sizeof(float)*2) == 0); assert(futhark_free_opaque_t(ctx, out) == 0); assert(futhark_free_f32_1d(ctx, out1) == 0); } assert(futhark_free_f32_1d(ctx, b_fut) == 0); assert(futhark_free_opaque_t(ctx, t) == 0); assert(futhark_free_opaque_arr1d_t(ctx, arr_t) == 0); assert(futhark_free_opaque_tup2_i32_arr1d_f32(ctx, a_b_fut) == 0); } int main() { struct futhark_context_config *cfg = futhark_context_config_new(); struct futhark_context *ctx = futhark_context_new(cfg); assert(futhark_context_get_error(ctx) == NULL); test1(ctx); futhark_context_free(ctx); futhark_context_config_free(cfg); } futhark-0.25.27/tests_lib/c/test_phantomsize.c000066400000000000000000000017461475065116200213300ustar00rootroot00000000000000#include "phantomsize.h" #include #include int main() { struct futhark_context_config *cfg = futhark_context_config_new(); struct futhark_context *ctx = futhark_context_new(cfg); int64_t n = 1000000, m; struct futhark_opaque_state *s; assert(futhark_entry_construct(ctx, &s, n) == 0); assert(futhark_context_sync(ctx) == 0); void *bytes = NULL; size_t num_bytes = 0; assert(futhark_store_opaque_state(ctx, s, &bytes, &num_bytes) == 0); assert(futhark_free_opaque_state(ctx, s) == 0); s = NULL; // Point here is to check that we don't need to store a large array. 
assert(num_bytes < 100); s = futhark_restore_opaque_state(ctx, bytes); assert(s != NULL); assert(futhark_context_sync(ctx) == 0); assert(futhark_entry_destruct(ctx, &m, s) == 0); assert(futhark_context_sync(ctx) == 0); assert(n == m); assert(futhark_free_opaque_state(ctx, s) == 0); free(bytes); futhark_context_free(ctx); futhark_context_config_free(cfg); } futhark-0.25.27/tests_lib/c/test_project.c000066400000000000000000000045701475065116200204330ustar00rootroot00000000000000#include "project.h" #include #include int main() { struct futhark_context_config *cfg = futhark_context_config_new(); struct futhark_context *ctx = futhark_context_new(cfg); int err; struct futhark_opaque_sum *sum; err = futhark_new_opaque_sum_foo(ctx, &sum, 42); assert(err == 0); uint32_t u32_data[3] = { 42, 1337, 420 }; struct futhark_u32_1d* u32_arr = futhark_new_u32_1d(ctx, u32_data, 3); assert(u32_arr != NULL); float f32_data[3] = { 42, 1337, 420 }; struct futhark_f32_1d* f32_arr = futhark_new_f32_1d(ctx, f32_data, 3); assert(f32_arr != NULL); struct futhark_opaque_t0 *t0; err = futhark_new_opaque_t0(ctx, &t0, u32_arr, 10, sum); assert(err == 0); err = futhark_free_u32_1d(ctx, u32_arr); assert(err == 0); for (int i = 0; i < 2; i++) { struct futhark_u32_1d* u32_arr_projected; err = futhark_project_opaque_t0_0(ctx, &u32_arr_projected, t0); assert(err == 0); uint32_t u32_data_projected[3]; err = futhark_values_u32_1d(ctx, u32_arr_projected, u32_data_projected); assert(err == 0); err = futhark_context_sync(ctx); assert(err == 0); assert(memcmp(u32_data, u32_data_projected, sizeof(uint32_t) * 3) == 0); err = futhark_free_u32_1d(ctx, u32_arr_projected); assert(err == 0); } struct futhark_opaque_t1 *t1; err = futhark_new_opaque_t1(ctx, &t1, t0, f32_arr); assert(err == 0); for (int i = 0; i < 2; i++) { struct futhark_opaque_t0* t0; err = futhark_project_opaque_t1_0(ctx, &t0, t1); assert(err == 0); struct futhark_u32_1d* u32_arr_projected; err = futhark_project_opaque_t0_0(ctx, 
&u32_arr_projected, t0); assert(err == 0); uint32_t u32_data_projected[3]; err = futhark_values_u32_1d(ctx, u32_arr_projected, u32_data_projected); assert(err == 0); err = futhark_context_sync(ctx); assert(err == 0); assert(memcmp(u32_data, u32_data_projected, sizeof(uint32_t) * 3) == 0); err = futhark_free_u32_1d(ctx, u32_arr_projected); assert(err == 0); err = futhark_free_opaque_t0(ctx, t0); assert(err == 0); } err = futhark_free_opaque_t1(ctx, t1); assert(err == 0); err = futhark_free_opaque_t0(ctx, t0); assert(err == 0); err = futhark_free_opaque_sum(ctx, sum); assert(err == 0); err = futhark_free_f32_1d(ctx, f32_arr); assert(err == 0); futhark_context_free(ctx); futhark_context_config_free(cfg); } futhark-0.25.27/tests_lib/c/test_raw.c000066400000000000000000000015471475065116200175570ustar00rootroot00000000000000// Fiddling around with the raw arrays API. #include "raw.h" #include #include #include int main() { struct futhark_context_config *cfg = futhark_context_config_new(); struct futhark_context *ctx = futhark_context_new(cfg); #ifdef FUTHARK_BACKEND_c int32_t data[3] = {1,2,3}; struct futhark_i32_1d *in = futhark_new_raw_i32_1d(ctx, (unsigned char*)&data, 3); assert(in != NULL); struct futhark_i32_1d *out; assert(futhark_entry_main(ctx, &out, in) == FUTHARK_SUCCESS); int32_t *out_ptr = (int32_t*)futhark_values_raw_i32_1d(ctx, out); assert(futhark_context_sync(ctx) == FUTHARK_SUCCESS); assert(out_ptr[0] == 3); assert(out_ptr[1] == 4); assert(out_ptr[2] == 5); futhark_free_i32_1d(ctx, in); futhark_free_i32_1d(ctx, out); #endif futhark_context_free(ctx); futhark_context_config_free(cfg); } futhark-0.25.27/tests_lib/c/test_record_array.c000066400000000000000000000102611475065116200214330ustar00rootroot00000000000000#include "record_array.h" #include #include #include #include void test1(struct futhark_context *ctx) { char* err; int32_t a[] = {1,2,3}; float b[] = {1,2,3}; struct futhark_i32_1d *a_fut = futhark_new_i32_1d(ctx, a, 3); assert(a_fut != NULL); 
struct futhark_f32_1d *b_fut = futhark_new_f32_1d(ctx, b, 3); assert(b_fut != NULL); struct futhark_f32_1d *b_short_fut = futhark_new_f32_1d(ctx, b, 2); assert(b_short_fut != NULL); struct futhark_opaque_arr1d_tup2_f32_f32 *b_b_fut; // Error case for zip. assert(futhark_zip_opaque_arr1d_tup2_f32_f32(ctx, &b_b_fut, b_fut, b_short_fut) == 1); err = futhark_context_get_error(ctx); assert(err != NULL); free(err); // Correct use of zip. assert(futhark_zip_opaque_arr1d_tup2_f32_f32(ctx, &b_b_fut, b_fut, b_fut) == 0); struct futhark_opaque_arr1d_tup2_i32_tup2_f32_f32 *a_b_b_fut; assert(futhark_zip_opaque_arr1d_tup2_i32_tup2_f32_f32(ctx, &a_b_b_fut, a_fut, b_b_fut) == 0); assert(futhark_shape_opaque_arr1d_tup2_i32_tup2_f32_f32(ctx, a_b_b_fut)[0] == 3); // Test indexing: out of bounds. assert(futhark_index_opaque_arr1d_tup2_i32_tup2_f32_f32(ctx, NULL, a_b_b_fut, 3) != 0); err = futhark_context_get_error(ctx); assert(err != NULL); free(err); // Test indexing: in bounds. struct futhark_opaque_tup2_i32_tup2_f32_f32* trip_fut; assert(futhark_index_opaque_arr1d_tup2_i32_tup2_f32_f32(ctx, &trip_fut, a_b_b_fut, 1) == 0); assert(futhark_context_sync(ctx) == 0); // XXX, would be nice if this was not required. 
struct futhark_opaque_tup2_f32_f32* pair_fut; assert(futhark_project_opaque_tup2_i32_tup2_f32_f32_1(ctx, &pair_fut, trip_fut) == 0); { int x; assert(futhark_project_opaque_tup2_i32_tup2_f32_f32_0(ctx, &x, trip_fut) == 0); assert(futhark_context_sync(ctx) == 0); assert(x == a[1]); } { float x; assert(futhark_project_opaque_tup2_f32_f32_0(ctx, &x, pair_fut) == 0); assert(futhark_context_sync(ctx) == 0); assert(x == b[1]); } { float x; assert(futhark_project_opaque_tup2_f32_f32_1(ctx, &x, pair_fut) == 0); assert(futhark_context_sync(ctx) == 0); assert(x == b[1]); } assert(futhark_free_opaque_tup2_f32_f32(ctx, pair_fut) == 0); assert(futhark_free_opaque_tup2_i32_tup2_f32_f32(ctx, trip_fut) == 0); assert(futhark_free_opaque_arr1d_tup2_i32_tup2_f32_f32(ctx, a_b_b_fut) == 0); assert(futhark_free_f32_1d(ctx, b_fut) == 0); assert(futhark_free_f32_1d(ctx, b_short_fut) == 0); assert(futhark_free_i32_1d(ctx, a_fut) == 0); assert(futhark_free_opaque_arr1d_tup2_f32_f32(ctx, b_b_fut) == 0); } void test2(struct futhark_context *ctx) { struct futhark_opaque_arr1d_tup2_arr1d_i32_f32 *a_b_fut; int32_t a[] = {1,2,3,4}; float b[] = {5,6}; struct futhark_i32_2d *a_fut = futhark_new_i32_2d(ctx, a, 2, 2); assert(a_fut != NULL); struct futhark_f32_1d *b_fut = futhark_new_f32_1d(ctx, b, 2); assert(b_fut != NULL); assert(futhark_zip_opaque_arr1d_tup2_arr1d_i32_f32(ctx, &a_b_fut, a_fut, b_fut) == 0); assert(futhark_free_f32_1d(ctx, b_fut) == 0); assert(futhark_free_i32_2d(ctx, a_fut) == 0); // Valid indexing. 
struct futhark_opaque_tup2_arr1d_i32_f32* a_b_elem_fut; assert(futhark_index_opaque_arr1d_tup2_arr1d_i32_f32(ctx, &a_b_elem_fut, a_b_fut, 1) == 0); { struct futhark_i32_1d *out_fut; assert(futhark_project_opaque_tup2_arr1d_i32_f32_0(ctx, &out_fut, a_b_elem_fut) == 0); int32_t out[2]; assert(futhark_values_i32_1d(ctx, out_fut, out) == 0); assert(futhark_context_sync(ctx) == 0); assert(memcmp(out, &a[2], sizeof(int32_t)*2) == 0); assert(futhark_free_i32_1d(ctx, out_fut) == 0); } { float out; assert(futhark_project_opaque_tup2_arr1d_i32_f32_1(ctx, &out, a_b_elem_fut) == 0); assert(out == b[1]); } assert(futhark_free_opaque_arr1d_tup2_arr1d_i32_f32(ctx, a_b_fut) == 0); assert(futhark_free_opaque_tup2_arr1d_i32_f32(ctx, a_b_elem_fut) == 0); } int main() { struct futhark_context_config *cfg = futhark_context_config_new(); struct futhark_context *ctx = futhark_context_new(cfg); assert(futhark_context_get_error(ctx) == NULL); test1(ctx); test2(ctx); futhark_context_free(ctx); futhark_context_config_free(cfg); } futhark-0.25.27/tests_lib/c/test_restore.c000066400000000000000000000030701475065116200204420ustar00rootroot00000000000000#include "restore.h" #include #include #include int main() { struct futhark_context_config *cfg = futhark_context_config_new(); struct futhark_context *ctx = futhark_context_new(cfg); int err; struct futhark_opaque_whatever *whatever; err = futhark_entry_mk(ctx, &whatever, 1); assert(err == 0); err = futhark_context_sync(ctx); assert(err == 0); void *bytes0 = NULL; void *bytes1; size_t n0, n1; err = futhark_store_opaque_whatever(ctx, whatever, &bytes0, &n0); assert(err == 0); err = futhark_store_opaque_whatever(ctx, whatever, NULL, &n1); assert(err == 0); assert(n0 == n1); bytes1 = malloc(n1); err = futhark_store_opaque_whatever(ctx, whatever, &bytes1, &n1); err = futhark_context_sync(ctx); assert(err == 0); assert(memcmp(bytes0, bytes1, n0) == 0); err = futhark_free_opaque_whatever(ctx, whatever); assert(err == 0); whatever = 
futhark_restore_opaque_whatever(ctx, bytes0); assert(whatever != NULL); struct futhark_i64_1d *out0; struct futhark_bool_1d *out1; bool out2; err = futhark_entry_unmk(ctx, &out0, &out1, &out2, whatever); assert(err == 0); assert(out2 == 1); free(bytes0); free(bytes1); err = futhark_free_i64_1d(ctx, out0); assert(err == 0); err = futhark_free_bool_1d(ctx, out1); assert(err == 0); // Test that passing in garbage fails predictably. bytes1 = calloc(n0, 1); whatever = futhark_restore_opaque_whatever(ctx, bytes1); assert(whatever == NULL); free(bytes1); futhark_context_free(ctx); futhark_context_config_free(cfg); } futhark-0.25.27/tests_lib/c/test_sum.c000066400000000000000000000032771475065116200175740ustar00rootroot00000000000000#include "sum.h" #include #include #include int main() { int rounds = 10; struct futhark_context_config *cfg = futhark_context_config_new(); struct futhark_context *ctx = futhark_context_new(cfg); struct futhark_opaque_contrived* v; { int32_t data[] = { 1, 2, 3 }; struct futhark_i32_1d* arr = futhark_new_i32_1d(ctx, data, sizeof(data)/sizeof(int32_t)); assert(arr != NULL); assert(futhark_context_sync(ctx) == FUTHARK_SUCCESS); assert(futhark_new_opaque_contrived_foo(ctx, &v, arr, true) == FUTHARK_SUCCESS); futhark_free_i32_1d(ctx, arr); } for (int i = 0; i < rounds; i++) { switch (futhark_variant_opaque_contrived(ctx, v)) { case 0: { bool v0; struct futhark_u32_1d *v1; futhark_destruct_opaque_contrived_bar(ctx, &v0, &v1, v); futhark_free_u32_1d(ctx, v1); } break; case 1: { struct futhark_i32_1d *v0; struct futhark_i32_1d *v1; futhark_destruct_opaque_contrived_baz(ctx, &v0, &v1, v); futhark_free_i32_1d(ctx, v0); futhark_free_i32_1d(ctx, v1); } break; case 2: { struct futhark_i32_1d *v0; bool v1; futhark_destruct_opaque_contrived_foo(ctx, &v0, &v1, v); futhark_free_i32_1d(ctx, v0); } break; default: abort(); }; struct futhark_opaque_contrived* v_new; assert(futhark_entry_next(ctx, &v_new, v) == FUTHARK_SUCCESS); 
assert(futhark_free_opaque_contrived(ctx, v) == FUTHARK_SUCCESS); v = v_new; } assert(futhark_free_opaque_contrived(ctx, v) == FUTHARK_SUCCESS); futhark_context_free(ctx); futhark_context_config_free(cfg); } futhark-0.25.27/tests_lib/c/validatemanifest.py000077500000000000000000000003121475065116200214450ustar00rootroot00000000000000#!/usr/bin/env python3 from jsonschema import validate import json import sys schema = json.load(open(sys.argv[1])) manifest = json.load(open(sys.argv[2])) validate(instance=manifest, schema=schema) futhark-0.25.27/tests_lib/javascript/000077500000000000000000000000001475065116200175005ustar00rootroot00000000000000futhark-0.25.27/tests_lib/javascript/Makefile000066400000000000000000000003701475065116200211400ustar00rootroot00000000000000FUTHARK_BACKEND ?= wasm .PHONY: test clean test: do_test_a do_test_array do_test_%: test_%.js %.mjs node --experimental-wasm-simd test_$*.js %.mjs: %.fut futhark $(FUTHARK_BACKEND) --library $^ clean: rm -rf *.c *.h *.class.js *.wasm *.mjs futhark-0.25.27/tests_lib/javascript/a.fut000066400000000000000000000001141475065116200204340ustar00rootroot00000000000000-- Simple library with the increment function entry incr (x : i32) = x + 1 futhark-0.25.27/tests_lib/javascript/array.fut000066400000000000000000000006031475065116200213350ustar00rootroot00000000000000-- Library with entry points that accept and return -- array inputs and outputs entry sum1d (xs : []i32) : i32 = i32.sum xs entry sum2d (xss : [][]i64) : i64 = let row_sums = map i64.sum xss in i64.sum row_sums entry replicate1d (n : i64) (x : f32) : []f32 = replicate n x entry replicate2d (n: i64) (m : i64) (x : f32) : [][]f32 = let row = replicate m x in replicate n row futhark-0.25.27/tests_lib/javascript/package.json000066400000000000000000000000271475065116200217650ustar00rootroot00000000000000{ "type": "module" } futhark-0.25.27/tests_lib/javascript/test_a.js000066400000000000000000000012131475065116200213120ustar00rootroot00000000000000// Tests 
for a.fut, the increment function import assert from 'assert/strict'; // Hack for Running generated ES6 modules from Emscripten with Node // https://github.com/emscripten-core/emscripten/issues/11792#issuecomment-877120580 import {dirname} from "path"; import { createRequire } from 'module'; // substring removes file:// from the filepath globalThis.__dirname = dirname(import.meta.url).substring(7); globalThis.require = createRequire(import.meta.url); // Imports from the generated ES6 Module import { newFutharkContext } from './a.mjs'; var fc; newFutharkContext().then(x => { fc = x; var res = fc.incr(7) assert(res == 8); }); futhark-0.25.27/tests_lib/javascript/test_array.js000066400000000000000000000067101475065116200222170ustar00rootroot00000000000000// Tests for array.fut import assert from 'assert/strict'; // Hack for Running generated ES6 modules from Emscripten with Node // https://github.com/emscripten-core/emscripten/issues/11792#issuecomment-877120580 import {dirname} from "path"; import { createRequire } from 'module'; // substring removes file:// from the filepath globalThis.__dirname = dirname(import.meta.url).substring(7); globalThis.require = createRequire(import.meta.url); // Imports from the generated ES6 Module import { newFutharkContext, FutharkArray } from './array.mjs'; newFutharkContext().then(fc => { console.log(); console.log("Testing Entry Points..."); console.log(); console.log("Testing Entry Point : sum1d"); var arr_1d = new Int32Array([1, 2, 3, 4, 5, 6]); var fut_arr_1d = fc.new_i32_1d(arr_1d, arr_1d.length); var sum1d_res = fc.sum1d(fut_arr_1d); assert(sum1d_res === 21); console.log("Testing Entry Point : sum2d"); var arr_2d = new BigInt64Array([1n, 2n, 3n, 4n, 5n, 6n]); var fut_arr_2d = fc.new_i64_2d(arr_2d, 2, 3); var sum2d_res = fc.sum2d(fut_arr_2d); assert(sum2d_res === 21n); console.log("Testing Entry Point : replicate1d"); var n = 5n; var fut_res_arr_1d = fc.replicate1d(n, 1.1); var replicate1d_res_arr = fut_res_arr_1d.toArray(); 
for (var i = 0; i < Number(n); i++) { // check with consideration for floating point precision assert(Math.abs(replicate1d_res_arr[i] - 1.1) < .0001); } console.log("Testing Entry Point : replicate2d"); var x = 5n; var y = 2n; var fut_res_arr_2d = fc.replicate2d(x, y, 1.1); var replicate1d_res_arr = fut_res_arr_2d.toArray(); for (var i = 0; i < Number(x); i++) { for (var j = 0; j < Number(y); j++) { // check with consideration for floating point precision assert(Math.abs(replicate1d_res_arr[i][j] - 1.1) < .0001); } } console.log(); console.log("Array API Tests..."); console.log(); console.log("Testing array construction with numbers and bigints"); var test_arr = [1n, 2n, 3n, 4n, 5n, 6n]; var futhark_test_array_shape_ints= fc.new_i64_2d(test_arr, 2, 3); var futhark_test_array_shape_bigints = fc.new_i64_2d(test_arr, 2n, 3n); var shape_ints = futhark_test_array_shape_ints.shape(); var shape_bigints = futhark_test_array_shape_bigints.shape(); assert(shape_ints[0] === shape_bigints[0]); assert(shape_ints[1] === shape_bigints[1]); console.log("Testing toArray"); var futhark_test_array = fc.new_i64_2d(test_arr, 2, 3); var arr_toArray = futhark_test_array.toArray(); for (var i = 0; i < 2; i++) { for (var j = 0; j < 3; j++) { assert(arr_toArray[i][j] === test_arr[i * 3 + j]); } } console.log("Testing toTypedArray"); var arr_toTypedArray = futhark_test_array.toTypedArray(); for (var i = 0; i < 3 * 2; i++) { assert(arr_toTypedArray[i] === test_arr[i]); } console.log("Testing shape"); var expected_shape = [2n, 3n]; var actual_shape = futhark_test_array.shape(); assert(actual_shape[0] === actual_shape[0]); assert(expected_shape[1] === expected_shape[1]); console.log("Testing frees"); fut_arr_1d.free(); fut_arr_2d.free(); fut_res_arr_2d.free(); fut_res_arr_1d.free(); futhark_test_array.free(); console.log("Testing access after free") assert.throws(() => fut_test_array.toArray()); assert.throws(() => fut_test_array.toTypedArray()); assert.throws(() => fut_test_array.shape()); 
fc.free(); console.log(); console.log("Tests complete"); console.log(); }); futhark-0.25.27/tests_lib/python/000077500000000000000000000000001475065116200166535ustar00rootroot00000000000000futhark-0.25.27/tests_lib/python/.gitignore000066400000000000000000000000271475065116200206420ustar00rootroot00000000000000__pycache__ *.pyc *.py futhark-0.25.27/tests_lib/python/Makefile000066400000000000000000000003171475065116200203140ustar00rootroot00000000000000FUTHARK_BACKEND ?= python .PHONY: test clean test: do_test_a do_test_g do_test_%: test_% %.py ./test_$* test_%: %.py %.py: %.fut futhark $(FUTHARK_BACKEND) --library $^ clean: rm -rf test_? ?.c ?.h futhark-0.25.27/tests_lib/python/a.fut000066400000000000000000000002241475065116200176110ustar00rootroot00000000000000-- Test that error states are not sticky from one call of an entry -- point to the next. let main (xs: []f32) (is: []i32) = map (\i -> xs[i]) is futhark-0.25.27/tests_lib/python/g.fut000066400000000000000000000003311475065116200176160ustar00rootroot00000000000000-- Test that size constraints on opaque types are respected. 
type vec [n] = {vec: [n]i64} entry mk_vec (n: i64) : vec [n] = {vec=iota n} entry use_vec [n] (x: vec [n]) (y: vec [n]) = i64.sum (map2 (+) x.vec y.vec) futhark-0.25.27/tests_lib/python/test_a000077500000000000000000000005451475065116200200640ustar00rootroot00000000000000#!/usr/bin/env python3 import a import numpy as np obj = a.a() try: obj.main(np.array([1,2,3], dtype=np.float32), np.array([-1], dtype=np.int32)) except Exception as e: assert("Index [-1] out of bounds" in str(e)) res = obj.main(np.array([1,2,3], dtype=np.float32), np.array([0], dtype=np.int32)) assert(res[0] == 1) futhark-0.25.27/tests_lib/python/test_g000077500000000000000000000004421475065116200200660ustar00rootroot00000000000000#!/usr/bin/env python3 import g import numpy as np obj = g.g() n = 1 m = 1000 vec_n = obj.mk_vec(n) vec_m = obj.mk_vec(m) assert(obj.use_vec(vec_n, vec_n) == 0) try: obj.use_vec(vec_n, vec_m) # Should fail. assert(False) except Exception as e: assert('invalid' in str(e)) futhark-0.25.27/tests_literate/000077500000000000000000000000001475065116200163755ustar00rootroot00000000000000futhark-0.25.27/tests_literate/.gitignore000066400000000000000000000000661475065116200203670ustar00rootroot00000000000000/* /*/ !*.fut !/expected/ !.gitignore !test.sh !data/ futhark-0.25.27/tests_literate/audio.fut000066400000000000000000000030031475065116200202120ustar00rootroot00000000000000def output_hz = 44100i64 def standard_pitch = 440.0f32 def pitch (i: i64): f32 = standard_pitch * 2 ** (f32.i64 i/12) def num_samples (duration: f32): i64 = i64.f32 (f32.i64 output_hz * duration) def sample (p: f32) (i: i64): f32 = f32.sin (2 * f32.pi * f32.i64 i * p / f32.i64 output_hz) def note (i: i64) (duration: f32): []f32 = let p = pitch i let n = num_samples duration in tabulate n (sample p) def break (duration: f32): []f32 = replicate (num_samples duration) 0.0 entry left = let c = note 3 let d = note 5 let e = note 7 let f = note 8 let g = note 10 in c 0.3 ++ d 0.3 ++ e 0.3 ++ c 0.3 ++ c 0.3 ++ d 
0.3 ++ e 0.3 ++ c 0.3 ++ e 0.3 ++ f 0.3 ++ g 0.6 ++ e 0.3 ++ f 0.3 ++ g 0.6 entry right = let c = note 3 let d = note 5 let e = note 7 let f = note 8 let g = note 10 in break (8 * 0.3) ++ c 0.3 ++ d 0.3 ++ e 0.3 ++ c 0.3 ++ c 0.3 ++ d 0.3 ++ e 0.3 ++ c 0.3 -- > :audio left; -- sampling_frequency: 44100 -- codec: ogg -- > :audio right; -- sampling_frequency: 44100 -- codec: ogg entry stereo = let [k] left: [k]f32 = left let right = right :> [k]f32 in [left, right] -- > :audio stereo entry surround = let [k] left: [k]f32 = left let right = right :> [k]f32 in [left, right, right, right, right, right] -- > :audio surround -- > $loadaudio "mono.wav" -- > $loadaudio "stereo.wav" futhark-0.25.27/tests_literate/basic.fut000066400000000000000000000002141475065116200201730ustar00rootroot00000000000000-- Let us see if this works. let main x = x + 2 -- > main 2 -- > main 2f32 --The lack of a final newline here is intentional let x = truefuthark-0.25.27/tests_literate/coerce.fut000066400000000000000000000001011475065116200203450ustar00rootroot00000000000000def f = i64.sum def g = f64.sum -- > f [1,2,3] -- > g [1,2,3] futhark-0.25.27/tests_literate/data/000077500000000000000000000000001475065116200173065ustar00rootroot00000000000000futhark-0.25.27/tests_literate/data/array.in000066400000000000000000000000151475065116200207500ustar00rootroot00000000000000[1i32, 2, 3] futhark-0.25.27/tests_literate/data/array_and_value.in000066400000000000000000000000231475065116200227650ustar00rootroot00000000000000[1i32, 2, 3] 10i32 futhark-0.25.27/tests_literate/expected/000077500000000000000000000000001475065116200201765ustar00rootroot00000000000000futhark-0.25.27/tests_literate/expected/audio.md000066400000000000000000000037211475065116200216240ustar00rootroot00000000000000```futhark def output_hz = 44100i64 def standard_pitch = 440.0f32 def pitch (i: i64): f32 = standard_pitch * 2 ** (f32.i64 i/12) def num_samples (duration: f32): i64 = i64.f32 (f32.i64 output_hz * duration) def sample (p: 
f32) (i: i64): f32 = f32.sin (2 * f32.pi * f32.i64 i * p / f32.i64 output_hz) def note (i: i64) (duration: f32): []f32 = let p = pitch i let n = num_samples duration in tabulate n (sample p) def break (duration: f32): []f32 = replicate (num_samples duration) 0.0 entry left = let c = note 3 let d = note 5 let e = note 7 let f = note 8 let g = note 10 in c 0.3 ++ d 0.3 ++ e 0.3 ++ c 0.3 ++ c 0.3 ++ d 0.3 ++ e 0.3 ++ c 0.3 ++ e 0.3 ++ f 0.3 ++ g 0.6 ++ e 0.3 ++ f 0.3 ++ g 0.6 entry right = let c = note 3 let d = note 5 let e = note 7 let f = note 8 let g = note 10 in break (8 * 0.3) ++ c 0.3 ++ d 0.3 ++ e 0.3 ++ c 0.3 ++ c 0.3 ++ d 0.3 ++ e 0.3 ++ c 0.3 ``` ``` > :audio left; sampling_frequency: 44100 codec: ogg ``` ![](audio-img/4f72a37a9b3cf4b8a68ce450f0f696d5-output.ogg) ``` > :audio right; sampling_frequency: 44100 codec: ogg ``` ![](audio-img/53655790dd8b1739a749c772c9f6bc13-output.ogg) ```futhark entry stereo = let [k] left: [k]f32 = left let right = right :> [k]f32 in [left, right] ``` ``` > :audio stereo ``` ![](audio-img/9e12f26cca7539fa04a923f7b652d2e8-output.wav) ```futhark entry surround = let [k] left: [k]f32 = left let right = right :> [k]f32 in [left, right, right, right, right, right] ``` ``` > :audio surround ``` ![](audio-img/7e0c96822c449db1dd5712a2b809ff40-output.wav) ``` > $loadaudio "mono.wav" ``` ``` [[-0.9921875f64, -0.984375f64, -0.9765625f64]] ``` ``` > $loadaudio "stereo.wav" ``` ``` [[-0.9921875f64, -0.984375f64, -0.9765625f64], [-0.96875f64, -0.9609375f64, -0.953125f64]] ``` futhark-0.25.27/tests_literate/expected/basic.md000066400000000000000000000004651475065116200216060ustar00rootroot00000000000000Let us see if this works. 
```futhark let main x = x + 2 ``` ``` > main 2 ``` ``` 4i32 ``` ``` > main 2f32 ``` **FAILED** ``` Function "main" expects 1 argument(s) of types: i32 But applied to 1 argument(s) of types: f32 ``` The lack of a final newline here is intentional ```futhark let x = true ``` futhark-0.25.27/tests_literate/expected/coerce.md000066400000000000000000000001711475065116200217570ustar00rootroot00000000000000```futhark def f = i64.sum def g = f64.sum ``` ``` > f [1,2,3] ``` ``` 6i64 ``` ``` > g [1,2,3] ``` ``` 6.0f64 ``` futhark-0.25.27/tests_literate/expected/img.md000066400000000000000000000007761475065116200213060ustar00rootroot00000000000000```futhark let checkerboard_f32 = map (\i -> map (\j -> f32.i64 ((j + i % 2) % 2)) (iota 8)) (iota 8) ``` ``` > :img checkerboard_f32 ``` ![](img-img/26b07dbf1f772d987b40f08e0b0e0eab-img.png) ```futhark let checkerboard_bool = map (\i -> map (\j -> ((j + i % 2) % 2) == 0) (iota 8)) (iota 8) ``` ``` > :img checkerboard_bool ``` ![](img-img/6ffd56d07e2c485080c59b4cbd6682b0-img.png) ``` > :img checkerboard_bool; file: foo.bar ``` ![](img-img/foo.bar) futhark-0.25.27/tests_literate/expected/img_noentry.md000066400000000000000000000001551475065116200230530ustar00rootroot00000000000000``` > :img $loadimg "../assets/ohyes.png" ``` ![](img_noentry-img/e66b70cae5b1b3df4bae3afa03f7ee7e-img.png) futhark-0.25.27/tests_literate/expected/loaddata.md000066400000000000000000000003421475065116200222700ustar00rootroot00000000000000``` > $loaddata "data/array.in" ``` ``` [1i32, 2i32, 3i32] ``` ```futhark let add_scalar (y: i32) = map (+y) ``` ``` > let (xs, y) = $loaddata "data/array_and_value.in" in add_scalar y xs ``` ``` [11i32, 12i32, 13i32] ``` futhark-0.25.27/tests_literate/expected/loadimg.md000066400000000000000000000002771475065116200221420ustar00rootroot00000000000000```futhark let foo [n][m] (img: [n][m]u32): [n][m]u32 = map (map id) img ``` ``` > :img foo ($loadimg "../assets/ohyes.png") ``` 
![](loadimg-img/fdccf88abb5fb0001143631ae1c452e6-img.png) futhark-0.25.27/tests_literate/expected/nonl.md000066400000000000000000000002451475065116200214670ustar00rootroot00000000000000Note: it is intentional that this file does not include a final newline character. Do not add it! ```futhark def main x = x + 1 ``` ``` > main 1 ``` ``` 2i32 ``` futhark-0.25.27/tests_literate/expected/opaque_tuple.md000066400000000000000000000005311475065116200232220ustar00rootroot00000000000000Based on #1750. ```futhark type vec2 = (i64, i64) def moving_point (x: i64, y: i64) : ([][]u32, vec2) = let canvas = replicate 128 (replicate 128 0) with [y, x] = 255 in (canvas, (x + 1, y+1)) ``` ``` > :video (moving_point, (10i64, 10i64), 50i64); fps: 1 format: gif ``` ![](opaque_tuple-img/da140d8304219232f384b6a2d5ea4be9-video.gif) futhark-0.25.27/tests_literate/expected/tuple.md000066400000000000000000000003141475065116200216470ustar00rootroot00000000000000```futhark entry f (x: i32, (y: i32, z: i32)) = x + y + z ``` ``` > f (1,(2,3)) ``` ``` 6i32 ``` ```futhark entry g {x: i32, y: (i32,i32)} = x - y.0 + y.1 ``` ``` > g {y=(7,2),x=10} ``` ``` 5i32 ``` futhark-0.25.27/tests_literate/expected/video.md000066400000000000000000000004671475065116200216350ustar00rootroot00000000000000```futhark def moving_point (x: i64, y: i64) : ([][]u32, (i64, i64)) = let canvas = replicate 128 (replicate 128 0) with [y, x] = 0x00FFFFFF in (canvas, (x + 1, y+1)) ``` ``` > :video (moving_point, (10i64, 10i64), 50i64,); fps: 1 format: gif ``` ![](video-img/f1a78d77be28ef1a5de76f71fe3a4b91-video.gif) futhark-0.25.27/tests_literate/img.fut000066400000000000000000000005171475065116200176740ustar00rootroot00000000000000let checkerboard_f32 = map (\i -> map (\j -> f32.i64 ((j + i % 2) % 2)) (iota 8)) (iota 8) -- > :img checkerboard_f32 let checkerboard_bool = map (\i -> map (\j -> ((j + i % 2) % 2) == 0) (iota 8)) (iota 8) -- > :img checkerboard_bool -- > :img checkerboard_bool; -- file: foo.bar 
futhark-0.25.27/tests_literate/img_noentry.fut000066400000000000000000000000511475065116200214430ustar00rootroot00000000000000-- > :img $loadimg "../assets/ohyes.png" futhark-0.25.27/tests_literate/loaddata.fut000066400000000000000000000002161475065116200206650ustar00rootroot00000000000000-- > $loaddata "data/array.in" let add_scalar (y: i32) = map (+y) -- > let (xs, y) = $loaddata "data/array_and_value.in" in add_scalar y xs futhark-0.25.27/tests_literate/loadimg.fut000066400000000000000000000001601475065116200205260ustar00rootroot00000000000000let foo [n][m] (img: [n][m]u32): [n][m]u32 = map (map id) img -- > :img foo ($loadimg "../assets/ohyes.png") futhark-0.25.27/tests_literate/mono.wav000066400000000000000000000001241475065116200200610ustar00rootroot00000000000000RIFFLWAVEfmt DXLISTINFOISFTLavf58.76.100datafuthark-0.25.27/tests_literate/nonl.fut000066400000000000000000000002101475065116200200540ustar00rootroot00000000000000-- Note: it is intentional that this file does not include a final -- newline character. Do not add it! def main x = x + 1 -- > main 1futhark-0.25.27/tests_literate/opaque_tuple.fut000066400000000000000000000004141475065116200216170ustar00rootroot00000000000000-- Based on #1750. type vec2 = (i64, i64) def moving_point (x: i64, y: i64) : ([][]u32, vec2) = let canvas = replicate 128 (replicate 128 0) with [y, x] = 255 in (canvas, (x + 1, y+1)) -- > :video (moving_point, (10i64, 10i64), 50i64); -- fps: 1 -- format: gif futhark-0.25.27/tests_literate/stereo.wav000066400000000000000000000001321475065116200204110ustar00rootroot00000000000000RIFFRWAVEfmt DLISTINFOISFTLavf58.76.100data futhark-0.25.27/tests_literate/test.sh000077500000000000000000000003071475065116200177130ustar00rootroot00000000000000#!/bin/sh for x in *.fut; do md=$(basename -s .fut $x).md echo echo "$x ($(futhark hash $x)):" futhark literate $x if ! 
diff -u expected/$md $md; then exit 1 fi done futhark-0.25.27/tests_literate/tuple.fut000066400000000000000000000002101475065116200202370ustar00rootroot00000000000000entry f (x: i32, (y: i32, z: i32)) = x + y + z -- > f (1,(2,3)) entry g {x: i32, y: (i32,i32)} = x - y.0 + y.1 -- > g {y=(7,2),x=10} futhark-0.25.27/tests_literate/video.fut000066400000000000000000000003561475065116200202270ustar00rootroot00000000000000def moving_point (x: i64, y: i64) : ([][]u32, (i64, i64)) = let canvas = replicate 128 (replicate 128 0) with [y, x] = 0x00FFFFFF in (canvas, (x + 1, y+1)) -- > :video (moving_point, (10i64, 10i64), 50i64,); -- fps: 1 -- format: gif futhark-0.25.27/tests_pkg/000077500000000000000000000000001475065116200153455ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/.gitignore000066400000000000000000000000201475065116200173250ustar00rootroot00000000000000lib futhark.pkg futhark-0.25.27/tests_pkg/README.md000066400000000000000000000010201475065116200166150ustar00rootroot00000000000000# futhark-pkg tests This directory contains a shell script (sorry) for testing futhark-pkg. This is done by serially performing package management operations, and after each operation comparing the resulting `lib` directory against an expected `lib` directory. It is distressingly awkward to write this using one of the normal Haskell test frameworks. The tests here are somewhat unstable, in that they depend on certain remote Git repositories to have specific contents. If they change, you have to change this test, too. 
futhark-0.25.27/tests_pkg/futhark.pkg.0000066400000000000000000000000621475065116200176500ustar00rootroot00000000000000package github.com/sturluson/testpkg require { } futhark-0.25.27/tests_pkg/futhark.pkg.1000066400000000000000000000001751475065116200176560ustar00rootroot00000000000000package github.com/sturluson/testpkg require { github.com/athas/fut-foo 0.1.0 #d285563c25c5152b1ae80fc64de64ff2775fa733 } futhark-0.25.27/tests_pkg/futhark.pkg.10000066400000000000000000000003121475065116200177270ustar00rootroot00000000000000package github.com/sturluson/testpkg require { github.com/athas/fut-baz 0.2.0 #44da85224224d37803976c1d30cedb1d2cd20b74 github.com/athas/fut-foo@2 2.0.0 #9459117fa75aea8fffc677294d25f273e894d19f } futhark-0.25.27/tests_pkg/futhark.pkg.11000066400000000000000000000004441475065116200177360ustar00rootroot00000000000000package github.com/sturluson/testpkg require { github.com/athas/fut-baz 0.2.0 #44da85224224d37803976c1d30cedb1d2cd20b74 github.com/athas/fut-foo@2 2.0.0 #9459117fa75aea8fffc677294d25f273e894d19f github.com/athas/fut-quux 0.0.0-20180801102532+b70028521e4dbcc286834b32ce82c1d2721a6209 } futhark-0.25.27/tests_pkg/futhark.pkg.12000066400000000000000000000004441475065116200177370ustar00rootroot00000000000000package github.com/sturluson/testpkg require { github.com/athas/fut-baz 0.2.0 #44da85224224d37803976c1d30cedb1d2cd20b74 github.com/athas/fut-foo@2 2.0.0 #9459117fa75aea8fffc677294d25f273e894d19f github.com/athas/fut-quux 0.0.0-20180801102532+b70028521e4dbcc286834b32ce82c1d2721a6209 } futhark-0.25.27/tests_pkg/futhark.pkg.13000066400000000000000000000004441475065116200177400ustar00rootroot00000000000000package github.com/sturluson/testpkg require { github.com/athas/fut-baz 0.2.0 #44da85224224d37803976c1d30cedb1d2cd20b74 github.com/athas/fut-foo@2 2.0.0 #9459117fa75aea8fffc677294d25f273e894d19f github.com/athas/fut-quux 0.0.0-20180801102533+dd5168df1b8a20cb0547a88afd4e4a6cc098e0f1 } 
futhark-0.25.27/tests_pkg/futhark.pkg.14000066400000000000000000000004441475065116200177410ustar00rootroot00000000000000package github.com/sturluson/testpkg require { github.com/athas/fut-baz 0.2.0 #44da85224224d37803976c1d30cedb1d2cd20b74 github.com/athas/fut-foo@2 2.0.0 #9459117fa75aea8fffc677294d25f273e894d19f github.com/athas/fut-quux 0.0.0-20180801102533+dd5168df1b8a20cb0547a88afd4e4a6cc098e0f1 } futhark-0.25.27/tests_pkg/futhark.pkg.15000066400000000000000000000005621475065116200177430ustar00rootroot00000000000000package github.com/sturluson/testpkg require { github.com/athas/fut-baz 0.2.0 #44da85224224d37803976c1d30cedb1d2cd20b74 github.com/athas/fut-foo@2 2.0.0 #9459117fa75aea8fffc677294d25f273e894d19f github.com/athas/fut-quux 0.0.0-20180801102533+dd5168df1b8a20cb0547a88afd4e4a6cc098e0f1 gitlab.com/athas/fut-gitlab 1.0.1 #631578b71d68381dd7461fc1d0c669cf84d0d5fe } futhark-0.25.27/tests_pkg/futhark.pkg.16000066400000000000000000000005621475065116200177440ustar00rootroot00000000000000package github.com/sturluson/testpkg require { github.com/athas/fut-baz 0.2.0 #44da85224224d37803976c1d30cedb1d2cd20b74 github.com/athas/fut-foo@2 2.0.0 #9459117fa75aea8fffc677294d25f273e894d19f github.com/athas/fut-quux 0.0.0-20180801102533+dd5168df1b8a20cb0547a88afd4e4a6cc098e0f1 gitlab.com/athas/fut-gitlab 1.0.1 #631578b71d68381dd7461fc1d0c669cf84d0d5fe } futhark-0.25.27/tests_pkg/futhark.pkg.17000066400000000000000000000006001475065116200177360ustar00rootroot00000000000000package github.com/sturluson/testpkg require { github.com/athas/fut-baz 0.2.0 #44da85224224d37803976c1d30cedb1d2cd20b74 github.com/athas/fut-foo@2 2.0.0 #9459117fa75aea8fffc677294d25f273e894d19f github.com/athas/fut-quux 0.0.0-20180801102533+dd5168df1b8a20cb0547a88afd4e4a6cc098e0f1 gitlab.com/athas/fut-gitlab 0.0.0-20180922095419+44bc83247e7b4995dd0c65acf3a72b70d6fe3efe } futhark-0.25.27/tests_pkg/futhark.pkg.18000066400000000000000000000006001475065116200177370ustar00rootroot00000000000000package 
github.com/sturluson/testpkg require { github.com/athas/fut-baz 0.2.0 #44da85224224d37803976c1d30cedb1d2cd20b74 github.com/athas/fut-foo@2 2.0.0 #9459117fa75aea8fffc677294d25f273e894d19f github.com/athas/fut-quux 0.0.0-20180801102533+dd5168df1b8a20cb0547a88afd4e4a6cc098e0f1 gitlab.com/athas/fut-gitlab 0.0.0-20180922095419+44bc83247e7b4995dd0c65acf3a72b70d6fe3efe } futhark-0.25.27/tests_pkg/futhark.pkg.2000066400000000000000000000001751475065116200176570ustar00rootroot00000000000000package github.com/sturluson/testpkg require { github.com/athas/fut-foo 0.1.0 #d285563c25c5152b1ae80fc64de64ff2775fa733 } futhark-0.25.27/tests_pkg/futhark.pkg.3000066400000000000000000000003101475065116200176470ustar00rootroot00000000000000package github.com/sturluson/testpkg require { github.com/athas/fut-foo 0.1.0 #d285563c25c5152b1ae80fc64de64ff2775fa733 github.com/athas/fut-baz 0.1.0 #6d26e7059eb138f95d4d9747051fc0bb6a4eb85c } futhark-0.25.27/tests_pkg/futhark.pkg.4000066400000000000000000000003101475065116200176500ustar00rootroot00000000000000package github.com/sturluson/testpkg require { github.com/athas/fut-foo 0.1.0 #d285563c25c5152b1ae80fc64de64ff2775fa733 github.com/athas/fut-baz 0.1.0 #6d26e7059eb138f95d4d9747051fc0bb6a4eb85c } futhark-0.25.27/tests_pkg/futhark.pkg.5000066400000000000000000000003101475065116200176510ustar00rootroot00000000000000package github.com/sturluson/testpkg require { github.com/athas/fut-foo 0.2.0 #87d372c689131f33bef1b013ac2421fb5e75642b github.com/athas/fut-baz 0.2.0 #44da85224224d37803976c1d30cedb1d2cd20b74 } futhark-0.25.27/tests_pkg/futhark.pkg.6000066400000000000000000000003101475065116200176520ustar00rootroot00000000000000package github.com/sturluson/testpkg require { github.com/athas/fut-foo 0.2.0 #87d372c689131f33bef1b013ac2421fb5e75642b github.com/athas/fut-baz 0.2.0 #44da85224224d37803976c1d30cedb1d2cd20b74 } futhark-0.25.27/tests_pkg/futhark.pkg.7000066400000000000000000000001751475065116200176640ustar00rootroot00000000000000package 
github.com/sturluson/testpkg require { github.com/athas/fut-baz 0.2.0 #44da85224224d37803976c1d30cedb1d2cd20b74 } futhark-0.25.27/tests_pkg/futhark.pkg.8000066400000000000000000000001751475065116200176650ustar00rootroot00000000000000package github.com/sturluson/testpkg require { github.com/athas/fut-baz 0.2.0 #44da85224224d37803976c1d30cedb1d2cd20b74 } futhark-0.25.27/tests_pkg/futhark.pkg.9000066400000000000000000000003121475065116200176570ustar00rootroot00000000000000package github.com/sturluson/testpkg require { github.com/athas/fut-baz 0.2.0 #44da85224224d37803976c1d30cedb1d2cd20b74 github.com/athas/fut-foo@2 2.0.0 #9459117fa75aea8fffc677294d25f273e894d19f } futhark-0.25.27/tests_pkg/lib.10/000077500000000000000000000000001475065116200163325ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.10/github.com/000077500000000000000000000000001475065116200203715ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.10/github.com/athas/000077500000000000000000000000001475065116200214715ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.10/github.com/athas/fut-bar/000077500000000000000000000000001475065116200230315ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.10/github.com/athas/fut-bar/bar.fut000066400000000000000000000000241475065116200243110ustar00rootroot00000000000000let bar = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.10/github.com/athas/fut-baz/000077500000000000000000000000001475065116200230415ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.10/github.com/athas/fut-baz/baz.fut000066400000000000000000000002531475065116200243350ustar00rootroot00000000000000module foo_mod = import "../fut-foo/foo" -- Naughty! module bar_mod = import "../fut-bar/bar" -- Naughty! 
let baz = (0, 1, 1) let foo = foo_mod.foo let bar = bar_mod.bar futhark-0.25.27/tests_pkg/lib.10/github.com/athas/fut-foo/000077500000000000000000000000001475065116200230505ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.10/github.com/athas/fut-foo/foo.fut000066400000000000000000000000241475065116200243470ustar00rootroot00000000000000let foo = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.10/github.com/athas/fut-foo@2/000077500000000000000000000000001475065116200232325ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.10/github.com/athas/fut-foo@2/foo.fut000066400000000000000000000001241475065116200245320ustar00rootroot00000000000000module fut_baz = import "../fut-baz/baz" let foo = (0, 2, 0) let baz = fut_baz.baz futhark-0.25.27/tests_pkg/lib.11/000077500000000000000000000000001475065116200163335ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.11/github.com/000077500000000000000000000000001475065116200203725ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.11/github.com/athas/000077500000000000000000000000001475065116200214725ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.11/github.com/athas/fut-bar/000077500000000000000000000000001475065116200230325ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.11/github.com/athas/fut-bar/bar.fut000066400000000000000000000000241475065116200243120ustar00rootroot00000000000000let bar = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.11/github.com/athas/fut-baz/000077500000000000000000000000001475065116200230425ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.11/github.com/athas/fut-baz/baz.fut000066400000000000000000000002531475065116200243360ustar00rootroot00000000000000module foo_mod = import "../fut-foo/foo" -- Naughty! module bar_mod = import "../fut-bar/bar" -- Naughty! 
let baz = (0, 1, 1) let foo = foo_mod.foo let bar = bar_mod.bar futhark-0.25.27/tests_pkg/lib.11/github.com/athas/fut-foo/000077500000000000000000000000001475065116200230515ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.11/github.com/athas/fut-foo/foo.fut000066400000000000000000000000241475065116200243500ustar00rootroot00000000000000let foo = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.11/github.com/athas/fut-foo@2/000077500000000000000000000000001475065116200232335ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.11/github.com/athas/fut-foo@2/foo.fut000066400000000000000000000001241475065116200245330ustar00rootroot00000000000000module fut_baz = import "../fut-baz/baz" let foo = (0, 2, 0) let baz = fut_baz.baz futhark-0.25.27/tests_pkg/lib.12/000077500000000000000000000000001475065116200163345ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.12/github.com/000077500000000000000000000000001475065116200203735ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.12/github.com/athas/000077500000000000000000000000001475065116200214735ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.12/github.com/athas/fut-bar/000077500000000000000000000000001475065116200230335ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.12/github.com/athas/fut-bar/bar.fut000066400000000000000000000000241475065116200243130ustar00rootroot00000000000000let bar = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.12/github.com/athas/fut-baz/000077500000000000000000000000001475065116200230435ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.12/github.com/athas/fut-baz/baz.fut000066400000000000000000000002531475065116200243370ustar00rootroot00000000000000module foo_mod = import "../fut-foo/foo" -- Naughty! module bar_mod = import "../fut-bar/bar" -- Naughty! 
let baz = (0, 1, 1) let foo = foo_mod.foo let bar = bar_mod.bar futhark-0.25.27/tests_pkg/lib.12/github.com/athas/fut-foo/000077500000000000000000000000001475065116200230525ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.12/github.com/athas/fut-foo/foo.fut000066400000000000000000000000241475065116200243510ustar00rootroot00000000000000let foo = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.12/github.com/athas/fut-foo@2/000077500000000000000000000000001475065116200232345ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.12/github.com/athas/fut-foo@2/foo.fut000066400000000000000000000001241475065116200245340ustar00rootroot00000000000000module fut_baz = import "../fut-baz/baz" let foo = (0, 2, 0) let baz = fut_baz.baz futhark-0.25.27/tests_pkg/lib.12/github.com/athas/fut-quux/000077500000000000000000000000001475065116200232715ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.12/github.com/athas/fut-quux/quux.fut000066400000000000000000000000211475065116200250040ustar00rootroot00000000000000let quux = "123" futhark-0.25.27/tests_pkg/lib.13/000077500000000000000000000000001475065116200163355ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.13/github.com/000077500000000000000000000000001475065116200203745ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.13/github.com/athas/000077500000000000000000000000001475065116200214745ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.13/github.com/athas/fut-bar/000077500000000000000000000000001475065116200230345ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.13/github.com/athas/fut-bar/bar.fut000066400000000000000000000000241475065116200243140ustar00rootroot00000000000000let bar = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.13/github.com/athas/fut-baz/000077500000000000000000000000001475065116200230445ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.13/github.com/athas/fut-baz/baz.fut000066400000000000000000000002531475065116200243400ustar00rootroot00000000000000module 
foo_mod = import "../fut-foo/foo" -- Naughty! module bar_mod = import "../fut-bar/bar" -- Naughty! let baz = (0, 1, 1) let foo = foo_mod.foo let bar = bar_mod.bar futhark-0.25.27/tests_pkg/lib.13/github.com/athas/fut-foo/000077500000000000000000000000001475065116200230535ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.13/github.com/athas/fut-foo/foo.fut000066400000000000000000000000241475065116200243520ustar00rootroot00000000000000let foo = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.13/github.com/athas/fut-foo@2/000077500000000000000000000000001475065116200232355ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.13/github.com/athas/fut-foo@2/foo.fut000066400000000000000000000001241475065116200245350ustar00rootroot00000000000000module fut_baz = import "../fut-baz/baz" let foo = (0, 2, 0) let baz = fut_baz.baz futhark-0.25.27/tests_pkg/lib.13/github.com/athas/fut-quux/000077500000000000000000000000001475065116200232725ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.13/github.com/athas/fut-quux/quux.fut000066400000000000000000000000211475065116200250050ustar00rootroot00000000000000let quux = "123" futhark-0.25.27/tests_pkg/lib.14/000077500000000000000000000000001475065116200163365ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.14/github.com/000077500000000000000000000000001475065116200203755ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.14/github.com/athas/000077500000000000000000000000001475065116200214755ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.14/github.com/athas/fut-bar/000077500000000000000000000000001475065116200230355ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.14/github.com/athas/fut-bar/bar.fut000066400000000000000000000000241475065116200243150ustar00rootroot00000000000000let bar = (0, 1, 0) 
futhark-0.25.27/tests_pkg/lib.14/github.com/athas/fut-baz/000077500000000000000000000000001475065116200230455ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.14/github.com/athas/fut-baz/baz.fut000066400000000000000000000002531475065116200243410ustar00rootroot00000000000000module foo_mod = import "../fut-foo/foo" -- Naughty! module bar_mod = import "../fut-bar/bar" -- Naughty! let baz = (0, 1, 1) let foo = foo_mod.foo let bar = bar_mod.bar futhark-0.25.27/tests_pkg/lib.14/github.com/athas/fut-foo/000077500000000000000000000000001475065116200230545ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.14/github.com/athas/fut-foo/foo.fut000066400000000000000000000000241475065116200243530ustar00rootroot00000000000000let foo = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.14/github.com/athas/fut-foo@2/000077500000000000000000000000001475065116200232365ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.14/github.com/athas/fut-foo@2/foo.fut000066400000000000000000000001241475065116200245360ustar00rootroot00000000000000module fut_baz = import "../fut-baz/baz" let foo = (0, 2, 0) let baz = fut_baz.baz futhark-0.25.27/tests_pkg/lib.14/github.com/athas/fut-quux/000077500000000000000000000000001475065116200232735ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.14/github.com/athas/fut-quux/quux.fut000066400000000000000000000001361475065116200250150ustar00rootroot00000000000000-- This is not a version number, but this is also not a released -- version! 
let quux = "123" futhark-0.25.27/tests_pkg/lib.15/000077500000000000000000000000001475065116200163375ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.15/github.com/000077500000000000000000000000001475065116200203765ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.15/github.com/athas/000077500000000000000000000000001475065116200214765ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.15/github.com/athas/fut-bar/000077500000000000000000000000001475065116200230365ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.15/github.com/athas/fut-bar/bar.fut000066400000000000000000000000241475065116200243160ustar00rootroot00000000000000let bar = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.15/github.com/athas/fut-baz/000077500000000000000000000000001475065116200230465ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.15/github.com/athas/fut-baz/baz.fut000066400000000000000000000002531475065116200243420ustar00rootroot00000000000000module foo_mod = import "../fut-foo/foo" -- Naughty! module bar_mod = import "../fut-bar/bar" -- Naughty! 
let baz = (0, 1, 1) let foo = foo_mod.foo let bar = bar_mod.bar futhark-0.25.27/tests_pkg/lib.15/github.com/athas/fut-foo/000077500000000000000000000000001475065116200230555ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.15/github.com/athas/fut-foo/foo.fut000066400000000000000000000000241475065116200243540ustar00rootroot00000000000000let foo = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.15/github.com/athas/fut-foo@2/000077500000000000000000000000001475065116200232375ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.15/github.com/athas/fut-foo@2/foo.fut000066400000000000000000000001241475065116200245370ustar00rootroot00000000000000module fut_baz = import "../fut-baz/baz" let foo = (0, 2, 0) let baz = fut_baz.baz futhark-0.25.27/tests_pkg/lib.15/github.com/athas/fut-quux/000077500000000000000000000000001475065116200232745ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.15/github.com/athas/fut-quux/quux.fut000066400000000000000000000001361475065116200250160ustar00rootroot00000000000000-- This is not a version number, but this is also not a released -- version! 
let quux = "123" futhark-0.25.27/tests_pkg/lib.16/000077500000000000000000000000001475065116200163405ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.16/github.com/000077500000000000000000000000001475065116200203775ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.16/github.com/athas/000077500000000000000000000000001475065116200214775ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.16/github.com/athas/fut-bar/000077500000000000000000000000001475065116200230375ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.16/github.com/athas/fut-bar/bar.fut000066400000000000000000000000241475065116200243170ustar00rootroot00000000000000let bar = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.16/github.com/athas/fut-baz/000077500000000000000000000000001475065116200230475ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.16/github.com/athas/fut-baz/baz.fut000066400000000000000000000002531475065116200243430ustar00rootroot00000000000000module foo_mod = import "../fut-foo/foo" -- Naughty! module bar_mod = import "../fut-bar/bar" -- Naughty! 
let baz = (0, 1, 1) let foo = foo_mod.foo let bar = bar_mod.bar futhark-0.25.27/tests_pkg/lib.16/github.com/athas/fut-foo/000077500000000000000000000000001475065116200230565ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.16/github.com/athas/fut-foo/foo.fut000066400000000000000000000000241475065116200243550ustar00rootroot00000000000000let foo = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.16/github.com/athas/fut-foo@2/000077500000000000000000000000001475065116200232405ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.16/github.com/athas/fut-foo@2/foo.fut000066400000000000000000000001241475065116200245400ustar00rootroot00000000000000module fut_baz = import "../fut-baz/baz" let foo = (0, 2, 0) let baz = fut_baz.baz futhark-0.25.27/tests_pkg/lib.16/github.com/athas/fut-quux/000077500000000000000000000000001475065116200232755ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.16/github.com/athas/fut-quux/quux.fut000066400000000000000000000001361475065116200250170ustar00rootroot00000000000000-- This is not a version number, but this is also not a released -- version! 
let quux = "123" futhark-0.25.27/tests_pkg/lib.16/gitlab.com/000077500000000000000000000000001475065116200203575ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.16/gitlab.com/athas/000077500000000000000000000000001475065116200214575ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.16/gitlab.com/athas/fut-gitlab/000077500000000000000000000000001475065116200235155ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.16/gitlab.com/athas/fut-gitlab/gitlab.fut000066400000000000000000000000271475065116200254760ustar00rootroot00000000000000let gitlab = (1, 0, 1) futhark-0.25.27/tests_pkg/lib.17/000077500000000000000000000000001475065116200163415ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.17/github.com/000077500000000000000000000000001475065116200204005ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.17/github.com/athas/000077500000000000000000000000001475065116200215005ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.17/github.com/athas/fut-bar/000077500000000000000000000000001475065116200230405ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.17/github.com/athas/fut-bar/bar.fut000066400000000000000000000000241475065116200243200ustar00rootroot00000000000000let bar = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.17/github.com/athas/fut-baz/000077500000000000000000000000001475065116200230505ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.17/github.com/athas/fut-baz/baz.fut000066400000000000000000000002531475065116200243440ustar00rootroot00000000000000module foo_mod = import "../fut-foo/foo" -- Naughty! module bar_mod = import "../fut-bar/bar" -- Naughty! 
let baz = (0, 1, 1) let foo = foo_mod.foo let bar = bar_mod.bar futhark-0.25.27/tests_pkg/lib.17/github.com/athas/fut-foo/000077500000000000000000000000001475065116200230575ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.17/github.com/athas/fut-foo/foo.fut000066400000000000000000000000241475065116200243560ustar00rootroot00000000000000let foo = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.17/github.com/athas/fut-foo@2/000077500000000000000000000000001475065116200232415ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.17/github.com/athas/fut-foo@2/foo.fut000066400000000000000000000001241475065116200245410ustar00rootroot00000000000000module fut_baz = import "../fut-baz/baz" let foo = (0, 2, 0) let baz = fut_baz.baz futhark-0.25.27/tests_pkg/lib.17/github.com/athas/fut-quux/000077500000000000000000000000001475065116200232765ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.17/github.com/athas/fut-quux/quux.fut000066400000000000000000000001361475065116200250200ustar00rootroot00000000000000-- This is not a version number, but this is also not a released -- version! 
let quux = "123" futhark-0.25.27/tests_pkg/lib.18/000077500000000000000000000000001475065116200163425ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.18/github.com/000077500000000000000000000000001475065116200204015ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.18/github.com/athas/000077500000000000000000000000001475065116200215015ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.18/github.com/athas/fut-bar/000077500000000000000000000000001475065116200230415ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.18/github.com/athas/fut-bar/bar.fut000066400000000000000000000000241475065116200243210ustar00rootroot00000000000000let bar = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.18/github.com/athas/fut-baz/000077500000000000000000000000001475065116200230515ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.18/github.com/athas/fut-baz/baz.fut000066400000000000000000000002531475065116200243450ustar00rootroot00000000000000module foo_mod = import "../fut-foo/foo" -- Naughty! module bar_mod = import "../fut-bar/bar" -- Naughty! 
let baz = (0, 1, 1) let foo = foo_mod.foo let bar = bar_mod.bar futhark-0.25.27/tests_pkg/lib.18/github.com/athas/fut-foo/000077500000000000000000000000001475065116200230605ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.18/github.com/athas/fut-foo/foo.fut000066400000000000000000000000241475065116200243570ustar00rootroot00000000000000let foo = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.18/github.com/athas/fut-foo@2/000077500000000000000000000000001475065116200232425ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.18/github.com/athas/fut-foo@2/foo.fut000066400000000000000000000001241475065116200245420ustar00rootroot00000000000000module fut_baz = import "../fut-baz/baz" let foo = (0, 2, 0) let baz = fut_baz.baz futhark-0.25.27/tests_pkg/lib.18/github.com/athas/fut-quux/000077500000000000000000000000001475065116200232775ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.18/github.com/athas/fut-quux/quux.fut000066400000000000000000000001361475065116200250210ustar00rootroot00000000000000-- This is not a version number, but this is also not a released -- version! let quux = "123" futhark-0.25.27/tests_pkg/lib.18/gitlab.com/000077500000000000000000000000001475065116200203615ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.18/gitlab.com/athas/000077500000000000000000000000001475065116200214615ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.18/gitlab.com/athas/fut-gitlab/000077500000000000000000000000001475065116200235175ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.18/gitlab.com/athas/fut-gitlab/gitlab.fut000066400000000000000000000000561475065116200255020ustar00rootroot00000000000000-- Unreleased version. 
let gitlab = (1, 0, 1) futhark-0.25.27/tests_pkg/lib.2/000077500000000000000000000000001475065116200162535ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.2/github.com/000077500000000000000000000000001475065116200203125ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.2/github.com/athas/000077500000000000000000000000001475065116200214125ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.2/github.com/athas/fut-foo/000077500000000000000000000000001475065116200227715ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.2/github.com/athas/fut-foo/foo.fut000066400000000000000000000000241475065116200242700ustar00rootroot00000000000000let foo = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.3/000077500000000000000000000000001475065116200162545ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.3/github.com/000077500000000000000000000000001475065116200203135ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.3/github.com/athas/000077500000000000000000000000001475065116200214135ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.3/github.com/athas/fut-foo/000077500000000000000000000000001475065116200227725ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.3/github.com/athas/fut-foo/foo.fut000066400000000000000000000000241475065116200242710ustar00rootroot00000000000000let foo = (0, 1, 0) 
futhark-0.25.27/tests_pkg/lib.4/000077500000000000000000000000001475065116200162555ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.4/github.com/000077500000000000000000000000001475065116200203145ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.4/github.com/athas/000077500000000000000000000000001475065116200214145ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.4/github.com/athas/fut-baz/000077500000000000000000000000001475065116200227645ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.4/github.com/athas/fut-baz/baz.fut000066400000000000000000000001401475065116200242530ustar00rootroot00000000000000module foo_mod = import "../fut-foo/foo" -- Naughty! let baz = (0, 1, 0) let foo = foo_mod.foo futhark-0.25.27/tests_pkg/lib.4/github.com/athas/fut-foo/000077500000000000000000000000001475065116200227735ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.4/github.com/athas/fut-foo/foo.fut000066400000000000000000000000241475065116200242720ustar00rootroot00000000000000let foo = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.5/000077500000000000000000000000001475065116200162565ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.5/github.com/000077500000000000000000000000001475065116200203155ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.5/github.com/athas/000077500000000000000000000000001475065116200214155ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.5/github.com/athas/fut-baz/000077500000000000000000000000001475065116200227655ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.5/github.com/athas/fut-baz/baz.fut000066400000000000000000000001401475065116200242540ustar00rootroot00000000000000module foo_mod = import "../fut-foo/foo" -- Naughty! 
let baz = (0, 1, 0) let foo = foo_mod.foo futhark-0.25.27/tests_pkg/lib.5/github.com/athas/fut-foo/000077500000000000000000000000001475065116200227745ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.5/github.com/athas/fut-foo/foo.fut000066400000000000000000000000241475065116200242730ustar00rootroot00000000000000let foo = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.6/000077500000000000000000000000001475065116200162575ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.6/github.com/000077500000000000000000000000001475065116200203165ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.6/github.com/athas/000077500000000000000000000000001475065116200214165ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.6/github.com/athas/fut-bar/000077500000000000000000000000001475065116200227565ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.6/github.com/athas/fut-bar/bar.fut000066400000000000000000000000241475065116200242360ustar00rootroot00000000000000let bar = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.6/github.com/athas/fut-baz/000077500000000000000000000000001475065116200227665ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.6/github.com/athas/fut-baz/baz.fut000066400000000000000000000002531475065116200242620ustar00rootroot00000000000000module foo_mod = import "../fut-foo/foo" -- Naughty! module bar_mod = import "../fut-bar/bar" -- Naughty! 
let baz = (0, 1, 1) let foo = foo_mod.foo let bar = bar_mod.bar futhark-0.25.27/tests_pkg/lib.6/github.com/athas/fut-foo/000077500000000000000000000000001475065116200227755ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.6/github.com/athas/fut-foo/foo.fut000066400000000000000000000001531475065116200242770ustar00rootroot00000000000000module fut_bar = import "../../../github.com/athas/fut-bar/bar" let foo = (0, 2, 0) let bar = fut_bar.bar futhark-0.25.27/tests_pkg/lib.7/000077500000000000000000000000001475065116200162605ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.7/github.com/000077500000000000000000000000001475065116200203175ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.7/github.com/athas/000077500000000000000000000000001475065116200214175ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.7/github.com/athas/fut-bar/000077500000000000000000000000001475065116200227575ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.7/github.com/athas/fut-bar/bar.fut000066400000000000000000000000241475065116200242370ustar00rootroot00000000000000let bar = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.7/github.com/athas/fut-baz/000077500000000000000000000000001475065116200227675ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.7/github.com/athas/fut-baz/baz.fut000066400000000000000000000002531475065116200242630ustar00rootroot00000000000000module foo_mod = import "../fut-foo/foo" -- Naughty! module bar_mod = import "../fut-bar/bar" -- Naughty! 
let baz = (0, 1, 1) let foo = foo_mod.foo let bar = bar_mod.bar futhark-0.25.27/tests_pkg/lib.7/github.com/athas/fut-foo/000077500000000000000000000000001475065116200227765ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.7/github.com/athas/fut-foo/foo.fut000066400000000000000000000001531475065116200243000ustar00rootroot00000000000000module fut_bar = import "../../../github.com/athas/fut-bar/bar" let foo = (0, 2, 0) let bar = fut_bar.bar futhark-0.25.27/tests_pkg/lib.8/000077500000000000000000000000001475065116200162615ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.8/github.com/000077500000000000000000000000001475065116200203205ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.8/github.com/athas/000077500000000000000000000000001475065116200214205ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.8/github.com/athas/fut-bar/000077500000000000000000000000001475065116200227605ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.8/github.com/athas/fut-bar/bar.fut000066400000000000000000000000241475065116200242400ustar00rootroot00000000000000let bar = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.8/github.com/athas/fut-baz/000077500000000000000000000000001475065116200227705ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.8/github.com/athas/fut-baz/baz.fut000066400000000000000000000002531475065116200242640ustar00rootroot00000000000000module foo_mod = import "../fut-foo/foo" -- Naughty! module bar_mod = import "../fut-bar/bar" -- Naughty! 
let baz = (0, 1, 1) let foo = foo_mod.foo let bar = bar_mod.bar futhark-0.25.27/tests_pkg/lib.8/github.com/athas/fut-foo/000077500000000000000000000000001475065116200227775ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.8/github.com/athas/fut-foo/foo.fut000066400000000000000000000000241475065116200242760ustar00rootroot00000000000000let foo = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.9/000077500000000000000000000000001475065116200162625ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.9/github.com/000077500000000000000000000000001475065116200203215ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.9/github.com/athas/000077500000000000000000000000001475065116200214215ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.9/github.com/athas/fut-bar/000077500000000000000000000000001475065116200227615ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.9/github.com/athas/fut-bar/bar.fut000066400000000000000000000000241475065116200242410ustar00rootroot00000000000000let bar = (0, 1, 0) futhark-0.25.27/tests_pkg/lib.9/github.com/athas/fut-baz/000077500000000000000000000000001475065116200227715ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.9/github.com/athas/fut-baz/baz.fut000066400000000000000000000002531475065116200242650ustar00rootroot00000000000000module foo_mod = import "../fut-foo/foo" -- Naughty! module bar_mod = import "../fut-bar/bar" -- Naughty! let baz = (0, 1, 1) let foo = foo_mod.foo let bar = bar_mod.bar futhark-0.25.27/tests_pkg/lib.9/github.com/athas/fut-foo/000077500000000000000000000000001475065116200230005ustar00rootroot00000000000000futhark-0.25.27/tests_pkg/lib.9/github.com/athas/fut-foo/foo.fut000066400000000000000000000000241475065116200242770ustar00rootroot00000000000000let foo = (0, 1, 0) futhark-0.25.27/tests_pkg/test.sh000077500000000000000000000030771475065116200166720ustar00rootroot00000000000000#!/bin/sh # # You must be in the directory when running this script. It does not # try to be clever. 
set -e # Die on error. lastrun="" expects () { # Hack to correctly handle empty directories, which are otherwise # not committed to Git. if ! diff -urN $1 $2; then echo "After command '$lastrun', $1 does not match $2" exit 1 fi } i=0 succeed () { lastrun="$@" echo '$' "$@" if ! "$@"; then echo "Command '$lastrun' failed unexpectedly." exit 1 fi expects futhark.pkg futhark.pkg.$i expects lib lib.$i i=$(($i+1)) } fail () { lastrun="$@" echo '$' "$@" if "$@"; then echo "Command '$lastrun' succeeded unexpectedly." exit 1 fi } # Clean up after previous test runs. rm -rf futhark.pkg lib succeed futhark pkg init github.com/sturluson/testpkg succeed futhark pkg add github.com/athas/fut-foo 0.1.0 succeed futhark pkg sync succeed futhark pkg add github.com/athas/fut-baz 0.1.0 succeed futhark pkg sync succeed futhark pkg upgrade succeed futhark pkg sync succeed futhark pkg remove github.com/athas/fut-foo succeed futhark pkg sync succeed futhark pkg add github.com/athas/fut-foo@2 2.0.0 succeed futhark pkg sync succeed futhark pkg add github.com/athas/fut-quux 0.0.0-20180801102532+b70028521e4dbcc286834b32ce82c1d2721a6209 succeed futhark pkg sync succeed futhark pkg add github.com/athas/fut-quux 0.0.0-20180801102533+dd5168df1b8a20cb0547a88afd4e4a6cc098e0f1 succeed futhark pkg sync succeed futhark pkg add gitlab.com/athas/fut-gitlab succeed futhark pkg sync futhark-0.25.27/tests_repl/000077500000000000000000000000001475065116200155265ustar00rootroot00000000000000futhark-0.25.27/tests_repl/.gitignore000066400000000000000000000000111475065116200175060ustar00rootroot00000000000000*.actual futhark-0.25.27/tests_repl/issue1347.fut000066400000000000000000000002351475065116200177150ustar00rootroot00000000000000entry blockify [n] (b: i64) (xs: [n][n]i32) = (xs :> [(n/b)*b][(n/b)*b]i32) |> unflatten |> map transpose |> map unflatten |> map (map transpose) futhark-0.25.27/tests_repl/issue1347.in000066400000000000000000000001711475065116200175240ustar00rootroot00000000000000let blocked = 
blockify 2 (copy <| unflatten (map i32.i64 <| iota (8*8))) let pre_transpose = map (map transpose) blocked futhark-0.25.27/tests_repl/issue1347.out000066400000000000000000000000471475065116200177270ustar00rootroot00000000000000[0]> [1]> [2]> Leaving 'futhark repl'. futhark-0.25.27/tests_repl/local.fut000066400000000000000000000000341475065116200173350ustar00rootroot00000000000000local def secret : i32 = 42 futhark-0.25.27/tests_repl/local.in000066400000000000000000000000071475065116200171450ustar00rootroot00000000000000secret futhark-0.25.27/tests_repl/local.out000066400000000000000000000000451475065116200173500ustar00rootroot00000000000000[0]> 42 [1]> Leaving 'futhark repl'. futhark-0.25.27/tests_repl/test.sh000077500000000000000000000005041475065116200170430ustar00rootroot00000000000000#!/bin/sh set -e do_test() { prog=$1 in=$(basename -s .fut $prog).in out=$(basename -s .fut $prog).out actual=$(basename -s .fut $prog).actual echo $prog futhark repl $prog < $in > $actual if ! diff -u $actual $out; then exit 1 fi } for x in *.fut; do do_test $x || exit 1; done futhark-0.25.27/tools/000077500000000000000000000000001475065116200145025ustar00rootroot00000000000000futhark-0.25.27/tools/.gitignore000066400000000000000000000000131475065116200164640ustar00rootroot00000000000000__pycache__futhark-0.25.27/tools/README.md000066400000000000000000000025221475065116200157620ustar00rootroot00000000000000Futhark tools ============= This directory contains useful programs for working with and and on the futhark compiler. Below are the (non-comprehensive) descriptions of some of the tools in this directory. Style ------------- `style-check`. The futhark CI does a style check using the following style checker. It's convienant to run it locally before making a PR, as merging will be blocked if it doesn't pass. `style-check.sh` assumes that [hlint](https://github.com/ndmitchell/hlint) is installed. 
```bash ./style-check.sh src/futhark.hs ``` Similarly `run-formatter.sh` is a convienent tool for automatic code formatting. Useful for catching some obvious issues that trigger the style checker. It requires [ormolu](https://github.com/tweag/ormolu) to be installed. It can be run on files or whole directories. ```bash ./run-formatter.sh src unittests ``` ``` GtkSourceView ------------- To install the Futhark syntax highlighter for GtkSourceView (e.g. used by Gedit), copy `futhark.lang` and place it in the following directory: ~/.local/share/gtksourceview-3.0/language-specs/ Restart Gedit and open a `*.fut`-file. Miscellaneous ------------- `generate-module-list.sh` generates a list of exposed modules for the .cabal file (excluding the parser files). This should be run from futhark's root directory. ``` ./tools/generate-module-list.sh ``` futhark-0.25.27/tools/auto-hlint.sh000077500000000000000000000005441475065116200171300ustar00rootroot00000000000000#!/bin/sh # # Automatically apply hlint rules to the given file(s) and directories. # # Note: you may need to run this tool to a fixed point manually, as # some rewrites only do partial improvements, and the resulting code # may still violate other rules. for d in "$@"; do find "$d" -name \*.hs | parallel hlint --refactor --refactor-options=-i done futhark-0.25.27/tools/bench-compilers.py000077500000000000000000000316271475065116200201420ustar00rootroot00000000000000#!/usr/bin/env python3 # # Construct PNG graphs of compile- and run-time performance for # historical versions of the Futhark compiler. The legend is almost # useless, so manual investigation is needed to get something usable # out of this. # # This program is quite naive in its construction of commands passed # to os.system, so don't run it from a directory with interesting # characters in its absolute path. Run it from an initially empty # directory, because it will create subdirectories for its own # purposes. 
# # This program has lots of hacks and special cases to handle the # various changes in Futhark tooling over the years. But all things # considered, it's not too bad. # # Regarding the special 'nightly' version, delete it from the # releases/ directory if you want this script to download a newer one. # # $ bench-compilers.py plot_runtime 0.1.0 0.2.0 0.3.0 0.4.0 0.4.1 0.5.1 0.6.1 0.6.2 0.6.3 0.7.1 0.7.2 0.7.3 0.7.4 0.8.1 0.1 0.1 0.10.2 0.11.1 0.11.2 0.12.1 0.12.2 0.12.3 0.13.1 0.13.2 0.14.1 0.15.1 0.15.2 0.15.3 0.15.4 0.15.5 0.15.6 0.15.7 0.15.8 0.16.1 0.16.2 0.16.3 0.16.4 0.16.5 0.17.2 0.17 .3 0.18.1 0.18.2 0.18.3 0.18.4 0.18.5 0.18.6 0.19.1 0.19.2 0.19.3 0.19.4 0.19.5 nightly # nopep8 # # The example above skips 0.9.1 because a few benchmarks do not # validate with it. Seems like there was a bug with certain # reductions. import sys import os.path import urllib.request import subprocess from pathlib import Path import json import time import matplotlib.pyplot as plt import matplotlib.ticker import numpy as np RELEASES_DIR = "releases" GITS_DIR = "gits" JSON_DIR = "json" FUTHARK_BACKEND = "opencl" FUTHARK_REPO = "https://github.com/diku-dk/futhark.git" FUTHARK_BENCHMARKS_REPO = "https://github.com/diku-dk/futhark-benchmarks.git" FUTHARK_DIR = os.path.join(GITS_DIR, "futhark") FUTHARK_BENCHMARKS_DIR = os.path.join(GITS_DIR, "futhark-benchmarks") BENCH_RUNS = 20 # For newer versions than this, the 'futhark-benchmarks' submodule # points to the right revision already. 
versions_to_benchmark_commits = { "0.3.1": "3b9c6cd06784fbf94c40a30f4302eeef119352b7", "0.3.0": "289000e5734705b45840ee6315bd683ab14b6ddb", "0.2.0": "33aabdd88bb3c935934032c5d49a7563f7078322", "0.1.0": "24cba07b5b70d3cbcf4195108a4caddbb649481f", } # Hand-curated set of benchmarks and corresponding data sets that have # existed since version 0.1.0 benchmarks = { "futhark-benchmarks/accelerate/canny/canny.fut": "data/lena512.in", "futhark-benchmarks/accelerate/crystal/crystal.fut": '#6 ("4000i32 30.0f32 50i32 1i32 1.0f32")', "futhark-benchmarks/accelerate/fluid/fluid.fut": "benchmarking/medium.in", "futhark-benchmarks/accelerate/mandelbrot/mandelbrot.fut": '#5 ("8000i32 8000i32 -0.7f32 0.0f32 3.067f32 100i32 16....")', "futhark-benchmarks/accelerate/nbody/nbody.fut": "data/100000-bodies.in", "futhark-benchmarks/accelerate/tunnel/tunnel.fut": '#5 ("10.0f32 8000i32 8000i32")', "futhark-benchmarks/finpar/LocVolCalib.fut": "LocVolCalib-data/large.in", "futhark-benchmarks/finpar/OptionPricing.fut": "OptionPricing-data/large.in", "futhark-benchmarks/jgf/crypt/crypt.fut": "crypt-data/medium.in", "futhark-benchmarks/jgf/series/series.fut": "data/1000000.in", "futhark-benchmarks/misc/heston/heston32.fut": "data/100000_quotes.in", "futhark-benchmarks/misc/heston/heston64.fut": "data/100000_quotes.in", "futhark-benchmarks/misc/radix_sort/radix_sort_large.fut": "data/radix_sort_1M.in", "futhark-benchmarks/parboil/mri-q/mri-q.fut": "data/large.in", "futhark-benchmarks/parboil/sgemm/sgemm.fut": "data/medium.in", "futhark-benchmarks/parboil/stencil/stencil.fut": "data/default.in", "futhark-benchmarks/parboil/tpacf/tpacf.fut": "data/large.in", "futhark-benchmarks/rodinia/backprop/backprop.fut": "data/medium.in", "futhark-benchmarks/rodinia/bfs/bfs_filt_padded_fused.fut": "data/graph1MW_6.in", "futhark-benchmarks/rodinia/cfd/cfd.fut": "data/fvcorr.domn.193K.toa", "futhark-benchmarks/rodinia/hotspot/hotspot.fut": "data/1024.in", "futhark-benchmarks/rodinia/kmeans/kmeans.fut": 
"data/kdd_cup.in", "futhark-benchmarks/rodinia/lavaMD/lavaMD.fut": "data/10_boxes.in", "futhark-benchmarks/rodinia/lud/lud.fut": "data/2048.in", "futhark-benchmarks/rodinia/myocyte/myocyte.fut": "data/medium.in", "futhark-benchmarks/rodinia/nn/nn.fut": "data/medium.in", "futhark-benchmarks/rodinia/pathfinder/pathfinder.fut": "data/medium.in", "futhark-benchmarks/rodinia/srad/srad.fut": "data/image.in", } # Some data sets changed names over time - renormalise here. def datamap(d): m = { '#6 ("4000i32 30.0f32 50i32 1i32 1.0f32")': '#5 ("4000i32 30.0f32 50i32 1i32 1.0f32")', '#5 ("8000i32 8000i32 -0.7f32 0.0f32 3.067f32 100i32 16....")': '#4 ("8000i32 8000i32 -0.7f32 0.0f32 3.067f32 100i32 16....")', '#5 ("10.0f32 8000i32 8000i32")': '#4 ("10.0f32 8000i32 8000i32")', } if d in m: return m[d] else: return d def shell(cmd): print("$ {}".format(cmd)) if os.system(cmd) != 0: raise Exception("shell command failed") def version_before(x, y): if x == "nightly": return False else: return tuple(map(int, x.split("."))) < tuple(map(int, y.split("."))) def runtime_json_for_version(version): return os.path.join(JSON_DIR, "{}.json".format(version)) def compile_json_for_version(version): return os.path.join(JSON_DIR, "{}-compile.json".format(version)) def release_for_version(version): return os.path.join( RELEASES_DIR, "futhark-{}-linux-x86_64".format(version) ) def tarball_for_version(version): return os.path.join( RELEASES_DIR, "futhark-{}-linux-x86_64.tar.xz".format(version) ) def futhark_compile_for_version(version): if version_before(version, "0.9.1"): return os.path.abspath( os.path.join( release_for_version(version), "bin", "futhark-{}".format(FUTHARK_BACKEND), ) ) else: return os.path.join( release_for_version(version), "bin/futhark {}".format(FUTHARK_BACKEND), ) def futhark_bench_for_version(version): bs = " ".join(list(benchmarks.keys())) if version_before(version, "0.9.1"): prog = os.path.abspath( os.path.join(release_for_version(version), "bin", "futhark-bench") ) 
compiler = os.path.abspath( os.path.join( release_for_version(version), "bin", "futhark-{}".format(FUTHARK_BACKEND), ) ) return "{} -r {} --compiler={} {}".format( prog, BENCH_RUNS, compiler, bs ) else: prog = os.path.abspath( os.path.join(release_for_version(version), "bin", "futhark") ) return "{} bench -r {} --backend={} --futhark={} {}".format( prog, BENCH_RUNS, FUTHARK_BACKEND, prog, bs ) def tarball_url(version): return "https://futhark-lang.org/releases/futhark-{}-linux-x86_64.tar.xz".format( version ) def ensure_tarball(version): f = tarball_for_version(version) if not os.path.exists(f): print("Downloading {}...".format(tarball_url(version))) urllib.request.urlretrieve(tarball_url(version), f) def ensure_release(version): if not os.path.exists(release_for_version(version)): ensure_tarball(version) print("Extracting {}...".format(tarball_for_version(version))) shell( "cd {} && tar xf ../{}".format( RELEASES_DIR, tarball_for_version(version) ) ) def ensure_repo(what, url): dir = os.path.join(GITS_DIR, what) if os.path.exists(dir): shell("cd {} && git checkout master && git pull".format(dir)) else: shell("git clone --recursive {} {}".format(url, dir)) return dir def ensure_futhark_repo(): return ensure_repo("futhark", FUTHARK_REPO) def ensure_benchmarks(version): if version in versions_to_benchmark_commits: ensure_repo("futhark-benchmarks", FUTHARK_BENCHMARKS_REPO) shell( "cd {} && git checkout -f {}".format( FUTHARK_BENCHMARKS_DIR, versions_to_benchmark_commits[version] ) ) return GITS_DIR else: d = ensure_futhark_repo() tag = "master" if version == "nightly" else "v{}".format(version) shell( "cd {} && git checkout -f {} && git submodule update".format( d, tag ) ) return os.path.join(d) def produce_runtime_json(version): f = runtime_json_for_version(version) ensure_release(version) ensure_repo("futhark", FUTHARK_REPO) d = ensure_benchmarks(version) json_f = os.path.abspath(runtime_json_for_version(version)) shell( "cd {} && {} --json={}".format( d, 
futhark_bench_for_version(version), json_f ) ) def ensure_runtime_json(version): f = runtime_json_for_version(version) if not os.path.exists(f): produce_runtime_json(version) return f def produce_compile_json(version): ensure_release(version) d = ensure_benchmarks(version) compiletimes = {} for b in benchmarks: bef = time.time() shell( "{} {}".format( futhark_compile_for_version(version), os.path.join(d, b) ) ) aft = time.time() compiletimes[b] = aft - bef with open(compile_json_for_version(version), "w") as f: json.dump(compiletimes, f) def ensure_compile_json(version): f = compile_json_for_version(version) if not os.path.exists(f): produce_compile_json(version) return f def ensure_json(version): ensure_runtime_json(version) ensure_compile_json(version) def load_json(f): return json.load(open(f, "r")) def plot_runtime(versions): fastest = {} results_per_bench = {} for v in versions: results = load_json(ensure_runtime_json(v)) for b, d in benchmarks.items(): try: x = np.median(results[b]["datasets"][d]["runtimes"]) except KeyError: x = np.median(results[b]["datasets"][datamap(d)]["runtimes"]) if (b, d) in fastest: fastest[b, d] = min(fastest[b, d], x) else: fastest[b, d] = x if (b, d) in results_per_bench: results_per_bench[b, d] += [x] else: results_per_bench[b, d] = [x] for b, d in benchmarks.items(): results_per_bench[b, d] = ( np.array(results_per_bench[b, d]) / fastest[b, d] ) fig, ax = plt.subplots() ax.set_yscale("log") ax.yaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter()) ax.set_ylabel("Runtime relative to fastest") plt.xticks(rotation=90) for b, d in benchmarks.items(): xs = versions + ["..."] ys = np.append(results_per_bench[b, d], [results_per_bench[b, d][-1]]) ax.plot(xs, ys, label=b[len("futhark-benchmarks/") :]) ax.grid() ax.legend( bbox_to_anchor=(0, -0.35), ncol=3, loc="center left", fontsize=5.5 ) fig.savefig("runtime.png", dpi=300, figsize=(12, 6), bbox_inches="tight") def plot_compiletime(versions): fastest = {} results_per_bench = {} 
for v in versions: results = load_json(ensure_compile_json(v)) for b in benchmarks: x = np.mean(results[b]) if b in fastest: fastest[b] = min(fastest[b], x) else: fastest[b] = x if b in results_per_bench: results_per_bench[b] += [x] else: results_per_bench[b] = [x] for b in benchmarks: results_per_bench[b] = np.array(results_per_bench[b]) / fastest[b] fig, ax = plt.subplots() ax.set_yscale("log") ax.yaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter()) ax.set_ylabel("Compilation time relative to fastest") plt.xticks(rotation=90) for b in benchmarks: xs = versions + ["..."] ys = np.append(results_per_bench[b], [results_per_bench[b][-1]]) ax.plot(xs, ys, label=b[len("futhark-benchmarks/") :]) ax.grid() ax.legend( bbox_to_anchor=(0, -0.35), ncol=3, loc="center left", fontsize=5.5 ) fig.savefig( "compiletime.png", dpi=300, figsize=(12, 6), bbox_inches="tight" ) if __name__ == "__main__": _, cmd, *args = sys.argv Path(RELEASES_DIR).mkdir(parents=True, exist_ok=True) Path(GITS_DIR).mkdir(parents=True, exist_ok=True) Path(JSON_DIR).mkdir(parents=True, exist_ok=True) if cmd == "bench": ensure_json(*args) elif cmd == "benchdir": print(os.path.join(ensure_benchmarks(*args), "futhark-benchmarks")) elif cmd == "plot_runtime": plot_runtime(args) elif cmd == "plot_compiletime": plot_compiletime(args) else: print("Invalid command: {}\n".format(cmd)) sys.exit(1) futhark-0.25.27/tools/changelog.sh000077500000000000000000000004241475065116200167700ustar00rootroot00000000000000#!/bin/sh # # Extract CHANGELOG.md entries for a given version number. # # Usage: # # $ sh changelog.sh x.y.z < CHANGELOG.md set -e if [ $# -ne 1 ]; then echo "Usage: $0 x.y.z" fi awk '/^## / { thisone = index($0,V) != 0; next } \ thisone { print }' \ "V=$1" futhark-0.25.27/tools/commit-impact.py000077500000000000000000000036711475065116200176310ustar00rootroot00000000000000#!/usr/bin/env python3 # # See the impact of a Futhark commit compared to the previous one we # have benchmarking for. 
import sys import subprocess from urllib.request import urlopen from urllib.error import HTTPError import json import tempfile import os import gzip def url_for(backend, system, commit): return ( "https://futhark-lang.org/benchmark-results/futhark-" "{}-{}-{}.json.gz" ).format(backend, system, commit) def results_for_commit(backend, system, commit): try: url = url_for(backend, system, commit) print("Fetching {}...".format(url)) return json.loads(gzip.decompress(urlopen(url).read())) except HTTPError: return None def first_commit_with_results(backend, system, commits): for commit in commits: res = results_for_commit(backend, system, commit) if res: return commit, res def find_commits(start): return ( subprocess.check_output(["git", "rev-list", start]) .decode("utf-8") .splitlines() ) if __name__ == "__main__": backend, system, commit = sys.argv[1:4] now = results_for_commit(backend, system, commit) if not now: print("No results found") sys.exit(1) if len(sys.argv) == 5: commits = find_commits(sys.argv[4]) else: commits = find_commits(commit)[1:] then_commit, then = first_commit_with_results(backend, system, commits[1:]) print("Comparing {}".format(commit)) print(" with {}".format(then_commit)) with tempfile.NamedTemporaryFile(prefix=commit, mode="w") as now_file: with tempfile.NamedTemporaryFile( prefix=then_commit, mode="w" ) as then_file: json.dump(now, now_file) json.dump(then, then_file) now_file.flush() then_file.flush() os.system( "futhark benchcmp {} {}".format(then_file.name, now_file.name) ) futhark-0.25.27/tools/compare-compiler-versions.sh000077500000000000000000000011011475065116200221360ustar00rootroot00000000000000#!/usr/bin/env bash # # Quick hack to compare compiler revisions by automatically pulling # statistics from futhark-lang.org/benchmark-results. # # Pass the reference version as the first argument. You must run this # while standing in the Futhark directory. 
if [ $# -ne 4 ]; then echo "Usage: $0 " >&2 exit 1 fi RESULTS_URL=https://futhark-lang.org/benchmark-results/ a_url=$RESULTS_URL/$1-$2.json b_url=$RESULTS_URL/$3-$4.json echo "Fetching $a_url" echo " and $b_url" futhark benchcmp <(curl "$a_url") <(curl "$b_url") futhark-0.25.27/tools/data2fut.py000077500000000000000000000014571475065116200166000ustar00rootroot00000000000000#!/usr/bin/env python3 # # Convert the contents of a file to a Futhark-compatible binary data # file. The file is interpreted as raw binary data corresponding to a # one-dimensional array. # # Takes a single option: the name of the element type. Only the # numeric types are supported (no floats). # # Usage example: # # $ ./data2fut.py f32 INPUT OUTPUT import futhark_data as fd import numpy as np import sys types = { "f16": np.float16, "f32": np.float32, "f64": np.float64, "i8": np.int8, "i16": np.int16, "i32": np.int32, "i64": np.int64, "u8": np.uint8, "u16": np.uint16, "u32": np.uint32, "u64": np.uint64, } _, type, input, output = sys.argv dtype = types[type] v = np.fromfile(open(input, mode="rb"), dtype) fd.dump(v, open(output, mode="wb"), binary=True) futhark-0.25.27/tools/data2png.py000077500000000000000000000063501475065116200165630ustar00rootroot00000000000000#!/usr/bin/env python3 # # Turn a Futhark value encoded in the binary data format in some file # into a PNG image in another file. Very little error checking is # done. The following input types are supported: # # [height][width]u8 # [height][width]i32 # [height][width]u32 # # [height][width]f32 # [height][width]f64 # # [height][width][3]i8 # [height][width][3]u8 # # Requires purepng and Numpy. # # Example: # # $ cat input | ./some-futhark-program | tools/data2png.py /dev/stdin out.png import sys import struct import numpy as np import png if "purepng" not in png.__file__: print("data2png works only with purepng, not pypng.", file=sys.stderr) sys.exit(1) def read_image(f): # Skip binary header. 
assert f.read(1) == b"b", "Invalid binary header" assert f.read(1) == b"\2", "Invalid binary format" rank = np.int8(struct.unpack("> 16 image[:, :, 1] = (array & 0xFF00) >> 8 image[:, :, 2] = array & 0xFF return (width, height, np.reshape(image, (height, width * 3))) elif rank == 3 and type in [b" i8", b" u8"]: height = np.int64(struct.unpack(" text/x-futhark;application/x-futhark;text/x-futhark-source;application/x-futhark-source *.fut --

{root}

{sections} """ def task(plot_jobs: Dict[str, PlotJob]) -> None: """Begins plotting, it is used""" global plots global PLOT_TYPES_USED global TRANSPARENT plot_types = [ plot_type() # type: ignore for key, plot_type in ALL_PLOT_TYPES.items() if key in PLOT_TYPES_USED ] plotter = Plotter(plot_types, dpi=200, transparent=TRANSPARENT) plotter.plot(plot_jobs) TRANSPARENT: bool def main() -> None: global PLOT_TYPES_USED global TRANSPARENT plt.rcParams.update( { "ytick.color": "black", "xtick.color": "black", "axes.labelcolor": "black", "axes.edgecolor": "black", "axes.axisbelow": True, "text.usetex": False, "axes.prop_cycle": cycler(color=["#5f021f"]), } ) args = get_args() filename = pathlib.Path(args.filename).stem data = json.load(open(args.filename, "r")) programs = format_arg_list(args.programs) plots_used = format_arg_list(args.plots) if plots_used is None: PLOT_TYPES_USED = list(sorted(ALL_PLOT_TYPES.keys())) else: PLOT_TYPES_USED = list(sorted(plots_used)) temp = list(ALL_PLOT_TYPES.keys()) for plot_type in PLOT_TYPES_USED: if plot_type not in temp: existing_plot_types = ", ".join(temp) raise Exception( ( '"{plot_type}" is not a plot type try ' f"{existing_plot_types}" ) ) filetype = args.filetype TRANSPARENT = args.transparent root = f"{filename}-plots" if os.path.exists(root): raise Exception( ( f'The folder "{root}" must be removed before the plots can be ' "made." 
) ) if programs is None: programs = set(data.keys()) else: programs = set(programs) keys = set(data.keys()) if not programs.issubset(keys): diff = ", ".join(programs.difference(keys)) raise Exception(f'"{diff}" are not valid keys.') plot_jobs, folder_content = make_plot_jobs_and_directories( list(programs), data, filetype, PLOT_TYPES_USED, root=root ) with open(f"{filename}.html", "w") as fp: fp.write(make_html(folder_content, plot_jobs, root)) with Pool(16) as p: p.map(task, chunks(plot_jobs, max(len(plot_jobs) // 32, 1))) print(f"Open {filename}.html in a browser.") if __name__ == "__main__": main() futhark-0.25.27/tools/png2data.py000077500000000000000000000021551475065116200165620ustar00rootroot00000000000000#!/usr/bin/env python3 # # Turn a PNG image into a Futhark value of type [height][width]i32 encoded in # the binary data format. # Usage: ./png2data image.png output.data # # Absolutely no error checking is done. # # png2data currently supports 24-bit RGB PNGs. Transparency, grayscale, # multiple palettes and most other features of PNG is not supported. 
import sys import numpy as np import png if __name__ == "__main__": infile = sys.argv[1] outfile = sys.argv[2] r = png.Reader(infile) (width, height, img, _) = r.read() image_2d = np.vstack(list(map(np.uint32, img))) image_3d = np.reshape(image_2d, (height, width, 3)) array = np.empty((height, width), dtype=np.int32) array = array | (image_3d[:, :, 0] << 16) array = array | (image_3d[:, :, 1] << 8) array = array | (image_3d[:, :, 2]) with open(outfile, "wb") as f: f.write(b"b") f.write(np.int8(2)) # type: ignore f.write(np.int8(2)) # type: ignore f.write(b" i32") f.write(np.uint64(height)) # type: ignore f.write(np.uint64(width)) # type: ignore array.tofile(f) futhark-0.25.27/tools/release/000077500000000000000000000000001475065116200161225ustar00rootroot00000000000000futhark-0.25.27/tools/release/README.md000066400000000000000000000042601475065116200174030ustar00rootroot00000000000000Release Engineering =================== This repository contains hacks, scripts and tooling for building binary releases of Futhark. For a binary release on a POSIXy system, the directory `skeleton` is populated with compiled binaries and manpages and turned into a tarball. Release Procedure ----------------- When making a release, we not only make a binary release, we also make a source release based on a specific Git commit. A source release corresponds exactly to the state of the Git repository at some point. * Decide on a version number X.Y.Z. Remember that only unreleased versions have Z=0. * Find a commit that would make for a good release. Make sure it is at least minimally functional. * Verify that `CHANGELOG.md` is updated, and that the most recent entries refer to the correct version number. * Verify the version number in `futhark.cabal`. * Make sure you've committed any changes you may have made due to the previous two steps. * Run `git tag vX.Y.Z`. * Push the tag: `git push origin vX.Y.Z`. This counts as a release on GitHub. 
* Wait for GitHub Actions to create the new release. You're done! Congratulations! Increment the version number in `futhark.cabal` (such that Z=0), make room for a new release in `CHANGELOG.md`, and go hack some more on the compiler. The following steps are for making the release available elsewhere. Some of them are supposed to be automatic. * **This is done automatically by a CI job**: Run `tools/release/binary-tarball.sh . -X.Y.Z-linux-x86_64`. This produces `futhark-X.Y.Z-linux-x86_64.xz`. Put this tarball in some public location and make sure its permissions make it readable. * **This is done automatically by a CI job**: Run `tools/release/hackage.sh`. * **This is done automatically by a CI job**: Go to `https://github.com/diku-dk/futhark/releases` and copy release notes from `CHANGELOG.md`. * Update the Homebrew formula with `brew bump-formula-pr --url=https://github.com/diku-dk/futhark/archive/vX.Y.Z.tar.gz futhark --verbose`. This may take significant previous setup. In practice someone else is running a script that automatically bumps formulae every few hours. futhark-0.25.27/tools/release/binary-deb.sh000077500000000000000000000020761475065116200205020ustar00rootroot00000000000000#!/bin/sh set -e if [ $# != 2 ]; then echo "Use: $0 " exit 1 fi repodir=$1 version=$2 skeletondir=$repodir/tools/release/deb-skeleton tmpdir=$(mktemp -d) debdir=$tmpdir/futhark-$version deb=futhark-$version.deb if ! [ -d "$tmpdir" ]; then echo "Failed to create temporary directory." exit 1 fi inrepo() { (cd $repodir; "$@") } commit=$(inrepo git describe --dirty=-modified --always) if echo "$commit" | grep -q modified; then echo "Refusing to package a modified repository." exit 1 fi inrepo stack build inrepo make -C docs man binpath=$(inrepo stack path --local-install-root)/bin umask 000 # dpkg-deb is picky about permissions. 
cp -r $skeletondir $debdir mkdir -p $debdir/usr/bin install $binpath/* $debdir/usr/bin/ mkdir -p $debdir/usr/share/man/man1/ install -D -m 644 $repodir/docs/_build/man/*.1 $debdir/usr/share/man/man1/ sed s/VERSION/$version/ -i $debdir/DEBIAN/control echo "Building .deb in $tmpdir..." (cd $tmpdir && dpkg-deb --build futhark-$version) mv $tmpdir/$deb . rm -rf "$debdir" echo "Created $deb." futhark-0.25.27/tools/release/binary-tarball.sh000077500000000000000000000022721475065116200213670ustar00rootroot00000000000000#!/bin/sh set -e if [ $# -gt 2 ]; then echo "Usage: $0 [suffix]" exit 1 fi repodir=$1 suffix=$2 if [ $# -lt 2 ]; then suffix=-$(uname -s)-$(uname -m)-$(date +%Y-%m-%d) echo "Defaulting suffix to $suffix." fi skeletondir=$repodir/tools/release/skeleton tmpdir=$(mktemp -d) tarballdir=$tmpdir/futhark$suffix tarball=futhark$suffix.tar.xz if ! [ -d "$tmpdir" ]; then echo "Failed to create temporary directory." exit 1 fi inrepo() { (cd $repodir; "$@") } commit=$(inrepo git describe --dirty=-modified --always) if echo "$commit" | grep -q modified; then echo "Refusing to package a modified repository." exit 1 fi inrepo stack build inrepo make -C docs man bins=$(inrepo stack path --local-install-root)/bin/futhark* cp -r $skeletondir $tarballdir echo "$commit" > $tarballdir/commit-id mkdir -p $tarballdir/bin cp -r $bins $tarballdir/bin cp $repodir/LICENSE $tarballdir/LICENSE mkdir -p $tarballdir/share/man/man1/ cp -r $repodir/docs/_build/man/*.1 $tarballdir/share/man/man1/ echo "Building tarball in $tmpdir..." (cd $tmpdir; tar -Jcf $tarball futhark$suffix) mv $tmpdir/$tarball . rm -rf "$tarballdir" echo "Created $tarball." 
futhark-0.25.27/tools/release/deb-skeleton/000077500000000000000000000000001475065116200204765ustar00rootroot00000000000000futhark-0.25.27/tools/release/deb-skeleton/DEBIAN/000077500000000000000000000000001475065116200214205ustar00rootroot00000000000000futhark-0.25.27/tools/release/deb-skeleton/DEBIAN/control000077500000000000000000000012111475065116200230210ustar00rootroot00000000000000Package: futhark Version: VERSION Section: devel Maintainer: Troels Henriksen Architecture: amd64 Priority: optional Depends: libgmp10, zlib1g, libtinfo5 Homepage: https://futhark-lang.org Description: Purely functional parallel array language Futhark is a small programming language designed to be compiled to efficient parallel code. It is a statically typed, data-parallel, and purely functional array language in the ML family. The compiler generates GPU code via OpenCL, although the language itself is hardware-agnostic. . This package contains various Futhark compilers, as well as some Futhark development tools. futhark-0.25.27/tools/release/hackage.sh000077500000000000000000000012501475065116200200420ustar00rootroot00000000000000#!/bin/sh # # Generate sdists and documentation for Hackage, then upload them. set -e user=TroelsHenriksen pass=$HACKAGE_KEY dir=$(mktemp -d dist.XXXXXX) trap 'rm -rf "$dir"' EXIT echo "Generating sdist..." cabal sdist --builddir="$dir" echo "Uploading sdist..." cabal upload --publish --username=$user --password=$pass $dir/sdist/*.tar.gz # See https://github.com/haskell/cabal/issues/8104 for why we have --haddock-options=--quickjump echo "Generating Haddock..." cabal v2-haddock --builddir="$dir" --haddock-for-hackage --enable-doc --haddock-options=--quickjump echo "Uploading Haddock..." 
cabal upload --publish --username=$user --password=$pass -d $dir/*-docs.tar.gz futhark-0.25.27/tools/release/skeleton/000077500000000000000000000000001475065116200177465ustar00rootroot00000000000000futhark-0.25.27/tools/release/skeleton/Makefile000066400000000000000000000013011475065116200214010ustar00rootroot00000000000000# High-performance purely functional data-parallel array programming on the GPU # See LICENSE file for copyright and license details. include config.mk BINARIES=bin/* all: @echo "This is a precompiled binary distribution of Futhark - no building necessary." @echo "But you may want to check out config.mk to ensure that 'make install' installs to the right place." install: @echo \# Installing executable files to ${PREFIX}/bin @mkdir -p ${PREFIX}/bin/ install bin/* ${PREFIX}/bin/ @echo \# Installing manual page to ${MANPREFIX}/man1 @mkdir -p ${MANPREFIX}/man1/ @echo \# Installing manpages to ${MANPREFIX}/man1/ install -D -m 644 share/man/man1/* ${MANPREFIX}/man1/ .PHONY: all install futhark-0.25.27/tools/release/skeleton/README.md000066400000000000000000000010171475065116200212240ustar00rootroot00000000000000Futhark binary distribution =========================== This is a precompiled distribution of the Futhark compiler. For more information on Futhark itself, see our [website][1] or [GitHub repository][2]. 
Installation ============ You can install the Futhark compiler and its manual pages by invoking: make install If you wish to install to a specific location, you can set the `PREFIX` environment variable: PREFIX=$HOME/.local make install [1]: https://futhark-lang.org [2]: https://github.com/diku-dk/futhark futhark-0.25.27/tools/release/skeleton/config.mk000066400000000000000000000001431475065116200215420ustar00rootroot00000000000000# Customize below to fit your system # paths PREFIX ?= /usr/local MANPREFIX = ${PREFIX}/share/man futhark-0.25.27/tools/remove-from-bench-json.py000077500000000000000000000015471475065116200213500ustar00rootroot00000000000000#!/usr/bin/env python # # Remove some benchmark from a JSON file produced by futhark-bench's --json # option. # # This is useful if we accidentally added some pointless programs that # just obscure things. import json import sys remove = sys.argv[1] files = sys.argv[2:] removed = False i = 0 for fp in files: with open(fp, "r") as infile: try: file_json = json.load(infile) except Exception as e: print("Could not read {}: {}".format(fp, e)) continue for b in file_json.keys(): if remove in b: file_json.pop(b) removed = True with open(fp, "w") as outfile: json.dump(file_json, outfile) i += 1 if i % 100 == 0: print("{}/{} files processed...".format(i, len(files))) if not removed: print("Warning: no JSON file contained {}".format(remove)) futhark-0.25.27/tools/run-formatter.sh000077500000000000000000000003031475065116200176420ustar00rootroot00000000000000#!/bin/sh # # Run ormolu on the input directories # # Example command: # # ./run-formatter.sh src unittests find "$@" -name '*.hs' -print -exec ormolu --mode inplace --check-idempotence {} \; futhark-0.25.27/tools/style-check-file.sh000077500000000000000000000050371475065116200201760ustar00rootroot00000000000000#!/usr/bin/env bash # # Perform basic style checks on a single Futhark compiler source code # file. 
If a style violation is found, this script will exit with a # nonzero exit code. Checks performed: # # 0) if Haskell: hlint (with some rules ignored). # # 1) Trailing whitespace. # # 2) A file ending in blank lines. # # 3) Tabs, anywhere. # # 4) DOS line endings (CRLF). # # 5) If Python: black and mypy. # # This script can be called on directories (in which case it applies # to every file inside), or on files. cyan=$(printf '%b' "\033[0;36m") NC=$(printf '%b' "\033[0m") if [ "$#" -ne 1 ]; then echo "Usage: $0 " exit 1 fi exit=0 hlint_check() { # Some hlint-suggestions are terrible, so ignore them here. hlint -XNoCPP -i "Use import/export shortcut" -i "Use const" -i "Use tuple-section" -i "Too strict maybe" -i "Functor law" "$1" } no_trailing_blank_lines() { awk '/^$/ { sawblank=1; next } { sawblank=0 } END { if (sawblank==1) { exit 1 } }' "$1" } file="$1" if grep -E -n " +$" "$file"; then echo echo "${cyan}Trailing whitespace in $file:${NC}" echo "$output" exit=1 fi no_tabs() { if grep -E -n "$(printf '\t')" "$1"; then echo echo "${cyan}Tab characters found in $1:${NC}" echo "$output" exit=1 fi } if file "$file" | grep -q 'CRLF line terminators'; then echo echo "${cyan}CRLF line terminators in $file.${NC}" exit=1 fi if ! no_trailing_blank_lines "$file"; then echo echo "${cyan}$file ends in several blank lines.${NC}" exit=1 fi case "$file" in *.fut) no_tabs "$file" ;; *.hs) no_tabs "$file" if ! LC_ALL=C.UTF-8 ormolu --mode check "$file"; then echo echo "${cyan}$file:${NC} is not formatted correctly with Ormolu" echo "$output" exit=1 fi output=$(hlint_check "$file") if [ $? = 1 ]; then echo echo "${cyan}$file:${NC} hlint issues" echo "$output" exit=1 fi ;; *.py) no_tabs "$file" output=$(mypy --no-error-summary "$file") if [ $? != 0 ]; then echo echo "${cyan}$file:${NC} Mypy is upset" echo "$output" exit=1 fi output=$(black --check --diff --quiet "$file") if [ $? 
!= 0 ]; then echo echo "${cyan}$file:${NC} is not formatted correctly with Black" echo "$output" exit=1 fi ;; esac exit $exit futhark-0.25.27/tools/style-check.sh000077500000000000000000000004741475065116200172610ustar00rootroot00000000000000#!/usr/bin/env bash # # Run style-check-file.sh on every file in a directory tree. set -e set -o pipefail check="$(dirname "$0")"/style-check-file.sh if [ $# -ne 0 ]; then # Running a style checker will not contribute to a scientific # publication. find "$@" -type f | parallel --will-cite "$check" fi futhark-0.25.27/tools/test-autotuner.sh000077500000000000000000000006151475065116200200460ustar00rootroot00000000000000#!/usr/bin/env bash set -o errexit set -o nounset set -o pipefail for bench in rodinia/lud/lud misc/bfast/bfast finpar/LocVolCalib accelerate/nbody/nbody do futhark autotune --backend=opencl --tuning=.tuning.test futhark-benchmarks/$bench.fut diff -u futhark-benchmarks/$bench.fut.tuning futhark-benchmarks/$bench.fut.tuning.test rm -f futhark-benchmarks/$bench.fut.tuning.test done futhark-0.25.27/tools/testfmt.sh000077500000000000000000000065641475065116200165420ustar00rootroot00000000000000#!/bin/sh THREADS=16 TEST_DIR="TEMP_FORMATTER_TEST" if [ "$TESTFMT_WORKER" ]; then shift testwith() { prog="$1" if [ -d "$prog" ]; then exit 0 fi if ! futhark check-syntax "$prog" 2> /dev/null > /dev/null; then rm "$prog" exit 0 fi name=${prog%.fut} futhark hash "$prog" 2> /dev/null > "$prog.expected" if [ ! $? -eq 0 ]; then rm "$prog" "$prog.expected" exit 0 fi futhark fmt < "$prog" 2> /dev/null > "$name.fmt.fut" futhark fmt < "$name.fmt.fut" 2> /dev/null > "$name.fmt.fmt.fut" futhark hash "$name.fmt.fut" 2> /dev/null > "$prog.actual" futhark tokens "$prog" 2> /dev/null | grep '^COMMENT' > "$name.comments" futhark tokens "$name.fmt.fut" 2> /dev/null | grep '^COMMENT' > "$name.fmt.comments" hash_result=1 idempotent_result=1 comments_result=1 if ! 
cmp --silent "$prog.expected" "$prog.actual" then hash_result=0 echo "Failed Hash Comparison Test" >> "$name.log" fi printf "$hash_result" >> "$name.result" if ! cmp --silent "$name.fmt.fut" "$name.fmt.fmt.fut" then idempotent_result=0 echo "Failed Idempotent Comparison Test" >> "$name.log" fi printf "$idempotent_result" >> "$name.result" if ! cmp --silent "$name.comments" "$name.fmt.comments" then comments_result=0 echo "Failed Order of Comments Test" >> "$name.log" fi printf "$comments_result" >> "$name.result" } for f in "$@"; do testwith "$f" done else if [ $# != 1 ]; then echo "Usage: $0 " exit 1 fi rm -rf "$TEST_DIR" && mkdir "$TEST_DIR" find "$1" -name '*.fut' -exec cp --parents \{\} "$TEST_DIR" \; export TESTFMT_WORKER=1 find "$TEST_DIR" -name '*.fut' -print0 | xargs -0 -n 1 -P $THREADS "$0" -rec idempotent_pass=0 idempotent_fail=0 hash_pass=0 hash_fail=0 comments_pass=0 comments_fail=0 for file in $(find "$TEST_DIR" -name '*.result'); do if ! [ -d $file ]; then content=$(cat "$file") hash_result=$(printf "$content" | cut -c1) idempotent_result=$(printf "$content" | cut -c2) comments_result=$(printf "$content" | cut -c3) hash_pass=$((hash_pass + hash_result)) hash_fail=$((hash_fail + 1 - hash_result)) idempotent_pass=$((idempotent_pass + idempotent_result)) idempotent_fail=$((idempotent_fail + 1 - idempotent_result)) comments_pass=$((comments_pass + comments_result)) comments_faul=$((comments_pass + 1 - comments_result)) if [ "$hash_result" -eq 1 ] && [ "$idempotent_result" -eq 1 ] && [ "$comments_result" -eq 1 ]; then rm "${file%.*}."* fi fi done find "$TEST_DIR" -type d -empty -delete echo "Hash Tests Passed: $hash_pass/$((hash_pass + hash_fail))" echo "Idempotent Tests Passed: $idempotent_pass/$((idempotent_pass + idempotent_fail))" echo "Order of Comments Tests Passed: $comments_pass/$((comments_pass + comments_fail))" if [ $hash_fail -eq 0 ] && [ $idempotent_fail -eq 0 ] && [ $comments_fail -eq 0 ]; then exit 0; else exit 1; fi fi 
futhark-0.25.27/tools/testparser.sh000077500000000000000000000022431475065116200172360ustar00rootroot00000000000000#!/bin/sh # # This script checks that the IR of various Futhark compiler pipelines # can be parsed back in. It does not do any correctness verification. # # To make the xargs hack work, you must run it as './testparser.sh' # (or with some other path), *not* by passing it to 'sh'. set -e THREADS=16 dir=testparser mkdir -p $dir if [ "$TESTPARSER_WORKER" ]; then shift testwith() { f=$1 suffix=$2 shift; shift out=$dir/${f}_${suffix} mkdir -p "$(dirname "$out")" if ! ( futhark dev -w "$@" "$f" > "$out" && futhark dev "$out" >/dev/null); then echo "^- $f $*" exit 1 fi } for f in "$@"; do if futhark check "$f" 2>/dev/null && ! (grep -F 'tags { disable }' -q "$f") ; then testwith "$f" soacs -s testwith "$f" mc -s --extract-multicore testwith "$f" gpu --gpu testwith "$f" mc_mem --mc-mem if ! grep -q no_opencl "$f"; then testwith "$f" gpu_mem --gpu-mem fi fi done else export TESTPARSER_WORKER=1 find "$@" -name \*.fut -print0 | xargs -0 -n 1 -P $THREADS "$0" -rec fi futhark-0.25.27/unittests/000077500000000000000000000000001475065116200154045ustar00rootroot00000000000000futhark-0.25.27/unittests/Futhark/000077500000000000000000000000001475065116200170105ustar00rootroot00000000000000futhark-0.25.27/unittests/Futhark/AD/000077500000000000000000000000001475065116200172745ustar00rootroot00000000000000futhark-0.25.27/unittests/Futhark/AD/DerivativesTests.hs000066400000000000000000000027601475065116200231450ustar00rootroot00000000000000module Futhark.AD.DerivativesTests (tests) where import Data.Map qualified as M import Data.Text qualified as T import Futhark.AD.Derivatives import Futhark.Analysis.PrimExp import Futhark.IR.Syntax.Core (nameFromText) import Futhark.Util.Pretty (prettyString) import Test.Tasty import Test.Tasty.HUnit tests :: TestTree tests = testGroup "Futhark.AD.DerivativesTests" [ testGroup "Primitive functions" $ map primFunTest $ filter (not . 
(`elem` missing_primfuns) . fst) $ M.toList primFuns, testGroup "BinOps" $ map binOpTest allBinOps, testGroup "UnOps" $ map unOpTest allUnOps ] where blank = ValueExp . blankPrimValue primFunTest (f, (ts, ret, _)) = testCase (T.unpack f) $ case pdBuiltin (nameFromText f) (map blank ts) of Nothing -> assertFailure "pdBuiltin gives Nothing" Just v -> map primExpType v @?= replicate (length ts) ret -- We know we have no derivatives for these... and they are not -- coming any time soon. missing_primfuns = [ "gamma16", "gamma32", "gamma64", "lgamma16", "lgamma32", "lgamma64" ] binOpTest bop = testCase (prettyString bop) $ let t = binOpType bop (dx, dy) = pdBinOp bop (blank t) (blank t) in (primExpType dx, primExpType dy) @?= (t, t) unOpTest bop = testCase (prettyString bop) $ let t = unOpType bop in primExpType (pdUnOp bop $ blank t) @?= t futhark-0.25.27/unittests/Futhark/Analysis/000077500000000000000000000000001475065116200205735ustar00rootroot00000000000000futhark-0.25.27/unittests/Futhark/Analysis/AlgSimplifyTests.hs000066400000000000000000000051261475065116200243760ustar00rootroot00000000000000{-# LANGUAGE FlexibleInstances #-} {-# OPTIONS_GHC -fno-warn-orphans #-} {-# OPTIONS_GHC -fno-warn-unused-imports #-} {-# OPTIONS_GHC -fno-warn-unused-matches #-} {-# OPTIONS_GHC -fno-warn-unused-top-binds #-} module Futhark.Analysis.AlgSimplifyTests ( tests, ) where import Control.Monad import Data.Function ((&)) import Data.List (subsequences) import Data.Map qualified as M import Data.Maybe (fromMaybe, mapMaybe) import Futhark.Analysis.AlgSimplify hiding (add, sub) import Futhark.Analysis.PrimExp import Futhark.IR.Syntax.Core import Test.Tasty import Test.Tasty.HUnit import Test.Tasty.QuickCheck tests :: TestTree tests = testGroup "AlgSimplifyTests" [ testProperty "simplify is idempotent" $ \(TestableExp e) -> simplify e == simplify (simplify e), testProperty "simplify doesn't change exp evalutation result" $ \(TestableExp e) -> evalPrimExp (\_ -> Nothing) e == 
evalPrimExp (\_ -> Nothing) (simplify e) ] eval :: TestableExp -> Int64 eval (TestableExp e) = evalExp e evalExp :: PrimExp VName -> Int64 evalExp (ValueExp (IntValue (Int64Value i))) = i evalExp (BinOpExp (Add Int64 OverflowUndef) e1 e2) = evalExp e1 + evalExp e2 evalExp (BinOpExp (Sub Int64 OverflowUndef) e1 e2) = evalExp e1 - evalExp e2 evalExp (BinOpExp (Mul Int64 OverflowUndef) e1 e2) = evalExp e1 * evalExp e2 evalExp _ = undefined add :: PrimExp VName -> PrimExp VName -> PrimExp VName add = BinOpExp (Add Int64 OverflowUndef) sub :: PrimExp VName -> PrimExp VName -> PrimExp VName sub = BinOpExp (Sub Int64 OverflowUndef) mul :: PrimExp VName -> PrimExp VName -> PrimExp VName mul = BinOpExp (Mul Int64 OverflowUndef) neg :: PrimExp VName -> PrimExp VName neg = BinOpExp (Sub Int64 OverflowUndef) (val 0) l :: Int -> PrimExp VName l i = LeafExp (VName (nameFromString $ show i) i) (IntType Int64) val :: Int64 -> PrimExp VName val = ValueExp . IntValue . Int64Value generateExp :: Gen (PrimExp VName) generateExp = do n <- getSize if n <= 1 then val <$> arbitrary else oneof [ scale (`div` 2) $ generateBinOp add, scale (`div` 2) $ generateBinOp sub, scale (`div` 2) $ generateBinOp mul, scale (`div` 2) generateNeg, val <$> arbitrary ] generateBinOp :: (PrimExp VName -> PrimExp VName -> PrimExp VName) -> Gen (PrimExp VName) generateBinOp op = do t1 <- generateExp op t1 <$> generateExp generateNeg :: Gen (PrimExp VName) generateNeg = do neg <$> generateExp newtype TestableExp = TestableExp (PrimExp VName) deriving (Show) instance Arbitrary TestableExp where arbitrary = TestableExp <$> generateExp futhark-0.25.27/unittests/Futhark/Analysis/PrimExp/000077500000000000000000000000001475065116200221575ustar00rootroot00000000000000futhark-0.25.27/unittests/Futhark/Analysis/PrimExp/TableTests.hs000066400000000000000000000252731475065116200245760ustar00rootroot00000000000000module Futhark.Analysis.PrimExp.TableTests (tests) where import Control.Monad.State.Strict import 
Data.Map.Strict qualified as M import Futhark.Analysis.PrimExp import Futhark.Analysis.PrimExp.Table import Futhark.IR.GPU import Futhark.IR.GPUTests () import Futhark.IR.MC import Futhark.IR.MCTests () import Test.Tasty import Test.Tasty.HUnit tests :: TestTree tests = testGroup "AnalyzePrim" [stmToPrimExpsTests] stmToPrimExpsTests :: TestTree stmToPrimExpsTests = testGroup "stmToPrimExps" [stmToPrimExpsTestsGPU, stmToPrimExpsTestsMC] stmToPrimExpsTestsGPU :: TestTree stmToPrimExpsTestsGPU = testGroup "GPU" $ do let scope = M.fromList [ ("n_5142", FParamName "i64"), ("m_5143", FParamName "i64"), ("xss_5144", FParamName "[n_5142][m_5143]i64"), ("segmap_group_size_5201", LetName "i64"), ("segmap_usable_groups_5202", LetName "i64"), ("defunc_0_map_res_5203", LetName "[n_5142]i64"), ("defunc_0_f_res_5207", LetName "i64"), ("i_5208", IndexName Int64), ("acc_5209", FParamName "i64"), ("b_5210", LetName "i64"), ("defunc_0_f_res_5211", LetName "i64") ] [ testCase "BinOp" $ do let stm = "let {defunc_0_f_res_5211 : i64} = add64(acc_5209, b_5210)" let res = execState (stmToPrimExps scope stm) mempty let expected = M.fromList [ ( "defunc_0_f_res_5211", Just ( BinOpExp (Add Int64 OverflowWrap) (LeafExp "acc_5209" (IntType Int64)) (LeafExp "b_5210" (IntType Int64)) ) ) ] res @?= expected, testCase "Index" $ do let stm = "let {b_5210 : i64} = xss_5144[gtid_5204, i_5208]" let res = execState (stmToPrimExps scope stm) mempty let expected = M.fromList [("b_5210", Nothing)] res @?= expected, testCase "Loop" $ do let stm = "let {defunc_0_f_res_5207 : i64} = loop {acc_5209 : i64} = {0i64} for i_5208:i64 < m_5143 do { {defunc_0_f_res_5211} }" let res = execState (stmToPrimExps scope stm) mempty let expected = M.fromList [ ("defunc_0_f_res_5207", Nothing), ("i_5208", Just (LeafExp "i_5208" (IntType Int64))), ("acc_5209", Just (LeafExp "acc_5209" (IntType Int64))) ] res @?= expected, testCase "Loop body" $ do let stm = "let {defunc_0_f_res_5207 : i64} = loop {acc_5209 : i64} = {0i64} for 
i_5208:i64 < m_5143 do { let {b_5210 : i64} = xss_5144[gtid_5204, i_5208] let {defunc_0_f_res_5211 : i64} = add64(acc_5209, b_5210) in {defunc_0_f_res_5211} }" let res = execState (stmToPrimExps scope stm) mempty let expected = M.fromList [ ("defunc_0_f_res_5207", Nothing), ("i_5208", Just (LeafExp "i_5208" (IntType Int64))), ("acc_5209", Just (LeafExp "acc_5209" (IntType Int64))), ("b_5210", Nothing), ( "defunc_0_f_res_5211", Just ( BinOpExp (Add Int64 OverflowWrap) (LeafExp "acc_5209" (IntType Int64)) (LeafExp "b_5210" (IntType Int64)) ) ) ] res @?= expected, testCase "SegMap" $ do let stm = "let {defunc_0_map_res_5125 : [n_5142]i64} =\ \ segmap(thread; ; grid=segmap_usable_groups_5124; blocksize=segmap_group_size_5123)\ \ (gtid_5126 < n_5142) (~phys_tid_5127) : {i64} {\ \ return {returns lifted_lambda_res_5129} \ \}" let res = execState (stmToPrimExps scope stm) mempty let expected = M.fromList [ ("defunc_0_map_res_5125", Nothing), ("gtid_5126", Just (LeafExp "gtid_5126" (IntType Int64))) ] res @?= expected, testCase "SegMap body" $ do let stm :: Stm GPU stm = "let {defunc_0_map_res_5125 : [n_5142]i64} =\ \ segmap(thread; ; grid=segmap_usable_groups_5124; blocksize=segmap_group_size_5123)\ \ (gtid_5126 < n_5142) (~phys_tid_5127) : {i64} {\ \ let {eta_p_5128 : i64} =\ \ xs_5093[gtid_5126]\ \ let {lifted_lambda_res_5129 : i64} =\ \ add64(2i64, eta_p_5128)\ \ return {returns lifted_lambda_res_5129}\ \ }" let res = execState (stmToPrimExps scope stm) mempty let expected = M.fromList [ ("defunc_0_map_res_5125", Nothing), ("gtid_5126", Just (LeafExp "gtid_5126" (IntType Int64))), ("eta_p_5128", Nothing), ( "lifted_lambda_res_5129", Just ( BinOpExp (Add Int64 OverflowWrap) (ValueExp (IntValue (Int64Value 2))) (LeafExp "eta_p_5128" (IntType Int64)) ) ) ] res @?= expected ] stmToPrimExpsTestsMC :: TestTree stmToPrimExpsTestsMC = testGroup "MC" $ do let scope = M.fromList [ ("n_5142", FParamName "i64"), ("m_5143", FParamName "i64"), ("xss_5144", FParamName 
"[n_5142][5143]i64"), ("segmap_group_size_5201", LetName "i64"), ("segmap_usable_groups_5202", LetName "i64"), ("defunc_0_map_res_5203", LetName "[n_5142]i64"), ("defunc_0_f_res_5207", LetName "i64"), ("i_5208", IndexName Int64), ("acc_5209", FParamName "i64"), ("b_5210", LetName "i64"), ("defunc_0_f_res_5211", LetName "i64") ] [ testCase "BinOp" $ do let stm = "let {defunc_0_f_res_5211 : i64} = add64(acc_5209, b_5210)" let res = execState (stmToPrimExps scope stm) mempty let expected = M.fromList [ ( "defunc_0_f_res_5211", Just ( BinOpExp (Add Int64 OverflowWrap) (LeafExp "acc_5209" (IntType Int64)) (LeafExp "b_5210" (IntType Int64)) ) ) ] res @?= expected, testCase "Index" $ do let stm = "let {b_5210 : i64} = xss_5144[gtid_5204, i_5208]" let res = execState (stmToPrimExps scope stm) mempty let expected = M.fromList [("b_5210", Nothing)] res @?= expected, testCase "Loop" $ do let stm = "let {defunc_0_f_res_5207 : i64} = loop {acc_5209 : i64} = {0i64} for i_5208:i64 < m_5143 do { {defunc_0_f_res_5211} }" let res = execState (stmToPrimExps scope stm) mempty let expected = M.fromList [ ("defunc_0_f_res_5207", Nothing), ("i_5208", Just (LeafExp "i_5208" (IntType Int64))), ("acc_5209", Just (LeafExp "acc_5209" (IntType Int64))) ] res @?= expected, testCase "Loop body" $ do let stm = "\ \let {defunc_0_f_res_5207 : i64} =\ \ loop {acc_5209 : i64} = {0i64}\ \ for i_5208:i64 < m_5143 do {\ \ let {b_5210 : i64} =\ \ xss_5144[gtid_5204, i_5208]\ \ let {defunc_0_f_res_5211 : i64} =\ \ add64(acc_5209, b_5210)\ \ in {defunc_0_f_res_5211}\ \ }" let res = execState (stmToPrimExps scope stm) mempty let expected = M.fromList [ ("defunc_0_f_res_5207", Nothing), ("i_5208", Just (LeafExp "i_5208" (IntType Int64))), ("acc_5209", Just (LeafExp "acc_5209" (IntType Int64))), ("b_5210", Nothing), ( "defunc_0_f_res_5211", Just ( BinOpExp (Add Int64 OverflowWrap) (LeafExp "acc_5209" (IntType Int64)) (LeafExp "b_5210" (IntType Int64)) ) ) ] res @?= expected, testCase "SegMap" $ do let stm = 
"let {defunc_0_map_res_5125 : [n_5142]i64} =\ \ segmap()\ \ (gtid_5126 < n_5142) (~flat_tid_5112) : {i64} {\ \ return {returns lifted_lambda_res_5129}\ \ }" let res = execState (stmToPrimExps scope stm) mempty let expected = M.fromList [ ("defunc_0_map_res_5125", Nothing), ("gtid_5126", Just (LeafExp "gtid_5126" (IntType Int64))) ] res @?= expected, testCase "SegMap body" $ do let stm :: Stm MC stm = "let {defunc_0_map_res_5125 : [n_5142]i64} =\ \ segmap()\ \ (gtid_5126 < n_5142) (~flat_tid_5112) : {i64} {\ \ let {eta_p_5128 : i64} =\ \ xs_5093[gtid_5126]\ \ let {lifted_lambda_res_5129 : i64} =\ \ add64(2i64, eta_p_5128)\ \ return {returns lifted_lambda_res_5129}\ \ }" let res = execState (stmToPrimExps scope stm) mempty let expected = M.fromList [ ("defunc_0_map_res_5125", Nothing), ("gtid_5126", Just (LeafExp "gtid_5126" (IntType Int64))), ("eta_p_5128", Nothing), ( "lifted_lambda_res_5129", Just ( BinOpExp (Add Int64 OverflowWrap) (ValueExp (IntValue (Int64Value 2))) (LeafExp "eta_p_5128" (IntType Int64)) ) ) ] res @?= expected ] futhark-0.25.27/unittests/Futhark/BenchTests.hs000066400000000000000000000025521475065116200214120ustar00rootroot00000000000000{-# OPTIONS_GHC -fno-warn-orphans #-} module Futhark.BenchTests (tests) where import Data.Map qualified as M import Data.Text qualified as T import Futhark.Bench import Futhark.ProfileTests () import Test.Tasty import Test.Tasty.QuickCheck instance Arbitrary RunResult where arbitrary = RunResult . getPositive <$> arbitrary printable :: Gen String printable = getASCIIString <$> arbitrary instance Arbitrary DataResult where arbitrary = DataResult <$> (T.pack <$> printable) <*> oneof [ Left <$> arbText, Right <$> ( Result <$> arbitrary <*> arbMap <*> oneof [pure Nothing, Just <$> arbText] <*> arbitrary ) ] where arbText = T.pack <$> printable arbMap = M.fromList <$> listOf ((,) <$> arbText <*> arbitrary) -- XXX: we restrict this generator to single datasets to we don't have -- to worry about duplicates. 
instance Arbitrary BenchResult where arbitrary = BenchResult <$> printable <*> (pure <$> arbitrary) encodeDecodeJSON :: TestTree encodeDecodeJSON = testProperty "encoding and decoding are inverse" prop where prop :: BenchResult -> Bool prop brs = decodeBenchResults (encodeBenchResults [brs]) == Right [brs] tests :: TestTree tests = testGroup "Futhark.BenchTests" [encodeDecodeJSON] futhark-0.25.27/unittests/Futhark/IR/000077500000000000000000000000001475065116200173225ustar00rootroot00000000000000futhark-0.25.27/unittests/Futhark/IR/GPUTests.hs000066400000000000000000000010251475065116200213320ustar00rootroot00000000000000{-# OPTIONS_GHC -fno-warn-orphans #-} module Futhark.IR.GPUTests () where import Data.String import Futhark.IR.GPU import Futhark.IR.Parse import Futhark.IR.SyntaxTests (parseString) -- There isn't anything to test in this module, but we define some -- convenience instances. instance IsString (Stm GPU) where fromString = parseString "Stm GPU" parseStmGPU instance IsString (Body GPU) where fromString = parseString "Body GPU" parseBodyGPU instance IsString (Prog GPU) where fromString = parseString "Prog GPU" parseGPU futhark-0.25.27/unittests/Futhark/IR/MCTests.hs000066400000000000000000000010121475065116200211720ustar00rootroot00000000000000{-# OPTIONS_GHC -fno-warn-orphans #-} module Futhark.IR.MCTests () where import Data.String import Futhark.IR.MC import Futhark.IR.Parse import Futhark.IR.SyntaxTests (parseString) -- There isn't anything to test in this module, but we define some -- convenience instances. 
instance IsString (Stm MC) where fromString = parseString "Stm MC" parseStmMC instance IsString (Body MC) where fromString = parseString "Body MC" parseBodyMC instance IsString (Prog MC) where fromString = parseString "Prog MC" parseMC futhark-0.25.27/unittests/Futhark/IR/Mem/000077500000000000000000000000001475065116200200405ustar00rootroot00000000000000futhark-0.25.27/unittests/Futhark/IR/Mem/IntervalTests.hs000066400000000000000000000042071475065116200232060ustar00rootroot00000000000000module Futhark.IR.Mem.IntervalTests ( tests, ) where import Futhark.Analysis.AlgSimplify import Futhark.Analysis.PrimExp.Convert import Futhark.IR.Mem.Interval import Futhark.IR.Syntax import Futhark.IR.Syntax.Core () import Test.Tasty import Test.Tasty.HUnit -- Actual tests. tests :: TestTree tests = testGroup "IntervalTests" testDistributeOffset name :: String -> Int -> VName name s = VName (nameFromString s) testDistributeOffset :: [TestTree] testDistributeOffset = [ testCase "Stride is (nb-b)" $ do let n = TPrimExp $ LeafExp (name "n" 1) $ IntType Int64 b = TPrimExp $ LeafExp (name "b" 2) $ IntType Int64 res <- distributeOffset [Prod False [untyped (n * b - b :: TPrimExp Int64 VName)]] [ Interval 0 1 (n * b - b), Interval 0 b b, Interval 0 b 1 ] res == [Interval 1 1 (n * b - b), Interval 0 b b, Interval 0 b 1] @? "Failed", testCase "Stride is 1024r" $ do let r = TPrimExp $ LeafExp (name "r" 1) $ IntType Int64 res <- distributeOffset [Prod False [untyped (1024 :: TPrimExp Int64 VName), untyped r]] [ Interval 0 1 (1024 * r), Interval 0 32 32, Interval 0 32 1 ] res == [Interval 1 1 (1024 * r), Interval 0 32 32, Interval 0 32 1] @? "Failed. 
Got " <> show res, testCase "Stride is 32, offsets are multples of 32" $ do let n = TPrimExp $ LeafExp (name "n" 0) $ IntType Int64 let g1 = TPrimExp $ LeafExp (name "g" 1) $ IntType Int64 let g2 = TPrimExp $ LeafExp (name "g" 2) $ IntType Int64 res <- distributeOffset [ Prod False [untyped (1024 :: TPrimExp Int64 VName)], Prod False [untyped (1024 :: TPrimExp Int64 VName), untyped g1], Prod False [untyped (32 :: TPrimExp Int64 VName), untyped g2] ] [ Interval 0 1 (1024 * n), Interval 0 1 32, Interval 0 32 1 ] res == [ Interval 0 1 (1024 * n), Interval (32 + 32 * g1 + g2) 1 32, Interval 0 32 1 ] @? "Failed. Got " <> show res ] futhark-0.25.27/unittests/Futhark/IR/Mem/IxFun/000077500000000000000000000000001475065116200210715ustar00rootroot00000000000000futhark-0.25.27/unittests/Futhark/IR/Mem/IxFun/Alg.hs000066400000000000000000000107151475065116200221340ustar00rootroot00000000000000-- | A simple index operation representation. Every operation corresponds to a -- constructor. module Futhark.IR.Mem.IxFun.Alg ( IxFun (..), iota, offsetIndex, permute, reshape, coerce, slice, flatSlice, expand, shape, index, disjoint, ) where import Data.List qualified as L import Data.Set qualified as S import Futhark.IR.Pretty () import Futhark.IR.Prop import Futhark.IR.Syntax ( DimIndex (..), FlatDimIndex (..), FlatSlice (..), Slice (..), flatSliceDims, sliceDims, unitSlice, ) import Futhark.Util.IntegralExp import Futhark.Util.Pretty import Prelude hiding (div, mod, span) type Shape num = [num] type Indices num = [num] type Permutation = [Int] data IxFun num = Direct (Shape num) | Permute (IxFun num) Permutation | Index (IxFun num) (Slice num) | FlatIndex (IxFun num) (FlatSlice num) | Reshape (IxFun num) (Shape num) | Coerce (IxFun num) (Shape num) | OffsetIndex (IxFun num) num | Expand num num (IxFun num) deriving (Eq, Show) instance (Pretty num) => Pretty (IxFun num) where pretty (Direct dims) = "Direct" <> parens (commasep $ map pretty dims) pretty (Permute fun perm) = pretty fun 
<> pretty perm pretty (Index fun is) = pretty fun <> pretty is pretty (FlatIndex fun is) = pretty fun <> pretty is pretty (Reshape fun oldshape) = pretty fun <> "->reshape" <> parens (pretty oldshape) pretty (Coerce fun oldshape) = pretty fun <> "->coerce" <> parens (pretty oldshape) pretty (OffsetIndex fun i) = pretty fun <> "->offset_index" <> parens (pretty i) pretty (Expand o p fun) = "expand(" <> pretty o <> "," <+> pretty p <> "," <+> pretty fun <> ")" iota :: Shape num -> IxFun num iota = Direct offsetIndex :: IxFun num -> num -> IxFun num offsetIndex = OffsetIndex permute :: IxFun num -> Permutation -> IxFun num permute = Permute slice :: IxFun num -> Slice num -> IxFun num slice = Index flatSlice :: IxFun num -> FlatSlice num -> IxFun num flatSlice = FlatIndex expand :: num -> num -> IxFun num -> IxFun num expand = Expand reshape :: IxFun num -> Shape num -> IxFun num reshape = Reshape coerce :: IxFun num -> Shape num -> IxFun num coerce = Reshape shape :: (IntegralExp num) => IxFun num -> Shape num shape (Direct dims) = dims shape (Permute ixfun perm) = rearrangeShape perm $ shape ixfun shape (Index _ how) = sliceDims how shape (FlatIndex ixfun how) = flatSliceDims how <> tail (shape ixfun) shape (Reshape _ dims) = dims shape (Coerce _ dims) = dims shape (OffsetIndex ixfun _) = shape ixfun shape (Expand _ _ ixfun) = shape ixfun index :: (Eq num, IntegralExp num) => IxFun num -> Indices num -> num index (Direct dims) is = sum $ zipWith (*) is slicesizes where slicesizes = drop 1 $ sliceSizes dims index (Permute fun perm) is_new = index fun is_old where is_old = rearrangeShape (rearrangeInverse perm) is_new index (Index fun (Slice js)) is = index fun (adjust js is) where adjust (DimFix j : js') is' = j : adjust js' is' adjust (DimSlice j _ s : js') (i : is') = j + i * s : adjust js' is' adjust _ _ = [] index (FlatIndex fun (FlatSlice offset js)) is = index fun $ sum (offset : zipWith f is js) : drop (length js) is where f i (FlatDimIndex _ s) = i * s index 
(Reshape fun newshape) is = let new_indices = reshapeIndex (shape fun) newshape is in index fun new_indices index (Coerce fun _) is = index fun is index (OffsetIndex fun i) is = case shape fun of d : ds -> index (Index fun (Slice (DimSlice i (d - i) 1 : map (unitSlice 0) ds))) is [] -> error "index: OffsetIndex: underlying index function has rank zero" index (Expand o p ixfun) is = o + p * index ixfun is allPoints :: (IntegralExp num, Enum num) => [num] -> [[num]] allPoints dims = let total = product dims strides = drop 1 $ L.reverse $ scanl (*) 1 $ L.reverse dims in map (unflatInd strides) [0 .. total - 1] where unflatInd strides x = fst $ foldl ( \(res, acc) span -> (res ++ [acc `div` span], acc `mod` span) ) ([], x) strides disjoint :: (IntegralExp num, Ord num, Enum num) => IxFun num -> IxFun num -> Bool disjoint ixf1 ixf2 = let shp1 = shape ixf1 points1 = S.fromList $ allPoints shp1 allIdxs1 = S.map (index ixf1) points1 shp2 = shape ixf2 points2 = S.fromList $ allPoints shp2 allIdxs2 = S.map (index ixf2) points2 in S.disjoint allIdxs1 allIdxs2 futhark-0.25.27/unittests/Futhark/IR/Mem/IxFunTests.hs000066400000000000000000000514561475065116200224630ustar00rootroot00000000000000{-# OPTIONS_GHC -fno-warn-orphans #-} module Futhark.IR.Mem.IxFunTests ( tests, ) where import Data.Bifunctor import Data.Function ((&)) import Data.List qualified as L import Data.Map qualified as M import Data.Text qualified as T import Futhark.Analysis.PrimExp.Convert import Futhark.IR.Mem.IxFun.Alg qualified as IxFunAlg import Futhark.IR.Mem.IxFunWrapper import Futhark.IR.Mem.IxFunWrapper qualified as IxFunWrap import Futhark.IR.Mem.LMAD qualified as IxFunLMAD import Futhark.IR.Prop import Futhark.IR.Syntax import Futhark.IR.Syntax.Core () import Futhark.Util.IntegralExp qualified as IE import Futhark.Util.Pretty import Test.Tasty import Test.Tasty.HUnit import Prelude hiding (span) import Prelude qualified as P instance IE.IntegralExp Int where quot = P.quot rem = P.rem div = P.div 
mod = P.mod pow = (P.^) sgn = Just . P.signum allPoints :: [Int] -> [[Int]] allPoints dims = let total = product dims strides = drop 1 $ L.reverse $ scanl (*) 1 $ L.reverse dims in map (unflatInd strides) [0 .. total - 1] where unflatInd :: [Int] -> Int -> [Int] unflatInd strides x = fst $ foldl ( \(res, acc) span -> (res ++ [acc `P.div` span], acc `P.mod` span) ) ([], x) strides compareIxFuns :: Maybe (IxFunLMAD.LMAD Int) -> IxFunAlg.IxFun Int -> Assertion compareIxFuns (Just ixfunLMAD) ixfunAlg = let lmadShape = IxFunLMAD.shape ixfunLMAD algShape = IxFunAlg.shape ixfunAlg points = allPoints lmadShape resLMAD = map (IxFunLMAD.index ixfunLMAD) points resAlg = map (IxFunAlg.index ixfunAlg) points errorMessage = T.unpack . docText $ "lmad ixfun: " <> pretty ixfunLMAD "alg ixfun: " <> pretty ixfunAlg "lmad shape: " <> pretty lmadShape "alg shape: " <> pretty algShape "lmad points length: " <> pretty (length resLMAD) "alg points length: " <> pretty (length resAlg) "lmad points: " <> pretty resLMAD "alg points: " <> pretty resAlg in (lmadShape == algShape && resLMAD == resAlg) @? errorMessage compareIxFuns Nothing ixfunAlg = assertFailure $ unlines [ "lmad ixfun: Nothing", "alg ixfun: " <> prettyString ixfunAlg ] compareOps :: IxFunWrap.IxFun Int -> Assertion compareOps (ixfunLMAD, ixfunAlg) = compareIxFuns ixfunLMAD ixfunAlg compareOpsFailure :: IxFunWrap.IxFun Int -> Assertion compareOpsFailure (Nothing, _) = pure () compareOpsFailure (Just ixfunLMAD, ixfunAlg) = assertFailure . T.unpack . docText $ "Not supposed to be representable as LMAD." "lmad ixfun: " <> pretty ixfunLMAD "alg ixfun: " <> pretty ixfunAlg -- XXX: Clean this up. n :: Int n = 19 slice3 :: Slice Int slice3 = Slice [ DimSlice 2 (n `P.div` 3) 3, DimFix (n `P.div` 2), DimSlice 1 (n `P.div` 2) 2 ] -- Actual tests. 
tests :: TestTree tests = testGroup "IxFunTests" $ concat [ test_iota, test_slice_iota, test_slice_reshape_iota1, test_permute_slice_iota, test_reshape_iota, test_reshape_permute_iota, test_slice_reshape_iota2, test_reshape_slice_iota3, test_flatten_strided, test_complex1, test_complex2, test_expand1, test_expand2, test_expand3, test_expand4, test_flatSlice_iota, test_slice_flatSlice_iota, test_flatSlice_flatSlice_iota, test_flatSlice_slice_iota, test_flatSlice_transpose_slice_iota -- TODO: Without z3, these tests fail. Ideally, our internal simplifier -- should be able to handle them: -- -- test_disjoint3 ] singleton :: TestTree -> [TestTree] singleton = (: []) test_iota :: [TestTree] test_iota = singleton . testCase "iota" . compareOps $ iota [n] test_slice_iota :: [TestTree] test_slice_iota = singleton . testCase "slice . iota" . compareOps $ slice (iota [n, n, n]) slice3 test_slice_reshape_iota1 :: [TestTree] test_slice_reshape_iota1 = singleton . testCase "slice . reshape . iota 1" . compareOps $ slice (reshape (iota [n, n, n]) [n `P.div` 2, n `P.div` 3, 1]) slice3 test_permute_slice_iota :: [TestTree] test_permute_slice_iota = singleton . testCase "permute . slice . iota" . compareOps $ permute (slice (iota [n, n, n]) slice3) [1, 0] test_reshape_iota :: [TestTree] test_reshape_iota = -- This tests a pattern that occurs with ScalarSpace. singleton . testCase "reshape . zeroslice . iota" . compareOps $ let s = Slice [DimSlice 0 n 0, DimSlice 0 n 1] in reshape (slice (iota [n, n]) s) [1, n, 1, n] test_reshape_permute_iota :: [TestTree] test_reshape_permute_iota = -- negative reshape test singleton . testCase "reshape . permute . iota" . compareOpsFailure $ let newdims = [n * n, n] in reshape (permute (iota [n, n, n]) [1, 2, 0]) newdims test_slice_reshape_iota2 :: [TestTree] test_slice_reshape_iota2 = singleton . testCase "slice . reshape . iota 2" . 
compareOps $ let newdims = [n * n, n] slc = Slice [ DimFix (n `P.div` 2), DimSlice 0 n 1 ] in slice (reshape (iota [n, n, n, n]) newdims) slc test_reshape_slice_iota3 :: [TestTree] test_reshape_slice_iota3 = -- negative reshape test singleton . testCase "reshape . slice . iota 3" . compareOpsFailure $ let newdims = [n * n, n] slc = Slice [ DimFix (n `P.div` 2), DimSlice 0 n 1, DimSlice 0 (n `P.div` 2) 1, DimSlice 0 n 1 ] in reshape (slice (iota [n, n, n, n]) slc) newdims -- Tests flattening something that is strided - this can occur after -- memory expansion. test_flatten_strided :: [TestTree] test_flatten_strided = singleton . testCase "reshape . fix . iota 3d" . compareOps $ let slc = Slice [DimSlice 0 n 1, DimSlice 0 2 1, DimFix 1] in reshape (slice (iota [n, 2, n * n]) slc) [2 * 10] test_complex1 :: [TestTree] test_complex1 = singleton . testCase "permute . slice . permute . slice . iota 1" . compareOps $ let slice33 = Slice [ DimSlice (n - 1) (n `P.div` 3) (-1), DimSlice (n - 1) n (-1), DimSlice (n - 1) n (-1), DimSlice 0 n 1 ] ixfun = permute (slice (iota [n, n, n, n, n]) slice33) [3, 1, 2, 0] m = n `P.div` 3 slice1 = Slice [ DimSlice 0 n 1, DimSlice (n - 1) n (-1), DimSlice (n - 1) n (-1), DimSlice 1 (m - 2) (-1) ] ixfun' = slice ixfun slice1 in ixfun' test_complex2 :: [TestTree] test_complex2 = singleton . testCase "permute . slice . permute . slice . iota 2" . compareOps $ let slc2 = Slice [ DimFix (n `P.div` 2), DimSlice (n - 1) (n `P.div` 3) (-1), DimSlice (n - 1) n (-1), DimSlice (n - 1) n (-1), DimSlice 0 n 1 ] ixfun = permute (slice (iota [n, n, n, n, n]) slc2) [3, 1, 2, 0] m = n `P.div` 3 slice1 = Slice [ DimSlice 0 n 1, DimSlice (n - 1) n (-1), DimSlice (n - 1) n (-1), DimSlice 1 (m - 2) (-1) ] ixfun' = slice ixfun slice1 in ixfun' -- Imitates a case from memory expansion. test_expand1 :: [TestTree] test_expand1 = [ testCase "expand . iota1d" . compareOps $ expand t nt (iota [n]) ] where t = 3 nt = 7 -- Imitates another case from memory expansion. 
test_expand2 :: [TestTree]
test_expand2 =
  [ testCase "expand . iota2d" . compareOps $
      expand t nt (iota [n, n])
  ]
  where
    t = 3
    nt = 7

-- | Like 'test_expand2', but expanding a transposed 2D iota.
test_expand3 :: [TestTree]
test_expand3 =
  [ testCase "expand . permute . iota2d" . compareOps $
      expand t nt (permute (iota [n, n `div` 2]) [1, 0])
  ]
  where
    t = 3
    nt = 7

-- | Like 'test_expand2', but expanding a sliced 1D iota.
test_expand4 :: [TestTree]
test_expand4 =
  [ testCase "expand . slice . iota1d" . compareOps $
      expand t nt (slice (iota [n]) (Slice [DimSlice (n `div` 2) (n `div` 2) 1]))
  ]
  where
    t = 3
    nt = 7

-- | A flat slice with three 'FlatDimIndex' entries, taken from a 1D iota.
test_flatSlice_iota :: [TestTree]
test_flatSlice_iota =
  singleton . testCase "flatSlice . iota" . compareOps $
    flatSlice (iota [n * n * n * n]) $
      FlatSlice 2 [FlatDimIndex (n * 2) 4, FlatDimIndex n 3, FlatDimIndex 1 2]

-- | A regular slice applied to a flat slice of a 1D iota.
test_slice_flatSlice_iota :: [TestTree]
test_slice_flatSlice_iota =
  singleton . testCase "slice . flatSlice . iota " . compareOps $
    slice (flatSlice (iota [2 + n * n * n]) flat_slice) $
      Slice [DimFix 2, DimSlice 0 n 1, DimFix 0]
  where
    flat_slice = FlatSlice 2 [FlatDimIndex (n * n) 1, FlatDimIndex n 1, FlatDimIndex 1 1]

-- | Two flat slices applied in sequence to a 1D iota.
test_flatSlice_flatSlice_iota :: [TestTree]
test_flatSlice_flatSlice_iota =
  singleton . testCase "flatSlice . flatSlice . iota " . compareOps $
    flatSlice (flatSlice (iota [10 * 10]) flat_slice_1) flat_slice_2
  where
    flat_slice_1 = FlatSlice 17 [FlatDimIndex 3 27, FlatDimIndex 3 10, FlatDimIndex 3 1]
    flat_slice_2 = FlatSlice 2 [FlatDimIndex 2 (-2)]

-- | A flat slice of a regular slice of a 2D iota.
test_flatSlice_slice_iota :: [TestTree]
test_flatSlice_slice_iota =
  singleton . testCase "flatSlice . slice . iota " . compareOps $
    flatSlice (slice (iota [210, 100]) $ Slice [DimSlice 10 100 2, DimFix 10]) flat_slice_1
  where
    flat_slice_1 = FlatSlice 17 [FlatDimIndex 3 27, FlatDimIndex 3 10, FlatDimIndex 3 1]

-- | A flat slice of a transposed ('permute' by @[1, 0]@) slice of a 2D iota.
test_flatSlice_transpose_slice_iota :: [TestTree]
test_flatSlice_transpose_slice_iota =
  singleton . testCase "flatSlice . transpose . slice . iota " . compareOps $
    flatSlice (permute (slice (iota [20, 20]) $ Slice [DimSlice 1 5 2, DimSlice 0 5 2]) [1, 0]) flat_slice_1
  where
    flat_slice_1 = FlatSlice 1 [FlatDimIndex 2 2]

-- test_disjoint2 :: [TestTree]
-- test_disjoint2 =
--   let add_nw64 = (+)
--       mul_nw64 = (*)
--       sub64 = (-)
--       vname s i = VName (nameFromString s) i
--    in [ let gtid_8472 = TPrimExp $ LeafExp (vname "gtid" 8472) $ IntType Int64
--             gtid_8473 = TPrimExp $ LeafExp (vname "gtid" 8473) $ IntType Int64
--             gtid_8474 = TPrimExp $ LeafExp (vname "gtid" 8474) $ IntType Int64
--             num_blocks_8284 = TPrimExp $ LeafExp (vname "num_blocks" 8284) $ IntType Int64
--             nonnegs = freeIn [gtid_8472, gtid_8473, gtid_8474, num_blocks_8284]
--             j_m_i_8287 :: TPrimExp Int64 VName
--             j_m_i_8287 = num_blocks_8284 - 1
--             lessthans :: [(VName, PrimExp VName)]
--             lessthans =
--               [ (head $ namesToList $ freeIn gtid_8472, untyped j_m_i_8287),
--                 (head $ namesToList $ freeIn gtid_8473, untyped j_m_i_8287),
--                 (head $ namesToList $ freeIn gtid_8474, untyped (16 :: TPrimExp Int64 VName))
--               ]
--             lm1 :: IxFunLMAD.LMAD (TPrimExp Int64 VName)
--             lm1 =
--               IxFunLMAD.LMAD
--                 256
--                 [ IxFunLMAD.LMADDim 256 0 (sub64 (num_blocks_8284) 1) 0 ,
--                   IxFunLMAD.LMADDim 1 0 16 1 ,
--                   IxFunLMAD.LMADDim 16 0 16 2
--                 ]
--             lm2 :: IxFunLMAD.LMAD (TPrimExp Int64 VName)
--             lm2 =
--               IxFunLMAD.LMAD
--                 (add_nw64 (add_nw64 (add_nw64 (add_nw64 (mul_nw64 (256) (num_blocks_8284)) (256)) (mul_nw64 (gtid_8472) (mul_nw64 (256) (num_blocks_8284)))) (mul_nw64 (gtid_8473) (256))) (mul_nw64 (gtid_8474) (16)))
--                 [IxFunLMAD.LMADDim 1 0 16 0 ]
--          in testCase (pretty lm1 <> " and " <> pretty lm2) $ IxFunLMAD.disjoint2 lessthans nonnegs lm1 lm2 @? "Failed"
--      ]

-- test_lessThanish :: [TestTree]
-- test_lessThanish =
--   [testCase "0 < 1" $ IxFunLMAD.lessThanish mempty mempty 0 1 @? "Failed"]

-- test_lessThanOrEqualish :: [TestTree]
-- test_lessThanOrEqualish =
--   [testCase "1 <= 1" $ IxFunLMAD.lessThanOrEqualish mempty mempty 1 1 @? "Failed"]

-- | Symbolic disjointness checks via 'IxFunLMAD.disjoint3'. Each test case
-- requires every listed LMAD pair to be reported disjoint under its
-- @asserts@ and @lessthans@ assumptions.
-- NOTE(review): the leading underscore suggests this group is not registered
-- in the test suite - confirm.
_test_disjoint3 :: [TestTree]
_test_disjoint3 =
  let foo s = VName (nameFromString s)
      add_nw64 = (+)
      add64 = (+)
      mul_nw64 = (*)
      mul64 = (*)
      sub64 = (-)
      sdiv64 = IE.div
      sub_nw64 = (-)
      -- Every free variable of the two LMADs is passed as non-negative, and
      -- every free variable (LMADs, lessthans, asserts) is typed as i64.
      disjointTester asserts lessthans lm1 lm2 =
        let nonnegs = map (`LeafExp` IntType Int64) $ namesToList $ freeIn lm1 <> freeIn lm2
            scmap =
              M.fromList $
                map (\x -> (x, Prim $ IntType Int64)) $
                  namesToList $
                    freeIn lm1 <> freeIn lm2 <> freeIn lessthans <> freeIn asserts
         in IxFunLMAD.disjoint3 scmap asserts lessthans nonnegs lm1 lm2
   in [ testCase "lm1 and lm2" $
          let lessthans =
                [ ( i_12214,
                    sdiv64 (sub64 n_blab 1) block_size_12121
                  ),
                  (gtid_12553, add64 1 i_12214)
                ]
                  & map (\(v, p) -> (head $ namesToList $ freeIn v, untyped p))
              asserts =
                [ untyped ((2 * block_size_12121 :: TPrimExp Int64 VName) .<. n_blab :: TPrimExp Bool VName),
                  untyped ((3 :: TPrimExp Int64 VName) .<. n_blab :: TPrimExp Bool VName)
                ]
              block_size_12121 = TPrimExp $ LeafExp (foo "block_size" 12121) $ IntType Int64
              i_12214 = TPrimExp $ LeafExp (foo "i" 12214) $ IntType Int64
              n_blab = TPrimExp $ LeafExp (foo "n" 1337) $ IntType Int64
              gtid_12553 = TPrimExp $ LeafExp (foo "gtid" 12553) $ IntType Int64
              lm1 =
                IxFunLMAD.LMAD
                  (add_nw64 (mul64 block_size_12121 i_12214) (mul_nw64 (add_nw64 gtid_12553 1) (sub64 (mul64 block_size_12121 n_blab) block_size_12121)))
                  [ IxFunLMAD.LMADDim (add_nw64 (mul_nw64 block_size_12121 n_blab) (mul_nw64 (-1) block_size_12121)) (sub_nw64 (sub_nw64 (add64 1 i_12214) gtid_12553) 1),
                    IxFunLMAD.LMADDim 1 (block_size_12121 + 1)
                  ]
              lm2 =
                IxFunLMAD.LMAD
                  (block_size_12121 * i_12214)
                  [ IxFunLMAD.LMADDim (add_nw64 (mul_nw64 block_size_12121 n_blab) (mul_nw64 (-1) block_size_12121)) gtid_12553,
                    IxFunLMAD.LMADDim 1 (1 + block_size_12121)
                  ]
              lm_w =
                IxFunLMAD.LMAD
                  (add_nw64 (add64 (add64 1 n_blab) (mul64 block_size_12121 i_12214)) (mul_nw64 gtid_12553 (sub64 (mul64 block_size_12121 n_blab) block_size_12121)))
                  [ IxFunLMAD.LMADDim n_blab block_size_12121,
                    IxFunLMAD.LMADDim 1 block_size_12121
                  ]
              lm_blocks =
                IxFunLMAD.LMAD
                  (block_size_12121 * i_12214 + n_blab + 1)
                  [ IxFunLMAD.LMADDim (add_nw64 (mul_nw64 block_size_12121 n_blab) (mul_nw64 (-1) block_size_12121)) (i_12214 + 1),
                    IxFunLMAD.LMADDim n_blab block_size_12121,
                    IxFunLMAD.LMADDim 1 block_size_12121
                  ]
              lm_lower_per =
                IxFunLMAD.LMAD
                  (block_size_12121 * i_12214)
                  [ IxFunLMAD.LMADDim (add_nw64 (mul_nw64 block_size_12121 n_blab) (mul_nw64 (-1) block_size_12121)) (i_12214 + 1),
                    IxFunLMAD.LMADDim 1 (block_size_12121 + 1)
                  ]
              res1 = disjointTester asserts lessthans lm1 lm_w
              res2 = disjointTester asserts lessthans lm2 lm_w
              res3 = disjointTester asserts lessthans lm_lower_per lm_blocks
           in res1 && res2 && res3 @? "Failed",
        testCase "nw second half" $ do
          let lessthans =
                [ ( i_12214,
                    sdiv64 (sub64 n_blab 1) block_size_12121
                  ),
                  (gtid_12553, add64 1 i_12214)
                ]
                  & map (\(v, p) -> (head $ namesToList $ freeIn v, untyped p))
              asserts =
                [ untyped ((2 * block_size_12121 :: TPrimExp Int64 VName) .<. n_blab :: TPrimExp Bool VName),
                  untyped ((3 :: TPrimExp Int64 VName) .<. n_blab :: TPrimExp Bool VName)
                ]
              block_size_12121 = TPrimExp $ LeafExp (foo "block_size" 12121) $ IntType Int64
              i_12214 = TPrimExp $ LeafExp (foo "i" 12214) $ IntType Int64
              n_blab = TPrimExp $ LeafExp (foo "n" 1337) $ IntType Int64
              gtid_12553 = TPrimExp $ LeafExp (foo "gtid" 12553) $ IntType Int64
              lm1 =
                IxFunLMAD.LMAD
                  (add_nw64 (add64 n_blab (sub64 (sub64 (mul64 n_blab (add64 1 (mul64 block_size_12121 (add64 1 i_12214)))) block_size_12121) 1)) (mul_nw64 (add_nw64 gtid_12553 1) (sub64 (mul64 block_size_12121 n_blab) block_size_12121)))
                  [ IxFunLMAD.LMADDim (add_nw64 (mul_nw64 block_size_12121 n_blab) (mul_nw64 (-1) block_size_12121)) (sub_nw64 (sub_nw64 (sub64 (sub64 (sdiv64 (sub64 n_blab 1) block_size_12121) i_12214) 1) gtid_12553) 1),
                    IxFunLMAD.LMADDim n_blab block_size_12121
                  ]
              lm2 =
                IxFunLMAD.LMAD
                  (add_nw64 (sub64 (sub64 (mul64 n_blab (add64 1 (mul64 block_size_12121 (add64 1 i_12214)))) block_size_12121) 1) (mul_nw64 (add_nw64 gtid_12553 1) (sub64 (mul64 block_size_12121 n_blab) block_size_12121)))
                  [ IxFunLMAD.LMADDim (add_nw64 (mul_nw64 block_size_12121 n_blab) (mul_nw64 (-1) block_size_12121)) (sub_nw64 (sub_nw64 (sub64 (sub64 (sdiv64 (sub64 n_blab 1) block_size_12121) i_12214) 1) gtid_12553) 1),
                    IxFunLMAD.LMADDim 1 (1 + block_size_12121)
                  ]
              lm3 =
                IxFunLMAD.LMAD
                  (add64 n_blab (sub64 (sub64 (mul64 n_blab (add64 1 (mul64 block_size_12121 (add64 1 i_12214)))) block_size_12121) 1))
                  [ IxFunLMAD.LMADDim (add_nw64 (mul_nw64 block_size_12121 n_blab) (mul_nw64 (-1) block_size_12121)) gtid_12553,
                    IxFunLMAD.LMADDim n_blab block_size_12121
                  ]
              lm4 =
                IxFunLMAD.LMAD
                  (sub64 (sub64 (mul64 n_blab (add64 1 (mul64 block_size_12121 (add64 1 i_12214)))) block_size_12121) 1)
                  [ IxFunLMAD.LMADDim (add_nw64 (mul_nw64 block_size_12121 n_blab) (mul_nw64 (-1) block_size_12121)) gtid_12553,
                    IxFunLMAD.LMADDim 1 (1 + block_size_12121)
                  ]
              lm_w =
                IxFunLMAD.LMAD
                  (add_nw64 (sub64 (mul64 n_blab (add64 2 (mul64 block_size_12121 (add64 1 i_12214)))) block_size_12121) (mul_nw64 gtid_12553 (sub64 (mul64 block_size_12121 n_blab) block_size_12121)))
                  [ IxFunLMAD.LMADDim n_blab block_size_12121,
                    IxFunLMAD.LMADDim 1 block_size_12121
                  ]
              res1 = disjointTester asserts lessthans lm1 lm_w
              res2 = disjointTester asserts lessthans lm2 lm_w
              res3 = disjointTester asserts lessthans lm3 lm_w
              res4 = disjointTester asserts lessthans lm4 lm_w
           in res1 && res2 && res3 && res4 @? "Failed " <> show [res1, res2, res3, res4],
        testCase "lud long" $
          let lessthans =
                [ bimap (head . namesToList . freeIn) untyped (step, num_blocks - 1 :: TPrimExp Int64 VName)
                ]
              step = TPrimExp $ LeafExp (foo "step" 1337) $ IntType Int64
              num_blocks = TPrimExp $ LeafExp (foo "n" 1338) $ IntType Int64
              lm1 =
                IxFunLMAD.LMAD
                  (1024 * num_blocks * (1 + step) + 1024 * step)
                  [ IxFunLMAD.LMADDim (1024 * num_blocks) (num_blocks - step - 1),
                    IxFunLMAD.LMADDim 32 32,
                    IxFunLMAD.LMADDim 1 32
                  ]
              lm_w1 =
                IxFunLMAD.LMAD
                  (1024 * num_blocks * step + 1024 * step)
                  [ IxFunLMAD.LMADDim 32 32,
                    IxFunLMAD.LMADDim 1 32
                  ]
              lm_w2 =
                IxFunLMAD.LMAD
                  ((1 + step) * 1024 * num_blocks + (1 + step) * 1024)
                  [ IxFunLMAD.LMADDim (1024 * num_blocks) (num_blocks - step - 1),
                    IxFunLMAD.LMADDim 1024 (num_blocks - step - 1),
                    IxFunLMAD.LMADDim 1024 1,
                    IxFunLMAD.LMADDim 32 1,
                    IxFunLMAD.LMADDim 128 8,
                    IxFunLMAD.LMADDim 4 8,
                    IxFunLMAD.LMADDim 32 4,
                    IxFunLMAD.LMADDim 1 4
                  ]
              asserts =
                [ untyped ((1 :: TPrimExp Int64 VName) .<. num_blocks :: TPrimExp Bool VName)
                ]
              res1 = disjointTester asserts lessthans lm1 lm_w1
              res2 = disjointTester asserts lessthans lm1 lm_w2
           in res1 && res2 @? "Failed"
      ]
futhark-0.25.27/unittests/Futhark/IR/Mem/IxFunWrapper.hs000066400000000000000000000026611475065116200227730ustar00rootroot00000000000000
-- | Perform index function operations in both algebraic and LMAD
-- representations.
module Futhark.IR.Mem.IxFunWrapper
  ( IxFun,
    iota,
    permute,
    reshape,
    coerce,
    slice,
    flatSlice,
    expand,
  )
where

import Control.Monad (join)
import Futhark.IR.Mem.IxFun.Alg qualified as IA
import Futhark.IR.Mem.LMAD qualified as I
import Futhark.IR.Syntax (FlatSlice, Slice)
import Futhark.Util.IntegralExp

type Shape num = [num]

type Permutation = [Int]

-- | An index function held in both representations at once. The LMAD side
-- is a 'Maybe' because 'I.reshape' can fail; once it is 'Nothing' it stays
-- 'Nothing' through every later operation.
type IxFun num = (Maybe (I.LMAD num), IA.IxFun num)

-- | An iota index function of the given shape (LMAD offset 0).
iota :: (IntegralExp num) => Shape num -> IxFun num
iota x = (Just $ I.iota 0 x, IA.iota x)

-- | Permute the dimensions of both representations.
permute :: IxFun num -> Permutation -> IxFun num
permute (l, a) x = (I.permute <$> l <*> pure x, IA.permute a x)

-- | Reshape both representations; the LMAD side becomes 'Nothing' when
-- 'I.reshape' cannot express the result.
reshape :: (Eq num, IntegralExp num) => IxFun num -> Shape num -> IxFun num
reshape (l, a) x = (join (I.reshape <$> l <*> pure x), IA.reshape a x)

-- | Coerce both representations to the given shape.
coerce :: IxFun num -> Shape num -> IxFun num
coerce (l, a) x = (I.coerce <$> l <*> pure x, IA.coerce a x)

-- | Slice both representations.
slice :: (Eq num, IntegralExp num) => IxFun num -> Slice num -> IxFun num
slice (l, a) x = (I.slice <$> l <*> pure x, IA.slice a x)

-- | Flat-slice both representations.
flatSlice :: (IntegralExp num) => IxFun num -> FlatSlice num -> IxFun num
flatSlice (l, a) x = (I.flatSlice <$> l <*> pure x, IA.flatSlice a x)

-- | Apply 'expand' with the same two arguments to both representations.
expand :: (IntegralExp num) => num -> num -> IxFun num -> IxFun num
-- Equivalent to the former @Just . I.expand o p =<< lf@, written with
-- '<$>' like the other wrappers in this module.
expand o p (lf, af) = (I.expand o p <$> lf, IA.expand o p af)
futhark-0.25.27/unittests/Futhark/IR/Prop/000077500000000000000000000000001475065116200202425ustar00rootroot00000000000000futhark-0.25.27/unittests/Futhark/IR/Prop/RearrangeTests.hs000066400000000000000000000030061475065116200235260ustar00rootroot00000000000000
module Futhark.IR.Prop.RearrangeTests (tests) where

import Control.Applicative
import Futhark.IR.Prop.Rearrange
import Test.Tasty
import Test.Tasty.HUnit
import Test.Tasty.QuickCheck
import Prelude

tests :: TestTree
tests = testGroup "RearrangeTests" $ isMapTransposeTests ++ [isMapTransposeProp]

-- | Fixed permutations paired with the expected 'isMapTranspose' result.
isMapTransposeTests :: [TestTree]
isMapTransposeTests =
  [ testCase (unwords ["isMapTranspose", show perm, "==", show dres]) $
      isMapTranspose perm @?= dres
    | (perm, dres) <-
        [ ([0, 1, 4, 5, 2, 3], Just (2, 2, 2)),
          ([1, 0, 4, 5, 2, 3], Nothing),
          ([1, 0], Just (0, 1, 1)),
          ([0, 2, 1], Just (1, 1, 1)),
          ([0, 1, 2], Nothing),
          ([1, 0, 2], Nothing)
        ]
  ]

-- | A random permutation of @[0 .. n - 1]@ for some positive @n@.
newtype Permutation = Permutation [Int]
  deriving (Eq, Ord, Show)

instance Arbitrary Permutation where
  arbitrary = do
    Positive n <- arbitrary
    Permutation <$> shuffle [0 .. n - 1]

-- | Whenever 'isMapTranspose' returns @(r1, r2, r3)@, the counts must be
-- sane and sum to the permutation's length. Swapping the @r2@ and @r3@
-- groups that follow the first @r1@ elements must also give the identity
-- permutation.
isMapTransposeProp :: TestTree
isMapTransposeProp = testProperty "isMapTranspose corresponds to a map of transpose" prop
  where
    prop :: Permutation -> Bool
    prop (Permutation perm) =
      case isMapTranspose perm of
        Nothing -> True
        Just (r1, r2, r3) ->
          and
            [ r1 >= 0,
              r2 > 0,
              r3 > 0,
              r1 + r2 + r3 == length perm,
              let (mapped, notmapped) = splitAt r1 perm
                  (pretrans, posttrans) = splitAt r2 notmapped
               in mapped ++ posttrans ++ pretrans == [0 .. length perm - 1]
            ]
futhark-0.25.27/unittests/Futhark/IR/Prop/ReshapeTests.hs000066400000000000000000000023131475065116200232070ustar00rootroot00000000000000
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Futhark.IR.Prop.ReshapeTests
  ( tests,
  )
where

import Futhark.IR.Prop.Constants
import Futhark.IR.Prop.Reshape
import Futhark.IR.Syntax
import Test.Tasty
import Test.Tasty.HUnit

-- | Table-driven checks of 'reshapeOuter' on small integer shapes.
reshapeOuterTests :: [TestTree]
reshapeOuterTests =
  [ testCase (unwords ["reshapeOuter", show sc, show n, show shape, "==", show sc_res]) $
      reshapeOuter (intShape sc) n (intShape shape) @?= intShape sc_res
    | (sc, n, shape, sc_res) <-
        [ ([1], 1, [4, 3], [1, 3]),
          ([1], 2, [4, 3], [1]),
          ([2, 2], 1, [4, 3], [2, 2, 3]),
          ([2, 2], 2, [4, 3], [2, 2])
        ]
  ]

-- | Table-driven checks of 'reshapeInner' on small integer shapes.
reshapeInnerTests :: [TestTree]
reshapeInnerTests =
  [ testCase (unwords ["reshapeInner", show sc, show n, show shape, "==", show sc_res]) $
      reshapeInner (intShape sc) n (intShape shape) @?= intShape sc_res
    | (sc, n, shape, sc_res) <-
        [ ([1], 1, [4, 3], [4, 1]),
          ([1], 0, [4, 3], [1]),
          ([2, 2], 1, [4, 3], [4, 2, 2]),
          ([2, 2], 0, [4, 3], [2, 2])
        ]
  ]

-- | Build a 'Shape' of 'Int32' constants.
intShape :: [Int] -> Shape
intShape = Shape . map (intConst Int32 . toInteger)

tests :: TestTree
tests = testGroup "ReshapeTests" $ reshapeOuterTests ++ reshapeInnerTests
futhark-0.25.27/unittests/Futhark/IR/PropTests.hs000066400000000000000000000005331475065116200216220ustar00rootroot00000000000000
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Futhark.IR.PropTests
  ( tests,
  )
where

import Futhark.IR.Prop.RearrangeTests qualified
import Futhark.IR.Prop.ReshapeTests qualified
import Test.Tasty

tests :: TestTree
tests =
  testGroup
    "PropTests"
    [ Futhark.IR.Prop.ReshapeTests.tests,
      Futhark.IR.Prop.RearrangeTests.tests
    ]
futhark-0.25.27/unittests/Futhark/IR/Syntax/000077500000000000000000000000001475065116200206105ustar00rootroot00000000000000futhark-0.25.27/unittests/Futhark/IR/Syntax/CoreTests.hs000066400000000000000000000040331475065116200230570ustar00rootroot00000000000000
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Futhark.IR.Syntax.CoreTests (tests) where

import Control.Applicative
import Futhark.IR.Pretty (prettyString)
import Futhark.IR.Syntax.Core
import Language.Futhark.CoreTests ()
import Language.Futhark.PrimitiveTests ()
import Test.QuickCheck
import Test.Tasty
import Test.Tasty.HUnit
import Prelude

tests :: TestTree
tests = testGroup "Internal CoreTests" subShapeTests

-- | Checks of 'subShapeOf' on shapes built from free constants and
-- existential sizes.
subShapeTests :: [TestTree]
subShapeTests =
  [ shape [free 1, free 2] `isSubShapeOf` shape [free 1, free 2],
    shape [free 1, free 3] `isNotSubShapeOf` shape [free 1, free 2],
    shape [free 1] `isNotSubShapeOf` shape [free 1, free 2],
    shape [free 1, free 2] `isSubShapeOf` shape [free 1, Ext 3],
    shape [Ext 1, Ext 2] `isNotSubShapeOf` shape [Ext 1, Ext 1],
    shape [Ext 1, Ext 1] `isSubShapeOf` shape [Ext 1, Ext 2]
  ]
  where
    shape :: [ExtSize] -> ExtShape
    shape = Shape

    -- A known (non-existential) 32-bit integer size.
    free :: Int -> ExtSize
    free = Free . Constant . IntValue . Int32Value . fromIntegral

    isSubShapeOf shape1 shape2 = subShapeTest shape1 shape2 True
    isNotSubShapeOf shape1 shape2 = subShapeTest shape1 shape2 False

    subShapeTest :: ExtShape -> ExtShape -> Bool -> TestTree
    subShapeTest shape1 shape2 expected =
      testCase
        ( "subshapeOf "
            ++ prettyString shape1
            ++ " "
            ++ prettyString shape2
            ++ " == "
            ++ show expected
        )
        $ shape1 `subShapeOf` shape2 @?= expected

instance Arbitrary NoUniqueness where
  arbitrary = pure NoUniqueness

instance (Arbitrary shape, Arbitrary u) => Arbitrary (TypeBase shape u) where
  arbitrary =
    oneof
      [ Prim <$> arbitrary,
        Array <$> arbitrary <*> arbitrary <*> arbitrary
      ]

instance Arbitrary Ident where
  arbitrary = Ident <$> arbitrary <*> arbitrary

instance Arbitrary Rank where
  arbitrary = Rank <$> elements [1 .. 9]

instance Arbitrary Shape where
  arbitrary = Shape . map intconst <$> listOf1 (elements [1 .. 9])
    where
      intconst = Constant . IntValue . Int32Value
futhark-0.25.27/unittests/Futhark/IR/SyntaxTests.hs000066400000000000000000000016711475065116200221740ustar00rootroot00000000000000
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Futhark.IR.SyntaxTests (parseString) where

import Data.String
import Data.Text qualified as T
import Futhark.IR.Parse
import Futhark.IR.Syntax

-- There isn't anything to test in this module, but we define some
-- convenience instances.

-- | Run the IR parser @p@ on a 'String', aborting with 'error' on failure.
-- The text @"IsString " <> desc@ is passed as the parser's 'FilePath'
-- argument.
parseString :: String -> (FilePath -> T.Text -> Either T.Text a) -> String -> a
parseString desc p = either (error . T.unpack) id . p ("IsString " <> desc) . T.pack

instance IsString Type where
  fromString = parseString "Type" parseType

instance IsString DeclExtType where
  fromString = parseString "DeclExtType" parseDeclExtType

instance IsString DeclType where
  fromString = parseString "DeclType" parseDeclType

instance IsString VName where
  fromString = parseString "VName" parseVName

instance IsString SubExp where
  fromString = parseString "SubExp" parseSubExp

instance IsString SubExpRes where
  fromString = parseString "SubExpRes" parseSubExpRes
futhark-0.25.27/unittests/Futhark/Internalise/000077500000000000000000000000001475065116200212655ustar00rootroot00000000000000futhark-0.25.27/unittests/Futhark/Internalise/TypesValuesTests.hs000066400000000000000000000120071475065116200251300ustar00rootroot00000000000000
module Futhark.Internalise.TypesValuesTests (tests) where

import Control.Monad.Free (Free (..))
import Data.Map qualified as M
import Data.String (fromString)
import Futhark.IR.Syntax hiding (Free)
import Futhark.IR.SyntaxTests ()
import Futhark.Internalise.TypesValues
import Language.Futhark.SyntaxTests ()
import Test.Tasty
import Test.Tasty.HUnit

-- | Source-language types paired with their expected internalised form.
internaliseTypeTests :: TestTree
internaliseTypeTests =
  testGroup
    "internaliseType"
    [ mkTest
        "[0]()"
        [Free [Pure "[0i64]unit"]],
      mkTest
        "{a: [t_7447][n_7448](f32, f32), b: i64, c: i64}"
        [Free [Pure "[t_7447][n_7448]f32", Pure "[t_7447][n_7448]f32"], Pure "i64", Pure "i64"],
      mkTest
        "([0]i32, {a: f32, b: f32, c: f32, d: [0]((f32, f32), (f32, f32))})"
        [ Free [Pure "[0i64]i32"],
          Pure "f32",
          Pure "f32",
          Pure "f32",
          Free [Pure "[0i64]f32", Pure "[0i64]f32", Pure "[0i64]f32", Pure "[0i64]f32"]
        ],
      mkTest
        "[0]([1]i32, f32)"
        [Free [Free [Pure "[0i64][1i64]i32"], Pure "[0i64]f32"]]
    ]
  where
    mkTest x y =
      testCase (prettyString x) $ internaliseType x @?= y

-- | Checks of how 'internaliseConstructors' shares payload slots between
-- constructors.
sumTypeTests :: TestTree
sumTypeTests =
  testGroup
    "internaliseConstructors"
    [ testCase "Dedup of primitives" $
        internaliseConstructors
          ( M.fromList
              [ ("foo", [Pure "i64"]),
                ("bar", [Pure "i64"])
              ]
          )
          @?= ( [Pure "i64"],
                [ ("bar", [0]),
                  ("foo", [0])
                ]
              ),
      testCase "Dedup of array" $
        internaliseConstructors
          ( M.fromList
              [ ("foo", [Pure "[?0]i64"]),
                ("bar", [Pure "[?0]i64"])
              ]
          )
          @?= ( [Pure "[?0]i64", Pure "[?0]i64"],
                [ ("bar", [0]),
                  ("foo", [1])
                ]
              ),
      testCase "Dedup of array of tuple" $
        internaliseConstructors
          ( M.fromList
              [ ("foo", [Free [Pure "[?0]i64", Pure "[?0]i64"]]),
                ("bar", [Pure "[?0]i64"])
              ]
          )
          @?= ( [Pure "[?0]i64", Free [Pure "[?0]i64", Pure "[?0]i64"]],
                [ ("bar", [0]),
                  ("foo", [1, 2])
                ]
              )
    ]

-- Be aware that some of these tests simply reinforce current
-- behaviour - it may be that we want to restrict aliasing even
-- further in the future; these tests would have to be updated in such
-- cases.
inferAliasesTests :: TestTree
inferAliasesTests =
  testGroup
    "inferAliases"
    [ mkTest
        [Free [Pure "[0i64]i32"]]
        [Free [Pure "[?0]i32"]]
        [[("[?0]i32", RetAls [0] [0])]],
      mkTest
        [Free [Pure "[0i64]i32", Pure "[0i64]i32"]]
        [Free [Pure "[0i64]i32", Pure "[0i64]i32"]]
        [ [ ("[0i64]i32", RetAls [0] [0]),
            ("[0i64]i32", RetAls [1] [1])
          ]
        ],
      -- Basically zip.
      mkTest
        [Free [Pure "[n_0]i32"], Free [Pure "[n_0]i32"]]
        [Free [Pure "[n_0]i32", Pure "[n_0]i32"]]
        [ [ ("[n_0]i32", RetAls [] [0]),
            ("[n_0]i32", RetAls [] [1])
          ]
        ],
      mkTest
        [Free [Pure "[0i64]i32"], Free [Pure "[0i64]i32", Pure "[0i64]i32"]]
        [Free [Pure "[?0]i32", Pure "[?0]i32"]]
        [ [ ("[?0]i32", RetAls [1] [0]),
            ("[?0]i32", RetAls [2] [1])
          ]
        ],
      mkTest
        [Free [Pure "[0i64][1i64]i32", Pure "[0i64][1i64]i32"]]
        [Free [Pure "[?0]i32", Pure "[?0]i32"]]
        [ [ ("[?0]i32", RetAls [0] [0]),
            ("[?0]i32", RetAls [1] [1])
          ]
        ],
      -- Basically unzip.
      mkTest
        [Free [Pure "[n_0][n_1]i32", Pure "[n_0][n_1]i32"]]
        [Free [Pure "[?0]i32"], Free [Pure "[?0]i32"]]
        [ [("[?0]i32", RetAls [] [0, 1])],
          [("[?0]i32", RetAls [] [0, 1])]
        ],
      mkTest
        [ Free [Pure "*[n_0][n_1]i32"],
          Free [Pure "[n_2]i64"],
          Free [Pure "[n_3]i64"]
        ]
        [Free [Pure "*[n_0][n_1]i32"]]
        [[("*[n_0][n_1]i32", RetAls [] [])]],
      mkTest
        [Free [Pure "[n_0]i32", Free [Pure "[n_0][n_1]i32"]]]
        [Free [Pure "[n_0]i32"]]
        [[("[n_0]i32", RetAls [1] [0])]],
      mkTest
        []
        [ Free [Pure "[n_0]i32", Free [Pure "[n_0][n_1]i32"]],
          Free [Pure "[n_0]i32"]
        ]
        [ [("[n_0]i32", RetAls [] [0]), ("[n_0][n_1]i32", RetAls [] [1])],
          [("[n_0]i32", RetAls [] [1, 2])]
        ],
      mkTest
        [Free [Pure "[n_0]i32"]]
        [Free [Pure "[m_1][m_1]i32"]]
        [ [("[m_1][m_1]i32", RetAls [0] [0])]
        ]
    ]
  where
    -- Parameter and result types are written as strings and parsed through
    -- the 'IsString' instances.
    mkTest all_param_ts all_res_ts expected =
      testCase (show all_param_ts <> " " <> show all_res_ts) $
        inferAliases
          (map (fmap fromString) all_param_ts)
          (map (fmap fromString) all_res_ts)
          @?= expected

tests :: TestTree
tests =
  testGroup
    "Futhark.Internalise.TypesValuesTests"
    [ internaliseTypeTests,
      sumTypeTests,
      inferAliasesTests
    ]
futhark-0.25.27/unittests/Futhark/Optimise/000077500000000000000000000000001475065116200206015ustar00rootroot00000000000000futhark-0.25.27/unittests/Futhark/Optimise/ArrayLayout/000077500000000000000000000000001475065116200230555ustar00rootroot00000000000000futhark-0.25.27/unittests/Futhark/Optimise/ArrayLayout/AnalyseTests.hs000066400000000000000000000236551475065116200260430ustar00rootroot00000000000000
module Futhark.Optimise.ArrayLayout.AnalyseTests (tests) where

import Data.Map.Strict qualified as M
import Futhark.Analysis.AccessPattern
import Futhark.IR.GPU
import Futhark.IR.GPUTests ()
import Futhark.IR.SyntaxTests ()
import Test.Tasty
import Test.Tasty.HUnit

tests :: TestTree
tests = testGroup "Analyse" [analyseStmTests]

analyseStmTests :: TestTree
analyseStmTests =
  testGroup
    "analyseStm"
    [analyseIndexTests, analyseDimAccessesTests]

-- | Checks of the 'IndexTable' that 'analyseIndex' records for a single
-- indexing of @xss_5144@ under various contexts.
analyseIndexTests :: TestTree
analyseIndexTests =
  testGroup "analyseIndex" $ do
    let arr_name = "xss_5144"
    -- ============================= TestCase0 =============================
    -- Most simple case where we want to manifest an array, hence, we record
    -- the Index in the IndexTable.
    let testCase0 = testCase "2D manifest" $ do
          let ctx =
                mempty
                  { parents =
                      [ SegOpName (SegmentedMap "defunc_0_map_res_5204"),
                        LoopBodyName "defunc_0_f_res_5208"
                      ],
                    assignments =
                      M.fromList
                        [ ("gtid_5205", VariableInfo mempty 0 mempty ThreadID),
                          ("i_5209", VariableInfo mempty 1 mempty LoopVar)
                        ]
                  }
          let patternNames = ["b_5211"]
          let dimFixes =
                [ DimFix (Var "gtid_5205"),
                  DimFix (Var "i_5209")
                ]
          let indexTable =
                M.fromList
                  [ ( SegmentedMap "defunc_0_map_res_5204",
                      M.fromList
                        [ ( (arr_name, [], [0 .. 1]),
                            M.fromList
                              [ ( "b_5211",
                                  [ DimAccess (M.fromList [("gtid_5205", Dependency 0 ThreadID)]) (Just "gtid_5205"),
                                    DimAccess (M.fromList [("i_5209", Dependency 1 LoopVar)]) (Just "i_5209")
                                  ]
                                )
                              ]
                          )
                        ]
                    )
                  ]
          let (_, indexTable') = analyseIndex ctx patternNames arr_name dimFixes
          indexTable' @?= indexTable

    -- ============================= TestCase1 =============================
    -- We don't want to manifest an array with only one dimension, so we don't
    -- record anything in the IndexTable.
    let testCase1 = testCase "1D manifest" $ do
          let ctx =
                mempty
                  { parents =
                      [ SegOpName (SegmentedMap "defunc_0_map_res_5204"),
                        LoopBodyName "defunc_0_f_res_5208"
                      ]
                  }
          let patternNames = ["b_5211"]
          let dimFixes = [DimFix "i_5209"]
          let (_, indexTable') = analyseIndex ctx patternNames arr_name dimFixes
          indexTable' @?= mempty

    -- ============================= TestCase2 =============================
    -- We don't want to record anything to the IndexTable when the array is
    -- not accessed inside a SegMap
    -- TODO: Create a similar one for MC with loops
    let testCase2 = testCase "Not inside SegMap" $ do
          let ctx = mempty
          let patternNames = ["b_5211"]
          let dimFixes =
                [ DimFix "gtid_5205",
                  DimFix "i_5209"
                ]
          let (_, indexTable') = analyseIndex ctx patternNames arr_name dimFixes
          indexTable' @?= mempty

    -- ============================= TestCase3 =============================
    -- If an array is allocated inside a loop or SegMap, we want to record that
    -- information in the ArrayName of the IndexTable.
    let testCase3 = testCase "Allocated inside SegMap" $ do
          let parents' =
                [ SegOpName (SegmentedMap "defunc_0_map_res_5204"),
                  LoopBodyName "defunc_0_f_res_5208"
                ]
          let ctx =
                mempty
                  { parents = parents',
                    assignments =
                      M.fromList
                        [ ("gtid_5205", VariableInfo mempty 0 mempty ThreadID),
                          ("i_5209", VariableInfo mempty 1 mempty LoopVar),
                          (arr_name, VariableInfo mempty 0 parents' Variable)
                        ]
                  }
          let patternNames = ["b_5211"]
          let dimFixes =
                [ DimFix "gtid_5205",
                  DimFix "i_5209"
                ]
          let indexTable =
                M.fromList
                  [ ( SegmentedMap "defunc_0_map_res_5204",
                      M.fromList
                        [ ( (arr_name, parents', [0 .. 1]),
                            M.fromList
                              [ ( "b_5211",
                                  [ DimAccess (M.fromList [("gtid_5205", Dependency 0 ThreadID)]) (Just "gtid_5205"),
                                    DimAccess (M.fromList [("i_5209", Dependency 1 LoopVar)]) (Just "i_5209")
                                  ]
                                )
                              ]
                          )
                        ]
                    )
                  ]
          let (_, indexTable') = analyseIndex ctx patternNames arr_name dimFixes
          indexTable' @?= indexTable

    -- ============================= TestCase4 =============================
    -- If the vars in the index are temporaries, we want to reduce them to
    -- the thread IDs and/or loop counters they are functions of.
    let testCase4 = testCase "Reduce dependencies" $ do
          let ctx =
                mempty
                  { parents =
                      [ SegOpName (SegmentedMap "defunc_0_map_res_5204"),
                        LoopBodyName "defunc_0_f_res_5208"
                      ],
                    assignments =
                      M.fromList
                        [ ("gtid_5205", VariableInfo mempty 0 mempty ThreadID),
                          ("i_5209", VariableInfo mempty 1 mempty LoopVar),
                          ("tmp0_5210", VariableInfo (namesFromList ["gtid_5205"]) 2 mempty Variable),
                          ("tmp1_5211", VariableInfo (namesFromList ["i_5209"]) 3 mempty Variable),
                          ("k_5212", VariableInfo mempty 1 mempty ConstType)
                        ]
                  }
          let patternNames = ["b_5211"]
          let dimFixes =
                [ DimFix "tmp0_5210",
                  DimFix "tmp1_5211",
                  DimFix "k_5212"
                ]
          let indexTable =
                M.fromList
                  [ ( SegmentedMap "defunc_0_map_res_5204",
                      M.fromList
                        [ ( (arr_name, [], [0 .. 2]),
                            M.fromList
                              [ ( "b_5211",
                                  [ DimAccess (M.fromList [("gtid_5205", Dependency 0 ThreadID)]) (Just "tmp0_5210"),
                                    DimAccess (M.fromList [("i_5209", Dependency 1 LoopVar)]) (Just "tmp1_5211"),
                                    DimAccess mempty (Just "k_5212")
                                  ]
                                )
                              ]
                          )
                        ]
                    )
                  ]
          let (_, indexTable') = analyseIndex ctx patternNames arr_name dimFixes
          indexTable' @?= indexTable

    [testCase0, testCase1, testCase2, testCase3, testCase4]

-- | Runs 'analyseDimAccesses' on a whole parsed GPU program and compares the
-- resulting 'IndexTable'.
analyseDimAccessesTests :: TestTree
analyseDimAccessesTests =
  testGroup "analyseDimAccesses" $ do
    let testCase0 = testCase "Fold" $ do
          let indexTable =
                M.fromList
                  [ ( SegmentedMap "defunc_0_map_res_5204",
                      M.fromList
                        [ ( ("xss_5144", [], [0, 1]),
                            M.fromList
                              [ ( "b_5211",
                                  [ DimAccess (M.fromList [("gtid_5205", Dependency 0 ThreadID)]) (Just "gtid_5205"),
                                    DimAccess (M.fromList [("i_5209", Dependency 1 LoopVar)]) (Just "i_5209")
                                  ]
                                )
                              ]
                          )
                        ]
                    )
                  ]
          let indexTable' = (analyseDimAccesses @GPU) prog0
          indexTable' @?= indexTable

    [testCase0]
  where
    -- A segmap whose threads each run a loop over one row of xss_5144,
    -- parsed through the 'IsString' instance for 'Prog GPU'.
    prog0 :: Prog GPU
    prog0 =
      "\
      \entry(\"main\",\
      \ {xss: [][]i64},\
      \ {[]i64})\
      \ entry_main (n_5142 : i64,\
      \ m_5143 : i64,\
      \ xss_5144 : [n_5142][m_5143]i64)\
      \ : {[n_5142]i64#([2], [0])} = {\
      \ let {segmap_group_size_5202 : i64} =\
      \ get_size(segmap_group_size_5190, thread_block_size)\
      \ let {segmap_usable_groups_5203 : i64} =\
      \ sdiv_up64(n_5142, segmap_group_size_5202)\
      \ let {defunc_0_map_res_5204 : [n_5142]i64} =\
      \ segmap(thread; ; grid=segmap_usable_groups_5203; blocksize=segmap_group_size_5202)\
      \ (gtid_5205 < n_5142) (~phys_tid_5206) : {i64} {\
      \ let {defunc_0_f_res_5208 : i64} =\
      \ loop {acc_5210 : i64} = {0i64}\
      \ for i_5209:i64 < m_5143 do {\
      \ let {b_5211 : i64} =\
      \ xss_5144[gtid_5205, i_5209]\
      \ let {defunc_0_f_res_5212 : i64} =\
      \ add64(acc_5210, b_5211)\
      \ in {defunc_0_f_res_5212}\
      \ }\
      \ return {returns defunc_0_f_res_5208}\
      \ }\
      \ in {defunc_0_map_res_5204}\
      \}"
futhark-0.25.27/unittests/Futhark/Optimise/ArrayLayout/LayoutTests.hs000066400000000000000000000121641475065116200257150ustar00rootroot00000000000000
module Futhark.Optimise.ArrayLayout.LayoutTests (tests) where

import Data.Map.Strict qualified as M
import Futhark.Analysis.AccessPattern
import Futhark.Analysis.PrimExp
import Futhark.FreshNames
import Futhark.IR.GPU (GPU)
import Futhark.IR.GPUTests ()
import Futhark.Optimise.ArrayLayout.Layout
import Language.Futhark.Core
import Test.Tasty
import Test.Tasty.HUnit

tests :: TestTree
tests = testGroup "Layout" [commonPermutationEliminatorsTests]

commonPermutationEliminatorsTests :: TestTree
commonPermutationEliminatorsTests =
  testGroup
    "commonPermutationEliminators"
    [permutationTests, nestTests, dimAccessTests, constIndexElimTests]

-- | Expected 'commonPermutationEliminators' verdict for each permutation,
-- with an empty nest.
permutationTests :: TestTree
permutationTests =
  testGroup "Permutations" $ do
    -- This isn't the way to test this, in reality we should provide realistic
    -- access patterns that might result in the given permutations.
    -- Luckily we only use the original access for one check atm.
    [ testCase (unwords [show perm, "->", show res]) $
        commonPermutationEliminators perm [] @?= res
      | (perm, res) <-
          [ ([0], True),
            ([1, 0], False),
            ([0, 1], True),
            ([0, 0], True),
            ([1, 1], True),
            ([1, 2, 0], False),
            ([2, 0, 1], False),
            ([0, 1, 2], True),
            ([1, 0, 2], True),
            ([2, 1, 0], True),
            ([2, 2, 0], True),
            ([2, 1, 1], True),
            ([1, 0, 1], True),
            ([0, 0, 0], True),
            ([0, 1, 2, 3, 4], True),
            ([1, 0, 2, 3, 4], True),
            ([2, 3, 0, 1, 4], True),
            ([3, 4, 2, 0, 1], True),
            ([2, 3, 4, 0, 1], False),
            ([1, 2, 3, 4, 0], False),
            ([3, 4, 0, 1, 2], False)
          ]
      ]

-- | Expected verdict for the permutation @[1, 0]@ under different nests.
nestTests :: TestTree
nestTests =
  testGroup "Nests" $ do
    let names = generateNames 2
    [ testCase (unwords [args, "->", show res]) $
        commonPermutationEliminators [1, 0] nest @?= res
      | (args, nest, res) <-
          [ ("[]", [], False),
            ("[CondBodyName]", [CondBodyName] <*> names, False),
            ("[SegOpName]", [SegOpName . SegmentedMap] <*> names, True),
            ("[LoopBodyName]", [LoopBodyName] <*> names, False),
            ("[SegOpName, CondBodyName]", [SegOpName . SegmentedMap, CondBodyName] <*> names, True),
            ("[CondBodyName, LoopBodyName]", [CondBodyName, LoopBodyName] <*> names, False)
          ]
      ]

dimAccessTests :: TestTree
dimAccessTests = testGroup "DimAccesses" [] -- TODO: Write tests for the part of commonPermutationEliminators that checks the complexity of the DimAccesses.

-- | Checks of 'layoutTableFromIndexTable' on single-access index tables.
constIndexElimTests :: TestTree
constIndexElimTests =
  testGroup
    "constIndexElimTests"
    [ testCase "gpu eliminates indexes with constant in any dim" $ do
        let primExpTable =
              M.fromList
                [ ("gtid_4", Just (LeafExp "n_4" (IntType Int64))),
                  ("i_5", Just (LeafExp "n_4" (IntType Int64)))
                ]
        layoutTableFromIndexTable primExpTable accessTableGPU @?= mempty,
      testCase "gpu ignores when not last" $ do
        let primExpTable =
              M.fromList
                [ ("gtid_4", Just (LeafExp "gtid_4" (IntType Int64))),
                  ("gtid_5", Just (LeafExp "gtid_5" (IntType Int64))),
                  ("i_6", Just (LeafExp "i_6" (IntType Int64)))
                ]
        layoutTableFromIndexTable primExpTable accessTableGPUrev
          @?= M.fromList
            [ ( SegmentedMap "mapres_1",
                M.fromList
                  [ ( ("a_2", [], [0, 1, 2, 3]),
                      M.fromList [("A_3", [2, 3, 0, 1])]
                    )
                  ]
              )
            ]
    ]
  where
    accessTableGPU :: IndexTable GPU
    accessTableGPU =
      singleAccess
        [ singleParAccess 0 "gtid_4",
          DimAccess mempty Nothing,
          singleSeqAccess 1 "i_5"
        ]

    accessTableGPUrev :: IndexTable GPU
    accessTableGPUrev =
      singleAccess
        [ singleParAccess 1 "gtid_4",
          singleParAccess 2 "gtid_5",
          singleSeqAccess 0 "i_5",
          singleSeqAccess 2 "gtid_4"
        ]

    -- An index table with exactly one access, inside the SegMap "mapres_1".
    singleAccess :: [DimAccess rep] -> IndexTable rep
    singleAccess dims =
      M.fromList
        [ ( sgOp,
            M.fromList
              [ ( ("A_2", [], [0, 1, 2, 3]),
                  M.fromList
                    [ ( "a_3",
                        dims
                      )
                    ]
                )
              ]
          )
        ]
      where
        sgOp = SegmentedMap {vnameFromSegOp = "mapres_1"}

    -- A dimension indexed by a thread ID at the given level.
    singleParAccess :: Int -> VName -> DimAccess rep
    singleParAccess level name =
      DimAccess
        (M.singleton name $ Dependency level ThreadID)
        (Just name)

    -- A dimension indexed by a loop variable at the given level.
    singleSeqAccess :: Int -> VName -> DimAccess rep
    singleSeqAccess level name =
      DimAccess
        (M.singleton name $ Dependency level LoopVar)
        (Just name)

-- | Generate @count@ names with 'newName', starting from @"i_0"@. Each new
-- name is derived from the previous one. At least one name is always
-- returned, even when @count < 1@.
generateNames :: Int -> [VName]
generateNames count = do
  let (name, source) = newName blankNameSource "i_0"
  fst $ foldl f ([name], source) [1 .. count - 1]
  where
    f (names, source) _ = do
      let (name, source') = newName source (last names)
      (names ++ [name], source')
futhark-0.25.27/unittests/Futhark/Optimise/ArrayLayoutTests.hs000066400000000000000000000007351475065116200244410ustar00rootroot00000000000000
module Futhark.Optimise.ArrayLayoutTests (tests) where

import Futhark.Analysis.PrimExp.TableTests qualified
import Futhark.Optimise.ArrayLayout.AnalyseTests qualified
import Futhark.Optimise.ArrayLayout.LayoutTests qualified
import Test.Tasty

tests :: TestTree
tests =
  testGroup
    "OptimizeArrayLayoutTests"
    [ Futhark.Optimise.ArrayLayout.AnalyseTests.tests,
      Futhark.Optimise.ArrayLayout.LayoutTests.tests,
      Futhark.Analysis.PrimExp.TableTests.tests
    ]
futhark-0.25.27/unittests/Futhark/Optimise/MemoryBlockMerging/000077500000000000000000000000001475065116200243355ustar00rootroot00000000000000futhark-0.25.27/unittests/Futhark/Optimise/MemoryBlockMerging/GreedyColoringTests.hs000066400000000000000000000044071475065116200306350ustar00rootroot00000000000000
module Futhark.Optimise.MemoryBlockMerging.GreedyColoringTests
  ( tests,
  )
where

import Control.Arrow ((***))
import Data.Function ((&))
import Data.Map qualified as M
import Data.Set qualified as S
import Futhark.Optimise.MemoryBlockMerging.GreedyColoring qualified as GreedyColoring
import Test.Tasty
import Test.Tasty.HUnit

tests :: TestTree
tests =
  testGroup
    "GreedyColoringTests"
    [psumTest, allIntersect, emptyGraph, noIntersections, differentSpaces]

-- Each test compares both maps returned by 'GreedyColoring.colorGraph'
-- (converted to association lists) against the expected values.

psumTest :: TestTree
psumTest =
  testCase "psumTest" $
    assertEqual
      "Color simple 1-2-3 using two colors"
      ( [(0, "shared"), (1, "shared")] :: [(Int, String)],
        [(1 :: Int, 0), (2, 1), (3, 0)]
      )
      $ (M.toList *** M.toList)
      $ GreedyColoring.colorGraph
        (M.fromList [(1, "shared"), (2, "shared"), (3, "shared")])
      $ S.fromList [(1, 2), (2, 3)]

allIntersect :: TestTree
allIntersect =
  testCase "allIntersect" $
    assertEqual
      "Color a graph where all values intersect"
      ( [(0, "shared"), (1, "shared"), (2, "shared")] :: [(Int, String)],
        [(1 :: Int, 2), (2, 1), (3, 0)]
      )
      $ (M.toList *** M.toList)
      $ GreedyColoring.colorGraph
        (M.fromList [(1, "shared"), (2, "shared"), (3, "shared")])
      $ S.fromList [(1, 2), (2, 3), (1, 3)]

emptyGraph :: TestTree
emptyGraph =
  testCase "emptyGraph" $
    assertEqual
      "Color an empty graph"
      ([] :: [(Int, Char)], [] :: [(Int, Int)])
      $ (M.toList *** M.toList)
      $ GreedyColoring.colorGraph M.empty
      $ S.fromList []

noIntersections :: TestTree
noIntersections =
  GreedyColoring.colorGraph
    (M.fromList [(1, "shared"), (2, "shared"), (3, "shared")])
    (S.fromList [])
    & M.toList *** M.toList
    & assertEqual
      "Color nodes with no intersections"
      ( [(0, "shared")] :: [(Int, String)],
        [(1, 0), (2, 0), (3, 0)] :: [(Int, Int)]
      )
    & testCase "noIntersections"

differentSpaces :: TestTree
differentSpaces =
  GreedyColoring.colorGraph
    (M.fromList [(1, "a"), (2, "b"), (3, "c")])
    (S.fromList [])
    & M.toList *** M.toList
    & assertEqual
      "Color nodes with no intersections but in different spaces"
      ( [(0, "c"), (1, "b"), (2, "a")] :: [(Int, String)],
        [(1, 2), (2, 1), (3, 0)] :: [(Int, Int)]
      )
    & testCase "differentSpaces"
futhark-0.25.27/unittests/Futhark/Pkg/000077500000000000000000000000001475065116200175315ustar00rootroot00000000000000futhark-0.25.27/unittests/Futhark/Pkg/SolveTests.hs000066400000000000000000000077631475065116200222150ustar00rootroot00000000000000
module Futhark.Pkg.SolveTests (tests) where

import Data.Map qualified as M
import Data.Monoid
import Data.Text qualified as T
import Futhark.Pkg.Solve
import Futhark.Pkg.Types
import Test.Tasty
import Test.Tasty.HUnit
import Prelude

-- | Parse a version literal, calling 'error' if it is not a valid
-- version number.
semverE :: T.Text -> SemVer
semverE s = case parseVersion s of
  Left err ->
    error $
      T.unpack s
        <> " is not a valid version number: "
        <> errorBundlePretty err
  Right x -> x

-- | A world of packages and interdependencies for testing the solver
-- without touching the outside world.
testEnv :: PkgRevDepInfo testEnv = M.fromList $ concatMap frob [ ( "athas", [ ( "foo", [ ("0.1.0", []), ("0.2.0", [("athas/bar", "1.0.0")]), ("0.3.0", []) ] ), ("foo@v2", [("2.0.0", [("athas/quux", "0.1.0")])]), ("bar", [("1.0.0", [])]), ("baz", [("0.1.0", [("athas/foo", "0.3.0")])]), ( "quux", [ ( "0.1.0", [ ("athas/foo", "0.2.0"), ("athas/baz", "0.1.0") ] ) ] ), ( "quux_perm", [ ( "0.1.0", [ ("athas/baz", "0.1.0"), ("athas/foo", "0.2.0") ] ) ] ), ("x_bar", [("1.0.0", [("athas/bar", "1.0.0")])]), ("x_foo", [("1.0.0", [("athas/foo", "0.3.0")])]), ( "tricky", [ ( "1.0.0", [ ("athas/foo", "0.2.0"), ("athas/x_foo", "1.0.0") ] ) ] ) ] ), -- Some mutually recursive packages. ( "nasty", [ ("foo", [("1.0.0", [("nasty/bar", "1.0.0")])]), ("bar", [("1.0.0", [("nasty/foo", "1.0.0")])]) ] ) ] where frob (user, repos) = do (repo, repo_revs) <- repos (rev, deps) <- repo_revs let rev' = semverE rev onDep (dp, dv) = (dp, (semverE dv, Nothing)) deps' = PkgRevDeps $ M.fromList $ map onDep deps pure ((user <> "/" <> repo, rev'), deps') newtype SolverRes = SolverRes BuildList deriving (Eq) instance Show SolverRes where show (SolverRes bl) = T.unpack $ prettyBuildList bl solverTest :: PkgPath -> T.Text -> Either T.Text [(PkgPath, T.Text)] -> TestTree solverTest p v expected = testCase (T.unpack $ p <> "-" <> prettySemVer v') $ fmap SolverRes (solveDepsPure testEnv target) @?= expected' where target = PkgRevDeps $ M.singleton p (v', Nothing) v' = semverE v expected' = SolverRes . BuildList . M.fromList . 
map onRes <$> expected onRes (dp, dv) = (dp, semverE dv) tests :: TestTree tests = testGroup "SolveTests" [ solverTest "athas/foo" "0.1.0" $ Right [("athas/foo", "0.1.0")], solverTest "athas/foo" "0.2.0" $ Right [ ("athas/foo", "0.2.0"), ("athas/bar", "1.0.0") ], solverTest "athas/quux" "0.1.0" $ Right [ ("athas/quux", "0.1.0"), ("athas/foo", "0.3.0"), ("athas/baz", "0.1.0") ], solverTest "athas/quux_perm" "0.1.0" $ Right [ ("athas/quux_perm", "0.1.0"), ("athas/foo", "0.3.0"), ("athas/baz", "0.1.0") ], solverTest "athas/foo@v2" "2.0.0" $ Right [ ("athas/foo@v2", "2.0.0"), ("athas/quux", "0.1.0"), ("athas/foo", "0.3.0"), ("athas/baz", "0.1.0") ], solverTest "athas/foo@v3" "3.0.0" $ Left "Unknown package/version: athas/foo@v3-3.0.0", solverTest "nasty/foo" "1.0.0" $ Right [ ("nasty/foo", "1.0.0"), ("nasty/bar", "1.0.0") ], solverTest "athas/tricky" "1.0.0" $ Right [ ("athas/tricky", "1.0.0"), ("athas/foo", "0.3.0"), ("athas/x_foo", "1.0.0") ] ] futhark-0.25.27/unittests/Futhark/ProfileTests.hs000066400000000000000000000011121475065116200217620ustar00rootroot00000000000000{-# OPTIONS_GHC -fno-warn-orphans #-} module Futhark.ProfileTests () where import Data.Map qualified as M import Data.Text qualified as T import Futhark.Profile import Test.Tasty.QuickCheck printable :: Gen String printable = getPrintableString <$> arbitrary arbText :: Gen T.Text arbText = T.pack <$> printable instance Arbitrary ProfilingEvent where arbitrary = ProfilingEvent <$> arbText <*> arbitrary <*> arbText instance Arbitrary ProfilingReport where arbitrary = ProfilingReport <$> arbitrary <*> (M.fromList <$> listOf ((,) <$> arbText <*> arbitrary)) 
futhark-0.25.27/unittests/Language/000077500000000000000000000000001475065116200171275ustar00rootroot00000000000000futhark-0.25.27/unittests/Language/Futhark/000077500000000000000000000000001475065116200205335ustar00rootroot00000000000000futhark-0.25.27/unittests/Language/Futhark/CoreTests.hs000066400000000000000000000005411475065116200230020ustar00rootroot00000000000000{-# OPTIONS_GHC -fno-warn-orphans #-} module Language.Futhark.CoreTests () where import Language.Futhark.Core import Language.Futhark.PrimitiveTests () import Test.QuickCheck instance Arbitrary Name where arbitrary = nameFromString <$> listOf1 (elements ['a' .. 'z']) instance Arbitrary VName where arbitrary = VName <$> arbitrary <*> arbitrary futhark-0.25.27/unittests/Language/Futhark/PrimitiveTests.hs000066400000000000000000000043231475065116200240640ustar00rootroot00000000000000{-# OPTIONS_GHC -fno-warn-orphans #-} module Language.Futhark.PrimitiveTests ( tests, arbitraryPrimValOfType, ) where import Control.Applicative import Futhark.Util (convFloat) import Language.Futhark.Primitive import Test.QuickCheck import Test.Tasty import Test.Tasty.HUnit import Prelude tests :: TestTree tests = testGroup "PrimitiveTests" [propPrimValuesHaveRightType] propPrimValuesHaveRightType :: TestTree propPrimValuesHaveRightType = testGroup "propPrimValuesHaveRightTypes" [ testCase (show t ++ " has blank of right type") $ primValueType (blankPrimValue t) @?= t | t <- [minBound .. maxBound] ] instance Arbitrary IntType where arbitrary = elements [minBound .. maxBound] instance Arbitrary FloatType where arbitrary = elements [minBound .. maxBound] instance Arbitrary PrimType where arbitrary = elements [minBound .. 
maxBound] instance Arbitrary IntValue where arbitrary = oneof [ Int8Value <$> arbitrary, Int16Value <$> arbitrary, Int32Value <$> arbitrary, Int64Value <$> arbitrary ] instance Arbitrary Half where arbitrary = (convFloat :: Float -> Half) <$> arbitrary instance Arbitrary FloatValue where arbitrary = oneof [ Float16Value <$> arbitrary, Float32Value <$> arbitrary, Float64Value <$> arbitrary ] instance Arbitrary PrimValue where arbitrary = oneof [ IntValue <$> arbitrary, FloatValue <$> arbitrary, BoolValue <$> arbitrary, pure UnitValue ] arbitraryPrimValOfType :: PrimType -> Gen PrimValue arbitraryPrimValOfType (IntType Int8) = IntValue . Int8Value <$> arbitrary arbitraryPrimValOfType (IntType Int16) = IntValue . Int16Value <$> arbitrary arbitraryPrimValOfType (IntType Int32) = IntValue . Int32Value <$> arbitrary arbitraryPrimValOfType (IntType Int64) = IntValue . Int64Value <$> arbitrary arbitraryPrimValOfType (FloatType Float16) = FloatValue . Float16Value <$> arbitrary arbitraryPrimValOfType (FloatType Float32) = FloatValue . Float32Value <$> arbitrary arbitraryPrimValOfType (FloatType Float64) = FloatValue . 
Float32Value <$> arbitrary arbitraryPrimValOfType Bool = BoolValue <$> arbitrary arbitraryPrimValOfType Unit = pure UnitValue futhark-0.25.27/unittests/Language/Futhark/SemanticTests.hs000066400000000000000000000014471475065116200236630ustar00rootroot00000000000000module Language.Futhark.SemanticTests (tests) where import Language.Futhark (ImportName (..)) import Language.Futhark.Semantic import Test.Tasty import Test.Tasty.HUnit tests :: TestTree tests = testGroup "Semantic objects" [ testCase "a" $ mkInitialImport "a" @?= ImportName "a", testCase "./a" $ mkInitialImport "./a" @?= ImportName "a", testCase "a/b -> ../c" $ mkImportFrom (mkInitialImport "a/b") "../c" @?= ImportName "c", testCase "a/b -> ../../c" $ mkImportFrom (mkInitialImport "a/b") "../../c" @?= ImportName "../c", testCase "../a -> b" $ mkImportFrom (mkInitialImport "../a") "b" @?= ImportName "../b", testCase "../a -> ../b" $ mkImportFrom (mkInitialImport "../a") "../b" @?= ImportName "../../b" ] futhark-0.25.27/unittests/Language/Futhark/SyntaxTests.hs000066400000000000000000000124271475065116200234060ustar00rootroot00000000000000{-# OPTIONS_GHC -fno-warn-orphans #-} module Language.Futhark.SyntaxTests (tests) where import Control.Applicative hiding (many, some) import Data.Bifunctor import Data.Char (isAlpha) import Data.Functor import Data.Map qualified as M import Data.String import Data.Text qualified as T import Data.Void import Language.Futhark import Language.Futhark.Parser import Language.Futhark.Primitive.Parse (constituent, keyword, lexeme) import Language.Futhark.PrimitiveTests () import Test.QuickCheck import Test.Tasty import Text.Megaparsec import Text.Megaparsec.Char.Lexer qualified as L import Prelude tests :: TestTree tests = testGroup "Source SyntaxTests" [] instance Arbitrary BinOp where arbitrary = elements [minBound .. 
maxBound] instance Arbitrary Uniqueness where arbitrary = elements [Unique, Nonunique] instance Arbitrary PrimType where arbitrary = oneof [ Signed <$> arbitrary, Unsigned <$> arbitrary, FloatType <$> arbitrary, pure Bool ] instance Arbitrary PrimValue where arbitrary = oneof [ SignedValue <$> arbitrary, UnsignedValue <$> arbitrary, FloatValue <$> arbitrary, BoolValue <$> arbitrary ] -- The following dirty instances make it slightly nicer to write unit tests. instance IsString VName where fromString s = let (s', '_' : tag) = span (/= '_') s in VName (fromString s') (read tag) instance (IsString v) => IsString (QualName v) where fromString = QualName [] . fromString instance IsString UncheckedTypeExp where fromString = either (error . T.unpack . syntaxErrorMsg) id . parseType "IsString UncheckedTypeExp" . fromString type Parser = Parsec Void T.Text braces, brackets, parens :: Parser a -> Parser a braces = between (lexeme "{") (lexeme "}") brackets = between (lexeme "[") (lexeme "]") parens = between (lexeme "(") (lexeme ")") pName :: Parser Name pName = lexeme . 
fmap nameFromString $ (:) <$> satisfy isAlpha <*> many (satisfy constituent) pVName :: Parser VName pVName = lexeme $ do (s, tag) <- satisfy constituent `manyTill_` try pTag "variable name" pure $ VName (nameFromString s) tag where pTag = "_" *> L.decimal <* notFollowedBy (satisfy constituent) pQualName :: Parser (QualName VName) pQualName = QualName [] <$> pVName pPrimType :: Parser PrimType pPrimType = choice $ map f [ Bool, Signed Int8, Signed Int16, Signed Int32, Signed Int64, Unsigned Int8, Unsigned Int16, Unsigned Int32, Unsigned Int64, FloatType Float32, FloatType Float64 ] where f t = keyword (prettyText t) $> t pUniqueness :: Parser Uniqueness pUniqueness = choice [lexeme "*" $> Unique, pure Nonunique] pSize :: Parser Size pSize = brackets $ choice [ flip sizeFromInteger mempty <$> lexeme L.decimal, flip sizeFromName mempty <$> pQualName ] pScalarNonFun :: Parser (ScalarTypeBase Size Uniqueness) pScalarNonFun = choice [ Prim <$> pPrimType, pTypeVar, tupleRecord <$> parens (pType `sepBy` lexeme ","), Record . M.fromList <$> braces (pField `sepBy1` lexeme ",") ] where pField = (,) <$> pName <* lexeme ":" <*> pType pTypeVar = TypeVar <$> pUniqueness <*> pQualName <*> many pTypeArg pTypeArg = choice [ TypeArgDim <$> pSize, TypeArgType . second (const NoUniqueness) <$> pTypeArgType ] pTypeArgType = choice [ Scalar . 
Prim <$> pPrimType, parens pType ] pArrayType :: Parser ResType pArrayType = Array <$> pUniqueness <*> (Shape <$> some pSize) <*> (second (const NoUniqueness) <$> pScalarNonFun) pNonFunType :: Parser ResType pNonFunType = choice [ try pArrayType, try $ parens pType, Scalar <$> pScalarNonFun ] pScalarType :: Parser (ScalarTypeBase Size Uniqueness) pScalarType = choice [try pFun, pScalarNonFun] where pFun = pParam <* lexeme "->" <*> pRetType pParam = choice [ try pNamedParam, do t <- pNonFunType pure $ Arrow Nonunique Unnamed (diet $ resToParam t) (toStruct t) ] pNamedParam = parens $ do v <- pVName <* lexeme ":" t <- pType pure $ Arrow Nonunique (Named v) (diet $ resToParam t) (toStruct t) pRetType :: Parser ResRetType pRetType = choice [ lexeme "?" *> (RetType <$> some (brackets pVName) <* lexeme "." <*> pType), RetType [] <$> pType ] pType :: Parser ResType pType = choice [try $ Scalar <$> pScalarType, pArrayType, parens pType] fromStringParse :: Parser a -> String -> String -> a fromStringParse p what s = either onError id $ parse (p <* eof) "" (T.pack s) where onError e = error $ "not a " <> what <> ": " <> s <> "\n" <> errorBundlePretty e instance IsString (ScalarTypeBase Size NoUniqueness) where fromString = fromStringParse (second (const NoUniqueness) <$> pScalarType) "ScalarType" instance IsString StructType where fromString = fromStringParse (second (const NoUniqueness) <$> pType) "StructType" instance IsString StructRetType where fromString = fromStringParse (second (pure NoUniqueness) <$> pRetType) "StructRetType" instance IsString ResRetType where fromString = fromStringParse pRetType "ResRetType" futhark-0.25.27/unittests/Language/Futhark/TypeChecker/000077500000000000000000000000001475065116200227415ustar00rootroot00000000000000futhark-0.25.27/unittests/Language/Futhark/TypeChecker/TypesTests.hs000066400000000000000000000144451475065116200254340ustar00rootroot00000000000000module Language.Futhark.TypeChecker.TypesTests (tests) where import 
Data.Bifunctor import Data.List (isInfixOf) import Data.Map qualified as M import Data.Text qualified as T import Futhark.FreshNames import Futhark.Util.Pretty (docText, prettyTextOneLine) import Language.Futhark import Language.Futhark.Semantic import Language.Futhark.SyntaxTests () import Language.Futhark.TypeChecker (initialEnv) import Language.Futhark.TypeChecker.Monad import Language.Futhark.TypeChecker.Names (resolveTypeExp) import Language.Futhark.TypeChecker.Terms import Language.Futhark.TypeChecker.Types import Test.Tasty import Test.Tasty.HUnit evalTest :: TypeExp (ExpBase NoInfo Name) Name -> Either String ([VName], ResRetType) -> TestTree evalTest te expected = testCase (prettyString te) $ case (fmap (extract . fst) (run (checkTypeExp checkSizeExp =<< resolveTypeExp te)), expected) of (Left got_e, Left expected_e) -> let got_e_s = T.unpack $ docText $ prettyTypeError got_e in (expected_e `isInfixOf` got_e_s) @? got_e_s (Left got_e, Right _) -> let got_e_s = T.unpack $ docText $ prettyTypeError got_e in assertFailure $ "Failed: " <> got_e_s (Right actual_t, Right expected_t) -> actual_t @?= expected_t (Right actual_t, Left _) -> assertFailure $ "Expected error, got: " <> show actual_t where extract (_, svars, t, _) = (svars, t) run = snd . runTypeM env mempty (mkInitialImport "") (newNameSource 100) -- We hack up an environment with some predefined type -- abbreviations for testing. This is all pretty sensitive to the -- specific unique names, so we have to be careful! 
env = initialEnv { envTypeTable = M.fromList [ ( "square_1000", TypeAbbr Unlifted [TypeParamDim "n_1001" mempty] "[n_1001][n_1001]i32" ), ( "fun_1100", TypeAbbr Lifted [ TypeParamType Lifted "a_1101" mempty, TypeParamType Lifted "b_1102" mempty ] "a_1101 -> b_1102" ), ( "pair_1200", TypeAbbr SizeLifted [] "?[n_1201][m_1202].([n_1201]i64, [m_1202]i64)" ) ] <> envTypeTable initialEnv, envNameMap = M.fromList [ ((Type, "square"), "square_1000"), ((Type, "fun"), "fun_1100"), ((Type, "pair"), "pair_1200") ] <> envNameMap initialEnv } evalTests :: TestTree evalTests = testGroup "Type expression elaboration" [ testGroup "Positive tests" (map mkPos pos), testGroup "Negative tests" (map mkNeg neg) ] where mkPos (x, y) = evalTest x (Right y) mkNeg (x, y) = evalTest x (Left y) pos = [ ( "[]i32", ([], "?[d_100].[d_100]i32") ), ( "[][]i32", ([], "?[d_100][d_101].[d_100][d_101]i32") ), ( "bool -> []i32", ([], "bool -> ?[d_100].[d_100]i32") ), ( "bool -> []f32 -> []i32", (["d_100"], "bool -> [d_100]f32 -> ?[d_101].[d_101]i32") ), ( "([]i32,[]i32)", ([], "?[d_100][d_101].([d_100]i32, [d_101]i32)") ), ( "{a:[]i32,b:[]i32}", ([], "?[d_100][d_101].{a:[d_100]i32, b:[d_101]i32}") ), ( "?[n].[n][n]bool", ([], "?[n_100].[n_100][n_100]bool") ), ( "([]i32 -> []i32) -> bool -> []i32", (["d_100"], "([d_100]i32 -> ?[d_101].[d_101]i32) -> bool -> ?[d_102].[d_102]i32") ), ( "((k: i64) -> [k]i32 -> [k]i32) -> []i32 -> bool", (["d_101"], "((k_100: i64) -> [k_100]i32 -> [k_100]i32) -> [d_101]i32 -> bool") ), ( "square [10]", ([], "[10][10]i32") ), ( "square []", ([], "?[d_100].[d_100][d_100]i32") ), ( "bool -> square []", ([], "bool -> ?[d_100].[d_100][d_100]i32") ), ( "(k: i64) -> square [k]", ([], "(k_100: i64) -> [k_100][k_100]i32") ), ( "fun i32 bool", ([], "i32 -> bool") ), ( "fun ([]i32) bool", ([], "?[d_100].[d_100]i32 -> bool") ), ( "fun bool ([]i32)", ([], "?[d_100].bool -> [d_100]i32") ), ( "bool -> fun ([]i32) bool", ([], "bool -> ?[d_100].[d_100]i32 -> bool") ), ( "bool -> fun bool 
([]i32)", ([], "bool -> ?[d_100].bool -> [d_100]i32") ), ( "pair", ([], "?[n_100][m_101].([n_100]i64, [m_101]i64)") ), ( "(pair,pair)", ([], "?[n_100][m_101][n_102][m_103].(([n_100]i64, [m_101]i64), ([n_102]i64, [m_103]i64))") ) ] neg = [ ("?[n].bool", "Existential size \"n\""), ("?[n].bool -> [n]bool", "Existential size \"n\""), ("?[n].[n]bool -> [n]bool", "Existential size \"n\""), ("?[n].[n]bool -> bool", "Existential size \"n\"") ] substTest :: M.Map VName (Subst StructRetType) -> StructRetType -> StructRetType -> TestTree substTest m t expected = testCase (pretty_m <> ": " <> T.unpack (prettyTextOneLine t)) $ applySubst (`M.lookup` m) t @?= expected where pretty_m = T.unpack $ prettyText $ map (first toName) $ M.toList m -- Some of these tests may be a bit fragile, in that they depend on -- internal renumbering, which can be arbitrary. substTests :: TestTree substTests = testGroup "Type substitution" [ substTest m0 "t_0" "i64", substTest m0 "[1]t_0" "[1]i64", substTest m0 "?[n_10].[n_10]t_0" "?[n_10].[n_10]i64", -- substTest m1 "t_0" "?[n_1].[n_1]bool", substTest m1 "f32 -> t_0" "f32 -> ?[n_1].[n_1]bool", substTest m1 "f32 -> f64 -> t_0" "f32 -> f64 -> ?[n_1].[n_1]bool", substTest m1 "f32 -> t_0 -> bool" "?[n_1].f32 -> [n_1]bool -> bool", substTest m1 "f32 -> t_0 -> t_0" "?[n_1].f32 -> [n_1]bool -> ?[n_2].[n_2]bool" ] where m0 = M.fromList [("t_0", Subst [] "i64")] m1 = M.fromList [("t_0", Subst [] "?[n_1].[n_1]bool")] tests :: TestTree tests = testGroup "Basic type operations" [evalTests, substTests] futhark-0.25.27/unittests/Language/Futhark/TypeCheckerTests.hs000066400000000000000000000004041475065116200243160ustar00rootroot00000000000000module Language.Futhark.TypeCheckerTests (tests) where import Language.Futhark.TypeChecker.TypesTests qualified import Test.Tasty tests :: TestTree tests = testGroup "Source type checker tests" [ Language.Futhark.TypeChecker.TypesTests.tests ] 
futhark-0.25.27/unittests/futhark_tests.hs000066400000000000000000000027541475065116200206360ustar00rootroot00000000000000module Main (main) where import Futhark.AD.DerivativesTests qualified import Futhark.Analysis.AlgSimplifyTests qualified import Futhark.BenchTests qualified import Futhark.IR.Mem.IntervalTests qualified import Futhark.IR.Mem.IxFunTests qualified import Futhark.IR.PropTests qualified import Futhark.IR.Syntax.CoreTests qualified import Futhark.Internalise.TypesValuesTests qualified import Futhark.Optimise.ArrayLayoutTests qualified import Futhark.Optimise.MemoryBlockMerging.GreedyColoringTests qualified import Futhark.Pkg.SolveTests qualified import Language.Futhark.PrimitiveTests qualified import Language.Futhark.SemanticTests qualified import Language.Futhark.SyntaxTests qualified import Language.Futhark.TypeCheckerTests qualified import Test.Tasty allTests :: TestTree allTests = testGroup "" [ Language.Futhark.SyntaxTests.tests, Futhark.AD.DerivativesTests.tests, Futhark.BenchTests.tests, Futhark.IR.PropTests.tests, Futhark.IR.Syntax.CoreTests.tests, Futhark.Pkg.SolveTests.tests, Futhark.Internalise.TypesValuesTests.tests, Futhark.IR.Mem.IntervalTests.tests, Futhark.IR.Mem.IxFunTests.tests, Language.Futhark.PrimitiveTests.tests, Futhark.Optimise.MemoryBlockMerging.GreedyColoringTests.tests, Futhark.Analysis.AlgSimplifyTests.tests, Language.Futhark.TypeCheckerTests.tests, Language.Futhark.SemanticTests.tests, Futhark.Optimise.ArrayLayoutTests.tests ] main :: IO () main = defaultMain allTests futhark-0.25.27/weeder.toml000066400000000000000000000015101475065116200155070ustar00rootroot00000000000000roots = [ # The entry points for the main CLI program and the unit tests. "^Main.main$" # Modules intended as externally visible for library code. , "^Language.Futhark.Query" , "^Language.Futhark.Parser" , "^Futhark.Bench.decodeBenchResults" # Generated code that we cannot do anything about. 
, "^Paths_futhark" # Code that might technically be dead right now, but is kept around # for consistency of the internal API. , "^Futhark.Analysis.PrimExp" , "^Futhark.Builder" , "^Futhark.Construct.eConvOp" , "^Futhark.Pass.ExtractKernels.Distribution.ppKernelNest" , "^Futhark.Representation.AST.Attributes.Types.int16" , "^Futhark.Representation.AST.Attributes.Types.float32" , "^Futhark.Representation.AST.Attributes.Types.float64" , "^Futhark.Representation.PrimitiveTests" ] type-class-roots = true